1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns
16// columns for tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
19
22#include "llvm/ADT/SmallSet.h"
31#include "llvm/IR/CFG.h"
32#include "llvm/IR/DataLayout.h"
34#include "llvm/IR/Function.h"
35#include "llvm/IR/IRBuilder.h"
42#include "llvm/Support/Debug.h"
46
47#include <cmath>
48
49using namespace llvm;
50using namespace PatternMatch;
51
52#define DEBUG_TYPE "lower-matrix-intrinsics"
53
54static cl::opt<bool>
55 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
56 cl::desc("Enable/disable fusing matrix instructions."));
57// TODO: Allow and use non-square tiles.
59 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
61 "Tile size for matrix instruction fusion using square-shaped tiles."));
62static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
64 cl::desc("Generate loop nest for tiling."));
66 "force-fuse-matrix", cl::init(false), cl::Hidden,
67 cl::desc("Force matrix instruction fusion even if not profitable."));
69 "matrix-allow-contract", cl::init(false), cl::Hidden,
70 cl::desc("Allow the use of FMAs if available and profitable. This may "
71 "result in different results, due to less rounding error."));
72
73static cl::opt<bool>
74 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
75 cl::desc("Enable/disable matrix shape verification."),
76 cl::init(false));
77
78enum class MatrixLayoutTy { ColumnMajor, RowMajor };
79
80static cl::opt<MatrixLayoutTy> MatrixLayout(
81 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
82 cl::desc("Sets the default matrix layout"),
84 "Use column-major layout"),
86 "Use row-major layout")));
87
88static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
89 cl::init(false));
90
91/// Helper function to either return \p Scope, if it is a subprogram, or the
92/// attached subprogram for a local scope.
93static DISubprogram *getSubprogram(DIScope *Scope) {
94 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
95 return Subprogram;
96 return cast<DILocalScope>(Scope)->getSubprogram();
97}
98
99/// Erase \p V from \p BB and move \p II forward to avoid invalidating
100/// iterators.
101static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
102 BasicBlock &BB) {
103 auto *Inst = cast<Instruction>(V);
104 // Still used, don't erase.
105 if (!Inst->use_empty())
106 return;
107 if (II != BB.rend() && Inst == &*II)
108 ++II;
109 Inst->eraseFromParent();
110}
111
112/// Return true if V is a splat of a value (which is used when multiplying a
113/// matrix with a scalar).
114static bool isSplat(Value *V) {
115 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
116 return SV->isZeroEltSplat();
117 return false;
118}
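// For illustration (simplified; names are example placeholders): a scalar %k
// broadcast with IRBuilder::CreateVectorSplat typically appears as
//
//   %ins   = insertelement <4 x float> poison, float %k, i64 0
//   %splat = shufflevector <4 x float> %ins, <4 x float> poison,
//                          <4 x i32> zeroinitializer
//
// and is recognized here via ShuffleVectorInst::isZeroEltSplat().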
119
120/// Match any mul operation (fp or integer).
121template <typename LTy, typename RTy>
122auto m_AnyMul(const LTy &L, const RTy &R) {
123 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
124}
125
126/// Match any add operation (fp or integer).
127template <typename LTy, typename RTy>
128auto m_AnyAdd(const LTy &L, const RTy &R) {
129 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
130}
131
132namespace {
133
134// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
135// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
136// assuming \p Stride elements between the starts of two consecutive vectors.
137// \p Stride must be >= \p NumElements.
138// For column-major matrixes, the function computes the address of a column
139// vector and \p NumElements must be set to the number of elements in a column
140// (= number of rows of the matrix). For row-major matrixes, the function
141// computes the address of a row vector and \p NumElements must be set to the
142// number of elements in a row (= number of columns of the matrix).
143//
144// Consider a 4x4 matrix in column-major layout like below
145//
146// 0 1 2 3
147// 0 v_0_0 v_0_1 v_0_2 v_0_3
148// 1 v_1_0 v_1_1 v_1_2 v_1_3
149// 2 v_2_0 v_2_1 v_2_2 v_2_3
150// 3 v_3_0 v_3_1 v_3_2 v_3_3
151
152// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
153// we need a pointer to the first element of the submatrix as base pointer.
154// Then we can use computeVectorAddr to compute the addresses for the columns
155// of the sub-matrix.
156//
157// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
158// -> just returns Base
159// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
160// -> returns Base + (1 * 4)
161// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
162// -> returns Base + (2 * 4)
163//
164// The graphic below illustrates the number of elements in a column (marked
165// with |) and the number of skipped elements (marked with {).
166//
167//       v_0_0  v_0_1 {v_0_2 {v_0_3
168//              Base   Col 1  Col 2
169//              |      |      |
170//       v_1_0 |v_1_1 |v_1_2 |v_1_3
171//       v_2_0 |v_2_1 |v_2_2 |v_2_3
172//       v_3_0 {v_3_1 {v_3_2  v_3_3
173//
174Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
175 unsigned NumElements, Type *EltType,
176 IRBuilder<> &Builder) {
177
178 assert((!isa<ConstantInt>(Stride) ||
179 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
180 "Stride must be >= the number of elements in the result vector.");
181
182 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
183 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
184
185 // Get pointer to the start of the selected vector. Skip GEP creation,
186 // if we select vector 0.
187 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
188 VecStart = BasePtr;
189 else
190 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
191
192 return VecStart;
193}
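// For illustration (simplified; %p and %j are example names): for a double
// matrix based at %p, a runtime vector index %j and a stride of 4 elements,
// the code above emits roughly
//
//   %vec.start = mul i64 %j, 4
//   %vec.gep   = getelementptr double, ptr %p, i64 %vec.start
//
// while a constant index of 0 skips the GEP and returns %p directly.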
194
195/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
196///
197/// Currently, the lowering for each matrix intrinsic is done as follows:
198/// 1. Propagate the shape information from intrinsics to connected
199/// instructions.
200/// 2. Lower instructions with shape information (assuming column-major layout).
201/// The lowering works similarly using row-major layout.
202/// 2.1. Get column vectors for each argument. If we already lowered the
203/// definition of an argument, use the produced column vectors directly.
204/// If not, split the operand vector containing an embedded matrix into
205/// a set of column vectors.
206/// 2.2. Lower the instruction in terms of column major operations, which
207/// yields a set of column vectors containing the result matrix. Note that we
208/// lower all instructions that have shape information. Besides the
209/// intrinsics, this includes stores for example.
210/// 2.3. Update uses of the lowered instruction. If we have shape information
211/// for a user, there is nothing to do, as we will look up the result
212/// column matrix when lowering the user. For other uses, we embed the
213/// result matrix in a flat vector and update the use.
214/// 2.4. Cache the result column matrix for the instruction we lowered
215/// 3. After we lowered all instructions in a function, remove the now
216/// obsolete instructions.
217///
218class LowerMatrixIntrinsics {
219 Function &Func;
220 const DataLayout &DL;
221 const TargetTransformInfo &TTI;
222 AliasAnalysis *AA;
223 DominatorTree *DT;
224 LoopInfo *LI;
225 OptimizationRemarkEmitter *ORE;
226
227 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
228 struct OpInfoTy {
229 /// Number of stores emitted to generate this matrix.
230 unsigned NumStores = 0;
231 /// Number of loads emitted to generate this matrix.
232 unsigned NumLoads = 0;
233 /// Number of compute operations emitted to generate this matrix.
234 unsigned NumComputeOps = 0;
235 /// Most of the time transposes can be fused with matrix multiplies or can
236 /// be folded away via algebraic simplifications. This is the number of
237 /// transposes that we failed to make "free" via such optimizations.
238 unsigned NumExposedTransposes = 0;
239
240 OpInfoTy &operator+=(const OpInfoTy &RHS) {
241 NumStores += RHS.NumStores;
242 NumLoads += RHS.NumLoads;
243 NumComputeOps += RHS.NumComputeOps;
244 NumExposedTransposes += RHS.NumExposedTransposes;
245 return *this;
246 }
247 };
248
249 /// Wrapper class representing a matrix as a set of vectors, either in row or
250 /// column major layout. All vectors must have the same vector type.
251 class MatrixTy {
252 SmallVector<Value *, 16> Vectors;
253
254 OpInfoTy OpInfo;
255
256 bool IsColumnMajor = true;
257
258 public:
259 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
260 MatrixTy(ArrayRef<Value *> Vectors)
261 : Vectors(Vectors.begin(), Vectors.end()),
262 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
263 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
264 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
265
266 unsigned D = isColumnMajor() ? NumColumns : NumRows;
267 for (unsigned J = 0; J < D; ++J)
268 addVector(PoisonValue::get(FixedVectorType::get(
269 EltTy, isColumnMajor() ? NumRows : NumColumns)));
270 }
271
272 Value *getVector(unsigned i) const { return Vectors[i]; }
273 Value *getColumn(unsigned i) const {
274 assert(isColumnMajor() && "only supported for column-major matrixes");
275 return Vectors[i];
276 }
277 Value *getRow(unsigned i) const {
278 assert(!isColumnMajor() && "only supported for row-major matrixes");
279 return Vectors[i];
280 }
281
282 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
283
284 Type *getElementType() const { return getVectorTy()->getElementType(); }
285
286 unsigned getNumVectors() const {
287 if (isColumnMajor())
288 return getNumColumns();
289 return getNumRows();
290 }
291
292 unsigned getNumColumns() const {
293 if (isColumnMajor())
294 return Vectors.size();
295 else {
296 assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
297 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
298 }
299 }
300 unsigned getNumRows() const {
301 if (isColumnMajor()) {
302 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
303 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
304 } else
305 return Vectors.size();
306 }
307
308 void addVector(Value *V) { Vectors.push_back(V); }
309 VectorType *getColumnTy() {
310 assert(isColumnMajor() && "only supported for column-major matrixes");
311 return getVectorTy();
312 }
313
314 VectorType *getVectorTy() const {
315 return cast<VectorType>(Vectors[0]->getType());
316 }
317
318 iterator_range<SmallVector<Value *, 8>::iterator> columns() {
319 assert(isColumnMajor() &&
320 "columns() only supported for column-major matrixes");
321 return make_range(Vectors.begin(), Vectors.end());
322 }
323
324 iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
325 return make_range(Vectors.begin(), Vectors.end());
326 }
327
328 /// Embed the vectors of the matrix into a flat vector by concatenating
329 /// them.
330 Value *embedInVector(IRBuilder<> &Builder) const {
331 return Vectors.size() == 1 ? Vectors[0]
332 : concatenateVectors(Builder, Vectors);
333 }
334
335 MatrixTy &addNumLoads(unsigned N) {
336 OpInfo.NumLoads += N;
337 return *this;
338 }
339
340 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
341
342 MatrixTy &addNumStores(unsigned N) {
343 OpInfo.NumStores += N;
344 return *this;
345 }
346
347 MatrixTy &addNumExposedTransposes(unsigned N) {
348 OpInfo.NumExposedTransposes += N;
349 return *this;
350 }
351
352 MatrixTy &addNumComputeOps(unsigned N) {
353 OpInfo.NumComputeOps += N;
354 return *this;
355 }
356
357 unsigned getNumStores() const { return OpInfo.NumStores; }
358 unsigned getNumLoads() const { return OpInfo.NumLoads; }
359 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
360
361 const OpInfoTy &getOpInfo() const { return OpInfo; }
362
363 bool isColumnMajor() const { return IsColumnMajor; }
364
365 unsigned getStride() const {
366 if (isColumnMajor())
367 return getNumRows();
368 return getNumColumns();
369 }
370
371 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
372 /// matrix is column-major, the result vector is extracted from a column
373 /// vector, otherwise from a row vector.
374 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
375 IRBuilder<> &Builder) const {
376 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
377 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
378 NumElts &&
379 "Extracted vector will contain poison values");
380 return Builder.CreateShuffleVector(
381 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
382 "block");
383 }
384 };
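 // For illustration (not part of the interface above): with the default
 // column-major layout, a 3x2 float matrix is held as two column vectors,
 // e.g. Vectors = { <3 x float> %col.0, <3 x float> %col.1 }, so
 // getNumRows() == 3, getNumColumns() == 2 and getStride() == 3.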
385
386 struct ShapeInfo {
387 unsigned NumRows;
388 unsigned NumColumns;
389
390 bool IsColumnMajor;
391
392 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
393 : NumRows(NumRows), NumColumns(NumColumns),
394 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
395
396 ShapeInfo(Value *NumRows, Value *NumColumns)
397 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
398 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
399
400 bool operator==(const ShapeInfo &other) {
401 return NumRows == other.NumRows && NumColumns == other.NumColumns;
402 }
403 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
404
405 /// Returns true if shape-information is defined, meaning both dimensions
406 /// are != 0.
407 operator bool() const {
408 assert(NumRows == 0 || NumColumns != 0);
409 return NumRows != 0;
410 }
411
412 unsigned getStride() const {
413 if (IsColumnMajor)
414 return NumRows;
415 return NumColumns;
416 }
417
418 unsigned getNumVectors() const {
419 if (IsColumnMajor)
420 return NumColumns;
421 return NumRows;
422 }
423
424 /// Returns the transposed shape.
425 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
426 };
427
428 /// Maps instructions to their shape information. The shape information
429 /// describes the shape to be used while lowering. This matches the shape of
430 /// the result value of the instruction, with the only exceptions being store
431 /// instructions and the matrix_column_major_store intrinsics. For those, the
432 /// shape information indicates that those instructions should be lowered
433 /// using shape information as well. A ValueMap is used so that when
434 /// sub-passes like optimizeTransposes perform RAUW, the map stays
435 /// up-to-date.
436 ValueMap<Value *, ShapeInfo> ShapeMap;
437
438 /// List of instructions to remove. While lowering, we do not replace all
439 /// users of a lowered instruction if shape information is available; those
440 /// users need to be removed after we have finished lowering.
441 SmallVector<Instruction *, 16> ToRemove;
442
443 /// Map from instructions to their produced column matrix.
444 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
445
446private:
447 static FastMathFlags getFastMathFlags(Instruction *Inst) {
448 FastMathFlags FMF;
449
450 if (isa<FPMathOperator>(*Inst))
451 FMF = Inst->getFastMathFlags();
452
453 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
454
455 return FMF;
456 }
457
458public:
459 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
460 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
461 OptimizationRemarkEmitter *ORE)
462 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
463 LI(LI), ORE(ORE) {}
464
465 unsigned getNumOps(Type *VT) {
466 assert(isa<VectorType>(VT) && "Expected vector type");
467 return getNumOps(VT->getScalarType(),
468 cast<FixedVectorType>(VT)->getNumElements());
469 }
470
471 /// Returns true if this is the minimal version executed in the backend pipelines.
472 bool isMinimal() const {
473 return !DT;
474 }
475
476 /// Return the estimated number of vector ops required for an operation on
477 /// \p VT * N.
478 unsigned getNumOps(Type *ST, unsigned N) {
479 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
480 double(TTI.getRegisterBitWidth(
481 TargetTransformInfo::RGK_FixedWidthVector)
482 .getFixedValue()));
483 }
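 // Worked example (assuming 128-bit vector registers): a column of 4 doubles
 // is 4 * 64 = 256 bits, so getNumOps(double, 4) returns ceil(256 / 128) = 2
 // vector operations per column.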
484
485 /// Return the set of vectors that a matrix value is lowered to.
486 ///
487 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
488 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
489 /// into vectors.
490 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
491 IRBuilder<> &Builder) {
492 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
493 assert(VType && "MatrixVal must be a vector type");
494 assert(cast<FixedVectorType>(VType)->getNumElements() ==
495 SI.NumRows * SI.NumColumns &&
496 "The vector size must match the number of matrix elements");
497
498 // Check if we lowered MatrixVal using shape information. In that case,
499 // return the existing matrix, if it matches the requested shape
500 // information. If there is a mis-match, embed the result in a flat
501 // vector and split it later.
502 auto Found = Inst2ColumnMatrix.find(MatrixVal);
503 if (Found != Inst2ColumnMatrix.end()) {
504 MatrixTy &M = Found->second;
505 // Return the found matrix, if its shape matches the requested shape
506 // information
507 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
508 return M;
509
510 MatrixVal = M.embedInVector(Builder);
511 }
512
513 // Otherwise split MatrixVal.
514 SmallVector<Value *, 16> SplitVecs;
515 for (unsigned MaskStart = 0;
516 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
517 MaskStart += SI.getStride()) {
518 Value *V = Builder.CreateShuffleVector(
519 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
520 "split");
521 SplitVecs.push_back(V);
522 }
523
524 return {SplitVecs};
525 }
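 // For illustration (simplified; %m is an example name): splitting a flat
 // <6 x float> %m that holds a 3x2 column-major matrix emits roughly
 //
 //   %split  = shufflevector <6 x float> %m, <6 x float> poison,
 //                           <3 x i32> <i32 0, i32 1, i32 2>
 //   %split1 = shufflevector <6 x float> %m, <6 x float> poison,
 //                           <3 x i32> <i32 3, i32 4, i32 5>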
526
527 /// If \p V already has a known shape return false. Otherwise set the shape
528 /// for instructions that support it.
529 bool setShapeInfo(Value *V, ShapeInfo Shape) {
530 assert(Shape && "Shape not set");
531 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
532 return false;
533
534 auto SIter = ShapeMap.find(V);
535 if (SIter != ShapeMap.end()) {
536 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
537 SIter->second.NumColumns != Shape.NumColumns)) {
538 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
539 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
540 << Shape.NumColumns << ") for " << *V << "\n";
542 "Matrix shape verification failed, compilation aborted!");
543 }
544
545 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
546 << SIter->second.NumRows << " "
547 << SIter->second.NumColumns << " for " << *V << "\n");
548 return false;
549 }
550
551 ShapeMap.insert({V, Shape});
552 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
553 << " for " << *V << "\n");
554 return true;
555 }
556
557 bool isUniformShape(Value *V) {
558 Instruction *I = dyn_cast<Instruction>(V);
559 if (!I)
560 return true;
561
562 switch (I->getOpcode()) {
563 case Instruction::FAdd:
564 case Instruction::FSub:
565 case Instruction::FMul: // Scalar multiply.
566 case Instruction::FNeg:
567 case Instruction::Add:
568 case Instruction::Mul:
569 case Instruction::Sub:
570 return true;
571 default:
572 return false;
573 }
574 }
575
576 /// Returns true if shape information can be used for \p V. The supported
577 /// instructions must match the instructions that can be lowered by this pass.
578 bool supportsShapeInfo(Value *V) {
579 Instruction *Inst = dyn_cast<Instruction>(V);
580 if (!Inst)
581 return false;
582
583 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
584 if (II)
585 switch (II->getIntrinsicID()) {
586 case Intrinsic::matrix_multiply:
587 case Intrinsic::matrix_transpose:
588 case Intrinsic::matrix_column_major_load:
589 case Intrinsic::matrix_column_major_store:
590 return true;
591 default:
592 return false;
593 }
594 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
595 }
596
597 /// Propagate the shape information of instructions to their users.
598 /// The work list contains instructions for which we can compute the shape,
599 /// either based on the information provided by matrix intrinsics or known
600 /// shapes of operands.
601 SmallVector<Instruction *, 32>
602 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
603 SmallVector<Instruction *, 32> NewWorkList;
604 // Pop an element for which we are guaranteed to have at least one of the
605 // operand shapes. Add the shape for this and then add users to the work
606 // list.
607 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
608 while (!WorkList.empty()) {
609 Instruction *Inst = WorkList.pop_back_val();
610
611 // New entry, set the value and insert operands
612 bool Propagate = false;
613
614 Value *MatrixA;
615 Value *MatrixB;
616 Value *M;
617 Value *N;
618 Value *K;
619 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
620 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
621 m_Value(N), m_Value(K)))) {
622 Propagate = setShapeInfo(Inst, {M, K});
623 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
624 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
625 // Flip dimensions.
626 Propagate = setShapeInfo(Inst, {N, M});
627 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
628 m_Value(MatrixA), m_Value(), m_Value(),
629 m_Value(), m_Value(M), m_Value(N)))) {
630 Propagate = setShapeInfo(Inst, {N, M});
631 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
632 m_Value(), m_Value(), m_Value(), m_Value(M),
633 m_Value(N)))) {
634 Propagate = setShapeInfo(Inst, {M, N});
635 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
636 auto OpShape = ShapeMap.find(MatrixA);
637 if (OpShape != ShapeMap.end())
638 setShapeInfo(Inst, OpShape->second);
639 continue;
640 } else if (isUniformShape(Inst)) {
641 // Find the first operand that has a known shape and use that.
642 for (auto &Op : Inst->operands()) {
643 auto OpShape = ShapeMap.find(Op.get());
644 if (OpShape != ShapeMap.end()) {
645 Propagate |= setShapeInfo(Inst, OpShape->second);
646 break;
647 }
648 }
649 }
650
651 if (Propagate) {
652 NewWorkList.push_back(Inst);
653 for (auto *User : Inst->users())
654 if (ShapeMap.count(User) == 0)
655 WorkList.push_back(cast<Instruction>(User));
656 }
657 }
658
659 return NewWorkList;
660 }
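 // For illustration (simplified example IR): given
 //
 //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
 //            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
 //   store <4 x double> %c, ptr %p
 //
 // forward propagation records the 2x2 shape for %c and then for the store,
 // so both are lowered later using shape information.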
661
662 /// Propagate the shape to operands of instructions with shape information.
663 /// \p WorkList contains the instructions for which we already know the shape.
664 SmallVector<Instruction *, 32>
665 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
666 SmallVector<Instruction *, 32> NewWorkList;
667
668 auto pushInstruction = [](Value *V,
669 SmallVectorImpl<Instruction *> &WorkList) {
670 Instruction *I = dyn_cast<Instruction>(V);
671 if (I)
672 WorkList.push_back(I);
673 };
674 // Pop an element with known shape. Traverse the operands, if their shape
675 // derives from the result shape and is unknown, add it and add them to the
676 // worklist.
677 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
678 while (!WorkList.empty()) {
679 Value *V = WorkList.pop_back_val();
680
681 size_t BeforeProcessingV = WorkList.size();
682 if (!isa<Instruction>(V))
683 continue;
684
685 Value *MatrixA;
686 Value *MatrixB;
687 Value *M;
688 Value *N;
689 Value *K;
690 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
691 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
692 m_Value(N), m_Value(K)))) {
693 if (setShapeInfo(MatrixA, {M, N}))
694 pushInstruction(MatrixA, WorkList);
695
696 if (setShapeInfo(MatrixB, {N, K}))
697 pushInstruction(MatrixB, WorkList);
698
699 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
700 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
701 // Flip dimensions.
702 if (setShapeInfo(MatrixA, {M, N}))
703 pushInstruction(MatrixA, WorkList);
704 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
705 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
706 m_Value(M), m_Value(N)))) {
707 if (setShapeInfo(MatrixA, {M, N})) {
708 pushInstruction(MatrixA, WorkList);
709 }
710 } else if (isa<LoadInst>(V) ||
711 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
712 // Nothing to do, no matrix input.
713 } else if (isa<StoreInst>(V)) {
714 // Nothing to do. We forward-propagated to this so we would just
715 // backward propagate to an instruction with an already known shape.
716 } else if (isUniformShape(V)) {
717 // Propagate to all operands.
718 ShapeInfo Shape = ShapeMap[V];
719 for (Use &U : cast<Instruction>(V)->operands()) {
720 if (setShapeInfo(U.get(), Shape))
721 pushInstruction(U.get(), WorkList);
722 }
723 }
724 // After we discovered new shape info for new instructions in the
725 // worklist, we use their users as seeds for the next round of forward
726 // propagation.
727 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
728 for (User *U : WorkList[I]->users())
729 if (isa<Instruction>(U) && V != U)
730 NewWorkList.push_back(cast<Instruction>(U));
731 }
732 return NewWorkList;
733 }
734
735 /// (Op0 op Op1)^T -> Op0^T op Op1^T
736 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
737 /// them on both sides of \p Operation.
738 Instruction *distributeTransposes(
739 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
740 MatrixBuilder &Builder,
741 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
742 Operation) {
743 Value *T0 = Builder.CreateMatrixTranspose(
744 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
745 // We are being run after shape prop, add shape for newly created
746 // instructions so that we lower them later.
747 setShapeInfo(T0, Shape0.t());
748 Value *T1 = Builder.CreateMatrixTranspose(
749 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
750 setShapeInfo(T1, Shape1.t());
751 return Operation(T0, Shape0.t(), T1, Shape1.t());
752 }
753
754 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
755 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
756 // with New. We should only add New if it supportsShapeInfo, so we insert
757 // it conditionally instead.
758 auto S = ShapeMap.find(&Old);
759 if (S != ShapeMap.end()) {
760 ShapeMap.erase(S);
761 if (supportsShapeInfo(New))
762 ShapeMap.insert({New, S->second});
763 }
764 Old.replaceAllUsesWith(New);
765 }
766
767 /// Sink a top-level transpose inside matmuls and adds.
768 /// This creates and erases instructions as needed, and returns the newly
769 /// created instruction while updating the iterator to avoid invalidation. If
770 /// this returns nullptr, no new instruction was created.
772 BasicBlock &BB = *I.getParent();
773 IRBuilder<> IB(&I);
774 MatrixBuilder Builder(IB);
775
776 Value *TA, *TAMA, *TAMB;
777 ConstantInt *R, *K, *C;
778 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
779 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
780 return nullptr;
781
782 // Transpose of a transpose is a nop
783 Value *TATA;
784 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
785 updateShapeAndReplaceAllUsesWith(I, TATA);
786 eraseFromParentAndMove(&I, II, BB);
787 eraseFromParentAndMove(TA, II, BB);
788 return nullptr;
789 }
790
791 // k^T -> k
792 if (isSplat(TA)) {
793 updateShapeAndReplaceAllUsesWith(I, TA);
794 eraseFromParentAndMove(&I, II, BB);
795 return nullptr;
796 }
797
798 // (A * B)^t -> B^t * A^t
799 // RxK KxC CxK KxR
800 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
801 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
802 m_ConstantInt(K), m_ConstantInt(C)))) {
803 auto NewInst = distributeTransposes(
804 TAMB, {K, C}, TAMA, {R, K}, Builder,
805 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
806 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
807 Shape0.NumColumns,
808 Shape1.NumColumns, "mmul");
809 });
810 updateShapeAndReplaceAllUsesWith(I, NewInst);
811 eraseFromParentAndMove(&I, II, BB);
812 eraseFromParentAndMove(TA, II, BB);
813 return NewInst;
814 }
815
816 // Same as above, but with a mul, which occurs when multiplied
817 // with a scalar.
818 // (A * k)^t -> A^t * k
819 // R x C RxC
820 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
821 (isSplat(TAMA) || isSplat(TAMB))) {
822 IRBuilder<> LocalBuilder(&I);
823 // We know that the transposed operand is of shape RxC.
824 // And when multiplied with a scalar, the shape is preserved.
825 auto NewInst = distributeTransposes(
826 TAMA, {R, C}, TAMB, {R, C}, Builder,
827 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
828 bool IsFP = I.getType()->isFPOrFPVectorTy();
829 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
830 : LocalBuilder.CreateMul(T0, T1, "mmul");
831 auto *Result = cast<Instruction>(Mul);
832 setShapeInfo(Result, Shape0);
833 return Result;
834 });
835 updateShapeAndReplaceAllUsesWith(I, NewInst);
836 eraseFromParentAndMove(&I, II, BB);
837 eraseFromParentAndMove(TA, II, BB);
838 return NewInst;
839 }
840
841 // (A + B)^t -> A^t + B^t
842 // RxC RxC CxR CxR
843 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
844 IRBuilder<> LocalBuilder(&I);
845 auto NewInst = distributeTransposes(
846 TAMA, {R, C}, TAMB, {R, C}, Builder,
847 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
848 bool IsFP = I.getType()->isFPOrFPVectorTy();
849 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
850 : LocalBuilder.CreateAdd(T0, T1, "madd");
851
852 auto *Result = cast<Instruction>(Add);
853 setShapeInfo(Result, Shape0);
854 return Result;
855 });
856 updateShapeAndReplaceAllUsesWith(I, NewInst);
857 eraseFromParentAndMove(&I, II, BB);
858 eraseFromParentAndMove(TA, II, BB);
859 return NewInst;
860 }
861
862 return nullptr;
863 }
864
865 void liftTranspose(Instruction &I) {
866 // Erase dead Instructions after lifting transposes from binops.
867 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
868 if (T.use_empty())
869 T.eraseFromParent();
870 if (A->use_empty())
871 cast<Instruction>(A)->eraseFromParent();
872 if (A != B && B->use_empty())
873 cast<Instruction>(B)->eraseFromParent();
874 };
875
876 Value *A, *B, *AT, *BT;
877 ConstantInt *R, *K, *C;
878 // A^t * B ^t -> (B * A)^t
879 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
880 m_Value(A), m_Value(B), m_ConstantInt(R),
881 m_ConstantInt(K), m_ConstantInt(C))) &&
882 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
883 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
884 IRBuilder<> IB(&I);
885 MatrixBuilder Builder(IB);
886 Value *M = Builder.CreateMatrixMultiply(
887 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
888 setShapeInfo(M, {C, R});
889 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
890 R->getZExtValue());
891 updateShapeAndReplaceAllUsesWith(I, NewInst);
892 CleanupBinOp(I, A, B);
893 }
894 // A^t + B ^t -> (A + B)^t
895 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
896 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
897 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
898 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
899 m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
900 IRBuilder<> Builder(&I);
901 Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
902 setShapeInfo(Add, {C, R});
903 MatrixBuilder MBuilder(Builder);
904 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
905 Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
906 updateShapeAndReplaceAllUsesWith(I, NewInst);
907 CleanupBinOp(I, A, B);
908 }
909 }
910
911 /// Try moving transposes in order to fold them away or into multiplies.
912 void optimizeTransposes() {
913 // First sink all transposes inside matmuls and adds, hoping that we end up
914 // with NN, NT or TN variants.
915 for (BasicBlock &BB : reverse(Func)) {
916 for (auto II = BB.rbegin(); II != BB.rend();) {
917 Instruction &I = *II;
918 // We may remove II. By default continue on the next/prev instruction.
919 ++II;
920 if (Instruction *NewInst = sinkTranspose(I, II))
921 II = std::next(BasicBlock::reverse_iterator(NewInst));
922 }
923 }
924
925 // If we have a TT matmul or a TT add, lift the transpose. We may be able
926 // to fold into consuming multiply or add.
927 for (BasicBlock &BB : Func) {
928 for (Instruction &I : llvm::make_early_inc_range(BB)) {
929 liftTranspose(I);
930 }
931 }
932 }
933
934 bool Visit() {
935 SmallVector<Instruction *, 32> WorkList;
936
937 // Initially only the shape of matrix intrinsics is known.
938 // Initialize the work list with ops carrying shape information.
939 for (BasicBlock &BB : Func)
940 for (Instruction &Inst : BB) {
941 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
942 if (!II)
943 continue;
944
945 switch (II->getIntrinsicID()) {
946 case Intrinsic::matrix_multiply:
947 case Intrinsic::matrix_transpose:
948 case Intrinsic::matrix_column_major_load:
949 case Intrinsic::matrix_column_major_store:
950 WorkList.push_back(&Inst);
951 break;
952 default:
953 break;
954 }
955 }
956
957 // Avoid unnecessary work if there are no matrix intrinsics in the function.
958 if (WorkList.empty())
959 return false;
960
961 // Propagate shapes until nothing changes any longer.
962 while (!WorkList.empty()) {
963 WorkList = propagateShapeForward(WorkList);
964 WorkList = propagateShapeBackward(WorkList);
965 }
966
967 if (!isMinimal()) {
968 optimizeTransposes();
969 if (PrintAfterTransposeOpt) {
970 dbgs() << "Dump after matrix transpose optimization:\n";
971 Func.print(dbgs());
972 }
973 }
974
975 bool Changed = false;
976 SmallVector<CallInst *, 16> MaybeFusableInsts;
977 SmallVector<Instruction *, 16> MatrixInsts;
978
979 // First, collect all instructions with shape information and candidates for
980 // fusion (currently only matrix multiplies).
981 ReversePostOrderTraversal<Function *> RPOT(&Func);
982 for (auto *BB : RPOT)
983 for (Instruction &I : *BB) {
984 if (ShapeMap.find(&I) == ShapeMap.end())
985 continue;
986 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
987 MaybeFusableInsts.push_back(cast<CallInst>(&I));
988 MatrixInsts.push_back(&I);
989 }
990
991 // Second, try to lower any dot products
992 SmallPtrSet<Instruction *, 16> FusedInsts;
993 for (CallInst *CI : MaybeFusableInsts)
994 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
995
996 // Third, try to fuse candidates.
997 for (CallInst *CI : MaybeFusableInsts)
998 LowerMatrixMultiplyFused(CI, FusedInsts);
999
1000 Changed = !FusedInsts.empty();
1001
1002 // Fourth, lower remaining instructions with shape information.
1003 for (Instruction *Inst : MatrixInsts) {
1004 if (FusedInsts.count(Inst))
1005 continue;
1006
1007 IRBuilder<> Builder(Inst);
1008
1009 if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1010 Changed |= VisitCallInst(CInst);
1011
1012 Value *Op1;
1013 Value *Op2;
1014 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1015 Changed |= VisitBinaryOperator(BinOp);
1016 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1017 Changed |= VisitUnaryOperator(UnOp);
1018 if (match(Inst, m_Load(m_Value(Op1))))
1019 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1020 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1021 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1022 }
1023
1024 if (ORE) {
1025 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1026 RemarkGen.emitRemarks();
1027 }
1028
1029 // Delete the instructions backwards, as it has a reduced likelihood of
1030 // having to update as many def-use and use-def chains.
1031 //
1032 // Because we add to ToRemove during fusion we can't guarantee that defs
1033 // are before uses. Change uses to poison temporarily as these should get
1034 // removed as well.
1035 //
1036 // For verification, we keep track of where we changed uses to poison in
1037 // PoisonedInsts and then check that we in fact remove them.
1038 SmallSet<Instruction *, 16> PoisonedInsts;
1039 for (auto *Inst : reverse(ToRemove)) {
1040 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1041 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1042 PoisonedInsts.insert(Poisoned);
1043 U.set(PoisonValue::get(Inst->getType()));
1044 }
1045 Inst->eraseFromParent();
1046 PoisonedInsts.erase(Inst);
1047 }
1048 if (!PoisonedInsts.empty()) {
1049 // If we didn't remove all poisoned instructions, it's a hard error.
1050 dbgs() << "Poisoned but present instructions:\n";
1051 for (auto *I : PoisonedInsts)
1052 dbgs() << *I << "\n";
1053 llvm_unreachable("Poisoned but instruction not removed");
1054 }
1055
1056 return Changed;
1057 }
1058
1059 /// Replace intrinsic calls
1060 bool VisitCallInst(CallInst *Inst) {
1061 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1062 return false;
1063
1064 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1065 case Intrinsic::matrix_multiply:
1066 LowerMultiply(Inst);
1067 break;
1068 case Intrinsic::matrix_transpose:
1069 LowerTranspose(Inst);
1070 break;
1071 case Intrinsic::matrix_column_major_load:
1072 LowerColumnMajorLoad(Inst);
1073 break;
1074 case Intrinsic::matrix_column_major_store:
1075 LowerColumnMajorStore(Inst);
1076 break;
1077 default:
1078 return false;
1079 }
1080 return true;
1081 }
1082
1083 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1084 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1085 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1086 /// non-ConstantInt strides, return the common alignment of the initial
1087 /// alignment and the element size in bytes.
1088 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1089 MaybeAlign A) const {
1090 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1091 if (Idx == 0)
1092 return InitialAlign;
1093
1094 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1095 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1096 uint64_t StrideInBytes =
1097 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1098 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1099 }
1100 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1101 }
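 // Worked example: for an initial alignment of 16, double elements and a
 // constant stride of 3, column 1 starts 3 * 8 = 24 bytes past the base, so
 // commonAlignment(16, 24) yields an alignment of 8 for that column's access.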
1102
1103 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1104 /// vectors.
1105 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1106 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1107 auto *VType = cast<VectorType>(Ty);
1108 Type *EltTy = VType->getElementType();
1109 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1110 Value *EltPtr = Ptr;
1111 MatrixTy Result;
1112 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1113 Value *GEP = computeVectorAddr(
1114 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1115 Stride, Shape.getStride(), EltTy, Builder);
1116 Value *Vector = Builder.CreateAlignedLoad(
1117 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1118 IsVolatile, "col.load");
1119
1120 Result.addVector(Vector);
1121 }
1122 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1123 Result.getNumVectors());
1124 }
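 // For illustration (simplified; the alignments depend on getAlignForIndex):
 // loading a 2x2 double matrix from %p with a constant stride of 2 emits
 // roughly
 //
 //   %col.load  = load <2 x double>, ptr %p, align 8
 //   %vec.gep   = getelementptr double, ptr %p, i64 2
 //   %col.load1 = load <2 x double>, ptr %vec.gep, align 8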
1125
1126 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1127 /// starting at \p MatrixPtr[I][J].
1128 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1129 ShapeInfo MatrixShape, Value *I, Value *J,
1130 ShapeInfo ResultShape, Type *EltTy,
1131 IRBuilder<> &Builder) {
1132
1133 Value *Offset = Builder.CreateAdd(
1134 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1135
1136 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1137 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1138 ResultShape.NumColumns);
1139
1140 return loadMatrix(TileTy, TileStart, Align,
1141 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1142 ResultShape, Builder);
1143 }
1144
1145 /// Lower a load instruction with shape information.
1146 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1147 bool IsVolatile, ShapeInfo Shape) {
1148 IRBuilder<> Builder(Inst);
1149 finalizeLowering(Inst,
1150 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1151 Shape, Builder),
1152 Builder);
1153 }
1154
1155 /// Lowers llvm.matrix.column.major.load.
1156 ///
1157 /// The intrinsic loads a matrix from memory using a stride between columns.
1158 void LowerColumnMajorLoad(CallInst *Inst) {
1159 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1160 "Intrinsic only supports column-major layout!");
1161 Value *Ptr = Inst->getArgOperand(0);
1162 Value *Stride = Inst->getArgOperand(1);
1163 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1164 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1165 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1166 }
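 // A typical call handled here looks roughly like (operands are examples):
 //
 //   %m = call <4 x float> @llvm.matrix.column.major.load.v4f32.i64(
 //            ptr align 4 %p, i64 %stride, i1 false, i32 2, i32 2)
 //
 // where %stride is the column stride in elements and the i1 flag requests a
 // volatile access.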
1167
1168 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1169 /// MatrixPtr[I][J].
1170 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1171 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1172 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1173 Value *Offset = Builder.CreateAdd(
1174 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1175
1176 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1177 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1178 StoreVal.getNumColumns());
1179
1180 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1181 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1182 }
1183
1184 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1185 /// vectors.
1186 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1187 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1188 IRBuilder<> &Builder) {
1189 auto VType = cast<VectorType>(Ty);
1190 Value *EltPtr = Ptr;
1191 for (auto Vec : enumerate(StoreVal.vectors())) {
1192 Value *GEP = computeVectorAddr(
1193 EltPtr,
1194 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1195 Vec.index()),
1196 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1197 Builder.CreateAlignedStore(Vec.value(), GEP,
1198 getAlignForIndex(Vec.index(), Stride,
1199 VType->getElementType(),
1200 MAlign),
1201 IsVolatile);
1202 }
1203 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1204 StoreVal.getNumVectors());
1205 }
1206
1207 /// Lower a store instruction with shape information.
1208 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1209 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1210 IRBuilder<> Builder(Inst);
1211 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1212 finalizeLowering(Inst,
1213 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1214 IsVolatile, Builder),
1215 Builder);
1216 }
1217
1218 /// Lowers llvm.matrix.column.major.store.
1219 ///
1220 /// The intrinsic stores a matrix back to memory using a stride between columns.
1221 void LowerColumnMajorStore(CallInst *Inst) {
1222 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1223 "Intrinsic only supports column-major layout!");
1224 Value *Matrix = Inst->getArgOperand(0);
1225 Value *Ptr = Inst->getArgOperand(1);
1226 Value *Stride = Inst->getArgOperand(2);
1227 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1228 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1229 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1230 }
1231
1232 // Insert Block at positions I..I+BlockNumElts-1 of Col.
1233 Value *insertVector(Value *Col, unsigned I, Value *Block,
1234 IRBuilder<> &Builder) {
1235
1236 // First, bring Block to the same size as Col
1237 unsigned BlockNumElts =
1238 cast<FixedVectorType>(Block->getType())->getNumElements();
1239 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1240 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1241
1242 Block = Builder.CreateShuffleVector(
1243 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1244
1245 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1246 // 8, 4, 5, 6
1247 SmallVector<int, 16> Mask;
1248 unsigned i;
1249 for (i = 0; i < I; i++)
1250 Mask.push_back(i);
1251
1252 unsigned VecNumElts =
1253 cast<FixedVectorType>(Col->getType())->getNumElements();
1254 for (; i < I + BlockNumElts; i++)
1255 Mask.push_back(i - I + VecNumElts);
1256
1257 for (; i < VecNumElts; i++)
1258 Mask.push_back(i);
1259
1260 return Builder.CreateShuffleVector(Col, Block, Mask);
1261 }
1262
1263 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1264 IRBuilder<> &Builder, bool AllowContraction,
1265 unsigned &NumComputeOps) {
1266 NumComputeOps += getNumOps(A->getType());
1267 if (!Sum)
1268 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1269
1270 if (UseFPOp) {
1271 if (AllowContraction) {
1272 // Use fmuladd for floating point operations and let the backend decide
1273 // if that's profitable.
1274 Value *FMulAdd = Intrinsic::getDeclaration(
1275 Func.getParent(), Intrinsic::fmuladd, A->getType());
1276 return Builder.CreateCall(FMulAdd, {A, B, Sum});
1277 }
1278 NumComputeOps += getNumOps(A->getType());
1279 Value *Mul = Builder.CreateFMul(A, B);
1280 return Builder.CreateFAdd(Sum, Mul);
1281 }
1282
1283 NumComputeOps += getNumOps(A->getType());
1284 Value *Mul = Builder.CreateMul(A, B);
1285 return Builder.CreateAdd(Sum, Mul);
1286 }
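 // For illustration (simplified): with contraction allowed, one accumulation
 // step on <4 x float> blocks becomes a single call
 //
 //   %acc = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
 //                                               <4 x float> %b,
 //                                               <4 x float> %sum)
 //
 // otherwise it is emitted as separate fmul and fadd (or mul/add) instructions.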
1287
1288 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1289 /// users with shape information, there's nothing to do: they will use the
1290 /// cached value when they are lowered. For other users, \p Matrix is
1291 /// flattened and the uses are updated to use it. Also marks \p Inst for
1292 /// deletion.
1293 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1294 IRBuilder<> &Builder) {
1295 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1296 (void)inserted;
1297 assert(inserted.second && "multiple matrix lowering mapping");
1298
1299 ToRemove.push_back(Inst);
1300 Value *Flattened = nullptr;
1301 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1302 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1303 if (!Flattened)
1304 Flattened = Matrix.embedInVector(Builder);
1305 U.set(Flattened);
1306 }
1307 }
1308 }
1309
1310 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1311 /// vectors. Lowers to a vector reduction add instead of sequential adds if
1312 /// reassociation is enabled.
1313 void lowerDotProduct(CallInst *MatMul,
1314 SmallPtrSet<Instruction *, 16> &FusedInsts,
1315 FastMathFlags FMF) {
1316 if (FusedInsts.contains(MatMul) ||
1317 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1318 return;
1319 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1320 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1321
1322 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1323 return;
1324
1325 Value *LHS = MatMul->getArgOperand(0);
1326 Value *RHS = MatMul->getArgOperand(1);
1327
1328 Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1329 bool IsIntVec = ElementType->isIntegerTy();
1330
1331 // Floating point reductions require reassociation.
1332 if (!IsIntVec && !FMF.allowReassoc())
1333 return;
1334
1335 auto CanBeFlattened = [this](Value *Op) {
1336 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
1337 return true;
1338 return match(
1339 Op, m_OneUse(m_CombineOr(
1340 m_Load(m_Value()),
1341 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1342 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1343 m_Value(), m_SpecificInt(1))))));
1344 };
1345 // Returns the cost benefit of using \p Op with the dot product lowering. If
1346 // the returned cost is < 0, the argument is cheaper to use in the
1347 // dot-product lowering.
1348 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1349 if (!isa<Instruction>(Op))
1350 return InstructionCost(0);
1351
1352 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1353 Type *EltTy = VecTy->getElementType();
1354
1355 if (!CanBeFlattened(Op)) {
1356 InstructionCost EmbedCost(0);
1357 // Roughly estimate the cost for embedding the columns into a vector.
1358 for (unsigned I = 1; I < N; ++I)
1359 EmbedCost -=
1360 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1361 std::nullopt, TTI::TCK_RecipThroughput);
1362 return EmbedCost;
1363 }
1364
1365 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1366 InstructionCost OriginalCost =
1367 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1368 EltTy) *
1369 N;
1371 cast<Instruction>(Op)->getOpcode(), VecTy);
1372 return NewCost - OriginalCost;
1373 }
1374
1375 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1376 // The transpose can be skipped for the dot product lowering, roughly
1377 // estimate the savings as the cost of embedding the columns in a
1378 // vector.
1379 InstructionCost EmbedCost(0);
1380 for (unsigned I = 1; I < N; ++I)
1381 EmbedCost +=
1382 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1383 std::nullopt, TTI::TCK_RecipThroughput);
1384 return EmbedCost;
1385 }
1386
1387 // Costs for loads.
1388 if (N == 1)
1389 return InstructionCost(0);
1390
1391 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1392 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1393 };
1394 auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
1395
1396 // We compare the costs of a vector.reduce.add to sequential add.
1397 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1398 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1399 InstructionCost ReductionCost =
1400 TTI.getArithmeticReductionCost(
1401 AddOpCode, cast<VectorType>(LHS->getType()),
1402 IsIntVec ? std::nullopt : std::optional(FMF)) +
1403 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1404 InstructionCost SequentialAddCost =
1405 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1406 (LShape.NumColumns - 1) +
1407 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1408 (LShape.NumColumns);
1409 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1410 return;
1411
1412 FusedInsts.insert(MatMul);
1413 IRBuilder<> Builder(MatMul);
1414 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1415 this](Value *Op) -> Value * {
1416 // Matmul must be the only user of loads because we don't use LowerLoad
1417 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1418 // instead of single vector load).
1419 if (!CanBeFlattened(Op))
1420 return Op;
1421
1422 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1423 ShapeMap[Op] = ShapeMap[Op].t();
1424 return Op;
1425 }
1426
1427 FusedInsts.insert(cast<Instruction>(Op));
1428 // If vector uses the builtin load, lower to a LoadInst
1429 Value *Arg;
1430 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1431 m_Value(Arg)))) {
1432 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1433 Op->replaceAllUsesWith(NewLoad);
1434 cast<Instruction>(Op)->eraseFromParent();
1435 return NewLoad;
1436 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1437 m_Value(Arg)))) {
1438 ToRemove.push_back(cast<Instruction>(Op));
1439 return Arg;
1440 }
1441
1442 return Op;
1443 };
1444 LHS = FlattenArg(LHS);
1445
1446 // Insert mul/fmul and llvm.vector.reduce.fadd
1447 Value *Mul =
1448 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1449
1450 Value *Result;
1451 if (IsIntVec)
1452 Result = Builder.CreateAddReduce(Mul);
1453 else {
1454 Result = Builder.CreateFAddReduce(
1455 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1456 0.0),
1457 Mul);
1458 cast<Instruction>(Result)->setFastMathFlags(FMF);
1459 }
1460
1461 // pack scalar back into a matrix and then replace matmul inst
1462 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1463 Result, uint64_t(0));
1464 MatMul->replaceAllUsesWith(Result);
1465 FusedInsts.insert(MatMul);
1466 ToRemove.push_back(MatMul);
1467 }
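 // For illustration (simplified; the fast-math flags depend on FMF): a
 // 1xN * Nx1 multiply, i.e. a dot product, is emitted as an element-wise
 // multiply followed by a reduction, roughly
 //
 //   %prod = fmul fast <4 x float> %lhs, %rhs
 //   %dot  = call fast float @llvm.vector.reduce.fadd.v4f32(float 0.0,
 //                                                          <4 x float> %prod)
 //
 // instead of the extract/splat/multiply/add sequence of the generic lowering.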
1468
1469 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1470 /// addition.
1471 ///
1472 /// We can fold a transpose into the operand that is used to extract scalars.
1473 /// This is the first operand with row-major layout and the second with
1474 /// column-major layout. If \p IsScalarMatrixTransposed we assume the
1475 /// appropriate operand is transposed.
1476 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1477 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1478 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1479 const unsigned VF = std::max<unsigned>(
1480 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1481 .getFixedValue() /
1482 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1483 1U);
1484 unsigned R = Result.getNumRows();
1485 unsigned C = Result.getNumColumns();
1486 unsigned M = A.getNumColumns();
1487
1488 bool IsFP = Result.getElementType()->isFloatingPointTy();
1489 assert(A.isColumnMajor() == B.isColumnMajor() &&
1490 Result.isColumnMajor() == A.isColumnMajor() &&
1491 "operands must agree on matrix layout");
1492 unsigned NumComputeOps = 0;
1493
1494 Builder.setFastMathFlags(FMF);
1495
1496 if (A.isColumnMajor()) {
1497 // Multiply columns from the first operand with scalars from the second
1498 // operand. Then move along the K axis and accumulate the columns. With
1499 // this the adds can be vectorized without reassociation.
1500 for (unsigned J = 0; J < C; ++J) {
1501 unsigned BlockSize = VF;
1502 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1503 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1504
1505 for (unsigned I = 0; I < R; I += BlockSize) {
1506 // Gradually lower the vectorization factor to cover the remainder.
1507 while (I + BlockSize > R)
1508 BlockSize /= 2;
1509
1510 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1511 : nullptr;
1512 for (unsigned K = 0; K < M; ++K) {
1513 Value *L = A.extractVector(I, K, BlockSize, Builder);
1514 Value *RH = Builder.CreateExtractElement(
1515 B.getColumn(IsScalarMatrixTransposed ? K : J),
1516 IsScalarMatrixTransposed ? J : K);
1517 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1518 Sum =
1519 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1520 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1521 }
1522 Result.setVector(J,
1523 insertVector(Result.getVector(J), I, Sum, Builder));
1524 }
1525 }
1526 } else {
1527 // Multiply rows from the second operand with scalars from the first
1528 // operand. Then move along the K axis and accumulate the rows. With this
1529 // the adds can be vectorized without reassociation.
1530 for (unsigned I = 0; I < R; ++I) {
1531 unsigned BlockSize = VF;
1532 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1533 for (unsigned J = 0; J < C; J += BlockSize) {
1534 // Gradually lower the vectorization factor to cover the remainder.
1535 while (J + BlockSize > C)
1536 BlockSize /= 2;
1537
1538 Value *Sum = nullptr;
1539 for (unsigned K = 0; K < M; ++K) {
1540 Value *R = B.extractVector(K, J, BlockSize, Builder);
1541 Value *LH = Builder.CreateExtractElement(
1542 A.getVector(IsScalarMatrixTransposed ? K : I),
1543 IsScalarMatrixTransposed ? I : K);
1544 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1545 Sum =
1546 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1547 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1548 }
1549 Result.setVector(I,
1550 insertVector(Result.getVector(I), J, Sum, Builder));
1551 }
1552 }
1553 }
1554 Result.addNumComputeOps(NumComputeOps);
1555 }
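 // For illustration: with column-major operands, column J of the result is
 // built as
 //
 //   Result[:, J] += A[:, K] * splat(B[K, J])   for K = 0 .. M-1
 //
 // processed in row blocks of at most VF elements, so the adds stay
 // vectorized without requiring reassociation.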
1556
1557 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1558 /// copying it to a new location. The new location, or otherwise the original
1559 /// one, is returned.
1560 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1561 CallInst *MatMul) {
1562 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1563 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1564
1565 // If we can statically determine noalias we're good.
1566 if (AA->isNoAlias(LoadLoc, StoreLoc))
1567 return Load->getPointerOperand();
1568
1569 // Create code to check if the memory locations of the Load and Store
1570 // overlap and if they do, copy Load's operand to a new buffer.
1571
1572 // First, create new blocks for the second part of the check and the copy.
1573 BasicBlock *Check0 = MatMul->getParent();
1574 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1575 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1576 // as we adjust Check0 and Check1's branches.
1577 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1578 for (BasicBlock *Succ : successors(Check0))
1579 DTUpdates.push_back({DT->Delete, Check0, Succ});
1580
1581 BasicBlock *Check1 =
1582 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1583 nullptr, "alias_cont");
1584 BasicBlock *Copy =
1585 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1586 nullptr, "copy");
1587 BasicBlock *Fusion =
1588 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1589 nullptr, "no_alias");
1590
1591 // Check if the loaded memory location begins before the end of the store
1592 // location. If the condition holds, they might overlap, otherwise they are
1593 // guaranteed to not overlap.
1594 IRBuilder<> Builder(MatMul);
1595 Check0->getTerminator()->eraseFromParent();
1596 Builder.SetInsertPoint(Check0);
1597 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1598 Value *StoreBegin = Builder.CreatePtrToInt(
1599 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1600 Value *StoreEnd = Builder.CreateAdd(
1601 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1602 "store.end", true, true);
1603 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1604 IntPtrTy, "load.begin");
1605 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1606 Fusion);
1607
1608 // Check if the store begins before the end of the load location. If the
1609 // condition holds, they alias, otherwise they are guaranteed to not
1610 // overlap.
1611 Check1->getTerminator()->eraseFromParent();
1612 Builder.SetInsertPoint(Check1, Check1->begin());
1613 Value *LoadEnd = Builder.CreateAdd(
1614 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1615 "load.end", true, true);
1616 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1617 Fusion);
1618
1619 // Copy load operand to new alloca.
1620 Builder.SetInsertPoint(Copy, Copy->begin());
1621 auto *VT = cast<FixedVectorType>(Load->getType());
1622 // Use an array type for the alloca, to avoid potentially huge alignment
1623 // requirements for large vector types.
1624 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1625 AllocaInst *Alloca =
1626 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1627
1628 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1629 Load->getAlign(), LoadLoc.Size.getValue());
1630 Builder.SetInsertPoint(Fusion, Fusion->begin());
1631 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1632 PHI->addIncoming(Load->getPointerOperand(), Check0);
1633 PHI->addIncoming(Load->getPointerOperand(), Check1);
1634 PHI->addIncoming(Alloca, Copy);
1635
1636 // Adjust DT.
1637 DTUpdates.push_back({DT->Insert, Check0, Check1});
1638 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1639 DTUpdates.push_back({DT->Insert, Check1, Copy});
1640 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1641 DT->applyUpdates(DTUpdates);
1642 return PHI;
1643 }
1644
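// --- Editor's illustration (not part of LowerMatrixIntrinsics.cpp) ---
// The branch structure emitted by getNonAliasingPointer above boils down to a
// standard interval-overlap test on the integer addresses of the two memory
// locations. A plain C++ rendering of the condition (names hypothetical):
#include <cstdint>

static bool mayOverlap(std::uint64_t LoadBegin, std::uint64_t LoadSize,
                       std::uint64_t StoreBegin, std::uint64_t StoreSize) {
  std::uint64_t LoadEnd = LoadBegin + LoadSize;    // "load.end" above
  std::uint64_t StoreEnd = StoreBegin + StoreSize; // "store.end" above
  // Check0 tests LoadBegin < StoreEnd; Check1 tests StoreBegin < LoadEnd.
  return LoadBegin < StoreEnd && StoreBegin < LoadEnd;
}
// Only when both checks pass does control reach the "copy" block, which
// redirects the load to a fresh alloca; the phi in "no_alias" then selects
// between the original pointer and the copy.
// --- End illustration ---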
1645 bool isFusionProfitable(CallInst *MatMul) {
1646 if (ForceFusion)
1647 return true;
1648
1649 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1650 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1651
1652 const unsigned R = LShape.NumRows;
1653 const unsigned C = RShape.NumColumns;
1654 const unsigned M = LShape.NumColumns;
1655 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1656
1657 const unsigned VF = std::max<unsigned>(
1658 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1659 .getFixedValue() /
1660 EltType->getPrimitiveSizeInBits().getFixedValue(),
1661 1U);
1662
1663 // Cost model for tiling
1664 //
1665 // For tiling to be beneficial, we need reuse either along the R or
1666 // the C axis. We vectorize along the R axis so that means at least
1667 // 3 elements.
1668 // TODO: Also consider cost of copying if operands alias.
1669 if (R <= VF && C == 1)
1670 return false;
1671 // Then we need enough elements to exceed the number of vector
1672 // registers we have. Note that this is an oversimplification since
1673 // fusing also takes some extra loads which may exceed the number of
1674 // reloads necessary.
1675 unsigned Op0Regs = (R + VF - 1) / VF * M;
1676 unsigned Op1Regs = (M + VF - 1) / VF * C;
1677 return Op0Regs + Op1Regs >
1678 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1679 }
1680
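// --- Editor's illustration (not part of LowerMatrixIntrinsics.cpp) ---
// The heuristic above estimates how many vector registers each operand would
// occupy if kept live at once: ceil(Rows / VF) registers per column times the
// number of columns. For example, with R = M = C = 16 and VF = 4, each operand
// needs (16 / 4) * 16 = 64 registers, 128 in total, so fusion is considered
// profitable. A sketch of that count (name hypothetical):
static unsigned operandRegisterCount(unsigned Rows, unsigned Cols, unsigned VF) {
  return (Rows + VF - 1) / VF * Cols; // ceil(Rows / VF) registers per column
}
// --- End illustration ---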
1681 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1682 MatrixTy Res;
1683 auto *ColumType = FixedVectorType::get(EltType, R);
1684 for (unsigned I = 0; I < C; ++I)
1685 Res.addVector(ConstantAggregateZero::get(ColumType));
1686 return Res;
1687 }
1688
1689 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1690 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1691 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1692
1693 // Create the main tiling loop nest.
1694 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1695 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1696 Instruction *InsertI = cast<Instruction>(MatMul);
1697 BasicBlock *Start = InsertI->getParent();
1698 BasicBlock *End =
1699 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1700 IRBuilder<> Builder(MatMul);
1701 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1702
1703 Type *TileVecTy =
1704 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1705 MatrixTy TileResult;
1706 // Insert in the inner loop header.
1707 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1708 // Create PHI nodes for the result columns to accumulate across iterations.
1709 SmallVector<PHINode *, 4> ColumnPhis;
1710 for (unsigned I = 0; I < TileSize; I++) {
1711 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1712 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1713 TI.RowLoop.Header->getSingleSuccessor());
1714 TileResult.addVector(Phi);
1715 ColumnPhis.push_back(Phi);
1716 }
1717
1718 // Insert in the inner loop body, which computes
1719 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1720 Builder.SetInsertPoint(InnerBody->getTerminator());
1721 // Load tiles of the operands.
1722 MatrixTy A =
1723 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1724 {TileSize, TileSize}, EltType, Builder);
1725 MatrixTy B =
1726 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1727 {TileSize, TileSize}, EltType, Builder);
1728 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1729 getFastMathFlags(MatMul));
1730 // Store result after the inner loop is done.
1731 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1732 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1733 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1734 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1735
1736 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1737 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1738
1739 // Force unrolling of a few iterations of the inner loop, to make sure there
1740 // is enough work per iteration.
1741 // FIXME: The unroller should make this decision directly instead, but
1742 // currently the cost-model is not up to the task.
1743 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1744 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1745 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1746 }
1747
1748 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1749 StoreInst *Store,
1750 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1751 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1752 "Tiling only supported for column-major matrixes at the moment!");
1753 if (!isFusionProfitable(MatMul))
1754 return;
1755
1756 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1757 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1758
1759 const unsigned R = LShape.NumRows;
1760 const unsigned C = RShape.NumColumns;
1761 const unsigned M = LShape.NumColumns;
1762 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1763
1764 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1765 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1766 Value *CPtr = Store->getPointerOperand();
1767
1768 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1769 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1770 else {
1771 IRBuilder<> Builder(Store);
1772 for (unsigned J = 0; J < C; J += TileSize)
1773 for (unsigned I = 0; I < R; I += TileSize) {
1774 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1775 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1776 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1777
1778 for (unsigned K = 0; K < M; K += TileSize) {
1779 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1780 MatrixTy A =
1781 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1782 LShape, Builder.getInt64(I), Builder.getInt64(K),
1783 {TileR, TileM}, EltType, Builder);
1784 MatrixTy B =
1785 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1786 RShape, Builder.getInt64(K), Builder.getInt64(J),
1787 {TileM, TileC}, EltType, Builder);
1788 emitMatrixMultiply(Res, A, B, Builder, true, false,
1789 getFastMathFlags(MatMul));
1790 }
1791 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1792 Builder.getInt64(I), Builder.getInt64(J), EltType,
1793 Builder);
1794 }
1795 }
1796
1797 // Mark eliminated instructions as fused and remove them.
1798 FusedInsts.insert(Store);
1799 FusedInsts.insert(MatMul);
1800 Store->eraseFromParent();
1801 MatMul->eraseFromParent();
1802 if (LoadOp0->hasNUses(0)) {
1803 FusedInsts.insert(LoadOp0);
1804 LoadOp0->eraseFromParent();
1805 }
1806 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1807 FusedInsts.insert(LoadOp1);
1808 LoadOp1->eraseFromParent();
1809 }
1810 }
1811
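// --- Editor's illustration (not part of LowerMatrixIntrinsics.cpp) ---
// The non-loop path of emitSIMDTiling above unrolls the classic three-level
// tiling scheme: visit the result in TileR x TileC tiles and, for each tile,
// accumulate partial products over the shared dimension in TileM-wide slices.
// A scalar sketch, using row-major indexing for readability and assuming the
// output buffer starts zeroed (names hypothetical):
#include <algorithm>
#include <cstddef>

static void tiledMatMulSketch(const double *A, const double *B, double *Out,
                              std::size_t R, std::size_t M, std::size_t C,
                              std::size_t Tile) {
  for (std::size_t J = 0; J < C; J += Tile)
    for (std::size_t I = 0; I < R; I += Tile) {
      std::size_t TileR = std::min(R - I, Tile);
      std::size_t TileC = std::min(C - J, Tile);
      for (std::size_t K = 0; K < M; K += Tile) {
        std::size_t TileM = std::min(M - K, Tile);
        for (std::size_t i = 0; i < TileR; ++i)
          for (std::size_t j = 0; j < TileC; ++j)
            for (std::size_t k = 0; k < TileM; ++k)
              Out[(I + i) * C + (J + j)] +=
                  A[(I + i) * M + (K + k)] * B[(K + k) * C + (J + j)];
      }
    }
}
// --- End illustration ---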
1812 /// Try to lower matrix multiply chains by fusing operations.
1813 ///
1814 /// Call finalizeLowering on lowered instructions. Instructions that are
1815 /// completely eliminated by fusion are added to \p FusedInsts.
1816 void LowerMatrixMultiplyFused(CallInst *MatMul,
1817 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1818 if (!FuseMatrix || !DT)
1819 return;
1820
1821 assert(AA && LI && "Analyses should be available");
1822
1823 Value *A = MatMul->getArgOperand(0);
1824 Value *B = MatMul->getArgOperand(1);
1825
1826 // We can fold the transpose into the operand that is used to fetch scalars.
1827 Value *T;
1828 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1829 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1830 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1831 IRBuilder<> Builder(MatMul);
1832 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1833 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1834 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1835 const unsigned R = LShape.NumRows;
1836 const unsigned M = LShape.NumColumns;
1837 const unsigned C = RShape.NumColumns;
1838
1839 MatrixTy MA;
1840 MatrixTy MB;
1841
1842 Value *Transpose;
1843 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1844 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1845 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1846 Transpose = B;
1847 } else {
1848 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1849 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1850 Transpose = A;
1851 }
1852
1853 // Initialize the output
1854 MatrixTy Result(R, C, EltType);
1855
1856 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1857 getFastMathFlags(MatMul));
1858
1859 FusedInsts.insert(MatMul);
1860 if (Transpose->hasOneUse()) {
1861 FusedInsts.insert(cast<Instruction>(Transpose));
1862 ToRemove.push_back(cast<Instruction>(Transpose));
1863 // TODO: add a fake entry for the folded instruction so that this is
1864 // included in the expression in the remark.
1865 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1866 }
1867 finalizeLowering(MatMul, Result, Builder);
1868 return;
1869 }
1870
1871 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1872 return;
1873
1874 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1875 // since the single store user will be lowered as part of this.
1876 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1877 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1878 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1879 if (LoadOp0 && LoadOp1 && Store) {
1880 // The store address must dominate the MatMul instruction, otherwise
1881 // we create invalid IR.
1882 SetVector<Value *> WorkList;
1883 WorkList.insert(Store->getOperand(1));
1884 SmallVector<Instruction *> ToHoist;
1885 for (unsigned I = 0; I != WorkList.size(); ++I) {
1886 Value *Current = WorkList[I];
1887 auto *CurrI = dyn_cast<Instruction>(Current);
1888 if (!CurrI)
1889 continue;
1890 if (isa<PHINode>(CurrI))
1891 return;
1892 if (DT->dominates(CurrI, MatMul))
1893 continue;
1894 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1895 return;
1896 ToHoist.push_back(CurrI);
1897 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1898 }
1899
1900 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1901 return DT->dominates(A, B);
1902 });
1903 for (Instruction *I : ToHoist)
1904 I->moveBefore(MatMul);
1905
1906 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1907 return;
1908 }
1909 }
1910
1911 /// Lowers llvm.matrix.multiply.
1912 void LowerMultiply(CallInst *MatMul) {
1913 IRBuilder<> Builder(MatMul);
1914 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1915 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1916 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1917
1918 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
1919 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
1920 assert(Lhs.getElementType() == Rhs.getElementType() &&
1921 "Matrix multiply argument element types do not match.");
1922
1923 const unsigned R = LShape.NumRows;
1924 const unsigned C = RShape.NumColumns;
1925 assert(LShape.NumColumns == RShape.NumRows);
1926
1927 // Initialize the output
1928 MatrixTy Result(R, C, EltType);
1929 assert(Lhs.getElementType() == Result.getElementType() &&
1930 "Matrix multiply result element type does not match arguments.");
1931
1932 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
1933 getFastMathFlags(MatMul));
1934 finalizeLowering(MatMul, Result, Builder);
1935 }
1936
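// --- Editor's illustration (not part of LowerMatrixIntrinsics.cpp) ---
// A hedged sketch of how a frontend typically emits the call that LowerMultiply
// consumes, via llvm::MatrixBuilder. It assumes llvm/IR/MatrixBuilder.h and the
// CreateMatrixMultiply overload taking (LHS, RHS, LHSRows, LHSColumns,
// RHSColumns, Name); adjust to the exact API of the LLVM version in use.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MatrixBuilder.h"

static llvm::Value *emitMatMulSketch(llvm::IRBuilder<> &Builder,
                                     llvm::Value *LHS, llvm::Value *RHS,
                                     unsigned R, unsigned M, unsigned C) {
  llvm::MatrixBuilder MB(Builder);
  // Emits llvm.matrix.multiply.* with shape arguments (R, M, C); the pass
  // above later expands it into extract/splat/multiply-add sequences.
  return MB.CreateMatrixMultiply(LHS, RHS, R, M, C, "matmul");
}
// --- End illustration ---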
1937 /// Lowers llvm.matrix.transpose.
1938 void LowerTranspose(CallInst *Inst) {
1939 MatrixTy Result;
1940 IRBuilder<> Builder(Inst);
1941 Value *InputVal = Inst->getArgOperand(0);
1942 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
1943 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
1944 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
1945
1946 const unsigned NewNumVecs =
1947 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1948 const unsigned NewNumElts =
1949 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
1950
1951 for (unsigned I = 0; I < NewNumVecs; ++I) {
1952 // Build a single result vector. First initialize it.
1953 Value *ResultVector = PoisonValue::get(
1954 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
1955 // Go through the old elements and insert it into the resulting vector.
1956 for (auto J : enumerate(InputMatrix.vectors())) {
1957 Value *Elt = Builder.CreateExtractElement(J.value(), I);
1958 // Row and column indices are transposed.
1959 ResultVector =
1960 Builder.CreateInsertElement(ResultVector, Elt, J.index());
1961 }
1962 Result.addVector(ResultVector);
1963 }
1964
1965 // TODO: Improve estimate of operations needed for transposes. Currently we
1966 // just count the insertelement/extractelement instructions, but do not
1967 // account for later simplifications/combines.
1968 finalizeLowering(
1969 Inst,
1970 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
1971 .addNumExposedTransposes(1),
1972 Builder);
1973 }
1974
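// --- Editor's illustration (not part of LowerMatrixIntrinsics.cpp) ---
// In the column-major layout handled above, transposing only regroups
// elements: element I of input column J becomes element J of result column I,
// which is exactly the extract/insert pair in the loop. A scalar sketch over
// flat column-major storage (names hypothetical):
#include <cstddef>
#include <vector>

static std::vector<double> transposeColMajorSketch(const std::vector<double> &In,
                                                   std::size_t Rows,
                                                   std::size_t Cols) {
  std::vector<double> Out(Rows * Cols); // result is Cols x Rows, column-major
  for (std::size_t I = 0; I < Rows; ++I)   // result column I ...
    for (std::size_t J = 0; J < Cols; ++J) // ... takes element J from column J
      Out[I * Cols + J] = In[J * Rows + I];
  return Out;
}
// --- End illustration ---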
1975 /// Lower load instructions, if shape information is available.
1976 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
1977 auto I = ShapeMap.find(Inst);
1978 if (I == ShapeMap.end())
1979 return false;
1980
1981 LowerLoad(Inst, Ptr, Inst->getAlign(),
1982 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1983 I->second);
1984 return true;
1985 }
1986
1987 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
1988 IRBuilder<> &Builder) {
1989 auto I = ShapeMap.find(StoredVal);
1990 if (I == ShapeMap.end())
1991 return false;
1992
1993 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
1994 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
1995 I->second);
1996 return true;
1997 }
1998
1999 /// Lower binary operators, if shape information is available.
2000 bool VisitBinaryOperator(BinaryOperator *Inst) {
2001 auto I = ShapeMap.find(Inst);
2002 if (I == ShapeMap.end())
2003 return false;
2004
2005 Value *Lhs = Inst->getOperand(0);
2006 Value *Rhs = Inst->getOperand(1);
2007
2008 IRBuilder<> Builder(Inst);
2009 ShapeInfo &Shape = I->second;
2010
2011 MatrixTy Result;
2012 MatrixTy A = getMatrix(Lhs, Shape, Builder);
2013 MatrixTy B = getMatrix(Rhs, Shape, Builder);
2014 assert(A.isColumnMajor() == B.isColumnMajor() &&
2015 Result.isColumnMajor() == A.isColumnMajor() &&
2016 "operands must agree on matrix layout");
2017
2018 Builder.setFastMathFlags(getFastMathFlags(Inst));
2019
2020 // Helper to perform binary op on vectors.
2021 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2022 switch (Inst->getOpcode()) {
2023 case Instruction::Add:
2024 return Builder.CreateAdd(LHS, RHS);
2025 case Instruction::Mul:
2026 return Builder.CreateMul(LHS, RHS);
2027 case Instruction::Sub:
2028 return Builder.CreateSub(LHS, RHS);
2029 case Instruction::FAdd:
2030 return Builder.CreateFAdd(LHS, RHS);
2031 case Instruction::FMul:
2032 return Builder.CreateFMul(LHS, RHS);
2033 case Instruction::FSub:
2034 return Builder.CreateFSub(LHS, RHS);
2035 default:
2036 llvm_unreachable("Unsupported binary operator for matrix");
2037 }
2038 };
2039
2040 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2041 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2042
2043 finalizeLowering(Inst,
2044 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2045 Result.getNumVectors()),
2046 Builder);
2047 return true;
2048 }
2049
2050 /// Lower unary operators, if shape information is available.
2051 bool VisitUnaryOperator(UnaryOperator *Inst) {
2052 auto I = ShapeMap.find(Inst);
2053 if (I == ShapeMap.end())
2054 return false;
2055
2056 Value *Op = Inst->getOperand(0);
2057
2058 IRBuilder<> Builder(Inst);
2059 ShapeInfo &Shape = I->second;
2060
2061 MatrixTy Result;
2062 MatrixTy M = getMatrix(Op, Shape, Builder);
2063
2064 Builder.setFastMathFlags(getFastMathFlags(Inst));
2065
2066 // Helper to perform unary op on vectors.
2067 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2068 switch (Inst->getOpcode()) {
2069 case Instruction::FNeg:
2070 return Builder.CreateFNeg(Op);
2071 default:
2072 llvm_unreachable("Unsupported unary operator for matrix");
2073 }
2074 };
2075
2076 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2077 Result.addVector(BuildVectorOp(M.getVector(I)));
2078
2079 finalizeLowering(Inst,
2080 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2081 Result.getNumVectors()),
2082 Builder);
2083 return true;
2084 }
2085
2086 /// Helper to linearize a matrix expression tree into a string. Currently
2087 /// matrix expressions are linearized by starting at an expression leaf and
2088 /// linearizing bottom up.
2089 struct ExprLinearizer {
2090 unsigned LengthToBreak = 100;
2091 std::string Str;
2092 raw_string_ostream Stream;
2093 unsigned LineLength = 0;
2094 const DataLayout &DL;
2095
2096 /// Mapping from instructions to matrixes. It is used to identify
2097 /// matrix instructions.
2098 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2099
2100 /// Mapping from values to the leaves of all expressions that the value is
2101 /// part of.
2102 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2103
2104 /// Set of matrix expressions in the scope of a given DISubprogram.
2105 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2106
2107 /// Leaf node of the expression to linearize.
2108 Value *Leaf;
2109
2110 /// Used to keep track of sub-expressions that get reused while linearizing
2111 /// the expression. Re-used sub-expressions are marked as (reused).
2112 SmallPtrSet<Value *, 8> ReusedExprs;
2113
2114 ExprLinearizer(const DataLayout &DL,
2115 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2116 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2117 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2118 Value *Leaf)
2119 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2120 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2121
2122 void indent(unsigned N) {
2123 LineLength += N;
2124 for (unsigned i = 0; i < N; i++)
2125 Stream << " ";
2126 }
2127
2128 void lineBreak() {
2129 Stream << "\n";
2130 LineLength = 0;
2131 }
2132
2133 void maybeIndent(unsigned Indent) {
2134 if (LineLength >= LengthToBreak)
2135 lineBreak();
2136
2137 if (LineLength == 0)
2138 indent(Indent);
2139 }
2140
2141 void write(StringRef S) {
2142 LineLength += S.size();
2143 Stream << S;
2144 }
2145
2146 Value *getUnderlyingObjectThroughLoads(Value *V) {
2147 if (Value *Ptr = getPointerOperand(V))
2148 return getUnderlyingObjectThroughLoads(Ptr);
2149 else if (V->getType()->isPointerTy())
2150 return getUnderlyingObject(V);
2151 return V;
2152 }
2153
2154 /// Returns true if \p V is a matrix value in the given subprogram.
2155 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2156
2157 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2158 /// \p SS.
2159 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2160 auto M = Inst2Matrix.find(V);
2161 if (M == Inst2Matrix.end())
2162 SS << "unknown";
2163 else {
2164 SS << M->second.getNumRows();
2165 SS << "x";
2166 SS << M->second.getNumColumns();
2167 }
2168 }
2169
2170 /// Write the called function name. Handles calls to llvm.matrix.*
2171 /// specially: we write the name, followed by the dimensions of the input
2172 /// matrixes, followed by the scalar type name.
2173 void writeFnName(CallInst *CI) {
2174 if (!CI->getCalledFunction())
2175 write("<no called fn>");
2176 else {
2177 StringRef Name = CI->getCalledFunction()->getName();
2178 if (!Name.starts_with("llvm.matrix")) {
2179 write(Name);
2180 return;
2181 }
2182 auto *II = cast<IntrinsicInst>(CI);
2183 write(Intrinsic::getBaseName(II->getIntrinsicID())
2184 .drop_front(StringRef("llvm.matrix.").size()));
2185 write(".");
2186 std::string Tmp;
2187 raw_string_ostream SS(Tmp);
2188
2189 switch (II->getIntrinsicID()) {
2190 case Intrinsic::matrix_multiply:
2191 prettyPrintMatrixType(II->getOperand(0), SS);
2192 SS << ".";
2193 prettyPrintMatrixType(II->getOperand(1), SS);
2194 SS << "." << *II->getType()->getScalarType();
2195 break;
2196 case Intrinsic::matrix_transpose:
2197 prettyPrintMatrixType(II->getOperand(0), SS);
2198 SS << "." << *II->getType()->getScalarType();
2199 break;
2200 case Intrinsic::matrix_column_major_load:
2201 prettyPrintMatrixType(II, SS);
2202 SS << "." << *II->getType()->getScalarType();
2203 break;
2204 case Intrinsic::matrix_column_major_store:
2205 prettyPrintMatrixType(II->getOperand(0), SS);
2206 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2207 break;
2208 default:
2209 llvm_unreachable("Unhandled case");
2210 }
2211 SS.flush();
2212 write(Tmp);
2213 }
2214 }
2215
2216 unsigned getNumShapeArgs(CallInst *CI) const {
2217 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2218 switch (II->getIntrinsicID()) {
2219 case Intrinsic::matrix_multiply:
2220 return 3;
2221 case Intrinsic::matrix_transpose:
2222 return 2;
2223 case Intrinsic::matrix_column_major_load:
2224 case Intrinsic::matrix_column_major_store:
2225 return 3;
2226 default:
2227 return 0;
2228 }
2229 }
2230 return 0;
2231 }
2232
2233 /// Special printing for values: for pointers, we print whether they refer to
2234 /// a (function) external address or a stack address; for other values we
2235 /// either print the constant or "scalar"/"matrix".
2236 void write(Value *V) {
2237 V = getUnderlyingObjectThroughLoads(V);
2238 if (V->getType()->isPointerTy()) {
2239 if (isa<AllocaInst>(V)) {
2240 Stream << "stack addr";
2241 LineLength += StringRef("stack addr").size();
2242 } else {
2243 Stream << "addr";
2244 LineLength += StringRef("addr").size();
2245 }
2246 if (!V->getName().empty()) {
2247 Stream << " %" << V->getName() << "";
2248 LineLength += V->getName().size() + 2;
2249 }
2250 return;
2251 }
2252
2253 std::string Tmp;
2254 raw_string_ostream TmpStream(Tmp);
2255
2256 if (auto *CI = dyn_cast<ConstantInt>(V))
2257 TmpStream << CI->getValue();
2258 else if (isa<Constant>(V))
2259 TmpStream << "constant";
2260 else {
2261 if (isMatrix(V))
2262 TmpStream << "matrix";
2263 else
2264 TmpStream << "scalar";
2265 }
2266 TmpStream.flush();
2267 Tmp = std::string(StringRef(Tmp).trim());
2268 LineLength += Tmp.size();
2269 Stream << Tmp;
2270 }
2271
2272 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2273 /// Expressions that are re-used multiple times are prefixed with (reused)
2274 /// at the re-used root instruction.
2275 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2276 bool ParentShared) {
2277 auto *I = cast<Instruction>(Expr);
2278 maybeIndent(Indent);
2279 SmallVector<Value *, 8> Ops;
2280
2281 // Is Expr shared with other expression leaves?
2282 bool ExprShared = false;
2283
2284 // Deal with shared subtrees. Mark them as shared, if required.
2285 if (!ParentShared) {
2286 auto SI = Shared.find(Expr);
2287 assert(SI != Shared.end() && SI->second.count(Leaf));
2288
2289 for (Value *S : SI->second) {
2290 if (S == Leaf)
2291 continue;
2292 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2293 write("shared with remark at line " + std::to_string(DL.getLine()) +
2294 " column " + std::to_string(DL.getCol()) + " (");
2295 }
2296 ExprShared = SI->second.size() > 1;
2297 }
2298
2299 bool Reused = !ReusedExprs.insert(Expr).second;
2300 if (Reused && !ParentReused)
2301 write("(reused) ");
2302
2303 if (auto *CI = dyn_cast<CallInst>(I)) {
2304 writeFnName(CI);
2305
2306 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2307 } else if (isa<BitCastInst>(Expr)) {
2308 // Special case bitcasts, which are used to materialize matrixes from
2309 // non-matrix ops.
2310 write("matrix");
2311 return;
2312 } else {
2313 Ops.append(I->value_op_begin(), I->value_op_end());
2314 write(std::string(I->getOpcodeName()));
2315 }
2316
2317 write(std::string("("));
2318
2319 unsigned NumOpsToBreak = 1;
2320 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2321 NumOpsToBreak = 2;
2322
2323 for (Value *Op : Ops) {
2324 if (Ops.size() > NumOpsToBreak)
2325 lineBreak();
2326
2327 maybeIndent(Indent + 1);
2328 if (isMatrix(Op))
2329 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2330 else
2331 write(Op);
2332 if (Op != Ops.back())
2333 write(", ");
2334 }
2335
2336 write(")");
2337 }
2338
2339 const std::string &getResult() {
2340 Stream.flush();
2341 return Str;
2342 }
2343 };
2344
2345 /// Generate remarks for matrix operations in a function. To generate remarks
2346 /// for matrix expressions, the following approach is used:
2347 /// 1. Use the inlined-at debug information to group matrix operations to the
2348 /// DISubprograms they are contained in.
2349 /// 2. Collect leaves of matrix expressions (done in
2350 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2351 /// mapping. Leaves are lowered matrix instructions without other matrix
2352 /// users (like stores) in the current subprogram.
2353 /// 3. For each leaf, create a remark containing a linearized version of the
2354 /// matrix expression. The expression is linearized by a recursive
2355 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2356 /// that multiple leaves can share sub-expressions. Shared subexpressions
2357 /// are explicitly marked as shared().
2358 struct RemarkGenerator {
2359 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2360 OptimizationRemarkEmitter &ORE;
2361 Function &Func;
2362 const DataLayout &DL;
2363
2364 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2365 OptimizationRemarkEmitter &ORE, Function &Func)
2366 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2367 DL(Func.getParent()->getDataLayout()) {}
2368
2369 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2370 /// instructions in Inst2Matrix returning void or without any users in
2371 /// \p ExprsInSubprogram. Currently that should only include stores.
2372 SmallVector<Value *, 4>
2373 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2374 SmallVector<Value *, 4> Leaves;
2375 for (auto *Expr : ExprsInSubprogram)
2376 if (Expr->getType()->isVoidTy() ||
2377 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2378 return ExprsInSubprogram.count(U);
2379 }))
2380 Leaves.push_back(Expr);
2381 return Leaves;
2382 }
2383
2384 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2385 /// to all visited expressions in \p Shared. Limit the matrix operations to
2386 /// the ones in \p ExprsInSubprogram.
2387 void collectSharedInfo(Value *Leaf, Value *V,
2388 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2389 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2390
2391 if (!ExprsInSubprogram.count(V))
2392 return;
2393
2394 auto I = Shared.insert({V, {}});
2395 I.first->second.insert(Leaf);
2396
2397 for (Value *Op : cast<Instruction>(V)->operand_values())
2398 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2399 }
2400
2401 /// Calculate the number of exclusive and shared op counts for expression
2402 /// starting at \p V. Expressions used multiple times are counted once.
2403 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2404 std::pair<OpInfoTy, OpInfoTy>
2405 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2406 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2407 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2408 if (!ExprsInSubprogram.count(Root))
2409 return {};
2410
2411 // Already counted this expression. Stop.
2412 if (!ReusedExprs.insert(Root).second)
2413 return {};
2414
2415 OpInfoTy SharedCount;
2416 OpInfoTy Count;
2417
2418 auto I = Shared.find(Root);
2419 auto CM = Inst2Matrix.find(Root);
2420 if (I->second.size() == 1)
2421 Count = CM->second.getOpInfo();
2422 else
2423 SharedCount = CM->second.getOpInfo();
2424
2425 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2426 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2427 Count += C.first;
2428 SharedCount += C.second;
2429 }
2430 return {Count, SharedCount};
2431 }
2432
2433 void emitRemarks() {
2434 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2435 return;
2436
2437 // Map matrix operations to their containing subprograms, by traversing
2438 // the inlinedAt chain. If the function does not have a DISubprogram, we
2439 // only map them to the containing function.
2440 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2441 for (const auto &KV : Inst2Matrix) {
2442 if (Func.getSubprogram()) {
2443 auto *I = cast<Instruction>(KV.first);
2444 DILocation *Context = I->getDebugLoc();
2445 while (Context) {
2446 auto I =
2447 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2448 I.first->second.push_back(KV.first);
2449 Context = DebugLoc(Context).getInlinedAt();
2450 }
2451 } else {
2452 auto I = Subprog2Exprs.insert({nullptr, {}});
2453 I.first->second.push_back(KV.first);
2454 }
2455 }
2456 for (auto &KV : Subprog2Exprs) {
2457 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2458 KV.second.end());
2459 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2460
2461 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2462 for (Value *Leaf : Leaves)
2463 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2464
2465 // Generate remarks for each leaf.
2466 for (auto *L : Leaves) {
2467
2468 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2469 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2470 while (Context) {
2471 if (getSubprogram(Context->getScope()) == KV.first) {
2472 Loc = Context;
2473 break;
2474 }
2475 Context = DebugLoc(Context).getInlinedAt();
2476 }
2477
2478 SmallPtrSet<Value *, 8> ReusedExprs;
2479 OpInfoTy Counts, SharedCounts;
2480 std::tie(Counts, SharedCounts) =
2481 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2482
2483 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2484 cast<Instruction>(L)->getParent());
2485
2486 Rem << "Lowered with ";
2487 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2488 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2489 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2490 << " compute ops, "
2491 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2492 << " exposed transposes";
2493
2494 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2495 SharedCounts.NumComputeOps > 0) {
2496 Rem << ",\nadditionally "
2497 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2498 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2499 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2500 << " compute ops"
2501 << " are shared with other expressions";
2502 }
2503
2504 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2505 ORE.emit(Rem);
2506 }
2507 }
2508 }
2509
2510 std::string
2511 linearize(Value *L,
2512 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2513 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2514 const DataLayout &DL) {
2515 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2516 Lin.linearizeExpr(L, 0, false, false);
2517 return Lin.getResult();
2518 }
2519 };
2520};
2521} // namespace
2522
2523 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2524 FunctionAnalysisManager &AM) {
2525 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2526 OptimizationRemarkEmitter *ORE = nullptr;
2527 AAResults *AA = nullptr;
2528 DominatorTree *DT = nullptr;
2529 LoopInfo *LI = nullptr;
2530
2531 if (!Minimal) {
2532 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2533 AA = &AM.getResult<AAManager>(F);
2534 DT = &AM.getResult<DominatorTreeAnalysis>(F);
2535 LI = &AM.getResult<LoopAnalysis>(F);
2536 }
2537
2538 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2539 if (LMT.Visit()) {
2540 PreservedAnalyses PA;
2541 if (!Minimal) {
2542 PA.preserve<LoopAnalysis>();
2543 PA.preserve<DominatorTreeAnalysis>();
2544 }
2545 return PA;
2546 }
2547 return PreservedAnalyses::all();
2548}
2549
2550 void LowerMatrixIntrinsicsPass::printPipeline(
2551 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2552 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2553 OS, MapClassName2PassName);
2554 OS << '<';
2555 if (Minimal)
2556 OS << "minimal";
2557 OS << '>';
2558}
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Rewrite undef for PHI
ReachingDefAnalysis InstSet & ToRemove
static const Function * getParent(const Value *V)
BitTracker BT
Definition: BitTracker.cpp:73
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:680
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
std::string Name
bool End
Definition: ELF_riscv.cpp:478
Hexagon Common GEP
hexagon Hexagon specific predictive commoning for HVX vectors
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iv users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:526
Live Register Matrix
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, BasicBlock &BB)
Erase V from BB and move \II forward to avoid invalidating iterators.
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
#define DEBUG_TYPE
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define T1
LLVMContext & Context
PowerPC Reduce CR logical Operation
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2499
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2521
raw_pwrite_stream & OS
This file defines the SmallSet class.
This file defines the SmallVector class.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
static const int BlockSize
Definition: TarWriter.cpp:33
This pass exposes codegen information to IR-level passes.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Value * RHS
Value * LHS
BinaryOperator * Mul
A manager for alias analyses.
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
an instruction to allocate memory on the stack
Definition: Instructions.h:58
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:125
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:649
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:803
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:429
reverse_iterator rbegin()
Definition: BasicBlock.h:445
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:167
reverse_iterator rend()
Definition: BasicBlock.h:447
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:220
BinaryOps getOpcode() const
Definition: InstrTypes.h:391
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1449
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1369
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1799
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1394
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1375
This class represents a function call, abstracting a target machine's calling convention.
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1579
static Constant * get(Type *Ty, double V)
This returns a ConstantFP, or a vector containing a splat of a ConstantFP, for the specified value in...
Definition: Constants.cpp:927
This is the shared class of boolean and integer constants.
Definition: Constants.h:78
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:888
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Debug location.
Base class for scope-like contexts.
Subprogram description.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
A debug info location.
Definition: DebugLoc.h:33
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:39
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:278
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
static constexpr UpdateKind Delete
static constexpr UpdateKind Insert
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:165
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:123
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:20
void setAllowContract(bool B=true)
Definition: FMF.h:91
bool allowReassoc() const
Flag queries.
Definition: FMF.h:65
bool allowContract() const
Definition: FMF.h:70
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:539
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:699
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:230
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:235
CallInst * CreateFAddReduce(Value *Acc, Value *Src)
Create a sequential vector fadd reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:417
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:2227
Value * CreateFSub(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1551
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2442
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1769
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2430
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition: IRBuilder.h:1803
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1524
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
Definition: IRBuilder.cpp:1212
CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:433
IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)
Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...
Definition: IRBuilder.h:566
void setFastMathFlags(FastMathFlags NewFMF)
Set the fast-math flags to be used with generated fp-math operators.
Definition: IRBuilder.h:297
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition: IRBuilder.h:477
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition: IRBuilder.h:2367
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1335
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
Definition: IRBuilder.h:483
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition: IRBuilder.h:1111
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1786
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition: IRBuilder.h:2464
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1318
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2087
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:180
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Definition: IRBuilder.h:1822
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args=std::nullopt, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:2382
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", bool IsInBounds=false)
Definition: IRBuilder.h:1862
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1578
Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:1726
CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *TBAAStructTag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)
Create and insert a memcpy between the specified pointers.
Definition: IRBuilder.h:650
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1352
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2636
void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...
const BasicBlock * getParent() const
Definition: Instruction.h:134
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:93
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:47
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:54
An instruction for reading from memory.
Definition: Instructions.h:177
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:214
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:220
TypeSize getValue() const
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:36
iterator end()
Definition: MapVector.h:71
iterator find(const KeyT &Key)
Definition: MapVector.h:167
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:141
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
Representation for a specific memory location.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
const Value * Ptr
The address of the start of the location.
The optimization diagnostic interface.
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1743
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:172
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:178
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:193
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:345
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:384
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:366
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:390
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:451
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
bool empty() const
Definition: SmallSet.h:159
bool erase(const T &V)
Definition: SmallSet.h:207
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:577
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:687
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
An instruction for storing to memory.
Definition: Instructions.h:301
Align getAlign() const
Definition: Instructions.h:345
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:337
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:613
constexpr size_t size() const
size - Get the string size.
Definition: StringRef.h:137
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
TypeSize getRegisterBitWidth(RegisterKind K) const
InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
@ TCK_RecipThroughput
Reciprocal throughput.
InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args=ArrayRef< const Value * >(), const Instruction *CxtI=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *Tp, ArrayRef< int > Mask=std::nullopt, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args=std::nullopt) const
unsigned getNumberOfRegisters(unsigned ClassID) const
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:228
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:140
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:348
UnaryOps getOpcode() const
Definition: InstrTypes.h:170
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:242
Value * getOperand(unsigned i) const
Definition: User.h:169
See the file comment.
Definition: ValueMap.h:84
size_type count(const KeyT &Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: ValueMap.h:151
iterator find(const KeyT &Val)
Definition: ValueMap.h:155
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: ValueMap.h:172
iterator end()
Definition: ValueMap.h:135
bool erase(const KeyT &Val)
Definition: ValueMap.h:190
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
user_iterator user_begin()
Definition: Value.h:397
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:149
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Type * getElementType() const
Definition: DerivedTypes.h:436
constexpr ScalarTy getFixedValue() const
Definition: TypeSize.h:189
An efficient, type-erasing, non-owning reference to a callable.
A range adaptor for a pair of iterators.
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:642
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:121
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Function.cpp:1005
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1444
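For example (a sketch; getTransposeDecl is a hypothetical helper, and the overload set is assumed to be the single flattened vector type):

#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"

using namespace llvm;

static Function *getTransposeDecl(Module *M, Type *FlatVecTy) {
  // llvm.matrix.transpose is overloaded on the flattened vector type used
  // for both its operand and its result.
  return Intrinsic::getDeclaration(M, Intrinsic::matrix_transpose, {FlatVecTy});
}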
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
specific_intval< false > m_SpecificInt(APInt V)
Match a specific integer value or vector with all elements equal to the value.
Definition: PatternMatch.h:862
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
Definition: PatternMatch.h:982
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:84
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:147
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
Definition: PatternMatch.h:988
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
OneUse_match< T > m_OneUse(const T &SubPattern)
Definition: PatternMatch.h:67
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:218
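A sketch (assumed, not the pass's own matcher; both helper names are hypothetical) showing how these matchers compose: one recognizes an integer or floating-point multiply-add, the other a store of a single-use load.

#include "llvm/IR/Instruction.h"
#include "llvm/IR/PatternMatch.h"

using namespace llvm;
using namespace PatternMatch;

static bool matchSingleUseMulAdd(Value *V, Value *&A, Value *&B, Value *&Acc) {
  // m_OneUse keeps the multiply private to this add; m_CombineOr accepts
  // either the FP pair (fmul/fadd) or the integer pair (mul/add).
  return match(V, m_CombineOr(
                      m_FAdd(m_OneUse(m_FMul(m_Value(A), m_Value(B))),
                             m_Value(Acc)),
                      m_Add(m_OneUse(m_Mul(m_Value(A), m_Value(B))),
                            m_Value(Acc))));
}

static bool matchStoreOfLoad(Instruction *I, Value *&SrcPtr, Value *&DstPtr) {
  // m_Store(value, pointer): the stored value must be a single-use load.
  return match(I, m_Store(m_OneUse(m_Load(m_Value(SrcPtr))), m_Value(DstPtr)));
}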
@ SS
Definition: X86.h:207
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition: CommandLine.h:705
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:445
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
NodeAddr< FuncNode * > Func
Definition: RDFGraph.h:393
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:228
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:237
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:440
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1684
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A,...
Definition: STLExtras.h:2375
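For instance (a sketch; printValueIndices is hypothetical), visiting values together with their 0-based index:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

static void printValueIndices(ArrayRef<Value *> Vals) {
  for (auto [Idx, V] : enumerate(Vals))
    errs() << Idx << ": " << V->getName() << "\n";
}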
auto successors(const MachineBasicBlock *BB)
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2036
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
std::string & operator+=(std::string &buffer, StringRef string)
Definition: StringRef.h:907
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, bool ContinueOnCuIndexOverflow)
Definition: DWP.cpp:585
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments and pointer casts from the specified value,...
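A minimal use (a sketch; haveSameUnderlyingObject is a hypothetical helper):

#include "llvm/Analysis/ValueTracking.h"

using namespace llvm;

static bool haveSameUnderlyingObject(const Value *PtrA, const Value *PtrB) {
  // True when both pointers strip back to the same base object (alloca,
  // global, argument, ...); false says nothing about aliasing on its own.
  return getUnderlyingObject(PtrA) == getUnderlyingObject(PtrB);
}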
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
Definition: STLExtras.h:665
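A sketch (assumed; eraseTriviallyDead is hypothetical) of the usual idiom: erase instructions while walking their block without invalidating the loop iterator.

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

static void eraseTriviallyDead(BasicBlock &BB) {
  for (Instruction &I : make_early_inc_range(BB))
    if (I.use_empty() && !I.isTerminator() && !I.mayHaveSideEffects())
      I.eraseFromParent();  // safe: the iterator has already advanced past I
}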
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
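A small sketch (assumed; joinTwoColumns is hypothetical):

#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static Value *joinTwoColumns(IRBuilderBase &Builder, Value *Col0, Value *Col1) {
  // Both inputs are assumed to be fixed vectors of the same element type;
  // the result contains Col0's lanes followed by Col1's lanes.
  Value *Cols[] = {Col0, Col1};
  return concatenateVectors(Builder, Cols);
}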
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition: LoopUtils.cpp:214
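For example (a sketch; markLoopHandled and the metadata key are made up for illustration):

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;

static void markLoopHandled(Loop *L) {
  // "llvm.loop.example.handled" is a hypothetical key; existing loop
  // metadata entries are preserved, only this one is added or updated.
  addStringMetadataToLoop(L, "llvm.loop.example.handled", 1);
}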
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1733
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:428
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1651
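A sketch (assumed; largestNegative is hypothetical) exercising the range helpers listed above:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace llvm;

static int largestNegative(SmallVectorImpl<int> &Vals) {
  sort(Vals.begin(), Vals.end());                   // ascending order
  if (!any_of(Vals, [](int V) { return V < 0; }))   // range-based wrapper
    return 0;
  for (int V : reverse(Vals))                       // walk back to front
    if (V < 0)
      return V;                                     // largest negative value
  return 0;
}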
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:156
raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
@ FMulAdd
Sum of float products with llvm.fmuladd(a * b + sum).
@ Add
Sum of integers.
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition: Casting.h:565
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
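A minimal use (a sketch; splitBefore is a hypothetical helper):

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

static BasicBlock *splitBefore(Instruction *SplitPt, DominatorTree *DT) {
  // Everything from SplitPt to the end of its block moves into the returned
  // successor block; the dominator tree is updated in place.
  return SplitBlock(SplitPt->getParent(), SplitPt->getIterator(), DT);
}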
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:212
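A small sketch (assumed; alignmentAtOffset is hypothetical) of the typical use: the alignment known to hold at a byte offset from an aligned base pointer.

#include "llvm/Support/Alignment.h"

using namespace llvm;

static Align alignmentAtOffset(Align Base, uint64_t OffsetBytes) {
  // e.g. a 16-byte aligned base accessed 8 bytes in is only 8-byte aligned.
  return commonAlignment(Base, OffsetBytes);
}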
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
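A sketch (assumed; extractChunk is hypothetical) combining this with IRBuilder to shuffle out a contiguous slice of a flat vector:

#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static Value *extractChunk(IRBuilderBase &Builder, Value *Flat,
                           unsigned Start, unsigned NumElts) {
  // Build the mask [Start, Start + NumElts) with no trailing undefs and use
  // it to pull the requested contiguous slice out of Flat.
  SmallVector<int, 16> Mask =
      createSequentialMask(Start, NumElts, /*NumUndefs=*/0);
  return Builder.CreateShuffleVector(Flat, Mask);
}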
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition: PassManager.h:391
A helper struct to create IR loop nests for tiling in IR of the following form: for ColumnLoop....
Definition: MatrixUtils.h:31