LLVM  15.0.0git
X86LowerAMXType.cpp
1 //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform <256 x i32> load/store
10 /// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
11 /// provides simple operations on x86_amx. Basic elementwise operations are
12 /// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
13 /// and only AMX intrinsics can operate on the type, we need to transform
14 /// load/store <256 x i32> instructions to AMX load/store. If the bitcast
15 /// cannot be combined with a load/store, we transform the bitcast to an AMX
16 /// load/store plus a <256 x i32> store/load.
17 ///
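/// A minimal sketch of the basic rewrite (illustrative; the concrete shape
/// operands come from the AMX intrinsic that consumes the tile):
/// %src = load <256 x i32>, <256 x i32>* %addr, align 64
/// %2 = bitcast <256 x i32> %src to x86_amx
/// -->
/// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                  i8* %addr, i64 64)
///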
18 /// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
19 /// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
20 /// volatile, because that is necessary for AMX fast register allocation. (In
21 /// fast register allocation, registers are allocated before spill/reload, so
22 /// there is no additional register for AMX to identify the step in a spill.)
23 /// volatileTileData() handles this case.
24 /// e.g.
25 /// ----------------------------------------------------------
26 /// | def %td = ... |
27 /// | ... |
28 /// | "use %td" |
29 /// ----------------------------------------------------------
30 /// is transformed to -->
31 /// ----------------------------------------------------------
32 /// | def %td = ... |
33 /// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34 /// | ... |
35 /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36 /// | "use %td2" |
37 /// ----------------------------------------------------------
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 #include "X86.h"
42 #include "llvm/ADT/PostOrderIterator.h"
43 #include "llvm/ADT/SetVector.h"
44 #include "llvm/ADT/SmallSet.h"
45 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
46 #include "llvm/Analysis/TargetLibraryInfo.h"
47 #include "llvm/Analysis/TargetTransformInfo.h"
48 #include "llvm/CodeGen/Passes.h"
49 #include "llvm/CodeGen/TargetPassConfig.h"
50 #include "llvm/CodeGen/ValueTypes.h"
51 #include "llvm/IR/DataLayout.h"
52 #include "llvm/IR/Function.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/Instructions.h"
55 #include "llvm/IR/IntrinsicInst.h"
56 #include "llvm/IR/IntrinsicsX86.h"
57 #include "llvm/IR/PatternMatch.h"
58 #include "llvm/InitializePasses.h"
59 #include "llvm/Pass.h"
60 #include "llvm/Target/TargetMachine.h"
61 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62 #include "llvm/Transforms/Utils/Local.h"
63 
64 #include <map>
65 
66 using namespace llvm;
67 using namespace PatternMatch;
68 
69 #define DEBUG_TYPE "lower-amx-type"
70 
71 static bool isAMXCast(Instruction *II) {
72  return match(II,
73  m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74  match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75 }
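// For example (illustrative), isAMXCast matches both cast directions:
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
//   %v2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)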
76 
77 static bool isAMXIntrinsic(Value *I) {
78  auto *II = dyn_cast<IntrinsicInst>(I);
79  if (!II)
80  return false;
81  if (isAMXCast(II))
82  return false;
83  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
84  // the intrinsic must be an x86 AMX intrinsic.
85  if (II->getType()->isX86_AMXTy())
86  return true;
87  for (Value *V : II->args()) {
88  if (V->getType()->isX86_AMXTy())
89  return true;
90  }
91 
92  return false;
93 }
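// For example (illustrative), @llvm.x86.tdpbssd.internal qualifies because it
// returns x86_amx, and @llvm.x86.tilestored64.internal qualifies because it
// takes an x86_amx argument, while the two cast intrinsics above are
// explicitly excluded.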
94 
95 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
96  Type *Ty) {
97  Function &F = *BB->getParent();
98  Module *M = BB->getModule();
99  const DataLayout &DL = M->getDataLayout();
100 
101  LLVMContext &Ctx = Builder.getContext();
102  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
103  unsigned AllocaAS = DL.getAllocaAddrSpace();
104  AllocaInst *AllocaRes =
105  new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
106  AllocaRes->setAlignment(AllocaAlignment);
107  return AllocaRes;
108 }
109 
110 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
111  for (Instruction &I : F.getEntryBlock())
112  if (!isa<AllocaInst>(&I))
113  return &I;
114  llvm_unreachable("No terminator in the entry block!");
115 }
116 
117 static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
118  IRBuilder<> Builder(II);
119  Value *Row = nullptr, *Col = nullptr;
120  switch (II->getIntrinsicID()) {
121  default:
122  llvm_unreachable("Expect amx intrinsics");
123  case Intrinsic::x86_tileloadd64_internal:
124  case Intrinsic::x86_tileloaddt164_internal:
125  case Intrinsic::x86_tilestored64_internal: {
126  Row = II->getArgOperand(0);
127  Col = II->getArgOperand(1);
128  break;
129  }
130  // a * b + c
131  // The shape depends on which operand is requested.
132  case Intrinsic::x86_tdpbssd_internal:
133  case Intrinsic::x86_tdpbsud_internal:
134  case Intrinsic::x86_tdpbusd_internal:
135  case Intrinsic::x86_tdpbuud_internal:
136  case Intrinsic::x86_tdpbf16ps_internal: {
137  switch (OpNo) {
138  case 3:
139  Row = II->getArgOperand(0);
140  Col = II->getArgOperand(1);
141  break;
142  case 4:
143  Row = II->getArgOperand(0);
144  Col = II->getArgOperand(2);
145  break;
146  case 5:
147  if (isa<ConstantInt>(II->getArgOperand(2)))
148  Row = Builder.getInt16(
149  (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
150  else if (isa<Instruction>(II->getArgOperand(2))) {
151  // When it is not a constant value and not a function argument, we
152  // create Row after the definition of II->getOperand(2) instead of
153  // before II. For example, II is %118 and we try to get the shape for %117:
154  // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
155  // i32> %115).
156  // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
157  // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
158  // %117).
159  // If we created %row = udiv i16 %106, 4 before %118 (a.k.a. II), its
160  // definition would be after its user (the new tileload for %117).
161  // So the best choice is to create %row right after the definition of
162  // %106.
163  Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
164  Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
165  cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
166  } else {
167  // When it is not a constant value but a function argument, we create
168  // Row in the entry BB.
169  IRBuilder<> NewBuilder(
170      getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
171  Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
172  }
173  Col = II->getArgOperand(1);
174  break;
175  }
176  break;
177  }
178  }
179 
180  return std::make_pair(Row, Col);
181 }
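// For example (illustrative), given
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a,
//                                                x86_amx %b)
// the shape of operand 3 (%c) is %m x %n, the shape of operand 4 (%a) is
// %m x %k, and the shape of operand 5 (%b) is %k/4 x %n (the row is emitted
// as a udiv of %k by 4 when %k is not a constant).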
182 
183 static std::pair<Value *, Value *> getShape(PHINode *Phi) {
184  Use &U = *(Phi->use_begin());
185  unsigned OpNo = U.getOperandNo();
186  User *V = U.getUser();
187  // TODO: We don't traverse all users. To keep the algorithm simple, we just
188  // follow the first user. If we can find the shape, return it; otherwise
189  // return nullptr and the optimization for undef/zero will be
190  // abandoned.
191  while (V) {
192  if (isAMXCast(dyn_cast<Instruction>(V))) {
193  if (V->use_empty())
194  break;
195  Use &U = *(V->use_begin());
196  OpNo = U.getOperandNo();
197  V = U.getUser();
198  } else if (isAMXIntrinsic(V)) {
199  return getShape(cast<IntrinsicInst>(V), OpNo);
200  } else if (isa<PHINode>(V)) {
201  if (V->use_empty())
202  break;
203  Use &U = *(V->use_begin());
204  V = U.getUser();
205  } else {
206  break;
207  }
208  }
209 
210  return std::make_pair(nullptr, nullptr);
211 }
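// For example (illustrative), for
//   %p = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %t,
//                                                x86_amx %b)
// the phi's shape is derived from the use of %t as operand 4 of the tdpbssd,
// i.e. %m x %k.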
212 
213 namespace {
214 class X86LowerAMXType {
215  Function &Func;
216 
217  // In AMX intrinsics we let Shape = {Row, Col}, but the
218  // RealCol = Col / ElementSize. We may use the RealCol
219  // as a new Row for other newly created AMX intrinsics.
220  std::map<Value *, Value *> Col2Row;
221 
222 public:
223  X86LowerAMXType(Function &F) : Func(F) {}
224  bool visit();
225  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
226  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
227  bool transformBitcast(BitCastInst *Bitcast);
228 };
229 
230 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
231 // %2 = bitcast <256 x i32> %src to x86_amx
232 // -->
233 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
234 // i8* %addr, i64 %stride64)
235 void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
236  Value *Row = nullptr, *Col = nullptr;
237  Use &U = *(Bitcast->use_begin());
238  unsigned OpNo = U.getOperandNo();
239  auto *II = cast<IntrinsicInst>(U.getUser());
240  std::tie(Row, Col) = getShape(II, OpNo);
241  IRBuilder<> Builder(Bitcast);
242  // Use the maximum column as stride.
243  Value *Stride = Builder.getInt64(64);
244  Value *I8Ptr =
245  Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
246  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
247 
248  Value *NewInst =
249  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
250  Bitcast->replaceAllUsesWith(NewInst);
251 }
252 
253 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
254 // %stride);
255 // %13 = bitcast x86_amx %src to <256 x i32>
256 // store <256 x i32> %13, <256 x i32>* %addr, align 64
257 // -->
258 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
259 // %stride64, %13)
260 void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
261 
262  Value *Tile = Bitcast->getOperand(0);
263  auto *II = cast<IntrinsicInst>(Tile);
264  // Tile is the output of an AMX intrinsic. The first operand of the
265  // intrinsic is the row, the second operand is the column.
266  Value *Row = II->getOperand(0);
267  Value *Col = II->getOperand(1);
268  IRBuilder<> Builder(ST);
269  // Use the maximum column as stride. It must be the same as the load
270  // stride.
271  Value *Stride = Builder.getInt64(64);
272  Value *I8Ptr =
273  Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
274  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
275  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
276  if (Bitcast->hasOneUse())
277  return;
278  // %13 = bitcast x86_amx %src to <256 x i32>
279  // store <256 x i32> %13, <256 x i32>* %addr, align 64
280  // %add = <256 x i32> %13, <256 x i32> %src2
281  // -->
282  // %13 = bitcast x86_amx %src to <256 x i32>
283  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
284  // %stride64, %13)
285  // %14 = load <256 x i32>, %addr
286  // %add = <256 x i32> %14, <256 x i32> %src2
287  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
288  Bitcast->replaceAllUsesWith(Vec);
289 }
290 
291 // Transform the bitcast to <store, load> instructions.
292 bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
293  IRBuilder<> Builder(Bitcast);
294  AllocaInst *AllocaAddr;
295  Value *I8Ptr, *Stride;
296  auto *Src = Bitcast->getOperand(0);
297 
298  auto Prepare = [&](Type *MemTy) {
299  AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
300  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
301  Stride = Builder.getInt64(64);
302  };
303 
304  if (Bitcast->getType()->isX86_AMXTy()) {
305  // %2 = bitcast <256 x i32> %src to x86_amx
306  // -->
307  // %addr = alloca <256 x i32>, align 64
308  // store <256 x i32> %src, <256 x i32>* %addr, align 64
309  // %addr2 = bitcast <256 x i32>* to i8*
310  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
311  // i8* %addr2,
312  // i64 64)
313  Use &U = *(Bitcast->use_begin());
314  unsigned OpNo = U.getOperandNo();
315  auto *II = dyn_cast<IntrinsicInst>(U.getUser());
316  if (!II)
317  return false; // May be bitcast from x86amx to <256 x i32>.
318  Prepare(Bitcast->getOperand(0)->getType());
319  Builder.CreateStore(Src, AllocaAddr);
320  // TODO: we can pick a constant operand for the shape.
321  Value *Row = nullptr, *Col = nullptr;
322  std::tie(Row, Col) = getShape(II, OpNo);
323  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
324  Value *NewInst = Builder.CreateIntrinsic(
325  Intrinsic::x86_tileloadd64_internal, None, Args);
326  Bitcast->replaceAllUsesWith(NewInst);
327  } else {
328  // %2 = bitcast x86_amx %src to <256 x i32>
329  // -->
330  // %addr = alloca <256 x i32>, align 64
331  // %addr2 = bitcast <256 x i32>* to i8*
332  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
333  // i8* %addr2, i64 %stride)
334  // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
335  auto *II = dyn_cast<IntrinsicInst>(Src);
336  if (!II)
337  return false; // May be bitcast from <256 x i32> to x86amx.
338  Prepare(Bitcast->getType());
339  Value *Row = II->getOperand(0);
340  Value *Col = II->getOperand(1);
341  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
342  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
343  Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
344  Bitcast->replaceAllUsesWith(NewInst);
345  }
346 
347  return true;
348 }
349 
350 bool X86LowerAMXType::visit() {
351  SmallVector<Instruction *, 8> DeadInsts;
352  Col2Row.clear();
353 
354  for (BasicBlock *BB : post_order(&Func)) {
355  for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
356  auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
357  if (!Bitcast)
358  continue;
359 
360  Value *Src = Bitcast->getOperand(0);
361  if (Bitcast->getType()->isX86_AMXTy()) {
362  if (Bitcast->user_empty()) {
363  DeadInsts.push_back(Bitcast);
364  continue;
365  }
366  LoadInst *LD = dyn_cast<LoadInst>(Src);
367  if (!LD) {
368  if (transformBitcast(Bitcast))
369  DeadInsts.push_back(Bitcast);
370  continue;
371  }
372  // If the load has multiple users, duplicate the vector load.
373  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
374  // %2 = bitcast <256 x i32> %src to x86_amx
375  // %add = add <256 x i32> %src, <256 x i32> %src2
376  // -->
377  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
378  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
379  // i8* %addr, i64 %stride64)
380  // %add = add <256 x i32> %src, <256 x i32> %src2
381 
382  // If load has one user, the load will be eliminated in DAG ISel.
383  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
384  // %2 = bitcast <256 x i32> %src to x86_amx
385  // -->
386  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
387  // i8* %addr, i64 %stride64)
388  combineLoadBitcast(LD, Bitcast);
389  DeadInsts.push_back(Bitcast);
390  if (LD->hasOneUse())
391  DeadInsts.push_back(LD);
392  } else if (Src->getType()->isX86_AMXTy()) {
393  if (Bitcast->user_empty()) {
394  DeadInsts.push_back(Bitcast);
395  continue;
396  }
397  StoreInst *ST = nullptr;
398  for (Use &U : Bitcast->uses()) {
399  ST = dyn_cast<StoreInst>(U.getUser());
400  if (ST)
401  break;
402  }
403  if (!ST) {
404  if (transformBitcast(Bitcast))
405  DeadInsts.push_back(Bitcast);
406  continue;
407  }
408  // If the bitcast (%13) has one use, combine the bitcast and store into an AMX store.
409  // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
410  // %stride);
411  // %13 = bitcast x86_amx %src to <256 x i32>
412  // store <256 x i32> %13, <256 x i32>* %addr, align 64
413  // -->
414  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
415  // %stride64, %13)
416  //
417  // If the bitcast (%13) has multiple uses, transform as below.
418  // %13 = bitcast x86_amx %src to <256 x i32>
419  // store <256 x i32> %13, <256 x i32>* %addr, align 64
420  // %add = <256 x i32> %13, <256 x i32> %src2
421  // -->
422  // %13 = bitcast x86_amx %src to <256 x i32>
423  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
424  // %stride64, %13)
425  // %14 = load <256 x i32>, %addr
426  // %add = <256 x i32> %14, <256 x i32> %src2
427  //
428  combineBitcastStore(Bitcast, ST);
429  // Delete user first.
430  DeadInsts.push_back(ST);
431  DeadInsts.push_back(Bitcast);
432  }
433  }
434  }
435 
436  bool C = !DeadInsts.empty();
437 
438  for (auto *Inst : DeadInsts)
439  Inst->eraseFromParent();
440 
441  return C;
442 }
443 } // anonymous namespace
444 
445 static Value *getAllocaPos(BasicBlock *BB) {
446  Module *M = BB->getModule();
447  Function *F = BB->getParent();
448  IRBuilder<> Builder(&F->getEntryBlock().front());
449  const DataLayout &DL = M->getDataLayout();
450  unsigned AllocaAS = DL.getAllocaAddrSpace();
451  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
452  AllocaInst *AllocaRes =
453  new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
454  BasicBlock::iterator Iter = AllocaRes->getIterator();
455  ++Iter;
456  Builder.SetInsertPoint(&*Iter);
457  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
458  return I8Ptr;
459 }
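// For example (illustrative), after getAllocaPos the entry block begins with
//   %0 = alloca <256 x i32>
//   %1 = bitcast <256 x i32>* %0 to i8*
// and %1 is returned as the stack slot shared by the tile stores and reloads
// created below.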
460 
461 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
462  assert(TileDef->getType()->isX86_AMXTy() && "Not a tile definition!");
463  auto *II = cast<IntrinsicInst>(TileDef);
464  assert(II && "Not tile intrinsic!");
465  Value *Row = II->getOperand(0);
466  Value *Col = II->getOperand(1);
467 
468  BasicBlock *BB = TileDef->getParent();
469  BasicBlock::iterator Iter = TileDef->getIterator();
470  IRBuilder<> Builder(BB, ++Iter);
471  Value *Stride = Builder.getInt64(64);
472  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
473 
474  Instruction *TileStore =
475  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
476  return TileStore;
477 }
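// For example (illustrative), for a tile %td defined with shape %row x %col,
// createTileStore emits right after the definition:
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %ptr,
//                                             i64 64, x86_amx %td)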
478 
479 static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
480  Value *V = U.get();
481  assert(V->getType()->isX86_AMXTy() && "Not a tile definition!");
482 
483  // Get tile shape.
484  IntrinsicInst *II = nullptr;
485  if (IsPHI) {
486  Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
487  II = cast<IntrinsicInst>(PhiOp);
488  } else {
489  II = cast<IntrinsicInst>(V);
490  }
491  Value *Row = II->getOperand(0);
492  Value *Col = II->getOperand(1);
493 
494  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
495  IRBuilder<> Builder(UserI);
496  Value *Stride = Builder.getInt64(64);
497  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
498 
499  Value *TileLoad =
500  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
501  UserI->replaceUsesOfWith(V, TileLoad);
502 }
503 
504 static bool isIncomingOfPHI(Instruction *I) {
505  for (Use &U : I->uses()) {
506  User *V = U.getUser();
507  if (isa<PHINode>(V))
508  return true;
509  }
510  return false;
511 }
512 
513 // Let all AMX tile data become volatile data and shorten the live range
514 // of each tile register before fast register allocation.
515 namespace {
516 class X86VolatileTileData {
517  Function &F;
518 
519 public:
520  X86VolatileTileData(Function &Func) : F(Func) {}
521  Value *updatePhiIncomings(BasicBlock *BB,
522  SmallVector<Instruction *, 2> &Incomings);
523  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
524  bool volatileTileData();
525  void volatileTilePHI(PHINode *Inst);
526  void volatileTileNonPHI(Instruction *I);
527 };
528 
529 Value *X86VolatileTileData::updatePhiIncomings(
530     BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
531  Value *I8Ptr = getAllocaPos(BB);
532 
533  for (auto *I : Incomings) {
534  User *Store = createTileStore(I, I8Ptr);
535 
536  // All its uses (except phi) should load from stored mem.
537  for (Use &U : I->uses()) {
538  User *V = U.getUser();
539  if (isa<PHINode>(V) || V == Store)
540  continue;
541  replaceWithTileLoad(U, I8Ptr);
542  }
543  }
544  return I8Ptr;
545 }
546 
547 void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
548  Value *StorePtr) {
549  for (Use &U : PHI->uses())
550  replaceWithTileLoad(U, StorePtr, true);
551  PHI->eraseFromParent();
552 }
553 
554 // Similar to volatileTileNonPHI, this function only handles PHI nodes
555 // and their related AMX intrinsics.
556 // 1) The PHI def should be changed to a tileload.
557 // 2) The PHI incoming values should be tilestored just after their defs.
558 // 3) The memory of these tileloads and tilestores should be the same.
559 // e.g.
560 // ------------------------------------------------------
561 // bb_dom:
562 // ...
563 // br i1 %bool.cond, label %if.else, label %if.then
564 //
565 // if.then:
566 // def %t0 = ...
567 // ...
568 // use %t0
569 // ...
570 // br label %if.end
571 //
572 // if.else:
573 // def %t1 = ...
574 // br label %if.end
575 //
576 // if.end:
577 // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
578 // ...
579 // use %td
580 // ------------------------------------------------------
581 // -->
582 // ------------------------------------------------------
583 // bb_entry:
584 // %mem = alloca <256 x i32>, align 1024 *
585 // ...
586 // bb_dom:
587 // ...
588 // br i1 %bool.cond, label %if.else, label %if.then
589 //
590 // if.then:
591 // def %t0 = ...
592 // call void @llvm.x86.tilestored64.internal(mem, %t0) *
593 // ...
594 // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
595 // use %t0` *
596 // ...
597 // br label %if.end
598 //
599 // if.else:
600 // def %t1 = ...
601 // call void @llvm.x86.tilestored64.internal(mem, %t1) *
602 // br label %if.end
603 //
604 // if.end:
605 // ...
606 // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
607 // use %td
608 // ------------------------------------------------------
609 void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
610  BasicBlock *BB = PHI->getParent();
611  SmallVector<Instruction *, 2> Incomings;
612 
613  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
614  Value *Op = PHI->getIncomingValue(I);
615  Instruction *Inst = dyn_cast<Instruction>(Op);
616  assert(Inst && "We shouldn't fold AMX instructions!");
617  Incomings.push_back(Inst);
618  }
619 
620  Value *StorePtr = updatePhiIncomings(BB, Incomings);
621  replacePhiDefWithLoad(PHI, StorePtr);
622 }
623 
624 // Store the defined tile and load it before use.
625 // None of its users is a PHI node.
626 // e.g.
627 // ------------------------------------------------------
628 // def %td = ...
629 // ...
630 // "use %td"
631 // ------------------------------------------------------
632 // -->
633 // ------------------------------------------------------
634 // def %td = ...
635 // call void @llvm.x86.tilestored64.internal(mem, %td)
636 // ...
637 // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
638 // "use %td2"
639 // ------------------------------------------------------
640 void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
641  BasicBlock *BB = I->getParent();
642  Value *I8Ptr = getAllocaPos(BB);
643  User *Store = createTileStore(I, I8Ptr);
644 
645  // All its uses should load from stored mem.
646  for (Use &U : I->uses()) {
647  User *V = U.getUser();
648  assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
649  if (V != Store)
650  replaceWithTileLoad(U, I8Ptr);
651  }
652 }
653 
654 // Volatile Tile Model:
655 // 1) All uses of tile data come from a tileload just in time.
656 // 2) All defs of tile data are tilestored into memory immediately.
657 // For example:
658 // --------------------------------------------------------------------------
659 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
660 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
661 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
662 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
663 // call void @llvm.x86.tilestored64.internal(... td) area
664 // --------------------------------------------------------------------------
665 // 3) No terminator, call or other amx instructions in the key amx area.
666 bool X86VolatileTileData::volatileTileData() {
667  bool Changed = false;
668  for (BasicBlock &BB : F) {
669  SmallVector<Instruction *, 2> PHIInsts;
670  SmallVector<Instruction *, 8> AMXDefInsts;
671 
672  for (Instruction &I : BB) {
673  if (!I.getType()->isX86_AMXTy())
674  continue;
675  if (isa<PHINode>(&I))
676  PHIInsts.push_back(&I);
677  else
678  AMXDefInsts.push_back(&I);
679  }
680 
681  // First we "volatile" the non-phi related amx intrinsics.
682  for (Instruction *I : AMXDefInsts) {
683  if (isIncomingOfPHI(I))
684  continue;
685  volatileTileNonPHI(I);
686  Changed = true;
687  }
688 
689  for (Instruction *I : PHIInsts) {
690  volatileTilePHI(dyn_cast<PHINode>(I));
691  Changed = true;
692  }
693  }
694  return Changed;
695 }
696 
697 } // anonymous namespace
698 
699 namespace {
700 
701 class X86LowerAMXCast {
702  Function &Func;
703 
704 public:
705  X86LowerAMXCast(Function &F) : Func(F) {}
706  void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
707  void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
708  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
709  bool combineAMXcast(TargetLibraryInfo *TLI);
710  bool transformAMXCast(IntrinsicInst *AMXCast);
711  bool transformAllAMXCast();
712  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
713      SmallSetVector<Instruction *, 16> &DeadInst);
714 };
715 
716 static bool DCEInstruction(Instruction *I,
717     SmallSetVector<Instruction *, 16> &WorkList,
718  const TargetLibraryInfo *TLI) {
719  if (isInstructionTriviallyDead(I, TLI)) {
720  salvageDebugInfo(*I);
721  salvageKnowledge(I);
722 
723  // Null out all of the instruction's operands to see if any operand becomes
724  // dead as we go.
725  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
726  Value *OpV = I->getOperand(i);
727  I->setOperand(i, nullptr);
728 
729  if (!OpV->use_empty() || I == OpV)
730  continue;
731 
732  // If the operand is an instruction that became dead as we nulled out the
733  // operand, and if it is 'trivially' dead, delete it in a future loop
734  // iteration.
735  if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
736  if (isInstructionTriviallyDead(OpI, TLI)) {
737  WorkList.insert(OpI);
738  }
739  }
740  }
741  I->eraseFromParent();
742  return true;
743  }
744  return false;
745 }
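// For example (illustrative), when a dead
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
// is erased here, its operand %v is nulled out first; if %v itself then
// becomes trivially dead, it is queued in WorkList for a later iteration.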
746 
747 /// This function handles the following case:
748 ///
749 /// A -> B amxcast
750 /// PHI
751 /// B -> A amxcast
752 ///
753 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
754 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
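/// A minimal sketch (illustrative), with A = x86_amx and B = <256 x i32>:
///   bb0: %v0 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t0)
///   bb1: %v1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///   bb2: %p  = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
///        %t  = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
/// becomes a new phi of type A that feeds the users of %t directly:
///   bb2: %p2 = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]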
755 bool X86LowerAMXCast::optimizeAMXCastFromPhi(
756  IntrinsicInst *CI, PHINode *PN,
757  SmallSetVector<Instruction *, 16> &DeadInst) {
758  IRBuilder<> Builder(CI);
759  Value *Src = CI->getOperand(0);
760  Type *SrcTy = Src->getType(); // Type B
761  Type *DestTy = CI->getType(); // Type A
762 
763  SmallVector<PHINode *, 4> PhiWorklist;
764  SmallSetVector<PHINode *, 4> OldPhiNodes;
765 
766  // Find all of the A->B casts and PHI nodes.
767  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
768  // OldPhiNodes is used to track all known PHI nodes; before adding a new
769  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
770  PhiWorklist.push_back(PN);
771  OldPhiNodes.insert(PN);
772  while (!PhiWorklist.empty()) {
773  auto *OldPN = PhiWorklist.pop_back_val();
774  for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
775  Value *IncValue = OldPN->getIncomingValue(I);
776  // TODO: Currently we ignore cases where it is a constant. In the future, we
777  // might support constants.
778  if (isa<Constant>(IncValue)) {
779  auto *IncConst = dyn_cast<Constant>(IncValue);
780  if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
781  return false;
782  Value *Row = nullptr, *Col = nullptr;
783  std::tie(Row, Col) = getShape(OldPN);
784  // TODO: If it is not constant, the Row and Col must dominate the tilezero
785  // that we are going to create.
786  if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
787  return false;
788  // Create tilezero at the end of incoming block.
789  auto *Block = OldPN->getIncomingBlock(I);
790  BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
791  Instruction *NewInst = Builder.CreateIntrinsic(
792  Intrinsic::x86_tilezero_internal, None, {Row, Col});
793  NewInst->moveBefore(&*Iter);
794  NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
795  {IncValue->getType()}, {NewInst});
796  NewInst->moveBefore(&*Iter);
797  // Replace IncValue with the new value.
798  OldPN->setIncomingValue(I, NewInst);
799  IncValue = NewInst;
800  }
801 
802  if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
803  if (OldPhiNodes.insert(PNode))
804  PhiWorklist.push_back(PNode);
805  continue;
806  }
807  Instruction *ACI = dyn_cast<Instruction>(IncValue);
808  if (ACI && isAMXCast(ACI)) {
809  // Verify it's an A->B cast.
810  Type *TyA = ACI->getOperand(0)->getType();
811  Type *TyB = ACI->getType();
812  if (TyA != DestTy || TyB != SrcTy)
813  return false;
814  continue;
815  }
816  return false;
817  }
818  }
819 
820  // Check that each user of each old PHI node is something that we can
821  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
822  for (auto *OldPN : OldPhiNodes) {
823  for (User *V : OldPN->users()) {
824  Instruction *ACI = dyn_cast<Instruction>(V);
825  if (ACI && isAMXCast(ACI)) {
826  // Verify it's a B->A cast.
827  Type *TyB = ACI->getOperand(0)->getType();
828  Type *TyA = ACI->getType();
829  if (TyA != DestTy || TyB != SrcTy)
830  return false;
831  } else if (auto *PHI = dyn_cast<PHINode>(V)) {
832  // As long as the user is another old PHI node, even if we don't
833  // rewrite it, the PHI web we're considering won't have any users
834  // outside itself, so it'll be dead.
835  // example:
836  // bb.0:
837  // %0 = amxcast ...
838  // bb.1:
839  // %1 = amxcast ...
840  // bb.2:
841  // %goodphi = phi %0, %1
842  // %3 = amxcast %goodphi
843  // bb.3:
844  // %goodphi2 = phi %0, %goodphi
845  // %4 = amxcast %goodphi2
846  // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
847  // outside the phi-web, so the combination stops. When
848  // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
849  // will be done.
850  if (OldPhiNodes.count(PHI) == 0)
851  return false;
852  } else
853  return false;
854  }
855  }
856 
857  // For each old PHI node, create a corresponding new PHI node with type A.
858  SmallDenseMap<PHINode *, PHINode *, 4> NewPNodes;
859  for (auto *OldPN : OldPhiNodes) {
860  Builder.SetInsertPoint(OldPN);
861  PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
862  NewPNodes[OldPN] = NewPN;
863  }
864 
865  // Fill in the operands of new PHI nodes.
866  for (auto *OldPN : OldPhiNodes) {
867  PHINode *NewPN = NewPNodes[OldPN];
868  for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
869  Value *V = OldPN->getOperand(j);
870  Value *NewV = nullptr;
871  Instruction *ACI = dyn_cast<Instruction>(V);
872  // There should not be an AMXcast from a constant.
873  if (ACI && isAMXCast(ACI))
874  NewV = ACI->getOperand(0);
875  else if (auto *PrevPN = dyn_cast<PHINode>(V))
876  NewV = NewPNodes[PrevPN];
877  assert(NewV);
878  NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
879  }
880  }
881 
882  // Traverse all accumulated PHI nodes and process their users,
883  // which are Stores and BitCasts. Without this processing
884  // NewPHI nodes could be replicated and could lead to extra
885  // moves generated after DeSSA.
886  // If there is a store with type B, change it to type A.
887 
888  // Replace users of BitCast B->A with NewPHI. These will help
889  // later to get rid of a closure formed by OldPHI nodes.
890  for (auto *OldPN : OldPhiNodes) {
891  PHINode *NewPN = NewPNodes[OldPN];
892  for (User *V : make_early_inc_range(OldPN->users())) {
893  Instruction *ACI = dyn_cast<Instruction>(V);
894  if (ACI && isAMXCast(ACI)) {
895  Type *TyB = ACI->getOperand(0)->getType();
896  Type *TyA = ACI->getType();
897  assert(TyA == DestTy && TyB == SrcTy);
898  (void)TyA;
899  (void)TyB;
900  ACI->replaceAllUsesWith(NewPN);
901  DeadInst.insert(ACI);
902  } else if (auto *PHI = dyn_cast<PHINode>(V)) {
903  // We don't need to push PHINodes into DeadInst since they are operands
904  // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
905  assert(OldPhiNodes.contains(PHI));
906  (void)PHI;
907  } else
908  llvm_unreachable("all uses should be handled");
909  }
910  }
911  return true;
912 }
913 
914 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
915 // store <256 x i32> %43, <256 x i32>* %p, align 64
916 // -->
917 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
918 // i64 64, x86_amx %42)
919 void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
920  Value *Tile = Cast->getOperand(0);
921  // TODO: If it is a cast intrinsic or phi node, we can propagate the
922  // shape information through the def-use chain.
923  if (!isAMXIntrinsic(Tile))
924  return;
925  auto *II = cast<IntrinsicInst>(Tile);
926  // Tile is the output of an AMX intrinsic. The first operand of the
927  // intrinsic is the row, the second operand is the column.
928  Value *Row = II->getOperand(0);
929  Value *Col = II->getOperand(1);
930  IRBuilder<> Builder(ST);
931  // Use the maximum column as stride. It must be the same as the load
932  // stride.
933  Value *Stride = Builder.getInt64(64);
934  Value *I8Ptr =
935  Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
936  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
937  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
938 }
939 
940 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
941 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
942 // -->
943 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
944 // i8* %p, i64 64)
945 void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
946  Value *Row = nullptr, *Col = nullptr;
947  Use &U = *(Cast->use_begin());
948  unsigned OpNo = U.getOperandNo();
949  auto *II = cast<IntrinsicInst>(U.getUser());
950  // TODO: If it is a cast intrinsic or phi node, we can propagate the
951  // shape information through the def-use chain.
952  if (!isAMXIntrinsic(II))
953  return;
954  std::tie(Row, Col) = getShape(II, OpNo);
955  IRBuilder<> Builder(LD);
956  // Use the maximum column as stride.
957  Value *Stride = Builder.getInt64(64);
958  Value *I8Ptr =
959  Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
960  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
961 
962  Value *NewInst =
963  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
964  Cast->replaceAllUsesWith(NewInst);
965 }
966 
967 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
968  bool Change = false;
969  for (auto *Cast : Casts) {
970  auto *II = cast<IntrinsicInst>(Cast);
971  // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
972  // store <256 x i32> %43, <256 x i32>* %p, align 64
973  // -->
974  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
975  // i64 64, x86_amx %42)
976  if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
977  SmallVector<Instruction *, 2> DeadStores;
978  for (User *U : Cast->users()) {
979  StoreInst *Store = dyn_cast<StoreInst>(U);
980  if (!Store)
981  continue;
982  combineCastStore(cast<IntrinsicInst>(Cast), Store);
983  DeadStores.push_back(Store);
984  Change = true;
985  }
986  for (auto *Store : DeadStores)
987  Store->eraseFromParent();
988  } else { // x86_cast_vector_to_tile
990  auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
991  if (!Load || !Load->hasOneUse())
992  continue;
993  // %65 = load <256 x i32>, <256 x i32>* %p, align 64
994  // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
995  // -->
996  // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
997  // i8* %p, i64 64)
998  combineLoadCast(cast<IntrinsicInst>(Cast), Load);
999  // Set the operand to null so that the load instruction can be erased.
1000  Cast->setOperand(0, nullptr);
1001  Load->eraseFromParent();
1002  }
1003  }
1004  return Change;
1005 }
1006 
1007 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1008  bool Change = false;
1009  // Collect tile cast instructions.
1010  SmallVector<Instruction *, 8> Vec2TileInsts;
1011  SmallVector<Instruction *, 8> Tile2VecInsts;
1012  SmallVector<Instruction *, 8> PhiCastWorkList;
1013  SmallSetVector<Instruction *, 16> DeadInst;
1014  for (BasicBlock &BB : Func) {
1015  for (Instruction &I : BB) {
1016  Value *Vec;
1017  if (match(&I,
1018  m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1019  Vec2TileInsts.push_back(&I);
1020  else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1021  m_Value(Vec))))
1022  Tile2VecInsts.push_back(&I);
1023  }
1024  }
1025 
1026  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1027  for (auto *Inst : Insts) {
1028  for (User *U : Inst->users()) {
1029  IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1030  if (!II || II->getIntrinsicID() != IID)
1031  continue;
1032  // T1 = vec2tile V0
1033  // V2 = tile2vec T1
1034  // V3 = OP V2
1035  // -->
1036  // T1 = vec2tile V0
1037  // V2 = tile2vec T1
1038  // V3 = OP V0
1039  II->replaceAllUsesWith(Inst->getOperand(0));
1040  Change = true;
1041  }
1042  }
1043  };
1044 
1045  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1046  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1047 
1049  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1050  for (auto *Inst : Insts) {
1051  if (Inst->use_empty()) {
1052  Inst->eraseFromParent();
1053  Change = true;
1054  } else {
1055  LiveCasts.push_back(Inst);
1056  }
1057  }
1058  };
1059 
1060  EraseInst(Vec2TileInsts);
1061  EraseInst(Tile2VecInsts);
1062  Change |= combineLdSt(LiveCasts);
1063  EraseInst(LiveCasts);
1064 
1065  // Handle the A->B->A cast, and there is an intervening PHI node.
1066  for (BasicBlock &BB : Func) {
1067  for (Instruction &I : BB) {
1068  if (isAMXCast(&I)) {
1069  if (isa<PHINode>(I.getOperand(0)))
1070  PhiCastWorkList.push_back(&I);
1071  }
1072  }
1073  }
1074  for (auto *I : PhiCastWorkList) {
1075  // Skip dead AMXcasts.
1076  if (DeadInst.contains(I))
1077  continue;
1078  PHINode *PN = cast<PHINode>(I->getOperand(0));
1079  if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1080  DeadInst.insert(PN);
1081  Change = true;
1082  }
1083  }
1084 
1085  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
1086  // might have no uses. We do some dead code elimination for them.
1087  while (!DeadInst.empty()) {
1088  Instruction *I = DeadInst.pop_back_val();
1089  Change |= DCEInstruction(I, DeadInst, TLI);
1090  }
1091  return Change;
1092 }
1093 
1094 // There might be remaining AMXcasts after combineAMXcast, and they should be
1095 // handled elegantly.
1096 bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1097  IRBuilder<> Builder(AMXCast);
1098  AllocaInst *AllocaAddr;
1099  Value *I8Ptr, *Stride;
1100  auto *Src = AMXCast->getOperand(0);
1101 
1102  auto Prepare = [&](Type *MemTy) {
1103  AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1104  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
1105  Stride = Builder.getInt64(64);
1106  };
1107 
1108  if (AMXCast->getType()->isX86_AMXTy()) {
1109  // %2 = amxcast <225 x i32> %src to x86_amx
1110  // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1111  // i8* %addr3, i64 60, x86_amx %2)
1112  // -->
1113  // %addr = alloca <225 x i32>, align 64
1114  // store <225 x i32> %src, <225 x i32>* %addr, align 64
1115  // %addr2 = bitcast <225 x i32>* %addr to i8*
1116  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1117  // i8* %addr2,
1118  // i64 60)
1119  // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1120  // i8* %addr3, i64 60, x86_amx %2)
1121  if (AMXCast->use_empty()) {
1122  AMXCast->eraseFromParent();
1123  return true;
1124  }
1125  Use &U = *(AMXCast->use_begin());
1126  unsigned OpNo = U.getOperandNo();
1127  auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1128  if (!II)
1129  return false; // May be bitcast from x86amx to <256 x i32>.
1130  Prepare(AMXCast->getOperand(0)->getType());
1131  Builder.CreateStore(Src, AllocaAddr);
1132  // TODO: we can pick a constant operand for the shape.
1133  Value *Row = nullptr, *Col = nullptr;
1134  std::tie(Row, Col) = getShape(II, OpNo);
1135  std::array<Value *, 4> Args = {
1136  Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1137  Value *NewInst = Builder.CreateIntrinsic(
1138  Intrinsic::x86_tileloadd64_internal, None, Args);
1139  AMXCast->replaceAllUsesWith(NewInst);
1140  AMXCast->eraseFromParent();
1141  } else {
1142  // %2 = amxcast x86_amx %src to <225 x i32>
1143  // -->
1144  // %addr = alloca <225 x i32>, align 64
1145  // %addr2 = bitcast <225 x i32>* to i8*
1146  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1147  // i8* %addr2, i64 %stride)
1148  // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1149  auto *II = dyn_cast<IntrinsicInst>(Src);
1150  if (!II)
1151  return false; // May be bitcast from <256 x i32> to x86amx.
1152  Prepare(AMXCast->getType());
1153  Value *Row = II->getOperand(0);
1154  Value *Col = II->getOperand(1);
1155  std::array<Value *, 5> Args = {
1156  Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1157  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
1158  Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1159  AMXCast->replaceAllUsesWith(NewInst);
1160  AMXCast->eraseFromParent();
1161  }
1162 
1163  return true;
1164 }
1165 
1166 bool X86LowerAMXCast::transformAllAMXCast() {
1167  bool Change = false;
1168  // Collect tile cast instructions.
1169  SmallVector<Instruction *, 8> WorkLists;
1170  for (BasicBlock &BB : Func) {
1171  for (Instruction &I : BB) {
1172  if (isAMXCast(&I))
1173  WorkLists.push_back(&I);
1174  }
1175  }
1176 
1177  for (auto *Inst : WorkLists) {
1178  Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1179  }
1180 
1181  return Change;
1182 }
1183 
1184 } // anonymous namespace
1185 
1186 namespace {
1187 
1188 class X86LowerAMXTypeLegacyPass : public FunctionPass {
1189 public:
1190  static char ID;
1191 
1192  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1193  initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1194  }
1195 
1196  bool runOnFunction(Function &F) override {
1197  bool C = false;
1198  TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1199  TargetLibraryInfo *TLI =
1200  &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1201  X86LowerAMXCast LAC(F);
1202  C |= LAC.combineAMXcast(TLI);
1203  // There might be remaining AMXcasts after combineAMXcast, and they should
1204  // be handled elegantly.
1205  C |= LAC.transformAllAMXCast();
1206 
1207  X86LowerAMXType LAT(F);
1208  C |= LAT.visit();
1209 
1210  // Prepare for fast register allocation at O0.
1211  // TODO: It may be better to check the volatile model of AMX code, not just
1212  // by checking Attribute::OptimizeNone and CodeGenOpt::None.
1213  if (TM->getOptLevel() == CodeGenOpt::None) {
1214  // If the front end does not use O0 but the mid/back end uses O0 (e.g.
1215  // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1216  // sure the AMX data is volatile; that is necessary for AMX fast
1217  // register allocation.
1218  if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1219  X86VolatileTileData VTD(F);
1220  C = VTD.volatileTileData() || C;
1221  }
1222  }
1223 
1224  return C;
1225  }
1226 
1227  void getAnalysisUsage(AnalysisUsage &AU) const override {
1228  AU.setPreservesCFG();
1229  AU.addRequired<TargetPassConfig>();
1230  AU.addRequired<TargetLibraryInfoWrapperPass>();
1231  }
1232 };
1233 
1234 } // anonymous namespace
1235 
1236 static const char PassName[] = "Lower AMX type for load/store";
1237 char X86LowerAMXTypeLegacyPass::ID = 0;
1238 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1239  false)
1240 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1241 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1242 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1243  false)
1244 
1245 FunctionPass *llvm::createX86LowerAMXTypePass() {
1246  return new X86LowerAMXTypeLegacyPass();
1247 }
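// Illustrative usage of the factory above (a sketch, not taken from this
// file): since runOnFunction queries TargetPassConfig and TargetLibraryInfo,
// the pass is intended to run inside a codegen pipeline, e.g. one set up with
// addPassesToEmitFile, rather than standalone. Assuming TM, OS and M are an
// already configured TargetMachine, output stream and Module:
//   llvm::legacy::PassManager PM;
//   TM->addPassesToEmitFile(PM, OS, nullptr, llvm::CGFT_AssemblyFile);
//   PM.run(M); // the X86 IR pipeline includes createX86LowerAMXTypePass()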