15#include "llvm/IR/IntrinsicsSPIRV.h"
37 if (
MI.getOpcode() != TargetOpcode::G_INTRINSIC ||
44 if (SubInstr->
getOpcode() != TargetOpcode::G_FSUB)
52 Register SubDestReg =
MI.getOperand(2).getReg();
56 Register ResultReg =
MI.getOperand(0).getReg();
59 Builder.buildIntrinsic(Intrinsic::spv_distance, ResultReg)
103 if (DotInstr->
getOpcode() != TargetOpcode::G_INTRINSIC ||
109 !
MRI.getType(DotOperand1).isScalar() ||
110 !
MRI.getType(DotOperand2).isScalar())
119 auto AreNegatedConstantsOrSplats = [&](
Register TrueReg,
Register FalseReg) {
120 std::optional<FPValueAndVReg> TrueVal, FalseVal;
124 APFloat TrueValNegated = TrueVal->Value;
131 std::optional<FPValueAndVReg> MulConstant;
134 if (TrueInstr->
getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
135 FalseInstr->
getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
153 if (!MulConstant || !MulConstant->Value.isExactlyValue(-1.0))
155 }
else if (!AreNegatedConstantsOrSplats(TrueReg, FalseReg))
172 if (DotInstr->
getOpcode() == TargetOpcode::G_FMUL) {
180 Register FalseReg =
MI.getOperand(3).getReg();
182 if (TrueInstr->
getOpcode() == TargetOpcode::G_FNEG ||
183 TrueInstr->
getOpcode() == TargetOpcode::G_FMUL)
187 Register ResultReg =
MI.getOperand(0).getReg();
189 Builder.buildIntrinsic(Intrinsic::spv_faceforward, ResultReg)
192 .addUse(DotOperand2);
196 auto RemoveAllUses = [&](
Register Reg) {
198 for (
auto &
UseMI :
MRI.use_instructions(Reg))
202 for (
auto *MIToErase : UsesToErase)
203 MIToErase->eraseFromParent();
206 RemoveAllUses(CondReg);
209 RemoveAllUses(DotReg);
212 RemoveAllUses(FalseReg);
218 return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
231 if (Rows == 1 || Cols == 1) {
232 Builder.buildCopy(ResReg, InReg);
233 MI.eraseFromParent();
238 for (
uint32_t K = 0; K < Rows * Cols; ++K) {
241 Mask.push_back(
C * Rows + R);
244 Builder.buildShuffleVector(ResReg, InReg, InReg, Mask);
245 MI.eraseFromParent();
249 return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
254SPIRVCombinerHelper::extractColumns(
Register MatrixReg,
uint32_t NumberOfCols,
258 if (NumberOfCols == 1)
263 for (
uint32_t J = 0; J < NumberOfCols; ++J)
265 Builder.buildUnmerge(Cols, MatrixReg);
273SPIRVCombinerHelper::extractRows(
Register MatrixReg, uint32_t NumRows,
283 for (uint32_t
I = 0;
I < NumRows; ++
I)
284 Rows.
push_back(
MRI.createGenericVirtualRegister(VecTy));
285 Builder.buildUnmerge(Rows, MatrixReg);
297 for (uint32_t
I = 0;
I < NumRows; ++
I) {
298 SmallVector<int, 4>
Mask;
299 for (uint32_t k = 0;
k < NumCols; ++
k)
300 Mask.push_back(k * NumRows +
I);
301 Rows.
push_back(
Builder.buildShuffleVector(VecTy, MatrixReg, MatrixReg, Mask)
313 bool IsVectorOp = SpvVecType->
getOpcode() == SPIRV::OpTypeVector;
315 bool IsFloatOp = SpvScalarType->
getOpcode() == SPIRV::OpTypeFloat;
321 Intrinsic::SPVIntrinsics DotIntrinsic =
322 (IsFloatOp ? Intrinsic::spv_fdot : Intrinsic::spv_udot);
323 DotRes =
Builder.buildIntrinsic(DotIntrinsic, {ScalarTy})
329 DotRes =
Builder.buildFMul(VecTy, RowA, ColB).getReg(0);
331 DotRes =
Builder.buildMul(VecTy, RowA, ColB).getReg(0);
342 SmallVector<Register, 16> ResultScalars;
343 for (uint32_t J = 0; J < ColsB.
size(); ++J) {
344 for (uint32_t
I = 0;
I < RowsA.
size(); ++
I) {
346 computeDotProduct(RowsA[
I], ColsB[J], SpvVecType, GR));
349 return ResultScalars;
353SPIRVCombinerHelper::getDotProductVectorType(
Register ResReg, uint32_t K,
356 Type *ScalarResType =
nullptr;
357 for (
auto &
UseMI :
MRI.use_instructions(ResReg)) {
358 if (
UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
377 SPIRV::AccessQualifier::None,
false);
384 uint32_t NumRowsA =
MI.getOperand(4).getImm();
385 uint32_t NumColsA =
MI.getOperand(5).getImm();
386 uint32_t NumColsB =
MI.getOperand(6).getImm();
393 SPIRVTypeInst SpvVecType = getDotProductVectorType(ResReg, NumColsA, GR);
395 extractColumns(BReg, NumColsB, SpvVecType, GR);
397 extractRows(AReg, NumRowsA, NumColsA, SpvVecType, GR);
399 computeDotProducts(RowsA, ColsB, SpvVecType, GR);
401 if (ResultScalars.
size() == 1)
402 Builder.buildCopy(ResReg, ResultScalars[0]);
404 Builder.buildBuildVector(ResReg, ResultScalars);
405 MI.eraseFromParent();
MachineInstrBuilder & UseMI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Declares convenience wrapper classes for interpreting MachineInstr instances as specific generic oper...
Contains matchers for matching SSA Machine Instructions.
Promote Memory to Register
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
@ FCMP_OLT
0 1 0 0 True if ordered and less than
@ FCMP_OGT
0 0 1 0 True if ordered and greater than
@ FCMP_ULT
1 1 0 0 True if unordered or less than
@ FCMP_UGT
1 0 1 0 True if unordered or greater than
MachineRegisterInfo & MRI
MachineDominatorTree * MDT
GISelChangeObserver & Observer
MachineIRBuilder & Builder
ConstantFP - Floating Point Values [float, double].
bool isZero() const
Return true if the value is positive or negative zero.
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Abstract class that contains various methods for clients to notify about changes.
LLT getElementType() const
Returns the vector's element type. Only valid for vector types.
DominatorTree Class - Concrete subclass of DominatorTreeBase that is used to compute a normal dominat...
Helper class to build MachineInstr.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
unsigned getNumOperands() const
Returns the total number of operands.
const MachineOperand & getOperand(unsigned i) const
LLVM_ABI MachineInstrBundleIterator< MachineInstr > eraseFromParent()
Unlink 'this' from the containing basic block and delete it.
Register getReg() const
getReg - Returns the register number.
const MachineFunction & getMF() const
LLVM_ABI Register createGenericVirtualRegister(LLT Ty, StringRef Name="")
Create and return a new generic virtual register with low-level type Ty.
Wrapper class representing virtual and physical registers.
void applyMatrixMultiply(MachineInstr &MI) const
bool matchSelectToFaceForward(MachineInstr &MI) const
This match is part of a combine that rewrites select(fcmp(dot(I, Ng), 0), N, -N) to faceforward(N,...
void applyMatrixTranspose(MachineInstr &MI) const
bool matchMatrixTranspose(MachineInstr &MI) const
LLVM_ABI CombinerHelper(GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize, GISelValueTracking *VT=nullptr, MachineDominatorTree *MDT=nullptr, const LegalizerInfo *LI=nullptr)
void applySPIRVFaceForward(MachineInstr &MI) const
SPIRVCombinerHelper(GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize, GISelValueTracking *VT, MachineDominatorTree *MDT, const LegalizerInfo *LI, const SPIRVSubtarget &STI)
bool matchMatrixMultiply(MachineInstr &MI) const
const SPIRVSubtarget & STI
void applySPIRVDistance(MachineInstr &MI) const
bool matchLengthToDistance(MachineInstr &MI) const
This match is part of a combine that rewrites length(X - Y) to distance(X, Y) (f32 (g_intrinsic lengt...
LLT getRegType(SPIRVTypeInst SpvType) const
void invalidateMachineInstr(MachineInstr *MI)
SPIRVTypeInst getScalarOrVectorComponentType(SPIRVTypeInst Type) const
SPIRVTypeInst getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool isVectorTy() const
True if this is an instance of VectorType.
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
bool isIntegerTy() const
True if this is an instance of IntegerType.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
operand_type_match m_Reg()
operand_type_match m_Pred()
TernaryOp_match< Src0Ty, Src1Ty, Src2Ty, TargetOpcode::G_SELECT > m_GISelect(const Src0Ty &Src0, const Src1Ty &Src1, const Src2Ty &Src2)
bool mi_match(Reg R, const MachineRegisterInfo &MRI, Pattern &&P)
SpecificRegisterMatch m_SpecificReg(Register RequestedReg)
Matches a register only if it is equal to RequestedReg.
UnaryOp_match< SrcTy, TargetOpcode::G_FNEG > m_GFNeg(const SrcTy &Src)
GFCstAndRegMatch m_GFCst(std::optional< FPValueAndVReg > &FPValReg)
GFCstOrSplatGFCstMatch m_GFCstOrSplat(std::optional< FPValueAndVReg > &FPValReg)
BinaryOp_match< LHS, RHS, TargetOpcode::G_FMUL, true > m_GFMul(const LHS &L, const RHS &R)
CompareOp_match< Pred, LHS, RHS, TargetOpcode::G_FCMP > m_GFCmp(const Pred &P, const LHS &L, const RHS &R)
This is an optimization pass for GlobalISel generic memory operations.
void setRegClassType(Register Reg, SPIRVTypeInst SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF, bool Force)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Type * getMDOperandAsType(const MDNode *N, unsigned I)
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID)
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.