Bug Summary

File: llvm/lib/Target/X86/X86LowerAMXType.cpp
Warning: line 560, column 20
Called C++ object pointer is null
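
The warning involves LLVM's casting utilities: dyn_cast<> returns a null pointer when the dynamic type does not match, whereas cast<> asserts the match and never returns null. The sketch below is a minimal illustration of that distinction, assuming only the LLVM headers; the function demo() is hypothetical and is not part of the file under analysis. The flagged path arises when an unchecked dyn_cast<> result is dereferenced.

#include "llvm/IR/Instructions.h"
#include "llvm/Support/Casting.h"
using namespace llvm;

// Hypothetical helper, for illustration only.
void demo(Instruction *I) {
  if (PHINode *PN = dyn_cast<PHINode>(I)) // nullptr when I is not a PHINode
    (void)PN->getParent();                // safe: checked before use
  PHINode *Unchecked = dyn_cast<PHINode>(I); // may be nullptr
  (void)Unchecked->getParent();              // unchecked dereference: the flagged pattern
}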

Annotated Source Code


clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name X86LowerAMXType.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/build-llvm -resource-dir /usr/lib/llvm-14/lib/clang/14.0.0 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I lib/Target/X86 -I /build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/llvm/lib/Target/X86 -I include -I /build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/llvm/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-14/lib/clang/14.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/build-llvm=build-llvm -fmacro-prefix-map=/build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/build-llvm=build-llvm -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/= -O3 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/build-llvm -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/build-llvm=build-llvm -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/= -ferror-limit 19 -fvisibility hidden -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-01-19-001817-16337-1 -x c++ /build/llvm-toolchain-snapshot-14~++20220118101002+ec47dba1c8a2/llvm/lib/Target/X86/X86LowerAMXType.cpp
1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform <256 x i32> load/store
10/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
11/// only provides simple operations on x86_amx. Basic elementwise operations
12/// are not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
13/// and only AMX intrinsics can operate on the type, we need to transform
14/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
15/// cannot be combined with the load/store, we transform the bitcast to amx
16/// load/store and <256 x i32> store/load.
17///
18/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang -O2 -S
19/// -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is volatile,
20/// because that is necessary for AMX fast register allocation. (In fast
21/// register allocation, registers are allocated before spill/reload, so
22/// there is no additional register for amx to identify the step in spill.)
23/// The volatileTileData() will handle this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will transfer to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
42#include "llvm/ADT/PostOrderIterator.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallSet.h"
45#include "llvm/Analysis/OptimizationRemarkEmitter.h"
46#include "llvm/Analysis/TargetLibraryInfo.h"
47#include "llvm/Analysis/TargetTransformInfo.h"
48#include "llvm/CodeGen/Passes.h"
49#include "llvm/CodeGen/TargetPassConfig.h"
50#include "llvm/CodeGen/ValueTypes.h"
51#include "llvm/IR/DataLayout.h"
52#include "llvm/IR/Function.h"
53#include "llvm/IR/IRBuilder.h"
54#include "llvm/IR/Instructions.h"
55#include "llvm/IR/IntrinsicInst.h"
56#include "llvm/IR/IntrinsicsX86.h"
57#include "llvm/IR/PatternMatch.h"
58#include "llvm/InitializePasses.h"
59#include "llvm/Pass.h"
60#include "llvm/Target/TargetMachine.h"
61#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62#include "llvm/Transforms/Utils/Local.h"
63
64using namespace llvm;
65using namespace PatternMatch;
66
67#define DEBUG_TYPE "lower-amx-type"
68
69static bool isAMXCast(Instruction *II) {
70 return match(II,
71 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
72 match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
73}
74
75static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
76 Type *Ty) {
77 Function &F = *BB->getParent();
78 Module *M = BB->getModule();
79 const DataLayout &DL = M->getDataLayout();
80
81 LLVMContext &Ctx = Builder.getContext();
82 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
83 unsigned AllocaAS = DL.getAllocaAddrSpace();
84 AllocaInst *AllocaRes =
85 new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
86 AllocaRes->setAlignment(AllocaAlignment);
87 return AllocaRes;
88}
89
90static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
91 for (Instruction &I : F.getEntryBlock())
92 if (!isa<AllocaInst>(&I))
93 return &I;
94 llvm_unreachable("No terminator in the entry block!");
95}
96
97static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
98 IRBuilder<> Builder(II);
99 Value *Row = nullptr, *Col = nullptr;
100 switch (II->getIntrinsicID()) {
101 default:
102 llvm_unreachable("Expect amx intrinsics");
103 case Intrinsic::x86_tileloadd64_internal:
104 case Intrinsic::x86_tileloaddt164_internal:
105 case Intrinsic::x86_tilestored64_internal: {
106 Row = II->getArgOperand(0);
107 Col = II->getArgOperand(1);
108 break;
109 }
110 // a * b + c
111 // The shape depends on which operand.
112 case Intrinsic::x86_tdpbssd_internal:
113 case Intrinsic::x86_tdpbsud_internal:
114 case Intrinsic::x86_tdpbusd_internal:
115 case Intrinsic::x86_tdpbuud_internal:
116 case Intrinsic::x86_tdpbf16ps_internal: {
117 switch (OpNo) {
118 case 3:
119 Row = II->getArgOperand(0);
120 Col = II->getArgOperand(1);
121 break;
122 case 4:
123 Row = II->getArgOperand(0);
124 Col = II->getArgOperand(2);
125 break;
126 case 5:
127 if (isa<ConstantInt>(II->getArgOperand(2)))
128 Row = Builder.getInt16(
129 (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
130 else if (isa<Instruction>(II->getArgOperand(2))) {
131 // When it is not a const value and it is not a function argument, we
132 // create Row after the definition of II->getOperand(2) instead of
133 // before II. For example, if II is %118, we try to get the shape for %117:
134 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
135 // i32> %115).
136 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
137 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
138 // %117).
139 // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
140 // definition is after its user(new tileload for %117).
141 // So, the best choice is to create %row right after the definition of
142 // %106.
143 Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
144 Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
145 cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
146 } else {
147 // When it is not a const value and it is a function argument, we create
148 // Row at the entry bb.
149 IRBuilder<> NewBuilder(
150 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
151 Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
152 }
153 Col = II->getArgOperand(1);
154 break;
155 }
156 break;
157 }
158 }
159
160 return std::make_pair(Row, Col);
161}
162
163namespace {
164class X86LowerAMXType {
165 Function &Func;
166
167 // In AMX intrinsics we let Shape = {Row, Col}, but the
168 // RealCol = Col / ElementSize. We may use the RealCol
169 // as a new Row for other newly created AMX intrinsics.
170 std::map<Value *, Value *> Col2Row;
171
172public:
173 X86LowerAMXType(Function &F) : Func(F) {}
174 bool visit();
175 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
176 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
177 bool transformBitcast(BitCastInst *Bitcast);
178};
179
180// %src = load <256 x i32>, <256 x i32>* %addr, align 64
181// %2 = bitcast <256 x i32> %src to x86_amx
182// -->
183// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
184// i8* %addr, i64 %stride64)
185void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
186 Value *Row = nullptr, *Col = nullptr;
187 Use &U = *(Bitcast->use_begin());
188 unsigned OpNo = U.getOperandNo();
189 auto *II = cast<IntrinsicInst>(U.getUser());
190 std::tie(Row, Col) = getShape(II, OpNo);
191 IRBuilder<> Builder(Bitcast);
192 // Use the maximum column as the stride.
193 Value *Stride = Builder.getInt64(64);
194 Value *I8Ptr =
195 Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
196 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
197
198 Value *NewInst =
199 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
200 Bitcast->replaceAllUsesWith(NewInst);
201}
202
203// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
204// %stride);
205// %13 = bitcast x86_amx %src to <256 x i32>
206// store <256 x i32> %13, <256 x i32>* %addr, align 64
207// -->
208// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
209// %stride64, %13)
210void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
211
212 Value *Tile = Bitcast->getOperand(0);
213 auto *II = cast<IntrinsicInst>(Tile);
214 // Tile is output from AMX intrinsic. The first operand of the
215 // intrinsic is row, the second operand of the intrinsic is column.
216 Value *Row = II->getOperand(0);
217 Value *Col = II->getOperand(1);
218 IRBuilder<> Builder(ST);
219 // Use the maximum column as the stride. It must be the same as the
220 // load stride.
221 Value *Stride = Builder.getInt64(64);
222 Value *I8Ptr =
223 Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
224 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
225 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
226 if (Bitcast->hasOneUse())
227 return;
228 // %13 = bitcast x86_amx %src to <256 x i32>
229 // store <256 x i32> %13, <256 x i32>* %addr, align 64
230 // %add = <256 x i32> %13, <256 x i32> %src2
231 // -->
232 // %13 = bitcast x86_amx %src to <256 x i32>
233 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
234 // %stride64, %13)
235 // %14 = load <256 x i32>, %addr
236 // %add = <256 x i32> %14, <256 x i32> %src2
237 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
238 Bitcast->replaceAllUsesWith(Vec);
239}
240
241// transform bitcast to <store, load> instructions.
242bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
243 IRBuilder<> Builder(Bitcast);
244 AllocaInst *AllocaAddr;
245 Value *I8Ptr, *Stride;
246 auto *Src = Bitcast->getOperand(0);
247
248 auto Prepare = [&](Type *MemTy) {
249 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
250 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
251 Stride = Builder.getInt64(64);
252 };
253
254 if (Bitcast->getType()->isX86_AMXTy()) {
255 // %2 = bitcast <256 x i32> %src to x86_amx
256 // -->
257 // %addr = alloca <256 x i32>, align 64
258 // store <256 x i32> %src, <256 x i32>* %addr, align 64
259 // %addr2 = bitcast <256 x i32>* to i8*
260 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
261 // i8* %addr2,
262 // i64 64)
263 Use &U = *(Bitcast->use_begin());
264 unsigned OpNo = U.getOperandNo();
265 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
266 if (!II)
267 return false; // May be bitcast from x86amx to <256 x i32>.
268 Prepare(Bitcast->getOperand(0)->getType());
269 Builder.CreateStore(Src, AllocaAddr);
270 // TODO: we can pick a constant operand for the shape.
271 Value *Row = nullptr, *Col = nullptr;
272 std::tie(Row, Col) = getShape(II, OpNo);
273 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
274 Value *NewInst = Builder.CreateIntrinsic(
275 Intrinsic::x86_tileloadd64_internal, None, Args);
276 Bitcast->replaceAllUsesWith(NewInst);
277 } else {
278 // %2 = bitcast x86_amx %src to <256 x i32>
279 // -->
280 // %addr = alloca <256 x i32>, align 64
281 // %addr2 = bitcast <256 x i32>* to i8*
282 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
283 // i8* %addr2, i64 %stride)
284 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
285 auto *II = dyn_cast<IntrinsicInst>(Src);
286 if (!II)
287 return false; // May be bitcast from <256 x i32> to x86amx.
288 Prepare(Bitcast->getType());
289 Value *Row = II->getOperand(0);
290 Value *Col = II->getOperand(1);
291 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
292 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
293 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
294 Bitcast->replaceAllUsesWith(NewInst);
295 }
296
297 return true;
298}
299
300bool X86LowerAMXType::visit() {
301 SmallVector<Instruction *, 8> DeadInsts;
302 Col2Row.clear();
303
304 for (BasicBlock *BB : post_order(&Func)) {
305 for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
306 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
307 if (!Bitcast)
308 continue;
309
310 Value *Src = Bitcast->getOperand(0);
311 if (Bitcast->getType()->isX86_AMXTy()) {
312 if (Bitcast->user_empty()) {
313 DeadInsts.push_back(Bitcast);
314 continue;
315 }
316 LoadInst *LD = dyn_cast<LoadInst>(Src);
317 if (!LD) {
318 if (transformBitcast(Bitcast))
319 DeadInsts.push_back(Bitcast);
320 continue;
321 }
322 // If the load has multiple users, duplicate a vector load.
323 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
324 // %2 = bitcast <256 x i32> %src to x86_amx
325 // %add = add <256 x i32> %src, <256 x i32> %src2
326 // -->
327 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
328 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
329 // i8* %addr, i64 %stride64)
330 // %add = add <256 x i32> %src, <256 x i32> %src2
331
332 // If load has one user, the load will be eliminated in DAG ISel.
333 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
334 // %2 = bitcast <256 x i32> %src to x86_amx
335 // -->
336 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
337 // i8* %addr, i64 %stride64)
338 combineLoadBitcast(LD, Bitcast);
339 DeadInsts.push_back(Bitcast);
340 if (LD->hasOneUse())
341 DeadInsts.push_back(LD);
342 } else if (Src->getType()->isX86_AMXTy()) {
343 if (Bitcast->user_empty()) {
344 DeadInsts.push_back(Bitcast);
345 continue;
346 }
347 StoreInst *ST = nullptr;
348 for (Use &U : Bitcast->uses()) {
349 ST = dyn_cast<StoreInst>(U.getUser());
350 if (ST)
351 break;
352 }
353 if (!ST) {
354 if (transformBitcast(Bitcast))
355 DeadInsts.push_back(Bitcast);
356 continue;
357 }
358 // If bitcast (%13) has one use, combine bitcast and store to amx store.
359 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
360 // %stride);
361 // %13 = bitcast x86_amx %src to <256 x i32>
362 // store <256 x i32> %13, <256 x i32>* %addr, align 64
363 // -->
364 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
365 // %stride64, %13)
366 //
367 // If bitcast (%13) has multiple uses, transform as below.
368 // %13 = bitcast x86_amx %src to <256 x i32>
369 // store <256 x i32> %13, <256 x i32>* %addr, align 64
370 // %add = <256 x i32> %13, <256 x i32> %src2
371 // -->
372 // %13 = bitcast x86_amx %src to <256 x i32>
373 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
374 // %stride64, %13)
375 // %14 = load <256 x i32>, %addr
376 // %add = <256 x i32> %14, <256 x i32> %src2
377 //
378 combineBitcastStore(Bitcast, ST);
379 // Delete user first.
380 DeadInsts.push_back(ST);
381 DeadInsts.push_back(Bitcast);
382 }
383 }
384 }
385
386 bool C = !DeadInsts.empty();
387
388 for (auto *Inst : DeadInsts)
389 Inst->eraseFromParent();
390
391 return C;
392}
393} // anonymous namespace
394
395static Value *getAllocaPos(BasicBlock *BB) {
396 Module *M = BB->getModule();
397 Function *F = BB->getParent();
398 IRBuilder<> Builder(&F->getEntryBlock().front());
399 const DataLayout &DL = M->getDataLayout();
400 unsigned AllocaAS = DL.getAllocaAddrSpace();
401 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
402 AllocaInst *AllocaRes =
403 new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
404 BasicBlock::iterator Iter = AllocaRes->getIterator();
405 ++Iter;
406 Builder.SetInsertPoint(&*Iter);
407 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
408 return I8Ptr;
409}
410
411static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
412 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
413 auto *II = cast<IntrinsicInst>(TileDef);
414 assert(II && "Not tile intrinsic!");
415 Value *Row = II->getOperand(0);
416 Value *Col = II->getOperand(1);
417
418 BasicBlock *BB = TileDef->getParent();
419 BasicBlock::iterator Iter = TileDef->getIterator();
420 IRBuilder<> Builder(BB, ++Iter);
421 Value *Stride = Builder.getInt64(64);
422 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
423
424 Instruction *TileStore =
425 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
426 return TileStore;
427}
428
429static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
430 Value *V = U.get();
431 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
432
433 // Get tile shape.
434 IntrinsicInst *II = nullptr;
435 if (IsPHI) {
436 Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
437 II = cast<IntrinsicInst>(PhiOp);
438 } else {
439 II = cast<IntrinsicInst>(V);
440 }
441 Value *Row = II->getOperand(0);
442 Value *Col = II->getOperand(1);
443
444 Instruction *UserI = dyn_cast<Instruction>(U.getUser());
445 IRBuilder<> Builder(UserI);
446 Value *Stride = Builder.getInt64(64);
447 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
448
449 Value *TileLoad =
450 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
451 UserI->replaceUsesOfWith(V, TileLoad);
452}
453
454static bool isIncomingOfPHI(Instruction *I) {
455 for (Use &U : I->uses()) {
456 User *V = U.getUser();
457 if (isa<PHINode>(V))
458 return true;
459 }
460 return false;
461}
462
463// Let all AMX tile data become volatile data and shorten the live range
464// of each tile register before fast register allocation.
465namespace {
466class X86VolatileTileData {
467 Function &F;
468
469public:
470 X86VolatileTileData(Function &Func) : F(Func) {}
471 Value *updatePhiIncomings(BasicBlock *BB,
472 SmallVector<Instruction *, 2> &Incomings);
473 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
474 bool volatileTileData();
475 void volatileTilePHI(PHINode *Inst);
476 void volatileTileNonPHI(Instruction *I);
477};
478
479Value *X86VolatileTileData::updatePhiIncomings(
480 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
481 Value *I8Ptr = getAllocaPos(BB);
482
483 for (auto *I : Incomings) {
484 User *Store = createTileStore(I, I8Ptr);
485
486 // All its uses (except phi) should load from stored mem.
487 for (Use &U : I->uses()) {
488 User *V = U.getUser();
489 if (isa<PHINode>(V) || V == Store)
490 continue;
491 replaceWithTileLoad(U, I8Ptr);
492 }
493 }
494 return I8Ptr;
495}
496
497void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
498 Value *StorePtr) {
499 for (Use &U : PHI->uses())
500 replaceWithTileLoad(U, StorePtr, true);
501 PHI->eraseFromParent();
502}
503
504// Similar to volatileTileNonPHI, this function only handles PHI nodes
505// and their related AMX intrinsics.
506// 1) The PHI def should change to a tileload.
507// 2) The PHI incoming values should be tilestored just after their defs.
508// 3) The mem of these tileloads and tilestores should be the same.
509// e.g.
510// ------------------------------------------------------
511// bb_dom:
512// ...
513// br i1 %bool.cond, label %if.else, label %if.then
514//
515// if.then:
516// def %t0 = ...
517// ...
518// use %t0
519// ...
520// br label %if.end
521//
522// if.else:
523// def %t1 = ...
524// br label %if.end
525//
526// if.end:
527// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
528// ...
529// use %td
530// ------------------------------------------------------
531// -->
532// ------------------------------------------------------
533// bb_entry:
534// %mem = alloca <256 x i32>, align 1024 *
535// ...
536// bb_dom:
537// ...
538// br i1 %bool.cond, label %if.else, label %if.then
539//
540// if.then:
541// def %t0 = ...
542// call void @llvm.x86.tilestored64.internal(mem, %t0) *
543// ...
544// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
545// use %t0` *
546// ...
547// br label %if.end
548//
549// if.else:
550// def %t1 = ...
551// call void @llvm.x86.tilestored64.internal(mem, %t1) *
552// br label %if.end
553//
554// if.end:
555// ...
556// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
557// use %td
558// ------------------------------------------------------
559void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
560 BasicBlock *BB = PHI->getParent();
11. Called C++ object pointer is null
561 SmallVector<Instruction *, 2> Incomings;
562
563 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
564 Value *Op = PHI->getIncomingValue(I);
565 Instruction *Inst = dyn_cast<Instruction>(Op);
566 assert(Inst && "We shouldn't fold AMX instruction!");
567 Incomings.push_back(Inst);
568 }
569
570 Value *StorePtr = updatePhiIncomings(BB, Incomings);
571 replacePhiDefWithLoad(PHI, StorePtr);
572}
573
574// Store the defined tile and load it before use.
575// None of its users is a PHI node.
576// e.g.
577// ------------------------------------------------------
578// def %td = ...
579// ...
580// "use %td"
581// ------------------------------------------------------
582// -->
583// ------------------------------------------------------
584// def %td = ...
585// call void @llvm.x86.tilestored64.internal(mem, %td)
586// ...
587// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
588// "use %td2"
589// ------------------------------------------------------
590void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
591 BasicBlock *BB = I->getParent();
592 Value *I8Ptr = getAllocaPos(BB);
593 User *Store = createTileStore(I, I8Ptr);
594
595 // All its uses should load from stored mem.
596 for (Use &U : I->uses()) {
597 User *V = U.getUser();
598 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
599 if (V != Store)
600 replaceWithTileLoad(U, I8Ptr);
601 }
602}
603
604// Volatile Tile Model:
605// 1) All the uses of tile data come from tileload in time.
606// 2) All the defs of tile data are tilestored into mem immediately.
607// For example:
608// --------------------------------------------------------------------------
609// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
610// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
611// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
612// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
613// call void @llvm.x86.tilestored64.internal(... td) area
614// --------------------------------------------------------------------------
615// 3) No terminator, call or other amx instructions in the key amx area.
616bool X86VolatileTileData::volatileTileData() {
617 bool Changed = false;
618 for (BasicBlock &BB : F) {
619 SmallVector<Instruction *, 2> PHIInsts;
620 SmallVector<Instruction *, 8> AMXDefInsts;
621
622 for (Instruction &I : BB) {
623 if (!I.getType()->isX86_AMXTy())
624 continue;
625 if (isa<PHINode>(&I))
626 PHIInsts.push_back(&I);
627 else
628 AMXDefInsts.push_back(&I);
629 }
630
631 // First we "volatile" the non-phi related amx intrinsics.
632 for (Instruction *I : AMXDefInsts) {
6. Assuming '__begin2' is equal to '__end2'
633 if (isIncomingOfPHI(I))
634 continue;
635 volatileTileNonPHI(I);
636 Changed = true;
637 }
638
639 for (Instruction *I : PHIInsts) {
7. Assuming '__begin2' is not equal to '__end2'
640 volatileTilePHI(dyn_cast<PHINode>(I));
8. Assuming 'I' is not a 'PHINode'
9. Passing null pointer value via 1st parameter 'PHI'
10. Calling 'X86VolatileTileData::volatileTilePHI'
641 Changed = true;
642 }
643 }
644 return Changed;
645}
646
647} // anonymous namespace
648
649namespace {
650
651class X86LowerAMXCast {
652 Function &Func;
653
654public:
655 X86LowerAMXCast(Function &F) : Func(F) {}
656 bool combineAMXcast(TargetLibraryInfo *TLI);
657 bool transformAMXCast(IntrinsicInst *AMXCast);
658 bool transformAllAMXCast();
659 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
660 SmallSetVector<Instruction *, 16> &DeadInst);
661};
662
663static bool DCEInstruction(Instruction *I,
664 SmallSetVector<Instruction *, 16> &WorkList,
665 const TargetLibraryInfo *TLI) {
666 if (isInstructionTriviallyDead(I, TLI)) {
667 salvageDebugInfo(*I);
668 salvageKnowledge(I);
669
670 // Null out all of the instruction's operands to see if any operand becomes
671 // dead as we go.
672 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
673 Value *OpV = I->getOperand(i);
674 I->setOperand(i, nullptr);
675
676 if (!OpV->use_empty() || I == OpV)
677 continue;
678
679 // If the operand is an instruction that became dead as we nulled out the
680 // operand, and if it is 'trivially' dead, delete it in a future loop
681 // iteration.
682 if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
683 if (isInstructionTriviallyDead(OpI, TLI)) {
684 WorkList.insert(OpI);
685 }
686 }
687 }
688 I->eraseFromParent();
689 return true;
690 }
691 return false;
692}
693
694/// This function handles the following case:
695///
696/// A -> B amxcast
697/// PHI
698/// B -> A amxcast
699///
700/// All the related PHI nodes can be replaced by new PHI nodes with type A.
701/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
702bool X86LowerAMXCast::optimizeAMXCastFromPhi(
703 IntrinsicInst *CI, PHINode *PN,
704 SmallSetVector<Instruction *, 16> &DeadInst) {
705 IRBuilder<> Builder(CI);
706 Value *Src = CI->getOperand(0);
707 Type *SrcTy = Src->getType(); // Type B
708 Type *DestTy = CI->getType(); // Type A
709
710 SmallVector<PHINode *, 4> PhiWorklist;
711 SmallSetVector<PHINode *, 4> OldPhiNodes;
712
713 // Find all of the A->B casts and PHI nodes.
714 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
715 // OldPhiNodes is used to track all known PHI nodes; before adding a new
716 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
717 PhiWorklist.push_back(PN);
718 OldPhiNodes.insert(PN);
719 while (!PhiWorklist.empty()) {
720 auto *OldPN = PhiWorklist.pop_back_val();
721 for (Value *IncValue : OldPN->incoming_values()) {
722 // TODO: currently, we ignore cases where it is a const. In the future, we
723 // might support const.
724 if (isa<Constant>(IncValue))
725 return false;
726
727 if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
728 if (OldPhiNodes.insert(PNode))
729 PhiWorklist.push_back(PNode);
730 continue;
731 }
732 Instruction *ACI = dyn_cast<Instruction>(IncValue);
733 if (ACI && isAMXCast(ACI)) {
734 // Verify it's a A->B cast.
735 Type *TyA = ACI->getOperand(0)->getType();
736 Type *TyB = ACI->getType();
737 if (TyA != DestTy || TyB != SrcTy)
738 return false;
739 continue;
740 }
741 return false;
742 }
743 }
744
745 // Check that each user of each old PHI node is something that we can
746 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
747 for (auto *OldPN : OldPhiNodes) {
748 for (User *V : OldPN->users()) {
749 Instruction *ACI = dyn_cast<Instruction>(V);
750 if (ACI && isAMXCast(ACI)) {
751 // Verify it's a B->A cast.
752 Type *TyB = ACI->getOperand(0)->getType();
753 Type *TyA = ACI->getType();
754 if (TyA != DestTy || TyB != SrcTy)
755 return false;
756 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
757 // As long as the user is another old PHI node, even if we don't
758 // rewrite it, the PHI web we're considering won't have any users
759 // outside itself, so it'll be dead.
760 // example:
761 // bb.0:
762 // %0 = amxcast ...
763 // bb.1:
764 // %1 = amxcast ...
765 // bb.2:
766 // %goodphi = phi %0, %1
767 // %3 = amxcast %goodphi
768 // bb.3:
769 // %goodphi2 = phi %0, %goodphi
770 // %4 = amxcast %goodphi2
771 // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
772 // outside the phi-web, so the combination stops. When
773 // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
774 // will be done.
775 if (OldPhiNodes.count(PHI) == 0)
776 return false;
777 } else
778 return false;
779 }
780 }
781
782 // For each old PHI node, create a corresponding new PHI node with a type A.
783 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
784 for (auto *OldPN : OldPhiNodes) {
785 Builder.SetInsertPoint(OldPN);
786 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
787 NewPNodes[OldPN] = NewPN;
788 }
789
790 // Fill in the operands of new PHI nodes.
791 for (auto *OldPN : OldPhiNodes) {
792 PHINode *NewPN = NewPNodes[OldPN];
793 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
794 Value *V = OldPN->getOperand(j);
795 Value *NewV = nullptr;
796 Instruction *ACI = dyn_cast<Instruction>(V);
797 // There should not be an AMXcast from a const.
798 if (ACI && isAMXCast(ACI))
799 NewV = ACI->getOperand(0);
800 else if (auto *PrevPN = dyn_cast<PHINode>(V))
801 NewV = NewPNodes[PrevPN];
802 assert(NewV);
803 NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
804 }
805 }
806
807 // Traverse all accumulated PHI nodes and process their users,
808 // which are Stores and BitCasts. Without this processing
809 // NewPHI nodes could be replicated and could lead to extra
810 // moves generated after DeSSA.
811 // If there is a store with type B, change it to type A.
812
813 // Replace users of BitCast B->A with NewPHI. These will help
814 // later to get rid of a closure formed by OldPHI nodes.
815 for (auto *OldPN : OldPhiNodes) {
816 PHINode *NewPN = NewPNodes[OldPN];
817 for (User *V : make_early_inc_range(OldPN->users())) {
818 Instruction *ACI = dyn_cast<Instruction>(V);
819 if (ACI && isAMXCast(ACI)) {
820 Type *TyB = ACI->getOperand(0)->getType();
821 Type *TyA = ACI->getType();
822 assert(TyA == DestTy && TyB == SrcTy);
823 (void)TyA;
824 (void)TyB;
825 ACI->replaceAllUsesWith(NewPN);
826 DeadInst.insert(ACI);
827 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
828 // We don't need to push the PHINodes into DeadInst since they are operands
829 // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
830 assert(OldPhiNodes.contains(PHI));
831 (void)PHI;
832 } else
833 llvm_unreachable("all uses should be handled");
834 }
835 }
836 return true;
837}
838
839bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
840 bool Change = false;
841 // Collect tile cast instruction.
842 SmallVector<Instruction *, 8> Vec2TileInsts;
843 SmallVector<Instruction *, 8> Tile2VecInsts;
844 SmallVector<Instruction *, 8> PhiCastWorkList;
845 SmallSetVector<Instruction *, 16> DeadInst;
846 for (BasicBlock &BB : Func) {
847 for (Instruction &I : BB) {
848 Value *Vec;
849 if (match(&I,
850 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
851 Vec2TileInsts.push_back(&I);
852 else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
853 m_Value(Vec))))
854 Tile2VecInsts.push_back(&I);
855 }
856 }
857
858 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
859 for (auto *Inst : Insts) {
860 for (User *U : Inst->users()) {
861 IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
862 if (!II || II->getIntrinsicID() != IID)
863 continue;
864 // T1 = vec2tile V0
865 // V2 = tile2vec T1
866 // V3 = OP V2
867 // -->
868 // T1 = vec2tile V0
869 // V2 = tile2vec T1
870 // V3 = OP V0
871 II->replaceAllUsesWith(Inst->getOperand(0));
872 Change = true;
873 }
874 }
875 };
876
877 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
878 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
879
880 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
881 for (auto *Inst : Insts) {
882 if (Inst->use_empty()) {
883 Inst->eraseFromParent();
884 Change = true;
885 }
886 }
887 };
888
889 EraseInst(Vec2TileInsts);
890 EraseInst(Tile2VecInsts);
891
892 // Handle the A->B->A cast, and there is an intervening PHI node.
893 for (BasicBlock &BB : Func) {
894 for (Instruction &I : BB) {
895 if (isAMXCast(&I)) {
896 if (isa<PHINode>(I.getOperand(0)))
897 PhiCastWorkList.push_back(&I);
898 }
899 }
900 }
901 for (auto *I : PhiCastWorkList) {
902 // We skip the dead Amxcast.
903 if (DeadInst.contains(I))
904 continue;
905 PHINode *PN = cast<PHINode>(I->getOperand(0));
906 if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
907 DeadInst.insert(PN);
908 Change = true;
909 }
910 }
911
912 // Since we create new phi and merge AMXCast, some old phis and AMXCast might
913 // have no uses. We do some DeadCodeElimination for them.
914 while (!DeadInst.empty()) {
915 Instruction *I = DeadInst.pop_back_val();
916 Change |= DCEInstruction(I, DeadInst, TLI);
917 }
918 return Change;
919}
920
921// There might be remaining AMXcasts after combineAMXcast, and they should be
922// handled elegantly.
923bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
924 IRBuilder<> Builder(AMXCast);
925 AllocaInst *AllocaAddr;
926 Value *I8Ptr, *Stride;
927 auto *Src = AMXCast->getOperand(0);
928
929 auto Prepare = [&](Type *MemTy) {
930 AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
931 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
932 Stride = Builder.getInt64(64);
933 };
934
935 if (AMXCast->getType()->isX86_AMXTy()) {
936 // %2 = amxcast <225 x i32> %src to x86_amx
937 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
938 // i8* %addr3, i64 60, x86_amx %2)
939 // -->
940 // %addr = alloca <225 x i32>, align 64
941 // store <225 x i32> %src, <225 x i32>* %addr, align 64
942 // %addr2 = bitcast <225 x i32>* %addr to i8*
943 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
944 // i8* %addr2,
945 // i64 60)
946 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
947 // i8* %addr3, i64 60, x86_amx %2)
948 Use &U = *(AMXCast->use_begin());
949 unsigned OpNo = U.getOperandNo();
950 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
951 if (!II)
952 return false; // May be bitcast from x86amx to <256 x i32>.
953 Prepare(AMXCast->getOperand(0)->getType());
954 Builder.CreateStore(Src, AllocaAddr);
955 // TODO: we can pick a constant operand for the shape.
956 Value *Row = nullptr, *Col = nullptr;
957 std::tie(Row, Col) = getShape(II, OpNo);
958 std::array<Value *, 4> Args = {
959 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
960 Value *NewInst = Builder.CreateIntrinsic(
961 Intrinsic::x86_tileloadd64_internal, None, Args);
962 AMXCast->replaceAllUsesWith(NewInst);
963 AMXCast->eraseFromParent();
964 } else {
965 // %2 = amxcast x86_amx %src to <225 x i32>
966 // -->
967 // %addr = alloca <225 x i32>, align 64
968 // %addr2 = bitcast <225 x i32>* to i8*
969 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
970 // i8* %addr2, i64 %stride)
971 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
972 auto *II = dyn_cast<IntrinsicInst>(Src);
973 if (!II)
974 return false; // May be bitcast from <256 x i32> to x86amx.
975 Prepare(AMXCast->getType());
976 Value *Row = II->getOperand(0);
977 Value *Col = II->getOperand(1);
978 std::array<Value *, 5> Args = {
979 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
980 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
981 Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
982 AMXCast->replaceAllUsesWith(NewInst);
983 AMXCast->eraseFromParent();
984 }
985
986 return true;
987}
988
989bool X86LowerAMXCast::transformAllAMXCast() {
990 bool Change = false;
991 // Collect tile cast instruction.
992 SmallVector<Instruction *, 8> WorkLists;
993 for (BasicBlock &BB : Func) {
994 for (Instruction &I : BB) {
995 if (isAMXCast(&I))
996 WorkLists.push_back(&I);
997 }
998 }
999
1000 for (auto *Inst : WorkLists) {
1001 Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1002 }
1003
1004 return Change;
1005}
1006
1007} // anonymous namespace
1008
1009namespace {
1010
1011class X86LowerAMXTypeLegacyPass : public FunctionPass {
1012public:
1013 static char ID;
1014
1015 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1016 initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1017 }
1018
1019 bool runOnFunction(Function &F) override {
1020 bool C = false;
1021 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1022 TargetLibraryInfo *TLI =
1023 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1024 X86LowerAMXCast LAC(F);
1025 C |= LAC.combineAMXcast(TLI);
1026 // There might be remaining AMXcasts after combineAMXcast, and they should be
1027 // handled elegantly.
1028 C |= LAC.transformAllAMXCast();
1029
1030 X86LowerAMXType LAT(F);
1031 C |= LAT.visit();
1032
1033 // Prepare for fast register allocation at O0.
1034 // TODO: It may be better to check the volatile model of AMX code, not just
1035 // by checking Attribute::OptimizeNone and CodeGenOpt::None.
1036 if (TM->getOptLevel() == CodeGenOpt::None) {
1. Assuming the condition is true
2. Taking true branch
1037 // If the front end does not use O0 but the mid/back end uses O0 (e.g.
1038 // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1039 // sure the amx data is volatile; that is necessary for AMX fast
1040 // register allocation.
1041 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
3. Assuming the condition is true
4. Taking true branch
1042 X86VolatileTileData VTD(F);
1043 C = VTD.volatileTileData() || C;
5. Calling 'X86VolatileTileData::volatileTileData'
1044 }
1045 }
1046
1047 return C;
1048 }
1049
1050 void getAnalysisUsage(AnalysisUsage &AU) const override {
1051 AU.setPreservesCFG();
1052 AU.addRequired<TargetPassConfig>();
1053 AU.addRequired<TargetLibraryInfoWrapperPass>();
1054 }
1055};
1056
1057} // anonymous namespace
1058
1059static const char PassName[] = "Lower AMX type for load/store";
1060char X86LowerAMXTypeLegacyPass::ID = 0;
1061INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1062 false)
1063INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1064INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1065INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1066 false)
1067
1068FunctionPass *llvm::createX86LowerAMXTypePass() {
1069 return new X86LowerAMXTypeLegacyPass();
1070}