[IA] Use a single callback for lowerDeinterleaveIntrinsic [nfc] #148978
Conversation
This essentially merges the handling for VPLoad - currently in lowerInterleavedVPLoad, which is shared between shuffle- and intrinsic-based interleaves - into the existing dedicated routine. My plan is that, if we like this factoring, I'll do the same for the intrinsic store paths, and then remove the excess generality from the shuffle paths since we don't need to support both modes.
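For reference, a minimal sketch of the IR shape this path handles (illustrative only; the value names and types are not taken from the patch): a deinterleave intrinsic fed by a vp.load, which the pass now hands to lowerDeinterleaveIntrinsicToLoad together with the vp.load's mask.

; Illustrative only: a factor-2 deinterleave of a masked vp.load. The pass
; derives a per-field mask from %m (all-true/false or interleaved) and passes
; the vp.load plus that mask to the target's lowerDeinterleaveIntrinsicToLoad.
%wide = call <8 x i32> @llvm.vp.load.v8i32.p0(ptr %p, <8 x i1> %m, i32 %evl)
%dei  = call { <4 x i32>, <4 x i32> } @llvm.vector.deinterleave2.v8i32(<8 x i32> %wide)
%even = extractvalue { <4 x i32>, <4 x i32> } %dei, 0
%odd  = extractvalue { <4 x i32>, <4 x i32> } %dei, 1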
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-backend-risc-v
Author: Philip Reames (preames)
Changes
This essentially merges the handling for VPLoad - currently in lowerInterleavedVPLoad, which is shared between shuffle- and intrinsic-based interleaves - into the existing dedicated routine. My plan is that, if we like this factoring, I'll do the same for the intrinsic store paths, and then remove the excess generality from the shuffle paths since we don't need to support both modes in the shared VPLoad/Store callbacks. We can probably even fold the VP versions into the non-VP shuffle variants in the analogous way.
Full diff: https://github.com/llvm/llvm-project/pull/148978.diff
6 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a248eb7444b20..72594c7f9783c 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3249,10 +3249,11 @@ class LLVM_ABI TargetLoweringBase {
/// Return true on success. Currently only supports
/// llvm.vector.deinterleave{2,3,5,7}
///
- /// \p LI is the accompanying load instruction.
+ /// \p Load is the accompanying load instruction. Can be either a plain load
+ /// instruction or a vp.load intrinsic.
/// \p DeinterleaveValues contains the deinterleaved values.
virtual bool
- lowerDeinterleaveIntrinsicToLoad(LoadInst *LI,
+ lowerDeinterleaveIntrinsicToLoad(Instruction *Load, Value *Mask,
ArrayRef<Value *> DeinterleaveValues) const {
return false;
}
diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
index 7259834975cf4..95599837e1bfc 100644
--- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp
+++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
@@ -634,24 +634,18 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
if (!LastFactor)
return false;
+ Value *Mask = nullptr;
if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadedVal)) {
if (VPLoad->getIntrinsicID() != Intrinsic::vp_load)
return false;
// Check mask operand. Handle both all-true/false and interleaved mask.
Value *WideMask = VPLoad->getOperand(1);
- Value *Mask =
- getMask(WideMask, Factor, cast<VectorType>(LastFactor->getType()));
+ Mask = getMask(WideMask, Factor, cast<VectorType>(LastFactor->getType()));
if (!Mask)
return false;
LLVM_DEBUG(dbgs() << "IA: Found a vp.load with deinterleave intrinsic "
<< *DI << " and factor = " << Factor << "\n");
-
- // Since lowerInterleaveLoad expects Shuffles and LoadInst, use special
- // TLI function to emit target-specific interleaved instruction.
- if (!TLI->lowerInterleavedVPLoad(VPLoad, Mask, DeinterleaveValues))
- return false;
-
} else {
auto *LI = cast<LoadInst>(LoadedVal);
if (!LI->isSimple())
@@ -659,12 +653,13 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
LLVM_DEBUG(dbgs() << "IA: Found a load with deinterleave intrinsic " << *DI
<< " and factor = " << Factor << "\n");
-
- // Try and match this with target specific intrinsics.
- if (!TLI->lowerDeinterleaveIntrinsicToLoad(LI, DeinterleaveValues))
- return false;
}
+ // Try and match this with target specific intrinsics.
+ if (!TLI->lowerDeinterleaveIntrinsicToLoad(cast<Instruction>(LoadedVal), Mask,
+ DeinterleaveValues))
+ return false;
+
for (Value *V : DeinterleaveValues)
if (V)
DeadInsts.insert(cast<Instruction>(V));
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index bde4ba993f69e..235df9022c6fb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17476,12 +17476,17 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
}
bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
- LoadInst *LI, ArrayRef<Value *> DeinterleavedValues) const {
+ Instruction *Load, Value *Mask,
+ ArrayRef<Value *> DeinterleavedValues) const {
unsigned Factor = DeinterleavedValues.size();
if (Factor != 2 && Factor != 4) {
LLVM_DEBUG(dbgs() << "Matching ld2 and ld4 patterns failed\n");
return false;
}
+ auto *LI = dyn_cast<LoadInst>(Load);
+ if (!LI)
+ return false;
+ assert(!Mask && "Unexpected mask on a load\n");
Value *FirstActive = *llvm::find_if(DeinterleavedValues,
[](Value *V) { return V != nullptr; });
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 65fe08e92c235..6afb3c330d25b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -219,7 +219,8 @@ class AArch64TargetLowering : public TargetLowering {
unsigned Factor) const override;
bool lowerDeinterleaveIntrinsicToLoad(
- LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const override;
+ Instruction *Load, Value *Mask,
+ ArrayRef<Value *> DeinterleaveValues) const override;
bool lowerInterleaveIntrinsicToStore(
StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 00e969056df7d..41bbf6b9dcf2e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -438,7 +438,8 @@ class RISCVTargetLowering : public TargetLowering {
unsigned Factor) const override;
bool lowerDeinterleaveIntrinsicToLoad(
- LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const override;
+ Instruction *Load, Value *Mask,
+ ArrayRef<Value *> DeinterleaveValues) const override;
bool lowerInterleaveIntrinsicToStore(
StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp b/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp
index 440857f831fa6..ba5bb4cc44bc8 100644
--- a/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp
@@ -234,53 +234,100 @@ bool RISCVTargetLowering::lowerInterleavedStore(StoreInst *SI,
return true;
}
+static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
+ assert(N);
+ if (N == 1)
+ return true;
+
+ using namespace PatternMatch;
+ // Right now we're only recognizing the simplest pattern.
+ uint64_t C;
+ if (match(V, m_CombineOr(m_ConstantInt(C),
+ m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
+ C && C % N == 0)
+ return true;
+
+ if (isPowerOf2_32(N)) {
+ KnownBits KB = llvm::computeKnownBits(V, DL);
+ return KB.countMinTrailingZeros() >= Log2_32(N);
+ }
+
+ return false;
+}
+
bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
- LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const {
+ Instruction *Load, Value *Mask,
+ ArrayRef<Value *> DeinterleaveValues) const {
const unsigned Factor = DeinterleaveValues.size();
if (Factor > 8)
return false;
- assert(LI->isSimple());
- IRBuilder<> Builder(LI);
+ IRBuilder<> Builder(Load);
Value *FirstActive =
*llvm::find_if(DeinterleaveValues, [](Value *V) { return V != nullptr; });
VectorType *ResVTy = cast<VectorType>(FirstActive->getType());
- const DataLayout &DL = LI->getDataLayout();
+ const DataLayout &DL = Load->getDataLayout();
+ auto *XLenTy = Type::getIntNTy(Load->getContext(), Subtarget.getXLen());
- if (!isLegalInterleavedAccessType(ResVTy, Factor, LI->getAlign(),
- LI->getPointerAddressSpace(), DL))
+ Value *Ptr, *VL;
+ Align Alignment;
+ if (auto *LI = dyn_cast<LoadInst>(Load)) {
+ assert(LI->isSimple());
+ Ptr = LI->getPointerOperand();
+ Alignment = LI->getAlign();
+ assert(!Mask && "Unexpected mask on a load\n");
+ Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
+ VL = isa<FixedVectorType>(ResVTy)
+ ? Builder.CreateElementCount(XLenTy, ResVTy->getElementCount())
+ : Constant::getAllOnesValue(XLenTy);
+ } else {
+ auto *VPLoad = cast<VPIntrinsic>(Load);
+ assert(VPLoad->getIntrinsicID() == Intrinsic::vp_load &&
+ "Unexpected intrinsic");
+ Ptr = VPLoad->getArgOperand(0);
+ Alignment = VPLoad->getParamAlign(0).value_or(
+ DL.getABITypeAlign(ResVTy->getElementType()));
+
+ assert(Mask && "vp.load needs a mask!");
+
+ Value *WideEVL = VPLoad->getVectorLengthParam();
+ // Conservatively check if EVL is a multiple of factor, otherwise some
+ // (trailing) elements might be lost after the transformation.
+ if (!isMultipleOfN(WideEVL, Load->getDataLayout(), Factor))
+ return false;
+
+ VL = Builder.CreateZExt(
+ Builder.CreateUDiv(WideEVL,
+ ConstantInt::get(WideEVL->getType(), Factor)),
+ XLenTy);
+ }
+
+ Type *PtrTy = Ptr->getType();
+ unsigned AS = PtrTy->getPointerAddressSpace();
+ if (!isLegalInterleavedAccessType(ResVTy, Factor, Alignment, AS, DL))
return false;
Value *Return;
- Type *PtrTy = LI->getPointerOperandType();
- Type *XLenTy = Type::getIntNTy(LI->getContext(), Subtarget.getXLen());
-
if (isa<FixedVectorType>(ResVTy)) {
- Value *VL = Builder.CreateElementCount(XLenTy, ResVTy->getElementCount());
- Value *Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
Return = Builder.CreateIntrinsic(FixedVlsegIntrIds[Factor - 2],
- {ResVTy, PtrTy, XLenTy},
- {LI->getPointerOperand(), Mask, VL});
+ {ResVTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
} else {
unsigned SEW = DL.getTypeSizeInBits(ResVTy->getElementType());
unsigned NumElts = ResVTy->getElementCount().getKnownMinValue();
Type *VecTupTy = TargetExtType::get(
- LI->getContext(), "riscv.vector.tuple",
- ScalableVectorType::get(Type::getInt8Ty(LI->getContext()),
+ Load->getContext(), "riscv.vector.tuple",
+ ScalableVectorType::get(Type::getInt8Ty(Load->getContext()),
NumElts * SEW / 8),
Factor);
- Value *VL = Constant::getAllOnesValue(XLenTy);
- Value *Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
-
Function *VlsegNFunc = Intrinsic::getOrInsertDeclaration(
- LI->getModule(), ScalableVlsegIntrIds[Factor - 2],
+ Load->getModule(), ScalableVlsegIntrIds[Factor - 2],
{VecTupTy, PtrTy, Mask->getType(), VL->getType()});
Value *Operands[] = {
PoisonValue::get(VecTupTy),
- LI->getPointerOperand(),
+ Ptr,
Mask,
VL,
ConstantInt::get(XLenTy,
@@ -290,7 +337,7 @@ bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
CallInst *Vlseg = Builder.CreateCall(VlsegNFunc, Operands);
SmallVector<Type *, 2> AggrTypes{Factor, ResVTy};
- Return = PoisonValue::get(StructType::get(LI->getContext(), AggrTypes));
+ Return = PoisonValue::get(StructType::get(Load->getContext(), AggrTypes));
for (unsigned i = 0; i < Factor; ++i) {
Value *VecExtract = Builder.CreateIntrinsic(
Intrinsic::riscv_tuple_extract, {ResVTy, VecTupTy},
@@ -370,27 +417,6 @@ bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore(
return true;
}
-static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
- assert(N);
- if (N == 1)
- return true;
-
- using namespace PatternMatch;
- // Right now we're only recognizing the simplest pattern.
- uint64_t C;
- if (match(V, m_CombineOr(m_ConstantInt(C),
- m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
- C && C % N == 0)
- return true;
-
- if (isPowerOf2_32(N)) {
- KnownBits KB = llvm::computeKnownBits(V, DL);
- return KB.countMinTrailingZeros() >= Log2_32(N);
- }
-
- return false;
-}
-
/// Lower an interleaved vp.load into a vlsegN intrinsic.
///
/// E.g. Lower an interleaved vp.load (Factor = 2):
LGTM.
I agree splitting the TLI callbacks by intrinsic / shufflevector rather than load / vp.load yields a cleaner interface in the long run.
// Right now we're only recognizing the simplest pattern.
uint64_t C;
if (match(V, m_CombineOr(m_ConstantInt(C),
             m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
Not related to this patch since this was just moved from another location. Do we need to use m_c_Mul instead of m_Mul here? Constants should have been canonicalized to the right-hand side by this point.
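For context, a minimal sketch of the difference between the two matchers (illustrative only, not part of the patch):

#include "llvm/IR/PatternMatch.h"

// Sketch: m_Mul matches operands in the written order, so with the constant
// bound on the RHS it only matches `mul %x, C`; m_c_Mul also accepts the
// commuted `mul C, %x`. If constants are always canonicalized to the RHS
// before this point, the ordered matcher would be enough.
static bool isMulByConstant(const llvm::Value *V) {
  using namespace llvm::PatternMatch;
  uint64_t C;
  if (match(V, m_Mul(m_Value(), m_ConstantInt(C))))
    return true;                                           // `mul %x, C`
  return match(V, m_c_Mul(m_Value(), m_ConstantInt(C)));   // also `mul C, %x`
}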
VL = Builder.CreateZExt(
    Builder.CreateUDiv(WideEVL,
                       ConstantInt::get(WideEVL->getType(), Factor)),
Should this be an exact udiv?
It probably should be, but doing this revealed a deeper problem in the EVL recognition. We're not checking for overflow in the multiply check, and I believe that to be unsound. I am going to return to this, but want to get a few things off my queue first.
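For reference, a sketch of the exact-division form discussed above, assuming the surrounding variables from the patch (Builder, WideEVL, Factor, XLenTy); the overflow concern with the multiply check would still need to be settled separately.

// Sketch only: carries the `exact` flag on the udiv, which is valid only if
// EVL is provably an exact multiple of Factor.
VL = Builder.CreateZExt(
    Builder.CreateExactUDiv(WideEVL,
                            ConstantInt::get(WideEVL->getType(), Factor)),
    XLenTy);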