[NVPTX] Eliminate `prmt`s that result from BUILD_VECTOR of LoadV2 #149581
Conversation
@llvm/pr-subscribers-backend-nvptx

Author: Justin Fargnoli (justinfargnoli)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/149581.diff

2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..5f98b1a27617d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N,
}
static SDValue
-PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+PerformBUILD_VECTOROfV2i16Combine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
auto VT = N->getValueType(0);
if (!DCI.isAfterLegalizeDAG() ||
// only process v2*16 types
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getBitcast(VT, PRMT);
}
+static SDValue
+PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ // Match: BUILD_VECTOR of v4i8, where first two elements are from a
+ // NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
+ // zero constants. Replace with: zext the loaded i16 to i32, and return as a
+ // bitcast to v4i8.
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::v4i8)
+ return SDValue();
+ // Check operands: [0]=lo, [1]=hi
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ // Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
+ // NVPTXISD::LDUV2
+ if (Op0.getNode() != Op1.getNode())
+ return SDValue();
+ if (!(Op0.getOpcode() == NVPTXISD::LoadV2 ||
+ Op0.getOpcode() == NVPTXISD::LDUV2))
+ return SDValue();
+ if (Op0.getValueType() != MVT::i16)
+ return SDValue();
+ if (!(Op0.hasOneUse() && Op1.hasOneUse()))
+ return SDValue();
+
+ // Check operands: [2]= 0 or undef, [3]= 0 or undef
+ SDValue Op2 = N->getOperand(2);
+ SDValue Op3 = N->getOperand(3);
+ if (Op2 != Op3)
+ return SDValue();
+ if (!Op2.isUndef()) {
+ auto *C2 = dyn_cast<ConstantSDNode>(Op2);
+ if (!(C2 && C2->isZero()))
+ return SDValue();
+ }
+
+ // Now, replace with: zext(load i16) -> i32, then bitcast to v4i8
+ auto &DAG = DCI.DAG;
+ // Rebuild the load as i16
+ auto *Load = cast<MemSDNode>(Op0.getNode());
+ SDLoc DL(Load);
+ SDValue LoadI16;
+ if (Load->getOpcode() == NVPTXISD::LoadV2) {
+ LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(),
+ Load->getPointerInfo(), Load->getAlign(),
+ Load->getMemOperand()->getFlags());
+ } else {
+ assert(Load->getOpcode() == NVPTXISD::LDUV2);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ SmallVector<SDValue, 4> Ops;
+ Ops.push_back(Load->getChain());
+ Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL,
+ TLI.getPointerTy(DAG.getDataLayout())));
+ for (unsigned i = 1; i < Load->getNumOperands(); ++i)
+ Ops.push_back(Load->getOperand(i));
+ SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other);
+ LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList,
+ Ops, MVT::i16, Load->getPointerInfo(),
+ Load->getAlign());
+ }
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
+ SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
+ return DAG.getBitcast(MVT::v4i8, Zext);
+}
+
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))
+ return V;
+ if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI))
+ return V;
+ return SDValue();
+}
+
static SDValue combineADDRSPACECAST(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
diff --git a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
new file mode 100644
index 0000000000000..019bd3bde8761
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -0,0 +1,106 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+target datalayout = "e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-f128:128:128-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8"
+target triple = "nvptx64-nvidia-cuda"
+
+define void @t1() {
+; CHECK-LABEL: t1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r1, 0, 0, 0};
+; CHECK-NEXT: ret;
+entry:
+ %0 = load <2 x i8>, ptr addrspace(1) null, align 4
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ %2 = bitcast <4 x i8> %1 to i32
+ %3 = insertelement <4 x i32> zeroinitializer, i32 %2, i64 0
+ store <4 x i32> %3, ptr addrspace(1) null, align 16
+ ret void
+}
+
+define void @t2() {
+; CHECK-LABEL: t2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.local.b32 [%rd1], %r1;
+; CHECK-NEXT: ret;
+entry:
+ %0 = load <2 x i8>, ptr addrspace(1) null, align 8
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldg(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldg(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b64 %rd1, [ldg_param_0];
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: mov.b64 %rd2, 0;
+; CHECK-NEXT: st.local.b32 [%rd2], %r1;
+; CHECK-NEXT: ret;
+entry:
+ %0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldu(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldu(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b64 %rd1, [ldu_param_0];
+; CHECK-NEXT: ldu.global.b16 %rs1, [%rd1];
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT: mov.b64 %rd2, 0;
+; CHECK-NEXT: st.local.b32 [%rd2], %r1;
+; CHECK-NEXT: ret;
+entry:
+ %0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+define void @t3() {
+; CHECK-LABEL: t3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r1, 0};
+; CHECK-NEXT: ret;
+ %1 = load <2 x i8>, ptr addrspace(1) null, align 2
+ %insval2 = bitcast <2 x i8> %1 to i16
+ %2 = insertelement <4 x i16> zeroinitializer, i16 %insval2, i32 0
+ store <4 x i16> %2, ptr addrspace(1) null, align 8
+ ret void
+}
Pull Request Overview
This PR introduces an optimization to eliminate unnecessary prmt (permute) instructions in NVPTX code generation by handling BUILD_VECTOR operations containing LoadV2 instructions more efficiently.
- Adds a new DAG combine that recognizes when BUILD_VECTOR operations are constructed from vector loads
- Refactors the existing BUILD_VECTOR combine logic to handle multiple optimization patterns
- Replaces inefficient permute operations with direct zero-extension and bitcasting for specific load patterns (see the sketch below)
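For reference, the core of the replacement is sketched below (simplified from the patch above; the pattern checks and the LDUV2 path are omitted):

// Simplified sketch of the rewrite: once the BUILD_VECTOR is known to be
// {lo, hi, 0, 0}, where lo/hi are the two i8 lanes of one NVPTXISD::LoadV2,
// re-emit the load as a plain i16 load, widen it to i32, and reinterpret
// the bits as v4i8 -- no prmt needed.
SDValue LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(),
                              Load->getBasePtr(), Load->getPointerInfo(),
                              Load->getAlign(),
                              Load->getMemOperand()->getFlags());
// Rewire the old load's chain result so its users see the new load's chain.
DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
return DAG.getBitcast(MVT::v4i8, Zext);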
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
File | Description
---|---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | Implements the core optimization logic by adding PerformBUILD_VECTOROfTargetLoadCombine and refactoring the existing combine functions
llvm/test/CodeGen/NVPTX/build-vector-combine.ll | Adds test cases covering regular, ldg, and ldu loads and several vector patterns to verify the optimization
Comments suppressed due to low confidence (2)
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:5904
- [nitpick] The variable name 'V' is not descriptive. Consider using a more meaningful name like 'Result' or 'CombinedValue' to improve code readability.
if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:5906
- [nitpick] The variable name 'V' is not descriptive. Consider using a more meaningful name like 'Result' or 'CombinedValue' to improve code readability.
if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI))
// Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
// NVPTXISD::LDUV2
if (Op0.getNode() != Op1.getNode())
This condition checks if Op0 and Op1 come from the same node, but the comment and function logic suggest they should be consecutive elements from a LoadV2. Consider adding a comment explaining why the same node check is sufficient, or verify that Op0 and Op1 are specifically the low and high bytes of the same load.
Suggested change:
-// Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
-// NVPTXISD::LDUV2
-if (Op0.getNode() != Op1.getNode())
+// Check that Op0 and Op1 are consecutive elements (low and high bytes)
+// from the same NVPTXISD::LoadV2 or NVPTXISD::LDUV2
+if (Op0.getNode() != Op1.getNode() || Op0.getResNo() != 0 || Op1.getResNo() != 1)
Bug: we need to check that they're different values from the same node.
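To illustrate why the same-node check alone is insufficient (a hypothetical DAG, not taken from the patch): a BUILD_VECTOR that uses the low lane twice would still match, yet the rewrite assumes lanes 0 and 1:

// Hypothetical input the same-node check alone would accept:
//   BUILD_VECTOR v4i8 (LoadV2:0, LoadV2:0, 0, 0)   // low lane duplicated
// The rewrite yields bitcast v4i8 (zext i32 (load i16)) = {lo, hi, 0, 0},
// which differs from the intended {lo, lo, 0, 0}. Checking the result
// numbers, as in the suggestion above, rules this out:
if (Op0.getResNo() != 0 || Op1.getResNo() != 1)
  return SDValue();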
@AlexMaclean suggested extending this to …
// Match: BUILD_VECTOR of v4i8, where first two elements are from a
// NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
// zero constants. Replace with: zext the loaded i16 to i32, and return as a
// bitcast to v4i8.
This appears to be an oddly specific pattern with two specific elements loaded, and two specific constants. Is that a particularly common pattern? Where does it come from, if so?
// Match: BUILD_VECTOR of v4i8, where first two elements are from a
// NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
// zero constants. Replace with: zext the loaded i16 to i32, and return as a
// bitcast to v4i8.
This seems like a very specific and uncommon occurrence and the test cases look a bit contrived. Is this something we're going to see in real programs?