//===- ScalarReplAggregates.cpp - Scalar Replacement of Aggregates --------===//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//===----------------------------------------------------------------------===//
//
// This transformation implements the well known scalar replacement of
// aggregates transformation.  This xform breaks up alloca instructions of
// aggregate type (structure or array) into individual alloca instructions for
// each member (if possible).  Then, if possible, it transforms the individual
// alloca instructions into nice clean scalar SSA form.
//
// This combines a simple SRoA algorithm with the Mem2Reg algorithm because
// they often interact, especially for C++ programs.  As such, iterating between
// SRoA, then Mem2Reg until we run out of things to promote works well.
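//
// For example (illustrative IR only, not from a test case), given:
//
//   %p = alloca {i32, float}
//   %f = getelementptr {i32, float}* %p, i32 0, i32 0
//   store i32 1, i32* %f
//
// SRoA rewrites the aggregate into one alloca per member:
//
//   %p.0 = alloca i32
//   %p.1 = alloca float
//   store i32 1, i32* %p.0
//
// after which Mem2Reg can promote %p.0 and %p.1 into SSA values.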
//
//===----------------------------------------------------------------------===//

#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GetElementPtrTypeIterator.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
using namespace llvm;
STATISTIC(NumReplaced,  "Number of allocas broken up");
STATISTIC(NumPromoted,  "Number of allocas promoted");
STATISTIC(NumConverted, "Number of aggregates converted to scalar");
STATISTIC(NumGlobals,   "Number of allocas copied from constant global");

namespace {
  struct SROA : public FunctionPass {
    static char ID; // Pass identification, replacement for typeid
    explicit SROA(signed T = -1) : FunctionPass(&ID) {
      if (T == -1)
        SRThreshold = 128;
      else
        SRThreshold = T;
    }

    bool runOnFunction(Function &F);

    bool performScalarRepl(Function &F);
    bool performPromotion(Function &F);

    // getAnalysisUsage - This pass does not require any passes, but we know it
    // will not alter the CFG, so say so.
    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
      AU.addRequired<DominatorTree>();
      AU.addRequired<DominanceFrontier>();
      AU.setPreservesCFG();
    }

  private:
    TargetData *TD;

    /// AllocaInfo - When analyzing uses of an alloca instruction, this captures
    /// information about the uses.  All these fields are initialized to false
    /// and set to true when something is learned.
    struct AllocaInfo {
      /// isUnsafe - This is set to true if the alloca cannot be SROA'd.
      bool isUnsafe : 1;
      
      /// needsCleanup - This is set to true if there is some use of the alloca
      /// that requires cleanup.
      bool needsCleanup : 1;
      
      /// isMemCpySrc - This is true if this aggregate is memcpy'd from.
      bool isMemCpySrc : 1;

      /// isMemCpyDst - This is true if this aggregate is memcpy'd into.
      bool isMemCpyDst : 1;

      AllocaInfo()
        : isUnsafe(false), needsCleanup(false), 
          isMemCpySrc(false), isMemCpyDst(false) {}
    };
    
    unsigned SRThreshold;

    void MarkUnsafe(AllocaInfo &I) { I.isUnsafe = true; }

    int isSafeAllocaToScalarRepl(AllocaInst *AI);
    void isSafeUseOfAllocation(Instruction *User, AllocaInst *AI,
                               AllocaInfo &Info);
    void isSafeElementUse(Value *Ptr, bool isFirstElt, AllocaInst *AI,
                          AllocaInfo &Info);
    void isSafeMemIntrinsicOnAllocation(MemIntrinsic *MI, AllocaInst *AI,
                                        unsigned OpNo, AllocaInfo &Info);
    void isSafeUseOfBitCastedAllocation(BitCastInst *User, AllocaInst *AI,
                                        AllocaInfo &Info);
    void DoScalarReplacement(AllocaInst *AI, 
                             std::vector<AllocaInst*> &WorkList);
    void CleanupGEP(GetElementPtrInst *GEP);
    void CleanupAllocaUsers(AllocaInst *AI);
    AllocaInst *AddNewAlloca(Function &F, const Type *Ty, AllocaInst *Base);
    void RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocaInst *AI,
                                    SmallVector<AllocaInst*, 32> &NewElts);
    
    void RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *BCInst,
                                      AllocaInst *AI,
                                      SmallVector<AllocaInst*, 32> &NewElts);
    void RewriteStoreUserOfWholeAlloca(StoreInst *SI, AllocaInst *AI,
                                       SmallVector<AllocaInst*, 32> &NewElts);
    void RewriteLoadUserOfWholeAlloca(LoadInst *LI, AllocaInst *AI,
                                      SmallVector<AllocaInst*, 32> &NewElts);
    bool CanConvertToScalar(Value *V, bool &IsNotTrivial, const Type *&VecTy,
                            bool &SawVec, uint64_t Offset, unsigned AllocaSize);
    void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset);
    Value *ConvertScalar_ExtractValue(Value *NV, const Type *ToType,
                                     uint64_t Offset, IRBuilder<> &Builder);
    Value *ConvertScalar_InsertValue(Value *StoredVal, Value *ExistingVal,
                                     uint64_t Offset, IRBuilder<> &Builder);
    static Instruction *isOnlyCopiedFromConstantGlobal(AllocaInst *AI);
  };
}

char SROA::ID = 0;
static RegisterPass<SROA> X("scalarrepl", "Scalar Replacement of Aggregates");

// Public interface to the ScalarReplAggregates pass
FunctionPass *llvm::createScalarReplAggregatesPass(signed int Threshold) { 
  return new SROA(Threshold);
}

bool SROA::runOnFunction(Function &F) {
  TD = getAnalysisIfAvailable<TargetData>();

  bool Changed = performPromotion(F);

  // FIXME: ScalarRepl currently depends on TargetData more than it
  // theoretically needs to. It should be refactored in order to support
  // target-independent IR. Until this is done, just skip the actual
  // scalar-replacement portion of this pass.
  if (!TD) return Changed;

  while (1) {
    bool LocalChange = performScalarRepl(F);
    if (!LocalChange) break;   // No need to repromote if no scalarrepl
    Changed = true;
    LocalChange = performPromotion(F);
    if (!LocalChange) break;   // No need to re-scalarrepl if no promotion
  }

  return Changed;
}


bool SROA::performPromotion(Function &F) {
  std::vector<AllocaInst*> Allocas;
  DominatorTree         &DT = getAnalysis<DominatorTree>();
  DominanceFrontier &DF = getAnalysis<DominanceFrontier>();
  BasicBlock &BB = F.getEntryBlock();  // Get the entry node for the function
  bool Changed = false;
  while (1) {
    Allocas.clear();

    // Find allocas that are safe to promote, by looking at all instructions in
    // the entry node
    for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
      if (AllocaInst *AI = dyn_cast<AllocaInst>(I))       // Is it an alloca?
        if (isAllocaPromotable(AI))
          Allocas.push_back(AI);

    if (Allocas.empty()) break;

    PromoteMemToReg(Allocas, DT, DF);
    NumPromoted += Allocas.size();
    Changed = true;
  }

  return Changed;
}

/// getNumSAElements - Return the number of elements in the specific struct or
/// array.
static uint64_t getNumSAElements(const Type *T) {
  if (const StructType *ST = dyn_cast<StructType>(T))
    return ST->getNumElements();
  return cast<ArrayType>(T)->getNumElements();
}

// performScalarRepl - This algorithm is a simple worklist driven algorithm,
// which runs on all of the malloc/alloca instructions in the function, removing
// them if they are only used by getelementptr instructions.
//
bool SROA::performScalarRepl(Function &F) {
  std::vector<AllocaInst*> WorkList;

  // Scan the entry basic block, adding any alloca's and mallocs to the worklist
  BasicBlock &BB = F.getEntryBlock();
  for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
    if (AllocaInst *A = dyn_cast<AllocaInst>(I))
      WorkList.push_back(A);

  // Process the worklist
  bool Changed = false;
  while (!WorkList.empty()) {
    AllocaInst *AI = WorkList.back();
    WorkList.pop_back();
    
    // Handle dead allocas trivially.  These can be formed by SROA'ing arrays
    // with unused elements.
    if (AI->use_empty()) {
      AI->eraseFromParent();
      continue;
    }

    // If this alloca is impossible for us to promote, reject it early.
    if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized())
      continue;
    
    // Check to see if this allocation is only modified by a memcpy/memmove from
    // a constant global.  If this is the case, we can change all users to use
    // the constant global instead.  This is commonly produced by the CFE by
    // constructs like "void foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A'
    // is only subsequently read.
    if (Instruction *TheCopy = isOnlyCopiedFromConstantGlobal(AI)) {
      DEBUG(errs() << "Found alloca equal to global: " << *AI << '\n');
      DEBUG(errs() << "  memcpy = " << *TheCopy << '\n');
      Constant *TheSrc = cast<Constant>(TheCopy->getOperand(2));
      AI->replaceAllUsesWith(ConstantExpr::getBitCast(TheSrc, AI->getType()));
      TheCopy->eraseFromParent();  // Don't mutate the global.
      AI->eraseFromParent();
      ++NumGlobals;
      Changed = true;
      continue;
    }
    // Check to see if we can perform the core SROA transformation.  We cannot
    // transform the allocation instruction if it is an array allocation
    // (allocations OF arrays are ok though), and an allocation of a scalar
    // value cannot be decomposed at all.
    uint64_t AllocaSize = TD->getTypeAllocSize(AI->getAllocatedType());
    // Do not promote [0 x %struct].
    if (AllocaSize == 0) continue;

    // Do not promote any struct whose size is too big.
    if (AllocaSize > SRThreshold) continue;
    if ((isa<StructType>(AI->getAllocatedType()) ||
         isa<ArrayType>(AI->getAllocatedType())) &&
        // Do not promote any struct into more than "32" separate vars.
        getNumSAElements(AI->getAllocatedType()) <= SRThreshold/4) {
      // Check that all of the users of the allocation are capable of being
      // transformed.
      switch (isSafeAllocaToScalarRepl(AI)) {
      default: llvm_unreachable("Unexpected value!");
      case 0:  // Not safe to scalar replace.
        break;
      case 1:  // Safe, but requires cleanup/canonicalizations first
        CleanupAllocaUsers(AI);
        // FALL THROUGH.
      case 3:  // Safe to scalar replace.
        DoScalarReplacement(AI, WorkList);
        Changed = true;
        continue;
      }
    }

    // If we can turn this aggregate value (potentially with casts) into a
    // simple scalar value that can be mem2reg'd into a register value.
    // IsNotTrivial tracks whether this is something that mem2reg could have
    // promoted itself.  If so, we don't want to transform it needlessly.  Note
    // that we can't just check based on the type: the alloca may be of an i32
    // but that has pointer arithmetic to set byte 3 of it or something.
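    // For instance (illustrative IR), mem2reg cannot handle this i32 alloca
    // because byte 3 is written through a bitcasted i8*:
    //   %a = alloca i32
    //   %p = bitcast i32* %a to i8*
    //   %p3 = getelementptr i8* %p, i32 3
    //   store i8 7, i8* %p3
    // but the conversion below can rewrite such accesses with shifts/masks.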
    bool IsNotTrivial = false;
    const Type *VectorTy = 0;
    bool HadAVector = false;
    if (CanConvertToScalar(AI, IsNotTrivial, VectorTy, HadAVector, 
                           0, unsigned(AllocaSize)) && IsNotTrivial) {
      AllocaInst *NewAI;
      // If we were able to find a vector type that can handle this with
      // insert/extract elements, and if there was at least one use that had
      // a vector type, promote this to a vector.  We don't want to promote
      // random stuff that doesn't use vectors (e.g. <9 x double>) because then
      // we just get a lot of insert/extracts.  If at least one vector is
      // involved, then we probably really do have a union of vector/array.
      if (VectorTy && isa<VectorType>(VectorTy) && HadAVector) {
        DEBUG(errs() << "CONVERT TO VECTOR: " << *AI << "\n  TYPE = "
        // Create and insert the vector alloca.
        NewAI = new AllocaInst(VectorTy, 0, "",  AI->getParent()->begin());
        ConvertUsesToScalar(AI, NewAI, 0);
      } else {
        DEBUG(errs() << "CONVERT TO SCALAR INTEGER: " << *AI << "\n");
        
        // Create and insert the integer alloca.
        const Type *NewTy = IntegerType::get(AI->getContext(), AllocaSize*8);
        NewAI = new AllocaInst(NewTy, 0, "", AI->getParent()->begin());
        ConvertUsesToScalar(AI, NewAI, 0);
      }
      NewAI->takeName(AI);
      AI->eraseFromParent();
      ++NumConverted;
      Changed = true;
      continue;
    }
    // Otherwise, couldn't process this alloca.
  }

  return Changed;
}

/// DoScalarReplacement - This alloca satisfied the isSafeAllocaToScalarRepl
/// predicate, do SROA now.
void SROA::DoScalarReplacement(AllocaInst *AI, 
                               std::vector<AllocaInst*> &WorkList) {
  DEBUG(errs() << "Found inst to SROA: " << *AI << '\n');
  SmallVector<AllocaInst*, 32> ElementAllocas;
  if (const StructType *ST = dyn_cast<StructType>(AI->getAllocatedType())) {
    ElementAllocas.reserve(ST->getNumContainedTypes());
    for (unsigned i = 0, e = ST->getNumContainedTypes(); i != e; ++i) {
      AllocaInst *NA = new AllocaInst(ST->getContainedType(i), 0, 
                                      AI->getName() + "." + Twine(i), AI);
      ElementAllocas.push_back(NA);
      WorkList.push_back(NA);  // Add to worklist for recursive processing
    }
  } else {
    const ArrayType *AT = cast<ArrayType>(AI->getAllocatedType());
    ElementAllocas.reserve(AT->getNumElements());
    const Type *ElTy = AT->getElementType();
    for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) {
      AllocaInst *NA = new AllocaInst(ElTy, 0, AI->getAlignment(),
                                      AI->getName() + "." + Twine(i), AI);
      ElementAllocas.push_back(NA);
      WorkList.push_back(NA);  // Add to worklist for recursive processing
    }
  }

  // Now that we have created the alloca instructions that we want to use,
  // expand the getelementptr instructions to use them.
  while (!AI->use_empty()) {
    Instruction *User = cast<Instruction>(AI->use_back());
    if (BitCastInst *BCInst = dyn_cast<BitCastInst>(User)) {
      RewriteBitCastUserOfAlloca(BCInst, AI, ElementAllocas);
      BCInst->eraseFromParent();
      continue;
    }
    
    // Replace:
    //   %res = load { i32, i32 }* %alloc
    // with:
    //   %load.0 = load i32* %alloc.0
    //   %insert.0 insertvalue { i32, i32 } zeroinitializer, i32 %load.0, 0 
    //   %load.1 = load i32* %alloc.1
    //   %insert = insertvalue { i32, i32 } %insert.0, i32 %load.1, 1 
    // (Also works for arrays instead of structs)
    if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
      Value *Insert = UndefValue::get(LI->getType());
      for (unsigned i = 0, e = ElementAllocas.size(); i != e; ++i) {
        Value *Load = new LoadInst(ElementAllocas[i], "load", LI);
        Insert = InsertValueInst::Create(Insert, Load, i, "insert", LI);
      }
      LI->replaceAllUsesWith(Insert);
      LI->eraseFromParent();
      continue;
    }
    // Replace:
    //   store { i32, i32 } %val, { i32, i32 }* %alloc
    // with:
    //   %val.0 = extractvalue { i32, i32 } %val, 0 
    //   store i32 %val.0, i32* %alloc.0
    //   %val.1 = extractvalue { i32, i32 } %val, 1 
    //   store i32 %val.1, i32* %alloc.1
    // (Also works for arrays instead of structs)
    if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
      Value *Val = SI->getOperand(0);
      for (unsigned i = 0, e = ElementAllocas.size(); i != e; ++i) {
        Value *Extract = ExtractValueInst::Create(Val, i, Val->getName(), SI);
        new StoreInst(Extract, ElementAllocas[i], SI);
      }
      SI->eraseFromParent();
      continue;
    }
    
    GetElementPtrInst *GEPI = cast<GetElementPtrInst>(User);
    // We now know that the GEP is of the form: GEP <ptr>, 0, <cst>
    unsigned Idx =
       (unsigned)cast<ConstantInt>(GEPI->getOperand(2))->getZExtValue();

    assert(Idx < ElementAllocas.size() && "Index out of range?");
    AllocaInst *AllocaToUse = ElementAllocas[Idx];

    Value *RepValue;
    if (GEPI->getNumOperands() == 3) {
      // Do not insert a new getelementptr instruction with zero indices, only
      // to have it optimized out later.
      RepValue = AllocaToUse;
    } else {
      // We are indexing deeply into the structure, so we still need a
      // getelement ptr instruction to finish the indexing.  This may be
      // expanded itself once the worklist is rerun.
      //
      SmallVector<Value*, 8> NewArgs;
      NewArgs.push_back(Constant::getNullValue(
                                           Type::getInt32Ty(AI->getContext())));
      NewArgs.append(GEPI->op_begin()+3, GEPI->op_end());
      RepValue = GetElementPtrInst::Create(AllocaToUse, NewArgs.begin(),
                                           NewArgs.end(), "", GEPI);
      RepValue->takeName(GEPI);
    }
    
    // If this GEP is to the start of the aggregate, check for memcpys.
    if (Idx == 0 && GEPI->hasAllZeroIndices())
      RewriteBitCastUserOfAlloca(GEPI, AI, ElementAllocas);

    // Move all of the users over to the new GEP.
    GEPI->replaceAllUsesWith(RepValue);
    // Delete the old GEP
    GEPI->eraseFromParent();
  }
  // Finally, delete the Alloca instruction
  AI->eraseFromParent();
  NumReplaced++;
}

/// isSafeElementUse - Check to see if this use is an allowed use for a
/// getelementptr instruction of an array aggregate allocation.  isFirstElt
/// indicates whether Ptr is known to the start of the aggregate.
void SROA::isSafeElementUse(Value *Ptr, bool isFirstElt, AllocaInst *AI,
                            AllocaInfo &Info) {
  for (Value::use_iterator I = Ptr->use_begin(), E = Ptr->use_end();
       I != E; ++I) {
    Instruction *User = cast<Instruction>(*I);
    switch (User->getOpcode()) {
    case Instruction::Load:  break;
    case Instruction::Store:
      // Store is ok if storing INTO the pointer, not storing the pointer
      if (User->getOperand(0) == Ptr) return MarkUnsafe(Info);
      break;
    case Instruction::GetElementPtr: {
      GetElementPtrInst *GEP = cast<GetElementPtrInst>(User);
      bool AreAllZeroIndices = isFirstElt;
      if (GEP->getNumOperands() > 1 &&
          (!isa<ConstantInt>(GEP->getOperand(1)) ||
           !cast<ConstantInt>(GEP->getOperand(1))->isZero()))
        // Using pointer arithmetic to navigate the array.
        return MarkUnsafe(Info);
      
      // Verify that any array subscripts are in range.
      for (gep_type_iterator GEPIt = gep_type_begin(GEP),
           E = gep_type_end(GEP); GEPIt != E; ++GEPIt) {
        // Ignore struct elements, no extra checking needed for these.
        if (isa<StructType>(*GEPIt))
          continue;
        // This GEP indexes an array.  Verify that this is an in-range
        // constant integer. Specifically, consider A[0][i]. We cannot know that
        // the user isn't doing invalid things like allowing i to index an
        // out-of-range subscript that accesses A[1].  Because of this, we have
        // to reject SROA of any accesses into structs where any of the
        // components are variables. 
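        // For example (illustrative IR), a non-constant index like
        //   %p = getelementptr [4 x i32]* %A, i32 0, i32 %i
        // must be rejected here, since %i could be out of range at runtime.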
        ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
        if (!IdxVal) return MarkUnsafe(Info);
        
        // Are all indices still zero?
        AreAllZeroIndices &= IdxVal->isZero();
        
        if (const ArrayType *AT = dyn_cast<ArrayType>(*GEPIt)) {
          if (IdxVal->getZExtValue() >= AT->getNumElements())
            return MarkUnsafe(Info);
        } else if (const VectorType *VT = dyn_cast<VectorType>(*GEPIt)) {
          if (IdxVal->getZExtValue() >= VT->getNumElements())
            return MarkUnsafe(Info);
        }
      }
      
      isSafeElementUse(GEP, AreAllZeroIndices, AI, Info);
      if (Info.isUnsafe) return;
      break;
    }
    case Instruction::BitCast:
      if (isFirstElt) {
        isSafeUseOfBitCastedAllocation(cast<BitCastInst>(User), AI, Info);
        if (Info.isUnsafe) return;
        break;
      }
      DEBUG(errs() << "  Transformation preventing inst: " << *User << '\n');
      return MarkUnsafe(Info);
    case Instruction::Call:
      if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) {
        if (isFirstElt) {
          isSafeMemIntrinsicOnAllocation(MI, AI, I.getOperandNo(), Info);
          if (Info.isUnsafe) return;
          break;
        }
      }
      DEBUG(errs() << "  Transformation preventing inst: " << *User << '\n');
      return MarkUnsafe(Info);
    default:
      DEBUG(errs() << "  Transformation preventing inst: " << *User << '\n');
      return MarkUnsafe(Info);
    }
  }
  return;  // All users look ok :)
}

/// AllUsersAreLoads - Return true if all users of this value are loads.
static bool AllUsersAreLoads(Value *Ptr) {
  for (Value::use_iterator I = Ptr->use_begin(), E = Ptr->use_end();
       I != E; ++I)
    if (cast<Instruction>(*I)->getOpcode() != Instruction::Load)
      return false;
  return true;
}

/// isSafeUseOfAllocation - Check if this user is an allowed use for an
/// aggregate allocation.
void SROA::isSafeUseOfAllocation(Instruction *User, AllocaInst *AI,
                                 AllocaInfo &Info) {
  if (BitCastInst *C = dyn_cast<BitCastInst>(User))
    return isSafeUseOfBitCastedAllocation(C, AI, Info);

  if (LoadInst *LI = dyn_cast<LoadInst>(User))
    if (!LI->isVolatile())
      return;// Loads (returning a first class aggregate) are always rewritable

  if (StoreInst *SI = dyn_cast<StoreInst>(User))
    if (!SI->isVolatile() && SI->getOperand(0) != AI)
      return;// Store is ok if storing INTO the pointer, not storing the pointer
 
  GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(User);
  if (GEPI == 0)
    return MarkUnsafe(Info);
  gep_type_iterator I = gep_type_begin(GEPI), E = gep_type_end(GEPI);
  // The GEP is not safe to transform if not of the form "GEP <ptr>, 0, <cst>".
  if (I == E ||
      I.getOperand() != Constant::getNullValue(I.getOperand()->getType())) {
    return MarkUnsafe(Info);
  }

  ++I;
  if (I == E) return MarkUnsafe(Info);  // ran out of GEP indices??

  bool IsAllZeroIndices = true;
  
  // If the first index is a non-constant index into an array, see if we can
  // handle it as a special case.
  if (const ArrayType *AT = dyn_cast<ArrayType>(*I)) {
    if (!isa<ConstantInt>(I.getOperand())) {
      IsAllZeroIndices = 0;
      uint64_t NumElements = AT->getNumElements();
      
      // If this is an array index and the index is not constant, we cannot
      // promote... that is unless the array has exactly one or two elements in
      // it, in which case we CAN promote it, but we have to canonicalize this
      // out if this is the only problem.
      if ((NumElements == 1 || NumElements == 2) &&
          AllUsersAreLoads(GEPI)) {
        Info.needsCleanup = true;
        return;  // Canonicalization required!
      }
      return MarkUnsafe(Info);
    }
  }
 
  // Walk through the GEP type indices, checking the types that this indexes
  // into.
  for (; I != E; ++I) {
    // Ignore struct elements, no extra checking needed for these.
    if (isa<StructType>(*I))
      continue;
    
    ConstantInt *IdxVal = dyn_cast<ConstantInt>(I.getOperand());
    if (!IdxVal) return MarkUnsafe(Info);
    // Are all indices still zero?
    IsAllZeroIndices &= IdxVal->isZero();
    
    if (const ArrayType *AT = dyn_cast<ArrayType>(*I)) {
      // This GEP indexes an array.  Verify that this is an in-range constant
      // integer. Specifically, consider A[0][i]. We cannot know that the user
      // isn't doing invalid things like allowing i to index an out-of-range
      // subscript that accesses A[1].  Because of this, we have to reject SROA
      // of any accesses into structs where any of the components are variables.
      if (IdxVal->getZExtValue() >= AT->getNumElements())
        return MarkUnsafe(Info);
    } else if (const VectorType *VT = dyn_cast<VectorType>(*I)) {
      if (IdxVal->getZExtValue() >= VT->getNumElements())
        return MarkUnsafe(Info);
    }
  }
  
  // If there are any non-simple uses of this getelementptr, make sure to reject
  // them.
  return isSafeElementUse(GEPI, IsAllZeroIndices, AI, Info);
}

/// isSafeMemIntrinsicOnAllocation - Check if the specified memory
/// intrinsic can be promoted by SROA.  At this point, we know that the operand
/// of the memintrinsic is a pointer to the beginning of the allocation.
void SROA::isSafeMemIntrinsicOnAllocation(MemIntrinsic *MI, AllocaInst *AI,
                                          unsigned OpNo, AllocaInfo &Info) {
  // If not constant length, give up.
  ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength());
  if (!Length) return MarkUnsafe(Info);
  
  // If not the whole aggregate, give up.
  if (Length->getZExtValue() !=
      TD->getTypeAllocSize(AI->getType()->getElementType()))
    return MarkUnsafe(Info);
  
  // We only know about memcpy/memset/memmove.
  if (!isa<MemIntrinsic>(MI))
    return MarkUnsafe(Info);
  
  // Otherwise, we can transform it.  Determine whether this is a memcpy/set
  // into or out of the aggregate.
  if (OpNo == 1)
    Info.isMemCpyDst = true;
  else {
    assert(OpNo == 2);
    Info.isMemCpySrc = true;
  }
}

/// isSafeUseOfBitCastedAllocation - Check if all users of this bitcast
/// from an alloca are safe for SROA of that alloca.
void SROA::isSafeUseOfBitCastedAllocation(BitCastInst *BC, AllocaInst *AI,
                                          AllocaInfo &Info) {
  for (Value::use_iterator UI = BC->use_begin(), E = BC->use_end();
       UI != E; ++UI) {
    if (BitCastInst *BCU = dyn_cast<BitCastInst>(UI)) {
      isSafeUseOfBitCastedAllocation(BCU, AI, Info);
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(UI)) {
      isSafeMemIntrinsicOnAllocation(MI, AI, UI.getOperandNo(), Info);
    } else if (StoreInst *SI = dyn_cast<StoreInst>(UI)) {
      if (SI->isVolatile())
        return MarkUnsafe(Info);
      
      // If storing the entire alloca in one chunk through a bitcasted pointer
      // to integer, we can transform it.  This happens (for example) when you
      // cast a {i32,i32}* to i64* and store through it.  This is similar to the
      // memcpy case and occurs in various "byval" cases and emulated memcpys.
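      // For example (illustrative IR):
      //   %a = alloca {i32, i32}
      //   %p = bitcast {i32, i32}* %a to i64*
      //   store i64 %v, i64* %p      ; one store covers both fields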
      if (isa<IntegerType>(SI->getOperand(0)->getType()) &&
          TD->getTypeAllocSize(SI->getOperand(0)->getType()) ==
          TD->getTypeAllocSize(AI->getType()->getElementType())) {
        Info.isMemCpyDst = true;
        continue;
      }
      return MarkUnsafe(Info);
    } else if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
      if (LI->isVolatile())
        return MarkUnsafe(Info);
      // If loading the entire alloca in one chunk through a bitcasted pointer
      // to integer, we can transform it.  This happens (for example) when you
      // cast a {i32,i32}* to i64* and load through it.  This is similar to the
      // memcpy case and occurs in various "byval" cases and emulated memcpys.
      if (isa<IntegerType>(LI->getType()) &&
          TD->getTypeAllocSize(LI->getType()) ==
          TD->getTypeAllocSize(AI->getType()->getElementType())) {
        Info.isMemCpySrc = true;
        continue;
      }
      return MarkUnsafe(Info);
    } else if (isa<DbgInfoIntrinsic>(UI)) {
      // If one user is DbgInfoIntrinsic then check if all users are
      // DbgInfoIntrinsics.
      if (OnlyUsedByDbgInfoIntrinsics(BC)) {
        Info.needsCleanup = true;
        return;
      }
      else
        MarkUnsafe(Info);
    }
    else {
      return MarkUnsafe(Info);
    }
    if (Info.isUnsafe) return;
  }
}

/// RewriteBitCastUserOfAlloca - BCInst (transitively) bitcasts AI, or indexes
/// to its first element.  Transform users of the cast to use the new values
/// instead.
void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocaInst *AI,
                                      SmallVector<AllocaInst*, 32> &NewElts) {
  Value::use_iterator UI = BCInst->use_begin(), UE = BCInst->use_end();
  while (UI != UE) {
    Instruction *User = cast<Instruction>(*UI++);
    if (BitCastInst *BCU = dyn_cast<BitCastInst>(User)) {
      RewriteBitCastUserOfAlloca(BCU, AI, NewElts);
      if (BCU->use_empty()) BCU->eraseFromParent();
      continue;
    }
    if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) {
      // This must be memcpy/memmove/memset of the entire aggregate.
      // Split into one per element.
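      // For example (illustrative IR), a 16-byte memcpy into
      // %a = alloca {i64, i64} becomes, roughly, per-element transfers:
      //   %a.0.val = load i64* %src.0
      //   store i64 %a.0.val, i64* %a.0
      //   %a.1.val = load i64* %src.1
      //   store i64 %a.1.val, i64* %a.1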
      RewriteMemIntrinUserOfAlloca(MI, BCInst, AI, NewElts);
      continue;
    }
      
    if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
      // If this is a store of the entire alloca from an integer, rewrite it.
      RewriteStoreUserOfWholeAlloca(SI, AI, NewElts);
      continue;
    }
    if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
      // If this is a load of the entire alloca to an integer, rewrite it.
      RewriteLoadUserOfWholeAlloca(LI, AI, NewElts);
      continue;
    }
    
    // Otherwise it must be some other user of a gep of the first pointer.  Just
    // leave these alone.
    continue;
  }
}

/// RewriteMemIntrinUserOfAlloca - MI is a memcpy/memset/memmove from or to AI.
/// Rewrite it to copy or set the elements of the scalarized memory.
void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *BCInst,
                                        AllocaInst *AI,
                                        SmallVector<AllocaInst*, 32> &NewElts) {
  // If this is a memcpy/memmove, construct the other pointer as the
  // appropriate type.  The "Other" pointer is the pointer that goes to memory
  // that doesn't have anything to do with the alloca that we are promoting. For
  // memset, this Value* stays null.
  Value *OtherPtr = 0;
  LLVMContext &Context = MI->getContext();
  unsigned MemAlignment = MI->getAlignment();
  if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { // memmove/memcopy
    if (BCInst == MTI->getRawDest())
      OtherPtr = MTI->getRawSource();
    else {
      assert(BCInst == MTI->getRawSource());
      OtherPtr = MTI->getRawDest();
    }
  }

  // Keep track of the other intrinsic argument, so it can be removed if it
  // is dead when the intrinsic is replaced.
  Value *PossiblyDead = OtherPtr;
  
  // If there is an other pointer, we want to convert it to the same pointer
  // type as AI has, so we can GEP through it safely.
  if (OtherPtr) {
    // It is likely that OtherPtr is a bitcast, if so, remove it.
    if (BitCastInst *BC = dyn_cast<BitCastInst>(OtherPtr))
      OtherPtr = BC->getOperand(0);
    // All zero GEPs are effectively bitcasts.
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(OtherPtr))
      if (GEP->hasAllZeroIndices())
        OtherPtr = GEP->getOperand(0);
    if (ConstantExpr *BCE = dyn_cast<ConstantExpr>(OtherPtr))
      if (BCE->getOpcode() == Instruction::BitCast)
        OtherPtr = BCE->getOperand(0);
    
    // If the pointer is not the right type, insert a bitcast to the right
    // type.
    if (OtherPtr->getType() != AI->getType())
      OtherPtr = new BitCastInst(OtherPtr, AI->getType(), OtherPtr->getName(),
                                 MI);
  }
  
  // Process each element of the aggregate.
  Value *TheFn = MI->getOperand(0);
  const Type *BytePtrTy = MI->getRawDest()->getType();
  bool SROADest = MI->getRawDest() == BCInst;
  Constant *Zero = Constant::getNullValue(Type::getInt32Ty(MI->getContext()));
  for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
    // If this is a memcpy/memmove, emit a GEP of the other element address.
    Value *OtherElt = 0;
    unsigned OtherEltAlign = MemAlignment;
    
    if (OtherPtr) {
      Value *Idx[2] = { Zero,
                      ConstantInt::get(Type::getInt32Ty(MI->getContext()), i) };
      OtherElt = GetElementPtrInst::Create(OtherPtr, Idx, Idx + 2,
                                           OtherPtr->getNameStr()+"."+Twine(i),
                                           MI);
      uint64_t EltOffset;
      const PointerType *OtherPtrTy = cast<PointerType>(OtherPtr->getType());
      if (const StructType *ST =
            dyn_cast<StructType>(OtherPtrTy->getElementType())) {
        EltOffset = TD->getStructLayout(ST)->getElementOffset(i);
      } else {
        const Type *EltTy =
          cast<SequentialType>(OtherPtr->getType())->getElementType();
        EltOffset = TD->getTypeAllocSize(EltTy)*i;
      }
      
      // The alignment of the other pointer is the guaranteed alignment of the
      // element, which is affected by both the known alignment of the whole
      // mem intrinsic and the alignment of the element.  If the alignment of
      // the memcpy (f.e.) is 32 but the element is at a 4-byte offset, then the
      // known alignment is just 4 bytes.
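      // E.g. a 32-byte-aligned memcpy and an element at offset 4 give
      // MinAlign(32, 4) == 4 as the guaranteed alignment for that element.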
      OtherEltAlign = (unsigned)MinAlign(OtherEltAlign, EltOffset);
    }
    
    Value *EltPtr = NewElts[i];
    const Type *EltTy = cast<PointerType>(EltPtr->getType())->getElementType();
    
    // If we got down to a scalar, insert a load or store as appropriate.
    if (EltTy->isSingleValueType()) {
      if (isa<MemTransferInst>(MI)) {
        if (SROADest) {
          // From Other to Alloca.
          Value *Elt = new LoadInst(OtherElt, "tmp", false, OtherEltAlign, MI);
          new StoreInst(Elt, EltPtr, MI);
        } else {
          // From Alloca to Other.
          Value *Elt = new LoadInst(EltPtr, "tmp", MI);
          new StoreInst(Elt, OtherElt, false, OtherEltAlign, MI);
        }
      assert(isa<MemSetInst>(MI));
      // If the stored element is zero (common case), just store a null
      // constant.
      Constant *StoreVal;
      if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getOperand(2))) {
        if (CI->isZero()) {
          StoreVal = Constant::getNullValue(EltTy);  // 0.0, null, 0, <0,0>
        } else {
          // If EltTy is a vector type, get the element type.
          const Type *ValTy = EltTy->getScalarType();

          // Construct an integer with the right value.
          unsigned EltSize = TD->getTypeSizeInBits(ValTy);
          APInt OneVal(EltSize, CI->getZExtValue());
          APInt TotalVal(OneVal);
          // Set each byte.
          for (unsigned i = 0; 8*i < EltSize; ++i) {
            TotalVal = TotalVal.shl(8);
            TotalVal |= OneVal;
          }
          
          // Convert the integer value to the appropriate type.
          StoreVal = ConstantInt::get(Context, TotalVal);
          if (isa<PointerType>(ValTy))
            StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy);
          else if (ValTy->isFloatingPoint())
            StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy);
          assert(StoreVal->getType() == ValTy && "Type mismatch!");
          
          // If the requested value was a vector constant, create it.
          if (EltTy != ValTy) {
            unsigned NumElts = cast<VectorType>(ValTy)->getNumElements();
            SmallVector<Constant*, 16> Elts(NumElts, StoreVal);
            StoreVal = ConstantVector::get(&Elts[0], NumElts);
          }
        }
        new StoreInst(StoreVal, EltPtr, MI);
        continue;
      }
      
      // Otherwise, if we're storing a byte variable, use a memset call for
      // this element.
    }
    // Cast the element pointer to BytePtrTy.
    if (EltPtr->getType() != BytePtrTy)
      EltPtr = new BitCastInst(EltPtr, BytePtrTy, EltPtr->getNameStr(), MI);
    // Cast the other pointer (if we have one) to BytePtrTy. 
    if (OtherElt && OtherElt->getType() != BytePtrTy)
      OtherElt = new BitCastInst(OtherElt, BytePtrTy,OtherElt->getNameStr(),
                                 MI);
    
    unsigned EltSize = TD->getTypeAllocSize(EltTy);
    
    // Finally, insert the meminst for this element.
    if (isa<MemTransferInst>(MI)) {
      Value *Ops[] = {
        SROADest ? EltPtr : OtherElt,  // Dest ptr
        SROADest ? OtherElt : EltPtr,  // Src ptr
        ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size
        // Align
        ConstantInt::get(Type::getInt32Ty(MI->getContext()), OtherEltAlign)
      };
      CallInst::Create(TheFn, Ops, Ops + 4, "", MI);
    } else {
      assert(isa<MemSetInst>(MI));
      Value *Ops[] = {
        EltPtr, MI->getOperand(2),  // Dest, Value,
        ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size
        Zero  // Align
      };
      CallInst::Create(TheFn, Ops, Ops + 4, "", MI);
    }
  }

  MI->eraseFromParent();
  if (PossiblyDead)
    RecursivelyDeleteTriviallyDeadInstructions(PossiblyDead);
}

/// RewriteStoreUserOfWholeAlloca - We found a store of an integer that
/// overwrites the entire allocation.  Extract out the pieces of the stored
/// integer and store them individually.
void SROA::RewriteStoreUserOfWholeAlloca(StoreInst *SI, AllocaInst *AI,
                                         SmallVector<AllocaInst*, 32> &NewElts){
  // Extract each element out of the integer according to its structure offset
  // and store the element value to the individual alloca.
  Value *SrcVal = SI->getOperand(0);
  const Type *AllocaEltTy = AI->getType()->getElementType();
  uint64_t AllocaSizeBits = TD->getTypeAllocSizeInBits(AllocaEltTy);
  // If this isn't a store of an integer to the whole alloca, it may be a store
  // to the first element.  Just ignore the store in this case and normal SROA
  // will handle it.
  if (!isa<IntegerType>(SrcVal->getType()) ||
      TD->getTypeAllocSizeInBits(SrcVal->getType()) != AllocaSizeBits)
    return;
  // Handle tail padding by extending the operand
  if (TD->getTypeSizeInBits(SrcVal->getType()) != AllocaSizeBits)
    SrcVal = new ZExtInst(SrcVal,
                          IntegerType::get(SI->getContext(), AllocaSizeBits), 
                          "", SI);
  DEBUG(errs() << "PROMOTING STORE TO WHOLE ALLOCA: " << *AI << '\n' << *SI
               << '\n');

  // There are two forms here: AI could be an array or struct.  Both cases
  // have different ways to compute the element offset.
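  // For example (illustrative): for {i8, i32} with the usual padding,
  // StructLayout reports element offsets of 0 and 32 bits, whereas for
  // [4 x i16] element i simply lives at i*16 bits.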
  if (const StructType *EltSTy = dyn_cast<StructType>(AllocaEltTy)) {
    const StructLayout *Layout = TD->getStructLayout(EltSTy);
    
    for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
      // Get the number of bits to shift SrcVal to get the value.
      const Type *FieldTy = EltSTy->getElementType(i);
      uint64_t Shift = Layout->getElementOffsetInBits(i);
      
      if (TD->isBigEndian())
        Shift = AllocaSizeBits-Shift-TD->getTypeAllocSizeInBits(FieldTy);
      
      Value *EltVal = SrcVal;
      if (Shift) {
        Value *ShiftVal = ConstantInt::get(EltVal->getType(), Shift);
        EltVal = BinaryOperator::CreateLShr(EltVal, ShiftVal,
                                            "sroa.store.elt", SI);
      }
      
      // Truncate down to an integer of the right size.
      uint64_t FieldSizeBits = TD->getTypeSizeInBits(FieldTy);
      
      // Ignore zero sized fields like {}, they obviously contain no data.
      if (FieldSizeBits == 0) continue;
      
      if (FieldSizeBits != AllocaSizeBits)
        EltVal = new TruncInst(EltVal,
                             IntegerType::get(SI->getContext(), FieldSizeBits),
                              "", SI);
      Value *DestField = NewElts[i];
      if (EltVal->getType() == FieldTy) {
        // Storing to an integer field of this size, just do it.
      } else if (FieldTy->isFloatingPoint() || isa<VectorType>(FieldTy)) {
        // Bitcast to the right element type (for fp/vector values).
        EltVal = new BitCastInst(EltVal, FieldTy, "", SI);
      } else {
        // Otherwise, bitcast the dest pointer (for aggregates).
        DestField = new BitCastInst(DestField,
                              PointerType::getUnqual(EltVal->getType()),
                                    "", SI);
      }
      new StoreInst(EltVal, DestField, SI);