using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));
static cl::opt<bool> VerifyShapeInfo(
    "verify-matrix-shapes", cl::Hidden,
    cl::desc("Enable/disable matrix shape verification."), cl::init(false));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));
/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}
/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}
/// Return true if \p V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}
/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}
/// Wrapper class holding the number of rows and columns of a matrix.
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  bool operator==(const ShapeInfo &other) {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  unsigned getStride() const {
    if (IsColumnMajor)
      return NumRows;
    return NumColumns;
  }

  unsigned getNumVectors() const {
    if (IsColumnMajor)
      return NumColumns;
    return NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
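// Illustrative sketch (not part of the pass): how ShapeInfo describes a
// lowered matrix. For a 4x2 matrix in the default column-major layout,
// getStride() is 4 (the distance between the starts of adjacent columns) and
// getNumVectors() is 2 (one vector per column); t() merely swaps the
// dimensions, so a transpose can be tracked without touching any data.
//
//   ShapeInfo SI(4, 2);
//   assert(SI.getStride() == 4 && SI.getNumVectors() == 2);
//   assert(SI.t().NumRows == 2 && SI.t().NumColumns == 4);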
/// Return true if \p V is an instruction whose result shape matches the shape
/// of its operands (element-wise ops and scalar multiplies).
static bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;

  switch (I->getOpcode()) {
  case Instruction::FAdd:
  case Instruction::FSub:
  case Instruction::FMul: // Scalar multiply.
  case Instruction::FNeg:
  case Instruction::Add:
  case Instruction::Mul:
  case Instruction::Sub:
    return true;
  default:
    return false;
  }
}
/// Compute the shape of the result for \p I, if it can be determined from its
/// operands or from \p ShapeMap.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N))))
    // Flip dimensions.
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }

  if (isUniformShape(I)) {
    // Find the first operand that has a known shape and use that.
    for (auto &Op : I->operands()) {
      auto OpShape = ShapeMap.find(Op.get());
      if (OpShape != ShapeMap.end())
        return OpShape->second;
    }
  }
  return std::nullopt;
}
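// Illustrative sketch (not part of the pass): the operand order the matchers
// above rely on. llvm.matrix.multiply carries its dimensions as constant
// operands (M, N, K): the LHS is MxN, the RHS is NxK, and the result is MxK,
// which is what the first match returns as ShapeInfo(M, K). Using the
// MatrixBuilder helper, a 2x4 times 4x3 multiply looks like:
//
//   MatrixBuilder MB(Builder);
//   // %c = call <6 x double> @llvm.matrix.multiply(...)  -- 2x3 result
//   Value *C = MB.CreateMatrixMultiply(A, B, /*LHSRows=*/2, /*LHSColumns=*/4,
//                                      /*RHSColumns=*/3);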
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Estimates of the number of operations (loads, stores, compute) required
  /// to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either in row
  /// or column major layout.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }
    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }
    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
      return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
      return Vectors.size();
    }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }
    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }
    auto columns() const {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }
    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }
    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
  /// Return the fast-math flags to use when lowering \p Inst, honoring the
  /// matrix-allow-contract option.
  FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;
    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();
    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
    return FMF;
  }

  /// Return the estimated number of vector ops required when lowering a value
  /// of vector type \p VT.
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation on
  /// \p N elements of scalar type \p ST.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                    TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // If we already lowered MatrixVal with matching shape, use the cached
    // result; otherwise flatten the cached matrix and re-split it below.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
  /// If \p V already has a known shape, return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    if (auto *II = dyn_cast<IntrinsicInst>(Inst))
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;
      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
        Propagate = setShapeInfo(Inst, *SI);

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
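  // Illustrative sketch (not part of the pass): given
  //   %t = fadd <8 x double> %a, %b
  // where %a was seeded as a 2x4 matrix by a matrix intrinsic, forward
  // propagation assigns %t the same 2x4 shape (fadd has "uniform" shape: the
  // result matches its operands), and %t's users are then queued for the next
  // round.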
  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to
    // the worklist.
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V,
                       m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Distribute a transpose over a binary operation by creating transposes of
  /// both operands and letting \p Operation combine them.
  Value *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }
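  // The rewrites built on distributeTransposes below use two standard
  // identities: (A * B)^T = B^T * A^T and (A + B)^T = A^T + B^T. For the
  // multiply case, an RxK times KxC product is re-expressed as a CxK times
  // KxR product, which yields the CxR transposed result directly, so one
  // transpose of the result replaces two transposes of the operands.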
  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // Remove Old from the ShapeMap and only add New if it supports shape
    // information, so RAUW does not leave a stale entry behind.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, S->second});
    }
    Old.replaceAllUsesWith(New);
  }
  /// Sink a top-level transpose inside matmuls and adds.
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }
    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK   KxR   CxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }
    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      // We know that the transposed operand is of shape RxC, and a multiply
      // with a scalar preserves the shape.
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }
    // (A + B)^t -> A^t + B^t
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");
            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }
  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        T.eraseFromParent();
      if (A->use_empty())
        cast<Instruction>(A)->eraseFromParent();
      if (A != B && B->use_empty())
        cast<Instruction>(B)->eraseFromParent();
    };
    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(
          M, C->getZExtValue(), R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose;
    // if the second transpose has a conflicting shape, the first one wins.
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
      IRBuilder<> Builder(&I);
      MatrixBuilder MBuilder(Builder);
      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
      setShapeInfo(Add, {R, C});
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
                 computeShapeInfoForInst(&I, ShapeMap) &&
             "Shape of new instruction doesn't match original shape.");
      CleanupBinOp(I, A, B);
      assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
                 ShapeMap[Add] &&
             "Shape of updated addition doesn't match cached shape.");
    }
  }
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // Walk each basic block, sinking transposes into matmuls and adds, then
    // lifting them out of binary operators.
    // ...
  }

  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }
    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;
    SmallVector<IntrinsicInst *, 16> LifetimeEnds;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
          LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

    // Third, try to fuse candidates.
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);

    Changed = !FusedInsts.empty();

    // Fourth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      Value *Op1;
      Value *Op2;
      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }
    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains. Because we add to
    // ToRemove during fusion we can't guarantee that defs are before uses, so
    // change remaining uses to poison and track them in PoisonedInsts.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }
  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() ||
        !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    auto *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = Ptr;
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Loads a sub-matrix with shape \p ResultShape from the \p MatrixShape
  /// matrix starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);

    return loadMatrix(TileTy, TileStart, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }
  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
                 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride,
                                IsVolatile, Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
  /// Stores a sub-matrix \p StoreVal into the \p MatrixShape matrix starting
  /// at \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());

    storeMatrix(TileTy, StoreVal, TileStart, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = Ptr;
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }
  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  /// Set elements I..I+NumElts-1 of \p Col to \p Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
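  // Illustrative sketch (not part of the pass): the same block-insert
  // expressed directly, assuming an 8-element column and a 2-element block
  // inserted at index 2. Indices >= 8 select from the (widened) second
  // shuffle operand.
  //
  //   SmallVector<int, 8> Mask = {0, 1, 8, 9, 4, 5, 6, 7};
  //   Value *NewCol = Builder.CreateShuffleVector(Col, WidenedBlock, Mask);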
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide on how to lower it.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Special case for MatMul lowering: if the matrix multiply is actually a
  /// dot product (1-row LHS times 1-column RHS), lower it to a vector
  /// multiply followed by a reduction.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;

    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);
    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();

    // Floating point reductions require reassociation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;

    // Operands that are a binary operator, or a one-use load, transpose or
    // column-major load with stride 1 can be folded into the lowering.
    auto CanBeFlattened = [](Value *Op) {
      if (match(Op, m_BinOp()))
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };

    // Returns the cost benefit of using \p Op with the dot product lowering.
    // If the returned cost is < 0, \p Op is cheaper to use in the dot product
    // lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (ShapeMap.find(Op) == ShapeMap.end())
        return InstructionCost::getInvalid();

      if (!isa<Instruction>(Op))
        return InstructionCost(0);

      FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
      Type *EltTy = VecTy->getElementType();

      if (!CanBeFlattened(Op)) {
        // Roughly estimate the cost of embedding the N column vectors into a
        // flat vector: one shuffle per additional column.
        InstructionCost EmbedCost(0);
        for (unsigned I = 1; I < N; ++I)
          EmbedCost += TTI.getShuffleCost(TargetTransformInfo::SK_Splice,
                                          FixedVectorType::get(EltTy, 1));
        return EmbedCost;
      }

      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
        InstructionCost OriginalCost =
            TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
                                       EltTy) *
            N;
        InstructionCost NewCost = TTI.getArithmeticInstrCost(
            cast<Instruction>(Op)->getOpcode(), VecTy);
        return NewCost - OriginalCost;
      }
      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
        // The transpose can be skipped for the dot product lowering, roughly
        // estimate the savings as the cost of embedding the columns in a
        // vector.
        InstructionCost EmbedCost(0);
        for (unsigned I = 1; I < N; ++I)
          EmbedCost -= TTI.getShuffleCost(TargetTransformInfo::SK_Splice,
                                          FixedVectorType::get(EltTy, 1));
        return EmbedCost;
      }

      // Costs for the remaining cases (loads etc.).
      // ...
    };

    // Iterate over LHS and the operations feeding it, and check if it is
    // profitable to flatten the visited ops.
    SmallVector<Value *> WorkList;
    SmallVector<Value *> ToFlatten;
    WorkList.push_back(LHS);
    InstructionCost LHSCost(0);
    while (!WorkList.empty()) {
      Value *Op = WorkList.pop_back_val();
      InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
      if (OpCost + LHSCost >= LHSCost)
        continue;

      LHSCost += OpCost;
      ToFlatten.push_back(Op);
      if (auto *I = dyn_cast<Instruction>(Op))
        WorkList.append(I->op_begin(), I->op_end());
    }
    // Compare the estimated cost of a multiply + add-reduction lowering
    // against the default sequential scalar lowering.
    int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
    int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
    InstructionCost ReductionCost =
        TTI.getArithmeticReductionCost(
            AddOpCode, cast<VectorType>(LHS->getType()),
            IsIntVec ? std::nullopt : std::optional(FMF)) +
        TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
    InstructionCost SequentialAddCost =
        TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
            (LShape.NumColumns - 1) +
        TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
            (LShape.NumColumns);
    if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
      return;

    FusedInsts.insert(MatMul);
    IRBuilder<> Builder(MatMul);
    auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
                       this](Value *Op) {
      // Matmul must be the only user of loads because we don't use LowerLoad
      // for row vectors (LowerLoad results in scalar loads and shufflevectors
      // instead of a single vector load).
      if (!CanBeFlattened(Op))
        return;

      if (match(Op, m_BinOp())) {
        // Flatten the binary operator by transposing its cached shape.
        ShapeMap[Op] = ShapeMap[Op].t();
        return;
      }

      FusedInsts.insert(cast<Instruction>(Op));

      // If the operand is a column-major load with stride 1, lower it to a
      // plain vector load.
      Value *Arg;
      if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                        m_Value(Arg)))) {
        auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
        Op->replaceAllUsesWith(NewLoad);
        cast<Instruction>(Op)->eraseFromParent();
        return;
      }
      // A transpose feeding a dot product can simply be skipped.
      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(Arg)))) {
        ToRemove.push_back(cast<Instruction>(Op));
        Op->replaceAllUsesWith(Arg);
      }
    };
    for (auto *V : ToFlatten)
      FlattenArg(V);

    // Emit the multiply and a vector add-reduction of the result.
    Value *Mul =
        IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
    Value *Result;
    if (IsIntVec)
      Result = Builder.CreateAddReduce(Mul);
    else
      Result = Builder.CreateFAddReduce(
          ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
                          0.0),
          Mul);

    // Pack the scalar back into a 1x1 matrix and replace the matmul.
    // ...
    FusedInsts.insert(MatMul);
    ToRemove.push_back(MatMul);
  }
  /// Compute \p Result += \p A * \p B for the given matrix layout.
  ///
  /// If \p IsScalarMatrixTransposed we assume the operand used to extract
  /// scalars is already transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder,
                          bool IsTiled, bool IsScalarMatrixTransposed,
                          FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be processed as mul + add chains.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K == 0
        // iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(),
                             NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand, moving along the K axes and accumulating the rows.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(),
                             NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
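  // Illustrative sketch (not part of the pass): the scalar computation the
  // column-major path above vectorizes. Each result column J accumulates
  // columns of A scaled by scalars from B's column J, with BlockSize-wide
  // chunks along the row axis replacing the innermost scalar loop.
  //
  //   // A is RxM, B is MxC, Res is RxC, all column-major.
  //   for (unsigned J = 0; J < C; ++J)
  //     for (unsigned I = 0; I < R; ++I) {
  //       double Sum = 0;
  //       for (unsigned K = 0; K < M; ++K)
  //         Sum += A[I + K * R] * B[K + J * M];
  //       Res[I + J * R] = Sum;
  //     }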
  /// Return a pointer for \p Load that is guaranteed not to alias \p Store.
  /// If the locations may overlap, emit runtime checks and copy the loaded
  /// memory to a new alloca.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create new blocks for the overlap checks and the copy.
    // (Dominator tree updates elided.)
    BasicBlock *Check0 = MatMul->getParent();
    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy = SplitBlock(MatMul->getParent(), MatMul,
                                  (DomTreeUpdater *)nullptr, LI, nullptr,
                                  "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(LoadLoc.Ptr), IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy the load operand to a new alloca. Use an array type to avoid
    // potentially huge alignment requirements for large vector types.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
    Builder.CreateMemCpy(Alloca, Alloca->getAlign(),
                         Load->getPointerOperand(), Load->getAlign(),
                         LoadLoc.Size.getValue());

    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(Alloca, Copy);

    return PHI;
  }
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);

    // For tiling to be beneficial, we need reuse either along the R or the C
    // axis. We vectorize along the R axis, so that means at least 3 elements.
    if (R <= VF && C == 1)
      return false;
    // We need enough elements to exceed the number of vector registers we
    // have. Note that this is an oversimplification since fusing also takes
    // some extra loads which may exceed the number of reloads necessary.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns,
                TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoop.Header->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index,
                   TI.ColumnLoop.Index, {TileSize, TileSize}, EltType,
                   Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store the result tile after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    unsigned InnerLoopUnrollCount =
        std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
1799 "Tiling only supported for column-major matrixes at the moment!");
1800 if (!isFusionProfitable(MatMul))
1806 const unsigned R = LShape.NumRows;
1807 const unsigned C = RShape.NumColumns;
1808 const unsigned M = LShape.NumColumns;
1809 auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1811 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1812 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1816 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1819 for (
unsigned J = 0; J <
C; J +=
TileSize)
1821 const unsigned TileR = std::min(R -
I,
unsigned(
TileSize));
1822 const unsigned TileC = std::min(
C - J,
unsigned(
TileSize));
1823 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1826 const unsigned TileM = std::min(M - K,
unsigned(
TileSize));
1830 {TileR, TileM}, EltType, Builder);
1834 {TileM, TileC}, EltType, Builder);
1835 emitMatrixMultiply(Res,
A,
B, Builder,
true,
false,
1838 storeMatrix(Res, CPtr,
Store->getAlign(),
Store->isVolatile(), {R, M},
1845 FusedInsts.
insert(Store);
1846 FusedInsts.
insert(MatMul);
1847 Store->eraseFromParent();
1850 FusedInsts.
insert(LoadOp0);
1853 if (LoadOp1 != LoadOp0 && LoadOp1->
hasNUses(0)) {
1854 FusedInsts.
insert(LoadOp1);
  /// Try to lower matrix multiply chains by fusing operations. Instructions
  /// that are completely eliminated by fusion are added to \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts,
                                SmallVectorImpl<IntrinsicInst *> &LifetimeEnds) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to extract
    // scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output.
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // still counted as an expression or shared with other expressions.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise we
      // create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      // Deal with lifetime.end calls that might be between Load0/Load1 and
      // the store.
      auto *StoreParent = Store->getParent();
      bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
                                   LoadOp1->getParent() == StoreParent;
      for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
        IntrinsicInst *End = LifetimeEnds[Idx];
        // If all fusable ops are in the same block and the lifetime.end is in
        // a different block, it does not interfere with fusion.
        if (FusableOpsInSameBlock && End->getParent() != StoreParent) {
          Idx++;
          continue;
        }
        // ...
        // Otherwise, conservatively sink the lifetime.end after the store.
        if (End->getParent() == StoreParent) {
          End->moveAfter(Store);
          Idx++;
          continue;
        }
        // ...
      }

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting
      // vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
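  // Illustrative sketch (not part of the pass): for a column-major 2x3 input
  // held as three 2-element columns, the loops above build two 3-element
  // result vectors by extracting element I of every input column, so the
  // (row, column) indices swap.
  //
  //   // Input columns:  (a0 a1) (b0 b1) (c0 c1)
  //   // Result vectors: (a0 b0 c0) and (a1 b1 c1)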
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }
  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }
  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform the binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform the unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;
    Value *Leaf;
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }
    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }
    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.starts_with("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        write(Tmp);
      }
    }
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }
    /// Special printing for values: for pointers, we print whether they refer
    /// to an (function) external address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }
    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are reused multiple times are prefixed with (reused)
    /// at the reused root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);
        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() { return Str; }
  };
  /// Generate remarks for matrix operations in a function, linearizing each
  /// matrix expression from its leaves (currently stores) upwards.
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those
    /// are instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }
    /// Recursively traverse expression \p V starting at \p Leaf and add \p
    /// Leaf to the set of expressions \p V is used in.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }
    /// Calculate the exclusive and shared op counts for the expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }
    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto It =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            It.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto It = Subprog2Exprs.insert({nullptr, {}});
          It.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }
    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
  if (Minimal)
    OS << "minimal";
  OS << '>';
}