Bug Summary

File:build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/lib/Target/X86/X86LowerAMXType.cpp
Warning:line 486, column 20
Called C++ object pointer is null

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name X86LowerAMXType.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I lib/Target/X86 -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/lib/Target/X86 -I include -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem 
/usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm=build-llvm -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm=build-llvm -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -O3 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm=build-llvm -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -ferror-limit 19 -fvisibility=hidden -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-09-04-125545-48738-1 -x c++ /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/lib/Target/X86/X86LowerAMXType.cpp
1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform <256 x i32> load/store
10/// <256 x i32> is bitcasted to x86_amx on X86, and AMX instruction set only
11/// provides simple operation on x86_amx. The basic elementwise operation
12/// is not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
13/// and only AMX intrinsics can operate on the type, we need transform
14/// load/store <256 x i32> instruction to AMX load/store. If the bitcast can
15/// not be combined with load/store, we transform the bitcast to amx load/store
16/// and <256 x i32> store/load.
17///
18/// If the front end does not use O0 but the mid/back end does (e.g. "Clang -O2
19/// -S -emit-llvm t.c" + "llc t.ll"), we must make sure the AMX data is volatile,
20/// because that is necessary for AMX fast register allocation. (In fast
21/// register allocation, registers are allocated before spill/reload, so
22/// there is no additional register for AMX to identify the step in spill.)
23/// volatileTileData() handles this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will transfer to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
42#include "llvm/ADT/PostOrderIterator.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallSet.h"
45#include "llvm/Analysis/OptimizationRemarkEmitter.h"
46#include "llvm/Analysis/TargetLibraryInfo.h"
47#include "llvm/Analysis/TargetTransformInfo.h"
48#include "llvm/CodeGen/Passes.h"
49#include "llvm/CodeGen/TargetPassConfig.h"
50#include "llvm/CodeGen/ValueTypes.h"
51#include "llvm/IR/DataLayout.h"
52#include "llvm/IR/Function.h"
53#include "llvm/IR/IRBuilder.h"
54#include "llvm/IR/Instructions.h"
55#include "llvm/IR/IntrinsicInst.h"
56#include "llvm/IR/IntrinsicsX86.h"
57#include "llvm/IR/PatternMatch.h"
58#include "llvm/InitializePasses.h"
59#include "llvm/Pass.h"
60#include "llvm/Target/TargetMachine.h"
61#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62#include "llvm/Transforms/Utils/Local.h"
63
64#include <map>
65
66using namespace llvm;
67using namespace PatternMatch;
68
// Debug category used by LLVM_DEBUG / -debug-only for this file. The macro
// must expand to exactly one string literal; two adjacent literals would be
// concatenated into "lower-amx-typelower-amx-type".
#define DEBUG_TYPE "lower-amx-type"
70
71static bool isAMXCast(Instruction *II) {
72 return match(II,
73 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74 match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75}
76
77static bool isAMXIntrinsic(Value *I) {
78 auto *II = dyn_cast<IntrinsicInst>(I);
79 if (!II)
80 return false;
81 if (isAMXCast(II))
82 return false;
83 // Check if return type or parameter is x86_amx. If it is x86_amx
84 // the intrinsic must be x86 amx intrinsics.
85 if (II->getType()->isX86_AMXTy())
86 return true;
87 for (Value *V : II->args()) {
88 if (V->getType()->isX86_AMXTy())
89 return true;
90 }
91
92 return false;
93}
94
95static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
96 Type *Ty) {
97 Function &F = *BB->getParent();
98 Module *M = BB->getModule();
99 const DataLayout &DL = M->getDataLayout();
100
101 LLVMContext &Ctx = Builder.getContext();
102 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
103 unsigned AllocaAS = DL.getAllocaAddrSpace();
104 AllocaInst *AllocaRes =
105 new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
106 AllocaRes->setAlignment(AllocaAlignment);
107 return AllocaRes;
108}
109
110static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
111 for (Instruction &I : F.getEntryBlock())
112 if (!isa<AllocaInst>(&I))
113 return &I;
114 llvm_unreachable("No terminator in the entry block!")::llvm::llvm_unreachable_internal("No terminator in the entry block!"
, "llvm/lib/Target/X86/X86LowerAMXType.cpp", 114)
;
115}
116
117static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
118 IRBuilder<> Builder(II);
119 Value *Row = nullptr, *Col = nullptr;
120 switch (II->getIntrinsicID()) {
121 default:
122 llvm_unreachable("Expect amx intrinsics")::llvm::llvm_unreachable_internal("Expect amx intrinsics", "llvm/lib/Target/X86/X86LowerAMXType.cpp"
, 122)
;
123 case Intrinsic::x86_tileloadd64_internal:
124 case Intrinsic::x86_tileloaddt164_internal:
125 case Intrinsic::x86_tilestored64_internal: {
126 Row = II->getArgOperand(0);
127 Col = II->getArgOperand(1);
128 break;
129 }
130 // a * b + c
131 // The shape depends on which operand.
132 case Intrinsic::x86_tdpbssd_internal:
133 case Intrinsic::x86_tdpbsud_internal:
134 case Intrinsic::x86_tdpbusd_internal:
135 case Intrinsic::x86_tdpbuud_internal:
136 case Intrinsic::x86_tdpbf16ps_internal: {
137 switch (OpNo) {
138 case 3:
139 Row = II->getArgOperand(0);
140 Col = II->getArgOperand(1);
141 break;
142 case 4:
143 Row = II->getArgOperand(0);
144 Col = II->getArgOperand(2);
145 break;
146 case 5:
147 if (isa<ConstantInt>(II->getArgOperand(2)))
148 Row = Builder.getInt16(
149 (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
150 else if (isa<Instruction>(II->getArgOperand(2))) {
151 // When it is not a const value and it is not a function argument, we
152 // create Row after the definition of II->getOperand(2) instead of
153 // before II. For example, II is %118, we try to getshape for %117:
154 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
155 // i32> %115).
156 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
157 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
158 // %117).
159 // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
160 // definition is after its user(new tileload for %117).
161 // So, the best choice is to create %row right after the definition of
162 // %106.
163 Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
164 Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
165 cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
166 } else {
167 // When it is not a const value and it is a function argument, we create
168 // Row at the entry bb.
169 IRBuilder<> NewBuilder(
170 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
171 Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
172 }
173 Col = II->getArgOperand(1);
174 break;
175 }
176 break;
177 }
178 }
179
180 return std::make_pair(Row, Col);
181}
182
183static std::pair<Value *, Value *> getShape(PHINode *Phi) {
184 Use &U = *(Phi->use_begin());
185 unsigned OpNo = U.getOperandNo();
186 User *V = U.getUser();
187 // TODO We don't traverse all users. To make the algorithm simple, here we
188 // just traverse the first user. If we can find shape, then return the shape,
189 // otherwise just return nullptr and the optimization for undef/zero will be
190 // abandoned.
191 while (V) {
192 if (isAMXCast(dyn_cast<Instruction>(V))) {
193 if (V->use_empty())
194 break;
195 Use &U = *(V->use_begin());
196 OpNo = U.getOperandNo();
197 V = U.getUser();
198 } else if (isAMXIntrinsic(V)) {
199 return getShape(cast<IntrinsicInst>(V), OpNo);
200 } else if (isa<PHINode>(V)) {
201 if (V->use_empty())
202 break;
203 Use &U = *(V->use_begin());
204 V = U.getUser();
205 } else {
206 break;
207 }
208 }
209
210 return std::make_pair(nullptr, nullptr);
211}
212
213namespace {
214class X86LowerAMXType {
215 Function &Func;
216
217 // In AMX intrinsics we let Shape = {Row, Col}, but the
218 // RealCol = Col / ElementSize. We may use the RealCol
219 // as a new Row for other new created AMX intrinsics.
220 std::map<Value *, Value *> Col2Row;
221
222public:
223 X86LowerAMXType(Function &F) : Func(F) {}
224 bool visit();
225 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
226 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
227 bool transformBitcast(BitCastInst *Bitcast);
228};
229
230// %src = load <256 x i32>, <256 x i32>* %addr, align 64
231// %2 = bitcast <256 x i32> %src to x86_amx
232// -->
233// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
234// i8* %addr, i64 %stride64)
235void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
236 Value *Row = nullptr, *Col = nullptr;
237 Use &U = *(Bitcast->use_begin());
238 unsigned OpNo = U.getOperandNo();
239 auto *II = cast<IntrinsicInst>(U.getUser());
240 std::tie(Row, Col) = getShape(II, OpNo);
241 IRBuilder<> Builder(Bitcast);
242 // Use the maximun column as stride.
243 Value *Stride = Builder.getInt64(64);
244 Value *I8Ptr =
245 Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
246 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
247
248 Value *NewInst =
249 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
250 Bitcast->replaceAllUsesWith(NewInst);
251}
252
253// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
254// %stride);
255// %13 = bitcast x86_amx %src to <256 x i32>
256// store <256 x i32> %13, <256 x i32>* %addr, align 64
257// -->
258// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
259// %stride64, %13)
260void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
261
262 Value *Tile = Bitcast->getOperand(0);
263 auto *II = cast<IntrinsicInst>(Tile);
264 // Tile is output from AMX intrinsic. The first operand of the
265 // intrinsic is row, the second operand of the intrinsic is column.
266 Value *Row = II->getOperand(0);
267 Value *Col = II->getOperand(1);
268 IRBuilder<> Builder(ST);
269 // Use the maximum column as stride. It must be the same with load
270 // stride.
271 Value *Stride = Builder.getInt64(64);
272 Value *I8Ptr =
273 Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
274 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
275 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
276 if (Bitcast->hasOneUse())
277 return;
278 // %13 = bitcast x86_amx %src to <256 x i32>
279 // store <256 x i32> %13, <256 x i32>* %addr, align 64
280 // %add = <256 x i32> %13, <256 x i32> %src2
281 // -->
282 // %13 = bitcast x86_amx %src to <256 x i32>
283 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
284 // %stride64, %13)
285 // %14 = load <256 x i32>, %addr
286 // %add = <256 x i32> %14, <256 x i32> %src2
287 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
288 Bitcast->replaceAllUsesWith(Vec);
289}
290
291// transform bitcast to <store, load> instructions.
292bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
293 IRBuilder<> Builder(Bitcast);
294 AllocaInst *AllocaAddr;
295 Value *I8Ptr, *Stride;
296 auto *Src = Bitcast->getOperand(0);
297
298 auto Prepare = [&](Type *MemTy) {
299 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
300 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
301 Stride = Builder.getInt64(64);
302 };
303
304 if (Bitcast->getType()->isX86_AMXTy()) {
305 // %2 = bitcast <256 x i32> %src to x86_amx
306 // -->
307 // %addr = alloca <256 x i32>, align 64
308 // store <256 x i32> %src, <256 x i32>* %addr, align 64
309 // %addr2 = bitcast <256 x i32>* to i8*
310 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
311 // i8* %addr2,
312 // i64 64)
313 Use &U = *(Bitcast->use_begin());
314 unsigned OpNo = U.getOperandNo();
315 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
316 if (!II)
317 return false; // May be bitcast from x86amx to <256 x i32>.
318 Prepare(Bitcast->getOperand(0)->getType());
319 Builder.CreateStore(Src, AllocaAddr);
320 // TODO we can pick an constant operand for the shape.
321 Value *Row = nullptr, *Col = nullptr;
322 std::tie(Row, Col) = getShape(II, OpNo);
323 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
324 Value *NewInst = Builder.CreateIntrinsic(
325 Intrinsic::x86_tileloadd64_internal, None, Args);
326 Bitcast->replaceAllUsesWith(NewInst);
327 } else {
328 // %2 = bitcast x86_amx %src to <256 x i32>
329 // -->
330 // %addr = alloca <256 x i32>, align 64
331 // %addr2 = bitcast <256 x i32>* to i8*
332 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
333 // i8* %addr2, i64 %stride)
334 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
335 auto *II = dyn_cast<IntrinsicInst>(Src);
336 if (!II)
337 return false; // May be bitcast from <256 x i32> to x86amx.
338 Prepare(Bitcast->getType());
339 Value *Row = II->getOperand(0);
340 Value *Col = II->getOperand(1);
341 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
342 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
343 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
344 Bitcast->replaceAllUsesWith(NewInst);
345 }
346
347 return true;
348}
349
350bool X86LowerAMXType::visit() {
351 SmallVector<Instruction *, 8> DeadInsts;
352 Col2Row.clear();
353
354 for (BasicBlock *BB : post_order(&Func)) {
355 for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
356 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
357 if (!Bitcast)
358 continue;
359
360 Value *Src = Bitcast->getOperand(0);
361 if (Bitcast->getType()->isX86_AMXTy()) {
362 if (Bitcast->user_empty()) {
363 DeadInsts.push_back(Bitcast);
364 continue;
365 }
366 LoadInst *LD = dyn_cast<LoadInst>(Src);
367 if (!LD) {
368 if (transformBitcast(Bitcast))
369 DeadInsts.push_back(Bitcast);
370 continue;
371 }
372 // If load has mutli-user, duplicate a vector load.
373 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
374 // %2 = bitcast <256 x i32> %src to x86_amx
375 // %add = add <256 x i32> %src, <256 x i32> %src2
376 // -->
377 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
378 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
379 // i8* %addr, i64 %stride64)
380 // %add = add <256 x i32> %src, <256 x i32> %src2
381
382 // If load has one user, the load will be eliminated in DAG ISel.
383 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
384 // %2 = bitcast <256 x i32> %src to x86_amx
385 // -->
386 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
387 // i8* %addr, i64 %stride64)
388 combineLoadBitcast(LD, Bitcast);
389 DeadInsts.push_back(Bitcast);
390 if (LD->hasOneUse())
391 DeadInsts.push_back(LD);
392 } else if (Src->getType()->isX86_AMXTy()) {
393 if (Bitcast->user_empty()) {
394 DeadInsts.push_back(Bitcast);
395 continue;
396 }
397 StoreInst *ST = nullptr;
398 for (Use &U : Bitcast->uses()) {
399 ST = dyn_cast<StoreInst>(U.getUser());
400 if (ST)
401 break;
402 }
403 if (!ST) {
404 if (transformBitcast(Bitcast))
405 DeadInsts.push_back(Bitcast);
406 continue;
407 }
408 // If bitcast (%13) has one use, combine bitcast and store to amx store.
409 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
410 // %stride);
411 // %13 = bitcast x86_amx %src to <256 x i32>
412 // store <256 x i32> %13, <256 x i32>* %addr, align 64
413 // -->
414 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
415 // %stride64, %13)
416 //
417 // If bitcast (%13) has multi-use, transform as below.
418 // %13 = bitcast x86_amx %src to <256 x i32>
419 // store <256 x i32> %13, <256 x i32>* %addr, align 64
420 // %add = <256 x i32> %13, <256 x i32> %src2
421 // -->
422 // %13 = bitcast x86_amx %src to <256 x i32>
423 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
424 // %stride64, %13)
425 // %14 = load <256 x i32>, %addr
426 // %add = <256 x i32> %14, <256 x i32> %src2
427 //
428 combineBitcastStore(Bitcast, ST);
429 // Delete user first.
430 DeadInsts.push_back(ST);
431 DeadInsts.push_back(Bitcast);
432 }
433 }
434 }
435
436 bool C = !DeadInsts.empty();
437
438 for (auto *Inst : DeadInsts)
439 Inst->eraseFromParent();
440
441 return C;
442}
443} // anonymous namespace
444
445static Value *getAllocaPos(BasicBlock *BB) {
446 Module *M = BB->getModule();
447 Function *F = BB->getParent();
448 IRBuilder<> Builder(&F->getEntryBlock().front());
449 const DataLayout &DL = M->getDataLayout();
450 unsigned AllocaAS = DL.getAllocaAddrSpace();
451 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
452 AllocaInst *AllocaRes =
453 new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
454 BasicBlock::iterator Iter = AllocaRes->getIterator();
455 ++Iter;
456 Builder.SetInsertPoint(&*Iter);
457 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
458 return I8Ptr;
459}
460
461static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
462 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!")(static_cast <bool> (TileDef->getType()->isX86_AMXTy
() && "Not define tile!") ? void (0) : __assert_fail (
"TileDef->getType()->isX86_AMXTy() && \"Not define tile!\""
, "llvm/lib/Target/X86/X86LowerAMXType.cpp", 462, __extension__
__PRETTY_FUNCTION__))
;
463 auto *II = cast<IntrinsicInst>(TileDef);
464 assert(II && "Not tile intrinsic!")(static_cast <bool> (II && "Not tile intrinsic!"
) ? void (0) : __assert_fail ("II && \"Not tile intrinsic!\""
, "llvm/lib/Target/X86/X86LowerAMXType.cpp", 464, __extension__
__PRETTY_FUNCTION__))
;
465 Value *Row = II->getOperand(0);
466 Value *Col = II->getOperand(1);
467
468 BasicBlock *BB = TileDef->getParent();
469 BasicBlock::iterator Iter = TileDef->getIterator();
470 IRBuilder<> Builder(BB, ++Iter);
471 Value *Stride = Builder.getInt64(64);
472 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
473
474 Instruction *TileStore =
475 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
476 return TileStore;
477}
478
479static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
480 Value *V = U.get();
481 assert(V->getType()->isX86_AMXTy() && "Not define tile!")(static_cast <bool> (V->getType()->isX86_AMXTy() &&
"Not define tile!") ? void (0) : __assert_fail ("V->getType()->isX86_AMXTy() && \"Not define tile!\""
, "llvm/lib/Target/X86/X86LowerAMXType.cpp", 481, __extension__
__PRETTY_FUNCTION__))
;
14
'?' condition is true
482
483 // Get tile shape.
484 IntrinsicInst *II = nullptr;
485 if (IsPHI
14.1
'IsPHI' is true
) {
15
Taking true branch
486 Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
16
Assuming 'V' is not a 'CastReturnType'
17
Called C++ object pointer is null
487 II = cast<IntrinsicInst>(PhiOp);
488 } else {
489 II = cast<IntrinsicInst>(V);
490 }
491 Value *Row = II->getOperand(0);
492 Value *Col = II->getOperand(1);
493
494 Instruction *UserI = dyn_cast<Instruction>(U.getUser());
495 IRBuilder<> Builder(UserI);
496 Value *Stride = Builder.getInt64(64);
497 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
498
499 Value *TileLoad =
500 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
501 UserI->replaceUsesOfWith(V, TileLoad);
502}
503
504static bool isIncomingOfPHI(Instruction *I) {
505 for (Use &U : I->uses()) {
506 User *V = U.getUser();
507 if (isa<PHINode>(V))
508 return true;
509 }
510 return false;
511}
512
513// Let all AMX tile data become volatile data, shorten the life range
514// of each tile register before fast register allocation.
515namespace {
516class X86VolatileTileData {
517 Function &F;
518
519public:
520 X86VolatileTileData(Function &Func) : F(Func) {}
521 Value *updatePhiIncomings(BasicBlock *BB,
522 SmallVector<Instruction *, 2> &Incomings);
523 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
524 bool volatileTileData();
525 void volatileTilePHI(PHINode *Inst);
526 void volatileTileNonPHI(Instruction *I);
527};
528
529Value *X86VolatileTileData::updatePhiIncomings(
530 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
531 Value *I8Ptr = getAllocaPos(BB);
532
533 for (auto *I : Incomings) {
534 User *Store = createTileStore(I, I8Ptr);
535
536 // All its uses (except phi) should load from stored mem.
537 for (Use &U : I->uses()) {
538 User *V = U.getUser();
539 if (isa<PHINode>(V) || V == Store)
540 continue;
541 replaceWithTileLoad(U, I8Ptr);
542 }
543 }
544 return I8Ptr;
545}
546
547void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
548 Value *StorePtr) {
549 for (Use &U : PHI->uses())
550 replaceWithTileLoad(U, StorePtr, true);
13
Calling 'replaceWithTileLoad'
551 PHI->eraseFromParent();
552}
553
554// Similar to volatileTileNonPHI, this function only handles PHI nodes
555// and their related AMX intrinsics.
556// 1) PHI Def should change to tileload.
557// 2) PHI Incoming Values should tilestored in just after their def.
558// 3) The mem of these tileload and tilestores should be same.
559// e.g.
560// ------------------------------------------------------
561// bb_dom:
562// ...
563// br i1 %bool.cond, label %if.else, label %if.then
564//
565// if.then:
566// def %t0 = ...
567// ...
568// use %t0
569// ...
570// br label %if.end
571//
572// if.else:
573// def %t1 = ...
574// br label %if.end
575//
576// if.end:
577// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
578// ...
579// use %td
580// ------------------------------------------------------
581// -->
582// ------------------------------------------------------
583// bb_entry:
584// %mem = alloca <256 x i32>, align 1024 *
585// ...
586// bb_dom:
587// ...
588// br i1 %bool.cond, label %if.else, label %if.then
589//
590// if.then:
591// def %t0 = ...
592// call void @llvm.x86.tilestored64.internal(mem, %t0) *
593// ...
594// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
595// use %t0` *
596// ...
597// br label %if.end
598//
599// if.else:
600// def %t1 = ...
601// call void @llvm.x86.tilestored64.internal(mem, %t1) *
602// br label %if.end
603//
604// if.end:
605// ...
606// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
607// use %td
608// ------------------------------------------------------
609void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
610 BasicBlock *BB = PHI->getParent();
611 SmallVector<Instruction *, 2> Incomings;
612
613 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
10
Assuming 'I' is equal to 'E'
11
Loop condition is false. Execution continues on line 620
614 Value *Op = PHI->getIncomingValue(I);
615 Instruction *Inst = dyn_cast<Instruction>(Op);
616 assert(Inst && "We shouldn't fold AMX instrution!")(static_cast <bool> (Inst && "We shouldn't fold AMX instrution!"
) ? void (0) : __assert_fail ("Inst && \"We shouldn't fold AMX instrution!\""
, "llvm/lib/Target/X86/X86LowerAMXType.cpp", 616, __extension__
__PRETTY_FUNCTION__))
;
617 Incomings.push_back(Inst);
618 }
619
620 Value *StorePtr = updatePhiIncomings(BB, Incomings);
621 replacePhiDefWithLoad(PHI, StorePtr);
12
Calling 'X86VolatileTileData::replacePhiDefWithLoad'
622}
623
624// Store the defined tile and load it before use.
625// All its users are not PHI.
626// e.g.
627// ------------------------------------------------------
628// def %td = ...
629// ...
630// "use %td"
631// ------------------------------------------------------
632// -->
633// ------------------------------------------------------
634// def %td = ...
635// call void @llvm.x86.tilestored64.internal(mem, %td)
636// ...
637// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
638// "use %td2"
639// ------------------------------------------------------
640void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
641 BasicBlock *BB = I->getParent();
642 Value *I8Ptr = getAllocaPos(BB);
643 User *Store = createTileStore(I, I8Ptr);
644
645 // All its uses should load from stored mem.
646 for (Use &U : I->uses()) {
647 User *V = U.getUser();
648 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!")(static_cast <bool> (!isa<PHINode>(V) && "PHI Nodes should be excluded!"
) ? void (0) : __assert_fail ("!isa<PHINode>(V) && \"PHI Nodes should be excluded!\""
, "llvm/lib/Target/X86/X86LowerAMXType.cpp", 648, __extension__
__PRETTY_FUNCTION__))
;
649 if (V != Store)
650 replaceWithTileLoad(U, I8Ptr);
651 }
652}
653
654// Volatile Tile Model:
655// 1) All the uses of tile data comes from tileload in time.
656// 2) All the defs of tile data tilestore into mem immediately.
657// For example:
658// --------------------------------------------------------------------------
659// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
660// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
661// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
662// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
663// call void @llvm.x86.tilestored64.internal(... td) area
664// --------------------------------------------------------------------------
665// 3) No terminator, call or other amx instructions in the key amx area.
666bool X86VolatileTileData::volatileTileData() {
667 bool Changed = false;
668 for (BasicBlock &BB : F) {
669 SmallVector<Instruction *, 2> PHIInsts;
670 SmallVector<Instruction *, 8> AMXDefInsts;
671
672 for (Instruction &I : BB) {
673 if (!I.getType()->isX86_AMXTy())
674 continue;
675 if (isa<PHINode>(&I))
676 PHIInsts.push_back(&I);
677 else
678 AMXDefInsts.push_back(&I);
679 }
680
681 // First we "volatile" the non-phi related amx intrinsics.
682 for (Instruction *I : AMXDefInsts) {
6
Assuming '__begin2' is equal to '__end2'
683 if (isIncomingOfPHI(I))
684 continue;
685 volatileTileNonPHI(I);
686 Changed = true;
687 }
688
689 for (Instruction *I : PHIInsts) {
7
Assuming '__begin2' is not equal to '__end2'
690 volatileTilePHI(dyn_cast<PHINode>(I));
8
Assuming 'I' is a 'CastReturnType'
9
Calling 'X86VolatileTileData::volatileTilePHI'
691 Changed = true;
692 }
693 }
694 return Changed;
695}
696
697} // anonymous namespace
698
699namespace {
700
701class X86LowerAMXCast {
702 Function &Func;
703
704public:
705 X86LowerAMXCast(Function &F) : Func(F) {}
706 void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
707 void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
708 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
709 bool combineAMXcast(TargetLibraryInfo *TLI);
710 bool transformAMXCast(IntrinsicInst *AMXCast);
711 bool transformAllAMXCast();
712 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
713 SmallSetVector<Instruction *, 16> &DeadInst);
714};
715
716static bool DCEInstruction(Instruction *I,
717 SmallSetVector<Instruction *, 16> &WorkList,
718 const TargetLibraryInfo *TLI) {
719 if (isInstructionTriviallyDead(I, TLI)) {
720 salvageDebugInfo(*I);
721 salvageKnowledge(I);
722
723 // Null out all of the instruction's operands to see if any operand becomes
724 // dead as we go.
725 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
726 Value *OpV = I->getOperand(i);
727 I->setOperand(i, nullptr);
728
729 if (!OpV->use_empty() || I == OpV)
730 continue;
731
732 // If the operand is an instruction that became dead as we nulled out the
733 // operand, and if it is 'trivially' dead, delete it in a future loop
734 // iteration.
735 if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
736 if (isInstructionTriviallyDead(OpI, TLI)) {
737 WorkList.insert(OpI);
738 }
739 }
740 }
741 I->eraseFromParent();
742 return true;
743 }
744 return false;
745}
746
747/// This function handles the following case
748///
749/// A -> B amxcast
750/// PHI
751/// B -> A amxcast
752///
753/// All the related PHI nodes can be replaced by new PHI nodes with type A.
754/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
///
/// \param CI       the B -> A cast whose operand is \p PN.
/// \param PN       root PHI node (type B) of the web to rewrite.
/// \param DeadInst receives the B -> A casts whose uses were redirected.
/// \returns true if the web was rewritten, false if some incoming value or
///          user could not be handled.
755bool X86LowerAMXCast::optimizeAMXCastFromPhi(
756 IntrinsicInst *CI, PHINode *PN,
757 SmallSetVector<Instruction *, 16> &DeadInst) {
758 IRBuilder<> Builder(CI);
759 Value *Src = CI->getOperand(0);
760 Type *SrcTy = Src->getType(); // Type B
761 Type *DestTy = CI->getType(); // Type A
762
763 SmallVector<PHINode *, 4> PhiWorklist;
764 SmallSetVector<PHINode *, 4> OldPhiNodes;
765
766 // Find all of the A->B casts and PHI nodes.
767 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
768 // OldPhiNodes is used to track all known PHI nodes, before adding a new
769 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
770 PhiWorklist.push_back(PN);
771 OldPhiNodes.insert(PN);
772 while (!PhiWorklist.empty()) {
773 auto *OldPN = PhiWorklist.pop_back_val();
774 for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
775 Value *IncValue = OldPN->getIncomingValue(I);
776 // TODO: currently, we only handle undef and zero constants. In the
777 // future, we might support other constants.
778 if (isa<Constant>(IncValue)) {
// NOTE(review): the isa<> check above already guarantees the type, so this
// dyn_cast<> cannot return null here.
779 auto *IncConst = dyn_cast<Constant>(IncValue);
780 if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
781 return false;
782 Value *Row = nullptr, *Col = nullptr;
783 std::tie(Row, Col) = getShape(OldPN);
784 // TODO: If it is not constant the Row and Col must dominate tilezero
785 // that we are going to create.
786 if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
787 return false;
788 // Create tilezero at the end of incoming block.
// The builder still inserts before CI, so each new instruction is created
// there and then moved in front of the incoming block's terminator.
789 auto *Block = OldPN->getIncomingBlock(I);
790 BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
791 Instruction *NewInst = Builder.CreateIntrinsic(
792 Intrinsic::x86_tilezero_internal, None, {Row, Col});
793 NewInst->moveBefore(&*Iter);
794 NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
795 {IncValue->getType()}, {NewInst});
796 NewInst->moveBefore(&*Iter);
797 // Replace InValue with new Value.
// NOTE(review): this mutates OldPN before the remaining checks have run; a
// later `return false` leaves the new tilezero/cast pair in the IR.
798 OldPN->setIncomingValue(I, NewInst);
799 IncValue = NewInst;
800 }
801
802 if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
803 if (OldPhiNodes.insert(PNode))
804 PhiWorklist.push_back(PNode);
805 continue;
806 }
807 Instruction *ACI = dyn_cast<Instruction>(IncValue);
808 if (ACI && isAMXCast(ACI)) {
809 // Verify it's a A->B cast.
810 Type *TyA = ACI->getOperand(0)->getType();
811 Type *TyB = ACI->getType();
812 if (TyA != DestTy || TyB != SrcTy)
813 return false;
814 continue;
815 }
// Any other kind of incoming value is not supported.
816 return false;
817 }
818 }
819
820 // Check that each user of each old PHI node is something that we can
821 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
822 for (auto *OldPN : OldPhiNodes) {
823 for (User *V : OldPN->users()) {
824 Instruction *ACI = dyn_cast<Instruction>(V);
825 if (ACI && isAMXCast(ACI)) {
826 // Verify it's a B->A cast.
827 Type *TyB = ACI->getOperand(0)->getType();
828 Type *TyA = ACI->getType();
829 if (TyA != DestTy || TyB != SrcTy)
830 return false;
831 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
832 // As long as the user is another old PHI node, then even if we don't
833 // rewrite it, the PHI web we're considering won't have any users
834 // outside itself, so it'll be dead.
835 // example:
836 // bb.0:
837 // %0 = amxcast ...
838 // bb.1:
839 // %1 = amxcast ...
840 // bb.2:
841 // %goodphi = phi %0, %1
842 // %3 = amxcast %goodphi
843 // bb.3:
844 // %goodphi2 = phi %0, %goodphi
845 // %4 = amxcast %goodphi2
846 // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
847 // outside the phi-web, so the combination stops. When
848 // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
849 // will be done.
850 if (OldPhiNodes.count(PHI) == 0)
851 return false;
852 } else
853 return false;
854 }
855 }
856
857 // For each old PHI node, create a corresponding new PHI node with a type A.
858 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
859 for (auto *OldPN : OldPhiNodes) {
860 Builder.SetInsertPoint(OldPN);
861 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
862 NewPNodes[OldPN] = NewPN;
863 }
864
865 // Fill in the operands of new PHI nodes.
866 for (auto *OldPN : OldPhiNodes) {
867 PHINode *NewPN = NewPNodes[OldPN];
868 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
869 Value *V = OldPN->getOperand(j);
870 Value *NewV = nullptr;
871 Instruction *ACI = dyn_cast<Instruction>(V);
872 // There should not be a AMXcast from a const.
873 if (ACI && isAMXCast(ACI))
874 NewV = ACI->getOperand(0);
875 else if (auto *PrevPN = dyn_cast<PHINode>(V))
876 NewV = NewPNodes[PrevPN];
// The first loop only accepted AMX casts and PHIs as incoming values, so one
// of the two branches above is expected to have set NewV.
877 assert(NewV)(static_cast <bool> (NewV) ? void (0) : __assert_fail (
"NewV", "llvm/lib/Target/X86/X86LowerAMXType.cpp", 877, __extension__
__PRETTY_FUNCTION__))
;
878 NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
879 }
880 }
881
882 // Traverse all accumulated PHI nodes and process its users,
883 // which are Stores and BitCasts. Without this processing
884 // NewPHI nodes could be replicated and could lead to extra
885 // moves generated after DeSSA.
886 // If there is a store with type B, change it to type A.
887
888 // Replace users of BitCast B->A with NewPHI. These will help
889 // later to get rid of a closure formed by OldPHI nodes.
890 for (auto *OldPN : OldPhiNodes) {
891 PHINode *NewPN = NewPNodes[OldPN];
892 for (User *V : make_early_inc_range(OldPN->users())) {
893 Instruction *ACI = dyn_cast<Instruction>(V);
894 if (ACI && isAMXCast(ACI)) {
895 Type *TyB = ACI->getOperand(0)->getType();
896 Type *TyA = ACI->getType();
897 assert(TyA == DestTy && TyB == SrcTy)(static_cast <bool> (TyA == DestTy && TyB == SrcTy
) ? void (0) : __assert_fail ("TyA == DestTy && TyB == SrcTy"
, "llvm/lib/Target/X86/X86LowerAMXType.cpp", 897, __extension__
__PRETTY_FUNCTION__))
;
// The void casts keep TyA/TyB "used" when the assert is compiled out.
898 (void)TyA;
899 (void)TyB;
900 ACI->replaceAllUsesWith(NewPN);
901 DeadInst.insert(ACI);
902 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
903 // We don't need to push PHINode into DeadInst since they are operands
904 // of rootPN. DCE can safely delete rootPN's operands if rootPN is dead.
905 assert(OldPhiNodes.contains(PHI))(static_cast <bool> (OldPhiNodes.contains(PHI)) ? void (
0) : __assert_fail ("OldPhiNodes.contains(PHI)", "llvm/lib/Target/X86/X86LowerAMXType.cpp"
, 905, __extension__ __PRETTY_FUNCTION__))
;
906 (void)PHI;
907 } else
908 llvm_unreachable("all uses should be handled")::llvm::llvm_unreachable_internal("all uses should be handled"
, "llvm/lib/Target/X86/X86LowerAMXType.cpp", 908)
;
909 }
910 }
911 return true;
912}
913
914// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
915// store <256 x i32> %43, <256 x i32>* %p, align 64
916// -->
917// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
918// i64 64, x86_amx %42)
919void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
920 Value *Tile = Cast->getOperand(0);
921 // TODO: If it is cast intrinsic or phi node, we can propagate the
922 // shape information through def-use chain.
923 if (!isAMXIntrinsic(Tile))
924 return;
925 auto *II = cast<IntrinsicInst>(Tile);
926 // Tile is output from AMX intrinsic. The first operand of the
927 // intrinsic is row, the second operand of the intrinsic is column.
928 Value *Row = II->getOperand(0);
929 Value *Col = II->getOperand(1);
930 IRBuilder<> Builder(ST);
931 // Use the maximum column as stride. It must be the same with load
932 // stride.
933 Value *Stride = Builder.getInt64(64);
934 Value *I8Ptr =
935 Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
936 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
937 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
938}
939
940// %65 = load <256 x i32>, <256 x i32>* %p, align 64
941// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
942// -->
943// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
944// i8* %p, i64 64)
945void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
946 Value *Row = nullptr, *Col = nullptr;
947 Use &U = *(Cast->use_begin());
948 unsigned OpNo = U.getOperandNo();
949 auto *II = cast<IntrinsicInst>(U.getUser());
950 // TODO: If it is cast intrinsic or phi node, we can propagate the
951 // shape information through def-use chain.
952 if (!isAMXIntrinsic(II))
953 return;
954 std::tie(Row, Col) = getShape(II, OpNo);
955 IRBuilder<> Builder(LD);
956 // Use the maximun column as stride.
957 Value *Stride = Builder.getInt64(64);
958 Value *I8Ptr =
959 Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
960 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
961
962 Value *NewInst =
963 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
964 Cast->replaceAllUsesWith(NewInst);
965}
966
967bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
968 bool Change = false;
969 for (auto *Cast : Casts) {
970 auto *II = cast<IntrinsicInst>(Cast);
971 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
972 // store <256 x i32> %43, <256 x i32>* %p, align 64
973 // -->
974 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
975 // i64 64, x86_amx %42)
976 if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
977 SmallVector<Instruction *, 2> DeadStores;
978 for (User *U : Cast->users()) {
979 StoreInst *Store = dyn_cast<StoreInst>(U);
980 if (!Store)
981 continue;
982 combineCastStore(cast<IntrinsicInst>(Cast), Store);
983 DeadStores.push_back(Store);
984 Change = true;
985 }
986 for (auto *Store : DeadStores)
987 Store->eraseFromParent();
988 } else { // x86_cast_vector_to_tile
989 SmallVector<Instruction *, 2> DeadLoads;
990 auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
991 if (!Load || !Load->hasOneUse())
992 continue;
993 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
994 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
995 // -->
996 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
997 // i8* %p, i64 64)
998 combineLoadCast(cast<IntrinsicInst>(Cast), Load);
999 // Set the operand is null so that load instruction can be erased.
1000 Cast->setOperand(0, nullptr);
1001 Load->eraseFromParent();
1002 }
1003 }
1004 return Change;
1005}
1006
// Combine the AMX cast intrinsics of the function in four phases:
//  1. cancel vec2tile/tile2vec pairs and erase casts left without uses,
//  2. fold the remaining casts with adjacent loads/stores (combineLdSt),
//  3. rewrite cast -> PHI -> cast webs (optimizeAMXCastFromPhi),
//  4. run dead code elimination on the instructions queued in DeadInst.
// \p TLI is forwarded to DCEInstruction(). Returns true if the IR changed.
1007bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1008 bool Change = false;
1009 // Collect tile cast instruction.
1010 SmallVector<Instruction *, 8> Vec2TileInsts;
1011 SmallVector<Instruction *, 8> Tile2VecInsts;
1012 SmallVector<Instruction *, 8> PhiCastWorkList;
1013 SmallSetVector<Instruction *, 16> DeadInst;
1014 for (BasicBlock &BB : Func) {
1015 for (Instruction &I : BB) {
// Vec is only a binding target for the matcher; its value is not read here.
1016 Value *Vec;
1017 if (match(&I,
1018 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1019 Vec2TileInsts.push_back(&I);
1020 else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1021 m_Value(Vec))))
1022 Tile2VecInsts.push_back(&I);
1023 }
1024 }
1025
// For every cast in Insts, forward the cast's own operand to the users of any
// inverse cast (intrinsic IID) applied to its result.
1026 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1027 for (auto *Inst : Insts) {
1028 for (User *U : Inst->users()) {
1029 IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1030 if (!II || II->getIntrinsicID() != IID)
1031 continue;
1032 // T1 = vec2tile V0
1033 // V2 = tile2vec T1
1034 // V3 = OP V2
1035 // -->
1036 // T1 = vec2tile V0
1037 // V2 = tile2vec T1
1038 // V3 = OP V0
1039 II->replaceAllUsesWith(Inst->getOperand(0));
1040 Change = true;
1041 }
1042 }
1043 };
1044
1045 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1046 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1047
// Erase casts that have no uses; the others are appended to LiveCasts.
1048 SmallVector<Instruction *, 8> LiveCasts;
1049 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1050 for (auto *Inst : Insts) {
1051 if (Inst->use_empty()) {
1052 Inst->eraseFromParent();
1053 Change = true;
1054 } else {
1055 LiveCasts.push_back(Inst);
1056 }
1057 }
1058 };
1059
// Vec2TileInsts is processed first so that tile2vec casts whose only user was
// an erased vec2tile cast are already unused when their turn comes.
1060 EraseInst(Vec2TileInsts);
1061 EraseInst(Tile2VecInsts);
1062 Change |= combineLdSt(LiveCasts);
// NOTE(review): EraseInst appends to LiveCasts while this call iterates over
// LiveCasts itself; confirm the container cannot reallocate mid-iteration.
1063 EraseInst(LiveCasts);
1064
1065 // Handle the A->B->A cast, and there is an intervening PHI node.
1066 for (BasicBlock &BB : Func) {
1067 for (Instruction &I : BB) {
1068 if (isAMXCast(&I)) {
1069 if (isa<PHINode>(I.getOperand(0)))
1070 PhiCastWorkList.push_back(&I);
1071 }
1072 }
1073 }
1074 for (auto *I : PhiCastWorkList) {
1075 // We skip the dead Amxcast.
1076 if (DeadInst.contains(I))
1077 continue;
1078 PHINode *PN = cast<PHINode>(I->getOperand(0));
1079 if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
// The old root PHI is queued as well so the DCE loop below can remove it.
1080 DeadInst.insert(PN);
1081 Change = true;
1082 }
1083 }
1084
1085 // Since we create new phi and merge AMXCast, some old phis and AMXCast might
1086 // have no uses. We do some DeadCodeElimination for them.
1087 while (!DeadInst.empty()) {
1088 Instruction *I = DeadInst.pop_back_val();
1089 Change |= DCEInstruction(I, DeadInst, TLI);
1090 }
1091 return Change;
1092}
1093
1094// There might be remaining AMXcast after combineAMXcast and they should be
1095// handled elegantly.
// The cast is lowered through a stack slot created by
// createAllocaInstAtEntry(): the source is written to the slot in one form
// and read back in the other. Returns false (leaving the cast untouched) when
// the shape-providing instruction is not an IntrinsicInst, true otherwise.
1096bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1097 IRBuilder<> Builder(AMXCast);
1098 AllocaInst *AllocaAddr;
1099 Value *I8Ptr, *Stride;
1100 auto *Src = AMXCast->getOperand(0);
1101
// Creates the stack slot of type MemTy and the i8* view of it.
// NOTE(review): Stride is assigned here but not read anywhere in this
// function; both intrinsic calls below pass sext(Col) as the stride instead.
1102 auto Prepare = [&](Type *MemTy) {
1103 AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1104 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
1105 Stride = Builder.getInt64(64);
1106 };
1107
1108 if (AMXCast->getType()->isX86_AMXTy()) {
1109 // %2 = amxcast <225 x i32> %src to x86_amx
1110 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1111 // i8* %addr3, i64 60, x86_amx %2)
1112 // -->
1113 // %addr = alloca <225 x i32>, align 64
1114 // store <225 x i32> %src, <225 x i32>* %addr, align 64
1115 // %addr2 = bitcast <225 x i32>* %addr to i8*
1116 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1117 // i8* %addr2,
1118 // i64 60)
1119 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1120 // i8* %addr3, i64 60, x86_amx %2)
1121 if (AMXCast->use_empty()) {
1122 AMXCast->eraseFromParent();
1123 return true;
1124 }
// Only the first use is consulted to obtain the tile shape.
1125 Use &U = *(AMXCast->use_begin());
1126 unsigned OpNo = U.getOperandNo();
1127 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1128 if (!II)
1129 return false; // May be bitcast from x86amx to <256 x i32>.
1130 Prepare(AMXCast->getOperand(0)->getType());
1131 Builder.CreateStore(Src, AllocaAddr);
1132 // TODO we can pick a constant operand for the shape.
1133 Value *Row = nullptr, *Col = nullptr;
1134 std::tie(Row, Col) = getShape(II, OpNo);
1135 std::array<Value *, 4> Args = {
1136 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1137 Value *NewInst = Builder.CreateIntrinsic(
1138 Intrinsic::x86_tileloadd64_internal, None, Args);
1139 AMXCast->replaceAllUsesWith(NewInst);
1140 AMXCast->eraseFromParent();
1141 } else {
1142 // %2 = amxcast x86_amx %src to <225 x i32>
1143 // -->
1144 // %addr = alloca <225 x i32>, align 64
1145 // %addr2 = bitcast <225 x i32>* to i8*
1146 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1147 // i8* %addr2, i64 %stride)
1148 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1149 auto *II = dyn_cast<IntrinsicInst>(Src);
1150 if (!II)
1151 return false; // May be bitcast from <256 x i32> to x86amx.
1152 Prepare(AMXCast->getType());
// The shape is read from the first two operands of the intrinsic that
// produces the tile.
1153 Value *Row = II->getOperand(0);
1154 Value *Col = II->getOperand(1);
1155 std::array<Value *, 5> Args = {
1156 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1157 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
1158 Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1159 AMXCast->replaceAllUsesWith(NewInst);
1160 AMXCast->eraseFromParent();
1161 }
1162
1163 return true;
1164}
1165
1166bool X86LowerAMXCast::transformAllAMXCast() {
1167 bool Change = false;
1168 // Collect tile cast instruction.
1169 SmallVector<Instruction *, 8> WorkLists;
1170 for (BasicBlock &BB : Func) {
1171 for (Instruction &I : BB) {
1172 if (isAMXCast(&I))
1173 WorkLists.push_back(&I);
1174 }
1175 }
1176
1177 for (auto *Inst : WorkLists) {
1178 Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1179 }
1180
1181 return Change;
1182}
1183
1184} // anonymous namespace
1185
1186namespace {
1187
// Legacy pass-manager wrapper that runs the AMX cast combining/lowering, the
// AMX type lowering and, when compiling without optimization, the volatile
// tile data preparation on a function.
1188class X86LowerAMXTypeLegacyPass : public FunctionPass {
1189public:
// Pass identification; the address of ID is what identifies the pass.
1190 static char ID;
1191
1192 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1193 initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1194 }
1195
// Runs all lowering steps on F and returns true if any of them changed it.
1196 bool runOnFunction(Function &F) override {
1197 bool C = false;
1198 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1199 TargetLibraryInfo *TLI =
1200 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1201 X86LowerAMXCast LAC(F);
1202 C |= LAC.combineAMXcast(TLI);
1203 // There might be remaining AMXcast after combineAMXcast and they should be
1204 // handled elegantly.
1205 C |= LAC.transformAllAMXCast();
1206
1207 X86LowerAMXType LAT(F);
1208 C |= LAT.visit();
1209
1210 // Prepare for fast register allocation at O0.
1211 // Todo: May better check the volatile model of AMX code, not just
1212 // by checking Attribute::OptimizeNone and CodeGenOpt::None.
1213 if (TM->getOptLevel() == CodeGenOpt::None) {
1
Assuming the condition is true
2
Taking true branch
1214 // If Front End not use O0 but the Mid/Back end use O0, (e.g.
1215 // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll") we should make
1216 // sure the amx data is volatile, that is necessary for AMX fast
1217 // register allocation.
1218 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
3
Assuming the condition is true
4
Taking true branch
1219 X86VolatileTileData VTD(F);
// volatileTileData() is placed on the left of || so that it always runs.
1220 C = VTD.volatileTileData() || C;
5
Calling 'X86VolatileTileData::volatileTileData'
1221 }
1222 }
1223
1224 return C;
1225 }
1226
// The pass keeps the CFG intact and needs the target configuration and the
// target library info.
1227 void getAnalysisUsage(AnalysisUsage &AU) const override {
1228 AU.setPreservesCFG();
1229 AU.addRequired<TargetPassConfig>();
1230 AU.addRequired<TargetLibraryInfoWrapperPass>();
1231 }
1232};
1233
1234} // anonymous namespace
1235
// Legacy pass registration: defines the human-readable pass name and the pass
// ID, then registers the pass together with the two analyses it depends on
// (TargetPassConfig and TargetLibraryInfoWrapperPass). The text following each
// INITIALIZE_PASS_* invocation is that macro's expansion.
1236static const char PassName[] = "Lower AMX type for load/store";
1237char X86LowerAMXTypeLegacyPass::ID = 0;
1238INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,static void *initializeX86LowerAMXTypeLegacyPassPassOnce(PassRegistry
&Registry) {
1239 false)static void *initializeX86LowerAMXTypeLegacyPassPassOnce(PassRegistry
&Registry) {
1240INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)initializeTargetPassConfigPass(Registry);
1241INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)initializeTargetLibraryInfoWrapperPassPass(Registry);
1242INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,PassInfo *PI = new PassInfo( PassName, "lower-amx-type", &
X86LowerAMXTypeLegacyPass::ID, PassInfo::NormalCtor_t(callDefaultCtor
<X86LowerAMXTypeLegacyPass>), false, false); Registry.registerPass
(*PI, true); return PI; } static llvm::once_flag InitializeX86LowerAMXTypeLegacyPassPassFlag
; void llvm::initializeX86LowerAMXTypeLegacyPassPass(PassRegistry
&Registry) { llvm::call_once(InitializeX86LowerAMXTypeLegacyPassPassFlag
, initializeX86LowerAMXTypeLegacyPassPassOnce, std::ref(Registry
)); }
1243 false)PassInfo *PI = new PassInfo( PassName, "lower-amx-type", &
X86LowerAMXTypeLegacyPass::ID, PassInfo::NormalCtor_t(callDefaultCtor
<X86LowerAMXTypeLegacyPass>), false, false); Registry.registerPass
(*PI, true); return PI; } static llvm::once_flag InitializeX86LowerAMXTypeLegacyPassPassFlag
; void llvm::initializeX86LowerAMXTypeLegacyPassPass(PassRegistry
&Registry) { llvm::call_once(InitializeX86LowerAMXTypeLegacyPassPassFlag
, initializeX86LowerAMXTypeLegacyPassPassOnce, std::ref(Registry
)); }
1244
1245FunctionPass *llvm::createX86LowerAMXTypePass() {
1246 return new X86LowerAMXTypeLegacyPass();
1247}