Skip to content

Commit 86203b6

Browse files
authored
[NVPTX] Use PRMT more widely, and improve folding around this instruction (#148261)
Replace uses of BFE with PRMT when lowering v4i8 vectors. This will generally lead to equivalent or better SASS and reduces the number of target specific operations we need to represent. (https://cuda.godbolt.org/z/M75W6f8xd) Also implement KnownBits tracking for PRMT allowing elimination of redundant AND instructions when lowering various i8 operations.
1 parent c384ec4 commit 86203b6

File tree

10 files changed

+653
-655
lines changed

10 files changed

+653
-655
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
#include "llvm/Support/CodeGen.h"
5858
#include "llvm/Support/CommandLine.h"
5959
#include "llvm/Support/ErrorHandling.h"
60+
#include "llvm/Support/KnownBits.h"
6061
#include "llvm/Support/NVPTXAddrSpace.h"
6162
#include "llvm/Support/raw_ostream.h"
6263
#include "llvm/Target/TargetMachine.h"
@@ -1087,7 +1088,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10871088
MAKE_CASE(NVPTXISD::StoreV8)
10881089
MAKE_CASE(NVPTXISD::FSHL_CLAMP)
10891090
MAKE_CASE(NVPTXISD::FSHR_CLAMP)
1090-
MAKE_CASE(NVPTXISD::BFE)
10911091
MAKE_CASE(NVPTXISD::BFI)
10921092
MAKE_CASE(NVPTXISD::PRMT)
10931093
MAKE_CASE(NVPTXISD::FCOPYSIGN)
@@ -2173,14 +2173,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
21732173
EVT VectorVT = Vector.getValueType();
21742174

21752175
if (VectorVT == MVT::v4i8) {
2176-
SDValue BFE =
2177-
DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
2178-
{Vector,
2179-
DAG.getNode(ISD::MUL, DL, MVT::i32,
2180-
DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2181-
DAG.getConstant(8, DL, MVT::i32)),
2182-
DAG.getConstant(8, DL, MVT::i32)});
2183-
return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
2176+
SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,
2177+
DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2178+
DAG.getConstant(0x7770, DL, MVT::i32));
2179+
SDValue PRMT = DAG.getNode(
2180+
NVPTXISD::PRMT, DL, MVT::i32,
2181+
{DAG.getBitcast(MVT::i32, Vector), DAG.getConstant(0, DL, MVT::i32),
2182+
Selector, DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2183+
return DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
21842184
}
21852185

21862186
// Constant index will be matched by tablegen.
@@ -5271,31 +5271,6 @@ static SDValue PerformANDCombine(SDNode *N,
52715271

52725272
SDValue AExt;
52735273

5274-
// Convert BFE-> truncate i16 -> and 255
5275-
// To just BFE-> truncate i16, as the value already has all the bits in the
5276-
// right places.
5277-
if (Val.getOpcode() == ISD::TRUNCATE) {
5278-
SDValue BFE = Val.getOperand(0);
5279-
if (BFE.getOpcode() != NVPTXISD::BFE)
5280-
return SDValue();
5281-
5282-
ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0));
5283-
if (!BFEBits)
5284-
return SDValue();
5285-
uint64_t BFEBitsVal = BFEBits->getZExtValue();
5286-
5287-
ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
5288-
if (!MaskCnst) {
5289-
// Not an AND with a constant
5290-
return SDValue();
5291-
}
5292-
uint64_t MaskVal = MaskCnst->getZExtValue();
5293-
5294-
if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
5295-
return SDValue();
5296-
// If we get here, the AND is unnecessary. Just replace it with the trunc
5297-
DCI.CombineTo(N, Val, false);
5298-
}
52995274
// Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
53005275
if (Val.getOpcode() == ISD::ANY_EXTEND) {
53015276
AExt = Val;
@@ -6402,3 +6377,45 @@ MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
64026377
const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
64036378
return getDataSection();
64046379
}
6380+
6381+
static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
6382+
const SelectionDAG &DAG, unsigned Depth) {
6383+
SDValue A = Op.getOperand(0);
6384+
SDValue B = Op.getOperand(1);
6385+
ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));
6386+
unsigned Mode = Op.getConstantOperandVal(3);
6387+
6388+
if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
6389+
return;
6390+
6391+
KnownBits AKnown = DAG.computeKnownBits(A, Depth);
6392+
KnownBits BKnown = DAG.computeKnownBits(B, Depth);
6393+
6394+
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6395+
KnownBits BitField = BKnown.concat(AKnown);
6396+
6397+
APInt SelectorVal = Selector->getAPIntValue();
6398+
for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {
6399+
APInt Sel = SelectorVal.extractBits(4, I * 4);
6400+
unsigned Idx = Sel.getLoBits(3).getZExtValue();
6401+
unsigned Sign = Sel.getHiBits(1).getZExtValue();
6402+
KnownBits Byte = BitField.extractBits(8, Idx * 8);
6403+
if (Sign)
6404+
Byte = KnownBits::ashr(Byte, 8);
6405+
Known.insertBits(Byte, I * 8);
6406+
}
6407+
}
6408+
6409+
void NVPTXTargetLowering::computeKnownBitsForTargetNode(
6410+
const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
6411+
const SelectionDAG &DAG, unsigned Depth) const {
6412+
Known.resetAll();
6413+
6414+
switch (Op.getOpcode()) {
6415+
case NVPTXISD::PRMT:
6416+
computeKnownBitsForPRMT(Op, Known, DAG, Depth);
6417+
break;
6418+
default:
6419+
break;
6420+
}
6421+
}

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ enum NodeType : unsigned {
5050
MUL_WIDE_UNSIGNED,
5151
SETP_F16X2,
5252
SETP_BF16X2,
53-
BFE,
5453
BFI,
5554
PRMT,
5655

@@ -272,6 +271,11 @@ class NVPTXTargetLowering : public TargetLowering {
272271
unsigned getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
273272
EVT ToVT) const override;
274273

274+
void computeKnownBitsForTargetNode(const SDValue Op, KnownBits &Known,
275+
const APInt &DemandedElts,
276+
const SelectionDAG &DAG,
277+
unsigned Depth = 0) const override;
278+
275279
private:
276280
const NVPTXSubtarget &STI; // cache the subtarget here
277281
mutable unsigned GlobalUniqueCallSite;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,11 +1372,6 @@ def BREV64 :
13721372
// restriction in PTX?
13731373
//
13741374
// dest and src may be int32 or int64, but start and end are always int32.
1375-
def SDTBFE :
1376-
SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>,
1377-
SDTCisVT<2, i32>, SDTCisVT<3, i32>]>;
1378-
def bfe : SDNode<"NVPTXISD::BFE", SDTBFE>;
1379-
13801375
def SDTBFI :
13811376
SDTypeProfile<1, 4, [SDTCisInt<0>, SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>,
13821377
SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
@@ -1387,22 +1382,13 @@ def SDTPRMT :
13871382
SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
13881383
def prmt : SDNode<"NVPTXISD::PRMT", SDTPRMT>;
13891384

1390-
multiclass BFE<string Instr, ValueType T, RegisterClass RC> {
1385+
multiclass BFE<string Instr, RegisterClass RC> {
13911386
def rrr
1392-
: BasicNVPTXInst<(outs RC:$d),
1393-
(ins RC:$a, B32:$b, B32:$c),
1394-
Instr,
1395-
[(set T:$d, (bfe T:$a, i32:$b, i32:$c))]>;
1387+
: BasicNVPTXInst<(outs RC:$d), (ins RC:$a, B32:$b, B32:$c), Instr>;
13961388
def rri
1397-
: BasicNVPTXInst<(outs RC:$d),
1398-
(ins RC:$a, B32:$b, i32imm:$c),
1399-
Instr,
1400-
[(set T:$d, (bfe T:$a, i32:$b, imm:$c))]>;
1389+
: BasicNVPTXInst<(outs RC:$d), (ins RC:$a, B32:$b, i32imm:$c), Instr>;
14011390
def rii
1402-
: BasicNVPTXInst<(outs RC:$d),
1403-
(ins RC:$a, i32imm:$b, i32imm:$c),
1404-
Instr,
1405-
[(set T:$d, (bfe T:$a, imm:$b, imm:$c))]>;
1391+
: BasicNVPTXInst<(outs RC:$d), (ins RC:$a, i32imm:$b, i32imm:$c), Instr>;
14061392
}
14071393

14081394
multiclass BFI<string Instr, ValueType T, RegisterClass RC, Operand ImmCls> {
@@ -1447,10 +1433,10 @@ let hasSideEffects = false in {
14471433
// the same patterns, so the first one wins. Having unsigned byte extraction
14481434
// has the benefit of always having zero in unused bits, which makes some
14491435
// optimizations easier (e.g. no need to mask them).
1450-
defm BFE_U32 : BFE<"bfe.u32", i32, B32>;
1451-
defm BFE_S32 : BFE<"bfe.s32", i32, B32>;
1452-
defm BFE_U64 : BFE<"bfe.u64", i64, B64>;
1453-
defm BFE_S64 : BFE<"bfe.s64", i64, B64>;
1436+
defm BFE_U32 : BFE<"bfe.u32", B32>;
1437+
defm BFE_S32 : BFE<"bfe.s32", B32>;
1438+
defm BFE_U64 : BFE<"bfe.u64", B64>;
1439+
defm BFE_S64 : BFE<"bfe.s64", B64>;
14541440

14551441
defm BFI_B32 : BFI<"bfi.b32", i32, B32, i32imm>;
14561442
defm BFI_B64 : BFI<"bfi.b64", i64, B64, i64imm>;
@@ -1487,19 +1473,26 @@ def : Pat<(fshr i32:$hi, i32:$lo, (shl i32:$amt, (i32 3))),
14871473
(PRMT_B32rrr $lo, $hi, $amt, PrmtF4E)>;
14881474

14891475

1476+
def byte_extract_prmt : ImmLeaf<i32, [{
1477+
return (Imm == 0x7770) || (Imm == 0x7771) || (Imm == 0x7772) || (Imm == 0x7773);
1478+
}]>;
1479+
1480+
def to_sign_extend_selector : SDNodeXForm<imm, [{
1481+
const APInt &V = N->getAPIntValue();
1482+
const APInt B = V.trunc(4);
1483+
const APInt BSext = B | 8;
1484+
const APInt R = BSext.concat(BSext).concat(BSext).concat(B).zext(32);
1485+
return CurDAG->getTargetConstant(R, SDLoc(N), MVT::i32);
1486+
}]>;
1487+
1488+
14901489
// byte extraction + signed/unsigned extension to i32.
1491-
def : Pat<(i32 (sext_inreg (bfe i32:$s, i32:$o, 8), i8)),
1492-
(BFE_S32rri $s, $o, 8)>;
1493-
def : Pat<(i32 (sext_inreg (bfe i32:$s, imm:$o, 8), i8)),
1494-
(BFE_S32rii $s, imm:$o, 8)>;
1495-
def : Pat<(i32 (and (bfe i32:$s, i32:$o, 8), 255)),
1496-
(BFE_U32rri $s, $o, 8)>;
1497-
def : Pat<(i32 (and (bfe i32:$s, imm:$o, 8), 255)),
1498-
(BFE_U32rii $s, imm:$o, 8)>;
1490+
def : Pat<(i32 (sext_inreg (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtNONE), i8)),
1491+
(PRMT_B32rii $s, 0, (to_sign_extend_selector $sel), PrmtNONE)>;
14991492

15001493
// byte extraction + signed extension to i16
1501-
def : Pat<(i16 (sext_inreg (trunc (bfe i32:$s, imm:$o, 8)), i8)),
1502-
(CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>;
1494+
def : Pat<(i16 (sext_inreg (trunc (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtNONE)), i8)),
1495+
(CVT_u16_u32 (PRMT_B32rii $s, 0, (to_sign_extend_selector $sel), PrmtNONE), CvtNONE)>;
15031496

15041497

15051498
// Byte extraction via shift/trunc/sext
@@ -1709,28 +1702,36 @@ def cond_not_signed : PatLeaf<(cond), [{
17091702
return !isSignedIntSetCC(N->get());
17101703
}]>;
17111704

1712-
// comparisons of i8 extracted with BFE as i32
1713-
// It's faster to do comparison directly on i32 extracted by BFE,
1705+
// comparisons of i8 extracted with PRMT as i32
1706+
// It's faster to do comparison directly on i32 extracted by PRMT,
17141707
// instead of the long conversion and sign extending.
1715-
def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (bfe B32:$a, B32:$oa, 8))), i8)),
1716-
(i16 (sext_inreg (i16 (trunc (bfe B32:$b, B32:$ob, 8))), i8)),
1708+
def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), i8)),
1709+
(i16 (sext_inreg (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), i8)),
17171710
cond_signed:$cc),
1718-
(SETP_i32rr (BFE_S32rri $a, $oa, 8), (BFE_S32rri $b, $ob, 8), (cond2cc $cc))>;
1711+
(SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
1712+
(PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
1713+
(cond2cc $cc))>;
17191714

1720-
def: Pat<(setcc (i16 (sext_inreg (trunc (bfe B32:$a, imm:$oa, 8)), i8)),
1721-
(i16 (sext_inreg (trunc (bfe B32:$b, imm:$ob, 8)), i8)),
1715+
def: Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)),
1716+
(i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)),
17221717
cond_signed:$cc),
1723-
(SETP_i32rr (BFE_S32rii $a, imm:$oa, 8), (BFE_S32rii $b, imm:$ob, 8), (cond2cc $cc))>;
1718+
(SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
1719+
(PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
1720+
(cond2cc $cc))>;
17241721

1725-
def: Pat<(setcc (i16 (and (trunc (bfe B32:$a, B32:$oa, 8)), 255)),
1726-
(i16 (and (trunc (bfe B32:$b, B32:$ob, 8)), 255)),
1722+
def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
1723+
(i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
17271724
cond_signed:$cc),
1728-
(SETP_i32rr (BFE_U32rri $a, $oa, 8), (BFE_U32rri $b, $ob, 8), (cond2cc $cc))>;
1725+
(SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
1726+
(PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
1727+
(cond2cc $cc))>;
17291728

1730-
def: Pat<(setcc (i16 (and (trunc (bfe B32:$a, imm:$oa, 8)), 255)),
1731-
(i16 (and (trunc (bfe B32:$b, imm:$ob, 8)), 255)),
1729+
def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
1730+
(i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
17321731
cond_not_signed:$cc),
1733-
(SETP_i32rr (BFE_U32rii $a, imm:$oa, 8), (BFE_U32rii $b, imm:$ob, 8), (cond2cc $cc))>;
1732+
(SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
1733+
(PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
1734+
(cond2cc $cc))>;
17341735

17351736
def SDTDeclareArrayParam :
17361737
SDTypeProfile<0, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>]>;

0 commit comments

Comments
 (0)