//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
/// cannot be combined with the load/store, we transform the bitcast to an AMX
/// load/store plus a <256 x i32> store/load.
///
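/// For example (this mirrors the case handled by combineLoadBitcast below):
///   %src = load <256 x i32>, <256 x i32>* %addr, align 64
///   %2 = bitcast <256 x i32> %src to x86_amx
/// becomes a single AMX tile load:
///   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                    i8* %addr, i64 64)
///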
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// The volatileTileData() will handle this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will transform to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If so, the
  // intrinsic must be an x86 AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
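  // Illustrative mapping, derived from the cases below: operand 3 is the
  // accumulator c with shape (op0, op1); operand 4 is a with shape (op0, op2);
  // operand 5 is b, whose row is op2 / 4 because op2 (the column count of a)
  // is measured in bytes.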
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118; we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //          i32> %115)
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //          %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114,
        //          x86_amx %117)
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, here we
  // just traverse the first user. If we can find the shape, return it;
  // otherwise return nullptr and the optimization for undef/zero will be
  // abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
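  // (getShape() above applies this idea when it computes the Row of a tdp*
  // intrinsic's operand 5 as operand 2 divided by 4.)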
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = AllocaAddr;
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store into
        // an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data and shorten the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should change to a tileload.
// 2) The PHI incoming values should be tile-stored just after their defs.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// None of its users is a PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from tileloads in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi-related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
///     A -> B amxcast
///     PHI
///     B -> A amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
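/// For illustration only (the value and block names below are made up), with
/// A = x86_amx and B = <256 x i32>:
///   %vx = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %tx)
///   %vy = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %ty)
///   %vphi = phi <256 x i32> [ %vx, %bb.x ], [ %vy, %bb.y ]
///   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %vphi)
/// is rewritten so that uses of %t go through a new phi of type A instead:
///   %t = phi x86_amx [ %tx, %bb.x ], [ %ty, %bb.y ]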
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently, we ignore cases where it is a const. In the future,
      // we might support consts.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not constant, the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace IncValue with the new Value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        //   bb.0:
        //     %0 = amxcast ...
        //   bb.1:
        //     %1 = amxcast ...
        //   bb.2:
        //     %goodphi = phi %0, %1
        //     %3 = amxcast %goodphi
        //   bb.3:
        //     %goodphi2 = phi %0, %goodphi
        //     %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of the new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since they are
        // operands of rootPN. DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(Tile))
    return false;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compile time, we create the dominator tree only when it is really
  // needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead AMXcast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some dead code elimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code, not
    // just Attribute::OptimizeNone and CodeGenOptLevel::None.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}