using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"
static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));
/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}
/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}
/// Compute the address of the vector with index \p VecIdx in a matrix laid
/// out starting at \p BasePtr with \p Stride elements between vectors.
static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                                unsigned NumElements, Type *EltType,
                                IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}
/// ShapeInfo describes the shape of a matrix: its number of rows and columns.
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  bool operator==(const ShapeInfo &other) {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  unsigned getStride() const {
    if (IsColumnMajor)
      return NumRows;
    return NumColumns;
  }

  unsigned getNumVectors() const {
    if (IsColumnMajor)
      return NumColumns;
    return NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
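// For illustration: under the default column-major layout, a 3x2 matrix is
// held as 2 column vectors of 3 elements each, so getStride() == 3 and
// getNumVectors() == 2, while t() yields the transposed 2x3 shape.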
/// Return true if \p V is an operation whose operands must have the same
/// shape as its result (element-wise ops and scalar multiplies).
static bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;

  switch (I->getOpcode()) {
  case Instruction::FAdd:
  case Instruction::FSub:
  case Instruction::FMul: // Scalar multiply.
  case Instruction::FNeg:
  case Instruction::Add:
  case Instruction::Mul:
  case Instruction::Sub:
    return true;
  default:
    return false;
  }
}
/// Compute the shape of the result of \p I, if it can be determined from its
/// operands or from the shapes cached in \p ShapeMap.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const DenseMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N))))
    // Flip dimensions.
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }

  if (isUniformShape(I)) {
    // Find the first operand that has a known shape and use that.
    for (auto &Op : I->operands()) {
      auto OpShape = ShapeMap.find(Op.get());
      if (OpShape != ShapeMap.end())
        return OpShape->second;
    }
  }
  return std::nullopt;
}
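// As an example of the rules above: for a call
//   %c = call <6 x double> @llvm.matrix.multiply(..., i32 2, i32 4, i32 3)
// the result shape is M x K = 2x3, while a transpose or a column-major store
// flips the dimensions it is given.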
namespace {

/// LowerMatrixIntrinsics contains the methods used to lower matrix
/// intrinsics: propagate shape information from the intrinsics to connected
/// instructions, then lower each instruction with known shape to operations
/// on embedded vectors.
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  FunctionAnalysisManager *AM;
  AliasAnalysis *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;
  OptimizationRemarkEmitter *ORE = nullptr;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either in row
  /// or column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
      return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
      return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }

    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };

  /// Maps instructions to their produced matrix of vectors.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

  /// Maps values with known shape to their shape information.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we are not replacing all
  /// users of a lowered instruction, if shape information is available; those
  /// are removed after lowering finishes.
  SmallVector<Instruction *, 16> ToRemove;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        FunctionAnalysisManager *AM)
      : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
  /// Return the fast-math flags of \p Inst, enabling contraction when
  /// requested via -matrix-allow-contract.
  FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());

    return FMF;
  }

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation on
  /// \p ST * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                   TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;
      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
        Propagate = setShapeInfo(Inst, *SI);

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to
    // the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
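  // Example of a backward step: for %c = llvm.matrix.multiply(%a, %b, M, N,
  // K), operand %a is assigned shape MxN and %b shape NxK; newly shaped
  // operands are then pushed so their own operands can be visited in turn.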
  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// them on both sides of \p Operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }
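  // This is the workhorse behind rewrites such as (A * B)^T -> B^T * A^T in
  // sinkTranspose below: both operands are transposed up front and the binary
  // operation is re-created on the transposed values via the Operation
  // callback.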
  void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
    auto Iter = ShapeMap.find(Inst);
    if (Iter != ShapeMap.end())
      ShapeMap.erase(Iter);
    Inst->eraseFromParent();
  }

  /// Erase \p V from \p BB and move \p II forward to avoid invalidating
  /// iterators.
  void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                              BasicBlock &BB) {
    auto *Inst = cast<Instruction>(V);
    // Still used, don't erase.
    if (!Inst->use_empty())
      return;
    if (II != BB.rend() && Inst == &*II)
      ++II;
    eraseFromParentAndRemoveFromShapeMap(Inst);
  }

  /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase
  /// the entry for \p Old and replace all uses of \p Old with \p New.
  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New it it supportsShapeInfo so we insert
    // it conditionally instead.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, S->second});
    }
    Old.replaceAllUsesWith(New);
  }
  /// Sink a top-level transpose inside matmuls and adds. This creates and
  /// erases instructions as needed, and returns the newly created instruction
  /// while updating the iterator to avoid invalidation. If this returns
  /// nullptr, no new instruction was created.
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK   KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    //  R  x  C     RxC
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    // RxC RxC      CxR   CxR
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");
            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }
  /// Lift transposes above binary operators: A^t op B^t -> (A op B)^t.
  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        eraseFromParentAndRemoveFromShapeMap(&T);
      if (A->use_empty())
        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
      if (A != B && B->use_empty())
        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(
          M, C->getZExtValue(), R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose;
    // if the second transpose has a different shape, the conflict is resolved
    // in favor of the first operand's shape.
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
      IRBuilder<> Builder(&I);
      auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");
      MatrixBuilder MBuilder(Builder);
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
                 computeShapeInfoForInst(&I, ShapeMap) &&
             "Shape of new instruction doesn't match original shape.");
      CleanupBinOp(I, A, B);
      if (auto *AddI = dyn_cast<Instruction>(Add)) {
        setShapeInfo(AddI, {R, C});
        assert(
            computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
                ShapeMap[AddI] &&
            "Shape of updated addition doesn't match cached shape.");
      }
    }
  }
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // First sink all transposes inside matmuls and adds, hoping that we end
    // up with NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II. By default continue on the next/prev instruction.
        ++II;
        if (Instruction *NewInst = sinkTranspose(I, II))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul or a TT add, lift the transpose. We may be able
    // to fold into a consuming multiply or add.
    for (BasicBlock &BB : Func)
      for (Instruction &I : llvm::make_early_inc_range(BB))
        liftTranspose(I);
  }

  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    if (AM) {
      ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func);
      AA = &AM->getResult<AAManager>(Func);
      DT = &AM->getResult<DominatorTreeAnalysis>(Func);
      LI = &AM->getResult<LoopAnalysis>(Func);
    }

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;
    SmallVector<IntrinsicInst *, 16> LifetimeEnds;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
          LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

    // Third, try to fuse candidates.
    for (CallInst *CI : MaybeFusableInsts)
      if (!FusedInsts.contains(CI))
        LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);

    Changed = !FusedInsts.empty();

    // Fourth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains.
    //
    // Because we add to ToRemove during fusion we can't guarantee that defs
    // are before uses. Change uses to poison temporarily as these should get
    // removed as well.
    //
    // For verification, keep track of where we changed uses to poison in
    // PoisonedInsts and then check that we in fact remove them.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }
  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = Ptr;
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Loads a sub-matrix with shape \p ResultShape from a \p MatrixShape
  /// matrix, starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);

    return loadMatrix(TileTy, TileStart, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }
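  // Tile loads address into the larger matrix: for a tile starting at row I,
  // column J of a column-major matrix with stride S, the linear element
  // offset is J * S + I, which is exactly what the CreateMul/CreateAdd above
  // compute.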
  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
                 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride,
                                IsVolatile, Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
  /// Stores a sub-matrix \p StoreVal into the \p MatrixShape matrix starting
  /// at \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());

    storeMatrix(TileTy, StoreVal, TileStart, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                Builder);
  }
  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = Ptr;
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }
  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide if that's profitable.
        return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),
                                       {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
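  // With -matrix-allow-contract (or contract fast-math flags on the call),
  // each multiply-add in the inner product becomes a single llvm.fmuladd
  // call; otherwise a separate fmul and fadd are emitted and both are counted
  // in NumComputeOps.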
  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Special case for MatMul if the matrix is a scalar dot product.
  /// Since the tiles of the dot product are only loaded once, we should use
  /// reassociation to enable using a vector add for the reduction.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    if (FusedInsts.contains(MatMul) ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;

    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);

    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();
    // Floating point reductions require reassocation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;

    auto CanBeFlattened = [](Value *Op) {
      if (match(Op, m_BinOp()))
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };
    // Returns the cost benefit of using \p Op with the dot product lowering.
    // If the returned cost is < 0, the argument is cheaper to use in the
    // dot product lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (ShapeMap.find(Op) == ShapeMap.end())
        return InstructionCost::getInvalid();

      if (!isa<Instruction>(Op))
        return InstructionCost(0);

      FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
      Type *EltTy = VecTy->getElementType();

      if (!CanBeFlattened(Op)) {
        InstructionCost EmbedCost(0);
        // Roughly estimate the cost for embedding the columns into a vector.
        for (unsigned I = 1; I < N; ++I)
          EmbedCost += TTI.getShuffleCost(TTI::SK_Splice,
                                          FixedVectorType::get(EltTy, 1),
                                          std::nullopt,
                                          TTI::TCK_RecipThroughput);
        return EmbedCost;
      }

      if (match(Op, m_BinOp())) {
        InstructionCost OriginalCost =
            TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
                                       EltTy) *
            N;
        InstructionCost NewCost = TTI.getArithmeticInstrCost(
            cast<Instruction>(Op)->getOpcode(), VecTy);
        return NewCost - OriginalCost;
      }

      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
        // The transpose can be skipped for the dot product lowering, roughly
        // estimate the savings as the cost of embedding the columns in a
        // vector.
        InstructionCost EmbedCost(0);
        for (unsigned I = 1; I < N; ++I)
          EmbedCost -= TTI.getShuffleCost(TTI::SK_Splice,
                                          FixedVectorType::get(EltTy, 1),
                                          std::nullopt,
                                          TTI::TCK_RecipThroughput);
        return EmbedCost;
      }

      // Costs for loads.
      if (N == 1)
        return InstructionCost(0);

      return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
             N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
    };

    // Iterate over LHS and operations feeding LHS and check if it is
    // profitable to flatten the visited ops. For each op, we compute the
    // difference between the flattened and matrix versions.
    SmallPtrSet<Value *, 4> Seen;
    SmallVector<Value *> WorkList;
    SmallVector<Value *> ToFlatten;
    WorkList.push_back(LHS);
    InstructionCost LHSCost(0);
    while (!WorkList.empty()) {
      Value *Op = WorkList.pop_back_val();
      if (!Seen.insert(Op).second)
        continue;

      InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
      if (OpCost + LHSCost >= LHSCost)
        continue;

      LHSCost += OpCost;
      ToFlatten.push_back(Op);
      if (auto *I = dyn_cast<Instruction>(Op))
        WorkList.append(I->op_begin(), I->op_end());
    }

    // We compare the costs of a vector.reduce.add to sequential add.
    int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
    int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
    InstructionCost ReductionCost =
        TTI.getArithmeticReductionCost(
            AddOpCode, cast<VectorType>(LHS->getType()),
            IsIntVec ? std::nullopt : std::optional(FMF)) +
        TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
    InstructionCost SequentialAddCost =
        TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
            (LShape.NumColumns - 1) +
        TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
            (LShape.NumColumns);
    if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
      return;

    FusedInsts.insert(MatMul);
    IRBuilder<> Builder(MatMul);
    auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
                       this](Value *Op) {
      // Matmul must be the only user of loads because we don't use LowerLoad
      // for row vectors (LowerLoad results in scalar loads and shufflevectors
      // instead of a single vector load).
      if (!CanBeFlattened(Op))
        return;

      if (match(Op, m_BinOp())) {
        auto It = ShapeMap.find(Op);
        if (It != ShapeMap.end()) {
          It->second = It->second.t();
          return;
        }
      }

      FusedInsts.insert(cast<Instruction>(Op));

      // If vector uses the builtin load, lower to a LoadInst.
      Value *Arg;
      if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                        m_Value(Arg)))) {
        auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
        Op->replaceAllUsesWith(NewLoad);
        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
        return;
      } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(Arg)))) {
        ToRemove.push_back(cast<Instruction>(Op));
        Op->replaceAllUsesWith(Arg);
        return;
      }
    };

    for (auto *V : ToFlatten)
      FlattenArg(V);

    LHS = MatMul->getArgOperand(0);

    // Insert mul/fmul and llvm.vector.reduce.fadd.
    Value *Mul =
        IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);

    Value *Result;
    if (IsIntVec)
      Result = Builder.CreateAddReduce(Mul);
    else {
      Result = Builder.CreateFAddReduce(
          ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
                          0.0),
          Mul);
      cast<Instruction>(Result)->setFastMathFlags(FMF);
    }

    // Pack the scalar back into a matrix and then replace the matmul inst.
    Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
                                         Result, uint64_t(0));
    MatMul->replaceAllUsesWith(Result);
    FusedInsts.insert(MatMul);
    ToRemove.push_back(MatMul);
  }
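  // For example, a 1x4 * 4x1 matmul of floats becomes a single <4 x float>
  // fmul followed by llvm.vector.reduce.fadd, instead of four scalar
  // multiply-adds, whenever the cost model above finds the reduction cheaper.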
  /// Compute \p Result += \p A * \p B for input matrices with
  /// left-associating addition.
  ///
  /// We can fold a transpose into the operand that is used to extract
  /// scalars: the first operand for row-major and the second for
  /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
  /// operand is transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder,
                          bool IsTiled, bool IsScalarMatrixTransposed,
                          FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0
        // iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(),
                             NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axes and accumulate the rows. With
      // this the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(),
                             NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
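  // In the column-major loop nest above, column J of the result is
  // accumulated as the sum over K of A.column(K) blocks scaled by
  // splat(B[K][J]); keeping the adds vector-shaped means no reassociation is
  // needed to vectorize them.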
  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. The new or otherwise the original location
  /// is returned.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the checks and the copy.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    BasicBlock *Check0 = MatMul->getParent();
    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they
    // are guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(DL);
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to a new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    // Use an array type for the alloca, to avoid potentially huge alignment
    // requirements for large vector types.
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());

    Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(Alloca, Copy);

    // Manually collected dominator tree updates; we traversed the CFG in RPO
    // and only need to record the newly inserted edges.
    DTUpdates.push_back({DominatorTree::Insert, Check0, Check1});
    DTUpdates.push_back({DominatorTree::Insert, Check0, Fusion});
    DTUpdates.push_back({DominatorTree::Insert, Check1, Copy});
    DTUpdates.push_back({DominatorTree::Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }
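  // The emitted runtime check is effectively:
  //   if (load.begin < store.end && store.begin < load.end) copy to alloca;
  // the PHI then selects between the original pointer (no overlap on either
  // check) and the freshly copied buffer.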
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);

    // Cost model for tiling.
    //
    // For tiling to be beneficial, we need reuse either along the R or the C
    // axis. We vectorize along the R axis so that means at least 3 elements.
    // TODO: Also consider cost of copying if operands alias.
    if (R <= VF && C == 1)
      return false;
    // Then we need enough elements to exceed the number of vector registers
    // we have. Note that this is an oversimplification since fusing also
    // takes some extra loads which may exceed the number of reloads
    // necessary.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }
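  // Worked example: R = C = M = 16 with 256-bit vectors of double (VF = 4)
  // gives Op0Regs = (16 + 3) / 4 * 16 = 64 and Op1Regs = 64, far more than
  // the available vector registers on typical targets, so tiling is
  // considered profitable.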
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns,
                TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
    // Create PHI nodes for the result columns to accumulate across
    // iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoop.Header->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index,
                   TI.ColumnLoop.Index, {TileSize, TileSize}, EltType,
                   Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    // FIXME: The unroller should make this decision directly instead, but
    // currently the cost-model is not up to the task.
    unsigned InnerLoopUnrollCount =
        std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
1828 "Tiling only supported for column-major matrixes at the moment!");
1829 if (!isFusionProfitable(MatMul))
1835 const unsigned R = LShape.NumRows;
1836 const unsigned C = RShape.NumColumns;
1837 const unsigned M = LShape.NumColumns;
1838 auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1840 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1841 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1845 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1848 for (
unsigned J = 0; J <
C; J +=
TileSize)
1850 const unsigned TileR = std::min(R -
I,
unsigned(
TileSize));
1851 const unsigned TileC = std::min(
C - J,
unsigned(
TileSize));
1852 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1855 const unsigned TileM = std::min(M - K,
unsigned(
TileSize));
1859 {TileR, TileM}, EltType, Builder);
1863 {TileM, TileC}, EltType, Builder);
1864 emitMatrixMultiply(Res,
A,
B, Builder,
true,
false,
1867 storeMatrix(Res, CPtr,
Store->getAlign(),
Store->isVolatile(), {R, M},
1874 FusedInsts.
insert(Store);
1875 FusedInsts.
insert(MatMul);
1876 eraseFromParentAndRemoveFromShapeMap(Store);
1877 eraseFromParentAndRemoveFromShapeMap(MatMul);
1879 FusedInsts.
insert(LoadOp0);
1880 eraseFromParentAndRemoveFromShapeMap(LoadOp0);
1882 if (LoadOp1 != LoadOp0 && LoadOp1->
hasNUses(0)) {
1883 FusedInsts.
insert(LoadOp1);
1884 eraseFromParentAndRemoveFromShapeMap(LoadOp1);
  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void
  LowerMatrixMultiplyFused(CallInst *MatMul,
                           SmallPtrSetImpl<Instruction *> &FusedInsts,
                           SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to extract
    // scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output.
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist.begin(), ToHoist.end(),
           [this](Instruction *A, Instruction *B) {
             return DT->dominates(A, B);
           });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      // Deal with lifetime.end calls that might be between Load0/Load1 and
      // the store. To avoid invalid IR, adjust them so they survive until
      // after the fused code is emitted.
      BasicBlock *StoreParent = Store->getParent();
      bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
                                   LoadOp1->getParent() == StoreParent;
      for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
        IntrinsicInst *End = LifetimeEnds[Idx];
        auto Inc = make_scope_exit([&Idx]() { Idx++; });
        // If the lifetime.end is guaranteed to be before the loads or after
        // the store, it won't interfere with fusion.
        if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
          continue;
        if (DT->dominates(Store, End))
          continue;
        // If all fusable ops are in the same block and the lifetime.end is in
        // a different block, it won't interfere with fusion.
        if (FusableOpsInSameBlock && End->getParent() != StoreParent)
          continue;

        // If the loads don't alias the lifetime.end, it won't interfere with
        // fusion.
        MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
        if (!EndLoc.Ptr)
          continue;
        if (AA->isNoAlias(MemoryLocation::get(LoadOp0), EndLoc) &&
            AA->isNoAlias(MemoryLocation::get(LoadOp1), EndLoc))
          continue;

        // If both lifetime.end and the store are in the same block, extend
        // the lifetime until after the store, so the new lifetime covers the
        // loads we introduce later.
        if (End->getParent() == StoreParent) {
          End->moveAfter(Store);
          continue;
        }

        // Otherwise, conservatively remove the conflicting lifetime.end
        // marker.
        ToRemove.push_back(End);
        std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
        LifetimeEnds.pop_back();
        Inc.release();
      }

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert it into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently
    // we just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
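  // For a 2x2 transpose this expands to four extractelement/insertelement
  // pairs: result vector I gathers element I from each input column, swapping
  // row and column indices.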
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }
  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform a binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform a unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linarized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while
    /// linearizing the expression. Re-used sub-expressions are marked as
    /// (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.starts_with("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print if they refer to
    /// an (function) external address or a stack address, for other values we
    /// either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else if (isMatrix(V))
        TmpStream << "matrix";
      else
        TmpStream << "scalar";
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with
    /// (reused) at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);
        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() { return Str; }
  };
  /// Generate remarks for matrix operations in a function: map the matrix
  /// expressions to their containing subprograms, find the leaves of each
  /// expression tree, and emit one remark per leaf with accumulated operation
  /// counts and a linearized form of the expression.
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those
    /// are instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p
    /// Leaf to all visited expressions in \p Shared. Limit the matrix
    /// operations to the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      Shared[V].insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for the
    /// expression starting at \p Root. Expressions used multiple times are
    /// counted once. Limit the matrix operations to the ones in \p
    /// ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
                KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          Subprog2Exprs[nullptr].push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace
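// A remark produced by RemarkGenerator looks roughly like (counts vary):
//   remark: test.c:35:7: Lowered with 6 stores, 6 loads, 24 compute ops,
//   0 exposed transposes
// followed by the linearized expression tree from ExprLinearizer.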
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
  if (Minimal)
    OS << "minimal";
  OS << '>';
}