LLVM 23.0.0git
SPIRVCombinerHelper.cpp
Go to the documentation of this file.
1//===-- SPIRVCombinerHelper.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "SPIRVGlobalRegistry.h"
11#include "SPIRVUtils.h"
15#include "llvm/IR/IntrinsicsSPIRV.h"
16#include "llvm/IR/LLVMContext.h" // Explicitly include for LLVMContext
18
19using namespace llvm;
20using namespace MIPatternMatch;
21
27
28/// This match is part of a combine that
29/// rewrites length(X - Y) to distance(X, Y)
30/// (f32 (g_intrinsic length
31/// (g_fsub (vXf32 X) (vXf32 Y))))
32/// ->
33/// (f32 (g_intrinsic distance
34/// (vXf32 X) (vXf32 Y)))
35///
37 if (MI.getOpcode() != TargetOpcode::G_INTRINSIC ||
38 cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_length)
39 return false;
40
41 // Operand 0 is the result and operand 1 the intrinsic ID; arguments start at 2.
42 Register SubReg = MI.getOperand(2).getReg();
43 MachineInstr *SubInstr = MRI.getVRegDef(SubReg);
44 if (SubInstr->getOpcode() != TargetOpcode::G_FSUB)
45 return false;
46
47 return true;
48}
49
51 // Extract the operands for X and Y from the match criteria.
52 Register SubDestReg = MI.getOperand(2).getReg();
53 MachineInstr *SubInstr = MRI.getVRegDef(SubDestReg);
54 Register SubOperand1 = SubInstr->getOperand(1).getReg();
55 Register SubOperand2 = SubInstr->getOperand(2).getReg();
56 Register ResultReg = MI.getOperand(0).getReg();
57
58 Builder.setInstrAndDebugLoc(MI);
59 Builder.buildIntrinsic(Intrinsic::spv_distance, ResultReg)
60 .addUse(SubOperand1)
61 .addUse(SubOperand2);
62
63 MI.eraseFromParent();
64}
65
66/// This match is part of a combine that
67/// rewrites select(fcmp(dot(I, Ng), 0), N, -N) to faceforward(N, I, Ng)
68/// (vXf32 (g_select
69/// (g_fcmp
70/// (g_intrinsic dot(vXf32 I) (vXf32 Ng)
71/// 0)
72/// (vXf32 N)
73/// (vXf32 g_fneg (vXf32 N))))
74/// ->
75/// (vXf32 (g_intrinsic faceforward
76/// (vXf32 N) (vXf32 I) (vXf32 Ng)))
77///
78/// This only works for Vulkan shader targets.
79///
81 if (!STI.isShader())
82 return false;
83
84 // Match overall select pattern.
85 Register CondReg, TrueReg, FalseReg;
86 if (!mi_match(MI.getOperand(0).getReg(), MRI,
87 m_GISelect(m_Reg(CondReg), m_Reg(TrueReg), m_Reg(FalseReg))))
88 return false;
89
90 // Match the FCMP condition.
91 Register DotReg, CondZeroReg;
93 if (!mi_match(CondReg, MRI,
94 m_GFCmp(m_Pred(Pred), m_Reg(DotReg), m_Reg(CondZeroReg))))
95 return false;
96 if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
97 std::swap(DotReg, CondZeroReg);
98 else if (!(Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_ULT))
99 return false;
100
101 // Check if FCMP is a comparison between a dot product and 0.
102 MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
103 if (DotInstr->getOpcode() != TargetOpcode::G_INTRINSIC ||
104 cast<GIntrinsic>(DotInstr)->getIntrinsicID() != Intrinsic::spv_fdot) {
105 Register DotOperand1, DotOperand2;
106 // Check for scalar dot product.
107 if (!mi_match(DotReg, MRI,
108 m_GFMul(m_Reg(DotOperand1), m_Reg(DotOperand2))) ||
109 !MRI.getType(DotOperand1).isScalar() ||
110 !MRI.getType(DotOperand2).isScalar())
111 return false;
112 }
113
114 const ConstantFP *ZeroVal;
115 if (!mi_match(CondZeroReg, MRI, m_GFCst(ZeroVal)) || !ZeroVal->isZero())
116 return false;
117
118 // Check if select's false operand is the negation of the true operand.
119 auto AreNegatedConstantsOrSplats = [&](Register TrueReg, Register FalseReg) {
120 std::optional<FPValueAndVReg> TrueVal, FalseVal;
121 if (!mi_match(TrueReg, MRI, m_GFCstOrSplat(TrueVal)) ||
122 !mi_match(FalseReg, MRI, m_GFCstOrSplat(FalseVal)))
123 return false;
124 APFloat TrueValNegated = TrueVal->Value;
125 TrueValNegated.changeSign();
126 return FalseVal->Value.compare(TrueValNegated) == APFloat::cmpEqual;
127 };
128
129 if (!mi_match(TrueReg, MRI, m_GFNeg(m_SpecificReg(FalseReg))) &&
130 !mi_match(FalseReg, MRI, m_GFNeg(m_SpecificReg(TrueReg)))) {
131 std::optional<FPValueAndVReg> MulConstant;
132 MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
133 MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);
134 if (TrueInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
135 FalseInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
136 TrueInstr->getNumOperands() == FalseInstr->getNumOperands()) {
137 for (unsigned I = 1; I < TrueInstr->getNumOperands(); ++I)
138 if (!AreNegatedConstantsOrSplats(TrueInstr->getOperand(I).getReg(),
139 FalseInstr->getOperand(I).getReg()))
140 return false;
141 } else if (mi_match(TrueReg, MRI,
142 m_GFMul(m_SpecificReg(FalseReg),
143 m_GFCstOrSplat(MulConstant))) ||
144 mi_match(FalseReg, MRI,
145 m_GFMul(m_SpecificReg(TrueReg),
146 m_GFCstOrSplat(MulConstant))) ||
147 mi_match(TrueReg, MRI,
148 m_GFMul(m_GFCstOrSplat(MulConstant),
149 m_SpecificReg(FalseReg))) ||
150 mi_match(FalseReg, MRI,
151 m_GFMul(m_GFCstOrSplat(MulConstant),
152 m_SpecificReg(TrueReg)))) {
153 if (!MulConstant || !MulConstant->Value.isExactlyValue(-1.0))
154 return false;
155 } else if (!AreNegatedConstantsOrSplats(TrueReg, FalseReg))
156 return false;
157 }
158
159 return true;
160}
161
163 // Extract the operands for N, I, and Ng from the match criteria.
164 Register CondReg = MI.getOperand(1).getReg();
165 MachineInstr *CondInstr = MRI.getVRegDef(CondReg);
166 Register DotReg = CondInstr->getOperand(2).getReg();
167 CmpInst::Predicate Pred = cast<GFCmp>(CondInstr)->getCond();
168 if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
169 DotReg = CondInstr->getOperand(3).getReg();
170 MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
171 Register DotOperand1, DotOperand2;
172 if (DotInstr->getOpcode() == TargetOpcode::G_FMUL) {
173 DotOperand1 = DotInstr->getOperand(1).getReg();
174 DotOperand2 = DotInstr->getOperand(2).getReg();
175 } else {
176 DotOperand1 = DotInstr->getOperand(2).getReg();
177 DotOperand2 = DotInstr->getOperand(3).getReg();
178 }
179 Register TrueReg = MI.getOperand(2).getReg();
180 Register FalseReg = MI.getOperand(3).getReg();
181 MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
182 if (TrueInstr->getOpcode() == TargetOpcode::G_FNEG ||
183 TrueInstr->getOpcode() == TargetOpcode::G_FMUL)
184 std::swap(TrueReg, FalseReg);
185 MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);
186
187 Register ResultReg = MI.getOperand(0).getReg();
188 Builder.setInstrAndDebugLoc(MI);
189 Builder.buildIntrinsic(Intrinsic::spv_faceforward, ResultReg)
190 .addUse(TrueReg) // N
191 .addUse(DotOperand1) // I
192 .addUse(DotOperand2); // Ng
193
195 MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
196 auto RemoveAllUses = [&](Register Reg) {
198 for (auto &UseMI : MRI.use_instructions(Reg))
199 UsesToErase.push_back(&UseMI);
200
201 // Calling eraseFromParent inside the use loop would invalidate the iterator.
202 for (auto *MIToErase : UsesToErase)
203 MIToErase->eraseFromParent();
204 };
205
206 RemoveAllUses(CondReg); // remove all uses of FCMP Result
207 GR->invalidateMachineInstr(CondInstr);
208 CondInstr->eraseFromParent(); // remove FCMP instruction
209 RemoveAllUses(DotReg); // remove all uses of spv_fdot/G_FMUL Result
210 GR->invalidateMachineInstr(DotInstr);
211 DotInstr->eraseFromParent(); // remove spv_fdot/G_FMUL instruction
212 RemoveAllUses(FalseReg);
213 GR->invalidateMachineInstr(FalseInstr);
214 FalseInstr->eraseFromParent();
215}
216
218 return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
219 cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_transpose;
220}
221
223 Register ResReg = MI.getOperand(0).getReg();
224 Register InReg = MI.getOperand(2).getReg();
225 uint32_t Rows = MI.getOperand(3).getImm();
226 uint32_t Cols = MI.getOperand(4).getImm();
227
228 Builder.setInstrAndDebugLoc(MI);
229
230 // A 1xN or Nx1 transpose is a pure reshape.
231 if (Rows == 1 || Cols == 1) {
232 Builder.buildCopy(ResReg, InReg);
233 MI.eraseFromParent();
234 return;
235 }
236
238 for (uint32_t K = 0; K < Rows * Cols; ++K) {
239 uint32_t R = K / Cols;
240 uint32_t C = K % Cols;
241 Mask.push_back(C * Rows + R);
242 }
243
244 Builder.buildShuffleVector(ResReg, InReg, InReg, Mask);
245 MI.eraseFromParent();
246}
247
249 return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
250 cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_multiply;
251}
252
254SPIRVCombinerHelper::extractColumns(Register MatrixReg, uint32_t NumberOfCols,
255 SPIRVTypeInst SpvColType,
256 SPIRVGlobalRegistry *GR) const {
257 // If the matrix is a single column, return that single column.
258 if (NumberOfCols == 1)
259 return {MatrixReg};
260
262 LLT ColTy = GR->getRegType(SpvColType);
263 for (uint32_t J = 0; J < NumberOfCols; ++J)
265 Builder.buildUnmerge(Cols, MatrixReg);
266 for (Register R : Cols) {
267 setRegClassType(R, SpvColType, GR, &MRI, Builder.getMF());
268 }
269 return Cols;
270}
271
273SPIRVCombinerHelper::extractRows(Register MatrixReg, uint32_t NumRows,
274 uint32_t NumCols, SPIRVTypeInst SpvRowType,
275 SPIRVGlobalRegistry *GR) const {
277 LLT VecTy = GR->getRegType(SpvRowType);
278
279 // If there is only one column, then each row is a scalar that needs
280 // to be extracted.
281 if (NumCols == 1) {
282 assert(SpvRowType->getOpcode() != SPIRV::OpTypeVector);
283 for (uint32_t I = 0; I < NumRows; ++I)
284 Rows.push_back(MRI.createGenericVirtualRegister(VecTy));
285 Builder.buildUnmerge(Rows, MatrixReg);
286 for (Register R : Rows) {
287 setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
288 }
289 return Rows;
290 }
291
292 // If the matrix is a single row return that row.
293 if (NumRows == 1) {
294 return {MatrixReg};
295 }
296
297 for (uint32_t I = 0; I < NumRows; ++I) {
298 SmallVector<int, 4> Mask;
299 for (uint32_t k = 0; k < NumCols; ++k)
300 Mask.push_back(k * NumRows + I);
301 Rows.push_back(Builder.buildShuffleVector(VecTy, MatrixReg, MatrixReg, Mask)
302 .getReg(0));
303 }
304 for (Register R : Rows) {
305 setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
306 }
307 return Rows;
308}
309
310Register SPIRVCombinerHelper::computeDotProduct(Register RowA, Register ColB,
311 SPIRVTypeInst SpvVecType,
312 SPIRVGlobalRegistry *GR) const {
313 bool IsVectorOp = SpvVecType->getOpcode() == SPIRV::OpTypeVector;
314 SPIRVTypeInst SpvScalarType = GR->getScalarOrVectorComponentType(SpvVecType);
315 bool IsFloatOp = SpvScalarType->getOpcode() == SPIRV::OpTypeFloat;
316 LLT VecTy = GR->getRegType(SpvVecType);
317
318 Register DotRes;
319 if (IsVectorOp) {
320 LLT ScalarTy = VecTy.getElementType();
321 Intrinsic::SPVIntrinsics DotIntrinsic =
322 (IsFloatOp ? Intrinsic::spv_fdot : Intrinsic::spv_udot);
323 DotRes = Builder.buildIntrinsic(DotIntrinsic, {ScalarTy})
324 .addUse(RowA)
325 .addUse(ColB)
326 .getReg(0);
327 } else {
328 if (IsFloatOp)
329 DotRes = Builder.buildFMul(VecTy, RowA, ColB).getReg(0);
330 else
331 DotRes = Builder.buildMul(VecTy, RowA, ColB).getReg(0);
332 }
333 setRegClassType(DotRes, SpvScalarType, GR, &MRI, Builder.getMF());
334 return DotRes;
335}
336
338SPIRVCombinerHelper::computeDotProducts(const SmallVector<Register, 4> &RowsA,
339 const SmallVector<Register, 4> &ColsB,
340 SPIRVTypeInst SpvVecType,
341 SPIRVGlobalRegistry *GR) const {
342 SmallVector<Register, 16> ResultScalars;
343 for (uint32_t J = 0; J < ColsB.size(); ++J) {
344 for (uint32_t I = 0; I < RowsA.size(); ++I) {
345 ResultScalars.push_back(
346 computeDotProduct(RowsA[I], ColsB[J], SpvVecType, GR));
347 }
348 }
349 return ResultScalars;
350}
351
353SPIRVCombinerHelper::getDotProductVectorType(Register ResReg, uint32_t K,
354 SPIRVGlobalRegistry *GR) const {
355 // Loop over all non debug uses of ResReg
356 Type *ScalarResType = nullptr;
357 for (auto &UseMI : MRI.use_instructions(ResReg)) {
358 if (UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
359 continue;
360
361 if (!isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type))
362 continue;
363
364 Type *Ty = getMDOperandAsType(UseMI.getOperand(2).getMetadata(), 0);
365 if (Ty->isVectorTy())
366 ScalarResType = cast<VectorType>(Ty)->getElementType();
367 else
368 ScalarResType = Ty;
369 assert(ScalarResType->isIntegerTy() || ScalarResType->isFloatingPointTy());
370 break;
371 }
372 if (!ScalarResType)
373 llvm_unreachable("Could not determine scalar result type");
374 Type *VecType =
375 (K > 1 ? FixedVectorType::get(ScalarResType, K) : ScalarResType);
376 return GR->getOrCreateSPIRVType(VecType, Builder,
377 SPIRV::AccessQualifier::None, false);
378}
379
381 Register ResReg = MI.getOperand(0).getReg();
382 Register AReg = MI.getOperand(2).getReg();
383 Register BReg = MI.getOperand(3).getReg();
384 uint32_t NumRowsA = MI.getOperand(4).getImm();
385 uint32_t NumColsA = MI.getOperand(5).getImm();
386 uint32_t NumColsB = MI.getOperand(6).getImm();
387
388 Builder.setInstrAndDebugLoc(MI);
389
391 MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
392
393 SPIRVTypeInst SpvVecType = getDotProductVectorType(ResReg, NumColsA, GR);
395 extractColumns(BReg, NumColsB, SpvVecType, GR);
397 extractRows(AReg, NumRowsA, NumColsA, SpvVecType, GR);
398 SmallVector<Register, 16> ResultScalars =
399 computeDotProducts(RowsA, ColsB, SpvVecType, GR);
400
401 if (ResultScalars.size() == 1)
402 Builder.buildCopy(ResReg, ResultScalars[0]);
403 else
404 Builder.buildBuildVector(ResReg, ResultScalars);
405 MI.eraseFromParent();
406}
MachineInstrBuilder & UseMI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Declares convenience wrapper classes for interpreting MachineInstr instances as specific generic oper...
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition MD5.cpp:57
Contains matchers for matching SSA Machine Instructions.
Promote Memory to Register
Definition Mem2Reg.cpp:110
void changeSign()
Definition APFloat.h:1356
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:740
@ FCMP_OLT
0 1 0 0 True if ordered and less than
Definition InstrTypes.h:746
@ FCMP_OGT
0 0 1 0 True if ordered and greater than
Definition InstrTypes.h:744
@ FCMP_ULT
1 1 0 0 True if unordered or less than
Definition InstrTypes.h:754
@ FCMP_UGT
1 0 1 0 True if unordered or greater than
Definition InstrTypes.h:752
MachineRegisterInfo & MRI
const LegalizerInfo * LI
MachineDominatorTree * MDT
GISelValueTracking * VT
GISelChangeObserver & Observer
MachineIRBuilder & Builder
ConstantFP - Floating Point Values [float, double].
Definition Constants.h:420
bool isZero() const
Return true if the value is positive or negative zero.
Definition Constants.h:467
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition Type.cpp:869
Abstract class that contains various methods for clients to notify about changes.
LLT getElementType() const
Returns the vector's element type. Only valid for vector types.
DominatorTree Class - Concrete subclass of DominatorTreeBase that is used to compute a normal dominat...
Helper class to build MachineInstr.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
unsigned getNumOperands() const
Returns the total number of operands.
const MachineOperand & getOperand(unsigned i) const
LLVM_ABI MachineInstrBundleIterator< MachineInstr > eraseFromParent()
Unlink 'this' from the containing basic block and delete it.
Register getReg() const
getReg - Returns the register number.
const MachineFunction & getMF() const
LLVM_ABI Register createGenericVirtualRegister(LLT Ty, StringRef Name="")
Create and return a new generic virtual register with low-level type Ty.
Wrapper class representing virtual and physical registers.
Definition Register.h:20
void applyMatrixMultiply(MachineInstr &MI) const
bool matchSelectToFaceForward(MachineInstr &MI) const
This match is part of a combine that rewrites select(fcmp(dot(I, Ng), 0), N, -N) to faceforward(N,...
void applyMatrixTranspose(MachineInstr &MI) const
bool matchMatrixTranspose(MachineInstr &MI) const
LLVM_ABI CombinerHelper(GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize, GISelValueTracking *VT=nullptr, MachineDominatorTree *MDT=nullptr, const LegalizerInfo *LI=nullptr)
void applySPIRVFaceForward(MachineInstr &MI) const
SPIRVCombinerHelper(GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize, GISelValueTracking *VT, MachineDominatorTree *MDT, const LegalizerInfo *LI, const SPIRVSubtarget &STI)
bool matchMatrixMultiply(MachineInstr &MI) const
const SPIRVSubtarget & STI
void applySPIRVDistance(MachineInstr &MI) const
bool matchLengthToDistance(MachineInstr &MI) const
This match is part of a combine that rewrites length(X - Y) to distance(X, Y) (f32 (g_intrinsic lengt...
LLT getRegType(SPIRVTypeInst SpvType) const
void invalidateMachineInstr(MachineInstr *MI)
SPIRVTypeInst getScalarOrVectorComponentType(SPIRVTypeInst Type) const
SPIRVTypeInst getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:288
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition Type.h:186
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
operand_type_match m_Reg()
operand_type_match m_Pred()
TernaryOp_match< Src0Ty, Src1Ty, Src2Ty, TargetOpcode::G_SELECT > m_GISelect(const Src0Ty &Src0, const Src1Ty &Src1, const Src2Ty &Src2)
bool mi_match(Reg R, const MachineRegisterInfo &MRI, Pattern &&P)
SpecificRegisterMatch m_SpecificReg(Register RequestedReg)
Matches a register only if it is equal to RequestedReg.
UnaryOp_match< SrcTy, TargetOpcode::G_FNEG > m_GFNeg(const SrcTy &Src)
GFCstAndRegMatch m_GFCst(std::optional< FPValueAndVReg > &FPValReg)
GFCstOrSplatGFCstMatch m_GFCstOrSplat(std::optional< FPValueAndVReg > &FPValReg)
BinaryOp_match< LHS, RHS, TargetOpcode::G_FMUL, true > m_GFMul(const LHS &L, const RHS &R)
CompareOp_match< Pred, LHS, RHS, TargetOpcode::G_FCMP > m_GFCmp(const Pred &P, const LHS &L, const RHS &R)
This is an optimization pass for GlobalISel generic memory operations.
void setRegClassType(Register Reg, SPIRVTypeInst SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF, bool Force)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
Type * getMDOperandAsType(const MDNode *N, unsigned I)
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID)
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:862