LLVM  13.0.0git
X86LowerAMXType.cpp
Go to the documentation of this file.
1 //===- X86LowerAMXType.cpp - ------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform <256 x i32> load/store
10 /// <256 x i32> is bitcasted to x86_amx on X86, and AMX instruction set only
11 /// provides simple operation on x86_amx. The basic elementwise operation
12 /// is not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
13 /// and only AMX intrinsics can operate on the type, we need to transform
14 /// load/store <256 x i32> instruction to AMX load/store. If the bitcast can
15 /// not be combined with load/store, we transform the bitcast to amx load/store
16 /// and <256 x i32> store/load.
17 //
18 //===----------------------------------------------------------------------===//
19 //
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
36 
37 using namespace llvm;
38 using namespace PatternMatch;
39 
40 #define DEBUG_TYPE "lower-amx-type"
41 
43  Function &F = *BB->getParent();
44  Module *M = BB->getModule();
45  const DataLayout &DL = M->getDataLayout();
46 
47  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
48  LLVMContext &Ctx = Builder.getContext();
49  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
50  unsigned AllocaAS = DL.getAllocaAddrSpace();
51  AllocaInst *AllocaRes =
52  new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
53  AllocaRes->setAlignment(AllocaAlignment);
54  return AllocaRes;
55 }
56 
57 static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
58  Value *Row = nullptr, *Col = nullptr;
59  switch (II->getIntrinsicID()) {
60  default:
61  llvm_unreachable("Expect amx intrinsics");
62  case Intrinsic::x86_tileloadd64_internal:
63  case Intrinsic::x86_tilestored64_internal: {
64  Row = II->getArgOperand(0);
65  Col = II->getArgOperand(1);
66  break;
67  }
68  // a * b + c
69  // The shape depends on which operand.
70  case Intrinsic::x86_tdpbssd_internal:
71  case Intrinsic::x86_tdpbsud_internal:
72  case Intrinsic::x86_tdpbusd_internal:
73  case Intrinsic::x86_tdpbuud_internal:
74  case Intrinsic::x86_tdpbf16ps_internal: {
75  switch (OpNo) {
76  case 3:
77  Row = II->getArgOperand(0);
78  Col = II->getArgOperand(1);
79  break;
80  case 4:
81  Row = II->getArgOperand(0);
82  Col = II->getArgOperand(2);
83  break;
84  case 5:
85  Row = II->getArgOperand(2);
86  Col = II->getArgOperand(1);
87  break;
88  }
89  break;
90  }
91  }
92 
93  return std::make_pair(Row, Col);
94 }
95 
96 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
97 // %2 = bitcast <256 x i32> %src to x86_amx
98 // -->
99 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
100 // i8* %addr, i64 %stride64)
102  Value *Row = nullptr, *Col = nullptr;
103  Use &U = *(Bitcast->use_begin());
104  unsigned OpNo = U.getOperandNo();
105  auto *II = cast<IntrinsicInst>(U.getUser());
106  std::tie(Row, Col) = getShape(II, OpNo);
108  // Use the maximun column as stride.
109  Value *Stride = Builder.getInt64(64);
110  Value *I8Ptr =
111  Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
112  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
113 
114  Value *NewInst =
115  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
116  Bitcast->replaceAllUsesWith(NewInst);
117 }
118 
119 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
120 // %stride);
121 // %13 = bitcast x86_amx %src to <256 x i32>
122 // store <256 x i32> %13, <256 x i32>* %addr, align 64
123 // -->
124 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
125 // %stride64, %13)
127 
128  Value *Tile = Bitcast->getOperand(0);
129  auto *II = cast<IntrinsicInst>(Tile);
130  // Tile is output from AMX intrinsic. The first operand of the
131  // intrinsic is row, the second operand of the intrinsic is column.
132  Value *Row = II->getOperand(0);
133  Value *Col = II->getOperand(1);
135  // Use the maximum column as stride. It must be the same with load
136  // stride.
137  Value *Stride = Builder.getInt64(64);
138  Value *I8Ptr =
139  Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
140  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
141  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
142  if (Bitcast->hasOneUse())
143  return;
144  // %13 = bitcast x86_amx %src to <256 x i32>
145  // store <256 x i32> %13, <256 x i32>* %addr, align 64
146  // %add = <256 x i32> %13, <256 x i32> %src2
147  // -->
148  // %13 = bitcast x86_amx %src to <256 x i32>
149  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
150  // %stride64, %13)
151  // %14 = load <256 x i32>, %addr
152  // %add = <256 x i32> %14, <256 x i32> %src2
153  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
154  Bitcast->replaceAllUsesWith(Vec);
155 }
156 
157 // transform bitcast to <store, load> instructions.
160  AllocaInst *AllocaAddr;
161  Value *I8Ptr, *Stride;
162  auto *Src = Bitcast->getOperand(0);
163 
164  auto Prepare = [&]() {
165  AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
166  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
167  Stride = Builder.getInt64(64);
168  };
169 
170  if (Bitcast->getType()->isX86_AMXTy()) {
171  // %2 = bitcast <256 x i32> %src to x86_amx
172  // -->
173  // %addr = alloca <256 x i32>, align 64
174  // store <256 x i32> %src, <256 x i32>* %addr, align 64
175  // %addr2 = bitcast <256 x i32>* to i8*
176  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
177  // i8* %addr2,
178  // i64 64)
179  Use &U = *(Bitcast->use_begin());
180  unsigned OpNo = U.getOperandNo();
181  auto *II = dyn_cast<IntrinsicInst>(U.getUser());
182  if (!II)
183  return false; // May be bitcast from x86amx to <256 x i32>.
184  Prepare();
185  Builder.CreateStore(Src, AllocaAddr);
186  // TODO we can pick an constant operand for the shape.
187  Value *Row = nullptr, *Col = nullptr;
188  std::tie(Row, Col) = getShape(II, OpNo);
189  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
190  Value *NewInst = Builder.CreateIntrinsic(
191  Intrinsic::x86_tileloadd64_internal, None, Args);
192  Bitcast->replaceAllUsesWith(NewInst);
193  } else {
194  // %2 = bitcast x86_amx %src to <256 x i32>
195  // -->
196  // %addr = alloca <256 x i32>, align 64
197  // %addr2 = bitcast <256 x i32>* to i8*
198  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
199  // i8* %addr2, i64 %stride)
200  // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
201  auto *II = dyn_cast<IntrinsicInst>(Src);
202  if (!II)
203  return false; // May be bitcast from <256 x i32> to x86amx.
204  Prepare();
205  Value *Row = II->getOperand(0);
206  Value *Col = II->getOperand(1);
207  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
208  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
209  Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
210  Bitcast->replaceAllUsesWith(NewInst);
211  }
212 
213  return true;
214 }
215 
216 namespace {
217 class X86LowerAMXType {
218  Function &Func;
219 
220 public:
221  X86LowerAMXType(Function &F) : Func(F) {}
222  bool visit();
223 };
224 
225 bool X86LowerAMXType::visit() {
227 
228  for (BasicBlock *BB : post_order(&Func)) {
229  for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
230  II != IE;) {
231  Instruction &Inst = *II++;
232  auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
233  if (!Bitcast)
234  continue;
235 
236  Value *Src = Bitcast->getOperand(0);
237  if (Bitcast->getType()->isX86_AMXTy()) {
238  if (Bitcast->user_empty()) {
239  DeadInsts.push_back(Bitcast);
240  continue;
241  }
242  LoadInst *LD = dyn_cast<LoadInst>(Src);
243  if (!LD) {
245  DeadInsts.push_back(Bitcast);
246  continue;
247  }
248  // If load has mutli-user, duplicate a vector load.
249  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
250  // %2 = bitcast <256 x i32> %src to x86_amx
251  // %add = add <256 x i32> %src, <256 x i32> %src2
252  // -->
253  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
254  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
255  // i8* %addr, i64 %stride64)
256  // %add = add <256 x i32> %src, <256 x i32> %src2
257 
258  // If load has one user, the load will be eliminated in DAG ISel.
259  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
260  // %2 = bitcast <256 x i32> %src to x86_amx
261  // -->
262  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
263  // i8* %addr, i64 %stride64)
265  DeadInsts.push_back(Bitcast);
266  if (LD->hasOneUse())
267  DeadInsts.push_back(LD);
268  } else if (Src->getType()->isX86_AMXTy()) {
269  if (Bitcast->user_empty()) {
270  DeadInsts.push_back(Bitcast);
271  continue;
272  }
273  StoreInst *ST = nullptr;
274  for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
275  UI != UE;) {
276  Value *I = (UI++)->getUser();
277  ST = dyn_cast<StoreInst>(I);
278  if (ST)
279  break;
280  }
281  if (!ST) {
283  DeadInsts.push_back(Bitcast);
284  continue;
285  }
286  // If bitcast (%13) has one use, combine bitcast and store to amx store.
287  // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
288  // %stride);
289  // %13 = bitcast x86_amx %src to <256 x i32>
290  // store <256 x i32> %13, <256 x i32>* %addr, align 64
291  // -->
292  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
293  // %stride64, %13)
294  //
295  // If bitcast (%13) has multi-use, transform as below.
296  // %13 = bitcast x86_amx %src to <256 x i32>
297  // store <256 x i32> %13, <256 x i32>* %addr, align 64
298  // %add = <256 x i32> %13, <256 x i32> %src2
299  // -->
300  // %13 = bitcast x86_amx %src to <256 x i32>
301  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
302  // %stride64, %13)
303  // %14 = load <256 x i32>, %addr
304  // %add = <256 x i32> %14, <256 x i32> %src2
305  //
307  // Delete user first.
308  DeadInsts.push_back(ST);
309  DeadInsts.push_back(Bitcast);
310  }
311  }
312  }
313 
314  bool C = !DeadInsts.empty();
315 
316  for (auto *Inst : DeadInsts)
317  Inst->eraseFromParent();
318 
319  return C;
320 }
321 } // anonymous namespace
322 
323 namespace {
324 
325 class X86LowerAMXTypeLegacyPass : public FunctionPass {
326 public:
327  static char ID;
328 
329  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
331  }
332 
333  bool runOnFunction(Function &F) override {
334  X86LowerAMXType LAT(F);
335  bool C = LAT.visit();
336  return C;
337  }
338 
339  void getAnalysisUsage(AnalysisUsage &AU) const override {
340  AU.setPreservesCFG();
341  }
342 };
343 
344 } // anonymous namespace
345 
346 static const char PassName[] = "Lower AMX type for load/store";
348 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
349  false)
350 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
351  false)
352 
354  return new X86LowerAMXTypeLegacyPass();
355 }
llvm::createX86LowerAMXTypePass
FunctionPass * createX86LowerAMXTypePass()
The pass transform load/store <256 x i32> to AMX load/store intrinsics or split the data to two <128 ...
Definition: X86LowerAMXType.cpp:353
ValueTypes.h
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, false) INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass
llvm
This class represents lattice values for constants.
Definition: AllocatorList.h:23
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:111
IntrinsicInst.h
llvm::Function
Definition: Function.h:61
Pass.h
llvm::IntrinsicInst::getIntrinsicID
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:51
llvm::ARM_MB::LD
@ LD
Definition: ARMBaseInfo.h:72
llvm::BitCastInst
This class represents a no-op cast from one type to another.
Definition: Instructions.h:5136
llvm::SmallVector< Instruction *, 8 >
combineLoadBitcast
static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast)
Definition: X86LowerAMXType.cpp:101
llvm::IRBuilder<>
OptimizationRemarkEmitter.h
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:46
transformBitcast
static bool transformBitcast(BitCastInst *Bitcast)
Definition: X86LowerAMXType.cpp:158
llvm::Use::getOperandNo
unsigned getOperandNo() const
Return the operand # of this use in its User.
Definition: Use.cpp:33
INITIALIZE_PASS_END
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
Definition: RegBankSelect.cpp:69
F
#define F(x, y, z)
Definition: MD5.cpp:56
DEBUG_TYPE
#define DEBUG_TYPE
Definition: X86LowerAMXType.cpp:40
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
X86.h
llvm::Type::getX86_AMXTy
static Type * getX86_AMXTy(LLVMContext &C)
Definition: Type.cpp:192
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:31
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
combineBitcastStore
static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST)
Definition: X86LowerAMXType.cpp:126
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
false
Definition: StackSlotColoring.cpp:142
CreateAllocaInst
static AllocaInst * CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB)
Definition: X86LowerAMXType.cpp:42
llvm::Instruction
Definition: Instruction.h:45
llvm::Use::getUser
User * getUser() const
Returns the User that contains this Use.
Definition: Use.h:73
PatternMatch.h
llvm::None
const NoneType None
Definition: None.h:23
llvm::ARM_PROC::IE
@ IE
Definition: ARMBaseInfo.h:27
Passes.h
llvm::StoreInst
An instruction for storing to memory.
Definition: Instructions.h:303
llvm::Instruction::eraseFromParent
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:77
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
llvm::LegalizeActions::Bitcast
@ Bitcast
Perform the operation on a different, but equivalently sized type.
Definition: LegalizerInfo.h:72
llvm::LLVMContext
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:68
I
#define I(x, y, z)
Definition: MD5.cpp:59
llvm::initializeX86LowerAMXTypeLegacyPassPass
void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &)
IRBuilder.h
llvm::elfabi::ELFSymbolType::Func
@ Func
llvm::Module
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
PassName
static const char PassName[]
Definition: X86LowerAMXType.cpp:346
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:643
DataLayout.h
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:253
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition: ErrorHandling.h:136
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:246
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition: AArch64SLSHardening.cpp:76
llvm::LoadInst
An instruction for reading from memory.
Definition: Instructions.h:174
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:69
getShape
static std::pair< Value *, Value * > getShape(IntrinsicInst *II, unsigned OpNo)
Definition: X86LowerAMXType.cpp:57
llvm::post_order
iterator_range< po_iterator< T > > post_order(const T &G)
Definition: PostOrderIterator.h:186
Function.h
llvm::IntrinsicInst
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:44
llvm::AllocaInst::setAlignment
void setAlignment(Align Align)
Definition: Instructions.h:123
llvm::BasicBlock::reverse_iterator
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:92
Instructions.h
PostOrderIterator.h
llvm::CallBase::getArgOperand
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1341
TargetTransformInfo.h
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:298
BB
Common register allocation spilling lr str ldr sxth r3 ldr mla r4 can lr mov lr str ldr sxth r3 mla r4 and then merge mul and lr str ldr sxth r3 mla r4 It also increase the likelihood the store may become dead bb27 Successors according to LLVM BB
Definition: README.txt:39
llvm::AMDGPU::HSAMD::Kernel::Key::Args
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
Definition: AMDGPUMetadata.h:379
llvm::AllocaInst
an instruction to allocate memory on the stack
Definition: Instructions.h:61
InitializePasses.h
llvm::Value
LLVM Value Representation.
Definition: Value.h:75
llvm::VectorType::get
static VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct an VectorType.
Definition: Type.cpp:627
llvm::Type::isX86_AMXTy
bool isX86_AMXTy() const
Return true if this is X86 AMX.
Definition: Type.h:187
llvm::Use
A Use represents the edge between a Value definition and its users.
Definition: Use.h:44
SmallSet.h
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:37