#include "llvm/IR/IntrinsicsX86.h"

using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"
static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}
static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II || isAMXCast(II))
    return false;
  // An intrinsic is an AMX intrinsic if its result or any argument is x86_amx.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }
  return false;
}
// In createAllocaInstAtEntry(): the alloca is always created in the entry
// block of the enclosing function.
unsigned AllocaAS = DL.getAllocaAddrSpace();
AllocaInst *AllocaRes =
    new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
// In getFirstNonAllocaInTheEntryBlock(): skip over the leading allocas.
if (!isa<AllocaInst>(&I))
  return &I;
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  // Loads and stores carry the shape directly in operands 0 and 1.
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // For the dot-product intrinsics the shape depends on which operand (OpNo)
  // is being asked about; see the sketch after this function.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    // For the second multiplicand (operand 5) the row count is K / 4.  When K
    // (operand 2) is a constant it becomes an i16 immediate:
    Row = Builder.getInt16(
        (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
    // ...
    // Otherwise (K only known at run time) the division is emitted as an
    // instruction and hoisted right after K's definition:
    cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
    break;
  }
  }
  return std::make_pair(Row, Col);
}
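// A minimal sketch (illustration only; the helper name is made up) of the
// shape rule encoded above for the dot-product intrinsics, written with plain
// integers instead of IR values.  For a call
// @llvm.x86.tdpbssd.internal(i16 m, i16 n, i16 k, c, a, b), getShape() picks:
//   operand 3 (c) -> (m, n), operand 4 (a) -> (m, k), operand 5 (b) -> (k/4, n),
// which is where the "getSExtValue() / 4" above comes from.
static std::pair<int, int> shapeForDotProductOperandExample(int M, int N, int K,
                                                            unsigned OpNo) {
  switch (OpNo) {
  case 3:
    return {M, N}; // accumulator tile
  case 4:
    return {M, K}; // first multiplicand
  case 5:
    return {K / 4, N}; // second multiplicand: K bytes packed four per row
  default:
    return {0, 0};
  }
}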
static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  // Walk the first-user chain of the PHI until an AMX intrinsic is found, then
  // reuse the overload above; otherwise give up.
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      // Step through the cast to its first user.
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }
  return std::make_pair(nullptr, nullptr);
}
class X86LowerAMXType {
  // ...
  std::map<Value *, Value *> Col2Row;
  // ...
};
// Fold
//   %src = load <256 x i32>, <256 x i32>* %addr
//   %amx = bitcast <256 x i32> %src to x86_amx
// into a single call to @llvm.x86.tileloadd64.internal.
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // The maximum column width (64 bytes) is used as the stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}
// In combineBitcastStore(): the symmetric rewrite, turning
// "bitcast x86_amx to <256 x i32>" followed by a store into a call to
// @llvm.x86.tilestored64.internal.
auto *II = cast<IntrinsicInst>(Tile);
// The tile comes from an AMX intrinsic; operands 0 and 1 carry its shape.
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
// ...
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                        Args);
// Remaining users of the bitcast are redirected to a reload of the stored data.
Bitcast->replaceAllUsesWith(Vec);
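// Aside (illustrative arithmetic, not from the original source): the constant
// 64 used as the stride above comes from viewing a <256 x i32> value
// (256 * 4 = 1024 bytes) as at most 16 rows of 64 bytes each.
static_assert(256 * 4 == 16 * 64, "one <256 x i32> tile = 16 rows x 64 bytes");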
// transformBitcast(): when the bitcast cannot be folded with an adjacent
// load/store, spill the value through a stack slot instead.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // <256 x i32> -> x86_amx: store the vector to the stack slot and reload it
    // as a tile.
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false;
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // x86_amx -> <256 x i32>: store the tile and reload it as a plain vector.
    auto *II = dyn_cast<IntrinsicInst>(Src);
    // ... (null check and Prepare(Src->getType()) elided)
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }
  return true;
}
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  // Walk every bitcast in the function (blocks in post order, instructions in
  // reverse), handling both directions of the <256 x i32> <-> x86_amx cast.
  auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
  if (!Bitcast)
    continue;
  Value *Src = Bitcast->getOperand(0);
  if (Bitcast->getType()->isX86_AMXTy()) {
    // <256 x i32> -> x86_amx: if the source is not a plain load, fall back to
    // the stack-slot path; otherwise fold load + bitcast into one tile load.
    if (transformBitcast(Bitcast))
      DeadInsts.push_back(Bitcast);
    // ...
    combineLoadBitcast(LD, Bitcast);
  } else if (Src->getType()->isX86_AMXTy()) {
    // x86_amx -> <256 x i32>: fold with the single store user if there is one.
    ST = dyn_cast<StoreInst>(U.getUser());
    if (transformBitcast(Bitcast))
      DeadInsts.push_back(Bitcast);
    // ...
    combineBitcastStore(Bitcast, ST);
  }
  // ...
  bool C = !DeadInsts.empty();
  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();
  return C;
}
// In getAllocaPos(): materialize a <256 x i32> stack slot in the entry block
// of the enclosing function and hand it back as an i8* for tile load/store.
unsigned AllocaAS = DL.getAllocaAddrSpace();
Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
AllocaInst *AllocaRes =
    new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
// The i8* bitcast is built just after the alloca.
Builder.SetInsertPoint(&*Iter);
// In createTileStore(): spill the tile defined by TileDef to Ptr, right after
// its definition.
auto *II = cast<IntrinsicInst>(TileDef);
assert(II && "Not tile intrinsic!");
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
// ...
std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
Instruction *TileStore = Builder.CreateIntrinsic(
    Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
return TileStore;
// In replaceWithTileLoad(U, Ptr, IsPHI): rewrite use U of a tile value so that
// it reads the tile back from Ptr through tileloadd instead.
assert(V->getType()->isX86_AMXTy() && "Not define tile!");
// Get the tile shape from the defining intrinsic (for a PHI, from its first
// incoming definition).
IntrinsicInst *II = nullptr;
if (IsPHI) {
  Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
  II = cast<IntrinsicInst>(PhiOp);
} else {
  II = cast<IntrinsicInst>(V);
}
// ...
Instruction *UserI = dyn_cast<Instruction>(U.getUser());
// ...
std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                          std::nullopt, Args);
UserI->replaceUsesOfWith(V, TileLoad);
// In isIncomingOfPHI(): does any use of I feed a PHI node?
for (Use &U : I->uses()) {
  User *V = U.getUser();
  if (isa<PHINode>(V))
    return true;
}
return false;
class X86VolatileTileData {
  // ...
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  // ...
};
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);
  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);
    // Every other use of the incoming tile value reloads it from the slot.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

// In replacePhiDefWithLoad(): every use of the PHI reloads the tile, then the
// PHI itself goes away.
for (Use &U : PHI->uses())
  replaceWithTileLoad(U, StorePtr, true);
PHI->eraseFromParent();
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;
  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Instruction *Inst = dyn_cast<Instruction>(PHI->getIncomingValue(I));
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }
  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  // ...
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    // ...
  }
}
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  // For each block, collect every x86_amx-typed definition, keeping PHI and
  // non-PHI definitions apart.
  if (!I.getType()->isX86_AMXTy())
    continue;
  if (isa<PHINode>(&I))
    PHIInsts.push_back(&I);
  else
    AMXDefInsts.push_back(&I);
  // ... non-PHI tile definitions (except those feeding a PHI) are spilled and
  // reloaded first:
  volatileTileNonPHI(I);
  Changed = true;
  // ... and then the PHI definitions:
  volatileTilePHI(dyn_cast<PHINode>(I));
  Changed = true;
  // ...
  return Changed;
}
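// A minimal sketch (the helper name is made up) condensing the per-definition
// rewrite that volatileTileNonPHI() performs, assuming `I` defines an x86_amx
// value inside block `BB`:
static void volatileTileNonPHIExample(Instruction *I, BasicBlock *BB) {
  Value *I8Ptr = getAllocaPos(BB);         // shared stack slot for this block
  User *Store = createTileStore(I, I8Ptr); // spill right after the definition
  for (Use &U : I->uses())
    if (U.getUser() != Store)
      replaceWithTileLoad(U, I8Ptr);       // every other use reloads the tile
}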
class X86LowerAMXCast {
  std::unique_ptr<DominatorTree> DT;
  // ...
  bool transformAllAMXCast();
  // ...
};
// In DCEInstruction(): delete a trivially dead instruction and queue operands
// that become dead in turn.
for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
  Value *OpV = I->getOperand(i);
  I->setOperand(i, nullptr);
  if (!OpV->use_empty() || I == OpV)
    continue;
  // An operand that just became trivially dead is handled in a later
  // iteration of the caller's work list.
  if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
    if (isInstructionTriviallyDead(OpI, TLI))
      WorkList.insert(OpI);
  }
}
I->eraseFromParent();
// optimizeAMXCastFromPhi(): rewrite a web of PHI nodes whose incoming values
// are AMX casts so that the PHIs are performed on the cast's destination type
// and the casts become dead.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType();  // type B, what the PHIs currently produce
  Type *DestTy = CI->getType();  // type A, what the cast produces

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;
  // PHIs can be cyclic, so the whole web is walked with a worklist.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // A constant incoming value is only handled if it is undef or zero; a
      // zero tile can be rematerialized with tilezero in the incoming block.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create the tilezero at the end of the incoming block and cast it
        // back to the vector type the PHI currently expects.
        auto *Block = OldPN->getIncomingBlock(I);
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(Block->getTerminator());
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(Block->getTerminator());
        OldPN->setIncomingValue(I, NewInst);
      }
      // A PHI incoming value extends the web; anything else must be the
      // matching A->B AMX cast.
      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Every user of every old PHI node must be either the inverse (B->A) cast or
  // another old PHI; otherwise the rewrite is not safe.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else {
        return false;
      }
    }
  }
  // For each old PHI node, create a corresponding new PHI node of the cast's
  // destination type (type A).
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of the new PHI nodes: the source of an AMX cast, or
  // the new PHI created for a nested old PHI.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Finally redirect the users of the old PHI web: each inverse cast is
  // replaced by the new PHI and queued for deletion.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // Another old PHI in the web: nothing to do here, it becomes dead once
        // its cast users are gone and DCE cleans it up.
        (void)PHI;
      }
    }
  }
  return true;
}
// In combineCastStore(): fold "cast_tile_to_vector + store" into a single
// tilestored64, mirroring combineBitcastStore() above.
auto *II = cast<IntrinsicInst>(Tile);
// The defining AMX intrinsic carries the shape in operands 0 and 1.
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
// ...
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                        Args);
// In combineLoadCast(): fold "load + cast_vector_to_tile" into a single
// tileloadd64.  Returns whether the original load can be erased.
bool EraseLoad = true;
Value *Row = nullptr, *Col = nullptr;
Use &U = *(Cast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = cast<IntrinsicInst>(U.getUser());
// ...
std::tie(Row, Col) = getShape(II, OpNo);
Value *Stride = Builder.getInt64(64);
Value *I8Ptr;
// If the shape values do not dominate the load, the tile load cannot be
// emitted at the load; copy the loaded vector through a stack slot next to
// the load instead and keep the original load alive.
if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
  // ...
  Builder.SetInsertPoint(&*std::next(LD->getIterator()));
  Builder.CreateStore(LD, AllocaAddr);
  // ...
  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
  EraseLoad = false;
} else {
  I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
}
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                         std::nullopt, Args);
Cast->replaceAllUsesWith(NewInst);
return EraseLoad;
// In combineLdSt(): try to fold each collected cast with its adjacent load or
// store.
for (auto *Cast : Casts) {
  auto *II = cast<IntrinsicInst>(Cast);
  if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
    // tile -> vector feeding stores: fold each store into tilestored64.
    // ... for every StoreInst user `Store`:
    combineCastStore(cast<IntrinsicInst>(Cast), Store);
    // ...
    for (auto *Store : DeadStores)
      Store->eraseFromParent();
  } else {
    // vector -> tile fed by a single-use load: fold into tileloadd64.
    if (!Load || !Load->hasOneUse())
      continue;
    if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
      Cast->setOperand(0, nullptr); // detach so the load can be erased
      Load->eraseFromParent();
    }
  }
}
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect the two kinds of AMX cast intrinsics in the function.
  Value *Vec;
  if (match(&I,
            m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
    Vec2TileInsts.push_back(&I);
  else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                         m_Value(Vec))))
    Tile2VecInsts.push_back(&I);
  // ...
  // Convert: a cast whose user is the opposite cast collapses to the cast's
  // own operand (see the sketch after this function).
  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      // ... users of Inst that are casts of kind IID are replaced with
      // Inst->getOperand(0), cancelling the round trip.
    }
  };
  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
  // EraseInst: delete casts that became dead; the survivors are handed to
  // combineLdSt().
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };
  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  // Handle casts whose source is a PHI node (an A->B->A web through PHIs).
  if (isa<PHINode>(I.getOperand(0)))
    PhiCastWorkList.push_back(&I);
  // ...
  for (auto *I : PhiCastWorkList) {
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }
  // Rewriting the PHI web can leave old PHIs and casts unused; DCE them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  return Change;
}
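// A minimal sketch (illustration only; the helper name is made up) of the fold
// performed inside Convert: a vector->tile cast whose operand is the inverse
// tile->vector cast collapses to the original tile value, and symmetrically
// for the other direction.
static bool foldRoundTripCastExample(Instruction &I) {
  Value *Tile = nullptr;
  if (match(&I, m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(
                    m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                        m_Value(Tile))))) {
    I.replaceAllUsesWith(Tile); // %t2 = vec2tile(tile2vec(%t1))  ==>  %t1
    return true;
  }
  return false;
}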
// transformAMXCast(): any cast left over after combineAMXcast() is lowered the
// same way as transformBitcast(), by spilling through a stack slot.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // vector -> x86_amx: store the vector and reload it as a tile.
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false;
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // x86_amx -> vector: store the tile and reload it as a vector.
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false;
    Prepare(Src->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }
  return true;
}
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect every remaining AMX cast, then lower them one by one.
  // ...
  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }
  return Change;
}
  // runOnFunction() of the legacy pass: run the cast combiner, lower any
  // remaining casts, then lower the bitcasts, and finally make tile data
  // volatile when the backend runs at O0.
  bool C = false;
  TargetLibraryInfo *TLI =
      &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

  X86LowerAMXCast LAC(F);
  C |= LAC.combineAMXcast(TLI);
  // There may be casts left after combineAMXcast; lower them through memory.
  C |= LAC.transformAllAMXCast();

  X86LowerAMXType LAT(F);
  C |= LAT.visit();

  // When the backend is at O0 but the function was not built with -O0 in the
  // front end (no optnone attribute), tile data is made volatile so that fast
  // register allocation can handle it.
  if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
    X86VolatileTileData VTD(F);
    C = VTD.volatileTileData() || C;
  }
  return C;
static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}
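// For context, a sketch of how the pass gets scheduled (the helper name is
// made up).  In-tree, X86PassConfig::addIRPasses() adds it before instruction
// selection, roughly as addPass(createX86LowerAMXTypePass()).  With the legacy
// pass manager the equivalent by-hand setup, assuming
// llvm/IR/LegacyPassManager.h is available, is:
static void runLowerAMXTypeExample(Module &M) {
  legacy::PassManager PM;
  PM.add(createX86LowerAMXTypePass());
  PM.run(M);
}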