Bug Summary

File: llvm/lib/Target/X86/X86LowerAMXType.cpp
Warning: line 581, column 20
Called C++ object pointer is null
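In short: X86VolatileTileData::volatileTileData() collects AMX-typed PHI nodes into PHIInsts and then calls volatileTilePHI(dyn_cast<PHINode>(I)); the analyzer assumes the dyn_cast can return null (path step 8), and volatileTilePHI immediately dereferences the parameter at line 581 via PHI->getParent(). A minimal sketch of the flagged call site and one way to harden it follows; the cast<>-based rewrite is an illustration based on the fact that PHIInsts only ever receives instructions that passed an isa<PHINode> check, not a confirmed upstream fix.

  // Flagged pattern (simplified from the listing below):
  for (Instruction *I : PHIInsts) {
    volatileTilePHI(dyn_cast<PHINode>(I)); // dyn_cast may yield nullptr
    Changed = true;
  }

  // Possible hardening, assuming PHIInsts only ever holds PHI nodes:
  // cast<> asserts the dynamic type instead of silently returning nullptr.
  for (Instruction *I : PHIInsts) {
    volatileTilePHI(cast<PHINode>(I));
    Changed = true;
  }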

Annotated Source Code


clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -disable-llvm-verifier -discard-value-names -main-file-name X86LowerAMXType.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -fno-rounding-math -mconstructor-aliases -munwind-tables -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-14~++20210828111110+16086d47c0d0/build-llvm/lib/Target/X86 -resource-dir /usr/lib/llvm-14/lib/clang/14.0.0 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I /build/llvm-toolchain-snapshot-14~++20210828111110+16086d47c0d0/build-llvm/lib/Target/X86 -I /build/llvm-toolchain-snapshot-14~++20210828111110+16086d47c0d0/llvm/lib/Target/X86 -I /build/llvm-toolchain-snapshot-14~++20210828111110+16086d47c0d0/build-llvm/include -I /build/llvm-toolchain-snapshot-14~++20210828111110+16086d47c0d0/llvm/include -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-14/lib/clang/14.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -O2 -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-14~++20210828111110+16086d47c0d0/build-llvm/lib/Target/X86 -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20210828111110+16086d47c0d0=. -ferror-limit 19 -fvisibility hidden -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2021-08-28-193554-24367-1 -x c++ /build/llvm-toolchain-snapshot-14~++20210828111110+16086d47c0d0/llvm/lib/Target/X86/X86LowerAMXType.cpp
1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform <256 x i32> load/store
10/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set only
11/// provides simple operations on x86_amx. Basic elementwise operations
12/// are not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
13/// and only AMX intrinsics can operate on the type, we need to transform
14/// load/store <256 x i32> instructions to AMX load/store. If the bitcast cannot
15/// be combined with the load/store, we transform the bitcast to an AMX load/store
16/// and a <256 x i32> store/load.
17///
18/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang -O2
19/// -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is volatile,
20/// because that is necessary for AMX fast register allocation. (In fast
21/// register allocation, registers are allocated before spill/reload, so
22/// there is no additional register for AMX to identify the step in spill.)
23/// The volatileTileData() will handle this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will transfer to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
42#include "llvm/ADT/PostOrderIterator.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallSet.h"
45#include "llvm/Analysis/OptimizationRemarkEmitter.h"
46#include "llvm/Analysis/TargetLibraryInfo.h"
47#include "llvm/Analysis/TargetTransformInfo.h"
48#include "llvm/CodeGen/Passes.h"
49#include "llvm/CodeGen/TargetPassConfig.h"
50#include "llvm/CodeGen/ValueTypes.h"
51#include "llvm/IR/DataLayout.h"
52#include "llvm/IR/Function.h"
53#include "llvm/IR/IRBuilder.h"
54#include "llvm/IR/Instructions.h"
55#include "llvm/IR/IntrinsicInst.h"
56#include "llvm/IR/IntrinsicsX86.h"
57#include "llvm/IR/PatternMatch.h"
58#include "llvm/InitializePasses.h"
59#include "llvm/Pass.h"
60#include "llvm/Target/TargetMachine.h"
61#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62#include "llvm/Transforms/Utils/Local.h"
63
64using namespace llvm;
65using namespace PatternMatch;
66
67#define DEBUG_TYPE "lower-amx-type"
68
69static bool isAMXCast(Instruction *II) {
70 return match(II,
71 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
72 match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
73}
74
75static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
76 Type *Ty) {
77 Function &F = *BB->getParent();
78 Module *M = BB->getModule();
79 const DataLayout &DL = M->getDataLayout();
80
81 LLVMContext &Ctx = Builder.getContext();
82 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
83 unsigned AllocaAS = DL.getAllocaAddrSpace();
84 AllocaInst *AllocaRes =
85 new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
86 AllocaRes->setAlignment(AllocaAlignment);
87 return AllocaRes;
88}
89
90static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
91 for (Instruction &I : F.getEntryBlock())
92 if (!isa<AllocaInst>(&I))
93 return &I;
94 llvm_unreachable("No terminator in the entry block!");
95}
96
97static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
98 IRBuilder<> Builder(II);
99 Value *Row = nullptr, *Col = nullptr;
100 switch (II->getIntrinsicID()) {
101 default:
102 llvm_unreachable("Expect amx intrinsics");
103 case Intrinsic::x86_tileloadd64_internal:
104 case Intrinsic::x86_tileloaddt164_internal:
105 case Intrinsic::x86_tilestored64_internal: {
106 Row = II->getArgOperand(0);
107 Col = II->getArgOperand(1);
108 break;
109 }
110 // a * b + c
111 // The shape depends on which operand.
112 case Intrinsic::x86_tdpbssd_internal:
113 case Intrinsic::x86_tdpbsud_internal:
114 case Intrinsic::x86_tdpbusd_internal:
115 case Intrinsic::x86_tdpbuud_internal:
116 case Intrinsic::x86_tdpbf16ps_internal: {
117 switch (OpNo) {
118 case 3:
119 Row = II->getArgOperand(0);
120 Col = II->getArgOperand(1);
121 break;
122 case 4:
123 Row = II->getArgOperand(0);
124 Col = II->getArgOperand(2);
125 break;
126 case 5:
127 if (isa<ConstantInt>(II->getArgOperand(2)))
128 Row = Builder.getInt16(
129 (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
130 else if (isa<Instruction>(II->getArgOperand(2))) {
131 // When it is not a const value and it is not a function argument, we
132 // create Row after the definition of II->getOperand(2) instead of
133 // before II. For example, II is %118, we try to getshape for %117:
134 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
135 // i32> %115).
136 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
137 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
138 // %117).
139 // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
140 // definition is after its user (the new tileload for %117).
141 // So, the best choice is to create %row right after the definition of
142 // %106.
143 Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
144 Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
145 cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
146 } else {
147 // When it is not a const value and it is a function argument, we create
148 // Row at the entry bb.
149 IRBuilder<> NewBuilder(
150 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
151 Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
152 }
153 Col = II->getArgOperand(1);
154 break;
155 }
156 break;
157 }
158 }
159
160 return std::make_pair(Row, Col);
161}
162
163namespace {
164class X86LowerAMXType {
165 Function &Func;
166
167 // In AMX intrinsics we let Shape = {Row, Col}, but the
168 // RealCol = Col / ElementSize. We may use the RealCol
169 // as a new Row for other newly created AMX intrinsics.
170 std::map<Value *, Value *> Col2Row;
171
172public:
173 X86LowerAMXType(Function &F) : Func(F) {}
174 bool visit();
175 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
176 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
177 bool transformBitcast(BitCastInst *Bitcast);
178 Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
179};
180
181Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
182 unsigned Granularity) {
183 if (Col2Row.count(V))
184 return Col2Row[V];
185 IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
186 if (auto *I = dyn_cast<Instruction>(V)) {
187 BasicBlock::iterator Iter = I->getIterator();
188 ++Iter;
189 Builder.SetInsertPoint(&*Iter);
190 }
191 ConstantInt *Gran = Builder.getInt16(Granularity);
192 Value *RealRow = Builder.CreateUDiv(V, Gran);
193 Col2Row[V] = RealRow;
194 return RealRow;
195}
196
197// %src = load <256 x i32>, <256 x i32>* %addr, align 64
198// %2 = bitcast <256 x i32> %src to x86_amx
199// -->
200// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
201// i8* %addr, i64 %stride64)
202void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
203 Value *Row = nullptr, *Col = nullptr;
204 Use &U = *(Bitcast->use_begin());
205 unsigned OpNo = U.getOperandNo();
206 auto *II = cast<IntrinsicInst>(U.getUser());
207 std::tie(Row, Col) = getShape(II, OpNo);
208 IRBuilder<> Builder(Bitcast);
209 // Use the maximum column as stride.
210 Value *Stride = Builder.getInt64(64);
211 Value *I8Ptr =
212 Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
213 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
214
215 Value *NewInst =
216 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
217 Bitcast->replaceAllUsesWith(NewInst);
218}
219
220// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
221// %stride);
222// %13 = bitcast x86_amx %src to <256 x i32>
223// store <256 x i32> %13, <256 x i32>* %addr, align 64
224// -->
225// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
226// %stride64, %13)
227void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
228
229 Value *Tile = Bitcast->getOperand(0);
230 auto *II = cast<IntrinsicInst>(Tile);
231 // Tile is output from AMX intrinsic. The first operand of the
232 // intrinsic is row, the second operand of the intrinsic is column.
233 Value *Row = II->getOperand(0);
234 Value *Col = II->getOperand(1);
235 IRBuilder<> Builder(ST);
236 // Use the maximum column as stride. It must be the same as the load
237 // stride.
238 Value *Stride = Builder.getInt64(64);
239 Value *I8Ptr =
240 Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
241 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
242 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
243 if (Bitcast->hasOneUse())
244 return;
245 // %13 = bitcast x86_amx %src to <256 x i32>
246 // store <256 x i32> %13, <256 x i32>* %addr, align 64
247 // %add = <256 x i32> %13, <256 x i32> %src2
248 // -->
249 // %13 = bitcast x86_amx %src to <256 x i32>
250 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
251 // %stride64, %13)
252 // %14 = load <256 x i32>, %addr
253 // %add = <256 x i32> %14, <256 x i32> %src2
254 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
255 Bitcast->replaceAllUsesWith(Vec);
256}
257
258// transform bitcast to <store, load> instructions.
259bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
260 IRBuilder<> Builder(Bitcast);
261 AllocaInst *AllocaAddr;
262 Value *I8Ptr, *Stride;
263 auto *Src = Bitcast->getOperand(0);
264
265 auto Prepare = [&](Type *MemTy) {
266 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
267 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
268 Stride = Builder.getInt64(64);
269 };
270
271 if (Bitcast->getType()->isX86_AMXTy()) {
272 // %2 = bitcast <256 x i32> %src to x86_amx
273 // -->
274 // %addr = alloca <256 x i32>, align 64
275 // store <256 x i32> %src, <256 x i32>* %addr, align 64
276 // %addr2 = bitcast <256 x i32>* to i8*
277 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
278 // i8* %addr2,
279 // i64 64)
280 Use &U = *(Bitcast->use_begin());
281 unsigned OpNo = U.getOperandNo();
282 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
283 if (!II)
284 return false; // May be bitcast from x86amx to <256 x i32>.
285 Prepare(Bitcast->getOperand(0)->getType());
286 Builder.CreateStore(Src, AllocaAddr);
287 // TODO: we can pick a constant operand for the shape.
288 Value *Row = nullptr, *Col = nullptr;
289 std::tie(Row, Col) = getShape(II, OpNo);
290 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
291 Value *NewInst = Builder.CreateIntrinsic(
292 Intrinsic::x86_tileloadd64_internal, None, Args);
293 Bitcast->replaceAllUsesWith(NewInst);
294 } else {
295 // %2 = bitcast x86_amx %src to <256 x i32>
296 // -->
297 // %addr = alloca <256 x i32>, align 64
298 // %addr2 = bitcast <256 x i32>* to i8*
299 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
300 // i8* %addr2, i64 %stride)
301 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
302 auto *II = dyn_cast<IntrinsicInst>(Src);
303 if (!II)
304 return false; // May be bitcast from <256 x i32> to x86amx.
305 Prepare(Bitcast->getType());
306 Value *Row = II->getOperand(0);
307 Value *Col = II->getOperand(1);
308 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
309 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
310 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
311 Bitcast->replaceAllUsesWith(NewInst);
312 }
313
314 return true;
315}
316
317bool X86LowerAMXType::visit() {
318 SmallVector<Instruction *, 8> DeadInsts;
319 Col2Row.clear();
320
321 for (BasicBlock *BB : post_order(&Func)) {
322 for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
323 II != IE;) {
324 Instruction &Inst = *II++;
325 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
326 if (!Bitcast)
327 continue;
328
329 Value *Src = Bitcast->getOperand(0);
330 if (Bitcast->getType()->isX86_AMXTy()) {
331 if (Bitcast->user_empty()) {
332 DeadInsts.push_back(Bitcast);
333 continue;
334 }
335 LoadInst *LD = dyn_cast<LoadInst>(Src);
336 if (!LD) {
337 if (transformBitcast(Bitcast))
338 DeadInsts.push_back(Bitcast);
339 continue;
340 }
341 // If the load has multiple users, duplicate a vector load.
342 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
343 // %2 = bitcast <256 x i32> %src to x86_amx
344 // %add = add <256 x i32> %src, <256 x i32> %src2
345 // -->
346 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
347 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
348 // i8* %addr, i64 %stride64)
349 // %add = add <256 x i32> %src, <256 x i32> %src2
350
351 // If load has one user, the load will be eliminated in DAG ISel.
352 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
353 // %2 = bitcast <256 x i32> %src to x86_amx
354 // -->
355 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
356 // i8* %addr, i64 %stride64)
357 combineLoadBitcast(LD, Bitcast);
358 DeadInsts.push_back(Bitcast);
359 if (LD->hasOneUse())
360 DeadInsts.push_back(LD);
361 } else if (Src->getType()->isX86_AMXTy()) {
362 if (Bitcast->user_empty()) {
363 DeadInsts.push_back(Bitcast);
364 continue;
365 }
366 StoreInst *ST = nullptr;
367 for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
368 UI != UE;) {
369 Value *I = (UI++)->getUser();
370 ST = dyn_cast<StoreInst>(I);
371 if (ST)
372 break;
373 }
374 if (!ST) {
375 if (transformBitcast(Bitcast))
376 DeadInsts.push_back(Bitcast);
377 continue;
378 }
379 // If bitcast (%13) has one use, combine bitcast and store to amx store.
380 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
381 // %stride);
382 // %13 = bitcast x86_amx %src to <256 x i32>
383 // store <256 x i32> %13, <256 x i32>* %addr, align 64
384 // -->
385 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
386 // %stride64, %13)
387 //
388 // If bitcast (%13) has multi-use, transform as below.
389 // %13 = bitcast x86_amx %src to <256 x i32>
390 // store <256 x i32> %13, <256 x i32>* %addr, align 64
391 // %add = <256 x i32> %13, <256 x i32> %src2
392 // -->
393 // %13 = bitcast x86_amx %src to <256 x i32>
394 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
395 // %stride64, %13)
396 // %14 = load <256 x i32>, %addr
397 // %add = <256 x i32> %14, <256 x i32> %src2
398 //
399 combineBitcastStore(Bitcast, ST);
400 // Delete user first.
401 DeadInsts.push_back(ST);
402 DeadInsts.push_back(Bitcast);
403 }
404 }
405 }
406
407 bool C = !DeadInsts.empty();
408
409 for (auto *Inst : DeadInsts)
410 Inst->eraseFromParent();
411
412 return C;
413}
414} // anonymous namespace
415
416static Value *getAllocaPos(BasicBlock *BB) {
417 Module *M = BB->getModule();
418 Function *F = BB->getParent();
419 IRBuilder<> Builder(&F->getEntryBlock().front());
420 const DataLayout &DL = M->getDataLayout();
421 unsigned AllocaAS = DL.getAllocaAddrSpace();
422 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
423 AllocaInst *AllocaRes =
424 new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
425 BasicBlock::iterator Iter = AllocaRes->getIterator();
426 ++Iter;
427 Builder.SetInsertPoint(&*Iter);
428 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
429 return I8Ptr;
430}
431
432static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
433 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
434 auto *II = cast<IntrinsicInst>(TileDef);
435 assert(II && "Not tile intrinsic!");
436 Value *Row = II->getOperand(0);
437 Value *Col = II->getOperand(1);
438
439 BasicBlock *BB = TileDef->getParent();
440 BasicBlock::iterator Iter = TileDef->getIterator();
441 IRBuilder<> Builder(BB, ++Iter);
442 Value *Stride = Builder.getInt64(64);
443 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
444
445 Instruction *TileStore =
446 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
447 return TileStore;
448}
449
450static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
451 Value *V = U.get();
452 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
453
454 // Get tile shape.
455 IntrinsicInst *II = nullptr;
456 if (IsPHI) {
457 Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
458 II = cast<IntrinsicInst>(PhiOp);
459 } else {
460 II = cast<IntrinsicInst>(V);
461 }
462 Value *Row = II->getOperand(0);
463 Value *Col = II->getOperand(1);
464
465 Instruction *UserI = dyn_cast<Instruction>(U.getUser());
466 IRBuilder<> Builder(UserI);
467 Value *Stride = Builder.getInt64(64);
468 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
469
470 Value *TileLoad =
471 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
472 UserI->replaceUsesOfWith(V, TileLoad);
473}
474
475static bool isIncomingOfPHI(Instruction *I) {
476 for (Use &U : I->uses()) {
477 User *V = U.getUser();
478 if (isa<PHINode>(V))
479 return true;
480 }
481 return false;
482}
483
484// Let all AMX tile data become volatile data, shortening the live range
485// of each tile register before fast register allocation.
486namespace {
487class X86VolatileTileData {
488 Function &F;
489
490public:
491 X86VolatileTileData(Function &Func) : F(Func) {}
492 Value *updatePhiIncomings(BasicBlock *BB,
493 SmallVector<Instruction *, 2> &Incomings);
494 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
495 bool volatileTileData();
496 void volatileTilePHI(PHINode *Inst);
497 void volatileTileNonPHI(Instruction *I);
498};
499
500Value *X86VolatileTileData::updatePhiIncomings(
501 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
502 Value *I8Ptr = getAllocaPos(BB);
503
504 for (auto *I : Incomings) {
505 User *Store = createTileStore(I, I8Ptr);
506
507 // All its uses (except phi) should load from stored mem.
508 for (Use &U : I->uses()) {
509 User *V = U.getUser();
510 if (isa<PHINode>(V) || V == Store)
511 continue;
512 replaceWithTileLoad(U, I8Ptr);
513 }
514 }
515 return I8Ptr;
516}
517
518void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
519 Value *StorePtr) {
520 for (Use &U : PHI->uses())
521 replaceWithTileLoad(U, StorePtr, true);
522 PHI->eraseFromParent();
523}
524
525// Similar to volatileTileNonPHI, this function only handles PHI nodes
526// and their related AMX intrinsics.
527// 1) The PHI def should change to a tileload.
528// 2) The PHI incoming values should be tilestored just after their defs.
529// 3) The mem of these tileloads and tilestores should be the same.
530// e.g.
531// ------------------------------------------------------
532// bb_dom:
533// ...
534// br i1 %bool.cond, label %if.else, label %if.then
535//
536// if.then:
537// def %t0 = ...
538// ...
539// use %t0
540// ...
541// br label %if.end
542//
543// if.else:
544// def %t1 = ...
545// br label %if.end
546//
547// if.end:
548// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
549// ...
550// use %td
551// ------------------------------------------------------
552// -->
553// ------------------------------------------------------
554// bb_entry:
555// %mem = alloca <256 x i32>, align 1024 *
556// ...
557// bb_dom:
558// ...
559// br i1 %bool.cond, label %if.else, label %if.then
560//
561// if.then:
562// def %t0 = ...
563// call void @llvm.x86.tilestored64.internal(mem, %t0) *
564// ...
565// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
566// use %t0` *
567// ...
568// br label %if.end
569//
570// if.else:
571// def %t1 = ...
572// call void @llvm.x86.tilestored64.internal(mem, %t1) *
573// br label %if.end
574//
575// if.end:
576// ...
577// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
578// use %td
579// ------------------------------------------------------
580void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
581 BasicBlock *BB = PHI->getParent();
11. Called C++ object pointer is null
582 SmallVector<Instruction *, 2> Incomings;
583
584 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
585 Value *Op = PHI->getIncomingValue(I);
586 Instruction *Inst = dyn_cast<Instruction>(Op);
587 assert(Inst && "We shouldn't fold AMX instruction!");
588 Incomings.push_back(Inst);
589 }
590
591 Value *StorePtr = updatePhiIncomings(BB, Incomings);
592 replacePhiDefWithLoad(PHI, StorePtr);
593}
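Note on the flagged function above: the dereference at line 581 is its first statement, and a null value can only reach it through the dyn_cast at the call site in volatileTileData() further down in this file. As an illustration only (not the upstream fix), an entry guard would make the non-null precondition explicit:

  void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
    // Fail fast instead of dereferencing a possibly-null pointer below.
    assert(PHI && "volatileTilePHI expects a non-null PHI node");
    BasicBlock *BB = PHI->getParent();
    // ... remainder unchanged from the listing above ...
  }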
594
595// Store the defined tile and load it before use.
596// All its users are not PHI.
597// e.g.
598// ------------------------------------------------------
599// def %td = ...
600// ...
601// "use %td"
602// ------------------------------------------------------
603// -->
604// ------------------------------------------------------
605// def %td = ...
606// call void @llvm.x86.tilestored64.internal(mem, %td)
607// ...
608// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
609// "use %td2"
610// ------------------------------------------------------
611void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
612 BasicBlock *BB = I->getParent();
613 Value *I8Ptr = getAllocaPos(BB);
614 User *Store = createTileStore(I, I8Ptr);
615
616 // All its uses should load from stored mem.
617 for (Use &U : I->uses()) {
618 User *V = U.getUser();
619 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
620 if (V != Store)
621 replaceWithTileLoad(U, I8Ptr);
622 }
623}
624
625// Volatile Tile Model:
626// 1) All the uses of tile data come from tileload in time.
627// 2) All the defs of tile data are tilestored into mem immediately.
628// For example:
629// --------------------------------------------------------------------------
630// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
631// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
632// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
633// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
634// call void @llvm.x86.tilestored64.internal(... td) area
635// --------------------------------------------------------------------------
636// 3) No terminator, call or other amx instructions in the key amx area.
637bool X86VolatileTileData::volatileTileData() {
638 bool Changed = false;
639 for (BasicBlock &BB : F) {
640 SmallVector<Instruction *, 2> PHIInsts;
641 SmallVector<Instruction *, 8> AMXDefInsts;
642
643 for (Instruction &I : BB) {
644 if (!I.getType()->isX86_AMXTy())
645 continue;
646 if (isa<PHINode>(&I))
647 PHIInsts.push_back(&I);
648 else
649 AMXDefInsts.push_back(&I);
650 }
651
652 // First we "volatile" the non-phi related amx intrinsics.
653 for (Instruction *I : AMXDefInsts) {
6. Assuming '__begin2' is equal to '__end2'
654 if (isIncomingOfPHI(I))
655 continue;
656 volatileTileNonPHI(I);
657 Changed = true;
658 }
659
660 for (Instruction *I : PHIInsts) {
7. Assuming '__begin2' is not equal to '__end2'
661 volatileTilePHI(dyn_cast<PHINode>(I));
8. Assuming 'I' is not a 'PHINode'
9. Passing null pointer value via 1st parameter 'PHI'
10. Calling 'X86VolatileTileData::volatileTilePHI'
662 Changed = true;
663 }
664 }
665 return Changed;
666}
667
668} // anonymous namespace
669
670namespace {
671
672class X86LowerAMXCast {
673 Function &Func;
674
675public:
676 X86LowerAMXCast(Function &F) : Func(F) {}
677 bool combineAMXcast(TargetLibraryInfo *TLI);
678 bool transformAMXCast(IntrinsicInst *AMXCast);
679 bool transformAllAMXCast();
680 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
681 SmallSetVector<Instruction *, 16> &DeadInst);
682};
683
684static bool DCEInstruction(Instruction *I,
685 SmallSetVector<Instruction *, 16> &WorkList,
686 const TargetLibraryInfo *TLI) {
687 if (isInstructionTriviallyDead(I, TLI)) {
688 salvageDebugInfo(*I);
689 salvageKnowledge(I);
690
691 // Null out all of the instruction's operands to see if any operand becomes
692 // dead as we go.
693 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
694 Value *OpV = I->getOperand(i);
695 I->setOperand(i, nullptr);
696
697 if (!OpV->use_empty() || I == OpV)
698 continue;
699
700 // If the operand is an instruction that became dead as we nulled out the
701 // operand, and if it is 'trivially' dead, delete it in a future loop
702 // iteration.
703 if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
704 if (isInstructionTriviallyDead(OpI, TLI)) {
705 WorkList.insert(OpI);
706 }
707 }
708 }
709 I->eraseFromParent();
710 return true;
711 }
712 return false;
713}
714
715/// This function handles following case
716///
717/// A -> B amxcast
718/// PHI
719/// B -> A amxcast
720///
721/// All the related PHI nodes can be replaced by new PHI nodes with type A.
722/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
723bool X86LowerAMXCast::optimizeAMXCastFromPhi(
724 IntrinsicInst *CI, PHINode *PN,
725 SmallSetVector<Instruction *, 16> &DeadInst) {
726 IRBuilder<> Builder(CI);
727 Value *Src = CI->getOperand(0);
728 Type *SrcTy = Src->getType(); // Type B
729 Type *DestTy = CI->getType(); // Type A
730
731 SmallVector<PHINode *, 4> PhiWorklist;
732 SmallSetVector<PHINode *, 4> OldPhiNodes;
733
734 // Find all of the A->B casts and PHI nodes.
735 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
736// OldPhiNodes is used to track all known PHI nodes; before adding a new
737 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
738 PhiWorklist.push_back(PN);
739 OldPhiNodes.insert(PN);
740 while (!PhiWorklist.empty()) {
741 auto *OldPN = PhiWorklist.pop_back_val();
742 for (Value *IncValue : OldPN->incoming_values()) {
743 // TODO: currently, we ignore cases where it is a const. In the future, we
744 // might support const.
745 if (isa<Constant>(IncValue))
746 return false;
747
748 if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
749 if (OldPhiNodes.insert(PNode))
750 PhiWorklist.push_back(PNode);
751 continue;
752 }
753 Instruction *ACI = dyn_cast<Instruction>(IncValue);
754 if (ACI && isAMXCast(ACI)) {
755 // Verify it's a A->B cast.
756 Type *TyA = ACI->getOperand(0)->getType();
757 Type *TyB = ACI->getType();
758 if (TyA != DestTy || TyB != SrcTy)
759 return false;
760 continue;
761 }
762 return false;
763 }
764 }
765
766 // Check that each user of each old PHI node is something that we can
767 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
768 for (auto *OldPN : OldPhiNodes) {
769 for (User *V : OldPN->users()) {
770 Instruction *ACI = dyn_cast<Instruction>(V);
771 if (ACI && isAMXCast(ACI)) {
772 // Verify it's a B->A cast.
773 Type *TyB = ACI->getOperand(0)->getType();
774 Type *TyA = ACI->getType();
775 if (TyA != DestTy || TyB != SrcTy)
776 return false;
777 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
778 // As long as the user is another old PHI node, then even if we don't
779 // rewrite it, the PHI web we're considering won't have any users
780 // outside itself, so it'll be dead.
781 // example:
782 // bb.0:
783 // %0 = amxcast ...
784 // bb.1:
785 // %1 = amxcast ...
786 // bb.2:
787 // %goodphi = phi %0, %1
788 // %3 = amxcast %goodphi
789 // bb.3:
790 // %goodphi2 = phi %0, %goodphi
791 // %4 = amxcast %goodphi2
792 // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
793 // outside the phi-web, so the combination stops. When
794 // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
795 // will be done.
796 if (OldPhiNodes.count(PHI) == 0)
797 return false;
798 } else
799 return false;
800 }
801 }
802
803 // For each old PHI node, create a corresponding new PHI node with a type A.
804 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
805 for (auto *OldPN : OldPhiNodes) {
806 Builder.SetInsertPoint(OldPN);
807 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
808 NewPNodes[OldPN] = NewPN;
809 }
810
811 // Fill in the operands of new PHI nodes.
812 for (auto *OldPN : OldPhiNodes) {
813 PHINode *NewPN = NewPNodes[OldPN];
814 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
815 Value *V = OldPN->getOperand(j);
816 Value *NewV = nullptr;
817 Instruction *ACI = dyn_cast<Instruction>(V);
818 // There should not be a AMXcast from a const.
819 if (ACI && isAMXCast(ACI))
820 NewV = ACI->getOperand(0);
821 else if (auto *PrevPN = dyn_cast<PHINode>(V))
822 NewV = NewPNodes[PrevPN];
823 assert(NewV);
824 NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
825 }
826 }
827
828 // Traverse all accumulated PHI nodes and process its users,
829 // which are Stores and BitCasts. Without this processing
830 // NewPHI nodes could be replicated and could lead to extra
831 // moves generated after DeSSA.
832 // If there is a store with type B, change it to type A.
833
834 // Replace users of BitCast B->A with NewPHI. These will help
835 // later to get rid of a closure formed by OldPHI nodes.
836 for (auto *OldPN : OldPhiNodes) {
837 PHINode *NewPN = NewPNodes[OldPN];
838 for (User *V : make_early_inc_range(OldPN->users())) {
839 Instruction *ACI = dyn_cast<Instruction>(V);
840 if (ACI && isAMXCast(ACI)) {
841 Type *TyB = ACI->getOperand(0)->getType();
842 Type *TyA = ACI->getType();
843 assert(TyA == DestTy && TyB == SrcTy);
844 (void)TyA;
845 (void)TyB;
846 ACI->replaceAllUsesWith(NewPN);
847 DeadInst.insert(ACI);
848 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
849 // We don't need to push PHINodes into DeadInst since they are operands
850 // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
851 assert(OldPhiNodes.contains(PHI));
852 (void)PHI;
853 } else
854 llvm_unreachable("all uses should be handled");
855 }
856 }
857 return true;
858}
859
860bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
861 bool Change = false;
862 // Collect tile cast instruction.
863 SmallVector<Instruction *, 8> Vec2TileInsts;
864 SmallVector<Instruction *, 8> Tile2VecInsts;
865 SmallVector<Instruction *, 8> PhiCastWorkList;
866 SmallSetVector<Instruction *, 16> DeadInst;
867 for (BasicBlock &BB : Func) {
868 for (Instruction &I : BB) {
869 Value *Vec;
870 if (match(&I,
871 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
872 Vec2TileInsts.push_back(&I);
873 else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
874 m_Value(Vec))))
875 Tile2VecInsts.push_back(&I);
876 }
877 }
878
879 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
880 for (auto *Inst : Insts) {
881 for (User *U : Inst->users()) {
882 IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
883 if (!II || II->getIntrinsicID() != IID)
884 continue;
885 // T1 = vec2tile V0
886 // V2 = tile2vec T1
887 // V3 = OP V2
888 // -->
889 // T1 = vec2tile V0
890 // V2 = tile2vec T1
891 // V3 = OP V0
892 II->replaceAllUsesWith(Inst->getOperand(0));
893 Change = true;
894 }
895 }
896 };
897
898 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
899 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
900
901 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
902 for (auto *Inst : Insts) {
903 if (Inst->use_empty()) {
904 Inst->eraseFromParent();
905 Change = true;
906 }
907 }
908 };
909
910 EraseInst(Vec2TileInsts);
911 EraseInst(Tile2VecInsts);
912
913 // Handle the A->B->A cast, and there is an intervening PHI node.
914 for (BasicBlock &BB : Func) {
915 for (Instruction &I : BB) {
916 if (isAMXCast(&I)) {
917 if (isa<PHINode>(I.getOperand(0)))
918 PhiCastWorkList.push_back(&I);
919 }
920 }
921 }
922 for (auto *I : PhiCastWorkList) {
923 // We skip the dead Amxcast.
924 if (DeadInst.contains(I))
925 continue;
926 PHINode *PN = cast<PHINode>(I->getOperand(0));
927 if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
928 DeadInst.insert(PN);
929 Change = true;
930 }
931 }
932
933 // Since we create new phi and merge AMXCast, some old phis and AMXCast might
934 // have no uses. We do some DeadCodeElimination for them.
935 while (!DeadInst.empty()) {
936 Instruction *I = DeadInst.pop_back_val();
937 Change |= DCEInstruction(I, DeadInst, TLI);
938 }
939 return Change;
940}
941
942// There might be remaining AMXcast after combineAMXcast and they should be
943// handled elegantly.
944bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
945 IRBuilder<> Builder(AMXCast);
946 AllocaInst *AllocaAddr;
947 Value *I8Ptr, *Stride;
948 auto *Src = AMXCast->getOperand(0);
949
950 auto Prepare = [&](Type *MemTy) {
951 AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
952 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
953 Stride = Builder.getInt64(64);
954 };
955
956 if (AMXCast->getType()->isX86_AMXTy()) {
957 // %2 = amxcast <225 x i32> %src to x86_amx
958 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
959 // i8* %addr3, i64 60, x86_amx %2)
960 // -->
961 // %addr = alloca <225 x i32>, align 64
962 // store <225 x i32> %src, <225 x i32>* %addr, align 64
963 // %addr2 = bitcast <225 x i32>* %addr to i8*
964 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
965 // i8* %addr2,
966 // i64 60)
967 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
968 // i8* %addr3, i64 60, x86_amx %2)
969 Use &U = *(AMXCast->use_begin());
970 unsigned OpNo = U.getOperandNo();
971 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
972 if (!II)
973 return false; // May be bitcast from x86amx to <256 x i32>.
974 Prepare(AMXCast->getOperand(0)->getType());
975 Builder.CreateStore(Src, AllocaAddr);
977 // TODO: we can pick a constant operand for the shape.
977 Value *Row = nullptr, *Col = nullptr;
978 std::tie(Row, Col) = getShape(II, OpNo);
979 std::array<Value *, 4> Args = {
980 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
981 Value *NewInst = Builder.CreateIntrinsic(
982 Intrinsic::x86_tileloadd64_internal, None, Args);
983 AMXCast->replaceAllUsesWith(NewInst);
984 AMXCast->eraseFromParent();
985 } else {
986 // %2 = amxcast x86_amx %src to <225 x i32>
987 // -->
988 // %addr = alloca <225 x i32>, align 64
989 // %addr2 = bitcast <225 x i32>* to i8*
990 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
991 // i8* %addr2, i64 %stride)
992 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
993 auto *II = dyn_cast<IntrinsicInst>(Src);
994 if (!II)
995 return false; // May be bitcast from <256 x i32> to x86amx.
996 Prepare(AMXCast->getType());
997 Value *Row = II->getOperand(0);
998 Value *Col = II->getOperand(1);
999 std::array<Value *, 5> Args = {
1000 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1001 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
1002 Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1003 AMXCast->replaceAllUsesWith(NewInst);
1004 AMXCast->eraseFromParent();
1005 }
1006
1007 return true;
1008}
1009
1010bool X86LowerAMXCast::transformAllAMXCast() {
1011 bool Change = false;
1012 // Collect tile cast instruction.
1013 SmallVector<Instruction *, 8> WorkLists;
1014 for (BasicBlock &BB : Func) {
1015 for (Instruction &I : BB) {
1016 if (isAMXCast(&I))
1017 WorkLists.push_back(&I);
1018 }
1019 }
1020
1021 for (auto *Inst : WorkLists) {
1022 Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1023 }
1024
1025 return Change;
1026}
1027
1028} // anonymous namespace
1029
1030namespace {
1031
1032class X86LowerAMXTypeLegacyPass : public FunctionPass {
1033public:
1034 static char ID;
1035
1036 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1037 initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1038 }
1039
1040 bool runOnFunction(Function &F) override {
1041 bool C = false;
1042 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1043 TargetLibraryInfo *TLI =
1044 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1045 X86LowerAMXCast LAC(F);
1046 C |= LAC.combineAMXcast(TLI);
1047 // There might be remaining AMXcast after combineAMXcast and they should be
1048 // handled elegantly.
1049 C |= LAC.transformAllAMXCast();
1050
1051 X86LowerAMXType LAT(F);
1052 C |= LAT.visit();
1053
1054 // Prepare for fast register allocation at O0.
1055 // TODO: It may be better to check the volatile model of the AMX code, not
1056 // just Attribute::OptimizeNone and CodeGenOpt::None.
1057 if (TM->getOptLevel() == CodeGenOpt::None) {
1. Assuming the condition is true
2. Taking true branch
1058 // If the front end does not use O0 but the mid/back end uses O0 (e.g.
1059 // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1060 // sure the AMX data is volatile, which is necessary for AMX fast
1061 // register allocation.
1062 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
3. Assuming the condition is true
4. Taking true branch
1063 X86VolatileTileData VTD(F);
1064 C = VTD.volatileTileData() || C;
5. Calling 'X86VolatileTileData::volatileTileData'
1065 }
1066 }
1067
1068 return C;
1069 }
1070
1071 void getAnalysisUsage(AnalysisUsage &AU) const override {
1072 AU.setPreservesCFG();
1073 AU.addRequired<TargetPassConfig>();
1074 AU.addRequired<TargetLibraryInfoWrapperPass>();
1075 }
1076};
1077
1078} // anonymous namespace
1079
1080static const char PassName[] = "Lower AMX type for load/store";
1081char X86LowerAMXTypeLegacyPass::ID = 0;
1082INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1083                      false)
1084INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1085INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1086INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1087                    false)
1088
1089FunctionPass *llvm::createX86LowerAMXTypePass() {
1090 return new X86LowerAMXTypeLegacyPass();
1091}