diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 01fc1528681b79450137e92d89f3592b518762d7..aa59c67e009a10401b3f9425b44c838c44347c97 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -1339,25 +1339,39 @@ static bool equivalentAddressValues(Value *A, Value *B) { /// Converts store (bitcast (load (bitcast (select ...)))) to /// store (load (select ...)), where select is minmax: /// select ((cmp load V1, load V2), V1, V2). -bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, StoreInst &SI) { +Instruction *removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, + StoreInst &SI) { // bitcast? - Value *StoreAddr; - if (!match(SI.getPointerOperand(), m_BitCast(m_Value(StoreAddr)))) - return false; + if (!match(SI.getPointerOperand(), m_BitCast(m_Value()))) + return nullptr; // load? integer? Value *LoadAddr; if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr))))) - return false; + return nullptr; auto *LI = cast(SI.getValueOperand()); if (!LI->getType()->isIntegerTy()) - return false; + return nullptr; if (!isMinMaxWithLoads(LoadAddr)) - return false; + return nullptr; + if (!all_of(LI->users(), [LI, LoadAddr](User *U) { + auto *SI = dyn_cast(U); + return SI && SI->getPointerOperand() != LI && + peekThroughBitcast(SI->getPointerOperand()) != LoadAddr && + !SI->getPointerOperand()->isSwiftError(); + })) + return nullptr; + + IC.Builder.SetInsertPoint(LI); LoadInst *NewLI = combineLoadToNewType( IC, *LI, LoadAddr->getType()->getPointerElementType()); - combineStoreToNewValue(IC, SI, NewLI); - return true; + // Replace all the stores with stores of the newly loaded value. + for (auto *UI : LI->users()) { + auto *USI = cast(UI); + IC.Builder.SetInsertPoint(USI); + combineStoreToNewValue(IC, *USI, NewLI); + } + return LI; } Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { @@ -1384,8 +1398,12 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (unpackStoreToAggregate(*this, SI)) return eraseInstFromFunction(SI); - if (removeBitcastsFromLoadStoreOnMinMax(*this, SI)) - return eraseInstFromFunction(SI); + if (Instruction *I = removeBitcastsFromLoadStoreOnMinMax(*this, SI)) { + for (auto *UI : I->users()) + eraseInstFromFunction(*cast(UI)); + eraseInstFromFunction(*I); + return nullptr; + } // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { diff --git a/llvm/test/Transforms/InstCombine/multiple-uses-load-bitcast-select.ll b/llvm/test/Transforms/InstCombine/multiple-uses-load-bitcast-select.ll new file mode 100644 index 0000000000000000000000000000000000000000..28509df6d2faa53cd3e06595499d7b8639bc3f28 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/multiple-uses-load-bitcast-select.ll @@ -0,0 +1,30 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instcombine -S -data-layout="E-m:e-i1:8:16-i8:8:16-i64:64-f128:64-v128:64-a:8:16-n32:64" | FileCheck %s + +define void @PR35618(i64* %st1, double* %st2) { +; CHECK-LABEL: @PR35618( +; CHECK-NEXT: [[Y1:%.*]] = alloca double, align 8 +; CHECK-NEXT: [[Z1:%.*]] = alloca double, align 8 +; CHECK-NEXT: [[LD1:%.*]] = load double, double* [[Y1]], align 8 +; CHECK-NEXT: [[LD2:%.*]] = load double, double* [[Z1]], align 8 +; CHECK-NEXT: [[TMP10:%.*]] = fcmp olt double [[LD1]], [[LD2]] +; CHECK-NEXT: [[TMP121:%.*]] = select i1 [[TMP10]], double [[LD1]], double [[LD2]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast i64* [[ST1:%.*]] to double* +; CHECK-NEXT: store double [[TMP121]], double* [[TMP1]], align 8 +; CHECK-NEXT: store double [[TMP121]], double* [[ST2:%.*]], align 8 +; CHECK-NEXT: ret void +; + %y1 = alloca double + %z1 = alloca double + %ld1 = load double, double* %y1 + %ld2 = load double, double* %z1 + %tmp10 = fcmp olt double %ld1, %ld2 + %sel = select i1 %tmp10, double* %y1, double* %z1 + %tmp11 = bitcast double* %sel to i64* + %tmp12 = load i64, i64* %tmp11 + store i64 %tmp12, i64* %st1 + %bc = bitcast double* %st2 to i64* + store i64 %tmp12, i64* %bc + ret void +} +