34 #include "llvm/IR/IntrinsicsX86.h"
44 using namespace PatternMatch;
46 #define DEBUG_TYPE "lower-amx-intrinsics"
// Fragment (orig lines 49-52 elided): predicate that returns true iff Ty is
// the flattened AMX tile type <256 x i32> -- a FixedVectorType with exactly
// 256 elements of i32. The enclosing function signature is outside this view.
50 if (
auto *FVT = dyn_cast<FixedVectorType>(Ty))
51 return FVT->getNumElements() == 256 &&
52 FVT->getElementType()->isIntegerTy(32);
// Fragment: description string of the cl::opt flag that gates this pass.
// NOTE(review): "scalarizition" is a typo for "scalarization" in this
// user-visible text; it is a runtime string, so fix it in a code change,
// not a comment-only edit.
59 cl::desc(
"X86: enable AMX scalarizition."));
// Helper class (declaration fragment; several members elided) that rewrites
// X86 AMX tile intrinsics in a Function into plain scalar IR loops, keeping
// the dominator tree and loop info up to date via DTU/LI.
62 class X86LowerAMXIntrinsics {
// Constructor: captures the function to transform plus the caller-supplied
// DomTreeUpdater and LoopInfo (LI may be null -- see the loop builders,
// which guard LI before use).
67 :
Func(
F), DTU(DomTU), LI(LoopI) {}
// Loop-nest builder for tile load/store, parameterized at compile time on
// whether a tile load (true) or tile store (false) is being scalarized.
76 template <
bool IsTileLoad>
// Dot-product loop builder, SFINAE-restricted to exactly the five AMX
// tile dot-product intrinsics (ssd/sud/usd/uud and bf16ps).
80 template <Intrinsic::ID IntrID>
81 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
82 IntrID == Intrinsic::x86_tdpbsud_internal ||
83 IntrID == Intrinsic::x86_tdpbusd_internal ||
84 IntrID == Intrinsic::x86_tdpbuud_internal ||
85 IntrID == Intrinsic::x86_tdpbf16ps_internal,
// Per-instruction entry points: each lowers one intrinsic call and
// returns whether the IR changed.
90 template <
bool IsTileLoad>
91 bool lowerTileLoadStore(
Instruction *TileLoadStore);
// lowerTileDP carries the same enable_if constraint as the loop builder
// above so the two templates stay instantiable for the same IDs.
92 template <Intrinsic::ID IntrID>
93 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
94 IntrID == Intrinsic::x86_tdpbsud_internal ||
95 IntrID == Intrinsic::x86_tdpbusd_internal ||
96 IntrID == Intrinsic::x86_tdpbuud_internal ||
97 IntrID == Intrinsic::x86_tdpbf16ps_internal,
// Fragment of the single-loop construction helper (presumably createLoop --
// signature not visible here): positions the builder at the loop latch,
// wires the induction variable's back-edge increment into its PHI, and
// notifies the DomTreeUpdater of the new CFG edges. applyUpdatesPermissive
// tolerates duplicate/no-op updates, which is convenient while blocks are
// still being stitched together.
123 B.SetInsertPoint(Latch);
127 IV->addIncoming(Inc, Latch);
132 DTU.applyUpdatesPermissive({
// Builds a Row x Col doubly nested loop between Start and End that moves one
// i32 element of the tile per innermost iteration. For a tile load the loops
// accumulate loaded elements into a <256 x i32> vector PHI; for a tile store
// they extract elements from the vector the tile was bitcast from and store
// them. (Fragment: several lines elided, including the parameter list and
// the final return.)
148 template <
bool IsTileLoad>
149 Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
// Name prefix used for the generated scalarization blocks.
152 std::string IntrinName = IsTileLoad ?
"tileload" :
"tilestore";
153 Loop *RowLoop =
nullptr;
154 Loop *ColLoop =
nullptr;
// Register the two new loops with LoopInfo, nesting under an existing
// parent loop when Start is already inside one.
156 RowLoop = LI->AllocateLoop();
157 ColLoop = LI->AllocateLoop();
159 if (
Loop *ParentL = LI->getLoopFor(Start))
160 ParentL->addChildLoop(RowLoop);
162 LI->addTopLevelLoop(RowLoop);
// Outer loop over rows, inner loop over columns; both step by 1 (i16).
165 BasicBlock *RowBody = createLoop(Start, End, Row,
B.getInt16(1),
166 IntrinName +
".scalarize.rows",
B, RowLoop);
169 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col,
B.getInt16(1),
170 IntrinName +
".scalarize.cols",
B, ColLoop);
// Elements are moved as i32 (one dword at a time).
177 Type *EltTy =
B.getInt32Ty();
// Memory offset = CurrentRow * Stride + CurrentCol, computed in Stride's
// (wider) type, hence the zexts of the i16 induction variables.
184 Value *CurrentRowZExt =
B.CreateZExt(CurrentRow, Stride->
getType());
185 Value *CurrentColZExt =
B.CreateZExt(CurrentCol, Stride->
getType());
187 B.CreateAdd(
B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
// Preserve the pointer's address space when forming the element pointer.
188 unsigned AS = cast<PointerType>(Ptr->
getType())->getAddressSpace();
190 Value *EltPtr =
B.CreateGEP(EltTy, EltBasePtr, Offset);
// Index into the flat <256 x i32> tile vector: row-major with a fixed
// 16-dword row pitch (Idx = row * 16 + col).
191 Value *Idx =
B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentCol);
// Load path: thread the partially-filled result vector through PHIs in
// the row and column headers.
198 PHINode *VecCPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.phi.row");
205 PHINode *VecPhi =
B.CreatePHI(V256I32Ty, 2,
"vec.phi");
214 Value *Elt =
B.CreateLoad(EltTy, EltPtr);
215 Value *ResVec =
B.CreateInsertElement(VecPhi, Elt, Idx);
// Store path: the tile operand is expected to be a bitcast of a plain
// vector; peel the bitcast to get the source vector to read from.
221 auto *BitCast = cast<BitCastInst>(Tile);
222 Value *Vec = BitCast->getOperand(0);
230 Value *Elt =
B.CreateExtractElement(Vec, Idx);
232 B.CreateStore(Elt, EltPtr);
// Builds the triple loop nest (rows x cols x inner K) that scalarizes one AMX
// tile dot-product intrinsic. Instantiable only for the five tdp* intrinsics
// in the enable_if list. Integer variants widen <4 x i8> sub-vectors per the
// intrinsic's signedness and accumulate via mul + integer add-reduce; the
// bf16 variant reassembles f32 lanes via shuffles and uses fmul + ordered
// fadd-reduce. (Fragment: the signature line, parameter list and return are
// among the elided lines.)
237 template <Intrinsic::ID IntrID>
238 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
239 IntrID == Intrinsic::x86_tdpbsud_internal ||
240 IntrID == Intrinsic::x86_tdpbusd_internal ||
241 IntrID == Intrinsic::x86_tdpbuud_internal ||
242 IntrID == Intrinsic::x86_tdpbf16ps_internal,
// Block-name prefix chosen per intrinsic so the generated CFG is readable.
248 std::string IntrinName;
250 case Intrinsic::x86_tdpbssd_internal:
251 IntrinName =
"tiledpbssd";
253 case Intrinsic::x86_tdpbsud_internal:
254 IntrinName =
"tiledpbsud";
256 case Intrinsic::x86_tdpbusd_internal:
257 IntrinName =
"tiledpbusd";
259 case Intrinsic::x86_tdpbuud_internal:
260 IntrinName =
"tiledpbuud";
262 case Intrinsic::x86_tdpbf16ps_internal:
263 IntrinName =
"tiledpbf16ps";
// Three LoopInfo loops: rows (outer), cols (middle), inner K reduction.
266 Loop *RowLoop =
nullptr;
267 Loop *ColLoop =
nullptr;
268 Loop *InnerLoop =
nullptr;
270 RowLoop = LI->AllocateLoop();
271 ColLoop = LI->AllocateLoop();
272 InnerLoop = LI->AllocateLoop();
// Nest under an enclosing loop if Start already lives in one.
275 if (
Loop *ParentL = LI->getLoopFor(Start))
276 ParentL->addChildLoop(RowLoop);
278 LI->addTopLevelLoop(RowLoop);
// All three loops step by 1 (i16 induction variables).
281 BasicBlock *RowBody = createLoop(Start, End, Row,
B.getInt16(1),
282 IntrinName +
".scalarize.rows",
B, RowLoop);
285 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col,
B.getInt16(1),
286 IntrinName +
".scalarize.cols",
B, ColLoop);
292 createLoop(ColBody, ColLoopLatch, K,
B.getInt16(1),
293 IntrinName +
".scalarize.inner",
B, InnerLoop);
// The inner loop's induction PHI is the first instruction of its header.
301 Value *CurrentInner = &*InnerLoopHeader->
begin();
// The three tile operands (accumulator C, LHS A, RHS B) are expected to
// be bitcasts of flat vectors; peel each bitcast to get the vector.
304 auto *BitCastAcc = cast<BitCastInst>(Acc);
305 Value *VecC = BitCastAcc->getOperand(0);
310 auto *BitCastLHS = cast<BitCastInst>(
LHS);
311 Value *VecA = BitCastLHS->getOperand(0);
313 auto *BitCastRHS = cast<BitCastInst>(
RHS);
314 Value *VecB = BitCastRHS->getOperand(0);
// Row-header PHIs: C is the running accumulator, D the result being built.
324 PHINode *VecCPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.c.phi.row");
327 PHINode *VecDPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.d.phi.row");
// Column-header PHIs, seeded from the row-body values on loop entry.
341 PHINode *VecCPhiColLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.c.phi.col");
342 VecCPhiColLoop->
addIncoming(VecCPhiRowLoop, RowBody);
343 PHINode *VecDPhiColLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.d.phi.col");
344 VecDPhiColLoop->
addIncoming(VecDPhiRowLoop, RowBody);
// IdxC = row * 16 + col: element of C/D touched by this (row, col) pair
// (16-dword row pitch in the flat <256 x i32> layout).
346 B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentCol);
// Inner-loop PHI carrying the accumulator across K iterations.
354 PHINode *VecCPhi =
B.CreatePHI(V256I32Ty, 2,
"vec.c.inner.phi");
// IdxA = row * 16 + k; IdxB = k * 16 + col (same 16-dword pitch).
359 B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentInner);
361 B.CreateAdd(
B.CreateMul(CurrentInner,
B.getInt16(16)), CurrentCol);
362 Value *NewVecC =
nullptr;
// Integer dot-product path (all variants except bf16ps).
364 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
// Pull one i32 from each operand; reinterpret A/B dwords as 4 bytes.
381 Value *EltC =
B.CreateExtractElement(VecCPhi, IdxC);
382 Value *EltA =
B.CreateExtractElement(VecA, IdxA);
383 Value *SubVecA =
B.CreateBitCast(EltA, V4I8Ty);
384 Value *EltB =
B.CreateExtractElement(VecB, IdxB);
385 Value *SubVecB =
B.CreateBitCast(EltB, V4I8Ty);
386 Value *SEXTSubVecB =
nullptr;
387 Value *SEXTSubVecA =
nullptr;
// Widen to <4 x i32> with signedness per the intrinsic's name:
// ssd = signed*signed, sud = signed*unsigned, usd = unsigned*signed,
// uud = unsigned*unsigned (A is the first letter, B the second).
389 case Intrinsic::x86_tdpbssd_internal:
390 SEXTSubVecB =
B.CreateSExt(SubVecB, V4I32Ty);
391 SEXTSubVecA =
B.CreateSExt(SubVecA, V4I32Ty);
393 case Intrinsic::x86_tdpbsud_internal:
394 SEXTSubVecB =
B.CreateZExt(SubVecB, V4I32Ty);
395 SEXTSubVecA =
B.CreateSExt(SubVecA, V4I32Ty);
397 case Intrinsic::x86_tdpbusd_internal:
398 SEXTSubVecB =
B.CreateSExt(SubVecB, V4I32Ty);
399 SEXTSubVecA =
B.CreateZExt(SubVecA, V4I32Ty);
401 case Intrinsic::x86_tdpbuud_internal:
402 SEXTSubVecB =
B.CreateZExt(SubVecB, V4I32Ty);
403 SEXTSubVecA =
B.CreateZExt(SubVecA, V4I32Ty);
// Lane-wise multiply, horizontal add, then accumulate into C's element.
408 Value *SubVecR =
B.CreateAddReduce(
B.CreateMul(SEXTSubVecA, SEXTSubVecB));
409 Value *ResElt =
B.CreateAdd(EltC, SubVecR);
410 NewVecC =
B.CreateInsertElement(VecCPhi, ResElt, IdxC);
// bf16 path: the C element is an f32 stored in an i32 lane; A/B dwords
// each hold two bf16 values.
436 Value *EltC =
B.CreateExtractElement(VecCPhi, IdxC);
437 Value *EltCF32 =
B.CreateBitCast(EltC,
B.getFloatTy());
438 Value *EltA =
B.CreateExtractElement(VecA, IdxA);
439 Value *SubVecA =
B.CreateBitCast(EltA, V2I16Ty);
440 Value *EltB =
B.CreateExtractElement(VecB, IdxB);
441 Value *SubVecB =
B.CreateBitCast(EltB, V2I16Ty);
// Interleave each bf16 lane with a zero i16 so that, after bitcast to
// <2 x float>, every bf16 lands in the high half of an f32.
443 int ShuffleMask[4] = {2, 0, 3, 1};
445 Value *AV2F32 =
B.CreateBitCast(
446 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
447 Value *BV2F32 =
B.CreateBitCast(
448 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
// fmul pairwise, then ordered fadd-reduce seeded with the current C value.
449 Value *SubVecR =
B.CreateFAddReduce(EltCF32,
B.CreateFMul(AV2F32, BV2F32));
// Store the f32 result back into its i32 lane.
450 Value *ResElt =
B.CreateBitCast(SubVecR,
B.getInt32Ty());
451 NewVecC =
B.CreateInsertElement(VecCPhi, ResElt, IdxC);
// After the K loop, copy the finished C element into the D result vector.
459 Value *NewEltC =
B.CreateExtractElement(NewVecC, IdxC);
460 Value *NewVecD =
B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
// Close the column-loop PHIs with the latch values.
464 VecCPhiColLoop->
addIncoming(NewVecC, ColLoopLatch);
466 VecDPhiColLoop->
addIncoming(NewVecD, ColLoopLatch);
// Lowers one AMX tile dot-product intrinsic call: splits the block around the
// call, converts the byte dimensions to dword counts, emits the scalarization
// loop nest via createTileDPLoops, and replaces/erases the original call.
// (Fragment: operand matching and the return are among the elided lines.)
471 template <Intrinsic::ID IntrID>
472 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
473 IntrID == Intrinsic::x86_tdpbsud_internal ||
474 IntrID == Intrinsic::x86_tdpbusd_internal ||
475 IntrID == Intrinsic::x86_tdpbuud_internal ||
476 IntrID == Intrinsic::x86_tdpbf16ps_internal,
478 X86LowerAMXIntrinsics::lowerTileDP(
Instruction *TileDP) {
// Pre-loop setup is inserted right before the intrinsic call.
484 PreBuilder.SetInsertPoint(TileDP);
// N and K arrive in bytes; lshr by 2 converts them to i32 (dword) counts.
488 Value *NDWord = PreBuilder.CreateLShr(
N, PreBuilder.getInt16(2));
489 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
494 Value *ResVec = createTileDPLoops<IntrID>(Start, End,
Builder, M, NDWord,
// Continue after the loop nest, past any PHIs at the top of End.
498 Builder.SetInsertPoint(End->getFirstNonPHI());
// Replace every use of the intrinsic's tile result with the built vector,
// then remove the original call.
506 I->replaceAllUsesWith(ResVec);
507 I->eraseFromParent();
// Lowers one tileloadd64/tilestored64 intrinsic: pattern-matches the call's
// operands (M, N, pointer, stride, and -- for stores -- the tile value),
// converts byte quantities to dword counts, then emits the row/col loop nest
// via createTileLoadStoreLoops. (Fragment: block-splitting code and the
// return are among the elided lines.)
515 template <
bool IsTileLoad>
516 bool X86LowerAMXIntrinsics::lowerTileLoadStore(
Instruction *TileLoadStore) {
517 Value *
M, *
N, *Ptr, *Stride, *Tile;
// Match either the load or the store intrinsic form depending on
// IsTileLoad; stores additionally bind the Tile operand.
520 m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
523 match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
529 PreBuilder.SetInsertPoint(TileLoadStore);
// N (i16) and Stride (i64) are byte quantities; lshr by 2 converts both
// to dword units to match the i32 element size used by the loops.
530 Value *NDWord = PreBuilder.CreateLShr(
N, PreBuilder.getInt16(2));
531 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
// Loads produce a result vector and pass no tile; stores pass Tile and
// the loop builder reads from it.
536 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
537 Start, End,
Builder, M, NDWord, Ptr, StrideDWord,
538 IsTileLoad ?
nullptr : Tile);
542 Builder.SetInsertPoint(End->getFirstNonPHI());
// Swap the intrinsic's uses over to the scalarized result and delete it.
550 I->replaceAllUsesWith(ResVec);
551 I->eraseFromParent();
// Lowers a tilezero intrinsic by replacing all uses of its tile result with a
// zero vector (VecZero -- built on elided lines) and erasing the call.
560 bool X86LowerAMXIntrinsics::lowerTileZero(
Instruction *TileZero) {
568 I->replaceAllUsesWith(VecZero);
569 I->eraseFromParent();
// Driver: first collect every supported AMX intrinsic call in the function
// into WorkList, then dispatch each to its lowering routine. C accumulates
// whether any instruction changed ("lower(...) || C" keeps C sticky once
// true). Collecting first avoids iterating the function while the lowering
// routines split blocks and erase instructions.
576 bool X86LowerAMXIntrinsics::visit() {
// Phase 1: scan for the eight AMX intrinsics this pass handles.
581 if (
auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
582 switch (Inst->getIntrinsicID()) {
583 case Intrinsic::x86_tdpbssd_internal:
584 case Intrinsic::x86_tdpbsud_internal:
585 case Intrinsic::x86_tdpbusd_internal:
586 case Intrinsic::x86_tdpbuud_internal:
587 case Intrinsic::x86_tileloadd64_internal:
588 case Intrinsic::x86_tilestored64_internal:
589 case Intrinsic::x86_tilezero_internal:
590 case Intrinsic::x86_tdpbf16ps_internal:
591 WorkList.push_back(Inst);
// Phase 2: lower each collected call; the template argument selects the
// signedness/element behavior inside the shared lowering code.
600 for (
auto *Inst : WorkList) {
601 switch (Inst->getIntrinsicID()) {
602 case Intrinsic::x86_tdpbssd_internal:
603 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) ||
C;
605 case Intrinsic::x86_tdpbsud_internal:
606 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) ||
C;
608 case Intrinsic::x86_tdpbusd_internal:
609 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) ||
C;
611 case Intrinsic::x86_tdpbuud_internal:
612 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) ||
C;
614 case Intrinsic::x86_tdpbf16ps_internal:
615 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) ||
C;
617 case Intrinsic::x86_tileloadd64_internal:
618 C = lowerTileLoadStore<true>(Inst) ||
C;
620 case Intrinsic::x86_tilestored64_internal:
621 C = lowerTileLoadStore<false>(Inst) ||
C;
623 case Intrinsic::x86_tilezero_internal:
624 C = lowerTileZero(Inst) ||
C;
// Legacy pass-manager wrapper: fetches optional DominatorTree/LoopInfo
// analyses (either may be absent -- hence getAnalysisIfAvailable and the
// nullptr fallbacks), constructs the X86LowerAMXIntrinsics helper, and runs
// it. (Fragment: runOnFunction signature, the skip conditions around line
// 648, and the class footer are among the elided lines.)
635 class X86LowerAMXIntrinsicsLegacyPass :
public FunctionPass {
// Skip functions marked optnone (plus further conditions on elided lines).
648 if (!
F.hasFnAttribute(Attribute::OptimizeNone) &&
652 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
653 auto *DT = DTWP ? &DTWP->getDomTree() :
nullptr;
654 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
655 auto *LI = LIWP ? &LIWP->getLoopInfo() :
nullptr;
// DTU (constructed on an elided line) wraps DT for incremental updates.
658 X86LowerAMXIntrinsics LAT(
F, DTU, LI);
661 StringRef getPassName()
const override {
return "Lower AMX intrinsics"; }
// Pass-registration name string and (fragment of) the factory function that
// creates the legacy pass; the factory's signature is on elided lines.
671 static const char PassName[] =
"Lower AMX intrinsics";
680 return new X86LowerAMXIntrinsicsLegacyPass();