X86LowerAMXType.cpp
1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform <256 x i32> load/store
10/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
11/// provides simple operations on x86_amx. Basic elementwise operations are
12/// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32> and
13/// only AMX intrinsics can operate on the type, we need to transform
14/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
15/// cannot be combined with the load/store, we transform the bitcast to an amx
16/// load/store and a <256 x i32> store/load.
17///
18/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
19/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is
20/// volatile, because that is necessary for AMX fast register allocation. (In
21/// fast register allocation, registers are allocated before spill/reload, so
22/// there is no additional register for amx to identify the step in spill.)
23/// volatileTileData() will handle this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will be transformed into -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
43#include "llvm/ADT/SetVector.h"
46#include "llvm/CodeGen/Passes.h"
49#include "llvm/IR/Analysis.h"
50#include "llvm/IR/DataLayout.h"
51#include "llvm/IR/Function.h"
52#include "llvm/IR/IRBuilder.h"
55#include "llvm/IR/IntrinsicsX86.h"
56#include "llvm/IR/PassManager.h"
59#include "llvm/Pass.h"
63
64#include <map>
65
66using namespace llvm;
67using namespace PatternMatch;
68
69#define DEBUG_TYPE "x86-lower-amx-type"
70
71static bool isAMXCast(Instruction *II) {
72 return match(II,
73 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74 match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75}
76
77 static bool isAMXIntrinsic(Value *I) {
78 auto *II = dyn_cast<IntrinsicInst>(I);
79 if (!II)
80 return false;
81 if (isAMXCast(II))
82 return false;
83 // Check if the return type or any parameter is x86_amx. If it is x86_amx,
84 // the intrinsic must be an x86 AMX intrinsic.
85 if (II->getType()->isX86_AMXTy())
86 return true;
87 for (Value *V : II->args()) {
88 if (V->getType()->isX86_AMXTy())
89 return true;
90 }
91
92 return false;
93}
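// For illustration (a sketch; the value names are made up): a call such as
//   %t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a,
//                                                x86_amx %b)
// is an AMX intrinsic here (x86_amx result and operands), while the cast
// intrinsics matched by isAMXCast() above are deliberately excluded.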
94
95static bool containsAMXCode(Function &F) {
96 for (BasicBlock &BB : F)
97 for (Instruction &I : BB)
98 if (I.getType()->isX86_AMXTy())
99 return true;
100 return false;
101}
102
103static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
104 Type *Ty) {
105 Function &F = *BB->getParent();
106 const DataLayout &DL = F.getDataLayout();
107
108 LLVMContext &Ctx = Builder.getContext();
109 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
110 unsigned AllocaAS = DL.getAllocaAddrSpace();
111 AllocaInst *AllocaRes =
112 new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
113 AllocaRes->setAlignment(AllocaAlignment);
114 return AllocaRes;
115}
116
117static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
118 for (Instruction &I : F.getEntryBlock())
119 if (!isa<AllocaInst>(&I))
120 return &I;
121 llvm_unreachable("No terminator in the entry block!");
122}
123
124static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {
125 IRBuilder<> Builder(II);
126 Value *RealRow = nullptr;
127 if (isa<ConstantInt>(V))
128 RealRow =
129 Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
130 else if (isa<Instruction>(V)) {
131 // When it is not a constant value and it is not a function argument, we
132 // create Row after the definition of V instead of
133 // before II. For example, II is %118, and we try to get the shape for %117:
134 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
135 // i32> %115).
136 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
137 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
138 // %117).
139 // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
140 // definition is after its user(new tileload for %117).
141 // So, the best choice is to create %row right after the definition of
142 // %106.
143 Builder.SetInsertPoint(cast<Instruction>(V));
144 RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
145 cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
146 } else {
147 // When it is not a constant value and it is a function argument, we create
148 // Row at the entry bb.
149 IRBuilder<> NewBuilder(
150 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
151 RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
152 }
153 return RealRow;
154}
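// For illustration (a sketch; the values are made up): with Granularity == 4,
// a constant column of 16 becomes the constant row 4, while a column defined
// by an instruction %c becomes
//   %row = udiv i16 %c, 4
// inserted right after the definition of %c (or in the entry block when the
// column is a function argument).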
155
156// TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
157std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
158 IRBuilder<> Builder(II);
159 Value *Row = nullptr, *Col = nullptr;
160 switch (II->getIntrinsicID()) {
161 default:
162 llvm_unreachable("Expect amx intrinsics");
163 case Intrinsic::x86_tileloadd64_internal:
164 case Intrinsic::x86_tileloaddt164_internal:
165 case Intrinsic::x86_tilestored64_internal:
166 case Intrinsic::x86_t2rpntlvwz0rs_internal:
167 case Intrinsic::x86_t2rpntlvwz0rst1_internal:
168 case Intrinsic::x86_t2rpntlvwz1rs_internal:
169 case Intrinsic::x86_t2rpntlvwz1rst1_internal:
170 case Intrinsic::x86_tileloaddrs64_internal:
171 case Intrinsic::x86_tileloaddrst164_internal: {
172 Row = II->getArgOperand(0);
173 Col = II->getArgOperand(1);
174 break;
175 }
176 // a * b + c
177 // The shape depends on which operand.
178 case Intrinsic::x86_tcmmimfp16ps_internal:
179 case Intrinsic::x86_tcmmrlfp16ps_internal:
180 case Intrinsic::x86_tdpbssd_internal:
181 case Intrinsic::x86_tdpbsud_internal:
182 case Intrinsic::x86_tdpbusd_internal:
183 case Intrinsic::x86_tdpbuud_internal:
184 case Intrinsic::x86_tdpbf16ps_internal:
185 case Intrinsic::x86_tdpfp16ps_internal:
186 case Intrinsic::x86_tmmultf32ps_internal:
187 case Intrinsic::x86_tdpbf8ps_internal:
188 case Intrinsic::x86_tdpbhf8ps_internal:
189 case Intrinsic::x86_tdphbf8ps_internal:
190 case Intrinsic::x86_tdphf8ps_internal: {
191 switch (OpNo) {
192 case 3:
193 Row = II->getArgOperand(0);
194 Col = II->getArgOperand(1);
195 break;
196 case 4:
197 Row = II->getArgOperand(0);
198 Col = II->getArgOperand(2);
199 break;
200 case 5:
201 Row = getRowFromCol(II, II->getArgOperand(2), 4);
202 Col = II->getArgOperand(1);
203 break;
204 }
205 break;
206 }
207 case Intrinsic::x86_tcvtrowd2ps_internal:
208 case Intrinsic::x86_tcvtrowps2bf16h_internal:
209 case Intrinsic::x86_tcvtrowps2bf16l_internal:
210 case Intrinsic::x86_tcvtrowps2phh_internal:
211 case Intrinsic::x86_tcvtrowps2phl_internal:
212 case Intrinsic::x86_tilemovrow_internal: {
213 assert(OpNo == 2 && "Illegal Operand Number.");
214 Row = II->getArgOperand(0);
215 Col = II->getArgOperand(1);
216 break;
217 }
218 }
219
220 return std::make_pair(Row, Col);
221}
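// For illustration (a sketch; the value names are made up): for
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a,
//                                                x86_amx %b)
// getShape(%d, 3) returns {%m, %n} for the accumulator %c,
// getShape(%d, 4) returns {%m, %k} for %a, and
// getShape(%d, 5) returns {%k / 4, %n} for %b, where the row is derived from
// the column-in-bytes via getRowFromCol.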
222
223static std::pair<Value *, Value *> getShape(PHINode *Phi) {
224 Use &U = *(Phi->use_begin());
225 unsigned OpNo = U.getOperandNo();
226 User *V = U.getUser();
227 // TODO: We don't traverse all users. To keep the algorithm simple, we only
228 // follow the first user. If we can find the shape, return it; otherwise
229 // return nullptr and the optimization for undef/zero will be
230 // abandoned.
231 while (V) {
232 if (isAMXCast(dyn_cast<Instruction>(V))) {
233 if (V->use_empty())
234 break;
235 Use &U = *(V->use_begin());
236 OpNo = U.getOperandNo();
237 V = U.getUser();
238 } else if (isAMXIntrinsic(V)) {
239 return getShape(cast<IntrinsicInst>(V), OpNo);
240 } else if (isa<PHINode>(V)) {
241 if (V->use_empty())
242 break;
243 Use &U = *(V->use_begin());
244 V = U.getUser();
245 } else {
246 break;
247 }
248 }
249
250 return std::make_pair(nullptr, nullptr);
251}
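// For illustration (a sketch; the value names are made up):
//   %p = phi <256 x i32> [ ... ], [ ... ]
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %t, x86_amx %a,
//                                                x86_amx %b)
// getShape(%p) follows the first-user chain %p -> %t -> %d and returns the
// shape of operand 3 of %d, i.e. {%m, %n}.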
252
253namespace {
254class X86LowerAMXType {
255 Function &Func;
256
257 // In AMX intrinsics we let Shape = {Row, Col}, but the
258 // RealCol = Col / ElementSize. We may use the RealCol
259 // as a new Row for other newly created AMX intrinsics.
260 std::map<Value *, Value *> Col2Row;
261
262public:
263 X86LowerAMXType(Function &F) : Func(F) {}
264 bool visit();
265 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
266 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
267 bool transformBitcast(BitCastInst *Bitcast);
268};
269
270// %src = load <256 x i32>, <256 x i32>* %addr, align 64
271// %2 = bitcast <256 x i32> %src to x86_amx
272// -->
273// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
274// i8* %addr, i64 %stride64)
275void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
276 Value *Row = nullptr, *Col = nullptr;
277 Use &U = *(Bitcast->use_begin());
278 unsigned OpNo = U.getOperandNo();
279 auto *II = cast<IntrinsicInst>(U.getUser());
280 std::tie(Row, Col) = getShape(II, OpNo);
281 IRBuilder<> Builder(Bitcast);
282 // Use the maximum column as the stride.
283 Value *Stride = Builder.getInt64(64);
284 Value *I8Ptr = LD->getOperand(0);
285 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
286
287 Value *NewInst =
288 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
289 Bitcast->replaceAllUsesWith(NewInst);
290}
291
292// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
293// %stride);
294// %13 = bitcast x86_amx %src to <256 x i32>
295// store <256 x i32> %13, <256 x i32>* %addr, align 64
296// -->
297// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
298// %stride64, %13)
299void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
300
301 Value *Tile = Bitcast->getOperand(0);
302 auto *II = cast<IntrinsicInst>(Tile);
303 // Tile is output from AMX intrinsic. The first operand of the
304 // intrinsic is row, the second operand of the intrinsic is column.
305 Value *Row = II->getOperand(0);
306 Value *Col = II->getOperand(1);
307 IRBuilder<> Builder(ST);
308 // Use the maximum column as the stride. It must be the same as the load
309 // stride.
310 Value *Stride = Builder.getInt64(64);
311 Value *I8Ptr = ST->getOperand(1);
312 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
313 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
314 if (Bitcast->hasOneUse())
315 return;
316 // %13 = bitcast x86_amx %src to <256 x i32>
317 // store <256 x i32> %13, <256 x i32>* %addr, align 64
318 // %add = <256 x i32> %13, <256 x i32> %src2
319 // -->
320 // %13 = bitcast x86_amx %src to <256 x i32>
321 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
322 // %stride64, %13)
323 // %14 = load <256 x i32>, %addr
324 // %add = <256 x i32> %14, <256 x i32> %src2
325 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
326 Bitcast->replaceAllUsesWith(Vec);
327}
328
329 // Transform a bitcast into <store, load> instructions.
330bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
331 IRBuilder<> Builder(Bitcast);
332 AllocaInst *AllocaAddr;
333 Value *I8Ptr, *Stride;
334 auto *Src = Bitcast->getOperand(0);
335
336 auto Prepare = [&](Type *MemTy) {
337 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
338 I8Ptr = AllocaAddr;
339 Stride = Builder.getInt64(64);
340 };
341
342 if (Bitcast->getType()->isX86_AMXTy()) {
343 // %2 = bitcast <256 x i32> %src to x86_amx
344 // -->
345 // %addr = alloca <256 x i32>, align 64
346 // store <256 x i32> %src, <256 x i32>* %addr, align 64
347 // %addr2 = bitcast <256 x i32>* to i8*
348 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
349 // i8* %addr2,
350 // i64 64)
351 Use &U = *(Bitcast->use_begin());
352 unsigned OpNo = U.getOperandNo();
353 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
354 if (!II)
355 return false; // May be bitcast from x86amx to <256 x i32>.
356 Prepare(Bitcast->getOperand(0)->getType());
357 Builder.CreateStore(Src, AllocaAddr);
358 // TODO: we can pick a constant operand for the shape.
359 Value *Row = nullptr, *Col = nullptr;
360 std::tie(Row, Col) = getShape(II, OpNo);
361 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
362 Value *NewInst =
363 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
364 Bitcast->replaceAllUsesWith(NewInst);
365 } else {
366 // %2 = bitcast x86_amx %src to <256 x i32>
367 // -->
368 // %addr = alloca <256 x i32>, align 64
369 // %addr2 = bitcast <256 x i32>* to i8*
370 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
371 // i8* %addr2, i64 %stride)
372 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
373 auto *II = dyn_cast<IntrinsicInst>(Src);
374 if (!II)
375 return false; // May be bitcast from <256 x i32> to x86amx.
376 Prepare(Bitcast->getType());
377 Value *Row = II->getOperand(0);
378 Value *Col = II->getOperand(1);
379 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
380 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
381 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
382 Bitcast->replaceAllUsesWith(NewInst);
383 }
384
385 return true;
386}
387
388bool X86LowerAMXType::visit() {
389 SmallVector<Instruction *, 8> DeadInsts;
390 Col2Row.clear();
391
392 for (BasicBlock *BB : post_order(&Func)) {
393 for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
394 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
395 if (!Bitcast)
396 continue;
397
398 Value *Src = Bitcast->getOperand(0);
399 if (Bitcast->getType()->isX86_AMXTy()) {
400 if (Bitcast->user_empty()) {
401 DeadInsts.push_back(Bitcast);
402 continue;
403 }
404 LoadInst *LD = dyn_cast<LoadInst>(Src);
405 if (!LD) {
406 if (transformBitcast(Bitcast))
407 DeadInsts.push_back(Bitcast);
408 continue;
409 }
410 // If the load has multiple users, a vector load is kept alongside the new tile load.
411 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
412 // %2 = bitcast <256 x i32> %src to x86_amx
413 // %add = add <256 x i32> %src, <256 x i32> %src2
414 // -->
415 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
416 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
417 // i8* %addr, i64 %stride64)
418 // %add = add <256 x i32> %src, <256 x i32> %src2
419
420 // If the load has one user, the load will be eliminated in DAG ISel.
421 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
422 // %2 = bitcast <256 x i32> %src to x86_amx
423 // -->
424 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
425 // i8* %addr, i64 %stride64)
426 combineLoadBitcast(LD, Bitcast);
427 DeadInsts.push_back(Bitcast);
428 if (LD->hasOneUse())
429 DeadInsts.push_back(LD);
430 } else if (Src->getType()->isX86_AMXTy()) {
431 if (Bitcast->user_empty()) {
432 DeadInsts.push_back(Bitcast);
433 continue;
434 }
435 StoreInst *ST = nullptr;
436 for (Use &U : Bitcast->uses()) {
437 ST = dyn_cast<StoreInst>(U.getUser());
438 if (ST)
439 break;
440 }
441 if (!ST) {
442 if (transformBitcast(Bitcast))
443 DeadInsts.push_back(Bitcast);
444 continue;
445 }
446 // If the bitcast (%13) has one use, combine the bitcast and store into an amx store.
447 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
448 // %stride);
449 // %13 = bitcast x86_amx %src to <256 x i32>
450 // store <256 x i32> %13, <256 x i32>* %addr, align 64
451 // -->
452 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
453 // %stride64, %13)
454 //
455 // If the bitcast (%13) has multiple uses, transform as below.
456 // %13 = bitcast x86_amx %src to <256 x i32>
457 // store <256 x i32> %13, <256 x i32>* %addr, align 64
458 // %add = <256 x i32> %13, <256 x i32> %src2
459 // -->
460 // %13 = bitcast x86_amx %src to <256 x i32>
461 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
462 // %stride64, %13)
463 // %14 = load <256 x i32>, %addr
464 // %add = <256 x i32> %14, <256 x i32> %src2
465 //
466 combineBitcastStore(Bitcast, ST);
467 // Delete user first.
468 DeadInsts.push_back(ST);
469 DeadInsts.push_back(Bitcast);
470 }
471 }
472 }
473
474 bool C = !DeadInsts.empty();
475
476 for (auto *Inst : DeadInsts)
477 Inst->eraseFromParent();
478
479 return C;
480}
481} // anonymous namespace
482
483static Value *getAllocaPos(BasicBlock *BB) {
484 Function *F = BB->getParent();
485 IRBuilder<> Builder(&F->getEntryBlock().front());
486 const DataLayout &DL = F->getDataLayout();
487 unsigned AllocaAS = DL.getAllocaAddrSpace();
488 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
489 AllocaInst *AllocaRes =
490 new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
491 BasicBlock::iterator Iter = AllocaRes->getIterator();
492 ++Iter;
493 Builder.SetInsertPoint(&*Iter);
494 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
495 return I8Ptr;
496}
497
498static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
499 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
500 auto *II = cast<IntrinsicInst>(TileDef);
501
502 assert(II && "Not tile intrinsic!");
503 Value *Row = II->getOperand(0);
504 Value *Col = II->getOperand(1);
505
506 BasicBlock *BB = TileDef->getParent();
507 BasicBlock::iterator Iter = TileDef->getIterator();
508 IRBuilder<> Builder(BB, ++Iter);
509 Value *Stride = Builder.getInt64(64);
510 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
511
512 Instruction *TileStore =
513 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
514 return TileStore;
515}
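// For illustration (value names are made up): for a tile definition %td with
// shape {%row, %col}, this emits right after %td:
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %mem,
//                                             i64 64, x86_amx %td)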
516
517static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
518 Value *V = U.get();
519 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
520
521 // Get tile shape.
522 IntrinsicInst *II = nullptr;
523 if (IsPHI) {
524 Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
525 II = cast<IntrinsicInst>(PhiOp);
526 } else {
527 II = cast<IntrinsicInst>(V);
528 }
529 Value *Row = II->getOperand(0);
530 Value *Col = II->getOperand(1);
531
532 Instruction *UserI = cast<Instruction>(U.getUser());
533 IRBuilder<> Builder(UserI);
534 Value *Stride = Builder.getInt64(64);
535 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
536
537 Value *TileLoad =
538 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
539 UserI->replaceUsesOfWith(V, TileLoad);
540}
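// For illustration (value names are made up): a use of a tile %td with shape
// {%row, %col} becomes a use of a fresh load from the shared memory:
//   %td2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                      i8* %mem, i64 64)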
541
542static bool isIncomingOfPHI(Instruction *I) {
543 for (Use &U : I->uses()) {
544 User *V = U.getUser();
545 if (isa<PHINode>(V))
546 return true;
547 }
548 return false;
549}
550
551 // Make all AMX tile data volatile, which shortens the live range
552 // of each tile register before fast register allocation.
553namespace {
554class X86VolatileTileData {
555 Function &F;
556
557public:
558 X86VolatileTileData(Function &Func) : F(Func) {}
559 Value *updatePhiIncomings(BasicBlock *BB,
560 SmallVector<Instruction *, 2> &Incomings);
561 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
562 bool volatileTileData();
563 void volatileTilePHI(PHINode *PHI);
564 void volatileTileNonPHI(Instruction *I);
565};
566
567Value *X86VolatileTileData::updatePhiIncomings(
568 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
569 Value *I8Ptr = getAllocaPos(BB);
570
571 for (auto *I : Incomings) {
572 User *Store = createTileStore(I, I8Ptr);
573
574 // All its uses (except phi) should load from stored mem.
575 for (Use &U : I->uses()) {
576 User *V = U.getUser();
577 if (isa<PHINode>(V) || V == Store)
578 continue;
579 replaceWithTileLoad(U, I8Ptr);
580 }
581 }
582 return I8Ptr;
583}
584
585void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
586 Value *StorePtr) {
587 for (Use &U : PHI->uses())
588 replaceWithTileLoad(U, StorePtr, true);
589 PHI->eraseFromParent();
590}
591
592 // Similar to volatileTileNonPHI, this function only handles PHI nodes
593 // and their related AMX intrinsics.
594 // 1) The PHI def should be changed to a tileload.
595 // 2) The PHI incoming values should be tilestored right after their defs.
596 // 3) The mem of these tileloads and tilestores should be the same.
597// e.g.
598// ------------------------------------------------------
599// bb_dom:
600// ...
601// br i1 %bool.cond, label %if.else, label %if.then
602//
603// if.then:
604// def %t0 = ...
605// ...
606// use %t0
607// ...
608// br label %if.end
609//
610// if.else:
611// def %t1 = ...
612// br label %if.end
613//
614// if.end:
615// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
616// ...
617// use %td
618// ------------------------------------------------------
619// -->
620// ------------------------------------------------------
621// bb_entry:
622// %mem = alloca <256 x i32>, align 1024 *
623// ...
624// bb_dom:
625// ...
626// br i1 %bool.cond, label %if.else, label %if.then
627//
628// if.then:
629// def %t0 = ...
630// call void @llvm.x86.tilestored64.internal(mem, %t0) *
631// ...
632// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
633// use %t0` *
634// ...
635// br label %if.end
636//
637// if.else:
638// def %t1 = ...
639// call void @llvm.x86.tilestored64.internal(mem, %t1) *
640// br label %if.end
641//
642// if.end:
643// ...
644// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
645// use %td
646// ------------------------------------------------------
647void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
648 BasicBlock *BB = PHI->getParent();
649 SmallVector<Instruction *, 2> Incomings;
650
651 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
652 Value *Op = PHI->getIncomingValue(I);
653 Instruction *Inst = dyn_cast<Instruction>(Op);
654 assert(Inst && "We shouldn't fold AMX instruction!");
655 Incomings.push_back(Inst);
656 }
657
658 Value *StorePtr = updatePhiIncomings(BB, Incomings);
659 replacePhiDefWithLoad(PHI, StorePtr);
660}
661
662// Store the defined tile and load it before use.
663// All its users are not PHI.
664// e.g.
665// ------------------------------------------------------
666// def %td = ...
667// ...
668// "use %td"
669// ------------------------------------------------------
670// -->
671// ------------------------------------------------------
672// def %td = ...
673// call void @llvm.x86.tilestored64.internal(mem, %td)
674// ...
675// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
676// "use %td2"
677// ------------------------------------------------------
678void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
679 BasicBlock *BB = I->getParent();
680 Value *I8Ptr = getAllocaPos(BB);
681 User *Store = createTileStore(I, I8Ptr);
682
683 // All its uses should load from stored mem.
684 for (Use &U : I->uses()) {
685 User *V = U.getUser();
686 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
687 if (V != Store)
688 replaceWithTileLoad(U, I8Ptr);
689 }
690}
691
692// Volatile Tile Model:
693 // 1) All uses of tile data come from a tileload just in time.
694 // 2) All defs of tile data are tilestored into mem immediately.
695// For example:
696// --------------------------------------------------------------------------
697// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
698// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
699// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
700// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
701// call void @llvm.x86.tilestored64.internal(... td) area
702// --------------------------------------------------------------------------
703// 3) No terminator, call or other amx instructions in the key amx area.
704bool X86VolatileTileData::volatileTileData() {
705 bool Changed = false;
706 for (BasicBlock &BB : F) {
707 SmallVector<Instruction *, 2> PHIInsts;
708 SmallVector<Instruction *, 8> AMXDefInsts;
709
710 for (Instruction &I : BB) {
711 if (!I.getType()->isX86_AMXTy())
712 continue;
713 if (isa<PHINode>(&I))
714 PHIInsts.push_back(&I);
715 else
716 AMXDefInsts.push_back(&I);
717 }
718
719 // First we "volatile" the non-phi related amx intrinsics.
720 for (Instruction *I : AMXDefInsts) {
721 if (isIncomingOfPHI(I))
722 continue;
723 volatileTileNonPHI(I);
724 Changed = true;
725 }
726
727 for (Instruction *I : PHIInsts) {
728 volatileTilePHI(dyn_cast<PHINode>(I));
729 Changed = true;
730 }
731 }
732 return Changed;
733}
734
735} // anonymous namespace
736
737namespace {
738
739class X86LowerAMXCast {
740 Function &Func;
741 std::unique_ptr<DominatorTree> DT;
742
743public:
744 X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
745 bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
746 bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
747 bool combineTilezero(IntrinsicInst *Cast);
748 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
749 bool combineAMXcast(TargetLibraryInfo *TLI);
750 bool transformAMXCast(IntrinsicInst *AMXCast);
751 bool transformAllAMXCast();
752 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
753 SmallSetVector<Instruction *, 16> &DeadInst);
754};
755
756static bool DCEInstruction(Instruction *I,
757 SmallSetVector<Instruction *, 16> &WorkList,
758 const TargetLibraryInfo *TLI) {
759 if (isInstructionTriviallyDead(I, TLI)) {
760 salvageDebugInfo(*I);
761 salvageKnowledge(I);
762
763 // Null out all of the instruction's operands to see if any operand becomes
764 // dead as we go.
765 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
766 Value *OpV = I->getOperand(i);
767 I->setOperand(i, nullptr);
768
769 if (!OpV->use_empty() || I == OpV)
770 continue;
771
772 // If the operand is an instruction that became dead as we nulled out the
773 // operand, and if it is 'trivially' dead, delete it in a future loop
774 // iteration.
775 if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
776 if (isInstructionTriviallyDead(OpI, TLI)) {
777 WorkList.insert(OpI);
778 }
779 }
780 }
781 I->eraseFromParent();
782 return true;
783 }
784 return false;
785}
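// For illustration (value names are made up): erasing a dead
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
// nulls out its use of %v; if that was the last use and the definition of %v
// is itself trivially dead, it is queued in WorkList and deleted in a later
// iteration.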
786
787/// This function handles the following case:
788///
789/// A -> B amxcast
790/// PHI
791/// B -> A amxcast
792///
793/// All the related PHI nodes can be replaced by new PHI nodes with type A.
794/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
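/// For illustration (a sketch; the value names are made up), with A = x86_amx
/// and B = <256 x i32>:
///   bb1: %v1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///   bb2: %v2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t2)
///   bb3: %p  = phi <256 x i32> [ %v1, %bb1 ], [ %v2, %bb2 ]
///        %t  = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
/// -->
///   bb3: %p2 = phi x86_amx [ %t1, %bb1 ], [ %t2, %bb2 ]
/// and the uses of %t are replaced with %p2; the old phi and casts become dead.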
795bool X86LowerAMXCast::optimizeAMXCastFromPhi(
796 IntrinsicInst *CI, PHINode *PN,
797 SmallSetVector<Instruction *, 16> &DeadInst) {
798 IRBuilder<> Builder(CI);
799 Value *Src = CI->getOperand(0);
800 Type *SrcTy = Src->getType(); // Type B
801 Type *DestTy = CI->getType(); // Type A
802
803 SmallVector<PHINode *, 4> PhiWorklist;
804 SmallSetVector<PHINode *, 4> OldPhiNodes;
805
806 // Find all of the A->B casts and PHI nodes.
807 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
808 // OldPhiNodes is used to track all known PHI nodes, before adding a new
809 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
810 PhiWorklist.push_back(PN);
811 OldPhiNodes.insert(PN);
812 while (!PhiWorklist.empty()) {
813 auto *OldPN = PhiWorklist.pop_back_val();
814 for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
815 Value *IncValue = OldPN->getIncomingValue(I);
816 // TODO: Currently we ignore cases where it is a constant. In the future, we
817 // might support constants.
818 if (isa<Constant>(IncValue)) {
819 auto *IncConst = dyn_cast<Constant>(IncValue);
820 if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
821 return false;
822 Value *Row = nullptr, *Col = nullptr;
823 std::tie(Row, Col) = getShape(OldPN);
824 // TODO: If it is not a constant, the Row and Col must dominate the tilezero
825 // that we are going to create.
826 if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
827 return false;
828 // Create tilezero at the end of incoming block.
829 auto *Block = OldPN->getIncomingBlock(I);
830 BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
831 Instruction *NewInst = Builder.CreateIntrinsic(
832 Intrinsic::x86_tilezero_internal, {}, {Row, Col});
833 NewInst->moveBefore(Iter);
834 NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
835 {IncValue->getType()}, {NewInst});
836 NewInst->moveBefore(Iter);
837 // Replace IncValue with the new value.
838 OldPN->setIncomingValue(I, NewInst);
839 IncValue = NewInst;
840 }
841
842 if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
843 if (OldPhiNodes.insert(PNode))
844 PhiWorklist.push_back(PNode);
845 continue;
846 }
847 Instruction *ACI = dyn_cast<Instruction>(IncValue);
848 if (ACI && isAMXCast(ACI)) {
849 // Verify it's a A->B cast.
850 Type *TyA = ACI->getOperand(0)->getType();
851 Type *TyB = ACI->getType();
852 if (TyA != DestTy || TyB != SrcTy)
853 return false;
854 continue;
855 }
856 return false;
857 }
858 }
859
860 // Check that each user of each old PHI node is something that we can
861 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
862 for (auto *OldPN : OldPhiNodes) {
863 for (User *V : OldPN->users()) {
864 Instruction *ACI = dyn_cast<Instruction>(V);
865 if (ACI && isAMXCast(ACI)) {
866 // Verify it's a B->A cast.
867 Type *TyB = ACI->getOperand(0)->getType();
868 Type *TyA = ACI->getType();
869 if (TyA != DestTy || TyB != SrcTy)
870 return false;
871 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
872 // As long as the user is another old PHI node, then even if we don't
873 // rewrite it, the PHI web we're considering won't have any users
874 // outside itself, so it'll be dead.
875 // example:
876 // bb.0:
877 // %0 = amxcast ...
878 // bb.1:
879 // %1 = amxcast ...
880 // bb.2:
881 // %goodphi = phi %0, %1
882 // %3 = amxcast %goodphi
883 // bb.3:
884 // %goodphi2 = phi %0, %goodphi
885 // %4 = amxcast %goodphi2
886 // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
887 // outside the phi-web, so the combination stops. When
888 // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
889 // will be done.
890 if (OldPhiNodes.count(PHI) == 0)
891 return false;
892 } else
893 return false;
894 }
895 }
896
897 // For each old PHI node, create a corresponding new PHI node with a type A.
898 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
899 for (auto *OldPN : OldPhiNodes) {
900 Builder.SetInsertPoint(OldPN);
901 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
902 NewPNodes[OldPN] = NewPN;
903 }
904
905 // Fill in the operands of new PHI nodes.
906 for (auto *OldPN : OldPhiNodes) {
907 PHINode *NewPN = NewPNodes[OldPN];
908 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
909 Value *V = OldPN->getOperand(j);
910 Value *NewV = nullptr;
911 Instruction *ACI = dyn_cast<Instruction>(V);
912 // There should not be an AMXcast from a constant.
913 if (ACI && isAMXCast(ACI))
914 NewV = ACI->getOperand(0);
915 else if (auto *PrevPN = dyn_cast<PHINode>(V))
916 NewV = NewPNodes[PrevPN];
917 assert(NewV);
918 NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
919 }
920 }
921
922 // Traverse all accumulated PHI nodes and process their users,
923 // which are stores and bitcasts. Without this processing,
924 // new PHI nodes could be replicated and could lead to extra
925 // moves generated after DeSSA.
926 // If there is a store with type B, change it to type A.
927
928 // Replace users of BitCast B->A with NewPHI. These will help
929 // later to get rid of a closure formed by OldPHI nodes.
930 for (auto *OldPN : OldPhiNodes) {
931 PHINode *NewPN = NewPNodes[OldPN];
932 for (User *V : make_early_inc_range(OldPN->users())) {
933 Instruction *ACI = dyn_cast<Instruction>(V);
934 if (ACI && isAMXCast(ACI)) {
935 Type *TyB = ACI->getOperand(0)->getType();
936 Type *TyA = ACI->getType();
937 assert(TyA == DestTy && TyB == SrcTy);
938 (void)TyA;
939 (void)TyB;
940 ACI->replaceAllUsesWith(NewPN);
941 DeadInst.insert(ACI);
942 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
943 // We don't need to push the PHINode into DeadInst since it is an operand
944 // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
945 assert(OldPhiNodes.contains(PHI));
946 (void)PHI;
947 } else
948 llvm_unreachable("all uses should be handled");
949 }
950 }
951 return true;
952}
953
954// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
955// store <256 x i32> %43, <256 x i32>* %p, align 64
956// -->
957// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
958// i64 64, x86_amx %42)
959bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
960 Value *Tile = Cast->getOperand(0);
961
962 assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");
963
964 // TODO: Specially handle the multi-use case.
965 if (!Tile->hasOneUse())
966 return false;
967
968 auto *II = cast<IntrinsicInst>(Tile);
969 // Tile is output from AMX intrinsic. The first operand of the
970 // intrinsic is row, the second operand of the intrinsic is column.
971 Value *Row = II->getOperand(0);
972 Value *Col = II->getOperand(1);
973
974 IRBuilder<> Builder(ST);
975
976 // Stride should be equal to col (measured in bytes).
977 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
978 Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
979 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
980 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
981 return true;
982}
983
984// %65 = load <256 x i32>, <256 x i32>* %p, align 64
985// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
986// -->
987// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
988// i8* %p, i64 64)
989bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
990 bool EraseLoad = true;
991 Value *Row = nullptr, *Col = nullptr;
992 Use &U = *(Cast->use_begin());
993 unsigned OpNo = U.getOperandNo();
994 auto *II = cast<IntrinsicInst>(U.getUser());
995 // TODO: If it is a cast intrinsic or phi node, we can propagate the
996 // shape information through the def-use chain.
997 if (!isAMXIntrinsic(II))
998 return false;
999 std::tie(Row, Col) = getShape(II, OpNo);
1000 IRBuilder<> Builder(LD);
1001 // Stride should be equal to col (measured in bytes).
1002 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
1003 Value *I8Ptr;
1004
1005 // To save compile time, we create the dominator tree only when it is really
1006 // needed.
1007 if (!DT)
1008 DT.reset(new DominatorTree(Func));
1009 if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
1010 // Store the value to the stack and reload it from the stack before the cast.
1011 auto *AllocaAddr =
1012 createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
1013 Builder.SetInsertPoint(&*std::next(LD->getIterator()));
1014 Builder.CreateStore(LD, AllocaAddr);
1015
1016 Builder.SetInsertPoint(Cast);
1017 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1018 EraseLoad = false;
1019 } else {
1020 I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
1021 }
1022 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
1023
1024 Value *NewInst =
1025 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
1026 Cast->replaceAllUsesWith(NewInst);
1027
1028 return EraseLoad;
1029}
1030
1031// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1032// -->
1033// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1034bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
1035 Value *Row = nullptr, *Col = nullptr;
1036 Use &U = *(Cast->use_begin());
1037 unsigned OpNo = U.getOperandNo();
1038 auto *II = cast<IntrinsicInst>(U.getUser());
1039 if (!isAMXIntrinsic(II))
1040 return false;
1041
1042 std::tie(Row, Col) = getShape(II, OpNo);
1043
1044 IRBuilder<> Builder(Cast);
1045 Value *NewInst =
1046 Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {}, {Row, Col});
1047 Cast->replaceAllUsesWith(NewInst);
1048 return true;
1049}
1050
1051bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
1052 bool Change = false;
1053 for (auto *Cast : Casts) {
1054 auto *II = cast<IntrinsicInst>(Cast);
1055 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1056 // store <256 x i32> %43, <256 x i32>* %p, align 64
1057 // -->
1058 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1059 // i64 64, x86_amx %42)
1060 if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1061 SmallVector<Instruction *, 2> DeadStores;
1062 for (User *U : Cast->users()) {
1063 StoreInst *Store = dyn_cast<StoreInst>(U);
1064 if (!Store)
1065 continue;
1066 if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1067 DeadStores.push_back(Store);
1068 Change = true;
1069 }
1070 }
1071 for (auto *Store : DeadStores)
1072 Store->eraseFromParent();
1073 } else { // x86_cast_vector_to_tile
1074 // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1075 // -->
1076 // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1077 if (isa<ConstantAggregateZero>(Cast->getOperand(0))) {
1078 Change |= combineTilezero(cast<IntrinsicInst>(Cast));
1079 continue;
1080 }
1081
1082 auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1083 if (!Load || !Load->hasOneUse())
1084 continue;
1085 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1086 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1087 // -->
1088 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1089 // i8* %p, i64 64)
1090 if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1091 // Set the operand to null so that the load instruction can be erased.
1092 Cast->setOperand(0, nullptr);
1093 Load->eraseFromParent();
1094 Change = true;
1095 }
1096 }
1097 }
1098 return Change;
1099}
1100
1101bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1102 bool Change = false;
1103 // Collect tile cast instructions.
1104 SmallVector<Instruction *, 8> Vec2TileInsts;
1105 SmallVector<Instruction *, 8> Tile2VecInsts;
1106 SmallVector<Instruction *, 8> PhiCastWorkList;
1107 SmallSetVector<Instruction *, 16> DeadInst;
1108 for (BasicBlock &BB : Func) {
1109 for (Instruction &I : BB) {
1110 Value *Vec;
1111 if (match(&I,
1112 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1113 Vec2TileInsts.push_back(&I);
1114 else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1115 m_Value(Vec))))
1116 Tile2VecInsts.push_back(&I);
1117 }
1118 }
1119
1120 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1121 for (auto *Inst : Insts) {
1122 for (User *U : Inst->users()) {
1123 IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1124 if (!II || II->getIntrinsicID() != IID)
1125 continue;
1126 // T1 = vec2tile V0
1127 // V2 = tile2vec T1
1128 // V3 = OP V2
1129 // -->
1130 // T1 = vec2tile V0
1131 // V2 = tile2vec T1
1132 // V3 = OP V0
1133 II->replaceAllUsesWith(Inst->getOperand(0));
1134 Change = true;
1135 }
1136 }
1137 };
1138
1139 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1140 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1141
1142 SmallVector<Instruction *, 8> LiveCasts;
1143 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1144 for (auto *Inst : Insts) {
1145 if (Inst->use_empty()) {
1146 Inst->eraseFromParent();
1147 Change = true;
1148 } else {
1149 LiveCasts.push_back(Inst);
1150 }
1151 }
1152 };
1153
1154 EraseInst(Vec2TileInsts);
1155 EraseInst(Tile2VecInsts);
1156 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1157 "Vec2Tile and Tile2Vec:\n";
1158 Func.dump());
1159 Change |= combineLdSt(LiveCasts);
1160 EraseInst(LiveCasts);
1161 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1162 "AMXCast and load/store:\n";
1163 Func.dump());
1164
1165 // Handle the A->B->A cast when there is an intervening PHI node.
1166 for (BasicBlock &BB : Func) {
1167 for (Instruction &I : BB) {
1168 if (isAMXCast(&I)) {
1169 if (isa<PHINode>(I.getOperand(0)))
1170 PhiCastWorkList.push_back(&I);
1171 }
1172 }
1173 }
1174 for (auto *I : PhiCastWorkList) {
1175 // Skip dead AMXcasts.
1176 if (DeadInst.contains(I))
1177 continue;
1178 PHINode *PN = cast<PHINode>(I->getOperand(0));
1179 if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1180 DeadInst.insert(PN);
1181 Change = true;
1182 }
1183 }
1184
1185 // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
1186 // might have no uses. We do some dead code elimination for them.
1187 while (!DeadInst.empty()) {
1188 Instruction *I = DeadInst.pop_back_val();
1189 Change |= DCEInstruction(I, DeadInst, TLI);
1190 }
1191 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
1192 "optimizeAMXCastFromPhi:\n";
1193 Func.dump());
1194 return Change;
1195}
1196
1197// There might be remaining AMXcasts after combineAMXcast, and they should be
1198// handled elegantly.
1199bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1200 IRBuilder<> Builder(AMXCast);
1201 AllocaInst *AllocaAddr;
1202 Value *I8Ptr, *Stride;
1203 auto *Src = AMXCast->getOperand(0);
1204
1205 auto Prepare = [&](Type *MemTy) {
1206 AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1207 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1208 Stride = Builder.getInt64(64);
1209 };
1210
1211 if (AMXCast->getType()->isX86_AMXTy()) {
1212 // %2 = amxcast <225 x i32> %src to x86_amx
1213 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1214 // i8* %addr3, i64 60, x86_amx %2)
1215 // -->
1216 // %addr = alloca <225 x i32>, align 64
1217 // store <225 x i32> %src, <225 x i32>* %addr, align 64
1218 // %addr2 = bitcast <225 x i32>* %addr to i8*
1219 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1220 // i8* %addr2,
1221 // i64 60)
1222 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1223 // i8* %addr3, i64 60, x86_amx %2)
1224 if (AMXCast->use_empty()) {
1225 AMXCast->eraseFromParent();
1226 return true;
1227 }
1228 Use &U = *(AMXCast->use_begin());
1229 unsigned OpNo = U.getOperandNo();
1230 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1231 if (!II)
1232 return false; // May be bitcast from x86amx to <256 x i32>.
1233 Prepare(AMXCast->getOperand(0)->getType());
1234 Builder.CreateStore(Src, AllocaAddr);
1235 // TODO: we can pick a constant operand for the shape.
1236 Value *Row = nullptr, *Col = nullptr;
1237 std::tie(Row, Col) = getShape(II, OpNo);
1238 std::array<Value *, 4> Args = {
1239 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1240 Value *NewInst =
1241 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
1242 AMXCast->replaceAllUsesWith(NewInst);
1243 AMXCast->eraseFromParent();
1244 } else {
1245 // %2 = amxcast x86_amx %src to <225 x i32>
1246 // -->
1247 // %addr = alloca <225 x i32>, align 64
1248 // %addr2 = bitcast <225 x i32>* to i8*
1249 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1250 // i8* %addr2, i64 %stride)
1251 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1252 auto *II = dyn_cast<IntrinsicInst>(Src);
1253 if (!II)
1254 return false; // May be bitcast from <256 x i32> to x86amx.
1255 Prepare(AMXCast->getType());
1256 Value *Row = II->getOperand(0);
1257 Value *Col = II->getOperand(1);
1258 std::array<Value *, 5> Args = {
1259 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1260 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
1261 Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1262 AMXCast->replaceAllUsesWith(NewInst);
1263 AMXCast->eraseFromParent();
1264 }
1265
1266 return true;
1267}
1268
1269bool X86LowerAMXCast::transformAllAMXCast() {
1270 bool Change = false;
1271 // Collect tile cast instructions.
1272 SmallVector<Instruction *, 8> WorkLists;
1273 for (BasicBlock &BB : Func) {
1274 for (Instruction &I : BB) {
1275 if (isAMXCast(&I))
1276 WorkLists.push_back(&I);
1277 }
1278 }
1279
1280 for (auto *Inst : WorkLists) {
1281 Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1282 }
1283
1284 return Change;
1285}
1286
1287bool lowerAmxType(Function &F, const TargetMachine *TM,
1288 TargetLibraryInfo *TLI) {
1289 // Performance optimization: most code doesn't use AMX, so return early if
1290 // there are no instructions that produce AMX values. This is sufficient, as
1291 // AMX arguments and constants are not allowed -- so any producer of an AMX
1292 // value must be an instruction.
1293 // TODO: find a cheaper way for this, without looking at all instructions.
1294 if (!containsAMXCode(F))
1295 return false;
1296
1297 bool C = false;
1298 X86LowerAMXCast LAC(F);
1299 C |= LAC.combineAMXcast(TLI);
1300 // There might be remaining AMXcasts after combineAMXcast, and they should be
1301 // handled elegantly.
1302 C |= LAC.transformAllAMXCast();
1303
1304 X86LowerAMXType LAT(F);
1305 C |= LAT.visit();
1306
1307 // Prepare for fast register allocation at O0.
1308 // TODO: It may be better to check the volatile model of the AMX code, not
1309 // just by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
1310 if (TM->getOptLevel() == CodeGenOptLevel::None) {
1311 // If the front end does not use O0 but the mid/back end uses O0 (e.g.
1312 // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1313 // sure the amx data is volatile; that is necessary for AMX fast
1314 // register allocation.
1315 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1316 X86VolatileTileData VTD(F);
1317 C = VTD.volatileTileData() || C;
1318 }
1319 }
1320
1321 return C;
1322}
1323
1324} // anonymous namespace
1325
1326PreservedAnalyses X86LowerAMXTypePass::run(Function &F,
1327 FunctionAnalysisManager &FAM) {
1328 TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
1329 bool Changed = lowerAmxType(F, TM, &TLI);
1330 if (!Changed)
1331 return PreservedAnalyses::all();
1332
1335 return PA;
1336}
1337
1338namespace {
1339
1340class X86LowerAMXTypeLegacyPass : public FunctionPass {
1341public:
1342 static char ID;
1343
1344 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}
1345
1346 bool runOnFunction(Function &F) override {
1347 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1348 TargetLibraryInfo *TLI =
1349 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1350 return lowerAmxType(F, TM, TLI);
1351 }
1352
1353 void getAnalysisUsage(AnalysisUsage &AU) const override {
1354 AU.setPreservesCFG();
1355 AU.addRequired<TargetPassConfig>();
1356 AU.addRequired<TargetLibraryInfoWrapperPass>();
1357 }
1358};
1359
1360} // anonymous namespace
1361
1362static const char PassName[] = "Lower AMX type for load/store";
1363char X86LowerAMXTypeLegacyPass::ID = 0;
1364INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1365 false)
1366INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1367INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1368INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1369 false)
1370
1371FunctionPass *llvm::createX86LowerAMXTypeLegacyPass() {
1372 return new X86LowerAMXTypeLegacyPass();
1373}