#include "llvm/IR/IntrinsicsX86.h"

using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"
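// Returns true if the instruction is one of the AMX cast intrinsics emitted
// by the front end, i.e. llvm.x86.cast.vector.to.tile or
// llvm.x86.cast.tile.to.vector.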
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
auto *II = dyn_cast<IntrinsicInst>(I);

if (II->getType()->isX86_AMXTy())

for (Value *V : II->args()) {
  if (V->getType()->isX86_AMXTy())
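// Helper that creates an alloca in the function's entry block, so the stack
// slot dominates every use in the function.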
unsigned AllocaAS = DL.getAllocaAddrSpace();

new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());

if (!isa<AllocaInst>(&I))
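// getShape: derive the (Row, Col) shape values for the tile used at operand
// OpNo of an AMX intrinsic. For the tile load/store intrinsics the shape is
// simply operands 0 and 1; for the tdpb* dot-product intrinsics the shape of
// each tile operand is derived from the intrinsic's dimension operands, so
// the result depends on which operand is asked about.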
Value *Row = nullptr, *Col = nullptr;

case Intrinsic::x86_tileloadd64_internal:
case Intrinsic::x86_tileloaddt164_internal:
case Intrinsic::x86_tilestored64_internal: {

case Intrinsic::x86_tdpbssd_internal:
case Intrinsic::x86_tdpbsud_internal:
case Intrinsic::x86_tdpbusd_internal:
case Intrinsic::x86_tdpbuud_internal:
case Intrinsic::x86_tdpbf16ps_internal: {

    (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);

cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));

return std::make_pair(Row, Col);
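// Overload of getShape for a PHI node: follow the (first) user chain through
// AMX casts and PHIs until an AMX intrinsic is reached whose shape operands
// can be reused; return {nullptr, nullptr} if none is found.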
if (isAMXCast(dyn_cast<Instruction>(V))) {

  return getShape(cast<IntrinsicInst>(V), OpNo);
} else if (isa<PHINode>(V)) {

return std::make_pair(nullptr, nullptr);
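// X86LowerAMXType rewrites bitcasts between plain vectors and the opaque
// x86_amx type into tile load/store intrinsics, so tile data always flows
// through llvm.x86.tileloadd64.internal / llvm.x86.tilestored64.internal.
// Roughly (a sketch, with %row/%col standing for the tile shape):
//   %amx = bitcast <256 x i32> %vec to x86_amx
//   -->
//   %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                      i8* %buf, i64 %stride)
// where %buf is a stack slot holding %vec when the cast cannot be folded into
// an adjacent load or store.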
class X86LowerAMXType {

  std::map<Value *, Value *> Col2Row;
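// combineLoadBitcast: when a load feeds a bitcast to x86_amx, replace the
// pair with a single tileloadd64 from the load's pointer, using the shape
// obtained from the AMX intrinsic that consumes the cast.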
Value *Row = nullptr, *Col = nullptr;

auto *II = cast<IntrinsicInst>(U.getUser());
std::tie(Row, Col) = getShape(II, OpNo);

std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

Value *NewInst =
    Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
Bitcast->replaceAllUsesWith(NewInst);
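// combineBitcastStore: when a bitcast from x86_amx to a vector feeds a store,
// replace the pair with a single tilestored64 to the store's pointer; the
// shape comes from the first two operands of the tile-defining intrinsic.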
auto *II = cast<IntrinsicInst>(Tile);

Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);

std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);

Bitcast->replaceAllUsesWith(Vec);
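// transformBitcast: fallback for bitcasts that cannot be folded into an
// adjacent load or store. The value is spilled to a stack slot and picked up
// on the other side of the cast via tileloadd64 / tilestored64.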
Value *I8Ptr, *Stride;
auto *Src = Bitcast->getOperand(0);

auto Prepare = [&](Type *MemTy) {

  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());

if (Bitcast->getType()->isX86_AMXTy()) {

  auto *II = dyn_cast<IntrinsicInst>(U.getUser());

  Prepare(Bitcast->getOperand(0)->getType());
  Builder.CreateStore(Src, AllocaAddr);

  Value *Row = nullptr, *Col = nullptr;
  std::tie(Row, Col) = getShape(II, OpNo);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(
      Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);

  auto *II = dyn_cast<IntrinsicInst>(Src);

  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);

  Bitcast->replaceAllUsesWith(NewInst);
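// visit() walks every instruction in the function, dispatches each
// vector<->x86_amx bitcast to one of the combine/transform helpers above,
// and erases the instructions that became dead afterwards.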
bool X86LowerAMXType::visit() {

  auto *Bitcast = dyn_cast<BitCastInst>(&Inst);

  if (Bitcast->getType()->isX86_AMXTy()) {

    DeadInsts.push_back(LD);
  } else if (Src->getType()->isX86_AMXTy()) {

    DeadInsts.push_back(ST);

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();
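// Helpers for the volatile tile rewrite below: reserve a <256 x i32> stack
// slot in the entry block, spill each tile definition to it with
// tilestored64, and reload it with tileloadd64 right before each use.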
unsigned AllocaAS = DL.getAllocaAddrSpace();

new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());

Builder.SetInsertPoint(&*Iter);

auto *II = cast<IntrinsicInst>(TileDef);
assert(II && "Not tile intrinsic!");
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);

std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);

Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
II = cast<IntrinsicInst>(PhiOp);

II = cast<IntrinsicInst>(V);

std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);

for (Use &U : I->uses()) {
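// X86VolatileTileData makes tile data "volatile": every tile def is spilled
// to a stack slot right after it is produced and every use reloads it just
// before the user. In particular, tile PHIs are replaced by a store in each
// predecessor plus a load at the PHI's position, so no x86_amx value stays
// live across basic blocks; this is what fast register allocation needs.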
class X86VolatileTileData {

  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);

Value *X86VolatileTileData::updatePhiIncomings(

  for (auto *I : Incomings) {

    for (Use &U : I->uses()) {

      if (isa<PHINode>(V) || V == Store)

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {

  assert(Inst && "We shouldn't fold AMX instruction!");
  Incomings.push_back(Inst);

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
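// volatileTileNonPHI handles tile defs that are not PHI nodes: the def is
// stored to a stack slot and every (non-PHI) use is rewritten to reload the
// tile from that slot.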
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {

  for (Use &U : I->uses()) {

    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");

bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;

    if (!I.getType()->isX86_AMXTy())

    if (isa<PHINode>(&I))
      PHIInsts.push_back(&I);

      AMXDefInsts.push_back(&I);

    volatileTileNonPHI(I);

    volatileTilePHI(dyn_cast<PHINode>(I));
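// X86LowerAMXCast lowers the llvm.x86.cast.vector.to.tile /
// llvm.x86.cast.tile.to.vector intrinsics emitted by the front end. Pairs of
// opposite casts are folded away, casts adjacent to loads/stores become tile
// load/store intrinsics, casts feeding PHIs are optimized, and any cast left
// over is lowered through a stack slot.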
class X86LowerAMXCast {

  bool transformAllAMXCast();

  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {

    I->setOperand(i, nullptr);

    if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {

  I->eraseFromParent();
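// optimizeAMXCastFromPhi rewrites a PHI of vector values that is consumed
// through cast intrinsics so the PHI operates on x86_amx directly: new
// tile-typed PHIs are created, the incoming casts are bypassed, and
// undef/zero incoming values are materialized with tilezero. A sketch of the
// idea (names illustrative, not taken from the source):
//   %v = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %v)
//   -->
//   %t = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]
// where %t0/%t1 are the tile values that were being cast into %v0/%v1.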
bool X86LowerAMXCast::optimizeAMXCastFromPhi(

  Type *SrcTy = Src->getType();

  PhiWorklist.push_back(PN);

  while (!PhiWorklist.empty()) {

    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);

      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())

        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);

        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))

        auto *Block = OldPN->getIncomingBlock(I);

        Value *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, None, {Row, Col});
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});

        OldPN->setIncomingValue(I, NewInst);

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);

      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (TyA != DestTy || TyB != SrcTy)

  for (auto *OldPN : OldPhiNodes) {

      if (TyA != DestTy || TyB != SrcTy)
    } else if (auto *PHI = dyn_cast<PHINode>(V)) {

      if (OldPhiNodes.count(PHI) == 0)

  for (auto *OldPN : OldPhiNodes) {

    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;

  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {

      Value *NewV = nullptr;

      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];

  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];

      assert(TyA == DestTy && TyB == SrcTy);
    } else if (auto *PHI = dyn_cast<PHINode>(V)) {

      assert(OldPhiNodes.contains(PHI));
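// combineCastStore folds a cast_tile_to_vector whose result only feeds a
// store into a single tilestored64 to the store's pointer. Roughly:
//   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %t)
//   store <256 x i32> %v, <256 x i32>* %p
//   -->
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                             i64 %stride, x86_amx %t)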
auto *II = cast<IntrinsicInst>(Tile);

Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);

std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
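// combineLoadCast is the mirror image: a load whose result is consumed by
// cast_vector_to_tile becomes a single tileloadd64 from the load's pointer,
// with the shape taken from the AMX intrinsic that uses the tile.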
Value *Row = nullptr, *Col = nullptr;

auto *II = cast<IntrinsicInst>(U.getUser());

std::tie(Row, Col) = getShape(II, OpNo);

std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
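// combineLdSt walks the surviving cast instructions and tries the two
// combines above on each one, erasing the stores and loads that became dead.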
for (auto *Cast : Casts) {
  auto *II = cast<IntrinsicInst>(Cast);

  if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {

      combineCastStore(cast<IntrinsicInst>(Cast), Store);
      DeadStores.push_back(Store);

    for (auto *Store : DeadStores)
      Store->eraseFromParent();

      combineLoadCast(cast<IntrinsicInst>(Cast), Load);

      Load->eraseFromParent();
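// combineAMXcast is the driver: it collects vector->tile and tile->vector
// casts, cancels matching pairs, folds the remaining casts into loads/stores
// via combineLdSt, and finally optimizes casts whose operand is a PHI node.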
bool Change = false;

      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);

    for (auto *Inst : Insts) {

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();

        LiveCasts.push_back(Inst);

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);

      if (isa<PHINode>(I.getOperand(0)))
        PhiCastWorkList.push_back(&I);

  for (auto *I : PhiCastWorkList) {

    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {

  while (!DeadInst.empty()) {
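// transformAMXCast lowers a cast that survived all the combines above by
// round-tripping the value through a stack slot: the vector side is handled
// with an ordinary load/store of the alloca, the tile side with a
// tileloadd64/tilestored64 on the same memory.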
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {

  Value *I8Ptr, *Stride;

  auto Prepare = [&](Type *MemTy) {

    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);

    auto *II = dyn_cast<IntrinsicInst>(U.getUser());

    Builder.CreateStore(Src, AllocaAddr);

    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};

    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);

    auto *II = dyn_cast<IntrinsicInst>(Src);

    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
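// transformAllAMXCast collects every remaining AMX cast in the function and
// lowers it through transformAMXCast.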
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;

      WorkLists.push_back(&I);

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
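// Legacy pass wrapper: run the cast combining first, then the bitcast
// lowering, and, under the conditions checked below, the volatile-tile
// rewrite.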
class X86LowerAMXTypeLegacyPass : public FunctionPass {

    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);

    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);

      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;

static const char PassName[] = "Lower AMX type for load/store";

  return new X86LowerAMXTypeLegacyPass();