1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 // * Improve fusion:
13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14 // transposed.
15 //  * Improve cost-modeling, e.g. choose different numbers of rows/columns
16 //    for tiles, consider cost of copies on alias.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Analysis/AliasAnalysis.h"
24 #include "llvm/Analysis/DomTreeUpdater.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
27 #include "llvm/Analysis/TargetTransformInfo.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Analysis/VectorUtils.h"
30 #include "llvm/IR/CFG.h"
31 #include "llvm/IR/DataLayout.h"
32 #include "llvm/IR/DebugInfoMetadata.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/Instructions.h"
36 #include "llvm/IR/IntrinsicInst.h"
37 #include "llvm/IR/MatrixBuilder.h"
38 #include "llvm/IR/PatternMatch.h"
39 #include "llvm/InitializePasses.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Alignment.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Transforms/Scalar.h"
45 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
46 #include "llvm/Transforms/Utils/LoopUtils.h"
47 #include "llvm/Transforms/Utils/MatrixUtils.h"
48 
49 using namespace llvm;
50 using namespace PatternMatch;
51 
52 #define DEBUG_TYPE "lower-matrix-intrinsics"
53 
54 static cl::opt<bool>
55  FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
56  cl::desc("Enable/disable fusing matrix instructions."));
57 // TODO: Allow and use non-square tiles.
58 static cl::opt<unsigned> TileSize(
59     "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
60  cl::desc(
61  "Tile size for matrix instruction fusion using square-shaped tiles."));
62 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
63  cl::Hidden,
64  cl::desc("Generate loop nest for tiling."));
65 static cl::opt<bool> ForceFusion(
66     "force-fuse-matrix", cl::init(false), cl::Hidden,
67  cl::desc("Force matrix instruction fusion even if not profitable."));
68 static cl::opt<bool> AllowContractEnabled(
69     "matrix-allow-contract", cl::init(false), cl::Hidden,
70  cl::desc("Allow the use of FMAs if available and profitable. This may "
71  "result in different results, due to less rounding error."));
72 
73 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
74 
75 static cl::opt<MatrixLayoutTy> MatrixLayout(
76     "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
77     cl::desc("Sets the default matrix layout"),
78     cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
79                           "Use column-major layout"),
80                clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
81                           "Use row-major layout")));
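
// For example, the default layout can be switched when running this pass
// through opt:
//   opt -passes=lower-matrix-intrinsics -matrix-default-layout=row-major <input.ll>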
82 
83 /// Helper function to either return \p Scope, if it is a subprogram, or the
84 /// subprogram attached to a local scope.
85 static DISubprogram *getSubprogram(DIScope *Scope) {
86   if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
87  return Subprogram;
88  return cast<DILocalScope>(Scope)->getSubprogram();
89 }
90 
91 namespace {
92 
93 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
94 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
95 // assuming \p Stride elements between the starts of two consecutive vectors.
96 // \p Stride must be >= \p NumElements.
97 // For column-major matrixes, the function computes the address of a column
98 // vector and \p NumElements must be set to the number of elements in a column
99 // (= number of rows of the matrix). For row-major matrixes, the function
100 // computes the address of a row vector and \p NumElements must be set to the
101 // number of elements in a row (= number of columns of the matrix).
102 //
103 // Consider a 4x4 matrix in column-major layout like below
104 //
105 //           0       1       2       3
106 //   0   v_0_0   v_0_1   v_0_2   v_0_3
107 //   1   v_1_0   v_1_1   v_1_2   v_1_3
108 //   2   v_2_0   v_2_1   v_2_2   v_2_3
109 //   3   v_3_0   v_3_1   v_3_2   v_3_3
110 
111 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
112 // we need a pointer to the first element of the submatrix as base pointer.
113 // Then we can use computeVectorAddr to compute the addresses for the columns
114 // of the sub-matrix.
115 //
116 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
117 // -> just returns Base
118 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
119 // -> returns Base + (1 * 4)
120 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
121 // -> returns Base + (2 * 4)
122 //
123 // The graphic below illustrates the number of elements in a column (marked
124 // with |) and the number of skipped elements (marked with {).
125 //
126 //          v_0_0   v_0_1  {v_0_2  {v_0_3
127 //                  Base    Col 1   Col 2
128 //                    |       |       |
129 //          v_1_0  |v_1_1  |v_1_2  |v_1_3
130 //          v_2_0  |v_2_1  |v_2_2  |v_2_3
131 //          v_3_0  {v_3_1  {v_3_2   v_3_3
132 //
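// For illustration, the address computation for Column 1 above is emitted
// roughly as follows (assuming double elements and an i64 stride):
//
//   %vec.start = mul i64 1, 4
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <2 x double>*
//
// For Column 0 the GEP is skipped and %base is cast directly.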
133 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
134  unsigned NumElements, Type *EltType,
135  IRBuilder<> &Builder) {
136 
137  assert((!isa<ConstantInt>(Stride) ||
138  cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
139  "Stride must be >= the number of elements in the result vector.");
140  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
141 
142  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
143  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
144 
145  // Get pointer to the start of the selected vector. Skip GEP creation,
146  // if we select vector 0.
147  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
148  VecStart = BasePtr;
149  else
150  VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
151 
152  // Cast elementwise vector start pointer to a pointer to a vector
153  // (EltType x NumElements)*.
154  auto *VecType = FixedVectorType::get(EltType, NumElements);
155  Type *VecPtrType = PointerType::get(VecType, AS);
156  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
157 }
158 
159 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
160 ///
161 /// Currently, the lowering for each matrix intrinsic is done as follows:
162 /// 1. Propagate the shape information from intrinsics to connected
163 /// instructions.
164 /// 2. Lower instructions with shape information (assuming column-major layout).
165 /// The lowering works similarly using row-major layout.
166 /// 2.1. Get column vectors for each argument. If we already lowered the
167 /// definition of an argument, use the produced column vectors directly.
168 /// If not, split the operand vector containing an embedded matrix into
169 /// a set of column vectors.
170 /// 2.2. Lower the instruction in terms of column-major operations, which
171 /// yields a set of column vectors containing the result matrix. Note that we
172 /// lower all instructions that have shape information. Besides the
173 /// intrinsics, this includes stores for example.
174 /// 2.3. Update uses of the lowered instruction. If we have shape information
175 /// for a user, there is nothing to do, as we will look up the result
176 /// column matrix when lowering the user. For other uses, we embed the
177 /// result matrix in a flat vector and update the use.
178 /// 2.4. Cache the result column matrix for the instruction we lowered.
179 /// 3. After we lowered all instructions in a function, remove the now
180 /// obsolete instructions.
181 ///
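/// For example (illustrative), a 2x2 multiply such as
///
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
///
/// is split into two <2 x double> columns per operand and lowered to
/// shufflevector, splat and fmul/fadd (or fmuladd) instructions, whose results
/// are concatenated again for users without shape information.
///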
182 class LowerMatrixIntrinsics {
183  Function &Func;
184  const DataLayout &DL;
185  const TargetTransformInfo &TTI;
186  AliasAnalysis *AA;
187  DominatorTree *DT;
188  LoopInfo *LI;
189   OptimizationRemarkEmitter *ORE;
190 
191  /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
192  struct OpInfoTy {
193  /// Number of stores emitted to generate this matrix.
194  unsigned NumStores = 0;
195  /// Number of loads emitted to generate this matrix.
196  unsigned NumLoads = 0;
197  /// Number of compute operations emitted to generate this matrix.
198  unsigned NumComputeOps = 0;
199  /// Most of the time transposes can be fused with matrix multiplies or can
200  /// be folded away via algebraic simplifications. This is the number of
201  /// transposes that we failed to make "free" via such optimizations.
202  unsigned NumExposedTransposes = 0;
203 
204  OpInfoTy &operator+=(const OpInfoTy &RHS) {
205  NumStores += RHS.NumStores;
206  NumLoads += RHS.NumLoads;
207  NumComputeOps += RHS.NumComputeOps;
208  NumExposedTransposes += RHS.NumExposedTransposes;
209  return *this;
210  }
211  };
212 
213  /// Wrapper class representing a matrix as a set of vectors, either in row or
214  /// column major layout. All vectors must have the same vector type.
215  class MatrixTy {
216  SmallVector<Value *, 16> Vectors;
217 
218  OpInfoTy OpInfo;
219 
220  bool IsColumnMajor = true;
221 
222  public:
223  MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
224  MatrixTy(ArrayRef<Value *> Vectors)
225  : Vectors(Vectors.begin(), Vectors.end()),
226  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
227  MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
228  : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
229 
230  unsigned D = isColumnMajor() ? NumColumns : NumRows;
231  for (unsigned J = 0; J < D; ++J)
232         addVector(UndefValue::get(FixedVectorType::get(
233             EltTy, isColumnMajor() ? NumRows : NumColumns)));
234  }
235 
236  Value *getVector(unsigned i) const { return Vectors[i]; }
237  Value *getColumn(unsigned i) const {
238  assert(isColumnMajor() && "only supported for column-major matrixes");
239  return Vectors[i];
240  }
241  Value *getRow(unsigned i) const {
242  assert(!isColumnMajor() && "only supported for row-major matrixes");
243  return Vectors[i];
244  }
245 
246  void setVector(unsigned i, Value *V) { Vectors[i] = V; }
247 
248  Type *getElementType() const { return getVectorTy()->getElementType(); }
249 
250  unsigned getNumVectors() const {
251  if (isColumnMajor())
252  return getNumColumns();
253  return getNumRows();
254  }
255 
256  unsigned getNumColumns() const {
257  if (isColumnMajor())
258  return Vectors.size();
259  else {
260       assert(Vectors.size() > 0 && "Cannot call getNumColumns without columns");
261  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
262  }
263  }
264  unsigned getNumRows() const {
265  if (isColumnMajor()) {
266  assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
267  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
268  } else
269  return Vectors.size();
270  }
271 
272  void addVector(Value *V) { Vectors.push_back(V); }
273  VectorType *getColumnTy() {
274  assert(isColumnMajor() && "only supported for column-major matrixes");
275  return getVectorTy();
276  }
277 
278  VectorType *getVectorTy() const {
279  return cast<VectorType>(Vectors[0]->getType());
280  }
281 
283  assert(isColumnMajor() &&
284  "columns() only supported for column-major matrixes");
285  return make_range(Vectors.begin(), Vectors.end());
286  }
287 
289  return make_range(Vectors.begin(), Vectors.end());
290  }
291 
292  /// Embed the vectors of the matrix into a flat vector by concatenating
293  /// them.
294  Value *embedInVector(IRBuilder<> &Builder) const {
295  return Vectors.size() == 1 ? Vectors[0]
296  : concatenateVectors(Builder, Vectors);
297  }
298 
299  MatrixTy &addNumLoads(unsigned N) {
300  OpInfo.NumLoads += N;
301  return *this;
302  }
303 
304  void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
305 
306  MatrixTy &addNumStores(unsigned N) {
307  OpInfo.NumStores += N;
308  return *this;
309  }
310 
311  MatrixTy &addNumExposedTransposes(unsigned N) {
312  OpInfo.NumExposedTransposes += N;
313  return *this;
314  }
315 
316  MatrixTy &addNumComputeOps(unsigned N) {
317  OpInfo.NumComputeOps += N;
318  return *this;
319  }
320 
321  unsigned getNumStores() const { return OpInfo.NumStores; }
322  unsigned getNumLoads() const { return OpInfo.NumLoads; }
323  unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
324 
325  const OpInfoTy &getOpInfo() const { return OpInfo; }
326 
327  bool isColumnMajor() const { return IsColumnMajor; }
328 
329  unsigned getStride() const {
330  if (isColumnMajor())
331  return getNumRows();
332  return getNumColumns();
333  }
334 
335  /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
336  /// matrix is column-major, the result vector is extracted from a column
337  /// vector, otherwise from a row vector.
338  Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
339  IRBuilder<> &Builder) const {
340  Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
341  return Builder.CreateShuffleVector(
342  Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
343  "block");
344  }
345  };
346 
347  struct ShapeInfo {
348  unsigned NumRows;
349  unsigned NumColumns;
350 
351  bool IsColumnMajor;
352 
353  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
354  : NumRows(NumRows), NumColumns(NumColumns),
355  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
356 
357  ShapeInfo(Value *NumRows, Value *NumColumns)
358  : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
359  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
360 
361  bool operator==(const ShapeInfo &other) {
362  return NumRows == other.NumRows && NumColumns == other.NumColumns;
363  }
364  bool operator!=(const ShapeInfo &other) { return !(*this == other); }
365 
366  /// Returns true if shape-information is defined, meaning both dimensions
367  /// are != 0.
368  operator bool() const {
369  assert(NumRows == 0 || NumColumns != 0);
370  return NumRows != 0;
371  }
372 
373  unsigned getStride() const {
374  if (IsColumnMajor)
375  return NumRows;
376  return NumColumns;
377  }
378 
379  unsigned getNumVectors() const {
380  if (IsColumnMajor)
381  return NumColumns;
382  return NumRows;
383  }
384  };
385 
386  /// Maps instructions to their shape information. The shape information
387  /// describes the shape to be used while lowering. This matches the shape of
388  /// the result value of the instruction, with the only exceptions being store
389  /// instructions and the matrix_column_major_store intrinsics. For those, the
390  /// shape information indicates that those instructions should be lowered
391  /// using shape information as well. A ValueMap is used so that when
392 /// sub-passes like optimizeTransposes perform RAUW the map stays
393  /// up-to-date.
394   ValueMap<Value *, ShapeInfo> ShapeMap;
395 
396   /// List of instructions to remove. While lowering, we do not replace all
397   /// users of a lowered instruction if shape information is available; those
398   /// users need to be removed after we finished lowering.
399   SmallVector<Instruction *, 16> ToRemove;
400 
401  /// Map from instructions to their produced column matrix.
402  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
403 
404 private:
405  static FastMathFlags getFastMathFlags(Instruction *Inst) {
406  FastMathFlags FMF;
407 
408  if (isa<FPMathOperator>(*Inst))
409  FMF = Inst->getFastMathFlags();
410 
411     FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
412 
413  return FMF;
414  }
415 
416 public:
417  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
418                         AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
419                         OptimizationRemarkEmitter *ORE)
420       : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
421  LI(LI), ORE(ORE) {}
422 
423  unsigned getNumOps(Type *VT) {
424  assert(isa<VectorType>(VT) && "Expected vector type");
425  return getNumOps(VT->getScalarType(),
426  cast<FixedVectorType>(VT)->getNumElements());
427  }
428 
429  /// Is this the minimal version executed in the backend pipelines.
430  bool isMinimal() const {
431  return !DT;
432  }
433 
434  /// Return the estimated number of vector ops required for an operation on
435  /// \p VT * N.
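  /// For example, assuming a 256-bit fixed-width vector register: for double
  /// elements and N == 8 this returns ceil((64 * 8) / 256.0) == 2 vector ops.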
436  unsigned getNumOps(Type *ST, unsigned N) {
437  return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
438  double(TTI.getRegisterBitWidth(
439                          TargetTransformInfo::RGK_FixedWidthVector)
440                          .getFixedSize()));
441  }
442 
443  /// Return the set of vectors that a matrix value is lowered to.
444  ///
445   /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
446  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
447  /// into vectors.
448  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
449  IRBuilder<> &Builder) {
450  VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
451  assert(VType && "MatrixVal must be a vector type");
452  assert(cast<FixedVectorType>(VType)->getNumElements() ==
453  SI.NumRows * SI.NumColumns &&
454  "The vector size must match the number of matrix elements");
455 
456  // Check if we lowered MatrixVal using shape information. In that case,
457  // return the existing matrix, if it matches the requested shape
458  // information. If there is a mis-match, embed the result in a flat
459  // vector and split it later.
460  auto Found = Inst2ColumnMatrix.find(MatrixVal);
461  if (Found != Inst2ColumnMatrix.end()) {
462  MatrixTy &M = Found->second;
463  // Return the found matrix, if its shape matches the requested shape
464  // information
465  if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
466  return M;
467 
468  MatrixVal = M.embedInVector(Builder);
469  }
470 
471  // Otherwise split MatrixVal.
472  SmallVector<Value *, 16> SplitVecs;
473  for (unsigned MaskStart = 0;
474  MaskStart < cast<FixedVectorType>(VType)->getNumElements();
475  MaskStart += SI.getStride()) {
476  Value *V = Builder.CreateShuffleVector(
477  MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
478  "split");
479  SplitVecs.push_back(V);
480  }
481 
482  return {SplitVecs};
483  }
484 
485  /// If \p V already has a known shape return false. Otherwise set the shape
486  /// for instructions that support it.
487  bool setShapeInfo(Value *V, ShapeInfo Shape) {
488  assert(Shape && "Shape not set");
489  if (isa<UndefValue>(V) || !supportsShapeInfo(V))
490  return false;
491 
492  auto SIter = ShapeMap.find(V);
493  if (SIter != ShapeMap.end()) {
494  LLVM_DEBUG(dbgs() << " not overriding existing shape: "
495  << SIter->second.NumRows << " "
496  << SIter->second.NumColumns << " for " << *V << "\n");
497  return false;
498  }
499 
500  ShapeMap.insert({V, Shape});
501  LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
502  << " for " << *V << "\n");
503  return true;
504  }
505 
506  bool isUniformShape(Value *V) {
507  Instruction *I = dyn_cast<Instruction>(V);
508  if (!I)
509  return true;
510 
511  switch (I->getOpcode()) {
512  case Instruction::FAdd:
513  case Instruction::FSub:
514  case Instruction::FMul: // Scalar multiply.
515  case Instruction::FNeg:
516  case Instruction::Add:
517  case Instruction::Mul:
518  case Instruction::Sub:
519  return true;
520  default:
521  return false;
522  }
523  }
524 
525  /// Returns true if shape information can be used for \p V. The supported
526  /// instructions must match the instructions that can be lowered by this pass.
527  bool supportsShapeInfo(Value *V) {
528  Instruction *Inst = dyn_cast<Instruction>(V);
529  if (!Inst)
530  return false;
531 
532  IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
533  if (II)
534  switch (II->getIntrinsicID()) {
535  case Intrinsic::matrix_multiply:
536  case Intrinsic::matrix_transpose:
537  case Intrinsic::matrix_column_major_load:
538  case Intrinsic::matrix_column_major_store:
539  return true;
540  default:
541  return false;
542  }
543  return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
544  }
545 
546  /// Propagate the shape information of instructions to their users.
547  /// The work list contains instructions for which we can compute the shape,
548  /// either based on the information provided by matrix intrinsics or known
549  /// shapes of operands.
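  /// For example (illustrative), for %t = call @llvm.matrix.transpose(%a, i32 3, i32 4)
  /// the result shape 4 x 3 is recorded for %t and %t's users are added to the
  /// work list.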
550   SmallVector<Instruction *, 32>
551   propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
552  SmallVector<Instruction *, 32> NewWorkList;
553     // Pop an element for which we are guaranteed to have at least one of the
554     // operand shapes. Add the shape for this instruction and then add its
555     // users to the work list.
556  LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
557  while (!WorkList.empty()) {
558  Instruction *Inst = WorkList.pop_back_val();
559 
560  // New entry, set the value and insert operands
561  bool Propagate = false;
562 
563  Value *MatrixA;
564  Value *MatrixB;
565  Value *M;
566  Value *N;
567  Value *K;
568  if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
569  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
570  m_Value(N), m_Value(K)))) {
571  Propagate = setShapeInfo(Inst, {M, K});
572  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
573  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
574  // Flip dimensions.
575  Propagate = setShapeInfo(Inst, {N, M});
576  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
577  m_Value(MatrixA), m_Value(), m_Value(),
578  m_Value(), m_Value(M), m_Value(N)))) {
579  Propagate = setShapeInfo(Inst, {N, M});
580  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
581  m_Value(), m_Value(), m_Value(), m_Value(M),
582  m_Value(N)))) {
583  Propagate = setShapeInfo(Inst, {M, N});
584  } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
585  auto OpShape = ShapeMap.find(MatrixA);
586  if (OpShape != ShapeMap.end())
587  setShapeInfo(Inst, OpShape->second);
588  continue;
589  } else if (isUniformShape(Inst)) {
590  // Find the first operand that has a known shape and use that.
591  for (auto &Op : Inst->operands()) {
592  auto OpShape = ShapeMap.find(Op.get());
593  if (OpShape != ShapeMap.end()) {
594  Propagate |= setShapeInfo(Inst, OpShape->second);
595  break;
596  }
597  }
598  }
599 
600  if (Propagate) {
601  NewWorkList.push_back(Inst);
602  for (auto *User : Inst->users())
603  if (ShapeMap.count(User) == 0)
604  WorkList.push_back(cast<Instruction>(User));
605  }
606  }
607 
608  return NewWorkList;
609  }
610 
611  /// Propagate the shape to operands of instructions with shape information.
612   /// \p WorkList contains the instructions for which we already know the shape.
613   SmallVector<Instruction *, 32>
614   propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
615  SmallVector<Instruction *, 32> NewWorkList;
616 
617  auto pushInstruction = [](Value *V,
618  SmallVectorImpl<Instruction *> &WorkList) {
619  Instruction *I = dyn_cast<Instruction>(V);
620  if (I)
621  WorkList.push_back(I);
622  };
623     // Pop an element with known shape. Traverse the operands; if an operand's
624     // shape derives from the result shape and is still unknown, set it and add
625     // the operand to the worklist.
626  LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
627  while (!WorkList.empty()) {
628  Value *V = WorkList.pop_back_val();
629 
630  size_t BeforeProcessingV = WorkList.size();
631  if (!isa<Instruction>(V))
632  continue;
633 
634  Value *MatrixA;
635  Value *MatrixB;
636  Value *M;
637  Value *N;
638  Value *K;
639  if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
640  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
641  m_Value(N), m_Value(K)))) {
642  if (setShapeInfo(MatrixA, {M, N}))
643  pushInstruction(MatrixA, WorkList);
644 
645  if (setShapeInfo(MatrixB, {N, K}))
646  pushInstruction(MatrixB, WorkList);
647 
648  } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
649  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
650  // Flip dimensions.
651  if (setShapeInfo(MatrixA, {M, N}))
652  pushInstruction(MatrixA, WorkList);
653  } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
654  m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
655  m_Value(M), m_Value(N)))) {
656  if (setShapeInfo(MatrixA, {M, N})) {
657  pushInstruction(MatrixA, WorkList);
658  }
659  } else if (isa<LoadInst>(V) ||
660  match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
661  // Nothing to do, no matrix input.
662  } else if (isa<StoreInst>(V)) {
663  // Nothing to do. We forward-propagated to this so we would just
664  // backward propagate to an instruction with an already known shape.
665  } else if (isUniformShape(V)) {
666  // Propagate to all operands.
667  ShapeInfo Shape = ShapeMap[V];
668  for (Use &U : cast<Instruction>(V)->operands()) {
669  if (setShapeInfo(U.get(), Shape))
670  pushInstruction(U.get(), WorkList);
671  }
672  }
673  // After we discovered new shape info for new instructions in the
674  // worklist, we use their users as seeds for the next round of forward
675  // propagation.
676  for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
677  for (User *U : WorkList[I]->users())
678  if (isa<Instruction>(U) && V != U)
679  NewWorkList.push_back(cast<Instruction>(U));
680  }
681  return NewWorkList;
682  }
683 
684  /// Try moving transposes in order to fold them away or into multiplies.
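  /// In particular, transpose(transpose(A)) is replaced by A, and
  /// transpose(A * B) is rewritten to transpose(B) * transpose(A), so that the
  /// remaining transposes can be folded into the multiply lowering.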
685  void optimizeTransposes() {
686  auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
687       // We need to remove Old from the ShapeMap, otherwise RAUW will replace it
688       // with New. We should only add New if it supports shape info, so we
689       // insert it conditionally instead.
690  auto S = ShapeMap.find(&Old);
691  if (S != ShapeMap.end()) {
692  ShapeMap.erase(S);
693  if (supportsShapeInfo(New))
694  ShapeMap.insert({New, S->second});
695  }
696  Old.replaceAllUsesWith(New);
697  };
698 
699  // First sink all transposes inside matmuls, hoping that we end up with NN,
700  // NT or TN variants.
701  for (BasicBlock &BB : reverse(Func)) {
702  for (auto II = BB.rbegin(); II != BB.rend();) {
703  Instruction &I = *II;
704  // We may remove II. By default continue on the next/prev instruction.
705  ++II;
706  // If we were to erase II, move again.
707  auto EraseFromParent = [&II, &BB](Value *V) {
708  auto *Inst = cast<Instruction>(V);
709  if (Inst->use_empty()) {
710  if (II != BB.rend() && Inst == &*II) {
711  ++II;
712  }
713  Inst->eraseFromParent();
714  }
715  };
716 
717  // If we're creating a new instruction, continue from there.
718  Instruction *NewInst = nullptr;
719 
720  IRBuilder<> IB(&I);
721         MatrixBuilder Builder(IB);
722 
723  Value *TA, *TAMA, *TAMB;
724  ConstantInt *R, *K, *C;
725  if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
726 
727  // Transpose of a transpose is a nop
728  Value *TATA;
729  if (match(TA,
730  m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
731  ReplaceAllUsesWith(I, TATA);
732  EraseFromParent(&I);
733  EraseFromParent(TA);
734  }
735 
736  // (A * B)^t -> B^t * A^t
737  // RxK KxC CxK KxR
738  else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
739  m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
740  m_ConstantInt(K), m_ConstantInt(C)))) {
741  Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
742  C->getZExtValue(),
743  TAMB->getName() + "_t");
744  // We are being run after shape prop, add shape for newly created
745  // instructions so that we lower them later.
746  setShapeInfo(T0, {C, K});
747  Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
748  K->getZExtValue(),
749  TAMA->getName() + "_t");
750  setShapeInfo(T1, {K, R});
751  NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
752  K->getZExtValue(),
753  R->getZExtValue(), "mmul");
754  ReplaceAllUsesWith(I, NewInst);
755  EraseFromParent(&I);
756  EraseFromParent(TA);
757  }
758  }
759 
760  // If we replaced I with a new instruction, continue from there.
761  if (NewInst)
762  II = std::next(BasicBlock::reverse_iterator(NewInst));
763  }
764  }
765 
766     // If we have a TT matmul, lift the transpose. We may be able to fold it
767     // into a consuming multiply.
768  for (BasicBlock &BB : Func) {
769       for (Instruction &I : llvm::make_early_inc_range(BB)) {
770         Value *A, *B, *AT, *BT;
771  ConstantInt *R, *K, *C;
772  // A^t * B ^t -> (B * A)^t
773  if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
774  m_Value(A), m_Value(B), m_ConstantInt(R),
775  m_ConstantInt(K), m_ConstantInt(C))) &&
776  match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
777  match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
778  IRBuilder<> IB(&I);
779           MatrixBuilder Builder(IB);
780           Value *M = Builder.CreateMatrixMultiply(
781  BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
782  setShapeInfo(M, {C, R});
783  Instruction *NewInst = Builder.CreateMatrixTranspose(
784  M, C->getZExtValue(), R->getZExtValue());
785  ReplaceAllUsesWith(I, NewInst);
786  if (I.use_empty())
787  I.eraseFromParent();
788  if (A->use_empty())
789  cast<Instruction>(A)->eraseFromParent();
790  if (A != B && B->use_empty())
791  cast<Instruction>(B)->eraseFromParent();
792  }
793  }
794  }
795  }
796 
797  bool Visit() {
798     SmallVector<Instruction *, 32> WorkList;
799 
800  // Initially only the shape of matrix intrinsics is known.
801  // Initialize the work list with ops carrying shape information.
802  for (BasicBlock &BB : Func)
803  for (Instruction &Inst : BB) {
804  IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
805  if (!II)
806  continue;
807 
808  switch (II->getIntrinsicID()) {
809  case Intrinsic::matrix_multiply:
810  case Intrinsic::matrix_transpose:
811  case Intrinsic::matrix_column_major_load:
812  case Intrinsic::matrix_column_major_store:
813  WorkList.push_back(&Inst);
814  break;
815  default:
816  break;
817  }
818  }
819 
820  // Avoid unnecessary work if there are no matrix intrinsics in the function.
821  if (WorkList.empty())
822  return false;
823 
824  // Propagate shapes until nothing changes any longer.
825  while (!WorkList.empty()) {
826  WorkList = propagateShapeForward(WorkList);
827  WorkList = propagateShapeBackward(WorkList);
828  }
829 
830  if (!isMinimal()) {
831  optimizeTransposes();
832  LLVM_DEBUG({
833  dbgs() << "Dump after matrix transpose optimization:\n";
834  Func.dump();
835  });
836  }
837 
838  bool Changed = false;
839  SmallVector<CallInst *, 16> MaybeFusableInsts;
840  SmallVector<Instruction *, 16> MatrixInsts;
841 
842  // First, collect all instructions with shape information and candidates for
843  // fusion (currently only matrix multiplies).
844     ReversePostOrderTraversal<Function *> RPOT(&Func);
845     for (auto *BB : RPOT)
846  for (Instruction &I : *BB) {
847  if (ShapeMap.find(&I) == ShapeMap.end())
848  continue;
849  if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
850  MaybeFusableInsts.push_back(cast<CallInst>(&I));
851  MatrixInsts.push_back(&I);
852  }
853 
854  // Second, try to fuse candidates.
855     SmallPtrSet<Instruction *, 16> FusedInsts;
856     for (CallInst *CI : MaybeFusableInsts)
857  LowerMatrixMultiplyFused(CI, FusedInsts);
858  Changed = !FusedInsts.empty();
859 
860  // Third, lower remaining instructions with shape information.
861  for (Instruction *Inst : MatrixInsts) {
862  if (FusedInsts.count(Inst))
863  continue;
864 
865  IRBuilder<> Builder(Inst);
866 
867  if (CallInst *CInst = dyn_cast<CallInst>(Inst))
868  Changed |= VisitCallInst(CInst);
869 
870  Value *Op1;
871  Value *Op2;
872  if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
873  Changed |= VisitBinaryOperator(BinOp);
874  if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
875  Changed |= VisitUnaryOperator(UnOp);
876  if (match(Inst, m_Load(m_Value(Op1))))
877  Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
878  else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
879  Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
880  }
881 
882  if (ORE) {
883  RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
884  RemarkGen.emitRemarks();
885  }
886 
887     // Delete the instructions backwards, as this reduces the likelihood of
888     // having to update as many def-use and use-def chains.
889  //
890  // Because we add to ToRemove during fusion we can't guarantee that defs
891  // are before uses. Change uses to undef temporarily as these should get
892  // removed as well.
893  //
894  // For verification, we keep track of where we changed uses to undefs in
895  // UndefedInsts and then check that we in fact remove them.
896  SmallSet<Instruction *, 16> UndefedInsts;
897  for (auto *Inst : reverse(ToRemove)) {
898  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
899  if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
900  UndefedInsts.insert(Undefed);
901  U.set(UndefValue::get(Inst->getType()));
902  }
903  Inst->eraseFromParent();
904  UndefedInsts.erase(Inst);
905  }
906  if (!UndefedInsts.empty()) {
907  // If we didn't remove all undefed instructions, it's a hard error.
908  dbgs() << "Undefed but present instructions:\n";
909  for (auto *I : UndefedInsts)
910  dbgs() << *I << "\n";
911  llvm_unreachable("Undefed but instruction not removed");
912  }
913 
914  return Changed;
915  }
916 
917  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
918  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
919  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
920  Type *EltPtrType = PointerType::get(EltType, AS);
921  return Builder.CreatePointerCast(BasePtr, EltPtrType);
922  }
923 
924  /// Replace intrinsic calls
925  bool VisitCallInst(CallInst *Inst) {
926  if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
927  return false;
928 
929  switch (Inst->getCalledFunction()->getIntrinsicID()) {
930  case Intrinsic::matrix_multiply:
931  LowerMultiply(Inst);
932  break;
933  case Intrinsic::matrix_transpose:
934  LowerTranspose(Inst);
935  break;
936  case Intrinsic::matrix_column_major_load:
937  LowerColumnMajorLoad(Inst);
938  break;
939  case Intrinsic::matrix_column_major_store:
940  LowerColumnMajorStore(Inst);
941  break;
942  default:
943  return false;
944  }
945  return true;
946  }
947 
948  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
949  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
950  /// ConstantInt, reduce the initial alignment based on the byte offset. For
951  /// non-ConstantInt strides, return the common alignment of the initial
952  /// alignment and the element size in bytes.
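  /// For example (illustrative): for double elements, an initial alignment of
  /// 16 and a constant stride of 3, the vector at \p Idx == 1 starts at byte
  /// offset 24, so commonAlignment(16, 24) == 8 is returned.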
953  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
954  MaybeAlign A) const {
955  Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
956  if (Idx == 0)
957  return InitialAlign;
958 
959  TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
960  if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
961  uint64_t StrideInBytes =
962  ConstStride->getZExtValue() * ElementSizeInBits / 8;
963  return commonAlignment(InitialAlign, Idx * StrideInBytes);
964  }
965  return commonAlignment(InitialAlign, ElementSizeInBits / 8);
966  }
967 
968  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
969  /// vectors.
970  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
971  bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
972  auto *VType = cast<VectorType>(Ty);
973  Type *EltTy = VType->getElementType();
974  Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
975  Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
976  MatrixTy Result;
977  for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
978  Value *GEP = computeVectorAddr(
979  EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
980  Stride, Shape.getStride(), EltTy, Builder);
981  Value *Vector = Builder.CreateAlignedLoad(
982  VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
983  IsVolatile, "col.load");
984 
985  Result.addVector(Vector);
986  }
987  return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
988  Result.getNumVectors());
989  }
990 
991  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
992  /// starting at \p MatrixPtr[I][J].
993  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
994  ShapeInfo MatrixShape, Value *I, Value *J,
995  ShapeInfo ResultShape, Type *EltTy,
996  IRBuilder<> &Builder) {
997 
998  Value *Offset = Builder.CreateAdd(
999  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1000 
1001  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1002  Value *EltPtr =
1003  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1004  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1005  auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1006  ResultShape.NumColumns);
1007  Type *TilePtrTy = PointerType::get(TileTy, AS);
1008  Value *TilePtr =
1009  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1010 
1011  return loadMatrix(TileTy, TilePtr, Align,
1012  Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1013  ResultShape, Builder);
1014  }
1015 
1016  /// Lower a load instruction with shape information.
1017  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1018  bool IsVolatile, ShapeInfo Shape) {
1019  IRBuilder<> Builder(Inst);
1020  finalizeLowering(Inst,
1021  loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1022  Shape, Builder),
1023  Builder);
1024  }
1025 
1026  /// Lowers llvm.matrix.column.major.load.
1027  ///
1028  /// The intrinsic loads a matrix from memory using a stride between columns.
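  /// For example (illustrative), a call such as
  ///
  ///   %m = call <6 x float> @llvm.matrix.column.major.load.v6f32.i64(
  ///            ptr %p, i64 %stride, i1 false, i32 2, i32 3)
  ///
  /// is lowered to three strided <2 x float> loads, one per column.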
1029  void LowerColumnMajorLoad(CallInst *Inst) {
1030     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1031            "Intrinsic only supports column-major layout!");
1032  Value *Ptr = Inst->getArgOperand(0);
1033  Value *Stride = Inst->getArgOperand(1);
1034  LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1035  cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1036  {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1037  }
1038 
1039  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1040  /// MatrixPtr[I][J].
1041  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1042  MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1043  Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1044  Value *Offset = Builder.CreateAdd(
1045  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1046 
1047  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1048  Value *EltPtr =
1049  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1050  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1051  auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1052  StoreVal.getNumColumns());
1053  Type *TilePtrTy = PointerType::get(TileTy, AS);
1054  Value *TilePtr =
1055  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1056 
1057  storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
1058  Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1059  }
1060 
1061  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1062  /// vectors.
1063  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1064  MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1065  IRBuilder<> &Builder) {
1066  auto VType = cast<VectorType>(Ty);
1067  Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
1068  for (auto Vec : enumerate(StoreVal.vectors())) {
1069  Value *GEP = computeVectorAddr(
1070  EltPtr,
1071  Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1072  Vec.index()),
1073  Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1074  Builder.CreateAlignedStore(Vec.value(), GEP,
1075  getAlignForIndex(Vec.index(), Stride,
1076  VType->getElementType(),
1077  MAlign),
1078  IsVolatile);
1079  }
1080  return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1081  StoreVal.getNumVectors());
1082  }
1083 
1084  /// Lower a store instruction with shape information.
1085  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1086  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1087  IRBuilder<> Builder(Inst);
1088  auto StoreVal = getMatrix(Matrix, Shape, Builder);
1089  finalizeLowering(Inst,
1090  storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1091  IsVolatile, Builder),
1092  Builder);
1093  }
1094 
1095  /// Lowers llvm.matrix.column.major.store.
1096  ///
1097   /// The intrinsic stores a matrix back to memory using a stride between columns.
1098  void LowerColumnMajorStore(CallInst *Inst) {
1099     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1100            "Intrinsic only supports column-major layout!");
1101  Value *Matrix = Inst->getArgOperand(0);
1102  Value *Ptr = Inst->getArgOperand(1);
1103  Value *Stride = Inst->getArgOperand(2);
1104  LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1105  cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1106  {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1107  }
1108 
1109  // Set elements I..I+NumElts-1 to Block
1110  Value *insertVector(Value *Col, unsigned I, Value *Block,
1111  IRBuilder<> &Builder) {
1112 
1113  // First, bring Block to the same size as Col
1114  unsigned BlockNumElts =
1115  cast<FixedVectorType>(Block->getType())->getNumElements();
1116  unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1117  assert(NumElts >= BlockNumElts && "Too few elements for current block");
1118 
1119  Block = Builder.CreateShuffleVector(
1120  Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1121 
1122  // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1123  // 8, 4, 5, 6
1124     SmallVector<int, 16> Mask;
1125     unsigned i;
1126  for (i = 0; i < I; i++)
1127  Mask.push_back(i);
1128 
1129  unsigned VecNumElts =
1130  cast<FixedVectorType>(Col->getType())->getNumElements();
1131  for (; i < I + BlockNumElts; i++)
1132  Mask.push_back(i - I + VecNumElts);
1133 
1134  for (; i < VecNumElts; i++)
1135  Mask.push_back(i);
1136 
1137  return Builder.CreateShuffleVector(Col, Block, Mask);
1138  }
1139 
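  /// Compute \p Sum + \p A * \p B, returning just the product if \p Sum is
  /// nullptr. For floating-point operands with contraction allowed this emits a
  /// single fmuladd intrinsic call; otherwise a separate multiply and add are
  /// created. \p NumComputeOps is updated with the number of emitted operations.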
1140  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1141  IRBuilder<> &Builder, bool AllowContraction,
1142  unsigned &NumComputeOps) {
1143  NumComputeOps += getNumOps(A->getType());
1144  if (!Sum)
1145  return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1146 
1147  if (UseFPOp) {
1148  if (AllowContraction) {
1149  // Use fmuladd for floating point operations and let the backend decide
1150  // if that's profitable.
1151         Function *FMulAdd = Intrinsic::getDeclaration(
1152             Func.getParent(), Intrinsic::fmuladd, A->getType());
1153  return Builder.CreateCall(FMulAdd, {A, B, Sum});
1154  }
1155  NumComputeOps += getNumOps(A->getType());
1156  Value *Mul = Builder.CreateFMul(A, B);
1157  return Builder.CreateFAdd(Sum, Mul);
1158  }
1159 
1160  NumComputeOps += getNumOps(A->getType());
1161  Value *Mul = Builder.CreateMul(A, B);
1162  return Builder.CreateAdd(Sum, Mul);
1163  }
1164 
1165  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1166  /// users with shape information, there's nothing to do: they will use the
1167  /// cached value when they are lowered. For other users, \p Matrix is
1168  /// flattened and the uses are updated to use it. Also marks \p Inst for
1169  /// deletion.
1170  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1171  IRBuilder<> &Builder) {
1172  auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1173  (void)inserted;
1174  assert(inserted.second && "multiple matrix lowering mapping");
1175 
1176  ToRemove.push_back(Inst);
1177  Value *Flattened = nullptr;
1178  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1179  if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1180  if (!Flattened)
1181  Flattened = Matrix.embedInVector(Builder);
1182  U.set(Flattened);
1183  }
1184  }
1185  }
1186 
1187  /// Compute \p Result += \p A * \p B for input matrices with left-associating
1188  /// addition.
1189  ///
1190  /// We can fold a transpose into the operand that is used to extract scalars.
1191   /// This is the first operand for row-major layouts and the second for
1192   /// column-major layouts. If \p IsScalarMatrixTransposed we assume the appropriate
1193  /// operand is transposed.
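  /// For column-major operands this computes each result column J roughly as
  ///
  ///   Result.col(J) = A.col(0) * splat(B[0][J]) + ... + A.col(M-1) * splat(B[M-1][J])
  ///
  /// processing at most VF rows of a column per block.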
1194  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1195  const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1196  bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1197  const unsigned VF = std::max<unsigned>(
1198         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1199             .getFixedSize() /
1200  Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
1201  1U);
1202  unsigned R = Result.getNumRows();
1203  unsigned C = Result.getNumColumns();
1204  unsigned M = A.getNumColumns();
1205 
1206  bool IsFP = Result.getElementType()->isFloatingPointTy();
1207  assert(A.isColumnMajor() == B.isColumnMajor() &&
1208  Result.isColumnMajor() == A.isColumnMajor() &&
1209  "operands must agree on matrix layout");
1210  unsigned NumComputeOps = 0;
1211 
1212  Builder.setFastMathFlags(FMF);
1213 
1214  if (A.isColumnMajor()) {
1215  // Multiply columns from the first operand with scalars from the second
1216  // operand. Then move along the K axes and accumulate the columns. With
1217  // this the adds can be vectorized without reassociation.
1218  for (unsigned J = 0; J < C; ++J) {
1219  unsigned BlockSize = VF;
1220  // If Result is zero, we don't need to accumulate in the K==0 iteration.
1221  bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1222 
1223  for (unsigned I = 0; I < R; I += BlockSize) {
1224  // Gradually lower the vectorization factor to cover the remainder.
1225  while (I + BlockSize > R)
1226  BlockSize /= 2;
1227 
1228  Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1229  : nullptr;
1230  for (unsigned K = 0; K < M; ++K) {
1231  Value *L = A.extractVector(I, K, BlockSize, Builder);
1232  Value *RH = Builder.CreateExtractElement(
1233  B.getColumn(IsScalarMatrixTransposed ? K : J),
1234  IsScalarMatrixTransposed ? J : K);
1235  Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1236  Sum =
1237  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1238  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1239  }
1240  Result.setVector(J,
1241  insertVector(Result.getVector(J), I, Sum, Builder));
1242  }
1243  }
1244  } else {
1245  // Multiply rows from the second operand with scalars from the first
1246  // operand. Then move along the K axes and accumulate the rows. With this
1247  // the adds can be vectorized without reassociation.
1248  for (unsigned I = 0; I < R; ++I) {
1249  unsigned BlockSize = VF;
1250  bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1251  for (unsigned J = 0; J < C; J += BlockSize) {
1252  // Gradually lower the vectorization factor to cover the remainder.
1253  while (J + BlockSize > C)
1254  BlockSize /= 2;
1255 
1256  Value *Sum = nullptr;
1257  for (unsigned K = 0; K < M; ++K) {
1258  Value *R = B.extractVector(K, J, BlockSize, Builder);
1259  Value *LH = Builder.CreateExtractElement(
1260  A.getVector(IsScalarMatrixTransposed ? K : I),
1261  IsScalarMatrixTransposed ? I : K);
1262  Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1263  Sum =
1264  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1265  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1266  }
1267  Result.setVector(I,
1268  insertVector(Result.getVector(I), J, Sum, Builder));
1269  }
1270  }
1271  }
1272  Result.addNumComputeOps(NumComputeOps);
1273  }
1274 
1275  /// Ensure that the memory in \p Load does not alias \p Store by potentially
1276   /// copying it to a new location. The new location, or otherwise the original
1277   /// one, is returned.
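  /// The emitted runtime check has roughly the following structure:
  ///
  ///   check0:     if (load.begin < store.end) goto alias_cont; else goto no_alias;
  ///   alias_cont: if (store.begin < load.end) goto copy;       else goto no_alias;
  ///   copy:       copy the loaded memory into a fresh alloca
  ///   no_alias:   phi [original pointer, check0], [original pointer, alias_cont],
  ///                   [alloca, copy]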
1278  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1279  CallInst *MatMul) {
1280     MemoryLocation StoreLoc = MemoryLocation::get(Store);
1281     MemoryLocation LoadLoc = MemoryLocation::get(Load);
1282 
1283  // If we can statically determine noalias we're good.
1284  if (AA->isNoAlias(LoadLoc, StoreLoc))
1285  return Load->getPointerOperand();
1286 
1287  // Create code to check if the memory locations of the Load and Store
1288  // overlap and if they do, copy Load's operand to a new buffer.
1289 
1290     // First, create new blocks for the two parts of the check and the copy.
1291  BasicBlock *Check0 = MatMul->getParent();
1292  // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1293  // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1294  // as we adjust Check0 and Check1's branches.
1295     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1296     for (BasicBlock *Succ : successors(Check0))
1297  DTUpdates.push_back({DT->Delete, Check0, Succ});
1298 
1299  BasicBlock *Check1 =
1300  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1301  nullptr, "alias_cont");
1302  BasicBlock *Copy =
1303  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1304  nullptr, "copy");
1305  BasicBlock *Fusion =
1306  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1307  nullptr, "no_alias");
1308 
1309  // Check if the loaded memory location begins before the end of the store
1310  // location. If the condition holds, they might overlap, otherwise they are
1311  // guaranteed to not overlap.
1312  IRBuilder<> Builder(MatMul);
1313  Check0->getTerminator()->eraseFromParent();
1314  Builder.SetInsertPoint(Check0);
1315  Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1316  Value *StoreBegin = Builder.CreatePtrToInt(
1317  const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1318  Value *StoreEnd = Builder.CreateAdd(
1319  StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1320  "store.end", true, true);
1321  Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1322  IntPtrTy, "load.begin");
1323  Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1324  Fusion);
1325 
1326  // Check if the store begins before the end of the load location. If the
1327  // condition holds, they alias, otherwise they are guaranteed to not
1328  // overlap.
1329  Check1->getTerminator()->eraseFromParent();
1330  Builder.SetInsertPoint(Check1, Check1->begin());
1331  Value *LoadEnd = Builder.CreateAdd(
1332  LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1333  "load.end", true, true);
1334  Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1335  Fusion);
1336 
1337  // Copy load operand to new alloca.
1338  Builder.SetInsertPoint(Copy, Copy->begin());
1339  auto *VT = cast<FixedVectorType>(Load->getType());
1340  // Use an array type for the alloca, to avoid potentially huge alignment
1341  // requirements for large vector types.
1342  auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1343  AllocaInst *Alloca =
1344  Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1345  Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
1346 
1347  Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
1348  Load->getAlign(), LoadLoc.Size.getValue());
1349  Builder.SetInsertPoint(Fusion, Fusion->begin());
1350  PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1351  PHI->addIncoming(Load->getPointerOperand(), Check0);
1352  PHI->addIncoming(Load->getPointerOperand(), Check1);
1353  PHI->addIncoming(BC, Copy);
1354 
1355  // Adjust DT.
1356  DTUpdates.push_back({DT->Insert, Check0, Check1});
1357  DTUpdates.push_back({DT->Insert, Check0, Fusion});
1358  DTUpdates.push_back({DT->Insert, Check1, Copy});
1359  DTUpdates.push_back({DT->Insert, Check1, Fusion});
1360  DT->applyUpdates(DTUpdates);
1361  return PHI;
1362  }
1363 
1364  bool isFusionProfitable(CallInst *MatMul) {
1365  if (ForceFusion)
1366  return true;
1367 
1368  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1369  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1370 
1371  const unsigned R = LShape.NumRows;
1372  const unsigned C = RShape.NumColumns;
1373  const unsigned M = LShape.NumColumns;
1374  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1375 
1376  const unsigned VF = std::max<unsigned>(
1377         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1378             .getFixedSize() /
1379  EltType->getPrimitiveSizeInBits().getFixedSize(),
1380  1U);
1381 
1382  // Cost model for tiling
1383  //
1384  // For tiling to be beneficial, we need reuse either along the R or
1385  // the C axis. We vectorize along the R axis so that means at least
1386  // 3 elements.
1387  // TODO: Also consider cost of copying if operands alias.
1388  if (R <= VF && C == 1)
1389  return false;
1390  // Then we need enough elements to exceed the number of vector
1391  // registers we have. Note that this is an oversimplification since
1392  // fusing also takes some extra loads which may exceed the number of
1393  // reloads necessary.
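    // For example (illustrative), for an 8x8 * 8x8 double multiply with
    // VF == 4, each operand needs (8 / 4) * 8 == 16 vector registers, i.e. 32
    // in total, which exceeds the register file on most targets, so fusion is
    // considered profitable.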
1394  unsigned Op0Regs = (R + VF - 1) / VF * M;
1395  unsigned Op1Regs = (M + VF - 1) / VF * C;
1396  return Op0Regs + Op1Regs >
1397            TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1398   }
1399 
1400  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1401  MatrixTy Res;
1402  auto *ColumType = FixedVectorType::get(EltType, R);
1403  for (unsigned I = 0; I < C; ++I)
1404  Res.addVector(ConstantAggregateZero::get(ColumType));
1405  return Res;
1406  }
1407 
1408  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1409  Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1410  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1411 
1412  // Create the main tiling loop nest.
1413  TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1414     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1415     Instruction *InsertI = cast<Instruction>(MatMul);
1416  BasicBlock *Start = InsertI->getParent();
1417  BasicBlock *End =
1418  SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1419  IRBuilder<> Builder(MatMul);
1420  BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1421 
1422  Type *TileVecTy =
1423         FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1424     MatrixTy TileResult;
1425  // Insert in the inner loop header.
1426  Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
1427  // Create PHI nodes for the result columns to accumulate across iterations.
1428  SmallVector<PHINode *, 4> ColumnPhis;
1429  for (unsigned I = 0; I < TileSize; I++) {
1430  auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1431  Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1432  TI.RowLoopHeader->getSingleSuccessor());
1433  TileResult.addVector(Phi);
1434  ColumnPhis.push_back(Phi);
1435  }
1436 
1437  // Insert in the inner loop body, which computes
1438  // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1439  Builder.SetInsertPoint(InnerBody->getTerminator());
1440  // Load tiles of the operands.
1441  MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
1442  {TileSize, TileSize}, EltType, Builder);
1443  MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
1444  {TileSize, TileSize}, EltType, Builder);
1445  emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1446  getFastMathFlags(MatMul));
1447  // Store result after the inner loop is done.
1448  Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
1449  storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1450  Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1451  TI.CurrentRow, TI.CurrentCol, EltType, Builder);
1452 
1453  for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1454  ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
1455 
1456  // Force unrolling of a few iterations of the inner loop, to make sure there
1457  // is enough work per iteration.
1458  // FIXME: The unroller should make this decision directly instead, but
1459  // currently the cost-model is not up to the task.
1460  unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1461  addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
1462  "llvm.loop.unroll.count", InnerLoopUnrollCount);
1463  }
1464 
1465  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1466  StoreInst *Store,
1467  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1468     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1469            "Tiling only supported for column-major matrixes at the moment!");
1470  if (!isFusionProfitable(MatMul))
1471  return;
1472 
1473  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1474  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1475 
1476  const unsigned R = LShape.NumRows;
1477  const unsigned C = RShape.NumColumns;
1478  const unsigned M = LShape.NumColumns;
1479  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1480 
1481  Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1482  Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1483  Value *CPtr = Store->getPointerOperand();
1484 
1485  if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1486  createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1487  else {
1488       IRBuilder<> Builder(Store);
1489       for (unsigned J = 0; J < C; J += TileSize)
1490  for (unsigned I = 0; I < R; I += TileSize) {
1491  const unsigned TileR = std::min(R - I, unsigned(TileSize));
1492  const unsigned TileC = std::min(C - J, unsigned(TileSize));
1493  MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1494 
1495  for (unsigned K = 0; K < M; K += TileSize) {
1496  const unsigned TileM = std::min(M - K, unsigned(TileSize));
1497  MatrixTy A =
1498  loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1499  LShape, Builder.getInt64(I), Builder.getInt64(K),
1500  {TileR, TileM}, EltType, Builder);
1501  MatrixTy B =
1502  loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1503  RShape, Builder.getInt64(K), Builder.getInt64(J),
1504  {TileM, TileC}, EltType, Builder);
1505  emitMatrixMultiply(Res, A, B, Builder, true, false,
1506  getFastMathFlags(MatMul));
1507  }
1508  storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1509  Builder.getInt64(I), Builder.getInt64(J), EltType,
1510  Builder);
1511  }
1512  }
1513 
1514  // Mark eliminated instructions as fused and remove them.
1515  FusedInsts.insert(Store);
1516  FusedInsts.insert(MatMul);
1517  Store->eraseFromParent();
1518  MatMul->eraseFromParent();
1519  if (LoadOp0->hasNUses(0)) {
1520  FusedInsts.insert(LoadOp0);
1521  LoadOp0->eraseFromParent();
1522  }
1523  if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1524  FusedInsts.insert(LoadOp1);
1525  LoadOp1->eraseFromParent();
1526  }
1527  }
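// Illustrative example (hypothetical IR, for exposition): the chains fused
// here look like
//
//   %a = load <16 x double>, ptr %A
//   %b = load <16 x double>, ptr %B
//   %c = call <16 x double> @llvm.matrix.multiply.v16f64.v16f64.v16f64(
//            <16 x double> %a, <16 x double> %b, i32 4, i32 4, i32 4)
//   store <16 x double> %c, ptr %C
//
// and are rewritten into TileSize x TileSize loads, multiplies and stores so
// that only small sub-matrixes of the operands are live at a time.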
1528 
1529  /// Try to lower matrix multiply chains by fusing operations.
1530  ///
1531  /// Call finalizeLowering on lowered instructions. Instructions that are
1532  /// completely eliminated by fusion are added to \p FusedInsts.
1533  void LowerMatrixMultiplyFused(CallInst *MatMul,
1534  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1535  if (!FuseMatrix || !DT)
1536  return;
1537 
1538  assert(AA && LI && "Analyses should be available");
1539 
1540  Value *A = MatMul->getArgOperand(0);
1541  Value *B = MatMul->getArgOperand(1);
1542 
1543  // We can fold the transpose into the operand that is used to fetch scalars.
1544  Value *T;
1545  if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1546  ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1547  : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1548  IRBuilder<> Builder(MatMul);
1549  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1550  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1551  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1552  const unsigned R = LShape.NumRows;
1553  const unsigned M = LShape.NumColumns;
1554  const unsigned C = RShape.NumColumns;
1555 
1556  MatrixTy MA;
1557  MatrixTy MB;
1558 
1559  Value *Transpose;
1560  if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1561  MA = getMatrix(A, ShapeInfo(R, M), Builder);
1562  MB = getMatrix(T, ShapeInfo(C, M), Builder);
1563  Transpose = B;
1564  } else {
1565  MA = getMatrix(T, ShapeInfo(R, M), Builder);
1566  MB = getMatrix(B, ShapeInfo(C, M), Builder);
1567  Transpose = A;
1568  }
1569 
1570  // Initialize the output
1571  MatrixTy Result(R, C, EltType);
1572 
1573  emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1574  getFastMathFlags(MatMul));
1575 
1576  FusedInsts.insert(MatMul);
1577  if (Transpose->hasOneUse()) {
1578  FusedInsts.insert(cast<Instruction>(Transpose));
1579  ToRemove.push_back(cast<Instruction>(Transpose));
1580  // TODO: add a fake entry for the folded instruction so that this is
1581  // included in the expression in the remark.
1582  Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1583  }
1584  finalizeLowering(MatMul, Result, Builder);
1585  return;
1586  }
1587 
1588  if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1589  return;
1590 
1591  // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1592  // since the single store user will be lowered as part of this.
1593  auto *LoadOp0 = dyn_cast<LoadInst>(A);
1594  auto *LoadOp1 = dyn_cast<LoadInst>(B);
1595  auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1596  if (LoadOp0 && LoadOp1 && Store) {
1597  // The store address must dominate the MatMul instruction, otherwise
1598  // we create invalid IR.
1599  SetVector<Value *> WorkList;
1600  WorkList.insert(Store->getOperand(1));
1601  SmallVector<Instruction *> ToHoist;
1602  for (unsigned I = 0; I != WorkList.size(); ++I) {
1603  Value *Current = WorkList[I];
1604  auto *CurrI = dyn_cast<Instruction>(Current);
1605  if (!CurrI)
1606  continue;
1607  if (isa<PHINode>(CurrI))
1608  return;
1609  if (DT->dominates(CurrI, MatMul))
1610  continue;
1611  if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1612  return;
1613  ToHoist.push_back(CurrI);
1614  WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1615  }
1616 
1617  sort(ToHoist, [this](Instruction *A, Instruction *B) {
1618  return DT->dominates(A, B);
1619  });
1620  for (Instruction *I : ToHoist)
1621  I->moveBefore(MatMul);
1622 
1623  emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1624  return;
1625  }
1626  }
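// Illustrative example (hypothetical IR, for exposition): the transpose
// folding above turns, e.g.,
//
//   %bt = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %b,
//                                                       i32 2, i32 3)
//   %c  = call <4 x float> @llvm.matrix.multiply.v4f32.v6f32.v6f32(
//             <6 x float> %a, <6 x float> %bt, i32 2, i32 3, i32 2)
//
// into a multiply that fetches the scalars of %b directly, so the transposed
// operand never needs to be materialized.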
1627 
1628  /// Lowers llvm.matrix.multiply.
1629  void LowerMultiply(CallInst *MatMul) {
1630  IRBuilder<> Builder(MatMul);
1631  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1632  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1633  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1634 
1635  const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1636  const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1637  assert(Lhs.getElementType() == Rhs.getElementType() &&
1638  "Matrix multiply argument element types do not match.");
1639 
1640  const unsigned R = LShape.NumRows;
1641  const unsigned C = RShape.NumColumns;
1642  assert(LShape.NumColumns == RShape.NumRows);
1643 
1644  // Initialize the output
1645  MatrixTy Result(R, C, EltType);
1646  assert(Lhs.getElementType() == Result.getElementType() &&
1647  "Matrix multiply result element type does not match arguments.");
1648 
1649  emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1650  getFastMathFlags(MatMul));
1651  finalizeLowering(MatMul, Result, Builder);
1652  }
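// Illustrative example (hypothetical IR, for exposition): a column-major
// 2x3 * 3x2 multiply such as
//
//   %c = call <4 x float> @llvm.matrix.multiply.v4f32.v6f32.v6f32(
//            <6 x float> %a, <6 x float> %b, i32 2, i32 3, i32 2)
//
// is expanded by emitMatrixMultiply into shufflevectors that extract the
// operand columns, splats of the scalars, and fmul/fadd chains (or fused
// multiply-adds when contraction is allowed) per result column.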
1653 
1654  /// Lowers llvm.matrix.transpose.
1655  void LowerTranspose(CallInst *Inst) {
1656  MatrixTy Result;
1657  IRBuilder<> Builder(Inst);
1658  Value *InputVal = Inst->getArgOperand(0);
1659  VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1660  ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1661  MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1662 
1663  const unsigned NewNumVecs =
1664  InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1665  const unsigned NewNumElts =
1666  InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1667 
1668  for (unsigned I = 0; I < NewNumVecs; ++I) {
1669  // Build a single result vector. First initialize it.
1670  Value *ResultVector = UndefValue::get(
1671  FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
1672  // Go through the old elements and insert each into the resulting vector.
1673  for (auto J : enumerate(InputMatrix.vectors())) {
1674  Value *Elt = Builder.CreateExtractElement(J.value(), I);
1675  // Row and column indices are transposed.
1676  ResultVector =
1677  Builder.CreateInsertElement(ResultVector, Elt, J.index());
1678  }
1679  Result.addVector(ResultVector);
1680  }
1681 
1682  // TODO: Improve estimate of operations needed for transposes. Currently we
1683  // just count the insertelement/extractelement instructions, but do not
1684  // account for later simplifications/combines.
1685  finalizeLowering(
1686  Inst,
1687  Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1688  .addNumExposedTransposes(1),
1689  Builder);
1690  }
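// Illustrative example (hypothetical IR, for exposition): transposing a 2x3
// matrix,
//
//   %t = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %m,
//                                                      i32 2, i32 3)
//
// produces 2 result vectors of 3 elements each (column-major), built from one
// extractelement/insertelement pair per element, which matches the
// 2 * NumRows * NumColumns compute ops recorded above.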
1691 
1692  /// Lower load instructions, if shape information is available.
1693  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
1694  auto I = ShapeMap.find(Inst);
1695  if (I == ShapeMap.end())
1696  return false;
1697 
1698  LowerLoad(Inst, Ptr, Inst->getAlign(),
1699  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1700  I->second);
1701  return true;
1702  }
1703 
1704  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
1705  IRBuilder<> &Builder) {
1706  auto I = ShapeMap.find(StoredVal);
1707  if (I == ShapeMap.end())
1708  return false;
1709 
1710  LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
1711  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1712  I->second);
1713  return true;
1714  }
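// Illustrative example (hypothetical, for exposition): a load or store whose
// value has a known shape, e.g. a 4x4 matrix of double, is split into one
// vector access per column, with the shape's stride separating the columns:
//
//   %m = load <16 x double>, ptr %p            ; shape 4x4
//     ; becomes four <4 x double> loads at element offsets 0, 4, 8 and 12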
1715 
1716  /// Lower binary operators, if shape information is available.
1717  bool VisitBinaryOperator(BinaryOperator *Inst) {
1718  auto I = ShapeMap.find(Inst);
1719  if (I == ShapeMap.end())
1720  return false;
1721 
1722  Value *Lhs = Inst->getOperand(0);
1723  Value *Rhs = Inst->getOperand(1);
1724 
1725  IRBuilder<> Builder(Inst);
1726  ShapeInfo &Shape = I->second;
1727 
1728  MatrixTy Result;
1729  MatrixTy A = getMatrix(Lhs, Shape, Builder);
1730  MatrixTy B = getMatrix(Rhs, Shape, Builder);
1731  assert(A.isColumnMajor() == B.isColumnMajor() &&
1732  Result.isColumnMajor() == A.isColumnMajor() &&
1733  "operands must agree on matrix layout");
1734 
1735  Builder.setFastMathFlags(getFastMathFlags(Inst));
1736 
1737  // Helper to perform binary op on vectors.
1738  auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
1739  switch (Inst->getOpcode()) {
1740  case Instruction::Add:
1741  return Builder.CreateAdd(LHS, RHS);
1742  case Instruction::Mul:
1743  return Builder.CreateMul(LHS, RHS);
1744  case Instruction::Sub:
1745  return Builder.CreateSub(LHS, RHS);
1746  case Instruction::FAdd:
1747  return Builder.CreateFAdd(LHS, RHS);
1748  case Instruction::FMul:
1749  return Builder.CreateFMul(LHS, RHS);
1750  case Instruction::FSub:
1751  return Builder.CreateFSub(LHS, RHS);
1752  default:
1753  llvm_unreachable("Unsupported binary operator for matrix");
1754  }
1755  };
1756 
1757  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1758  Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
1759 
1760  finalizeLowering(Inst,
1761  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1762  Result.getNumVectors()),
1763  Builder);
1764  return true;
1765  }
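// Illustrative example (hypothetical IR, for exposition): an elementwise op
// on values with a known 4x4 shape, e.g.
//
//   %s = fadd <16 x double> %a, %b
//
// is rewritten into four fadds on the <4 x double> column vectors, so the
// result stays in the same per-column form as the surrounding matrix code.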
1766 
1767  /// Lower unary operators, if shape information is available.
1768  bool VisitUnaryOperator(UnaryOperator *Inst) {
1769  auto I = ShapeMap.find(Inst);
1770  if (I == ShapeMap.end())
1771  return false;
1772 
1773  Value *Op = Inst->getOperand(0);
1774 
1775  IRBuilder<> Builder(Inst);
1776  ShapeInfo &Shape = I->second;
1777 
1778  MatrixTy Result;
1779  MatrixTy M = getMatrix(Op, Shape, Builder);
1780 
1781  Builder.setFastMathFlags(getFastMathFlags(Inst));
1782 
1783  // Helper to perform unary op on vectors.
1784  auto BuildVectorOp = [&Builder, Inst](Value *Op) {
1785  switch (Inst->getOpcode()) {
1786  case Instruction::FNeg:
1787  return Builder.CreateFNeg(Op);
1788  default:
1789  llvm_unreachable("Unsupported unary operator for matrix");
1790  }
1791  };
1792 
1793  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1794  Result.addVector(BuildVectorOp(M.getVector(I)));
1795 
1796  finalizeLowering(Inst,
1797  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1798  Result.getNumVectors()),
1799  Builder);
1800  return true;
1801  }
1802 
1803  /// Helper to linearize a matrix expression tree into a string. Currently
1804  /// matrix expressions are linearized by starting at an expression leaf and
1805  /// linearizing bottom up.
1806  struct ExprLinearizer {
1807  unsigned LengthToBreak = 100;
1808  std::string Str;
1809  raw_string_ostream Stream;
1810  unsigned LineLength = 0;
1811  const DataLayout &DL;
1812 
1813  /// Mapping from instructions to matrixes. It is used to identify
1814  /// matrix instructions.
1815  const MapVector<Value *, MatrixTy> &Inst2Matrix;
1816 
1817  /// Mapping from values to the leaves of all expressions that the value is
1818  /// part of.
1819  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
1820 
1821  /// Set of matrix expressions in the scope of a given DISubprogram.
1822  const SmallSetVector<Value *, 32> &ExprsInSubprogram;
1823 
1824  /// Leaf node of the expression to linearize.
1825  Value *Leaf;
1826 
1827  /// Used to keep track of sub-expressions that get reused while linearizing
1828  /// the expression. Re-used sub-expressions are marked as (reused).
1829  SmallPtrSet<Value *, 8> ReusedExprs;
1830 
1831  ExprLinearizer(const DataLayout &DL,
1832  const MapVector<Value *, MatrixTy> &Inst2Matrix,
1833  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
1834  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
1835  Value *Leaf)
1836  : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
1837  ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
1838 
1839  void indent(unsigned N) {
1840  LineLength += N;
1841  for (unsigned i = 0; i < N; i++)
1842  Stream << " ";
1843  }
1844 
1845  void lineBreak() {
1846  Stream << "\n";
1847  LineLength = 0;
1848  }
1849 
1850  void maybeIndent(unsigned Indent) {
1851  if (LineLength >= LengthToBreak)
1852  lineBreak();
1853 
1854  if (LineLength == 0)
1855  indent(Indent);
1856  }
1857 
1858  void write(StringRef S) {
1859  LineLength += S.size();
1860  Stream << S;
1861  }
1862 
1863  Value *getUnderlyingObjectThroughLoads(Value *V) {
1864  if (Value *Ptr = getPointerOperand(V))
1865  return getUnderlyingObjectThroughLoads(Ptr);
1866  else if (V->getType()->isPointerTy())
1867  return getUnderlyingObject(V);
1868  return V;
1869  }
1870 
1871  /// Returns true if \p V is a matrix value in the given subprogram.
1872  bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
1873 
1874  /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
1875  /// \p SS.
1876  void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
1877  auto M = Inst2Matrix.find(V);
1878  if (M == Inst2Matrix.end())
1879  SS << "unknown";
1880  else {
1881  SS << M->second.getNumRows();
1882  SS << "x";
1883  SS << M->second.getNumColumns();
1884  }
1885  }
1886 
1887  /// Write the called function name. Handles calls to llvm.matrix.*
1888  /// specially: we write the name, followed by the dimensions of the input
1889  /// matrixes, followed by the scalar type name.
1890  void writeFnName(CallInst *CI) {
1891  if (!CI->getCalledFunction())
1892  write("<no called fn>");
1893  else {
1894  StringRef Name = CI->getCalledFunction()->getName();
1895  if (!Name.startswith("llvm.matrix")) {
1896  write(Name);
1897  return;
1898  }
1899  auto *II = cast<IntrinsicInst>(CI);
1900  write(Intrinsic::getBaseName(II->getIntrinsicID())
1901  .drop_front(StringRef("llvm.matrix.").size()));
1902  write(".");
1903  std::string Tmp;
1904  raw_string_ostream SS(Tmp);
1905 
1906  switch (II->getIntrinsicID()) {
1907  case Intrinsic::matrix_multiply:
1908  prettyPrintMatrixType(II->getOperand(0), SS);
1909  SS << ".";
1910  prettyPrintMatrixType(II->getOperand(1), SS);
1911  SS << "." << *II->getType()->getScalarType();
1912  break;
1913  case Intrinsic::matrix_transpose:
1914  prettyPrintMatrixType(II->getOperand(0), SS);
1915  SS << "." << *II->getType()->getScalarType();
1916  break;
1917  case Intrinsic::matrix_column_major_load:
1918  prettyPrintMatrixType(II, SS);
1919  SS << "." << *II->getType()->getScalarType();
1920  break;
1921  case Intrinsic::matrix_column_major_store:
1922  prettyPrintMatrixType(II->getOperand(0), SS);
1923  SS << "." << *II->getOperand(0)->getType()->getScalarType();
1924  break;
1925  default:
1926  llvm_unreachable("Unhandled case");
1927  }
1928  SS.flush();
1929  write(Tmp);
1930  }
1931  }
1932 
1933  unsigned getNumShapeArgs(CallInst *CI) const {
1934  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
1935  switch (II->getIntrinsicID()) {
1936  case Intrinsic::matrix_multiply:
1937  return 3;
1938  case Intrinsic::matrix_transpose:
1939  return 2;
1940  case Intrinsic::matrix_column_major_load:
1941  case Intrinsic::matrix_column_major_store:
1942  return 3;
1943  default:
1944  return 0;
1945  }
1946  }
1947  return 0;
1948  }
1949 
1950  /// Special printing for values: for pointers, we print if they refer to an
1951  /// (function) external address or a stack address; for other values we
1952  /// either print the constant or "scalar"/"matrix".
1953  void write(Value *V) {
1954  V = getUnderlyingObjectThroughLoads(V);
1955  if (V->getType()->isPointerTy()) {
1956  if (isa<AllocaInst>(V)) {
1957  Stream << "stack addr";
1958  LineLength += StringRef("stack addr").size();
1959  } else {
1960  Stream << "addr";
1961  LineLength += StringRef("addr").size();
1962  }
1963  if (!V->getName().empty()) {
1964  Stream << " %" << V->getName() << "";
1965  LineLength += V->getName().size() + 2;
1966  }
1967  return;
1968  }
1969 
1970  std::string Tmp;
1971  raw_string_ostream TmpStream(Tmp);
1972 
1973  if (auto *CI = dyn_cast<ConstantInt>(V))
1974  TmpStream << CI->getValue();
1975  else if (isa<Constant>(V))
1976  TmpStream << "constant";
1977  else {
1978  if (isMatrix(V))
1979  TmpStream << "matrix";
1980  else
1981  TmpStream << "scalar";
1982  }
1983  TmpStream.flush();
1984  Tmp = std::string(StringRef(Tmp).trim());
1985  LineLength += Tmp.size();
1986  Stream << Tmp;
1987  }
1988 
1989  /// Linearize expression \p Expr starting at an indentation of \p Indent.
1990  /// Expressions that are re-used multiple times are prefixed with (reused)
1991  /// at the re-used root instruction.
1992  void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
1993  bool ParentShared) {
1994  auto *I = cast<Instruction>(Expr);
1995  maybeIndent(Indent);
1996  SmallVector<Value *, 8> Ops;
1997 
1998  // Is Expr shared with other expression leaves?
1999  bool ExprShared = false;
2000 
2001  // Deal with shared subtrees. Mark them as shared, if required.
2002  if (!ParentShared) {
2003  auto SI = Shared.find(Expr);
2004  assert(SI != Shared.end() && SI->second.count(Leaf));
2005 
2006  for (Value *S : SI->second) {
2007  if (S == Leaf)
2008  continue;
2009  DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2010  write("shared with remark at line " + std::to_string(DL.getLine()) +
2011  " column " + std::to_string(DL.getCol()) + " (");
2012  }
2013  ExprShared = SI->second.size() > 1;
2014  }
2015 
2016  bool Reused = !ReusedExprs.insert(Expr).second;
2017  if (Reused && !ParentReused)
2018  write("(reused) ");
2019 
2020  if (auto *CI = dyn_cast<CallInst>(I)) {
2021  writeFnName(CI);
2022 
2023  Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2024  } else if (isa<BitCastInst>(Expr)) {
2025  // Special case bitcasts, which are used to materialize matrixes from
2026  // non-matrix ops.
2027  write("matrix");
2028  return;
2029  } else {
2030  Ops.append(I->value_op_begin(), I->value_op_end());
2031  write(std::string(I->getOpcodeName()));
2032  }
2033 
2034  write(std::string("("));
2035 
2036  unsigned NumOpsToBreak = 1;
2037  if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2038  NumOpsToBreak = 2;
2039 
2040  for (Value *Op : Ops) {
2041  if (Ops.size() > NumOpsToBreak)
2042  lineBreak();
2043 
2044  maybeIndent(Indent + 1);
2045  if (isMatrix(Op))
2046  linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2047  else
2048  write(Op);
2049  if (Op != Ops.back())
2050  write(", ");
2051  }
2052 
2053  write(")");
2054  }
2055 
2056  const std::string &getResult() {
2057  Stream.flush();
2058  return Str;
2059  }
2060  };
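// Illustrative example (hypothetical output, for exposition): for a store of
// a multiply of two loaded matrixes, the linearizer produces a string along
// the lines of
//
//   store(
//    multiply.2x3.3x2.float(
//     load(addr %A),
//     load(addr %B)),
//    addr %C)
//
// which is appended to the optimization remark emitted by RemarkGenerator.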
2061 
2062  /// Generate remarks for matrix operations in a function. To generate remarks
2063  /// for matrix expressions, the following approach is used:
2064  /// 1. Use the inlined-at debug information to group matrix operations to the
2065  /// DISubprograms they are contained in.
2066  /// 2. Collect leaves of matrix expressions (done in
2067  /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2068  /// mapping. Leaves are lowered matrix instructions without other matrix
2069  /// users (like stores) in the current subprogram.
2070  /// 3. For each leaf, create a remark containing a linearized version of the
2071  /// matrix expression. The expression is linearized by a recursive
2072  /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2073  /// that multiple leaves can share sub-expressions. Shared subexpressions
2074  /// are explicitly marked as shared().
2075  struct RemarkGenerator {
2076  const MapVector<Value *, MatrixTy> &Inst2Matrix;
2077  OptimizationRemarkEmitter &ORE;
2078  Function &Func;
2079  const DataLayout &DL;
2080 
2081  RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2082  OptimizationRemarkEmitter &ORE, Function &Func)
2083  : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2084  DL(Func.getParent()->getDataLayout()) {}
2085 
2086  /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2087  /// instructions in Inst2Matrix returning void or without any users in
2088  /// \p ExprsInSubprogram. Currently that should only include stores.
2089  SmallVector<Value *, 4>
2090  getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2091  SmallVector<Value *, 4> Leaves;
2092  for (auto *Expr : ExprsInSubprogram)
2093  if (Expr->getType()->isVoidTy() ||
2094  !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2095  return ExprsInSubprogram.count(U);
2096  }))
2097  Leaves.push_back(Expr);
2098  return Leaves;
2099  }
2100 
2101  /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2102  /// to all visited expressions in \p Shared. Limit the matrix operations to
2103  /// the ones in \p ExprsInSubprogram.
2104  void collectSharedInfo(Value *Leaf, Value *V,
2105  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2106  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2107 
2108  if (!ExprsInSubprogram.count(V))
2109  return;
2110 
2111  auto I = Shared.insert({V, {}});
2112  I.first->second.insert(Leaf);
2113 
2114  for (Value *Op : cast<Instruction>(V)->operand_values())
2115  collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2116  }
2117 
2118  /// Calculate the exclusive and shared op counts for the expression
2119  /// starting at \p Root. Expressions used multiple times are counted once.
2120  /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2121  std::pair<OpInfoTy, OpInfoTy>
2122  sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2123  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2124  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2125  if (!ExprsInSubprogram.count(Root))
2126  return {};
2127 
2128  // Already counted this expression. Stop.
2129  if (!ReusedExprs.insert(Root).second)
2130  return {};
2131 
2132  OpInfoTy SharedCount;
2133  OpInfoTy Count;
2134 
2135  auto I = Shared.find(Root);
2136  auto CM = Inst2Matrix.find(Root);
2137  if (I->second.size() == 1)
2138  Count = CM->second.getOpInfo();
2139  else
2140  SharedCount = CM->second.getOpInfo();
2141 
2142  for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2143  auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2144  Count += C.first;
2145  SharedCount += C.second;
2146  }
2147  return {Count, SharedCount};
2148  }
2149 
2150  void emitRemarks() {
2151  if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2152  return;
2153 
2154  // Map matrix operations to their containing subprograms, by traversing
2155  // the inlinedAt chain. If the function does not have a DISubprogram, we
2156  // only map them to the containing function.
2157  MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2158  for (auto &KV : Inst2Matrix) {
2159  if (Func.getSubprogram()) {
2160  auto *I = cast<Instruction>(KV.first);
2161  DILocation *Context = I->getDebugLoc();
2162  while (Context) {
2163  auto I =
2164  Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2165  I.first->second.push_back(KV.first);
2166  Context = DebugLoc(Context).getInlinedAt();
2167  }
2168  } else {
2169  auto I = Subprog2Exprs.insert({nullptr, {}});
2170  I.first->second.push_back(KV.first);
2171  }
2172  }
2173  for (auto &KV : Subprog2Exprs) {
2174  SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2175  KV.second.end());
2176  auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2177 
2178  DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2179  for (Value *Leaf : Leaves)
2180  collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2181 
2182  // Generate remarks for each leaf.
2183  for (auto *L : Leaves) {
2184 
2185  DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2186  DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2187  while (Context) {
2188  if (getSubprogram(Context->getScope()) == KV.first) {
2189  Loc = Context;
2190  break;
2191  }
2192  Context = DebugLoc(Context).getInlinedAt();
2193  }
2194 
2195  SmallPtrSet<Value *, 8> ReusedExprs;
2196  OpInfoTy Counts, SharedCounts;
2197  std::tie(Counts, SharedCounts) =
2198  sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2199 
2200  OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2201  cast<Instruction>(L)->getParent());
2202 
2203  Rem << "Lowered with ";
2204  Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2205  << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2206  << ore::NV("NumComputeOps", Counts.NumComputeOps)
2207  << " compute ops, "
2208  << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2209  << " exposed transposes";
2210 
2211  if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2212  SharedCounts.NumComputeOps > 0) {
2213  Rem << ",\nadditionally "
2214  << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2215  << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2216  << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2217  << " compute ops"
2218  << " are shared with other expressions";
2219  }
2220 
2221  Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2222  ORE.emit(Rem);
2223  }
2224  }
2225  }
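// Illustrative example (hypothetical output, for exposition): a generated
// remark reads roughly
//
//   remark: foo.c:12:3: Lowered with 2 stores, 4 loads, 16 compute ops,
//   0 exposed transposes
//
// followed by the linearized expression returned by linearize() below.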
2226 
2227  std::string
2228  linearize(Value *L,
2229  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2230  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2231  const DataLayout &DL) {
2232  ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2233  Lin.linearizeExpr(L, 0, false, false);
2234  return Lin.getResult();
2235  }
2236  };
2237 };
2238 } // namespace
2239 
2240 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2241  FunctionAnalysisManager &AM) {
2242  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2243  OptimizationRemarkEmitter *ORE = nullptr;
2244  AAResults *AA = nullptr;
2245  DominatorTree *DT = nullptr;
2246  LoopInfo *LI = nullptr;
2247 
2248  if (!Minimal) {
2249  ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2250  AA = &AM.getResult<AAManager>(F);
2251  DT = &AM.getResult<DominatorTreeAnalysis>(F);
2252  LI = &AM.getResult<LoopAnalysis>(F);
2253  }
2254 
2255  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2256  if (LMT.Visit()) {
2257  PreservedAnalyses PA;
2258  if (!Minimal) {
2259  PA.preserve<LoopAnalysis>();
2260  PA.preserve<DominatorTreeAnalysis>();
2261  }
2262  return PA;
2263  }
2264  return PreservedAnalyses::all();
2265 }
2266 
2267 void LowerMatrixIntrinsicsPass::printPipeline(
2268  raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2269  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2270  OS, MapClassName2PassName);
2271  OS << "<";
2272  if (Minimal)
2273  OS << "minimal";
2274  OS << ">";
2275 }
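// Illustrative usage (assumed pipeline names, for exposition): with the new
// pass manager the pass can be requested explicitly, e.g.
//
//   opt -passes='lower-matrix-intrinsics' in.ll -S
//   opt -passes='lower-matrix-intrinsics<minimal>' in.ll -S
//
// where the second form corresponds to the "minimal" option printed above.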
2276 
2277 namespace {
2278 
2279 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
2280 public:
2281  static char ID;
2282 
2283  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
2284  initializeLowerMatrixIntrinsicsLegacyPassPass(
2285  *PassRegistry::getPassRegistry());
2286  }
2287 
2288  bool runOnFunction(Function &F) override {
2289  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2290  auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
2291  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
2292  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2293  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
2294  LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
2295  bool C = LMT.Visit();
2296  return C;
2297  }
2298 
2299  void getAnalysisUsage(AnalysisUsage &AU) const override {
2300  AU.addRequired<TargetTransformInfoWrapperPass>();
2301  AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
2302  AU.addRequired<AAResultsWrapperPass>();
2303  AU.addRequired<DominatorTreeWrapperPass>();
2304  AU.addPreserved<DominatorTreeWrapperPass>();
2305  AU.addRequired<LoopInfoWrapperPass>();
2306  AU.addPreserved<LoopInfoWrapperPass>();
2307  }
2308 };
2309 } // namespace
2310 
2311 static const char pass_name[] = "Lower the matrix intrinsics";
2312 char LowerMatrixIntrinsicsLegacyPass::ID = 0;
2313 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2314  false, false)
2315 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
2316 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
2317 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
2318 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
2319 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2320  false, false)
2321 
2322 Pass *llvm::createLowerMatrixIntrinsicsPass() {
2323  return new LowerMatrixIntrinsicsLegacyPass();
2324 }
2325 
2326 namespace {
2327 
2328 /// A lightweight version of the matrix lowering pass that only requires TTI.
2329 /// Advanced features that require DT, AA or ORE like tiling are disabled. This
2330 /// is used to lower matrix intrinsics if the main lowering pass is not run, for
2331 /// example with -O0.
2332 class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
2333 public:
2334  static char ID;
2335 
2336  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
2337  initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
2338  *PassRegistry::getPassRegistry());
2339  }
2340 
2341  bool runOnFunction(Function &F) override {
2342  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2343  LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
2344  bool C = LMT.Visit();
2345  return C;
2346  }
2347 
2348  void getAnalysisUsage(AnalysisUsage &AU) const override {
2349  AU.addRequired<TargetTransformInfoWrapperPass>();
2350  AU.setPreservesCFG();
2351  }
2352 };
2353 } // namespace
2354 
2355 static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
2356 char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
2357 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
2358  "lower-matrix-intrinsics-minimal", pass_name_minimal,
2359  false, false)
2360 INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
2361  "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
2362  false)
2363 
2364 Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
2365  return new LowerMatrixIntrinsicsMinimalLegacyPass();
2366 }
llvm::Intrinsic::getBaseName
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Function.cpp:875
i
i
Definition: README.txt:29
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
pass_name
static const char pass_name[]
Definition: LowerMatrixIntrinsics.cpp:2311
llvm::AAManager
A manager for alias analyses.
Definition: AliasAnalysis.h:1303
llvm::TargetIRAnalysis
Analysis pass providing the TargetTransformInfo.
Definition: TargetTransformInfo.h:2485
llvm::Function::isIntrinsic
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:210
llvm::MemoryLocation::get
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
Definition: MemoryLocation.cpp:35
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:17
BlockSize
static const int BlockSize
Definition: TarWriter.cpp:33
llvm::DILocalScope::getSubprogram
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Definition: DebugInfoMetadata.cpp:930
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::make_range
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
Definition: iterator_range.h:53
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:113
llvm::Value::hasOneUse
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
llvm::Intrinsic::getDeclaration
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1419
llvm::User::operands
op_range operands()
Definition: User.h:242
llvm::X86II::TA
@ TA
Definition: X86BaseInfo.h:808
llvm::AllocaInst::getAlign
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:121
IntrinsicInst.h
llvm::Type::isPointerTy
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:218
ceil
We have fiadd patterns now but the followings have the same cost and complexity We need a way to specify the later is more profitable def def The FP stackifier should handle simple permutates to reduce number of shuffle e g ceil
Definition: README-FPStack.txt:54
llvm::AnalysisManager::getResult
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:780
DebugInfoMetadata.h
llvm::MemoryLocation::Ptr
const Value * Ptr
The address of the start of the location.
Definition: MemoryLocation.h:218
llvm::ValueMap::end
iterator end()
Definition: ValueMap.h:136
Scalar.h
llvm::PassInfoMixin< LowerMatrixIntrinsicsPass >
LowerLoad
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition: X86ISelLowering.cpp:25567
llvm::TypeSize::getFixedSize
ScalarTy getFixedSize() const
Definition: TypeSize.h:444
T
llvm::Function
Definition: Function.h:60
Pass.h
llvm::TargetTransformInfo::getRegisterBitWidth
TypeSize getRegisterBitWidth(RegisterKind K) const
Definition: TargetTransformInfo.cpp:628
llvm::IntrinsicInst::getIntrinsicID
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:53
LowerStore
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition: X86ISelLowering.cpp:25478
llvm::raw_string_ostream
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:632
llvm::write
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs)
Definition: DWP.cpp:536
llvm::PointerType::get
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Definition: Type.cpp:727
llvm::SetVector::size
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:77
llvm::Type::getScalarType
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:309
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1185
llvm::cast
decltype(auto) LLVM_NODISCARD cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition: Casting.h:565
llvm::enumerate
detail::enumerator< R > enumerate(R &&TheRange)
Given an input range, returns a new range whose values are are pair (A,B) such that A is the 0-based ...
Definition: STLExtras.h:2057
llvm::PatternMatch::m_Load
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
Definition: PatternMatch.h:1557
llvm::TargetTransformInfo
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
Definition: TargetTransformInfo.h:168
insertVector
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2190
ToRemove
ReachingDefAnalysis InstSet & ToRemove
Definition: ARMLowOverheadLoops.cpp:542
llvm::IRBuilder<>
DomTreeUpdater.h
ValueTracking.h
OptimizationRemarkEmitter.h
llvm::DominatorTree
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:166
llvm::OptimizationRemarkEmitter::allowExtraAnalysis
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
Definition: OptimizationRemarkEmitter.h:98
llvm::cl::Hidden
@ Hidden
Definition: CommandLine.h:139
llvm::DILocation
Debug location.
Definition: DebugInfoMetadata.h:1557
ForceFusion
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
llvm::MemoryLocation::Size
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
Definition: MemoryLocation.h:227
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::DebugLoc::getInlinedAt
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:39
llvm::reverse
auto reverse(ContainerTy &&C, std::enable_if_t< has_rbegin< ContainerTy >::value > *=nullptr)
Definition: STLExtras.h:380
llvm::sys::path::end
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:235
llvm::sys::path::begin
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:226
llvm::LoopInfoWrapperPass
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:1287
llvm::DominatorTreeBase::Insert
static constexpr UpdateKind Insert
Definition: GenericDomTree.h:242
T1
#define T1
Definition: Mips16ISelLowering.cpp:340
llvm::SmallSet
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:136
llvm::operator!=
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:1992
Vector
So we should use XX3Form_Rcr to implement intrinsic Convert DP outs ins xscvdpsp No builtin are required Round &Convert QP DP(dword[1] is set to zero) No builtin are required Round to Quad Precision because you need to assign rounding mode in instruction Provide builtin(set f128:$vT,(int_ppc_vsx_xsrqpi f128:$vB))(set f128 yields< n x< ty > >< result > yields< ty >< result > No builtin are required Load Store Vector
Definition: README_P9.txt:497
vectors
hexagon Hexagon specific predictive commoning for HVX vectors
Definition: HexagonVectorLoopCarriedReuse.cpp:221
llvm::MapVector
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:37
llvm::SmallPtrSet
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
llvm::ore::NV
DiagnosticInfoOptimizationBase::Argument NV
Definition: OptimizationRemarkEmitter.h:136
llvm::VectorType::getElementType
Type * getElementType() const
Definition: DerivedTypes.h:422
llvm::Value::user_begin
user_iterator user_begin()
Definition: Value.h:397
llvm::successors
auto successors(MachineBasicBlock *BB)
Definition: MachineSSAContext.h:29
llvm::CallBase::arg_begin
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1316
llvm::SmallVectorImpl::pop_back_val
LLVM_NODISCARD T pop_back_val()
Definition: SmallVector.h:654
RHS
Value * RHS
Definition: X86PartialReduction.cpp:76
llvm::UnaryOperator
Definition: InstrTypes.h:101
llvm::FastMathFlags
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:21
llvm::initializeLowerMatrixIntrinsicsMinimalLegacyPassPass
void initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(PassRegistry &)
llvm::LoadInst::getAlign
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:216
llvm::DominatorTreeBase::applyUpdates
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
Definition: GenericDomTree.h:544
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:101
llvm::NVPTX::PTXLdStInstCode::VecType
VecType
Definition: NVPTX.h:121
llvm::commonAlignment
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:213
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::RISCVFenceField::R
@ R
Definition: RISCVBaseInfo.h:241
llvm::DomTreeUpdater::UpdateStrategy::Lazy
@ Lazy
llvm::TileInfo
A helper struct to create IR loop nests for tiling in IR of the following form: for CurrentColumn = 0...
Definition: MatrixUtils.h:31
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:55
getSubprogram
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
Definition: LowerMatrixIntrinsics.cpp:85
AliasAnalysis.h
MatrixBuilder.h
Context
LLVMContext & Context
Definition: NVVMIntrRange.cpp:66
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
llvm::DominatorTree::dominates
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
llvm::BitmaskEnumDetail::Mask
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:80
CommandLine.h
llvm::UnaryOperator::getOpcode
UnaryOps getOpcode() const
Definition: InstrTypes.h:171
llvm::SPII::Load
@ Load
Definition: SparcInstrInfo.h:32
extractVector
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2168
LHS
Value * LHS
Definition: X86PartialReduction.cpp:75
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
llvm::Intrinsic::getType
FunctionType * getType(LLVMContext &Context, ID id, ArrayRef< Type * > Tys=None)
Return the function type for an intrinsic.
Definition: Function.cpp:1375
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:31
MatrixUtils.h
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
isZero
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:524
llvm::AAResults
Definition: AliasAnalysis.h:511
E
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
llvm::SmallVectorImpl::append
void append(in_iter in_start, in_iter in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:667
llvm::createLowerMatrixIntrinsicsPass
Pass * createLowerMatrixIntrinsicsPass()
Definition: LowerMatrixIntrinsics.cpp:2322
llvm::User
Definition: User.h:44
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
llvm::ARM_PROC::A
@ A
Definition: ARMBaseInfo.h:34
TileSize
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
llvm::CallBase::getCalledFunction
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1396
llvm::BasicBlock::begin
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:297
llvm::operator+=
std::string & operator+=(std::string &buffer, StringRef string)
Definition: StringRef.h:964
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
llvm::LocationSize::getValue
uint64_t getValue() const
Definition: MemoryLocation.h:159
llvm::ms_demangle::QualifierMangleMode::Result
@ Result
llvm::BitTracker
Definition: BitTracker.h:35
llvm::Value::uses
iterator_range< use_iterator > uses()
Definition: Value.h:376
false
Definition: StackSlotColoring.cpp:141
llvm::MaybeAlign
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
B
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
llvm::BinaryOperator::getOpcode
BinaryOps getOpcode() const
Definition: InstrTypes.h:392
llvm::PatternMatch::m_ConstantInt
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:147
llvm::Instruction
Definition: Instruction.h:42
llvm::Type::getScalarSizeInBits
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition: Type.cpp:189
llvm::DominatorTreeWrapperPass
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:302
llvm::raw_ostream
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:54
llvm::LowerMatrixIntrinsicsPass::run
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition: LowerMatrixIntrinsics.cpp:2240
llvm::codeview::EncodedFramePtrReg::BasePtr
@ BasePtr
llvm::UndefValue::get
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1768
llvm::DomTreeUpdater
Definition: DomTreeUpdater.h:28
llvm::ConstantInt::get
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:926
LoopUtils.h
llvm::getUnderlyingObject
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments and pointer casts from the specified value,...
Definition: ValueTracking.cpp:4450
PatternMatch.h
llvm::TargetTransformInfo::RGK_FixedWidthVector
@ RGK_FixedWidthVector
Definition: TargetTransformInfo.h:933
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:684
llvm::StoreInst::getAlign
Align getAlign() const
Definition: Instructions.h:341
llvm::Align
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
llvm::FastMathFlags::setAllowContract
void setAllowContract(bool B=true)
Definition: FMF.h:92
llvm::SPII::Store
@ Store
Definition: SparcInstrInfo.h:33
llvm::ValueMap::count
size_type count(const KeyT &Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: ValueMap.h:152
llvm::Value::use_empty
bool use_empty() const
Definition: Value.h:344
llvm::CallingConv::ID
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
MatrixLayoutTy
MatrixLayoutTy
Definition: LowerMatrixIntrinsics.cpp:73
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:58
CFG.h
LoopInfo.h
llvm::function_ref
An efficient, type-erasing, non-owning reference to a callable.
Definition: STLFunctionalExtras.h:36
llvm::VectorType
Base class of all SIMD vector types.
Definition: DerivedTypes.h:389
llvm::X86AS::SS
@ SS
Definition: X86.h:193
VectorUtils.h
llvm::cl::opt< bool >
llvm::StoreInst::isVolatile
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:333
llvm::StoreInst
An instruction for storing to memory.
Definition: Instructions.h:297
llvm::cl::values
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition: CommandLine.h:685
llvm::Instruction::eraseFromParent
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:77
llvm::AMDGPU::Hwreg::Offset
Offset
Definition: SIDefines.h:416
llvm::LowerMatrixIntrinsicsPass::printPipeline
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
Definition: LowerMatrixIntrinsics.cpp:2267
llvm::StringRef::empty
constexpr LLVM_NODISCARD bool empty() const
empty - Check if the string is empty.
Definition: StringRef.h:153
llvm::MapVector::find
iterator find(const KeyT &Key)
Definition: MapVector.h:148
llvm::getPointerOperand
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
Definition: Instructions.h:5331
uint64_t
llvm::TargetTransformInfoWrapperPass
Wrapper pass for TargetTransformInfo.
Definition: TargetTransformInfo.h:2541
D
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
intrinsics
expand Expand reduction intrinsics
Definition: ExpandReductions.cpp:198
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
INITIALIZE_PASS_DEPENDENCY
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
llvm::PreservedAnalyses::preserve
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:173
llvm::PHINode::addIncoming
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Definition: Instructions.h:2801
llvm::addStringMetadataToLoop
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition: LoopUtils.cpp:217
llvm::TargetTransformInfo::getRegisterClassForType
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
Definition: TargetTransformInfo.cpp:619
llvm::omp::AddressSpace::Shared
@ Shared
llvm::DenseMap
Definition: DenseMap.h:716
llvm::PatternMatch::m_Store
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
Definition: PatternMatch.h:1564
llvm::codeview::FrameCookieKind::Copy
@ Copy
llvm::ValueMap::erase
bool erase(const KeyT &Val)
Definition: ValueMap.h:191
llvm::LoopInfoBase::getLoopFor
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Definition: LoopInfo.h:986
I
#define I(x, y, z)
Definition: MD5.cpp:58
llvm::cl::init
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:432
llvm::make_early_inc_range
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:618
llvm::concatenateVectors
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
Definition: VectorUtils.cpp:989
IRBuilder.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::SPIRV::Decoration::Stream
@ Stream
SI
StandardInstrumentations SI(Debug, VerifyEach)
FuseMatrix
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
llvm::OptimizationRemarkEmitter::emit
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Definition: OptimizationRemarkEmitter.cpp:77
llvm::ValueMap::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: ValueMap.h:173
llvm::operator==
bool operator==(uint64_t V1, const APInt &V2)
Definition: APInt.h:1990
llvm::TTI
TargetTransformInfo TTI
Definition: TargetTransformInfo.h:163
llvm::Type::isVoidTy
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:139
llvm::ArrayType::get
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:638
MatrixLayoutTy::ColumnMajor
@ ColumnMajor
LowerMatrixIntrinsics.h
llvm::CallBase::arg_end
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1322
llvm::SmallPtrSetImpl::count
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:383
llvm::SmallSet::erase
bool erase(const T &V)
Definition: SmallSet.h:209
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:651
llvm::PatternMatch::m_Value
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
llvm::SetVector::insert
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:141
llvm::size
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1598
llvm::Function::getIntrinsicID
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:205
llvm::ArrayRef
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: APInt.h:32
llvm::LoopInfo
Definition: LoopInfo.h:1102
llvm::Instruction::getFastMathFlags
FastMathFlags getFastMathFlags() const
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
Definition: Instruction.cpp:289
llvm::BinaryOperator
Definition: InstrTypes.h:188
llvm::OptimizationRemarkEmitter
The optimization diagnostic interface.
Definition: OptimizationRemarkEmitter.h:33
Matrix
Live Register Matrix
Definition: LiveRegMatrix.cpp:44
llvm::min
Expected< ExpressionValue > min(const ExpressionValue &Lhs, const ExpressionValue &Rhs)
Definition: FileCheck.cpp:357
llvm::MapVector::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:118
Mul
BinaryOperator * Mul
Definition: X86PartialReduction.cpp:70
llvm::any_of
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1624
llvm::CallBase::getParamAlign
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1742
DataLayout.h
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:263
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:58
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition: ErrorHandling.h:143
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::AnalysisUsage::addPreserved
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
Definition: PassAnalysisSupport.h:98
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:529
getParent
static const Function * getParent(const Value *V)
Definition: BasicAliasAnalysis.cpp:845
llvm::ms_demangle::IntrinsicFunctionKind::New
@ New
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition: AArch64SLSHardening.cpp:76
clEnumValN
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:660
llvm::MatrixBuilder
Definition: MatrixBuilder.h:33
S
add sub stmia L5 ldr r0 bl L_printf $stub Instead of a and a wouldn t it be better to do three moves *Return an aggregate type is even return S
Definition: README.txt:210
pass_name_minimal
static const char pass_name_minimal[]
Definition: LowerMatrixIntrinsics.cpp:2355
llvm::TargetTransformInfo::getNumberOfRegisters
unsigned getNumberOfRegisters(unsigned ClassID) const
Definition: TargetTransformInfo.cpp:615
llvm::AMDGPU::HSAMD::Kernel::Arg::Key::IsVolatile
constexpr char IsVolatile[]
Key for Kernel::Arg::Metadata::mIsVolatile.
Definition: AMDGPUMetadata.h:199
llvm::SmallSet::insert
std::pair< NoneType, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:182
llvm::ifs::IFSSymbolType::Func
@ Func
llvm::Value::getName
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:305
llvm::ValueMap
See the file comment.
Definition: ValueMap.h:85
llvm::LoadInst
An instruction for reading from memory.
Definition: Instructions.h:173
users
iv users
Definition: IVUsers.cpp:48
llvm::MapVector::end
iterator end()
Definition: MapVector.h:72
llvm::ConstantInt::getZExtValue
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:142
llvm::RecurKind::FMulAdd
@ FMulAdd
Fused multiply-add of floats (a * b + c).
llvm::StringRef::size
constexpr LLVM_NODISCARD size_t size() const
size - Get the string size.
Definition: StringRef.h:157
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:69
llvm::Twine
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:83
Alignment.h
llvm::OptimizationRemarkEmitterWrapperPass
OptimizationRemarkEmitter legacy analysis pass.
Definition: OptimizationRemarkEmitter.h:146
MatrixLayout
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
llvm::GraphProgram::Name
Name
Definition: GraphWriter.h:50
AllowContractEnabled
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
llvm::StringRef::drop_front
LLVM_NODISCARD StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:657
llvm::DIScope
Base class for scope-like contexts.
Definition: DebugInfoMetadata.h:471
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:348
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
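A hypothetical new-pass-manager skeleton showing where PreservedAnalyses::all fits: it is returned when the pass made no change, so every cached analysis stays valid. The pass name is invented for illustration.

#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"

using namespace llvm;

namespace {
struct ExampleLoweringPass : PassInfoMixin<ExampleLoweringPass> {
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
    bool Changed = false; // a real pass would transform F here and set this
    if (!Changed)
      return PreservedAnalyses::all(); // nothing invalidated
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>(); // the transform left the CFG shape intact
    return PA;
  }
};
} // namespace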
MatrixLayoutTy::RowMajor
@ RowMajor
llvm::TypeSize
Definition: TypeSize.h:435
Function.h
llvm::Value::hasNUses
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:145
llvm::sort
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1562
llvm::SetVector< T, SmallVector< T, N >, SmallDenseSet< T, N > >::count
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:215
llvm::Type::getPointerTo
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
Definition: Type.cpp:774
llvm::ValueMap::find
iterator find(const KeyT &Val)
Definition: ValueMap.h:156
llvm::MCID::Add
@ Add
Definition: MCInstrDesc.h:185
llvm::ReversePostOrderTraversal
Definition: PostOrderIterator.h:291
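A short sketch, not from this file, of iterating a function's blocks in reverse post-order so that predecessors are generally visited before their successors; the helper name is made up.

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Function.h"

using namespace llvm;

// Walk every reachable basic block of F in reverse post-order.
static void visitInRPO(Function &F) {
  ReversePostOrderTraversal<Function *> RPOT(&F);
  for (BasicBlock *BB : RPOT) {
    (void)BB; // process BB here
  }
}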
llvm::IntrinsicInst
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:46
AA
llvm::DominatorTreeAnalysis
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:267
llvm::BasicBlock::reverse_iterator
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:89
llvm::OptimizationRemark
Diagnostic information for applied optimization remarks.
Definition: DiagnosticInfo.h:690
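An illustrative sketch of emitting an applied-optimization remark through an OptimizationRemarkEmitter; the pass name, remark name, and message are placeholders, not the remarks this file actually emits.

#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Attach a remark to I's debug location. The lambda only runs when remark
// output is enabled (e.g. via -Rpass or serialized remark files).
static void reportLowered(OptimizationRemarkEmitter &ORE, Instruction *I) {
  ORE.emit([&]() {
    return OptimizationRemark("example-pass", "Lowered", I)
           << "lowered one matrix operation";
  });
}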
llvm::Pass
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:91
Instructions.h
PostOrderIterator.h
llvm::LoadInst::isVolatile
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:210
llvm::initializeLowerMatrixIntrinsicsLegacyPassPass
void initializeLowerMatrixIntrinsicsLegacyPassPass(PassRegistry &)
SmallVector.h
BT
BitTracker BT
Definition: BitTracker.cpp:73
llvm::CallBase::getArgOperand
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1341
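A sketch of reading the dimension operands off an llvm.matrix.multiply call with getArgOperand; the operand positions follow the LangRef description of the intrinsic at the time of writing, and the helper name is hypothetical.

#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"

using namespace llvm;

// Operands 0-1 of llvm.matrix.multiply are the flattened matrices; operands
// 2-4 are the constant row/inner/column counts.
static bool getMultiplyDims(const IntrinsicInst *II, uint64_t &LHSRows,
                            uint64_t &Inner, uint64_t &RHSCols) {
  if (II->getIntrinsicID() != Intrinsic::matrix_multiply)
    return false;
  LHSRows = cast<ConstantInt>(II->getArgOperand(2))->getZExtValue();
  Inner = cast<ConstantInt>(II->getArgOperand(3))->getZExtValue();
  RHSCols = cast<ConstantInt>(II->getArgOperand(4))->getZExtValue();
  return true;
}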
llvm::SmallPtrSetImplBase::empty
LLVM_NODISCARD bool empty() const
Definition: SmallPtrSet.h:92
N
#define N
llvm::AAResultsWrapperPass
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
Definition: AliasAnalysis.h:1351
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:91
llvm::to_string
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:85
TargetTransformInfo.h
llvm::iterator_range
A range adaptor for a pair of iterators.
Definition: iterator_range.h:30
llvm::PHINode
Definition: Instructions.h:2651
llvm::BasicBlock::getTerminator
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
Definition: BasicBlock.h:119
DEBUG_TYPE
#define DEBUG_TYPE
Definition: LowerMatrixIntrinsics.cpp:52
minimal
lower matrix intrinsics minimal
Definition: LowerMatrixIntrinsics.cpp:2361
llvm::DISubprogram
Subprogram description.
Definition: DebugInfoMetadata.h:1803
llvm::SmallSet::empty
LLVM_NODISCARD bool empty() const
Definition: SmallSet.h:157
llvm::SmallVectorImpl< Instruction * >
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, false, false) INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass
llvm::SmallPtrSetImpl< Instruction * >
llvm::SmallSetVector
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:307
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:42
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:308
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1461
GEP
Hexagon Common GEP
Definition: HexagonCommonGEP.cpp:171
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
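A hypothetical legacy-pass-manager skeleton tying together addRequired, addPreserved, and setPreservesCFG; the pass name is invented, and the specific analyses are just examples.

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/Pass.h"

using namespace llvm;

namespace {
struct ExampleLegacyPass : FunctionPass {
  static char ID;
  ExampleLegacyPass() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();  // needed as input
    AU.addPreserved<DominatorTreeWrapperPass>();       // kept up to date
    AU.setPreservesCFG();                              // no CFG changes
  }

  bool runOnFunction(Function &F) override { return false; }
};
} // namespace

char ExampleLegacyPass::ID = 0;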
TileUseLoops
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
llvm::DebugLoc
A debug info location.
Definition: DebugLoc.h:33
llvm::AllocaInst
an instruction to allocate memory on the stack
Definition: Instructions.h:58
llvm::FastMathFlags::allowContract
bool allowContract() const
Definition: FMF.h:71
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition: User.h:169
llvm::cl::desc
Definition: CommandLine.h:405
llvm::createLowerMatrixIntrinsicsMinimalPass
Pass * createLowerMatrixIntrinsicsMinimalPass()
Definition: LowerMatrixIntrinsics.cpp:2364
llvm::createSequentialMask
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
Definition: VectorUtils.cpp:934
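A small sketch of using createSequentialMask with a single-operand shuffle to pull a contiguous slice out of a flattened matrix vector; the helper name and parameters are illustrative.

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// Extract NumElts consecutive elements starting at Offset from the flat
// vector Mat, e.g. one column of a column-major matrix.
static Value *extractSlice(IRBuilder<> &Builder, Value *Mat, unsigned Offset,
                           unsigned NumElts) {
  SmallVector<int, 16> Mask =
      createSequentialMask(/*Start=*/Offset, /*NumInts=*/NumElts,
                           /*NumUndefs=*/0);
  return Builder.CreateShuffleVector(Mat, Mask, "slice");
}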
llvm::SetVector< Value * >
llvm::ConstantAggregateZero::get
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1647
BasicBlockUtils.h
llvm::SplitBlock
BasicBlock * SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Definition: BasicBlockUtils.cpp:837
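A sketch of splitting a block at a given instruction while keeping the dominator tree (and optionally LoopInfo) consistent; the wrapper name and block name are made up for illustration.

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

// Split I's block so that I becomes the first instruction of the returned
// successor block; DT and LI are updated for the new edge.
static BasicBlock *splitAt(Instruction *I, DominatorTree *DT, LoopInfo *LI) {
  return SplitBlock(I->getParent(), I, DT, LI, /*MSSAU=*/nullptr,
                    "split.cont");
}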
llvm::pdb::PDB_SymType::Block
@ Block
InitializePasses.h
llvm::OptimizationRemarkEmitterAnalysis
Definition: OptimizationRemarkEmitter.h:164
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
Debug.h
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:421
llvm::MemoryLocation
Representation for a specific memory location.
Definition: MemoryLocation.h:210
llvm::LoopAnalysis
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:1262
llvm::DominatorTreeBase::Delete
static constexpr UpdateKind Delete
Definition: GenericDomTree.h:243
llvm::Use
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
llvm::Type::getPrimitiveSizeInBits
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition: Type.cpp:164
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:38