LLVM  13.0.0git
X86LowerAMXIntrinsics.cpp
Go to the documentation of this file.
1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform amx intrinsics to scalar operations.
10 /// This pass is always enabled and it skips when it is not -O0 and has no
11 /// optnone attributes. With -O0 or optnone attribute, the def of shape to amx
12 /// intrinsics is near the amx intrinsics code. We are not able to find a
13 /// point which post-dominate all the shape and dominate all amx intrinsics.
14 /// To decouple the dependency of the shape, we transform amx intrinsics
15 /// to scalar operation, so that compiling doesn't fail. In long term, we
16 /// should improve fast register allocation to allocate amx register.
17 //===----------------------------------------------------------------------===//
18 //
#include "X86.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
40 
41 using namespace llvm;
42 using namespace PatternMatch;
43 
44 #define DEBUG_TYPE "lower-amx-intrinsics"
45 
46 #ifndef NDEBUG
47 static bool isV256I32Ty(Type *Ty) {
48  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
49  return FVT->getNumElements() == 256 &&
50  FVT->getElementType()->isIntegerTy(32);
51  return false;
52 }
53 #endif
54 
55 namespace {
56 class X86LowerAMXIntrinsics {
57  Function &Func;
58 
59 public:
60  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
61  : Func(F), DTU(DomTU), LI(LoopI) {}
62  bool visit();
63 
64 private:
65  DomTreeUpdater &DTU;
66  LoopInfo *LI;
67  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
69  Loop *L);
70  template <bool IsTileLoad>
71  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
72  IRBuilderBase &B, Value *Row, Value *Col,
73  Value *Ptr, Value *Stride, Value *Tile);
74  template <Intrinsic::ID IntrID>
75  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
76  IntrID == Intrinsic::x86_tdpbsud_internal ||
77  IntrID == Intrinsic::x86_tdpbusd_internal ||
78  IntrID == Intrinsic::x86_tdpbuud_internal ||
79  IntrID == Intrinsic::x86_tdpbf16ps_internal,
80  Value *>::type
81  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
82  Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
83  Value *RHS);
84  template <bool IsTileLoad>
85  bool lowerTileLoadStore(Instruction *TileLoadStore);
86  template <Intrinsic::ID IntrID>
87  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
88  IntrID == Intrinsic::x86_tdpbsud_internal ||
89  IntrID == Intrinsic::x86_tdpbusd_internal ||
90  IntrID == Intrinsic::x86_tdpbuud_internal ||
91  IntrID == Intrinsic::x86_tdpbf16ps_internal,
92  bool>::type
93  lowerTileDP(Instruction *TileDP);
94  bool lowerTileZero(Instruction *TileZero);
95 };
96 
// Build a counted-loop skeleton (header -> body -> latch) that steps an i16
// induction variable from 0 by \p Step while the incremented value != \p
// Bound. \p Preheader's first successor edge is redirected into the new
// header and the latch exits to \p Exit. Returns the (still empty) body
// block; callers insert per-iteration work before its terminator. The
// dominator tree is kept up to date through DTU; LoopInfo (if available)
// is updated via \p L, which is non-null exactly when LI is.
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  // Place the three new blocks immediately before Exit so function layout
  // stays in loop order.
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  Type *I16Ty = Type::getInt16Ty(Ctx);
  // Unconditional header -> body -> latch edges.
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  // Induction variable PHI; it is the first (and only) instruction in the
  // header, which callers rely on to retrieve it.
  PHINode *IV =
      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  // Latch: iv.next = iv + Step; branch back to header while iv.next != Bound.
  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  // Splice the loop into the CFG: the preheader's successor 0 now targets
  // the header instead of its old successor.
  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
  if (LI) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}
140 
// Emit a row/column loop nest that scalarizes tileload/tilestore between
// \p Start and \p End. For each (row, col) element it computes the memory
// slot Ptr + row*Stride + col (Stride is in i32 units here, see caller) and
// the flat vector index row*16 + col. For a load the elements are threaded
// through <256 x i32> PHIs and the final vector is returned; for a store
// elements are extracted from \p Tile and written out, returning nullptr.
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  // Register the new two-deep loop nest with LoopInfo when it is available.
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  // The induction-variable PHI is the first instruction of each header
  // (created that way by createLoop).
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
  // A tile row holds 16 i32 slots, hence the row*16 flattening.
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %ptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    // The tile operand is expected to be produced by a bitcast from a
    // <256 x i32> vector; scalarize by extracting from that vector.
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %vec = extractelement <16 x i32> %vec, i16 %idx
    // store i32 %vec, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}
229 
// Emit the three-deep (row, col, inner) loop nest that scalarizes a tile
// dot-product intrinsic between \p Start and \p End. \p Acc, \p LHS and
// \p RHS must be bitcasts from <256 x i32> vectors (C, A, B operands).
// For the integer variants each i32 element is treated as four i8 lanes
// (sign/zero extended per variant) that are multiplied and add-reduced
// into the accumulator; for tdpbf16ps each i32 is two bf16 lanes expanded
// to float via a shuffle-with-zero trick and fadd-reduced. Returns the
// <256 x i32> result vector (NewVecD).
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                            IntrID == Intrinsic::x86_tdpbsud_internal ||
                            IntrID == Intrinsic::x86_tdpbusd_internal ||
                            IntrID == Intrinsic::x86_tdpbuud_internal ||
                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
                        Value *>::type
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  // Basic-block name stem for the generated loops.
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  // Register the three-deep loop nest with LoopInfo when available.
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  // Each header's first instruction is the i16 induction-variable PHI that
  // createLoop placed there.
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO else create BitCast from x86amx to v256i32.
  // Store x86amx to memory, and reload from memory
  // to vector. However with -O0, it doesn't happen.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]

  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]

  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]

  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]

  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;

  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
    // %neweltc = add i32 %elt, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    // The ss/su/us/uu suffix encodes the signedness of the A (first letter)
    // and B (second letter) operands.
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
    // i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    // Interleaving each bf16 lane with a zero low half widens it to the
    // bit pattern of the corresponding f32.
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = makeArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  // Close the PHI cycles now that the loop-carried values exist.
  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}
463 
// Lower one tile dot-product intrinsic call \p TileDP into scalar loops.
// The block containing the call is split; the loop nest is emitted between
// the two halves, the result is bitcast back to x86amx for any remaining
// users, and direct bitcast users are rewired to the raw <256 x i32> result.
// Always returns true (the IR is always changed).
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                            IntrID == Intrinsic::x86_tdpbsud_internal ||
                            IntrID == Intrinsic::x86_tdpbusd_internal ||
                            IntrID == Intrinsic::x86_tdpbuud_internal ||
                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
                        bool>::type
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
  Instruction *InsertI = TileDP;
  IRBuilder<> PreBuilder(TileDP);
  PreBuilder.SetInsertPoint(TileDP);
  // We visit the loop with (m, n/4, k/4):
  // %n_dword = lshr i16 %n, 2
  // %k_dword = lshr i16 %k, 2
  // (each i32 element packs four 8-bit / two 16-bit lanes, so the column and
  // inner trip counts are in dwords)
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileDP);
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // we cannot assume there always be bitcast after tiledpbssd. So we need to
  // insert one bitcast as required
  Builder.SetInsertPoint(End->getFirstNonPHI());
  Value *ResAMX =
      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
  // Delete TileDP intrinsic and do some clean-up.
  // Iterate with a copy of the use iterator since erasing invalidates it.
  for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
    Instruction *I = cast<Instruction>((UI++)->getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(ResAMX);
  TileDP->eraseFromParent();
  return true;
}
507 
// Lower one tileload/tilestore intrinsic call into scalar loops. The block
// is split around the call and the row/col nest is emitted in between.
// Columns and stride are converted from bytes to i32 units (lshr 2). For a
// load, bitcast users are rewired to the scalar result vector and remaining
// users to a fresh vector->x86amx bitcast. Always returns true.
template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));

  Instruction *InsertI = TileLoadStore;
  IRBuilder<> PreBuilder(TileLoadStore);
  PreBuilder.SetInsertPoint(TileLoadStore);
  // N and Stride are byte counts; the loops index i32 elements.
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileLoadStore);
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // we cannot assume there always be bitcast after tileload. So we need to
    // insert one bitcast as required
    Builder.SetInsertPoint(End->getFirstNonPHI());
    Value *ResAMX =
        Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
    // Delete tileloadd6 intrinsic and do some clean-up
    for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end();
         UI != UE;) {
      Instruction *I = cast<Instruction>((UI++)->getUser());
      Value *Vec;
      if (match(I, m_BitCast(m_Value(Vec)))) {
        I->replaceAllUsesWith(ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(ResAMX);
  }
  TileLoadStore->eraseFromParent();
  return true;
}
553 
554 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
555  IRBuilder<> Builder(TileZero);
556  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
557  Value *VecZero = Constant::getNullValue(V256I32Ty);
558  for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
559  Instruction *I = cast<Instruction>((UI++)->getUser());
560  Value *Vec;
561  if (match(I, m_BitCast(m_Value(Vec)))) {
562  I->replaceAllUsesWith(VecZero);
563  I->eraseFromParent();
564  }
565  }
566  TileZero->eraseFromParent();
567  return true;
568 }
569 
570 bool X86LowerAMXIntrinsics::visit() {
571  bool C = false;
573  for (BasicBlock *BB : depth_first(&Func)) {
574  for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
575  if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
576  switch (Inst->getIntrinsicID()) {
577  case Intrinsic::x86_tdpbssd_internal:
578  case Intrinsic::x86_tdpbsud_internal:
579  case Intrinsic::x86_tdpbusd_internal:
580  case Intrinsic::x86_tdpbuud_internal:
581  case Intrinsic::x86_tileloadd64_internal:
582  case Intrinsic::x86_tilestored64_internal:
583  case Intrinsic::x86_tilezero_internal:
584  case Intrinsic::x86_tdpbf16ps_internal:
585  WorkList.push_back(Inst);
586  break;
587  default:
588  break;
589  }
590  }
591  }
592  }
593 
594  for (auto *Inst : WorkList) {
595  switch (Inst->getIntrinsicID()) {
596  case Intrinsic::x86_tdpbssd_internal:
597  C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
598  break;
599  case Intrinsic::x86_tdpbsud_internal:
600  C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
601  break;
602  case Intrinsic::x86_tdpbusd_internal:
603  C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
604  break;
605  case Intrinsic::x86_tdpbuud_internal:
606  C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
607  break;
608  case Intrinsic::x86_tdpbf16ps_internal:
609  C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
610  break;
611  case Intrinsic::x86_tileloadd64_internal:
612  C = lowerTileLoadStore<true>(Inst) || C;
613  break;
614  case Intrinsic::x86_tilestored64_internal:
615  C = lowerTileLoadStore<false>(Inst) || C;
616  break;
617  case Intrinsic::x86_tilezero_internal:
618  C = lowerTileZero(Inst) || C;
619  break;
620  default:
621  llvm_unreachable("invalid amx intrinsics!");
622  }
623  }
624 
625  return C;
626 }
627 } // anonymous namespace
628 
629 namespace {
630 
631 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
632 public:
633  static char ID;
634 
635  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
638  }
639 
640  bool runOnFunction(Function &F) override {
641  TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
642  if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
643  TM->getOptLevel() != CodeGenOpt::None)
644  return false;
645 
646  auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
647  auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
648  auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
649  auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
651 
652  X86LowerAMXIntrinsics LAT(F, DTU, LI);
653  return LAT.visit();
654  }
655  StringRef getPassName() const override { return "Lower AMX intrinsics"; }
656 
657  void getAnalysisUsage(AnalysisUsage &AU) const override {
661  }
662 };
663 
664 } // anonymous namespace
665 
666 static const char PassName[] = "Lower AMX intrinsics";
668 INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
669  false, false)
671 INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
673 
675  return new X86LowerAMXIntrinsicsLegacyPass();
676 }
ValueTypes.h
llvm
Definition: AllocatorList.h:23
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::SystemZISD::TM
@ TM
Definition: SystemZISelLowering.h:65
type
llvm::BasicBlock::iterator
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:90
llvm::BasicBlock::getParent
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:107
IntrinsicInst.h
llvm::Function
Definition: Function.h:61
llvm::Loop
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:530
Pass.h
llvm::PointerType::get
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Definition: Type.cpp:687
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1168
llvm::IRBuilder<>
DomTreeUpdater.h
OptimizationRemarkEmitter.h
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:46
llvm::LoopInfoWrapperPass
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:1258
llvm::BasicBlock::getSingleSuccessor
const BasicBlock * getSingleSuccessor() const
Return the successor of this block if it has a single successor.
Definition: BasicBlock.cpp:294
llvm::DominatorTreeBase< BasicBlock, false >::Insert
static constexpr UpdateKind Insert
Definition: GenericDomTree.h:242
Offset
uint64_t Offset
Definition: ELFObjHandler.cpp:81
llvm::BasicBlock::getSinglePredecessor
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:264
llvm::PatternMatch::m_BitCast
CastClass_match< OpTy, Instruction::BitCast > m_BitCast(const OpTy &Op)
Matches BitCast.
Definition: PatternMatch.h:1603
llvm::FixedVectorType
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:527
F
#define F(x, y, z)
Definition: MD5.cpp:56
llvm::BranchInst::setSuccessor
void setSuccessor(unsigned idx, BasicBlock *NewSucc)
Definition: Instructions.h:3105
llvm::DomTreeUpdater::UpdateStrategy::Lazy
@ Lazy
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
X86.h
llvm::Type::getX86_AMXTy
static Type * getX86_AMXTy(LLVMContext &C)
Definition: Type.cpp:192
TargetMachine.h
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:31
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
llvm::ARM_PROC::A
@ A
Definition: ARMBaseInfo.h:34
llvm::LoopBase::addChildLoop
void addChildLoop(LoopT *NewChild)
Add the specified loop to be a child of this loop.
Definition: LoopInfo.h:395
llvm::BasicBlock::begin
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:296
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
DenseSet.h
false
Definition: StackSlotColoring.cpp:142
B
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
llvm::Instruction
Definition: Instruction.h:45
llvm::createX86LowerAMXIntrinsicsPass
FunctionPass * createX86LowerAMXIntrinsicsPass()
The pass transforms amx intrinsics to scalar operation if the function has optnone attribute or it is...
Definition: X86LowerAMXIntrinsics.cpp:674
llvm::DominatorTreeWrapperPass
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:281
llvm::DomTreeUpdater
Definition: DomTreeUpdater.h:28
llvm::ConstantInt::get
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:885
LoopUtils.h
PatternMatch.h
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:644
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:58
llvm::ARM_PROC::IE
@ IE
Definition: ARMBaseInfo.h:27
Passes.h
llvm::TargetPassConfig
Target-Independent Code Generator Pass Configuration Options.
Definition: TargetPassConfig.h:84
DEBUG_TYPE
#define DEBUG_TYPE
Definition: X86LowerAMXIntrinsics.cpp:44
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, false, false) INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass
llvm::Instruction::eraseFromParent
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:78
INITIALIZE_PASS_DEPENDENCY
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
llvm::PHINode::addIncoming
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Definition: Instructions.h:2722
llvm::LLVMContext
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:68
llvm::BranchInst::Create
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
Definition: Instructions.h:3063
I
#define I(x, y, z)
Definition: MD5.cpp:59
TargetPassConfig.h
IRBuilder.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::TargetMachine
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:77
llvm::elfabi::ELFSymbolType::Func
@ Func
llvm::Value::use_begin
use_iterator use_begin()
Definition: Value.h:373
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:651
llvm::PatternMatch::m_Value
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
llvm::CodeGenOpt::None
@ None
Definition: CodeGen.h:53
llvm::LoopInfo
Definition: LoopInfo.h:1080
DataLayout.h
Cond
SmallVector< MachineOperand, 4 > Cond
Definition: BasicBlockSections.cpp:167
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:57
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition: ErrorHandling.h:136
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
llvm::AnalysisUsage::addPreserved
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
Definition: PassAnalysisSupport.h:98
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:527
llvm::BasicBlock::Create
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:100
llvm::IRBuilderBase
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:95
llvm::Value::use_end
use_iterator use_end()
Definition: Value.h:381
llvm::BasicBlock::getTerminator
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:148
llvm::BasicBlock::getContext
LLVMContext & getContext() const
Get the context in which this basic block lives.
Definition: BasicBlock.cpp:32
llvm::depth_first
iterator_range< df_iterator< T > > depth_first(const T &G)
Definition: DepthFirstIterator.h:229
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:69
llvm::GraphProgram::Name
Name
Definition: GraphWriter.h:52
llvm::Constant::getNullValue
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:347
llvm::initializeX86LowerAMXIntrinsicsLegacyPassPass
void initializeX86LowerAMXIntrinsicsLegacyPassPass(PassRegistry &)
llvm::PHINode::Create
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
Definition: Instructions.h:2614
Function.h
llvm::makeArrayRef
ArrayRef< T > makeArrayRef(const T &OneElt)
Construct an ArrayRef from a single element.
Definition: ArrayRef.h:474
Instructions.h
PostOrderIterator.h
llvm::LoopBase::addBasicBlockToLoop
void addBasicBlockToLoop(BlockT *NewBB, LoopInfoBase< BlockT, LoopT > &LI)
This method is used by other analyses to update loop information.
Definition: LoopInfoImpl.h:242
N
#define N
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:94
TargetTransformInfo.h
llvm::PHINode
Definition: Instructions.h:2572
llvm::Type::getInt16Ty
static IntegerType * getInt16Ty(LLVMContext &C)
Definition: Type.cpp:196
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:298
BB
Common register allocation spilling lr str ldr sxth r3 ldr mla r4 can lr mov lr str ldr sxth r3 mla r4 and then merge mul and lr str ldr sxth r3 mla r4 It also increase the likelihood the store may become dead bb27 Successors according to LLVM BB
Definition: README.txt:39
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
llvm::BranchInst
Conditional or Unconditional Branch instruction.
Definition: Instructions.h:3007
BasicBlockUtils.h
llvm::SplitBlock
BasicBlock * SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Definition: BasicBlockUtils.cpp:820
InitializePasses.h
llvm::Value
LLVM Value Representation.
Definition: Value.h:75
PassName
static const char PassName[]
Definition: X86LowerAMXIntrinsics.cpp:666
isV256I32Ty
static bool isV256I32Ty(Type *Ty)
Definition: X86LowerAMXIntrinsics.cpp:47
llvm::BranchInst::getSuccessor
BasicBlock * getSuccessor(unsigned i) const
Definition: Instructions.h:3100
llvm::DominatorTreeBase< BasicBlock, false >::Delete
static constexpr UpdateKind Delete
Definition: GenericDomTree.h:243
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:38