LLVM 20.0.0git
X86LowerAMXType.cpp
Go to the documentation of this file.
1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Pass to transform <256 x i32> load/store.
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are
/// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// <256 x i32> load/store instructions into AMX load/store. If the bitcast
/// cannot be combined with a load/store, we transform the bitcast into an AMX
/// load/store plus a <256 x i32> store/load.
17///
/// If the front end does not use O0 but the mid/back end does (e.g. "clang -O2
/// -S -emit-llvm t.c" + "llc t.ll"), we must make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// volatileTileData() handles this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will transfer to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
43#include "llvm/ADT/SetVector.h"
46#include "llvm/CodeGen/Passes.h"
49#include "llvm/IR/DataLayout.h"
50#include "llvm/IR/Function.h"
51#include "llvm/IR/IRBuilder.h"
54#include "llvm/IR/IntrinsicsX86.h"
57#include "llvm/Pass.h"
61
62#include <map>
63
64using namespace llvm;
65using namespace PatternMatch;
66
67#define DEBUG_TYPE "lower-amx-type"
68
69static bool isAMXCast(Instruction *II) {
70 return match(II,
71 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
72 match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
73}
74
// Some instructions may return more than one tile.
76// e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal
77static unsigned getNumDefTiles(IntrinsicInst *II) {
78 Type *Ty = II->getType();
79 if (Ty->isX86_AMXTy())
80 return 1;
81
82 unsigned Num = 0;
83 for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) {
84 Type *STy = Ty->getContainedType(i);
85 if (STy->isX86_AMXTy())
86 Num++;
87 }
88 return Num;
89}
90
91static bool isAMXIntrinsic(Value *I) {
92 auto *II = dyn_cast<IntrinsicInst>(I);
93 if (!II)
94 return false;
95 if (isAMXCast(II))
96 return false;
97 // Check if return type or parameter is x86_amx. If it is x86_amx
98 // the intrinsic must be x86 amx intrinsics.
99 if (getNumDefTiles(II) > 0)
100 return true;
101 for (Value *V : II->args()) {
102 if (V->getType()->isX86_AMXTy())
103 return true;
104 }
105
106 return false;
107}
108
110 for (BasicBlock &BB : F)
111 for (Instruction &I : BB)
112 if (I.getType()->isX86_AMXTy())
113 return true;
114 return false;
115}
116
118 Type *Ty) {
119 Function &F = *BB->getParent();
120 const DataLayout &DL = F.getDataLayout();
121
122 LLVMContext &Ctx = Builder.getContext();
123 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
124 unsigned AllocaAS = DL.getAllocaAddrSpace();
125 AllocaInst *AllocaRes =
126 new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
127 AllocaRes->setAlignment(AllocaAlignment);
128 return AllocaRes;
129}
130
132 for (Instruction &I : F.getEntryBlock())
133 if (!isa<AllocaInst>(&I))
134 return &I;
135 llvm_unreachable("No terminator in the entry block!");
136}
137
139private:
140 TargetMachine *TM = nullptr;
141
142 // In AMX intrinsics we let Shape = {Row, Col}, but the
143 // RealCol = Col / ElementSize. We may use the RealCol
144 // as a new Row for other new created AMX intrinsics.
145 std::map<Value *, Value *> Col2Row, Row2Col;
146
147public:
148 ShapeCalculator(TargetMachine *TargetM) : TM(TargetM) {}
149 std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
150 std::pair<Value *, Value *> getShape(PHINode *Phi);
151 Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
152 Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity);
153};
154
156 unsigned Granularity) {
157 if (Col2Row.count(V))
158 return Col2Row[V];
159 IRBuilder<> Builder(II);
160 Value *RealRow = nullptr;
161 if (isa<ConstantInt>(V))
162 RealRow =
163 Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
164 else if (isa<Instruction>(V)) {
165 // When it is not a const value and it is not a function argument, we
166 // create Row after the definition of V instead of
167 // before II. For example, II is %118, we try to getshape for %117:
168 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
169 // i32> %115).
170 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
171 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
172 // %117).
173 // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
174 // definition is after its user(new tileload for %117).
175 // So, the best choice is to create %row right after the definition of
176 // %106.
177 Builder.SetInsertPoint(cast<Instruction>(V));
178 RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
179 cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
180 } else {
181 // When it is not a const value and it is a function argument, we create
182 // Row at the entry bb.
183 IRBuilder<> NewBuilder(
184 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
185 RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
186 }
187 Col2Row[V] = RealRow;
188 return RealRow;
189}
190
192 unsigned Granularity) {
193 if (Row2Col.count(V))
194 return Row2Col[V];
195 IRBuilder<> Builder(II);
196 Value *RealCol = nullptr;
197 if (isa<ConstantInt>(V))
198 RealCol =
199 Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity);
200 else if (isa<Instruction>(V)) {
201 Builder.SetInsertPoint(cast<Instruction>(V));
202 RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity));
203 cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V));
204 } else {
205 // When it is not a const value and it is a function argument, we create
206 // Row at the entry bb.
207 IRBuilder<> NewBuilder(
208 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
209 RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity));
210 }
211 Row2Col[V] = RealCol;
212 return RealCol;
213}
214
215// TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
216std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
217 unsigned OpNo) {
218 (void)TM;
219 IRBuilder<> Builder(II);
220 Value *Row = nullptr, *Col = nullptr;
221 switch (II->getIntrinsicID()) {
222 default:
223 llvm_unreachable("Expect amx intrinsics");
224 case Intrinsic::x86_t2rpntlvwz0_internal:
225 case Intrinsic::x86_t2rpntlvwz0t1_internal:
226 case Intrinsic::x86_t2rpntlvwz1_internal:
227 case Intrinsic::x86_t2rpntlvwz1t1_internal:
228 case Intrinsic::x86_tileloadd64_internal:
229 case Intrinsic::x86_tileloaddt164_internal:
230 case Intrinsic::x86_tilestored64_internal:
231 case Intrinsic::x86_t2rpntlvwz0rs_internal:
232 case Intrinsic::x86_t2rpntlvwz0rst1_internal:
233 case Intrinsic::x86_t2rpntlvwz1rs_internal:
234 case Intrinsic::x86_t2rpntlvwz1rst1_internal:
235 case Intrinsic::x86_tileloaddrs64_internal:
236 case Intrinsic::x86_tileloaddrst164_internal: {
237 Row = II->getArgOperand(0);
238 Col = II->getArgOperand(1);
239 break;
240 }
241 // a * b + c
242 // The shape depends on which operand.
243 case Intrinsic::x86_tcmmimfp16ps_internal:
244 case Intrinsic::x86_tcmmrlfp16ps_internal:
245 case Intrinsic::x86_tdpbssd_internal:
246 case Intrinsic::x86_tdpbsud_internal:
247 case Intrinsic::x86_tdpbusd_internal:
248 case Intrinsic::x86_tdpbuud_internal:
249 case Intrinsic::x86_tdpbf16ps_internal:
250 case Intrinsic::x86_tdpfp16ps_internal:
251 case Intrinsic::x86_tmmultf32ps_internal: {
252 switch (OpNo) {
253 case 3:
254 Row = II->getArgOperand(0);
255 Col = II->getArgOperand(1);
256 break;
257 case 4:
258 Row = II->getArgOperand(0);
259 Col = II->getArgOperand(2);
260 break;
261 case 5:
262 Row = getRowFromCol(II, II->getArgOperand(2), 4);
263 Col = II->getArgOperand(1);
264 break;
265 }
266 break;
267 }
268 case Intrinsic::x86_ttransposed_internal:
269 case Intrinsic::x86_tconjtfp16_internal: {
270 assert((OpNo == 2) && "Illegal Operand Number.");
271 Row = getRowFromCol(II, II->getArgOperand(1), 4);
272 Col = getColFromRow(II, II->getArgOperand(0), 4);
273 break;
274 }
275 case Intrinsic::x86_tcvtrowd2ps_internal:
276 case Intrinsic::x86_tcvtrowps2pbf16h_internal:
277 case Intrinsic::x86_tcvtrowps2pbf16l_internal:
278 case Intrinsic::x86_tcvtrowps2phh_internal:
279 case Intrinsic::x86_tcvtrowps2phl_internal:
280 case Intrinsic::x86_tilemovrow_internal: {
281 assert(OpNo == 2 && "Illegal Operand Number.");
282 Row = II->getArgOperand(0);
283 Col = II->getArgOperand(1);
284 break;
285 }
286 case Intrinsic::x86_ttdpbf16ps_internal:
287 case Intrinsic::x86_ttdpfp16ps_internal:
288 case Intrinsic::x86_ttcmmimfp16ps_internal:
289 case Intrinsic::x86_ttcmmrlfp16ps_internal:
290 case Intrinsic::x86_tconjtcmmimfp16ps_internal:
291 case Intrinsic::x86_ttmmultf32ps_internal: {
292 switch (OpNo) {
293 case 3:
294 Row = II->getArgOperand(0);
295 Col = II->getArgOperand(1);
296 break;
297 case 4:
298 Row = getRowFromCol(II, II->getArgOperand(2), 4);
299 Col = getColFromRow(II, II->getArgOperand(0), 4);
300 break;
301 case 5:
302 Row = getRowFromCol(II, II->getArgOperand(2), 4);
303 Col = II->getArgOperand(1);
304 break;
305 }
306 break;
307 }
308 }
309
310 return std::make_pair(Row, Col);
311}
312
313std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {
314 Use &U = *(Phi->use_begin());
315 unsigned OpNo = U.getOperandNo();
316 User *V = U.getUser();
317 // TODO We don't traverse all users. To make the algorithm simple, here we
318 // just traverse the first user. If we can find shape, then return the shape,
319 // otherwise just return nullptr and the optimization for undef/zero will be
320 // abandoned.
321 while (V) {
322 if (isAMXCast(dyn_cast<Instruction>(V))) {
323 if (V->use_empty())
324 break;
325 Use &U = *(V->use_begin());
326 OpNo = U.getOperandNo();
327 V = U.getUser();
328 } else if (isAMXIntrinsic(V)) {
329 return getShape(cast<IntrinsicInst>(V), OpNo);
330 } else if (isa<PHINode>(V)) {
331 if (V->use_empty())
332 break;
333 Use &U = *(V->use_begin());
334 V = U.getUser();
335 } else {
336 break;
337 }
338 }
339
340 return std::make_pair(nullptr, nullptr);
341}
342
343namespace {
344class X86LowerAMXType {
345 Function &Func;
346 ShapeCalculator *SC;
347
348 // In AMX intrinsics we let Shape = {Row, Col}, but the
349 // RealCol = Col / ElementSize. We may use the RealCol
350 // as a new Row for other new created AMX intrinsics.
351 std::map<Value *, Value *> Col2Row, Row2Col;
352
353public:
354 X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {}
355 bool visit();
356 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
357 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
358 bool transformBitcast(BitCastInst *Bitcast);
359};
360
361// %src = load <256 x i32>, <256 x i32>* %addr, align 64
362// %2 = bitcast <256 x i32> %src to x86_amx
363// -->
364// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
365// i8* %addr, i64 %stride64)
366void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
367 Value *Row = nullptr, *Col = nullptr;
368 Use &U = *(Bitcast->use_begin());
369 unsigned OpNo = U.getOperandNo();
370 auto *II = cast<IntrinsicInst>(U.getUser());
371 std::tie(Row, Col) = SC->getShape(II, OpNo);
372 IRBuilder<> Builder(Bitcast);
373 // Use the maximun column as stride.
374 Value *Stride = Builder.getInt64(64);
375 Value *I8Ptr = LD->getOperand(0);
376 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
377
378 Value *NewInst =
379 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
380 Bitcast->replaceAllUsesWith(NewInst);
381}
382
383// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
384// %stride);
385// %13 = bitcast x86_amx %src to <256 x i32>
386// store <256 x i32> %13, <256 x i32>* %addr, align 64
387// -->
388// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
389// %stride64, %13)
390void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
391
392 Value *Tile = Bitcast->getOperand(0);
393 auto *II = cast<IntrinsicInst>(Tile);
394 // Tile is output from AMX intrinsic. The first operand of the
395 // intrinsic is row, the second operand of the intrinsic is column.
396 Value *Row = II->getOperand(0);
397 Value *Col = II->getOperand(1);
398 IRBuilder<> Builder(ST);
399 // Use the maximum column as stride. It must be the same with load
400 // stride.
401 Value *Stride = Builder.getInt64(64);
402 Value *I8Ptr = ST->getOperand(1);
403 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
404 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
405 if (Bitcast->hasOneUse())
406 return;
407 // %13 = bitcast x86_amx %src to <256 x i32>
408 // store <256 x i32> %13, <256 x i32>* %addr, align 64
409 // %add = <256 x i32> %13, <256 x i32> %src2
410 // -->
411 // %13 = bitcast x86_amx %src to <256 x i32>
412 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
413 // %stride64, %13)
414 // %14 = load <256 x i32>, %addr
415 // %add = <256 x i32> %14, <256 x i32> %src2
416 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
417 Bitcast->replaceAllUsesWith(Vec);
418}
419
420// transform bitcast to <store, load> instructions.
421bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
422 IRBuilder<> Builder(Bitcast);
423 AllocaInst *AllocaAddr;
424 Value *I8Ptr, *Stride;
425 auto *Src = Bitcast->getOperand(0);
426
427 auto Prepare = [&](Type *MemTy) {
428 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
429 I8Ptr = AllocaAddr;
430 Stride = Builder.getInt64(64);
431 };
432
433 if (Bitcast->getType()->isX86_AMXTy()) {
434 // %2 = bitcast <256 x i32> %src to x86_amx
435 // -->
436 // %addr = alloca <256 x i32>, align 64
437 // store <256 x i32> %src, <256 x i32>* %addr, align 64
438 // %addr2 = bitcast <256 x i32>* to i8*
439 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
440 // i8* %addr2,
441 // i64 64)
442 Use &U = *(Bitcast->use_begin());
443 unsigned OpNo = U.getOperandNo();
444 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
445 if (!II)
446 return false; // May be bitcast from x86amx to <256 x i32>.
447 Prepare(Bitcast->getOperand(0)->getType());
448 Builder.CreateStore(Src, AllocaAddr);
449 // TODO we can pick an constant operand for the shape.
450 Value *Row = nullptr, *Col = nullptr;
451 std::tie(Row, Col) = SC->getShape(II, OpNo);
452 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
453 Value *NewInst =
454 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
455 Bitcast->replaceAllUsesWith(NewInst);
456 } else {
457 // %2 = bitcast x86_amx %src to <256 x i32>
458 // -->
459 // %addr = alloca <256 x i32>, align 64
460 // %addr2 = bitcast <256 x i32>* to i8*
461 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
462 // i8* %addr2, i64 %stride)
463 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
464 auto *II = dyn_cast<IntrinsicInst>(Src);
465 if (!II)
466 return false; // May be bitcast from <256 x i32> to x86amx.
467 Prepare(Bitcast->getType());
468 Value *Row = II->getOperand(0);
469 Value *Col = II->getOperand(1);
470 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
471 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
472 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
473 Bitcast->replaceAllUsesWith(NewInst);
474 }
475
476 return true;
477}
478
479bool X86LowerAMXType::visit() {
481 Col2Row.clear();
482
483 for (BasicBlock *BB : post_order(&Func)) {
485 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
486 if (!Bitcast)
487 continue;
488
489 Value *Src = Bitcast->getOperand(0);
490 if (Bitcast->getType()->isX86_AMXTy()) {
491 if (Bitcast->user_empty()) {
492 DeadInsts.push_back(Bitcast);
493 continue;
494 }
495 LoadInst *LD = dyn_cast<LoadInst>(Src);
496 if (!LD) {
497 if (transformBitcast(Bitcast))
498 DeadInsts.push_back(Bitcast);
499 continue;
500 }
501 // If load has mutli-user, duplicate a vector load.
502 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
503 // %2 = bitcast <256 x i32> %src to x86_amx
504 // %add = add <256 x i32> %src, <256 x i32> %src2
505 // -->
506 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
507 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
508 // i8* %addr, i64 %stride64)
509 // %add = add <256 x i32> %src, <256 x i32> %src2
510
511 // If load has one user, the load will be eliminated in DAG ISel.
512 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
513 // %2 = bitcast <256 x i32> %src to x86_amx
514 // -->
515 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
516 // i8* %addr, i64 %stride64)
517 combineLoadBitcast(LD, Bitcast);
518 DeadInsts.push_back(Bitcast);
519 if (LD->hasOneUse())
520 DeadInsts.push_back(LD);
521 } else if (Src->getType()->isX86_AMXTy()) {
522 if (Bitcast->user_empty()) {
523 DeadInsts.push_back(Bitcast);
524 continue;
525 }
526 StoreInst *ST = nullptr;
527 for (Use &U : Bitcast->uses()) {
528 ST = dyn_cast<StoreInst>(U.getUser());
529 if (ST)
530 break;
531 }
532 if (!ST) {
533 if (transformBitcast(Bitcast))
534 DeadInsts.push_back(Bitcast);
535 continue;
536 }
537 // If bitcast (%13) has one use, combine bitcast and store to amx store.
538 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
539 // %stride);
540 // %13 = bitcast x86_amx %src to <256 x i32>
541 // store <256 x i32> %13, <256 x i32>* %addr, align 64
542 // -->
543 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
544 // %stride64, %13)
545 //
546 // If bitcast (%13) has multi-use, transform as below.
547 // %13 = bitcast x86_amx %src to <256 x i32>
548 // store <256 x i32> %13, <256 x i32>* %addr, align 64
549 // %add = <256 x i32> %13, <256 x i32> %src2
550 // -->
551 // %13 = bitcast x86_amx %src to <256 x i32>
552 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
553 // %stride64, %13)
554 // %14 = load <256 x i32>, %addr
555 // %add = <256 x i32> %14, <256 x i32> %src2
556 //
557 combineBitcastStore(Bitcast, ST);
558 // Delete user first.
559 DeadInsts.push_back(ST);
560 DeadInsts.push_back(Bitcast);
561 }
562 }
563 }
564
565 bool C = !DeadInsts.empty();
566
567 for (auto *Inst : DeadInsts)
568 Inst->eraseFromParent();
569
570 return C;
571}
572} // anonymous namespace
573
575 Function *F = BB->getParent();
576 IRBuilder<> Builder(&F->getEntryBlock().front());
577 const DataLayout &DL = F->getDataLayout();
578 unsigned AllocaAS = DL.getAllocaAddrSpace();
579 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
580 AllocaInst *AllocaRes =
581 new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
582 BasicBlock::iterator Iter = AllocaRes->getIterator();
583 ++Iter;
584 Builder.SetInsertPoint(&*Iter);
585 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
586 return I8Ptr;
587}
588
590 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
591 auto *II = dyn_cast<IntrinsicInst>(TileDef);
592 unsigned Idx = 0;
593 // Extract tile from multiple tiles' def.
594 if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) {
595 assert(Extr->hasIndices() && "Tile extract miss index!");
596 Idx = Extr->getIndices()[0];
597 II = cast<IntrinsicInst>(Extr->getOperand(0));
598 }
599
600 assert(II && "Not tile intrinsic!");
601 Value *Row = II->getOperand(Idx);
602 Value *Col = II->getOperand(Idx + 1);
603
604 BasicBlock *BB = TileDef->getParent();
605 BasicBlock::iterator Iter = TileDef->getIterator();
606 IRBuilder<> Builder(BB, ++Iter);
607 Value *Stride = Builder.getInt64(64);
608 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
609
610 Instruction *TileStore =
611 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
612 return TileStore;
613}
614
615static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
616 Value *V = U.get();
617 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
618
619 // Get tile shape.
620 IntrinsicInst *II = nullptr;
621 unsigned Idx = 0;
622 if (IsPHI) {
623 Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
624 II = cast<IntrinsicInst>(PhiOp);
625 } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) {
626 // Extract tile from multiple tiles' def.
627 assert(Extr->hasIndices() && "Tile extract miss index!");
628 Idx = Extr->getIndices()[0];
629 II = cast<IntrinsicInst>(Extr->getOperand(0));
630 } else {
631 II = cast<IntrinsicInst>(V);
632 }
633 Value *Row = II->getOperand(Idx);
634 Value *Col = II->getOperand(Idx + 1);
635
636 Instruction *UserI = cast<Instruction>(U.getUser());
637 IRBuilder<> Builder(UserI);
638 Value *Stride = Builder.getInt64(64);
639 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
640
641 Value *TileLoad =
642 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
643 UserI->replaceUsesOfWith(V, TileLoad);
644}
645
647 for (Use &U : I->uses()) {
648 User *V = U.getUser();
649 if (isa<PHINode>(V))
650 return true;
651 }
652 return false;
653}
654
655// Let all AMX tile data become volatile data, shorten the life range
656// of each tile register before fast register allocation.
657namespace {
658class X86VolatileTileData {
659 Function &F;
660
661public:
662 X86VolatileTileData(Function &Func) : F(Func) {}
663 Value *updatePhiIncomings(BasicBlock *BB,
665 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
666 bool volatileTileData();
667 void volatileTilePHI(PHINode *PHI);
668 void volatileTileNonPHI(Instruction *I);
669};
670
671Value *X86VolatileTileData::updatePhiIncomings(
673 Value *I8Ptr = getAllocaPos(BB);
674
675 for (auto *I : Incomings) {
676 User *Store = createTileStore(I, I8Ptr);
677
678 // All its uses (except phi) should load from stored mem.
679 for (Use &U : I->uses()) {
680 User *V = U.getUser();
681 if (isa<PHINode>(V) || V == Store)
682 continue;
683 replaceWithTileLoad(U, I8Ptr);
684 }
685 }
686 return I8Ptr;
687}
688
689void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
690 Value *StorePtr) {
691 for (Use &U : PHI->uses())
692 replaceWithTileLoad(U, StorePtr, true);
693 PHI->eraseFromParent();
694}
695
// Similar to volatileTileNonPHI, but this function only handles PHI nodes
// and their related AMX intrinsics:
// 1) The PHI def is changed to a tileload.
// 2) The PHI incoming values are tilestored just after their defs.
// 3) These tileloads and tilestores all use the same memory.
701// e.g.
702// ------------------------------------------------------
703// bb_dom:
704// ...
705// br i1 %bool.cond, label %if.else, label %if.then
706//
707// if.then:
708// def %t0 = ...
709// ...
710// use %t0
711// ...
712// br label %if.end
713//
714// if.else:
715// def %t1 = ...
716// br label %if.end
717//
718// if.end:
719// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
720// ...
721// use %td
722// ------------------------------------------------------
723// -->
724// ------------------------------------------------------
725// bb_entry:
726// %mem = alloca <256 x i32>, align 1024 *
727// ...
728// bb_dom:
729// ...
730// br i1 %bool.cond, label %if.else, label %if.then
731//
732// if.then:
733// def %t0 = ...
734// call void @llvm.x86.tilestored64.internal(mem, %t0) *
735// ...
736// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
737// use %t0` *
738// ...
739// br label %if.end
740//
741// if.else:
742// def %t1 = ...
743// call void @llvm.x86.tilestored64.internal(mem, %t1) *
744// br label %if.end
745//
746// if.end:
747// ...
748// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
749// use %td
750// ------------------------------------------------------
751void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
752 BasicBlock *BB = PHI->getParent();
754
755 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
756 Value *Op = PHI->getIncomingValue(I);
757 Instruction *Inst = dyn_cast<Instruction>(Op);
758 assert(Inst && "We shouldn't fold AMX instrution!");
759 Incomings.push_back(Inst);
760 }
761
762 Value *StorePtr = updatePhiIncomings(BB, Incomings);
763 replacePhiDefWithLoad(PHI, StorePtr);
764}
765
766// Store the defined tile and load it before use.
767// All its users are not PHI.
768// e.g.
769// ------------------------------------------------------
770// def %td = ...
771// ...
772// "use %td"
773// ------------------------------------------------------
774// -->
775// ------------------------------------------------------
776// def %td = ...
777// call void @llvm.x86.tilestored64.internal(mem, %td)
778// ...
779// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
780// "use %td2"
781// ------------------------------------------------------
782void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
783 BasicBlock *BB = I->getParent();
784 Value *I8Ptr = getAllocaPos(BB);
785 User *Store = createTileStore(I, I8Ptr);
786
787 // All its uses should load from stored mem.
788 for (Use &U : I->uses()) {
789 User *V = U.getUser();
790 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
791 if (V != Store)
792 replaceWithTileLoad(U, I8Ptr);
793 }
794}
795
796// Volatile Tile Model:
// 1) All uses of tile data come from a tileload just in time.
798// 2) All the defs of tile data tilestore into mem immediately.
799// For example:
800// --------------------------------------------------------------------------
801// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
802// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
803// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
804// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
805// call void @llvm.x86.tilestored64.internal(... td) area
806// --------------------------------------------------------------------------
807// 3) No terminator, call or other amx instructions in the key amx area.
808bool X86VolatileTileData::volatileTileData() {
809 bool Changed = false;
810 for (BasicBlock &BB : F) {
813
814 for (Instruction &I : BB) {
815 if (!I.getType()->isX86_AMXTy())
816 continue;
817 if (isa<PHINode>(&I))
818 PHIInsts.push_back(&I);
819 else
820 AMXDefInsts.push_back(&I);
821 }
822
823 // First we "volatile" the non-phi related amx intrinsics.
824 for (Instruction *I : AMXDefInsts) {
825 if (isIncomingOfPHI(I))
826 continue;
827 volatileTileNonPHI(I);
828 Changed = true;
829 }
830
831 for (Instruction *I : PHIInsts) {
832 volatileTilePHI(dyn_cast<PHINode>(I));
833 Changed = true;
834 }
835 }
836 return Changed;
837}
838
839} // anonymous namespace
840
841namespace {
842
843class X86LowerAMXCast {
844 Function &Func;
846 std::unique_ptr<DominatorTree> DT;
847
848public:
849 X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC)
850 : Func(F), SC(ShapeC), DT(nullptr) {}
851 bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
852 bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
853 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
854 bool combineAMXcast(TargetLibraryInfo *TLI);
855 bool transformAMXCast(IntrinsicInst *AMXCast);
856 bool transformAllAMXCast();
857 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
859};
860
861static bool DCEInstruction(Instruction *I,
863 const TargetLibraryInfo *TLI) {
864 if (isInstructionTriviallyDead(I, TLI)) {
867
868 // Null out all of the instruction's operands to see if any operand becomes
869 // dead as we go.
870 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
871 Value *OpV = I->getOperand(i);
872 I->setOperand(i, nullptr);
873
874 if (!OpV->use_empty() || I == OpV)
875 continue;
876
877 // If the operand is an instruction that became dead as we nulled out the
878 // operand, and if it is 'trivially' dead, delete it in a future loop
879 // iteration.
880 if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
881 if (isInstructionTriviallyDead(OpI, TLI)) {
882 WorkList.insert(OpI);
883 }
884 }
885 }
886 I->eraseFromParent();
887 return true;
888 }
889 return false;
890}
891
892/// This function handles following case
893///
894/// A -> B amxcast
895/// PHI
896/// B -> A amxcast
897///
898/// All the related PHI nodes can be replaced by new PHI nodes with type A.
899/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
900bool X86LowerAMXCast::optimizeAMXCastFromPhi(
901 IntrinsicInst *CI, PHINode *PN,
903 IRBuilder<> Builder(CI);
904 Value *Src = CI->getOperand(0);
905 Type *SrcTy = Src->getType(); // Type B
906 Type *DestTy = CI->getType(); // Type A
907
908 SmallVector<PHINode *, 4> PhiWorklist;
910
911 // Find all of the A->B casts and PHI nodes.
912 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
913 // OldPhiNodes is used to track all known PHI nodes, before adding a new
914 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
915 PhiWorklist.push_back(PN);
916 OldPhiNodes.insert(PN);
917 while (!PhiWorklist.empty()) {
918 auto *OldPN = PhiWorklist.pop_back_val();
919 for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
920 Value *IncValue = OldPN->getIncomingValue(I);
921 // TODO: currently, We ignore cases where it is a const. In the future, we
922 // might support const.
923 if (isa<Constant>(IncValue)) {
924 auto *IncConst = dyn_cast<Constant>(IncValue);
925 if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
926 return false;
927 Value *Row = nullptr, *Col = nullptr;
928 std::tie(Row, Col) = SC->getShape(OldPN);
929 // TODO: If it is not constant the Row and Col must domoniate tilezero
930 // that we are going to create.
931 if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
932 return false;
933 // Create tilezero at the end of incoming block.
934 auto *Block = OldPN->getIncomingBlock(I);
935 BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
936 Instruction *NewInst = Builder.CreateIntrinsic(
937 Intrinsic::x86_tilezero_internal, {}, {Row, Col});
938 NewInst->moveBefore(&*Iter);
939 NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
940 {IncValue->getType()}, {NewInst});
941 NewInst->moveBefore(&*Iter);
942 // Replace InValue with new Value.
943 OldPN->setIncomingValue(I, NewInst);
944 IncValue = NewInst;
945 }
946
947 if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
948 if (OldPhiNodes.insert(PNode))
949 PhiWorklist.push_back(PNode);
950 continue;
951 }
952 Instruction *ACI = dyn_cast<Instruction>(IncValue);
953 if (ACI && isAMXCast(ACI)) {
954 // Verify it's a A->B cast.
955 Type *TyA = ACI->getOperand(0)->getType();
956 Type *TyB = ACI->getType();
957 if (TyA != DestTy || TyB != SrcTy)
958 return false;
959 continue;
960 }
961 return false;
962 }
963 }
964
965 // Check that each user of each old PHI node is something that we can
966 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
967 for (auto *OldPN : OldPhiNodes) {
968 for (User *V : OldPN->users()) {
969 Instruction *ACI = dyn_cast<Instruction>(V);
970 if (ACI && isAMXCast(ACI)) {
971 // Verify it's a B->A cast.
972 Type *TyB = ACI->getOperand(0)->getType();
973 Type *TyA = ACI->getType();
974 if (TyA != DestTy || TyB != SrcTy)
975 return false;
976 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
977 // As long as the user is another old PHI node, then even if we don't
978 // rewrite it, the PHI web we're considering won't have any users
979 // outside itself, so it'll be dead.
980 // example:
981 // bb.0:
982 // %0 = amxcast ...
983 // bb.1:
984 // %1 = amxcast ...
985 // bb.2:
986 // %goodphi = phi %0, %1
987 // %3 = amxcast %goodphi
988 // bb.3:
989 // %goodphi2 = phi %0, %goodphi
990 // %4 = amxcast %goodphi2
991 // When optimizeAMXCastFromPhi process %3 and %goodphi, %goodphi2 is
992 // outside the phi-web, so the combination stop When
993 // optimizeAMXCastFromPhi process %4 and %goodphi2, the optimization
994 // will be done.
995 if (OldPhiNodes.count(PHI) == 0)
996 return false;
997 } else
998 return false;
999 }
1000 }
1001
1002 // For each old PHI node, create a corresponding new PHI node with a type A.
1004 for (auto *OldPN : OldPhiNodes) {
1005 Builder.SetInsertPoint(OldPN);
1006 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
1007 NewPNodes[OldPN] = NewPN;
1008 }
1009
1010 // Fill in the operands of new PHI nodes.
1011 for (auto *OldPN : OldPhiNodes) {
1012 PHINode *NewPN = NewPNodes[OldPN];
1013 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
1014 Value *V = OldPN->getOperand(j);
1015 Value *NewV = nullptr;
1016 Instruction *ACI = dyn_cast<Instruction>(V);
1017 // There should not be a AMXcast from a const.
1018 if (ACI && isAMXCast(ACI))
1019 NewV = ACI->getOperand(0);
1020 else if (auto *PrevPN = dyn_cast<PHINode>(V))
1021 NewV = NewPNodes[PrevPN];
1022 assert(NewV);
1023 NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
1024 }
1025 }
1026
1027 // Traverse all accumulated PHI nodes and process its users,
1028 // which are Stores and BitcCasts. Without this processing
1029 // NewPHI nodes could be replicated and could lead to extra
1030 // moves generated after DeSSA.
1031 // If there is a store with type B, change it to type A.
1032
1033 // Replace users of BitCast B->A with NewPHI. These will help
1034 // later to get rid of a closure formed by OldPHI nodes.
1035 for (auto *OldPN : OldPhiNodes) {
1036 PHINode *NewPN = NewPNodes[OldPN];
1037 for (User *V : make_early_inc_range(OldPN->users())) {
1038 Instruction *ACI = dyn_cast<Instruction>(V);
1039 if (ACI && isAMXCast(ACI)) {
1040 Type *TyB = ACI->getOperand(0)->getType();
1041 Type *TyA = ACI->getType();
1042 assert(TyA == DestTy && TyB == SrcTy);
1043 (void)TyA;
1044 (void)TyB;
1045 ACI->replaceAllUsesWith(NewPN);
1046 DeadInst.insert(ACI);
1047 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
1048 // We don't need to push PHINode into DeadInst since they are operands
1049 // of rootPN DCE can safely delete rootPN's operands if rootPN is dead.
1050 assert(OldPhiNodes.contains(PHI));
1051 (void)PHI;
1052 } else
1053 llvm_unreachable("all uses should be handled");
1054 }
1055 }
1056 return true;
1057}
1058
1059static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx,
1060 bool IsRow) {
1061 if (!isAMXIntrinsic(Inst))
1062 return nullptr;
1063
1064 auto *II = cast<IntrinsicInst>(Inst);
1065 if (IsRow)
1066 return II->getOperand(0);
1067
1068 assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!");
1069 return II->getOperand(ShapeIdx + 1);
1070}
1071
1072// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
1073// store <256 x i32> %43, <256 x i32>* %p, align 64
1074// -->
1075// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1076// i64 64, x86_amx %42)
1077bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
1078 Value *Tile = Cast->getOperand(0);
1079
1080 assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");
1081
1082 // TODO: Specially handle the multi-use case.
1083 if (Tile->getNumUses() != 1)
1084 return false;
1085
1086 // We don't fetch shape from tilestore, we only get shape from tiledef,
1087 // so we can set the max tile shape to tilestore for special cases.
1088 IRBuilder<> Builder(ST);
1089 Value *Row = nullptr;
1090 Value *Col = nullptr;
1091
1092 if (isAMXIntrinsic(Tile)) {
1093 auto *II = cast<IntrinsicInst>(Tile);
1094 // Tile is output from AMX intrinsic. The first operand of the
1095 // intrinsic is row, the second operand of the intrinsic is column.
1096 Row = II->getOperand(0);
1097 Col = II->getOperand(1);
1098 } else {
1099 // Now we supported multi-tiles value in structure, so we may get tile
1100 // from extracting multi-tiles structure.
1101 // For example:
1102 // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1,
1103 // i16 %2, i16 %3, i8* %4, i64 %5)
1104 // %7 = extractvalue { x86_amx, x86_amx } %6, 0
1105 // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7)
1106 // store <256 x i32> %8, <256 x i32>* %0, align 1024
1107 //
1108 // TODO: Currently we only handle extractvalue case, enhance me for other
1109 // cases if possible.
1110 auto *II = cast<ExtractValueInst>(Tile);
1111 assert(II && "We meet unhandle source in fetching tile value!");
1112 unsigned ShapeIdx = II->getIndices()[0];
1113 Value *Tiles = II->getOperand(0);
1114 Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true);
1115 Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false);
1116 }
1117 assert(Row && Col && "Shape got failed!");
1118
1119 // Stride should be equal to col(measured by bytes)
1120 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
1121 Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
1122 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
1123 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
1124 return true;
1125}
1126
1127// %65 = load <256 x i32>, <256 x i32>* %p, align 64
1128// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1129// -->
1130// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1131// i8* %p, i64 64)
1132bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
1133 bool EraseLoad = true;
1134 Value *Row = nullptr, *Col = nullptr;
1135 Use &U = *(Cast->use_begin());
1136 unsigned OpNo = U.getOperandNo();
1137 auto *II = cast<IntrinsicInst>(U.getUser());
1138 // TODO: If it is cast intrinsic or phi node, we can propagate the
1139 // shape information through def-use chain.
1140 if (!isAMXIntrinsic(II))
1141 return false;
1142 std::tie(Row, Col) = SC->getShape(II, OpNo);
1143 IRBuilder<> Builder(LD);
1144 // Stride should be equal to col(measured by bytes)
1145 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
1146 Value *I8Ptr;
1147
1148 // To save compiling time, we create doninator tree when it is really
1149 // needed.
1150 if (!DT)
1151 DT.reset(new DominatorTree(Func));
1152 if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
1153 // store the value to stack and reload it from stack before cast.
1154 auto *AllocaAddr =
1155 createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
1156 Builder.SetInsertPoint(&*std::next(LD->getIterator()));
1157 Builder.CreateStore(LD, AllocaAddr);
1158
1159 Builder.SetInsertPoint(Cast);
1160 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1161 EraseLoad = false;
1162 } else {
1163 I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
1164 }
1165 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
1166
1167 Value *NewInst =
1168 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
1169 Cast->replaceAllUsesWith(NewInst);
1170
1171 return EraseLoad;
1172}
1173
1174bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
1175 bool Change = false;
1176 for (auto *Cast : Casts) {
1177 auto *II = cast<IntrinsicInst>(Cast);
1178 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1179 // store <256 x i32> %43, <256 x i32>* %p, align 64
1180 // -->
1181 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1182 // i64 64, x86_amx %42)
1183 if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1185 for (User *U : Cast->users()) {
1186 StoreInst *Store = dyn_cast<StoreInst>(U);
1187 if (!Store)
1188 continue;
1189 if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1190 DeadStores.push_back(Store);
1191 Change = true;
1192 }
1193 }
1194 for (auto *Store : DeadStores)
1195 Store->eraseFromParent();
1196 } else { // x86_cast_vector_to_tile
1198 auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1199 if (!Load || !Load->hasOneUse())
1200 continue;
1201 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1202 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1203 // -->
1204 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1205 // i8* %p, i64 64)
1206 if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1207 // Set the operand is null so that load instruction can be erased.
1208 Cast->setOperand(0, nullptr);
1209 Load->eraseFromParent();
1210 }
1211 }
1212 }
1213 return Change;
1214}
1215
1216bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1217 bool Change = false;
1218 // Collect tile cast instruction.
1219 SmallVector<Instruction *, 8> Vec2TileInsts;
1220 SmallVector<Instruction *, 8> Tile2VecInsts;
1221 SmallVector<Instruction *, 8> PhiCastWorkList;
1223 for (BasicBlock &BB : Func) {
1224 for (Instruction &I : BB) {
1225 Value *Vec;
1226 if (match(&I,
1227 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1228 Vec2TileInsts.push_back(&I);
1229 else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1230 m_Value(Vec))))
1231 Tile2VecInsts.push_back(&I);
1232 }
1233 }
1234
1235 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1236 for (auto *Inst : Insts) {
1237 for (User *U : Inst->users()) {
1238 IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1239 if (!II || II->getIntrinsicID() != IID)
1240 continue;
1241 // T1 = vec2tile V0
1242 // V2 = tile2vec T1
1243 // V3 = OP V2
1244 // -->
1245 // T1 = vec2tile V0
1246 // V2 = tile2vec T1
1247 // V3 = OP V0
1248 II->replaceAllUsesWith(Inst->getOperand(0));
1249 Change = true;
1250 }
1251 }
1252 };
1253
1254 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1255 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1256
1258 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1259 for (auto *Inst : Insts) {
1260 if (Inst->use_empty()) {
1261 Inst->eraseFromParent();
1262 Change = true;
1263 } else {
1264 LiveCasts.push_back(Inst);
1265 }
1266 }
1267 };
1268
1269 EraseInst(Vec2TileInsts);
1270 EraseInst(Tile2VecInsts);
1271 LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine "
1272 "Vec2Tile and Tile2Vec:\n";
1273 Func.dump());
1274 Change |= combineLdSt(LiveCasts);
1275 EraseInst(LiveCasts);
1276 LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine "
1277 "AMXCast and load/store:\n";
1278 Func.dump());
1279
1280 // Handle the A->B->A cast, and there is an intervening PHI node.
1281 for (BasicBlock &BB : Func) {
1282 for (Instruction &I : BB) {
1283 if (isAMXCast(&I)) {
1284 if (isa<PHINode>(I.getOperand(0)))
1285 PhiCastWorkList.push_back(&I);
1286 }
1287 }
1288 }
1289 for (auto *I : PhiCastWorkList) {
1290 // We skip the dead Amxcast.
1291 if (DeadInst.contains(I))
1292 continue;
1293 PHINode *PN = cast<PHINode>(I->getOperand(0));
1294 if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1295 DeadInst.insert(PN);
1296 Change = true;
1297 }
1298 }
1299
1300 // Since we create new phi and merge AMXCast, some old phis and AMXCast might
1301 // have no uses. We do some DeadCodeElimination for them.
1302 while (!DeadInst.empty()) {
1303 Instruction *I = DeadInst.pop_back_val();
1304 Change |= DCEInstruction(I, DeadInst, TLI);
1305 }
1306 LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after "
1307 "optimizeAMXCastFromPhi:\n";
1308 Func.dump());
1309 return Change;
1310}
1311
1312// There might be remaining AMXcast after combineAMXcast and they should be
1313// handled elegantly.
1314bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1315 IRBuilder<> Builder(AMXCast);
1316 AllocaInst *AllocaAddr;
1317 Value *I8Ptr, *Stride;
1318 auto *Src = AMXCast->getOperand(0);
1319
1320 auto Prepare = [&](Type *MemTy) {
1321 AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1322 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1323 Stride = Builder.getInt64(64);
1324 };
1325
1326 if (AMXCast->getType()->isX86_AMXTy()) {
1327 // %2 = amxcast <225 x i32> %src to x86_amx
1328 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1329 // i8* %addr3, i64 60, x86_amx %2)
1330 // -->
1331 // %addr = alloca <225 x i32>, align 64
1332 // store <225 x i32> %src, <225 x i32>* %addr, align 64
1333 // %addr2 = bitcast <225 x i32>* %addr to i8*
1334 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1335 // i8* %addr2,
1336 // i64 60)
1337 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1338 // i8* %addr3, i64 60, x86_amx %2)
1339 if (AMXCast->use_empty()) {
1340 AMXCast->eraseFromParent();
1341 return true;
1342 }
1343 Use &U = *(AMXCast->use_begin());
1344 unsigned OpNo = U.getOperandNo();
1345 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1346 if (!II)
1347 return false; // May be bitcast from x86amx to <256 x i32>.
1348 Prepare(AMXCast->getOperand(0)->getType());
1349 Builder.CreateStore(Src, AllocaAddr);
1350 // TODO we can pick an constant operand for the shape.
1351 Value *Row = nullptr, *Col = nullptr;
1352 std::tie(Row, Col) = SC->getShape(II, OpNo);
1353 std::array<Value *, 4> Args = {
1354 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1355 Value *NewInst =
1356 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
1357 AMXCast->replaceAllUsesWith(NewInst);
1358 AMXCast->eraseFromParent();
1359 } else {
1360 // %2 = amxcast x86_amx %src to <225 x i32>
1361 // -->
1362 // %addr = alloca <225 x i32>, align 64
1363 // %addr2 = bitcast <225 x i32>* to i8*
1364 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1365 // i8* %addr2, i64 %stride)
1366 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1367 auto *II = dyn_cast<IntrinsicInst>(Src);
1368 if (!II)
1369 return false; // May be bitcast from <256 x i32> to x86amx.
1370 Prepare(AMXCast->getType());
1371 Value *Row = II->getOperand(0);
1372 Value *Col = II->getOperand(1);
1373 std::array<Value *, 5> Args = {
1374 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1375 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
1376 Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1377 AMXCast->replaceAllUsesWith(NewInst);
1378 AMXCast->eraseFromParent();
1379 }
1380
1381 return true;
1382}
1383
1384bool X86LowerAMXCast::transformAllAMXCast() {
1385 bool Change = false;
1386 // Collect tile cast instruction.
1388 for (BasicBlock &BB : Func) {
1389 for (Instruction &I : BB) {
1390 if (isAMXCast(&I))
1391 WorkLists.push_back(&I);
1392 }
1393 }
1394
1395 for (auto *Inst : WorkLists) {
1396 Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1397 }
1398
1399 return Change;
1400}
1401
1402} // anonymous namespace
1403
1404namespace {
1405
1406class X86LowerAMXTypeLegacyPass : public FunctionPass {
1407public:
1408 static char ID;
1409
1410 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}
1411
1412 bool runOnFunction(Function &F) override {
1413 // Performance optimization: most code doesn't use AMX, so return early if
1414 // there are no instructions that produce AMX values. This is sufficient, as
1415 // AMX arguments and constants are not allowed -- so any producer of an AMX
1416 // value must be an instruction.
1417 // TODO: find a cheaper way for this, without looking at all instructions.
1418 if (!containsAMXCode(F))
1419 return false;
1420
1421 bool C = false;
1422 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1423 TargetLibraryInfo *TLI =
1424 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1425
1426 ShapeCalculator SC(TM);
1427 X86LowerAMXCast LAC(F, &SC);
1428 C |= LAC.combineAMXcast(TLI);
1429 // There might be remaining AMXcast after combineAMXcast and they should be
1430 // handled elegantly.
1431 C |= LAC.transformAllAMXCast();
1432
1433 X86LowerAMXType LAT(F, &SC);
1434 C |= LAT.visit();
1435
1436 // Prepare for fast register allocation at O0.
1437 // Todo: May better check the volatile model of AMX code, not just
1438 // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
1439 if (TM->getOptLevel() == CodeGenOptLevel::None) {
1440 // If Front End not use O0 but the Mid/Back end use O0, (e.g.
1441 // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll") we should make
1442 // sure the amx data is volatile, that is nessary for AMX fast
1443 // register allocation.
1444 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1445 X86VolatileTileData VTD(F);
1446 C = VTD.volatileTileData() || C;
1447 }
1448 }
1449
1450 return C;
1451 }
1452
1453 void getAnalysisUsage(AnalysisUsage &AU) const override {
1454 AU.setPreservesCFG();
1457 }
1458};
1459
1460} // anonymous namespace
1461
1462static const char PassName[] = "Lower AMX type for load/store";
1463char X86LowerAMXTypeLegacyPass::ID = 0;
1464INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1465 false)
1468INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1469 false)
1470
1472 return new X86LowerAMXTypeLegacyPass();
1473}
Rewrite undef for PHI
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static bool DCEInstruction(Instruction *I, SmallSetVector< Instruction *, 16 > &WorkList, const TargetLibraryInfo *TLI)
Definition: DCE.cpp:55
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file implements a set that has insertion order iteration characteristics.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
static bool isAMXCast(Instruction *II)
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI=false)
static Instruction * createTileStore(Instruction *TileDef, Value *Ptr)
static unsigned getNumDefTiles(IntrinsicInst *II)
static Value * getAllocaPos(BasicBlock *BB)
static bool containsAMXCode(Function &F)
static bool isIncomingOfPHI(Instruction *I)
static bool isAMXIntrinsic(Value *I)
#define DEBUG_TYPE
static const char PassName[]
static Instruction * getFirstNonAllocaInTheEntryBlock(Function &F)
static AllocaInst * createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB, Type *Ty)
Value * getRowFromCol(Instruction *II, Value *V, unsigned Granularity)
ShapeCalculator(TargetMachine *TargetM)
Value * getColFromRow(Instruction *II, Value *V, unsigned Granularity)
std::pair< Value *, Value * > getShape(IntrinsicInst *II, unsigned OpNo)
an instruction to allocate memory on the stack
Definition: Instructions.h:63
void setAlignment(Align Align)
Definition: Instructions.h:128
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:256
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:219
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:177
This class represents a no-op cast from one type to another.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
Value * CreateNUWMul(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1397
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:890
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
Definition: IRBuilder.h:523
Value * CreateUDiv(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Definition: IRBuilder.h:1401
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition: IRBuilder.h:488
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2155
LLVMContext & getContext() const
Definition: IRBuilder.h:173
PointerType * getPtrTy(unsigned AddrSpace=0)
Fetch the type representing a pointer.
Definition: IRBuilder.h:566
ConstantInt * getInt16(uint16_t C)
Get a constant 16-bit value.
Definition: IRBuilder.h:478
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:94
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:48
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
An instruction for reading from memory.
Definition: Instructions.h:176
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:98
bool empty() const
Determine if the SetVector is empty or not.
Definition: SetVector.h:93
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
value_type pop_back_val()
Definition: SetVector.h:285
bool contains(const key_type &key) const
Check if the SetVector contains the given key.
Definition: SetVector.h:254
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
bool empty() const
Definition: SmallVector.h:81
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
Provides information about what library functions are available for the current target.
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:77
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static Type * getX86_AMXTy(LLVMContext &C)
unsigned getNumContainedTypes() const
Return the number of types in the derived type.
Definition: Type.h:390
bool isX86_AMXTy() const
Return true if this is X86 AMX.
Definition: Type.h:200
Type * getContainedType(unsigned i) const
This method is used to implement the type iterator (defined at the end of the file).
Definition: Type.h:384
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
void setOperand(unsigned i, Value *Val)
Definition: User.h:233
Value * getOperand(unsigned i) const
Definition: User.h:228
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
use_iterator use_begin()
Definition: Value.h:360
bool use_empty() const
Definition: Value.h:344
const ParentTy * getParent() const
Definition: ilist_node.h:32
self_iterator getIterator()
Definition: ilist_node.h:132
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ Bitcast
Perform the operation on a different, but equivalently sized type.
@ SC
CHAIN = SC CHAIN, Imm128 - System call.
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
NodeAddr< FuncNode * > Func
Definition: RDFGraph.h:393
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void salvageDebugInfo(const MachineRegisterInfo &MRI, MachineInstr &MI)
Assuming the instruction MI is going to be deleted, attempt to salvage debug users of MI by writing t...
Definition: Utils.cpp:1668
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:657
iterator_range< po_iterator< T > > post_order(const T &G)
bool isInstructionTriviallyDead(Instruction *I, const TargetLibraryInfo *TLI=nullptr)
Return true if the result produced by the instruction is not used, and the instruction will return.
Definition: Local.cpp:406
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool salvageKnowledge(Instruction *I, AssumptionCache *AC=nullptr, DominatorTree *DT=nullptr)
Calls BuildAssumeFromInst and if the resulting llvm.assume is valid insert if before I.
FunctionPass * createX86LowerAMXTypePass()
The pass transforms load/store <256 x i32> to AMX load/store intrinsics or split the data to two <128...