//===- LoopStrengthReduce.cpp - Strength Reduce IVs in Loops --------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This transformation analyzes and transforms the induction variables (and
// computations derived from them) into forms suitable for efficient execution
// on the target.
//
// This pass performs strength reduction on array references inside loops that
// involve the loop induction variable in one or more of their components. It
// rewrites expressions to take advantage of scaled-index addressing modes
// available on the target, and it performs a variety of other optimizations
// related to loop induction variables.
//
// Terminology note: this code has a lot of handling for "post-increment" or
// "post-inc" users. This is not talking about post-increment addressing modes;
// it is instead talking about code like this:
//
//   %i = phi [ 0, %entry ], [ %i.next, %latch ]
//   ...
//   %i.next = add %i, 1
//   %c = icmp eq %i.next, %n
//
// The SCEV for %i is {0,+,1}<%L>. The SCEV for %i.next is {1,+,1}<%L>. However,
// it's useful to think about these as the same register, with some uses using
// the value of the register before the add and some using it after. In this
// example, the icmp is a post-increment user, since it uses %i.next, which is
// the value of the induction variable after the increment. The other common
// case of post-increment users is users outside the loop.
//
// TODO: More sophistication in the way Formulae are generated.
//
// TODO: Handle multiple loops at a time.
//
// TODO: test/CodeGen/X86/full-lsr.ll should get full lsr. The problem is
//       that {0,+,1}<%bb> is getting picked first because all 7 uses can
//       use it, and while it's a pretty good solution, it means that LSR
//       doesn't look further to find an even better solution.
//
// TODO: Should TargetLowering::AddrMode::BaseGV be changed to a ConstantExpr
//       instead of a GlobalValue?
//
// TODO: When truncation is free, truncate ICmp users' operands to make it a
//       smaller encoding (on x86 at least).
//
// TODO: When a negated register is used by an add (such as in a list of
//       multiple base registers, or as the increment expression in an addrec),
//       we may not actually need both reg and (-1 * reg) in registers; the
//       negation can be implemented by using a sub instead of an add. The
//       lack of support for taking this into consideration when making
//       register pressure decisions is partly worked around by the "Special"
//       use kind.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "loop-reduce"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Constants.h"
#include "llvm/Instructions.h"
#include "llvm/IntrinsicInst.h"
#include "llvm/Analysis/IVUsers.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/ValueHandle.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLowering.h"
#include <algorithm>

using namespace llvm;

// Constant strides come first, sorted by their absolute values. If the
// absolute values are equal, positive strides come first.
// e.g.
// 4, -1, X, 1, 2 ==> 1, -1, 2, 4, X
struct StrideCompare {
  const ScalarEvolution &SE;
  explicit StrideCompare(const ScalarEvolution &se) : SE(se) {}

  bool operator()(const SCEV *const &LHS, const SCEV *const &RHS) const {
    const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS);
    const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
    if (LHSC && RHSC) {
      unsigned BitWidth = std::max(SE.getTypeSizeInBits(LHS->getType()),
                                   SE.getTypeSizeInBits(RHS->getType()));
      APInt  LV = LHSC->getValue()->getValue();
      APInt  RV = RHSC->getValue()->getValue();
      LV.sextOrTrunc(BitWidth);
      RV.sextOrTrunc(BitWidth);
      APInt ALV = LV.abs();
      APInt ARV = RV.abs();
      if (ALV == ARV) {
        if (LV != RV)
          return LV.sgt(RV);
      } else {
        return ALV.ult(ARV);
      }
      // If it's the same value but different type, sort by bit width so
      // that we emit larger induction variables before smaller
      // ones, letting the smaller be re-written in terms of larger ones.
      return SE.getTypeSizeInBits(RHS->getType()) <
             SE.getTypeSizeInBits(LHS->getType());
    }
    return LHSC && !RHSC;
  }
};
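
// Illustrative usage (hypothetical, not from the original source): a
// stride-keyed map can be kept in the order described above by supplying the
// comparator at construction time:
//
//   std::map<const SCEV *, Uses, StrideCompare> StrideMap((StrideCompare(SE)));
//
// where "Uses" stands for whatever per-stride payload the client tracks.
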
/// RegSortData - This class holds data which is used to order reuse
/// candidates.
class RegSortData {
public:
  /// Bits - This represents the set of LSRUses (by index) which reference a
  /// particular register.
  SmallBitVector Bits;
  /// MaxComplexity - This represents the greatest complexity (see the comments
  /// on Formula::getComplexity) seen with a particular register.
  uint32_t MaxComplexity;
  /// Index - This holds an arbitrary value used as a last-resort tie breaker
  /// to ensure deterministic behavior.
  unsigned Index;
  void print(raw_ostream &OS) const;
  void dump() const;
};
void RegSortData::print(raw_ostream &OS) const {
  OS << "[NumUses=" << Bits.count()
     << ", MaxComplexity=";
  OS.write_hex(MaxComplexity);
  OS << ", Index=" << Index << ']';
void RegSortData::dump() const {
  print(errs()); errs() << '\n';
}
/// RegCount - This is a helper class to sort a given set of registers
/// according to associated RegSortData values.
class RegCount {
public:
  const SCEV *Reg;
  RegSortData Sort;

  RegCount(const SCEV *R, const RegSortData &RSD)
    : Reg(R), Sort(RSD) {}

  // Sort by count. Returning true means the register is preferred.
  bool operator<(const RegCount &Other) const {
    // Sort by the number of unique uses of this register.
    unsigned A = Sort.Bits.count();
    unsigned B = Other.Sort.Bits.count();
    if (A != B) return A > B;

    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg)) {
      const SCEVAddRecExpr *BR = dyn_cast<SCEVAddRecExpr>(Other.Reg);
      // AddRecs have higher priority than other things.
      if (!BR) return true;

      // Prefer affine values.
      if (AR->isAffine() != BR->isAffine())
        return AR->isAffine();

      const Loop *AL = AR->getLoop();
      const Loop *BL = BR->getLoop();
      if (AL != BL) {
        unsigned ADepth = AL->getLoopDepth();
        unsigned BDepth = BL->getLoopDepth();
        // Prefer a less deeply nested addrec.
        if (ADepth != BDepth)
          return ADepth < BDepth;

        // Different loops at the same depth; do something arbitrary.
        BasicBlock *AH = AL->getHeader();
        BasicBlock *BH = BL->getHeader();
        for (Function::iterator I = AH, E = AH->getParent()->end(); I != E; ++I)
          if (&*I == BH) return true;
        return false;
      }

      // Sort addrecs by stride.
      const SCEV *AStep = AR->getOperand(1);
      const SCEV *BStep = BR->getOperand(1);
      if (AStep != BStep) {
        if (const SCEVConstant *AC = dyn_cast<SCEVConstant>(AStep)) {
          const SCEVConstant *BC = dyn_cast<SCEVConstant>(BStep);
          if (!BC) return true;
          // Arbitrarily prefer wider registers.
          if (AC->getValue()->getValue().getBitWidth() !=
              BC->getValue()->getValue().getBitWidth())
            return AC->getValue()->getValue().getBitWidth() >
                   BC->getValue()->getValue().getBitWidth();
          // Ignore the sign bit, assuming that striding by a negative value
          // is just as easy as by a positive value.
          // Prefer the addrec with the lesser absolute value stride, as it
          // will allow uses to have simpler addressing modes.
          return AC->getValue()->getValue().abs()
            .ult(BC->getValue()->getValue().abs());
        }
      }
      // Then sort by the register which will permit the simplest uses.
      // This is a heuristic; currently we only track the most complex use as a
      // representative.
      if (Sort.MaxComplexity != Other.Sort.MaxComplexity)
        return Sort.MaxComplexity < Other.Sort.MaxComplexity;

      // Then sort them by their start values.
      const SCEV *AStart = AR->getStart();
      const SCEV *BStart = BR->getStart();
      if (AStart != BStart) {
        if (const SCEVConstant *AC = dyn_cast<SCEVConstant>(AStart)) {
          const SCEVConstant *BC = dyn_cast<SCEVConstant>(BStart);
          if (!BC) return true;
          // Arbitrarily prefer wider registers.
          if (AC->getValue()->getValue().getBitWidth() !=
              BC->getValue()->getValue().getBitWidth())
            return AC->getValue()->getValue().getBitWidth() >
                   BC->getValue()->getValue().getBitWidth();
          // Prefer positive over negative if the absolute values are the same.
          if (AC->getValue()->getValue().abs() ==
              BC->getValue()->getValue().abs())
            return AC->getValue()->getValue().isStrictlyPositive();
          // Prefer the addrec with the lesser absolute value start.
          return AC->getValue()->getValue().abs()
            .ult(BC->getValue()->getValue().abs());
        }
      }
    } else {
      // AddRecs have higher priority than other things.
      if (isa<SCEVAddRecExpr>(Other.Reg)) return false;
      // Sort by the register which will permit the simplest uses.
      // This is a heuristic; currently we only track the most complex use as a
      // representative.
      if (Sort.MaxComplexity != Other.Sort.MaxComplexity)
        return Sort.MaxComplexity < Other.Sort.MaxComplexity;
    }

    // Tie-breaker: the arbitrary index, to ensure a reliable ordering.
    return Sort.Index < Other.Sort.Index;
  }

  void print(raw_ostream &OS) const;
  void dump() const;
};

void RegCount::print(raw_ostream &OS) const {
  OS << *Reg << ':';
  Sort.print(OS);
}

void RegCount::dump() const {
  print(errs()); errs() << '\n';
}
/// Formula - This class holds information that describes a formula for
/// satisfying a use. It may include broken-out immediates and scaled registers.
struct Formula {
  /// AM - This is used to represent complex addressing, as well as other kinds
  /// of interesting uses.
  TargetLowering::AddrMode AM;
  /// BaseRegs - The list of "base" registers for this use. When this is
  /// non-empty, AM.HasBaseReg should be set to true.
  SmallVector<const SCEV *, 2> BaseRegs;
  /// ScaledReg - The 'scaled' register for this use. This should be non-null
  /// when AM.Scale is not zero.
  const SCEV *ScaledReg;
  unsigned getNumRegs() const;
  uint32_t getComplexity() const;
  const Type *getType() const;
  void InitialMatch(const SCEV *S, Loop *L,
                    ScalarEvolution &SE, DominatorTree &DT);
  /// referencesReg - Test if this formula references the given register.
  bool referencesReg(const SCEV *S) const {
    return S == ScaledReg ||
           std::find(BaseRegs.begin(), BaseRegs.end(), S) != BaseRegs.end();
  }

  bool operator==(const Formula &Other) const {
    return BaseRegs == Other.BaseRegs &&
           ScaledReg == Other.ScaledReg &&
           AM.Scale == Other.AM.Scale &&
           AM.BaseOffs == Other.AM.BaseOffs &&
           AM.BaseGV == Other.AM.BaseGV;
  }
  // This sorts at least partially based on host pointer values which are
  // not deterministic, so it is only usable for uniquification.
  bool operator<(const Formula &Other) const {
    if (BaseRegs != Other.BaseRegs)
      return BaseRegs < Other.BaseRegs;
    if (ScaledReg != Other.ScaledReg)
      return ScaledReg < Other.ScaledReg;
    if (AM.Scale != Other.AM.Scale)
      return AM.Scale < Other.AM.Scale;
    if (AM.BaseOffs != Other.AM.BaseOffs)
      return AM.BaseOffs < Other.AM.BaseOffs;
    if (AM.BaseGV != Other.AM.BaseGV)
      return AM.BaseGV < Other.AM.BaseGV;
    return false;
  }

  void print(raw_ostream &OS) const;
  void dump() const;
};
/// getNumRegs - Return the total number of register operands used by this
/// formula. This does not include register uses implied by non-constant
/// addrec strides.
unsigned Formula::getNumRegs() const {
  return !!ScaledReg + BaseRegs.size();
}
/// getComplexity - Return an oversimplified value indicating the complexity
/// of this formula. This is used as a tie-breaker in choosing register
/// preferences.
uint32_t Formula::getComplexity() const {
  // Encode the information in a uint32_t so that comparing with operator<
  // will be meaningful.
  return
    // Most significant, the number of registers. This saturates because we
    // need the bits, and because beyond a few hundred it doesn't really matter.
    (std::min(getNumRegs(), (1u<<15)-1) << 17) |
    // Having multiple base regs is worse than having a base reg and a scale.
    ((BaseRegs.size() > 1) << 16) |
    // Scale absolute value.
    ((AM.Scale != 0 ? (Log2_64(abs64(AM.Scale)) + 1) : 0u) << 9) |
    // Scale sign, which is less significant than the absolute value.
    ((AM.Scale < 0) << 8) |
    // Offset absolute value.
    ((AM.BaseOffs != 0 ? (Log2_64(abs64(AM.BaseOffs)) + 1) : 0u) << 1) |
    // If a GV is present, treat it like a maximal offset.
    ((AM.BaseGV ? ((1u<<7)-1) : 0) << 1) |
    // Offset sign, which is less significant than the absolute offset.
    ((AM.BaseOffs < 0) << 0);
}
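
// For illustration (not from the original source), a formula with two base
// registers, scale 4, and base offset -8 encodes as
//
//     (2 << 17)                 // two registers
//   | (1 << 16)                 // multiple base regs
//   | ((Log2_64(4) + 1) << 9)   // scale magnitude: 3 << 9
//   | (0 << 8)                  // non-negative scale
//   | ((Log2_64(8) + 1) << 1)   // offset magnitude: 4 << 1
//   | (1 << 0)                  // negative offset
//
// so any formula using fewer registers always compares as less complex.
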
/// getType - Return the type of this formula, if it has one, or null
/// otherwise. This type is meaningless except for the bit size.
const Type *Formula::getType() const {
  return !BaseRegs.empty() ? BaseRegs.front()->getType() :
         ScaledReg ? ScaledReg->getType() :
         AM.BaseGV ? AM.BaseGV->getType() :
         0;
}

/// ComplexitySorter - A predicate which orders Formulae by the number of
/// registers they contain.
struct ComplexitySorter {
  bool operator()(const Formula &LHS, const Formula &RHS) const {
    unsigned L = LHS.getNumRegs();
    unsigned R = RHS.getNumRegs();
    if (L != R) return L < R;
    return LHS.getComplexity() < RHS.getComplexity();
  }
};
/// DoInitialMatch - Recursion helper for InitialMatch.
static void DoInitialMatch(const SCEV *S, Loop *L,
                           SmallVectorImpl<const SCEV *> &Good,
                           SmallVectorImpl<const SCEV *> &Bad,
                           ScalarEvolution &SE, DominatorTree &DT) {
  // Collect expressions which properly dominate the loop header.
  if (S->properlyDominates(L->getHeader(), &DT)) {
    Good.push_back(S);
    return;
  }

  // Look at add operands.
  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
    for (SCEVAddExpr::op_iterator I = Add->op_begin(), E = Add->op_end();
         I != E; ++I)
      DoInitialMatch(*I, L, Good, Bad, SE, DT);
    return;
  }

  // Look at addrec operands.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
    if (!AR->getStart()->isZero()) {
      DoInitialMatch(AR->getStart(), L, Good, Bad, SE, DT);
      DoInitialMatch(SE.getAddRecExpr(SE.getIntegerSCEV(0, AR->getType()),
                                      AR->getStepRecurrence(SE),
                                      AR->getLoop()),
                     L, Good, Bad, SE, DT);
      return;
    }
  }

  // Handle a multiplication by -1 (negation) if it didn't fold.
  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S))
    if (Mul->getOperand(0)->isAllOnesValue()) {
      SmallVector<const SCEV *, 4> Ops(Mul->op_begin()+1, Mul->op_end());
      const SCEV *NewMul = SE.getMulExpr(Ops);

      SmallVector<const SCEV *, 4> MyGood;
      SmallVector<const SCEV *, 4> MyBad;
      DoInitialMatch(NewMul, L, MyGood, MyBad, SE, DT);
      const SCEV *NegOne = SE.getSCEV(ConstantInt::getAllOnesValue(
        SE.getEffectiveSCEVType(NewMul->getType())));
      for (SmallVectorImpl<const SCEV *>::const_iterator I = MyGood.begin(),
           E = MyGood.end(); I != E; ++I)
        Good.push_back(SE.getMulExpr(NegOne, *I));
      for (SmallVectorImpl<const SCEV *>::const_iterator I = MyBad.begin(),
           E = MyBad.end(); I != E; ++I)
        Bad.push_back(SE.getMulExpr(NegOne, *I));
      return;
    }

  // Ok, we can't do anything interesting. Just stuff the whole thing into a
  // register and hope for the best.
  Bad.push_back(S);
}

/// InitialMatch - Incorporate loop-variant parts of S into this Formula,
/// attempting to keep all loop-invariant and loop-computable values in a
/// single base register.
void Formula::InitialMatch(const SCEV *S, Loop *L,
                           ScalarEvolution &SE, DominatorTree &DT) {
  SmallVector<const SCEV *, 4> Good;
  SmallVector<const SCEV *, 4> Bad;
  DoInitialMatch(S, L, Good, Bad, SE, DT);
  if (!Good.empty()) {
    BaseRegs.push_back(SE.getAddExpr(Good));
    AM.HasBaseReg = true;
  }
  if (!Bad.empty()) {
    BaseRegs.push_back(SE.getAddExpr(Bad));
    AM.HasBaseReg = true;
  }
}
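
// For example (illustrative, not from the original source): given
// S = {A,+,4}<%L> with A loop-invariant, DoInitialMatch yields Good = [A]
// and Bad = [{0,+,4}<%L>], so the initial formula has two base registers:
// reg(A) + reg({0,+,4}<%L>).
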
void Formula::print(raw_ostream &OS) const {
  bool First = true;
  if (AM.BaseGV) {
    if (!First) OS << " + "; else First = false;
    WriteAsOperand(OS, AM.BaseGV, /*PrintType=*/false);
  }
  if (AM.BaseOffs != 0) {
    if (!First) OS << " + "; else First = false;
    OS << AM.BaseOffs;
  }
  for (SmallVectorImpl<const SCEV *>::const_iterator I = BaseRegs.begin(),
       E = BaseRegs.end(); I != E; ++I) {
    if (!First) OS << " + "; else First = false;
    OS << "reg(";
    OS << **I;
    OS << ")";
  }
  if (AM.Scale != 0) {
    if (!First) OS << " + "; else First = false;
    OS << AM.Scale << "*reg(";
    if (ScaledReg)
      OS << *ScaledReg;
    else
      OS << "<unknown>";
    OS << ")";
  }
}
void Formula::dump() const {
  print(errs()); errs() << '\n';
}
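
// As an illustration (not from the original source), a use of A[i] where A is
// a global array of i32 might be described by the formula
//
//   A + 4*reg({0,+,1}<%loop>)
//
// i.e. AM.BaseGV = A, AM.Scale = 4, and ScaledReg = {0,+,1}<%loop>, which is
// exactly how print() renders it.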

/// getSDiv - Return an expression for LHS /s RHS, if it can be determined,
/// or null otherwise. If IgnoreSignificantBits is true, expressions like
/// (X * Y) /s Y are simplified to X, ignoring that the multiplication may
/// overflow, which is useful when the result will be used in a context where
/// the most significant bits are ignored.
static const SCEV *getSDiv(const SCEV *LHS, const SCEV *RHS,
                           ScalarEvolution &SE,
                           bool IgnoreSignificantBits = false) {
  // Handle the trivial case, which works for any SCEV type.
  if (LHS == RHS)
    return SE.getIntegerSCEV(1, LHS->getType());

  // Handle x /s -1 as x * -1, to give ScalarEvolution a chance to do some
  // folding.
  if (RHS->isAllOnesValue())
    return SE.getMulExpr(LHS, RHS);

  // Check for a division of a constant by a constant.
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(LHS)) {
    const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS);
    if (!RC)
      return 0;
    if (C->getValue()->getValue().srem(RC->getValue()->getValue()) != 0)
      return 0;
    return SE.getConstant(C->getValue()->getValue()
               .sdiv(RC->getValue()->getValue()));
  }

  // Distribute the sdiv over addrec operands.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS)) {
    const SCEV *Start = getSDiv(AR->getStart(), RHS, SE,
                                IgnoreSignificantBits);
    if (!Start) return 0;
    const SCEV *Step = getSDiv(AR->getStepRecurrence(SE), RHS, SE,
                               IgnoreSignificantBits);
    if (!Step) return 0;
    return SE.getAddRecExpr(Start, Step, AR->getLoop());
  }

  // Distribute the sdiv over add operands.
  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(LHS)) {
    SmallVector<const SCEV *, 8> Ops;
    for (SCEVAddExpr::op_iterator I = Add->op_begin(), E = Add->op_end();
         I != E; ++I) {
      const SCEV *Op = getSDiv(*I, RHS, SE,
                               IgnoreSignificantBits);
      if (!Op) return 0;
      Ops.push_back(Op);
    }
    return SE.getAddExpr(Ops);
  }

  // Check for a multiply operand that we can pull RHS out of.
  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS))
    if (IgnoreSignificantBits || Mul->hasNoSignedWrap()) {
      SmallVector<const SCEV *, 4> Ops;
      bool Found = false;
      for (SCEVMulExpr::op_iterator I = Mul->op_begin(), E = Mul->op_end();
           I != E; ++I) {
        if (!Found)
          if (const SCEV *Q = getSDiv(*I, RHS, SE, IgnoreSignificantBits)) {
            Ops.push_back(Q);
            Found = true;
            continue;
          }
        Ops.push_back(*I);
      }
      return Found ? SE.getMulExpr(Ops) : 0;
    }

  // Otherwise we don't know.
  return 0;
}
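
// For example (illustrative, not from the original source):
//
//   getSDiv({8,+,4}<%L>, 4, SE)  ==>  {2,+,1}<%L>
//   getSDiv({8,+,4}<%L>, 3, SE)  ==>  null, since 8 /s 3 has a remainder
//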
/// LSRUse - This class holds the state that LSR keeps for each use in
/// IVUsers, as well as uses invented by LSR itself. It includes information
/// about what kinds of things can be folded into the user, information
/// about the user itself, and information about how the use may be satisfied.
/// TODO: Represent multiple users of the same expression in common?
class LSRUse {
  SmallSet<Formula, 8> FormulaeUniquifier;

public:
  /// KindType - An enum for a kind of use, indicating what types of
  /// scaled and immediate operands it might support.
  enum KindType {
    Basic,   ///< A normal use, with no folding.
    Special, ///< A special case of basic, allowing -1 scales.
    Address, ///< An address use; folding according to TargetLowering
    ICmpZero ///< An equality icmp with both operands folded into one.
    // TODO: Add a generic icmp too?
  };
  KindType Kind;
  const Type *AccessTy;
  Instruction *UserInst;
  Value *OperandValToReplace;
  /// PostIncLoop - If this user is to use the post-incremented value of an
  /// induction variable, this variable is non-null and holds the loop
  /// associated with the induction variable.
  const Loop *PostIncLoop;
  /// Formulae - A list of ways to build a value that can satisfy this user.
  /// After the list is populated, one of these is selected heuristically and
  /// used to formulate a replacement for OperandValToReplace in UserInst.
  SmallVector<Formula, 12> Formulae;
  LSRUse() : Kind(Basic), AccessTy(0),
             UserInst(0), OperandValToReplace(0), PostIncLoop(0) {}
  void InsertInitialFormula(const SCEV *S, Loop *L,
                            ScalarEvolution &SE, DominatorTree &DT);
  void InsertSupplementalFormula(const SCEV *S);
  bool InsertFormula(const Formula &F);
  void Rewrite(Loop *L, Instruction *IVIncInsertPos,
               SCEVExpander &Rewriter,
               SmallVectorImpl<WeakVH> &DeadInsts,
               ScalarEvolution &SE, DominatorTree &DT,
               Pass *P) const;
  void print(raw_ostream &OS) const;
  void dump() const;
  Value *Expand(BasicBlock::iterator IP, Loop *L, Instruction *IVIncInsertPos,
                SCEVExpander &Rewriter,
                SmallVectorImpl<WeakVH> &DeadInsts,
                ScalarEvolution &SE, DominatorTree &DT) const;
};
/// ExtractImmediate - If S involves the addition of a constant integer value,
/// return that integer value, and mutate S to point to a new SCEV with that
/// value excluded.
static int64_t ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
    if (C->getValue()->getValue().getMinSignedBits() <= 64) {
      S = SE.getIntegerSCEV(0, C->getType());
      return C->getValue()->getSExtValue();
    }
  } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
    SmallVector<const SCEV *, 8> NewOps(Add->op_begin(), Add->op_end());
    int64_t Result = ExtractImmediate(NewOps.front(), SE);
    S = SE.getAddExpr(NewOps);
    return Result;
  } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
    SmallVector<const SCEV *, 8> NewOps(AR->op_begin(), AR->op_end());
    int64_t Result = ExtractImmediate(NewOps.front(), SE);
    S = SE.getAddRecExpr(NewOps, AR->getLoop());
    return Result;
  }
  return 0;
}
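
// For example (illustrative, not from the original source): if S is
// (%a + 42), ExtractImmediate returns 42 and leaves S as %a; if S is
// {42,+,1}<%L>, it returns 42 and leaves S as {0,+,1}<%L>.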

/// ExtractSymbol - If S involves the addition of a GlobalValue address,
/// return that symbol, and mutate S to point to a new SCEV with that
/// value excluded.
static GlobalValue *ExtractSymbol(const SCEV *&S, ScalarEvolution &SE) {
  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
    if (GlobalValue *GV = dyn_cast<GlobalValue>(U->getValue())) {
      S = SE.getIntegerSCEV(0, GV->getType());
      return GV;
    }
  } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
    SmallVector<const SCEV *, 8> NewOps(Add->op_begin(), Add->op_end());
    GlobalValue *Result = ExtractSymbol(NewOps.back(), SE);
    S = SE.getAddExpr(NewOps);
    return Result;
  } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
    SmallVector<const SCEV *, 8> NewOps(AR->op_begin(), AR->op_end());
    GlobalValue *Result = ExtractSymbol(NewOps.front(), SE);
    S = SE.getAddRecExpr(NewOps, AR->getLoop());
    return Result;
  }
  return 0;
}

/// isLegalUse - Test whether the use described by AM is "legal", meaning
/// it can be completely folded into the user instruction at isel time.
/// This includes address-mode folding and special icmp tricks.
static bool isLegalUse(const TargetLowering::AddrMode &AM,
                       LSRUse::KindType Kind, const Type *AccessTy,
                       const TargetLowering *TLI) {
  switch (Kind) {
  case LSRUse::Address:
    // If we have low-level target information, ask the target if it can
    // completely fold this address.
    if (TLI) return TLI->isLegalAddressingMode(AM, AccessTy);

    // Otherwise, just guess that reg+reg addressing is legal.
    return !AM.BaseGV && AM.BaseOffs == 0 && AM.Scale <= 1;

  case LSRUse::ICmpZero:
    // There's not even a target hook for querying whether it would be legal
    // to fold a GV into an ICmp.
    if (AM.BaseGV)
      return false;
    // ICmp only has two operands; don't allow more than two non-trivial parts.
    if (AM.Scale != 0 && AM.HasBaseReg && AM.BaseOffs != 0)
      return false;
    // ICmp only supports no scale or a -1 scale, as we can "fold" a -1 scale
    // by putting the scaled register in the other operand of the icmp.
    if (AM.Scale != 0 && AM.Scale != -1)
      return false;
    // If we have low-level target information, ask the target if it can
    // fold an integer immediate on an icmp.
    if (AM.BaseOffs != 0) {
      if (TLI) return TLI->isLegalICmpImmediate(-AM.BaseOffs);
      return false;
    }

    return true;

  case LSRUse::Basic:
    // Only handle single-register values.
    return !AM.BaseGV && AM.Scale == 0 && AM.BaseOffs == 0;
  case LSRUse::Special:
    // Only handle -1 scales, or no scale.
    return AM.Scale == 0 || AM.Scale == -1;
  }

  return false;
}
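
// As a concrete illustration (not from the original source): in non-PIC code
// x86 can fold the full form BaseGV + BaseOffs + BaseReg + Scale*ScaledReg
// with Scale in {1, 2, 4, 8}, so isLegalAddressingMode accepts such an AM,
// while a target with only reg+imm addressing would reject any nonzero Scale.

/// isAlwaysFoldable - Test whether the given expression can always be folded
/// into a use of the given kind: it is conservatively modeled as a full
/// addressing mode (immediate and symbol extracted, plus a base and a scale),
/// and the result is whatever isLegalUse says about that mode.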
static bool isAlwaysFoldable(const SCEV *S,
                             bool HasBaseReg,
                             LSRUse::KindType Kind, const Type *AccessTy,
                             const TargetLowering *TLI,
                             ScalarEvolution &SE) {
  // Fast-path: zero is always foldable.
  if (S->isZero()) return true;

  // Conservatively, create an address with an immediate and a
  // base and a scale.
  TargetLowering::AddrMode AM;
  AM.BaseOffs = ExtractImmediate(S, SE);
  AM.BaseGV = ExtractSymbol(S, SE);
  AM.HasBaseReg = HasBaseReg;
  AM.Scale = Kind == LSRUse::ICmpZero ? -1 : 1;

  // If there's anything else involved, it's not foldable.
  if (!S->isZero()) return false;

  return isLegalUse(AM, Kind, AccessTy, TLI);
}
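
// For example (illustrative, not from the original source): for S = (@gv + 16)
// with Kind == Address, the 16 is extracted into AM.BaseOffs and @gv into
// AM.BaseGV, the remainder is zero, and the answer is whatever isLegalUse
// (and thus the target hook, if available) says about that addressing mode.
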
/// InsertFormula - If the given formula has not yet been inserted, add it
/// to the list, and return true. Return false otherwise.
bool LSRUse::InsertFormula(const Formula &F) {
  Formula Copy = F;
  // Sort the base regs, to avoid adding the same solution twice with
  // the base regs in different orders. This uses host pointer values, but
  // it doesn't matter since it's only used for uniquifying.
  std::sort(Copy.BaseRegs.begin(), Copy.BaseRegs.end());
  DEBUG(for (SmallVectorImpl<const SCEV *>::const_iterator I =
             F.BaseRegs.begin(), E = F.BaseRegs.end(); I != E; ++I)
          assert(!(*I)->isZero() && "Zero allocated in a base register!");
        assert((!F.ScaledReg || !F.ScaledReg->isZero()) &&
               "Zero allocated in a scaled register!"));
  if (FormulaeUniquifier.insert(Copy)) {
    Formulae.push_back(F);
    return true;
  }

  return false;
}

void
LSRUse::InsertInitialFormula(const SCEV *S, Loop *L,
                             ScalarEvolution &SE, DominatorTree &DT) {
  Formula F;
  F.InitialMatch(S, L, SE, DT);
  bool Inserted = InsertFormula(F);
  assert(Inserted && "Initial formula already exists!"); (void)Inserted;
LSRUse::InsertSupplementalFormula(const SCEV *S) {
  Formula F;
  F.BaseRegs.push_back(S);
  F.AM.HasBaseReg = true;
  bool Inserted = InsertFormula(F);
  assert(Inserted && "Supplemental formula already exists!"); (void)Inserted;
/// getImmediateDominator - A handy utility for the specific DominatorTree
/// query that we need here.
static BasicBlock *getImmediateDominator(BasicBlock *BB, DominatorTree &DT) {
  DomTreeNode *Node = DT.getNode(BB);
  if (!Node) return 0;
  Node = Node->getIDom();
  if (!Node) return 0;
  return Node->getBlock();
}

Value *LSRUse::Expand(BasicBlock::iterator IP,
                      Loop *L, Instruction *IVIncInsertPos,
                      SCEVExpander &Rewriter,
                      SmallVectorImpl<WeakVH> &DeadInsts,
                      ScalarEvolution &SE, DominatorTree &DT) const {
  // Collect the instructions which produce the operand values this expansion
  // will need; the insertion position we choose must remain dominated by all
  // of them.
  SmallVector<Instruction *, 4> Inputs;
  if (Instruction *I = dyn_cast<Instruction>(OperandValToReplace))
    Inputs.push_back(I);
  if (Kind == ICmpZero)
    if (Instruction *I =
          dyn_cast<Instruction>(cast<ICmpInst>(UserInst)->getOperand(1)))
      Inputs.push_back(I);
  if (PostIncLoop && !L->contains(UserInst))
    Inputs.push_back(L->getLoopLatch()->getTerminator());

  // Then, climb up the immediate dominator tree as far as we can go while
  // still being dominated by the input positions.
  for (;;) {
    bool AllDominate = true;
    Instruction *BetterPos = 0;
    BasicBlock *IDom = getImmediateDominator(IP->getParent(), DT);
    if (!IDom) break;
    Instruction *Tentative = IDom->getTerminator();
    for (SmallVectorImpl<Instruction *>::const_iterator I = Inputs.begin(),
         E = Inputs.end(); I != E; ++I) {
      Instruction *Inst = *I;
      if (Inst == Tentative || !DT.dominates(Inst, Tentative)) {
        AllDominate = false;
        break;
      }
      if (IDom == Inst->getParent() &&
          (!BetterPos || DT.dominates(BetterPos, Inst)))
        BetterPos = next(BasicBlock::iterator(Inst));
    }
    if (!AllDominate)
      break;
    if (BetterPos)
      IP = BetterPos;
    else
      IP = Tentative;
  }
  while (isa<PHINode>(IP)) ++IP;

  // The first formula in the list is the winner.
  const Formula &F = Formulae.front();

  // Inform the Rewriter if we have a post-increment use, so that it can
  // perform an advantageous expansion.
  Rewriter.setPostInc(PostIncLoop);

  // This is the type that the user actually needs.
  const Type *OpTy = OperandValToReplace->getType();
  // This will be the type that we'll initially expand to.
  const Type *Ty = F.getType();
  if (!Ty)
    // No type known; just expand directly to the ultimate type.
    Ty = OpTy;
  else if (SE.getEffectiveSCEVType(Ty) == SE.getEffectiveSCEVType(OpTy))
    // Expand directly to the ultimate type if it's the right size.
    Ty = OpTy;
  // This is the type to do integer arithmetic in.
  const Type *IntTy = SE.getEffectiveSCEVType(Ty);

  // Build up a list of operands to add together to form the full base.
  SmallVector<const SCEV *, 8> Ops;

  // Expand the BaseRegs portion.
  for (SmallVectorImpl<const SCEV *>::const_iterator I = F.BaseRegs.begin(),
       E = F.BaseRegs.end(); I != E; ++I) {
    const SCEV *Reg = *I;
    assert(!Reg->isZero() && "Zero allocated in a base register!");

    // If we're expanding for a post-inc user for the add-rec's loop, make the
    // post-inc adjustment.
    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg))
      if (AR->getLoop() == PostIncLoop) {
        Reg = SE.getAddExpr(Reg, AR->getStepRecurrence(SE));
        // If the user is inside the loop, insert the code after the increment
        // so that it is dominated by its operand.
        if (L->contains(UserInst))
          IP = IVIncInsertPos;
      }

    Ops.push_back(SE.getUnknown(Rewriter.expandCodeFor(Reg, 0, IP)));
  }

  // Expand the ScaledReg portion.
  Value *ICmpScaledV = 0;
  if (F.AM.Scale != 0) {
    const SCEV *ScaledS = F.ScaledReg;

    // If we're expanding for a post-inc user for the add-rec's loop, make the
    // post-inc adjustment.
    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ScaledS))
      if (AR->getLoop() == PostIncLoop)
        ScaledS = SE.getAddExpr(ScaledS, AR->getStepRecurrence(SE));

    if (Kind == ICmpZero) {
      // An interesting way of "folding" with an icmp is to use a negated
      // scale, which we'll implement by inserting it into the other operand
      // of the icmp.
      assert(F.AM.Scale == -1 &&
             "The only scale supported by ICmpZero uses is -1!");
      ICmpScaledV = Rewriter.expandCodeFor(ScaledS, 0, IP);
    } else {
      // Otherwise just expand the scaled register and an explicit scale,
      // which is expected to be matched as part of the address.
      ScaledS = SE.getUnknown(Rewriter.expandCodeFor(ScaledS, 0, IP));
      const Type *ScaledTy = SE.getEffectiveSCEVType(ScaledS->getType());
      ScaledS = SE.getMulExpr(ScaledS,
                              SE.getSCEV(ConstantInt::get(ScaledTy,
                                                          F.AM.Scale)));
      Ops.push_back(ScaledS);
    }
  }

  // Expand the immediate portions.
  if (F.AM.BaseGV)
    Ops.push_back(SE.getSCEV(F.AM.BaseGV));
  if (F.AM.BaseOffs != 0) {
    if (Kind == ICmpZero) {
      // The other interesting way of "folding" with an ICmpZero is to use a
      // negated immediate.
      if (!ICmpScaledV)
        ICmpScaledV = ConstantInt::get(IntTy, -F.AM.BaseOffs);
      else {
        Ops.push_back(SE.getUnknown(ICmpScaledV));
        ICmpScaledV = ConstantInt::get(IntTy, F.AM.BaseOffs);
      }
    } else {
      // Just add the immediate values. These again are expected to be matched
      // as part of the address.
      Ops.push_back(SE.getSCEV(ConstantInt::get(IntTy, F.AM.BaseOffs)));
    }
  }

  // Emit instructions summing all the operands.
  const SCEV *FullS = Ops.empty() ?
                      SE.getIntegerSCEV(0, IntTy) :
                      SE.getAddExpr(Ops);
  Value *FullV = Rewriter.expandCodeFor(FullS, Ty, IP);

  // We're done expanding now, so reset the rewriter.
  Rewriter.setPostInc(0);

  // An ICmpZero Formula represents an ICmp which we're handling as a
  // comparison against zero. Now that we've expanded an expression for that
  // form, update the ICmp's other operand.
  if (Kind == ICmpZero) {
    ICmpInst *CI = cast<ICmpInst>(UserInst);
    DeadInsts.push_back(CI->getOperand(1));
    assert(!F.AM.BaseGV && "ICmp does not support folding a global value and "
                           "a scale at the same time!");
    if (F.AM.Scale == -1) {
      if (ICmpScaledV->getType() != OpTy) {
        Instruction *Cast =
          CastInst::Create(CastInst::getCastOpcode(ICmpScaledV, false,
                                                   OpTy, false),
                           ICmpScaledV, OpTy, "tmp", CI);
        ICmpScaledV = Cast;
      }
      CI->setOperand(1, ICmpScaledV);
    } else {
      assert(F.AM.Scale == 0 &&
             "ICmp does not support folding a global value and "
             "a scale at the same time!");
      Constant *C = ConstantInt::getSigned(SE.getEffectiveSCEVType(OpTy),
                                           -(uint64_t)F.AM.BaseOffs);
      if (C->getType() != OpTy)
        C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
                                                          OpTy, false),
                                  C, OpTy);

      CI->setOperand(1, C);
    }
  }

  return FullV;
}

/// Rewrite - Emit instructions for the leading candidate expression for this
/// LSRUse (this is called "expanding"), and update the UserInst to reference
/// the newly expanded value.
void LSRUse::Rewrite(Loop *L, Instruction *IVIncInsertPos,
                     SCEVExpander &Rewriter,
                     SmallVectorImpl<WeakVH> &DeadInsts,