File: build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/llvm/lib/Target/X86/X86LowerAMXType.cpp
Warning: line 610, column 20: Called C++ object pointer is null
1 | //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | /// \file Pass to transform <256 x i32> load/store | |||
10 | /// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set only | |||
11 | /// provides simple operations on x86_amx. Basic elementwise operations | |||
12 | /// are not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32> | |||
13 | /// and only AMX intrinsics can operate on the type, we need to transform | |||
14 | /// load/store <256 x i32> instructions to AMX load/store. If the bitcast | |||
15 | /// cannot be combined with the load/store, we transform the bitcast to an amx | |||
16 | /// load/store and a <256 x i32> store/load. | |||
17 | /// | |||
18 | /// If the Front End does not use O0 but the Mid/Back end uses O0 (e.g. "Clang -O2 -S | |||
19 | /// -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is volatile, | |||
20 | /// because that is necessary for AMX fast register allocation. (In fast | |||
21 | /// register allocation, registers are allocated before spill/reload, so | |||
22 | /// there is no additional register for amx to identify the step in spill.) | |||
23 | /// The volatileTileData() will handle this case. | |||
24 | /// e.g. | |||
25 | /// ---------------------------------------------------------- | |||
26 | /// | def %td = ... | | |||
27 | /// | ... | | |||
28 | /// | "use %td" | | |||
29 | /// ---------------------------------------------------------- | |||
30 | /// will transfer to --> | |||
31 | /// ---------------------------------------------------------- | |||
32 | /// | def %td = ... | | |||
33 | /// | call void @llvm.x86.tilestored64.internal(mem, %td) | | |||
34 | /// | ... | | |||
35 | /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)| | |||
36 | /// | "use %td2" | | |||
37 | /// ---------------------------------------------------------- | |||
38 | // | |||
39 | //===----------------------------------------------------------------------===// | |||
40 | // | |||
41 | #include "X86.h" | |||
42 | #include "llvm/ADT/PostOrderIterator.h" | |||
43 | #include "llvm/ADT/SetVector.h" | |||
44 | #include "llvm/ADT/SmallSet.h" | |||
45 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" | |||
46 | #include "llvm/Analysis/TargetLibraryInfo.h" | |||
47 | #include "llvm/Analysis/TargetTransformInfo.h" | |||
48 | #include "llvm/CodeGen/Passes.h" | |||
49 | #include "llvm/CodeGen/TargetPassConfig.h" | |||
50 | #include "llvm/CodeGen/ValueTypes.h" | |||
51 | #include "llvm/IR/DataLayout.h" | |||
52 | #include "llvm/IR/Function.h" | |||
53 | #include "llvm/IR/IRBuilder.h" | |||
54 | #include "llvm/IR/Instructions.h" | |||
55 | #include "llvm/IR/IntrinsicInst.h" | |||
56 | #include "llvm/IR/IntrinsicsX86.h" | |||
57 | #include "llvm/IR/PatternMatch.h" | |||
58 | #include "llvm/InitializePasses.h" | |||
59 | #include "llvm/Pass.h" | |||
60 | #include "llvm/Target/TargetMachine.h" | |||
61 | #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" | |||
62 | #include "llvm/Transforms/Utils/Local.h" | |||
63 | ||||
64 | #include <map> | |||
65 | ||||
66 | using namespace llvm; | |||
67 | using namespace PatternMatch; | |||
68 | ||||
69 | #define DEBUG_TYPE "lower-amx-type" | |||
70 | ||||
71 | static bool isAMXCast(Instruction *II) { | |||
72 | return match(II, | |||
73 | m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) || | |||
74 | match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value())); | |||
75 | } | |||
76 | ||||
77 | static bool isAMXInstrinsic(User *I) { | |||
78 | auto *II = dyn_cast<IntrinsicInst>(I); | |||
79 | if (!II) | |||
80 | return false; | |||
81 | if (isAMXCast(II)) | |||
82 | return false; | |||
83 | // Check if the return type or any parameter is x86_amx. If it is, | |||
84 | // the intrinsic must be an x86 AMX intrinsic. | |||
85 | if (II->getType()->isX86_AMXTy()) | |||
86 | return true; | |||
87 | for (Value *V : II->args()) { | |||
88 | if (V->getType()->isX86_AMXTy()) | |||
89 | return true; | |||
90 | } | |||
91 | ||||
92 | return false; | |||
93 | } | |||
94 | ||||
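// Create an alloca of type Ty at the start of the function's entry block,
// aligned to the preferred alignment of the x86_amx type.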
95 | static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB, | |||
96 | Type *Ty) { | |||
97 | Function &F = *BB->getParent(); | |||
98 | Module *M = BB->getModule(); | |||
99 | const DataLayout &DL = M->getDataLayout(); | |||
100 | ||||
101 | LLVMContext &Ctx = Builder.getContext(); | |||
102 | auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx)); | |||
103 | unsigned AllocaAS = DL.getAllocaAddrSpace(); | |||
104 | AllocaInst *AllocaRes = | |||
105 | new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front()); | |||
106 | AllocaRes->setAlignment(AllocaAlignment); | |||
107 | return AllocaRes; | |||
108 | } | |||
109 | ||||
110 | static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { | |||
111 | for (Instruction &I : F.getEntryBlock()) | |||
112 | if (!isa<AllocaInst>(&I)) | |||
113 | return &I; | |||
114 | llvm_unreachable("No terminator in the entry block!"); | |||
115 | } | |||
116 | ||||
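// Return the {Row, Col} shape that operand OpNo of the AMX intrinsic II must
// have. For the tdp* intrinsics the shape depends on which tile operand
// (accumulator or one of the two sources) is being queried.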
117 | static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) { | |||
118 | IRBuilder<> Builder(II); | |||
119 | Value *Row = nullptr, *Col = nullptr; | |||
120 | switch (II->getIntrinsicID()) { | |||
121 | default: | |||
122 | llvm_unreachable("Expect amx intrinsics"); | |||
123 | case Intrinsic::x86_tileloadd64_internal: | |||
124 | case Intrinsic::x86_tileloaddt164_internal: | |||
125 | case Intrinsic::x86_tilestored64_internal: { | |||
126 | Row = II->getArgOperand(0); | |||
127 | Col = II->getArgOperand(1); | |||
128 | break; | |||
129 | } | |||
130 | // a * b + c | |||
131 | // The shape depends on which operand of the intrinsic it is. | |||
132 | case Intrinsic::x86_tdpbssd_internal: | |||
133 | case Intrinsic::x86_tdpbsud_internal: | |||
134 | case Intrinsic::x86_tdpbusd_internal: | |||
135 | case Intrinsic::x86_tdpbuud_internal: | |||
136 | case Intrinsic::x86_tdpbf16ps_internal: { | |||
137 | switch (OpNo) { | |||
138 | case 3: | |||
139 | Row = II->getArgOperand(0); | |||
140 | Col = II->getArgOperand(1); | |||
141 | break; | |||
142 | case 4: | |||
143 | Row = II->getArgOperand(0); | |||
144 | Col = II->getArgOperand(2); | |||
145 | break; | |||
146 | case 5: | |||
147 | if (isa<ConstantInt>(II->getArgOperand(2))) | |||
148 | Row = Builder.getInt16( | |||
149 | (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4); | |||
150 | else if (isa<Instruction>(II->getArgOperand(2))) { | |||
151 | // When it is not a const value and it is not a function argument, we | |||
152 | // create Row after the definition of II->getOperand(2) instead of | |||
153 | // before II. For example, II is %118 and we try to get the shape of %117: | |||
154 | // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x | |||
155 | // i32> %115). | |||
156 | // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 | |||
157 | // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx | |||
158 | // %117). | |||
159 | // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its | |||
160 | // definition is after its user(new tileload for %117). | |||
161 | // So, the best choice is to create %row right after the definition of | |||
162 | // %106. | |||
163 | Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2))); | |||
164 | Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4)); | |||
165 | cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2))); | |||
166 | } else { | |||
167 | // When it is not a const value and it is a function argument, we create | |||
168 | // Row at the entry bb. | |||
169 | IRBuilder<> NewBuilder( | |||
170 | getFirstNonAllocaInTheEntryBlock(*II->getFunction())); | |||
171 | Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4)); | |||
172 | } | |||
173 | Col = II->getArgOperand(1); | |||
174 | break; | |||
175 | } | |||
176 | break; | |||
177 | } | |||
178 | } | |||
179 | ||||
180 | return std::make_pair(Row, Col); | |||
181 | } | |||
182 | ||||
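// Try to deduce the tile shape of a PHI node by following its first-use chain
// through AMX casts and other PHIs until an AMX intrinsic is reached; return
// {nullptr, nullptr} if no shape can be found.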
183 | static std::pair<Value *, Value *> getShape(PHINode *Phi) { | |||
184 | Use &U = *(Phi->use_begin()); | |||
185 | unsigned OpNo = U.getOperandNo(); | |||
186 | User *V = U.getUser(); | |||
187 | // TODO We don't traverse all users. To make the algorithm simple, here we | |||
188 | // just traverse the first user. If we can find the shape, then return it; | |||
189 | // otherwise just return nullptr and the optimization for undef/zero will be | |||
190 | // abandoned. | |||
191 | while (V) { | |||
192 | if (isAMXCast(dyn_cast<Instruction>(V))) { | |||
193 | if (V->use_empty()) | |||
194 | break; | |||
195 | Use &U = *(V->use_begin()); | |||
196 | OpNo = U.getOperandNo(); | |||
197 | V = U.getUser(); | |||
198 | } else if (isAMXInstrinsic(V)) { | |||
199 | return getShape(cast<IntrinsicInst>(V), OpNo); | |||
200 | } else if (isa<PHINode>(V)) { | |||
201 | if (V->use_empty()) | |||
202 | break; | |||
203 | Use &U = *(V->use_begin()); | |||
204 | V = U.getUser(); | |||
205 | } else { | |||
206 | break; | |||
207 | } | |||
208 | } | |||
209 | ||||
210 | return std::make_pair(nullptr, nullptr); | |||
211 | } | |||
212 | ||||
213 | namespace { | |||
214 | class X86LowerAMXType { | |||
215 | Function &Func; | |||
216 | ||||
217 | // In AMX intrinsics we let Shape = {Row, Col}, but the | |||
218 | // RealCol = Col / ElementSize. We may use the RealCol | |||
219 | // as a new Row for other newly created AMX intrinsics. | |||
220 | std::map<Value *, Value *> Col2Row; | |||
221 | ||||
222 | public: | |||
223 | X86LowerAMXType(Function &F) : Func(F) {} | |||
224 | bool visit(); | |||
225 | void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); | |||
226 | void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); | |||
227 | bool transformBitcast(BitCastInst *Bitcast); | |||
228 | }; | |||
229 | ||||
230 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
231 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
232 | // --> | |||
233 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
234 | // i8* %addr, i64 %stride64) | |||
235 | void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { | |||
236 | Value *Row = nullptr, *Col = nullptr; | |||
237 | Use &U = *(Bitcast->use_begin()); | |||
238 | unsigned OpNo = U.getOperandNo(); | |||
239 | auto *II = cast<IntrinsicInst>(U.getUser()); | |||
240 | std::tie(Row, Col) = getShape(II, OpNo); | |||
241 | IRBuilder<> Builder(Bitcast); | |||
242 | // Use the maximum column as stride. | |||
243 | Value *Stride = Builder.getInt64(64); | |||
244 | Value *I8Ptr = | |||
245 | Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy()); | |||
246 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; | |||
247 | ||||
248 | Value *NewInst = | |||
249 | Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); | |||
250 | Bitcast->replaceAllUsesWith(NewInst); | |||
251 | } | |||
252 | ||||
253 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, | |||
254 | // %stride); | |||
255 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
256 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
257 | // --> | |||
258 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
259 | // %stride64, %13) | |||
260 | void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { | |||
261 | ||||
262 | Value *Tile = Bitcast->getOperand(0); | |||
263 | auto *II = cast<IntrinsicInst>(Tile); | |||
264 | // Tile is output from AMX intrinsic. The first operand of the | |||
265 | // intrinsic is row, the second operand of the intrinsic is column. | |||
266 | Value *Row = II->getOperand(0); | |||
267 | Value *Col = II->getOperand(1); | |||
268 | IRBuilder<> Builder(ST); | |||
269 | // Use the maximum column as stride. It must be the same as the load | |||
270 | // stride. | |||
271 | Value *Stride = Builder.getInt64(64); | |||
272 | Value *I8Ptr = | |||
273 | Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy()); | |||
274 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; | |||
275 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
276 | if (Bitcast->hasOneUse()) | |||
277 | return; | |||
278 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
279 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
280 | // %add = <256 x i32> %13, <256 x i32> %src2 | |||
281 | // --> | |||
282 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
283 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
284 | // %stride64, %13) | |||
285 | // %14 = load <256 x i32>, %addr | |||
286 | // %add = <256 x i32> %14, <256 x i32> %src2 | |||
287 | Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1)); | |||
288 | Bitcast->replaceAllUsesWith(Vec); | |||
289 | } | |||
290 | ||||
291 | // transform bitcast to <store, load> instructions. | |||
292 | bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { | |||
293 | IRBuilder<> Builder(Bitcast); | |||
294 | AllocaInst *AllocaAddr; | |||
295 | Value *I8Ptr, *Stride; | |||
296 | auto *Src = Bitcast->getOperand(0); | |||
297 | ||||
298 | auto Prepare = [&](Type *MemTy) { | |||
299 | AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy); | |||
300 | I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy()); | |||
301 | Stride = Builder.getInt64(64); | |||
302 | }; | |||
303 | ||||
304 | if (Bitcast->getType()->isX86_AMXTy()) { | |||
305 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
306 | // --> | |||
307 | // %addr = alloca <256 x i32>, align 64 | |||
308 | // store <256 x i32> %src, <256 x i32>* %addr, align 64 | |||
309 | // %addr2 = bitcast <256 x i32>* to i8* | |||
310 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
311 | // i8* %addr2, | |||
312 | // i64 64) | |||
313 | Use &U = *(Bitcast->use_begin()); | |||
314 | unsigned OpNo = U.getOperandNo(); | |||
315 | auto *II = dyn_cast<IntrinsicInst>(U.getUser()); | |||
316 | if (!II) | |||
317 | return false; // May be bitcast from x86amx to <256 x i32>. | |||
318 | Prepare(Bitcast->getOperand(0)->getType()); | |||
319 | Builder.CreateStore(Src, AllocaAddr); | |||
320 | // TODO we can pick a constant operand for the shape. | |||
321 | Value *Row = nullptr, *Col = nullptr; | |||
322 | std::tie(Row, Col) = getShape(II, OpNo); | |||
323 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; | |||
324 | Value *NewInst = Builder.CreateIntrinsic( | |||
325 | Intrinsic::x86_tileloadd64_internal, None, Args); | |||
326 | Bitcast->replaceAllUsesWith(NewInst); | |||
327 | } else { | |||
328 | // %2 = bitcast x86_amx %src to <256 x i32> | |||
329 | // --> | |||
330 | // %addr = alloca <256 x i32>, align 64 | |||
331 | // %addr2 = bitcast <256 x i32>* to i8* | |||
332 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, | |||
333 | // i8* %addr2, i64 %stride) | |||
334 | // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
335 | auto *II = dyn_cast<IntrinsicInst>(Src); | |||
336 | if (!II) | |||
337 | return false; // May be bitcast from <256 x i32> to x86amx. | |||
338 | Prepare(Bitcast->getType()); | |||
339 | Value *Row = II->getOperand(0); | |||
340 | Value *Col = II->getOperand(1); | |||
341 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src}; | |||
342 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
343 | Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr); | |||
344 | Bitcast->replaceAllUsesWith(NewInst); | |||
345 | } | |||
346 | ||||
347 | return true; | |||
348 | } | |||
349 | ||||
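// Walk the function and combine <256 x i32>/x86_amx bitcasts with adjacent
// loads and stores into AMX tile load/store intrinsics; bitcasts that cannot
// be combined are lowered through a stack slot by transformBitcast().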
350 | bool X86LowerAMXType::visit() { | |||
351 | SmallVector<Instruction *, 8> DeadInsts; | |||
352 | Col2Row.clear(); | |||
353 | ||||
354 | for (BasicBlock *BB : post_order(&Func)) { | |||
355 | for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) { | |||
356 | auto *Bitcast = dyn_cast<BitCastInst>(&Inst); | |||
357 | if (!Bitcast) | |||
358 | continue; | |||
359 | ||||
360 | Value *Src = Bitcast->getOperand(0); | |||
361 | if (Bitcast->getType()->isX86_AMXTy()) { | |||
362 | if (Bitcast->user_empty()) { | |||
363 | DeadInsts.push_back(Bitcast); | |||
364 | continue; | |||
365 | } | |||
366 | LoadInst *LD = dyn_cast<LoadInst>(Src); | |||
367 | if (!LD) { | |||
368 | if (transformBitcast(Bitcast)) | |||
369 | DeadInsts.push_back(Bitcast); | |||
370 | continue; | |||
371 | } | |||
372 | // If the load has multiple users, duplicate a vector load. | |||
373 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
374 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
375 | // %add = add <256 x i32> %src, <256 x i32> %src2 | |||
376 | // --> | |||
377 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
378 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
379 | // i8* %addr, i64 %stride64) | |||
380 | // %add = add <256 x i32> %src, <256 x i32> %src2 | |||
381 | ||||
382 | // If the load has one user, the load will be eliminated in DAG ISel. | |||
383 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
384 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
385 | // --> | |||
386 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
387 | // i8* %addr, i64 %stride64) | |||
388 | combineLoadBitcast(LD, Bitcast); | |||
389 | DeadInsts.push_back(Bitcast); | |||
390 | if (LD->hasOneUse()) | |||
391 | DeadInsts.push_back(LD); | |||
392 | } else if (Src->getType()->isX86_AMXTy()) { | |||
393 | if (Bitcast->user_empty()) { | |||
394 | DeadInsts.push_back(Bitcast); | |||
395 | continue; | |||
396 | } | |||
397 | StoreInst *ST = nullptr; | |||
398 | for (Use &U : Bitcast->uses()) { | |||
399 | ST = dyn_cast<StoreInst>(U.getUser()); | |||
400 | if (ST) | |||
401 | break; | |||
402 | } | |||
403 | if (!ST) { | |||
404 | if (transformBitcast(Bitcast)) | |||
405 | DeadInsts.push_back(Bitcast); | |||
406 | continue; | |||
407 | } | |||
408 | // If the bitcast (%13) has one use, combine the bitcast and store into an amx store. | |||
409 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, | |||
410 | // %stride); | |||
411 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
412 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
413 | // --> | |||
414 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
415 | // %stride64, %13) | |||
416 | // | |||
417 | // If the bitcast (%13) has multiple uses, transform as below. | |||
418 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
419 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
420 | // %add = <256 x i32> %13, <256 x i32> %src2 | |||
421 | // --> | |||
422 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
423 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
424 | // %stride64, %13) | |||
425 | // %14 = load <256 x i32>, %addr | |||
426 | // %add = <256 x i32> %14, <256 x i32> %src2 | |||
427 | // | |||
428 | combineBitcastStore(Bitcast, ST); | |||
429 | // Delete user first. | |||
430 | DeadInsts.push_back(ST); | |||
431 | DeadInsts.push_back(Bitcast); | |||
432 | } | |||
433 | } | |||
434 | } | |||
435 | ||||
436 | bool C = !DeadInsts.empty(); | |||
437 | ||||
438 | for (auto *Inst : DeadInsts) | |||
439 | Inst->eraseFromParent(); | |||
440 | ||||
441 | return C; | |||
442 | } | |||
443 | } // anonymous namespace | |||
444 | ||||
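// Allocate a <256 x i32> stack slot in the entry block and return it as an
// i8* suitable for the tile load/store intrinsics.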
445 | static Value *getAllocaPos(BasicBlock *BB) { | |||
446 | Module *M = BB->getModule(); | |||
447 | Function *F = BB->getParent(); | |||
448 | IRBuilder<> Builder(&F->getEntryBlock().front()); | |||
449 | const DataLayout &DL = M->getDataLayout(); | |||
450 | unsigned AllocaAS = DL.getAllocaAddrSpace(); | |||
451 | Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false); | |||
452 | AllocaInst *AllocaRes = | |||
453 | new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front()); | |||
454 | BasicBlock::iterator Iter = AllocaRes->getIterator(); | |||
455 | ++Iter; | |||
456 | Builder.SetInsertPoint(&*Iter); | |||
457 | Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy()); | |||
458 | return I8Ptr; | |||
459 | } | |||
460 | ||||
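// Spill the tile defined by TileDef: emit a tilestored64 to Ptr immediately
// after the definition, using the maximum stride (64).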
461 | static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { | |||
462 | assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!"); | |||
463 | auto *II = cast<IntrinsicInst>(TileDef); | |||
464 | assert(II && "Not tile intrinsic!"); | |||
465 | Value *Row = II->getOperand(0); | |||
466 | Value *Col = II->getOperand(1); | |||
467 | ||||
468 | BasicBlock *BB = TileDef->getParent(); | |||
469 | BasicBlock::iterator Iter = TileDef->getIterator(); | |||
470 | IRBuilder<> Builder(BB, ++Iter); | |||
471 | Value *Stride = Builder.getInt64(64); | |||
472 | std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef}; | |||
473 | ||||
474 | Instruction *TileStore = | |||
475 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
476 | return TileStore; | |||
477 | } | |||
478 | ||||
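// Reload the tile from Ptr right before the user of U and rewrite that use to
// consume the reloaded tile. For PHI defs the shape is taken from the first
// incoming value.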
479 | static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { | |||
480 | Value *V = U.get(); | |||
481 | assert(V->getType()->isX86_AMXTy() && "Not define tile!"); | |||
482 | ||||
483 | // Get tile shape. | |||
484 | IntrinsicInst *II = nullptr; | |||
485 | if (IsPHI) { | |||
486 | Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0); | |||
487 | II = cast<IntrinsicInst>(PhiOp); | |||
488 | } else { | |||
489 | II = cast<IntrinsicInst>(V); | |||
490 | } | |||
491 | Value *Row = II->getOperand(0); | |||
492 | Value *Col = II->getOperand(1); | |||
493 | ||||
494 | Instruction *UserI = dyn_cast<Instruction>(U.getUser()); | |||
495 | IRBuilder<> Builder(UserI); | |||
496 | Value *Stride = Builder.getInt64(64); | |||
497 | std::array<Value *, 4> Args = {Row, Col, Ptr, Stride}; | |||
498 | ||||
499 | Value *TileLoad = | |||
500 | Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); | |||
501 | UserI->replaceUsesOfWith(V, TileLoad); | |||
502 | } | |||
503 | ||||
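// Return true if I is used as an incoming value of some PHI node.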
504 | static bool isIncomingOfPHI(Instruction *I) { | |||
505 | for (Use &U : I->uses()) { | |||
506 | User *V = U.getUser(); | |||
507 | if (isa<PHINode>(V)) | |||
508 | return true; | |||
509 | } | |||
510 | return false; | |||
511 | } | |||
512 | ||||
513 | // Let all AMX tile data become volatile data and shorten the live range | |||
514 | // of each tile register before fast register allocation. | |||
515 | namespace { | |||
516 | class X86VolatileTileData { | |||
517 | Function &F; | |||
518 | ||||
519 | public: | |||
520 | X86VolatileTileData(Function &Func) : F(Func) {} | |||
521 | Value *updatePhiIncomings(BasicBlock *BB, | |||
522 | SmallVector<Instruction *, 2> &Incomings); | |||
523 | void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr); | |||
524 | bool volatileTileData(); | |||
525 | void volatileTilePHI(PHINode *Inst); | |||
526 | void volatileTileNonPHI(Instruction *I); | |||
527 | }; | |||
528 | ||||
529 | Value *X86VolatileTileData::updatePhiIncomings( | |||
530 | BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { | |||
531 | Value *I8Ptr = getAllocaPos(BB); | |||
532 | ||||
533 | for (auto *I : Incomings) { | |||
534 | User *Store = createTileStore(I, I8Ptr); | |||
535 | ||||
536 | // All its uses (except phi) should load from stored mem. | |||
537 | for (Use &U : I->uses()) { | |||
538 | User *V = U.getUser(); | |||
539 | if (isa<PHINode>(V) || V == Store) | |||
540 | continue; | |||
541 | replaceWithTileLoad(U, I8Ptr); | |||
542 | } | |||
543 | } | |||
544 | return I8Ptr; | |||
545 | } | |||
546 | ||||
547 | void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI, | |||
548 | Value *StorePtr) { | |||
549 | for (Use &U : PHI->uses()) | |||
550 | replaceWithTileLoad(U, StorePtr, true); | |||
551 | PHI->eraseFromParent(); | |||
552 | } | |||
553 | ||||
554 | // Similar to volatileTileNonPHI, this function only handles PHI nodes | |||
555 | // and their related AMX intrinsics. | |||
556 | // 1) The PHI def should change to a tileload. | |||
557 | // 2) The PHI incoming values should be tilestored just after their defs. | |||
558 | // 3) The mem of these tileloads and tilestores should be the same. | |||
559 | // e.g. | |||
560 | // ------------------------------------------------------ | |||
561 | // bb_dom: | |||
562 | // ... | |||
563 | // br i1 %bool.cond, label %if.else, label %if.then | |||
564 | // | |||
565 | // if.then: | |||
566 | // def %t0 = ... | |||
567 | // ... | |||
568 | // use %t0 | |||
569 | // ... | |||
570 | // br label %if.end | |||
571 | // | |||
572 | // if.else: | |||
573 | // def %t1 = ... | |||
574 | // br label %if.end | |||
575 | // | |||
576 | // if.end: | |||
577 | // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ] | |||
578 | // ... | |||
579 | // use %td | |||
580 | // ------------------------------------------------------ | |||
581 | // --> | |||
582 | // ------------------------------------------------------ | |||
583 | // bb_entry: | |||
584 | // %mem = alloca <256 x i32>, align 1024 * | |||
585 | // ... | |||
586 | // bb_dom: | |||
587 | // ... | |||
588 | // br i1 %bool.cond, label %if.else, label %if.then | |||
589 | // | |||
590 | // if.then: | |||
591 | // def %t0 = ... | |||
592 | // call void @llvm.x86.tilestored64.internal(mem, %t0) * | |||
593 | // ... | |||
594 | // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)* | |||
595 | // use %t0` * | |||
596 | // ... | |||
597 | // br label %if.end | |||
598 | // | |||
599 | // if.else: | |||
600 | // def %t1 = ... | |||
601 | // call void @llvm.x86.tilestored64.internal(mem, %t1) * | |||
602 | // br label %if.end | |||
603 | // | |||
604 | // if.end: | |||
605 | // ... | |||
606 | // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) * | |||
607 | // use %td | |||
608 | // ------------------------------------------------------ | |||
609 | void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { | |||
610 | BasicBlock *BB = PHI->getParent(); | |||
611 | SmallVector<Instruction *, 2> Incomings; | |||
612 | ||||
613 | for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { | |||
614 | Value *Op = PHI->getIncomingValue(I); | |||
615 | Instruction *Inst = dyn_cast<Instruction>(Op); | |||
616 | assert(Inst && "We shouldn't fold AMX instruction!"); | |||
617 | Incomings.push_back(Inst); | |||
618 | } | |||
619 | ||||
620 | Value *StorePtr = updatePhiIncomings(BB, Incomings); | |||
621 | replacePhiDefWithLoad(PHI, StorePtr); | |||
622 | } | |||
623 | ||||
624 | // Store the defined tile and load it before use. | |||
625 | // All its users are not PHI. | |||
626 | // e.g. | |||
627 | // ------------------------------------------------------ | |||
628 | // def %td = ... | |||
629 | // ... | |||
630 | // "use %td" | |||
631 | // ------------------------------------------------------ | |||
632 | // --> | |||
633 | // ------------------------------------------------------ | |||
634 | // def %td = ... | |||
635 | // call void @llvm.x86.tilestored64.internal(mem, %td) | |||
636 | // ... | |||
637 | // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem) | |||
638 | // "use %td2" | |||
639 | // ------------------------------------------------------ | |||
640 | void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { | |||
641 | BasicBlock *BB = I->getParent(); | |||
642 | Value *I8Ptr = getAllocaPos(BB); | |||
643 | User *Store = createTileStore(I, I8Ptr); | |||
644 | ||||
645 | // All its uses should load from stored mem. | |||
646 | for (Use &U : I->uses()) { | |||
647 | User *V = U.getUser(); | |||
648 | assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!"); | |||
649 | if (V != Store) | |||
650 | replaceWithTileLoad(U, I8Ptr); | |||
651 | } | |||
652 | } | |||
653 | ||||
654 | // Volatile Tile Model: | |||
655 | // 1) All the uses of tile data come from tileloads in time. | |||
656 | // 2) All the defs of tile data are tilestored into mem immediately. | |||
657 | // For example: | |||
658 | // -------------------------------------------------------------------------- | |||
659 | // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key | |||
660 | // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) | |||
661 | // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx | |||
662 | // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) | |||
663 | // call void @llvm.x86.tilestored64.internal(... td) area | |||
664 | // -------------------------------------------------------------------------- | |||
665 | // 3) No terminator, call or other amx instructions in the key amx area. | |||
666 | bool X86VolatileTileData::volatileTileData() { | |||
667 | bool Changed = false; | |||
668 | for (BasicBlock &BB : F) { | |||
669 | SmallVector<Instruction *, 2> PHIInsts; | |||
670 | SmallVector<Instruction *, 8> AMXDefInsts; | |||
671 | ||||
672 | for (Instruction &I : BB) { | |||
673 | if (!I.getType()->isX86_AMXTy()) | |||
674 | continue; | |||
675 | if (isa<PHINode>(&I)) | |||
676 | PHIInsts.push_back(&I); | |||
677 | else | |||
678 | AMXDefInsts.push_back(&I); | |||
679 | } | |||
680 | ||||
681 | // First we "volatile" the non-phi related amx intrinsics. | |||
682 | for (Instruction *I : AMXDefInsts) { | |||
683 | if (isIncomingOfPHI(I)) | |||
684 | continue; | |||
685 | volatileTileNonPHI(I); | |||
686 | Changed = true; | |||
687 | } | |||
688 | ||||
689 | for (Instruction *I : PHIInsts) { | |||
690 | volatileTilePHI(dyn_cast<PHINode>(I)); | |||
691 | Changed = true; | |||
692 | } | |||
693 | } | |||
694 | return Changed; | |||
695 | } | |||
696 | ||||
697 | } // anonymous namespace | |||
698 | ||||
699 | namespace { | |||
700 | ||||
701 | class X86LowerAMXCast { | |||
702 | Function &Func; | |||
703 | ||||
704 | public: | |||
705 | X86LowerAMXCast(Function &F) : Func(F) {} | |||
706 | bool combineAMXcast(TargetLibraryInfo *TLI); | |||
707 | bool transformAMXCast(IntrinsicInst *AMXCast); | |||
708 | bool transformAllAMXCast(); | |||
709 | bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN, | |||
710 | SmallSetVector<Instruction *, 16> &DeadInst); | |||
711 | }; | |||
712 | ||||
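// Erase I if it is trivially dead, and queue any operands that become
// trivially dead as a result so they can be removed in a later iteration.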
713 | static bool DCEInstruction(Instruction *I, | |||
714 | SmallSetVector<Instruction *, 16> &WorkList, | |||
715 | const TargetLibraryInfo *TLI) { | |||
716 | if (isInstructionTriviallyDead(I, TLI)) { | |||
717 | salvageDebugInfo(*I); | |||
718 | salvageKnowledge(I); | |||
719 | ||||
720 | // Null out all of the instruction's operands to see if any operand becomes | |||
721 | // dead as we go. | |||
722 | for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { | |||
723 | Value *OpV = I->getOperand(i); | |||
724 | I->setOperand(i, nullptr); | |||
725 | ||||
726 | if (!OpV->use_empty() || I == OpV) | |||
727 | continue; | |||
728 | ||||
729 | // If the operand is an instruction that became dead as we nulled out the | |||
730 | // operand, and if it is 'trivially' dead, delete it in a future loop | |||
731 | // iteration. | |||
732 | if (Instruction *OpI = dyn_cast<Instruction>(OpV)) { | |||
733 | if (isInstructionTriviallyDead(OpI, TLI)) { | |||
734 | WorkList.insert(OpI); | |||
735 | } | |||
736 | } | |||
737 | } | |||
738 | I->eraseFromParent(); | |||
739 | return true; | |||
740 | } | |||
741 | return false; | |||
742 | } | |||
743 | ||||
744 | /// This function handles the following case | |||
745 | /// | |||
746 | /// A -> B amxcast | |||
747 | /// PHI | |||
748 | /// B -> A amxcast | |||
749 | /// | |||
750 | /// All the related PHI nodes can be replaced by new PHI nodes with type A. | |||
751 | /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. | |||
752 | bool X86LowerAMXCast::optimizeAMXCastFromPhi( | |||
753 | IntrinsicInst *CI, PHINode *PN, | |||
754 | SmallSetVector<Instruction *, 16> &DeadInst) { | |||
755 | IRBuilder<> Builder(CI); | |||
756 | Value *Src = CI->getOperand(0); | |||
757 | Type *SrcTy = Src->getType(); // Type B | |||
758 | Type *DestTy = CI->getType(); // Type A | |||
759 | ||||
760 | SmallVector<PHINode *, 4> PhiWorklist; | |||
761 | SmallSetVector<PHINode *, 4> OldPhiNodes; | |||
762 | ||||
763 | // Find all of the A->B casts and PHI nodes. | |||
764 | // We need to inspect all related PHI nodes, but PHIs can be cyclic, so | |||
766 | // OldPhiNodes is used to track all known PHI nodes; before adding a new | |||
766 | // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. | |||
767 | PhiWorklist.push_back(PN); | |||
768 | OldPhiNodes.insert(PN); | |||
769 | while (!PhiWorklist.empty()) { | |||
770 | auto *OldPN = PhiWorklist.pop_back_val(); | |||
771 | for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) { | |||
772 | Value *IncValue = OldPN->getIncomingValue(I); | |||
773 | // TODO: Currently, we ignore cases where it is a const. In the future, we | |||
774 | // might support const. | |||
775 | if (isa<Constant>(IncValue)) { | |||
776 | auto *IncConst = dyn_cast<Constant>(IncValue); | |||
777 | if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue()) | |||
778 | return false; | |||
779 | Value *Row = nullptr, *Col = nullptr; | |||
780 | std::tie(Row, Col) = getShape(OldPN); | |||
781 | // TODO: If it is not constant, the Row and Col must dominate the tilezero | |||
782 | // that we are going to create. | |||
783 | if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col)) | |||
784 | return false; | |||
785 | // Create tilezero at the end of incoming block. | |||
786 | auto *Block = OldPN->getIncomingBlock(I); | |||
787 | BasicBlock::iterator Iter = Block->getTerminator()->getIterator(); | |||
788 | Instruction *NewInst = Builder.CreateIntrinsic( | |||
789 | Intrinsic::x86_tilezero_internal, None, {Row, Col}); | |||
790 | NewInst->moveBefore(&*Iter); | |||
791 | NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector, | |||
792 | {IncValue->getType()}, {NewInst}); | |||
793 | NewInst->moveBefore(&*Iter); | |||
794 | // Replace IncValue with the new value. | |||
795 | OldPN->setIncomingValue(I, NewInst); | |||
796 | IncValue = NewInst; | |||
797 | } | |||
798 | ||||
799 | if (auto *PNode = dyn_cast<PHINode>(IncValue)) { | |||
800 | if (OldPhiNodes.insert(PNode)) | |||
801 | PhiWorklist.push_back(PNode); | |||
802 | continue; | |||
803 | } | |||
804 | Instruction *ACI = dyn_cast<Instruction>(IncValue); | |||
805 | if (ACI && isAMXCast(ACI)) { | |||
806 | // Verify it's an A->B cast. | |||
807 | Type *TyA = ACI->getOperand(0)->getType(); | |||
808 | Type *TyB = ACI->getType(); | |||
809 | if (TyA != DestTy || TyB != SrcTy) | |||
810 | return false; | |||
811 | continue; | |||
812 | } | |||
813 | return false; | |||
814 | } | |||
815 | } | |||
816 | ||||
817 | // Check that each user of each old PHI node is something that we can | |||
818 | // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. | |||
819 | for (auto *OldPN : OldPhiNodes) { | |||
820 | for (User *V : OldPN->users()) { | |||
821 | Instruction *ACI = dyn_cast<Instruction>(V); | |||
822 | if (ACI && isAMXCast(ACI)) { | |||
823 | // Verify it's a B->A cast. | |||
824 | Type *TyB = ACI->getOperand(0)->getType(); | |||
825 | Type *TyA = ACI->getType(); | |||
826 | if (TyA != DestTy || TyB != SrcTy) | |||
827 | return false; | |||
828 | } else if (auto *PHI = dyn_cast<PHINode>(V)) { | |||
829 | // As long as the user is another old PHI node, then even if we don't | |||
830 | // rewrite it, the PHI web we're considering won't have any users | |||
831 | // outside itself, so it'll be dead. | |||
832 | // example: | |||
833 | // bb.0: | |||
834 | // %0 = amxcast ... | |||
835 | // bb.1: | |||
836 | // %1 = amxcast ... | |||
837 | // bb.2: | |||
838 | // %goodphi = phi %0, %1 | |||
839 | // %3 = amxcast %goodphi | |||
840 | // bb.3: | |||
841 | // %goodphi2 = phi %0, %goodphi | |||
842 | // %4 = amxcast %goodphi2 | |||
843 | // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is | |||
844 | // outside the phi-web, so the combination stops. When | |||
845 | // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization | |||
846 | // will be done. | |||
847 | if (OldPhiNodes.count(PHI) == 0) | |||
848 | return false; | |||
849 | } else | |||
850 | return false; | |||
851 | } | |||
852 | } | |||
853 | ||||
854 | // For each old PHI node, create a corresponding new PHI node with type A. | |||
855 | SmallDenseMap<PHINode *, PHINode *> NewPNodes; | |||
856 | for (auto *OldPN : OldPhiNodes) { | |||
857 | Builder.SetInsertPoint(OldPN); | |||
858 | PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands()); | |||
859 | NewPNodes[OldPN] = NewPN; | |||
860 | } | |||
861 | ||||
862 | // Fill in the operands of new PHI nodes. | |||
863 | for (auto *OldPN : OldPhiNodes) { | |||
864 | PHINode *NewPN = NewPNodes[OldPN]; | |||
865 | for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { | |||
866 | Value *V = OldPN->getOperand(j); | |||
867 | Value *NewV = nullptr; | |||
868 | Instruction *ACI = dyn_cast<Instruction>(V); | |||
869 | // There should not be an AMXcast from a const. | |||
870 | if (ACI && isAMXCast(ACI)) | |||
871 | NewV = ACI->getOperand(0); | |||
872 | else if (auto *PrevPN = dyn_cast<PHINode>(V)) | |||
873 | NewV = NewPNodes[PrevPN]; | |||
874 | assert(NewV); | |||
875 | NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j)); | |||
876 | } | |||
877 | } | |||
878 | ||||
879 | // Traverse all accumulated PHI nodes and process their users, | |||
880 | // which are Stores and BitCasts. Without this processing, | |||
881 | // NewPHI nodes could be replicated and could lead to extra | |||
882 | // moves generated after DeSSA. | |||
883 | // If there is a store with type B, change it to type A. | |||
884 | ||||
885 | // Replace users of BitCast B->A with NewPHI. These will help | |||
886 | // later to get rid of a closure formed by OldPHI nodes. | |||
887 | for (auto *OldPN : OldPhiNodes) { | |||
888 | PHINode *NewPN = NewPNodes[OldPN]; | |||
889 | for (User *V : make_early_inc_range(OldPN->users())) { | |||
890 | Instruction *ACI = dyn_cast<Instruction>(V); | |||
891 | if (ACI && isAMXCast(ACI)) { | |||
892 | Type *TyB = ACI->getOperand(0)->getType(); | |||
893 | Type *TyA = ACI->getType(); | |||
894 | assert(TyA == DestTy && TyB == SrcTy); | |||
895 | (void)TyA; | |||
896 | (void)TyB; | |||
897 | ACI->replaceAllUsesWith(NewPN); | |||
898 | DeadInst.insert(ACI); | |||
899 | } else if (auto *PHI = dyn_cast<PHINode>(V)) { | |||
900 | // We don't need to push the PHINode into DeadInst since it is an operand | |||
901 | // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead. | |||
902 | assert(OldPhiNodes.contains(PHI)); | |||
903 | (void)PHI; | |||
904 | } else | |||
905 | llvm_unreachable("all uses should be handled"); | |||
906 | } | |||
907 | } | |||
908 | return true; | |||
909 | } | |||
910 | ||||
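// Fold vec2tile/tile2vec round trips, erase dead casts, optimize AMX casts
// whose operand is a PHI node, and finally DCE whatever became dead.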
911 | bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) { | |||
912 | bool Change = false; | |||
913 | // Collect tile cast instructions. | |||
914 | SmallVector<Instruction *, 8> Vec2TileInsts; | |||
915 | SmallVector<Instruction *, 8> Tile2VecInsts; | |||
916 | SmallVector<Instruction *, 8> PhiCastWorkList; | |||
917 | SmallSetVector<Instruction *, 16> DeadInst; | |||
918 | for (BasicBlock &BB : Func) { | |||
919 | for (Instruction &I : BB) { | |||
920 | Value *Vec; | |||
921 | if (match(&I, | |||
922 | m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec)))) | |||
923 | Vec2TileInsts.push_back(&I); | |||
924 | else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>( | |||
925 | m_Value(Vec)))) | |||
926 | Tile2VecInsts.push_back(&I); | |||
927 | } | |||
928 | } | |||
929 | ||||
930 | auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) { | |||
931 | for (auto *Inst : Insts) { | |||
932 | for (User *U : Inst->users()) { | |||
933 | IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); | |||
934 | if (!II || II->getIntrinsicID() != IID) | |||
935 | continue; | |||
936 | // T1 = vec2tile V0 | |||
937 | // V2 = tile2vec T1 | |||
938 | // V3 = OP V2 | |||
939 | // --> | |||
940 | // T1 = vec2tile V0 | |||
941 | // V2 = tile2vec T1 | |||
942 | // V3 = OP V0 | |||
943 | II->replaceAllUsesWith(Inst->getOperand(0)); | |||
944 | Change = true; | |||
945 | } | |||
946 | } | |||
947 | }; | |||
948 | ||||
949 | Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector); | |||
950 | Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile); | |||
951 | ||||
952 | auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) { | |||
953 | for (auto *Inst : Insts) { | |||
954 | if (Inst->use_empty()) { | |||
955 | Inst->eraseFromParent(); | |||
956 | Change = true; | |||
957 | } | |||
958 | } | |||
959 | }; | |||
960 | ||||
961 | EraseInst(Vec2TileInsts); | |||
962 | EraseInst(Tile2VecInsts); | |||
963 | ||||
964 | // Handle the A->B->A cast where there is an intervening PHI node. | |||
965 | for (BasicBlock &BB : Func) { | |||
966 | for (Instruction &I : BB) { | |||
967 | if (isAMXCast(&I)) { | |||
968 | if (isa<PHINode>(I.getOperand(0))) | |||
969 | PhiCastWorkList.push_back(&I); | |||
970 | } | |||
971 | } | |||
972 | } | |||
973 | for (auto *I : PhiCastWorkList) { | |||
974 | // We skip the dead AMXcast. | |||
975 | if (DeadInst.contains(I)) | |||
976 | continue; | |||
977 | PHINode *PN = cast<PHINode>(I->getOperand(0)); | |||
978 | if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) { | |||
979 | DeadInst.insert(PN); | |||
980 | Change = true; | |||
981 | } | |||
982 | } | |||
983 | ||||
984 | // Since we create new phis and merge AMXCasts, some old phis and AMXCasts might | |||
985 | // have no uses. We do some dead code elimination for them. | |||
986 | while (!DeadInst.empty()) { | |||
987 | Instruction *I = DeadInst.pop_back_val(); | |||
988 | Change |= DCEInstruction(I, DeadInst, TLI); | |||
989 | } | |||
990 | return Change; | |||
991 | } | |||
992 | ||||
993 | // There might be remaining AMXcasts after combineAMXcast, and they should be | |||
994 | // handled elegantly. | |||
995 | bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) { | |||
996 | IRBuilder<> Builder(AMXCast); | |||
997 | AllocaInst *AllocaAddr; | |||
998 | Value *I8Ptr, *Stride; | |||
999 | auto *Src = AMXCast->getOperand(0); | |||
1000 | ||||
1001 | auto Prepare = [&](Type *MemTy) { | |||
1002 | AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy); | |||
1003 | I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy()); | |||
1004 | Stride = Builder.getInt64(64); | |||
1005 | }; | |||
1006 | ||||
1007 | if (AMXCast->getType()->isX86_AMXTy()) { | |||
1008 | // %2 = amxcast <225 x i32> %src to x86_amx | |||
1009 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, | |||
1010 | // i8* %addr3, i64 60, x86_amx %2) | |||
1011 | // --> | |||
1012 | // %addr = alloca <225 x i32>, align 64 | |||
1013 | // store <225 x i32> %src, <225 x i32>* %addr, align 64 | |||
1014 | // %addr2 = bitcast <225 x i32>* %addr to i8* | |||
1015 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60, | |||
1016 | // i8* %addr2, | |||
1017 | // i64 60) | |||
1018 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, | |||
1019 | // i8* %addr3, i64 60, x86_amx %2) | |||
1020 | if (AMXCast->use_empty()) { | |||
1021 | AMXCast->eraseFromParent(); | |||
1022 | return true; | |||
1023 | } | |||
1024 | Use &U = *(AMXCast->use_begin()); | |||
1025 | unsigned OpNo = U.getOperandNo(); | |||
1026 | auto *II = dyn_cast<IntrinsicInst>(U.getUser()); | |||
1027 | if (!II) | |||
1028 | return false; // May be bitcast from x86amx to <256 x i32>. | |||
1029 | Prepare(AMXCast->getOperand(0)->getType()); | |||
1030 | Builder.CreateStore(Src, AllocaAddr); | |||
1031 | // TODO we can pick a constant operand for the shape. | |||
1032 | Value *Row = nullptr, *Col = nullptr; | |||
1033 | std::tie(Row, Col) = getShape(II, OpNo); | |||
1034 | std::array<Value *, 4> Args = { | |||
1035 | Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())}; | |||
1036 | Value *NewInst = Builder.CreateIntrinsic( | |||
1037 | Intrinsic::x86_tileloadd64_internal, None, Args); | |||
1038 | AMXCast->replaceAllUsesWith(NewInst); | |||
1039 | AMXCast->eraseFromParent(); | |||
1040 | } else { | |||
1041 | // %2 = amxcast x86_amx %src to <225 x i32> | |||
1042 | // --> | |||
1043 | // %addr = alloca <225 x i32>, align 64 | |||
1044 | // %addr2 = bitcast <225 x i32>* to i8* | |||
1045 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, | |||
1046 | // i8* %addr2, i64 %stride) | |||
1047 | // %2 = load <225 x i32>, <225 x i32>* %addr, align 64 | |||
1048 | auto *II = dyn_cast<IntrinsicInst>(Src); | |||
1049 | if (!II) | |||
1050 | return false; // May be bitcast from <256 x i32> to x86amx. | |||
1051 | Prepare(AMXCast->getType()); | |||
1052 | Value *Row = II->getOperand(0); | |||
1053 | Value *Col = II->getOperand(1); | |||
1054 | std::array<Value *, 5> Args = { | |||
1055 | Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src}; | |||
1056 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
1057 | Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr); | |||
1058 | AMXCast->replaceAllUsesWith(NewInst); | |||
1059 | AMXCast->eraseFromParent(); | |||
1060 | } | |||
1061 | ||||
1062 | return true; | |||
1063 | } | |||
1064 | ||||
1065 | bool X86LowerAMXCast::transformAllAMXCast() { | |||
1066 | bool Change = false; | |||
1067 | // Collect tile cast instructions. | |||
1068 | SmallVector<Instruction *, 8> WorkLists; | |||
1069 | for (BasicBlock &BB : Func) { | |||
1070 | for (Instruction &I : BB) { | |||
1071 | if (isAMXCast(&I)) | |||
1072 | WorkLists.push_back(&I); | |||
1073 | } | |||
1074 | } | |||
1075 | ||||
1076 | for (auto *Inst : WorkLists) { | |||
1077 | Change |= transformAMXCast(cast<IntrinsicInst>(Inst)); | |||
1078 | } | |||
1079 | ||||
1080 | return Change; | |||
1081 | } | |||
1082 | ||||
1083 | } // anonymous namespace | |||
1084 | ||||
1085 | namespace { | |||
1086 | ||||
1087 | class X86LowerAMXTypeLegacyPass : public FunctionPass { | |||
1088 | public: | |||
1089 | static char ID; | |||
1090 | ||||
1091 | X86LowerAMXTypeLegacyPass() : FunctionPass(ID) { | |||
1092 | initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry()); | |||
1093 | } | |||
1094 | ||||
1095 | bool runOnFunction(Function &F) override { | |||
1096 | bool C = false; | |||
1097 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); | |||
1098 | TargetLibraryInfo *TLI = | |||
1099 | &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); | |||
1100 | X86LowerAMXCast LAC(F); | |||
1101 | C |= LAC.combineAMXcast(TLI); | |||
1102 | // There might be remaining AMXcast after combineAMXcast and they should be | |||
1103 | // handled elegantly. | |||
1104 | C |= LAC.transformAllAMXCast(); | |||
1105 | ||||
1106 | X86LowerAMXType LAT(F); | |||
1107 | C |= LAT.visit(); | |||
1108 | ||||
1109 | // Prepare for fast register allocation at O0. | |||
1110 | // TODO: It may be better to check the volatile model of the AMX code, not just | |||
1111 | // check Attribute::OptimizeNone and CodeGenOpt::None. | |||
1112 | if (TM->getOptLevel() == CodeGenOpt::None) { | |||
1113 | // If the Front End does not use O0 but the Mid/Back end uses O0 (e.g. | |||
1114 | // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make | |||
1115 | // sure the amx data is volatile; that is necessary for AMX fast | |||
1116 | // register allocation. | |||
1117 | if (!F.hasFnAttribute(Attribute::OptimizeNone)) { | |||
1118 | X86VolatileTileData VTD(F); | |||
1119 | C = VTD.volatileTileData() || C; | |||
1120 | } | |||
1121 | } | |||
1122 | ||||
1123 | return C; | |||
1124 | } | |||
1125 | ||||
1126 | void getAnalysisUsage(AnalysisUsage &AU) const override { | |||
1127 | AU.setPreservesCFG(); | |||
1128 | AU.addRequired<TargetPassConfig>(); | |||
1129 | AU.addRequired<TargetLibraryInfoWrapperPass>(); | |||
1130 | } | |||
1131 | }; | |||
1132 | ||||
1133 | } // anonymous namespace | |||
1134 | ||||
1135 | static const char PassName[] = "Lower AMX type for load/store"; | |||
1136 | char X86LowerAMXTypeLegacyPass::ID = 0; | |||
1137 | INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, | |||
1138 | false) | |||
1139 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) | |||
1140 | INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) | |||
1141 | INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, | |||
1142 | false) | |||
1143 | ||||
1144 | FunctionPass *llvm::createX86LowerAMXTypePass() { | |||
1145 | return new X86LowerAMXTypeLegacyPass(); | |||
1146 | } |