1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform <256 x i32> load/store
10/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
11/// provides simple operations on x86_amx. Basic elementwise operations
12/// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
13/// and only AMX intrinsics can operate on the type, we need to transform
14/// load/store <256 x i32> instructions into AMX load/store. If the bitcast
15/// cannot be combined with the load/store, we transform the bitcast into an
16/// amx load/store plus a <256 x i32> store/load.
17///
18/// If the front end does not use O0 but the mid/back end does (e.g. "clang -O2
19/// -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is
20/// volatile, because that is necessary for AMX fast register allocation. (In
21/// fast register allocation, registers are allocated before spill/reload, so
22/// there is no spare register for amx to identify the step in a spill.)
23/// volatileTileData() handles this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will be transformed to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
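// For illustration, the simplest pattern this pass rewrites (see
// combineLoadBitcast below) is:
//
//   %src = load <256 x i32>, <256 x i32>* %addr, align 64
//   %2 = bitcast <256 x i32> %src to x86_amx
// -->
//   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    i8* %addr, i64 64)
//
// where %row and %col are the shape taken from the AMX intrinsic that
// consumes %2.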
41#include "X86.h"
43#include "llvm/ADT/SetVector.h"
46#include "llvm/CodeGen/Passes.h"
49#include "llvm/IR/Analysis.h"
50#include "llvm/IR/DataLayout.h"
51#include "llvm/IR/Function.h"
52#include "llvm/IR/IRBuilder.h"
55#include "llvm/IR/IntrinsicsX86.h"
56#include "llvm/IR/PassManager.h"
59#include "llvm/Pass.h"
63
64#include <map>
65
66using namespace llvm;
67using namespace PatternMatch;
68
69#define DEBUG_TYPE "x86-lower-amx-type"
70
71static bool isAMXCast(Instruction *II) {
72  return match(II,
73               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75}
76
77static bool isAMXIntrinsic(Value *I) {
78 auto *II = dyn_cast<IntrinsicInst>(I);
79 if (!II)
80 return false;
81 if (isAMXCast(II))
82 return false;
83 // Check if the return type or any parameter is x86_amx. If so, the
84 // intrinsic must be an x86 AMX intrinsic.
85 if (II->getType()->isX86_AMXTy())
86 return true;
87 for (Value *V : II->args()) {
88 if (V->getType()->isX86_AMXTy())
89 return true;
90 }
91
92 return false;
93}
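// For example, @llvm.x86.tdpbssd.internal returns x86_amx, so it is treated as
// an AMX intrinsic here, while @llvm.x86.cast.vector.to.tile also returns
// x86_amx but is filtered out by the isAMXCast() check above.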
94
95static bool containsAMXCode(Function &F) {
96 for (BasicBlock &BB : F)
97 for (Instruction &I : BB)
98 if (I.getType()->isX86_AMXTy())
99 return true;
100 return false;
101}
102
103static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
104                                           Type *Ty) {
105 Function &F = *BB->getParent();
106 const DataLayout &DL = F.getDataLayout();
107
108 LLVMContext &Ctx = Builder.getContext();
109 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
110 unsigned AllocaAS = DL.getAllocaAddrSpace();
111 AllocaInst *AllocaRes =
112 new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
113 AllocaRes->setAlignment(AllocaAlignment);
114 return AllocaRes;
115}
116
117static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
118 for (Instruction &I : F.getEntryBlock())
119 if (!isa<AllocaInst>(&I))
120 return &I;
121 llvm_unreachable("No terminator in the entry block!");
122}
123
124static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {
125 IRBuilder<> Builder(II);
126 Value *RealRow = nullptr;
127 if (isa<ConstantInt>(V))
128 RealRow =
129 Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
130 else if (isa<Instruction>(V)) {
131 // When V is not a constant and not a function argument, we create Row
132 // right after the definition of V instead of before II.
133 // For example, II is %118 and we try to get the shape for %117:
134 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
135 // i32> %115).
136 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
137 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
138 // %117).
139 // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
140 // definition is after its user(new tileload for %117).
141 // So, the best choice is to create %row right after the definition of
142 // %106.
143 Builder.SetInsertPoint(cast<Instruction>(V));
144 RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
145 cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
146 } else {
147 // When it is not a const value and it is a function argument, we create
148 // Row at the entry bb.
149 IRBuilder<> NewBuilder(
150 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
151 RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
152 }
153 return RealRow;
154}
155
156// TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
157std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
158 IRBuilder<> Builder(II);
159 Value *Row = nullptr, *Col = nullptr;
160 switch (II->getIntrinsicID()) {
161 default:
162 llvm_unreachable("Expect amx intrinsics");
163 case Intrinsic::x86_tileloadd64_internal:
164 case Intrinsic::x86_tileloaddt164_internal:
165 case Intrinsic::x86_tilestored64_internal:
166 case Intrinsic::x86_t2rpntlvwz0rs_internal:
167 case Intrinsic::x86_t2rpntlvwz0rst1_internal:
168 case Intrinsic::x86_t2rpntlvwz1rs_internal:
169 case Intrinsic::x86_t2rpntlvwz1rst1_internal:
170 case Intrinsic::x86_tileloaddrs64_internal:
171 case Intrinsic::x86_tileloaddrst164_internal: {
172 Row = II->getArgOperand(0);
173 Col = II->getArgOperand(1);
174 break;
175 }
176 // a * b + c
177 // The shape depends on which operand.
178 case Intrinsic::x86_tcmmimfp16ps_internal:
179 case Intrinsic::x86_tcmmrlfp16ps_internal:
180 case Intrinsic::x86_tdpbssd_internal:
181 case Intrinsic::x86_tdpbsud_internal:
182 case Intrinsic::x86_tdpbusd_internal:
183 case Intrinsic::x86_tdpbuud_internal:
184 case Intrinsic::x86_tdpbf16ps_internal:
185 case Intrinsic::x86_tdpfp16ps_internal:
186 case Intrinsic::x86_tmmultf32ps_internal:
187 case Intrinsic::x86_tdpbf8ps_internal:
188 case Intrinsic::x86_tdpbhf8ps_internal:
189 case Intrinsic::x86_tdphbf8ps_internal:
190 case Intrinsic::x86_tdphf8ps_internal: {
191 switch (OpNo) {
192 case 3:
193 Row = II->getArgOperand(0);
194 Col = II->getArgOperand(1);
195 break;
196 case 4:
197 Row = II->getArgOperand(0);
198 Col = II->getArgOperand(2);
199 break;
200 case 5:
201 Row = getRowFromCol(II, II->getArgOperand(2), 4);
202 Col = II->getArgOperand(1);
203 break;
204 }
205 break;
206 }
207 case Intrinsic::x86_tcvtrowd2ps_internal:
208 case Intrinsic::x86_tcvtrowps2bf16h_internal:
209 case Intrinsic::x86_tcvtrowps2bf16l_internal:
210 case Intrinsic::x86_tcvtrowps2phh_internal:
211 case Intrinsic::x86_tcvtrowps2phl_internal:
212 case Intrinsic::x86_tilemovrow_internal: {
213 assert(OpNo == 2 && "Illegal Operand Number.");
214 Row = II->getArgOperand(0);
215 Col = II->getArgOperand(1);
216 break;
217 }
218 }
219
220 return std::make_pair(Row, Col);
221}
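// Worked example, derived from the switch above: for
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a,
//                                                x86_amx %b)
// the shapes are
//   getShape(II, 3) == {%m, %n}     ; accumulator %c
//   getShape(II, 4) == {%m, %k}     ; operand %a
//   getShape(II, 5) == {%k / 4, %n} ; operand %b, row derived via getRowFromCol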
222
223static std::pair<Value *, Value *> getShape(PHINode *Phi) {
224 Use &U = *(Phi->use_begin());
225 unsigned OpNo = U.getOperandNo();
226 User *V = U.getUser();
227 // TODO: We don't traverse all users. To keep the algorithm simple, we just
228 // follow the first user. If we can find the shape there, return it;
229 // otherwise return nullptr and the optimization for undef/zero will be
230 // abandoned.
231 while (V) {
232 if (isAMXCast(V)) {
233 if (V->use_empty())
234 break;
235 Use &U = *(V->use_begin());
236 OpNo = U.getOperandNo();
237 V = U.getUser();
238 } else if (isAMXIntrinsic(V)) {
239 return getShape(cast<IntrinsicInst>(V), OpNo);
240 } else if (isa<PHINode>(V)) {
241 if (V->use_empty())
242 break;
243 Use &U = *(V->use_begin());
244 V = U.getUser();
245 } else {
246 break;
247 }
248 }
249
250 return std::make_pair(nullptr, nullptr);
251}
252
253namespace {
254class X86LowerAMXType {
255 Function &Func;
256
257 // In AMX intrinsics we let Shape = {Row, Col}, but the
258 // RealCol = Col / ElementSize. We may use the RealCol
259 // as a new Row for other newly created AMX intrinsics.
260 std::map<Value *, Value *> Col2Row;
261
262public:
263 X86LowerAMXType(Function &F) : Func(F) {}
264 bool visit();
265 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
266 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
267 bool transformBitcast(BitCastInst *Bitcast);
268};
269
270// %src = load <256 x i32>, <256 x i32>* %addr, align 64
271// %2 = bitcast <256 x i32> %src to x86_amx
272// -->
273// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
274// i8* %addr, i64 %stride64)
275void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
276 Value *Row = nullptr, *Col = nullptr;
277 Use &U = *(Bitcast->use_begin());
278 unsigned OpNo = U.getOperandNo();
279 auto *II = cast<IntrinsicInst>(U.getUser());
280 std::tie(Row, Col) = getShape(II, OpNo);
281 IRBuilder<> Builder(Bitcast);
282 // Use the maximum column as stride.
283 Value *Stride = Builder.getInt64(64);
284 Value *I8Ptr = LD->getOperand(0);
285 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
286
287 Value *NewInst =
288 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
289 Bitcast->replaceAllUsesWith(NewInst);
290}
291
292// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
293// %stride);
294// %13 = bitcast x86_amx %src to <256 x i32>
295// store <256 x i32> %13, <256 x i32>* %addr, align 64
296// -->
297// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
298// %stride64, %13)
299void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
300
301 Value *Tile = Bitcast->getOperand(0);
302 auto *II = cast<IntrinsicInst>(Tile);
303 // Tile is output from AMX intrinsic. The first operand of the
304 // intrinsic is row, the second operand of the intrinsic is column.
305 Value *Row = II->getOperand(0);
306 Value *Col = II->getOperand(1);
307 IRBuilder<> Builder(ST);
308 // Use the maximum column as stride. It must be the same as the load
309 // stride.
310 Value *Stride = Builder.getInt64(64);
311 Value *I8Ptr = ST->getOperand(1);
312 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
313 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
314 if (Bitcast->hasOneUse())
315 return;
316 // %13 = bitcast x86_amx %src to <256 x i32>
317 // store <256 x i32> %13, <256 x i32>* %addr, align 64
318 // %add = <256 x i32> %13, <256 x i32> %src2
319 // -->
320 // %13 = bitcast x86_amx %src to <256 x i32>
321 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
322 // %stride64, %13)
323 // %14 = load <256 x i32>, %addr
324 // %add = <256 x i32> %14, <256 x i32> %src2
325 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
326 Bitcast->replaceAllUsesWith(Vec);
327}
328
329 // Transform the bitcast into a <store, load> pair through a stack temporary.
330bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
331 IRBuilder<> Builder(Bitcast);
332 AllocaInst *AllocaAddr;
333 Value *I8Ptr, *Stride;
334 auto *Src = Bitcast->getOperand(0);
335
336 auto Prepare = [&](Type *MemTy) {
337 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
338 I8Ptr = AllocaAddr;
339 Stride = Builder.getInt64(64);
340 };
341
342 if (Bitcast->getType()->isX86_AMXTy()) {
343 // %2 = bitcast <256 x i32> %src to x86_amx
344 // -->
345 // %addr = alloca <256 x i32>, align 64
346 // store <256 x i32> %src, <256 x i32>* %addr, align 64
347 // %addr2 = bitcast <256 x i32>* to i8*
348 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
349 // i8* %addr2,
350 // i64 64)
351 Use &U = *(Bitcast->use_begin());
352 unsigned OpNo = U.getOperandNo();
353 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
354 if (!II)
355 return false; // May be bitcast from x86amx to <256 x i32>.
356 Prepare(Bitcast->getOperand(0)->getType());
357 Builder.CreateStore(Src, AllocaAddr);
358 // TODO: we can pick a constant operand for the shape.
359 Value *Row = nullptr, *Col = nullptr;
360 std::tie(Row, Col) = getShape(II, OpNo);
361 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
362 Value *NewInst =
363 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
364 Bitcast->replaceAllUsesWith(NewInst);
365 } else {
366 // %2 = bitcast x86_amx %src to <256 x i32>
367 // -->
368 // %addr = alloca <256 x i32>, align 64
369 // %addr2 = bitcast <256 x i32>* to i8*
370 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
371 // i8* %addr2, i64 %stride)
372 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
373 auto *II = dyn_cast<IntrinsicInst>(Src);
374 if (!II)
375 return false; // May be bitcast from <256 x i32> to x86amx.
376 Prepare(Bitcast->getType());
377 Value *Row = II->getOperand(0);
378 Value *Col = II->getOperand(1);
379 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
380 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
381 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
382 Bitcast->replaceAllUsesWith(NewInst);
383 }
384
385 return true;
386}
387
388bool X86LowerAMXType::visit() {
389 SmallVector<Instruction *, 8> DeadInsts;
390 Col2Row.clear();
391
392 for (BasicBlock *BB : post_order(&Func)) {
393 for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
394 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
395 if (!Bitcast)
396 continue;
397
398 Value *Src = Bitcast->getOperand(0);
399 if (Bitcast->getType()->isX86_AMXTy()) {
400 if (Bitcast->user_empty()) {
401 DeadInsts.push_back(Bitcast);
402 continue;
403 }
404 LoadInst *LD = dyn_cast<LoadInst>(Src);
405 if (!LD) {
406 if (transformBitcast(Bitcast))
407 DeadInsts.push_back(Bitcast);
408 continue;
409 }
410 // If the load has multiple users, keep the vector load and create a tile load.
411 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
412 // %2 = bitcast <256 x i32> %src to x86_amx
413 // %add = add <256 x i32> %src, <256 x i32> %src2
414 // -->
415 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
416 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
417 // i8* %addr, i64 %stride64)
418 // %add = add <256 x i32> %src, <256 x i32> %src2
419
420 // If the load has a single user, the load will be eliminated in DAG ISel.
421 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
422 // %2 = bitcast <256 x i32> %src to x86_amx
423 // -->
424 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
425 // i8* %addr, i64 %stride64)
426 combineLoadBitcast(LD, Bitcast);
427 DeadInsts.push_back(Bitcast);
428 if (LD->hasOneUse())
429 DeadInsts.push_back(LD);
430 } else if (Src->getType()->isX86_AMXTy()) {
431 if (Bitcast->user_empty()) {
432 DeadInsts.push_back(Bitcast);
433 continue;
434 }
435 StoreInst *ST = nullptr;
436 for (Use &U : Bitcast->uses()) {
437 ST = dyn_cast<StoreInst>(U.getUser());
438 if (ST)
439 break;
440 }
441 if (!ST) {
442 if (transformBitcast(Bitcast))
443 DeadInsts.push_back(Bitcast);
444 continue;
445 }
446 // If bitcast (%13) has one use, combine bitcast and store to amx store.
447 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
448 // %stride);
449 // %13 = bitcast x86_amx %src to <256 x i32>
450 // store <256 x i32> %13, <256 x i32>* %addr, align 64
451 // -->
452 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
453 // %stride64, %13)
454 //
455 // If bitcast (%13) has multiple uses, transform as below.
456 // %13 = bitcast x86_amx %src to <256 x i32>
457 // store <256 x i32> %13, <256 x i32>* %addr, align 64
458 // %add = <256 x i32> %13, <256 x i32> %src2
459 // -->
460 // %13 = bitcast x86_amx %src to <256 x i32>
461 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
462 // %stride64, %13)
463 // %14 = load <256 x i32>, %addr
464 // %add = <256 x i32> %14, <256 x i32> %src2
465 //
466 combineBitcastStore(Bitcast, ST);
467 // Delete user first.
468 DeadInsts.push_back(ST);
469 DeadInsts.push_back(Bitcast);
470 }
471 }
472 }
473
474 bool C = !DeadInsts.empty();
475
476 for (auto *Inst : DeadInsts)
477 Inst->eraseFromParent();
478
479 return C;
480}
481} // anonymous namespace
482
483static Value *getAllocaPos(BasicBlock *BB) {
484 Function *F = BB->getParent();
485 IRBuilder<> Builder(&F->getEntryBlock().front());
486 const DataLayout &DL = F->getDataLayout();
487 unsigned AllocaAS = DL.getAllocaAddrSpace();
488 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
489 AllocaInst *AllocaRes =
490 new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
491 BasicBlock::iterator Iter = AllocaRes->getIterator();
492 ++Iter;
493 Builder.SetInsertPoint(&*Iter);
494 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
495 return I8Ptr;
496}
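// Note: the backing slot is a <256 x i32> alloca (1024 bytes), large enough to
// hold one full AMX tile (at most 16 rows x 64 bytes).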
497
498static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
499 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
500 auto *II = cast<IntrinsicInst>(TileDef);
501
502 assert(II && "Not tile intrinsic!");
503 Value *Row = II->getOperand(0);
504 Value *Col = II->getOperand(1);
505
506 BasicBlock *BB = TileDef->getParent();
507 BasicBlock::iterator Iter = TileDef->getIterator();
508 IRBuilder<> Builder(BB, ++Iter);
509 Value *Stride = Builder.getInt64(64);
510 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
511
512 Instruction *TileStore =
513 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
514 return TileStore;
515}
516
517static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
518 Value *V = U.get();
519 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
520
521 // Get tile shape.
522 IntrinsicInst *II = nullptr;
523 if (IsPHI) {
524 Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
525 II = cast<IntrinsicInst>(PhiOp);
526 } else {
527 II = cast<IntrinsicInst>(V);
528 }
529 Value *Row = II->getOperand(0);
530 Value *Col = II->getOperand(1);
531
532 Instruction *UserI = cast<Instruction>(U.getUser());
533 IRBuilder<> Builder(UserI);
534 Value *Stride = Builder.getInt64(64);
535 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
536
537 Value *TileLoad =
538 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
539 UserI->replaceUsesOfWith(V, TileLoad);
540}
541
542static bool isIncomingOfPHI(Instruction *I) {
543 for (Use &U : I->uses()) {
544 User *V = U.getUser();
545 if (isa<PHINode>(V))
546 return true;
547 }
548 return false;
549}
550
551// Make all AMX tile data volatile, shortening the live range
552// of each tile register before fast register allocation.
553namespace {
554class X86VolatileTileData {
555 Function &F;
556
557public:
558 X86VolatileTileData(Function &Func) : F(Func) {}
559 Value *updatePhiIncomings(BasicBlock *BB,
560 SmallVector<Instruction *, 2> &Incomings);
561 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
562 bool volatileTileData();
563 void volatileTilePHI(PHINode *PHI);
564 void volatileTileNonPHI(Instruction *I);
565};
566
567Value *X86VolatileTileData::updatePhiIncomings(
568 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
569 Value *I8Ptr = getAllocaPos(BB);
570
571 for (auto *I : Incomings) {
572 User *Store = createTileStore(I, I8Ptr);
573
574 // All its uses (except phi) should load from stored mem.
575 for (Use &U : I->uses()) {
576 User *V = U.getUser();
577 if (isa<PHINode>(V) || V == Store)
578 continue;
579 replaceWithTileLoad(U, I8Ptr);
580 }
581 }
582 return I8Ptr;
583}
584
585void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
586 Value *StorePtr) {
587 for (Use &U : PHI->uses())
588 replaceWithTileLoad(U, StorePtr, true);
589 PHI->eraseFromParent();
590}
591
592// Similar to volatileTileNonPHI, this function only handles PHI nodes
593// and their related AMX intrinsics.
594// 1) The PHI def should be changed to a tileload.
595// 2) The PHI incoming values should be tilestored right after their defs.
596// 3) The memory used by these tileloads and tilestores should be the same.
597// e.g.
598// ------------------------------------------------------
599// bb_dom:
600// ...
601// br i1 %bool.cond, label %if.else, label %if.then
602//
603// if.then:
604// def %t0 = ...
605// ...
606// use %t0
607// ...
608// br label %if.end
609//
610// if.else:
611// def %t1 = ...
612// br label %if.end
613//
614// if.end:
615// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
616// ...
617// use %td
618// ------------------------------------------------------
619// -->
620// ------------------------------------------------------
621// bb_entry:
622// %mem = alloca <256 x i32>, align 1024 *
623// ...
624// bb_dom:
625// ...
626// br i1 %bool.cond, label %if.else, label %if.then
627//
628// if.then:
629// def %t0 = ...
630// call void @llvm.x86.tilestored64.internal(mem, %t0) *
631// ...
632// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
633// use %t0` *
634// ...
635// br label %if.end
636//
637// if.else:
638// def %t1 = ...
639// call void @llvm.x86.tilestored64.internal(mem, %t1) *
640// br label %if.end
641//
642// if.end:
643// ...
644// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
645// use %td
646// ------------------------------------------------------
647void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
648 BasicBlock *BB = PHI->getParent();
649 SmallVector<Instruction *, 2> Incomings;
650
651 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
652 Value *Op = PHI->getIncomingValue(I);
653 Instruction *Inst = dyn_cast<Instruction>(Op);
654 assert(Inst && "We shouldn't fold AMX instruction!");
655 Incomings.push_back(Inst);
656 }
657
658 Value *StorePtr = updatePhiIncomings(BB, Incomings);
659 replacePhiDefWithLoad(PHI, StorePtr);
660}
661
662// Store the defined tile and load it before use.
663// None of its users is a PHI node.
664// e.g.
665// ------------------------------------------------------
666// def %td = ...
667// ...
668// "use %td"
669// ------------------------------------------------------
670// -->
671// ------------------------------------------------------
672// def %td = ...
673// call void @llvm.x86.tilestored64.internal(mem, %td)
674// ...
675// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
676// "use %td2"
677// ------------------------------------------------------
678void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
679 BasicBlock *BB = I->getParent();
680 Value *I8Ptr = getAllocaPos(BB);
681 User *Store = createTileStore(I, I8Ptr);
682
683 // All its uses should load from stored mem.
684 for (Use &U : I->uses()) {
685 User *V = U.getUser();
686 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
687 if (V != Store)
688 replaceWithTileLoad(U, I8Ptr);
689 }
690}
691
692// Volatile Tile Model:
693// 1) All uses of tile data come from a tileload right before the use.
694// 2) All defs of tile data are tilestored to memory immediately after the def.
695// For example:
696// --------------------------------------------------------------------------
697// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
698// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
699// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
700// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
701// call void @llvm.x86.tilestored64.internal(... td) area
702// --------------------------------------------------------------------------
703// 3) No terminator, call or other amx instructions in the key amx area.
704bool X86VolatileTileData::volatileTileData() {
705 bool Changed = false;
706 for (BasicBlock &BB : F) {
707 SmallVector<Instruction *, 2> PHIInsts;
708 SmallVector<Instruction *, 8> AMXDefInsts;
709
710 for (Instruction &I : BB) {
711 if (!I.getType()->isX86_AMXTy())
712 continue;
713 if (isa<PHINode>(&I))
714 PHIInsts.push_back(&I);
715 else
716 AMXDefInsts.push_back(&I);
717 }
718
719 // First we "volatile" the non-phi related amx intrinsics.
720 for (Instruction *I : AMXDefInsts) {
721 if (isIncomingOfPHI(I))
722 continue;
723 volatileTileNonPHI(I);
724 Changed = true;
725 }
726
727 for (Instruction *I : PHIInsts) {
728 volatileTilePHI(dyn_cast<PHINode>(I));
729 Changed = true;
730 }
731 }
732 return Changed;
733}
734
735} // anonymous namespace
736
737namespace {
738
739class X86LowerAMXCast {
740 Function &Func;
741 std::unique_ptr<DominatorTree> DT;
742
743public:
744 X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
745 bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
746 bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
747 bool combineTilezero(IntrinsicInst *Cast);
748 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
749 bool combineAMXcast(TargetLibraryInfo *TLI);
750 bool transformAMXCast(IntrinsicInst *AMXCast);
751 bool transformAllAMXCast();
752 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
753 SmallSetVector<Instruction *, 16> &DeadInst);
754};
755
756static bool DCEInstruction(Instruction *I,
757 SmallSetVector<Instruction *, 16> &WorkList,
758 const TargetLibraryInfo *TLI) {
759 if (isInstructionTriviallyDead(I, TLI)) {
760 salvageDebugInfo(*I);
761 salvageKnowledge(I);
762
763 // Null out all of the instruction's operands to see if any operand becomes
764 // dead as we go.
765 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
766 Value *OpV = I->getOperand(i);
767 I->setOperand(i, nullptr);
768
769 if (!OpV->use_empty() || I == OpV)
770 continue;
771
772 // If the operand is an instruction that became dead as we nulled out the
773 // operand, and if it is 'trivially' dead, delete it in a future loop
774 // iteration.
775 if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
776 if (isInstructionTriviallyDead(OpI, TLI)) {
777 WorkList.insert(OpI);
778 }
779 }
780 }
781 I->eraseFromParent();
782 return true;
783 }
784 return false;
785}
786
787/// This function handles following case
788///
789/// A -> B amxcast
790/// PHI
791/// B -> A amxcast
792///
793/// All the related PHI nodes can be replaced by new PHI nodes with type A.
794/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
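/// For illustration, using the same "amxcast" shorthand as the comments below:
///   bb1:  %v1 = amxcast x86_amx %t1 to <256 x i32>   ; A -> B
///   bb2:  %v2 = amxcast x86_amx %t2 to <256 x i32>   ; A -> B
///   bb3:  %p  = phi <256 x i32> [ %v1, %bb1 ], [ %v2, %bb2 ]
///         %r  = amxcast <256 x i32> %p to x86_amx    ; B -> A (\p CI)
/// is rewritten so the PHI carries x86_amx directly:
///   bb3:  %p2 = phi x86_amx [ %t1, %bb1 ], [ %t2, %bb2 ]
/// and the uses of %r are replaced by %p2.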
795bool X86LowerAMXCast::optimizeAMXCastFromPhi(
796 IntrinsicInst *CI, PHINode *PN,
797 SmallSetVector<Instruction *, 16> &DeadInst) {
798 IRBuilder<> Builder(CI);
799 Value *Src = CI->getOperand(0);
800 Type *SrcTy = Src->getType(); // Type B
801 Type *DestTy = CI->getType(); // Type A
802
803 SmallVector<PHINode *, 4> PhiWorklist;
804 SmallSetVector<PHINode *, 4> OldPhiNodes;
805
806 // Find all of the A->B casts and PHI nodes.
807 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
808 // OldPhiNodes is used to track all known PHI nodes, before adding a new
809 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
810 PhiWorklist.push_back(PN);
811 OldPhiNodes.insert(PN);
812 while (!PhiWorklist.empty()) {
813 auto *OldPN = PhiWorklist.pop_back_val();
814 for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
815 Value *IncValue = OldPN->getIncomingValue(I);
816 // TODO: Currently, we ignore cases where it is a const. In the future, we
817 // might support consts.
818 if (isa<Constant>(IncValue)) {
819 auto *IncConst = dyn_cast<Constant>(IncValue);
820 if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
821 return false;
822 Value *Row = nullptr, *Col = nullptr;
823 std::tie(Row, Col) = getShape(OldPN);
824 // TODO: If it is not a constant, the Row and Col must dominate the tilezero
825 // that we are going to create.
826 if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
827 return false;
828 // Create tilezero at the end of incoming block.
829 auto *Block = OldPN->getIncomingBlock(I);
830 BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
831 Instruction *NewInst = Builder.CreateIntrinsic(
832 Intrinsic::x86_tilezero_internal, {}, {Row, Col});
833 NewInst->moveBefore(Iter);
834 NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
835 {IncValue->getType()}, {NewInst});
836 NewInst->moveBefore(Iter);
837 // Replace InValue with new Value.
838 OldPN->setIncomingValue(I, NewInst);
839 IncValue = NewInst;
840 }
841
842 if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
843 if (OldPhiNodes.insert(PNode))
844 PhiWorklist.push_back(PNode);
845 continue;
846 }
847 Instruction *ACI = dyn_cast<Instruction>(IncValue);
848 if (ACI && isAMXCast(ACI)) {
849 // Verify it's a A->B cast.
850 Type *TyA = ACI->getOperand(0)->getType();
851 Type *TyB = ACI->getType();
852 if (TyA != DestTy || TyB != SrcTy)
853 return false;
854 continue;
855 }
856 return false;
857 }
858 }
859
860 // Check that each user of each old PHI node is something that we can
861 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
862 for (auto *OldPN : OldPhiNodes) {
863 for (User *V : OldPN->users()) {
864 Instruction *ACI = dyn_cast<Instruction>(V);
865 if (ACI && isAMXCast(ACI)) {
866 // Verify it's a B->A cast.
867 Type *TyB = ACI->getOperand(0)->getType();
868 Type *TyA = ACI->getType();
869 if (TyA != DestTy || TyB != SrcTy)
870 return false;
871 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
872 // As long as the user is another old PHI node, then even if we don't
873 // rewrite it, the PHI web we're considering won't have any users
874 // outside itself, so it'll be dead.
875 // example:
876 // bb.0:
877 // %0 = amxcast ...
878 // bb.1:
879 // %1 = amxcast ...
880 // bb.2:
881 // %goodphi = phi %0, %1
882 // %3 = amxcast %goodphi
883 // bb.3:
884 // %goodphi2 = phi %0, %goodphi
885 // %4 = amxcast %goodphi2
886 // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
887 // outside the phi-web, so the combination stops. When
888 // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
889 // will be done.
890 if (OldPhiNodes.count(PHI) == 0)
891 return false;
892 } else
893 return false;
894 }
895 }
896
897 // For each old PHI node, create a corresponding new PHI node with a type A.
898 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
899 for (auto *OldPN : OldPhiNodes) {
900 Builder.SetInsertPoint(OldPN);
901 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
902 NewPNodes[OldPN] = NewPN;
903 }
904
905 // Fill in the operands of new PHI nodes.
906 for (auto *OldPN : OldPhiNodes) {
907 PHINode *NewPN = NewPNodes[OldPN];
908 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
909 Value *V = OldPN->getOperand(j);
910 Value *NewV = nullptr;
911 Instruction *ACI = dyn_cast<Instruction>(V);
912 // There should not be an AMXCast from a const.
913 if (ACI && isAMXCast(ACI))
914 NewV = ACI->getOperand(0);
915 else if (auto *PrevPN = dyn_cast<PHINode>(V))
916 NewV = NewPNodes[PrevPN];
917 assert(NewV);
918 NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
919 }
920 }
921
922 // Traverse all accumulated PHI nodes and process their users,
923 // which are Stores and BitCasts. Without this processing
924 // NewPHI nodes could be replicated and could lead to extra
925 // moves generated after DeSSA.
926 // If there is a store with type B, change it to type A.
927
928 // Replace users of BitCast B->A with NewPHI. These will help
929 // later to get rid of a closure formed by OldPHI nodes.
930 for (auto *OldPN : OldPhiNodes) {
931 PHINode *NewPN = NewPNodes[OldPN];
932 for (User *V : make_early_inc_range(OldPN->users())) {
933 Instruction *ACI = dyn_cast<Instruction>(V);
934 if (ACI && isAMXCast(ACI)) {
935 Type *TyB = ACI->getOperand(0)->getType();
936 Type *TyA = ACI->getType();
937 assert(TyA == DestTy && TyB == SrcTy);
938 (void)TyA;
939 (void)TyB;
940 ACI->replaceAllUsesWith(NewPN);
941 DeadInst.insert(ACI);
942 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
943 // We don't need to push the PHINodes into DeadInst since they are operands
944 // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
945 assert(OldPhiNodes.contains(PHI));
946 (void)PHI;
947 } else
948 llvm_unreachable("all uses should be handled");
949 }
950 }
951 return true;
952}
953
954// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
955// store <256 x i32> %43, <256 x i32>* %p, align 64
956// -->
957// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
958// i64 64, x86_amx %42)
959bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
960 Value *Tile = Cast->getOperand(0);
961
962 assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");
963
964 // TODO: Specially handle the multi-use case.
965 if (!Tile->hasOneUse())
966 return false;
967
968 auto *II = cast<IntrinsicInst>(Tile);
969 // Tile is output from AMX intrinsic. The first operand of the
970 // intrinsic is row, the second operand of the intrinsic is column.
971 Value *Row = II->getOperand(0);
972 Value *Col = II->getOperand(1);
973
974 IRBuilder<> Builder(ST);
975
976 // Stride should be equal to col (measured in bytes).
977 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
978 Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
979 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
980 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
981 return true;
982}
983
984// %65 = load <256 x i32>, <256 x i32>* %p, align 64
985// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
986// -->
987// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
988// i8* %p, i64 64)
989bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
990 bool EraseLoad = true;
991 Value *Row = nullptr, *Col = nullptr;
992 Use &U = *(Cast->use_begin());
993 unsigned OpNo = U.getOperandNo();
994 auto *II = cast<IntrinsicInst>(U.getUser());
995 // TODO: If it is cast intrinsic or phi node, we can propagate the
996 // shape information through def-use chain.
997 if (!isAMXIntrinsic(II))
998 return false;
999 std::tie(Row, Col) = getShape(II, OpNo);
1000 IRBuilder<> Builder(LD);
1001 // Stride should be equal to col (measured in bytes).
1002 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
1003 Value *I8Ptr;
1004
1005 // To save compile time, we create the dominator tree only when it is
1006 // really needed.
1007 if (!DT)
1008 DT.reset(new DominatorTree(Func));
1009 if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
1010 // Store the value to the stack and reload it from the stack before the cast.
1011 auto *AllocaAddr =
1012 createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
1013 Builder.SetInsertPoint(&*std::next(LD->getIterator()));
1014 Builder.CreateStore(LD, AllocaAddr);
1015
1016 Builder.SetInsertPoint(Cast);
1017 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1018 EraseLoad = false;
1019 } else {
1020 I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
1021 }
1022 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
1023
1024 Value *NewInst =
1025 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
1026 Cast->replaceAllUsesWith(NewInst);
1027
1028 return EraseLoad;
1029}
1030
1031// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1032// -->
1033// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1034bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
1035 Value *Row = nullptr, *Col = nullptr;
1036 Use &U = *(Cast->use_begin());
1037 unsigned OpNo = U.getOperandNo();
1038 auto *II = cast<IntrinsicInst>(U.getUser());
1039 if (!isAMXIntrinsic(II))
1040 return false;
1041
1042 std::tie(Row, Col) = getShape(II, OpNo);
1043
1044 IRBuilder<> Builder(Cast);
1045 Value *NewInst =
1046 Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {}, {Row, Col});
1047 Cast->replaceAllUsesWith(NewInst);
1048 return true;
1049}
1050
1051bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
1052 bool Change = false;
1053 for (auto *Cast : Casts) {
1054 auto *II = cast<IntrinsicInst>(Cast);
1055 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1056 // store <256 x i32> %43, <256 x i32>* %p, align 64
1057 // -->
1058 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1059 // i64 64, x86_amx %42)
1060 if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1061 SmallVector<Instruction *, 2> DeadStores;
1062 for (User *U : Cast->users()) {
1063 StoreInst *Store = dyn_cast<StoreInst>(U);
1064 if (!Store)
1065 continue;
1066 if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1067 DeadStores.push_back(Store);
1068 Change = true;
1069 }
1070 }
1071 for (auto *Store : DeadStores)
1072 Store->eraseFromParent();
1073 } else { // x86_cast_vector_to_tile
1074 // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1075 // -->
1076 // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1077 if (isa<ConstantAggregateZero>(Cast->getOperand(0))) {
1078 Change |= combineTilezero(cast<IntrinsicInst>(Cast));
1079 continue;
1080 }
1081
1082 auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1083 if (!Load || !Load->hasOneUse())
1084 continue;
1085 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1086 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1087 // -->
1088 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1089 // i8* %p, i64 64)
1090 if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1091 // Set the operand to null so that the load instruction can be erased.
1092 Cast->setOperand(0, nullptr);
1093 Load->eraseFromParent();
1094 Change = true;
1095 }
1096 }
1097 }
1098 return Change;
1099}
1100
1101bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1102 bool Change = false;
1103 // Collect tile cast instruction.
1104 SmallVector<Instruction *, 8> Vec2TileInsts;
1105 SmallVector<Instruction *, 8> Tile2VecInsts;
1106 SmallVector<Instruction *, 8> PhiCastWorkList;
1107 SmallSetVector<Instruction *, 16> DeadInst;
1108 for (BasicBlock &BB : Func) {
1109 for (Instruction &I : BB) {
1110 Value *Vec;
1111      if (match(&I,
1112                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1113        Vec2TileInsts.push_back(&I);
1114      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1115                             m_Value(Vec))))
1116 Tile2VecInsts.push_back(&I);
1117 }
1118 }
1119
1120 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1121 for (auto *Inst : Insts) {
1122 for (User *U : Inst->users()) {
1123 IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1124 if (!II || II->getIntrinsicID() != IID)
1125 continue;
1126 // T1 = vec2tile V0
1127 // V2 = tile2vec T1
1128 // V3 = OP V2
1129 // -->
1130 // T1 = vec2tile V0
1131 // V2 = tile2vec T1
1132 // V3 = OP V0
1133 II->replaceAllUsesWith(Inst->getOperand(0));
1134 Change = true;
1135 }
1136 }
1137 };
1138
1139 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1140 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1141
1142 SmallVector<Instruction *, 8> LiveCasts;
1143 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1144 for (auto *Inst : Insts) {
1145 if (Inst->use_empty()) {
1146 Inst->eraseFromParent();
1147 Change = true;
1148 } else {
1149 LiveCasts.push_back(Inst);
1150 }
1151 }
1152 };
1153
1154 EraseInst(Vec2TileInsts);
1155 EraseInst(Tile2VecInsts);
1156 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1157 "Vec2Tile and Tile2Vec:\n";
1158 Func.dump());
1159 Change |= combineLdSt(LiveCasts);
1160 EraseInst(LiveCasts);
1161 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1162 "AMXCast and load/store:\n";
1163 Func.dump());
1164
1165 // Handle the A->B->A cast, and there is an intervening PHI node.
1166 for (BasicBlock &BB : Func) {
1167 for (Instruction &I : BB) {
1168 if (isAMXCast(&I)) {
1169 if (isa<PHINode>(I.getOperand(0)))
1170 PhiCastWorkList.push_back(&I);
1171 }
1172 }
1173 }
1174 for (auto *I : PhiCastWorkList) {
1175 // We skip the dead Amxcast.
1176 if (DeadInst.contains(I))
1177 continue;
1178 PHINode *PN = cast<PHINode>(I->getOperand(0));
1179 if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1180 DeadInst.insert(PN);
1181 Change = true;
1182 }
1183 }
1184
1185 // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
1186 // might have no uses. We do some dead code elimination for them.
1187 while (!DeadInst.empty()) {
1188 Instruction *I = DeadInst.pop_back_val();
1189 Change |= DCEInstruction(I, DeadInst, TLI);
1190 }
1191 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
1192 "optimizeAMXCastFromPhi:\n";
1193 Func.dump());
1194 return Change;
1195}
1196
1197// There might be remaining AMXcast after combineAMXcast and they should be
1198// handled elegantly.
1199bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1200 IRBuilder<> Builder(AMXCast);
1201 AllocaInst *AllocaAddr;
1202 Value *I8Ptr, *Stride;
1203 auto *Src = AMXCast->getOperand(0);
1204
1205 auto Prepare = [&](Type *MemTy) {
1206 AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1207 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1208 Stride = Builder.getInt64(64);
1209 };
1210
1211 if (AMXCast->getType()->isX86_AMXTy()) {
1212 // %2 = amxcast <225 x i32> %src to x86_amx
1213 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1214 // i8* %addr3, i64 60, x86_amx %2)
1215 // -->
1216 // %addr = alloca <225 x i32>, align 64
1217 // store <225 x i32> %src, <225 x i32>* %addr, align 64
1218 // %addr2 = bitcast <225 x i32>* %addr to i8*
1219 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1220 // i8* %addr2,
1221 // i64 60)
1222 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1223 // i8* %addr3, i64 60, x86_amx %2)
1224 if (AMXCast->use_empty()) {
1225 AMXCast->eraseFromParent();
1226 return true;
1227 }
1228 Use &U = *(AMXCast->use_begin());
1229 unsigned OpNo = U.getOperandNo();
1230 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1231 if (!II)
1232 return false; // May be bitcast from x86amx to <256 x i32>.
1233 Prepare(AMXCast->getOperand(0)->getType());
1234 Builder.CreateStore(Src, AllocaAddr);
1235 // TODO: we can pick a constant operand for the shape.
1236 Value *Row = nullptr, *Col = nullptr;
1237 std::tie(Row, Col) = getShape(II, OpNo);
1238 std::array<Value *, 4> Args = {
1239 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1240 Value *NewInst =
1241 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
1242 AMXCast->replaceAllUsesWith(NewInst);
1243 AMXCast->eraseFromParent();
1244 } else {
1245 // %2 = amxcast x86_amx %src to <225 x i32>
1246 // -->
1247 // %addr = alloca <225 x i32>, align 64
1248 // %addr2 = bitcast <225 x i32>* to i8*
1249 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1250 // i8* %addr2, i64 %stride)
1251 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1252 auto *II = dyn_cast<IntrinsicInst>(Src);
1253 if (!II)
1254 return false; // May be bitcast from <256 x i32> to x86amx.
1255 Prepare(AMXCast->getType());
1256 Value *Row = II->getOperand(0);
1257 Value *Col = II->getOperand(1);
1258 std::array<Value *, 5> Args = {
1259 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1260 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
1261 Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1262 AMXCast->replaceAllUsesWith(NewInst);
1263 AMXCast->eraseFromParent();
1264 }
1265
1266 return true;
1267}
1268
1269bool X86LowerAMXCast::transformAllAMXCast() {
1270 bool Change = false;
1271 // Collect tile cast instruction.
1272 SmallVector<Instruction *, 8> WorkLists;
1273 for (BasicBlock &BB : Func) {
1274 for (Instruction &I : BB) {
1275 if (isAMXCast(&I))
1276 WorkLists.push_back(&I);
1277 }
1278 }
1279
1280 for (auto *Inst : WorkLists) {
1281 Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1282 }
1283
1284 return Change;
1285}
1286
1287bool lowerAmxType(Function &F, const TargetMachine *TM,
1288 TargetLibraryInfo *TLI) {
1289 // Performance optimization: most code doesn't use AMX, so return early if
1290 // there are no instructions that produce AMX values. This is sufficient, as
1291 // AMX arguments and constants are not allowed -- so any producer of an AMX
1292 // value must be an instruction.
1293 // TODO: find a cheaper way for this, without looking at all instructions.
1294 if (!containsAMXCode(F))
1295 return false;
1296
1297 bool C = false;
1298 X86LowerAMXCast LAC(F);
1299 C |= LAC.combineAMXcast(TLI);
1300 // There might be remaining AMXcast after combineAMXcast and they should be
1301 // handled elegantly.
1302 C |= LAC.transformAllAMXCast();
1303
1304 X86LowerAMXType LAT(F);
1305 C |= LAT.visit();
1306
1307 // Prepare for fast register allocation at O0.
1308 // TODO: It may be better to check the volatile model of AMX code, not just
1309 // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
1310 if (TM->getOptLevel() == CodeGenOptLevel::None) {
1311 // If the front end does not use O0 but the mid/back end does (e.g.
1312 // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1313 // sure the amx data is volatile; that is necessary for AMX fast
1314 // register allocation.
1315 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1316 X86VolatileTileData VTD(F);
1317 C = VTD.volatileTileData() || C;
1318 }
1319 }
1320
1321 return C;
1322}
1323
1324} // anonymous namespace
1325
1326PreservedAnalyses X86LowerAMXTypePass::run(Function &F,
1327                                           FunctionAnalysisManager &FAM) {
1328 TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
1329 bool Changed = lowerAmxType(F, TM, &TLI);
1330 if (!Changed)
1331 return PreservedAnalyses::all();
1332
1333 PreservedAnalyses PA = PreservedAnalyses::none();
1334 PA.preserveSet<CFGAnalyses>();
1335 return PA;
1336}
1337
1338namespace {
1339
1340class X86LowerAMXTypeLegacyPass : public FunctionPass {
1341public:
1342 static char ID;
1343
1344 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}
1345
1346 bool runOnFunction(Function &F) override {
1347 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1348 TargetLibraryInfo *TLI =
1349 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1350 return lowerAmxType(F, TM, TLI);
1351 }
1352
1353 void getAnalysisUsage(AnalysisUsage &AU) const override {
1354 AU.setPreservesCFG();
1355 AU.addRequired<TargetPassConfig>();
1356 AU.addRequired<TargetLibraryInfoWrapperPass>();
1357 }
1358};
1359
1360} // anonymous namespace
1361
1362static const char PassName[] = "Lower AMX type for load/store";
1363char X86LowerAMXTypeLegacyPass::ID = 0;
1364INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1365 false)
1366INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1367INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1368INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1369 false)
1370
1371FunctionPass *llvm::createX86LowerAMXTypeLegacyPass() {
1372 return new X86LowerAMXTypeLegacyPass();
1373}