Commit 11996586 authored by Slava Nikolaev's avatar Slava Nikolaev Committed by Volkan Keles
Browse files

LoadStoreVectorizer: support different operand orders in the add sequence match

First we refactor the code which does no wrapping add sequences
match: we need to allow different operand orders for
the key add instructions involved in the match.

Then we use the refactored code trying 4 variants of matching operands.

Originally the code relied on the fact that the matching operands
of the two last add instructions of memory index calculations
had the same LHS argument. But which operand is the same
in the two instructions is actually not essential, so now we allow
that to be any of LHS or RHS of each of the two instructions.
This increases the chances of vectorization to happen.

Reviewed By: volkan

Differential Revision: https://reviews.llvm.org/D103912
parent b73742bc
......@@ -384,6 +384,81 @@ bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB,
return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
}
static bool checkNoWrapFlags(Instruction *I, bool Signed) {
BinaryOperator *BinOpI = cast<BinaryOperator>(I);
return (Signed && BinOpI->hasNoSignedWrap()) ||
(!Signed && BinOpI->hasNoUnsignedWrap());
}
static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
unsigned MatchingOpIdxA, Instruction *AddOpB,
unsigned MatchingOpIdxB, bool Signed) {
// If both OpA and OpB is an add with NSW/NUW and with
// one of the operands being the same, we can guarantee that the
// transformation is safe if we can prove that OpA won't overflow when
// IdxDiff added to the other operand of OpA.
// For example:
// %tmp7 = add nsw i32 %tmp2, %v0
// %tmp8 = sext i32 %tmp7 to i64
// ...
// %tmp11 = add nsw i32 %v0, 1
// %tmp12 = add nsw i32 %tmp2, %tmp11
// %tmp13 = sext i32 %tmp12 to i64
//
// Both %tmp7 and %tmp2 has the nsw flag and the first operand
// is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
// because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 has the
// nsw flag.
assert(AddOpA->getOpcode() == Instruction::Add &&
AddOpB->getOpcode() == Instruction::Add &&
checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed));
if (AddOpA->getOperand(MatchingOpIdxA) ==
AddOpB->getOperand(MatchingOpIdxB)) {
Value *OtherOperandA = AddOpA->getOperand(MatchingOpIdxA == 1 ? 0 : 1);
Value *OtherOperandB = AddOpB->getOperand(MatchingOpIdxB == 1 ? 0 : 1);
Instruction *OtherInstrA = dyn_cast<Instruction>(OtherOperandA);
Instruction *OtherInstrB = dyn_cast<Instruction>(OtherOperandB);
// Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
if (OtherInstrB && OtherInstrB->getOpcode() == Instruction::Add &&
checkNoWrapFlags(OtherInstrB, Signed) &&
isa<ConstantInt>(OtherInstrB->getOperand(1))) {
int64_t CstVal =
cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
if (OtherInstrB->getOperand(0) == OtherOperandA &&
IdxDiff.getSExtValue() == CstVal)
return true;
}
// Match `x +nsw/nuw (y +nsw/nuw -Idx)` and `x +nsw/nuw (y +nsw/nuw x)`.
if (OtherInstrA && OtherInstrA->getOpcode() == Instruction::Add &&
checkNoWrapFlags(OtherInstrA, Signed) &&
isa<ConstantInt>(OtherInstrA->getOperand(1))) {
int64_t CstVal =
cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
if (OtherInstrA->getOperand(0) == OtherOperandB &&
IdxDiff.getSExtValue() == -CstVal)
return true;
}
// Match `x +nsw/nuw (y +nsw/nuw c)` and
// `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
if (OtherInstrA && OtherInstrB &&
OtherInstrA->getOpcode() == Instruction::Add &&
OtherInstrB->getOpcode() == Instruction::Add &&
checkNoWrapFlags(OtherInstrA, Signed) &&
checkNoWrapFlags(OtherInstrB, Signed) &&
isa<ConstantInt>(OtherInstrA->getOperand(1)) &&
isa<ConstantInt>(OtherInstrB->getOperand(1))) {
int64_t CstValA =
cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
int64_t CstValB =
cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
if (OtherInstrA->getOperand(0) == OtherInstrB->getOperand(0) &&
IdxDiff.getSExtValue() == (CstValB - CstValA))
return true;
}
}
return false;
}
bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
APInt PtrDelta,
unsigned Depth) const {
......@@ -438,73 +513,30 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
// Now we need to prove that adding IdxDiff to ValA won't overflow.
bool Safe = false;
auto CheckFlags = [](Instruction *I, bool Signed) {
BinaryOperator *BinOpI = cast<BinaryOperator>(I);
return (Signed && BinOpI->hasNoSignedWrap()) ||
(!Signed && BinOpI->hasNoUnsignedWrap());
};
// First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
// ValA, we're okay.
if (OpB->getOpcode() == Instruction::Add &&
isa<ConstantInt>(OpB->getOperand(1)) &&
IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
CheckFlags(OpB, Signed))
checkNoWrapFlags(OpB, Signed))
Safe = true;
// Second attempt: If both OpA and OpB is an add with NSW/NUW and with
// the same LHS operand, we can guarantee that the transformation is safe
// if we can prove that OpA won't overflow when IdxDiff added to the RHS
// of OpA.
// For example:
// %tmp7 = add nsw i32 %tmp2, %v0
// %tmp8 = sext i32 %tmp7 to i64
// ...
// %tmp11 = add nsw i32 %v0, 1
// %tmp12 = add nsw i32 %tmp2, %tmp11
// %tmp13 = sext i32 %tmp12 to i64
//
// Both %tmp7 and %tmp2 has the nsw flag and the first operand
// is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
// because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 has the
// nsw flag.
// Second attempt: check if we have eligible add NSW/NUW instruction
// sequences.
OpA = dyn_cast<Instruction>(ValA);
if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
OpB->getOpcode() == Instruction::Add &&
OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
CheckFlags(OpB, Signed)) {
Value *RHSA = OpA->getOperand(1);
Value *RHSB = OpB->getOperand(1);
Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
// Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
Safe = true;
}
// Match `x +nsw/nuw (y +nsw/nuw -Idx)` and `x +nsw/nuw (y +nsw/nuw x)`.
if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
Safe = true;
}
// Match `x +nsw/nuw (y +nsw/nuw c)` and
// `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
isa<ConstantInt>(OpRHSB->getOperand(1))) {
int64_t CstValA =
cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
int64_t CstValB =
cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
IdxDiff.getSExtValue() == (CstValB - CstValA))
Safe = true;
}
OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) &&
checkNoWrapFlags(OpB, Signed)) {
// In the checks below a matching operand in OpA and OpB is
// an operand which is the same in those two instructions.
// Below we account for possible orders of the operands of
// these add instructions.
for (unsigned MatchingOpIdxA : {0, 1})
for (unsigned MatchingOpIdxB : {0, 1})
if (!Safe)
Safe = checkIfSafeAddSequence(IdxDiff, OpA, MatchingOpIdxA, OpB,
MatchingOpIdxB, Signed);
}
unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
......
......@@ -104,6 +104,104 @@ bb:
ret void
}
; Apply different operand orders for the nested add sequences
define void @ld_v4i8_add_nsw_operand_orders(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
; CHECK-LABEL: @ld_v4i8_add_nsw_operand_orders(
; CHECK-NEXT: bb:
; CHECK-NEXT: [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
; CHECK-NEXT: [[TMP1:%.*]] = add nsw i32 [[V1:%.*]], [[TMP]]
; CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
; CHECK-NEXT: [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
; CHECK-NEXT: [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
; CHECK-NEXT: [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
; CHECK-NEXT: [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
; CHECK-NEXT: [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
; CHECK-NEXT: [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
; CHECK-NEXT: store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
; CHECK-NEXT: ret void
;
bb:
%tmp = add nsw i32 %v0, -1
%tmp1 = add nsw i32 %v1, %tmp
%tmp2 = sext i32 %tmp1 to i64
%tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
%tmp4 = load i8, i8* %tmp3, align 1
%tmp5 = add nsw i32 %v0, %v1
%tmp6 = sext i32 %tmp5 to i64
%tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
%tmp8 = load i8, i8* %tmp7, align 1
%tmp9 = add nsw i32 %v0, 1
%tmp10 = add nsw i32 %tmp9, %v1
%tmp11 = sext i32 %tmp10 to i64
%tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
%tmp13 = load i8, i8* %tmp12, align 1
%tmp14 = add nsw i32 %v0, 2
%tmp15 = add nsw i32 %v1, %tmp14
%tmp16 = sext i32 %tmp15 to i64
%tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
%tmp18 = load i8, i8* %tmp17, align 1
%tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
%tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
%tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
%tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
store <4 x i8> %tmp22, <4 x i8>* %dst
ret void
}
; Apply different operand orders for the nested add sequences
define void @ld_v4i8_add_nuw_operand_orders(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
; CHECK-LABEL: @ld_v4i8_add_nuw_operand_orders(
; CHECK-NEXT: bb:
; CHECK-NEXT: [[TMP:%.*]] = add nuw i32 [[V0:%.*]], -1
; CHECK-NEXT: [[TMP1:%.*]] = add nuw i32 [[V1:%.*]], [[TMP]]
; CHECK-NEXT: [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
; CHECK-NEXT: [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
; CHECK-NEXT: [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
; CHECK-NEXT: [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
; CHECK-NEXT: [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
; CHECK-NEXT: [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
; CHECK-NEXT: [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
; CHECK-NEXT: store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
; CHECK-NEXT: ret void
;
bb:
%tmp = add nuw i32 %v0, -1
%tmp1 = add nuw i32 %v1, %tmp
%tmp2 = zext i32 %tmp1 to i64
%tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
%tmp4 = load i8, i8* %tmp3, align 1
%tmp5 = add nuw i32 %v0, %v1
%tmp6 = zext i32 %tmp5 to i64
%tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
%tmp8 = load i8, i8* %tmp7, align 1
%tmp9 = add nuw i32 %v0, 1
%tmp10 = add nuw i32 %tmp9, %v1
%tmp11 = zext i32 %tmp10 to i64
%tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
%tmp13 = load i8, i8* %tmp12, align 1
%tmp14 = add nuw i32 %v0, 2
%tmp15 = add nuw i32 %v1, %tmp14
%tmp16 = zext i32 %tmp15 to i64
%tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
%tmp18 = load i8, i8* %tmp17, align 1
%tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
%tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
%tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
%tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
store <4 x i8> %tmp22, <4 x i8>* %dst
ret void
}
define void @ld_v4i8_add_known_bits(i32 %ind0, i32 %ind1, i8* %src, <4 x i8>* %dst) {
; CHECK-LABEL: @ld_v4i8_add_known_bits(
; CHECK-NEXT: bb:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment