using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));
/// Return Scope if it is a subprogram, otherwise the subprogram attached to
/// the local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}
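// computeVectorAddr (below) returns the address of vector VecIdx within the
// flat matrix at BasePtr: in elements, vector VecIdx starts at
// VecIdx * Stride. E.g. for a column-major 4x4 matrix of doubles with
// Stride == 4, column 2 starts at element 8, i.e. byte offset 64.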
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get a pointer to the start of the selected vector. Skip the GEP if we
  // select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast the element-wise start pointer to a pointer to the vector type.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Estimates of the number of operations needed to lower a matrix value.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Number of transposes that could not be folded or fused away.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either in row
  /// or column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(UndefValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }
    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
      return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
      return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }

    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 16>::const_iterator> columns() const {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 16>::const_iterator> vectors() const {
      return make_range(Vectors.begin(), Vectors.end());
    }
    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }
    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J).
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;
    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if the shape is defined, meaning both dimensions are
    /// non-zero.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      return IsColumnMajor ? NumRows : NumColumns;
    }

    unsigned getNumVectors() const {
      return IsColumnMajor ? NumColumns : NumRows;
    }
  };
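  // Worked example: a column-major 3 x 4 matrix is stored as 4 column vectors
  // of 3 elements each, so getNumVectors() == 4 and getStride() == 3; in
  // row-major layout the same matrix is 3 row vectors of 4 elements.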
  /// Return the fast-math flags for \p Inst, forcing allow-contract if
  /// -matrix-allow-contract is set.
  FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;
    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();
    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
    return FMF;
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT.
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines?
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation on
  /// \p ST * \p N elements.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(
                                    TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedSize()));
  }
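  // E.g. on a target with 256-bit vector registers, an operation on
  // <8 x double> counts as ceil(8 * 64 / 256) = 2 ops.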
  /// Return the matrix value for \p MatrixVal with shape \p SI, splitting the
  /// flat vector if necessary.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // If MatrixVal was lowered with shape information, return the existing
    // matrix if its shape matches the requested one; otherwise embed it in a
    // flat vector and split it below.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal into vectors of getStride() elements.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
  /// Set the ShapeInfo for \p V, if it supports shape information. Returns
  /// true if the shape was newly set.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;
    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;
    if (auto *II = dyn_cast<IntrinsicInst>(Inst))
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
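  // Shape propagation seeds shapes at the four matrix intrinsics above and
  // spreads them to connected instructions. Illustrative IR:
  //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
  //            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
  // defines a 2 x 2 result, and a later "fadd <4 x double>" user of %c
  // inherits that 2 x 2 shape.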
  /// Propagate the shape information of instructions to their users. Returns
  /// the set of instructions for which shape information was newly set, which
  /// is used for backward propagation.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which the operand shapes are available.
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA, *MatrixB, *M, *N, *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                 m_Value(), m_Value(), m_Value(), m_Value(M),
                                 m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we start the propagation.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands: if their shape
    // derives from the result shape and is unknown, add it and push them on
    // the worklist.
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA, *MatrixB, *M, *N, *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this, so we would just
        // backward-propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // Use the users of instructions with newly discovered shapes as seeds
      // for the next round of forward propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
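  // Forward and backward propagation alternate until a fixpoint (see Visit
  // below): shapes discovered backwards seed new forward candidates and vice
  // versa.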
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
      // Remove Old from the ShapeMap before RAUW, and only insert New if it
      // supports shape information.
      auto S = ShapeMap.find(&Old);
      if (S != ShapeMap.end()) {
        ShapeInfo Shape = S->second;
        ShapeMap.erase(S);
        if (supportsShapeInfo(New))
          ShapeMap.insert({New, Shape});
      }
      Old.replaceAllUsesWith(New);
    };

    // First sink all transposes inside matmuls, hoping that we end up with NN,
    // NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove I. By default continue on the next/prev instruction.
        ++II;
        // If we erase the instruction II points at, advance again.
        auto EraseFromParent = [&II, &BB](Value *V) {
          auto *Inst = cast<Instruction>(V);
          if (Inst->use_empty()) {
            if (II != BB.rend() && Inst == &*II)
              ++II;
            Inst->eraseFromParent();
          }
        };

        Value *NewInst = nullptr;
        IRBuilder<> IB(&I);
        MatrixBuilder<IRBuilder<>> Builder(IB);

        Value *TA, *TAMA, *TAMB, *TATA;
        ConstantInt *R, *K, *C;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
          // Transpose of a transpose is a nop.
          if (match(TA,
                    m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
            ReplaceAllUsesWith(I, TATA);
          }
          // (A * B)^t -> B^t * A^t
          //  RxK KxC      CxK    KxR
          else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                                 m_Value(TAMA), m_Value(TAMB),
                                 m_ConstantInt(R), m_ConstantInt(K),
                                 m_ConstantInt(C)))) {
            Value *T0 = Builder.CreateMatrixTranspose(
                TAMB, K->getZExtValue(), C->getZExtValue(), "t0");
            // We run after shape propagation, so add shape information for
            // the newly created instructions so they get lowered later.
            setShapeInfo(T0, {C, K});
            Value *T1 = Builder.CreateMatrixTranspose(
                TAMA, R->getZExtValue(), K->getZExtValue(), "t1");
            setShapeInfo(T1, {K, R});
            NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
                                                   K->getZExtValue(),
                                                   R->getZExtValue(), "mmul");
            ReplaceAllUsesWith(I, NewInst);
          }

          // Erase the original transpose (and TA, if it became dead).
          EraseFromParent(&I);
          if (isa<Instruction>(TA))
            EraseFromParent(TA);
        }
      }
    }

    // If we have a TT matmul, lift the transpose; it may fold into a
    // consuming multiply.
    for (BasicBlock &BB : Func) {
      for (BasicBlock::iterator II = BB.begin(); II != BB.end();) {
        Instruction &I = *II;
        ++II;
        Value *A, *B, *AT, *BT;
        ConstantInt *R, *K, *C;
        // A^t * B^t -> (B * A)^t
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(A), m_Value(B), m_ConstantInt(R),
                          m_ConstantInt(K), m_ConstantInt(C))) &&
            match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
            match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
          IRBuilder<> IB(&I);
          MatrixBuilder<IRBuilder<>> Builder(IB);
          Value *M = Builder.CreateMatrixMultiply(
              BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
          setShapeInfo(M, {C, R});
          Value *NewInst = Builder.CreateMatrixTranspose(
              M, C->getZExtValue(), R->getZExtValue());
          ReplaceAllUsesWith(I, NewInst);
          if (I.use_empty())
            I.eraseFromParent();
          if (A->use_empty())
            cast<Instruction>(A)->eraseFromParent();
          if (A != B && B->use_empty())
            cast<Instruction>(B)->eraseFromParent();
        }
      }
    }
  }
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known; seed the
    // worklist with them.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        auto *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      LLVM_DEBUG({
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.dump();
      });
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to fuse candidates.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);
    Changed = !FusedInsts.empty();

    // Third, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }
    // Delete the instructions backwards; because we add to ToRemove during
    // fusion we cannot guarantee that defs come before uses, so temporarily
    // change remaining uses to undef and verify they all get removed as well.
    SmallPtrSet<Instruction *, 16> UndefedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
          UndefedInsts.insert(Undefed);
        U.set(UndefValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      UndefedInsts.erase(Inst);
    }
    if (!UndefedInsts.empty()) {
      // If we didn't remove all undefed instructions, it's a hard error.
      dbgs() << "Undefed but present instructions:\n";
      for (auto *I : UndefedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Undefed but instruction not removed");
    }

    return Changed;
  }
  /// Cast \p BasePtr to a pointer to \p EltType in the same address space.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }
  bool VisitCallInst(CallInst *Inst) {
    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits.getFixedSize() / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits.getFixedSize() / 8);
  }
  /// Load a matrix with \p Shape starting at \p Ptr, using \p Stride between
  /// vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    auto *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), IsVolatile,
          "col.load");
      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Load a sub-matrix with shape \p ResultShape from the matrix at
  /// \p MatrixPtr, starting at offset (\p I, \p J).
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    return loadMatrix(TileTy, TilePtr, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
                 bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
  /// Store a sub-matrix \p StoreVal into the matrix at \p MatrixPtr, starting
  /// at offset (\p I, \p J).
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr, using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  /// Set elements I..I+NumElts-1 of \p Col to \p Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and Block is 2 long:
    //   Col = Col[0..1] | Block[0..1] | Col[4..6]
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
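  // Worked example: for a 7-element Col, I == 2 and a 2-element Block, the
  // final mask is [0, 1, 7, 8, 4, 5, 6]; indices >= 7 select from the widened
  // Block operand.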
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide on how to lower it.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
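  // With contraction allowed, each accumulation step emits a single
  // llvm.fmuladd call instead of separate fmul/fadd, which targets can lower
  // to a fused multiply-add; as the -matrix-allow-contract description warns,
  // results may differ by rounding.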
  /// Cache \p Matrix as the result of \p Inst, flatten it for shape-unaware
  /// users, and mark \p Inst for deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
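  // finalizeLowering caches the lowered MatrixTy so later shape-aware users
  // pick it up directly from Inst2ColumnMatrix; only shape-unaware users force
  // the matrix to be flattened back into a single vector.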
  /// Compute \p Result += \p A * \p B with left-associating addition. If
  /// \p IsScalarMatrixTransposed, the operand used to extract scalars is
  /// assumed transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedSize() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0
        // iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand, accumulating along the K axis.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
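  // In effect, for column-major operands, column J of the result is computed
  // as Result[:, J] = sum over K of A[:, K] * B[K, J], using vector-times-
  // scalar multiply-adds, so no horizontal reductions or reassociation are
  // needed.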
  /// Return a pointer for \p Load that is guaranteed not to alias \p Store,
  /// emitting runtime checks and a copy when the locations may overlap.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Otherwise, emit code to check if the memory locations overlap and, if
    // they do, copy Load's operand to a new buffer.
    BasicBlock *Check0 = MatMul->getParent();
    // Collect dominator tree updates manually, to avoid unnecessary work.
    SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT->Delete, Check0, Succ});

    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If not, they are guaranteed not to overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If not,
    // they are guaranteed not to overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy the load operand to a new alloca. Use an array type to avoid
    // potentially huge alignment requirements for large vector types.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
    Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());

    Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(BC, Copy);

    // Adjust the dominator tree for the newly created blocks.
    DTUpdates.push_back({DT->Insert, Check0, Check1});
    DTUpdates.push_back({DT->Insert, Check0, Fusion});
    DTUpdates.push_back({DT->Insert, Check1, Copy});
    DTUpdates.push_back({DT->Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedSize() /
            EltType->getPrimitiveSizeInBits().getFixedSize(),
        1U);

    // For tiling to be beneficial, we need reuse either along the R or the C
    // axis. We vectorize along the R axis, so that means at least 3 elements.
    if (R <= VF && C == 1)
      return false;
    // We also need enough elements to exceed the number of vector registers
    // we have, since fusing also costs some extra loads.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }
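  // Worked example with doubles and 256-bit vectors (VF == 4): for an
  // 8x8 * 8x8 multiply, Op0Regs == Op1Regs == ceil(8 / 4) * 8 == 16, so
  // fusion is considered profitable once 32 exceeds the register count.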
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  /// Lower the multiply \p MatMul as a nest of tiled loops.
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoopHeader->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
                            {TileSize, TileSize}, EltType, Builder);
    MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
                            {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store the result tile after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.CurrentRow, TI.CurrentCol, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();

    if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
    else {
      IRBuilder<> Builder(Store);
      for (unsigned J = 0; J < C; J += TileSize)
        for (unsigned I = 0; I < R; I += TileSize) {
          const unsigned TileR = std::min(R - I, unsigned(TileSize));
          const unsigned TileC = std::min(C - J, unsigned(TileSize));
          MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

          for (unsigned K = 0; K < M; K += TileSize) {
            const unsigned TileM = std::min(M - K, unsigned(TileSize));
            MatrixTy A =
                loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                           LShape, Builder.getInt64(I), Builder.getInt64(K),
                           {TileR, TileM}, EltType, Builder);
            MatrixTy B =
                loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                           RShape, Builder.getInt64(K), Builder.getInt64(J),
                           {TileM, TileC}, EltType, Builder);
            emitMatrixMultiply(Res, A, B, Builder, true, false,
                               getFastMathFlags(MatMul));
          }
          storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(),
                      {R, M}, Builder.getInt64(I), Builder.getInt64(J), EltType,
                      Builder);
        }
    }

    // Mark eliminated instructions as fused and remove them.
    FusedInsts.insert(Store);
    FusedInsts.insert(MatMul);
    Store->eraseFromParent();
    MatMul->eraseFromParent();
    if (LoadOp0->hasNUses(0)) {
      FusedInsts.insert(LoadOp0);
      LoadOp0->eraseFromParent();
    }
    if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
      FusedInsts.insert(LoadOp1);
      LoadOp1->eraseFromParent();
    }
  }
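  // Whether via the loop nest or the unrolled tile code above, each tile of
  // the result is accumulated in registers across the K dimension and written
  // back exactly once, which is the point of the fusion.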
  /// Lowers or fuses \p MatMul: either fold a transposed operand into the
  /// multiply, or fuse a load-multiply-store chain into tiled code.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to fetch
    // scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output.
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // as we have to insert the result into the correct location.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = UndefValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the result vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row inserted into the transposed matrix.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
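  // The transpose is lowered elementwise: every element moves via one
  // extractelement plus one insertelement, which is where the
  // 2 * NumRows * NumColumns compute-op estimate above comes from.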
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform the binary op on the vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform the unary op on the vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string, starting at
  /// an expression leaf and working bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;
    Value *Leaf;
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();
      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getLoadStorePointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
    /// If \p V is a matrix value, print its shape as NumRows x NumColumns.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else
        SS << M->second.getNumRows() << "x" << M->second.getNumColumns();
    }

    /// Write the called function name; for llvm.matrix.* intrinsics, append
    /// the dimensions of the inputs and the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Name.drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }
    /// Return the number of trailing shape arguments of a matrix intrinsic,
    /// which are skipped when printing operands.
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Print pointers as stack or external addresses, constants directly, and
    /// other values as "matrix" or "scalar".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName();
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else if (isMatrix(V))
        TmpStream << "matrix";
      else
        TmpStream << "scalar";
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }
    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are reused multiple times are prefixed with (reused).
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);
        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };
  /// Generate remarks for matrix operations in a function.
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return the leaves of the expressions in \p ExprsInSubprogram: those
    /// returning void or without users inside the set (currently stores).
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Propagate the leaf \p Leaf to all transitive operands of \p V, so they
    /// know they are used by \p Leaf.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }
    /// Calculate the exclusive and shared op counts for the expression rooted
    /// at \p Root. Expressions used multiple times are counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }
    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
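  // A remark emitted above looks roughly like (illustrative):
  //   remark: test.c:12:10: Lowered with 6 stores, 6 loads, 24 compute ops,
  //   0 exposed transposes
  //   store(
  //    multiply(load(addr %A),
  //     load(addr %B)),
  //    addr %C)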
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  auto *AA = &AM.getResult<AAManager>(F);
  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  auto *LI = &AM.getResult<LoopAnalysis>(F);

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit())
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
}
class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    bool C = LMT.Visit();
    return C;
  }
};

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    // The minimal variant runs without AA, DT, LI and ORE, which disables
    // fusion, tiling and remark generation.
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    bool C = LMT.Visit();
    return C;
  }
};

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}