1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns
16//    for tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
19
22#include "llvm/ADT/ScopeExit.h"
23#include "llvm/ADT/SmallSet.h"
32#include "llvm/IR/CFG.h"
33#include "llvm/IR/DataLayout.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
43#include "llvm/Support/Debug.h"
47
48#include <cmath>
49
50using namespace llvm;
51using namespace PatternMatch;
52
53#define DEBUG_TYPE "lower-matrix-intrinsics"
54
55static cl::opt<bool>
56 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
57 cl::desc("Enable/disable fusing matrix instructions."));
58// TODO: Allow and use non-square tiles.
60 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
62 "Tile size for matrix instruction fusion using square-shaped tiles."));
63static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
65 cl::desc("Generate loop nest for tiling."));
67 "force-fuse-matrix", cl::init(false), cl::Hidden,
68 cl::desc("Force matrix instruction fusion even if not profitable."));
70 "matrix-allow-contract", cl::init(false), cl::Hidden,
71 cl::desc("Allow the use of FMAs if available and profitable. This may "
72 "result in different results, due to less rounding error."));
73
74static cl::opt<bool>
75 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
76 cl::desc("Enable/disable matrix shape verification."),
77 cl::init(false));
78
79enum class MatrixLayoutTy { ColumnMajor, RowMajor };
80
81static cl::opt<MatrixLayoutTy> MatrixLayout(
82 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
83 cl::desc("Sets the default matrix layout"),
85 "Use column-major layout"),
87 "Use row-major layout")));
88
89static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
90 cl::init(false));
91
92/// Helper function that either returns \p Scope, if it is a subprogram, or the
93/// subprogram attached to a local scope.
94static DISubprogram *getSubprogram(DIScope *Scope) {
95  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
96 return Subprogram;
97 return cast<DILocalScope>(Scope)->getSubprogram();
98}
99
100/// Return true if V is a splat of a value (which is used when multiplying a
101/// matrix with a scalar).
102static bool isSplat(Value *V) {
103 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
104 return SV->isZeroEltSplat();
105 return false;
106}
107
108/// Match any mul operation (fp or integer).
109template <typename LTy, typename RTy>
110auto m_AnyMul(const LTy &L, const RTy &R) {
111 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
112}
113
114/// Match any add operation (fp or integer).
115template <typename LTy, typename RTy>
116auto m_AnyAdd(const LTy &L, const RTy &R) {
117 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
118}
119
120namespace {
121
122// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
123// the start address of vector \p VecIdx with type (\p EltType x \p NumElements),
124// assuming \p Stride elements between the starts of two consecutive vectors.
125// \p Stride must be >= \p NumElements.
126// For column-major matrixes, the function computes the address of a column
127// vector and \p NumElements must be set to the number of elements in a column
128// (= number of rows of the matrix). For row-major matrixes, the function
129// computes the address of a row vector and \p NumElements must be set to the
130// number of elements in a row (= number of columns of the matrix).
131//
132// Consider a 4x4 matrix in column-major layout like below:
133//
134// 0 1 2 3
135// 0 v_0_0 v_0_1 v_0_2 v_0_3
136// 1 v_1_0 v_1_1 v_1_2 v_1_3
137// 2 v_2_0 v_2_1 v_2_2 v_2_3
138// 3 v_3_0 v_3_1 v_3_2 v_3_3
139
140// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
141// we need a pointer to the first element of the submatrix as base pointer.
142// Then we can use computeVectorAddr to compute the addresses for the columns
143// of the sub-matrix.
144//
145// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
146// -> just returns Base
147// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
148// -> returns Base + (1 * 4)
149// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
150// -> returns Base + (2 * 4)
151//
152// The graphic below illustrates the number of elements in a column (marked
153// with |) and the number of skipped elements (marked with {).
154//
155// v_0_0 v_0_1 {v_0_2 {v_0_3
156// Base Col 1 Col 2
157// | | |
158// v_1_0 |v_1_1 |v_1_2 |v_1_3
159// v_2_0 |v_2_1 |v_2_2 |v_2_3
160// v_3_0 {v_3_1 {v_3_2 v_3_3
161//
162Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
163 unsigned NumElements, Type *EltType,
164 IRBuilder<> &Builder) {
165
166 assert((!isa<ConstantInt>(Stride) ||
167 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
168 "Stride must be >= the number of elements in the result vector.");
169
170 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
171 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
172
173 // Get pointer to the start of the selected vector. Skip GEP creation,
174 // if we select vector 0.
175 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
176 VecStart = BasePtr;
177 else
178 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
179
180 return VecStart;
181}
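// Illustrative sketch (not part of the pass): for constant operands the address
// computation above reduces to plain index arithmetic. Assuming the column-major
// 4x4 example and a sub-matrix anchored at its first element:
//
//   // hypothetical helper mirroring computeVectorAddr for constant inputs
//   unsigned subMatrixVectorOffset(unsigned VecIdx, unsigned Stride) {
//     return VecIdx * Stride; // e.g. column 2 with stride 4 -> element offset 8
//   }
//
// computeVectorAddr emits the same arithmetic as IR (a mul feeding a GEP) and
// skips the GEP entirely when the computed offset is a constant zero.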
182
183namespace {
184struct ShapeInfo {
185 unsigned NumRows;
186 unsigned NumColumns;
187
188 bool IsColumnMajor;
189
190 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
191 : NumRows(NumRows), NumColumns(NumColumns),
192 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
193
194 ShapeInfo(Value *NumRows, Value *NumColumns)
195 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
196 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
197
198 bool operator==(const ShapeInfo &other) {
199 return NumRows == other.NumRows && NumColumns == other.NumColumns;
200 }
201 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
202
203 /// Returns true if shape-information is defined, meaning both dimensions
204 /// are != 0.
205 operator bool() const {
206 assert(NumRows == 0 || NumColumns != 0);
207 return NumRows != 0;
208 }
209
210 unsigned getStride() const {
211 if (IsColumnMajor)
212 return NumRows;
213 return NumColumns;
214 }
215
216 unsigned getNumVectors() const {
217 if (IsColumnMajor)
218 return NumColumns;
219 return NumRows;
220 }
221
222 /// Returns the transposed shape.
223 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
224};
225} // namespace
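// Example (illustrative): for the default column-major layout, ShapeInfo(4, 2)
// describes a 4 x 2 matrix stored as 2 column vectors of 4 elements each, so
// getStride() == 4 and getNumVectors() == 2. With -matrix-default-layout set
// to row-major, the same shape is stored as 4 row vectors of 2 elements,
// giving getStride() == 2 and getNumVectors() == 4.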
226
227static bool isUniformShape(Value *V) {
228 Instruction *I = dyn_cast<Instruction>(V);
229 if (!I)
230 return true;
231
232 switch (I->getOpcode()) {
233 case Instruction::FAdd:
234 case Instruction::FSub:
235 case Instruction::FMul: // Scalar multiply.
236 case Instruction::FNeg:
237 case Instruction::Add:
238 case Instruction::Mul:
239 case Instruction::Sub:
240 return true;
241 default:
242 return false;
243 }
244}
245
246/// Return the ShapeInfo for the result of \p I, if it can be determined.
247static std::optional<ShapeInfo>
248computeShapeInfoForInst(Instruction *I,
249 const DenseMap<Value *, ShapeInfo> &ShapeMap) {
250 Value *M;
251 Value *N;
252 Value *K;
253 if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
254 m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
255 return ShapeInfo(M, K);
256 if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
257 m_Value(N)))) {
258 // Flip dimensions.
259 return ShapeInfo(N, M);
260 }
261 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
262 m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
263 m_Value(N))))
264 return ShapeInfo(N, M);
265 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
266 m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
267 return ShapeInfo(M, N);
268 Value *MatrixA;
269 if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
270 auto OpShape = ShapeMap.find(MatrixA);
271 if (OpShape != ShapeMap.end())
272 return OpShape->second;
273 }
274
275 if (isUniformShape(I)) {
276 // Find the first operand that has a known shape and use that.
277 for (auto &Op : I->operands()) {
278 auto OpShape = ShapeMap.find(Op.get());
279 if (OpShape != ShapeMap.end())
280 return OpShape->second;
281 }
282 }
283 return std::nullopt;
284}
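// Example (illustrative IR): for
//   %c = call <6 x double> @llvm.matrix.multiply.v6f64.v4f64.v6f64(
//            <4 x double> %a, <6 x double> %b, i32 2, i32 2, i32 3)
// this returns ShapeInfo(2, 3), i.e. M x K taken from the intrinsic arguments,
// and for
//   %t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b,
//                                                       i32 2, i32 3)
// it returns the flipped shape ShapeInfo(3, 2).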
285
286/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
287///
288/// Currently, the lowering for each matrix intrinsic is done as follows:
289/// 1. Propagate the shape information from intrinsics to connected
290/// instructions.
291/// 2. Lower instructions with shape information (assuming column-major layout).
292/// The lowering works similarly using row-major layout.
293/// 2.1. Get column vectors for each argument. If we already lowered the
294/// definition of an argument, use the produced column vectors directly.
295/// If not, split the operand vector containing an embedded matrix into
296/// a set of column vectors,
297/// 2.2. Lower the instruction in terms of column major operations, which
298/// yields a set of column vectors containing result matrix. Note that we
299/// lower all instructions that have shape information. Besides the
300/// intrinsics, this includes stores for example.
301/// 2.3. Update uses of the lowered instruction. If we have shape information
302/// for a user, there is nothing to do, as we will look up the result
303/// column matrix when lowering the user. For other uses, we embed the
304/// result matrix in a flat vector and update the use.
305/// 2.4. Cache the result column matrix for the instruction we lowered
306/// 3. After we lowered all instructions in a function, remove the now
307/// obsolete instructions.
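/// For example (illustrative only), a 2x2 column-major multiply
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
/// is lowered to shufflevectors splitting %a and %b into <2 x double> columns,
/// fmul/fadd (or fmuladd) sequences producing the result columns, and a single
/// concatenating shufflevector only if some user without shape information
/// still needs the flat <4 x double> value.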
308///
309class LowerMatrixIntrinsics {
310 Function &Func;
311 const DataLayout &DL;
312  const TargetTransformInfo &TTI;
313  FunctionAnalysisManager *AM;
314  AliasAnalysis *AA = nullptr;
315 DominatorTree *DT = nullptr;
316 LoopInfo *LI = nullptr;
317 OptimizationRemarkEmitter *ORE = nullptr;
318
319 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
320 struct OpInfoTy {
321 /// Number of stores emitted to generate this matrix.
322 unsigned NumStores = 0;
323 /// Number of loads emitted to generate this matrix.
324 unsigned NumLoads = 0;
325 /// Number of compute operations emitted to generate this matrix.
326 unsigned NumComputeOps = 0;
327 /// Most of the time transposes can be fused with matrix multiplies or can
328 /// be folded away via algebraic simplifications. This is the number of
329 /// transposes that we failed to make "free" via such optimizations.
330 unsigned NumExposedTransposes = 0;
331
332 OpInfoTy &operator+=(const OpInfoTy &RHS) {
333 NumStores += RHS.NumStores;
334 NumLoads += RHS.NumLoads;
335 NumComputeOps += RHS.NumComputeOps;
336 NumExposedTransposes += RHS.NumExposedTransposes;
337 return *this;
338 }
339 };
340
341 /// Wrapper class representing a matrix as a set of vectors, either in row or
342 /// column major layout. All vectors must have the same vector type.
343 class MatrixTy {
344    SmallVector<Value *, 16> Vectors;
345
346 OpInfoTy OpInfo;
347
348 bool IsColumnMajor = true;
349
350 public:
351 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
352 MatrixTy(ArrayRef<Value *> Vectors)
353 : Vectors(Vectors),
354 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
355 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
356 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
357
358 unsigned D = isColumnMajor() ? NumColumns : NumRows;
359 for (unsigned J = 0; J < D; ++J)
360        addVector(PoisonValue::get(FixedVectorType::get(
361            EltTy, isColumnMajor() ? NumRows : NumColumns)));
362 }
363
364 Value *getVector(unsigned i) const { return Vectors[i]; }
365 Value *getColumn(unsigned i) const {
366 assert(isColumnMajor() && "only supported for column-major matrixes");
367 return Vectors[i];
368 }
369 Value *getRow(unsigned i) const {
370 assert(!isColumnMajor() && "only supported for row-major matrixes");
371 return Vectors[i];
372 }
373
374 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
375
376 Type *getElementType() const { return getVectorTy()->getElementType(); }
377
378 unsigned getNumVectors() const {
379 if (isColumnMajor())
380 return getNumColumns();
381 return getNumRows();
382 }
383
384 unsigned getNumColumns() const {
385 if (isColumnMajor())
386 return Vectors.size();
387 else {
388      assert(Vectors.size() > 0 && "Cannot call getNumColumns without vectors");
389 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
390 }
391 }
392 unsigned getNumRows() const {
393 if (isColumnMajor()) {
394 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
395 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
396 } else
397 return Vectors.size();
398 }
399
400 void addVector(Value *V) { Vectors.push_back(V); }
401 VectorType *getColumnTy() {
402 assert(isColumnMajor() && "only supported for column-major matrixes");
403 return getVectorTy();
404 }
405
406 VectorType *getVectorTy() const {
407 return cast<VectorType>(Vectors[0]->getType());
408 }
409
410    auto columns() const {
411      assert(isColumnMajor() &&
412 "columns() only supported for column-major matrixes");
413 return make_range(Vectors.begin(), Vectors.end());
414 }
415
416    auto vectors() const {
417      return make_range(Vectors.begin(), Vectors.end());
418 }
419
420 /// Embed the vectors of the matrix into a flat vector by concatenating
421 /// them.
422 Value *embedInVector(IRBuilder<> &Builder) const {
423 return Vectors.size() == 1 ? Vectors[0]
424 : concatenateVectors(Builder, Vectors);
425 }
426
427 MatrixTy &addNumLoads(unsigned N) {
428 OpInfo.NumLoads += N;
429 return *this;
430 }
431
432 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
433
434 MatrixTy &addNumStores(unsigned N) {
435 OpInfo.NumStores += N;
436 return *this;
437 }
438
439 MatrixTy &addNumExposedTransposes(unsigned N) {
440 OpInfo.NumExposedTransposes += N;
441 return *this;
442 }
443
444 MatrixTy &addNumComputeOps(unsigned N) {
445 OpInfo.NumComputeOps += N;
446 return *this;
447 }
448
449 unsigned getNumStores() const { return OpInfo.NumStores; }
450 unsigned getNumLoads() const { return OpInfo.NumLoads; }
451 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
452
453 const OpInfoTy &getOpInfo() const { return OpInfo; }
454
455 bool isColumnMajor() const { return IsColumnMajor; }
456
457 unsigned getStride() const {
458 if (isColumnMajor())
459 return getNumRows();
460 return getNumColumns();
461 }
462
463 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
464 /// matrix is column-major, the result vector is extracted from a column
465 /// vector, otherwise from a row vector.
466 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
467 IRBuilder<> &Builder) const {
468 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
469 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
470 NumElts &&
471 "Extracted vector will contain poison values");
472 return Builder.CreateShuffleVector(
473 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
474 "block");
475 }
476 };
477
478 /// Maps instructions to their shape information. The shape information
479 /// describes the shape to be used while lowering. This matches the shape of
480 /// the result value of the instruction, with the only exceptions being store
481 /// instructions and the matrix_column_major_store intrinsics. For those, the
482 /// shape information indicates that those instructions should be lowered
483 /// using shape information as well. Note that extra care is needed when
484 /// erasing or RAUW'ing a value that is present in ShapeMap. If the
485 /// replacement is also a matrix operation, use
486 /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
487 /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
488 /// want to add shape information for a replacement instruction. When directly
489 /// erasing a value with an entry in ShapeMap, use
490 /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
491 /// accordingly.
492  DenseMap<Value *, ShapeInfo> ShapeMap;
493
494  /// List of instructions to remove. While lowering, we do not replace all
495  /// users of a lowered instruction if shape information is available; those
496  /// users need to be removed after we have finished lowering.
497  SmallVector<Instruction *, 16> ToRemove;
498
499 /// Map from instructions to their produced column matrix.
500 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
501
502private:
503  FastMathFlags getFastMathFlags(Instruction *Inst) {
504    FastMathFlags FMF;
505
506 if (isa<FPMathOperator>(*Inst))
507 FMF = Inst->getFastMathFlags();
508
509    if (AllowContractEnabled)
510      FMF.setAllowContract();
511 return FMF;
512 }
513
514public:
515 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
516                        FunctionAnalysisManager *AM)
517      : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
518
519 unsigned getNumOps(Type *VT) {
520 assert(isa<VectorType>(VT) && "Expected vector type");
521 return getNumOps(VT->getScalarType(),
522 cast<FixedVectorType>(VT)->getNumElements());
523 }
524
525 /// Is this the minimal version executed in the backend pipelines.
526 bool isMinimal() const {
527 return !DT;
528 }
529
530 /// Return the estimated number of vector ops required for an operation on
531 /// \p VT * N.
532 unsigned getNumOps(Type *ST, unsigned N) {
533 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
534                     double(TTI.getRegisterBitWidth(
535                                TargetTransformInfo::RGK_FixedWidthVector)
536                                .getFixedValue()));
537 }
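  // Worked example (illustrative): for ST = double (64 bits), N = 8 and a
  // target whose widest fixed-width vector register is 256 bits, this yields
  // ceil(64 * 8 / 256.0) = 2 vector operations.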
538
539 /// Return the set of vectors that a matrix value is lowered to.
540 ///
541  /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
542 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
543 /// into vectors.
544 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
545 IRBuilder<> &Builder) {
546 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
547 assert(VType && "MatrixVal must be a vector type");
548 assert(cast<FixedVectorType>(VType)->getNumElements() ==
549 SI.NumRows * SI.NumColumns &&
550 "The vector size must match the number of matrix elements");
551
552 // Check if we lowered MatrixVal using shape information. In that case,
553 // return the existing matrix, if it matches the requested shape
554 // information. If there is a mis-match, embed the result in a flat
555 // vector and split it later.
556 auto Found = Inst2ColumnMatrix.find(MatrixVal);
557 if (Found != Inst2ColumnMatrix.end()) {
558 MatrixTy &M = Found->second;
559 // Return the found matrix, if its shape matches the requested shape
560 // information
561 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
562 return M;
563
564 MatrixVal = M.embedInVector(Builder);
565 }
566
567 // Otherwise split MatrixVal.
568 SmallVector<Value *, 16> SplitVecs;
569 for (unsigned MaskStart = 0;
570 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
571 MaskStart += SI.getStride()) {
572 Value *V = Builder.CreateShuffleVector(
573 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
574 "split");
575 SplitVecs.push_back(V);
576 }
577
578 return {SplitVecs};
579 }
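  // Example (illustrative): requesting a 4 x 2 shape for a flat <8 x double>
  // value that was not produced by this pass yields two <4 x double>
  // shufflevectors ("split"), one per column, with the sequential masks
  // <0, 1, 2, 3> and <4, 5, 6, 7>.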
580
581 /// If \p V already has a known shape return false. Otherwise set the shape
582 /// for instructions that support it.
583 bool setShapeInfo(Value *V, ShapeInfo Shape) {
584 assert(Shape && "Shape not set");
585 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
586 return false;
587
588 auto SIter = ShapeMap.find(V);
589 if (SIter != ShapeMap.end()) {
590 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
591 SIter->second.NumColumns != Shape.NumColumns)) {
592 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
593 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
594 << Shape.NumColumns << ") for " << *V << "\n";
596 "Matrix shape verification failed, compilation aborted!");
597 }
598
599 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
600 << SIter->second.NumRows << " "
601 << SIter->second.NumColumns << " for " << *V << "\n");
602 return false;
603 }
604
605 ShapeMap.insert({V, Shape});
606 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
607 << " for " << *V << "\n");
608 return true;
609 }
610
611 /// Returns true if shape information can be used for \p V. The supported
612 /// instructions must match the instructions that can be lowered by this pass.
613 bool supportsShapeInfo(Value *V) {
614 Instruction *Inst = dyn_cast<Instruction>(V);
615 if (!Inst)
616 return false;
617
618 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
619 if (II)
620 switch (II->getIntrinsicID()) {
621 case Intrinsic::matrix_multiply:
622 case Intrinsic::matrix_transpose:
623 case Intrinsic::matrix_column_major_load:
624 case Intrinsic::matrix_column_major_store:
625 return true;
626 default:
627 return false;
628 }
629 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
630 }
631
632 /// Propagate the shape information of instructions to their users.
633 /// The work list contains instructions for which we can compute the shape,
634 /// either based on the information provided by matrix intrinsics or known
635 /// shapes of operands.
636  SmallVector<Instruction *, 32>
637  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
638    SmallVector<Instruction *, 32> NewWorkList;
639 // Pop an element for which we guaranteed to have at least one of the
640 // operand shapes. Add the shape for this and then add users to the work
641 // list.
642 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
643 while (!WorkList.empty()) {
644 Instruction *Inst = WorkList.pop_back_val();
645
646 // New entry, set the value and insert operands
647 bool Propagate = false;
648 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
649 Propagate = setShapeInfo(Inst, *SI);
650
651 if (Propagate) {
652 NewWorkList.push_back(Inst);
653 for (auto *User : Inst->users())
654 if (ShapeMap.count(User) == 0)
655 WorkList.push_back(cast<Instruction>(User));
656 }
657 }
658
659 return NewWorkList;
660 }
661
662 /// Propagate the shape to operands of instructions with shape information.
663 /// \p Worklist contains the instruction for which we already know the shape.
664  SmallVector<Instruction *, 32>
665  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
666    SmallVector<Instruction *, 32> NewWorkList;
667
668 auto pushInstruction = [](Value *V,
669                              SmallVectorImpl<Instruction *> &WorkList) {
670      Instruction *I = dyn_cast<Instruction>(V);
671 if (I)
672 WorkList.push_back(I);
673 };
674 // Pop an element with known shape. Traverse the operands, if their shape
675 // derives from the result shape and is unknown, add it and add them to the
676 // worklist.
677 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
678 while (!WorkList.empty()) {
679 Value *V = WorkList.pop_back_val();
680
681 size_t BeforeProcessingV = WorkList.size();
682 if (!isa<Instruction>(V))
683 continue;
684
685 Value *MatrixA;
686 Value *MatrixB;
687 Value *M;
688 Value *N;
689 Value *K;
690 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
691 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
692 m_Value(N), m_Value(K)))) {
693 if (setShapeInfo(MatrixA, {M, N}))
694 pushInstruction(MatrixA, WorkList);
695
696 if (setShapeInfo(MatrixB, {N, K}))
697 pushInstruction(MatrixB, WorkList);
698
699 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
700 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
701 // Flip dimensions.
702 if (setShapeInfo(MatrixA, {M, N}))
703 pushInstruction(MatrixA, WorkList);
704 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
705 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
706 m_Value(M), m_Value(N)))) {
707 if (setShapeInfo(MatrixA, {M, N})) {
708 pushInstruction(MatrixA, WorkList);
709 }
710 } else if (isa<LoadInst>(V) ||
711 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
712 // Nothing to do, no matrix input.
713 } else if (isa<StoreInst>(V)) {
714 // Nothing to do. We forward-propagated to this so we would just
715 // backward propagate to an instruction with an already known shape.
716 } else if (isUniformShape(V)) {
717 // Propagate to all operands.
718 ShapeInfo Shape = ShapeMap[V];
719 for (Use &U : cast<Instruction>(V)->operands()) {
720 if (setShapeInfo(U.get(), Shape))
721 pushInstruction(U.get(), WorkList);
722 }
723 }
724 // After we discovered new shape info for new instructions in the
725 // worklist, we use their users as seeds for the next round of forward
726 // propagation.
727 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
728 for (User *U : WorkList[I]->users())
729 if (isa<Instruction>(U) && V != U)
730 NewWorkList.push_back(cast<Instruction>(U));
731 }
732 return NewWorkList;
733 }
734
735 /// (Op0 op Op1)^T -> Op0^T op Op1^T
736 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
737 /// them on both sides of \p Operation.
738 Instruction *distributeTransposes(
739 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
740 MatrixBuilder &Builder,
741 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
742 Operation) {
743 Value *T0 = Builder.CreateMatrixTranspose(
744 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
745 // We are being run after shape prop, add shape for newly created
746 // instructions so that we lower them later.
747 setShapeInfo(T0, Shape0.t());
748 Value *T1 = Builder.CreateMatrixTranspose(
749 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
750 setShapeInfo(T1, Shape1.t());
751 return Operation(T0, Shape0.t(), T1, Shape1.t());
752 }
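  // Example (illustrative): for (A * B)^T with A of shape RxK and B of shape
  // KxC, the caller passes Op0 = B and Op1 = A, so T0 and T1 have shapes CxK
  // and KxR and Operation builds B^T * A^T, i.e. the CxR matrix expected for
  // the transposed product.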
753
754 /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
755 /// itself.
756 void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
757 auto Iter = ShapeMap.find(Inst);
758 if (Iter != ShapeMap.end())
759 ShapeMap.erase(Iter);
760 Inst->eraseFromParent();
761 }
762
763  /// Erase \p V from \p BB and move \p II forward to avoid invalidating
764 /// iterators.
765 void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
766 BasicBlock &BB) {
767 auto *Inst = cast<Instruction>(V);
768 // Still used, don't erase.
769 if (!Inst->use_empty())
770 return;
771 if (II != BB.rend() && Inst == &*II)
772 ++II;
773 eraseFromParentAndRemoveFromShapeMap(Inst);
774 }
775
776 /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
777 /// entry for \p Old and replace all uses of \p Old with \p New.
778 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
779 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
780    // with New. We should only add New if it supportsShapeInfo so we insert
781 // it conditionally instead.
782 auto S = ShapeMap.find(&Old);
783 if (S != ShapeMap.end()) {
784 ShapeMap.erase(S);
785 if (supportsShapeInfo(New))
786 ShapeMap.insert({New, S->second});
787 }
788 Old.replaceAllUsesWith(New);
789 }
790
791 /// Sink a top-level transpose inside matmuls and adds.
792 /// This creates and erases instructions as needed, and returns the newly
793 /// created instruction while updating the iterator to avoid invalidation. If
794 /// this returns nullptr, no new instruction was created.
795  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
796    BasicBlock &BB = *I.getParent();
797 IRBuilder<> IB(&I);
798 MatrixBuilder Builder(IB);
799
800 Value *TA, *TAMA, *TAMB;
801 ConstantInt *R, *K, *C;
802 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
803                      m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
804      return nullptr;
805
806 // Transpose of a transpose is a nop
807 Value *TATA;
808 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
809 updateShapeAndReplaceAllUsesWith(I, TATA);
810 eraseFromParentAndMove(&I, II, BB);
811 eraseFromParentAndMove(TA, II, BB);
812 return nullptr;
813 }
814
815 // k^T -> k
816 if (isSplat(TA)) {
817 updateShapeAndReplaceAllUsesWith(I, TA);
818 eraseFromParentAndMove(&I, II, BB);
819 return nullptr;
820 }
821
822 // (A * B)^t -> B^t * A^t
823 // RxK KxC CxK KxR
824 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
825 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
826                      m_ConstantInt(K), m_ConstantInt(C)))) {
827      auto NewInst = distributeTransposes(
828 TAMB, {K, C}, TAMA, {R, K}, Builder,
829 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
830 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
831 Shape0.NumColumns,
832 Shape1.NumColumns, "mmul");
833 });
834 updateShapeAndReplaceAllUsesWith(I, NewInst);
835 eraseFromParentAndMove(&I, II, BB);
836 eraseFromParentAndMove(TA, II, BB);
837 return NewInst;
838 }
839
840 // Same as above, but with a mul, which occurs when multiplied
841 // with a scalar.
842 // (A * k)^t -> A^t * k
843 // R x C RxC
844 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
845 (isSplat(TAMA) || isSplat(TAMB))) {
846 IRBuilder<> LocalBuilder(&I);
847 // We know that the transposed operand is of shape RxC.
848      // And when multiplied by a scalar, the shape is preserved.
849 auto NewInst = distributeTransposes(
850 TAMA, {R, C}, TAMB, {R, C}, Builder,
851 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
852 bool IsFP = I.getType()->isFPOrFPVectorTy();
853 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
854 : LocalBuilder.CreateMul(T0, T1, "mmul");
855 auto *Result = cast<Instruction>(Mul);
856 setShapeInfo(Result, Shape0);
857 return Result;
858 });
859 updateShapeAndReplaceAllUsesWith(I, NewInst);
860 eraseFromParentAndMove(&I, II, BB);
861 eraseFromParentAndMove(TA, II, BB);
862 return NewInst;
863 }
864
865 // (A + B)^t -> A^t + B^t
866 // RxC RxC CxR CxR
867 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
868 IRBuilder<> LocalBuilder(&I);
869 auto NewInst = distributeTransposes(
870 TAMA, {R, C}, TAMB, {R, C}, Builder,
871 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
872 bool IsFP = I.getType()->isFPOrFPVectorTy();
873 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
874 : LocalBuilder.CreateAdd(T0, T1, "madd");
875
876 auto *Result = cast<Instruction>(Add);
877 setShapeInfo(Result, Shape0);
878 return Result;
879 });
880 updateShapeAndReplaceAllUsesWith(I, NewInst);
881 eraseFromParentAndMove(&I, II, BB);
882 eraseFromParentAndMove(TA, II, BB);
883 return NewInst;
884 }
885
886 return nullptr;
887 }
888
889 void liftTranspose(Instruction &I) {
890 // Erase dead Instructions after lifting transposes from binops.
891 auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
892 if (T.use_empty())
893 eraseFromParentAndRemoveFromShapeMap(&T);
894 if (A->use_empty())
895 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
896 if (A != B && B->use_empty())
897 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
898 };
899
900 Value *A, *B, *AT, *BT;
901 ConstantInt *R, *K, *C;
902    // A^t * B^t -> (B * A)^t
903 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
904                      m_Value(A), m_Value(B), m_ConstantInt(R),
905                      m_ConstantInt(K), m_ConstantInt(C))) &&
906        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
907 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
908 IRBuilder<> IB(&I);
909 MatrixBuilder Builder(IB);
910 Value *M = Builder.CreateMatrixMultiply(
911 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
912 setShapeInfo(M, {C, R});
913 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
914 R->getZExtValue());
915 updateShapeAndReplaceAllUsesWith(I, NewInst);
916 CleanupBinOp(I, A, B);
917 }
918    // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose. If
919 // the shape of the second transpose is different, there's a shape conflict
920 // which gets resolved by picking the shape of the first operand.
921 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
922 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
923 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
924 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
925                      m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
926      IRBuilder<> Builder(&I);
927 auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");
928 MatrixBuilder MBuilder(Builder);
929 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
930 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
931 updateShapeAndReplaceAllUsesWith(I, NewInst);
932 assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
933 computeShapeInfoForInst(&I, ShapeMap) &&
934 "Shape of new instruction doesn't match original shape.");
935 CleanupBinOp(I, A, B);
936 if (auto *AddI = dyn_cast<Instruction>(Add)) {
937 setShapeInfo(AddI, {R, C});
938 assert(
939 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
940 ShapeMap[AddI] &&
941 "Shape of updated addition doesn't match cached shape.");
942 }
943 }
944 }
945
946 /// Try moving transposes in order to fold them away or into multiplies.
947 void optimizeTransposes() {
948 // First sink all transposes inside matmuls and adds, hoping that we end up
949 // with NN, NT or TN variants.
950 for (BasicBlock &BB : reverse(Func)) {
951 for (auto II = BB.rbegin(); II != BB.rend();) {
952 Instruction &I = *II;
953 // We may remove II. By default continue on the next/prev instruction.
954 ++II;
955 if (Instruction *NewInst = sinkTranspose(I, II))
956 II = std::next(BasicBlock::reverse_iterator(NewInst));
957 }
958 }
959
960 // If we have a TT matmul or a TT add, lift the transpose. We may be able
961 // to fold into consuming multiply or add.
962 for (BasicBlock &BB : Func) {
963      for (Instruction &I : make_early_inc_range(BB)) {
964        liftTranspose(I);
965 }
966 }
967 }
968
969 bool Visit() {
970    SmallVector<Instruction *, 32> WorkList;
971
972 // Initially only the shape of matrix intrinsics is known.
973 // Initialize the work list with ops carrying shape information.
974 for (BasicBlock &BB : Func)
975 for (Instruction &Inst : BB) {
976 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
977 if (!II)
978 continue;
979
980 switch (II->getIntrinsicID()) {
981 case Intrinsic::matrix_multiply:
982 case Intrinsic::matrix_transpose:
983 case Intrinsic::matrix_column_major_load:
984 case Intrinsic::matrix_column_major_store:
985 WorkList.push_back(&Inst);
986 break;
987 default:
988 break;
989 }
990 }
991
992 // Avoid unnecessary work if there are no matrix intrinsics in the function.
993 if (WorkList.empty())
994 return false;
995
996 if (AM) {
997      ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func);
998      AA = &AM->getResult<AAManager>(Func);
999      DT = &AM->getResult<DominatorTreeAnalysis>(Func);
1000      LI = &AM->getResult<LoopAnalysis>(Func);
1001 }
1002
1003 // Propagate shapes until nothing changes any longer.
1004 while (!WorkList.empty()) {
1005 WorkList = propagateShapeForward(WorkList);
1006 WorkList = propagateShapeBackward(WorkList);
1007 }
1008
1009 if (!isMinimal()) {
1010 optimizeTransposes();
1011      if (PrintAfterTransposeOpt) {
1012        dbgs() << "Dump after matrix transpose optimization:\n";
1013 Func.print(dbgs());
1014 }
1015 }
1016
1017 bool Changed = false;
1018 SmallVector<CallInst *, 16> MaybeFusableInsts;
1019    SmallVector<Instruction *, 16> MatrixInsts;
1020    SmallVector<IntrinsicInst *, 16> LifetimeEnds;
1021
1022 // First, collect all instructions with shape information and candidates for
1023 // fusion (currently only matrix multiplies).
1024    ReversePostOrderTraversal<Function *> RPOT(&Func);
1025    for (auto *BB : RPOT)
1026 for (Instruction &I : *BB) {
1027 if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
1028 LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
1029 if (ShapeMap.find(&I) == ShapeMap.end())
1030 continue;
1031 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
1032 MaybeFusableInsts.push_back(cast<CallInst>(&I));
1033 MatrixInsts.push_back(&I);
1034 }
1035
1036 // Second, try to lower any dot products
1037    SmallPtrSet<Instruction *, 16> FusedInsts;
1038    for (CallInst *CI : MaybeFusableInsts)
1039 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1040
1041 // Third, try to fuse candidates.
1042 for (CallInst *CI : MaybeFusableInsts)
1043 if (!FusedInsts.contains(CI))
1044 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
1045
1046 Changed = !FusedInsts.empty();
1047
1048 // Fourth, lower remaining instructions with shape information.
1049 for (Instruction *Inst : MatrixInsts) {
1050 if (FusedInsts.count(Inst))
1051 continue;
1052
1053 IRBuilder<> Builder(Inst);
1054
1055 if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1056 Changed |= VisitCallInst(CInst);
1057
1058 Value *Op1;
1059 Value *Op2;
1060 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1061 Changed |= VisitBinaryOperator(BinOp);
1062 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1063 Changed |= VisitUnaryOperator(UnOp);
1064 if (match(Inst, m_Load(m_Value(Op1))))
1065 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1066 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1067 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1068 }
1069
1070 if (ORE) {
1071 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1072 RemarkGen.emitRemarks();
1073 }
1074
1075    // Delete the instructions backwards, as this reduces the number of def-use
1076    // and use-def chains that need to be updated.
1077 //
1078 // Because we add to ToRemove during fusion we can't guarantee that defs
1079 // are before uses. Change uses to poison temporarily as these should get
1080 // removed as well.
1081 //
1082 // For verification, we keep track of where we changed uses to poison in
1083 // PoisonedInsts and then check that we in fact remove them.
1084 SmallSet<Instruction *, 16> PoisonedInsts;
1085 for (auto *Inst : reverse(ToRemove)) {
1086 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1087 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1088 PoisonedInsts.insert(Poisoned);
1089 U.set(PoisonValue::get(Inst->getType()));
1090 }
1091 Inst->eraseFromParent();
1092 PoisonedInsts.erase(Inst);
1093 }
1094 if (!PoisonedInsts.empty()) {
1095 // If we didn't remove all poisoned instructions, it's a hard error.
1096 dbgs() << "Poisoned but present instructions:\n";
1097 for (auto *I : PoisonedInsts)
1098 dbgs() << *I << "\n";
1099 llvm_unreachable("Poisoned but instruction not removed");
1100 }
1101
1102 return Changed;
1103 }
1104
1105 /// Replace intrinsic calls
1106 bool VisitCallInst(CallInst *Inst) {
1107 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1108 return false;
1109
1110 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1111 case Intrinsic::matrix_multiply:
1112 LowerMultiply(Inst);
1113 break;
1114 case Intrinsic::matrix_transpose:
1115 LowerTranspose(Inst);
1116 break;
1117 case Intrinsic::matrix_column_major_load:
1118 LowerColumnMajorLoad(Inst);
1119 break;
1120 case Intrinsic::matrix_column_major_store:
1121 LowerColumnMajorStore(Inst);
1122 break;
1123 default:
1124 return false;
1125 }
1126 return true;
1127 }
1128
1129 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1130 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1131 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1132 /// non-ConstantInt strides, return the common alignment of the initial
1133 /// alignment and the element size in bytes.
1134 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1135 MaybeAlign A) const {
1136 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1137 if (Idx == 0)
1138 return InitialAlign;
1139
1140 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1141 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1142 uint64_t StrideInBytes =
1143 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1144 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1145 }
1146 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1147 }
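  // Worked example (illustrative): loading double columns with a constant
  // stride of 3 elements from a 16-byte aligned base puts column 1 at a byte
  // offset of 24, so commonAlignment(16, 24) reduces the usable alignment to 8.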
1148
1149 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1150 /// vectors.
1151 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1152 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1153 auto *VType = cast<VectorType>(Ty);
1154 Type *EltTy = VType->getElementType();
1155 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1156 Value *EltPtr = Ptr;
1157 MatrixTy Result;
1158 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1159 Value *GEP = computeVectorAddr(
1160 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1161 Stride, Shape.getStride(), EltTy, Builder);
1162 Value *Vector = Builder.CreateAlignedLoad(
1163 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1164 IsVolatile, "col.load");
1165
1166 Result.addVector(Vector);
1167 }
1168 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1169 Result.getNumVectors());
1170 }
1171
1172 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1173 /// starting at \p MatrixPtr[I][J].
1174 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1175 ShapeInfo MatrixShape, Value *I, Value *J,
1176 ShapeInfo ResultShape, Type *EltTy,
1177 IRBuilder<> &Builder) {
1178
1179 Value *Offset = Builder.CreateAdd(
1180 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1181
1182 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1183 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1184 ResultShape.NumColumns);
1185
1186 return loadMatrix(TileTy, TileStart, Align,
1187 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1188 ResultShape, Builder);
1189 }
1190
1191 /// Lower a load instruction with shape information.
1192 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1193 bool IsVolatile, ShapeInfo Shape) {
1194 IRBuilder<> Builder(Inst);
1195 finalizeLowering(Inst,
1196 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1197 Shape, Builder),
1198 Builder);
1199 }
1200
1201 /// Lowers llvm.matrix.column.major.load.
1202 ///
1203 /// The intrinsic loads a matrix from memory using a stride between columns.
1204 void LowerColumnMajorLoad(CallInst *Inst) {
1205 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1206 "Intrinsic only supports column-major layout!");
1207 Value *Ptr = Inst->getArgOperand(0);
1208 Value *Stride = Inst->getArgOperand(1);
1209 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1210 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1211 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1212 }
1213
1214 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1215 /// MatrixPtr[I][J].
1216 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1217 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1218 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1219 Value *Offset = Builder.CreateAdd(
1220 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1221
1222 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1223 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1224 StoreVal.getNumColumns());
1225
1226 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1227 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1228 }
1229
1230 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1231 /// vectors.
1232 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1233 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1234 IRBuilder<> &Builder) {
1235 auto VType = cast<VectorType>(Ty);
1236 Value *EltPtr = Ptr;
1237 for (auto Vec : enumerate(StoreVal.vectors())) {
1238 Value *GEP = computeVectorAddr(
1239 EltPtr,
1240 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1241 Vec.index()),
1242 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1243 Builder.CreateAlignedStore(Vec.value(), GEP,
1244 getAlignForIndex(Vec.index(), Stride,
1245 VType->getElementType(),
1246 MAlign),
1247 IsVolatile);
1248 }
1249 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1250 StoreVal.getNumVectors());
1251 }
1252
1253 /// Lower a store instruction with shape information.
1254  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1255                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1256 IRBuilder<> Builder(Inst);
1257 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1258 finalizeLowering(Inst,
1259 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1260 IsVolatile, Builder),
1261 Builder);
1262 }
1263
1264 /// Lowers llvm.matrix.column.major.store.
1265 ///
1266  /// The intrinsic stores a matrix back to memory using a stride between columns.
1267 void LowerColumnMajorStore(CallInst *Inst) {
1268 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1269 "Intrinsic only supports column-major layout!");
1270 Value *Matrix = Inst->getArgOperand(0);
1271 Value *Ptr = Inst->getArgOperand(1);
1272 Value *Stride = Inst->getArgOperand(2);
1273 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1274 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1275 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1276 }
1277
1278 // Set elements I..I+NumElts-1 to Block
1279 Value *insertVector(Value *Col, unsigned I, Value *Block,
1280 IRBuilder<> &Builder) {
1281
1282 // First, bring Block to the same size as Col
1283 unsigned BlockNumElts =
1284 cast<FixedVectorType>(Block->getType())->getNumElements();
1285 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1286 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1287
1288 Block = Builder.CreateShuffleVector(
1289 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1290
1291 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1292 // 8, 4, 5, 6
1293    SmallVector<int, 16> Mask;
1294    unsigned i;
1295 for (i = 0; i < I; i++)
1296 Mask.push_back(i);
1297
1298 unsigned VecNumElts =
1299 cast<FixedVectorType>(Col->getType())->getNumElements();
1300 for (; i < I + BlockNumElts; i++)
1301 Mask.push_back(i - I + VecNumElts);
1302
1303 for (; i < VecNumElts; i++)
1304 Mask.push_back(i);
1305
1306 return Builder.CreateShuffleVector(Col, Block, Mask);
1307 }
1308
1309 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1310 IRBuilder<> &Builder, bool AllowContraction,
1311 unsigned &NumComputeOps) {
1312 NumComputeOps += getNumOps(A->getType());
1313 if (!Sum)
1314 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1315
1316 if (UseFPOp) {
1317 if (AllowContraction) {
1318 // Use fmuladd for floating point operations and let the backend decide
1319 // if that's profitable.
1320 return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),
1321 {A, B, Sum});
1322 }
1323 NumComputeOps += getNumOps(A->getType());
1324 Value *Mul = Builder.CreateFMul(A, B);
1325 return Builder.CreateFAdd(Sum, Mul);
1326 }
1327
1328 NumComputeOps += getNumOps(A->getType());
1329 Value *Mul = Builder.CreateMul(A, B);
1330 return Builder.CreateAdd(Sum, Mul);
1331 }
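  // Example (illustrative): with contraction allowed, one accumulation step
  // becomes a single call such as
  //   %sum.next = call <4 x double> @llvm.fmuladd.v4f64(<4 x double> %a,
  //                    <4 x double> %splat, <4 x double> %sum)
  // whereas without contraction the same step is an fmul followed by an fadd.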
1332
1333 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1334 /// users with shape information, there's nothing to do: they will use the
1335 /// cached value when they are lowered. For other users, \p Matrix is
1336 /// flattened and the uses are updated to use it. Also marks \p Inst for
1337 /// deletion.
1338 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1339 IRBuilder<> &Builder) {
1340 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1341 (void)inserted;
1342 assert(inserted.second && "multiple matrix lowering mapping");
1343
1344 ToRemove.push_back(Inst);
1345 Value *Flattened = nullptr;
1346 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1347 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1348 if (!Flattened)
1349 Flattened = Matrix.embedInVector(Builder);
1350 U.set(Flattened);
1351 }
1352 }
1353 }
1354
1355  /// Special case for MatMul lowering. Prevents scalar loads of row-major
1356  /// vectors. Lowers to a vector reduction add instead of sequential adds if
1357  /// reassociation is enabled.
1358 void lowerDotProduct(CallInst *MatMul,
1359                       SmallPtrSetImpl<Instruction *> &FusedInsts,
1360                       FastMathFlags FMF) {
1361 if (FusedInsts.contains(MatMul) ||
1362 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1363 return;
1364 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1365 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1366
1367 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1368 return;
1369
1370 Value *LHS = MatMul->getArgOperand(0);
1371 Value *RHS = MatMul->getArgOperand(1);
1372
1373 Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1374 bool IsIntVec = ElementType->isIntegerTy();
1375
1376    // Floating point reductions require reassociation.
1377 if (!IsIntVec && !FMF.allowReassoc())
1378 return;
1379
1380 auto CanBeFlattened = [](Value *Op) {
1381 if (match(Op, m_BinOp()))
1382 return true;
1383 return match(
1384        Op, m_OneUse(m_CombineOr(
1385                m_Load(m_Value()),
1386 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1387 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1388 m_Value(), m_SpecificInt(1))))));
1389 };
1390 // Returns the cost benefit of using \p Op with the dot product lowering. If
1391 // the returned cost is < 0, the argument is cheaper to use in the
1392 // dot-product lowering.
1393 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1394 if (ShapeMap.find(Op) == ShapeMap.end())
1395        return InstructionCost::getInvalid();
1396
1397 if (!isa<Instruction>(Op))
1398 return InstructionCost(0);
1399
1400 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1401 Type *EltTy = VecTy->getElementType();
1402
1403 if (!CanBeFlattened(Op)) {
1404 InstructionCost EmbedCost(0);
1405 // Roughly estimate the cost for embedding the columns into a vector.
1406 for (unsigned I = 1; I < N; ++I)
1407 EmbedCost +=
1410 return EmbedCost;
1411 }
1412
1413 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1414 InstructionCost OriginalCost =
1415 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1416 EltTy) *
1417 N;
1418      InstructionCost NewCost = TTI.getArithmeticInstrCost(
1419          cast<Instruction>(Op)->getOpcode(), VecTy);
1420 return NewCost - OriginalCost;
1421 }
1422
1423 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1424 // The transpose can be skipped for the dot product lowering, roughly
1425 // estimate the savings as the cost of embedding the columns in a
1426 // vector.
1427 InstructionCost EmbedCost(0);
1428 for (unsigned I = 1; I < N; ++I)
1429 EmbedCost -=
1432 return EmbedCost;
1433 }
1434
1435 // Costs for loads.
1436 if (N == 1)
1437 return InstructionCost(0);
1438
1439 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1440 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1441 };
1442
1443 // Iterate over LHS and operations feeding LHS and check if it is profitable
1444 // to flatten the visited ops. For each op, we compute the difference
1445 // between the flattened and matrix versions.
1446    SmallPtrSet<Value *, 4> Seen;
1447    SmallVector<Value *> WorkList;
1448 SmallVector<Value *> ToFlatten;
1449 WorkList.push_back(LHS);
1450 InstructionCost LHSCost(0);
1451 while (!WorkList.empty()) {
1452 Value *Op = WorkList.pop_back_val();
1453 if (!Seen.insert(Op).second)
1454 continue;
1455
1456 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1457 if (OpCost + LHSCost >= LHSCost)
1458 continue;
1459
1460 LHSCost += OpCost;
1461 ToFlatten.push_back(Op);
1462 if (auto *I = dyn_cast<Instruction>(Op))
1463 WorkList.append(I->op_begin(), I->op_end());
1464 }
1465
1466 // We compare the costs of a vector.reduce.add to sequential add.
1467 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1468 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1469 InstructionCost ReductionCost =
1470        TTI.getArithmeticReductionCost(
1471            AddOpCode, cast<VectorType>(LHS->getType()),
1472 IsIntVec ? std::nullopt : std::optional(FMF)) +
1473 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1474 InstructionCost SequentialAddCost =
1475 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1476 (LShape.NumColumns - 1) +
1477 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1478 (LShape.NumColumns);
1479 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1480 return;
1481
1482 FusedInsts.insert(MatMul);
1483 IRBuilder<> Builder(MatMul);
1484 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1485 this](Value *Op) {
1486 // Matmul must be the only user of loads because we don't use LowerLoad
1487 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1488 // instead of single vector load).
1489 if (!CanBeFlattened(Op))
1490 return;
1491
1492 if (match(Op, m_BinOp())) {
1493 auto It = ShapeMap.find(Op);
1494 if (It != ShapeMap.end()) {
1495 It->second = It->second.t();
1496 return;
1497 }
1498 }
1499
1500 FusedInsts.insert(cast<Instruction>(Op));
1501 // If vector uses the builtin load, lower to a LoadInst
1502 Value *Arg;
1503 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1504 m_Value(Arg)))) {
1505 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1506 Op->replaceAllUsesWith(NewLoad);
1507 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
1508 return;
1509 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1510 m_Value(Arg)))) {
1511 ToRemove.push_back(cast<Instruction>(Op));
1512 Op->replaceAllUsesWith(Arg);
1513 return;
1514 }
1515 };
1516
1517 for (auto *V : ToFlatten)
1518 FlattenArg(V);
1519
1520 LHS = MatMul->getArgOperand(0);
1521
1522 // Insert mul/fmul and llvm.vector.reduce.fadd
1523 Value *Mul =
1524 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1525
1526 Value *Result;
1527 if (IsIntVec)
1528 Result = Builder.CreateAddReduce(Mul);
1529 else {
1530 Result = Builder.CreateFAddReduce(
1531 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1532 0.0),
1533 Mul);
1534 cast<Instruction>(Result)->setFastMathFlags(FMF);
1535 }
1536
1537 // pack scalar back into a matrix and then replace matmul inst
1538    Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1539                                         Result, uint64_t(0));
1540 MatMul->replaceAllUsesWith(Result);
1541 FusedInsts.insert(MatMul);
1542 ToRemove.push_back(MatMul);
1543 }
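  // Example (illustrative): a 1x3 * 3x1 double dot product with reassociation
  // enabled is emitted as
  //   %m = fmul <3 x double> %lhs, %rhs
  //   %s = call reassoc double @llvm.vector.reduce.fadd.v3f64(double 0.0,
  //                                                           <3 x double> %m)
  // with the scalar result inserted back into a <1 x double> matrix value.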
1544
1545 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1546 /// addition.
1547 ///
1548 /// We can fold a transpose into the operand that is used to extract scalars.
1549  /// This is the first operand with row-major layout and the second with
1550  /// column-major layout. If \p IsScalarMatrixTransposed we assume the appropriate
1551 /// operand is transposed.
1552 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1553 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1554 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1555 const unsigned VF = std::max<unsigned>(
1556        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1557            .getFixedValue() /
1558 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1559 1U);
1560 unsigned R = Result.getNumRows();
1561 unsigned C = Result.getNumColumns();
1562 unsigned M = A.getNumColumns();
1563
1564 bool IsFP = Result.getElementType()->isFloatingPointTy();
1565 assert(A.isColumnMajor() == B.isColumnMajor() &&
1566 Result.isColumnMajor() == A.isColumnMajor() &&
1567 "operands must agree on matrix layout");
1568 unsigned NumComputeOps = 0;
1569
1570 Builder.setFastMathFlags(FMF);
1571
1572 if (A.isColumnMajor()) {
1573 // Multiply columns from the first operand with scalars from the second
1574 // operand. Then move along the K axes and accumulate the columns. With
1575 // this the adds can be vectorized without reassociation.
1576 for (unsigned J = 0; J < C; ++J) {
1577 unsigned BlockSize = VF;
1578 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1579 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1580
1581 for (unsigned I = 0; I < R; I += BlockSize) {
1582 // Gradually lower the vectorization factor to cover the remainder.
1583 while (I + BlockSize > R)
1584 BlockSize /= 2;
1585
1586 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1587 : nullptr;
1588 for (unsigned K = 0; K < M; ++K) {
1589 Value *L = A.extractVector(I, K, BlockSize, Builder);
1590 Value *RH = Builder.CreateExtractElement(
1591 B.getColumn(IsScalarMatrixTransposed ? K : J),
1592 IsScalarMatrixTransposed ? J : K);
1593 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1594 Sum =
1595 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1596 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1597 }
1598 Result.setVector(J,
1599 insertVector(Result.getVector(J), I, Sum, Builder));
1600 }
1601 }
1602 } else {
1603 // Multiply rows from the second operand with scalars from the first
1604 // operand. Then move along the K axes and accumulate the rows. With this
1605 // the adds can be vectorized without reassociation.
1606 for (unsigned I = 0; I < R; ++I) {
1607 unsigned BlockSize = VF;
1608 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1609 for (unsigned J = 0; J < C; J += BlockSize) {
1610 // Gradually lower the vectorization factor to cover the remainder.
1611 while (J + BlockSize > C)
1612 BlockSize /= 2;
1613
1614 Value *Sum = nullptr;
1615 for (unsigned K = 0; K < M; ++K) {
1616 Value *R = B.extractVector(K, J, BlockSize, Builder);
1617 Value *LH = Builder.CreateExtractElement(
1618 A.getVector(IsScalarMatrixTransposed ? K : I),
1619 IsScalarMatrixTransposed ? I : K);
1620 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1621 Sum =
1622 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1623 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1624 }
1625 Result.setVector(I,
1626 insertVector(Result.getVector(I), J, Sum, Builder));
1627 }
1628 }
1629 }
1630 Result.addNumComputeOps(NumComputeOps);
1631 }
1632
1633 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1634 /// copying it to a new location. Returns the new location if a copy was
1635 /// made, otherwise the original location.
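/// Sketch of the emitted control flow (editor's addition; "check0" is a
/// hypothetical name for the original block, the other labels match the
/// blocks created below):
///   check0:     if (load.begin < store.end) goto alias_cont else no_alias
///   alias_cont: if (store.begin < load.end) goto copy else no_alias
///   copy:       %alloca = alloca ...; memcpy(%alloca, load ptr)
///   no_alias:   %ptr = phi [load ptr, check0], [load ptr, alias_cont],
///                         [%alloca, copy]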
1636 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1637 CallInst *MatMul) {
1638 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1639 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1640
1641 // If we can statically determine noalias we're good.
1642 if (AA->isNoAlias(LoadLoc, StoreLoc))
1643 return Load->getPointerOperand();
1644
1645 // Create code to check if the memory locations of the Load and Store
1646 // overlap and if they do, copy Load's operand to a new buffer.
1647
1648 // First, create new blocks for the 2nd part of the check and the copy.
1649 BasicBlock *Check0 = MatMul->getParent();
1650 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1651 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1652 // as we adjust Check0 and Check1's branches.
1653 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1654 for (BasicBlock *Succ : successors(Check0))
1655 DTUpdates.push_back({DT->Delete, Check0, Succ});
1656
1657 BasicBlock *Check1 =
1658 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1659 nullptr, "alias_cont");
1660 BasicBlock *Copy =
1661 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1662 nullptr, "copy");
1663 BasicBlock *Fusion =
1664 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1665 nullptr, "no_alias");
1666
1667 // Check if the loaded memory location begins before the end of the store
1668 // location. If the condition holds, they might overlap, otherwise they are
1669 // guaranteed to not overlap.
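// Editor's note: together with the second compare below, this forms the
// standard interval-overlap test: the locations overlap iff
// LoadBegin < StoreEnd && StoreBegin < LoadEnd.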
1670 IRBuilder<> Builder(MatMul);
1671 Check0->getTerminator()->eraseFromParent();
1672 Builder.SetInsertPoint(Check0);
1673 Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1674 Value *StoreBegin = Builder.CreatePtrToInt(
1675 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1676 Value *StoreEnd = Builder.CreateAdd(
1677 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1678 "store.end", true, true);
1679 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1680 IntPtrTy, "load.begin");
1681 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1682 Fusion);
1683
1684 // Check if the store begins before the end of the load location. If the
1685 // condition holds, they alias, otherwise they are guaranteed to not
1686 // overlap.
1687 Check1->getTerminator()->eraseFromParent();
1688 Builder.SetInsertPoint(Check1, Check1->begin());
1689 Value *LoadEnd = Builder.CreateAdd(
1690 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1691 "load.end", true, true);
1692 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1693 Fusion);
1694
1695 // Copy load operand to new alloca.
1696 Builder.SetInsertPoint(Copy, Copy->begin());
1697 auto *VT = cast<FixedVectorType>(Load->getType());
1698 // Use an array type for the alloca, to avoid potentially huge alignment
1699 // requirements for large vector types.
1700 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1701 AllocaInst *Alloca =
1702 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1703
1704 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1705 Load->getAlign(), LoadLoc.Size.getValue());
1706 Builder.SetInsertPoint(Fusion, Fusion->begin());
1707 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1708 PHI->addIncoming(Load->getPointerOperand(), Check0);
1709 PHI->addIncoming(Load->getPointerOperand(), Check1);
1710 PHI->addIncoming(Alloca, Copy);
1711
1712 // Adjust DT.
1713 DTUpdates.push_back({DT->Insert, Check0, Check1});
1714 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1715 DTUpdates.push_back({DT->Insert, Check1, Copy});
1716 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1717 DT->applyUpdates(DTUpdates);
1718 return PHI;
1719 }
1720
1721 bool isFusionProfitable(CallInst *MatMul) {
1722 if (ForceFusion)
1723 return true;
1724
1725 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1726 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1727
1728 const unsigned R = LShape.NumRows;
1729 const unsigned C = RShape.NumColumns;
1730 const unsigned M = LShape.NumColumns;
1731 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1732
1733 const unsigned VF = std::max<unsigned>(
1734 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1735 .getFixedValue() /
1736 EltType->getPrimitiveSizeInBits().getFixedValue(),
1737 1U);
1738
1739 // Cost model for tiling
1740 //
1741 // For tiling to be beneficial, we need reuse either along the R or
1742 // the C axis. We vectorize along the R axis so that means at least
1743 // 3 elements.
1744 // TODO: Also consider cost of copying if operands alias.
1745 if (R <= VF && C == 1)
1746 return false;
1747 // Then we need enough elements to exceed the number of vector
1748 // registers we have. Note that this is an oversimplification since
1749 // fusing also takes some extra loads which may exceed the number of
1750 // reloads necessary.
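// Worked example (editor's addition, assuming VF = 4, e.g. doubles with
// 256-bit vectors): for an 8x8 * 8x8 multiply, Op0Regs = Op1Regs =
// (8 + 3) / 4 * 8 = 16, so fusion is considered when 32 exceeds the number
// of available vector registers.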
1751 unsigned Op0Regs = (R + VF - 1) / VF * M;
1752 unsigned Op1Regs = (M + VF - 1) / VF * C;
1753 return Op0Regs + Op1Regs >
1754 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1755 }
1756
1757 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1758 MatrixTy Res;
1759 auto *ColumType = FixedVectorType::get(EltType, R);
1760 for (unsigned I = 0; I < C; ++I)
1761 Res.addVector(ConstantAggregateZero::get(ColumType));
1762 return Res;
1763 }
1764
1765 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1766 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1767 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1768
1769 // Create the main tiling loop nest.
1770 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1771 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1772 Instruction *InsertI = cast<Instruction>(MatMul);
1773 BasicBlock *Start = InsertI->getParent();
1774 BasicBlock *End =
1775 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1776 IRBuilder<> Builder(MatMul);
1777 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1778
1779 Type *TileVecTy =
1780 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1781 MatrixTy TileResult;
1782 // Insert in the inner loop header.
1783 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1784 // Create PHI nodes for the result columns to accumulate across iterations.
1785 SmallVector<PHINode *, 4> ColumnPhis;
1786 for (unsigned I = 0; I < TileSize; I++) {
1787 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1788 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1789 TI.RowLoop.Header->getSingleSuccessor());
1790 TileResult.addVector(Phi);
1791 ColumnPhis.push_back(Phi);
1792 }
1793
1794 // Insert in the inner loop body, which computes
1795 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1796 Builder.SetInsertPoint(InnerBody->getTerminator());
1797 // Load tiles of the operands.
1798 MatrixTy A =
1799 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1800 {TileSize, TileSize}, EltType, Builder);
1801 MatrixTy B =
1802 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1803 {TileSize, TileSize}, EltType, Builder);
1804 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1805 getFastMathFlags(MatMul));
1806 // Store result after the inner loop is done.
1807 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1808 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1809 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1810 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1811
1812 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1813 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1814
1815 // Force unrolling of a few iterations of the inner loop, to make sure there
1816 // is enough work per iteration.
1817 // FIXME: The unroller should make this decision directly instead, but
1818 // currently the cost-model is not up to the task.
1819 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1820 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1821 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1822 }
1823
1824 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1825 StoreInst *Store,
1826 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1827 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1828 "Tiling only supported for column-major matrixes at the moment!");
1829 if (!isFusionProfitable(MatMul))
1830 return;
1831
1832 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1833 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1834
1835 const unsigned R = LShape.NumRows;
1836 const unsigned C = RShape.NumColumns;
1837 const unsigned M = LShape.NumColumns;
1838 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1839
1840 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1841 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1842 Value *CPtr = Store->getPointerOperand();
1843
1844 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1845 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1846 else {
1847 IRBuilder<> Builder(Store);
1848 for (unsigned J = 0; J < C; J += TileSize)
1849 for (unsigned I = 0; I < R; I += TileSize) {
1850 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1851 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1852 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1853
1854 for (unsigned K = 0; K < M; K += TileSize) {
1855 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1856 MatrixTy A =
1857 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1858 LShape, Builder.getInt64(I), Builder.getInt64(K),
1859 {TileR, TileM}, EltType, Builder);
1860 MatrixTy B =
1861 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1862 RShape, Builder.getInt64(K), Builder.getInt64(J),
1863 {TileM, TileC}, EltType, Builder);
1864 emitMatrixMultiply(Res, A, B, Builder, true, false,
1865 getFastMathFlags(MatMul));
1866 }
1867 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1868 Builder.getInt64(I), Builder.getInt64(J), EltType,
1869 Builder);
1870 }
1871 }
1872
1873 // Mark eliminated instructions as fused and remove them.
1874 FusedInsts.insert(Store);
1875 FusedInsts.insert(MatMul);
1876 eraseFromParentAndRemoveFromShapeMap(Store);
1877 eraseFromParentAndRemoveFromShapeMap(MatMul);
1878 if (LoadOp0->hasNUses(0)) {
1879 FusedInsts.insert(LoadOp0);
1880 eraseFromParentAndRemoveFromShapeMap(LoadOp0);
1881 }
1882 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1883 FusedInsts.insert(LoadOp1);
1884 eraseFromParentAndRemoveFromShapeMap(LoadOp1);
1885 }
1886 }
1887
1888 /// Try to lower matrix multiply chains by fusing operations.
1889 ///
1890 /// Call finalizeLowering on lowered instructions. Instructions that are
1891 /// completely eliminated by fusion are added to \p FusedInsts.
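/// Shape of the fused pattern (editor's addition, schematic):
///   %a = load %A                                 ; LoadOp0
///   %b = load %B                                 ; LoadOp1
///   %c = call @llvm.matrix.multiply(%a, %b, R, M, C)
///   store %c, %C                                 ; single user of the multiply
/// which the tiling code below rewrites into tile-wise loads, multiplies and
/// stores, after ensuring the loaded memory does not alias the store.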
1892 void
1893 LowerMatrixMultiplyFused(CallInst *MatMul,
1894 SmallPtrSetImpl<Instruction *> &FusedInsts,
1895 SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
1896 if (!FuseMatrix || !DT)
1897 return;
1898
1899 assert(AA && LI && "Analyses should be available");
1900
1901 Value *A = MatMul->getArgOperand(0);
1902 Value *B = MatMul->getArgOperand(1);
1903
1904 // We can fold the transpose into the operand that is used to fetch scalars.
1905 Value *T;
1906 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1907 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1908 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1909 IRBuilder<> Builder(MatMul);
1910 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1911 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1912 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1913 const unsigned R = LShape.NumRows;
1914 const unsigned M = LShape.NumColumns;
1915 const unsigned C = RShape.NumColumns;
1916
1917 MatrixTy MA;
1918 MatrixTy MB;
1919
1920 Value *Transpose;
1921 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1922 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1923 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1924 Transpose = B;
1925 } else {
1926 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1927 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1928 Transpose = A;
1929 }
1930
1931 // Initialize the output
1932 MatrixTy Result(R, C, EltType);
1933
1934 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1935 getFastMathFlags(MatMul));
1936
1937 FusedInsts.insert(MatMul);
1938 if (Transpose->hasOneUse()) {
1939 FusedInsts.insert(cast<Instruction>(Transpose));
1940 ToRemove.push_back(cast<Instruction>(Transpose));
1941 // TODO: add a fake entry for the folded instruction so that this is
1942 // included in the expression in the remark.
1943 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1944 }
1945 finalizeLowering(MatMul, Result, Builder);
1946 return;
1947 }
1948
1949 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1950 return;
1951
1952 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1953 // since the single store user will be lowered as part of this.
1954 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1955 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1956 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1957 if (LoadOp0 && LoadOp1 && Store) {
1958 // The store address must dominate the MatMul instruction, otherwise
1959 // we create invalid IR.
1960 SetVector<Value *> WorkList;
1961 WorkList.insert(Store->getOperand(1));
1962 SmallVector<Instruction *> ToHoist;
1963 for (unsigned I = 0; I != WorkList.size(); ++I) {
1964 Value *Current = WorkList[I];
1965 auto *CurrI = dyn_cast<Instruction>(Current);
1966 if (!CurrI)
1967 continue;
1968 if (isa<PHINode>(CurrI))
1969 return;
1970 if (DT->dominates(CurrI, MatMul))
1971 continue;
1972 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1973 return;
1974 ToHoist.push_back(CurrI);
1975 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1976 }
1977
1978 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1979 return DT->dominates(A, B);
1980 });
1981 for (Instruction *I : ToHoist)
1982 I->moveBefore(MatMul);
1983
1984 // Deal with lifetime.end calls that might be between Load0/Load1 and the
1985 // store. To avoid introducing loads to dead objects (i.e. after the
1986 // lifetime has been terminated by @llvm.lifetime.end), either sink them
1987 // after the store if in the same block, or remove the lifetime.end marker
1988 // otherwise. This might pessimize further optimizations, by extending the
1989 // lifetime of the object until the function returns, but should be
1990 // conservatively correct.
1991 MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
1992 MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
1993 BasicBlock *StoreParent = Store->getParent();
1994 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
1995 LoadOp1->getParent() == StoreParent;
1996 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
1997 IntrinsicInst *End = LifetimeEnds[Idx];
1998 auto Inc = make_scope_exit([&Idx]() { Idx++; });
1999 // If the lifetime.end is guaranteed to be before the loads or after the
2000 // store, it won't interfere with fusion.
2001 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
2002 continue;
2003 if (DT->dominates(Store, End))
2004 continue;
2005 // If all fusable ops are in the same block and the lifetime.end is in a
2006 // different block, it won't interfere with fusion.
2007 if (FusableOpsInSameBlock && End->getParent() != StoreParent)
2008 continue;
2009
2010 // If the loads don't alias the lifetime.end, it won't interfere with
2011 // fusion.
2012 MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
2013 if (!EndLoc.Ptr)
2014 continue;
2015 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
2016 continue;
2017
2018 // If both lifetime.end and the store are in the same block, extend the
2019 // lifetime until after the store, so the new lifetime covers the loads
2020 // we introduce later.
2021 if (End->getParent() == StoreParent) {
2022 End->moveAfter(Store);
2023 continue;
2024 }
2025
2026 // Otherwise remove the conflicting lifetime.end marker.
2027 ToRemove.push_back(End);
2028 std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
2029 LifetimeEnds.pop_back();
2030 Inc.release();
2031 }
2032
2033 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2034 return;
2035 }
2036 }
2037
2038 /// Lowers llvm.matrix.multiply.
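/// Example (editor's addition): a 2x3 * 3x2 multiply of floats arrives as
///   %c = call <4 x float> @llvm.matrix.multiply.v4f32.v6f32.v6f32(
///            <6 x float> %a, <6 x float> %b, i32 2, i32 3, i32 2)
/// and is lowered to extract/splat, multiply and add operations on the
/// flattened operand vectors.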
2039 void LowerMultiply(CallInst *MatMul) {
2040 IRBuilder<> Builder(MatMul);
2041 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
2042 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2043 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2044
2045 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
2046 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
2047 assert(Lhs.getElementType() == Rhs.getElementType() &&
2048 "Matrix multiply argument element types do not match.");
2049
2050 const unsigned R = LShape.NumRows;
2051 const unsigned C = RShape.NumColumns;
2052 assert(LShape.NumColumns == RShape.NumRows);
2053
2054 // Initialize the output
2055 MatrixTy Result(R, C, EltType);
2056 assert(Lhs.getElementType() == Result.getElementType() &&
2057 "Matrix multiply result element type does not match arguments.");
2058
2059 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
2060 getFastMathFlags(MatMul));
2061 finalizeLowering(MatMul, Result, Builder);
2062 }
2063
2064 /// Lowers llvm.matrix.transpose.
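/// Example (editor's addition): transposing a column-major 2x3 matrix held as
/// three 2-element column vectors [m00 m10], [m01 m11], [m02 m12] produces
/// two 3-element vectors [m00 m01 m02] and [m10 m11 m12], built with the
/// extractelement/insertelement sequence below.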
2065 void LowerTranspose(CallInst *Inst) {
2066 MatrixTy Result;
2067 IRBuilder<> Builder(Inst);
2068 Value *InputVal = Inst->getArgOperand(0);
2069 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
2070 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
2071 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
2072
2073 const unsigned NewNumVecs =
2074 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2075 const unsigned NewNumElts =
2076 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
2077
2078 for (unsigned I = 0; I < NewNumVecs; ++I) {
2079 // Build a single result vector. First initialize it.
2080 Value *ResultVector = PoisonValue::get(
2081 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
2082 // Go through the old elements and insert them into the resulting vector.
2083 for (auto J : enumerate(InputMatrix.vectors())) {
2084 Value *Elt = Builder.CreateExtractElement(J.value(), I);
2085 // Row and column indices are transposed.
2086 ResultVector =
2087 Builder.CreateInsertElement(ResultVector, Elt, J.index());
2088 }
2089 Result.addVector(ResultVector);
2090 }
2091
2092 // TODO: Improve estimate of operations needed for transposes. Currently we
2093 // just count the insertelement/extractelement instructions, but do not
2094 // account for later simplifications/combines.
2095 finalizeLowering(
2096 Inst,
2097 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2098 .addNumExposedTransposes(1),
2099 Builder);
2100 }
2101
2102 /// Lower load instructions, if shape information is available.
2103 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2104 auto I = ShapeMap.find(Inst);
2105 if (I == ShapeMap.end())
2106 return false;
2107
2108 LowerLoad(Inst, Ptr, Inst->getAlign(),
2109 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2110 I->second);
2111 return true;
2112 }
2113
2114 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
2115 IRBuilder<> &Builder) {
2116 auto I = ShapeMap.find(StoredVal);
2117 if (I == ShapeMap.end())
2118 return false;
2119
2120 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2121 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2122 I->second);
2123 return true;
2124 }
2125
2126 /// Lower binary operators, if shape information is available.
2127 bool VisitBinaryOperator(BinaryOperator *Inst) {
2128 auto I = ShapeMap.find(Inst);
2129 if (I == ShapeMap.end())
2130 return false;
2131
2132 Value *Lhs = Inst->getOperand(0);
2133 Value *Rhs = Inst->getOperand(1);
2134
2135 IRBuilder<> Builder(Inst);
2136 ShapeInfo &Shape = I->second;
2137
2138 MatrixTy Result;
2139 MatrixTy A = getMatrix(Lhs, Shape, Builder);
2140 MatrixTy B = getMatrix(Rhs, Shape, Builder);
2141 assert(A.isColumnMajor() == B.isColumnMajor() &&
2142 Result.isColumnMajor() == A.isColumnMajor() &&
2143 "operands must agree on matrix layout");
2144
2145 Builder.setFastMathFlags(getFastMathFlags(Inst));
2146
2147 // Helper to perform binary op on vectors.
2148 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2149 switch (Inst->getOpcode()) {
2150 case Instruction::Add:
2151 return Builder.CreateAdd(LHS, RHS);
2152 case Instruction::Mul:
2153 return Builder.CreateMul(LHS, RHS);
2154 case Instruction::Sub:
2155 return Builder.CreateSub(LHS, RHS);
2156 case Instruction::FAdd:
2157 return Builder.CreateFAdd(LHS, RHS);
2158 case Instruction::FMul:
2159 return Builder.CreateFMul(LHS, RHS);
2160 case Instruction::FSub:
2161 return Builder.CreateFSub(LHS, RHS);
2162 default:
2163 llvm_unreachable("Unsupported binary operator for matrix");
2164 }
2165 };
2166
2167 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2168 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2169
2170 finalizeLowering(Inst,
2171 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2172 Result.getNumVectors()),
2173 Builder);
2174 return true;
2175 }
2176
2177 /// Lower unary operators, if shape information is available.
2178 bool VisitUnaryOperator(UnaryOperator *Inst) {
2179 auto I = ShapeMap.find(Inst);
2180 if (I == ShapeMap.end())
2181 return false;
2182
2183 Value *Op = Inst->getOperand(0);
2184
2185 IRBuilder<> Builder(Inst);
2186 ShapeInfo &Shape = I->second;
2187
2188 MatrixTy Result;
2189 MatrixTy M = getMatrix(Op, Shape, Builder);
2190
2191 Builder.setFastMathFlags(getFastMathFlags(Inst));
2192
2193 // Helper to perform unary op on vectors.
2194 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2195 switch (Inst->getOpcode()) {
2196 case Instruction::FNeg:
2197 return Builder.CreateFNeg(Op);
2198 default:
2199 llvm_unreachable("Unsupported unary operator for matrix");
2200 }
2201 };
2202
2203 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2204 Result.addVector(BuildVectorOp(M.getVector(I)));
2205
2206 finalizeLowering(Inst,
2207 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2208 Result.getNumVectors()),
2209 Builder);
2210 return true;
2211 }
2212
2213 /// Helper to linearize a matrix expression tree into a string. Currently
2214 /// matrix expressions are linearized by starting at an expression leaf and
2215 /// linearizing bottom up.
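/// Example (editor's addition, schematic): a store fed by a multiply of two
/// loads is rendered roughly as
///   store(
///    multiply.2x2.2x2.double(load(addr %A), load(addr %B)),
///    addr %C)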
2216 struct ExprLinearizer {
2217 unsigned LengthToBreak = 100;
2218 std::string Str;
2219 raw_string_ostream Stream;
2220 unsigned LineLength = 0;
2221 const DataLayout &DL;
2222
2223 /// Mapping from instructions to matrixes. It is used to identify
2224 /// matrix instructions.
2225 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2226
2227 /// Mapping from values to the leaves of all expressions that the value is
2228 /// part of.
2229 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2230
2231 /// Set of matrix expressions in the scope of a given DISubprogram.
2232 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2233
2234 /// Leaf node of the expression to linearize.
2235 Value *Leaf;
2236
2237 /// Used to keep track of sub-expressions that get reused while linearizing
2238 /// the expression. Re-used sub-expressions are marked as (reused).
2239 SmallPtrSet<Value *, 8> ReusedExprs;
2240
2241 ExprLinearizer(const DataLayout &DL,
2242 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2243 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2244 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2245 Value *Leaf)
2246 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2247 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2248
2249 void indent(unsigned N) {
2250 LineLength += N;
2251 for (unsigned i = 0; i < N; i++)
2252 Stream << " ";
2253 }
2254
2255 void lineBreak() {
2256 Stream << "\n";
2257 LineLength = 0;
2258 }
2259
2260 void maybeIndent(unsigned Indent) {
2261 if (LineLength >= LengthToBreak)
2262 lineBreak();
2263
2264 if (LineLength == 0)
2265 indent(Indent);
2266 }
2267
2268 void write(StringRef S) {
2269 LineLength += S.size();
2270 Stream << S;
2271 }
2272
2273 Value *getUnderlyingObjectThroughLoads(Value *V) {
2274 if (Value *Ptr = getPointerOperand(V))
2275 return getUnderlyingObjectThroughLoads(Ptr);
2276 else if (V->getType()->isPointerTy())
2277 return getUnderlyingObject(V);
2278 return V;
2279 }
2280
2281 /// Returns true if \p V is a matrix value in the given subprogram.
2282 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2283
2284 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2285 /// \p SS.
2286 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2287 auto M = Inst2Matrix.find(V);
2288 if (M == Inst2Matrix.end())
2289 SS << "unknown";
2290 else {
2291 SS << M->second.getNumRows();
2292 SS << "x";
2293 SS << M->second.getNumColumns();
2294 }
2295 }
2296
2297 /// Write the called function name. Handles calls to llvm.matrix.*
2298 /// specially: we write the name, followed by the dimensions of the input
2299 /// matrixes, followed by the scalar type name.
2300 void writeFnName(CallInst *CI) {
2301 if (!CI->getCalledFunction())
2302 write("<no called fn>");
2303 else {
2304 StringRef Name = CI->getCalledFunction()->getName();
2305 if (!Name.starts_with("llvm.matrix")) {
2306 write(Name);
2307 return;
2308 }
2309 auto *II = cast<IntrinsicInst>(CI);
2310 write(Intrinsic::getBaseName(II->getIntrinsicID())
2311 .drop_front(StringRef("llvm.matrix.").size()));
2312 write(".");
2313 std::string Tmp;
2314 raw_string_ostream SS(Tmp);
2315
2316 switch (II->getIntrinsicID()) {
2317 case Intrinsic::matrix_multiply:
2318 prettyPrintMatrixType(II->getOperand(0), SS);
2319 SS << ".";
2320 prettyPrintMatrixType(II->getOperand(1), SS);
2321 SS << "." << *II->getType()->getScalarType();
2322 break;
2323 case Intrinsic::matrix_transpose:
2324 prettyPrintMatrixType(II->getOperand(0), SS);
2325 SS << "." << *II->getType()->getScalarType();
2326 break;
2327 case Intrinsic::matrix_column_major_load:
2328 prettyPrintMatrixType(II, SS);
2329 SS << "." << *II->getType()->getScalarType();
2330 break;
2331 case Intrinsic::matrix_column_major_store:
2332 prettyPrintMatrixType(II->getOperand(0), SS);
2333 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2334 break;
2335 default:
2336 llvm_unreachable("Unhandled case");
2337 }
2338 write(Tmp);
2339 }
2340 }
2341
2342 unsigned getNumShapeArgs(CallInst *CI) const {
2343 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2344 switch (II->getIntrinsicID()) {
2345 case Intrinsic::matrix_multiply:
2346 return 3;
2347 case Intrinsic::matrix_transpose:
2348 return 2;
2349 case Intrinsic::matrix_column_major_load:
2350 case Intrinsic::matrix_column_major_store:
2351 return 3;
2352 default:
2353 return 0;
2354 }
2355 }
2356 return 0;
2357 }
2358
2359 /// Special printing for values: for pointers, we print whether they refer
2360 /// to a (function) external address or a stack address; for other values we
2361 /// either print the constant or "scalar"/"matrix".
2362 void write(Value *V) {
2363 V = getUnderlyingObjectThroughLoads(V);
2364 if (V->getType()->isPointerTy()) {
2365 if (isa<AllocaInst>(V)) {
2366 Stream << "stack addr";
2367 LineLength += StringRef("stack addr").size();
2368 } else {
2369 Stream << "addr";
2370 LineLength += StringRef("addr").size();
2371 }
2372 if (!V->getName().empty()) {
2373 Stream << " %" << V->getName() << "";
2374 LineLength += V->getName().size() + 2;
2375 }
2376 return;
2377 }
2378
2379 std::string Tmp;
2380 raw_string_ostream TmpStream(Tmp);
2381
2382 if (auto *CI = dyn_cast<ConstantInt>(V))
2383 TmpStream << CI->getValue();
2384 else if (isa<Constant>(V))
2385 TmpStream << "constant";
2386 else {
2387 if (isMatrix(V))
2388 TmpStream << "matrix";
2389 else
2390 TmpStream << "scalar";
2391 }
2392 Tmp = std::string(StringRef(Tmp).trim());
2393 LineLength += Tmp.size();
2394 Stream << Tmp;
2395 }
2396
2397 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2398 /// Expressions that are re-used multiple times are prefixed with (reused)
2399 /// at the re-used root instruction.
2400 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2401 bool ParentShared) {
2402 auto *I = cast<Instruction>(Expr);
2403 maybeIndent(Indent);
2404 SmallVector<Value *, 8> Ops;
2405
2406 // Is Expr shared with other expression leaves?
2407 bool ExprShared = false;
2408
2409 // Deal with shared subtrees. Mark them as shared, if required.
2410 if (!ParentShared) {
2411 auto SI = Shared.find(Expr);
2412 assert(SI != Shared.end() && SI->second.count(Leaf));
2413
2414 for (Value *S : SI->second) {
2415 if (S == Leaf)
2416 continue;
2417 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2418 write("shared with remark at line " + std::to_string(DL.getLine()) +
2419 " column " + std::to_string(DL.getCol()) + " (");
2420 }
2421 ExprShared = SI->second.size() > 1;
2422 }
2423
2424 bool Reused = !ReusedExprs.insert(Expr).second;
2425 if (Reused && !ParentReused)
2426 write("(reused) ");
2427
2428 if (auto *CI = dyn_cast<CallInst>(I)) {
2429 writeFnName(CI);
2430
2431 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2432 } else if (isa<BitCastInst>(Expr)) {
2433 // Special case bitcasts, which are used to materialize matrixes from
2434 // non-matrix ops.
2435 write("matrix");
2436 return;
2437 } else {
2438 Ops.append(I->value_op_begin(), I->value_op_end());
2439 write(std::string(I->getOpcodeName()));
2440 }
2441
2442 write(std::string("("));
2443
2444 unsigned NumOpsToBreak = 1;
2445 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2446 NumOpsToBreak = 2;
2447
2448 for (Value *Op : Ops) {
2449 if (Ops.size() > NumOpsToBreak)
2450 lineBreak();
2451
2452 maybeIndent(Indent + 1);
2453 if (isMatrix(Op))
2454 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2455 else
2456 write(Op);
2457 if (Op != Ops.back())
2458 write(", ");
2459 }
2460
2461 write(")");
2462 }
2463
2464 const std::string &getResult() {
2465 return Str;
2466 }
2467 };
2468
2469 /// Generate remarks for matrix operations in a function. To generate remarks
2470 /// for matrix expressions, the following approach is used:
2471 /// 1. Use the inlined-at debug information to group matrix operations to the
2472 /// DISubprograms they are contained in.
2473 /// 2. Collect leaves of matrix expressions (done in
2474 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2475 /// mapping. Leaves are lowered matrix instructions without other matrix
2476 /// users (like stores) in the current subprogram.
2477 /// 3. For each leaf, create a remark containing a linearized version of the
2478 /// matrix expression. The expression is linearized by a recursive
2479 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2480 /// that multiple leaves can share sub-expressions. Shared subexpressions
2481 /// are explicitly marked as shared().
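/// Example remark (editor's addition, schematic):
///   remark: <file>:<line>:<col>: Lowered with <S> stores, <L> loads,
///   <O> compute ops, <T> exposed transposes
/// followed by the linearized expression produced by ExprLinearizer above.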
2482 struct RemarkGenerator {
2483 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2484 OptimizationRemarkEmitter &ORE;
2485 Function &Func;
2486 const DataLayout &DL;
2487
2488 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2489 OptimizationRemarkEmitter &ORE, Function &Func)
2490 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2491 DL(Func.getDataLayout()) {}
2492
2493 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2494 /// instructions in Inst2Matrix returning void or without any users in
2495 /// \p ExprsInSubprogram. Currently that should only include stores.
2496 SmallVector<Value *, 4>
2497 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2498 SmallVector<Value *, 4> Leaves;
2499 for (auto *Expr : ExprsInSubprogram)
2500 if (Expr->getType()->isVoidTy() ||
2501 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2502 return ExprsInSubprogram.count(U);
2503 }))
2504 Leaves.push_back(Expr);
2505 return Leaves;
2506 }
2507
2508 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2509 /// to all visited expressions in \p Shared. Limit the matrix operations to
2510 /// the ones in \p ExprsInSubprogram.
2511 void collectSharedInfo(Value *Leaf, Value *V,
2512 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2513 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2514
2515 if (!ExprsInSubprogram.count(V))
2516 return;
2517
2518 Shared[V].insert(Leaf);
2519
2520 for (Value *Op : cast<Instruction>(V)->operand_values())
2521 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2522 }
2523
2524 /// Calculate the number of exclusive and shared op counts for expression
2525 /// starting at \p V. Expressions used multiple times are counted once.
2526 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2527 std::pair<OpInfoTy, OpInfoTy>
2528 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2529 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2530 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2531 if (!ExprsInSubprogram.count(Root))
2532 return {};
2533
2534 // Already counted this expression. Stop.
2535 if (!ReusedExprs.insert(Root).second)
2536 return {};
2537
2538 OpInfoTy SharedCount;
2539 OpInfoTy Count;
2540
2541 auto I = Shared.find(Root);
2542 auto CM = Inst2Matrix.find(Root);
2543 if (I->second.size() == 1)
2544 Count = CM->second.getOpInfo();
2545 else
2546 SharedCount = CM->second.getOpInfo();
2547
2548 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2549 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2550 Count += C.first;
2551 SharedCount += C.second;
2552 }
2553 return {Count, SharedCount};
2554 }
2555
2556 void emitRemarks() {
2557 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2558 return;
2559
2560 // Map matrix operations to their containing subprograms, by traversing
2561 // the inlinedAt chain. If the function does not have a DISubprogram, we
2562 // only map them to the containing function.
2563 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2564 for (const auto &KV : Inst2Matrix) {
2565 if (Func.getSubprogram()) {
2566 auto *I = cast<Instruction>(KV.first);
2567 DILocation *Context = I->getDebugLoc();
2568 while (Context) {
2569 Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
2570 KV.first);
2571 Context = DebugLoc(Context).getInlinedAt();
2572 }
2573 } else {
2574 Subprog2Exprs[nullptr].push_back(KV.first);
2575 }
2576 }
2577 for (auto &KV : Subprog2Exprs) {
2578 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2579 KV.second.end());
2580 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2581
2582 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2583 for (Value *Leaf : Leaves)
2584 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2585
2586 // Generate remarks for each leaf.
2587 for (auto *L : Leaves) {
2588
2589 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2590 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2591 while (Context) {
2592 if (getSubprogram(Context->getScope()) == KV.first) {
2593 Loc = Context;
2594 break;
2595 }
2596 Context = DebugLoc(Context).getInlinedAt();
2597 }
2598
2599 SmallPtrSet<Value *, 8> ReusedExprs;
2600 OpInfoTy Counts, SharedCounts;
2601 std::tie(Counts, SharedCounts) =
2602 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2603
2604 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2605 cast<Instruction>(L)->getParent());
2606
2607 Rem << "Lowered with ";
2608 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2609 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2610 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2611 << " compute ops, "
2612 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2613 << " exposed transposes";
2614
2615 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2616 SharedCounts.NumComputeOps > 0) {
2617 Rem << ",\nadditionally "
2618 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2619 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2620 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2621 << " compute ops"
2622 << " are shared with other expressions";
2623 }
2624
2625 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2626 ORE.emit(Rem);
2627 }
2628 }
2629 }
2630
2631 std::string
2632 linearize(Value *L,
2633 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2634 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2635 const DataLayout &DL) {
2636 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2637 Lin.linearizeExpr(L, 0, false, false);
2638 return Lin.getResult();
2639 }
2640 };
2641};
2642} // namespace
2643
2644 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2645 FunctionAnalysisManager &AM) {
2646 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2647
2648 LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);
2649 if (LMT.Visit()) {
2650 PreservedAnalyses PA;
2651 if (!Minimal) {
2652 PA.preserve<LoopAnalysis>();
2653 PA.preserve<DominatorTreeAnalysis>();
2654 }
2655 return PA;
2656 }
2657 return PreservedAnalyses::all();
2658}
2659
2660 void LowerMatrixIntrinsicsPass::printPipeline(
2661 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2662 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2663 OS, MapClassName2PassName);
2664 OS << '<';
2665 if (Minimal)
2666 OS << "minimal";
2667 OS << '>';
2668}
Rewrite undef for PHI
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static const Function * getParent(const Value *V)
BitTracker BT
Definition: BitTracker.cpp:73
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:686
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
std::string Name
bool End
Definition: ELF_riscv.cpp:480
Hexagon Common GEP
hexagon Hexagon specific predictive commoning for HVX vectors
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iv users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:533
Live Register Matrix
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
#define DEBUG_TYPE
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define T1
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static unsigned getNumElements(Type *Ty)
static unsigned getFastMathFlags(const MachineInstr &I)
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2550
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2572
raw_pwrite_stream & OS
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
This file defines the SmallSet class.
This file defines the SmallVector class.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:39
static const int BlockSize
Definition: TarWriter.cpp:33
This pass exposes codegen information to IR-level passes.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Value * RHS
Value * LHS
BinaryOperator * Mul
A manager for alias analyses.
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
an instruction to allocate memory on the stack
Definition: Instructions.h:63
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:124
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:448
reverse_iterator rbegin()
Definition: BasicBlock.h:464
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:179
reverse_iterator rend()
Definition: BasicBlock.h:466
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
BinaryOps getOpcode() const
Definition: InstrTypes.h:370
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1349
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1269
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1746
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1294
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1275
This class represents a function call, abstracting a target machine's calling convention.
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1672
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Debug location.
Base class for scope-like contexts.
Subprogram description.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
A debug info location.
Definition: DebugLoc.h:33
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:39
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:156
bool erase(const KeyT &Val)
Definition: DenseMap.h:321
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:152
iterator end()
Definition: DenseMap.h:84
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
static constexpr UpdateKind Delete
static constexpr UpdateKind Insert
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:20
void setAllowContract(bool B=true)
Definition: FMF.h:91
bool allowReassoc() const
Flag queries.
Definition: FMF.h:65
bool allowContract() const
Definition: FMF.h:70
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:563
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:791
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:251
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:256
CallInst * CreateFAddReduce(Value *Acc, Value *Src)
Create a sequential vector fadd reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:402
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:2289
Value * CreateFSub(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1583
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2503
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1796
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2491
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition: IRBuilder.h:1830
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1556
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
Definition: IRBuilder.cpp:1152
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:890
CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:412
IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)
Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...
Definition: IRBuilder.h:572
void setFastMathFlags(FastMathFlags NewFMF)
Set the fast-math flags to be used with generated fp-math operators.
Definition: IRBuilder.h:308
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Definition: IRBuilder.h:1889
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition: IRBuilder.h:488
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition: IRBuilder.h:2429
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1367
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
Definition: IRBuilder.h:494
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition: IRBuilder.h:1144
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1813
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition: IRBuilder.h:2525
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1350
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2145
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Definition: IRBuilder.h:1849
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1610
Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:1753
CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *TBAAStructTag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)
Create and insert a memcpy between the specified pointers.
Definition: IRBuilder.h:655
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1384
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
static InstructionCost getInvalid(CostType Val=0)
void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:92
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:48
An instruction for reading from memory.
Definition: Instructions.h:176
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:205
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:211
TypeSize getValue() const
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:36
iterator end()
Definition: MapVector.h:71
iterator find(const KeyT &Key)
Definition: MapVector.h:167
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:141
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
Representation for a specific memory location.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
const Value * Ptr
The address of the start of the location.
static MemoryLocation getForArgument(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo *TLI)
Return a location representing a particular argument of a call.
The optimization diagnostic interface.
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1878
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:131
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:363
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:452
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:458
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:132
bool empty() const
Definition: SmallSet.h:168
bool erase(const T &V)
Definition: SmallSet.h:193
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:181
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
Align getAlign() const
Definition: Instructions.h:333
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:325
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:609
constexpr size_t size() const
size - Get the string size.
Definition: StringRef.h:150
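A minimal sketch of drop_front and size on a StringRef; the "matrix." prefix and the helper name are arbitrary example values:
    #include "llvm/ADT/StringRef.h"

    // Strip a known prefix, if present, and report how much is left. StringRef
    // never copies the underlying characters, so all of this is O(1).
    static size_t tailLength(llvm::StringRef Name) {
      llvm::StringRef Prefix("matrix.");
      if (Name.starts_with(Prefix))
        Name = Name.drop_front(Prefix.size());
      return Name.size();
    }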
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
TypeSize getRegisterBitWidth(RegisterKind K) const
InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
@ TCK_RecipThroughput
Reciprocal throughput.
InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr, const TargetLibraryInfo *TLibInfo=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *Tp, ArrayRef< int > Mask={}, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr) const
unsigned getNumberOfRegisters(unsigned ClassID) const
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
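A minimal sketch of two cost-model queries, assuming a TargetTransformInfo reference and a fixed-width vector type are already available; the comparison a real caller would make with the register width is only hinted at, and the helper name is hypothetical:
    #include "llvm/Analysis/TargetTransformInfo.h"
    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/Instruction.h"

    // Ask how wide the target's fixed vector registers are and what a vector
    // fadd of VecTy costs (reciprocal throughput by default).
    static llvm::InstructionCost
    vectorFAddCost(const llvm::TargetTransformInfo &TTI, llvm::VectorType *VecTy) {
      llvm::TypeSize RegBits = TTI.getRegisterBitWidth(
          llvm::TargetTransformInfo::RGK_FixedWidthVector);
      (void)RegBits; // e.g. compare against VecTy->getPrimitiveSizeInBits()
      return TTI.getArithmeticInstrCost(llvm::Instruction::FAdd, VecTy);
    }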
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:139
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:355
UnaryOps getOpcode() const
Definition: InstrTypes.h:153
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
Value * getOperand(unsigned i) const
Definition: User.h:228
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
user_iterator user_begin()
Definition: Value.h:397
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:149
bool use_empty() const
Definition: Value.h:344
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
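A minimal sketch of the replace-all-uses-then-erase idiom; replaceAndErase is a hypothetical helper:
    #include "llvm/IR/Instruction.h"
    #include "llvm/IR/Value.h"
    #include <cassert>

    // Redirect every use of Old to New, then remove Old; after
    // replaceAllUsesWith, use_empty() must hold for Old.
    static void replaceAndErase(llvm::Instruction *Old, llvm::Value *New) {
      if (Old == New)
        return;
      Old->replaceAllUsesWith(New);
      assert(Old->use_empty() && "expected all uses to be rewritten");
      Old->eraseFromParent();
    }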
Type * getElementType() const
Definition: DerivedTypes.h:460
constexpr ScalarTy getFixedValue() const
Definition: TypeSize.h:202
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition: ilist_node.h:32
A range adaptor for a pair of iterators.
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:125
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Intrinsics.cpp:41
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:100
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
Definition: PatternMatch.h:982
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
OneUse_match< T > m_OneUse(const T &SubPattern)
Definition: PatternMatch.h:67
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:239
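A minimal sketch combining several of the matchers above; the pattern chosen, "(load p) + constant", is arbitrary and matchLoadPlusConst is a hypothetical helper:
    #include "llvm/IR/PatternMatch.h"
    #include "llvm/IR/Value.h"

    // Match an add whose left operand is a single-use load, binding the load's
    // pointer operand into Ptr. m_Add is not commutative, so only this operand
    // order is accepted (m_c_Add would accept both orders).
    static bool matchLoadPlusConst(llvm::Value *V, llvm::Value *&Ptr) {
      using namespace llvm::PatternMatch;
      return match(V, m_Add(m_OneUse(m_Load(m_Value(Ptr))), m_ConstantInt()));
    }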
@ SS
Definition: X86.h:212
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to the ValuesClass constructor.
Definition: CommandLine.h:711
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
ElementType
The element type of an SRV or UAV resource.
Definition: DXILABI.h:58
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
NodeAddr< FuncNode * > Func
Definition: RDFGraph.h:393
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:480
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1697
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B, C, ...), where A is the 0-based index of the item in the sequence and B, C, ... are the corresponding values from the input ranges.
Definition: STLExtras.h:2448
auto successors(const MachineBasicBlock *BB)
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2082
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & operator+=(DynamicAPInt &A, int64_t B)
Definition: DynamicAPInt.h:518
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
Definition: STLExtras.h:657
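A minimal sketch of why make_early_inc_range matters: it lets the loop body erase the current element. dropDeadInsts is a hypothetical helper:
    #include "llvm/ADT/STLExtras.h"
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Instruction.h"

    // The iterator is advanced before the body runs, so erasing I does not
    // invalidate the traversal.
    static void dropDeadInsts(llvm::BasicBlock &BB) {
      for (llvm::Instruction &I : llvm::make_early_inc_range(BB))
        if (I.use_empty() && !I.isTerminator() && !I.mayHaveSideEffects())
          I.eraseFromParent();
    }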
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
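A minimal sketch of gluing per-column vectors back into one flat value, under the assumption that Columns holds at least two vectors of the same element type; flattenColumns is a hypothetical helper:
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/Analysis/VectorUtils.h"
    #include "llvm/IR/IRBuilder.h"

    // Emits the shuffle sequence needed to concatenate the given vectors.
    static llvm::Value *flattenColumns(llvm::IRBuilderBase &Builder,
                                       llvm::ArrayRef<llvm::Value *> Columns) {
      return llvm::concatenateVectors(Builder, Columns);
    }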
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition: LoopUtils.cpp:214
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1746
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, OnCuIndexOverflow OverflowOptValue)
Definition: DWP.cpp:625
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1664
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:167
raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
@ Add
Sum of integers.
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition: Casting.h:565
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
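A minimal sketch of splitting a block right before an instruction while keeping the dominator tree consistent; splitBefore is a hypothetical helper, and DT may be null if no tree needs updating:
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Dominators.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"

    // Everything from I onwards moves to the returned block; the original
    // block is terminated by an unconditional branch to it.
    static llvm::BasicBlock *splitBefore(llvm::Instruction *I,
                                         llvm::DominatorTree *DT) {
      return llvm::SplitBlock(I->getParent(), I->getIterator(), DT);
    }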
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:212
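A minimal sketch of deriving the alignment that still holds after stepping Offset bytes from a pointer with a known alignment; alignAtOffset is a hypothetical helper:
    #include "llvm/Support/Alignment.h"
    #include <cstdint>

    // commonAlignment(Align(16), 4) == Align(4): only 4-byte alignment is
    // guaranteed 4 bytes past a 16-byte-aligned base.
    static llvm::Align alignAtOffset(llvm::Align BaseAlign, uint64_t Offset) {
      return llvm::commonAlignment(BaseAlign, Offset);
    }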
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
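A minimal sketch of building a shuffle mask as a sequential run followed by undef lanes; the particular numbers are arbitrary and exampleMask is a hypothetical helper:
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/Analysis/VectorUtils.h"

    // Produces {2, 3, 4, 5, -1, -1}: four consecutive indices starting at 2,
    // padded with two undef (-1) lanes.
    static llvm::SmallVector<int, 16> exampleMask() {
      return llvm::createSequentialMask(/*Start=*/2, /*NumInts=*/4,
                                        /*NumUndefs=*/2);
    }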
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition: PassManager.h:69
A helper struct to create IR loop nests for tiling in IR of the following form: for ColumnLoop....
Definition: MatrixUtils.h:31