1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 // * Improve fusion:
13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14 // transposed.
15 // * Improve cost-modeling, e.g. choose different number of rows/columns
16 //    for tiles, consider cost of copies on alias.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21 #include "llvm/ADT/GraphTraits.h"
22 #include "llvm/ADT/PostOrderIterator.h"
23 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/IR/CFG.h"
31 #include "llvm/IR/DataLayout.h"
32 #include "llvm/IR/DebugInfoMetadata.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/Instructions.h"
36 #include "llvm/IR/IntrinsicInst.h"
37 #include "llvm/IR/MatrixBuilder.h"
38 #include "llvm/IR/PatternMatch.h"
39 #include "llvm/InitializePasses.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Alignment.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Transforms/Scalar.h"
45 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
46 #include "llvm/Transforms/Utils/LoopUtils.h"
47 #include "llvm/Transforms/Utils/MatrixUtils.h"
48 
49 using namespace llvm;
50 using namespace PatternMatch;
51 
52 #define DEBUG_TYPE "lower-matrix-intrinsics"
53 
54 static cl::opt<bool>
55  FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
56  cl::desc("Enable/disable fusing matrix instructions."));
57 // TODO: Allow and use non-square tiles.
58 static cl::opt<unsigned> TileSize(
59  "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
60  cl::desc(
61  "Tile size for matrix instruction fusion using square-shaped tiles."));
62 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
63  cl::Hidden,
64  cl::desc("Generate loop nest for tiling."));
65 static cl::opt<bool> ForceFusion(
66  "force-fuse-matrix", cl::init(false), cl::Hidden,
67  cl::desc("Force matrix instruction fusion even if not profitable."));
68 static cl::opt<bool> AllowContractEnabled(
69  "matrix-allow-contract", cl::init(false), cl::Hidden,
70  cl::desc("Allow the use of FMAs if available and profitable. This may "
71  "result in different results, due to less rounding error."));
72 
73 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
74 
75 static cl::opt<MatrixLayoutTy> MatrixLayout(
76  "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
77  cl::desc("Sets the default matrix layout"),
78  cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
79  "Use column-major layout"),
80  clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
81  "Use row-major layout")));
82 
83 /// Helper function to either return \p Scope, if it is a subprogram, or the
84 /// attached subprogram for a local scope.
85 static DISubprogram *getSubprogram(DIScope *Scope) {
86  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
87  return Subprogram;
88  return cast<DILocalScope>(Scope)->getSubprogram();
89 }
90 
91 namespace {
92 
93 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
94 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
95 // assuming \p Stride elements between the starts of two consecutive vectors.
96 // \p Stride must be >= \p NumElements.
97 // For column-major matrixes, the function computes the address of a column
98 // vector and \p NumElements must be set to the number of elements in a column
99 // (= number of rows of the matrix). For row-major matrixes, the function
100 // computes the address of a row vector and \p NumElements must be set to the
101 // number of elements in a row (= number of columns of the matrix).
102 //
103 // Consider a 4x4 matrix in column-major layout like below
104 //
105 //      0       1      2      3
106 // 0   v_0_0  v_0_1  v_0_2  v_0_3
107 // 1   v_1_0  v_1_1  v_1_2  v_1_3
108 // 2   v_2_0  v_2_1  v_2_2  v_2_3
109 // 3   v_3_0  v_3_1  v_3_2  v_3_3
110 
111 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
112 // we need a pointer to the first element of the submatrix as base pointer.
113 // Then we can use computeVectorAddr to compute the addresses for the columns
114 // of the sub-matrix.
115 //
116 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
117 // -> just returns Base
118 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
119 // -> returns Base + (1 * 4)
120 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
121 // -> returns Base + (2 * 4)
122 //
123 // The graphic below illustrates the number of elements in a column (marked
124 // with |) and the number of skipped elements (marked with }).
125 //
126 // v_0_0  v_0_1 {v_0_2 {v_0_3
127 //         Base  Col 1  Col 2
128 //          |      |      |
129 // v_1_0 |v_1_1 |v_1_2 |v_1_3
130 // v_2_0 |v_2_1 |v_2_2 |v_2_3
131 // v_3_0 {v_3_1 {v_3_2  v_3_3
132 //
133 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
134  unsigned NumElements, Type *EltType,
135  IRBuilder<> &Builder) {
136 
137  assert((!isa<ConstantInt>(Stride) ||
138  cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
139  "Stride must be >= the number of elements in the result vector.");
140  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
141 
142  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
143  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
144 
145  // Get pointer to the start of the selected vector. Skip GEP creation,
146  // if we select vector 0.
147  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
148  VecStart = BasePtr;
149  else
150  VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
151 
152  // Cast elementwise vector start pointer to a pointer to a vector
153  // (EltType x NumElements)*.
154  auto *VecType = FixedVectorType::get(EltType, NumElements);
155  Type *VecPtrType = PointerType::get(VecType, AS);
156  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
157 }
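// Illustrative sketch (not from the original source): for EltType = double,
// NumElements = 4 and non-constant i64 operands %i (VecIdx) and %s (Stride),
// the code above emits roughly:
//   %vec.start = mul i64 %i, %s
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <4 x double>*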
158 
159 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
160 ///
161 /// Currently, the lowering for each matrix intrinsic is done as follows:
162 /// 1. Propagate the shape information from intrinsics to connected
163 /// instructions.
164 /// 2. Lower instructions with shape information (assuming column-major layout).
165 /// The lowering works similarly using row-major layout.
166 /// 2.1. Get column vectors for each argument. If we already lowered the
167 /// definition of an argument, use the produced column vectors directly.
168 /// If not, split the operand vector containing an embedded matrix into
169 /// a set of column vectors.
170 /// 2.2. Lower the instruction in terms of column major operations, which
171 /// yields a set of column vectors containing result matrix. Note that we
172 /// lower all instructions that have shape information. Besides the
173 /// intrinsics, this includes stores for example.
174 /// 2.3. Update uses of the lowered instruction. If we have shape information
175 /// for a user, there is nothing to do, as we will look up the result
176 /// column matrix when lowering the user. For other uses, we embed the
177 /// result matrix in a flat vector and update the use.
178 /// 2.4. Cache the result column matrix for the instruction we lowered
179 /// 3. After we lowered all instructions in a function, remove the now
180 /// obsolete instructions.
181 ///
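/// For example (illustrative), with the default column-major layout a call
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
/// is lowered to "split" shufflevectors extracting the columns of %a and %b,
/// per-column multiply/add (or fmuladd) operations and, only for users
/// without shape information, a concatenating shufflevector that rebuilds the
/// flat result vector.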
182 class LowerMatrixIntrinsics {
183  Function &Func;
184  const DataLayout &DL;
185  const TargetTransformInfo &TTI;
186  AliasAnalysis *AA;
187  DominatorTree *DT;
188  LoopInfo *LI;
189  OptimizationRemarkEmitter *ORE;
190 
191  /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
192  struct OpInfoTy {
193  /// Number of stores emitted to generate this matrix.
194  unsigned NumStores = 0;
195  /// Number of loads emitted to generate this matrix.
196  unsigned NumLoads = 0;
197  /// Number of compute operations emitted to generate this matrix.
198  unsigned NumComputeOps = 0;
199  /// Most of the time transposes can be fused with matrix multiplies or can
200  /// be folded away via algebraic simplifications. This is the number of
201  /// transposes that we failed to make "free" via such optimizations.
202  unsigned NumExposedTransposes = 0;
203 
204  OpInfoTy &operator+=(const OpInfoTy &RHS) {
205  NumStores += RHS.NumStores;
206  NumLoads += RHS.NumLoads;
207  NumComputeOps += RHS.NumComputeOps;
208  NumExposedTransposes += RHS.NumExposedTransposes;
209  return *this;
210  }
211  };
212 
213  /// Wrapper class representing a matrix as a set of vectors, either in row or
214  /// column major layout. All vectors must have the same vector type.
215  class MatrixTy {
216  SmallVector<Value *, 16> Vectors;
217 
218  OpInfoTy OpInfo;
219 
220  bool IsColumnMajor = true;
221 
222  public:
223  MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
224  MatrixTy(ArrayRef<Value *> Vectors)
225  : Vectors(Vectors.begin(), Vectors.end()),
226  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
227  MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
228  : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
229 
230  unsigned D = isColumnMajor() ? NumColumns : NumRows;
231  for (unsigned J = 0; J < D; ++J)
232  Vectors.push_back(UndefValue::get(FixedVectorType::get(
233  EltTy, isColumnMajor() ? NumRows : NumColumns)));
234  }
235 
236  Value *getVector(unsigned i) const { return Vectors[i]; }
237  Value *getColumn(unsigned i) const {
238  assert(isColumnMajor() && "only supported for column-major matrixes");
239  return Vectors[i];
240  }
241  Value *getRow(unsigned i) const {
242  assert(!isColumnMajor() && "only supported for row-major matrixes");
243  return Vectors[i];
244  }
245 
246  void setVector(unsigned i, Value *V) { Vectors[i] = V; }
247 
248  Type *getElementType() const { return getVectorTy()->getElementType(); }
249 
250  unsigned getNumVectors() const {
251  if (isColumnMajor())
252  return getNumColumns();
253  return getNumRows();
254  }
255 
256  unsigned getNumColumns() const {
257  if (isColumnMajor())
258  return Vectors.size();
259  else {
260  assert(Vectors.size() > 0 && "Cannot call getNumColumns without vectors");
261  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
262  }
263  }
264  unsigned getNumRows() const {
265  if (isColumnMajor()) {
266  assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
267  return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
268  } else
269  return Vectors.size();
270  }
271 
272  void addVector(Value *V) { Vectors.push_back(V); }
273  VectorType *getColumnTy() {
274  assert(isColumnMajor() && "only supported for column-major matrixes");
275  return getVectorTy();
276  }
277 
278  VectorType *getVectorTy() const {
279  return cast<VectorType>(Vectors[0]->getType());
280  }
281 
282  iterator_range<SmallVector<Value *, 8>::iterator> columns() {
283  assert(isColumnMajor() &&
284  "columns() only supported for column-major matrixes");
285  return make_range(Vectors.begin(), Vectors.end());
286  }
287 
288  iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
289  return make_range(Vectors.begin(), Vectors.end());
290  }
291 
292  /// Embed the vectors of the matrix into a flat vector by concatenating
293  /// them.
294  Value *embedInVector(IRBuilder<> &Builder) const {
295  return Vectors.size() == 1 ? Vectors[0]
296  : concatenateVectors(Builder, Vectors);
297  }
298 
299  MatrixTy &addNumLoads(unsigned N) {
300  OpInfo.NumLoads += N;
301  return *this;
302  }
303 
304  void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
305 
306  MatrixTy &addNumStores(unsigned N) {
307  OpInfo.NumStores += N;
308  return *this;
309  }
310 
311  MatrixTy &addNumExposedTransposes(unsigned N) {
312  OpInfo.NumExposedTransposes += N;
313  return *this;
314  }
315 
316  MatrixTy &addNumComputeOps(unsigned N) {
317  OpInfo.NumComputeOps += N;
318  return *this;
319  }
320 
321  unsigned getNumStores() const { return OpInfo.NumStores; }
322  unsigned getNumLoads() const { return OpInfo.NumLoads; }
323  unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
324 
325  const OpInfoTy &getOpInfo() const { return OpInfo; }
326 
327  bool isColumnMajor() const { return IsColumnMajor; }
328 
329  unsigned getStride() const {
330  if (isColumnMajor())
331  return getNumRows();
332  return getNumColumns();
333  }
334 
335  /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
336  /// matrix is column-major, the result vector is extracted from a column
337  /// vector, otherwise from a row vector.
338  Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
339  IRBuilder<> &Builder) const {
340  Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
341  return Builder.CreateShuffleVector(
342  Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
343  "block");
344  }
345  };
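// Illustration: with the default column-major layout, a 4x2 matrix of floats
// is represented as two <4 x float> vectors (one per column), so
// getNumVectors() == 2, getNumRows() == 4 and getStride() == 4.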
346 
347  struct ShapeInfo {
348  unsigned NumRows;
349  unsigned NumColumns;
350 
351  bool IsColumnMajor;
352 
353  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
354  : NumRows(NumRows), NumColumns(NumColumns),
355  IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
356 
357  ShapeInfo(Value *NumRows, Value *NumColumns)
358  : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
359  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
360 
361  bool operator==(const ShapeInfo &other) {
362  return NumRows == other.NumRows && NumColumns == other.NumColumns;
363  }
364  bool operator!=(const ShapeInfo &other) { return !(*this == other); }
365 
366  /// Returns true if shape-information is defined, meaning both dimensions
367  /// are != 0.
368  operator bool() const {
369  assert(NumRows == 0 || NumColumns != 0);
370  return NumRows != 0;
371  }
372 
373  unsigned getStride() const {
374  if (IsColumnMajor)
375  return NumRows;
376  return NumColumns;
377  }
378 
379  unsigned getNumVectors() const {
380  if (IsColumnMajor)
381  return NumColumns;
382  return NumRows;
383  }
384  };
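// Illustration: ShapeInfo(3, 5) describes a 3x5 matrix; with column-major
// layout it is lowered as 5 column vectors (getNumVectors()) with a stride of
// 3 elements (getStride()) between consecutive columns.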
385 
386  /// Maps instructions to their shape information. The shape information
387  /// describes the shape to be used while lowering. This matches the shape of
388  /// the result value of the instruction, with the only exceptions being store
389  /// instructions and the matrix_column_major_store intrinsics. For those, the
390  /// shape information indicates that those instructions should be lowered
391  /// using shape information as well. A ValueMap is used so that when
392  /// sub-passes like optimizeTransposes perform RAUW, the map stays
393  /// up-to-date.
394  ValueMap<Value *, ShapeInfo> ShapeMap;
395 
396  /// List of instructions to remove. While lowering, we do not replace all
397  /// users of a lowered instruction if shape information is available; those
398  /// users need to be removed after we finished lowering.
399  SmallVector<Instruction *, 16> ToRemove;
400 
401  /// Map from instructions to their produced column matrix.
402  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
403 
404 private:
405  static FastMathFlags getFastMathFlags(Instruction *Inst) {
406  FastMathFlags FMF;
407 
408  if (isa<FPMathOperator>(*Inst))
409  FMF = Inst->getFastMathFlags();
410 
411  FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
412 
413  return FMF;
414  }
415 
416 public:
417  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
418  AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
419  OptimizationRemarkEmitter *ORE)
420  : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
421  LI(LI), ORE(ORE) {}
422 
423  unsigned getNumOps(Type *VT) {
424  assert(isa<VectorType>(VT) && "Expected vector type");
425  return getNumOps(VT->getScalarType(),
426  cast<FixedVectorType>(VT)->getNumElements());
427  }
428 
429  /// Is this the minimal version executed in the backend pipelines.
430  bool isMinimal() const {
431  return !DT;
432  }
433 
434  /// Return the estimated number of vector ops required for an operation on
435  /// \p VT * N.
436  unsigned getNumOps(Type *ST, unsigned N) {
437  return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
438  double(TTI.getRegisterBitWidth(
439  TargetTransformInfo::RGK_FixedWidthVector)
440  .getFixedSize()));
441  }
442 
443  /// Return the set of vectors that a matrix value is lowered to.
444  ///
445  /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
446  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
447  /// into vectors.
448  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
449  IRBuilder<> &Builder) {
450  VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
451  assert(VType && "MatrixVal must be a vector type");
452  assert(cast<FixedVectorType>(VType)->getNumElements() ==
453  SI.NumRows * SI.NumColumns &&
454  "The vector size must match the number of matrix elements");
455 
456  // Check if we lowered MatrixVal using shape information. In that case,
457  // return the existing matrix, if it matches the requested shape
458  // information. If there is a mis-match, embed the result in a flat
459  // vector and split it later.
460  auto Found = Inst2ColumnMatrix.find(MatrixVal);
461  if (Found != Inst2ColumnMatrix.end()) {
462  MatrixTy &M = Found->second;
463  // Return the found matrix, if its shape matches the requested shape
464  // information
465  if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
466  return M;
467 
468  MatrixVal = M.embedInVector(Builder);
469  }
470 
471  // Otherwise split MatrixVal.
472  SmallVector<Value *, 16> SplitVecs;
473  for (unsigned MaskStart = 0;
474  MaskStart < cast<FixedVectorType>(VType)->getNumElements();
475  MaskStart += SI.getStride()) {
476  Value *V = Builder.CreateShuffleVector(
477  MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
478  "split");
479  SplitVecs.push_back(V);
480  }
481 
482  return {SplitVecs};
483  }
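// Illustration: a flat <8 x float> value with ShapeInfo 4x2 (column-major) is
// split into two <4 x float> "split" shufflevectors using the masks
// <0, 1, 2, 3> and <4, 5, 6, 7>.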
484 
485  /// If \p V already has a known shape return false. Otherwise set the shape
486  /// for instructions that support it.
487  bool setShapeInfo(Value *V, ShapeInfo Shape) {
488  assert(Shape && "Shape not set");
489  if (isa<UndefValue>(V) || !supportsShapeInfo(V))
490  return false;
491 
492  auto SIter = ShapeMap.find(V);
493  if (SIter != ShapeMap.end()) {
494  LLVM_DEBUG(dbgs() << " not overriding existing shape: "
495  << SIter->second.NumRows << " "
496  << SIter->second.NumColumns << " for " << *V << "\n");
497  return false;
498  }
499 
500  ShapeMap.insert({V, Shape});
501  LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
502  << " for " << *V << "\n");
503  return true;
504  }
505 
506  bool isUniformShape(Value *V) {
507  Instruction *I = dyn_cast<Instruction>(V);
508  if (!I)
509  return true;
510 
511  switch (I->getOpcode()) {
512  case Instruction::FAdd:
513  case Instruction::FSub:
514  case Instruction::FMul: // Scalar multiply.
515  case Instruction::FNeg:
516  case Instruction::Add:
517  case Instruction::Mul:
518  case Instruction::Sub:
519  return true;
520  default:
521  return false;
522  }
523  }
524 
525  /// Returns true if shape information can be used for \p V. The supported
526  /// instructions must match the instructions that can be lowered by this pass.
527  bool supportsShapeInfo(Value *V) {
528  Instruction *Inst = dyn_cast<Instruction>(V);
529  if (!Inst)
530  return false;
531 
532  IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
533  if (II)
534  switch (II->getIntrinsicID()) {
535  case Intrinsic::matrix_multiply:
536  case Intrinsic::matrix_transpose:
537  case Intrinsic::matrix_column_major_load:
538  case Intrinsic::matrix_column_major_store:
539  return true;
540  default:
541  return false;
542  }
543  return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
544  }
545 
546  /// Propagate the shape information of instructions to their users.
547  /// The work list contains instructions for which we can compute the shape,
548  /// either based on the information provided by matrix intrinsics or known
549  /// shapes of operands.
550  SmallVector<Instruction *, 32>
551  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
552  SmallVector<Instruction *, 32> NewWorkList;
553  // Pop an element for which we are guaranteed to have at least one of the
554  // operand shapes. Add the shape for this and then add users to the work
555  // list.
556  LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
557  while (!WorkList.empty()) {
558  Instruction *Inst = WorkList.pop_back_val();
559 
560  // New entry, set the value and insert operands
561  bool Propagate = false;
562 
563  Value *MatrixA;
564  Value *MatrixB;
565  Value *M;
566  Value *N;
567  Value *K;
568  if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
569  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
570  m_Value(N), m_Value(K)))) {
571  Propagate = setShapeInfo(Inst, {M, K});
572  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
573  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
574  // Flip dimensions.
575  Propagate = setShapeInfo(Inst, {N, M});
576  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
577  m_Value(MatrixA), m_Value(), m_Value(),
578  m_Value(), m_Value(M), m_Value(N)))) {
579  Propagate = setShapeInfo(Inst, {N, M});
580  } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
581  m_Value(), m_Value(), m_Value(), m_Value(M),
582  m_Value(N)))) {
583  Propagate = setShapeInfo(Inst, {M, N});
584  } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
585  auto OpShape = ShapeMap.find(MatrixA);
586  if (OpShape != ShapeMap.end())
587  setShapeInfo(Inst, OpShape->second);
588  continue;
589  } else if (isUniformShape(Inst)) {
590  // Find the first operand that has a known shape and use that.
591  for (auto &Op : Inst->operands()) {
592  auto OpShape = ShapeMap.find(Op.get());
593  if (OpShape != ShapeMap.end()) {
594  Propagate |= setShapeInfo(Inst, OpShape->second);
595  break;
596  }
597  }
598  }
599 
600  if (Propagate) {
601  NewWorkList.push_back(Inst);
602  for (auto *User : Inst->users())
603  if (ShapeMap.count(User) == 0)
604  WorkList.push_back(cast<Instruction>(User));
605  }
606  }
607 
608  return NewWorkList;
609  }
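// Illustration: for %t = llvm.matrix.transpose(%a, i32 3, i32 5), the
// transpose result is recorded with shape 5x3 (dimensions flipped), and a
// store of %t then inherits that shape when it is visited.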
610 
611  /// Propagate the shape to operands of instructions with shape information.
612  /// \p Worklist contains the instruction for which we already know the shape.
613  SmallVector<Instruction *, 32>
614  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
615  SmallVector<Instruction *, 32> NewWorkList;
616 
617  auto pushInstruction = [](Value *V,
618  SmallVectorImpl<Instruction *> &WorkList) {
619  Instruction *I = dyn_cast<Instruction>(V);
620  if (I)
621  WorkList.push_back(I);
622  };
623  // Pop an element with known shape. Traverse the operands; if their shape
624  // derives from the result shape and is not yet known, set it and add them
625  // to the worklist.
626  LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
627  while (!WorkList.empty()) {
628  Value *V = WorkList.pop_back_val();
629 
630  size_t BeforeProcessingV = WorkList.size();
631  if (!isa<Instruction>(V))
632  continue;
633 
634  Value *MatrixA;
635  Value *MatrixB;
636  Value *M;
637  Value *N;
638  Value *K;
639  if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
640  m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
641  m_Value(N), m_Value(K)))) {
642  if (setShapeInfo(MatrixA, {M, N}))
643  pushInstruction(MatrixA, WorkList);
644 
645  if (setShapeInfo(MatrixB, {N, K}))
646  pushInstruction(MatrixB, WorkList);
647 
648  } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
649  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
650  // Flip dimensions.
651  if (setShapeInfo(MatrixA, {M, N}))
652  pushInstruction(MatrixA, WorkList);
653  } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
654  m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
655  m_Value(M), m_Value(N)))) {
656  if (setShapeInfo(MatrixA, {M, N})) {
657  pushInstruction(MatrixA, WorkList);
658  }
659  } else if (isa<LoadInst>(V) ||
660  match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
661  // Nothing to do, no matrix input.
662  } else if (isa<StoreInst>(V)) {
663  // Nothing to do. We forward-propagated to this so we would just
664  // backward propagate to an instruction with an already known shape.
665  } else if (isUniformShape(V)) {
666  // Propagate to all operands.
667  ShapeInfo Shape = ShapeMap[V];
668  for (Use &U : cast<Instruction>(V)->operands()) {
669  if (setShapeInfo(U.get(), Shape))
670  pushInstruction(U.get(), WorkList);
671  }
672  }
673  // After we discovered new shape info for new instructions in the
674  // worklist, we use their users as seeds for the next round of forward
675  // propagation.
676  for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
677  for (User *U : WorkList[I]->users())
678  if (isa<Instruction>(U) && V != U)
679  NewWorkList.push_back(cast<Instruction>(U));
680  }
681  return NewWorkList;
682  }
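// Illustration: for %c = llvm.matrix.multiply(%a, %b, i32 M, i32 N, i32 K),
// operand %a is assigned shape MxN and %b shape NxK, and both are re-queued
// so their own operands can be visited in turn.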
683 
684  /// Try moving transposes in order to fold them away or into multiplies.
685  void optimizeTransposes() {
686  auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
687  // We need to remove Old from the ShapeMap otherwise RAUW will replace it
688  // with New. We should only add New if it supportsShapeInfo, so we insert
689  // it conditionally instead.
690  auto S = ShapeMap.find(&Old);
691  if (S != ShapeMap.end()) {
692  ShapeMap.erase(S);
693  if (supportsShapeInfo(New))
694  ShapeMap.insert({New, S->second});
695  }
696  Old.replaceAllUsesWith(New);
697  };
698 
699  // First sink all transposes inside matmuls, hoping that we end up with NN,
700  // NT or TN variants.
701  for (BasicBlock &BB : reverse(Func)) {
702  for (auto II = BB.rbegin(); II != BB.rend();) {
703  Instruction &I = *II;
704  // We may remove II. By default continue on the next/prev instruction.
705  ++II;
706  // If we were to erase II, move again.
707  auto EraseFromParent = [&II](Value *V) {
708  auto *Inst = cast<Instruction>(V);
709  if (Inst->use_empty()) {
710  if (Inst == &*II) {
711  ++II;
712  }
713  Inst->eraseFromParent();
714  }
715  };
716 
717  // If we're creating a new instruction, continue from there.
718  Instruction *NewInst = nullptr;
719 
720  IRBuilder<> IB(&I);
721  MatrixBuilder<IRBuilder<>> Builder(IB);
722 
723  Value *TA, *TAMA, *TAMB;
724  ConstantInt *R, *K, *C;
725  if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
726 
727  // Transpose of a transpose is a nop
728  Value *TATA;
729  if (match(TA,
730  m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
731  ReplaceAllUsesWith(I, TATA);
732  EraseFromParent(&I);
733  EraseFromParent(TA);
734  }
735 
736  // (A * B)^t -> B^t * A^t
737  // RxK KxC CxK KxR
738  else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
739  m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
740  m_ConstantInt(K), m_ConstantInt(C)))) {
741  Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
742  C->getZExtValue(),
743  TAMB->getName() + "_t");
744  // We are being run after shape prop, add shape for newly created
745  // instructions so that we lower them later.
746  setShapeInfo(T0, {C, K});
747  Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
748  K->getZExtValue(),
749  TAMA->getName() + "_t");
750  setShapeInfo(T1, {K, R});
751  NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
752  K->getZExtValue(),
753  R->getZExtValue(), "mmul");
754  ReplaceAllUsesWith(I, NewInst);
755  EraseFromParent(&I);
756  EraseFromParent(TA);
757  }
758  }
759 
760  // If we replaced I with a new instruction, continue from there.
761  if (NewInst)
762  II = std::next(BasicBlock::reverse_iterator(NewInst));
763  }
764  }
765 
766  // If we have a TT matmul, lift the transpose. We may be able to fold it
767  // into a consuming multiply.
768  for (BasicBlock &BB : Func) {
769  for (BasicBlock::iterator II = BB.begin(); II != BB.end();) {
770  Instruction *I = &*II;
771  // We may remove I.
772  ++II;
773  Value *A, *B, *AT, *BT;
774  ConstantInt *R, *K, *C;
775  // A^t * B^t -> (B * A)^t
776  if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
777  m_Value(A), m_Value(B), m_ConstantInt(R),
778  m_ConstantInt(K), m_ConstantInt(C))) &&
779  match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
780  match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
781  IRBuilder<> IB(&*I);
782  MatrixBuilder<IRBuilder<>> Builder(IB);
783  Value *M = Builder.CreateMatrixMultiply(
784  BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
785  setShapeInfo(M, {C, R});
786  Instruction *NewInst = Builder.CreateMatrixTranspose(
787  M, C->getZExtValue(), R->getZExtValue());
788  ReplaceAllUsesWith(*I, NewInst);
789  if (I->use_empty())
790  I->eraseFromParent();
791  if (A->use_empty())
792  cast<Instruction>(A)->eraseFromParent();
793  if (A != B && B->use_empty())
794  cast<Instruction>(B)->eraseFromParent();
795  }
796  }
797  }
798  }
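// Illustration of the rewrites above: transpose(transpose(A)) --> A,
// transpose(multiply(A, B)) --> multiply(transpose(B), transpose(A)), and
// multiply(transpose(A), transpose(B)) --> transpose(multiply(B, A)). The
// transposes introduced here are given shape information so they are either
// folded by a later visit or lowered normally.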
799 
800  bool Visit() {
801  SmallVector<Instruction *, 32> WorkList;
802 
803  // Initially only the shape of matrix intrinsics is known.
804  // Initialize the work list with ops carrying shape information.
805  for (BasicBlock &BB : Func)
806  for (Instruction &Inst : BB) {
807  IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
808  if (!II)
809  continue;
810 
811  switch (II->getIntrinsicID()) {
812  case Intrinsic::matrix_multiply:
813  case Intrinsic::matrix_transpose:
814  case Intrinsic::matrix_column_major_load:
815  case Intrinsic::matrix_column_major_store:
816  WorkList.push_back(&Inst);
817  break;
818  default:
819  break;
820  }
821  }
822 
823  // Avoid unnecessary work if there are no matrix intrinsics in the function.
824  if (WorkList.empty())
825  return false;
826 
827  // Propagate shapes until nothing changes any longer.
828  while (!WorkList.empty()) {
829  WorkList = propagateShapeForward(WorkList);
830  WorkList = propagateShapeBackward(WorkList);
831  }
832 
833  if (!isMinimal()) {
834  optimizeTransposes();
835  LLVM_DEBUG({
836  dbgs() << "Dump after matrix transpose optimization:\n";
837  Func.dump();
838  });
839  }
840 
841  bool Changed = false;
842  SmallVector<CallInst *, 16> MaybeFusableInsts;
843  SmallVector<Instruction *, 16> MatrixInsts;
844 
845  // First, collect all instructions with shape information and candidates for
846  // fusion (currently only matrix multiplies).
847  ReversePostOrderTraversal<Function *> RPOT(&Func);
848  for (auto *BB : RPOT)
849  for (Instruction &I : *BB) {
850  if (ShapeMap.find(&I) == ShapeMap.end())
851  continue;
852  if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
853  MaybeFusableInsts.push_back(cast<CallInst>(&I));
854  MatrixInsts.push_back(&I);
855  }
856 
857  // Second, try to fuse candidates.
858  SmallPtrSet<Instruction *, 16> FusedInsts;
859  for (CallInst *CI : MaybeFusableInsts)
860  LowerMatrixMultiplyFused(CI, FusedInsts);
861  Changed = !FusedInsts.empty();
862 
863  // Third, lower remaining instructions with shape information.
864  for (Instruction *Inst : MatrixInsts) {
865  if (FusedInsts.count(Inst))
866  continue;
867 
868  IRBuilder<> Builder(Inst);
869 
870  if (CallInst *CInst = dyn_cast<CallInst>(Inst))
871  Changed |= VisitCallInst(CInst);
872 
873  Value *Op1;
874  Value *Op2;
875  if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
876  Changed |= VisitBinaryOperator(BinOp);
877  if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
878  Changed |= VisitUnaryOperator(UnOp);
879  if (match(Inst, m_Load(m_Value(Op1))))
880  Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
881  else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
882  Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
883  }
884 
885  if (ORE) {
886  RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
887  RemarkGen.emitRemarks();
888  }
889 
890  // Delete the instructions backwards, as it has a reduced likelihood of
891  // having to update as many def-use and use-def chains.
892  //
893  // Because we add to ToRemove during fusion we can't guarantee that defs
894  // are before uses. Change uses to undef temporarily as these should get
895  // removed as well.
896  //
897  // For verification, we keep track of where we changed uses to undefs in
898  // UndefedInsts and then check that we in fact remove them.
899  SmallSet<Instruction *, 16> UndefedInsts;
900  for (auto *Inst : reverse(ToRemove)) {
901  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
902  if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
903  UndefedInsts.insert(Undefed);
904  U.set(UndefValue::get(Inst->getType()));
905  }
906  Inst->eraseFromParent();
907  UndefedInsts.erase(Inst);
908  }
909  if (!UndefedInsts.empty()) {
910  // If we didn't remove all undefed instructions, it's a hard error.
911  dbgs() << "Undefed but present instructions:\n";
912  for (auto *I : UndefedInsts)
913  dbgs() << *I << "\n";
914  llvm_unreachable("Undefed but instruction not removed");
915  }
916 
917  return Changed;
918  }
919 
920  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
921  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
922  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
923  Type *EltPtrType = PointerType::get(EltType, AS);
924  return Builder.CreatePointerCast(BasePtr, EltPtrType);
925  }
926 
927  /// Replace intrinsic calls
928  bool VisitCallInst(CallInst *Inst) {
929  if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
930  return false;
931 
932  switch (Inst->getCalledFunction()->getIntrinsicID()) {
933  case Intrinsic::matrix_multiply:
934  LowerMultiply(Inst);
935  break;
936  case Intrinsic::matrix_transpose:
937  LowerTranspose(Inst);
938  break;
939  case Intrinsic::matrix_column_major_load:
940  LowerColumnMajorLoad(Inst);
941  break;
942  case Intrinsic::matrix_column_major_store:
943  LowerColumnMajorStore(Inst);
944  break;
945  default:
946  return false;
947  }
948  return true;
949  }
950 
951  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
952  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
953  /// ConstantInt, reduce the initial alignment based on the byte offset. For
954  /// non-ConstantInt strides, return the common alignment of the initial
955  /// alignment and the element size in bytes.
956  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
957  MaybeAlign A) const {
958  Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
959  if (Idx == 0)
960  return InitialAlign;
961 
962  TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
963  if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
964  uint64_t StrideInBytes =
965  ConstStride->getZExtValue() * ElementSizeInBits / 8;
966  return commonAlignment(InitialAlign, Idx * StrideInBytes);
967  }
968  return commonAlignment(InitialAlign, ElementSizeInBits / 8);
969  }
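// Illustration: with an initial alignment of 16, double elements and a
// constant stride of 3, vector 1 starts at byte offset 1 * 3 * 8 = 24, so the
// returned alignment is commonAlignment(16, 24) = 8.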
970 
971  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
972  /// vectors.
973  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
974  bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
975  auto *VType = cast<VectorType>(Ty);
976  Type *EltTy = VType->getElementType();
977  Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
978  Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
979  MatrixTy Result;
980  for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
981  Value *GEP = computeVectorAddr(
982  EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
983  Stride, Shape.getStride(), EltTy, Builder);
984  Value *Vector = Builder.CreateAlignedLoad(
985  VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
986  IsVolatile, "col.load");
987 
988  Result.addVector(Vector);
989  }
990  return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
991  Result.getNumVectors());
992  }
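// Illustration: loading a 4x2 column-major matrix of floats with a constant
// stride of 4 emits two <4 x float> "col.load" loads, at element offsets 0
// and 4 from the element-wise base pointer.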
993 
994  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
995  /// starting at \p MatrixPtr[I][J].
996  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
997  ShapeInfo MatrixShape, Value *I, Value *J,
998  ShapeInfo ResultShape, Type *EltTy,
999  IRBuilder<> &Builder) {
1000 
1001  Value *Offset = Builder.CreateAdd(
1002  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1003 
1004  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1005  Value *EltPtr =
1006  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1007  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1008  auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1009  ResultShape.NumColumns);
1010  Type *TilePtrTy = PointerType::get(TileTy, AS);
1011  Value *TilePtr =
1012  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1013 
1014  return loadMatrix(TileTy, TilePtr, Align,
1015  Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1016  ResultShape, Builder);
1017  }
1018 
1019  /// Lower a load instruction with shape information.
1020  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1021  bool IsVolatile, ShapeInfo Shape) {
1022  IRBuilder<> Builder(Inst);
1023  finalizeLowering(Inst,
1024  loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1025  Shape, Builder),
1026  Builder);
1027  }
1028 
1029  /// Lowers llvm.matrix.column.major.load.
1030  ///
1031  /// The intrinsic loads a matrix from memory using a stride between columns.
1032  void LowerColumnMajorLoad(CallInst *Inst) {
1033  assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1034  "Intrinsic only supports column-major layout!");
1035  Value *Ptr = Inst->getArgOperand(0);
1036  Value *Stride = Inst->getArgOperand(1);
1037  LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1038  cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1039  {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1040  }
1041 
1042  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1043  /// MatrixPtr[I][J].
1044  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1045  MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1046  Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1047  Value *Offset = Builder.CreateAdd(
1048  Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1049 
1050  unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1051  Value *EltPtr =
1052  Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1053  Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1054  auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1055  StoreVal.getNumColumns());
1056  Type *TilePtrTy = PointerType::get(TileTy, AS);
1057  Value *TilePtr =
1058  Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1059 
1060  storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
1061  Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1062  }
1063 
1064  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1065  /// vectors.
1066  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1067  MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1068  IRBuilder<> &Builder) {
1069  auto VType = cast<VectorType>(Ty);
1070  Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
1071  for (auto Vec : enumerate(StoreVal.vectors())) {
1072  Value *GEP = computeVectorAddr(
1073  EltPtr,
1074  Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1075  Vec.index()),
1076  Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1077  Builder.CreateAlignedStore(Vec.value(), GEP,
1078  getAlignForIndex(Vec.index(), Stride,
1079  VType->getElementType(),
1080  MAlign),
1081  IsVolatile);
1082  }
1083  return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1084  StoreVal.getNumVectors());
1085  }
1086 
1087  /// Lower a store instruction with shape information.
1088  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1089  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1090  IRBuilder<> Builder(Inst);
1091  auto StoreVal = getMatrix(Matrix, Shape, Builder);
1092  finalizeLowering(Inst,
1093  storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1094  IsVolatile, Builder),
1095  Builder);
1096  }
1097 
1098  /// Lowers llvm.matrix.column.major.store.
1099  ///
1100  /// The intrinsic stores a matrix back to memory using a stride between columns.
1101  void LowerColumnMajorStore(CallInst *Inst) {
1102  assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1103  "Intrinsic only supports column-major layout!");
1104  Value *Matrix = Inst->getArgOperand(0);
1105  Value *Ptr = Inst->getArgOperand(1);
1106  Value *Stride = Inst->getArgOperand(2);
1107  LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1108  cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1109  {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1110  }
1111 
1112  // Set elements I..I+NumElts-1 to Block
1113  Value *insertVector(Value *Col, unsigned I, Value *Block,
1114  IRBuilder<> &Builder) {
1115 
1116  // First, bring Block to the same size as Col
1117  unsigned BlockNumElts =
1118  cast<FixedVectorType>(Block->getType())->getNumElements();
1119  unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1120  assert(NumElts >= BlockNumElts && "Too few elements for current block");
1121 
1122  Block = Builder.CreateShuffleVector(
1123  Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1124 
1125  // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1126  // 8, 4, 5, 6
1127  SmallVector<int, 16> Mask;
1128  unsigned i;
1129  for (i = 0; i < I; i++)
1130  Mask.push_back(i);
1131 
1132  unsigned VecNumElts =
1133  cast<FixedVectorType>(Col->getType())->getNumElements();
1134  for (; i < I + BlockNumElts; i++)
1135  Mask.push_back(i - I + VecNumElts);
1136 
1137  for (; i < VecNumElts; i++)
1138  Mask.push_back(i);
1139 
1140  return Builder.CreateShuffleVector(Col, Block, Mask);
1141  }
1142 
1143  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1144  IRBuilder<> &Builder, bool AllowContraction,
1145  unsigned &NumComputeOps) {
1146  NumComputeOps += getNumOps(A->getType());
1147  if (!Sum)
1148  return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1149 
1150  if (UseFPOp) {
1151  if (AllowContraction) {
1152  // Use fmuladd for floating point operations and let the backend decide
1153  // if that's profitable.
1154  Function *FMulAdd = Intrinsic::getDeclaration(
1155  Func.getParent(), Intrinsic::fmuladd, A->getType());
1156  return Builder.CreateCall(FMulAdd, {A, B, Sum});
1157  }
1158  NumComputeOps += getNumOps(A->getType());
1159  Value *Mul = Builder.CreateFMul(A, B);
1160  return Builder.CreateFAdd(Sum, Mul);
1161  }
1162 
1163  NumComputeOps += getNumOps(A->getType());
1164  Value *Mul = Builder.CreateMul(A, B);
1165  return Builder.CreateAdd(Sum, Mul);
1166  }
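// Illustration: for <4 x float> operands with contraction allowed, each
// accumulation step becomes a call to
//   @llvm.fmuladd.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %sum)
// otherwise it is emitted as separate fmul and fadd (or mul/add for integer
// element types).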
1167 
1168  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1169  /// users with shape information, there's nothing to do: they will use the
1170  /// cached value when they are lowered. For other users, \p Matrix is
1171  /// flattened and the uses are updated to use it. Also marks \p Inst for
1172  /// deletion.
1173  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1174  IRBuilder<> &Builder) {
1175  auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1176  (void)inserted;
1177  assert(inserted.second && "multiple matrix lowering mapping");
1178 
1179  ToRemove.push_back(Inst);
1180  Value *Flattened = nullptr;
1181  for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1182  if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1183  if (!Flattened)
1184  Flattened = Matrix.embedInVector(Builder);
1185  U.set(Flattened);
1186  }
1187  }
1188  }
1189 
1190  /// Compute \p Result += \p A * \p B for input matrices with left-associating
1191  /// addition.
1192  ///
1193  /// We can fold a transpose into the operand that is used to extract scalars.
1194  /// This is the first operand for row-major and the second for
1195  /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1196  /// operand is transposed.
1197  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1198  const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1199  bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1200  const unsigned VF = std::max<unsigned>(
1201  TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1202  .getFixedSize() /
1203  Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
1204  1U);
1205  unsigned R = Result.getNumRows();
1206  unsigned C = Result.getNumColumns();
1207  unsigned M = A.getNumColumns();
1208 
1209  bool IsFP = Result.getElementType()->isFloatingPointTy();
1210  assert(A.isColumnMajor() == B.isColumnMajor() &&
1211  Result.isColumnMajor() == A.isColumnMajor() &&
1212  "operands must agree on matrix layout");
1213  unsigned NumComputeOps = 0;
1214 
1215  Builder.setFastMathFlags(FMF);
1216 
1217  if (A.isColumnMajor()) {
1218  // Multiply columns from the first operand with scalars from the second
1219  // operand. Then move along the K axes and accumulate the columns. With
1220  // this the adds can be vectorized without reassociation.
1221  for (unsigned J = 0; J < C; ++J) {
1222  unsigned BlockSize = VF;
1223  // If Result is zero, we don't need to accumulate in the K==0 iteration.
1224  bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1225 
1226  for (unsigned I = 0; I < R; I += BlockSize) {
1227  // Gradually lower the vectorization factor to cover the remainder.
1228  while (I + BlockSize > R)
1229  BlockSize /= 2;
1230 
1231  Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1232  : nullptr;
1233  for (unsigned K = 0; K < M; ++K) {
1234  Value *L = A.extractVector(I, K, BlockSize, Builder);
1235  Value *RH = Builder.CreateExtractElement(
1236  B.getColumn(IsScalarMatrixTransposed ? K : J),
1237  IsScalarMatrixTransposed ? J : K);
1238  Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1239  Sum =
1240  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1241  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1242  }
1243  Result.setVector(J,
1244  insertVector(Result.getVector(J), I, Sum, Builder));
1245  }
1246  }
1247  } else {
1248  // Multiply rows from the second operand with scalars from the first
1249  // operand. Then move along the K axes and accumulate the rows. With this
1250  // the adds can be vectorized without reassociation.
1251  for (unsigned I = 0; I < R; ++I) {
1252  unsigned BlockSize = VF;
1253  bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1254  for (unsigned J = 0; J < C; J += BlockSize) {
1255  // Gradually lower the vectorization factor to cover the remainder.
1256  while (J + BlockSize > C)
1257  BlockSize /= 2;
1258 
1259  Value *Sum = nullptr;
1260  for (unsigned K = 0; K < M; ++K) {
1261  Value *R = B.extractVector(K, J, BlockSize, Builder);
1262  Value *LH = Builder.CreateExtractElement(
1263  A.getVector(IsScalarMatrixTransposed ? K : I),
1264  IsScalarMatrixTransposed ? I : K);
1265  Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1266  Sum =
1267  createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1268  IsFP, Builder, FMF.allowContract(), NumComputeOps);
1269  }
1270  Result.setVector(I,
1271  insertVector(Result.getVector(I), J, Sum, Builder));
1272  }
1273  }
1274  }
1275  Result.addNumComputeOps(NumComputeOps);
1276  }
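// Illustration (column-major): for a 4x4 * 4x4 multiply with VF = 4, each
// result column J is built as
//   Result[J] = sum over K of A.column(K) * splat(B[K][J])
// i.e. one vector multiply-add per K, so the adds vectorize without
// reassociation.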
1277 
1278  /// Ensure that the memory in \p Load does not alias \p Store by potentially
1279  /// copying it to a new location. The new location, or otherwise the
1280  /// original one, is returned.
1281  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1282  CallInst *MatMul) {
1283  MemoryLocation StoreLoc = MemoryLocation::get(Store);
1284  MemoryLocation LoadLoc = MemoryLocation::get(Load);
1285 
1286  // If we can statically determine noalias we're good.
1287  if (AA->isNoAlias(LoadLoc, StoreLoc))
1288  return Load->getPointerOperand();
1289 
1290  // Create code to check if the memory locations of the Load and Store
1291  // overlap and if they do, copy Load's operand to a new buffer.
1292 
1293  // First, create new blocks for the 2nd part of the check and the copy.
1294  BasicBlock *Check0 = MatMul->getParent();
1295  // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1296  // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1297  // as we adjust Check0 and Check1's branches.
1298  SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1299  for (BasicBlock *Succ : successors(Check0))
1300  DTUpdates.push_back({DT->Delete, Check0, Succ});
1301 
1302  BasicBlock *Check1 =
1303  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1304  nullptr, "alias_cont");
1305  BasicBlock *Copy =
1306  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1307  nullptr, "copy");
1308  BasicBlock *Fusion =
1309  SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1310  nullptr, "no_alias");
1311 
1312  // Check if the loaded memory location begins before the end of the store
1313  // location. If the condition holds, they might overlap, otherwise they are
1314  // guaranteed to not overlap.
1315  IRBuilder<> Builder(MatMul);
1316  Check0->getTerminator()->eraseFromParent();
1317  Builder.SetInsertPoint(Check0);
1318  Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1319  Value *StoreBegin = Builder.CreatePtrToInt(
1320  const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1321  Value *StoreEnd = Builder.CreateAdd(
1322  StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1323  "store.end", true, true);
1324  Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1325  IntPtrTy, "load.begin");
1326  Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1327  Fusion);
1328 
1329  // Check if the store begins before the end of the load location. If the
1330  // condition holds, they alias, otherwise they are guaranteed to not
1331  // overlap.
1332  Check1->getTerminator()->eraseFromParent();
1333  Builder.SetInsertPoint(Check1, Check1->begin());
1334  Value *LoadEnd = Builder.CreateAdd(
1335  LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1336  "load.end", true, true);
1337  Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1338  Fusion);
1339 
1340  // Copy load operand to new alloca.
1341  Builder.SetInsertPoint(Copy, Copy->begin());
1342  AllocaInst *NewLd =
1343  Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
1344  Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
1345  Load->getPointerOperand(), Load->getAlign(),
1346  LoadLoc.Size.getValue());
1347  Builder.SetInsertPoint(Fusion, Fusion->begin());
1348  PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1349  PHI->addIncoming(Load->getPointerOperand(), Check0);
1350  PHI->addIncoming(Load->getPointerOperand(), Check1);
1351  PHI->addIncoming(NewLd, Copy);
1352 
1353  // Adjust DT.
1354  DTUpdates.push_back({DT->Insert, Check0, Check1});
1355  DTUpdates.push_back({DT->Insert, Check0, Fusion});
1356  DTUpdates.push_back({DT->Insert, Check1, Copy});
1357  DTUpdates.push_back({DT->Insert, Check1, Fusion});
1358  DT->applyUpdates(DTUpdates);
1359  return PHI;
1360  }
1361 
1362  bool isFusionProfitable(CallInst *MatMul) {
1363  if (ForceFusion)
1364  return true;
1365 
1366  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1367  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1368 
1369  const unsigned R = LShape.NumRows;
1370  const unsigned C = RShape.NumColumns;
1371  const unsigned M = LShape.NumColumns;
1372  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1373 
1374  const unsigned VF = std::max<unsigned>(
1375  TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1376  .getFixedSize() /
1377  EltType->getPrimitiveSizeInBits().getFixedSize(),
1378  1U);
1379 
1380  // Cost model for tiling
1381  //
1382  // For tiling to be beneficial, we need reuse either along the R or
1383  // the C axis. We vectorize along the R axis so that means at least
1384  // 3 elements.
1385  // TODO: Also consider cost of copying if operands alias.
1386  if (R <= VF && C == 1)
1387  return false;
1388  // Then we need enough elements to exceed the number of vector
1389  // registers we have. Note that this is an oversimplification since
1390  // fusing also takes some extra loads which may exceed the number of
1391  // reloads necessary.
1392  unsigned Op0Regs = (R + VF - 1) / VF * M;
1393  unsigned Op1Regs = (M + VF - 1) / VF * C;
1394  return Op0Regs + Op1Regs >
1395  TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1396  }
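// Illustration: with R = C = M = 16 and VF = 4, each operand needs
// (16 / 4) * 16 = 64 vector registers' worth of data, which exceeds the
// vector register count on typical targets, so tiling/fusion is considered
// profitable.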
1397 
1398  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1399  MatrixTy Res;
1400  auto *ColumType = FixedVectorType::get(EltType, R);
1401  for (unsigned I = 0; I < C; ++I)
1402  Res.addVector(ConstantAggregateZero::get(ColumType));
1403  return Res;
1404  }
1405 
1406  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1407  Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1408  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1409 
1410  // Create the main tiling loop nest.
1411  TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1412  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1413  Instruction *InsertI = cast<Instruction>(MatMul);
1414  BasicBlock *Start = InsertI->getParent();
1415  BasicBlock *End =
1416  SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1417  IRBuilder<> Builder(MatMul);
1418  BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1419 
1420  Type *TileVecTy =
1421  FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1422  MatrixTy TileResult;
1423  // Insert in the inner loop header.
1424  Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
1425  // Create PHI nodes for the result columns to accumulate across iterations.
1426  SmallVector<PHINode *, 4> ColumnPhis;
1427  for (unsigned I = 0; I < TileSize; I++) {
1428  auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1429  Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1430  TI.RowLoopHeader->getSingleSuccessor());
1431  TileResult.addVector(Phi);
1432  ColumnPhis.push_back(Phi);
1433  }
1434 
1435  // Insert in the inner loop body, which computes
1436  // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1437  Builder.SetInsertPoint(InnerBody->getTerminator());
1438  // Load tiles of the operands.
1439  MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
1440  {TileSize, TileSize}, EltType, Builder);
1441  MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
1442  {TileSize, TileSize}, EltType, Builder);
1443  emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1444  getFastMathFlags(MatMul));
1445  // Store result after the inner loop is done.
1446  Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
1447  storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1448  Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1449  TI.CurrentRow, TI.CurrentCol, EltType, Builder);
1450 
1451  for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1452  ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
1453 
1454  // Force unrolling of a few iterations of the inner loop, to make sure there
1455  // is enough work per iteration.
1456  // FIXME: The unroller should make this decision directly instead, but
1457  // currently the cost-model is not up to the task.
1458  unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1459  addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
1460  "llvm.loop.unroll.count", InnerLoopUnrollCount);
1461  }
1462 
1463  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1464  StoreInst *Store,
1465  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1466  assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1467  "Tiling only supported for column-major matrixes at the moment!");
1468  if (!isFusionProfitable(MatMul))
1469  return;
1470 
1471  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1472  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1473 
1474  const unsigned R = LShape.NumRows;
1475  const unsigned C = RShape.NumColumns;
1476  const unsigned M = LShape.NumColumns;
1477  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1478 
1479  Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1480  Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1481  Value *CPtr = Store->getPointerOperand();
1482 
1483  if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1484  createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1485  else {
1486  IRBuilder<> Builder(Store);
1487  for (unsigned J = 0; J < C; J += TileSize)
1488  for (unsigned I = 0; I < R; I += TileSize) {
1489  const unsigned TileR = std::min(R - I, unsigned(TileSize));
1490  const unsigned TileC = std::min(C - J, unsigned(TileSize));
1491  MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1492 
1493  for (unsigned K = 0; K < M; K += TileSize) {
1494  const unsigned TileM = std::min(M - K, unsigned(TileSize));
1495  MatrixTy A =
1496  loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1497  LShape, Builder.getInt64(I), Builder.getInt64(K),
1498  {TileR, TileM}, EltType, Builder);
1499  MatrixTy B =
1500  loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1501  RShape, Builder.getInt64(K), Builder.getInt64(J),
1502  {TileM, TileC}, EltType, Builder);
1503  emitMatrixMultiply(Res, A, B, Builder, true, false,
1504  getFastMathFlags(MatMul));
1505  }
1506  storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1507  Builder.getInt64(I), Builder.getInt64(J), EltType,
1508  Builder);
1509  }
1510  }
1511 
1512  // Mark eliminated instructions as fused and remove them.
1513  FusedInsts.insert(Store);
1514  FusedInsts.insert(MatMul);
1515  Store->eraseFromParent();
1516  MatMul->eraseFromParent();
1517  if (LoadOp0->hasNUses(0)) {
1518  FusedInsts.insert(LoadOp0);
1519  LoadOp0->eraseFromParent();
1520  }
1521  if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1522  FusedInsts.insert(LoadOp1);
1523  LoadOp1->eraseFromParent();
1524  }
1525  }
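Note on the unrolled path above: when -fuse-matrix-use-loops is disabled, or R and C are not multiples of the tile size, the three tile loops are emitted fully unrolled at compile time, and the edge tiles are clamped to TileR/TileC/TileM via std::min so partial tiles at the matrix borders are handled correctly.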
1526 
1527  /// Try to lower matrix multiply chains by fusing operations.
1528  ///
1529  /// Call finalizeLowering on lowered instructions. Instructions that are
1530  /// completely eliminated by fusion are added to \p FusedInsts.
1531  void LowerMatrixMultiplyFused(CallInst *MatMul,
1532  SmallPtrSetImpl<Instruction *> &FusedInsts) {
1533  if (!FuseMatrix || !DT)
1534  return;
1535 
1536  assert(AA && LI && "Analyses should be available");
1537 
1538  Value *A = MatMul->getArgOperand(0);
1539  Value *B = MatMul->getArgOperand(1);
1540 
1541  // We can fold the transpose into the operand that is used to fetch scalars.
1542  Value *T;
1543  if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1544  ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1545  : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1546  IRBuilder<> Builder(MatMul);
1547  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1548  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1549  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1550  const unsigned R = LShape.NumRows;
1551  const unsigned M = LShape.NumColumns;
1552  const unsigned C = RShape.NumColumns;
1553 
1554  MatrixTy MA;
1555  MatrixTy MB;
1556 
1557  Value *Transpose;
1558  if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1559  MA = getMatrix(A, ShapeInfo(R, M), Builder);
1560  MB = getMatrix(T, ShapeInfo(C, M), Builder);
1561  Transpose = B;
1562  } else {
1563  MA = getMatrix(T, ShapeInfo(R, M), Builder);
1564  MB = getMatrix(B, ShapeInfo(C, M), Builder);
1565  Transpose = A;
1566  }
1567 
1568  // Initialize the output
1569  MatrixTy Result(R, C, EltType);
1570 
1571  emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1572  getFastMathFlags(MatMul));
1573 
1574  FusedInsts.insert(MatMul);
1575  if (Transpose->hasOneUse()) {
1576  FusedInsts.insert(cast<Instruction>(Transpose));
1577  ToRemove.push_back(cast<Instruction>(Transpose));
1578  // TODO: add a fake entry for the folded instruction so that this is
1579  // included in the expression in the remark.
1580  Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1581  }
1582  finalizeLowering(MatMul, Result, Builder);
1583  return;
1584  }
1585 
1586  if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1587  return;
1588 
1589  // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1590  // since the single store user will be lowered as part of this.
1591  auto *LoadOp0 = dyn_cast<LoadInst>(A);
1592  auto *LoadOp1 = dyn_cast<LoadInst>(B);
1593  auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1594  if (LoadOp0 && LoadOp1 && Store) {
1595  // The store address must dominate the MatMul instruction, otherwise
1596  // we create invalid IR.
1597  SetVector<Value *> WorkList;
1598  WorkList.insert(Store->getOperand(1));
1599  SmallVector<Instruction *> ToHoist;
1600  for (unsigned I = 0; I != WorkList.size(); ++I) {
1601  Value *Current = WorkList[I];
1602  auto *CurrI = dyn_cast<Instruction>(Current);
1603  if (!CurrI)
1604  continue;
1605  if (isa<PHINode>(CurrI))
1606  return;
1607  if (DT->dominates(CurrI, MatMul))
1608  continue;
1609  if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1610  return;
1611  ToHoist.push_back(CurrI);
1612  WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1613  }
1614 
1615  sort(ToHoist, [this](Instruction *A, Instruction *B) {
1616  return DT->dominates(A, B);
1617  });
1618  for (Instruction *I : ToHoist)
1619  I->moveBefore(MatMul);
1620 
1621  emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1622  return;
1623  }
1624  }
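For orientation, the {ld, ld} -> matmul -> st chain handed to emitSIMDTiling corresponds roughly to IR of the following shape; this is a hedged illustration for a 2x4 times 4x2 double multiply, and the exact signature and overload mangling of llvm.matrix.multiply are defined in the LangRef:

  //   %a = load <8 x double>, <8 x double>* %APtr
  //   %b = load <8 x double>, <8 x double>* %BPtr
  //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v8f64.v8f64(
  //            <8 x double> %a, <8 x double> %b, i32 2, i32 4, i32 2)
  //   store <4 x double> %c, <4 x double>* %CPtr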
1625 
1626  /// Lowers llvm.matrix.multiply.
1627  void LowerMultiply(CallInst *MatMul) {
1628  IRBuilder<> Builder(MatMul);
1629  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1630  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1631  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1632 
1633  const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1634  const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1635  assert(Lhs.getElementType() == Rhs.getElementType() &&
1636  "Matrix multiply argument element types do not match.");
1637 
1638  const unsigned R = LShape.NumRows;
1639  const unsigned C = RShape.NumColumns;
1640  assert(LShape.NumColumns == RShape.NumRows);
1641 
1642  // Initialize the output
1643  MatrixTy Result(R, C, EltType);
1644  assert(Lhs.getElementType() == Result.getElementType() &&
1645  "Matrix multiply result element type does not match arguments.");
1646 
1647  emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1648  getFastMathFlags(MatMul));
1649  finalizeLowering(MatMul, Result, Builder);
1650  }
1651 
1652  /// Lowers llvm.matrix.transpose.
1653  void LowerTranspose(CallInst *Inst) {
1654  MatrixTy Result;
1655  IRBuilder<> Builder(Inst);
1656  Value *InputVal = Inst->getArgOperand(0);
1657  VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1658  ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1659  MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1660 
1661  const unsigned NewNumVecs =
1662  InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1663  const unsigned NewNumElts =
1664  InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1665 
1666  for (unsigned I = 0; I < NewNumVecs; ++I) {
1667  // Build a single result vector. First initialize it.
1668  Value *ResultVector = UndefValue::get(
1669  FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
1670  // Go through the old elements and insert them into the resulting vector.
1671  for (auto J : enumerate(InputMatrix.vectors())) {
1672  Value *Elt = Builder.CreateExtractElement(J.value(), I);
1673  // Row and column indices are transposed.
1674  ResultVector =
1675  Builder.CreateInsertElement(ResultVector, Elt, J.index());
1676  }
1677  Result.addVector(ResultVector);
1678  }
1679 
1680  // TODO: Improve estimate of operations needed for transposes. Currently we
1681  // just count the insertelement/extractelement instructions, but do not
1682  // account for later simplifications/combines.
1683  finalizeLowering(
1684  Inst,
1685  Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1686  .addNumExposedTransposes(1),
1687  Builder);
1688  }
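As a scalar reference for what the extract/insert sequence above computes on column-major data (a sketch; the container and the function name are illustrative, not part of the pass):

  #include <vector>

  // Transpose a column-major Rows x Cols matrix into a column-major
  // Cols x Rows matrix: element (r, c) of the input becomes element (c, r)
  // of the result, i.e. result column r gathers lane r of every input column.
  std::vector<double> transposeColumnMajor(const std::vector<double> &In,
                                           unsigned Rows, unsigned Cols) {
    std::vector<double> Out(Rows * Cols);
    for (unsigned C = 0; C < Cols; ++C)
      for (unsigned R = 0; R < Rows; ++R)
        Out[C + R * Cols] = In[R + C * Rows];
    return Out;
  }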
1689 
1690  /// Lower load instructions, if shape information is available.
1691  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
1692  auto I = ShapeMap.find(Inst);
1693  if (I == ShapeMap.end())
1694  return false;
1695 
1696  LowerLoad(Inst, Ptr, Inst->getAlign(),
1697  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1698  I->second);
1699  return true;
1700  }
1701 
1702  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
1703  IRBuilder<> &Builder) {
1704  auto I = ShapeMap.find(StoredVal);
1705  if (I == ShapeMap.end())
1706  return false;
1707 
1708  LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
1709  Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1710  I->second);
1711  return true;
1712  }
1713 
1714  /// Lower binary operators, if shape information is available.
1715  bool VisitBinaryOperator(BinaryOperator *Inst) {
1716  auto I = ShapeMap.find(Inst);
1717  if (I == ShapeMap.end())
1718  return false;
1719 
1720  Value *Lhs = Inst->getOperand(0);
1721  Value *Rhs = Inst->getOperand(1);
1722 
1723  IRBuilder<> Builder(Inst);
1724  ShapeInfo &Shape = I->second;
1725 
1726  MatrixTy Result;
1727  MatrixTy A = getMatrix(Lhs, Shape, Builder);
1728  MatrixTy B = getMatrix(Rhs, Shape, Builder);
1729  assert(A.isColumnMajor() == B.isColumnMajor() &&
1730  Result.isColumnMajor() == A.isColumnMajor() &&
1731  "operands must agree on matrix layout");
1732 
1733  Builder.setFastMathFlags(getFastMathFlags(Inst));
1734 
1735  // Helper to perform binary op on vectors.
1736  auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
1737  switch (Inst->getOpcode()) {
1738  case Instruction::Add:
1739  return Builder.CreateAdd(LHS, RHS);
1740  case Instruction::Mul:
1741  return Builder.CreateMul(LHS, RHS);
1742  case Instruction::Sub:
1743  return Builder.CreateSub(LHS, RHS);
1744  case Instruction::FAdd:
1745  return Builder.CreateFAdd(LHS, RHS);
1746  case Instruction::FMul:
1747  return Builder.CreateFMul(LHS, RHS);
1748  case Instruction::FSub:
1749  return Builder.CreateFSub(LHS, RHS);
1750  default:
1751  llvm_unreachable("Unsupported binary operator for matrix");
1752  }
1753  };
1754 
1755  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1756  Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
1757 
1758  finalizeLowering(Inst,
1759  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1760  Result.getNumVectors()),
1761  Builder);
1762  return true;
1763  }
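To illustrate the per-vector split above (types and value names are made up): a column-major 4x2 float matrix is carried as two <4 x float> column vectors, so adding two such matrixes lowers to one vector instruction per column:

  //   %r.col0 = fadd <4 x float> %a.col0, %b.col0
  //   %r.col1 = fadd <4 x float> %a.col1, %b.col1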
1764 
1765  /// Lower unary operators, if shape information is available.
1766  bool VisitUnaryOperator(UnaryOperator *Inst) {
1767  auto I = ShapeMap.find(Inst);
1768  if (I == ShapeMap.end())
1769  return false;
1770 
1771  Value *Op = Inst->getOperand(0);
1772 
1773  IRBuilder<> Builder(Inst);
1774  ShapeInfo &Shape = I->second;
1775 
1776  MatrixTy Result;
1777  MatrixTy M = getMatrix(Op, Shape, Builder);
1778 
1779  Builder.setFastMathFlags(getFastMathFlags(Inst));
1780 
1781  // Helper to perform unary op on vectors.
1782  auto BuildVectorOp = [&Builder, Inst](Value *Op) {
1783  switch (Inst->getOpcode()) {
1784  case Instruction::FNeg:
1785  return Builder.CreateFNeg(Op);
1786  default:
1787  llvm_unreachable("Unsupported unary operator for matrix");
1788  }
1789  };
1790 
1791  for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
1792  Result.addVector(BuildVectorOp(M.getVector(I)));
1793 
1794  finalizeLowering(Inst,
1795  Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1796  Result.getNumVectors()),
1797  Builder);
1798  return true;
1799  }
1800 
1801  /// Helper to linearize a matrix expression tree into a string. Currently
1802  /// matrix expressions are linearized by starting at an expression leaf and
1803  /// linearizing bottom up.
1804  struct ExprLinearizer {
1805  unsigned LengthToBreak = 100;
1806  std::string Str;
1807  raw_string_ostream Stream;
1808  unsigned LineLength = 0;
1809  const DataLayout &DL;
1810 
1811  /// Mapping from instructions to matrixes. It is used to identify
1812  /// matrix instructions.
1813  const MapVector<Value *, MatrixTy> &Inst2Matrix;
1814 
1815  /// Mapping from values to the leaves of all expressions that the value is
1816  /// part of.
1817  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
1818 
1819  /// Set of matrix expressions in the scope of a given DISubprogram.
1820  const SmallSetVector<Value *, 32> &ExprsInSubprogram;
1821 
1822  /// Leaf node of the expression to linearize.
1823  Value *Leaf;
1824 
1825  /// Used to keep track of sub-expressions that get reused while linearizing
1826  /// the expression. Re-used sub-expressions are marked as (reused).
1827  SmallPtrSet<Value *, 8> ReusedExprs;
1828 
1829  ExprLinearizer(const DataLayout &DL,
1830  const MapVector<Value *, MatrixTy> &Inst2Matrix,
1831  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
1832  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
1833  Value *Leaf)
1834  : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
1835  ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
1836 
1837  void indent(unsigned N) {
1838  LineLength += N;
1839  for (unsigned i = 0; i < N; i++)
1840  Stream << " ";
1841  }
1842 
1843  void lineBreak() {
1844  Stream << "\n";
1845  LineLength = 0;
1846  }
1847 
1848  void maybeIndent(unsigned Indent) {
1849  if (LineLength >= LengthToBreak)
1850  lineBreak();
1851 
1852  if (LineLength == 0)
1853  indent(Indent);
1854  }
1855 
1856  void write(StringRef S) {
1857  LineLength += S.size();
1858  Stream << S;
1859  }
1860 
1861  Value *getUnderlyingObjectThroughLoads(Value *V) {
1862  if (Value *Ptr = getPointerOperand(V))
1863  return getUnderlyingObjectThroughLoads(Ptr);
1864  else if (V->getType()->isPointerTy())
1865  return getUnderlyingObject(V);
1866  return V;
1867  }
1868 
1869  /// Returns true if \p V is a matrix value in the given subprogram.
1870  bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
1871 
1872  /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
1873  /// \p SS.
1874  void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
1875  auto M = Inst2Matrix.find(V);
1876  if (M == Inst2Matrix.end())
1877  SS << "unknown";
1878  else {
1879  SS << M->second.getNumRows();
1880  SS << "x";
1881  SS << M->second.getNumColumns();
1882  }
1883  }
1884 
1885  /// Write the called function name. Handles calls to llvm.matrix.*
1886  /// specially: we write the name, followed by the dimensions of the input
1887  /// matrixes, followed by the scalar type name.
1888  void writeFnName(CallInst *CI) {
1889  if (!CI->getCalledFunction())
1890  write("<no called fn>");
1891  else {
1892  StringRef Name = CI->getCalledFunction()->getName();
1893  if (!Name.startswith("llvm.matrix")) {
1894  write(Name);
1895  return;
1896  }
1897  auto *II = cast<IntrinsicInst>(CI);
1898  write(Intrinsic::getBaseName(II->getIntrinsicID())
1899  .drop_front(StringRef("llvm.matrix.").size()));
1900  write(".");
1901  std::string Tmp;
1902  raw_string_ostream SS(Tmp);
1903 
1904  switch (II->getIntrinsicID()) {
1905  case Intrinsic::matrix_multiply:
1906  prettyPrintMatrixType(II->getOperand(0), SS);
1907  SS << ".";
1908  prettyPrintMatrixType(II->getOperand(1), SS);
1909  SS << "." << *II->getType()->getScalarType();
1910  break;
1911  case Intrinsic::matrix_transpose:
1912  prettyPrintMatrixType(II->getOperand(0), SS);
1913  SS << "." << *II->getType()->getScalarType();
1914  break;
1915  case Intrinsic::matrix_column_major_load:
1916  prettyPrintMatrixType(II, SS);
1917  SS << "." << *II->getType()->getScalarType();
1918  break;
1919  case Intrinsic::matrix_column_major_store:
1920  prettyPrintMatrixType(II->getOperand(0), SS);
1921  SS << "." << *II->getOperand(0)->getType()->getScalarType();
1922  break;
1923  default:
1924  llvm_unreachable("Unhandled case");
1925  }
1926  SS.flush();
1927  write(Tmp);
1928  }
1929  }
1930 
1931  unsigned getNumShapeArgs(CallInst *CI) const {
1932  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
1933  switch (II->getIntrinsicID()) {
1934  case Intrinsic::matrix_multiply:
1935  return 3;
1936  case Intrinsic::matrix_transpose:
1937  return 2;
1938  case Intrinsic::matrix_column_major_load:
1939  case Intrinsic::matrix_column_major_store:
1940  return 3;
1941  default:
1942  return 0;
1943  }
1944  }
1945  return 0;
1946  }
1947 
1948  /// Special printing for values: for pointers, we print whether they refer
1949  /// to an external (function) address or a stack address; for other values
1950  /// we print the constant, or "scalar"/"matrix" otherwise.
1951  void write(Value *V) {
1952  V = getUnderlyingObjectThroughLoads(V);
1953  if (V->getType()->isPointerTy()) {
1954  if (isa<AllocaInst>(V)) {
1955  Stream << "stack addr";
1956  LineLength += StringRef("stack addr").size();
1957  } else {
1958  Stream << "addr";
1959  LineLength += StringRef("addr").size();
1960  }
1961  if (!V->getName().empty()) {
1962  Stream << " %" << V->getName() << "";
1963  LineLength += V->getName().size() + 2;
1964  }
1965  return;
1966  }
1967 
1968  std::string Tmp;
1969  raw_string_ostream TmpStream(Tmp);
1970 
1971  if (auto *CI = dyn_cast<ConstantInt>(V))
1972  TmpStream << CI->getValue();
1973  else if (isa<Constant>(V))
1974  TmpStream << "constant";
1975  else {
1976  if (isMatrix(V))
1977  TmpStream << "matrix";
1978  else
1979  TmpStream << "scalar";
1980  }
1981  TmpStream.flush();
1982  Tmp = std::string(StringRef(Tmp).trim());
1983  LineLength += Tmp.size();
1984  Stream << Tmp;
1985  }
1986 
1987  /// Linearize expression \p Expr starting at an indentation of \p Indent.
1988  /// Expressions that are re-used multiple times are prefixed with (reused)
1989  /// at the re-used root instruction.
1990  void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
1991  bool ParentShared) {
1992  auto *I = cast<Instruction>(Expr);
1993  maybeIndent(Indent);
1994  SmallVector<Value *, 8> Ops;
1995 
1996  // Is Expr shared with other expression leaves?
1997  bool ExprShared = false;
1998 
1999  // Deal with shared subtrees. Mark them as shared, if required.
2000  if (!ParentShared) {
2001  auto SI = Shared.find(Expr);
2002  assert(SI != Shared.end() && SI->second.count(Leaf));
2003 
2004  for (Value *S : SI->second) {
2005  if (S == Leaf)
2006  continue;
2007  DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2008  write("shared with remark at line " + std::to_string(DL.getLine()) +
2009  " column " + std::to_string(DL.getCol()) + " (");
2010  }
2011  ExprShared = SI->second.size() > 1;
2012  }
2013 
2014  bool Reused = !ReusedExprs.insert(Expr).second;
2015  if (Reused && !ParentReused)
2016  write("(reused) ");
2017 
2018  if (auto *CI = dyn_cast<CallInst>(I)) {
2019  writeFnName(CI);
2020 
2021  Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2022  } else if (isa<BitCastInst>(Expr)) {
2023  // Special case bitcasts, which are used to materialize matrixes from
2024  // non-matrix ops.
2025  write("matrix");
2026  return;
2027  } else {
2028  Ops.append(I->value_op_begin(), I->value_op_end());
2029  write(std::string(I->getOpcodeName()));
2030  }
2031 
2032  write(std::string("("));
2033 
2034  unsigned NumOpsToBreak = 1;
2035  if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2036  NumOpsToBreak = 2;
2037 
2038  for (Value *Op : Ops) {
2039  if (Ops.size() > NumOpsToBreak)
2040  lineBreak();
2041 
2042  maybeIndent(Indent + 1);
2043  if (isMatrix(Op))
2044  linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2045  else
2046  write(Op);
2047  if (Op != Ops.back())
2048  write(", ");
2049  }
2050 
2051  write(")");
2052  }
2053 
2054  const std::string &getResult() {
2055  Stream.flush();
2056  return Str;
2057  }
2058  };
2059 
2060  /// Generate remarks for matrix operations in a function. To generate remarks
2061  /// for matrix expressions, the following approach is used:
2062  /// 1. Use the inlined-at debug information to group matrix operations to the
2063  /// DISubprograms they are contained in.
2064  /// 2. Collect leaves of matrix expressions (done in
2065  /// RemarkGenerator::getExpressionLeaves) for each subprogram-expression
2066  /// mapping. Leaves are lowered matrix instructions without other matrix
2067  /// users (like stores) in the current subprogram.
2068  /// 3. For each leaf, create a remark containing a linearized version of the
2069  /// matrix expression. The expression is linearized by a recursive
2070  /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2071  /// that multiple leaves can share sub-expressions. Shared subexpressions
2072  /// are explicitly marked as shared().
2073  struct RemarkGenerator {
2074  const MapVector<Value *, MatrixTy> &Inst2Matrix;
2075  OptimizationRemarkEmitter &ORE;
2076  Function &Func;
2077  const DataLayout &DL;
2078 
2079  RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2080  OptimizationRemarkEmitter &ORE, Function &Func)
2081  : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2082  DL(Func.getParent()->getDataLayout()) {}
2083 
2084  /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2085  /// instructions in Inst2Matrix returning void or without any users in
2086  /// \p ExprsInSubprogram. Currently that should only include stores.
2087  SmallVector<Value *, 4>
2088  getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2089  SmallVector<Value *, 4> Leaves;
2090  for (auto *Expr : ExprsInSubprogram)
2091  if (Expr->getType()->isVoidTy() ||
2092  !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2093  return ExprsInSubprogram.count(U);
2094  }))
2095  Leaves.push_back(Expr);
2096  return Leaves;
2097  }
2098 
2099  /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2100  /// to all visited expressions in \p Shared. Limit the matrix operations to
2101  /// the ones in \p ExprsInSubprogram.
2102  void collectSharedInfo(Value *Leaf, Value *V,
2103  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2104  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2105 
2106  if (!ExprsInSubprogram.count(V))
2107  return;
2108 
2109  auto I = Shared.insert({V, {}});
2110  I.first->second.insert(Leaf);
2111 
2112  for (Value *Op : cast<Instruction>(V)->operand_values())
2113  collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2114  }
2115 
2116  /// Calculate the exclusive and shared op counts for the expression
2117  /// starting at \p Root. Expressions used multiple times are counted once.
2118  /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2119  std::pair<OpInfoTy, OpInfoTy>
2120  sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2121  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2122  DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2123  if (!ExprsInSubprogram.count(Root))
2124  return {};
2125 
2126  // Already counted this expression. Stop.
2127  if (!ReusedExprs.insert(Root).second)
2128  return {};
2129 
2130  OpInfoTy SharedCount;
2131  OpInfoTy Count;
2132 
2133  auto I = Shared.find(Root);
2134  auto CM = Inst2Matrix.find(Root);
2135  if (I->second.size() == 1)
2136  Count = CM->second.getOpInfo();
2137  else
2138  SharedCount = CM->second.getOpInfo();
2139 
2140  for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2141  auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2142  Count += C.first;
2143  SharedCount += C.second;
2144  }
2145  return {Count, SharedCount};
2146  }
2147 
2148  void emitRemarks() {
2149  if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2150  return;
2151 
2152  // Map matrix operations to their containing subprograms, by traversing
2153  // the inlinedAt chain. If the function does not have a DISubprogram, we
2154  // only map them to the containing function.
2155  MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2156  for (auto &KV : Inst2Matrix) {
2157  if (Func.getSubprogram()) {
2158  auto *I = cast<Instruction>(KV.first);
2159  DILocation *Context = I->getDebugLoc();
2160  while (Context) {
2161  auto I =
2162  Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2163  I.first->second.push_back(KV.first);
2164  Context = DebugLoc(Context).getInlinedAt();
2165  }
2166  } else {
2167  auto I = Subprog2Exprs.insert({nullptr, {}});
2168  I.first->second.push_back(KV.first);
2169  }
2170  }
2171  for (auto &KV : Subprog2Exprs) {
2172  SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2173  KV.second.end());
2174  auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2175 
2176  DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2177  for (Value *Leaf : Leaves)
2178  collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2179 
2180  // Generate remarks for each leaf.
2181  for (auto *L : Leaves) {
2182 
2183  DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2184  DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2185  while (Context) {
2186  if (getSubprogram(Context->getScope()) == KV.first) {
2187  Loc = Context;
2188  break;
2189  }
2190  Context = DebugLoc(Context).getInlinedAt();
2191  }
2192 
2193  SmallPtrSet<Value *, 8> ReusedExprs;
2194  OpInfoTy Counts, SharedCounts;
2195  std::tie(Counts, SharedCounts) =
2196  sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2197 
2198  OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2199  cast<Instruction>(L)->getParent());
2200 
2201  Rem << "Lowered with ";
2202  Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2203  << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2204  << ore::NV("NumComputeOps", Counts.NumComputeOps)
2205  << " compute ops, "
2206  << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2207  << " exposed transposes";
2208 
2209  if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2210  SharedCounts.NumComputeOps > 0) {
2211  Rem << ",\nadditionally "
2212  << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2213  << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2214  << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2215  << " compute ops"
2216  << " are shared with other expressions";
2217  }
2218 
2219  Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2220  ORE.emit(Rem);
2221  }
2222  }
2223  }
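Put together, an emitted remark has roughly the following shape; the source location, counts, and shapes are illustrative, and the linearized expression is produced by the ExprLinearizer above:

  remark: matmul.c:12:10: Lowered with 4 stores, 4 loads, 16 compute ops,
  0 exposed transposes
  store(
   multiply.2x2.2x2.double(
    load(addr %A),
    load(addr %B)),
   addr %C)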
2224 
2225  std::string
2226  linearize(Value *L,
2227  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2228  const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2229  const DataLayout &DL) {
2230  ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2231  Lin.linearizeExpr(L, 0, false, false);
2232  return Lin.getResult();
2233  }
2234  };
2235 };
2236 } // namespace
2237 
2238 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2239  FunctionAnalysisManager &AM) {
2240  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2241  OptimizationRemarkEmitter *ORE = nullptr;
2242  AAResults *AA = nullptr;
2243  DominatorTree *DT = nullptr;
2244  LoopInfo *LI = nullptr;
2245 
2246  if (!Minimal) {
2247  ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2248  AA = &AM.getResult<AAManager>(F);
2249  DT = &AM.getResult<DominatorTreeAnalysis>(F);
2250  LI = &AM.getResult<LoopAnalysis>(F);
2251  }
2252 
2253  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2254  if (LMT.Visit()) {
2255  PreservedAnalyses PA;
2256  if (!Minimal) {
2257  PA.preserve<LoopAnalysis>();
2258  PA.preserve<DominatorTreeAnalysis>();
2259  }
2260  return PA;
2261  }
2262  return PreservedAnalyses::all();
2263 }
2264 
2265 void LowerMatrixIntrinsicsPass::printPipeline(
2266  raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2267  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2268  OS, MapClassName2PassName);
2269  OS << "<";
2270  if (Minimal)
2271  OS << "minimal";
2272  OS << ">";
2273 }
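This prints the pass option in the same form the new pass manager accepts, e.g. lower-matrix-intrinsics<minimal> when Minimal is set (assuming the pass is registered under a pipeline name matching DEBUG_TYPE; check PassRegistry.def for the authoritative spelling).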
2274 
2275 namespace {
2276 
2277 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
2278 public:
2279  static char ID;
2280 
2281  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
2282  initializeLowerMatrixIntrinsicsLegacyPassPass(
2283  *PassRegistry::getPassRegistry());
2284  }
2285 
2286  bool runOnFunction(Function &F) override {
2287  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2288  auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
2289  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
2290  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2291  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
2292  LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
2293  bool C = LMT.Visit();
2294  return C;
2295  }
2296 
2297  void getAnalysisUsage(AnalysisUsage &AU) const override {
2298  AU.addRequired<TargetTransformInfoWrapperPass>();
2299  AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
2300  AU.addRequired<AAResultsWrapperPass>();
2301  AU.addRequired<DominatorTreeWrapperPass>();
2302  AU.addPreserved<DominatorTreeWrapperPass>();
2303  AU.addRequired<LoopInfoWrapperPass>();
2304  AU.addPreserved<LoopInfoWrapperPass>();
2305  }
2306 };
2307 } // namespace
2308 
2309 static const char pass_name[] = "Lower the matrix intrinsics";
2310 char LowerMatrixIntrinsicsLegacyPass::ID = 0;
2311 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2312  false, false)
2313 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
2314 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
2315 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
2316 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
2317 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2318  false, false)
2319 
2320 Pass *llvm::createLowerMatrixIntrinsicsPass() {
2321  return new LowerMatrixIntrinsicsLegacyPass();
2322 }
2323 
2324 namespace {
2325 
2326 /// A lightweight version of the matrix lowering pass that only requires TTI.
2327 /// Advanced features that require DT, AA or ORE like tiling are disabled. This
2328 /// is used to lower matrix intrinsics if the main lowering pass is not run, for
2329 /// example with -O0.
2330 class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
2331 public:
2332  static char ID;
2333 
2334  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
2335  initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
2336  *PassRegistry::getPassRegistry());
2337  }
2338 
2339  bool runOnFunction(Function &F) override {
2340  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2341  LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
2342  bool C = LMT.Visit();
2343  return C;
2344  }
2345 
2346  void getAnalysisUsage(AnalysisUsage &AU) const override {
2347  AU.addRequired<TargetTransformInfoWrapperPass>();
2348  AU.setPreservesCFG();
2349  }
2350 };
2351 } // namespace
2352 
2353 static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
2354 char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
2355 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
2356  "lower-matrix-intrinsics-minimal", pass_name_minimal,
2357  false, false)
2358 INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
2359  "lower-matrix-intrinsics-minimal", pass_name_minimal,
2360  false)
2361 
2362 Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
2363  return new LowerMatrixIntrinsicsMinimalLegacyPass();
2364 }
llvm::Intrinsic::getBaseName
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Function.cpp:874
i
i
Definition: README.txt:29
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:155
pass_name
static const char pass_name[]
Definition: LowerMatrixIntrinsics.cpp:2309
llvm::AAManager
A manager for alias analyses.
Definition: AliasAnalysis.h:1287
llvm::TargetIRAnalysis
Analysis pass providing the TargetTransformInfo.
Definition: TargetTransformInfo.h:2418
llvm::cast
std::enable_if_t<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type > cast(const Y &Val)
Definition: Casting.h:254
llvm::Function::isIntrinsic
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:212
llvm::MemoryLocation::get
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
Definition: MemoryLocation.cpp:37
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AllocatorList.h:23
BlockSize
static const int BlockSize
Definition: TarWriter.cpp:33
llvm::DILocalScope::getSubprogram
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Definition: DebugInfoMetadata.cpp:828
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::make_range
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
Definition: iterator_range.h:53
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:113
llvm::Value::hasOneUse
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
llvm::Intrinsic::getDeclaration
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1399
llvm::User::operands
op_range operands()
Definition: User.h:242
llvm::BasicBlock::iterator
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:90
llvm::AllocaInst::getAlign
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:125
IntrinsicInst.h
llvm::Type::isPointerTy
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:217
ceil
We have fiadd patterns now but the followings have the same cost and complexity We need a way to specify the later is more profitable def def The FP stackifier should handle simple permutates to reduce number of shuffle e g ceil
Definition: README-FPStack.txt:54
llvm::AnalysisManager::getResult
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:783
DebugInfoMetadata.h
llvm::MemoryLocation::Ptr
const Value * Ptr
The address of the start of the location.
Definition: MemoryLocation.h:217
llvm::ValueMap::end
iterator end()
Definition: ValueMap.h:136
Scalar.h
llvm::PassInfoMixin< LowerMatrixIntrinsicsPass >
LowerLoad
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition: X86ISelLowering.cpp:25166
llvm::TypeSize::getFixedSize
ScalarTy getFixedSize() const
Definition: TypeSize.h:425
T
llvm::Function
Definition: Function.h:62
Pass.h
llvm::TargetTransformInfo::getRegisterBitWidth
TypeSize getRegisterBitWidth(RegisterKind K) const
Definition: TargetTransformInfo.cpp:609
llvm::IntrinsicInst::getIntrinsicID
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:52
LowerStore
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition: X86ISelLowering.cpp:25081
llvm::raw_string_ostream
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:631
llvm::write
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs)
Definition: DWP.cpp:535
llvm::PointerType::get
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Definition: Type.cpp:729
llvm::SetVector::size
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:77
llvm::Type::getScalarType
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:308
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1177
llvm::enumerate
detail::enumerator< R > enumerate(R &&TheRange)
Given an input range, returns a new range whose values are are pair (A,B) such that A is the 0-based ...
Definition: STLExtras.h:2080
llvm::PatternMatch::m_Load
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
Definition: PatternMatch.h:1573
llvm::TargetTransformInfo
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
Definition: TargetTransformInfo.h:168
insertVector
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2188
ToRemove
ReachingDefAnalysis InstSet & ToRemove
Definition: ARMLowOverheadLoops.cpp:540
llvm::IRBuilder<>
llvm::AAResults::isNoAlias
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
Definition: AliasAnalysis.h:561
DomTreeUpdater.h
ValueTracking.h
OptimizationRemarkEmitter.h
llvm::DominatorTree
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:151
llvm::OptimizationRemarkEmitter::allowExtraAnalysis
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
Definition: OptimizationRemarkEmitter.h:98
llvm::cl::Hidden
@ Hidden
Definition: CommandLine.h:143
llvm::DILocation
Debug location.
Definition: DebugInfoMetadata.h:1580
ForceFusion
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
llvm::MemoryLocation::Size
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
Definition: MemoryLocation.h:226
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::DebugLoc::getInlinedAt
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:40
llvm::reverse
auto reverse(ContainerTy &&C, std::enable_if_t< has_rbegin< ContainerTy >::value > *=nullptr)
Definition: STLExtras.h:414
llvm::sys::path::end
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:236
llvm::sys::path::begin
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:227
llvm::LoopInfoWrapperPass
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:1271
llvm::DominatorTreeBase::Insert
static constexpr UpdateKind Insert
Definition: GenericDomTree.h:242
T1
#define T1
Definition: Mips16ISelLowering.cpp:340
llvm::SmallSet
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:134
llvm::operator!=
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:1988
vectors
hexagon Hexagon specific predictive commoning for HVX vectors
Definition: HexagonVectorLoopCarriedReuse.cpp:221
Offset
uint64_t Offset
Definition: ELFObjHandler.cpp:80
llvm::MapVector
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:37
llvm::SmallPtrSet
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:449
llvm::ore::NV
DiagnosticInfoOptimizationBase::Argument NV
Definition: OptimizationRemarkEmitter.h:136
llvm::VectorType::getElementType
Type * getElementType() const
Definition: DerivedTypes.h:422
llvm::Value::user_begin
user_iterator user_begin()
Definition: Value.h:397
llvm::successors
auto successors(MachineBasicBlock *BB)
Definition: MachineSSAContext.h:31
llvm::CallBase::arg_begin
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1318
llvm::SmallVectorImpl::pop_back_val
LLVM_NODISCARD T pop_back_val()
Definition: SmallVector.h:644
RHS
Value * RHS
Definition: X86PartialReduction.cpp:74
llvm::UnaryOperator
Definition: InstrTypes.h:103
llvm::FastMathFlags
Convenience struct for specifying and reasoning about fast-math flags.
Definition: Operator.h:165
llvm::initializeLowerMatrixIntrinsicsMinimalLegacyPassPass
void initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(PassRegistry &)
llvm::LoadInst::getAlign
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:228
llvm::BitmaskEnumDetail::Mask
std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:80
llvm::DominatorTreeBase::applyUpdates
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
Definition: GenericDomTree.h:544
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:101
llvm::NVPTX::PTXLdStInstCode::VecType
VecType
Definition: NVPTX.h:121
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::RISCVFenceField::R
@ R
Definition: RISCVBaseInfo.h:207
llvm::DomTreeUpdater::UpdateStrategy::Lazy
@ Lazy
llvm::TileInfo
A helper struct to create IR loop nests for tiling in IR of the following form: for CurrentColumn = 0...
Definition: MatrixUtils.h:31
Context
ManagedStatic< detail::RecordContext > Context
Definition: Record.cpp:96
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
getSubprogram
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
Definition: LowerMatrixIntrinsics.cpp:85
AliasAnalysis.h
MatrixBuilder.h
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
GraphTraits.h
llvm::DominatorTree::dominates
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:115
CommandLine.h
llvm::UnaryOperator::getOpcode
UnaryOps getOpcode() const
Definition: InstrTypes.h:173
extractVector
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2163
LHS
Value * LHS
Definition: X86PartialReduction.cpp:73
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
llvm::Intrinsic::getType
FunctionType * getType(LLVMContext &Context, ID id, ArrayRef< Type * > Tys=None)
Return the function type for an intrinsic.
Definition: Function.cpp:1355
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:31
MatrixUtils.h
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
isZero
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:519
llvm::AAResults
Definition: AliasAnalysis.h:507
E
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
llvm::SmallVectorImpl::append
void append(in_iter in_start, in_iter in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:657
llvm::createLowerMatrixIntrinsicsPass
Pass * createLowerMatrixIntrinsicsPass()
Definition: LowerMatrixIntrinsics.cpp:2320
llvm::User
Definition: User.h:44
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
llvm::ARM_PROC::A
@ A
Definition: ARMBaseInfo.h:34
TileSize
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
llvm::CallBase::getCalledFunction
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation.
Definition: InstrTypes.h:1398
llvm::BasicBlock::begin
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:296
llvm::operator+=
std::string & operator+=(std::string &buffer, StringRef string)
Definition: StringRef.h:940
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
llvm::LocationSize::getValue
uint64_t getValue() const
Definition: MemoryLocation.h:158
llvm::ms_demangle::QualifierMangleMode::Result
@ Result
llvm::BitTracker
Definition: BitTracker.h:35
llvm::Value::uses
iterator_range< use_iterator > uses()
Definition: Value.h:376
false
Definition: StackSlotColoring.cpp:142
llvm::MaybeAlign
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:109
B
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
llvm::BinaryOperator::getOpcode
BinaryOps getOpcode() const
Definition: InstrTypes.h:394
llvm::PatternMatch::m_ConstantInt
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:145
llvm::Instruction
Definition: Instruction.h:45
llvm::Type::getScalarSizeInBits
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition: Type.cpp:191
llvm::DominatorTreeWrapperPass
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:287
llvm::raw_ostream
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:53
llvm::LowerMatrixIntrinsicsPass::run
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition: LowerMatrixIntrinsics.cpp:2238
llvm::codeview::EncodedFramePtrReg::BasePtr
@ BasePtr
llvm::raw_ostream::flush
void flush()
Definition: raw_ostream.h:186
llvm::UndefValue::get
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1804
llvm::DomTreeUpdater
Definition: DomTreeUpdater.h:28
llvm::ConstantInt::get
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:932
LoopUtils.h
llvm::getUnderlyingObject
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments and pointer casts from the specified value,...
Definition: ValueTracking.cpp:4280
PatternMatch.h
llvm::TargetTransformInfo::RGK_FixedWidthVector
@ RGK_FixedWidthVector
Definition: TargetTransformInfo.h:916
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:686
llvm::StoreInst::getAlign
Align getAlign() const
Definition: Instructions.h:358
llvm::Align
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
llvm::FastMathFlags::setAllowContract
void setAllowContract(bool B=true)
Definition: Operator.h:236
llvm::ValueMap::count
size_type count(const KeyT &Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: ValueMap.h:152
llvm::Value::use_empty
bool use_empty() const
Definition: Value.h:344
MatrixLayoutTy
MatrixLayoutTy
Definition: LowerMatrixIntrinsics.cpp:73
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:58
CFG.h
llvm::function_ref
An efficient, type-erasing, non-owning reference to a callable.
Definition: STLExtras.h:223
llvm::VectorType
Base class of all SIMD vector types.
Definition: DerivedTypes.h:389
VectorUtils.h
llvm::cl::opt< bool >
llvm::StoreInst::isVolatile
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:345
llvm::StoreInst
An instruction for storing to memory.
Definition: Instructions.h:309
llvm::cl::values
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition: CommandLine.h:697
llvm::Instruction::eraseFromParent
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:78
llvm::LowerMatrixIntrinsicsPass::printPipeline
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
Definition: LowerMatrixIntrinsics.cpp:2265
llvm::StringRef::empty
constexpr LLVM_NODISCARD bool empty() const
empty - Check if the string is empty.
Definition: StringRef.h:152
llvm::MapVector::find
iterator find(const KeyT &Key)
Definition: MapVector.h:148
llvm::getPointerOperand
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
Definition: Instructions.h:5333
uint64_t
llvm::TargetTransformInfoWrapperPass
Wrapper pass for TargetTransformInfo.
Definition: TargetTransformInfo.h:2474
D
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
intrinsics
expand Expand reduction intrinsics
Definition: ExpandReductions.cpp:200
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
INITIALIZE_PASS_DEPENDENCY
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
llvm::PreservedAnalyses::preserve
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:176
llvm::PHINode::addIncoming
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Definition: Instructions.h:2807
llvm::addStringMetadataToLoop
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition: LoopUtils.cpp:223
llvm::TargetTransformInfo::getRegisterClassForType
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
Definition: TargetTransformInfo.cpp:600
llvm::omp::AddressSpace::Shared
@ Shared
llvm::DenseMap
Definition: DenseMap.h:714
llvm::PatternMatch::m_Store
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
Definition: PatternMatch.h:1580
llvm::codeview::FrameCookieKind::Copy
@ Copy
llvm::ValueMap::erase
bool erase(const KeyT &Val)
Definition: ValueMap.h:191
llvm::LoopInfoBase::getLoopFor
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Definition: LoopInfo.h:970
I
#define I(x, y, z)
Definition: MD5.cpp:58
llvm::cl::init
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:441
llvm::make_early_inc_range
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:642
llvm::concatenateVectors
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
Definition: VectorUtils.cpp:875
llvm::X86AS::SS
@ SS
Definition: X86.h:189
IRBuilder.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI
StandardInstrumentations SI(Debug, VerifyEach)
FuseMatrix
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
llvm::OptimizationRemarkEmitter::emit
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Definition: OptimizationRemarkEmitter.cpp:77
llvm::ValueMap::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: ValueMap.h:173
llvm::operator==
bool operator==(uint64_t V1, const APInt &V2)
Definition: APInt.h:1986
llvm::TTI
TargetTransformInfo TTI
Definition: TargetTransformInfo.h:163
llvm::Type::isVoidTy
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:138
MatrixLayoutTy::ColumnMajor
@ ColumnMajor
LowerMatrixIntrinsics.h
llvm::CallBase::arg_end
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1324
llvm::SmallPtrSetImpl::count
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:382
llvm::SmallSet::erase
bool erase(const T &V)
Definition: SmallSet.h:207
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:650
llvm::PatternMatch::m_Value
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
llvm::SetVector::insert
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:141
llvm::size
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1630
llvm::Function::getIntrinsicID
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:207
llvm::ArrayRef
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: APInt.h:32
llvm::LoopInfo
Definition: LoopInfo.h:1086
llvm::Instruction::getFastMathFlags
FastMathFlags getFastMathFlags() const
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
Definition: Instruction.cpp:290
llvm::BinaryOperator
Definition: InstrTypes.h:190
llvm::OptimizationRemarkEmitter
The optimization diagnostic interface.
Definition: OptimizationRemarkEmitter.h:33
Matrix
Live Register Matrix
Definition: LiveRegMatrix.cpp:44
llvm::min
Expected< ExpressionValue > min(const ExpressionValue &Lhs, const ExpressionValue &Rhs)
Definition: FileCheck.cpp:357
llvm::MapVector::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:118
Mul
BinaryOperator * Mul
Definition: X86PartialReduction.cpp:68
llvm::any_of
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1656
llvm::CallBase::getParamAlign
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1729
DataLayout.h
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:253
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:57
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition: ErrorHandling.h:134
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::AnalysisUsage::addPreserved
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
Definition: PassAnalysisSupport.h:98
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:532
getParent
static const Function * getParent(const Value *V)
Definition: BasicAliasAnalysis.cpp:870
llvm::ms_demangle::IntrinsicFunctionKind::New
@ New
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition: AArch64SLSHardening.cpp:76
clEnumValN
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:672
llvm::MatrixBuilder
Definition: MatrixBuilder.h:33
S
add sub stmia L5 ldr r0 bl L_printf $stub Instead of a and a wouldn t it be better to do three moves *Return an aggregate type is even return S
Definition: README.txt:210
pass_name_minimal
static const char pass_name_minimal[]
Definition: LowerMatrixIntrinsics.cpp:2353
llvm::TargetTransformInfo::getNumberOfRegisters
unsigned getNumberOfRegisters(unsigned ClassID) const
Definition: TargetTransformInfo.cpp:596
llvm::AMDGPU::HSAMD::Kernel::Arg::Key::IsVolatile
constexpr char IsVolatile[]
Key for Kernel::Arg::Metadata::mIsVolatile.
Definition: AMDGPUMetadata.h:194
llvm::SmallSet::insert
std::pair< NoneType, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:180
llvm::ifs::IFSSymbolType::Func
@ Func
llvm::Value::getName
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
llvm::ValueMap
See the file comment.
Definition: ValueMap.h:85
llvm::LoadInst
An instruction for reading from memory.
Definition: Instructions.h:180
llvm::BasicBlock::getTerminator
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:152
users
iv users
Definition: IVUsers.cpp:52
llvm::MapVector::end
iterator end()
Definition: MapVector.h:72
llvm::ConstantInt::getZExtValue
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:142
llvm::RecurKind::FMulAdd
@ FMulAdd
Fused multiply-add of floats (a * b + c).
llvm::StringRef::size
constexpr LLVM_NODISCARD size_t size() const
size - Get the string size.
Definition: StringRef.h:156
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:69
llvm::Twine
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:83
Alignment.h
llvm::commonAlignment
Align commonAlignment(Align A, Align B)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:211
llvm::OptimizationRemarkEmitterWrapperPass
OptimizationRemarkEmitter legacy analysis pass.
Definition: OptimizationRemarkEmitter.h:146
MatrixLayout
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
llvm::GraphProgram::Name
Name
Definition: GraphWriter.h:52
AllowContractEnabled
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
llvm::StringRef::drop_front
LLVM_NODISCARD StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:652
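Illustrative sketch of drop_front; the string and helper name are arbitrary examples, not taken from the pass:

#include "llvm/ADT/StringRef.h"

static llvm::StringRef layoutSuffix() {
  llvm::StringRef Layout("column-major");
  return Layout.drop_front(7); // "major"; no characters are copied
}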
llvm::DIScope
Base class for scope-like contexts.
Definition: DebugInfoMetadata.h:476
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:325
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:161
MatrixLayoutTy::RowMajor
@ RowMajor
llvm::TypeSize
Definition: TypeSize.h:416
Function.h
llvm::Value::hasNUses
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:149
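A minimal, hedged example of the exact-count semantics; hasSingleUser is an invented helper name:

#include "llvm/IR/Value.h"

static bool hasSingleUser(const llvm::Value &V) {
  // hasNUses(N) checks for exactly N uses, unlike hasNUsesOrMore(N).
  return V.hasNUses(1);
}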
llvm::sort
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1590
llvm::SetVector< T, SmallVector< T, N >, SmallDenseSet< T, N > >::count
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:215
llvm::ValueMap::find
iterator find(const KeyT &Val)
Definition: ValueMap.h:156
llvm::MCID::Add
@ Add
Definition: MCInstrDesc.h:183
llvm::ReversePostOrderTraversal
Definition: PostOrderIterator.h:290
llvm::SPII::Load
@ Load
Definition: SparcInstrInfo.h:32
llvm::IntrinsicInst
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:45
llvm::DominatorTreeAnalysis
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:252
llvm::BasicBlock::reverse_iterator
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:92
llvm::OptimizationRemark
Diagnostic information for applied optimization remarks.
Definition: DiagnosticInfo.h:685
llvm::Pass
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:91
llvm::SPII::Store
@ Store
Definition: SparcInstrInfo.h:33
Instructions.h
PostOrderIterator.h
llvm::LoadInst::isVolatile
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:217
llvm::initializeLowerMatrixIntrinsicsLegacyPassPass
void initializeLowerMatrixIntrinsicsLegacyPassPass(PassRegistry &)
SmallVector.h
BT
BitTracker BT
Definition: BitTracker.cpp:73
llvm::CallBase::getArgOperand
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1343
llvm::SmallPtrSetImplBase::empty
LLVM_NODISCARD bool empty() const
Definition: SmallPtrSet.h:91
N
#define N
llvm::AAResultsWrapperPass
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
Definition: AliasAnalysis.h:1335
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:94
llvm::to_string
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:87
TargetTransformInfo.h
llvm::iterator_range
A range adaptor for a pair of iterators.
Definition: iterator_range.h:30
llvm::PHINode
Definition: Instructions.h:2657
DEBUG_TYPE
#define DEBUG_TYPE
Definition: LowerMatrixIntrinsics.cpp:52
minimal
lower matrix intrinsics minimal
Definition: LowerMatrixIntrinsics.cpp:2359
llvm::DISubprogram
Subprogram description.
Definition: DebugInfoMetadata.h:1826
llvm::SmallSet::empty
LLVM_NODISCARD bool empty() const
Definition: SmallSet.h:155
llvm::SmallVectorImpl< Instruction * >
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, false, false) INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass
llvm::SmallPtrSetImpl< Instruction * >
llvm::SmallSetVector
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:307
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:44
llvm::X86II::TA
@ TA
Definition: X86BaseInfo.h:808
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:298
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1478
GEP
Hexagon Common GEP
Definition: HexagonCommonGEP.cpp:172
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
TileUseLoops
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
llvm::DebugLoc
A debug info location.
Definition: DebugLoc.h:33
llvm::AllocaInst
an instruction to allocate memory on the stack
Definition: Instructions.h:62
llvm::FastMathFlags::allowContract
bool allowContract() const
Definition: Operator.h:215
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition: User.h:169
llvm::cl::desc
Definition: CommandLine.h:412
llvm::createLowerMatrixIntrinsicsMinimalPass
Pass * createLowerMatrixIntrinsicsMinimalPass()
Definition: LowerMatrixIntrinsics.cpp:2362
llvm::createSequentialMask
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
Definition: VectorUtils.cpp:820
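A hedged sketch of the mask this helper produces; the argument values and helper name are illustrative:

#include "llvm/Analysis/VectorUtils.h"

static llvm::SmallVector<int, 16> exampleMask() {
  // Three sequential lanes starting at index 2, followed by one undef lane:
  // {2, 3, 4, -1}. Undef lanes are encoded as -1 in shuffle masks.
  return llvm::createSequentialMask(/*Start=*/2, /*NumInts=*/3, /*NumUndefs=*/1);
}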
llvm::SetVector< Value * >
llvm::ConstantAggregateZero::get
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1683
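Illustrative only: building a zeroinitializer constant for a small vector type; the helper name is invented:

#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"

static llvm::Constant *zeroVector(llvm::LLVMContext &Ctx) {
  // zeroinitializer for a <4 x float> vector.
  llvm::Type *VecTy = llvm::FixedVectorType::get(llvm::Type::getFloatTy(Ctx), 4);
  return llvm::ConstantAggregateZero::get(VecTy);
}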
BasicBlockUtils.h
llvm::SplitBlock
BasicBlock * SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Definition: BasicBlockUtils.cpp:839
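A hedged sketch of a typical call (names invented for illustration): the new block begins at SplitPt and the old block falls through to it via an unconditional branch.

#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

static llvm::BasicBlock *splitAt(llvm::BasicBlock *BB, llvm::Instruction *SplitPt,
                                 llvm::DominatorTree &DT) {
  // Everything from SplitPt onward moves into the returned block; BB now
  // terminates with a branch to it, and DT is updated to reflect the split.
  return llvm::SplitBlock(BB, SplitPt, &DT);
}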
llvm::pdb::PDB_SymType::Block
@ Block
InitializePasses.h
llvm::OptimizationRemarkEmitterAnalysis
Definition: OptimizationRemarkEmitter.h:164
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
Debug.h
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:421
llvm::MemoryLocation
Representation for a specific memory location.
Definition: MemoryLocation.h:209
llvm::LoopAnalysis
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:1246
llvm::DominatorTreeBase::Delete
static constexpr UpdateKind Delete
Definition: GenericDomTree.h:243
llvm::Use
A Use represents the edge between a Value definition and its users.
Definition: Use.h:44
llvm::Type::getPrimitiveSizeInBits
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition: Type.cpp:166
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:364
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:38