1 //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform <256 x i32> load/store
10 /// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
11 /// provides simple operations on x86_amx. Basic elementwise operations are
12 /// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
13 /// and only AMX intrinsics can operate on the type, we need to transform
14 /// load/store <256 x i32> instructions to AMX load/store. If the bitcast
15 /// cannot be combined with the load/store, we transform the bitcast to an amx
16 /// load/store plus a <256 x i32> store/load.
17 ///
18 /// If the front end does not use O0 but the mid/back end uses O0 (e.g. "clang
19 /// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is
20 /// volatile, because that is necessary for AMX fast register allocation. (In
21 /// fast register allocation, registers are allocated before spill/reload, so
22 /// there is no additional register for amx to identify the step in spill.)
23 /// volatileTileData() will handle this case.
24 /// e.g.
25 /// ----------------------------------------------------------
26 /// | def %td = ... |
27 /// | ... |
28 /// | "use %td" |
29 /// ----------------------------------------------------------
30 /// will transfer to -->
31 /// ----------------------------------------------------------
32 /// | def %td = ... |
33 /// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34 /// | ... |
35 /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36 /// | "use %td2" |
37 /// ----------------------------------------------------------
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 #include "X86.h"
42 #include "llvm/ADT/PostOrderIterator.h"
43 #include "llvm/ADT/SetVector.h"
44 #include "llvm/ADT/SmallSet.h"
45 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
46 #include "llvm/Analysis/TargetLibraryInfo.h"
47 #include "llvm/Analysis/TargetTransformInfo.h"
48 #include "llvm/CodeGen/Passes.h"
49 #include "llvm/CodeGen/TargetPassConfig.h"
50 #include "llvm/CodeGen/ValueTypes.h"
51 #include "llvm/IR/DataLayout.h"
52 #include "llvm/IR/Function.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/Instructions.h"
55 #include "llvm/IR/IntrinsicInst.h"
56 #include "llvm/IR/IntrinsicsX86.h"
57 #include "llvm/IR/PatternMatch.h"
58 #include "llvm/InitializePasses.h"
59 #include "llvm/Pass.h"
60 #include "llvm/Target/TargetMachine.h"
61 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62 #include "llvm/Transforms/Utils/Local.h"
63 
64 #include <map>
65 
66 using namespace llvm;
67 using namespace PatternMatch;
68 
69 #define DEBUG_TYPE "lower-amx-type"
70 
71 static bool isAMXCast(Instruction *II) {
72  return match(II,
73  m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74  match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75 }
76 
77 static bool isAMXIntrinsic(Value *I) {
78  auto *II = dyn_cast<IntrinsicInst>(I);
79  if (!II)
80  return false;
81  if (isAMXCast(II))
82  return false;
83  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
84  // the intrinsic must be an x86 AMX intrinsic.
85  if (II->getType()->isX86_AMXTy())
86  return true;
87  for (Value *V : II->args()) {
88  if (V->getType()->isX86_AMXTy())
89  return true;
90  }
91 
92  return false;
93 }
94 
95 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
96  Type *Ty) {
97  Function &F = *BB->getParent();
98  Module *M = BB->getModule();
99  const DataLayout &DL = M->getDataLayout();
100 
101  LLVMContext &Ctx = Builder.getContext();
102  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
103  unsigned AllocaAS = DL.getAllocaAddrSpace();
104  AllocaInst *AllocaRes =
105  new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
106  AllocaRes->setAlignment(AllocaAlignment);
107  return AllocaRes;
108 }
109 
110 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
111  for (Instruction &I : F.getEntryBlock())
112  if (!isa<AllocaInst>(&I))
113  return &I;
114  llvm_unreachable("No terminator in the entry block!");
115 }
116 
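// getShape - Return the {row, col} shape of the tile that appears as operand
// OpNo of the AMX intrinsic II. Row and Col are the i16 shape operands of II
// itself (or values derived from them), so they can be reused directly when
// building new tileload/tilestore intrinsics for that operand.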
117 static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
118  IRBuilder<> Builder(II);
119  Value *Row = nullptr, *Col = nullptr;
120  switch (II->getIntrinsicID()) {
121  default:
122  llvm_unreachable("Expect amx intrinsics");
123  case Intrinsic::x86_tileloadd64_internal:
124  case Intrinsic::x86_tileloaddt164_internal:
125  case Intrinsic::x86_tilestored64_internal: {
126  Row = II->getArgOperand(0);
127  Col = II->getArgOperand(1);
128  break;
129  }
130  // a * b + c
131  // The shape depends on which operand.
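  // For dst = a * b + c (operands: 0=m, 1=n, 2=k, 3=c/dst, 4=a, 5=b):
  //   operand 3 has shape {m, n}, operand 4 has shape {m, k}, and operand 5
  //   has shape {k/4, n}, since tile columns are counted in bytes and b packs
  //   four bytes into each 32-bit element.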
132  case Intrinsic::x86_tdpbssd_internal:
133  case Intrinsic::x86_tdpbsud_internal:
134  case Intrinsic::x86_tdpbusd_internal:
135  case Intrinsic::x86_tdpbuud_internal:
136  case Intrinsic::x86_tdpbf16ps_internal: {
137  switch (OpNo) {
138  case 3:
139  Row = II->getArgOperand(0);
140  Col = II->getArgOperand(1);
141  break;
142  case 4:
143  Row = II->getArgOperand(0);
144  Col = II->getArgOperand(2);
145  break;
146  case 5:
147  if (isa<ConstantInt>(II->getArgOperand(2)))
148  Row = Builder.getInt16(
149  (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
150  else if (isa<Instruction>(II->getArgOperand(2))) {
151  // When it is not a const value and it is not a function argument, we
152  // create Row after the definition of II->getOperand(2) instead of
153  // before II. For example, II is %118 and we try to get the shape for %117:
154  // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
155  // i32> %115).
156  // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
157  // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
158  // %117).
159  // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
160  // definition would come after its user (the new tileload for %117).
161  // So, the best choice is to create %row right after the definition of
162  // %106.
163  Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
164  Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
165  cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
166  } else {
167  // When it is not a const value and it is a function argument, we create
168  // Row at the entry bb.
169  IRBuilder<> NewBuilder(
170  getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
171  Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
172  }
173  Col = II->getArgOperand(1);
174  break;
175  }
176  break;
177  }
178  }
179 
180  return std::make_pair(Row, Col);
181 }
182 
183 static std::pair<Value *, Value *> getShape(PHINode *Phi) {
184  Use &U = *(Phi->use_begin());
185  unsigned OpNo = U.getOperandNo();
186  User *V = U.getUser();
187  // TODO: We don't traverse all users. To keep the algorithm simple, here we
188  // just traverse the first user. If we can find the shape, then return it;
189  // otherwise return nullptr and the optimization for undef/zero will be
190  // abandoned.
191  while (V) {
192  if (isAMXCast(dyn_cast<Instruction>(V))) {
193  if (V->use_empty())
194  break;
195  Use &U = *(V->use_begin());
196  OpNo = U.getOperandNo();
197  V = U.getUser();
198  } else if (isAMXIntrinsic(V)) {
199  return getShape(cast<IntrinsicInst>(V), OpNo);
200  } else if (isa<PHINode>(V)) {
201  if (V->use_empty())
202  break;
203  Use &U = *(V->use_begin());
204  V = U.getUser();
205  } else {
206  break;
207  }
208  }
209 
210  return std::make_pair(nullptr, nullptr);
211 }
212 
213 namespace {
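// Lower bitcasts between <256 x i32> and x86_amx into AMX tile load/store
// intrinsics, so that x86_amx values are only produced and consumed by AMX
// intrinsics.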
214 class X86LowerAMXType {
215  Function &Func;
216 
217  // In AMX intrinsics we let Shape = {Row, Col}, but the
218  // RealCol = Col / ElementSize. We may use the RealCol
219  // as a new Row for other newly created AMX intrinsics.
220  std::map<Value *, Value *> Col2Row;
221 
222 public:
223  X86LowerAMXType(Function &F) : Func(F) {}
224  bool visit();
225  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
226  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
227  bool transformBitcast(BitCastInst *Bitcast);
228 };
229 
230 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
231 // %2 = bitcast <256 x i32> %src to x86_amx
232 // -->
233 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
234 // i8* %addr, i64 %stride64)
235 void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
236  Value *Row = nullptr, *Col = nullptr;
237  Use &U = *(Bitcast->use_begin());
238  unsigned OpNo = U.getOperandNo();
239  auto *II = cast<IntrinsicInst>(U.getUser());
240  std::tie(Row, Col) = getShape(II, OpNo);
241  IRBuilder<> Builder(LD);
242  // Use the maximum column as stride.
243  Value *Stride = Builder.getInt64(64);
244  Value *I8Ptr =
245  Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
246  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
247 
248  Value *NewInst =
249  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
250  Bitcast->replaceAllUsesWith(NewInst);
251 }
252 
253 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
254 // %stride);
255 // %13 = bitcast x86_amx %src to <256 x i32>
256 // store <256 x i32> %13, <256 x i32>* %addr, align 64
257 // -->
258 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
259 // %stride64, %13)
260 void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
261 
262  Value *Tile = Bitcast->getOperand(0);
263  auto *II = cast<IntrinsicInst>(Tile);
264  // Tile is output from AMX intrinsic. The first operand of the
265  // intrinsic is row, the second operand of the intrinsic is column.
266  Value *Row = II->getOperand(0);
267  Value *Col = II->getOperand(1);
268  IRBuilder<> Builder(ST);
269  // Use the maximum column as stride. It must be the same as the load
270  // stride.
271  Value *Stride = Builder.getInt64(64);
272  Value *I8Ptr =
273  Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
274  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
275  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
276  if (Bitcast->hasOneUse())
277  return;
278  // %13 = bitcast x86_amx %src to <256 x i32>
279  // store <256 x i32> %13, <256 x i32>* %addr, align 64
280  // %add = <256 x i32> %13, <256 x i32> %src2
281  // -->
282  // %13 = bitcast x86_amx %src to <256 x i32>
283  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
284  // %stride64, %13)
285  // %14 = load <256 x i32>, %addr
286  // %add = <256 x i32> %14, <256 x i32> %src2
287  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
288  Bitcast->replaceAllUsesWith(Vec);
289 }
290 
291 // transform bitcast to <store, load> instructions.
292 bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
293  IRBuilder<> Builder(Bitcast);
294  AllocaInst *AllocaAddr;
295  Value *I8Ptr, *Stride;
296  auto *Src = Bitcast->getOperand(0);
297 
298  auto Prepare = [&](Type *MemTy) {
299  AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
300  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
301  Stride = Builder.getInt64(64);
302  };
303 
304  if (Bitcast->getType()->isX86_AMXTy()) {
305  // %2 = bitcast <256 x i32> %src to x86_amx
306  // -->
307  // %addr = alloca <256 x i32>, align 64
308  // store <256 x i32> %src, <256 x i32>* %addr, align 64
309  // %addr2 = bitcast <256 x i32>* to i8*
310  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
311  // i8* %addr2,
312  // i64 64)
313  Use &U = *(Bitcast->use_begin());
314  unsigned OpNo = U.getOperandNo();
315  auto *II = dyn_cast<IntrinsicInst>(U.getUser());
316  if (!II)
317  return false; // May be bitcast from x86amx to <256 x i32>.
318  Prepare(Bitcast->getOperand(0)->getType());
319  Builder.CreateStore(Src, AllocaAddr);
320  // TODO: we can pick a constant operand for the shape.
321  Value *Row = nullptr, *Col = nullptr;
322  std::tie(Row, Col) = getShape(II, OpNo);
323  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
324  Value *NewInst = Builder.CreateIntrinsic(
325  Intrinsic::x86_tileloadd64_internal, None, Args);
326  Bitcast->replaceAllUsesWith(NewInst);
327  } else {
328  // %2 = bitcast x86_amx %src to <256 x i32>
329  // -->
330  // %addr = alloca <256 x i32>, align 64
331  // %addr2 = bitcast <256 x i32>* to i8*
332  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
333  // i8* %addr2, i64 %stride)
334  // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
335  auto *II = dyn_cast<IntrinsicInst>(Src);
336  if (!II)
337  return false; // May be bitcast from <256 x i32> to x86amx.
338  Prepare(Bitcast->getType());
339  Value *Row = II->getOperand(0);
340  Value *Col = II->getOperand(1);
341  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
342  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
343  Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
344  Bitcast->replaceAllUsesWith(NewInst);
345  }
346 
347  return true;
348 }
349 
350 bool X86LowerAMXType::visit() {
351  SmallVector<Instruction *, 8> DeadInsts;
352  Col2Row.clear();
353 
354  for (BasicBlock *BB : post_order(&Func)) {
355  for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
356  auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
357  if (!Bitcast)
358  continue;
359 
360  Value *Src = Bitcast->getOperand(0);
361  if (Bitcast->getType()->isX86_AMXTy()) {
362  if (Bitcast->user_empty()) {
363  DeadInsts.push_back(Bitcast);
364  continue;
365  }
366  LoadInst *LD = dyn_cast<LoadInst>(Src);
367  if (!LD) {
368  if (transformBitcast(Bitcast))
369  DeadInsts.push_back(Bitcast);
370  continue;
371  }
372  // If the load has multiple users, duplicate the vector load.
373  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
374  // %2 = bitcast <256 x i32> %src to x86_amx
375  // %add = add <256 x i32> %src, <256 x i32> %src2
376  // -->
377  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
378  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
379  // i8* %addr, i64 %stride64)
380  // %add = add <256 x i32> %src, <256 x i32> %src2
381 
382  // If load has one user, the load will be eliminated in DAG ISel.
383  // %src = load <256 x i32>, <256 x i32>* %addr, align 64
384  // %2 = bitcast <256 x i32> %src to x86_amx
385  // -->
386  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
387  // i8* %addr, i64 %stride64)
388  combineLoadBitcast(LD, Bitcast);
389  DeadInsts.push_back(Bitcast);
390  if (LD->hasOneUse())
391  DeadInsts.push_back(LD);
392  } else if (Src->getType()->isX86_AMXTy()) {
393  if (Bitcast->user_empty()) {
394  DeadInsts.push_back(Bitcast);
395  continue;
396  }
397  StoreInst *ST = nullptr;
398  for (Use &U : Bitcast->uses()) {
399  ST = dyn_cast<StoreInst>(U.getUser());
400  if (ST)
401  break;
402  }
403  if (!ST) {
404  if (transformBitcast(Bitcast))
405  DeadInsts.push_back(Bitcast);
406  continue;
407  }
408  // If bitcast (%13) has one use, combine bitcast and store to amx store.
409  // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
410  // %stride);
411  // %13 = bitcast x86_amx %src to <256 x i32>
412  // store <256 x i32> %13, <256 x i32>* %addr, align 64
413  // -->
414  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
415  // %stride64, %13)
416  //
417  // If bitcast (%13) has multi-use, transform as below.
418  // %13 = bitcast x86_amx %src to <256 x i32>
419  // store <256 x i32> %13, <256 x i32>* %addr, align 64
420  // %add = <256 x i32> %13, <256 x i32> %src2
421  // -->
422  // %13 = bitcast x86_amx %src to <256 x i32>
423  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
424  // %stride64, %13)
425  // %14 = load <256 x i32>, %addr
426  // %add = <256 x i32> %14, <256 x i32> %src2
427  //
428  combineBitcastStore(Bitcast, ST);
429  // Delete user first.
430  DeadInsts.push_back(ST);
431  DeadInsts.push_back(Bitcast);
432  }
433  }
434  }
435 
436  bool C = !DeadInsts.empty();
437 
438  for (auto *Inst : DeadInsts)
439  Inst->eraseFromParent();
440 
441  return C;
442 }
443 } // anonymous namespace
444 
445 static Value *getAllocaPos(BasicBlock *BB) {
446  Module *M = BB->getModule();
447  Function *F = BB->getParent();
448  IRBuilder<> Builder(&F->getEntryBlock().front());
449  const DataLayout &DL = M->getDataLayout();
450  unsigned AllocaAS = DL.getAllocaAddrSpace();
451  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
452  AllocaInst *AllocaRes =
453  new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
454  BasicBlock::iterator Iter = AllocaRes->getIterator();
455  ++Iter;
456  Builder.SetInsertPoint(&*Iter);
457  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
458  return I8Ptr;
459 }
460 
461 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
462  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
463  auto *II = cast<IntrinsicInst>(TileDef);
464  assert(II && "Not tile intrinsic!");
465  Value *Row = II->getOperand(0);
466  Value *Col = II->getOperand(1);
467 
468  BasicBlock *BB = TileDef->getParent();
469  BasicBlock::iterator Iter = TileDef->getIterator();
470  IRBuilder<> Builder(BB, ++Iter);
471  Value *Stride = Builder.getInt64(64);
472  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
473 
474  Instruction *TileStore =
475  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
476  return TileStore;
477 }
478 
479 static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
480  Value *V = U.get();
481  assert(V->getType()->isX86_AMXTy() && "Not define tile!");
482 
483  // Get tile shape.
484  IntrinsicInst *II = nullptr;
485  if (IsPHI) {
486  Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
487  II = cast<IntrinsicInst>(PhiOp);
488  } else {
489  II = cast<IntrinsicInst>(V);
490  }
491  Value *Row = II->getOperand(0);
492  Value *Col = II->getOperand(1);
493 
494  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
495  IRBuilder<> Builder(UserI);
496  Value *Stride = Builder.getInt64(64);
497  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
498 
499  Value *TileLoad =
500  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
501  UserI->replaceUsesOfWith(V, TileLoad);
502 }
503 
504 static bool isIncomingOfPHI(Instruction *I) {
505  for (Use &U : I->uses()) {
506  User *V = U.getUser();
507  if (isa<PHINode>(V))
508  return true;
509  }
510  return false;
511 }
512 
513 // Let all AMX tile data become volatile data, shortening the live range
514 // of each tile register before fast register allocation.
515 namespace {
516 class X86VolatileTileData {
517  Function &F;
518 
519 public:
520  X86VolatileTileData(Function &Func) : F(Func) {}
521  Value *updatePhiIncomings(BasicBlock *BB,
522  SmallVector<Instruction *, 2> &Incomings);
523  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
524  bool volatileTileData();
525  void volatileTilePHI(PHINode *Inst);
526  void volatileTileNonPHI(Instruction *I);
527 };
528 
529 Value *X86VolatileTileData::updatePhiIncomings(
530  BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
531  Value *I8Ptr = getAllocaPos(BB);
532 
533  for (auto *I : Incomings) {
534  User *Store = createTileStore(I, I8Ptr);
535 
536  // All its uses (except phi) should load from stored mem.
537  for (Use &U : I->uses()) {
538  User *V = U.getUser();
539  if (isa<PHINode>(V) || V == Store)
540  continue;
541  replaceWithTileLoad(U, I8Ptr);
542  }
543  }
544  return I8Ptr;
545 }
546 
547 void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
548  Value *StorePtr) {
549  for (Use &U : PHI->uses())
550  replaceWithTileLoad(U, StorePtr, true);
551  PHI->eraseFromParent();
552 }
553 
554 // Similar to volatileTileNonPHI, this function only handles PHI nodes
555 // and their related AMX intrinsics.
556 // 1) The PHI def should be changed to a tileload.
557 // 2) The PHI incoming values should be tilestored right after their defs.
558 // 3) The memory used by these tileloads and tilestores should be the same.
559 // e.g.
560 // ------------------------------------------------------
561 // bb_dom:
562 // ...
563 // br i1 %bool.cond, label %if.else, label %if.then
564 //
565 // if.then:
566 // def %t0 = ...
567 // ...
568 // use %t0
569 // ...
570 // br label %if.end
571 //
572 // if.else:
573 // def %t1 = ...
574 // br label %if.end
575 //
576 // if.end:
577 // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
578 // ...
579 // use %td
580 // ------------------------------------------------------
581 // -->
582 // ------------------------------------------------------
583 // bb_entry:
584 // %mem = alloca <256 x i32>, align 1024 *
585 // ...
586 // bb_dom:
587 // ...
588 // br i1 %bool.cond, label %if.else, label %if.then
589 //
590 // if.then:
591 // def %t0 = ...
592 // call void @llvm.x86.tilestored64.internal(mem, %t0) *
593 // ...
594 // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
595 // use %t0` *
596 // ...
597 // br label %if.end
598 //
599 // if.else:
600 // def %t1 = ...
601 // call void @llvm.x86.tilestored64.internal(mem, %t1) *
602 // br label %if.end
603 //
604 // if.end:
605 // ...
606 // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
607 // use %td
608 // ------------------------------------------------------
609 void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
610  BasicBlock *BB = PHI->getParent();
611  SmallVector<Instruction *, 2> Incomings;
612 
613  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
614  Value *Op = PHI->getIncomingValue(I);
615  Instruction *Inst = dyn_cast<Instruction>(Op);
616  assert(Inst && "We shouldn't fold AMX instruction!");
617  Incomings.push_back(Inst);
618  }
619 
620  Value *StorePtr = updatePhiIncomings(BB, Incomings);
621  replacePhiDefWithLoad(PHI, StorePtr);
622 }
623 
624 // Store the defined tile and load it before use.
625 // None of its users is a PHI node.
626 // e.g.
627 // ------------------------------------------------------
628 // def %td = ...
629 // ...
630 // "use %td"
631 // ------------------------------------------------------
632 // -->
633 // ------------------------------------------------------
634 // def %td = ...
635 // call void @llvm.x86.tilestored64.internal(mem, %td)
636 // ...
637 // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
638 // "use %td2"
639 // ------------------------------------------------------
640 void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
641  BasicBlock *BB = I->getParent();
642  Value *I8Ptr = getAllocaPos(BB);
643  User *Store = createTileStore(I, I8Ptr);
644 
645  // All its uses should load from stored mem.
646  for (Use &U : I->uses()) {
647  User *V = U.getUser();
648  assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
649  if (V != Store)
650  replaceWithTileLoad(U, I8Ptr);
651  }
652 }
653 
654 // Volatile Tile Model:
655 // 1) All the uses of tile data come from a tileload just in time.
656 // 2) All the defs of tile data are tilestored into memory immediately.
657 // For example:
658 // --------------------------------------------------------------------------
659 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
660 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
661 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
662 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
663 // call void @llvm.x86.tilestored64.internal(... td) area
664 // --------------------------------------------------------------------------
665 // 3) No terminator, call or other amx instructions in the key amx area.
666 bool X86VolatileTileData::volatileTileData() {
667  bool Changed = false;
668  for (BasicBlock &BB : F) {
669  SmallVector<Instruction *, 2> PHIInsts;
670  SmallVector<Instruction *, 8> AMXDefInsts;
671 
672  for (Instruction &I : BB) {
673  if (!I.getType()->isX86_AMXTy())
674  continue;
675  if (isa<PHINode>(&I))
676  PHIInsts.push_back(&I);
677  else
678  AMXDefInsts.push_back(&I);
679  }
680 
681  // First we "volatile" the non-phi related amx intrinsics.
682  for (Instruction *I : AMXDefInsts) {
683  if (isIncomingOfPHI(I))
684  continue;
685  volatileTileNonPHI(I);
686  Changed = true;
687  }
688 
689  for (Instruction *I : PHIInsts) {
690  volatileTilePHI(dyn_cast<PHINode>(I));
691  Changed = true;
692  }
693  }
694  return Changed;
695 }
696 
697 } // anonymous namespace
698 
699 namespace {
700 
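// Combine and lower the x86_cast_vector_to_tile / x86_cast_tile_to_vector
// intrinsics: fold them with adjacent loads, stores and PHI nodes where
// possible, and lower any remaining casts through a stack slot.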
701 class X86LowerAMXCast {
702  Function &Func;
703  std::unique_ptr<DominatorTree> DT;
704 
705 public:
706  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
707  void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
708  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
709  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
710  bool combineAMXcast(TargetLibraryInfo *TLI);
711  bool transformAMXCast(IntrinsicInst *AMXCast);
712  bool transformAllAMXCast();
713  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
714  SmallSetVector<Instruction *, 16> &DeadInst);
715 };
716 
717 static bool DCEInstruction(Instruction *I,
718  SmallSetVector<Instruction *, 16> &WorkList,
719  const TargetLibraryInfo *TLI) {
720  if (isInstructionTriviallyDead(I, TLI)) {
721  salvageDebugInfo(*I);
722  salvageKnowledge(I);
723 
724  // Null out all of the instruction's operands to see if any operand becomes
725  // dead as we go.
726  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
727  Value *OpV = I->getOperand(i);
728  I->setOperand(i, nullptr);
729 
730  if (!OpV->use_empty() || I == OpV)
731  continue;
732 
733  // If the operand is an instruction that became dead as we nulled out the
734  // operand, and if it is 'trivially' dead, delete it in a future loop
735  // iteration.
736  if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
737  if (isInstructionTriviallyDead(OpI, TLI)) {
738  WorkList.insert(OpI);
739  }
740  }
741  }
742  I->eraseFromParent();
743  return true;
744  }
745  return false;
746 }
747 
748 /// This function handles the following case:
749 ///
750 /// A -> B amxcast
751 /// PHI
752 /// B -> A amxcast
753 ///
754 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
755 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
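///
/// For example (illustrative, with A = x86_amx and B = <256 x i32>):
///   %b1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %a1)
///   %b2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %a2)
///   %vphi = phi <256 x i32> [ %b1, %bb1 ], [ %b2, %bb2 ]
///   %a3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %vphi)
/// becomes a new phi of tile type:
///   %tphi = phi x86_amx [ %a1, %bb1 ], [ %a2, %bb2 ]
/// and the uses of %a3 are replaced with %tphi.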
756 bool X86LowerAMXCast::optimizeAMXCastFromPhi(
757  IntrinsicInst *CI, PHINode *PN,
758  SmallSetVector<Instruction *, 16> &DeadInst) {
759  IRBuilder<> Builder(CI);
760  Value *Src = CI->getOperand(0);
761  Type *SrcTy = Src->getType(); // Type B
762  Type *DestTy = CI->getType(); // Type A
763 
764  SmallVector<PHINode *, 4> PhiWorklist;
765  SmallSetVector<PHINode *, 4> OldPhiNodes;
766 
767  // Find all of the A->B casts and PHI nodes.
768  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
769  // OldPhiNodes is used to track all known PHI nodes; before adding a new
770  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
771  PhiWorklist.push_back(PN);
772  OldPhiNodes.insert(PN);
773  while (!PhiWorklist.empty()) {
774  auto *OldPN = PhiWorklist.pop_back_val();
775  for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
776  Value *IncValue = OldPN->getIncomingValue(I);
777  // TODO: Currently, we ignore cases where it is a const. In the future, we
778  // might support const.
779  if (isa<Constant>(IncValue)) {
780  auto *IncConst = dyn_cast<Constant>(IncValue);
781  if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
782  return false;
783  Value *Row = nullptr, *Col = nullptr;
784  std::tie(Row, Col) = getShape(OldPN);
785  // TODO: If it is not constant, the Row and Col must dominate the tilezero
786  // that we are going to create.
787  if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
788  return false;
789  // Create tilezero at the end of incoming block.
790  auto *Block = OldPN->getIncomingBlock(I);
791  BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
792  Instruction *NewInst = Builder.CreateIntrinsic(
793  Intrinsic::x86_tilezero_internal, None, {Row, Col});
794  NewInst->moveBefore(&*Iter);
795  NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
796  {IncValue->getType()}, {NewInst});
797  NewInst->moveBefore(&*Iter);
798  // Replace IncValue with the new value.
799  OldPN->setIncomingValue(I, NewInst);
800  IncValue = NewInst;
801  }
802 
803  if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
804  if (OldPhiNodes.insert(PNode))
805  PhiWorklist.push_back(PNode);
806  continue;
807  }
808  Instruction *ACI = dyn_cast<Instruction>(IncValue);
809  if (ACI && isAMXCast(ACI)) {
810  // Verify it's a A->B cast.
811  Type *TyA = ACI->getOperand(0)->getType();
812  Type *TyB = ACI->getType();
813  if (TyA != DestTy || TyB != SrcTy)
814  return false;
815  continue;
816  }
817  return false;
818  }
819  }
820 
821  // Check that each user of each old PHI node is something that we can
822  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
823  for (auto *OldPN : OldPhiNodes) {
824  for (User *V : OldPN->users()) {
825  Instruction *ACI = dyn_cast<Instruction>(V);
826  if (ACI && isAMXCast(ACI)) {
827  // Verify it's a B->A cast.
828  Type *TyB = ACI->getOperand(0)->getType();
829  Type *TyA = ACI->getType();
830  if (TyA != DestTy || TyB != SrcTy)
831  return false;
832  } else if (auto *PHI = dyn_cast<PHINode>(V)) {
833  // As long as the user is another old PHI node, then even if we don't
834  // rewrite it, the PHI web we're considering won't have any users
835  // outside itself, so it'll be dead.
836  // example:
837  // bb.0:
838  // %0 = amxcast ...
839  // bb.1:
840  // %1 = amxcast ...
841  // bb.2:
842  // %goodphi = phi %0, %1
843  // %3 = amxcast %goodphi
844  // bb.3:
845  // %goodphi2 = phi %0, %goodphi
846  // %4 = amxcast %goodphi2
847  // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
848  // outside the phi-web, so the combination stops. When
849  // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
850  // will be done.
851  if (OldPhiNodes.count(PHI) == 0)
852  return false;
853  } else
854  return false;
855  }
856  }
857 
858  // For each old PHI node, create a corresponding new PHI node with type A.
859  SmallDenseMap<PHINode *, PHINode *, 16> NewPNodes;
860  for (auto *OldPN : OldPhiNodes) {
861  Builder.SetInsertPoint(OldPN);
862  PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
863  NewPNodes[OldPN] = NewPN;
864  }
865 
866  // Fill in the operands of new PHI nodes.
867  for (auto *OldPN : OldPhiNodes) {
868  PHINode *NewPN = NewPNodes[OldPN];
869  for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
870  Value *V = OldPN->getOperand(j);
871  Value *NewV = nullptr;
872  Instruction *ACI = dyn_cast<Instruction>(V);
873  // There should not be an AMXcast from a const.
874  if (ACI && isAMXCast(ACI))
875  NewV = ACI->getOperand(0);
876  else if (auto *PrevPN = dyn_cast<PHINode>(V))
877  NewV = NewPNodes[PrevPN];
878  assert(NewV);
879  NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
880  }
881  }
882 
883  // Traverse all accumulated PHI nodes and process their users,
884  // which are Stores and BitCasts. Without this processing,
885  // NewPHI nodes could be replicated and could lead to extra
886  // moves generated after DeSSA.
887  // If there is a store with type B, change it to type A.
888 
889  // Replace users of BitCast B->A with NewPHI. These will help
890  // later to get rid of a closure formed by OldPHI nodes.
891  for (auto *OldPN : OldPhiNodes) {
892  PHINode *NewPN = NewPNodes[OldPN];
893  for (User *V : make_early_inc_range(OldPN->users())) {
894  Instruction *ACI = dyn_cast<Instruction>(V);
895  if (ACI && isAMXCast(ACI)) {
896  Type *TyB = ACI->getOperand(0)->getType();
897  Type *TyA = ACI->getType();
898  assert(TyA == DestTy && TyB == SrcTy);
899  (void)TyA;
900  (void)TyB;
901  ACI->replaceAllUsesWith(NewPN);
902  DeadInst.insert(ACI);
903  } else if (auto *PHI = dyn_cast<PHINode>(V)) {
904  // We don't need to push the PHINode into DeadInst since they are operands
905  // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
906  assert(OldPhiNodes.contains(PHI));
907  (void)PHI;
908  } else
909  llvm_unreachable("all uses should be handled");
910  }
911  }
912  return true;
913 }
914 
915 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
916 // store <256 x i32> %43, <256 x i32>* %p, align 64
917 // -->
918 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
919 // i64 64, x86_amx %42)
920 void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
921  Value *Tile = Cast->getOperand(0);
922  // TODO: If it is cast intrinsic or phi node, we can propagate the
923  // shape information through def-use chain.
924  if (!isAMXIntrinsic(Tile))
925  return;
926  auto *II = cast<IntrinsicInst>(Tile);
927  // Tile is output from AMX intrinsic. The first operand of the
928  // intrinsic is row, the second operand of the intrinsic is column.
929  Value *Row = II->getOperand(0);
930  Value *Col = II->getOperand(1);
931  IRBuilder<> Builder(ST);
932  // Use the maximum column as stride. It must be the same as the load
933  // stride.
934  Value *Stride = Builder.getInt64(64);
935  Value *I8Ptr =
936  Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
937  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
938  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
939 }
940 
941 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
942 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
943 // -->
944 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
945 // i8* %p, i64 64)
946 bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
947  bool EraseLoad = true;
948  Value *Row = nullptr, *Col = nullptr;
949  Use &U = *(Cast->use_begin());
950  unsigned OpNo = U.getOperandNo();
951  auto *II = cast<IntrinsicInst>(U.getUser());
952  // TODO: If it is cast intrinsic or phi node, we can propagate the
953  // shape information through def-use chain.
954  if (!isAMXIntrinsic(II))
955  return false;
956  std::tie(Row, Col) = getShape(II, OpNo);
957  IRBuilder<> Builder(LD);
958  // Use the maximum column as stride.
959  Value *Stride = Builder.getInt64(64);
960  Value *I8Ptr;
961 
962  // To save compile time, we create the dominator tree only when it is
963  // really needed.
964  if (!DT)
965  DT.reset(new DominatorTree(Func));
966  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
967  // Store the value to the stack and reload it from the stack before the cast.
968  auto *AllocaAddr =
969  createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
970  Builder.SetInsertPoint(&*std::next(LD->getIterator()));
971  Builder.CreateStore(LD, AllocaAddr);
972 
973  Builder.SetInsertPoint(Cast);
974  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
975  EraseLoad = false;
976  } else {
977  I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
978  }
979  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
980 
981  Value *NewInst =
982  Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
983  Cast->replaceAllUsesWith(NewInst);
984 
985  return EraseLoad;
986 }
987 
988 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
989  bool Change = false;
990  for (auto *Cast : Casts) {
991  auto *II = cast<IntrinsicInst>(Cast);
992  // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
993  // store <256 x i32> %43, <256 x i32>* %p, align 64
994  // -->
995  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
996  // i64 64, x86_amx %42)
997  if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
998  SmallVector<Instruction *, 2> DeadStores;
999  for (User *U : Cast->users()) {
1000  StoreInst *Store = dyn_cast<StoreInst>(U);
1001  if (!Store)
1002  continue;
1003  combineCastStore(cast<IntrinsicInst>(Cast), Store);
1004  DeadStores.push_back(Store);
1005  Change = true;
1006  }
1007  for (auto *Store : DeadStores)
1008  Store->eraseFromParent();
1009  } else { // x86_cast_vector_to_tile
1011  auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1012  if (!Load || !Load->hasOneUse())
1013  continue;
1014  // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1015  // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1016  // -->
1017  // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1018  // i8* %p, i64 64)
1019  if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1020  // Set the operand to null so that the load instruction can be erased.
1021  Cast->setOperand(0, nullptr);
1022  Load->eraseFromParent();
1023  }
1024  }
1025  }
1026  return Change;
1027 }
1028 
1029 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1030  bool Change = false;
1031  // Collect tile cast instructions.
1032  SmallVector<Instruction *, 8> Vec2TileInsts;
1033  SmallVector<Instruction *, 8> Tile2VecInsts;
1034  SmallVector<Instruction *, 8> PhiCastWorkList;
1035  SmallSetVector<Instruction *, 16> DeadInst;
1036  for (BasicBlock &BB : Func) {
1037  for (Instruction &I : BB) {
1038  Value *Vec;
1039  if (match(&I,
1040  m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1041  Vec2TileInsts.push_back(&I);
1042  else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1043  m_Value(Vec))))
1044  Tile2VecInsts.push_back(&I);
1045  }
1046  }
1047 
1048  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1049  for (auto *Inst : Insts) {
1050  for (User *U : Inst->users()) {
1051  IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1052  if (!II || II->getIntrinsicID() != IID)
1053  continue;
1054  // T1 = vec2tile V0
1055  // V2 = tile2vec T1
1056  // V3 = OP V2
1057  // -->
1058  // T1 = vec2tile V0
1059  // V2 = tile2vec T1
1060  // V3 = OP V0
1061  II->replaceAllUsesWith(Inst->getOperand(0));
1062  Change = true;
1063  }
1064  }
1065  };
1066 
1067  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1068  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1069 
1070  SmallVector<Instruction *, 8> LiveCasts;
1071  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1072  for (auto *Inst : Insts) {
1073  if (Inst->use_empty()) {
1074  Inst->eraseFromParent();
1075  Change = true;
1076  } else {
1077  LiveCasts.push_back(Inst);
1078  }
1079  }
1080  };
1081 
1082  EraseInst(Vec2TileInsts);
1083  EraseInst(Tile2VecInsts);
1084  Change |= combineLdSt(LiveCasts);
1085  EraseInst(LiveCasts);
1086 
1087  // Handle the A->B->A cast where there is an intervening PHI node.
1088  for (BasicBlock &BB : Func) {
1089  for (Instruction &I : BB) {
1090  if (isAMXCast(&I)) {
1091  if (isa<PHINode>(I.getOperand(0)))
1092  PhiCastWorkList.push_back(&I);
1093  }
1094  }
1095  }
1096  for (auto *I : PhiCastWorkList) {
1097  // We skip the dead Amxcast.
1098  if (DeadInst.contains(I))
1099  continue;
1100  PHINode *PN = cast<PHINode>(I->getOperand(0));
1101  if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1102  DeadInst.insert(PN);
1103  Change = true;
1104  }
1105  }
1106 
1107  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
1108  // might have no uses. We do some dead code elimination for them.
1109  while (!DeadInst.empty()) {
1110  Instruction *I = DeadInst.pop_back_val();
1111  Change |= DCEInstruction(I, DeadInst, TLI);
1112  }
1113  return Change;
1114 }
1115 
1116 // There might be remaining AMXcasts after combineAMXcast, and they should be
1117 // handled elegantly.
1118 bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1119  IRBuilder<> Builder(AMXCast);
1120  AllocaInst *AllocaAddr;
1121  Value *I8Ptr, *Stride;
1122  auto *Src = AMXCast->getOperand(0);
1123 
1124  auto Prepare = [&](Type *MemTy) {
1125  AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1126  I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
1127  Stride = Builder.getInt64(64);
1128  };
1129 
1130  if (AMXCast->getType()->isX86_AMXTy()) {
1131  // %2 = amxcast <225 x i32> %src to x86_amx
1132  // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1133  // i8* %addr3, i64 60, x86_amx %2)
1134  // -->
1135  // %addr = alloca <225 x i32>, align 64
1136  // store <225 x i32> %src, <225 x i32>* %addr, align 64
1137  // %addr2 = bitcast <225 x i32>* %addr to i8*
1138  // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1139  // i8* %addr2,
1140  // i64 60)
1141  // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1142  // i8* %addr3, i64 60, x86_amx %2)
1143  if (AMXCast->use_empty()) {
1144  AMXCast->eraseFromParent();
1145  return true;
1146  }
1147  Use &U = *(AMXCast->use_begin());
1148  unsigned OpNo = U.getOperandNo();
1149  auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1150  if (!II)
1151  return false; // May be bitcast from x86amx to <256 x i32>.
1152  Prepare(AMXCast->getOperand(0)->getType());
1153  Builder.CreateStore(Src, AllocaAddr);
1154  // TODO: we can pick a constant operand for the shape.
1155  Value *Row = nullptr, *Col = nullptr;
1156  std::tie(Row, Col) = getShape(II, OpNo);
1157  std::array<Value *, 4> Args = {
1158  Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1159  Value *NewInst = Builder.CreateIntrinsic(
1160  Intrinsic::x86_tileloadd64_internal, None, Args);
1161  AMXCast->replaceAllUsesWith(NewInst);
1162  AMXCast->eraseFromParent();
1163  } else {
1164  // %2 = amxcast x86_amx %src to <225 x i32>
1165  // -->
1166  // %addr = alloca <225 x i32>, align 64
1167  // %addr2 = bitcast <225 x i32>* to i8*
1168  // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1169  // i8* %addr2, i64 %stride)
1170  // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1171  auto *II = dyn_cast<IntrinsicInst>(Src);
1172  if (!II)
1173  return false; // May be bitcast from <256 x i32> to x86amx.
1174  Prepare(AMXCast->getType());
1175  Value *Row = II->getOperand(0);
1176  Value *Col = II->getOperand(1);
1177  std::array<Value *, 5> Args = {
1178  Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1179  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
1180  Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1181  AMXCast->replaceAllUsesWith(NewInst);
1182  AMXCast->eraseFromParent();
1183  }
1184 
1185  return true;
1186 }
1187 
1188 bool X86LowerAMXCast::transformAllAMXCast() {
1189  bool Change = false;
1190  // Collect tile cast instructions.
1191  SmallVector<Instruction *, 8> WorkLists;
1192  for (BasicBlock &BB : Func) {
1193  for (Instruction &I : BB) {
1194  if (isAMXCast(&I))
1195  WorkLists.push_back(&I);
1196  }
1197  }
1198 
1199  for (auto *Inst : WorkLists) {
1200  Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1201  }
1202 
1203  return Change;
1204 }
1205 
1206 } // anonymous namespace
1207 
1208 namespace {
1209 
1210 class X86LowerAMXTypeLegacyPass : public FunctionPass {
1211 public:
1212  static char ID;
1213 
1214  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1215  initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1216  }
1217 
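  // Run the lowering in three steps: first combine or lower the AMX cast
  // intrinsics (X86LowerAMXCast), then lower plain bitcasts between
  // <256 x i32> and x86_amx (X86LowerAMXType), and finally, when the backend
  // runs at O0 for a function that is not marked optnone, make the tile data
  // volatile for AMX fast register allocation (X86VolatileTileData).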
1218  bool runOnFunction(Function &F) override {
1219  bool C = false;
1220  TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1221  TargetLibraryInfo *TLI =
1222  &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1223 
1224  X86LowerAMXCast LAC(F);
1225  C |= LAC.combineAMXcast(TLI);
1226  // There might be remaining AMXcasts after combineAMXcast, and they should
1227  // be handled elegantly.
1228  C |= LAC.transformAllAMXCast();
1229 
1230  X86LowerAMXType LAT(F);
1231  C |= LAT.visit();
1232 
1233  // Prepare for fast register allocation at O0.
1234  // TODO: It may be better to check the volatile model of AMX code, not
1235  // just by checking Attribute::OptimizeNone and CodeGenOpt::None.
1236  if (TM->getOptLevel() == CodeGenOpt::None) {
1237  // If the front end does not use O0 but the mid/back end uses O0 (e.g.
1238  // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1239  // sure the amx data is volatile; that is necessary for AMX fast
1240  // register allocation.
1241  if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1242  X86VolatileTileData VTD(F);
1243  C = VTD.volatileTileData() || C;
1244  }
1245  }
1246 
1247  return C;
1248  }
1249 
1250  void getAnalysisUsage(AnalysisUsage &AU) const override {
1251  AU.setPreservesCFG();
1252  AU.addRequired<TargetPassConfig>();
1253  AU.addRequired<TargetLibraryInfoWrapperPass>();
1254  }
1255 };
1256 
1257 } // anonymous namespace
1258 
1259 static const char PassName[] = "Lower AMX type for load/store";
1260 char X86LowerAMXTypeLegacyPass::ID = 0;
1261 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1262  false)
1263 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1264 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1265 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1266  false)
1267 
1268 FunctionPass *llvm::createX86LowerAMXTypePass() {
1269  return new X86LowerAMXTypeLegacyPass();
1270 }