using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));
/// Return Scope if it is a subprogram, or the subprogram attached to a local
/// scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if \p V is a splat of a value (used when multiplying a matrix
/// with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}
/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}
/// Compute the address of the vector with index \p VecIdx in a matrix that
/// starts at \p BasePtr, with \p Stride elements between consecutive vectors.
static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                                unsigned NumElements, Type *EltType,
                                IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get a pointer to the start of the selected vector. Skip the GEP if we
  // select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast the element-wise pointer to a pointer to a vector of NumElements.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}
namespace {

class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Estimates of the number of operations required to lower a matrix value.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Number of transposes that remain exposed after lowering.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either in row
  /// or column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(ConstantAggregateZero::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilderBase &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at row I and column J.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
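  // Illustration (not in the original source): with the default column-major
  // layout, a 4x2 matrix of double is held by MatrixTy as two <4 x double>
  // column vectors, so getNumVectors() == 2 and getStride() == getNumRows()
  // == 4, and embedInVector() concatenates the columns back into a single
  // <8 x double> value when a flat vector is required.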
  /// Describes the shape (rows x columns) of a matrix value.
  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;
    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }

    /// Returns the transposed shape.
    ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
  };
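  // Example (illustration only): ShapeInfo(4, 2) describes a 4x2 matrix. With
  // column-major layout getNumVectors() == 2 and getStride() == 4, and
  // ShapeInfo(4, 2).t() yields the 2x4 shape used when a transpose is folded.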
  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove once lowering has finished.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

private:
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;
    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();
    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
    return FMF;
  }

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  /// Return the estimated number of vector ops required for an operation of
  /// type \p VT.
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation on
  /// \p N scalars of type \p ST.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                   TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cache result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // If MatrixVal was lowered with shape information, return the existing
    // matrix when its shape matches. On a mismatch, embed it in a flat vector
    // and split it again below.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  /// Returns true if the result of \p V has the same shape as its operands.
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    if (auto *II = dyn_cast<IntrinsicInst>(Inst))
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                 m_Value(), m_Value(), m_Value(), m_Value(M),
                                 m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
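  // Example (illustration only; intrinsic mangling abbreviated): given
  //   %t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a,
  //                                                       i32 2, i32 3)
  //   %s = fadd <6 x double> %t, %t
  // forward propagation records %t as 3x2 (dimensions flipped) and, because
  // fadd has "uniform" shape behaviour, %s inherits the same 3x2 shape.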
  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with a known shape. Traverse the operands; if their
    // shape derives from the result shape and is unknown, add it and add them
    // to the worklist.
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this, so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
  /// Apply \p Operation to the transposed operands \p Op0 and \p Op1,
  /// registering shape information for the newly created transposes.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We run after shape propagation, so record shapes for the newly created
    // instructions so that they are lowered later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }

  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // Remove Old from the ShapeMap, otherwise RAUW would replace the key with
    // New. Only add New if it supports shape info.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeInfo Shape = S->second;
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, Shape});
    }
    Old.replaceAllUsesWith(New);
  }
  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidating it.
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK   KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A * k)^t -> A^t * k, for a scalar (splat) operand k. The shape RxC is
    // preserved by the scalar multiply.
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");
            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }

  void liftTranspose(Instruction &I) {
    // Erase dead instructions after lifting transposes from binops.
    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        T.eraseFromParent();
      if (A->use_empty())
        cast<Instruction>(A)->eraseFromParent();
      if (A != B && B->use_empty())
        cast<Instruction>(B)->eraseFromParent();
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(
          M, C->getZExtValue(), R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
      IRBuilder<> Builder(&I);
      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
      setShapeInfo(Add, {C, R});
      MatrixBuilder MBuilder(Builder);
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
  }
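  // Example of the rewrites above (illustration only): for
  //   %t = transpose(multiply(%a, %b))        ; (A * B)^T, A: RxK, B: KxC
  // sinkTranspose emits %a_t = transpose(%a) and %b_t = transpose(%b) and
  // replaces %t with multiply(%b_t, %a_t), i.e. B^T * A^T (CxK times KxR).
  // Conversely, liftTranspose turns multiply(transpose(%a), transpose(%b))
  // into transpose(multiply(%b, %a)), leaving only a single transpose that a
  // consumer may be able to fold away.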
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // Sink transposes (walking each block in reverse), then lift transposes
    // out of binary ops whose operands are both transposed.
    for (BasicBlock &BB : reverse(Func))
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II++;
        if (Instruction *NewInst = sinkTranspose(I, II))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    for (BasicBlock &BB : Func)
      for (Instruction &I : make_early_inc_range(BB))
        liftTranspose(I);
  }
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

    // Third, try to fuse candidates.
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);

    Changed = !FusedInsts.empty();

    // Fourth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards, which reduces the number of def-use
    // and use-def chains to update. Because we add to ToRemove during fusion,
    // defs are not guaranteed to come before uses, so change remaining uses to
    // poison temporarily and verify they all get removed.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace calls to the matrix intrinsics with their lowered equivalents.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() ||
        !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at index 0 has alignment \p A (or the ABI alignment of
  /// \p ElementTy).
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }

  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), IsVolatile,
          "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Load a sub-matrix with shape \p ResultShape from the matrix with shape
  /// \p MatrixShape, starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    return loadMatrix(TileTy, TilePtr, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
                 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
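  // Example (illustration only): lowering
  //   call <6 x double> @llvm.matrix.column.major.load.v6f64.i64(
  //       ptr %p, i64 4, i1 false, i32 2, i32 3)
  // with a 2x3 shape produces three <2 x double> loads from %p, %p + 4 and
  // %p + 8 elements (4 doubles of stride between columns), which become the
  // three column vectors of the resulting MatrixTy.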
  /// Store the sub-matrix \p StoreVal into the matrix with shape
  /// \p MatrixShape, starting at \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  /// Set elements I..I+NumElts-1 of \p Col to \p Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 elements long, I is 2 and BlockNumElts is 2, the mask is:
    // 0, 1, 7, 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as the result of \p Inst and update the uses of \p Inst.
  /// Users with shape information use the cached value when they are lowered;
  /// for other users, \p Matrix is flattened and the use is updated. Also
  /// marks \p Inst for deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Special case for MatMul lowering. Prevents scalar loads of row-major
  /// matrices. Replaces the MatMul with a dot product with a vector reduction
  /// add.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    if (FusedInsts.contains(MatMul) ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    // Only handle 1xN * Nx1 multiplies, i.e. dot products.
    if (LShape.NumRows != 1 || RShape.NumColumns != 1)
      return;

    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);

    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();

    // Floating point reductions require reassociation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;

    // Returns true if the operand can be used directly as a flat vector.
    auto CanBeFlattened = [this](Value *Op) {
      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };

    // Returns the cost benefit of using \p Op with the dot product lowering.
    // If the returned cost is < 0, the argument is cheaper to use in the
    // dot product lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (!isa<Instruction>(Op))
        return InstructionCost(0);

      FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
      Type *EltTy = VecTy->getElementType();

      if (!CanBeFlattened(Op)) {
        // Roughly estimate the cost for embedding the columns into a vector.
        InstructionCost EmbedCost(0);
        for (unsigned I = 1; I < N; ++I)
          EmbedCost += TTI.getShuffleCost(TargetTransformInfo::SK_Splice,
                                          FixedVectorType::get(EltTy, 1));
        return EmbedCost;
      }

      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
        InstructionCost OriginalCost =
            TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
                                       EltTy) *
            N;
        InstructionCost NewCost = TTI.getArithmeticInstrCost(
            cast<Instruction>(Op)->getOpcode(), VecTy);
        return NewCost - OriginalCost;
      }

      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
        // The transpose can be skipped for the dot product lowering; count
        // the saved embedding cost as a benefit.
        InstructionCost EmbedCost(0);
        for (unsigned I = 1; I < N; ++I)
          EmbedCost -= TTI.getShuffleCost(TargetTransformInfo::SK_Splice,
                                          FixedVectorType::get(EltTy, 1));
        return EmbedCost;
      }

      // Costs for a single flat load vs. N scalar loads.
      if (N == 1)
        return InstructionCost(0);
      return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
             N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
    };
    auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);

    // Compare the cost of a vector multiply + reduce against the sequential
    // multiply-add chain the default lowering would emit.
    int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
    int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
    InstructionCost ReductionCost =
        TTI.getArithmeticReductionCost(
            AddOpCode, cast<VectorType>(LHS->getType()),
            IsIntVec ? std::nullopt : std::optional(FMF)) +
        TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
    InstructionCost SequentialAddCost =
        TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
            (LShape.NumColumns - 1) +
        TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
            (LShape.NumColumns);
    if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
      return;

    FusedInsts.insert(MatMul);
    IRBuilder<> Builder(MatMul);
    auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
                       this](Value *Op) -> Value * {
      // MatMul must be the only user of loads because we don't use LowerLoad
      // for row vectors (LowerLoad results in scalar loads and shufflevectors
      // instead of a single vector load).
      if (!CanBeFlattened(Op))
        return Op;

      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
        ShapeMap[Op] = ShapeMap[Op].t();
        return Op;
      }

      FusedInsts.insert(cast<Instruction>(Op));
      // If the operand is a column-major load, lower it to a flat LoadInst.
      Value *Arg;
      if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                        m_Value(Arg)))) {
        auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
        Op->replaceAllUsesWith(NewLoad);
        cast<Instruction>(Op)->eraseFromParent();
        return NewLoad;
      } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(Arg)))) {
        ToRemove.push_back(cast<Instruction>(Op));
        return Arg;
      }
      return Op;
    };
    LHS = FlattenArg(LHS);

    // Insert mul/fmul and llvm.vector.reduce.fadd.
    Value *Mul =
        IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);

    Value *Result;
    if (IsIntVec)
      Result = Builder.CreateAddReduce(Mul);
    else {
      Result = Builder.CreateFAddReduce(
          ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
                          0.0),
          Mul);
      cast<Instruction>(Result)->setFastMathFlags(FMF);
    }

    // Pack the scalar back into a matrix and replace the matmul.
    Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
                                         Result, uint64_t(0));
    MatMul->replaceAllUsesWith(Result);
    FusedInsts.insert(MatMul);
    ToRemove.push_back(MatMul);
  }
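  // Example (illustration only): a 1x3 * 3x1 double multiply with reassoc
  // fast-math flags is rewritten from the intrinsic form into roughly
  //   %m = fmul reassoc <3 x double> %lhs, %rhs
  //   %r = call reassoc double @llvm.vector.reduce.fadd.v3f64(double 0.0,
  //                                                           <3 x double> %m)
  // followed by an insertelement of %r into a <1 x double> result.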
  /// Compute \p Result += \p A * \p B with left-associating addition.
  ///
  /// A transpose can be folded into the operand that is used to extract
  /// scalars: the first operand for row-major and the second for column-major
  /// layout. If \p IsScalarMatrixTransposed, the corresponding operand is
  /// assumed to be transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0 iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axes and accumulate the rows.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
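  // Sketch of the column-major strategy above (illustration only): for
  // Result(R x C) += A(R x M) * B(M x C), column J of the result is built as
  //   for K in 0..M-1:  Result[:,J] += A[:,K] * splat(B[K,J])
  // processed in chunks of at most VF rows, so the adds stay vectorized
  // without requiring reassociation.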
  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. Returns the pointer to use for the load.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and, if they do, copy Load's operand to a new buffer.
    BasicBlock *Check0 = MatMul->getParent();
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, LI,
                                    nullptr, "alias_cont");
    BasicBlock *Copy = SplitBlock(Check1, MatMul, nullptr, LI, nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(Copy, MatMul, nullptr, LI, nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy the load operand to a new alloca. Use an array type to avoid
    // potentially huge alignment requirements for large vector types.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
    Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
    Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());

    // Continue with a PHI selecting the original or the copied pointer.
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(BC, Copy);

    // Update the dominator tree for the newly created blocks.
    DTUpdates.push_back({DominatorTree::Insert, Check0, Check1});
    DTUpdates.push_back({DominatorTree::Insert, Check0, Fusion});
    DTUpdates.push_back({DominatorTree::Insert, Check1, Copy});
    DTUpdates.push_back({DominatorTree::Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);

    // Cost model for tiling: for tiling to be beneficial we need reuse either
    // along the R or the C axis. We vectorize along the R axis, so that means
    // at least 3 elements.
    if (R <= VF && C == 1)
      return false;
    // We also need enough elements to exceed the available vector registers.
    // This is an oversimplification, since fusing also takes extra loads.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }

  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the loop nest iterating over the result in TileSize blocks.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoop.Header->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store the result tile once the K loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
1779 "Tiling only supported for column-major matrixes at the moment!");
1780 if (!isFusionProfitable(MatMul))
1786 const unsigned R = LShape.NumRows;
1787 const unsigned C = RShape.NumColumns;
1788 const unsigned M = LShape.NumColumns;
1789 auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1791 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1792 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1796 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1799 for (
unsigned J = 0; J <
C; J +=
TileSize)
1801 const unsigned TileR = std::min(R -
I,
unsigned(
TileSize));
1802 const unsigned TileC = std::min(
C - J,
unsigned(
TileSize));
1803 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1806 const unsigned TileM = std::min(M - K,
unsigned(
TileSize));
1810 {TileR, TileM}, EltType, Builder);
1814 {TileM, TileC}, EltType, Builder);
1815 emitMatrixMultiply(Res,
A,
B, Builder,
true,
false,
1816 getFastMathFlags(MatMul));
1818 storeMatrix(Res, CPtr,
Store->getAlign(),
Store->isVolatile(), {R, M},
1825 FusedInsts.
insert(Store);
1826 FusedInsts.
insert(MatMul);
1827 Store->eraseFromParent();
1830 FusedInsts.
insert(LoadOp0);
1833 if (LoadOp1 != LoadOp0 && LoadOp1->
hasNUses(0)) {
1834 FusedInsts.
insert(LoadOp1);
  /// Lowers llvm.matrix.multiply, trying to fuse it with surrounding
  /// instructions (transposed operands, or load/multiply/store chains).
  /// Instructions that are completely eliminated by fusion are added to
  /// \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // Fold a transpose into the operand used to fetch scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output.
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // Add an entry for the folded instruction so that it is included in
        // the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user is lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR. Hoist the address computation if possible.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the result vector;
      // row and column indices are transposed.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve the estimate of operations needed for transposes. Currently
    // we just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
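  // Worked example (illustration only): transposing a column-major 2x3 matrix
  // held as three <2 x double> columns produces two <3 x double> vectors; the
  // element at row 1, column 2 of the input becomes row 2, column 1 of the
  // result, built with extractelement/insertelement pairs as above.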
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Apply the operation one vector at a time.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Apply the operation one vector at a time.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string. Matrix
  /// expressions are linearized by starting at an expression leaf and working
  /// bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. Used to identify matrix
    /// instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions the value is part
    /// of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Tracks sub-expressions that get reused while linearizing the
    /// expression. Reused sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();
      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: the name is followed by the dimensions of the matrix
    /// operands and the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction()) {
        write("<no called fn>");
        return;
      }
      StringRef Name = CI->getCalledFunction()->getName();
      if (!Name.startswith("llvm.matrix")) {
        write(Name);
        return;
      }
      auto *II = cast<IntrinsicInst>(CI);
      write(Intrinsic::getBaseName(II->getIntrinsicID())
                .drop_front(StringRef("llvm.matrix.").size()));
      write(".");
      std::string Tmp;
      raw_string_ostream SS(Tmp);

      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << ".";
        prettyPrintMatrixType(II->getOperand(1), SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_transpose:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_column_major_load:
        prettyPrintMatrixType(II, SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_column_major_store:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << "." << *II->getOperand(0)->getType()->getScalarType();
        break;
      default:
        llvm_unreachable("Unhandled case");
      }
      write(SS.str());
    }

    /// Number of trailing shape arguments to skip when printing operands.
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: pointers are printed as (stack) addresses,
    /// other values as constants, "matrix" or "scalar".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);
      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else if (isMatrix(V))
        TmpStream << "matrix";
      else
        TmpStream << "scalar";
      Tmp = std::string(StringRef(TmpStream.str()).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are reused multiple times are prefixed with (reused)
    /// at the reused root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);
        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };
  /// Generate remarks for matrix operations in a function. Matrix operations
  /// are grouped by the DISubprogram they are contained in (via the inlined-at
  /// chain), expression leaves are collected per subprogram, and a remark with
  /// a linearized version of the expression is emitted for each leaf.
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram: values
    /// returning void or without users in \p ExprsInSubprogram (e.g. stores).
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add
    /// \p Leaf to the set of shared leaves of all sub-expressions.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the exclusive and shared op counts for the expression rooted
    /// at \p Root. Expressions used multiple times are counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      // Map matrix operations to their containing subprograms by traversing
      // the inlined-at chain. Without debug info, map them to the function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};

} // namespace
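// Example of an emitted remark (illustration only, shortened): for a fused
// load/multiply/store chain the output looks roughly like
//   remark: test.c:12:10: Lowered with 6 stores, 6 loads, 24 compute ops,
//   0 exposed transposes
//   store(
//    multiply.2x6.6x2.double(load(addr %A), load(addr %B)),
//    addr %C)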
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
  if (Minimal)
    OS << "minimal";
  OS << '>';
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    return LMT.Visit();
  }

  // ... constructor and getAnalysisUsage elided ...
};

} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
// ... INITIALIZE_PASS_BEGIN/DEPENDENCY/END boilerplate elided ...

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}

namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that need DT, AA or ORE, like tiling, are disabled. Used
/// to lower matrix intrinsics from the backend pipelines.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    return LMT.Visit();
  }

  // ... constructor and getAnalysisUsage elided ...
};

} // namespace

char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
// ... INITIALIZE_PASS boilerplate elided ...

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}
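// Usage sketch (not part of the original source): with the new pass manager
// the lowering can be exercised directly via
//   opt -passes=lower-matrix-intrinsics -S input.ll
// and fusion/tiling behaviour can be steered with the flags defined at the
// top of this file, e.g. -fuse-matrix-tile-size=8 or -force-fuse-matrix.
// The "minimal" variant (no DT/AA/LI/ORE) is what the backend pipelines use.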