// X86LowerAMXType (excerpted listing): the pass transforms load/store of
// <256 x i32> into AMX tile load/store intrinsics, or splits the data into
// two <128 x i32> values.

#include "llvm/IR/IntrinsicsX86.h"

using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"
static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

// An intrinsic is an AMX intrinsic if its result or any of its arguments has
// x86_amx type.
static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II || isAMXCast(II))
    return false;
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }
  return false;
}
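// For reference, a minimal sketch of the two cast intrinsics matched above as
// they appear in IR (value names assumed; the vector type is the overloaded
// part of the intrinsic name):
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
//   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)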
// In createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB, Type *Ty):
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  // ...

// In getFirstNonAllocaInTheEntryBlock(Function &F):
    if (!isa<AllocaInst>(&I))
// In getShape(IntrinsicInst *II, unsigned OpNo):
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getOperand(0);
    Col = II->getOperand(1);
    break;
  }
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    // ...
        // When the K operand is a constant, the row of the second multiplicand
        // is simply K / 4.
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
    // ...
        // Otherwise Row is computed right after the definition of the K
        // operand so that it dominates every use.
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
    // ...
  }
  }
  return std::make_pair(Row, Col);
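// Shape recovery, summarized (operand positions as the *_internal intrinsics
// define them): for tileloadd64/tilestored64 the shape is simply
// (operand 0, operand 1). For the tdp*/tcmm* dot-product intrinsics the three
// x86_amx operands are the accumulator and the two multiplicands, and their
// shapes come from the leading M/N/K operands, roughly:
//   C: (M, N) = (op0, op1),  A: (M, K) = (op0, op2),  B: (K/4, N)
// The B row count is K divided by 4 because four bytes of a logical B column
// are packed into one physical tile row, which is what the division above
// implements.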
// In getShape(PHINode *Phi): walk the first-user chain through casts and PHIs
// until an AMX intrinsic pins down the shape of the PHI-defined tile.
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      // ...
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      // ...
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }
  return std::make_pair(nullptr, nullptr);
class X86LowerAMXType {
  // ...
  // Cache mapping a Col value to the Row value derived from it, so the
  // division is only materialized once.
  std::map<Value *, Value *> Col2Row;
// In X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast):
  Value *Row = nullptr, *Col = nullptr;
  // ...
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  // ...
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
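// The combine above performs, roughly (value names assumed):
//   %src = load <256 x i32>, <256 x i32>* %addr, align 64
//   %2   = bitcast <256 x i32> %src to x86_amx
// -->
//   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    i8* %addr, i64 64)
// where %row/%col are taken from the AMX intrinsic that consumes %2.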
// In X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST):
  auto *II = cast<IntrinsicInst>(Tile);
  // The tile is produced by an AMX intrinsic whose first two operands are the
  // row and column of its shape.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  // ...
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  // ...
  Bitcast->replaceAllUsesWith(Vec);
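// Likewise, this combine rewrites, roughly:
//   %src = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %2   = bitcast x86_amx %src to <256 x i32>
//   store <256 x i32> %2, <256 x i32>* %addr, align 64
// -->
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %addr,
//                                             i64 64, x86_amx %src)
// and any remaining users of the bitcast are fed by a reload from %addr
// (the %Vec value above).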
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    // Creates the entry-block stack slot and sets AllocaAddr/I8Ptr/Stride.
    // ...
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // bitcast <256 x i32> -> x86_amx: spill the vector and reload it as a
    // tile.
    // ...
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    // ...
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // bitcast x86_amx -> <256 x i32>: store the tile and reload it as a
    // vector.
    auto *II = dyn_cast<IntrinsicInst>(Src);
    // ...
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    // ...
    Bitcast->replaceAllUsesWith(NewInst);
  }
  return true;
}
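// transformBitcast is the fallback when no adjacent load/store can be folded:
// the <256 x i32> value goes through a stack slot created in the entry block.
// A minimal sketch of the vector->tile direction, with assumed value names:
//   %addr = alloca <256 x i32>, align 64
//   store <256 x i32> %src, <256 x i32>* %addr
//   %tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                       i8* %addr.i8, i64 64)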
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  // ...
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      // ...
      if (Bitcast->getType()->isX86_AMXTy()) {
        // Lower bitcast <256 x i32> -> x86_amx. If the source is not a simple
        // load, fall back to the stack-slot path.
        if (transformBitcast(Bitcast))
          DeadInsts.push_back(Bitcast);
        // ...
        // Otherwise fold the load and the bitcast into a tile load.
        combineLoadBitcast(LD, Bitcast);
        // ...
      } else if (Src->getType()->isX86_AMXTy()) {
        // Lower bitcast x86_amx -> <256 x i32>.
        // ...
        ST = dyn_cast<StoreInst>(U.getUser());
        // ...
        if (transformBitcast(Bitcast))
          DeadInsts.push_back(Bitcast);
        // ...
        // Fold the bitcast and the store into a tile store.
        combineBitcastStore(Bitcast, ST);
        // ...
      }
  // ...
  bool C = !DeadInsts.empty();
  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();
  return C;
}
// In getAllocaPos(BasicBlock *BB): create a <256 x i32> stack slot in the
// entry block to hold volatile tile data.
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  // ...
  Builder.SetInsertPoint(&*Iter);
// In createTileStore(Instruction *TileDef, Value *Ptr):
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  // ...
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
// In replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI):
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");
  // Get the tile shape from the defining intrinsic (for a PHI, from its first
  // incoming value).
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  // ...
  Instruction *UserI = cast<Instruction>(U.getUser());
  // ...
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}
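// createTileStore/replaceWithTileLoad implement the "volatile tile data"
// model used for the O0 path: every x86_amx def is stored to a stack slot
// immediately after the def, and every use reloads it immediately before the
// use, so no tile value stays live across a basic block boundary.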
class X86VolatileTileData {
  // ...
  bool volatileTileData();
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the store just created) should be replaced.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}
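// For a PHI-defined tile this means, as a sketch (names assumed):
//   bb.a:  %t1 = ...x86_amx...        bb.b:  %t2 = ...x86_amx...
//   bb.c:  %td = phi x86_amx [ %t1, %bb.a ], [ %t2, %bb.b ]
// becomes
//   bb.a:  %t1 = ...  followed by a tilestored64 of %t1 into a shared slot
//   bb.b:  %t2 = ...  followed by a tilestored64 of %t2 into the same slot
//   bb.c:  every use of %td is replaced by a tileloadd64 from that slot and
//          the PHI itself is erased.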
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All uses (other than the store just created) are replaced by tile loads.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  // Collect x86_amx-typed instructions per basic block, then make the non-PHI
  // defs volatile first and the PHI-defined tiles afterwards.
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      // ...
      volatileTileNonPHI(I);
      // ...
      volatileTilePHI(dyn_cast<PHINode>(I));
  // ...
  return Changed;
}
class X86LowerAMXCast {
  // ...
  std::unique_ptr<DominatorTree> DT;
  // ...
  bool transformAllAMXCast();
// In DCEInstruction(Instruction *I, SmallSetVector<Instruction *, 16> &WorkList,
//                   const TargetLibraryInfo *TLI):
    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);
      // ...
      // If a nulled-out operand is itself a trivially dead instruction, queue
      // it so it gets deleted on a later iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        // ...
      }
    }
    I->eraseFromParent();
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType();
  Type *DestTy = CI->getType();
  // ...

  // Walk the PHI web reachable from PN and check that every incoming value is
  // another PHI, a matching AMX cast, or a zero/undef constant.
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // The shape must be constant so the tilezero created below is valid
        // wherever it is inserted.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create a tilezero at the end of the incoming block and cast it back
        // to the incoming value's type so it can feed the PHI.
        auto *Block = OldPN->getIncomingBlock(I);
        // ...
        Value *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        // ...
        OldPN->setIncomingValue(I, NewInst);
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      // ...
      // Every non-PHI incoming value must be a cast from SrcTy to DestTy.
      if (TyA != DestTy || TyB != SrcTy)
        return false;
    }
  }

  // Check that each user of each old PHI node is something we know how to
  // rewrite (a matching cast, or another PHI inside the web).
  for (auto *OldPN : OldPhiNodes) {
    // ...
      if (TyA != DestTy || TyB != SrcTy)
        return false;
      // ...
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // Users that are themselves part of the PHI web are fine.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      }
    // ...
  }

  // For each old PHI node, create a corresponding new PHI node with the
  // cast's destination type.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of the new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      // ...
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Finally, rewrite the users of the old PHI nodes: matching casts are fed by
  // the new PHI values and the dead instructions are queued for cleanup.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    // ...
      assert(TyA == DestTy && TyB == SrcTy);
      // ...
    } else if (auto *PHI = dyn_cast<PHINode>(V)) {
      // ...
    }
  }
  return true;
}
// In X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST):
  auto *II = cast<IntrinsicInst>(Tile);
  // The tile is produced by an AMX intrinsic whose first two operands are the
  // row and column of its shape.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  // ...
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
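// combineCastStore performs, roughly:
//   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)
//   store <256 x i32> %v, <256 x i32>* %p, align 64
// -->
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                             i64 %stride, x86_amx %t)
// with the stride derived from the column value.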
// In X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD):
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  // ...
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // ...
  std::tie(Row, Col) = getShape(II, OpNo);
  // ...
  // If the shape values do not dominate the load, spill the loaded vector to
  // a stack slot right after the load and read the tile from there instead;
  // the original load then has to stay.
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // ...
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);
    // ...
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    EraseLoad = false;
  }
  // ...
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
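// combineLoadCast is the mirror image, roughly:
//   %v = load <256 x i32>, <256 x i32>* %p, align 64
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
// -->
//   %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    i8* %p, i64 64)
// unless Row/Col do not dominate the load, in which case the loaded vector is
// routed through a stack slot as shown above.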
// In X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts):
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // A tile-to-vector cast feeding stores becomes a tile store; a
    // vector-to-tile cast fed by a single-use load becomes a tile load.
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      // ...
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      // ...
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else {
      // ...
      if (!Load || !Load->hasOneUse())
        continue;
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // ...
        Load->eraseFromParent();
      }
    }
  }
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect the tile cast instructions.
  // ...
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
  // ...

  // Fold cast pairs that cancel out: a cast whose user is the opposite cast is
  // replaced by the original value.
  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      // ...
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  // Erase casts that became dead; the remaining ones are kept as live casts.
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      }
      // ...
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle A->B->A casts that have an intervening PHI node.
  // ...
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
  // ...
  for (auto *I : PhiCastWorkList) {
    // ...
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Creating new PHIs and merging casts may leave old PHIs and casts unused;
  // run a small DCE over the queued instructions.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    // Creates the entry-block stack slot, addresses it as i8*, and uses the
    // maximum column (64 bytes) as the stride.
    // ...
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // vector -> tile cast: spill the vector and reload it as a tile.
    // ...
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    // ...
    Builder.CreateStore(Src, AllocaAddr);
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        // ...
    };
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    // ...
  } else {
    // tile -> vector cast: store the tile and reload it as a vector.
    auto *II = dyn_cast<IntrinsicInst>(Src);
    // ...
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    // ...
  }
  return true;
}
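// A sketch of the tile->vector branch above, with assumed value names:
//   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)
// -->
//   %slot = alloca <256 x i32>, align 64   ; created in the entry block
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
//                                             i8* %slot.i8, i64 %col.sext,
//                                             x86_amx %t)
//   %v = load <256 x i32>, <256 x i32>* %slot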
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect every remaining AMX cast and lower each one individually.
  // ...
  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }
  return Change;
}
// In X86LowerAMXTypeLegacyPass::runOnFunction(Function &F):
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMX casts after combineAMXcast; lower them too.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0: AMX data is made volatile so
    // that tile values do not live across basic blocks.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }
    return C;
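// The volatileTileData step only runs on the O0 path: fast register allocation
// does not split or spill tile values across blocks, so each tile def/use pair
// is forced through memory first (see X86VolatileTileData above).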
static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}