1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns for
16// tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
19
20#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21#include "llvm/ADT/MapVector.h"
22#include "llvm/ADT/PostOrderIterator.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/SmallSet.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Analysis/AliasAnalysis.h"
27#include "llvm/Analysis/DomTreeUpdater.h"
28#include "llvm/Analysis/LoopInfo.h"
29#include "llvm/Analysis/OptimizationRemarkEmitter.h"
30#include "llvm/Analysis/TargetTransformInfo.h"
31#include "llvm/Analysis/VectorUtils.h"
32#include "llvm/IR/CFG.h"
33#include "llvm/IR/DataLayout.h"
34#include "llvm/IR/DebugInfoMetadata.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
37#include "llvm/IR/Instructions.h"
38#include "llvm/IR/IntrinsicInst.h"
39#include "llvm/IR/MatrixBuilder.h"
40#include "llvm/IR/PatternMatch.h"
41#include "llvm/IR/ValueMap.h"
42#include "llvm/Pass.h"
43#include "llvm/Support/Alignment.h"
44#include "llvm/Support/CommandLine.h"
45#include "llvm/Support/Debug.h"
46#include "llvm/Transforms/Utils/BasicBlockUtils.h"
47#include "llvm/Transforms/Utils/MatrixUtils.h"
48
49#include <cmath>
50
51using namespace llvm;
52using namespace PatternMatch;
53
54#define DEBUG_TYPE "lower-matrix-intrinsics"
55
56static cl::opt<bool>
57 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
58 cl::desc("Enable/disable fusing matrix instructions."));
59// TODO: Allow and use non-square tiles.
61 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
63 "Tile size for matrix instruction fusion using square-shaped tiles."));
64static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
65 cl::Hidden,
66 cl::desc("Generate loop nest for tiling."));
68 "force-fuse-matrix", cl::init(false), cl::Hidden,
69 cl::desc("Force matrix instruction fusion even if not profitable."));
71 "matrix-allow-contract", cl::init(false), cl::Hidden,
72 cl::desc("Allow the use of FMAs if available and profitable. This may "
73 "result in different results, due to less rounding error."));
74
75static cl::opt<bool>
76 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
77 cl::desc("Enable/disable matrix shape verification."),
78 cl::init(false));
79
79
80enum class MatrixLayoutTy { ColumnMajor, RowMajor };
81
83 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
84 cl::desc("Sets the default matrix layout"),
86 "Use column-major layout"),
88 "Use row-major layout")));
89
90static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
91 cl::init(false));
92
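// For illustration only (an assumed standalone invocation; the input file is a
// placeholder): since these are ordinary cl::opt flags, the pass and its knobs
// can be exercised directly with opt, e.g.
//
//   opt -passes=lower-matrix-intrinsics -matrix-default-layout=row-major \
//       -fuse-matrix-tile-size=8 -S input.ll -o -
//
// The flag names match the declarations above.
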
93/// Helper function to either return Scope, if it is a subprogram, or the
94/// attached subprogram for a local scope.
95static DISubprogram *getSubprogram(DIScope *Scope) {
96 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
97 return Subprogram;
98 return cast<DILocalScope>(Scope)->getSubprogram();
99}
100
101/// Erase \p V from \p BB and move \p II forward to avoid invalidating
102/// iterators.
103static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
104 BasicBlock &BB) {
105 auto *Inst = cast<Instruction>(V);
106 // Still used, don't erase.
107 if (!Inst->use_empty())
108 return;
109 if (II != BB.rend() && Inst == &*II)
110 ++II;
111 Inst->eraseFromParent();
112}
113
114/// Return true if V is a splat of a value (which is used when multiplying a
115/// matrix with a scalar).
116static bool isSplat(Value *V) {
117 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
118 return SV->isZeroEltSplat();
119 return false;
120}
121
122/// Match any mul operation (fp or integer).
123template <typename LTy, typename RTy>
124auto m_AnyMul(const LTy &L, const RTy &R) {
125 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
126}
127
128/// Match any add operation (fp or integer).
129template <typename LTy, typename RTy>
130auto m_AnyAdd(const LTy &L, const RTy &R) {
131 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
132}
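
// For illustration (a sketch; not used by the pass itself): these helpers plug
// into the usual PatternMatch API, e.g.
//
//   Value *L, *R;
//   if (match(V, m_AnyMul(m_Value(L), m_Value(R)))) {
//     // V is either a 'mul' or an 'fmul' with operands L and R.
//   }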
133
134namespace {
135
136// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
137// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
138// assuming \p Stride elements between the starts of two consecutive vectors.
139// \p Stride must be >= \p NumElements.
140// For column-major matrixes, the function computes the address of a column
141// vector and \p NumElements must be set to the number of elements in a column
142// (= number of rows of the matrix). For row-major matrixes, the function
143// computes the address of a row vector and \p NumElements must be set to the
144// number of elements in a row (= number of columns of the matrix).
145//
146// Consider a 4x4 matrix in column-major layout like below
147//
148// 0 1 2 3
149// 0 v_0_0 v_0_1 v_0_2 v_0_3
150// 1 v_1_0 v_1_1 v_1_2 v_1_3
151// 2 v_2_0 v_2_1 v_2_2 v_2_3
152// 3 v_3_0 v_3_1 v_3_2 v_3_3
153
154// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
155// we need a pointer to the first element of the submatrix as base pointer.
156// Then we can use computeVectorAddr to compute the addresses for the columns
157// of the sub-matrix.
158//
159// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
160// -> just returns Base
161// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
162// -> returns Base + (1 * 4)
163// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
164// -> returns Base + (2 * 4)
165//
166// The graphic below illustrates the number of elements in a column (marked
167// with |) and the number of skipped elements (marked with {).
168//
169// v_0_0 v_0_1 {v_0_2 {v_0_3
170// Base Col 1 Col 2
171// | | |
172// v_1_0 |v_1_1 |v_1_2 |v_1_3
173// v_2_0 |v_2_1 |v_2_2 |v_2_3
174// v_3_0 {v_3_1 {v_3_2 v_3_3
175//
176Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
177 unsigned NumElements, Type *EltType,
178 IRBuilder<> &Builder) {
179
180 assert((!isa<ConstantInt>(Stride) ||
181 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
182 "Stride must be >= the number of elements in the result vector.");
183 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
184
185 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
186 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
187
188 // Get pointer to the start of the selected vector. Skip GEP creation,
189 // if we select vector 0.
190 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
191 VecStart = BasePtr;
192 else
193 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
194
195 // Cast elementwise vector start pointer to a pointer to a vector
196 // (EltType x NumElements)*.
197 auto *VecType = FixedVectorType::get(EltType, NumElements);
198 Type *VecPtrType = PointerType::get(VecType, AS);
199 return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
200}
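
// For illustration (assumed names, double elements): for the 2x3 sub-matrix
// example above, requesting column 1 with a constant stride of 4 produces
// roughly
//
//   %vec.start = mul i64 1, 4
//   %vec.gep   = getelementptr double, ptr %base, i64 %vec.start
//
// followed by a pointer cast to a <2 x double> pointer. With constant operands
// the multiply is folded by the IRBuilder, and the cast is a no-op under
// opaque pointers.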
201
202/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
203///
204/// Currently, the lowering for each matrix intrinsic is done as follows:
205/// 1. Propagate the shape information from intrinsics to connected
206/// instructions.
207/// 2. Lower instructions with shape information (assuming column-major layout).
208/// The lowering works similarly using row-major layout.
209/// 2.1. Get column vectors for each argument. If we already lowered the
210/// definition of an argument, use the produced column vectors directly.
211/// If not, split the operand vector containing an embedded matrix into
212/// a set of column vectors.
213/// 2.2. Lower the instruction in terms of column major operations, which
214/// yields a set of column vectors containing the result matrix. Note that we
215/// lower all instructions that have shape information. Besides the
216/// intrinsics, this includes stores for example.
217/// 2.3. Update uses of the lowered instruction. If we have shape information
218/// for a user, there is nothing to do, as we will look up the result
219/// column matrix when lowering the user. For other uses, we embed the
220/// result matrix in a flat vector and update the use.
221/// 2.4. Cache the result column matrix for the instruction we lowered.
222/// 3. After we lowered all instructions in a function, remove the now
223/// obsolete instructions.
224///
225class LowerMatrixIntrinsics {
226 Function &Func;
227 const DataLayout &DL;
228 const TargetTransformInfo &TTI;
229 AliasAnalysis *AA;
230 DominatorTree *DT;
231 LoopInfo *LI;
232 OptimizationRemarkEmitter *ORE;
233
234 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
235 struct OpInfoTy {
236 /// Number of stores emitted to generate this matrix.
237 unsigned NumStores = 0;
238 /// Number of loads emitted to generate this matrix.
239 unsigned NumLoads = 0;
240 /// Number of compute operations emitted to generate this matrix.
241 unsigned NumComputeOps = 0;
242 /// Most of the time transposes can be fused with matrix multiplies or can
243 /// be folded away via algebraic simplifications. This is the number of
244 /// transposes that we failed to make "free" via such optimizations.
245 unsigned NumExposedTransposes = 0;
246
247 OpInfoTy &operator+=(const OpInfoTy &RHS) {
248 NumStores += RHS.NumStores;
249 NumLoads += RHS.NumLoads;
250 NumComputeOps += RHS.NumComputeOps;
251 NumExposedTransposes += RHS.NumExposedTransposes;
252 return *this;
253 }
254 };
255
256 /// Wrapper class representing a matrix as a set of vectors, either in row or
257 /// column major layout. All vectors must have the same vector type.
258 class MatrixTy {
259 SmallVector<Value *, 16> Vectors;
260
261 OpInfoTy OpInfo;
262
263 bool IsColumnMajor = true;
264
265 public:
266 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
267 MatrixTy(ArrayRef<Value *> Vectors)
268 : Vectors(Vectors.begin(), Vectors.end()),
269 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
270 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
271 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
272
273 unsigned D = isColumnMajor() ? NumColumns : NumRows;
274 for (unsigned J = 0; J < D; ++J)
275 addVector(PoisonValue::get(FixedVectorType::get(
276 EltTy, isColumnMajor() ? NumRows : NumColumns)));
277 }
278
279 Value *getVector(unsigned i) const { return Vectors[i]; }
280 Value *getColumn(unsigned i) const {
281 assert(isColumnMajor() && "only supported for column-major matrixes");
282 return Vectors[i];
283 }
284 Value *getRow(unsigned i) const {
285 assert(!isColumnMajor() && "only supported for row-major matrixes");
286 return Vectors[i];
287 }
288
289 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
290
291 Type *getElementType() const { return getVectorTy()->getElementType(); }
292
293 unsigned getNumVectors() const {
294 if (isColumnMajor())
295 return getNumColumns();
296 return getNumRows();
297 }
298
299 unsigned getNumColumns() const {
300 if (isColumnMajor())
301 return Vectors.size();
302 else {
303 assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
304 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
305 }
306 }
307 unsigned getNumRows() const {
308 if (isColumnMajor()) {
309 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
310 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
311 } else
312 return Vectors.size();
313 }
314
315 void addVector(Value *V) { Vectors.push_back(V); }
316 VectorType *getColumnTy() {
317 assert(isColumnMajor() && "only supported for column-major matrixes");
318 return getVectorTy();
319 }
320
321 VectorType *getVectorTy() const {
322 return cast<VectorType>(Vectors[0]->getType());
323 }
324
325 iterator_range<SmallVector<Value *, 16>::iterator> columns() {
326 assert(isColumnMajor() &&
327 "columns() only supported for column-major matrixes");
328 return make_range(Vectors.begin(), Vectors.end());
329 }
330
331 iterator_range<SmallVector<Value *, 16>::iterator> vectors() {
332 return make_range(Vectors.begin(), Vectors.end());
333 }
334
335 /// Embed the vectors of the matrix into a flat vector by concatenating
336 /// them.
337 Value *embedInVector(IRBuilder<> &Builder) const {
338 return Vectors.size() == 1 ? Vectors[0]
339 : concatenateVectors(Builder, Vectors);
340 }
341
342 MatrixTy &addNumLoads(unsigned N) {
343 OpInfo.NumLoads += N;
344 return *this;
345 }
346
347 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
348
349 MatrixTy &addNumStores(unsigned N) {
350 OpInfo.NumStores += N;
351 return *this;
352 }
353
354 MatrixTy &addNumExposedTransposes(unsigned N) {
355 OpInfo.NumExposedTransposes += N;
356 return *this;
357 }
358
359 MatrixTy &addNumComputeOps(unsigned N) {
360 OpInfo.NumComputeOps += N;
361 return *this;
362 }
363
364 unsigned getNumStores() const { return OpInfo.NumStores; }
365 unsigned getNumLoads() const { return OpInfo.NumLoads; }
366 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
367
368 const OpInfoTy &getOpInfo() const { return OpInfo; }
369
370 bool isColumnMajor() const { return IsColumnMajor; }
371
372 unsigned getStride() const {
373 if (isColumnMajor())
374 return getNumRows();
375 return getNumColumns();
376 }
377
378 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
379 /// matrix is column-major, the result vector is extracted from a column
380 /// vector, otherwise from a row vector.
381 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
382 IRBuilder<> &Builder) const {
383 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
384 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
385 NumElts &&
386 "Extracted vector will contain poison values");
387 return Builder.CreateShuffleVector(
388 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
389 "block");
390 }
391 };
392
393 struct ShapeInfo {
394 unsigned NumRows;
395 unsigned NumColumns;
396
397 bool IsColumnMajor;
398
399 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
400 : NumRows(NumRows), NumColumns(NumColumns),
401 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
402
403 ShapeInfo(Value *NumRows, Value *NumColumns)
404 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
405 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
406
407 bool operator==(const ShapeInfo &other) {
408 return NumRows == other.NumRows && NumColumns == other.NumColumns;
409 }
410 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
411
412 /// Returns true if shape-information is defined, meaning both dimensions
413 /// are != 0.
414 operator bool() const {
415 assert(NumRows == 0 || NumColumns != 0);
416 return NumRows != 0;
417 }
418
419 unsigned getStride() const {
420 if (IsColumnMajor)
421 return NumRows;
422 return NumColumns;
423 }
424
425 unsigned getNumVectors() const {
426 if (IsColumnMajor)
427 return NumColumns;
428 return NumRows;
429 }
430
431 /// Returns the transposed shape.
432 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
433 };
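
// Worked example (for illustration only; values assumed): for
//
//   %c = call <6 x double> @llvm.matrix.multiply.v6f64.v4f64.v6f64(
//            <4 x double> %a, <6 x double> %b, i32 2, i32 2, i32 3)
//
// the result %c gets ShapeInfo(2, 3), %a gets ShapeInfo(2, 2) and %b gets
// ShapeInfo(2, 3). With the default column-major layout each value is lowered
// as getNumVectors() column vectors of getStride() elements.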
434
435 /// Maps instructions to their shape information. The shape information
436 /// describes the shape to be used while lowering. This matches the shape of
437 /// the result value of the instruction, with the only exceptions being store
438 /// instructions and the matrix_column_major_store intrinsics. For those, the
439 /// shape information indicates that those instructions should be lowered
440 /// using shape information as well. A ValueMap is used so that when
441/// sub-passes like optimizeTransposes perform RAUW the map stays
442/// up-to-date.
443 ValueMap<Value *, ShapeInfo> ShapeMap;
444
445 /// List of instructions to remove. While lowering, we do not replace all
446 /// users of a lowered instruction if shape information is available; those
447 /// users need to be removed after we have finished lowering.
448 SmallVector<Instruction *, 16> ToRemove;
449
450 /// Map from instructions to their produced column matrix.
451 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
452
453private:
454 static FastMathFlags getFastMathFlags(Instruction *Inst) {
455 FastMathFlags FMF;
456
457 if (isa<FPMathOperator>(*Inst))
458 FMF = Inst->getFastMathFlags();
459
460 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
461
462 return FMF;
463 }
464
465public:
466 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
467 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
468 OptimizationRemarkEmitter *ORE)
469 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
470 LI(LI), ORE(ORE) {}
471
472 unsigned getNumOps(Type *VT) {
473 assert(isa<VectorType>(VT) && "Expected vector type");
474 return getNumOps(VT->getScalarType(),
475 cast<FixedVectorType>(VT)->getNumElements());
476 }
477
478 /// Is this the minimal version executed in the backend pipelines.
479 bool isMinimal() const {
480 return !DT;
481 }
482
483 /// Return the estimated number of vector ops required for an operation on
484 /// \p VT * N.
485 unsigned getNumOps(Type *ST, unsigned N) {
486 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
487 double(TTI.getRegisterBitWidth(
488 TargetTransformInfo::RGK_FixedWidthVector)
489 .getFixedValue()));
490 }
491
492 /// Return the set of vectors that a matrix value is lowered to.
493 ///
494 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
495 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
496 /// into vectors.
497 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
498 IRBuilder<> &Builder) {
499 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
500 assert(VType && "MatrixVal must be a vector type");
501 assert(cast<FixedVectorType>(VType)->getNumElements() ==
502 SI.NumRows * SI.NumColumns &&
503 "The vector size must match the number of matrix elements");
504
505 // Check if we lowered MatrixVal using shape information. In that case,
506 // return the existing matrix, if it matches the requested shape
507 // information. If there is a mis-match, embed the result in a flat
508 // vector and split it later.
509 auto Found = Inst2ColumnMatrix.find(MatrixVal);
510 if (Found != Inst2ColumnMatrix.end()) {
511 MatrixTy &M = Found->second;
512 // Return the found matrix, if its shape matches the requested shape
513 // information
514 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
515 return M;
516
517 MatrixVal = M.embedInVector(Builder);
518 }
519
520 // Otherwise split MatrixVal.
521 SmallVector<Value *, 16> SplitVecs;
522 for (unsigned MaskStart = 0;
523 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
524 MaskStart += SI.getStride()) {
525 Value *V = Builder.CreateShuffleVector(
526 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
527 "split");
528 SplitVecs.push_back(V);
529 }
530
531 return {SplitVecs};
532 }
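
// For illustration (assumed values): splitting a flat <6 x double> %m with
// shape {NumRows = 2, NumColumns = 3} in column-major layout yields three
// <2 x double> column vectors via shuffles such as
//
//   %split = shufflevector <6 x double> %m, <6 x double> poison,
//                          <2 x i32> <i32 2, i32 3>
//
// (shown here for the second column, i.e. MaskStart == 2).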
533
534 /// If \p V already has a known shape return false. Otherwise set the shape
535 /// for instructions that support it.
536 bool setShapeInfo(Value *V, ShapeInfo Shape) {
537 assert(Shape && "Shape not set");
538 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
539 return false;
540
541 auto SIter = ShapeMap.find(V);
542 if (SIter != ShapeMap.end()) {
543 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
544 SIter->second.NumColumns != Shape.NumColumns)) {
545 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
546 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
547 << Shape.NumColumns << ") for " << *V << "\n";
549 "Matrix shape verification failed, compilation aborted!");
550 }
551
552 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
553 << SIter->second.NumRows << " "
554 << SIter->second.NumColumns << " for " << *V << "\n");
555 return false;
556 }
557
558 ShapeMap.insert({V, Shape});
559 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
560 << " for " << *V << "\n");
561 return true;
562 }
563
564 bool isUniformShape(Value *V) {
565 Instruction *I = dyn_cast<Instruction>(V);
566 if (!I)
567 return true;
568
569 switch (I->getOpcode()) {
570 case Instruction::FAdd:
571 case Instruction::FSub:
572 case Instruction::FMul: // Scalar multiply.
573 case Instruction::FNeg:
574 case Instruction::Add:
575 case Instruction::Mul:
576 case Instruction::Sub:
577 return true;
578 default:
579 return false;
580 }
581 }
582
583 /// Returns true if shape information can be used for \p V. The supported
584 /// instructions must match the instructions that can be lowered by this pass.
585 bool supportsShapeInfo(Value *V) {
586 Instruction *Inst = dyn_cast<Instruction>(V);
587 if (!Inst)
588 return false;
589
590 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
591 if (II)
592 switch (II->getIntrinsicID()) {
593 case Intrinsic::matrix_multiply:
594 case Intrinsic::matrix_transpose:
595 case Intrinsic::matrix_column_major_load:
596 case Intrinsic::matrix_column_major_store:
597 return true;
598 default:
599 return false;
600 }
601 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
602 }
603
604 /// Propagate the shape information of instructions to their users.
605 /// The work list contains instructions for which we can compute the shape,
606 /// either based on the information provided by matrix intrinsics or known
607 /// shapes of operands.
608 SmallVector<Instruction *, 32>
609 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
610 SmallVector<Instruction *, 32> NewWorkList;
611 // Pop an element for which we are guaranteed to have at least one of the
612 // operand shapes. Add the shape for this and then add users to the work
613 // list.
614 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
615 while (!WorkList.empty()) {
616 Instruction *Inst = WorkList.pop_back_val();
617
618 // New entry, set the value and insert operands
619 bool Propagate = false;
620
621 Value *MatrixA;
622 Value *MatrixB;
623 Value *M;
624 Value *N;
625 Value *K;
626 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
627 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
628 m_Value(N), m_Value(K)))) {
629 Propagate = setShapeInfo(Inst, {M, K});
630 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
631 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
632 // Flip dimensions.
633 Propagate = setShapeInfo(Inst, {N, M});
634 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
635 m_Value(MatrixA), m_Value(), m_Value(),
636 m_Value(), m_Value(M), m_Value(N)))) {
637 Propagate = setShapeInfo(Inst, {N, M});
638 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
639 m_Value(), m_Value(), m_Value(), m_Value(M),
640 m_Value(N)))) {
641 Propagate = setShapeInfo(Inst, {M, N});
642 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
643 auto OpShape = ShapeMap.find(MatrixA);
644 if (OpShape != ShapeMap.end())
645 setShapeInfo(Inst, OpShape->second);
646 continue;
647 } else if (isUniformShape(Inst)) {
648 // Find the first operand that has a known shape and use that.
649 for (auto &Op : Inst->operands()) {
650 auto OpShape = ShapeMap.find(Op.get());
651 if (OpShape != ShapeMap.end()) {
652 Propagate |= setShapeInfo(Inst, OpShape->second);
653 break;
654 }
655 }
656 }
657
658 if (Propagate) {
659 NewWorkList.push_back(Inst);
660 for (auto *User : Inst->users())
661 if (ShapeMap.count(User) == 0)
662 WorkList.push_back(cast<Instruction>(User));
663 }
664 }
665
666 return NewWorkList;
667 }
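
// Forward propagation, for illustration (assumed IR): given
//
//   %t = call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %a,
//                                                       i32 2, i32 2)
//   %s = fadd <4 x double> %t, %t
//
// the transpose seeds %t with shape 2 x 2 (dimensions flipped), and the fadd
// then inherits 2 x 2 from its operand via isUniformShape().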
668
669 /// Propagate the shape to operands of instructions with shape information.
670 /// \p WorkList contains the instructions for which we already know the shape.
671 SmallVector<Instruction *, 32>
672 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
673 SmallVector<Instruction *, 32> NewWorkList;
674
675 auto pushInstruction = [](Value *V,
676 SmallVectorImpl<Instruction *> &WorkList) {
677 Instruction *I = dyn_cast<Instruction>(V);
678 if (I)
679 WorkList.push_back(I);
680 };
681 // Pop an element with known shape. Traverse the operands, if their shape
682 // derives from the result shape and is unknown, add it and add them to the
683 // worklist.
684 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
685 while (!WorkList.empty()) {
686 Value *V = WorkList.pop_back_val();
687
688 size_t BeforeProcessingV = WorkList.size();
689 if (!isa<Instruction>(V))
690 continue;
691
692 Value *MatrixA;
693 Value *MatrixB;
694 Value *M;
695 Value *N;
696 Value *K;
697 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
698 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
699 m_Value(N), m_Value(K)))) {
700 if (setShapeInfo(MatrixA, {M, N}))
701 pushInstruction(MatrixA, WorkList);
702
703 if (setShapeInfo(MatrixB, {N, K}))
704 pushInstruction(MatrixB, WorkList);
705
706 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
707 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
708 // Flip dimensions.
709 if (setShapeInfo(MatrixA, {M, N}))
710 pushInstruction(MatrixA, WorkList);
711 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
712 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
713 m_Value(M), m_Value(N)))) {
714 if (setShapeInfo(MatrixA, {M, N})) {
715 pushInstruction(MatrixA, WorkList);
716 }
717 } else if (isa<LoadInst>(V) ||
718 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
719 // Nothing to do, no matrix input.
720 } else if (isa<StoreInst>(V)) {
721 // Nothing to do. We forward-propagated to this so we would just
722 // backward propagate to an instruction with an already known shape.
723 } else if (isUniformShape(V)) {
724 // Propagate to all operands.
725 ShapeInfo Shape = ShapeMap[V];
726 for (Use &U : cast<Instruction>(V)->operands()) {
727 if (setShapeInfo(U.get(), Shape))
728 pushInstruction(U.get(), WorkList);
729 }
730 }
731 // After we discovered new shape info for new instructions in the
732 // worklist, we use their users as seeds for the next round of forward
733 // propagation.
734 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
735 for (User *U : WorkList[I]->users())
736 if (isa<Instruction>(U) && V != U)
737 NewWorkList.push_back(cast<Instruction>(U));
738 }
739 return NewWorkList;
740 }
741
742 /// (Op0 op Op1)^T -> Op0^T op Op1^T
743 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
744 /// them on both sides of \p Operation.
745 Instruction *distributeTransposes(
746 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
747 MatrixBuilder &Builder,
748 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
749 Operation) {
750 Value *T0 = Builder.CreateMatrixTranspose(
751 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
752 // We are being run after shape prop, add shape for newly created
753 // instructions so that we lower them later.
754 setShapeInfo(T0, Shape0.t());
755 Value *T1 = Builder.CreateMatrixTranspose(
756 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
757 setShapeInfo(T1, Shape1.t());
758 return Operation(T0, Shape0.t(), T1, Shape1.t());
759 }
760
761 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
762 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
763 // with New. We should only add New if it supportsShapeInfo so we insert
764 // it conditionally instead.
765 auto S = ShapeMap.find(&Old);
766 if (S != ShapeMap.end()) {
767 ShapeMap.erase(S);
768 if (supportsShapeInfo(New))
769 ShapeMap.insert({New, S->second});
770 }
771 Old.replaceAllUsesWith(New);
772 }
773
774 /// Sink a top-level transpose inside matmuls and adds.
775 /// This creates and erases instructions as needed, and returns the newly
776 /// created instruction while updating the iterator to avoid invalidation. If
777 /// this returns nullptr, no new instruction was created.
778 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
779 BasicBlock &BB = *I.getParent();
780 IRBuilder<> IB(&I);
781 MatrixBuilder Builder(IB);
782
783 Value *TA, *TAMA, *TAMB;
784 ConstantInt *R, *K, *C;
785 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
786 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
787 return nullptr;
788
789 // Transpose of a transpose is a nop
790 Value *TATA;
791 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
792 updateShapeAndReplaceAllUsesWith(I, TATA);
793 eraseFromParentAndMove(&I, II, BB);
794 eraseFromParentAndMove(TA, II, BB);
795 return nullptr;
796 }
797
798 // k^T -> k
799 if (isSplat(TA)) {
800 updateShapeAndReplaceAllUsesWith(I, TA);
801 eraseFromParentAndMove(&I, II, BB);
802 return nullptr;
803 }
804
805 // (A * B)^t -> B^t * A^t
806 // RxK KxC CxK KxR
807 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
808 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
809 m_ConstantInt(K), m_ConstantInt(C)))) {
810 auto NewInst = distributeTransposes(
811 TAMB, {K, C}, TAMA, {R, K}, Builder,
812 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
813 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
814 Shape0.NumColumns,
815 Shape1.NumColumns, "mmul");
816 });
817 updateShapeAndReplaceAllUsesWith(I, NewInst);
818 eraseFromParentAndMove(&I, II, BB);
819 eraseFromParentAndMove(TA, II, BB);
820 return NewInst;
821 }
822
823 // Same as above, but with a mul, which occurs when multiplied
824 // with a scalar.
825 // (A * k)^t -> A^t * k
826 // R x C RxC
827 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
828 (isSplat(TAMA) || isSplat(TAMB))) {
829 IRBuilder<> LocalBuilder(&I);
830 // We know that the transposed operand is of shape RxC.
831 // And when multiplied with a scalar, the shape is preserved.
832 auto NewInst = distributeTransposes(
833 TAMA, {R, C}, TAMB, {R, C}, Builder,
834 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
835 bool IsFP = I.getType()->isFPOrFPVectorTy();
836 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
837 : LocalBuilder.CreateMul(T0, T1, "mmul");
838 auto *Result = cast<Instruction>(Mul);
839 setShapeInfo(Result, Shape0);
840 return Result;
841 });
842 updateShapeAndReplaceAllUsesWith(I, NewInst);
843 eraseFromParentAndMove(&I, II, BB);
844 eraseFromParentAndMove(TA, II, BB);
845 return NewInst;
846 }
847
848 // (A + B)^t -> A^t + B^t
849 // RxC RxC CxR CxR
850 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
851 IRBuilder<> LocalBuilder(&I);
852 auto NewInst = distributeTransposes(
853 TAMA, {R, C}, TAMB, {R, C}, Builder,
854 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
855 bool IsFP = I.getType()->isFPOrFPVectorTy();
856 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
857 : LocalBuilder.CreateAdd(T0, T1, "madd");
858
859 auto *Result = cast<Instruction>(Add);
860 setShapeInfo(Result, Shape0);
861 return Result;
862 });
863 updateShapeAndReplaceAllUsesWith(I, NewInst);
864 eraseFromParentAndMove(&I, II, BB);
865 eraseFromParentAndMove(TA, II, BB);
866 return NewInst;
867 }
868
869 return nullptr;
870 }
871
872 void liftTranspose(Instruction &I) {
873 // Erase dead Instructions after lifting transposes from binops.
874 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
875 if (T.use_empty())
876 T.eraseFromParent();
877 if (A->use_empty())
878 cast<Instruction>(A)->eraseFromParent();
879 if (A != B && B->use_empty())
880 cast<Instruction>(B)->eraseFromParent();
881 };
882
883 Value *A, *B, *AT, *BT;
884 ConstantInt *R, *K, *C;
885 // A^t * B ^t -> (B * A)^t
886 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
887 m_Value(A), m_Value(B), m_ConstantInt(R),
888 m_ConstantInt(K), m_ConstantInt(C))) &&
889 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
890 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
891 IRBuilder<> IB(&I);
892 MatrixBuilder Builder(IB);
893 Value *M = Builder.CreateMatrixMultiply(
894 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
895 setShapeInfo(M, {C, R});
896 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
897 R->getZExtValue());
898 updateShapeAndReplaceAllUsesWith(I, NewInst);
899 CleanupBinOp(I, A, B);
900 }
901 // A^t + B ^t -> (A + B)^t
902 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
903 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
904 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
905 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
906 m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
907 IRBuilder<> Builder(&I);
908 Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
909 setShapeInfo(Add, {C, R});
910 MatrixBuilder MBuilder(Builder);
911 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
912 Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
913 updateShapeAndReplaceAllUsesWith(I, NewInst);
914 CleanupBinOp(I, A, B);
915 }
916 }
917
918 /// Try moving transposes in order to fold them away or into multiplies.
919 void optimizeTransposes() {
920 // First sink all transposes inside matmuls and adds, hoping that we end up
921 // with NN, NT or TN variants.
922 for (BasicBlock &BB : reverse(Func)) {
923 for (auto II = BB.rbegin(); II != BB.rend();) {
924 Instruction &I = *II;
925 // We may remove II. By default continue on the next/prev instruction.
926 ++II;
927 if (Instruction *NewInst = sinkTranspose(I, II))
928 II = std::next(BasicBlock::reverse_iterator(NewInst));
929 }
930 }
931
932 // If we have a TT matmul or a TT add, lift the transpose. We may be able
933 // to fold into consuming multiply or add.
934 for (BasicBlock &BB : Func) {
935 for (Instruction &I : llvm::make_early_inc_range(BB)) {
936 liftTranspose(I);
937 }
938 }
939 }
940
941 bool Visit() {
942 SmallVector<Instruction *> WorkList;
943
944 // Initially only the shape of matrix intrinsics is known.
945 // Initialize the work list with ops carrying shape information.
946 for (BasicBlock &BB : Func)
947 for (Instruction &Inst : BB) {
948 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
949 if (!II)
950 continue;
951
952 switch (II->getIntrinsicID()) {
953 case Intrinsic::matrix_multiply:
954 case Intrinsic::matrix_transpose:
955 case Intrinsic::matrix_column_major_load:
956 case Intrinsic::matrix_column_major_store:
957 WorkList.push_back(&Inst);
958 break;
959 default:
960 break;
961 }
962 }
963
964 // Avoid unnecessary work if there are no matrix intrinsics in the function.
965 if (WorkList.empty())
966 return false;
967
968 // Propagate shapes until nothing changes any longer.
969 while (!WorkList.empty()) {
970 WorkList = propagateShapeForward(WorkList);
971 WorkList = propagateShapeBackward(WorkList);
972 }
973
974 if (!isMinimal()) {
975 optimizeTransposes();
976 if (PrintAfterTransposeOpt) {
977 dbgs() << "Dump after matrix transpose optimization:\n";
978 Func.print(dbgs());
979 }
980 }
981
982 bool Changed = false;
983 SmallVector<CallInst *, 16> MaybeFusableInsts;
984 SmallVector<Instruction *, 16> MatrixInsts;
985
986 // First, collect all instructions with shape information and candidates for
987 // fusion (currently only matrix multiplies).
988 ReversePostOrderTraversal<Function *> RPOT(&Func);
989 for (auto *BB : RPOT)
990 for (Instruction &I : *BB) {
991 if (ShapeMap.find(&I) == ShapeMap.end())
992 continue;
993 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
994 MaybeFusableInsts.push_back(cast<CallInst>(&I));
995 MatrixInsts.push_back(&I);
996 }
997
998 // Second, try to lower any dot products
999 SmallPtrSet<Instruction *, 16> FusedInsts;
1000 for (CallInst *CI : MaybeFusableInsts)
1001 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1002
1003 // Third, try to fuse candidates.
1004 for (CallInst *CI : MaybeFusableInsts)
1005 LowerMatrixMultiplyFused(CI, FusedInsts);
1006
1007 Changed = !FusedInsts.empty();
1008
1009 // Fourth, lower remaining instructions with shape information.
1010 for (Instruction *Inst : MatrixInsts) {
1011 if (FusedInsts.count(Inst))
1012 continue;
1013
1014 IRBuilder<> Builder(Inst);
1015
1016 if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1017 Changed |= VisitCallInst(CInst);
1018
1019 Value *Op1;
1020 Value *Op2;
1021 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1022 Changed |= VisitBinaryOperator(BinOp);
1023 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1024 Changed |= VisitUnaryOperator(UnOp);
1025 if (match(Inst, m_Load(m_Value(Op1))))
1026 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1027 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1028 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1029 }
1030
1031 if (ORE) {
1032 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1033 RemarkGen.emitRemarks();
1034 }
1035
1036 // Delete the instructions backwards, as it has a reduced likelihood of
1037 // having to update as many def-use and use-def chains.
1038 //
1039 // Because we add to ToRemove during fusion we can't guarantee that defs
1040 // are before uses. Change uses to poison temporarily as these should get
1041 // removed as well.
1042 //
1043 // For verification, we keep track of where we changed uses to poison in
1044 // PoisonedInsts and then check that we in fact remove them.
1045 SmallSet<Instruction *, 16> PoisonedInsts;
1046 for (auto *Inst : reverse(ToRemove)) {
1047 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1048 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1049 PoisonedInsts.insert(Poisoned);
1050 U.set(PoisonValue::get(Inst->getType()));
1051 }
1052 Inst->eraseFromParent();
1053 PoisonedInsts.erase(Inst);
1054 }
1055 if (!PoisonedInsts.empty()) {
1056 // If we didn't remove all poisoned instructions, it's a hard error.
1057 dbgs() << "Poisoned but present instructions:\n";
1058 for (auto *I : PoisonedInsts)
1059 dbgs() << *I << "\n";
1060 llvm_unreachable("Poisoned but instruction not removed");
1061 }
1062
1063 return Changed;
1064 }
1065
1066 /// Turns \p BasePtr into an elementwise pointer to \p EltType.
1067 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
1068 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
1069 Type *EltPtrType = PointerType::get(EltType, AS);
1070 return Builder.CreatePointerCast(BasePtr, EltPtrType);
1071 }
1072
1073 /// Replace intrinsic calls
1074 bool VisitCallInst(CallInst *Inst) {
1075 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1076 return false;
1077
1078 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1079 case Intrinsic::matrix_multiply:
1080 LowerMultiply(Inst);
1081 break;
1082 case Intrinsic::matrix_transpose:
1083 LowerTranspose(Inst);
1084 break;
1085 case Intrinsic::matrix_column_major_load:
1086 LowerColumnMajorLoad(Inst);
1087 break;
1088 case Intrinsic::matrix_column_major_store:
1089 LowerColumnMajorStore(Inst);
1090 break;
1091 default:
1092 return false;
1093 }
1094 return true;
1095 }
1096
1097 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1098 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1099 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1100 /// non-ConstantInt strides, return the common alignment of the initial
1101 /// alignment and the element size in bytes.
1102 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1103 MaybeAlign A) const {
1104 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1105 if (Idx == 0)
1106 return InitialAlign;
1107
1108 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1109 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1110 uint64_t StrideInBytes =
1111 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1112 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1113 }
1114 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1115 }
1116
1117 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1118 /// vectors.
1119 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1120 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1121 auto *VType = cast<VectorType>(Ty);
1122 Type *EltTy = VType->getElementType();
1123 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1124 Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
1125 MatrixTy Result;
1126 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1127 Value *GEP = computeVectorAddr(
1128 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1129 Stride, Shape.getStride(), EltTy, Builder);
1130 Value *Vector = Builder.CreateAlignedLoad(
1131 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1132 IsVolatile, "col.load");
1133
1134 Result.addVector(Vector);
1135 }
1136 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1137 Result.getNumVectors());
1138 }
1139
1140 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1141 /// starting at \p MatrixPtr[I][J].
1142 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1143 ShapeInfo MatrixShape, Value *I, Value *J,
1144 ShapeInfo ResultShape, Type *EltTy,
1145 IRBuilder<> &Builder) {
1146
1147 Value *Offset = Builder.CreateAdd(
1148 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1149
1150 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1151 Value *EltPtr =
1152 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1153 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1154 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1155 ResultShape.NumColumns);
1156 Type *TilePtrTy = PointerType::get(TileTy, AS);
1157 Value *TilePtr =
1158 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1159
1160 return loadMatrix(TileTy, TilePtr, Align,
1161 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1162 ResultShape, Builder);
1163 }
1164
1165 /// Lower a load instruction with shape information.
1166 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1167 bool IsVolatile, ShapeInfo Shape) {
1168 IRBuilder<> Builder(Inst);
1169 finalizeLowering(Inst,
1170 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1171 Shape, Builder),
1172 Builder);
1173 }
1174
1175 /// Lowers llvm.matrix.column.major.load.
1176 ///
1177 /// The intrinsic loads a matrix from memory using a stride between columns.
1178 void LowerColumnMajorLoad(CallInst *Inst) {
1179 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1180 "Intrinsic only supports column-major layout!");
1181 Value *Ptr = Inst->getArgOperand(0);
1182 Value *Stride = Inst->getArgOperand(1);
1183 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1184 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1185 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1186 }
1187
1188 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1189 /// MatrixPtr[I][J].
1190 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1191 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1192 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1193 Value *Offset = Builder.CreateAdd(
1194 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1195
1196 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1197 Value *EltPtr =
1198 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1199 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1200 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1201 StoreVal.getNumColumns());
1202 Type *TilePtrTy = PointerType::get(TileTy, AS);
1203 Value *TilePtr =
1204 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
1205
1206 storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
1207 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1208 }
1209
1210 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1211 /// vectors.
1212 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1213 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1214 IRBuilder<> &Builder) {
1215 auto VType = cast<VectorType>(Ty);
1216 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
1217 for (auto Vec : enumerate(StoreVal.vectors())) {
1218 Value *GEP = computeVectorAddr(
1219 EltPtr,
1220 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1221 Vec.index()),
1222 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1223 Builder.CreateAlignedStore(Vec.value(), GEP,
1224 getAlignForIndex(Vec.index(), Stride,
1225 VType->getElementType(),
1226 MAlign),
1227 IsVolatile);
1228 }
1229 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1230 StoreVal.getNumVectors());
1231 }
1232
1233 /// Lower a store instruction with shape information.
1234 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1235 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1236 IRBuilder<> Builder(Inst);
1237 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1238 finalizeLowering(Inst,
1239 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1240 IsVolatile, Builder),
1241 Builder);
1242 }
1243
1244 /// Lowers llvm.matrix.column.major.store.
1245 ///
1246 /// The intrinsic stores a matrix back to memory using a stride between columns.
1247 void LowerColumnMajorStore(CallInst *Inst) {
1248 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1249 "Intrinsic only supports column-major layout!");
1250 Value *Matrix = Inst->getArgOperand(0);
1251 Value *Ptr = Inst->getArgOperand(1);
1252 Value *Stride = Inst->getArgOperand(2);
1253 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1254 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1255 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1256 }
1257
1258 // Set elements I..I+NumElts-1 to Block
1259 Value *insertVector(Value *Col, unsigned I, Value *Block,
1260 IRBuilder<> &Builder) {
1261
1262 // First, bring Block to the same size as Col
1263 unsigned BlockNumElts =
1264 cast<FixedVectorType>(Block->getType())->getNumElements();
1265 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1266 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1267
1268 Block = Builder.CreateShuffleVector(
1269 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1270
1271 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1272 // 8, 4, 5, 6
1273 SmallVector<int, 16> Mask;
1274 unsigned i;
1275 for (i = 0; i < I; i++)
1276 Mask.push_back(i);
1277
1278 unsigned VecNumElts =
1279 cast<FixedVectorType>(Col->getType())->getNumElements();
1280 for (; i < I + BlockNumElts; i++)
1281 Mask.push_back(i - I + VecNumElts);
1282
1283 for (; i < VecNumElts; i++)
1284 Mask.push_back(i);
1285
1286 return Builder.CreateShuffleVector(Col, Block, Mask);
1287 }
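
// For illustration, the example from the comment above (Col is <7 x double>,
// I = 2, BlockNumElts = 2; value names assumed) lowers to roughly
//
//   %block.wide = shufflevector <2 x double> %block, <2 x double> poison,
//       <7 x i32> <i32 0, i32 1, i32 poison, i32 poison, i32 poison,
//                  i32 poison, i32 poison>
//   %res = shufflevector <7 x double> %col, <7 x double> %block.wide,
//       <7 x i32> <i32 0, i32 1, i32 7, i32 8, i32 4, i32 5, i32 6>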
1288
1289 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1290 IRBuilder<> &Builder, bool AllowContraction,
1291 unsigned &NumComputeOps) {
1292 NumComputeOps += getNumOps(A->getType());
1293 if (!Sum)
1294 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1295
1296 if (UseFPOp) {
1297 if (AllowContraction) {
1298 // Use fmuladd for floating point operations and let the backend decide
1299 // if that's profitable.
1300 Function *FMulAdd = Intrinsic::getDeclaration(
1301 Func.getParent(), Intrinsic::fmuladd, A->getType());
1302 return Builder.CreateCall(FMulAdd, {A, B, Sum});
1303 }
1304 NumComputeOps += getNumOps(A->getType());
1305 Value *Mul = Builder.CreateFMul(A, B);
1306 return Builder.CreateFAdd(Sum, Mul);
1307 }
1308
1309 NumComputeOps += getNumOps(A->getType());
1310 Value *Mul = Builder.CreateMul(A, B);
1311 return Builder.CreateAdd(Sum, Mul);
1312 }
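
// For illustration (assumed operand types): when contraction is allowed, the
// floating-point path emits a single intrinsic call such as
//
//   %s = call <4 x double> @llvm.fmuladd.v4f64(<4 x double> %a,
//                                              <4 x double> %b,
//                                              <4 x double> %sum)
//
// instead of a separate fmul followed by an fadd.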
1313
1314 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1315 /// users with shape information, there's nothing to do: they will use the
1316 /// cached value when they are lowered. For other users, \p Matrix is
1317 /// flattened and the uses are updated to use it. Also marks \p Inst for
1318 /// deletion.
1319 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1320 IRBuilder<> &Builder) {
1321 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1322 (void)inserted;
1323 assert(inserted.second && "multiple matrix lowering mapping");
1324
1325 ToRemove.push_back(Inst);
1326 Value *Flattened = nullptr;
1327 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1328 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1329 if (!Flattened)
1330 Flattened = Matrix.embedInVector(Builder);
1331 U.set(Flattened);
1332 }
1333 }
1334 }
1335
1336 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1337 /// vectors. Lowers to vector reduction add instead of sequential add if
1338 /// reassociation is enabled.
1339 void lowerDotProduct(CallInst *MatMul,
1340 SmallPtrSet<Instruction *, 16> &FusedInsts,
1341 FastMathFlags FMF) {
1342 if (FusedInsts.contains(MatMul) ||
1343 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1344 return;
1345 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1346 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1347
1348 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1349 return;
1350
1351 Value *LHS = MatMul->getArgOperand(0);
1352 Value *RHS = MatMul->getArgOperand(1);
1353
1354 Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1355 bool IsIntVec = ElementType->isIntegerTy();
1356
1357 // Floating point reductions require reassociation.
1358 if (!IsIntVec && !FMF.allowReassoc())
1359 return;
1360
1361 auto CanBeFlattened = [this](Value *Op) {
1362 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
1363 return true;
1364 return match(
1365 Op, m_OneUse(m_CombineOr(
1366 m_Load(m_Value()),
1367 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1368 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1369 m_Value(), m_SpecificInt(1))))));
1370 };
1371 // Returns the cost benefit of using \p Op with the dot product lowering. If
1372 // the returned cost is < 0, the argument is cheaper to use in the
1373 // dot-product lowering.
1374 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1375 if (!isa<Instruction>(Op))
1376 return InstructionCost(0);
1377
1378 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1379 Type *EltTy = VecTy->getElementType();
1380
1381 if (!CanBeFlattened(Op)) {
1382 InstructionCost EmbedCost(0);
1383 // Roughly estimate the cost for embedding the columns into a vector.
1384 for (unsigned I = 1; I < N; ++I)
1385 EmbedCost -=
1386 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1387 std::nullopt, TTI::TCK_RecipThroughput);
1388 return EmbedCost;
1389 }
1390
1391 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1392 InstructionCost OriginalCost =
1393 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1394 EltTy) *
1395 N;
1396 InstructionCost NewCost = TTI.getArithmeticInstrCost(
1397 cast<Instruction>(Op)->getOpcode(), VecTy);
1398 return NewCost - OriginalCost;
1399 }
1400
1401 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1402 // The transpose can be skipped for the dot product lowering, roughly
1403 // estimate the savings as the cost of embedding the columns in a
1404 // vector.
1405 InstructionCost EmbedCost(0);
1406 for (unsigned I = 1; I < N; ++I)
1407 EmbedCost +=
1408 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1409 std::nullopt, TTI::TCK_RecipThroughput);
1410 return EmbedCost;
1411 }
1412
1413 // Costs for loads.
1414 if (N == 1)
1415 return InstructionCost(0);
1416
1417 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1418 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1419 };
1420 auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
1421
1422 // We compare the costs of a vector.reduce.add to sequential add.
1423 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1424 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1425 InstructionCost ReductionCost =
1426 TTI.getArithmeticReductionCost(
1427 AddOpCode, cast<VectorType>(LHS->getType()),
1428 IsIntVec ? std::nullopt : std::optional(FMF)) +
1429 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1430 InstructionCost SequentialAddCost =
1431 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1432 (LShape.NumColumns - 1) +
1433 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1434 (LShape.NumColumns);
1435 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1436 return;
1437
1438 FusedInsts.insert(MatMul);
1439 IRBuilder<> Builder(MatMul);
1440 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1441 this](Value *Op) -> Value * {
1442 // Matmul must be the only user of loads because we don't use LowerLoad
1443 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1444 // instead of single vector load).
1445 if (!CanBeFlattened(Op))
1446 return Op;
1447
1448 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1449 ShapeMap[Op] = ShapeMap[Op].t();
1450 return Op;
1451 }
1452
1453 FusedInsts.insert(cast<Instruction>(Op));
1454 // If vector uses the builtin load, lower to a LoadInst
1455 Value *Arg;
1456 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1457 m_Value(Arg)))) {
1458 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1459 Op->replaceAllUsesWith(NewLoad);
1460 cast<Instruction>(Op)->eraseFromParent();
1461 return NewLoad;
1462 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1463 m_Value(Arg)))) {
1464 ToRemove.push_back(cast<Instruction>(Op));
1465 return Arg;
1466 }
1467
1468 return Op;
1469 };
1470 LHS = FlattenArg(LHS);
1471
1472 // Insert mul/fmul and llvm.vector.reduce.fadd
1473 Value *Mul =
1474 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1475
1476 Value *Result;
1477 if (IsIntVec)
1478 Result = Builder.CreateAddReduce(Mul);
1479 else {
1480 Result = Builder.CreateFAddReduce(
1481 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1482 0.0),
1483 Mul);
1484 cast<Instruction>(Result)->setFastMathFlags(FMF);
1485 }
1486
1487 // pack scalar back into a matrix and then replace matmul inst
1488 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1489 Result, uint64_t(0));
1490 MatMul->replaceAllUsesWith(Result);
1491 FusedInsts.insert(MatMul);
1492 ToRemove.push_back(MatMul);
1493 }
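
// For illustration (assumed shapes and names): a 1x4 * 4x1 double multiply
// with reassociation enabled becomes roughly
//
//   %mul = fmul <4 x double> %lhs, %rhs
//   %red = call reassoc double @llvm.vector.reduce.fadd.v4f64(
//              double 0.000000e+00, <4 x double> %mul)
//   %res = insertelement <1 x double> poison, double %red, i64 0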
1494
1495 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1496 /// addition.
1497 ///
1498 /// We can fold a transpose into the operand that is used to extract scalars.
1499/// This is the first operand with row-major and the second with
1500 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1501 /// operand is transposed.
1502 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1503 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1504 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1505 const unsigned VF = std::max<unsigned>(
1506 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1507 .getFixedValue() /
1508 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1509 1U);
1510 unsigned R = Result.getNumRows();
1511 unsigned C = Result.getNumColumns();
1512 unsigned M = A.getNumColumns();
1513
1514 bool IsFP = Result.getElementType()->isFloatingPointTy();
1515 assert(A.isColumnMajor() == B.isColumnMajor() &&
1516 Result.isColumnMajor() == A.isColumnMajor() &&
1517 "operands must agree on matrix layout");
1518 unsigned NumComputeOps = 0;
1519
1520 Builder.setFastMathFlags(FMF);
1521
1522 if (A.isColumnMajor()) {
1523 // Multiply columns from the first operand with scalars from the second
1524 // operand. Then move along the K axes and accumulate the columns. With
1525 // this the adds can be vectorized without reassociation.
1526 for (unsigned J = 0; J < C; ++J) {
1527 unsigned BlockSize = VF;
1528 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1529 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1530
1531 for (unsigned I = 0; I < R; I += BlockSize) {
1532 // Gradually lower the vectorization factor to cover the remainder.
1533 while (I + BlockSize > R)
1534 BlockSize /= 2;
1535
1536 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1537 : nullptr;
1538 for (unsigned K = 0; K < M; ++K) {
1539 Value *L = A.extractVector(I, K, BlockSize, Builder);
1540 Value *RH = Builder.CreateExtractElement(
1541 B.getColumn(IsScalarMatrixTransposed ? K : J),
1542 IsScalarMatrixTransposed ? J : K);
1543 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1544 Sum =
1545 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1546 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1547 }
1548 Result.setVector(J,
1549 insertVector(Result.getVector(J), I, Sum, Builder));
1550 }
1551 }
1552 } else {
1553 // Multiply rows from the second operand with scalars from the first
1554 // operand. Then move along the K axes and accumulate the rows. With this
1555 // the adds can be vectorized without reassociation.
1556 for (unsigned I = 0; I < R; ++I) {
1557 unsigned BlockSize = VF;
1558 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1559 for (unsigned J = 0; J < C; J += BlockSize) {
1560 // Gradually lower the vectorization factor to cover the remainder.
1561 while (J + BlockSize > C)
1562 BlockSize /= 2;
1563
1564 Value *Sum = nullptr;
1565 for (unsigned K = 0; K < M; ++K) {
1566 Value *R = B.extractVector(K, J, BlockSize, Builder);
1567 Value *LH = Builder.CreateExtractElement(
1568 A.getVector(IsScalarMatrixTransposed ? K : I),
1569 IsScalarMatrixTransposed ? I : K);
1570 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1571 Sum =
1572 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1573 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1574 }
1575 Result.setVector(I,
1576 insertVector(Result.getVector(I), J, Sum, Builder));
1577 }
1578 }
1579 }
1580 Result.addNumComputeOps(NumComputeOps);
1581 }
1582
1583 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1584 /// copying it to a new location. The new location, or otherwise the original
1585 /// location, is returned.
1586 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1587 CallInst *MatMul) {
1588 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1589 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1590
1591 // If we can statically determine noalias we're good.
1592 if (AA->isNoAlias(LoadLoc, StoreLoc))
1593 return Load->getPointerOperand();
1594
1595 // Create code to check if the memory locations of the Load and Store
1596 // overlap and if they do, copy Load's operand to a new buffer.
1597
1598 // First, create new blocks for the 2nd part of the check and the copy.
1599 BasicBlock *Check0 = MatMul->getParent();
1600 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1601 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1602 // as we adjust Check0 and Check1's branches.
1603 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1604 for (BasicBlock *Succ : successors(Check0))
1605 DTUpdates.push_back({DT->Delete, Check0, Succ});
1606
1607 BasicBlock *Check1 =
1608 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1609 nullptr, "alias_cont");
1610 BasicBlock *Copy =
1611 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1612 nullptr, "copy");
1613 BasicBlock *Fusion =
1614 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1615 nullptr, "no_alias");
1616
1617 // Check if the loaded memory location begins before the end of the store
1618 // location. If the condition holds, they might overlap, otherwise they are
1619 // guaranteed to not overlap.
1620 IRBuilder<> Builder(MatMul);
1621 Check0->getTerminator()->eraseFromParent();
1622 Builder.SetInsertPoint(Check0);
1623 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1624 Value *StoreBegin = Builder.CreatePtrToInt(
1625 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1626 Value *StoreEnd = Builder.CreateAdd(
1627 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1628 "store.end", true, true);
1629 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1630 IntPtrTy, "load.begin");
1631 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1632 Fusion);
1633
1634 // Check if the store begins before the end of the load location. If the
1635 // condition holds, they alias, otherwise they are guaranteed to not
1636 // overlap.
1637 Check1->getTerminator()->eraseFromParent();
1638 Builder.SetInsertPoint(Check1, Check1->begin());
1639 Value *LoadEnd = Builder.CreateAdd(
1640 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1641 "load.end", true, true);
1642 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1643 Fusion);
1644
1645 // Copy load operand to new alloca.
1646 Builder.SetInsertPoint(Copy, Copy->begin());
1647 auto *VT = cast<FixedVectorType>(Load->getType());
1648 // Use an array type for the alloca, to avoid potentially huge alignment
1649 // requirements for large vector types.
1650 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1651 AllocaInst *Alloca =
1652 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1653 Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
1654
1655 Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
1656 Load->getAlign(), LoadLoc.Size.getValue());
1657 Builder.SetInsertPoint(Fusion, Fusion->begin());
1658 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1659 PHI->addIncoming(Load->getPointerOperand(), Check0);
1660 PHI->addIncoming(Load->getPointerOperand(), Check1);
1661 PHI->addIncoming(BC, Copy);
1662
1663 // Adjust DT.
1664 DTUpdates.push_back({DT->Insert, Check0, Check1});
1665 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1666 DTUpdates.push_back({DT->Insert, Check1, Copy});
1667 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1668 DT->applyUpdates(DTUpdates);
1669 return PHI;
1670 }
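The two conditional branches emitted above implement a standard interval-overlap test on the integer addresses of the load and store ranges. A hedged scalar equivalent of that check (illustrative only; all names are invented):

#include <cstdint>

// Scalar equivalent of the two emitted ULT comparisons: the load and store
// ranges may overlap iff each range begins before the other one ends.
bool rangesMayOverlap(uint64_t LoadBegin, uint64_t LoadSize,
                      uint64_t StoreBegin, uint64_t StoreSize) {
  uint64_t LoadEnd = LoadBegin + LoadSize;    // "load.end"
  uint64_t StoreEnd = StoreBegin + StoreSize; // "store.end"
  return LoadBegin < StoreEnd && StoreBegin < LoadEnd;
}

Only when both comparisons hold does control reach the copy block; the PHI in the no_alias block then selects either the original load pointer (from Check0 or Check1) or the freshly copied buffer (from Copy).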
1671
1672 bool isFusionProfitable(CallInst *MatMul) {
1673 if (ForceFusion)
1674 return true;
1675
1676 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1677 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1678
1679 const unsigned R = LShape.NumRows;
1680 const unsigned C = RShape.NumColumns;
1681 const unsigned M = LShape.NumColumns;
1682 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1683
1684 const unsigned VF = std::max<unsigned>(
1685 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1686 .getFixedValue() /
1687 EltType->getPrimitiveSizeInBits().getFixedValue(),
1688 1U);
1689
1690 // Cost model for tiling
1691 //
1692 // For tiling to be beneficial, we need reuse either along the R or
1693 // the C axis. We vectorize along the R axis so that means at least
1694 // 3 elements.
1695 // TODO: Also consider cost of copying if operands alias.
1696 if (R <= VF && C == 1)
1697 return false;
1698 // Then we need enough elements to exceed the number of vector
1699 // registers we have. Note that this is an oversimplification since
1700 // fusing also takes some extra loads which may exceed the number of
1701 // reloads necessary.
1702 unsigned Op0Regs = (R + VF - 1) / VF * M;
1703 unsigned Op1Regs = (M + VF - 1) / VF * C;
1704 return Op0Regs + Op1Regs >
1705 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1706 }
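A worked instance of the register-pressure heuristic above, under stated assumptions (the helper and its numbers are hypothetical, not part of this pass):

#include <cstdio>

// Restates the check above: fuse only if the operand tiles would not fit
// into the vector register file.
static bool exceedsVectorRegisters(unsigned R, unsigned C, unsigned M,
                                   unsigned VF, unsigned NumVectorRegs) {
  unsigned Op0Regs = (R + VF - 1) / VF * M; // vectors needed to hold the LHS
  unsigned Op1Regs = (M + VF - 1) / VF * C; // vectors needed to hold the RHS
  return Op0Regs + Op1Regs > NumVectorRegs;
}

int main() {
  // An 8x8 * 8x8 multiply with VF = 4 needs 16 + 16 = 32 operand vectors; a
  // 32-register vector file is not exceeded, so this check alone would not
  // justify fusion (prints 0).
  std::printf("%d\n", exceedsVectorRegisters(8, 8, 8, 4, 32));
  return 0;
}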
1707
1708 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1709 MatrixTy Res;
1710 auto *ColumType = FixedVectorType::get(EltType, R);
1711 for (unsigned I = 0; I < C; ++I)
1712 Res.addVector(ConstantAggregateZero::get(ColumType));
1713 return Res;
1714 }
1715
1716 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1717 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1718 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1719
1720 // Create the main tiling loop nest.
1721 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1722 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1723 Instruction *InsertI = cast<Instruction>(MatMul);
1724 BasicBlock *Start = InsertI->getParent();
1725 BasicBlock *End =
1726 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1727 IRBuilder<> Builder(MatMul);
1728 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1729
1730 Type *TileVecTy =
1731 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1732 MatrixTy TileResult;
1733 // Insert in the inner loop header.
1734 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1735 // Create PHI nodes for the result columns to accumulate across iterations.
1736 SmallVector<PHINode *, 4> ColumnPhis;
1737 for (unsigned I = 0; I < TileSize; I++) {
1738 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1739 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1740 TI.RowLoop.Header->getSingleSuccessor());
1741 TileResult.addVector(Phi);
1742 ColumnPhis.push_back(Phi);
1743 }
1744
1745 // Insert in the inner loop body, which computes
1746 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1747 Builder.SetInsertPoint(InnerBody->getTerminator());
1748 // Load tiles of the operands.
1749 MatrixTy A =
1750 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1751 {TileSize, TileSize}, EltType, Builder);
1752 MatrixTy B =
1753 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1754 {TileSize, TileSize}, EltType, Builder);
1755 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1756 getFastMathFlags(MatMul));
1757 // Store result after the inner loop is done.
1758 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1759 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1760 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1761 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1762
1763 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1764 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1765
1766 // Force unrolling of a few iterations of the inner loop, to make sure there
1767 // is enough work per iteration.
1768 // FIXME: The unroller should make this decision directly instead, but
1769 // currently the cost-model is not up to the task.
1770 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1771 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1772 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1773 }
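TI.CreateTiledLoops builds a loop nest over tiles of the result (result columns outermost here, the shared dimension innermost); the accumulator tile is carried across the whole inner K loop in the PHIs created above and stored once in the row-loop latch. A standalone C++ sketch of the same strategy, illustrative only and assuming, like the caller, that every dimension is a multiple of TileSize:

#include <vector>

// Column-major tiled multiply: C (Rows x Cols) += A (Rows x Inner) * B
// (Inner x Cols). The accumulator tile lives across the whole K loop and is
// written back once, mirroring the PHI/store placement above.
void tiledMultiply(const std::vector<double> &A, const std::vector<double> &B,
                   std::vector<double> &C, unsigned Rows, unsigned Inner,
                   unsigned Cols, unsigned TileSize) {
  auto At = [](const std::vector<double> &M, unsigned NumRows, unsigned Row,
               unsigned Col) { return M[Col * NumRows + Row]; };
  for (unsigned J = 0; J < Cols; J += TileSize)     // tile of result columns
    for (unsigned I = 0; I < Rows; I += TileSize) { // tile of result rows
      std::vector<double> Acc(TileSize * TileSize, 0.0); // accumulator tile
      for (unsigned K = 0; K < Inner; K += TileSize)
        for (unsigned TC = 0; TC < TileSize; ++TC)
          for (unsigned TR = 0; TR < TileSize; ++TR)
            for (unsigned TK = 0; TK < TileSize; ++TK)
              Acc[TC * TileSize + TR] += At(A, Rows, I + TR, K + TK) *
                                         At(B, Inner, K + TK, J + TC);
      for (unsigned TC = 0; TC < TileSize; ++TC)
        for (unsigned TR = 0; TR < TileSize; ++TR)
          C[(J + TC) * Rows + (I + TR)] = Acc[TC * TileSize + TR];
    }
}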
1774
1775 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1776 StoreInst *Store,
1777 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1778 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1779 "Tiling only supported for column-major matrixes at the moment!");
1780 if (!isFusionProfitable(MatMul))
1781 return;
1782
1783 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1784 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1785
1786 const unsigned R = LShape.NumRows;
1787 const unsigned C = RShape.NumColumns;
1788 const unsigned M = LShape.NumColumns;
1789 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1790
1791 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1792 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1793 Value *CPtr = Store->getPointerOperand();
1794
1795 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1796 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1797 else {
1798 IRBuilder<> Builder(Store);
1799 for (unsigned J = 0; J < C; J += TileSize)
1800 for (unsigned I = 0; I < R; I += TileSize) {
1801 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1802 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1803 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1804
1805 for (unsigned K = 0; K < M; K += TileSize) {
1806 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1807 MatrixTy A =
1808 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1809 LShape, Builder.getInt64(I), Builder.getInt64(K),
1810 {TileR, TileM}, EltType, Builder);
1811 MatrixTy B =
1812 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1813 RShape, Builder.getInt64(K), Builder.getInt64(J),
1814 {TileM, TileC}, EltType, Builder);
1815 emitMatrixMultiply(Res, A, B, Builder, true, false,
1816 getFastMathFlags(MatMul));
1817 }
1818 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1819 Builder.getInt64(I), Builder.getInt64(J), EltType,
1820 Builder);
1821 }
1822 }
1823
1824 // Mark eliminated instructions as fused and remove them.
1825 FusedInsts.insert(Store);
1826 FusedInsts.insert(MatMul);
1827 Store->eraseFromParent();
1828 MatMul->eraseFromParent();
1829 if (LoadOp0->hasNUses(0)) {
1830 FusedInsts.insert(LoadOp0);
1831 LoadOp0->eraseFromParent();
1832 }
1833 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1834 FusedInsts.insert(LoadOp1);
1835 LoadOp1->eraseFromParent();
1836 }
1837 }
1838
1839 /// Try to lower matrix multiply chains by fusing operations.
1840 ///
1841 /// Call finalizeLowering on lowered instructions. Instructions that are
1842 /// completely eliminated by fusion are added to \p FusedInsts.
1843 void LowerMatrixMultiplyFused(CallInst *MatMul,
1844 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1845 if (!FuseMatrix || !DT)
1846 return;
1847
1848 assert(AA && LI && "Analyses should be available");
1849
1850 Value *A = MatMul->getArgOperand(0);
1851 Value *B = MatMul->getArgOperand(1);
1852
1853 // We can fold the transpose into the operand that is used to fetch scalars.
1854 Value *T;
1855 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1856 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1857 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1858 IRBuilder<> Builder(MatMul);
1859 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1860 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1861 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1862 const unsigned R = LShape.NumRows;
1863 const unsigned M = LShape.NumColumns;
1864 const unsigned C = RShape.NumColumns;
1865
1866 MatrixTy MA;
1867 MatrixTy MB;
1868
1869 Value *Transpose;
1870 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1871 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1872 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1873 Transpose = B;
1874 } else {
1875 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1876 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1877 Transpose = A;
1878 }
1879
1880 // Initialize the output
1881 MatrixTy Result(R, C, EltType);
1882
1883 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1884 getFastMathFlags(MatMul));
1885
1886 FusedInsts.insert(MatMul);
1887 if (Transpose->hasOneUse()) {
1888 FusedInsts.insert(cast<Instruction>(Transpose));
1889 ToRemove.push_back(cast<Instruction>(Transpose));
1890 // TODO: add a fake entry for the folded instruction so that this is
1891 // included in the expression in the remark.
1892 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1893 }
1894 finalizeLowering(MatMul, Result, Builder);
1895 return;
1896 }
1897
1898 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1899 return;
1900
1901 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1902 // since the single store user will be lowered as part of this.
1903 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1904 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1905 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1906 if (LoadOp0 && LoadOp1 && Store) {
1907 // The store address must dominate the MatMul instruction, otherwise
1908 // we create invalid IR.
1909 SetVector<Value *> WorkList;
1910 WorkList.insert(Store->getOperand(1));
1911 SmallVector<Instruction *> ToHoist;
1912 for (unsigned I = 0; I != WorkList.size(); ++I) {
1913 Value *Current = WorkList[I];
1914 auto *CurrI = dyn_cast<Instruction>(Current);
1915 if (!CurrI)
1916 continue;
1917 if (isa<PHINode>(CurrI))
1918 return;
1919 if (DT->dominates(CurrI, MatMul))
1920 continue;
1921 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1922 return;
1923 ToHoist.push_back(CurrI);
1924 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1925 }
1926
1927 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1928 return DT->dominates(A, B);
1929 });
1930 for (Instruction *I : ToHoist)
1931 I->moveBefore(MatMul);
1932
1933 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1934 return;
1935 }
1936 }
1937
1938 /// Lowers llvm.matrix.multiply.
1939 void LowerMultiply(CallInst *MatMul) {
1940 IRBuilder<> Builder(MatMul);
1941 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1942 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1943 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1944
1945 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1946 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1947 assert(Lhs.getElementType() == Rhs.getElementType() &&
1948 "Matrix multiply argument element types do not match.");
1949
1950 const unsigned R = LShape.NumRows;
1951 const unsigned C = RShape.NumColumns;
1952 assert(LShape.NumColumns == RShape.NumRows);
1953
1954 // Initialize the output
1955 MatrixTy Result(R, C, EltType);
1956 assert(Lhs.getElementType() == Result.getElementType() &&
1957 "Matrix multiply result element type does not match arguments.");
1958
1959 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1960 getFastMathFlags(MatMul));
1961 finalizeLowering(MatMul, Result, Builder);
1962 }
1963
1964 /// Lowers llvm.matrix.transpose.
1965 void LowerTranspose(CallInst *Inst) {
1966 MatrixTy Result;
1967 IRBuilder<> Builder(Inst);
1968 Value *InputVal = Inst->getArgOperand(0);
1969 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1970 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1971 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1972
1973 const unsigned NewNumVecs =
1974 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1975 const unsigned NewNumElts =
1976 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1977
1978 for (unsigned I = 0; I < NewNumVecs; ++I) {
1979 // Build a single result vector. First initialize it.
1980 Value *ResultVector = PoisonValue::get(
1981 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
1982 // Go through the old elements and insert it into the resulting vector.
1983 for (auto J : enumerate(InputMatrix.vectors())) {
1984 Value *Elt = Builder.CreateExtractElement(J.value(), I);
1985 // Row and column indices are transposed.
1986 ResultVector =
1987 Builder.CreateInsertElement(ResultVector, Elt, J.index());
1988 }
1989 Result.addVector(ResultVector);
1990 }
1991
1992 // TODO: Improve estimate of operations needed for transposes. Currently we
1993 // just count the insertelement/extractelement instructions, but do not
1994 // account for later simplifications/combines.
1995 finalizeLowering(
1996 Inst,
1997 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1998 .addNumExposedTransposes(1),
1999 Builder);
2000 }
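As a worked example of the remapping above: a column-major 2x3 matrix is stored as three column vectors of length two, and its transpose is built as two vectors of length three, where result vector I collects element I of every input vector. A minimal standalone sketch (illustrative only, not part of this pass):

#include <array>
#include <cstdio>

// Transpose a column-major Rows x Cols matrix given as Cols columns of length
// Rows. Output vector I holds element I of every input column, matching the
// extract/insert pattern emitted by LowerTranspose.
template <unsigned Rows, unsigned Cols>
std::array<std::array<double, Cols>, Rows>
transpose(const std::array<std::array<double, Rows>, Cols> &In) {
  std::array<std::array<double, Cols>, Rows> Out{};
  for (unsigned I = 0; I < Rows; ++I)   // one result vector per input row
    for (unsigned J = 0; J < Cols; ++J) // element J comes from input column J
      Out[I][J] = In[J][I];
  return Out;
}

int main() {
  // The 2x3 matrix [[1 2 3], [4 5 6]] stored as columns (1,4), (2,5), (3,6).
  std::array<std::array<double, 2>, 3> In = {{{1, 4}, {2, 5}, {3, 6}}};
  auto Out = transpose<2, 3>(In);
  for (const auto &Vec : Out) {         // prints "1 2 3" and "4 5 6"
    for (double V : Vec)
      std::printf("%g ", V);
    std::printf("\n");
  }
  return 0;
}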
2001
2002 /// Lower load instructions, if shape information is available.
2003 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2004 auto I = ShapeMap.find(Inst);
2005 if (I == ShapeMap.end())
2006 return false;
2007
2008 LowerLoad(Inst, Ptr, Inst->getAlign(),
2009 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2010 I->second);
2011 return true;
2012 }
2013
2014 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
2015 IRBuilder<> &Builder) {
2016 auto I = ShapeMap.find(StoredVal);
2017 if (I == ShapeMap.end())
2018 return false;
2019
2020 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2021 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2022 I->second);
2023 return true;
2024 }
2025
2026 /// Lower binary operators, if shape information is available.
2027 bool VisitBinaryOperator(BinaryOperator *Inst) {
2028 auto I = ShapeMap.find(Inst);
2029 if (I == ShapeMap.end())
2030 return false;
2031
2032 Value *Lhs = Inst->getOperand(0);
2033 Value *Rhs = Inst->getOperand(1);
2034
2035 IRBuilder<> Builder(Inst);
2036 ShapeInfo &Shape = I->second;
2037
2038 MatrixTy Result;
2039 MatrixTy A = getMatrix(Lhs, Shape, Builder);
2040 MatrixTy B = getMatrix(Rhs, Shape, Builder);
2041 assert(A.isColumnMajor() == B.isColumnMajor() &&
2042 Result.isColumnMajor() == A.isColumnMajor() &&
2043 "operands must agree on matrix layout");
2044
2045 Builder.setFastMathFlags(getFastMathFlags(Inst));
2046
2047 // Helper to perform binary op on vectors.
2048 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2049 switch (Inst->getOpcode()) {
2050 case Instruction::Add:
2051 return Builder.CreateAdd(LHS, RHS);
2052 case Instruction::Mul:
2053 return Builder.CreateMul(LHS, RHS);
2054 case Instruction::Sub:
2055 return Builder.CreateSub(LHS, RHS);
2056 case Instruction::FAdd:
2057 return Builder.CreateFAdd(LHS, RHS);
2058 case Instruction::FMul:
2059 return Builder.CreateFMul(LHS, RHS);
2060 case Instruction::FSub:
2061 return Builder.CreateFSub(LHS, RHS);
2062 default:
2063 llvm_unreachable("Unsupported binary operator for matrix");
2064 }
2065 };
2066
2067 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2068 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2069
2070 finalizeLowering(Inst,
2071 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2072 Result.getNumVectors()),
2073 Builder);
2074 return true;
2075 }
2076
2077 /// Lower unary operators, if shape information is available.
2078 bool VisitUnaryOperator(UnaryOperator *Inst) {
2079 auto I = ShapeMap.find(Inst);
2080 if (I == ShapeMap.end())
2081 return false;
2082
2083 Value *Op = Inst->getOperand(0);
2084
2085 IRBuilder<> Builder(Inst);
2086 ShapeInfo &Shape = I->second;
2087
2088 MatrixTy Result;
2089 MatrixTy M = getMatrix(Op, Shape, Builder);
2090
2091 Builder.setFastMathFlags(getFastMathFlags(Inst));
2092
2093 // Helper to perform unary op on vectors.
2094 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2095 switch (Inst->getOpcode()) {
2096 case Instruction::FNeg:
2097 return Builder.CreateFNeg(Op);
2098 default:
2099 llvm_unreachable("Unsupported unary operator for matrix");
2100 }
2101 };
2102
2103 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2104 Result.addVector(BuildVectorOp(M.getVector(I)));
2105
2106 finalizeLowering(Inst,
2107 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2108 Result.getNumVectors()),
2109 Builder);
2110 return true;
2111 }
2112
2113 /// Helper to linearize a matrix expression tree into a string. Currently
2114 /// matrix expressions are linearized by starting at an expression leaf and
2115 /// linearizing bottom up.
2116 struct ExprLinearizer {
2117 unsigned LengthToBreak = 100;
2118 std::string Str;
2119 raw_string_ostream Stream;
2120 unsigned LineLength = 0;
2121 const DataLayout &DL;
2122
2123 /// Mapping from instructions to matrixes. It is used to identify
2124 /// matrix instructions.
2125 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2126
2127 /// Mapping from values to the leaves of all expressions that the value is
2128 /// part of.
2129 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2130
2131 /// Set of matrix expressions in the scope of a given DISubprogram.
2132 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2133
2134 /// Leaf node of the expression to linearize.
2135 Value *Leaf;
2136
2137 /// Used to keep track of sub-expressions that get reused while linearizing
2138 /// the expression. Re-used sub-expressions are marked as (reused).
2139 SmallPtrSet<Value *, 8> ReusedExprs;
2140
2141 ExprLinearizer(const DataLayout &DL,
2142 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2143 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2144 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2145 Value *Leaf)
2146 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2147 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2148
2149 void indent(unsigned N) {
2150 LineLength += N;
2151 for (unsigned i = 0; i < N; i++)
2152 Stream << " ";
2153 }
2154
2155 void lineBreak() {
2156 Stream << "\n";
2157 LineLength = 0;
2158 }
2159
2160 void maybeIndent(unsigned Indent) {
2161 if (LineLength >= LengthToBreak)
2162 lineBreak();
2163
2164 if (LineLength == 0)
2165 indent(Indent);
2166 }
2167
2168 void write(StringRef S) {
2169 LineLength += S.size();
2170 Stream << S;
2171 }
2172
2173 Value *getUnderlyingObjectThroughLoads(Value *V) {
2174 if (Value *Ptr = getPointerOperand(V))
2175 return getUnderlyingObjectThroughLoads(Ptr);
2176 else if (V->getType()->isPointerTy())
2177 return getUnderlyingObject(V);
2178 return V;
2179 }
2180
2181 /// Returns true if \p V is a matrix value in the given subprogram.
2182 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2183
2184 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2185 /// \p SS.
2186 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2187 auto M = Inst2Matrix.find(V);
2188 if (M == Inst2Matrix.end())
2189 SS << "unknown";
2190 else {
2191 SS << M->second.getNumRows();
2192 SS << "x";
2193 SS << M->second.getNumColumns();
2194 }
2195 }
2196
2197 /// Write the called function name. Handles calls to llvm.matrix.*
2198 /// specially: we write the name, followed by the dimensions of the input
2199 /// matrixes, followed by the scalar type name.
2200 void writeFnName(CallInst *CI) {
2201 if (!CI->getCalledFunction())
2202 write("<no called fn>");
2203 else {
2204 StringRef Name = CI->getCalledFunction()->getName();
2205 if (!Name.startswith("llvm.matrix")) {
2206 write(Name);
2207 return;
2208 }
2209 auto *II = cast<IntrinsicInst>(CI);
2210 write(Intrinsic::getBaseName(II->getIntrinsicID())
2211 .drop_front(StringRef("llvm.matrix.").size()));
2212 write(".");
2213 std::string Tmp;
2214 raw_string_ostream SS(Tmp);
2215
2216 switch (II->getIntrinsicID()) {
2217 case Intrinsic::matrix_multiply:
2218 prettyPrintMatrixType(II->getOperand(0), SS);
2219 SS << ".";
2220 prettyPrintMatrixType(II->getOperand(1), SS);
2221 SS << "." << *II->getType()->getScalarType();
2222 break;
2223 case Intrinsic::matrix_transpose:
2224 prettyPrintMatrixType(II->getOperand(0), SS);
2225 SS << "." << *II->getType()->getScalarType();
2226 break;
2227 case Intrinsic::matrix_column_major_load:
2228 prettyPrintMatrixType(II, SS);
2229 SS << "." << *II->getType()->getScalarType();
2230 break;
2231 case Intrinsic::matrix_column_major_store:
2232 prettyPrintMatrixType(II->getOperand(0), SS);
2233 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2234 break;
2235 default:
2236 llvm_unreachable("Unhandled case");
2237 }
2238 SS.flush();
2239 write(Tmp);
2240 }
2241 }
2242
2243 unsigned getNumShapeArgs(CallInst *CI) const {
2244 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2245 switch (II->getIntrinsicID()) {
2246 case Intrinsic::matrix_multiply:
2247 return 3;
2248 case Intrinsic::matrix_transpose:
2249 return 2;
2250 case Intrinsic::matrix_column_major_load:
2251 case Intrinsic::matrix_column_major_store:
2252 return 3;
2253 default:
2254 return 0;
2255 }
2256 }
2257 return 0;
2258 }
2259
2260 /// Special printing for values: for pointers, we print whether they refer to
2261 /// a (function) external address or a stack address; for other values we
2262 /// either print the constant or "scalar"/"matrix".
2263 void write(Value *V) {
2264 V = getUnderlyingObjectThroughLoads(V);
2265 if (V->getType()->isPointerTy()) {
2266 if (isa<AllocaInst>(V)) {
2267 Stream << "stack addr";
2268 LineLength += StringRef("stack addr").size();
2269 } else {
2270 Stream << "addr";
2271 LineLength += StringRef("addr").size();
2272 }
2273 if (!V->getName().empty()) {
2274 Stream << " %" << V->getName() << "";
2275 LineLength += V->getName().size() + 2;
2276 }
2277 return;
2278 }
2279
2280 std::string Tmp;
2281 raw_string_ostream TmpStream(Tmp);
2282
2283 if (auto *CI = dyn_cast<ConstantInt>(V))
2284 TmpStream << CI->getValue();
2285 else if (isa<Constant>(V))
2286 TmpStream << "constant";
2287 else {
2288 if (isMatrix(V))
2289 TmpStream << "matrix";
2290 else
2291 TmpStream << "scalar";
2292 }
2293 TmpStream.flush();
2294 Tmp = std::string(StringRef(Tmp).trim());
2295 LineLength += Tmp.size();
2296 Stream << Tmp;
2297 }
2298
2299 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2300 /// Expressions that are re-used multiple times are prefixed with (reused)
2301 /// at the re-used root instruction.
2302 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2303 bool ParentShared) {
2304 auto *I = cast<Instruction>(Expr);
2305 maybeIndent(Indent);
2306 SmallVector<Value *, 8> Ops;
2307
2308 // Is Expr shared with other expression leaves?
2309 bool ExprShared = false;
2310
2311 // Deal with shared subtrees. Mark them as shared, if required.
2312 if (!ParentShared) {
2313 auto SI = Shared.find(Expr);
2314 assert(SI != Shared.end() && SI->second.count(Leaf));
2315
2316 for (Value *S : SI->second) {
2317 if (S == Leaf)
2318 continue;
2319 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2320 write("shared with remark at line " + std::to_string(DL.getLine()) +
2321 " column " + std::to_string(DL.getCol()) + " (");
2322 }
2323 ExprShared = SI->second.size() > 1;
2324 }
2325
2326 bool Reused = !ReusedExprs.insert(Expr).second;
2327 if (Reused && !ParentReused)
2328 write("(reused) ");
2329
2330 if (auto *CI = dyn_cast<CallInst>(I)) {
2331 writeFnName(CI);
2332
2333 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2334 } else if (isa<BitCastInst>(Expr)) {
2335 // Special case bitcasts, which are used to materialize matrixes from
2336 // non-matrix ops.
2337 write("matrix");
2338 return;
2339 } else {
2340 Ops.append(I->value_op_begin(), I->value_op_end());
2341 write(std::string(I->getOpcodeName()));
2342 }
2343
2344 write(std::string("("));
2345
2346 unsigned NumOpsToBreak = 1;
2347 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2348 NumOpsToBreak = 2;
2349
2350 for (Value *Op : Ops) {
2351 if (Ops.size() > NumOpsToBreak)
2352 lineBreak();
2353
2354 maybeIndent(Indent + 1);
2355 if (isMatrix(Op))
2356 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2357 else
2358 write(Op);
2359 if (Op != Ops.back())
2360 write(", ");
2361 }
2362
2363 write(")");
2364 }
2365
2366 const std::string &getResult() {
2367 Stream.flush();
2368 return Str;
2369 }
2370 };
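To make the output shape concrete: hand-tracing the writer above for a leaf store whose value is a 4x2 * 2x4 double multiply of two loaded operands gives roughly the following linearized string (derived by hand from the code above, not captured pass output; the operand names are invented):

  store(
   multiply.4x2.2x4.double(
    load(addr %A),
    load(addr %B)),
   addr %C)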
2371
2372 /// Generate remarks for matrix operations in a function. To generate remarks
2373 /// for matrix expressions, the following approach is used:
2374 /// 1. Use the inlined-at debug information to group matrix operations to the
2375 /// DISubprograms they are contained in.
2376 /// 2. Collect leaves of matrix expressions (done in
2377 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2378 /// mapping. Leaves are lowered matrix instructions without other matrix
2379 /// users (like stores) in the current subprogram.
2380 /// 3. For each leaf, create a remark containing a linearizied version of the
2381 /// matrix expression. The expression is linearized by a recursive
2382 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2383 /// that multiple leaves can share sub-expressions. Shared subexpressions
2384 /// are explicitly marked as shared().
2385 struct RemarkGenerator {
2386 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2387 OptimizationRemarkEmitter &ORE;
2388 Function &Func;
2389 const DataLayout &DL;
2390
2391 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2392 OptimizationRemarkEmitter &ORE, Function &Func)
2393 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2394 DL(Func.getParent()->getDataLayout()) {}
2395
2396 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2397 /// instructions in Inst2Matrix returning void or without any users in
2398 /// \p ExprsInSubprogram. Currently that should only include stores.
2399 SmallVector<Value *, 4>
2400 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2401 SmallVector<Value *, 4> Leaves;
2402 for (auto *Expr : ExprsInSubprogram)
2403 if (Expr->getType()->isVoidTy() ||
2404 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2405 return ExprsInSubprogram.count(U);
2406 }))
2407 Leaves.push_back(Expr);
2408 return Leaves;
2409 }
2410
2411 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2412 /// to all visited expressions in \p Shared. Limit the matrix operations to
2413 /// the ones in \p ExprsInSubprogram.
2414 void collectSharedInfo(Value *Leaf, Value *V,
2415 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2416 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2417
2418 if (!ExprsInSubprogram.count(V))
2419 return;
2420
2421 auto I = Shared.insert({V, {}});
2422 I.first->second.insert(Leaf);
2423
2424 for (Value *Op : cast<Instruction>(V)->operand_values())
2425 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2426 }
2427
2428 /// Calculate the number of exclusive and shared op counts for expression
2429 /// starting at \p V. Expressions used multiple times are counted once.
2430 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2431 std::pair<OpInfoTy, OpInfoTy>
2432 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2433 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2434 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2435 if (!ExprsInSubprogram.count(Root))
2436 return {};
2437
2438 // Already counted this expression. Stop.
2439 if (!ReusedExprs.insert(Root).second)
2440 return {};
2441
2442 OpInfoTy SharedCount;
2443 OpInfoTy Count;
2444
2445 auto I = Shared.find(Root);
2446 auto CM = Inst2Matrix.find(Root);
2447 if (I->second.size() == 1)
2448 Count = CM->second.getOpInfo();
2449 else
2450 SharedCount = CM->second.getOpInfo();
2451
2452 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2453 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2454 Count += C.first;
2455 SharedCount += C.second;
2456 }
2457 return {Count, SharedCount};
2458 }
2459
2460 void emitRemarks() {
2461 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2462 return;
2463
2464 // Map matrix operations to their containing subprograms, by traversing
2465 // the inlinedAt chain. If the function does not have a DISubprogram, we
2466 // only map them to the containing function.
2467 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2468 for (const auto &KV : Inst2Matrix) {
2469 if (Func.getSubprogram()) {
2470 auto *I = cast<Instruction>(KV.first);
2471 DILocation *Context = I->getDebugLoc();
2472 while (Context) {
2473 auto I =
2474 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2475 I.first->second.push_back(KV.first);
2476 Context = DebugLoc(Context).getInlinedAt();
2477 }
2478 } else {
2479 auto I = Subprog2Exprs.insert({nullptr, {}});
2480 I.first->second.push_back(KV.first);
2481 }
2482 }
2483 for (auto &KV : Subprog2Exprs) {
2484 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2485 KV.second.end());
2486 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2487
2488 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2489 for (Value *Leaf : Leaves)
2490 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2491
2492 // Generate remarks for each leaf.
2493 for (auto *L : Leaves) {
2494
2495 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2496 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2497 while (Context) {
2498 if (getSubprogram(Context->getScope()) == KV.first) {
2499 Loc = Context;
2500 break;
2501 }
2502 Context = DebugLoc(Context).getInlinedAt();
2503 }
2504
2505 SmallPtrSet<Value *, 8> ReusedExprs;
2506 OpInfoTy Counts, SharedCounts;
2507 std::tie(Counts, SharedCounts) =
2508 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2509
2510 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2511 cast<Instruction>(L)->getParent());
2512
2513 Rem << "Lowered with ";
2514 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2515 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2516 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2517 << " compute ops, "
2518 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2519 << " exposed transposes";
2520
2521 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2522 SharedCounts.NumComputeOps > 0) {
2523 Rem << ",\nadditionally "
2524 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2525 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2526 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2527 << " compute ops"
2528 << " are shared with other expressions";
2529 }
2530
2531 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2532 ORE.emit(Rem);
2533 }
2534 }
2535 }
2536
2537 std::string
2538 linearize(Value *L,
2539 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2540 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2541 const DataLayout &DL) {
2542 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2543 Lin.linearizeExpr(L, 0, false, false);
2544 return Lin.getResult();
2545 }
2546 };
2547};
2548} // namespace
2549
2550PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2551 FunctionAnalysisManager &AM) {
2552 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2553 OptimizationRemarkEmitter *ORE = nullptr;
2554 AAResults *AA = nullptr;
2555 DominatorTree *DT = nullptr;
2556 LoopInfo *LI = nullptr;
2557
2558 if (!Minimal) {
2559 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2560 AA = &AM.getResult<AAManager>(F);
2561 DT = &AM.getResult<DominatorTreeAnalysis>(F);
2562 LI = &AM.getResult<LoopAnalysis>(F);
2563 }
2564
2565 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2566 if (LMT.Visit()) {
2567 PreservedAnalyses PA;
2568 if (!Minimal) {
2569 PA.preserve<LoopAnalysis>();
2570 PA.preserve<DominatorTreeAnalysis>();
2571 }
2572 return PA;
2573 }
2574 return PreservedAnalyses::all();
2575}
2576
2577void LowerMatrixIntrinsicsPass::printPipeline(
2578 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2579 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2580 OS, MapClassName2PassName);
2581 OS << '<';
2582 if (Minimal)
2583 OS << "minimal";
2584 OS << '>';
2585}
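For reference, this lowering is exposed to the new pass manager under the name lower-matrix-intrinsics, and printPipeline above renders the Minimal configuration with a <minimal> parameter. A hypothetical invocation via opt, assuming the usual -passes pipeline syntax, would look like:

opt -passes='lower-matrix-intrinsics' -S in.ll -o out.ll
opt -passes='lower-matrix-intrinsics<minimal>' -S in.ll -o out.ll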
2586
2587namespace {
2588
2589class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
2590public:
2591 static char ID;
2592
2593 LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
2594 initializeLowerMatrixIntrinsicsLegacyPassPass(
2595 *PassRegistry::getPassRegistry());
2596 }
2597
2598 bool runOnFunction(Function &F) override {
2599 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2600 auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
2601 auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
2602 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2603 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
2604 LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
2605 bool C = LMT.Visit();
2606 return C;
2607 }
2608
2609 void getAnalysisUsage(AnalysisUsage &AU) const override {
2610 AU.addRequired<TargetTransformInfoWrapperPass>();
2611 AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
2612 AU.addRequired<AAResultsWrapperPass>();
2613 AU.addRequired<DominatorTreeWrapperPass>();
2614 AU.addPreserved<DominatorTreeWrapperPass>();
2615 AU.addRequired<LoopInfoWrapperPass>();
2616 AU.addPreserved<LoopInfoWrapperPass>();
2617 }
2618};
2619} // namespace
2620
2621static const char pass_name[] = "Lower the matrix intrinsics";
2622char LowerMatrixIntrinsicsLegacyPass::ID = 0;
2623INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2624 false, false)
2625INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
2626INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
2627INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
2628INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
2629INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
2630 false, false)
2631
2632Pass *llvm::createLowerMatrixIntrinsicsPass() {
2633 return new LowerMatrixIntrinsicsLegacyPass();
2634}
2635
2636namespace {
2637
2638/// A lightweight version of the matrix lowering pass that only requires TTI.
2639/// Advanced features that require DT, AA or ORE like tiling are disabled. This
2640/// is used to lower matrix intrinsics if the main lowering pass is not run, for
2641/// example with -O0.
2642class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
2643public:
2644 static char ID;
2645
2646 LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
2647 initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
2648 *PassRegistry::getPassRegistry());
2649 }
2650
2651 bool runOnFunction(Function &F) override {
2652 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
2653 LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
2654 bool C = LMT.Visit();
2655 return C;
2656 }
2657
2658 void getAnalysisUsage(AnalysisUsage &AU) const override {
2659 AU.addRequired<TargetTransformInfoWrapperPass>();
2660 AU.setPreservesCFG();
2661 }
2662};
2663} // namespace
2664
2665static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
2666char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
2667INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
2668 "lower-matrix-intrinsics-minimal", pass_name_minimal,
2669 false, false)
2670INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
2671 "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
2672 false)
2673
2674Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
2675 return new LowerMatrixIntrinsicsMinimalLegacyPass();
2676}
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
amdgpu Simplify well known AMD library false FunctionCallee Value * Arg
Rewrite undef for PHI
ReachingDefAnalysis InstSet & ToRemove
assume Assume Builder
static const Function * getParent(const Value *V)
BitTracker BT
Definition: BitTracker.cpp:73
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:678
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
std::string Name
bool End
Definition: ELF_riscv.cpp:464
static bool runOnFunction(Function &F, bool PostInlining)
expand Expand reduction intrinsics
#define DEBUG_TYPE
Hexagon Common GEP
hexagon Hexagon specific predictive commoning for HVX vectors
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iv users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:522
Live Register Matrix
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, BasicBlock &BB)
Erase V from BB and move \II forward to avoid invalidating iterators.
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
lower matrix intrinsics minimal
static const char pass_name[]
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
static const char pass_name_minimal[]
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
#define DEBUG_TYPE
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define T1
LLVMContext & Context
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
@ SI
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2366
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2388
raw_pwrite_stream & OS
This file defines the SmallVector class.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
static const int BlockSize
Definition: TarWriter.cpp:33
This pass exposes codegen information to IR-level passes.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Value * RHS
Value * LHS
BinaryOperator * Mul
A manager for alias analyses.
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
an instruction to allocate memory on the stack
Definition: Instructions.h:58
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:125
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:620
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:774
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:265
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:56
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:326
reverse_iterator rbegin()
Definition: BasicBlock.h:331
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:89
reverse_iterator rend()
Definition: BasicBlock.h:333
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:127
BinaryOps getOpcode() const
Definition: InstrTypes.h:391
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1412
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1332
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1763
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1357
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1338
This class represents a function call, abstracting a target machine's calling convention.
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1579
static Constant * get(Type *Ty, double V)
This returns a ConstantFP, or a vector containing a splat of a ConstantFP, for the specified value in...
Definition: Constants.cpp:927
This is the shared class of boolean and integer constants.
Definition: Constants.h:78
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:888
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Debug location.
Base class for scope-like contexts.
Subprogram description.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
A debug info location.
Definition: DebugLoc.h:33
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:39
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
static constexpr UpdateKind Delete
static constexpr UpdateKind Insert
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:314
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:166
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:20
void setAllowContract(bool B=true)
Definition: FMF.h:91
bool allowReassoc() const
Flag queries.
Definition: FMF.h:65
bool allowContract() const
Definition: FMF.h:70
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:536
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:754
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:204
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:209
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2570
const BasicBlock * getParent() const
Definition: Instruction.h:90
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:82
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:47
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:54
An instruction for reading from memory.
Definition: Instructions.h:177
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:214
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:220
uint64_t getValue() const
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:569
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:594
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:37
iterator end()
Definition: MapVector.h:72
iterator find(const KeyT &Key)
Definition: MapVector.h:147
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:118
Representation for a specific memory location.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
const Value * Ptr
The address of the start of the location.
OptimizationRemarkEmitter legacy analysis pass.
The optimization diagnostic interface.
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Diagnostic information for applied optimization remarks.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:94
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1743
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:173
A vector that has set insertion semantics.
Definition: SetVector.h:51
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:88
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:219
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:152
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:344
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:383
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:389
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:312
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
bool empty() const
Definition: SmallSet.h:159
bool erase(const T &V)
Definition: SmallSet.h:207
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:577
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:687
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
An instruction for storing to memory.
Definition: Instructions.h:301
Align getAlign() const
Definition: Instructions.h:345
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:337
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:613
constexpr size_t size() const
size - Get the string size.
Definition: StringRef.h:137
Analysis pass providing the TargetTransformInfo.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
TypeSize getRegisterBitWidth(RegisterKind K) const
InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
@ TCK_RecipThroughput
Reciprocal throughput.
InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args=ArrayRef< const Value * >(), const Instruction *CxtI=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *Tp, ArrayRef< int > Mask=std::nullopt, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args=std::nullopt) const
unsigned getNumberOfRegisters(unsigned ClassID) const
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:229
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:140
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:348
UnaryOps getOpcode() const
Definition: InstrTypes.h:170
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1724
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:242
Value * getOperand(unsigned i) const
Definition: User.h:169
See the file comment.
Definition: ValueMap.h:84
size_type count(const KeyT &Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: ValueMap.h:151
iterator find(const KeyT &Val)
Definition: ValueMap.h:155
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: ValueMap.h:172
iterator end()
Definition: ValueMap.h:135
bool erase(const KeyT &Val)
Definition: ValueMap.h:190
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
user_iterator user_begin()
Definition: Value.h:397
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:535
iterator_range< user_iterator > users()
Definition: Value.h:421
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:149
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
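A minimal sketch combining the Value queries above; the helper and its replacement policy are illustrative assumptions:

  #include "llvm/IR/Instruction.h"
  #include "llvm/IR/Value.h"

  using namespace llvm;

  // Replace every use of Old with New, provided the types match. hasNUses(0)
  // lets the caller skip values that are already dead.
  static void replaceAndForget(Instruction *Old, Value *New) {
    if (Old->hasNUses(0) || Old->getType() != New->getType())
      return;
    Old->replaceAllUsesWith(New);
  }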
Type * getElementType() const
Definition: DerivedTypes.h:433
constexpr ScalarTy getFixedValue() const
Definition: TypeSize.h:182
An efficient, type-erasing, non-owning reference to a callable.
A range adaptor for a pair of iterators.
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:642
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:119
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Function.cpp:987
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1465
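A minimal sketch of materializing an overloaded intrinsic declaration with Intrinsic::getDeclaration; llvm.fmuladd and the wrapper name are chosen purely for illustration:

  #include "llvm/IR/Function.h"
  #include "llvm/IR/Intrinsics.h"
  #include "llvm/IR/Module.h"

  using namespace llvm;

  // llvm.fmuladd is overloaded on a single type, so one entry in Tys suffices;
  // getBaseName() would return "llvm.fmuladd" without the type suffix.
  static Function *getFMulAddDecl(Module *M, Type *Ty) {
    return Intrinsic::getDeclaration(M, Intrinsic::fmuladd, {Ty});
  }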
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
specific_intval< false > m_SpecificInt(APInt V)
Match a specific integer value or vector with all elements equal to the value.
Definition: PatternMatch.h:854
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
Definition: PatternMatch.h:979
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:84
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:147
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
Definition: PatternMatch.h:985
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
OneUse_match< T > m_OneUse(const T &SubPattern)
Definition: PatternMatch.h:67
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:218
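A minimal sketch of composing these matchers; the pattern (an integer add fed by a single-use multiply) and the bound variables are illustrative:

  #include "llvm/IR/PatternMatch.h"
  #include "llvm/IR/Value.h"

  using namespace llvm;
  using namespace llvm::PatternMatch;

  // Matches V == add(mul(A, B), C) where the multiply has exactly one use,
  // binding the three operands on success.
  static bool matchIntMulAdd(Value *V, Value *&A, Value *&B, Value *&C) {
    return match(V, m_Add(m_OneUse(m_Mul(m_Value(A), m_Value(B))), m_Value(C)));
  }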
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition: CommandLine.h:703
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:445
DiagnosticInfoOptimizationBase::Argument NV
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:440
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1777
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A,...
Definition: STLExtras.h:2430
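A minimal sketch of enumerate() over a single range; the printing helper is illustrative:

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/STLExtras.h"
  #include "llvm/IR/Value.h"
  #include "llvm/Support/raw_ostream.h"

  using namespace llvm;

  // Pairs each element with its index; with one input range the values behave
  // like (index, element) tuples.
  static void dumpNames(ArrayRef<Value *> Vals) {
    for (auto [Idx, V] : enumerate(Vals))
      errs() << Idx << ": " << V->getName() << "\n";
  }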
auto successors(const MachineBasicBlock *BB)
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2052
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
std::string & operator+=(std::string &buffer, StringRef string)
Definition: StringRef.h:907
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, bool ContinueOnCuIndexOverflow)
Definition: DWP.cpp:585
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments and pointer casts from the specified value,...
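A minimal sketch using getUnderlyingObject to ask whether two pointers demonstrably share a base object; this is a much weaker check than real alias analysis, and the helper name is an assumption:

  #include "llvm/Analysis/ValueTracking.h"
  #include "llvm/IR/Value.h"

  using namespace llvm;

  // True only when both pointers trace back to the same base object after
  // stripping GEPs and casts; a false result proves nothing about aliasing.
  static bool sameBaseObject(const Value *A, const Value *B) {
    return getUnderlyingObject(A) == getUnderlyingObject(B);
  }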
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:748
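A minimal sketch of the early-increment idiom, deleting dead instructions while walking a block; the cleanup helper itself is illustrative:

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/Instruction.h"
  #include "llvm/Transforms/Utils/Local.h"

  using namespace llvm;

  // Erasing the current instruction is safe because the iterator is advanced
  // before the loop body runs.
  static void dropTriviallyDead(BasicBlock &BB) {
    for (Instruction &I : make_early_inc_range(BB))
      if (isInstructionTriviallyDead(&I))
        I.eraseFromParent();
  }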
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
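A minimal sketch of concatenateVectors gluing two column vectors back into one flat vector value; the wrapper name is an assumption:

  #include "llvm/Analysis/VectorUtils.h"
  #include "llvm/IR/IRBuilder.h"

  using namespace llvm;

  // The builder emits the shufflevector instructions needed to join the
  // (equally typed) inputs into one wider vector.
  static Value *joinColumns(IRBuilderBase &Builder, Value *Col0, Value *Col1) {
    return concatenateVectors(Builder, {Col0, Col1});
  }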
void initializeLowerMatrixIntrinsicsLegacyPassPass(PassRegistry &)
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set the input string in the loop metadata while keeping the other values intact.
Definition: LoopUtils.cpp:214
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1826
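A minimal sketch of the range-based any_of wrapper; the predicate is illustrative:

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/IR/Instructions.h"

  using namespace llvm;

  // No explicit begin()/end(): the users() range is passed directly.
  static bool hasStoreUser(const Value *V) {
    return any_of(V->users(), [](const User *U) { return isa<StoreInst>(U); });
  }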
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:511
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1744
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:145
Pass * createLowerMatrixIntrinsicsPass()
void initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(PassRegistry &)
raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
@ FMulAdd
Fused multiply-add of floats (a * b + c).
@ Add
Sum of integers.
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition: Casting.h:565
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:212
BasicBlock * SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
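A minimal sketch of SplitBlock; the wrapper, the new block name, and leaving MemorySSA unset are illustrative choices:

  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/Instruction.h"
  #include "llvm/Transforms/Utils/BasicBlockUtils.h"

  using namespace llvm;

  // Split I's block right before I; everything from I onwards moves to the
  // returned block, and DT/LI are kept up to date when provided.
  static BasicBlock *splitBefore(Instruction *I, DominatorTree *DT,
                                 LoopInfo *LI) {
    return SplitBlock(I->getParent(), I, DT, LI, /*MSSAU=*/nullptr,
                      "split.cont");
  }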
Pass * createLowerMatrixIntrinsicsMinimalPass()
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
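A minimal sketch of createSequentialMask; the concrete indices are illustrative:

  #include "llvm/ADT/SmallVector.h"
  #include "llvm/Analysis/VectorUtils.h"

  using namespace llvm;

  // Builds the mask {2, 3, 4, 5}: four consecutive indices starting at 2 and
  // no trailing undef lanes. Suitable for IRBuilder::CreateShuffleVector.
  static SmallVector<int, 16> extractFourStartingAtTwo() {
    return createSequentialMask(/*Start=*/2, /*NumInts=*/4, /*NumUndefs=*/0);
  }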
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition: PassManager.h:371
A helper struct to create IR loop nests for tiling in IR of the following form: for ColumnLoop....
Definition: MatrixUtils.h:31