File: llvm/lib/Target/X86/X86LowerAMXType.cpp
Warning: line 581, column 20: Called C++ object pointer is null
1 | //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | /// \file Pass to transform <256 x i32> load/store | |||
10 | /// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set | |||
11 | /// only provides simple operations on x86_amx. Basic elementwise operations | |||
12 | /// are not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32> | |||
13 | /// and only AMX intrinsics can operate on the type, we need to transform | |||
14 | /// load/store <256 x i32> instructions to AMX load/store. If the bitcast | |||
15 | /// cannot be combined with the load/store, we transform the bitcast to an | |||
16 | /// amx load/store and a <256 x i32> store/load. | |||
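/// As a rough illustration (this mirrors the transformBitcast lowering shown
/// later in this file; %row and %col are taken from the AMX intrinsic that
/// consumes the tile), an uncombinable bitcast is lowered as:
///   %2 = bitcast <256 x i32> %src to x86_amx
/// -->
///   %addr = alloca <256 x i32>, align 64
///   store <256 x i32> %src, <256 x i32>* %addr, align 64
///   %addr2 = bitcast <256 x i32>* %addr to i8*
///   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                    i8* %addr2, i64 64)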
17 | /// | |||
18 | /// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang | |||
19 | /// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is | |||
20 | /// volatile, because that is necessary for AMX fast register allocation. (In | |||
21 | /// fast register allocation, registers are allocated before spill/reload, so | |||
22 | /// there is no additional register for amx to identify the step in spilling.) | |||
23 | /// The volatileTileData() will handle this case. | |||
24 | /// e.g. | |||
25 | /// ---------------------------------------------------------- | |||
26 | /// | def %td = ... | | |||
27 | /// | ... | | |||
28 | /// | "use %td" | | |||
29 | /// ---------------------------------------------------------- | |||
30 | /// will be transformed to --> | |||
31 | /// ---------------------------------------------------------- | |||
32 | /// | def %td = ... | | |||
33 | /// | call void @llvm.x86.tilestored64.internal(mem, %td) | | |||
34 | /// | ... | | |||
35 | /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)| | |||
36 | /// | "use %td2" | | |||
37 | /// ---------------------------------------------------------- | |||
38 | // | |||
39 | //===----------------------------------------------------------------------===// | |||
40 | // | |||
41 | #include "X86.h" | |||
42 | #include "llvm/ADT/PostOrderIterator.h" | |||
43 | #include "llvm/ADT/SetVector.h" | |||
44 | #include "llvm/ADT/SmallSet.h" | |||
45 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" | |||
46 | #include "llvm/Analysis/TargetLibraryInfo.h" | |||
47 | #include "llvm/Analysis/TargetTransformInfo.h" | |||
48 | #include "llvm/CodeGen/Passes.h" | |||
49 | #include "llvm/CodeGen/TargetPassConfig.h" | |||
50 | #include "llvm/CodeGen/ValueTypes.h" | |||
51 | #include "llvm/IR/DataLayout.h" | |||
52 | #include "llvm/IR/Function.h" | |||
53 | #include "llvm/IR/IRBuilder.h" | |||
54 | #include "llvm/IR/Instructions.h" | |||
55 | #include "llvm/IR/IntrinsicInst.h" | |||
56 | #include "llvm/IR/IntrinsicsX86.h" | |||
57 | #include "llvm/IR/PatternMatch.h" | |||
58 | #include "llvm/InitializePasses.h" | |||
59 | #include "llvm/Pass.h" | |||
60 | #include "llvm/Target/TargetMachine.h" | |||
61 | #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" | |||
62 | #include "llvm/Transforms/Utils/Local.h" | |||
63 | ||||
64 | using namespace llvm; | |||
65 | using namespace PatternMatch; | |||
66 | ||||
67 | #define DEBUG_TYPE "lower-amx-type" | |||
68 | ||||
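// For example, both of the following match isAMXCast (the .v256i32 suffix is
// just the usual mangling for a <256 x i32> operand; it varies with the
// vector type actually used):
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
//   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)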
69 | static bool isAMXCast(Instruction *II) { | |||
70 | return match(II, | |||
71 | m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) || | |||
72 | match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value())); | |||
73 | } | |||
74 | ||||
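// Note: the alloca is deliberately created in the entry block so it remains a
// static alloca (a fixed stack slot) rather than a dynamic allocation; the
// same convention is used by getAllocaPos further below.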
75 | static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB, | |||
76 | Type *Ty) { | |||
77 | Function &F = *BB->getParent(); | |||
78 | Module *M = BB->getModule(); | |||
79 | const DataLayout &DL = M->getDataLayout(); | |||
80 | ||||
81 | LLVMContext &Ctx = Builder.getContext(); | |||
82 | auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx)); | |||
83 | unsigned AllocaAS = DL.getAllocaAddrSpace(); | |||
84 | AllocaInst *AllocaRes = | |||
85 | new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front()); | |||
86 | AllocaRes->setAlignment(AllocaAlignment); | |||
87 | return AllocaRes; | |||
88 | } | |||
89 | ||||
90 | static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { | |||
91 | for (Instruction &I : F.getEntryBlock()) | |||
92 | if (!isa<AllocaInst>(&I)) | |||
93 | return &I; | |||
94 | llvm_unreachable("No terminator in the entry block!"); | |||
95 | } | |||
96 | ||||
97 | static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) { | |||
98 | IRBuilder<> Builder(II); | |||
99 | Value *Row = nullptr, *Col = nullptr; | |||
100 | switch (II->getIntrinsicID()) { | |||
101 | default: | |||
102 | llvm_unreachable("Expect amx intrinsics"); | |||
103 | case Intrinsic::x86_tileloadd64_internal: | |||
104 | case Intrinsic::x86_tileloaddt164_internal: | |||
105 | case Intrinsic::x86_tilestored64_internal: { | |||
106 | Row = II->getArgOperand(0); | |||
107 | Col = II->getArgOperand(1); | |||
108 | break; | |||
109 | } | |||
110 | // a * b + c | |||
111 | // The shape depends on which operand. | |||
112 | case Intrinsic::x86_tdpbssd_internal: | |||
113 | case Intrinsic::x86_tdpbsud_internal: | |||
114 | case Intrinsic::x86_tdpbusd_internal: | |||
115 | case Intrinsic::x86_tdpbuud_internal: | |||
116 | case Intrinsic::x86_tdpbf16ps_internal: { | |||
117 | switch (OpNo) { | |||
118 | case 3: | |||
119 | Row = II->getArgOperand(0); | |||
120 | Col = II->getArgOperand(1); | |||
121 | break; | |||
122 | case 4: | |||
123 | Row = II->getArgOperand(0); | |||
124 | Col = II->getArgOperand(2); | |||
125 | break; | |||
126 | case 5: | |||
127 | if (isa<ConstantInt>(II->getArgOperand(2))) | |||
128 | Row = Builder.getInt16( | |||
129 | (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4); | |||
130 | else if (isa<Instruction>(II->getArgOperand(2))) { | |||
131 | // When it is not a const value and it is not a function argument, we | |||
132 | // create Row after the definition of II->getOperand(2) instead of | |||
133 | // before II. For example, II is %118, we try to getshape for %117: | |||
134 | // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x | |||
135 | // i32> %115). | |||
136 | // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 | |||
137 | // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx | |||
138 | // %117). | |||
139 | // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its | |||
140 | // definition is after its user(new tileload for %117). | |||
141 | // So, the best choice is to create %row right after the definition of | |||
142 | // %106. | |||
143 | Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2))); | |||
144 | Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4)); | |||
145 | cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2))); | |||
146 | } else { | |||
147 | // When it is not a const value and it is a function argument, we create | |||
148 | // Row at the entry bb. | |||
149 | IRBuilder<> NewBuilder( | |||
150 | getFirstNonAllocaInTheEntryBlock(*II->getFunction())); | |||
151 | Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4)); | |||
152 | } | |||
153 | Col = II->getArgOperand(1); | |||
154 | break; | |||
155 | } | |||
156 | break; | |||
157 | } | |||
158 | } | |||
159 | ||||
160 | return std::make_pair(Row, Col); | |||
161 | } | |||
162 | ||||
163 | namespace { | |||
164 | class X86LowerAMXType { | |||
165 | Function &Func; | |||
166 | ||||
167 | // In AMX intrinsics we let Shape = {Row, Col}, but the | |||
168 | // RealCol = Col / ElementSize. We may use the RealCol | |||
169 | // as a new Row for other newly created AMX intrinsics. | |||
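// For illustration (this just restates what getShape above computes): in
// tdpbssd.internal(%m, %n, %k, %c, %a, %b) the Col of %a is %k, while the
// shape of %b is (%k / 4) x %n, so the RealCol %k / 4 becomes the Row used
// when a new tile load is created for %b.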
170 | std::map<Value *, Value *> Col2Row; | |||
171 | ||||
172 | public: | |||
173 | X86LowerAMXType(Function &F) : Func(F) {} | |||
174 | bool visit(); | |||
175 | void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); | |||
176 | void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); | |||
177 | bool transformBitcast(BitCastInst *Bitcast); | |||
178 | Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); | |||
179 | }; | |||
180 | ||||
181 | Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V, | |||
182 | unsigned Granularity) { | |||
183 | if (Col2Row.count(V)) | |||
184 | return Col2Row[V]; | |||
185 | IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt()); | |||
186 | if (auto *I = dyn_cast<Instruction>(V)) { | |||
187 | BasicBlock::iterator Iter = I->getIterator(); | |||
188 | ++Iter; | |||
189 | Builder.SetInsertPoint(&*Iter); | |||
190 | } | |||
191 | ConstantInt *Gran = Builder.getInt16(Granularity); | |||
192 | Value *RealRow = Builder.CreateUDiv(V, Gran); | |||
193 | Col2Row[V] = RealRow; | |||
194 | return RealRow; | |||
195 | } | |||
196 | ||||
197 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
198 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
199 | // --> | |||
200 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
201 | // i8* %addr, i64 %stride64) | |||
202 | void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { | |||
203 | Value *Row = nullptr, *Col = nullptr; | |||
204 | Use &U = *(Bitcast->use_begin()); | |||
205 | unsigned OpNo = U.getOperandNo(); | |||
206 | auto *II = cast<IntrinsicInst>(U.getUser()); | |||
207 | std::tie(Row, Col) = getShape(II, OpNo); | |||
208 | IRBuilder<> Builder(Bitcast); | |||
209 | // Use the maximum column as stride. | |||
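// (An AMX tile row holds at most 64 bytes, so a 64-byte stride matches the
// row-major layout of the <256 x i32> buffer: 16 rows of 64 bytes each.)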
210 | Value *Stride = Builder.getInt64(64); | |||
211 | Value *I8Ptr = | |||
212 | Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy()); | |||
213 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; | |||
214 | ||||
215 | Value *NewInst = | |||
216 | Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); | |||
217 | Bitcast->replaceAllUsesWith(NewInst); | |||
218 | } | |||
219 | ||||
220 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, | |||
221 | // %stride); | |||
222 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
223 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
224 | // --> | |||
225 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
226 | // %stride64, %13) | |||
227 | void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { | |||
228 | ||||
229 | Value *Tile = Bitcast->getOperand(0); | |||
230 | auto *II = cast<IntrinsicInst>(Tile); | |||
231 | // Tile is output from AMX intrinsic. The first operand of the | |||
232 | // intrinsic is row, the second operand of the intrinsic is column. | |||
233 | Value *Row = II->getOperand(0); | |||
234 | Value *Col = II->getOperand(1); | |||
235 | IRBuilder<> Builder(ST); | |||
236 | // Use the maximum column as stride. It must be the same as the load | |||
237 | // stride. | |||
238 | Value *Stride = Builder.getInt64(64); | |||
239 | Value *I8Ptr = | |||
240 | Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy()); | |||
241 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; | |||
242 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
243 | if (Bitcast->hasOneUse()) | |||
244 | return; | |||
245 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
246 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
247 | // %add = <256 x i32> %13, <256 x i32> %src2 | |||
248 | // --> | |||
249 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
250 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
251 | // %stride64, %13) | |||
252 | // %14 = load <256 x i32>, %addr | |||
253 | // %add = <256 x i32> %14, <256 x i32> %src2 | |||
254 | Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1)); | |||
255 | Bitcast->replaceAllUsesWith(Vec); | |||
256 | } | |||
257 | ||||
258 | // transform bitcast to <store, load> instructions. | |||
259 | bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { | |||
260 | IRBuilder<> Builder(Bitcast); | |||
261 | AllocaInst *AllocaAddr; | |||
262 | Value *I8Ptr, *Stride; | |||
263 | auto *Src = Bitcast->getOperand(0); | |||
264 | ||||
265 | auto Prepare = [&](Type *MemTy) { | |||
266 | AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy); | |||
267 | I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy()); | |||
268 | Stride = Builder.getInt64(64); | |||
269 | }; | |||
270 | ||||
271 | if (Bitcast->getType()->isX86_AMXTy()) { | |||
272 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
273 | // --> | |||
274 | // %addr = alloca <256 x i32>, align 64 | |||
275 | // store <256 x i32> %src, <256 x i32>* %addr, align 64 | |||
276 | // %addr2 = bitcast <256 x i32>* to i8* | |||
277 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
278 | // i8* %addr2, | |||
279 | // i64 64) | |||
280 | Use &U = *(Bitcast->use_begin()); | |||
281 | unsigned OpNo = U.getOperandNo(); | |||
282 | auto *II = dyn_cast<IntrinsicInst>(U.getUser()); | |||
283 | if (!II) | |||
284 | return false; // May be bitcast from x86amx to <256 x i32>. | |||
285 | Prepare(Bitcast->getOperand(0)->getType()); | |||
286 | Builder.CreateStore(Src, AllocaAddr); | |||
287 | // TODO: we can pick a constant operand for the shape. | |||
288 | Value *Row = nullptr, *Col = nullptr; | |||
289 | std::tie(Row, Col) = getShape(II, OpNo); | |||
290 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; | |||
291 | Value *NewInst = Builder.CreateIntrinsic( | |||
292 | Intrinsic::x86_tileloadd64_internal, None, Args); | |||
293 | Bitcast->replaceAllUsesWith(NewInst); | |||
294 | } else { | |||
295 | // %2 = bitcast x86_amx %src to <256 x i32> | |||
296 | // --> | |||
297 | // %addr = alloca <256 x i32>, align 64 | |||
298 | // %addr2 = bitcast <256 x i32>* to i8* | |||
299 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, | |||
300 | // i8* %addr2, i64 %stride) | |||
301 | // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
302 | auto *II = dyn_cast<IntrinsicInst>(Src); | |||
303 | if (!II) | |||
304 | return false; // May be bitcast from <256 x i32> to x86amx. | |||
305 | Prepare(Bitcast->getType()); | |||
306 | Value *Row = II->getOperand(0); | |||
307 | Value *Col = II->getOperand(1); | |||
308 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src}; | |||
309 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
310 | Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr); | |||
311 | Bitcast->replaceAllUsesWith(NewInst); | |||
312 | } | |||
313 | ||||
314 | return true; | |||
315 | } | |||
316 | ||||
317 | bool X86LowerAMXType::visit() { | |||
318 | SmallVector<Instruction *, 8> DeadInsts; | |||
319 | Col2Row.clear(); | |||
320 | ||||
321 | for (BasicBlock *BB : post_order(&Func)) { | |||
322 | for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend(); | |||
323 | II != IE;) { | |||
324 | Instruction &Inst = *II++; | |||
325 | auto *Bitcast = dyn_cast<BitCastInst>(&Inst); | |||
326 | if (!Bitcast) | |||
327 | continue; | |||
328 | ||||
329 | Value *Src = Bitcast->getOperand(0); | |||
330 | if (Bitcast->getType()->isX86_AMXTy()) { | |||
331 | if (Bitcast->user_empty()) { | |||
332 | DeadInsts.push_back(Bitcast); | |||
333 | continue; | |||
334 | } | |||
335 | LoadInst *LD = dyn_cast<LoadInst>(Src); | |||
336 | if (!LD) { | |||
337 | if (transformBitcast(Bitcast)) | |||
338 | DeadInsts.push_back(Bitcast); | |||
339 | continue; | |||
340 | } | |||
341 | // If the load has multiple users, duplicate it as an AMX tile load. | |||
342 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
343 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
344 | // %add = add <256 x i32> %src, <256 x i32> %src2 | |||
345 | // --> | |||
346 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
347 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
348 | // i8* %addr, i64 %stride64) | |||
349 | // %add = add <256 x i32> %src, <256 x i32> %src2 | |||
350 | ||||
351 | // If the load has a single user, the load will be eliminated in DAG ISel. | |||
352 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
353 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
354 | // --> | |||
355 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
356 | // i8* %addr, i64 %stride64) | |||
357 | combineLoadBitcast(LD, Bitcast); | |||
358 | DeadInsts.push_back(Bitcast); | |||
359 | if (LD->hasOneUse()) | |||
360 | DeadInsts.push_back(LD); | |||
361 | } else if (Src->getType()->isX86_AMXTy()) { | |||
362 | if (Bitcast->user_empty()) { | |||
363 | DeadInsts.push_back(Bitcast); | |||
364 | continue; | |||
365 | } | |||
366 | StoreInst *ST = nullptr; | |||
367 | for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end(); | |||
368 | UI != UE;) { | |||
369 | Value *I = (UI++)->getUser(); | |||
370 | ST = dyn_cast<StoreInst>(I); | |||
371 | if (ST) | |||
372 | break; | |||
373 | } | |||
374 | if (!ST) { | |||
375 | if (transformBitcast(Bitcast)) | |||
376 | DeadInsts.push_back(Bitcast); | |||
377 | continue; | |||
378 | } | |||
379 | // If bitcast (%13) has one use, combine bitcast and store to amx store. | |||
380 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, | |||
381 | // %stride); | |||
382 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
383 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
384 | // --> | |||
385 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
386 | // %stride64, %13) | |||
387 | // | |||
388 | // If bitcast (%13) has multi-use, transform as below. | |||
389 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
390 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
391 | // %add = <256 x i32> %13, <256 x i32> %src2 | |||
392 | // --> | |||
393 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
394 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
395 | // %stride64, %13) | |||
396 | // %14 = load <256 x i32>, %addr | |||
397 | // %add = <256 x i32> %14, <256 x i32> %src2 | |||
398 | // | |||
399 | combineBitcastStore(Bitcast, ST); | |||
400 | // Delete user first. | |||
401 | DeadInsts.push_back(ST); | |||
402 | DeadInsts.push_back(Bitcast); | |||
403 | } | |||
404 | } | |||
405 | } | |||
406 | ||||
407 | bool C = !DeadInsts.empty(); | |||
408 | ||||
409 | for (auto *Inst : DeadInsts) | |||
410 | Inst->eraseFromParent(); | |||
411 | ||||
412 | return C; | |||
413 | } | |||
414 | } // anonymous namespace | |||
415 | ||||
416 | static Value *getAllocaPos(BasicBlock *BB) { | |||
417 | Module *M = BB->getModule(); | |||
418 | Function *F = BB->getParent(); | |||
419 | IRBuilder<> Builder(&F->getEntryBlock().front()); | |||
420 | const DataLayout &DL = M->getDataLayout(); | |||
421 | unsigned AllocaAS = DL.getAllocaAddrSpace(); | |||
422 | Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false); | |||
423 | AllocaInst *AllocaRes = | |||
424 | new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front()); | |||
425 | BasicBlock::iterator Iter = AllocaRes->getIterator(); | |||
426 | ++Iter; | |||
427 | Builder.SetInsertPoint(&*Iter); | |||
428 | Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy()); | |||
429 | return I8Ptr; | |||
430 | } | |||
431 | ||||
432 | static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { | |||
433 | assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!"); | |||
434 | auto *II = cast<IntrinsicInst>(TileDef); | |||
435 | assert(II && "Not tile intrinsic!"); | |||
436 | Value *Row = II->getOperand(0); | |||
437 | Value *Col = II->getOperand(1); | |||
438 | ||||
439 | BasicBlock *BB = TileDef->getParent(); | |||
440 | BasicBlock::iterator Iter = TileDef->getIterator(); | |||
441 | IRBuilder<> Builder(BB, ++Iter); | |||
442 | Value *Stride = Builder.getInt64(64); | |||
443 | std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef}; | |||
444 | ||||
445 | Instruction *TileStore = | |||
446 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
447 | return TileStore; | |||
448 | } | |||
449 | ||||
450 | static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { | |||
451 | Value *V = U.get(); | |||
452 | assert(V->getType()->isX86_AMXTy() && "Not define tile!"); | |||
453 | ||||
454 | // Get tile shape. | |||
455 | IntrinsicInst *II = nullptr; | |||
456 | if (IsPHI) { | |||
457 | Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0); | |||
458 | II = cast<IntrinsicInst>(PhiOp); | |||
459 | } else { | |||
460 | II = cast<IntrinsicInst>(V); | |||
461 | } | |||
462 | Value *Row = II->getOperand(0); | |||
463 | Value *Col = II->getOperand(1); | |||
464 | ||||
465 | Instruction *UserI = dyn_cast<Instruction>(U.getUser()); | |||
466 | IRBuilder<> Builder(UserI); | |||
467 | Value *Stride = Builder.getInt64(64); | |||
468 | std::array<Value *, 4> Args = {Row, Col, Ptr, Stride}; | |||
469 | ||||
470 | Value *TileLoad = | |||
471 | Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); | |||
472 | UserI->replaceUsesOfWith(V, TileLoad); | |||
473 | } | |||
474 | ||||
475 | static bool isIncomingOfPHI(Instruction *I) { | |||
476 | for (Use &U : I->uses()) { | |||
477 | User *V = U.getUser(); | |||
478 | if (isa<PHINode>(V)) | |||
479 | return true; | |||
480 | } | |||
481 | return false; | |||
482 | } | |||
483 | ||||
484 | // Let all AMX tile data become volatile data and shorten the live range | |||
485 | // of each tile register before fast register allocation. | |||
486 | namespace { | |||
487 | class X86VolatileTileData { | |||
488 | Function &F; | |||
489 | ||||
490 | public: | |||
491 | X86VolatileTileData(Function &Func) : F(Func) {} | |||
492 | Value *updatePhiIncomings(BasicBlock *BB, | |||
493 | SmallVector<Instruction *, 2> &Incomings); | |||
494 | void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr); | |||
495 | bool volatileTileData(); | |||
496 | void volatileTilePHI(PHINode *Inst); | |||
497 | void volatileTileNonPHI(Instruction *I); | |||
498 | }; | |||
499 | ||||
500 | Value *X86VolatileTileData::updatePhiIncomings( | |||
501 | BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { | |||
502 | Value *I8Ptr = getAllocaPos(BB); | |||
503 | ||||
504 | for (auto *I : Incomings) { | |||
505 | User *Store = createTileStore(I, I8Ptr); | |||
506 | ||||
507 | // All its uses (except phi) should load from stored mem. | |||
508 | for (Use &U : I->uses()) { | |||
509 | User *V = U.getUser(); | |||
510 | if (isa<PHINode>(V) || V == Store) | |||
511 | continue; | |||
512 | replaceWithTileLoad(U, I8Ptr); | |||
513 | } | |||
514 | } | |||
515 | return I8Ptr; | |||
516 | } | |||
517 | ||||
518 | void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI, | |||
519 | Value *StorePtr) { | |||
520 | for (Use &U : PHI->uses()) | |||
521 | replaceWithTileLoad(U, StorePtr, true); | |||
522 | PHI->eraseFromParent(); | |||
523 | } | |||
524 | ||||
525 | // Similar to volatileTileNonPHI, this function only handles PHI nodes | |||
526 | // and their related AMX intrinsics. | |||
527 | // 1) The PHI def should be changed to a tileload. | |||
528 | // 2) The PHI incoming values should be tilestored right after their defs. | |||
529 | // 3) The memory of these tileloads and tilestores should be the same. | |||
530 | // e.g. | |||
531 | // ------------------------------------------------------ | |||
532 | // bb_dom: | |||
533 | // ... | |||
534 | // br i1 %bool.cond, label %if.else, label %if.then | |||
535 | // | |||
536 | // if.then: | |||
537 | // def %t0 = ... | |||
538 | // ... | |||
539 | // use %t0 | |||
540 | // ... | |||
541 | // br label %if.end | |||
542 | // | |||
543 | // if.else: | |||
544 | // def %t1 = ... | |||
545 | // br label %if.end | |||
546 | // | |||
547 | // if.end: | |||
548 | // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ] | |||
549 | // ... | |||
550 | // use %td | |||
551 | // ------------------------------------------------------ | |||
552 | // --> | |||
553 | // ------------------------------------------------------ | |||
554 | // bb_entry: | |||
555 | // %mem = alloca <256 x i32>, align 1024 * | |||
556 | // ... | |||
557 | // bb_dom: | |||
558 | // ... | |||
559 | // br i1 %bool.cond, label %if.else, label %if.then | |||
560 | // | |||
561 | // if.then: | |||
562 | // def %t0 = ... | |||
563 | // call void @llvm.x86.tilestored64.internal(mem, %t0) * | |||
564 | // ... | |||
565 | // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)* | |||
566 | // use %t0` * | |||
567 | // ... | |||
568 | // br label %if.end | |||
569 | // | |||
570 | // if.else: | |||
571 | // def %t1 = ... | |||
572 | // call void @llvm.x86.tilestored64.internal(mem, %t1) * | |||
573 | // br label %if.end | |||
574 | // | |||
575 | // if.end: | |||
576 | // ... | |||
577 | // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) * | |||
578 | // use %td | |||
579 | // ------------------------------------------------------ | |||
580 | void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { | |||
581 | BasicBlock *BB = PHI->getParent(); | |||
| ||||
582 | SmallVector<Instruction *, 2> Incomings; | |||
583 | ||||
584 | for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { | |||
585 | Value *Op = PHI->getIncomingValue(I); | |||
586 | Instruction *Inst = dyn_cast<Instruction>(Op); | |||
587 | assert(Inst && "We shouldn't fold AMX instruction!"); | |||
588 | Incomings.push_back(Inst); | |||
589 | } | |||
590 | ||||
591 | Value *StorePtr = updatePhiIncomings(BB, Incomings); | |||
592 | replacePhiDefWithLoad(PHI, StorePtr); | |||
593 | } | |||
594 | ||||
595 | // Store the defined tile and load it before use. | |||
596 | // All its users are not PHI. | |||
597 | // e.g. | |||
598 | // ------------------------------------------------------ | |||
599 | // def %td = ... | |||
600 | // ... | |||
601 | // "use %td" | |||
602 | // ------------------------------------------------------ | |||
603 | // --> | |||
604 | // ------------------------------------------------------ | |||
605 | // def %td = ... | |||
606 | // call void @llvm.x86.tilestored64.internal(mem, %td) | |||
607 | // ... | |||
608 | // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem) | |||
609 | // "use %td2" | |||
610 | // ------------------------------------------------------ | |||
611 | void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { | |||
612 | BasicBlock *BB = I->getParent(); | |||
613 | Value *I8Ptr = getAllocaPos(BB); | |||
614 | User *Store = createTileStore(I, I8Ptr); | |||
615 | ||||
616 | // All its uses should load from stored mem. | |||
617 | for (Use &U : I->uses()) { | |||
618 | User *V = U.getUser(); | |||
619 | assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!"); | |||
620 | if (V != Store) | |||
621 | replaceWithTileLoad(U, I8Ptr); | |||
622 | } | |||
623 | } | |||
624 | ||||
625 | // Volatile Tile Model: | |||
626 | // 1) All uses of tile data come from a tileload right before the use. | |||
627 | // 2) All defs of tile data are tilestored to memory immediately after the def. | |||
628 | // For example: | |||
629 | // -------------------------------------------------------------------------- | |||
630 | // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key | |||
631 | // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) | |||
632 | // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx | |||
633 | // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) | |||
634 | // call void @llvm.x86.tilestored64.internal(... td) area | |||
635 | // -------------------------------------------------------------------------- | |||
636 | // 3) No terminator, call or other amx instructions in the key amx area. | |||
637 | bool X86VolatileTileData::volatileTileData() { | |||
638 | bool Changed = false; | |||
639 | for (BasicBlock &BB : F) { | |||
640 | SmallVector<Instruction *, 2> PHIInsts; | |||
641 | SmallVector<Instruction *, 8> AMXDefInsts; | |||
642 | ||||
643 | for (Instruction &I : BB) { | |||
644 | if (!I.getType()->isX86_AMXTy()) | |||
645 | continue; | |||
646 | if (isa<PHINode>(&I)) | |||
647 | PHIInsts.push_back(&I); | |||
648 | else | |||
649 | AMXDefInsts.push_back(&I); | |||
650 | } | |||
651 | ||||
652 | // First we "volatile" the non-phi related amx intrinsics. | |||
653 | for (Instruction *I : AMXDefInsts) { | |||
654 | if (isIncomingOfPHI(I)) | |||
655 | continue; | |||
656 | volatileTileNonPHI(I); | |||
657 | Changed = true; | |||
658 | } | |||
659 | ||||
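// Note: PHIInsts only contains instructions that satisfied isa<PHINode> in
// the collection loop above, so the dyn_cast in the call below cannot
// actually return null here.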
660 | for (Instruction *I : PHIInsts) { | |||
661 | volatileTilePHI(dyn_cast<PHINode>(I)); | |||
662 | Changed = true; | |||
663 | } | |||
664 | } | |||
665 | return Changed; | |||
666 | } | |||
667 | ||||
668 | } // anonymous namespace | |||
669 | ||||
670 | namespace { | |||
671 | ||||
672 | class X86LowerAMXCast { | |||
673 | Function &Func; | |||
674 | ||||
675 | public: | |||
676 | X86LowerAMXCast(Function &F) : Func(F) {} | |||
677 | bool combineAMXcast(TargetLibraryInfo *TLI); | |||
678 | bool transformAMXCast(IntrinsicInst *AMXCast); | |||
679 | bool transformAllAMXCast(); | |||
680 | bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN, | |||
681 | SmallSetVector<Instruction *, 16> &DeadInst); | |||
682 | }; | |||
683 | ||||
684 | static bool DCEInstruction(Instruction *I, | |||
685 | SmallSetVector<Instruction *, 16> &WorkList, | |||
686 | const TargetLibraryInfo *TLI) { | |||
687 | if (isInstructionTriviallyDead(I, TLI)) { | |||
688 | salvageDebugInfo(*I); | |||
689 | salvageKnowledge(I); | |||
690 | ||||
691 | // Null out all of the instruction's operands to see if any operand becomes | |||
692 | // dead as we go. | |||
693 | for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { | |||
694 | Value *OpV = I->getOperand(i); | |||
695 | I->setOperand(i, nullptr); | |||
696 | ||||
697 | if (!OpV->use_empty() || I == OpV) | |||
698 | continue; | |||
699 | ||||
700 | // If the operand is an instruction that became dead as we nulled out the | |||
701 | // operand, and if it is 'trivially' dead, delete it in a future loop | |||
702 | // iteration. | |||
703 | if (Instruction *OpI = dyn_cast<Instruction>(OpV)) { | |||
704 | if (isInstructionTriviallyDead(OpI, TLI)) { | |||
705 | WorkList.insert(OpI); | |||
706 | } | |||
707 | } | |||
708 | } | |||
709 | I->eraseFromParent(); | |||
710 | return true; | |||
711 | } | |||
712 | return false; | |||
713 | } | |||
714 | ||||
715 | /// This function handles the following case: | |||
716 | /// | |||
717 | /// A -> B amxcast | |||
718 | /// PHI | |||
719 | /// B -> A amxcast | |||
720 | /// | |||
721 | /// All the related PHI nodes can be replaced by new PHI nodes with type A. | |||
722 | /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. | |||
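/// A minimal sketch of the rewrite (value names and intrinsic manglings are
/// illustrative; the actual mangling follows the vector type involved):
///   %v0 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t0)
///   %v1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///   %p  = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
///   %r  = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
/// -->
///   %p  = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]
/// with the casts becoming trivially dead afterwards.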
723 | bool X86LowerAMXCast::optimizeAMXCastFromPhi( | |||
724 | IntrinsicInst *CI, PHINode *PN, | |||
725 | SmallSetVector<Instruction *, 16> &DeadInst) { | |||
726 | IRBuilder<> Builder(CI); | |||
727 | Value *Src = CI->getOperand(0); | |||
728 | Type *SrcTy = Src->getType(); // Type B | |||
729 | Type *DestTy = CI->getType(); // Type A | |||
730 | ||||
731 | SmallVector<PHINode *, 4> PhiWorklist; | |||
732 | SmallSetVector<PHINode *, 4> OldPhiNodes; | |||
733 | ||||
734 | // Find all of the A->B casts and PHI nodes. | |||
735 | // We need to inspect all related PHI nodes, but PHIs can be cyclic, so | |||
736 | // OldPhiNodes is used to track all known PHI nodes; before adding a new | |||
737 | // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. | |||
738 | PhiWorklist.push_back(PN); | |||
739 | OldPhiNodes.insert(PN); | |||
740 | while (!PhiWorklist.empty()) { | |||
741 | auto *OldPN = PhiWorklist.pop_back_val(); | |||
742 | for (Value *IncValue : OldPN->incoming_values()) { | |||
743 | // TODO: currently, we ignore cases where the incoming value is a constant. | |||
744 | // In the future, we might support constants. | |||
745 | if (isa<Constant>(IncValue)) | |||
746 | return false; | |||
747 | ||||
748 | if (auto *PNode = dyn_cast<PHINode>(IncValue)) { | |||
749 | if (OldPhiNodes.insert(PNode)) | |||
750 | PhiWorklist.push_back(PNode); | |||
751 | continue; | |||
752 | } | |||
753 | Instruction *ACI = dyn_cast<Instruction>(IncValue); | |||
754 | if (ACI && isAMXCast(ACI)) { | |||
755 | // Verify it's an A->B cast. | |||
756 | Type *TyA = ACI->getOperand(0)->getType(); | |||
757 | Type *TyB = ACI->getType(); | |||
758 | if (TyA != DestTy || TyB != SrcTy) | |||
759 | return false; | |||
760 | continue; | |||
761 | } | |||
762 | return false; | |||
763 | } | |||
764 | } | |||
765 | ||||
766 | // Check that each user of each old PHI node is something that we can | |||
767 | // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. | |||
768 | for (auto *OldPN : OldPhiNodes) { | |||
769 | for (User *V : OldPN->users()) { | |||
770 | Instruction *ACI = dyn_cast<Instruction>(V); | |||
771 | if (ACI && isAMXCast(ACI)) { | |||
772 | // Verify it's a B->A cast. | |||
773 | Type *TyB = ACI->getOperand(0)->getType(); | |||
774 | Type *TyA = ACI->getType(); | |||
775 | if (TyA != DestTy || TyB != SrcTy) | |||
776 | return false; | |||
777 | } else if (auto *PHI = dyn_cast<PHINode>(V)) { | |||
778 | // As long as the user is another old PHI node, then even if we don't | |||
779 | // rewrite it, the PHI web we're considering won't have any users | |||
780 | // outside itself, so it'll be dead. | |||
781 | // example: | |||
782 | // bb.0: | |||
783 | // %0 = amxcast ... | |||
784 | // bb.1: | |||
785 | // %1 = amxcast ... | |||
786 | // bb.2: | |||
787 | // %goodphi = phi %0, %1 | |||
788 | // %3 = amxcast %goodphi | |||
789 | // bb.3: | |||
790 | // %goodphi2 = phi %0, %goodphi | |||
791 | // %4 = amxcast %goodphi2 | |||
792 | // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is | |||
793 | // outside the phi-web, so the combination stops. When | |||
794 | // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization | |||
795 | // will be done. | |||
796 | if (OldPhiNodes.count(PHI) == 0) | |||
797 | return false; | |||
798 | } else | |||
799 | return false; | |||
800 | } | |||
801 | } | |||
802 | ||||
803 | // For each old PHI node, create a corresponding new PHI node with type A. | |||
804 | SmallDenseMap<PHINode *, PHINode *> NewPNodes; | |||
805 | for (auto *OldPN : OldPhiNodes) { | |||
806 | Builder.SetInsertPoint(OldPN); | |||
807 | PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands()); | |||
808 | NewPNodes[OldPN] = NewPN; | |||
809 | } | |||
810 | ||||
811 | // Fill in the operands of new PHI nodes. | |||
812 | for (auto *OldPN : OldPhiNodes) { | |||
813 | PHINode *NewPN = NewPNodes[OldPN]; | |||
814 | for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { | |||
815 | Value *V = OldPN->getOperand(j); | |||
816 | Value *NewV = nullptr; | |||
817 | Instruction *ACI = dyn_cast<Instruction>(V); | |||
818 | // There should not be an AMXcast from a constant. | |||
819 | if (ACI && isAMXCast(ACI)) | |||
820 | NewV = ACI->getOperand(0); | |||
821 | else if (auto *PrevPN = dyn_cast<PHINode>(V)) | |||
822 | NewV = NewPNodes[PrevPN]; | |||
823 | assert(NewV); | |||
824 | NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j)); | |||
825 | } | |||
826 | } | |||
827 | ||||
828 | // Traverse all accumulated PHI nodes and process their users, | |||
829 | // which are Stores and BitCasts. Without this processing | |||
830 | // NewPHI nodes could be replicated and could lead to extra | |||
831 | // moves generated after DeSSA. | |||
832 | // If there is a store with type B, change it to type A. | |||
833 | ||||
834 | // Replace users of BitCast B->A with NewPHI. These will help | |||
835 | // later to get rid of a closure formed by OldPHI nodes. | |||
836 | for (auto *OldPN : OldPhiNodes) { | |||
837 | PHINode *NewPN = NewPNodes[OldPN]; | |||
838 | for (User *V : make_early_inc_range(OldPN->users())) { | |||
839 | Instruction *ACI = dyn_cast<Instruction>(V); | |||
840 | if (ACI && isAMXCast(ACI)) { | |||
841 | Type *TyB = ACI->getOperand(0)->getType(); | |||
842 | Type *TyA = ACI->getType(); | |||
843 | assert(TyA == DestTy && TyB == SrcTy); | |||
844 | (void)TyA; | |||
845 | (void)TyB; | |||
846 | ACI->replaceAllUsesWith(NewPN); | |||
847 | DeadInst.insert(ACI); | |||
848 | } else if (auto *PHI = dyn_cast<PHINode>(V)) { | |||
849 | // We don't need to push the PHINode into DeadInst since they are operands | |||
850 | // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead. | |||
851 | assert(OldPhiNodes.contains(PHI)); | |||
852 | (void)PHI; | |||
853 | } else | |||
854 | llvm_unreachable("all uses should be handled"); | |||
855 | } | |||
856 | } | |||
857 | return true; | |||
858 | } | |||
859 | ||||
860 | bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) { | |||
861 | bool Change = false; | |||
862 | // Collect the tile cast instructions. | |||
863 | SmallVector<Instruction *, 8> Vec2TileInsts; | |||
864 | SmallVector<Instruction *, 8> Tile2VecInsts; | |||
865 | SmallVector<Instruction *, 8> PhiCastWorkList; | |||
866 | SmallSetVector<Instruction *, 16> DeadInst; | |||
867 | for (BasicBlock &BB : Func) { | |||
868 | for (Instruction &I : BB) { | |||
869 | Value *Vec; | |||
870 | if (match(&I, | |||
871 | m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec)))) | |||
872 | Vec2TileInsts.push_back(&I); | |||
873 | else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>( | |||
874 | m_Value(Vec)))) | |||
875 | Tile2VecInsts.push_back(&I); | |||
876 | } | |||
877 | } | |||
878 | ||||
879 | auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) { | |||
880 | for (auto *Inst : Insts) { | |||
881 | for (User *U : Inst->users()) { | |||
882 | IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); | |||
883 | if (!II || II->getIntrinsicID() != IID) | |||
884 | continue; | |||
885 | // T1 = vec2tile V0 | |||
886 | // V2 = tile2vec T1 | |||
887 | // V3 = OP V2 | |||
888 | // --> | |||
889 | // T1 = vec2tile V0 | |||
890 | // V2 = tile2vec T1 | |||
891 | // V3 = OP V0 | |||
892 | II->replaceAllUsesWith(Inst->getOperand(0)); | |||
893 | Change = true; | |||
894 | } | |||
895 | } | |||
896 | }; | |||
897 | ||||
898 | Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector); | |||
899 | Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile); | |||
900 | ||||
901 | auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) { | |||
902 | for (auto *Inst : Insts) { | |||
903 | if (Inst->use_empty()) { | |||
904 | Inst->eraseFromParent(); | |||
905 | Change = true; | |||
906 | } | |||
907 | } | |||
908 | }; | |||
909 | ||||
910 | EraseInst(Vec2TileInsts); | |||
911 | EraseInst(Tile2VecInsts); | |||
912 | ||||
913 | // Handle the A->B->A cast where there is an intervening PHI node. | |||
914 | for (BasicBlock &BB : Func) { | |||
915 | for (Instruction &I : BB) { | |||
916 | if (isAMXCast(&I)) { | |||
917 | if (isa<PHINode>(I.getOperand(0))) | |||
918 | PhiCastWorkList.push_back(&I); | |||
919 | } | |||
920 | } | |||
921 | } | |||
922 | for (auto *I : PhiCastWorkList) { | |||
923 | // Skip dead AMXcasts. | |||
924 | if (DeadInst.contains(I)) | |||
925 | continue; | |||
926 | PHINode *PN = cast<PHINode>(I->getOperand(0)); | |||
927 | if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) { | |||
928 | DeadInst.insert(PN); | |||
929 | Change = true; | |||
930 | } | |||
931 | } | |||
932 | ||||
933 | // Since we create new phis and merge AMXCasts, some old phis and AMXCasts | |||
934 | // might have no uses. We do some dead code elimination for them. | |||
935 | while (!DeadInst.empty()) { | |||
936 | Instruction *I = DeadInst.pop_back_val(); | |||
937 | Change |= DCEInstruction(I, DeadInst, TLI); | |||
938 | } | |||
939 | return Change; | |||
940 | } | |||
941 | ||||
942 | // There might be remaining AMXcasts after combineAMXcast, and they should | |||
943 | // be handled gracefully. | |||
944 | bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) { | |||
945 | IRBuilder<> Builder(AMXCast); | |||
946 | AllocaInst *AllocaAddr; | |||
947 | Value *I8Ptr, *Stride; | |||
948 | auto *Src = AMXCast->getOperand(0); | |||
949 | ||||
950 | auto Prepare = [&](Type *MemTy) { | |||
951 | AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy); | |||
952 | I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy()); | |||
953 | Stride = Builder.getInt64(64); | |||
954 | }; | |||
955 | ||||
956 | if (AMXCast->getType()->isX86_AMXTy()) { | |||
957 | // %2 = amxcast <225 x i32> %src to x86_amx | |||
958 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, | |||
959 | // i8* %addr3, i64 60, x86_amx %2) | |||
960 | // --> | |||
961 | // %addr = alloca <225 x i32>, align 64 | |||
962 | // store <225 x i32> %src, <225 x i32>* %addr, align 64 | |||
963 | // %addr2 = bitcast <225 x i32>* %addr to i8* | |||
964 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60, | |||
965 | // i8* %addr2, | |||
966 | // i64 60) | |||
967 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, | |||
968 | // i8* %addr3, i64 60, x86_amx %2) | |||
969 | Use &U = *(AMXCast->use_begin()); | |||
970 | unsigned OpNo = U.getOperandNo(); | |||
971 | auto *II = dyn_cast<IntrinsicInst>(U.getUser()); | |||
972 | if (!II) | |||
973 | return false; // May be bitcast from x86amx to <256 x i32>. | |||
974 | Prepare(AMXCast->getOperand(0)->getType()); | |||
975 | Builder.CreateStore(Src, AllocaAddr); | |||
976 | // TODO: we can pick a constant operand for the shape. | |||
977 | Value *Row = nullptr, *Col = nullptr; | |||
978 | std::tie(Row, Col) = getShape(II, OpNo); | |||
979 | std::array<Value *, 4> Args = { | |||
980 | Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())}; | |||
981 | Value *NewInst = Builder.CreateIntrinsic( | |||
982 | Intrinsic::x86_tileloadd64_internal, None, Args); | |||
983 | AMXCast->replaceAllUsesWith(NewInst); | |||
984 | AMXCast->eraseFromParent(); | |||
985 | } else { | |||
986 | // %2 = amxcast x86_amx %src to <225 x i32> | |||
987 | // --> | |||
988 | // %addr = alloca <225 x i32>, align 64 | |||
989 | // %addr2 = bitcast <225 x i32>* to i8* | |||
990 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, | |||
991 | // i8* %addr2, i64 %stride) | |||
992 | // %2 = load <225 x i32>, <225 x i32>* %addr, align 64 | |||
993 | auto *II = dyn_cast<IntrinsicInst>(Src); | |||
994 | if (!II) | |||
995 | return false; // May be bitcast from <256 x i32> to x86amx. | |||
996 | Prepare(AMXCast->getType()); | |||
997 | Value *Row = II->getOperand(0); | |||
998 | Value *Col = II->getOperand(1); | |||
999 | std::array<Value *, 5> Args = { | |||
1000 | Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src}; | |||
1001 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
1002 | Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr); | |||
1003 | AMXCast->replaceAllUsesWith(NewInst); | |||
1004 | AMXCast->eraseFromParent(); | |||
1005 | } | |||
1006 | ||||
1007 | return true; | |||
1008 | } | |||
1009 | ||||
1010 | bool X86LowerAMXCast::transformAllAMXCast() { | |||
1011 | bool Change = false; | |||
1012 | // Collect the tile cast instructions. | |||
1013 | SmallVector<Instruction *, 8> WorkLists; | |||
1014 | for (BasicBlock &BB : Func) { | |||
1015 | for (Instruction &I : BB) { | |||
1016 | if (isAMXCast(&I)) | |||
1017 | WorkLists.push_back(&I); | |||
1018 | } | |||
1019 | } | |||
1020 | ||||
1021 | for (auto *Inst : WorkLists) { | |||
1022 | Change |= transformAMXCast(cast<IntrinsicInst>(Inst)); | |||
1023 | } | |||
1024 | ||||
1025 | return Change; | |||
1026 | } | |||
1027 | ||||
1028 | } // anonymous namespace | |||
1029 | ||||
1030 | namespace { | |||
1031 | ||||
1032 | class X86LowerAMXTypeLegacyPass : public FunctionPass { | |||
1033 | public: | |||
1034 | static char ID; | |||
1035 | ||||
1036 | X86LowerAMXTypeLegacyPass() : FunctionPass(ID) { | |||
1037 | initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry()); | |||
1038 | } | |||
1039 | ||||
1040 | bool runOnFunction(Function &F) override { | |||
1041 | bool C = false; | |||
1042 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); | |||
1043 | TargetLibraryInfo *TLI = | |||
1044 | &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); | |||
1045 | X86LowerAMXCast LAC(F); | |||
1046 | C |= LAC.combineAMXcast(TLI); | |||
1047 | // There might be remaining AMXcasts after combineAMXcast, and they should | |||
1048 | // be handled gracefully. | |||
1049 | C |= LAC.transformAllAMXCast(); | |||
1050 | ||||
1051 | X86LowerAMXType LAT(F); | |||
1052 | C |= LAT.visit(); | |||
1053 | ||||
1054 | // Prepare for fast register allocation at O0. | |||
1055 | // TODO: it may be better to check the volatile model of the AMX code, not | |||
1056 | // just check Attribute::OptimizeNone and CodeGenOpt::None. | |||
1057 | if (TM->getOptLevel() == CodeGenOpt::None) { | |||
| ||||
1058 | // If the front end does not use O0 but the mid/back end uses O0 (e.g. | |||
1059 | // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the | |||
1060 | // amx data is volatile; that is necessary for AMX fast register | |||
1061 | // allocation. | |||
1062 | if (!F.hasFnAttribute(Attribute::OptimizeNone)) { | |||
1063 | X86VolatileTileData VTD(F); | |||
1064 | C = VTD.volatileTileData() || C; | |||
1065 | } | |||
1066 | } | |||
1067 | ||||
1068 | return C; | |||
1069 | } | |||
1070 | ||||
1071 | void getAnalysisUsage(AnalysisUsage &AU) const override { | |||
1072 | AU.setPreservesCFG(); | |||
1073 | AU.addRequired<TargetPassConfig>(); | |||
1074 | AU.addRequired<TargetLibraryInfoWrapperPass>(); | |||
1075 | } | |||
1076 | }; | |||
1077 | ||||
1078 | } // anonymous namespace | |||
1079 | ||||
1080 | static const char PassName[] = "Lower AMX type for load/store"; | |||
1081 | char X86LowerAMXTypeLegacyPass::ID = 0; | |||
1082 | INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, | |||
1083 | false) | |||
1084 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) | |||
1085 | INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) | |||
1086 | INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, | |||
1087 | false) | |||
1088 | ||||
1089 | FunctionPass *llvm::createX86LowerAMXTypePass() { | |||
1090 | return new X86LowerAMXTypeLegacyPass(); | |||
1091 | } |