Skip to content

[NVPTX] Add PRMT constant folding and cleanup usage of PRMT node #148906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 17, 2025

Conversation

AlexMaclean
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 130.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148906.diff

5 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+138-23)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+19-4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (-18)
  • (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+1758-872)
  • (added) llvm/test/CodeGen/NVPTX/prmt-const-folding.ll (+171)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 14f05250ad6b8..e8f3b322ed90e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                       MVT::v32i32, MVT::v64i32, MVT::v128i32},
                      Custom);
 
-  setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
-  // Enable custom lowering for the i128 bit operand with clusterlaunchcontrol
-  setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i128, Custom);
+  // Enable custom lowering for the following:
+  //   * MVT::i128 - clusterlaunchcontrol
+  //   * MVT::i32 - prmt
+  //   * MVT::Other - internal.addrspace.wrap
+  setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
+                     Custom);
 }
 
 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -2060,6 +2063,13 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
+                       SelectionDAG &DAG,
+                       unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
+  return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
+                     {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2111,15 +2121,13 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
         L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
         R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
       }
-      return DAG.getNode(
-          NVPTXISD::PRMT, DL, MVT::v4i8,
-          {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
-           DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
+      return getPRMT(L, R, DAG.getConstant(SelectionValue, DL, MVT::i32), DL,
+                     DAG);
     };
     auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
     auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
     auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
-    return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
+    return DAG.getBitcast(VT, PRMT3210);
   }
 
   // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2176,11 +2184,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
     SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,
                                    DAG.getZExtOrTrunc(Index, DL, MVT::i32),
                                    DAG.getConstant(0x7770, DL, MVT::i32));
-    SDValue PRMT = DAG.getNode(
-        NVPTXISD::PRMT, DL, MVT::i32,
-        {DAG.getBitcast(MVT::i32, Vector), DAG.getConstant(0, DL, MVT::i32),
-         Selector, DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
-    return DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
+    SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, Vector),
+                           DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG);
+    SDValue Ext = DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
+    SDNodeFlags Flags;
+    Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
+    Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
+    Ext->setFlags(Flags);
+    return Ext;
   }
 
   // Constant index will be matched by tablegen.
@@ -2242,9 +2253,10 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
   }
 
   SDLoc DL(Op);
-  return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
-                     DAG.getConstant(Selector, DL, MVT::i32),
-                     DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
+  SDValue PRMT =
+      getPRMT(DAG.getBitcast(MVT::i32, V1), DAG.getBitcast(MVT::i32, V2),
+              DAG.getConstant(Selector, DL, MVT::i32), DL, DAG);
+  return DAG.getBitcast(Op.getValueType(), PRMT);
 }
 /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
 /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
@@ -2729,10 +2741,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
                      {TryCancelResponse0, TryCancelResponse1});
 }
 
+static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
+  const unsigned Mode = [&]() {
+    switch (Op->getConstantOperandVal(0)) {
+    case Intrinsic::nvvm_prmt:
+      return NVPTX::PTXPrmtMode::NONE;
+    case Intrinsic::nvvm_prmt_b4e:
+      return NVPTX::PTXPrmtMode::B4E;
+    case Intrinsic::nvvm_prmt_ecl:
+      return NVPTX::PTXPrmtMode::ECL;
+    case Intrinsic::nvvm_prmt_ecr:
+      return NVPTX::PTXPrmtMode::ECR;
+    case Intrinsic::nvvm_prmt_f4e:
+      return NVPTX::PTXPrmtMode::F4E;
+    case Intrinsic::nvvm_prmt_rc16:
+      return NVPTX::PTXPrmtMode::RC16;
+    case Intrinsic::nvvm_prmt_rc8:
+      return NVPTX::PTXPrmtMode::RC8;
+    default:
+      llvm_unreachable("unsupported/unhandled intrinsic");
+    }
+  }();
+  SDLoc DL(Op);
+  SDValue A = Op->getOperand(1);
+  SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(2)
+                                       : DAG.getConstant(0, DL, MVT::i32);
+  SDValue Selector = (Op->op_end() - 1)->get();
+  return getPRMT(A, B, Selector, DL, DAG, Mode);
+}
 static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
   switch (Op->getConstantOperandVal(0)) {
   default:
     return Op;
+  case Intrinsic::nvvm_prmt:
+  case Intrinsic::nvvm_prmt_b4e:
+  case Intrinsic::nvvm_prmt_ecl:
+  case Intrinsic::nvvm_prmt_ecr:
+  case Intrinsic::nvvm_prmt_f4e:
+  case Intrinsic::nvvm_prmt_rc16:
+  case Intrinsic::nvvm_prmt_rc8:
+    return lowerPrmtIntrinsic(Op, DAG);
   case Intrinsic::nvvm_internal_addrspace_wrap:
     return Op.getOperand(1);
   case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
@@ -5775,11 +5823,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
   auto &DAG = DCI.DAG;
 
-  auto PRMT = DAG.getNode(
-      NVPTXISD::PRMT, DL, MVT::v4i8,
-      {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
-       DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
-  return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
+  auto PRMT = getPRMT(
+      DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),
+      DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32), DL, DAG);
+  return DAG.getBitcast(VT, PRMT);
 }
 
 static SDValue combineADDRSPACECAST(SDNode *N,
@@ -5797,6 +5844,72 @@ static SDValue combineADDRSPACECAST(SDNode *N,
   return SDValue();
 }
 
+static APInt getPRMTSelector(APInt Selector, unsigned Mode) {
+  if (Mode == NVPTX::PTXPrmtMode::NONE)
+    return Selector;
+
+  unsigned V = Selector.trunc(2).getZExtValue();
+
+  const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
+                              unsigned S3) {
+    return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
+  };
+
+  switch (Mode) {
+  case NVPTX::PTXPrmtMode::F4E:
+    return GetSelector(V, V + 1, V + 2, V + 3);
+  case NVPTX::PTXPrmtMode::B4E:
+    return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
+  case NVPTX::PTXPrmtMode::RC8:
+    return GetSelector(V, V, V, V);
+  case NVPTX::PTXPrmtMode::ECL:
+    return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U);
+  case NVPTX::PTXPrmtMode::ECR:
+    return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V);
+  case NVPTX::PTXPrmtMode::RC16: {
+    unsigned V1 = (V & 1) << 1;
+    return GetSelector(V1, V1 + 1, V1, V1 + 1);
+  }
+  default:
+    llvm_unreachable("Invalid PRMT mode");
+  }
+}
+
+static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
+  // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
+  APInt BitField = B.concat(A);
+  APInt SelectorVal = getPRMTSelector(Selector, Mode);
+  APInt Result(32, 0);
+  for (unsigned I : llvm::seq(4U)) {
+    APInt Sel = SelectorVal.extractBits(4, I * 4);
+    unsigned Idx = Sel.getLoBits(3).getZExtValue();
+    unsigned Sign = Sel.getHiBits(1).getZExtValue();
+    APInt Byte = BitField.extractBits(8, Idx * 8);
+    if (Sign)
+      Byte = Byte.ashr(8);
+    Result.insertBits(Byte, I * 8);
+  }
+  return Result;
+}
+
+static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                           CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  // Constant fold PRMT
+  if (isa<ConstantSDNode>(N->getOperand(0)) &&
+      isa<ConstantSDNode>(N->getOperand(1)) &&
+      isa<ConstantSDNode>(N->getOperand(2)))
+    return DCI.DAG.getConstant(computePRMT(N->getConstantOperandAPInt(0),
+                                           N->getConstantOperandAPInt(1),
+                                           N->getConstantOperandAPInt(2),
+                                           N->getConstantOperandVal(3)),
+                               SDLoc(N), N->getValueType(0));
+
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5838,6 +5951,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformBUILD_VECTORCombine(N, DCI);
     case ISD::ADDRSPACECAST:
       return combineADDRSPACECAST(N, DCI);
+    case NVPTXISD::PRMT:
+      return combinePRMT(N, DCI, OptLevel);
   }
   return SDValue();
 }
@@ -6385,7 +6500,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
   ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));
   unsigned Mode = Op.getConstantOperandVal(3);
 
-  if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
+  if (!Selector)
     return;
 
   KnownBits AKnown = DAG.computeKnownBits(A, Depth);
@@ -6394,7 +6509,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
   // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
   KnownBits BitField = BKnown.concat(AKnown);
 
-  APInt SelectorVal = Selector->getAPIntValue();
+  APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
   for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {
     APInt Sel = SelectorVal.extractBits(4, I * 4);
     unsigned Idx = Sel.getLoBits(3).getZExtValue();
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index ecae03e77aa83..6741ccbb43abc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1453,18 +1453,33 @@ let hasSideEffects = false in {
                 (ins PrmtMode:$mode),
                 "prmt.b32$mode",
                 [(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
+  def PRMT_B32rir
+  : BasicFlagsNVPTXInst<(outs B32:$d),
+              (ins B32:$a, i32imm:$b, B32:$c),
+              (ins PrmtMode:$mode),
+              "prmt.b32$mode",
+              [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
   def PRMT_B32rii
     : BasicFlagsNVPTXInst<(outs B32:$d),
                 (ins B32:$a, i32imm:$b, Hexu32imm:$c),
                 (ins PrmtMode:$mode),
                 "prmt.b32$mode",
                 [(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
-  def PRMT_B32rir
+  def PRMT_B32irr
     : BasicFlagsNVPTXInst<(outs B32:$d),
-                (ins B32:$a, i32imm:$b, B32:$c),
-                (ins PrmtMode:$mode),
+                (ins i32imm:$a, B32:$b, B32:$c), (ins PrmtMode:$mode),
+                "prmt.b32$mode",
+                [(set i32:$d, (prmt imm:$a, i32:$b, i32:$c, imm:$mode))]>;
+  def PRMT_B32iri
+    : BasicFlagsNVPTXInst<(outs B32:$d),
+                (ins i32imm:$a, B32:$b, Hexu32imm:$c), (ins PrmtMode:$mode),
+                "prmt.b32$mode",
+                [(set i32:$d, (prmt imm:$a, i32:$b, imm:$c, imm:$mode))]>;
+  def PRMT_B32iir
+    : BasicFlagsNVPTXInst<(outs B32:$d),
+                (ins i32imm:$a, i32imm:$b, B32:$c), (ins PrmtMode:$mode),
                 "prmt.b32$mode",
-                [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
+                [(set i32:$d, (prmt imm:$a, imm:$b, i32:$c, imm:$mode))]>;
 
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 93827be5c2811..bdddf3f56cb13 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1007,24 +1007,6 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
 // MISC
 //
 
-class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
-    : Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),
-          (PRMT_B32rrr $a, $b, $c, prmt_mode)>;
-
-class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
-    : Pat<(prmt_intrinsic i32:$a, i32:$c),
-          (PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;
-
-def : PRMT3Pat<int_nvvm_prmt,      PrmtNONE>;
-def : PRMT3Pat<int_nvvm_prmt_f4e,  PrmtF4E>;
-def : PRMT3Pat<int_nvvm_prmt_b4e,  PrmtB4E>;
-
-def : PRMT2Pat<int_nvvm_prmt_rc8,  PrmtRC8>;
-def : PRMT2Pat<int_nvvm_prmt_ecl,  PrmtECL>;
-def : PRMT2Pat<int_nvvm_prmt_ecr,  PrmtECR>;
-def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;
-
-
 def INT_NVVM_NANOSLEEP_I : BasicNVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32",
                              [(int_nvvm_nanosleep imm:$i)]>,
         Requires<[hasPTX<63>, hasSM<70>]>;
diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
index 410c0019c7222..cbc9f700b1f01 100644
--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
@@ -1,14 +1,19 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
 ; ## Support i16x2 instructions
-; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 \
-; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
-; RUN: | FileCheck -allow-deprecated-dag-overlap %s
-; RUN: %if ptxas %{                                                           \
-; RUN:   llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 \
-; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
-; RUN:   | %ptxas-verify -arch=sm_90                                          \
+; RUN: llc < %s -mcpu=sm_90 -mattr=+ptx80 -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
+; RUN: | FileCheck %s --check-prefixes=CHECK,O0
+; RUN: llc < %s -mcpu=sm_90 -mattr=+ptx80 -verify-machineinstrs \
+; RUN: | FileCheck %s --check-prefixes=CHECK,O3
+; RUN: %if ptxas %{                                                            \
+; RUN:   llc < %s -mcpu=sm_90 -mattr=+ptx80 -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
+; RUN:   | %ptxas-verify -arch=sm_90                                           \
+; RUN: %}
+; RUN: %if ptxas %{                                                            \
+; RUN:   llc < %s -mcpu=sm_90 -mattr=+ptx80 -verify-machineinstrs              \
+; RUN:   | %ptxas-verify -arch=sm_90                                           \
 ; RUN: %}
 
+target triple = "nvptx64-nvidia-cuda"
 target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
 
 define <4 x i8> @test_ret_const() #0 {
@@ -79,61 +84,111 @@ define i8 @test_extract_3(<4 x i8> %a) #0 {
 }
 
 define i8 @test_extract_i(<4 x i8> %a, i64 %idx) #0 {
-; CHECK-LABEL: test_extract_i(
-; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<5>;
-; CHECK-NEXT:    .reg .b64 %rd<2>;
-; CHECK-EMPTY:
-; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd1, [test_extract_i_param_1];
-; CHECK-NEXT:    ld.param.b32 %r1, [test_extract_i_param_0];
-; CHECK-NEXT:    cvt.u32.u64 %r2, %rd1;
-; CHECK-NEXT:    or.b32 %r3, %r2, 30576;
-; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, %r3;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
-; CHECK-NEXT:    ret;
+; O0-LABEL: test_extract_i(
+; O0:       {
+; O0-NEXT:    .reg .b32 %r<5>;
+; O0-NEXT:    .reg .b64 %rd<2>;
+; O0-EMPTY:
+; O0-NEXT:  // %bb.0:
+; O0-NEXT:    ld.param.b64 %rd1, [test_extract_i_param_1];
+; O0-NEXT:    ld.param.b32 %r1, [test_extract_i_param_0];
+; O0-NEXT:    cvt.u32.u64 %r2, %rd1;
+; O0-NEXT:    or.b32 %r3, %r2, 30576;
+; O0-NEXT:    prmt.b32 %r4, %r1, 0, %r3;
+; O0-NEXT:    st.param.b32 [func_retval0], %r4;
+; O0-NEXT:    ret;
+;
+; O3-LABEL: test_extract_i(
+; O3:       {
+; O3-NEXT:    .reg .b32 %r<5>;
+; O3-EMPTY:
+; O3-NEXT:  // %bb.0:
+; O3-NEXT:    ld.param.b32 %r1, [test_extract_i_param_0];
+; O3-NEXT:    ld.param.b32 %r2, [test_extract_i_param_1];
+; O3-NEXT:    or.b32 %r3, %r2, 30576;
+; O3-NEXT:    prmt.b32 %r4, %r1, 0, %r3;
+; O3-NEXT:    st.param.b32 [func_retval0], %r4;
+; O3-NEXT:    ret;
   %e = extractelement <4 x i8> %a, i64 %idx
   ret i8 %e
 }
 
 define <4 x i8> @test_add(<4 x i8> %a, <4 x i8> %b) #0 {
-; CHECK-LABEL: test_add(
-; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<13>;
-; CHECK-NEXT:    .reg .b32 %r<18>;
-; CHECK-EMPTY:
-; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b32 %r2, [test_add_param_1];
-; CHECK-NEXT:    ld.param.b32 %r1, [test_add_param_0];
-; CHECK-NEXT:    prmt.b32 %r3, %r2, 0, 0x7773U;
-; CHECK-NEXT:    cvt.u16.u32 %rs1, %r3;
-; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0x7773U;
-; CHECK-NEXT:    cvt.u16.u32 %rs2, %r4;
-; CHECK-NEXT:    add.s16 %rs3, %rs2, %rs1;
-; CHECK-NEXT:    cvt.u32.u16 %r5, %rs3;
-; CHECK-NEXT:    prmt.b32 %r6, %r2, 0, 0x7772U;
-; CHECK-NEXT:    cvt.u16.u32 %rs4, %r6;
-; CHECK-NEXT:    prmt.b32 %r7, %r1, 0, 0x7772U;
-; CHECK-NEXT:    cvt.u16.u32 %rs5, %r7;
-; CHECK-NEXT:    add.s16 %rs6, %rs5, %rs4;
-; CHECK-NEXT:    cvt.u32.u16 %r8, %rs6;
-; CHECK-NEXT:    prmt.b32 %r9, %r8, %r5, 0x3340U;
-; CHECK-NEXT:    prmt.b32 %r10, %r2, 0, 0x7771U;
-; CHECK-NEXT:    cvt.u16.u32 %rs7, %r10;
-; CHECK-NEXT:    prmt.b32 %r11, %r1, 0, 0x7771U;
-; CHECK-NEXT:    cvt.u16.u32 %rs8, %r11;
-; CHECK-NEXT:    add.s16 %rs9, %rs8, %rs7;
-; CHECK-NEXT:    cvt.u32.u16 %r12, %rs9;
-; CHECK-NEXT:    prmt.b32 %r13, %r2, 0, 0x7770U;
-; CHECK-NEXT:    cvt.u16.u32 %rs10, %r13;
-; CHECK-NEXT:    prmt.b32 %r14, %r1, 0, 0x7770U;
-; CHECK-NEXT:    cvt.u16.u32 %rs11, %r14;
-; CHECK-NEXT:    add.s16 %rs12, %rs11, %rs10;
-; CHECK-NEXT:    cvt.u32.u16 %r15, %rs12;
-; CHECK-NEXT:    prmt.b32 %r16, %r15, %r12, 0x3340U;
-; CHECK-NEXT:    prmt.b32 %r17, %r16, %r9, 0x5410U;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r17;
-; CHECK-NEXT:    ret;
+; O0-LABEL: test_add(
+; O0:       {
+; O0-NEXT:    .reg .b16 %rs<13>;
+; O0-NEXT:    .reg .b32 %r<18>;
+; O0-EMPTY:
+; O0-NEXT:  // %bb.0:
+; O0-NEXT:    ld.param.b32 %r2, [test_add_param_1];
+; O0-NEXT:    ld.param.b32 %r1, [test_add_param_0];
+; O0-NEXT:    prmt.b32 %r3, %r2, 0, 0x7773U;
+; O0-NEXT:    cvt.u16.u32 %rs1, %r3;
+; O0-NEXT:    prmt.b32 %r4, %r1, 0, 0x7773U;
+; O0-NEXT:    cvt.u16.u32 %rs2, %r4;
+; O0-NEXT:    add.s16 %rs3, %rs2, %rs1;
+; O0-NEXT:    cvt.u32.u16 %r5, %rs3;
+; O0-NEXT:    prmt.b32 %r6, %r2, 0, 0x7772U;
+; O0-NEXT:    cvt.u16.u32 %rs4, %r6;
+; O0-NEXT:    prmt.b32 %r7, %r1, 0, 0x7772U;
+; O0-NEXT:    cvt.u16.u32 %rs5, %r7;
+; O0-NEXT:    add.s16 %rs6, %rs5, %rs4;
+; O0-NEXT:    cvt.u32.u16 %r8, %rs6;
+; O0-NEXT:    prmt.b32 %r9, %r8, %r5, 0x3340U;
+; O0-NEXT:    prmt.b32 %r10, %r2, 0, 0x7771U;
+; O0-NEXT:    cvt.u16.u32 %rs7, %r10;
+; O0-NEXT:    prmt.b32 %r11, %r1, 0, 0x7771U;
+; O0-NEXT:    cvt.u16.u32 %rs8, %r11;
+; O0-NEXT:    add.s16 %rs9, %rs8, %rs7;
+; O0-NEXT:    cvt.u32.u16 %r12, %rs9;
+; O0-NEXT:    prmt.b32 %r13, %r2, 0, 0x7770U;
+; O0-NEXT:    cvt.u16.u32 %rs10, %r13;
+; O0-NEXT:    prmt.b32 %r14, %r1, 0, 0x7770U;
+; O0-NEXT:    cvt.u16.u32 %rs11, %r14;
+; O0-NEXT:    add.s16 %rs12, %rs11, %rs10;
+; O0-NEXT:    cvt.u32.u16 %r15, %rs12;
+; O0-NEXT:    prmt.b32 %r16, %r15, %r12, 0x3340U;
+; O0-NEXT:    prmt.b32 %r17, %r16, %r9, 0x5410U;
+; O0-NEXT:    st.param.b32 [func_retval0], %r17;
+; O0-NEXT:    ret;
+;
+; O3-LABEL: test_add(
+; O3:       {
+; O3-NEXT:    .reg .b16 %rs<13>;
+; O3-NEXT:    .reg .b32 %r<18>;
+; O3-EMPTY:
+; O3-NEXT:  // %bb.0:
+; O3-NEXT:    ld.param.b32 %r1, [test_add_param_0];
+; O3-NEXT:    ld.param.b32 %r2, [test_add_param_1];
+; O3-NEXT:    prmt.b32 %r3, %r2, 0, 0x7773U;
+; O3...
[truncated]

Copy link

github-actions bot commented Jul 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/prmt-2 branch from 931aecd to 43ba2f9 Compare July 15, 2025 17:44
@AlexMaclean AlexMaclean requested a review from kalxr July 16, 2025 16:10
Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a few nits.

Comment on lines +2066 to +2068
static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
SelectionDAG &DAG,
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add another overload with Selector provided as an integer. That seems to be a common pattern that forces us to sprinkle DAG.getConstant(X, DL, MVT::i32) in numerous places.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

@@ -5797,47 +5844,116 @@ static SDValue combineADDRSPACECAST(SDNode *N,
return SDValue();
}

static APInt getPRMTSelector(APInt Selector, unsigned Mode) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: a pointer to the selector encoding docs would be useful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Comment on lines 5955 to 5956
case NVPTXISD::PRMT:
return combinePRMT(N, DCI, OptLevel);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Perhaps it's a good opportunity to sort the case values.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/prmt-2 branch from 43ba2f9 to 6ccb74c Compare July 17, 2025 16:50
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/prmt-2 branch from 6ccb74c to 81a2e32 Compare July 17, 2025 16:58
@AlexMaclean AlexMaclean merged commit f480e1b into llvm:main Jul 17, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants