#include "llvm/IR/IntrinsicsX86.h"
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "lower-amx-intrinsics"
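// An AMX tile is modeled here as a <256 x i32> vector: 16 rows of 16 32-bit
// elements, i.e. the maximum tile shape of 16 x 64 bytes.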
static bool isV256I32Ty(Type *Ty) {
  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
    return FVT->getNumElements() == 256 &&
           FVT->getElementType()->isIntegerTy(32);
  return false;
}
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarization."));
class X86LowerAMXIntrinsics {
  Function &Func;
  DomTreeUpdater &DTU;
  LoopInfo *LI;
public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   Value *>
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   bool>
  lowerTileDP(Instruction *TileDP);
  bool lowerTileZero(Instruction *TileZero);
  bool visit();
};
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  IV->addIncoming(Inc, Latch);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
  if (L) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}
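// Emit a two-deep loop nest (rows x cols) that scalarizes a tile load or
// store, transferring one 32-bit element per innermost iteration.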
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }
  BasicBlock *RowBody =
      createLoop(Start, End, Row, B.getInt16(1),
                 IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody =
      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                 IntrinName + ".scalarize.cols", B, ColLoop);
  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
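  // In the innermost (column) body, compute the element's memory offset
  // (row * stride + col) and its lane index in the <256 x i32> tile vector
  // (row * 16 + col).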
  Value *CurrentRow = &*RowBody->getSinglePredecessor()->begin();
  Value *CurrentCol = &*ColBody->getSinglePredecessor()->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset);
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // Build the loaded tile incrementally: PHIs thread the partially-filled
    // <256 x i32> vector through the row and column loops.
    B.SetInsertPoint(RowBody->getSinglePredecessor()->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    B.SetInsertPoint(ColBody->getSinglePredecessor()->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // Load one element and insert it at lane (row * 16 + col).
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) &&
           "bitcast from non-v256i32 to x86amx");
    // Extract lane (row * 16 + col) from the tile vector and store it.
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);
    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}
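// Emit a three-deep loop nest (rows x cols x K) that computes one tile
// dot-product element per innermost iteration, accumulating into the C tile.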
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 Value *>
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }
  BasicBlock *RowBody =
      createLoop(Start, End, Row, B.getInt16(1),
                 IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody =
      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                 IntrinName + ".scalarize.cols", B, ColLoop);
  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);
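  // Recover the induction variables of the three loops and the <256 x i32>
  // vectors behind the acc/lhs/rhs tile operands (each arrives as a bitcast
  // to x86_amx).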
  BasicBlock *RowLoopHeader = RowLatch->getSinglePredecessor();
  BasicBlock *ColLoopHeader = ColLoopLatch->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  BasicBlock *InnerLoopHeader = InnerLoopLatch->getSinglePredecessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;
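  // Two element strategies: the integer tdpb*d intrinsics treat each i32
  // lane as <4 x i8>; tdpbf16ps treats each i32 lane as <2 x bf16>.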
  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
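    // BF16 path: widen each <2 x bf16> lane pair to <2 x float> by shuffling
    // with zero i16s, multiply, then FAdd-reduce onto the f32 accumulator.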
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    // Mask {2, 0, 3, 1} interleaves the zero i16s below each bf16 half, so
    // each bf16 ends up in the high 16 bits of an i32 lane, i.e. widened to
    // f32 with zero-filled low mantissa bits.
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = ArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }
  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);

  // After the inner loop finishes, propagate the updated lane into D.
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);

  return NewVecD;
}
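// Rewrite one tile dot-product intrinsic: split the block around it, emit the
// scalarization loop nest, and replace all uses of the result with the final
// <256 x i32> accumulator vector.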
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 bool>
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  PreBuilder.SetInsertPoint(TileDP);
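  // The column (N) and inner (K) operands count bytes; each tile element is
  // 4 bytes, so shift right by 2 to iterate in dword (i32) units.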
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
  I->replaceAllUsesWith(ResVec);
  I->eraseFromParent();
template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));
  PreBuilder.SetInsertPoint(TileLoadStore);
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
  I->replaceAllUsesWith(ResVec);
  I->eraseFromParent();
bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
  I->replaceAllUsesWith(VecZero);
  I->eraseFromParent();
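// Walk the function, collect the AMX intrinsics this pass understands, then
// lower each one; returns true if any change was made.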
bool X86LowerAMXIntrinsics::visit() {
  bool C = false;
  SmallVector<IntrinsicInst *, 8> WorkList;
  for (BasicBlock *BB : depth_first(&Func)) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::x86_tdpbssd_internal:
        case Intrinsic::x86_tdpbsud_internal:
        case Intrinsic::x86_tdpbusd_internal:
        case Intrinsic::x86_tdpbuud_internal:
        case Intrinsic::x86_tileloadd64_internal:
        case Intrinsic::x86_tilestored64_internal:
        case Intrinsic::x86_tilezero_internal:
        case Intrinsic::x86_tdpbf16ps_internal:
          WorkList.push_back(Inst);
          break;
        default:
          break;
        }
      }
    }
  }
  for (auto *Inst : WorkList) {
    switch (Inst->getIntrinsicID()) {
    case Intrinsic::x86_tdpbssd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbsud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbusd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbuud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsics!");
    }
  }

  return C;
}
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
        TM->getOptLevel() != CodeGenOptLevel::None)
      return false;
    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    X86LowerAMXIntrinsics LAT(F, DTU, LI);
    return LAT.visit();
  }
static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}
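// Note: this lowering is opt-in. It runs only when -enable-x86-scalar-amx is
// set, and only for functions compiled without optimization (optnone / -O0);
// otherwise AMX intrinsics are left for the regular backend tile path.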