Skip to content

[NVPTX] Eliminate prmts that result from BUILD_VECTOR of LoadV2 #149581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

justinfargnoli
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Justin Fargnoli (justinfargnoli)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/149581.diff

2 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+76-1)
  • (added) llvm/test/CodeGen/NVPTX/build-vector-combine.ll (+106)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..5f98b1a27617d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N,
 }
 
 static SDValue
-PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+PerformBUILD_VECTOROfV2i16Combine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
   auto VT = N->getValueType(0);
   if (!DCI.isAfterLegalizeDAG() ||
       // only process v2*16 types
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getBitcast(VT, PRMT);
 }
 
+static SDValue
+PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI) {
+  // Match: BUILD_VECTOR of v4i8, where first two elements are from a
+  // NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
+  // zero constants. Replace with: zext the loaded i16 to i32, and return as a
+  // bitcast to v4i8.
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v4i8)
+    return SDValue();
+  // Check operands: [0]=lo, [1]=hi
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  // Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
+  // NVPTXISD::LDUV2
+  if (Op0.getNode() != Op1.getNode())
+    return SDValue();
+  if (!(Op0.getOpcode() == NVPTXISD::LoadV2 ||
+        Op0.getOpcode() == NVPTXISD::LDUV2))
+    return SDValue();
+  if (Op0.getValueType() != MVT::i16)
+    return SDValue();
+  if (!(Op0.hasOneUse() && Op1.hasOneUse()))
+    return SDValue();
+
+  // Check operands: [2]= 0 or undef, [3]= 0 or undef
+  SDValue Op2 = N->getOperand(2);
+  SDValue Op3 = N->getOperand(3);
+  if (Op2 != Op3)
+    return SDValue();
+  if (!Op2.isUndef()) {
+    auto *C2 = dyn_cast<ConstantSDNode>(Op2);
+    if (!(C2 && C2->isZero()))
+      return SDValue();
+  }
+
+  // Now, replace with: zext(load i16) -> i32, then bitcast to v4i8
+  auto &DAG = DCI.DAG;
+  // Rebuild the load as i16
+  auto *Load = cast<MemSDNode>(Op0.getNode());
+  SDLoc DL(Load);
+  SDValue LoadI16;
+  if (Load->getOpcode() == NVPTXISD::LoadV2) {
+    LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(),
+                          Load->getPointerInfo(), Load->getAlign(),
+                          Load->getMemOperand()->getFlags());
+  } else {
+    assert(Load->getOpcode() == NVPTXISD::LDUV2);
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    SmallVector<SDValue, 4> Ops;
+    Ops.push_back(Load->getChain());
+    Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL,
+                                  TLI.getPointerTy(DAG.getDataLayout())));
+    for (unsigned i = 1; i < Load->getNumOperands(); ++i)
+      Ops.push_back(Load->getOperand(i));
+    SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other);
+    LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList,
+                                      Ops, MVT::i16, Load->getPointerInfo(),
+                                      Load->getAlign());
+  }
+  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
+  SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
+  return DAG.getBitcast(MVT::v4i8, Zext);
+}
+
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))
+    return V;
+  if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI))
+    return V;
+  return SDValue();
+}
+
 static SDValue combineADDRSPACECAST(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
diff --git a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
new file mode 100644
index 0000000000000..019bd3bde8761
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -0,0 +1,106 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+target datalayout = "e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-f128:128:128-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8"
+target triple = "nvptx64-nvidia-cuda"
+
+define void @t1() {
+; CHECK-LABEL: t1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.global.v4.b32 [%rd1], {%r1, 0, 0, 0};
+; CHECK-NEXT:    ret;
+entry:
+  %0 = load <2 x i8>, ptr addrspace(1) null, align 4
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %2 = bitcast <4 x i8> %1 to i32
+  %3 = insertelement <4 x i32> zeroinitializer, i32 %2, i64 0
+  store <4 x i32> %3, ptr addrspace(1) null, align 16
+  ret void
+}
+
+define void @t2() {
+; CHECK-LABEL: t2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.local.b32 [%rd1], %r1;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = load <2 x i8>, ptr addrspace(1) null, align 8
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldg(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldg(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b64 %rd1, [ldg_param_0];
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    mov.b64 %rd2, 0;
+; CHECK-NEXT:    st.local.b32 [%rd2], %r1;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldu(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldu(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b64 %rd1, [ldu_param_0];
+; CHECK-NEXT:    ldu.global.b16 %rs1, [%rd1];
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT:    mov.b64 %rd2, 0;
+; CHECK-NEXT:    st.local.b32 [%rd2], %r1;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+define void @t3() {
+; CHECK-LABEL: t3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.global.v2.b32 [%rd1], {%r1, 0};
+; CHECK-NEXT:    ret;
+  %1 = load <2 x i8>, ptr addrspace(1) null, align 2
+  %insval2 = bitcast <2 x i8> %1 to i16
+  %2 = insertelement <4 x i16> zeroinitializer, i16 %insval2, i32 0
+  store <4 x i16> %2, ptr addrspace(1) null, align 8
+  ret void
+}

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces an optimization to eliminate unnecessary prmt (permute) instructions in NVPTX code generation by handling BUILD_VECTOR operations containing LoadV2 instructions more efficiently.

  • Adds a new optimization pass that recognizes when BUILD_VECTOR operations are constructed from vector loads
  • Refactors existing BUILD_VECTOR combine logic to handle multiple optimization patterns
  • Replaces inefficient permute operations with direct zero-extension and bitcasting for specific load patterns

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Implements the core optimization logic by adding PerformBUILD_VECTOROfTargetLoadCombine and refactoring existing combine functions
llvm/test/CodeGen/NVPTX/build-vector-combine.ll Adds comprehensive test cases covering various load types (regular, ldg, ldu) and vector patterns to verify the optimization
Comments suppressed due to low confidence (2)

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:5904

  • [nitpick] The variable name 'V' is not descriptive. Consider using a more meaningful name like 'Result' or 'CombinedValue' to improve code readability.
  if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:5906

  • [nitpick] The variable name 'V' is not descriptive. Consider using a more meaningful name like 'Result' or 'CombinedValue' to improve code readability.
  if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI))

Comment on lines +5850 to +5852
// Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
// NVPTXISD::LDUV2
if (Op0.getNode() != Op1.getNode())
Copy link
Preview

Copilot AI Jul 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition checks if Op0 and Op1 come from the same node, but the comment and function logic suggest they should be consecutive elements from a LoadV2. Consider adding a comment explaining why the same node check is sufficient, or verify that Op0 and Op1 are specifically the low and high bytes of the same load.

Suggested change
// Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
// NVPTXISD::LDUV2
if (Op0.getNode() != Op1.getNode())
// Check that Op0 and Op1 are consecutive elements (low and high bytes)
// from the same NVPTXISD::LoadV2 or NVPTXISD::LDUV2
if (Op0.getNode() != Op1.getNode() || Op0.getResNo() != 0 || Op1.getResNo() != 1)

Copilot uses AI. Check for mistakes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: we need to check that they're different values from the same node.

@justinfargnoli
Copy link
Contributor Author

@AlexMaclean suggested extending this to NVPTXISD::LoadV*. I'd love to do this, but I wasn't able to create a test case that would be impacted. If anyone's able to make one, please let me know!

Comment on lines +5840 to +5843
// Match: BUILD_VECTOR of v4i8, where first two elements are from a
// NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
// zero constants. Replace with: zext the loaded i16 to i32, and return as a
// bitcast to v4i8.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to be an oddly specific pattern with two specific elements loaded, and two specific constants. Is that a particularly common pattern? Where does it come from, if so?

Comment on lines +5840 to +5843
// Match: BUILD_VECTOR of v4i8, where first two elements are from a
// NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
// zero constants. Replace with: zext the loaded i16 to i32, and return as a
// bitcast to v4i8.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a very specific and uncommon occurrence and the test cases look a bit contrived. Is this something we're going to see in real programs?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants