#include "llvm/IR/IntrinsicsX86.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-intrinsics"
// Match the <256 x i32> type that flat AMX tile values are lowered to.
static bool isV256I32Ty(Type *Ty) {
  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
    return FVT->getNumElements() == 256 &&
           FVT->getElementType()->isIntegerTy(32);
  return false;
}
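// An AMX tile is at most 16 rows x 64 bytes = 1024 bytes, i.e. 256 dwords,
// which is why the flat form is <256 x i32>. A minimal sketch of the check in
// isolation (hypothetical standalone use, assuming an LLVMContext exists):
//
//   LLVMContext Ctx;
//   auto *Ty = FixedVectorType::get(Type::getInt32Ty(Ctx), 256);
//   assert(isV256I32Ty(Ty) && "flat tile type is <256 x i32>");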
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarization."));
class X86LowerAMXIntrinsics {
public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   Value *>
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   bool>
  lowerTileDP(Instruction *TileDP);
  // ...
};
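// Lowering strategy, in brief: tileload/tilestore become a two-deep
// (row, col) loop nest of scalar i32 loads/stores over a flat <256 x i32>
// value, and each tdpb* dot-product intrinsic becomes a three-deep
// (row, col, inner) nest that accumulates one dword of the result tile per
// innermost iteration. createLoop emits one counted loop of such a nest and
// keeps the DominatorTree and LoopInfo up to date.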
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  // ...
  B.SetInsertPoint(Latch);
  // ... (emit the IV increment and the exit compare in Latch)
  IV->addIncoming(Inc, Latch);
  // ...
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
  if (L) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}
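// Schematically, createLoop rewires the CFG from Preheader -> Exit into:
//
//   Preheader -> Header -> Body -> Latch -> Exit
//                   ^----------------'  (backedge while IV + Step != Bound)
//
// with the induction-variable PHI in Header and the increment and compare in
// Latch, matching the DomTree edge updates applied above.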
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }
  BasicBlock *RowBody =
      createLoop(Start, End, Row, B.getInt16(1),
                 IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
  BasicBlock *ColBody =
      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                 IntrinName + ".scalarize.cols", B, ColLoop);
  Type *EltTy = B.getInt32Ty();
  // ...
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
  Value *Idx =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
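  // Idx addresses the flat <256 x i32> in row-major order, Idx = row * 16 +
  // col, since each tile row holds 16 dwords. E.g. (row 2, col 5) maps to
  // element 2 * 16 + 5 = 37 of the flat vector.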
  if (IsTileLoad) {
    // ...
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    // ...
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    // ...
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    // ...
  } else {
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    // ...
    Value *Elt = B.CreateExtractElement(Vec, Idx);
    B.CreateStore(Elt, EltPtr);
  }
  // ...
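// Per (row, col) iteration the emitted IR is, roughly (a hand-written sketch,
// not verbatim pass output):
//
//   load:  %elt = load i32, i32* %eltptr
//          %res = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idx
//   store: %elt = extractelement <256 x i32> %vec, i16 %idx
//          store i32 %elt, i32* %eltptr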
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 Value *>
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }
  BasicBlock *RowBody =
      createLoop(Start, End, Row, B.getInt16(1),
                 IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
  BasicBlock *ColBody =
      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                 IntrinName + ".scalarize.cols", B, ColLoop);
  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);
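  // The nest built above corresponds to (in pseudo-C, with Col and K already
  // divided by 4, since the element granularity is one dword):
  //
  //   for (row = 0; row != Row; ++row)
  //     for (col = 0; col != Col; ++col)
  //       for (inner = 0; inner != K; ++inner)
  //         C[row][col] += dot(A[row][inner], B[inner][col]);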
  // ...
  Value *CurrentInner = &*InnerLoopHeader->begin();
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  // ...
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  // ...
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  // ...
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  // ...
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  // ...
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  // ...
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;
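  // VecA is indexed by (row, inner) and VecB by (inner, col), both row-major
  // with a 16-dword row pitch; each flat i32 element packs four i8 values
  // (integer variants) or two bf16 values (tdpbf16ps).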
  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // ...
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
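    // Per element this computes, with sign/zero extension chosen per variant
    // (e.g. tdpbssd extends both operands as signed):
    //
    //   C[row][col] += sum(i = 0..3) ext(A[row][inner].byte[i]) *
    //                                ext(B[inner][col].byte[i])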
  } else {
    // ...
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = ArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
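    // The {2, 0, 3, 1} mask interleaves the two bf16 lanes with zero lanes
    // from ZeroV2I16, turning {a0, a1} into {0, a0, 0, a1}; after the bitcast
    // to <2 x float>, each bf16 sits in the high 16 bits of an f32, which is
    // exactly how a bf16 value widens to IEEE f32 (little-endian layout).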
  }
  // ...
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
  // ...
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  // ...
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
  // ...
}
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 bool>
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
  // ...
  PreBuilder.SetInsertPoint(TileDP);
  // The col and inner trip counts are n/4 and k/4: dwords rather than bytes.
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  // ...
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // ...
  Builder.SetInsertPoint(End->getFirstNonPHI());
  // ...
  I->replaceAllUsesWith(ResVec);
  I->eraseFromParent();
  // ...
  return true;
}
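// For reference, a dot-product call site before lowering looks roughly like
// this (a hand-written illustration, operand names invented):
//
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a,
//                                                x86_amx %b)
//
// Afterwards the call is gone and its users see the flat <256 x i32> result
// of the loop nest instead.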
template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));
  // ...
  PreBuilder.SetInsertPoint(TileLoadStore);
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  // ...
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // ...
    Builder.SetInsertPoint(End->getFirstNonPHI());
    // ...
    I->replaceAllUsesWith(ResVec);
    I->eraseFromParent();
    // ...
  }
  return true;
}
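// Note that Stride arrives in bytes; shifting it right by 2 (like N above)
// converts it to the dword element stride that matches EltTy == i32 in
// createTileLoadStoreLoops.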
bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  // ...
  I->replaceAllUsesWith(VecZero);
  I->eraseFromParent();
  // ...
  return true;
}
bool X86LowerAMXIntrinsics::visit() {
  bool C = false;
  SmallVector<IntrinsicInst *, 8> WorkList;
  for (BasicBlock *BB : depth_first(&Func)) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::x86_tdpbssd_internal:
        case Intrinsic::x86_tdpbsud_internal:
        case Intrinsic::x86_tdpbusd_internal:
        case Intrinsic::x86_tdpbuud_internal:
        case Intrinsic::x86_tileloadd64_internal:
        case Intrinsic::x86_tilestored64_internal:
        case Intrinsic::x86_tilezero_internal:
        case Intrinsic::x86_tdpbf16ps_internal:
          WorkList.push_back(Inst);
          break;
        default:
          break;
        }
      }
    }
  }
  for (auto *Inst : WorkList) {
    switch (Inst->getIntrinsicID()) {
    case Intrinsic::x86_tdpbssd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbsud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbusd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbuud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsics!");
    }
  }
  return C;
}
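// The `C = lower...(Inst) || C;` form is deliberate: written as
// `C = C || lower...(Inst)`, short-circuit evaluation would skip the
// lowering call once C became true.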
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  // ...
  bool runOnFunction(Function &F) override {
    if (!X86ScalarizeAMX)
      return false;
    const TargetMachine *TM =
        &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
        TM->getOptLevel() != CodeGenOpt::None)
      return false;
    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    X86LowerAMXIntrinsics LAT(F, DTU, LI);
    return LAT.visit();
  }
  // ...
};
static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
// ...
FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}
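// A minimal sketch of adding the pass to a legacy pipeline (assumes a Module
// M already exists; in-tree it is added by the X86 codegen pipeline instead,
// which also provides the TargetPassConfig this pass queries):
//
//   legacy::PassManager PM;
//   PM.add(createX86LowerAMXIntrinsicsPass());
//   PM.run(M);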