//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx; basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions into AMX load/store. If the
/// bitcast cannot be combined with the load/store, we transform the bitcast
/// into an AMX load/store plus a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g.
/// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for AMX to identify the
/// step in spilling.) volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// is transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
                                           BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

namespace {
class X86LowerAMXType {
  Function &Func;
  TargetMachine *TM = nullptr;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
};

Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
  if (auto *I = dyn_cast<Instruction>(V)) {
    BasicBlock::iterator Iter = I->getIterator();
    ++Iter;
    Builder.SetInsertPoint(&*Iter);
  }
  ConstantInt *Gran = Builder.getInt16(Granularity);
  Value *RealRow = Builder.CreateUDiv(V, Gran);
  Col2Row[V] = RealRow;
  return RealRow;
}
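
// Illustrative sketch (values assumed): with Granularity == 4, a shape
// operand %col whose value is 64 is converted by emitting
//   %row = udiv i16 %col, 4
// right after the definition of %col (yielding 16), and the result is
// cached in Col2Row.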

std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
                                                      unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      // FIXME: There is a design bug for the AMX shape: the Col should be
      // Col/4 if it will be used as a Row, but the current greedy RA can't
      // handle this case well; it may fail if we generate a new shape
      // definition. So let's just do it at O0 first.
      // Row = Row / 4
      if (TM->getOptLevel() == CodeGenOpt::None)
        Row = getRowFromCol(II, Row, 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}
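
// Illustrative sketch (operand names assumed): for a dot-product intrinsic
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                              x86_amx %c, x86_amx %a, x86_amx %b)
// getShape(II, 3) yields {%m, %n} for %c, getShape(II, 4) yields {%m, %k}
// for %a, and getShape(II, 5) yields {%k, %n} for %b (with %k rewritten to
// %k/4 at O0, as handled above).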

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&]() {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86_amx to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86_amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}
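
// Illustrative sketch of what createTileStore emits (names assumed): given
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c, ...)
// it inserts, immediately after the definition of %td,
//   call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %ptr,
//                                             i64 64, x86_amx %td)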

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}
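
// Illustrative sketch of replaceWithTileLoad (names assumed): for a use of a
// tile %td in some instruction %user, it materializes
//   %tl = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c,
//                                                     i8* %ptr, i64 64)
// right before %user and rewrites %user to consume %tl instead of %td.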

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data to shorten the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Imcomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Imcomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Imcomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored memory.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored right after their defs.
// 3) The memory of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Imcomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Imcomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Imcomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// All its users are not PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored memory.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from tileloads in time.
// 2) All the defs of tile data are tilestored into memory immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call, or other AMX instructions in the key AMX area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();

    X86LowerAMXType LAT(F, TM);
    bool C = LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: it may be better to check the volatile model of the AMX code,
    // not just by checking Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}
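
// Usage sketch (hook name assumed, not defined in this file): the X86 target
// schedules this pass in its IR pipeline, roughly as
//   void X86PassConfig::addIRPasses() {
//     addPass(createX86LowerAMXTypePass());
//     ...
//   }
// See X86TargetMachine.cpp for the actual registration.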