using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));
/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

/// Compute the address of the vector (column or row) with index \p VecIdx,
/// starting at \p BasePtr with \p Stride elements between vectors.
static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                                unsigned NumElements, Type *EltType,
                                IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
  // Get a pointer to the start of the selected vector. Skip the GEP if we
  // access vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
  return VecStart;
}
class LowerMatrixIntrinsics {
  /// Counters for the number of operations generated while lowering a matrix
  /// expression.
  struct OpInfoTy {
    unsigned NumStores = 0;
    unsigned NumLoads = 0;
    unsigned NumComputeOps = 0;
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either columns
  /// (column-major layout) or rows (row-major layout).
  struct MatrixTy {
    SmallVector<Value *, 16> Vectors;
    OpInfoTy OpInfo;
    bool IsColumnMajor = true;

    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
      return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
      return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }

    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }
    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the matrix into a flat vector of elements.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }
    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }
    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }
    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }
    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }
    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
  /// Shape information of a matrix operand: the number of rows and columns.
  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;
    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }

    /// Returns the transposed shape.
    ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
  };
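  // Illustration (not from the original source): ShapeInfo(2, 3) describes a
  // 2x3 matrix; .t() yields the 3x2 shape. With the default column-major
  // layout it is stored as 3 vectors of 2 elements, so getStride() == 2 and
  // getNumVectors() == 3.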
  /// Return the fast-math flags to use when lowering \p Inst.
  FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;
    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();
    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
    return FMF;
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT.
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal lowering, executed in the backend pipelines?
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation of
  /// type \p ST and \p N elements.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                    TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
  /// Return the matrix value for \p MatrixVal with shape \p SI, either by
  /// re-using an already lowered matrix or by splitting the flat vector into
  /// a set of column/row vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // If we already lowered MatrixVal and the shapes match, re-use the lowered
    // vectors; otherwise embed the cached matrix into a flat vector first.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;
      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal into stride-sized chunks.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }
    return {SplitVecs};
  }
  /// If \p V is a matrix value, set its shape to \p Shape and return true.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  /// Returns true if \p V is an operation whose result has the same shape as
  /// its (matrix) operands.
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;
    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;
    if (auto *II = dyn_cast<IntrinsicInst>(Inst))
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
  /// Propagate the shape information of instructions to their users
  /// (forward propagation).
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry: compute the shape from the intrinsic arguments or the
      // known shape of an operand.
      bool Propagate = false;

      Value *MatrixA, *MatrixB, *M, *N, *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                 m_Value(), m_Value(), m_Value(), m_Value(M),
                                 m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }
    return NewWorkList;
  }
  /// Propagate the shape to operands of instructions with shape information
  /// (backward propagation).
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };

    // Pop an element with known shape. If the shape of its operands can be
    // derived, set it and add them to the worklist.
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA, *MatrixB, *M, *N, *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this, so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After discovering new shape info, use the users of the newly added
      // instructions as seeds for the next round of forward propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
  /// Create transposes of both operands and call \p Operation on them, so a
  /// transpose can be distributed over a binary operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape propagation; add shapes for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }
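  // For illustration: the rewrites below rely on the identities
  //   (A * B)^T == B^T * A^T   and   (A + B)^T == A^T + B^T,
  // so a transpose of a multiply with shapes (R x K) * (K x C) becomes a
  // multiply of the transposed operands with shapes (C x K) * (K x R).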
  /// Carry \p Old's shape information over to \p New, then replace all uses.
  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeInfo Shape = S->second;
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, Shape});
    }
    Old.replaceAllUsesWith(New);
  }
  // Sinking transposes: given I = t(TA), fold double transposes and splats,
  // and distribute the transpose over multiplies and adds.
  Value *TA, *TAMA, *TAMB;
  ConstantInt *R, *K, *C;
  if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                     m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
    return;

  // t(t(A)) -> A
  Value *TATA;
  if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
    updateShapeAndReplaceAllUsesWith(I, TATA);
    eraseFromParentAndMove(&I, II, BB);
    eraseFromParentAndMove(TA, II, BB);
    return;
  }

  // t(k) -> k, for splats of a scalar.
  if (isSplat(TA)) {
    updateShapeAndReplaceAllUsesWith(I, TA);
    eraseFromParentAndMove(&I, II, BB);
    return;
  }

  // t(A * B) -> t(B) * t(A):  RxK * KxC becomes CxK * KxR.
  if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                    m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                    m_ConstantInt(K), m_ConstantInt(C)))) {
    auto NewInst = distributeTransposes(
        TAMB, {K, C}, TAMA, {R, K}, Builder,
        [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
          return LocalBuilder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                   Shape0.NumColumns,
                                                   Shape1.NumColumns, "mmul");
        });
    updateShapeAndReplaceAllUsesWith(I, NewInst);
    eraseFromParentAndMove(&I, II, BB);
    eraseFromParentAndMove(TA, II, BB);
    return;
  }

  // t(A * k) -> t(A) * k, when one operand is a scalar splat.
  if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB)))) {
    auto NewInst = distributeTransposes(
        TAMA, {R, C}, TAMB, {R, C}, Builder,
        [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
          bool IsFP = I.getType()->isFPOrFPVectorTy();
          auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                           : LocalBuilder.CreateMul(T0, T1, "mmul");
          auto *Result = cast<Instruction>(Mul);
          setShapeInfo(Result, Shape0);
          return Result;
        });
    updateShapeAndReplaceAllUsesWith(I, NewInst);
    eraseFromParentAndMove(&I, II, BB);
    eraseFromParentAndMove(TA, II, BB);
    return;
  }

  // t(A + B) -> t(A) + t(B)
  if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
    auto NewInst = distributeTransposes(
        TAMA, {R, C}, TAMB, {R, C}, Builder,
        [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
          bool IsFP = I.getType()->isFPOrFPVectorTy();
          auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                           : LocalBuilder.CreateAdd(T0, T1, "madd");
          auto *Result = cast<Instruction>(Add);
          setShapeInfo(Result, Shape0);
          return Result;
        });
    updateShapeAndReplaceAllUsesWith(I, NewInst);
    eraseFromParentAndMove(&I, II, BB);
    eraseFromParentAndMove(TA, II, BB);
    return;
  }

  // Lifting transposes out of binary operators; dead operands are erased
  // afterwards.
  auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
    if (T.use_empty())
      T.eraseFromParent();
    if (A->use_empty())
      cast<Instruction>(A)->eraseFromParent();
    if (A != B && B->use_empty())
      cast<Instruction>(B)->eraseFromParent();
  };

  // t(A) * t(B) -> t(B * A)
  Value *A, *B, *AT, *BT;
  if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                    m_Value(A), m_Value(B), m_ConstantInt(R), m_ConstantInt(K),
                    m_ConstantInt(C))) &&
      match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
      match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);
    Value *M = Builder.CreateMatrixMultiply(
        BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
    setShapeInfo(M, {C, R});
    Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
                                                         R->getZExtValue());
    updateShapeAndReplaceAllUsesWith(I, NewInst);
    CleanupBinOp(I, A, B);
  }
  // t(A) + t(B) -> t(A + B)
  else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
           match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                        m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
           match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                        m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
    IRBuilder<> Builder(&I);
    MatrixBuilder MBuilder(Builder);
    Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
    setShapeInfo(Add, {C, R});
    Instruction *NewInst = MBuilder.CreateMatrixTranspose(
        Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
    updateShapeAndReplaceAllUsesWith(I, NewInst);
    CleanupBinOp(I, A, B);
  }
  void optimizeTransposes() {
    // First sink all transposes inside matmuls and adds, processing blocks
    // bottom-up, then lift transposes out of binary operators (see above).
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        // ...
      }
    }
  }
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known. Initialize the
    // work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

    // Third, try to fuse candidates.
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);

    Changed = !FusedInsts.empty();

    // Fourth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1, *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards. Poison any remaining uses of removed
    // instructions and verify they all got cleaned up.
    SmallPtrSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }
  /// Replace matrix intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A; if \p Stride is a
  /// constant, the alignment of the other accesses can be computed as well.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
  /// Load a matrix with \p Shape starting at \p Ptr, using \p Stride between
  /// vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    auto *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          Ptr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");
      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Load a sub-matrix with shape \p ResultShape from a larger matrix with
  /// shape \p MatrixShape, starting at element (\p I, \p J).
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    return loadMatrix(TileTy, TileStart, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
                 bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
  /// Store a sub-matrix \p StoreVal into a larger matrix with shape
  /// \p MatrixShape, starting at element (\p I, \p J).
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    storeMatrix(TileTy, StoreVal, TileStart, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr, using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          Ptr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  /// Insert vector \p Block at index \p I into the larger vector \p Col.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // Build a mask that keeps Col's elements outside [I, I + BlockNumElts) and
    // selects Block's elements inside that range.
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
  /// Emit Sum + A * B, using fmuladd when contraction is allowed.
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for FP operations and let the backend decide if that is
        // profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
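  // Illustration (not from the original source): with -matrix-allow-contract,
  // an accumulation step Sum = Sum + A[i] * B[i] is emitted as a single
  // llvm.fmuladd call instead of separate fmul/fadd, e.g.
  //   %s = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
  //            <4 x float> %b, <4 x float> %sum)
  // which may round differently than the separate operations, as the option
  // description above notes.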
  /// Cache \p Matrix as the result of \p Inst. Users with shape information
  /// will pick up the cached value when they are lowered; other users are
  /// updated to use a flattened vector. Also marks \p Inst for deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Lower a dot product (a 1 x K times K x 1 multiply) to a vector multiply
  /// followed by an add reduction, when that is estimated to be cheaper than
  /// the default lowering.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    // ...
    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;

    // ...
    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();

    // An operand can be used as a flat vector directly if it is a load, a
    // transpose or a column-major load with matching stride:
    auto CanBeFlattened = [this](Value *Op) {
      // ...
      //     m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
      //                 m_Intrinsic<Intrinsic::matrix_column_major_load>(
      //                     ...)));
    };

    // Estimate the cost difference of using a flattened operand vs. the
    // per-column lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (!isa<Instruction>(Op))
        return InstructionCost(0);
      // ...
      if (!CanBeFlattened(Op)) {
        // Embedding the matrix in a flat vector costs N - 1 concatenations.
        for (unsigned I = 1; I < N; ++I)
          // ... add shuffle cost ...
        return NewCost - OriginalCost;
      }
      // ...
      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
        // The transpose becomes free, but its operand still has to be
        // embedded from N vectors.
        for (unsigned I = 1; I < N; ++I)
          // ... add shuffle cost ...
      }
      // ...
    };

    auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);

    // Compare a vector add reduction against the sequential adds of the
    // default lowering.
    int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
    int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
    // ReductionCost: reduction of AddOpCode over the vector, passing
    //   IsIntVec ? std::nullopt : std::optional(FMF), plus the vector multiply.
    // SequentialAddCost: scalar adds * (LShape.NumColumns - 1) +
    //   scalar muls * (LShape.NumColumns).
    if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
      return;

    FusedInsts.insert(MatMul);

    // Use flattened operands where possible, adjusting shape info and
    // replacing column-major loads with plain vector loads.
    auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
                       this](Value *Op) {
      if (!CanBeFlattened(Op))
        return Op;
      // ...
      ShapeMap[Op] = ShapeMap[Op].t();
      // ...
      FusedInsts.insert(cast<Instruction>(Op));
      // ...
      if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                        m_Value(Arg)))) {
        auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
        Op->replaceAllUsesWith(NewLoad);
        cast<Instruction>(Op)->eraseFromParent();
        return cast<Value>(NewLoad);
      } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(Arg)))) {
        // ...
      }
      // ...
    };
    LHS = FlattenArg(LHS);

    // ... emit the element-wise multiply and the add reduction, then erase
    // the original multiply ...
    FusedInsts.insert(MatMul);
  }
  /// Compute \p Result += \p A * \p B for the vector representation of the
  /// operands, processing blocks of at most VF elements at a time.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder,
                          bool IsTiled, bool IsScalarMatrixTransposed,
                          FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand, moving along the K axis and accumulating columns. This way
      // the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K == 0
        // iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand, moving along the K axis and accumulating rows.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
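  // Worked example (illustrative): for a 4x2 times 2x2 multiply with VF >= 4,
  // each result column J is built as
  //   Result[:,J] = A[:,0] * B[0,J] + A[:,1] * B[1,J]
  // i.e. one broadcast (splat) of B[K,J] plus one vector multiply-add per K.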
  /// Return a pointer for \p Load that does not alias \p Store. If the
  /// accesses provably do not overlap, the load's pointer is used directly;
  /// otherwise runtime checks are emitted and, on overlap, the loaded matrix
  /// is copied to a stack slot first.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    // If we can statically prove noalias, just use the pointer.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();
    // ...

    // Otherwise split the block to insert runtime checks:
    // ... SplitBlock(..., nullptr, "alias_cont");
    // ... SplitBlock(..., nullptr, "no_alias");

    // Check whether [StoreBegin, StoreEnd) and [LoadBegin, LoadEnd) overlap.
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    // ...

    // On the copy path, copy the loaded matrix into a stack slot.
    auto *VT = cast<FixedVectorType>(Load->getType());
    // Use an array type to get the alignment right for element accesses.
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    // ... CreateAlloca(ArrayTy, ...) and CreateMemCpy(...) ...

    // Join the three paths with a PHI selecting the pointer to use.
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(Alloca, Copy);
    return PHI;
  }
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);

    // For tiny vectors, fusion does not pay off.
    if (R <= VF && C == 1)
      return false;
    // Pessimistically assume both operands have to be loaded into vector
    // registers at the same time; fuse only if that exceeds the register
    // budget.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }

  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the (Row, Column, K) loop nest iterating in TileSize steps.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    // ...
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
    // ...
    // Create PHI nodes that accumulate the tile result across the K loop.
    MatrixTy TileResult;
    // ... one PHI per result vector, with the zero value incoming from
    //     TI.RowLoop.Header->getSingleSuccessor() ...
    TileResult.addVector(Phi);
    // ...
    // Load tiles of the operands, multiply and accumulate.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store the tile result back and wire up the accumulator PHIs.
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
1752 "Tiling only supported for column-major matrixes at the moment!");
1753 if (!isFusionProfitable(MatMul))
1759 const unsigned R = LShape.NumRows;
1760 const unsigned C = RShape.NumColumns;
1761 const unsigned M = LShape.NumColumns;
1762 auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1764 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1765 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1769 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1772 for (
unsigned J = 0; J <
C; J +=
TileSize)
1774 const unsigned TileR = std::min(R -
I,
unsigned(
TileSize));
1775 const unsigned TileC = std::min(
C - J,
unsigned(
TileSize));
1776 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1779 const unsigned TileM = std::min(M - K,
unsigned(
TileSize));
1783 {TileR, TileM}, EltType, Builder);
1787 {TileM, TileC}, EltType, Builder);
1788 emitMatrixMultiply(Res,
A,
B, Builder,
true,
false,
1789 getFastMathFlags(MatMul));
1791 storeMatrix(Res, CPtr,
Store->getAlign(),
Store->isVolatile(), {R, M},
1798 FusedInsts.
insert(Store);
1799 FusedInsts.
insert(MatMul);
1800 Store->eraseFromParent();
1803 FusedInsts.
insert(LoadOp0);
1806 if (LoadOp1 != LoadOp0 && LoadOp1->
hasNUses(0)) {
1807 FusedInsts.
insert(LoadOp1);
  /// Try to lower a fused matrix multiply: either a multiply where one
  /// operand is a transpose (lowered without materializing the transpose), or
  /// a load/multiply/store chain that can be tiled.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // If one operand is a transpose, lower A * B^t or A^t * B without
    // materializing the transposed operand.
    Value *T;
    bool IsATranspose =
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)));
    if (IsATranspose ||
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA, MB;
      Value *Transpose;
      if (!IsATranspose) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output and emit the multiply, accessing the transposed
      // operand via scalar extracts.
      MatrixTy Result(R, C, EltType);
      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // Add a fake entry for the folded instruction, so that it is included
        // in the remark expression.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains with tiling, hoisting instructions
    // the store address depends on above the multiply if necessary.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise we
      // would create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector by extracting the I-th element of every
      // input vector; row and column indices are swapped.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
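  // Illustration (not from the original source): transposing a column-major
  // 2x3 matrix {c0 = <a,b>, c1 = <c,d>, c2 = <e,f>} produces 2 vectors of 3
  // elements, <a,c,e> and <b,d,f>, built from 6 extractelement and 6
  // insertelement instructions (2 * NumRows * NumColumns compute ops).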
  /// Lower load instructions with shape information.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;
    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;
    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform the operation on one pair of vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform the operation on one vector.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string, starting at
  /// an expression leaf and working bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes; used to identify matrix
    /// instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions the value is part
    /// of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Sub-expressions that get reused while linearizing the expression are
    /// marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();
      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name, printing matrix intrinsics with the
    /// dimensions of their operands.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.starts_with("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: pointers are printed as stack or external
    /// addresses, other values as their constant, "matrix" or "scalar".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);
      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused).
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);
        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() { return Str; }
  };
  /// Generate remarks for matrix operations in a function:
  /// 1. Use the inlined-at debug information to group matrix operations by the
  ///    DISubprograms they are contained in.
  /// 2. Collect the leaves of the matrix expressions for each subprogram.
  /// 3. For each leaf, emit a remark containing a linearized version of the
  ///    matrix expression.
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram: the
    /// expressions returning void or without users inside the set.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add
    /// \p Leaf to the set of expressions \p V is part of.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the exclusive and shared op counts for the expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy Count;
      OpInfoTy SharedCount;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix values to their containing subprograms, by traversing the
      // inlinedAt chain.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }

      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          // ...
          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
  // From LowerMatrixIntrinsicsPass::run():
  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  // From LowerMatrixIntrinsicsPass::printPipeline():
  //   ... printPipeline(OS, MapClassName2PassName);