// VectorPredication.cpp
#include "llvm/Transforms/Vectorize/VectorPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/DDG.h"
#include "llvm/Analysis/DDGPrinter.h"
#include "llvm/Analysis/DependenceAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/VectorBuilder.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Transforms/Utils/Local.h"

// Tag consumed by the STATISTIC macro below.
#define DEBUG_TYPE "vector-predication"
// Counter bumped once for every instruction rewritten as a VP operation.
STATISTIC(Transforms, "Number of full-length -> evl vector transformation.");

using namespace llvm;

// Hidden command line flag, off by default. When enabled, the pass writes the
// data dependence graph of each processed function to a dot file.
static cl::opt<bool> PrintDDG("print-ddg",
                              cl::desc("Create dot file for the DDG"),
                              cl::init(false), cl::Hidden);
// Register all memory writing VP instructions found in the function.
//
// Walks every instruction of F and appends each call to the vp.store or
// vp.scatter intrinsic to MemoryWritingVPInstructions. Any other instruction
// is ignored.
void VectorPredicationPass::analyseFunction(Function &F) {
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      auto *CI = dyn_cast<CallInst>(&I);
      if (!CI)
        continue;

      const Intrinsic::ID ID = CI->getIntrinsicID();
      if (ID == Intrinsic::vp_store || ID == Intrinsic::vp_scatter)
        MemoryWritingVPInstructions.push_back(&I);
    }
  }
}

// Register Op as a candidate for conversion into a vector predicated operation
// and recursively do the same for its operands.
//
// Op                - value to inspect; ignored unless it is a vector-typed
//                     instruction whose opcode has a VP intrinsic equivalent.
// Mask              - mask to associate with Op; nullptr stands for an
//                     all-ones (unmasked) operation.
// EVL               - explicit vector length to associate with Op.
// VecOpsToTransform - map from candidate instruction to its (mask, EVL) pair;
//                     updated in place.
static void findCandidateVectorOperation(Value *Op, Value *Mask, Value *EVL,
                                         InstToMaskEVLMap &VecOpsToTransform) {
  auto *OpInst = dyn_cast<Instruction>(Op);
  if (!OpInst)
    return;

  // Only vector-producing instructions can be turned into VP operations. The
  // walk below follows every operand, so it can reach scalar instructions
  // (e.g. a scalar address computation feeding a vector load) whose opcode
  // also has a VP counterpart; those must never be registered.
  if (!OpInst->getType()->isVectorTy())
    return;

  const Intrinsic::ID VPID = VPIntrinsic::getForOpcode(OpInst->getOpcode());
  if (VPID == Intrinsic::not_intrinsic)
    return;

  // If the instruction is already present in the map, it means it was already
  // visited starting from a previous memory writing vp operation.
  const auto InsertResult = VecOpsToTransform.insert(
      std::make_pair(OpInst, std::make_pair(Mask, EVL)));
  if (!InsertResult.second) {
    // We need to check if new mask and evl values differ from the old ones:
    // - if they are the same, then there is nothing to do;
    // - if only the mask differs, we use an allones mask;
    // - otherwise, we remove the instruction from the map (i.e., no
    //   transformation should happen).
    Value *OldMask, *OldEVL;
    std::tie(OldMask, OldEVL) = InsertResult.first->second;

    if (Mask == OldMask && EVL == OldEVL)
      return;

    // NOTE(review): once an instruction has been erased because of conflicting
    // EVLs, a later visit would register it again as if it were new; confirm
    // whether that situation can occur in practice.
    VecOpsToTransform.erase(OpInst);
    if (EVL == OldEVL) {
      VecOpsToTransform.insert(
          std::make_pair(OpInst, std::make_pair(nullptr, EVL)));
    }
  }

  // Recursively visit OpInst operands. The walk continues even when the
  // instruction was just downgraded or removed, so the new mask/EVL reaches
  // the operands as well.
  switch (VPID) {
  case Intrinsic::vp_select: {
    // A select does not forward the incoming mask: all of its operands are
    // visited as unmasked (nullptr).
    Value *Cond = OpInst->getOperand(0);
    if (Cond->getType()->isVectorTy())
      findCandidateVectorOperation(Cond, nullptr, EVL, VecOpsToTransform);

    // TODO: if the condition argument is a vector, we could backpropagate it
    // as mask for the true branch and its negation as mask for the false one.
    // WARNING: when creating the negation of the condition, we must ensure it
    // dominates all uses.
    findCandidateVectorOperation(OpInst->getOperand(1), nullptr, EVL,
                                 VecOpsToTransform);
    findCandidateVectorOperation(OpInst->getOperand(2), nullptr, EVL,
                                 VecOpsToTransform);
    break;
  }
  default:
    for (auto *OpVal : OpInst->operand_values())
      findCandidateVectorOperation(OpVal, Mask, EVL, VecOpsToTransform);
    break;
  }
}

// For each VP memory writing operation, go back to the stored vector defining
// instruction and verify it is a vector operation. Add it to the list of
// instructions to be transformed into vector predicated ones, then recursively
// repeat the process for its vector arguments.
//
// Results are accumulated in VecOpsToTransform. Nothing happens when no memory
// writing VP instruction has been registered.
void VectorPredicationPass::findCandidateVectorOperations() {
  // The guard must return on its own: without a statement of its own it would
  // govern the loop below and the loop would only run on an empty list.
  if (MemoryWritingVPInstructions.empty())
    return;

  for (Instruction *I : MemoryWritingVPInstructions) {
    auto *VPI = cast<VPIntrinsic>(I);
    Value *StoredOperand = VPI->getMemoryDataParam();
    Value *MaskOperand = VPI->getMaskParam();
    Value *EVLOperand = VPI->getVectorLengthParam();
    // First, visit the mask operand (assigning an allones mask to this branch)
    // and only then visit the stored operand.
    findCandidateVectorOperation(MaskOperand, nullptr, EVLOperand,
                                 VecOpsToTransform);
    findCandidateVectorOperation(StoredOperand, MaskOperand, EVLOperand,
                                 VecOpsToTransform);
  }
}

// Transform candidates to vector predicated instructions.
//
// For every entry of VecOpsToTransform a VP operation with the same opcode and
// result type is created right before the original instruction, using the
// recorded mask and EVL. The original instruction is left in place and paired
// with its replacement in OldInstructionsToRemove.
void VectorPredicationPass::transformCandidateVectorOperations() {
  // The guard must return on its own: without a statement of its own it would
  // govern the loop below and no candidate would ever be transformed.
  if (VecOpsToTransform.empty())
    return;

  for (auto [I, P] : VecOpsToTransform) {
    Value *Mask, *EVL;
    std::tie(Mask, EVL) = P;

    // NOTE(review): the new operation is inserted at the position of I;
    // nothing here verifies that Mask and EVL are defined before that point.
    IRBuilder<> Builder(I);
    unsigned int Opcode = I->getOpcode();
    Type *RetTy = I->getType();
    // By default the VP operation takes the same operands as the original.
    SmallVector<Value *> Operands(I->value_op_begin(), I->value_op_end());
    switch (Opcode) {
    case Instruction::FCmp:
    case Instruction::ICmp: {
      // Comparisons carry their predicate as a third, metadata-string operand.
      auto *CmpI = cast<CmpInst>(I);
      Value *PredOp = MetadataAsValue::get(
          Builder.getContext(),
          MDString::get(Builder.getContext(),
                        CmpInst::getPredicateName(CmpI->getPredicate())));
      Operands = {CmpI->getOperand(0), CmpI->getOperand(1), PredOp};
      break;
    }
    case Instruction::Select: {
      // A scalar condition is splatted to a vector with as many elements as
      // the selected values.
      if (!I->getOperand(0)->getType()->isVectorTy()) {
        Value *Op1 = I->getOperand(1);
        Value *Op2 = I->getOperand(2);
        Value *Cond = Builder.CreateVectorSplat(
            cast<VectorType>(Op1->getType())->getElementCount(),
            I->getOperand(0), "select.cond.splat");
        Operands = {Cond, Op1, Op2};
      }
      break;
    }
    default:
      break;
    }

    if (!Mask)
      // nullptr means unmasked operation, hence we use an all-ones mask.
      Mask = ConstantInt::getTrue(RetTy->getWithNewType(Builder.getInt1Ty()));

    VectorBuilder VecBuilder(Builder);
    VecBuilder.setMask(Mask).setEVL(EVL);
    Value *NewVPOp =
        VecBuilder.createVectorInstruction(Opcode, RetTy, Operands, "vp.op");

    Transforms++; // Stats
    OldInstructionsToRemove.insert(std::make_pair(I, NewVPOp));
  }
}

// Redirect every use of each original instruction to its VP replacement, then
// erase the original when it has become trivially dead. Originals that are
// not trivially dead stay in the function. The bookkeeping map is emptied at
// the end.
void VectorPredicationPass::removeOldInstructions() {
  for (auto Entry : OldInstructionsToRemove) {
    auto *Original = Entry.first;
    auto *Replacement = Entry.second;

    Original->replaceAllUsesWith(Replacement);
    if (!isInstructionTriviallyDead(Original))
      continue;
    Original->eraseFromParent();
  }

  OldInstructionsToRemove.clear();
}

// Pass entry point.
//
// Collects the memory writing VP instructions of F, finds the vector
// operations feeding them, rewrites those as VP operations and removes the
// originals. When -print-ddg is set, the data dependence graph of the
// resulting function is written to "ddg.<function name>.dot".
//
// F  - function to process.
// AM - analysis manager (currently unused).
// Returns PreservedAnalyses::none().
PreservedAnalyses VectorPredicationPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
  assert(OldInstructionsToRemove.empty() &&
         "Map should be cleared at the end of each run of the pass.");

  analyseFunction(F);
  findCandidateVectorOperations();
  // TODO: before transformation, create DDG and use it to check if adding a new
  // edge creates a cycle.
  transformCandidateVectorOperations();
  removeOldInstructions();

  // Drop the remaining per-function state. Both containers hold pointers to
  // instructions of F, some of which have just been erased; keeping them would
  // leak stale entries into the next function processed by this pass object.
  MemoryWritingVPInstructions.clear();
  VecOpsToTransform.clear();

  // The analyses below are only needed to dump the DDG, so build them on
  // demand instead of on every run.
  if (PrintDDG) {
    // Retrieve DependenceInfo for the function.
    TargetLibraryInfoImpl TLII;
    TargetLibraryInfo TLI(TLII);
    AAResults AA(TLI);
    AssumptionCache AC(F);
    DominatorTree DT(F);
    LoopInfo LI(DT);
    ScalarEvolution SE(F, TLI, AC, DT, LI);
    DependenceInfo DI(&F, &AA, &SE, &LI);

    DataDependenceGraph DDG(F, DI);

    std::string Filename = ("ddg." + F.getName() + ".dot").str();
    std::error_code EC;
    raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
    if (!EC)
      WriteGraph(File, &DDG);
    else
      errs() << "Error opening file for writing!\n";
  }

  // TODO: think about which analysis are preserved.
  return PreservedAnalyses::none();
}