//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions to AMX load/store. If the
/// bitcast cannot be combined with a load/store, we transform the bitcast to
/// an AMX load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g.
/// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for AMX to identify the
/// step in spill.) The volatileTileData() will handle this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will transfer to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
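// Illustrative sketch (not from the original source): the simplest shapes this
// pass rewrites. A vector load feeding a bitcast becomes an AMX tile load, and
// a bitcast feeding a store becomes an AMX tile store. The %row, %col and
// stride operands shown here are placeholders.
//
//   %vec = load <256 x i32>, <256 x i32>* %addr, align 64
//   %t   = bitcast <256 x i32> %vec to x86_amx
//   -->
//   %t   = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                      i8* %addr, i64 64)
//
//   %vec = bitcast x86_amx %t to <256 x i32>
//   store <256 x i32> %vec, <256 x i32>* %addr, align 64
//   -->
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %addr,
//                                             i64 64, x86_amx %t)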
#include "X86.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or a parameter is x86_amx. If it is x86_amx,
  // the intrinsic must be an x86 AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

static bool containsAMXCode(Function &F) {
  for (BasicBlock &BB : F)
    for (Instruction &I : BB)
      if (I.getType()->isX86_AMXTy())
        return true;
  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, and we try to getShape for %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //          i32> %115)
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //          %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114,
        //          x86_amx %117)
        // If we create %row = udiv i16 %106, 4 before %118 (aka II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}
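
// Illustrative sketch (not from the original source): for a dot-product
// intrinsic of the form
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a,
//                                                x86_amx %b)
// the switch above yields shape (%m, %n) for operand 3 (%c), (%m, %k) for
// operand 4 (%a), and (%k / 4, %n) for operand 5 (%b), i.e. the row/column
// operands a tile load for that operand would need.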

static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To make the algorithm simple, here we
  // just traverse the first user. If we can find the shape, then return the
  // shape, otherwise just return nullptr and the optimization for undef/zero
  // will be abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}
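
// Illustrative sketch (not from the original source), with placeholder names:
//   %p = phi <256 x i32> [ zeroinitializer, %bb0 ], [ %v, %bb1 ]
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %t, x86_amx %a,
//                                                x86_amx %b)
// The walk above follows the first use of %p into the cast and then into the
// tdpbssd intrinsic, and reports the shape of operand 3, i.e. (%m, %n).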

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = AllocaAddr;
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine bitcast and store to an
        // amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, shorten the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should change to a tileload.
// 2) The PHI incoming values should be tilestored just after their def.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                   *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)     *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %t0`                                                *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)     *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem)  *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// All its users are not PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from tileloads in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
/// A -> B amxcast
/// PHI
/// B -> A amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
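///
/// Illustrative sketch (not from the original source; value and block names
/// are placeholders):
///
///   bb1:
///     %a1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///     br label %bb3
///   bb2:
///     %a2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t2)
///     br label %bb3
///   bb3:
///     %p = phi <256 x i32> [ %a1, %bb1 ], [ %a2, %bb2 ]
///     %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
///   -->
///   bb3:
///     %p = phi x86_amx [ %t1, %bb1 ], [ %t2, %bb2 ]
///
/// Uses of %t are redirected to the new phi, so the vector-typed phi and both
/// casts become dead and can be cleaned up by the DCE that follows.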
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes, before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: currently, we ignore cases where it is a const. In the future,
      // we might support const.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not constant, the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace IncValue with the new value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        // bb.0:
        //   %0 = amxcast ...
        // bb.1:
        //   %1 = amxcast ...
        // bb.2:
        //   %goodphi = phi %0, %1
        //   %3 = amxcast %goodphi
        // bb.3:
        //   %goodphi2 = phi %0, %goodphi
        //   %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXCast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are stores and bitcasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(Tile))
    return false;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compile time, we create the dominator tree only when it is really
  // needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}
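
// Illustrative sketch (not from the original source): when the shape does not
// dominate the load, e.g. with placeholder names
//   %v   = load <256 x i32>, <256 x i32>* %p, align 64
//   %row = ...                      ; shape only defined after the load
//   %t   = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
// combineLoadCast keeps the original load, spills it to a stack slot right
// after it, and rewrites the cast into a tileload from that slot:
//   %buf = alloca <256 x i32>, align 64          ; in the entry block
//   %v   = load <256 x i32>, <256 x i32>* %p, align 64
//   store <256 x i32> %v, <256 x i32>* %buf, align 64
//   %row = ...
//   %t   = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                      i8* %buf, i64 %stride)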

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead AMXCast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMXCasts after combineAMXcast, and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    // Performance optimization: most code doesn't use AMX, so return early if
    // there are no instructions that produce AMX values. This is sufficient,
    // as AMX arguments and constants are not allowed -- so any producer of an
    // AMX value must be an instruction.
    // TODO: find a cheaper way for this, without looking at all instructions.
    if (!containsAMXCode(F))
      return false;

    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXCasts after combineAMXcast, and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: May better check the volatile model of AMX code, not just
    // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}