LowerMatrixIntrinsics.cpp
1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns
16// columns for tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
19
22#include "llvm/ADT/ScopeExit.h"
23#include "llvm/ADT/SmallSet.h"
32#include "llvm/IR/CFG.h"
33#include "llvm/IR/DataLayout.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
43#include "llvm/Support/Debug.h"
47
48#include <cmath>
49
50using namespace llvm;
51using namespace PatternMatch;
52
53#define DEBUG_TYPE "lower-matrix-intrinsics"
54
55static cl::opt<bool>
56 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
57 cl::desc("Enable/disable fusing matrix instructions."));
58// TODO: Allow and use non-square tiles.
59 static cl::opt<unsigned> TileSize(
60 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
61 cl::desc(
62 "Tile size for matrix instruction fusion using square-shaped tiles."));
63 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
64 cl::Hidden,
65 cl::desc("Generate loop nest for tiling."));
66 static cl::opt<bool> ForceFusion(
67 "force-fuse-matrix", cl::init(false), cl::Hidden,
68 cl::desc("Force matrix instruction fusion even if not profitable."));
69 static cl::opt<bool> AllowContractEnabled(
70 "matrix-allow-contract", cl::init(false), cl::Hidden,
71 cl::desc("Allow the use of FMAs if available and profitable. This may "
72 "result in different results, due to less rounding error."));
73
74static cl::opt<bool>
75 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
76 cl::desc("Enable/disable matrix shape verification."),
77 cl::init(false));
78
79 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
80
81 static cl::opt<MatrixLayoutTy> MatrixLayout(
82 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
83 cl::desc("Sets the default matrix layout"),
84 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
85 "Use column-major layout"),
86 clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
87 "Use row-major layout")));
88
89static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
90 cl::init(false));
91
92 /// Helper function to either return Scope, if it is a subprogram, or the
93 /// subprogram attached to a local scope.
94 static DISubprogram *getSubprogram(DIScope *Scope) {
95 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
96 return Subprogram;
97 return cast<DILocalScope>(Scope)->getSubprogram();
98}
99
100 /// Erase \p V from \p BB and move \p II forward to avoid invalidating
101 /// iterators.
102 static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
103 BasicBlock &BB) {
104 auto *Inst = cast<Instruction>(V);
105 // Still used, don't erase.
106 if (!Inst->use_empty())
107 return;
108 if (II != BB.rend() && Inst == &*II)
109 ++II;
110 Inst->eraseFromParent();
111}
112
113/// Return true if V is a splat of a value (which is used when multiplying a
114/// matrix with a scalar).
115static bool isSplat(Value *V) {
116 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
117 return SV->isZeroEltSplat();
118 return false;
119}
120
121/// Match any mul operation (fp or integer).
122template <typename LTy, typename RTy>
123auto m_AnyMul(const LTy &L, const RTy &R) {
124 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
125}
126
127/// Match any add operation (fp or integer).
128template <typename LTy, typename RTy>
129auto m_AnyAdd(const LTy &L, const RTy &R) {
130 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
131}
132
133namespace {
134
135 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
136 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements),
137 // assuming \p Stride elements between the starts of two consecutive vectors.
138 // \p Stride must be >= \p NumElements.
139 // For column-major matrixes, the function computes the address of a column
140 // vector and \p NumElements must be set to the number of elements in a column
141 // (= number of rows of the matrix). For row-major matrixes, the function
142 // computes the address of a row vector and \p NumElements must be set to the
143 // number of elements in a row (= number of columns of the matrix).
144//
145 // Consider a 4x4 matrix in column-major layout like the one below
146//
147// 0 1 2 3
148// 0 v_0_0 v_0_1 v_0_2 v_0_3
149// 1 v_1_0 v_1_1 v_1_2 v_1_3
150// 2 v_2_0 v_2_1 v_2_2 v_2_3
151// 3 v_3_0 v_3_1 v_3_2 v_3_3
152
153// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
154// we need a pointer to the first element of the submatrix as base pointer.
155// Then we can use computeVectorAddr to compute the addresses for the columns
156// of the sub-matrix.
157//
158// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
159// -> just returns Base
160// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
161// -> returns Base + (1 * 4)
162// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
163// -> returns Base + (2 * 4)
164//
165// The graphic below illustrates the number of elements in a column (marked
166// with |) and the number of skipped elements (marked with }).
167//
168// v_0_0 v_0_1 {v_0_2 {v_0_3
169// Base Col 1 Col 2
170// | | |
171// v_1_0 |v_1_1 |v_1_2 |v_1_3
172// v_2_0 |v_2_1 |v_2_2 |v_2_3
173// v_3_0 {v_3_1 {v_3_2 v_3_3
174//
175Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
176 unsigned NumElements, Type *EltType,
177 IRBuilder<> &Builder) {
178
179 assert((!isa<ConstantInt>(Stride) ||
180 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
181 "Stride must be >= the number of elements in the result vector.");
182
183 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
184 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
185
186 // Get pointer to the start of the selected vector. Skip GEP creation,
187 // if we select vector 0.
188 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
189 VecStart = BasePtr;
190 else
191 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
192
193 return VecStart;
194}
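// For illustration (SSA names are placeholders): for a column-major 4x4 matrix
// of doubles with base pointer %base and non-constant stride %n,
// computeVectorAddr(%base, 2, %n, 4, double, Builder) emits roughly
//
//   %vec.start = mul i64 2, %n
//   %vec.gep   = getelementptr double, ptr %base, i64 %vec.start
//
// i.e. the address of column 2. If the computed start folds to the constant 0,
// the GEP is skipped and %base is returned directly.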
195
196namespace {
197struct ShapeInfo {
198 unsigned NumRows;
199 unsigned NumColumns;
200
201 bool IsColumnMajor;
202
203 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
204 : NumRows(NumRows), NumColumns(NumColumns),
205 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
206
207 ShapeInfo(Value *NumRows, Value *NumColumns)
208 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
209 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
210
211 bool operator==(const ShapeInfo &other) {
212 return NumRows == other.NumRows && NumColumns == other.NumColumns;
213 }
214 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
215
216 /// Returns true if shape-information is defined, meaning both dimensions
217 /// are != 0.
218 operator bool() const {
219 assert(NumRows == 0 || NumColumns != 0);
220 return NumRows != 0;
221 }
222
223 unsigned getStride() const {
224 if (IsColumnMajor)
225 return NumRows;
226 return NumColumns;
227 }
228
229 unsigned getNumVectors() const {
230 if (IsColumnMajor)
231 return NumColumns;
232 return NumRows;
233 }
234
235 /// Returns the transposed shape.
236 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
237};
238} // namespace
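// For illustration: with the default column-major layout, ShapeInfo(4, 2)
// describes a 4x2 matrix stored as 2 column vectors of 4 elements each, so
// getStride() == 4, getNumVectors() == 2 and t() == ShapeInfo(2, 4). With
// -matrix-default-layout=row-major the same shape is stored as 4 row vectors
// of 2 elements (getStride() == 2, getNumVectors() == 4).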
239
240static bool isUniformShape(Value *V) {
241 Instruction *I = dyn_cast<Instruction>(V);
242 if (!I)
243 return true;
244
245 switch (I->getOpcode()) {
246 case Instruction::FAdd:
247 case Instruction::FSub:
248 case Instruction::FMul: // Scalar multiply.
249 case Instruction::FNeg:
250 case Instruction::Add:
251 case Instruction::Mul:
252 case Instruction::Sub:
253 return true;
254 default:
255 return false;
256 }
257}
258
259 /// Return the ShapeInfo for the result of \p I, if it can be determined.
260static std::optional<ShapeInfo>
261computeShapeInfoForInst(Instruction *I,
262 const ValueMap<Value *, ShapeInfo> &ShapeMap) {
263 Value *M;
264 Value *N;
265 Value *K;
266 if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
267 m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
268 return ShapeInfo(M, K);
269 if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
270 m_Value(N)))) {
271 // Flip dimensions.
272 return ShapeInfo(N, M);
273 }
274 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
275 m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
276 m_Value(N))))
277 return ShapeInfo(N, M);
278 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
279 m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
280 return ShapeInfo(M, N);
281 Value *MatrixA;
282 if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
283 auto OpShape = ShapeMap.find(MatrixA);
284 if (OpShape != ShapeMap.end())
285 return OpShape->second;
286 }
287
288 if (isUniformShape(I)) {
289 // Find the first operand that has a known shape and use that.
290 for (auto &Op : I->operands()) {
291 auto OpShape = ShapeMap.find(Op.get());
292 if (OpShape != ShapeMap.end())
293 return OpShape->second;
294 }
295 }
296 return std::nullopt;
297}
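// For example, for a call to llvm.matrix.multiply with dimension arguments
// M = 4, N = 3, K = 2 (a 4x3 times a 3x2 multiply), the inferred result shape
// is ShapeInfo(4, 2); for llvm.matrix.transpose with M = 4, N = 3 it is
// ShapeInfo(3, 4). A uniform-shape instruction such as an fadd of two matrix
// values simply inherits the shape of the first operand found in ShapeMap.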
298
299/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
300///
301/// Currently, the lowering for each matrix intrinsic is done as follows:
302/// 1. Propagate the shape information from intrinsics to connected
303/// instructions.
304/// 2. Lower instructions with shape information (assuming column-major layout).
305/// The lowering works similarly using row-major layout.
306/// 2.1. Get column vectors for each argument. If we already lowered the
307/// definition of an argument, use the produced column vectors directly.
308/// If not, split the operand vector containing an embedded matrix into
309 /// a set of column vectors.
310 /// 2.2. Lower the instruction in terms of column major operations, which
311 /// yields a set of column vectors containing the result matrix. Note that we
312/// lower all instructions that have shape information. Besides the
313/// intrinsics, this includes stores for example.
314/// 2.3. Update uses of the lowered instruction. If we have shape information
315/// for a user, there is nothing to do, as we will look up the result
316/// column matrix when lowering the user. For other uses, we embed the
317/// result matrix in a flat vector and update the use.
318 /// 2.4. Cache the result column matrix for the instruction we lowered.
319/// 3. After we lowered all instructions in a function, remove the now
320/// obsolete instructions.
321///
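// As a rough end-to-end illustration (schematic IR, not the pass's verbatim
// output): a call
//
//   %m = call <8 x double> @llvm.matrix.column.major.load.v8f64.i64(
//            ptr %p, i64 %stride, i1 false, i32 4, i32 2)
//
// is lowered to two per-column loads of <4 x double> whose addresses are %p
// and %p plus %stride elements, and users without shape information receive
// the two columns concatenated back into a flat <8 x double>.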
322class LowerMatrixIntrinsics {
323 Function &Func;
324 const DataLayout &DL;
325 const TargetTransformInfo &TTI;
326 AliasAnalysis *AA;
327 DominatorTree *DT;
328 LoopInfo *LI;
329 OptimizationRemarkEmitter *ORE;
330
331 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
332 struct OpInfoTy {
333 /// Number of stores emitted to generate this matrix.
334 unsigned NumStores = 0;
335 /// Number of loads emitted to generate this matrix.
336 unsigned NumLoads = 0;
337 /// Number of compute operations emitted to generate this matrix.
338 unsigned NumComputeOps = 0;
339 /// Most of the time transposes can be fused with matrix multiplies or can
340 /// be folded away via algebraic simplifications. This is the number of
341 /// transposes that we failed to make "free" via such optimizations.
342 unsigned NumExposedTransposes = 0;
343
344 OpInfoTy &operator+=(const OpInfoTy &RHS) {
345 NumStores += RHS.NumStores;
346 NumLoads += RHS.NumLoads;
347 NumComputeOps += RHS.NumComputeOps;
348 NumExposedTransposes += RHS.NumExposedTransposes;
349 return *this;
350 }
351 };
352
353 /// Wrapper class representing a matrix as a set of vectors, either in row or
354 /// column major layout. All vectors must have the same vector type.
355 class MatrixTy {
356 SmallVector<Value *, 16> Vectors;
357
358 OpInfoTy OpInfo;
359
360 bool IsColumnMajor = true;
361
362 public:
363 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
364 MatrixTy(ArrayRef<Value *> Vectors)
365 : Vectors(Vectors),
366 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
367 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
368 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
369
370 unsigned D = isColumnMajor() ? NumColumns : NumRows;
371 for (unsigned J = 0; J < D; ++J)
372 addVector(PoisonValue::get(FixedVectorType::get(
373 EltTy, isColumnMajor() ? NumRows : NumColumns)));
374 }
375
376 Value *getVector(unsigned i) const { return Vectors[i]; }
377 Value *getColumn(unsigned i) const {
378 assert(isColumnMajor() && "only supported for column-major matrixes");
379 return Vectors[i];
380 }
381 Value *getRow(unsigned i) const {
382 assert(!isColumnMajor() && "only supported for row-major matrixes");
383 return Vectors[i];
384 }
385
386 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
387
388 Type *getElementType() const { return getVectorTy()->getElementType(); }
389
390 unsigned getNumVectors() const {
391 if (isColumnMajor())
392 return getNumColumns();
393 return getNumRows();
394 }
395
396 unsigned getNumColumns() const {
397 if (isColumnMajor())
398 return Vectors.size();
399 else {
400 assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
401 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
402 }
403 }
404 unsigned getNumRows() const {
405 if (isColumnMajor()) {
406 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
407 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
408 } else
409 return Vectors.size();
410 }
411
412 void addVector(Value *V) { Vectors.push_back(V); }
413 VectorType *getColumnTy() {
414 assert(isColumnMajor() && "only supported for column-major matrixes");
415 return getVectorTy();
416 }
417
418 VectorType *getVectorTy() const {
419 return cast<VectorType>(Vectors[0]->getType());
420 }
421
422 auto columns() const {
423 assert(isColumnMajor() &&
424 "columns() only supported for column-major matrixes");
425 return make_range(Vectors.begin(), Vectors.end());
426 }
427
428 auto vectors() const {
429 return make_range(Vectors.begin(), Vectors.end());
430 }
431
432 /// Embed the vectors of the matrix into a flat vector by concatenating
433 /// them.
434 Value *embedInVector(IRBuilder<> &Builder) const {
435 return Vectors.size() == 1 ? Vectors[0]
436 : concatenateVectors(Builder, Vectors);
437 }
438
439 MatrixTy &addNumLoads(unsigned N) {
440 OpInfo.NumLoads += N;
441 return *this;
442 }
443
444 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
445
446 MatrixTy &addNumStores(unsigned N) {
447 OpInfo.NumStores += N;
448 return *this;
449 }
450
451 MatrixTy &addNumExposedTransposes(unsigned N) {
452 OpInfo.NumExposedTransposes += N;
453 return *this;
454 }
455
456 MatrixTy &addNumComputeOps(unsigned N) {
457 OpInfo.NumComputeOps += N;
458 return *this;
459 }
460
461 unsigned getNumStores() const { return OpInfo.NumStores; }
462 unsigned getNumLoads() const { return OpInfo.NumLoads; }
463 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
464
465 const OpInfoTy &getOpInfo() const { return OpInfo; }
466
467 bool isColumnMajor() const { return IsColumnMajor; }
468
469 unsigned getStride() const {
470 if (isColumnMajor())
471 return getNumRows();
472 return getNumColumns();
473 }
474
475 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
476 /// matrix is column-major, the result vector is extracted from a column
477 /// vector, otherwise from a row vector.
478 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
479 IRBuilder<> &Builder) const {
480 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
481 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
482 NumElts &&
483 "Extracted vector will contain poison values");
484 return Builder.CreateShuffleVector(
485 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
486 "block");
487 }
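    // For illustration: for a column-major 4x2 matrix, extractVector(2, 1, 2,
    // Builder) takes column 1 and extracts its elements 2..3 via a
    // shufflevector with mask <2, 3> (createSequentialMask(2, 2, 0)), yielding
    // a 2-element "block" vector.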
488 };
489
490 /// Maps instructions to their shape information. The shape information
491 /// describes the shape to be used while lowering. This matches the shape of
492 /// the result value of the instruction, with the only exceptions being store
493 /// instructions and the matrix_column_major_store intrinsics. For those, the
494 /// shape information indicates that those instructions should be lowered
495 /// using shape information as well. A ValueMap is used so that when
496 /// sub-passes like optimizeTransposes perform RAUW, the map stays
497 /// up-to-date.
498 ValueMap<Value *, ShapeInfo> ShapeMap;
499
500 /// List of instructions to remove. While lowering, we do not replace all
501 /// users of a lowered instruction if shape information is available; those
502 /// users need to be removed after we have finished lowering.
503 SmallVector<Instruction *, 16> ToRemove;
504
505 /// Map from instructions to their produced column matrix.
506 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
507
508private:
509 static FastMathFlags getFastMathFlags(Instruction *Inst) {
510 FastMathFlags FMF;
511
512 if (isa<FPMathOperator>(*Inst))
513 FMF = Inst->getFastMathFlags();
514
515 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
516
517 return FMF;
518 }
519
520public:
521 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
522 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
523 OptimizationRemarkEmitter *ORE)
524 : Func(F), DL(F.getDataLayout()), TTI(TTI), AA(AA), DT(DT),
525 LI(LI), ORE(ORE) {}
526
527 unsigned getNumOps(Type *VT) {
528 assert(isa<VectorType>(VT) && "Expected vector type");
529 return getNumOps(VT->getScalarType(),
530 cast<FixedVectorType>(VT)->getNumElements());
531 }
532
533 /// Is this the minimal version executed in the backend pipelines.
534 bool isMinimal() const {
535 return !DT;
536 }
537
538 /// Return the estimated number of vector ops required for an operation on
539 /// \p VT * N.
540 unsigned getNumOps(Type *ST, unsigned N) {
541 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
542 double(TTI.getRegisterBitWidth(
543 TargetTransformInfo::RGK_FixedWidthVector)
544 .getFixedValue()));
545 }
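  // For illustration: with 256-bit vector registers, a <8 x double> value
  // (8 * 64 = 512 bits) counts as ceil(512 / 256) = 2 vector ops, while a
  // <4 x float> value (128 bits) counts as 1.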
546
547 /// Return the set of vectors that a matrix value is lowered to.
548 ///
549 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
550 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
551 /// into vectors.
552 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
553 IRBuilder<> &Builder) {
554 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
555 assert(VType && "MatrixVal must be a vector type");
556 assert(cast<FixedVectorType>(VType)->getNumElements() ==
557 SI.NumRows * SI.NumColumns &&
558 "The vector size must match the number of matrix elements");
559
560 // Check if we lowered MatrixVal using shape information. In that case,
561 // return the existing matrix, if it matches the requested shape
562 // information. If there is a mis-match, embed the result in a flat
563 // vector and split it later.
564 auto Found = Inst2ColumnMatrix.find(MatrixVal);
565 if (Found != Inst2ColumnMatrix.end()) {
566 MatrixTy &M = Found->second;
567 // Return the found matrix, if its shape matches the requested shape
568 // information
569 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
570 return M;
571
572 MatrixVal = M.embedInVector(Builder);
573 }
574
575 // Otherwise split MatrixVal.
576 SmallVector<Value *, 16> SplitVecs;
577 for (unsigned MaskStart = 0;
578 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
579 MaskStart += SI.getStride()) {
580 Value *V = Builder.CreateShuffleVector(
581 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
582 "split");
583 SplitVecs.push_back(V);
584 }
585
586 return {SplitVecs};
587 }
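  // For illustration: splitting a flat <8 x double> with shape 4x2
  // (column-major) produces two shufflevector "split" values with masks
  // <0, 1, 2, 3> and <4, 5, 6, 7>, i.e. one vector per column.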
588
589 /// If \p V already has a known shape return false. Otherwise set the shape
590 /// for instructions that support it.
591 bool setShapeInfo(Value *V, ShapeInfo Shape) {
592 assert(Shape && "Shape not set");
593 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
594 return false;
595
596 auto SIter = ShapeMap.find(V);
597 if (SIter != ShapeMap.end()) {
598 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
599 SIter->second.NumColumns != Shape.NumColumns)) {
600 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
601 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
602 << Shape.NumColumns << ") for " << *V << "\n";
603 report_fatal_error(
604 "Matrix shape verification failed, compilation aborted!");
605 }
606
607 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
608 << SIter->second.NumRows << " "
609 << SIter->second.NumColumns << " for " << *V << "\n");
610 return false;
611 }
612
613 ShapeMap.insert({V, Shape});
614 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
615 << " for " << *V << "\n");
616 return true;
617 }
618
619 /// Returns true if shape information can be used for \p V. The supported
620 /// instructions must match the instructions that can be lowered by this pass.
621 bool supportsShapeInfo(Value *V) {
622 Instruction *Inst = dyn_cast<Instruction>(V);
623 if (!Inst)
624 return false;
625
626 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
627 if (II)
628 switch (II->getIntrinsicID()) {
629 case Intrinsic::matrix_multiply:
630 case Intrinsic::matrix_transpose:
631 case Intrinsic::matrix_column_major_load:
632 case Intrinsic::matrix_column_major_store:
633 return true;
634 default:
635 return false;
636 }
637 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
638 }
639
640 /// Propagate the shape information of instructions to their users.
641 /// The work list contains instructions for which we can compute the shape,
642 /// either based on the information provided by matrix intrinsics or known
643 /// shapes of operands.
644 SmallVector<Instruction *, 32>
645 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
646 SmallVector<Instruction *, 32> NewWorkList;
647 // Pop an element for which we are guaranteed to have at least one of the
648 // operand shapes. Add the shape for this and then add users to the work
649 // list.
650 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
651 while (!WorkList.empty()) {
652 Instruction *Inst = WorkList.pop_back_val();
653
654 // New entry, set the value and insert operands
655 bool Propagate = false;
656 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
657 Propagate = setShapeInfo(Inst, *SI);
658
659 if (Propagate) {
660 NewWorkList.push_back(Inst);
661 for (auto *User : Inst->users())
662 if (ShapeMap.count(User) == 0)
663 WorkList.push_back(cast<Instruction>(User));
664 }
665 }
666
667 return NewWorkList;
668 }
669
670 /// Propagate the shape to operands of instructions with shape information.
671 /// \p Worklist contains the instruction for which we already know the shape.
672 SmallVector<Instruction *, 32>
673 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
674 SmallVector<Instruction *, 32> NewWorkList;
675
676 auto pushInstruction = [](Value *V,
677 SmallVectorImpl<Instruction *> &WorkList) {
678 Instruction *I = dyn_cast<Instruction>(V);
679 if (I)
680 WorkList.push_back(I);
681 };
682 // Pop an element with known shape. Traverse the operands, if their shape
683 // derives from the result shape and is unknown, add it and add them to the
684 // worklist.
685 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
686 while (!WorkList.empty()) {
687 Value *V = WorkList.pop_back_val();
688
689 size_t BeforeProcessingV = WorkList.size();
690 if (!isa<Instruction>(V))
691 continue;
692
693 Value *MatrixA;
694 Value *MatrixB;
695 Value *M;
696 Value *N;
697 Value *K;
698 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
699 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
700 m_Value(N), m_Value(K)))) {
701 if (setShapeInfo(MatrixA, {M, N}))
702 pushInstruction(MatrixA, WorkList);
703
704 if (setShapeInfo(MatrixB, {N, K}))
705 pushInstruction(MatrixB, WorkList);
706
707 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
708 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
709 // Flip dimensions.
710 if (setShapeInfo(MatrixA, {M, N}))
711 pushInstruction(MatrixA, WorkList);
712 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
713 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
714 m_Value(M), m_Value(N)))) {
715 if (setShapeInfo(MatrixA, {M, N})) {
716 pushInstruction(MatrixA, WorkList);
717 }
718 } else if (isa<LoadInst>(V) ||
719 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
720 // Nothing to do, no matrix input.
721 } else if (isa<StoreInst>(V)) {
722 // Nothing to do. We forward-propagated to this so we would just
723 // backward propagate to an instruction with an already known shape.
724 } else if (isUniformShape(V)) {
725 // Propagate to all operands.
726 ShapeInfo Shape = ShapeMap[V];
727 for (Use &U : cast<Instruction>(V)->operands()) {
728 if (setShapeInfo(U.get(), Shape))
729 pushInstruction(U.get(), WorkList);
730 }
731 }
732 // After we discovered new shape info for new instructions in the
733 // worklist, we use their users as seeds for the next round of forward
734 // propagation.
735 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
736 for (User *U : WorkList[I]->users())
737 if (isa<Instruction>(U) && V != U)
738 NewWorkList.push_back(cast<Instruction>(U));
739 }
740 return NewWorkList;
741 }
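  // For illustration (schematic IR): given
  //   %t = call <12 x double> @llvm.matrix.transpose.v12f64(<12 x double> %a,
  //                                                         i32 4, i32 3)
  // forward propagation records shape 3x4 for %t; backward propagation then
  // records 4x3 for %a and re-seeds %a's users, and the two phases alternate
  // in Visit() until no new shapes are discovered.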
742
743 /// (Op0 op Op1)^T -> Op0^T op Op1^T
744 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
745 /// them on both sides of \p Operation.
746 Instruction *distributeTransposes(
747 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
748 MatrixBuilder &Builder,
749 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
750 Operation) {
751 Value *T0 = Builder.CreateMatrixTranspose(
752 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
753 // We are being run after shape prop, add shape for newly created
754 // instructions so that we lower them later.
755 setShapeInfo(T0, Shape0.t());
756 Value *T1 = Builder.CreateMatrixTranspose(
757 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
758 setShapeInfo(T1, Shape1.t());
759 return Operation(T0, Shape0.t(), T1, Shape1.t());
760 }
761
762 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
763 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
764 // with New. We should only add New if it supports shape info, so we insert
765 // it conditionally instead.
766 auto S = ShapeMap.find(&Old);
767 if (S != ShapeMap.end()) {
768 ShapeMap.erase(S);
769 if (supportsShapeInfo(New))
770 ShapeMap.insert({New, S->second});
771 }
772 Old.replaceAllUsesWith(New);
773 }
774
775 /// Sink a top-level transpose inside matmuls and adds.
776 /// This creates and erases instructions as needed, and returns the newly
777 /// created instruction while updating the iterator to avoid invalidation. If
778 /// this returns nullptr, no new instruction was created.
779 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
780 BasicBlock &BB = *I.getParent();
781 IRBuilder<> IB(&I);
782 MatrixBuilder Builder(IB);
783
784 Value *TA, *TAMA, *TAMB;
785 ConstantInt *R, *K, *C;
786 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
787 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
788 return nullptr;
789
790 // Transpose of a transpose is a nop
791 Value *TATA;
792 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
793 updateShapeAndReplaceAllUsesWith(I, TATA);
794 eraseFromParentAndMove(&I, II, BB);
795 eraseFromParentAndMove(TA, II, BB);
796 return nullptr;
797 }
798
799 // k^T -> k
800 if (isSplat(TA)) {
801 updateShapeAndReplaceAllUsesWith(I, TA);
802 eraseFromParentAndMove(&I, II, BB);
803 return nullptr;
804 }
805
806 // (A * B)^t -> B^t * A^t
807 // RxK KxC CxK KxR
808 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
809 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
810 m_ConstantInt(C)))) {
811 auto NewInst = distributeTransposes(
812 TAMB, {K, C}, TAMA, {R, K}, Builder,
813 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
814 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
815 Shape0.NumColumns,
816 Shape1.NumColumns, "mmul");
817 });
818 updateShapeAndReplaceAllUsesWith(I, NewInst);
819 eraseFromParentAndMove(&I, II, BB);
820 eraseFromParentAndMove(TA, II, BB);
821 return NewInst;
822 }
823
824 // Same as above, but with a mul, which occurs when multiplied
825 // with a scalar.
826 // (A * k)^t -> A^t * k
827 // R x C RxC
828 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
829 (isSplat(TAMA) || isSplat(TAMB))) {
830 IRBuilder<> LocalBuilder(&I);
831 // We know that the transposed operand is of shape RxC.
832 // And when multiplied with a scalar, the shape is preserved.
833 auto NewInst = distributeTransposes(
834 TAMA, {R, C}, TAMB, {R, C}, Builder,
835 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
836 bool IsFP = I.getType()->isFPOrFPVectorTy();
837 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
838 : LocalBuilder.CreateMul(T0, T1, "mmul");
839 auto *Result = cast<Instruction>(Mul);
840 setShapeInfo(Result, Shape0);
841 return Result;
842 });
843 updateShapeAndReplaceAllUsesWith(I, NewInst);
844 eraseFromParentAndMove(&I, II, BB);
845 eraseFromParentAndMove(TA, II, BB);
846 return NewInst;
847 }
848
849 // (A + B)^t -> A^t + B^t
850 // RxC RxC CxR CxR
851 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
852 IRBuilder<> LocalBuilder(&I);
853 auto NewInst = distributeTransposes(
854 TAMA, {R, C}, TAMB, {R, C}, Builder,
855 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
856 bool IsFP = I.getType()->isFPOrFPVectorTy();
857 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
858 : LocalBuilder.CreateAdd(T0, T1, "madd");
859
860 auto *Result = cast<Instruction>(Add);
861 setShapeInfo(Result, Shape0);
862 return Result;
863 });
864 updateShapeAndReplaceAllUsesWith(I, NewInst);
865 eraseFromParentAndMove(&I, II, BB);
866 eraseFromParentAndMove(TA, II, BB);
867 return NewInst;
868 }
869
870 return nullptr;
871 }
872
873 void liftTranspose(Instruction &I) {
874 // Erase dead Instructions after lifting transposes from binops.
875 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
876 if (T.use_empty())
877 T.eraseFromParent();
878 if (A->use_empty())
879 cast<Instruction>(A)->eraseFromParent();
880 if (A != B && B->use_empty())
881 cast<Instruction>(B)->eraseFromParent();
882 };
883
884 Value *A, *B, *AT, *BT;
885 ConstantInt *R, *K, *C;
886 // A^t * B ^t -> (B * A)^t
887 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
888 m_Value(A), m_Value(B), m_ConstantInt(R),
889 m_ConstantInt(K), m_ConstantInt(C))) &&
890 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
891 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
892 IRBuilder<> IB(&I);
893 MatrixBuilder Builder(IB);
894 Value *M = Builder.CreateMatrixMultiply(
895 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
896 setShapeInfo(M, {C, R});
897 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
898 R->getZExtValue());
899 updateShapeAndReplaceAllUsesWith(I, NewInst);
900 CleanupBinOp(I, A, B);
901 }
902 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
903 // the shape of the second transpose is different, there's a shape conflict
904 // which gets resolved by picking the shape of the first operand.
905 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
906 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
907 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
908 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
909 m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
910 IRBuilder<> Builder(&I);
911 auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
912 setShapeInfo(Add, {R, C});
913 MatrixBuilder MBuilder(Builder);
914 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
915 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
916 updateShapeAndReplaceAllUsesWith(I, NewInst);
917 assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
918 computeShapeInfoForInst(&I, ShapeMap) &&
919 "Shape of new instruction doesn't match original shape.");
920 CleanupBinOp(I, A, B);
921 assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
922 ShapeMap[Add] &&
923 "Shape of updated addition doesn't match cached shape.");
924 }
925 }
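  // For illustration: sinkTranspose rewrites
  //   transpose(multiply(A, B))            with A: RxK, B: KxC
  // into
  //   multiply(transpose(B), transpose(A)) with shapes CxK and KxR,
  // and liftTranspose rewrites multiply(transpose(A), transpose(B)) back into
  // transpose(multiply(B, A)), so that chains of transposes can cancel or be
  // folded into a single exposed transpose.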
926
927 /// Try moving transposes in order to fold them away or into multiplies.
928 void optimizeTransposes() {
929 // First sink all transposes inside matmuls and adds, hoping that we end up
930 // with NN, NT or TN variants.
931 for (BasicBlock &BB : reverse(Func)) {
932 for (auto II = BB.rbegin(); II != BB.rend();) {
933 Instruction &I = *II;
934 // We may remove II. By default continue on the next/prev instruction.
935 ++II;
936 if (Instruction *NewInst = sinkTranspose(I, II))
937 II = std::next(BasicBlock::reverse_iterator(NewInst));
938 }
939 }
940
941 // If we have a TT matmul or a TT add, lift the transpose. We may be able
942 // to fold into consuming multiply or add.
943 for (BasicBlock &BB : Func) {
944 for (Instruction &I : llvm::make_early_inc_range(BB)) {
945 liftTranspose(I);
946 }
947 }
948 }
949
950 bool Visit() {
951 SmallVector<Instruction *, 32> WorkList;
952
953 // Initially only the shape of matrix intrinsics is known.
954 // Initialize the work list with ops carrying shape information.
955 for (BasicBlock &BB : Func)
956 for (Instruction &Inst : BB) {
957 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
958 if (!II)
959 continue;
960
961 switch (II->getIntrinsicID()) {
962 case Intrinsic::matrix_multiply:
963 case Intrinsic::matrix_transpose:
964 case Intrinsic::matrix_column_major_load:
965 case Intrinsic::matrix_column_major_store:
966 WorkList.push_back(&Inst);
967 break;
968 default:
969 break;
970 }
971 }
972
973 // Avoid unnecessary work if there are no matrix intrinsics in the function.
974 if (WorkList.empty())
975 return false;
976
977 // Propagate shapes until nothing changes any longer.
978 while (!WorkList.empty()) {
979 WorkList = propagateShapeForward(WorkList);
980 WorkList = propagateShapeBackward(WorkList);
981 }
982
983 if (!isMinimal()) {
984 optimizeTransposes();
985 if (PrintAfterTransposeOpt) {
986 dbgs() << "Dump after matrix transpose optimization:\n";
987 Func.print(dbgs());
988 }
989 }
990
991 bool Changed = false;
992 SmallVector<CallInst *, 16> MaybeFusableInsts;
993 SmallVector<Instruction *, 16> MatrixInsts;
994 SmallVector<IntrinsicInst *, 16> LifetimeEnds;
995
996 // First, collect all instructions with shape information and candidates for
997 // fusion (currently only matrix multiplies).
998 ReversePostOrderTraversal<Function *> RPOT(&Func);
999 for (auto *BB : RPOT)
1000 for (Instruction &I : *BB) {
1001 if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
1002 LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
1003 if (ShapeMap.find(&I) == ShapeMap.end())
1004 continue;
1005 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
1006 MaybeFusableInsts.push_back(cast<CallInst>(&I));
1007 MatrixInsts.push_back(&I);
1008 }
1009
1010 // Second, try to lower any dot products
1011 SmallPtrSet<Instruction *, 16> FusedInsts;
1012 for (CallInst *CI : MaybeFusableInsts)
1013 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1014
1015 // Third, try to fuse candidates.
1016 for (CallInst *CI : MaybeFusableInsts)
1017 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
1018
1019 Changed = !FusedInsts.empty();
1020
1021 // Fourth, lower remaining instructions with shape information.
1022 for (Instruction *Inst : MatrixInsts) {
1023 if (FusedInsts.count(Inst))
1024 continue;
1025
1026 IRBuilder<> Builder(Inst);
1027
1028 if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1029 Changed |= VisitCallInst(CInst);
1030
1031 Value *Op1;
1032 Value *Op2;
1033 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1034 Changed |= VisitBinaryOperator(BinOp);
1035 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1036 Changed |= VisitUnaryOperator(UnOp);
1037 if (match(Inst, m_Load(m_Value(Op1))))
1038 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1039 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1040 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1041 }
1042
1043 if (ORE) {
1044 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1045 RemarkGen.emitRemarks();
1046 }
1047
1048 // Delete the instructions backwards, as it has a reduced likelihood of
1049 // having to update as many def-use and use-def chains.
1050 //
1051 // Because we add to ToRemove during fusion we can't guarantee that defs
1052 // are before uses. Change uses to poison temporarily as these should get
1053 // removed as well.
1054 //
1055 // For verification, we keep track of where we changed uses to poison in
1056 // PoisonedInsts and then check that we in fact remove them.
1057 SmallSet<Instruction *, 16> PoisonedInsts;
1058 for (auto *Inst : reverse(ToRemove)) {
1059 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1060 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1061 PoisonedInsts.insert(Poisoned);
1062 U.set(PoisonValue::get(Inst->getType()));
1063 }
1064 Inst->eraseFromParent();
1065 PoisonedInsts.erase(Inst);
1066 }
1067 if (!PoisonedInsts.empty()) {
1068 // If we didn't remove all poisoned instructions, it's a hard error.
1069 dbgs() << "Poisoned but present instructions:\n";
1070 for (auto *I : PoisonedInsts)
1071 dbgs() << *I << "\n";
1072 llvm_unreachable("Poisoned but instruction not removed");
1073 }
1074
1075 return Changed;
1076 }
1077
1078 /// Replace intrinsic calls
1079 bool VisitCallInst(CallInst *Inst) {
1080 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1081 return false;
1082
1083 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1084 case Intrinsic::matrix_multiply:
1085 LowerMultiply(Inst);
1086 break;
1087 case Intrinsic::matrix_transpose:
1088 LowerTranspose(Inst);
1089 break;
1090 case Intrinsic::matrix_column_major_load:
1091 LowerColumnMajorLoad(Inst);
1092 break;
1093 case Intrinsic::matrix_column_major_store:
1094 LowerColumnMajorStore(Inst);
1095 break;
1096 default:
1097 return false;
1098 }
1099 return true;
1100 }
1101
1102 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1103 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1104 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1105 /// non-ConstantInt strides, return the common alignment of the initial
1106 /// alignment and the element size in bytes.
1107 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1108 MaybeAlign A) const {
1109 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1110 if (Idx == 0)
1111 return InitialAlign;
1112
1113 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1114 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1115 uint64_t StrideInBytes =
1116 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1117 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1118 }
1119 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1120 }
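  // For illustration: with an initial alignment of 16, double elements and a
  // constant stride of 3, column 1 starts 3 * 8 = 24 bytes into the matrix,
  // so its alignment is commonAlignment(16, 24) = 8; with a non-constant
  // stride the result falls back to commonAlignment(16, 8) = 8.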
1121
1122 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1123 /// vectors.
1124 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1125 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1126 auto *VType = cast<VectorType>(Ty);
1127 Type *EltTy = VType->getElementType();
1128 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1129 Value *EltPtr = Ptr;
1130 MatrixTy Result;
1131 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1132 Value *GEP = computeVectorAddr(
1133 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1134 Stride, Shape.getStride(), EltTy, Builder);
1135 Value *Vector = Builder.CreateAlignedLoad(
1136 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1137 IsVolatile, "col.load");
1138
1139 Result.addVector(Vector);
1140 }
1141 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1142 Result.getNumVectors());
1143 }
1144
1145 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1146 /// starting at \p MatrixPtr[I][J].
1147 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1148 ShapeInfo MatrixShape, Value *I, Value *J,
1149 ShapeInfo ResultShape, Type *EltTy,
1150 IRBuilder<> &Builder) {
1151
1152 Value *Offset = Builder.CreateAdd(
1153 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1154
1155 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1156 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1157 ResultShape.NumColumns);
1158
1159 return loadMatrix(TileTy, TileStart, Align,
1160 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1161 ResultShape, Builder);
1162 }
1163
1164 /// Lower a load instruction with shape information.
1165 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1166 bool IsVolatile, ShapeInfo Shape) {
1167 IRBuilder<> Builder(Inst);
1168 finalizeLowering(Inst,
1169 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1170 Shape, Builder),
1171 Builder);
1172 }
1173
1174 /// Lowers llvm.matrix.column.major.load.
1175 ///
1176 /// The intrinsic loads a matrix from memory using a stride between columns.
1177 void LowerColumnMajorLoad(CallInst *Inst) {
1178 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1179 "Intrinsic only supports column-major layout!");
1180 Value *Ptr = Inst->getArgOperand(0);
1181 Value *Stride = Inst->getArgOperand(1);
1182 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1183 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1184 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1185 }
1186
1187 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1188 /// MatrixPtr[I][J].
1189 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1190 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1191 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1192 Value *Offset = Builder.CreateAdd(
1193 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1194
1195 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1196 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1197 StoreVal.getNumColumns());
1198
1199 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1200 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1201 }
1202
1203 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1204 /// vectors.
1205 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1206 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1207 IRBuilder<> &Builder) {
1208 auto VType = cast<VectorType>(Ty);
1209 Value *EltPtr = Ptr;
1210 for (auto Vec : enumerate(StoreVal.vectors())) {
1211 Value *GEP = computeVectorAddr(
1212 EltPtr,
1213 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1214 Vec.index()),
1215 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1216 Builder.CreateAlignedStore(Vec.value(), GEP,
1217 getAlignForIndex(Vec.index(), Stride,
1218 VType->getElementType(),
1219 MAlign),
1220 IsVolatile);
1221 }
1222 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1223 StoreVal.getNumVectors());
1224 }
1225
1226 /// Lower a store instruction with shape information.
1227 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1228 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1229 IRBuilder<> Builder(Inst);
1230 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1231 finalizeLowering(Inst,
1232 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1233 IsVolatile, Builder),
1234 Builder);
1235 }
1236
1237 /// Lowers llvm.matrix.column.major.store.
1238 ///
1239 /// The intrinsic stores a matrix back to memory using a stride between columns.
1240 void LowerColumnMajorStore(CallInst *Inst) {
1241 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1242 "Intrinsic only supports column-major layout!");
1243 Value *Matrix = Inst->getArgOperand(0);
1244 Value *Ptr = Inst->getArgOperand(1);
1245 Value *Stride = Inst->getArgOperand(2);
1246 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1247 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1248 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1249 }
1250
1251 // Set elements I..I+NumElts-1 to Block
1252 Value *insertVector(Value *Col, unsigned I, Value *Block,
1253 IRBuilder<> &Builder) {
1254
1255 // First, bring Block to the same size as Col
1256 unsigned BlockNumElts =
1257 cast<FixedVectorType>(Block->getType())->getNumElements();
1258 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1259 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1260
1261 Block = Builder.CreateShuffleVector(
1262 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1263
1264 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1265 // 8, 4, 5, 6
1266 SmallVector<int, 16> Mask;
1267 unsigned i;
1268 for (i = 0; i < I; i++)
1269 Mask.push_back(i);
1270
1271 unsigned VecNumElts =
1272 cast<FixedVectorType>(Col->getType())->getNumElements();
1273 for (; i < I + BlockNumElts; i++)
1274 Mask.push_back(i - I + VecNumElts);
1275
1276 for (; i < VecNumElts; i++)
1277 Mask.push_back(i);
1278
1279 return Builder.CreateShuffleVector(Col, Block, Mask);
1280 }
1281
1282 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1283 IRBuilder<> &Builder, bool AllowContraction,
1284 unsigned &NumComputeOps) {
1285 NumComputeOps += getNumOps(A->getType());
1286 if (!Sum)
1287 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1288
1289 if (UseFPOp) {
1290 if (AllowContraction) {
1291 // Use fmuladd for floating point operations and let the backend decide
1292 // if that's profitable.
1294 Func.getParent(), Intrinsic::fmuladd, A->getType());
1295 return Builder.CreateCall(FMulAdd, {A, B, Sum});
1296 }
1297 NumComputeOps += getNumOps(A->getType());
1298 Value *Mul = Builder.CreateFMul(A, B);
1299 return Builder.CreateFAdd(Sum, Mul);
1300 }
1301
1302 NumComputeOps += getNumOps(A->getType());
1303 Value *Mul = Builder.CreateMul(A, B);
1304 return Builder.CreateAdd(Sum, Mul);
1305 }
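  // For illustration (schematic IR): with floating point operands and
  // contraction allowed, accumulating into a partial sum emits a single
  //   %s = call <4 x double> @llvm.fmuladd.v4f64(<4 x double> %a,
  //                                              <4 x double> %b,
  //                                              <4 x double> %sum)
  // whereas without contraction it emits a separate fmul followed by an fadd.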
1306
1307 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1308 /// users with shape information, there's nothing to do: they will use the
1309 /// cached value when they are lowered. For other users, \p Matrix is
1310 /// flattened and the uses are updated to use it. Also marks \p Inst for
1311 /// deletion.
1312 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1313 IRBuilder<> &Builder) {
1314 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1315 (void)inserted;
1316 assert(inserted.second && "multiple matrix lowering mapping");
1317
1318 ToRemove.push_back(Inst);
1319 Value *Flattened = nullptr;
1320 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1321 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1322 if (!Flattened)
1323 Flattened = Matrix.embedInVector(Builder);
1324 U.set(Flattened);
1325 }
1326 }
1327 }
1328
1329 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1330 /// vectors. Lowers to a vector reduction add instead of sequential adds if
1331 /// reassociation is enabled.
1332 void lowerDotProduct(CallInst *MatMul,
1333 SmallPtrSetImpl<Instruction *> &FusedInsts,
1334 FastMathFlags FMF) {
1335 if (FusedInsts.contains(MatMul) ||
1336 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1337 return;
1338 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1339 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1340
1341 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1342 return;
1343
1344 Value *LHS = MatMul->getArgOperand(0);
1345 Value *RHS = MatMul->getArgOperand(1);
1346
1347 Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1348 bool IsIntVec = ElementType->isIntegerTy();
1349
1350 // Floating point reductions require reassociation.
1351 if (!IsIntVec && !FMF.allowReassoc())
1352 return;
1353
1354 auto CanBeFlattened = [](Value *Op) {
1355 if (match(Op, m_BinOp()))
1356 return true;
1357 return match(
1358 Op, m_OneUse(m_CombineOr(
1359 m_Load(m_Value()),
1360 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1361 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1362 m_Value(), m_SpecificInt(1))))));
1363 };
1364 // Returns the cost benefit of using \p Op with the dot product lowering. If
1365 // the returned cost is < 0, the argument is cheaper to use in the
1366 // dot-product lowering.
1367 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1368 if (ShapeMap.find(Op) == ShapeMap.end())
1369 return InstructionCost::getInvalid();
1370
1371 if (!isa<Instruction>(Op))
1372 return InstructionCost(0);
1373
1374 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1375 Type *EltTy = VecTy->getElementType();
1376
1377 if (!CanBeFlattened(Op)) {
1378 InstructionCost EmbedCost(0);
1379 // Roughly estimate the cost for embedding the columns into a vector.
1380 for (unsigned I = 1; I < N; ++I)
1381 EmbedCost +=
1383 std::nullopt, TTI::TCK_RecipThroughput);
1384 return EmbedCost;
1385 }
1386
1387 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1388 InstructionCost OriginalCost =
1389 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1390 EltTy) *
1391 N;
1392 InstructionCost NewCost = TTI.getArithmeticInstrCost(
1393 cast<Instruction>(Op)->getOpcode(), VecTy);
1394 return NewCost - OriginalCost;
1395 }
1396
1397 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1398 // The transpose can be skipped for the dot product lowering, roughly
1399 // estimate the savings as the cost of embedding the columns in a
1400 // vector.
1401 InstructionCost EmbedCost(0);
1402 for (unsigned I = 1; I < N; ++I)
1403 EmbedCost -=
1405 std::nullopt, TTI::TCK_RecipThroughput);
1406 return EmbedCost;
1407 }
1408
1409 // Costs for loads.
1410 if (N == 1)
1411 return InstructionCost(0);
1412
1413 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1414 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1415 };
1416
1417 // Iterate over LHS and operations feeding LHS and check if it is profitable
1418 // to flatten the visited ops. For each op, we compute the difference
1419 // between the flattened and matrix versions.
1420 SmallPtrSet<Value *, 4> Seen;
1421 SmallVector<Value *> WorkList;
1422 SmallVector<Value *> ToFlatten;
1423 WorkList.push_back(LHS);
1424 InstructionCost LHSCost(0);
1425 while (!WorkList.empty()) {
1426 Value *Op = WorkList.pop_back_val();
1427 if (!Seen.insert(Op).second)
1428 continue;
1429
1430 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1431 if (OpCost + LHSCost >= LHSCost)
1432 continue;
1433
1434 LHSCost += OpCost;
1435 ToFlatten.push_back(Op);
1436 if (auto *I = dyn_cast<Instruction>(Op))
1437 WorkList.append(I->op_begin(), I->op_end());
1438 }
1439
1440 // We compare the costs of a vector.reduce.add to sequential add.
1441 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1442 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1443 InstructionCost ReductionCost =
1444 TTI.getArithmeticReductionCost(
1445 AddOpCode, cast<VectorType>(LHS->getType()),
1446 IsIntVec ? std::nullopt : std::optional(FMF)) +
1447 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1448 InstructionCost SequentialAddCost =
1449 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1450 (LShape.NumColumns - 1) +
1451 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1452 (LShape.NumColumns);
1453 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1454 return;
1455
1456 FusedInsts.insert(MatMul);
1457 IRBuilder<> Builder(MatMul);
1458 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1459 this](Value *Op) {
1460 // Matmul must be the only user of loads because we don't use LowerLoad
1461 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1462 // instead of single vector load).
1463 if (!CanBeFlattened(Op))
1464 return;
1465
1466 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1467 ShapeMap[Op] = ShapeMap[Op].t();
1468 return;
1469 }
1470
1471 FusedInsts.insert(cast<Instruction>(Op));
1472 // If vector uses the builtin load, lower to a LoadInst
1473 Value *Arg;
1474 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1475 m_Value(Arg)))) {
1476 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1477 Op->replaceAllUsesWith(NewLoad);
1478 cast<Instruction>(Op)->eraseFromParent();
1479 return;
1480 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1481 m_Value(Arg)))) {
1482 ToRemove.push_back(cast<Instruction>(Op));
1483 Op->replaceAllUsesWith(Arg);
1484 return;
1485 }
1486 };
1487
1488 for (auto *V : ToFlatten)
1489 FlattenArg(V);
1490
1491 LHS = MatMul->getArgOperand(0);
1492
1493 // Insert mul/fmul and llvm.vector.reduce.fadd
1494 Value *Mul =
1495 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1496
1497 Value *Result;
1498 if (IsIntVec)
1499 Result = Builder.CreateAddReduce(Mul);
1500 else {
1501 Result = Builder.CreateFAddReduce(
1502 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1503 0.0),
1504 Mul);
1505 cast<Instruction>(Result)->setFastMathFlags(FMF);
1506 }
1507
1508 // pack scalar back into a matrix and then replace matmul inst
1509 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1510 Result, uint64_t(0));
1511 MatMul->replaceAllUsesWith(Result);
1512 FusedInsts.insert(MatMul);
1513 ToRemove.push_back(MatMul);
1514 }
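  // For illustration (schematic IR): a 1x3 * 3x1 multiply of <3 x float>
  // operands with reassociation enabled is lowered roughly to
  //   %m = fmul <3 x float> %lhs, %rhs
  //   %r = call float @llvm.vector.reduce.fadd.v3f32(float 0.0, <3 x float> %m)
  // with the scalar result inserted back into a 1x1 matrix value.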
1515
1516 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1517 /// addition.
1518 ///
1519 /// We can fold a transpose into the operand that is used to extract scalars.
1520 /// This is the first operand with row-major and the second with
1521 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1522 /// operand is transposed.
1523 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1524 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1525 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1526 const unsigned VF = std::max<unsigned>(
1527 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1528 .getFixedValue() /
1529 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1530 1U);
1531 unsigned R = Result.getNumRows();
1532 unsigned C = Result.getNumColumns();
1533 unsigned M = A.getNumColumns();
1534
1535 bool IsFP = Result.getElementType()->isFloatingPointTy();
1536 assert(A.isColumnMajor() == B.isColumnMajor() &&
1537 Result.isColumnMajor() == A.isColumnMajor() &&
1538 "operands must agree on matrix layout");
1539 unsigned NumComputeOps = 0;
1540
1541 Builder.setFastMathFlags(FMF);
1542
1543 if (A.isColumnMajor()) {
1544 // Multiply columns from the first operand with scalars from the second
1545 // operand. Then move along the K axes and accumulate the columns. With
1546 // this the adds can be vectorized without reassociation.
1547 for (unsigned J = 0; J < C; ++J) {
1548 unsigned BlockSize = VF;
1549 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1550 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1551
1552 for (unsigned I = 0; I < R; I += BlockSize) {
1553 // Gradually lower the vectorization factor to cover the remainder.
1554 while (I + BlockSize > R)
1555 BlockSize /= 2;
1556
1557 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1558 : nullptr;
1559 for (unsigned K = 0; K < M; ++K) {
1560 Value *L = A.extractVector(I, K, BlockSize, Builder);
1561 Value *RH = Builder.CreateExtractElement(
1562 B.getColumn(IsScalarMatrixTransposed ? K : J),
1563 IsScalarMatrixTransposed ? J : K);
1564 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1565 Sum =
1566 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1567 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1568 }
1569 Result.setVector(J,
1570 insertVector(Result.getVector(J), I, Sum, Builder));
1571 }
1572 }
1573 } else {
1574 // Multiply rows from the second operand with scalars from the first
1575 // operand. Then move along the K axes and accumulate the rows. With this
1576 // the adds can be vectorized without reassociation.
1577 for (unsigned I = 0; I < R; ++I) {
1578 unsigned BlockSize = VF;
1579 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1580 for (unsigned J = 0; J < C; J += BlockSize) {
1581 // Gradually lower the vectorization factor to cover the remainder.
1582 while (J + BlockSize > C)
1583 BlockSize /= 2;
1584
1585 Value *Sum = nullptr;
1586 for (unsigned K = 0; K < M; ++K) {
1587 Value *R = B.extractVector(K, J, BlockSize, Builder);
1588 Value *LH = Builder.CreateExtractElement(
1589 A.getVector(IsScalarMatrixTransposed ? K : I),
1590 IsScalarMatrixTransposed ? I : K);
1591 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1592 Sum =
1593 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1594 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1595 }
1596 Result.setVector(I,
1597 insertVector(Result.getVector(I), J, Sum, Builder));
1598 }
1599 }
1600 }
1601 Result.addNumComputeOps(NumComputeOps);
1602 }
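  // For illustration (column-major, R = 4, C = 2, M = 3, VF = 4): for each of
  // the 2 result columns, one 4-wide block is computed as
  //   sum over k = 0..2 of A.column(k)[0..3] * splat(B[k][j]),
  // i.e. 3 multiply-accumulate steps per column; for R = 6 the row dimension
  // would be covered by a 4-wide block followed by a 2-wide block.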
1603
1604 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1605 /// copying it to a new location. Returns the new location if a copy was made,
1606 /// otherwise the original location.
1607 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1608 CallInst *MatMul) {
1609 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1610 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1611
1612 // If we can statically determine noalias we're good.
1613 if (AA->isNoAlias(LoadLoc, StoreLoc))
1614 return Load->getPointerOperand();
1615
1616 // Create code to check if the memory locations of the Load and Store
1617 // overlap and if they do, copy Load's operand to a new buffer.
1618
1619 // First, create new blocks for the second part of the check and the copy.
1620 BasicBlock *Check0 = MatMul->getParent();
1621 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1622 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1623 // as we adjust Check0 and Check1's branches.
1624 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1625 for (BasicBlock *Succ : successors(Check0))
1626 DTUpdates.push_back({DT->Delete, Check0, Succ});
1627
1628 BasicBlock *Check1 =
1629 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1630 nullptr, "alias_cont");
1631 BasicBlock *Copy =
1632 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1633 nullptr, "copy");
1634 BasicBlock *Fusion =
1635 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1636 nullptr, "no_alias");
1637
1638 // Check if the loaded memory location begins before the end of the store
1639 // location. If the condition holds, they might overlap, otherwise they are
1640 // guaranteed to not overlap.
1641 IRBuilder<> Builder(MatMul);
1642 Check0->getTerminator()->eraseFromParent();
1643 Builder.SetInsertPoint(Check0);
1644 Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1645 Value *StoreBegin = Builder.CreatePtrToInt(
1646 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1647 Value *StoreEnd = Builder.CreateAdd(
1648 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1649 "store.end", true, true);
1650 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1651 IntPtrTy, "load.begin");
1652 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1653 Fusion);
1654
1655 // Check if the store begins before the end of the load location. If the
1656 // condition holds, they alias, otherwise they are guaranteed to not
1657 // overlap.
1658 Check1->getTerminator()->eraseFromParent();
1659 Builder.SetInsertPoint(Check1, Check1->begin());
1660 Value *LoadEnd = Builder.CreateAdd(
1661 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1662 "load.end", true, true);
1663 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1664 Fusion);
1665
1666 // Copy load operand to new alloca.
1667 Builder.SetInsertPoint(Copy, Copy->begin());
1668 auto *VT = cast<FixedVectorType>(Load->getType());
1669 // Use an array type for the alloca, to avoid potentially huge alignment
1670 // requirements for large vector types.
1671 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1672 AllocaInst *Alloca =
1673 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1674
1675 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1676 Load->getAlign(), LoadLoc.Size.getValue());
1677 Builder.SetInsertPoint(Fusion, Fusion->begin());
1678 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1679 PHI->addIncoming(Load->getPointerOperand(), Check0);
1680 PHI->addIncoming(Load->getPointerOperand(), Check1);
1681 PHI->addIncoming(Alloca, Copy);
1682
1683 // Adjust DT.
1684 DTUpdates.push_back({DT->Insert, Check0, Check1});
1685 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1686 DTUpdates.push_back({DT->Insert, Check1, Copy});
1687 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1688 DT->applyUpdates(DTUpdates);
1689 return PHI;
1690 }
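// Illustrative sketch, not part of the original source: when the load and the
// store cannot be statically disambiguated, the code above rewrites the CFG
// roughly as follows (block names other than alias_cont/copy/no_alias are
// hypothetical):
//   check0:     br (load.begin u< store.end), %alias_cont, %no_alias
//   alias_cont: br (store.begin u< load.end), %copy, %no_alias
//   copy:       %tmp = alloca [N x elt]; memcpy(%tmp, load ptr, size)
//   no_alias:   %ptr = phi [load ptr, %check0], [load ptr, %alias_cont],
//                         [%tmp, %copy]
// and the tiled multiply then reads the operand through %ptr.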
1691
1692 bool isFusionProfitable(CallInst *MatMul) {
1693 if (ForceFusion)
1694 return true;
1695
1696 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1697 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1698
1699 const unsigned R = LShape.NumRows;
1700 const unsigned C = RShape.NumColumns;
1701 const unsigned M = LShape.NumColumns;
1702 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1703
1704 const unsigned VF = std::max<unsigned>(
1705 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1706 .getFixedValue() /
1707 EltType->getPrimitiveSizeInBits().getFixedValue(),
1708 1U);
1709
1710 // Cost model for tiling
1711 //
1712 // For tiling to be beneficial, we need reuse either along the R or
1713 // the C axis. We vectorize along the R axis so that means at least
1714 // 3 elements.
1715 // TODO: Also consider cost of copying if operands alias.
1716 if (R <= VF && C == 1)
1717 return false;
1718 // Then we need enough elements to exceed the number of vector
1719 // registers we have. Note that this is an oversimplification since
1720 // fusing also takes some extra loads which may exceed the number of
1721 // reloads necessary.
1722 unsigned Op0Regs = (R + VF - 1) / VF * M;
1723 unsigned Op1Regs = (M + VF - 1) / VF * C;
1724 return Op0Regs + Op1Regs >
1725 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1726 }
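// Worked example, illustrative only: for R = C = M = 8 with double elements
// and 128-bit vector registers, VF == 128 / 64 == 2, so
//   Op0Regs = ceil(8 / 2) * 8 == 32 and Op1Regs = ceil(8 / 2) * 8 == 32.
// With, say, 32 vector registers available, 64 > 32 and fusion is considered
// profitable; a 2x1 * 1x1 product instead bails out early on the
// "R <= VF && C == 1" check above.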
1727
1728 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1729 MatrixTy Res;
1730 auto *ColumType = FixedVectorType::get(EltType, R);
1731 for (unsigned I = 0; I < C; ++I)
1732 Res.addVector(ConstantAggregateZero::get(ColumType));
1733 return Res;
1734 }
1735
1736 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1737 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1738 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1739
1740 // Create the main tiling loop nest.
1741 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1742 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1743 Instruction *InsertI = cast<Instruction>(MatMul);
1744 BasicBlock *Start = InsertI->getParent();
1745 BasicBlock *End =
1746 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1747 IRBuilder<> Builder(MatMul);
1748 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1749
1750 Type *TileVecTy =
1751 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1752 MatrixTy TileResult;
1753 // Insert in the inner loop header.
1754 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1755 // Create PHI nodes for the result columns to accumulate across iterations.
1756 SmallVector<PHINode *, 4> ColumnPhis;
1757 for (unsigned I = 0; I < TileSize; I++) {
1758 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1759 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1760 TI.RowLoop.Header->getSingleSuccessor());
1761 TileResult.addVector(Phi);
1762 ColumnPhis.push_back(Phi);
1763 }
1764
1765 // Insert in the inner loop body, which computes
1766 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1767 Builder.SetInsertPoint(InnerBody->getTerminator());
1768 // Load tiles of the operands.
1769 MatrixTy A =
1770 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1771 {TileSize, TileSize}, EltType, Builder);
1772 MatrixTy B =
1773 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1774 {TileSize, TileSize}, EltType, Builder);
1775 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1776 getFastMathFlags(MatMul));
1777 // Store result after the inner loop is done.
1778 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1779 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1780 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1781 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1782
1783 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1784 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1785
1786 // Force unrolling of a few iterations of the inner loop, to make sure there
1787 // is enough work per iteration.
1788 // FIXME: The unroller should make this decision directly instead, but
1789 // currently the cost-model is not up to the task.
1790 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1791 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1792 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1793 }
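// Illustrative sketch, not part of the original source: conceptually the
// generated nest for TileSize == 4 looks like
//   for (Col = 0; Col < RShape.NumColumns; Col += 4)   // ColumnLoop
//     for (Row = 0; Row < LShape.NumRows; Row += 4)    // RowLoop
//       for (K = 0; K < LShape.NumColumns; K += 4)     // KLoop
//         TileResult += tile of A at (Row, K) * tile of B at (K, Col);
//       store TileResult into the result tile at (Row, Col);
// with the accumulator carried by the PHIs created in the KLoop header.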
1794
1795 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1796 StoreInst *Store,
1797 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1798 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1799 "Tiling only supported for column-major matrixes at the moment!");
1800 if (!isFusionProfitable(MatMul))
1801 return;
1802
1803 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1804 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1805
1806 const unsigned R = LShape.NumRows;
1807 const unsigned C = RShape.NumColumns;
1808 const unsigned M = LShape.NumColumns;
1809 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1810
1811 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1812 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1813 Value *CPtr = Store->getPointerOperand();
1814
1815 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1816 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1817 else {
1818 IRBuilder<> Builder(Store);
1819 for (unsigned J = 0; J < C; J += TileSize)
1820 for (unsigned I = 0; I < R; I += TileSize) {
1821 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1822 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1823 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1824
1825 for (unsigned K = 0; K < M; K += TileSize) {
1826 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1827 MatrixTy A =
1828 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1829 LShape, Builder.getInt64(I), Builder.getInt64(K),
1830 {TileR, TileM}, EltType, Builder);
1831 MatrixTy B =
1832 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1833 RShape, Builder.getInt64(K), Builder.getInt64(J),
1834 {TileM, TileC}, EltType, Builder);
1835 emitMatrixMultiply(Res, A, B, Builder, true, false,
1836 getFastMathFlags(MatMul));
1837 }
1838 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1839 Builder.getInt64(I), Builder.getInt64(J), EltType,
1840 Builder);
1841 }
1842 }
1843
1844 // Mark eliminated instructions as fused and remove them.
1845 FusedInsts.insert(Store);
1846 FusedInsts.insert(MatMul);
1847 Store->eraseFromParent();
1848 MatMul->eraseFromParent();
1849 if (LoadOp0->hasNUses(0)) {
1850 FusedInsts.insert(LoadOp0);
1851 LoadOp0->eraseFromParent();
1852 }
1853 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1854 FusedInsts.insert(LoadOp1);
1855 LoadOp1->eraseFromParent();
1856 }
1857 }
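// Illustrative example, not from the original source: for an 8x8 * 8x8
// multiply with TileSize == 4, the direct (non-loop) path above visits the
// result tiles at (I, J) = (0,0), (4,0), (0,4), (4,4); each tile accumulates
// the partial products for K == 0 and K == 4 before being stored back once.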
1858
1859 /// Try to lower matrix multiply chains by fusing operations.
1860 ///
1861 /// Call finalizeLowering on lowered instructions. Instructions that are
1862 /// completely eliminated by fusion are added to \p FusedInsts.
1863 void
1864 LowerMatrixMultiplyFused(CallInst *MatMul,
1865 SmallPtrSetImpl<Instruction *> &FusedInsts,
1866 SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
1867 if (!FuseMatrix || !DT)
1868 return;
1869
1870 assert(AA && LI && "Analyses should be available");
1871
1872 Value *A = MatMul->getArgOperand(0);
1873 Value *B = MatMul->getArgOperand(1);
1874
1875 // We can fold the transpose into the operand that is used to fetch scalars.
1876 Value *T;
1877 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1878 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1879 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1880 IRBuilder<> Builder(MatMul);
1881 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1882 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1883 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1884 const unsigned R = LShape.NumRows;
1885 const unsigned M = LShape.NumColumns;
1886 const unsigned C = RShape.NumColumns;
1887
1888 MatrixTy MA;
1889 MatrixTy MB;
1890
1891 Value *Transpose;
1892 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1893 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1894 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1895 Transpose = B;
1896 } else {
1897 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1898 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1899 Transpose = A;
1900 }
1901
1902 // Initialize the output
1903 MatrixTy Result(R, C, EltType);
1904
1905 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1906 getFastMathFlags(MatMul));
1907
1908 FusedInsts.insert(MatMul);
1909 if (Transpose->hasOneUse()) {
1910 FusedInsts.insert(cast<Instruction>(Transpose));
1911 ToRemove.push_back(cast<Instruction>(Transpose));
1912 // TODO: add a fake entry for the folded instruction so that this is
1913 // included in the expression in the remark.
1914 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1915 }
1916 finalizeLowering(MatMul, Result, Builder);
1917 return;
1918 }
1919
1920 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1921 return;
1922
1923 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1924 // since the single store user will be lowered as part of this.
1925 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1926 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1927 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1928 if (LoadOp0 && LoadOp1 && Store) {
1929 // The store address must dominate the MatMul instruction, otherwise
1930 // we create invalid IR.
1931 SetVector<Value *> WorkList;
1932 WorkList.insert(Store->getOperand(1));
1933 SmallVector<Instruction *> ToHoist;
1934 for (unsigned I = 0; I != WorkList.size(); ++I) {
1935 Value *Current = WorkList[I];
1936 auto *CurrI = dyn_cast<Instruction>(Current);
1937 if (!CurrI)
1938 continue;
1939 if (isa<PHINode>(CurrI))
1940 return;
1941 if (DT->dominates(CurrI, MatMul))
1942 continue;
1943 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1944 return;
1945 ToHoist.push_back(CurrI);
1946 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1947 }
1948
1949 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1950 return DT->dominates(A, B);
1951 });
1952 for (Instruction *I : ToHoist)
1953 I->moveBefore(MatMul);
1954
1955 // Deal with lifetime.end calls that might be between Load0/Load1 and the
1956 // store. To avoid introducing loads to dead objects (i.e. after the
1957 // lifetime has been terminated by @llvm.lifetime.end), either sink them
1958 // after the store if in the same block, or remove the lifetime.end marker
1959 // otherwise. This might pessimize further optimizations, by extending the
1960 // lifetime of the object until the function returns, but should be
1961 // conservatively correct.
1962 MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
1963 MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
1964 BasicBlock *StoreParent = Store->getParent();
1965 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
1966 LoadOp1->getParent() == StoreParent;
1967 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
1968 IntrinsicInst *End = LifetimeEnds[Idx];
1969 auto Inc = make_scope_exit([&Idx]() { Idx++; });
1970 // If the lifetime.end is guaranteed to be before the loads or after the
1971 // store, it won't interfere with fusion.
1972 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
1973 continue;
1974 if (DT->dominates(Store, End))
1975 continue;
1976 // If all fusable ops are in the same block and the lifetime.end is in a
1977 // different block, it won't interfere with fusion.
1978 if (FusableOpsInSameBlock && End->getParent() != StoreParent)
1979 continue;
1980
1981 // If the loads don't alias the lifetime.end, it won't interfere with
1982 // fusion.
1983 MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
1984 if (!EndLoc.Ptr)
1985 continue;
1986 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
1987 continue;
1988
1989 // If both lifetime.end and the store are in the same block, extend the
1990 // lifetime until after the store, so the new lifetime covers the loads
1991 // we introduce later.
1992 if (End->getParent() == StoreParent) {
1993 End->moveAfter(Store);
1994 continue;
1995 }
1996
1997 // Otherwise remove the conflicting lifetime.end marker.
1998 ToRemove.push_back(End);
1999 std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
2000 LifetimeEnds.pop_back();
2001 Inc.release();
2002 }
2003
2004 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2005 return;
2006 }
2007 }
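// Illustrative sketch, not part of the original source: the load/load ->
// multiply -> store chain handled above looks roughly like (for 4x4 double
// operands; %A, %B, %C are hypothetical pointers)
//   %a = load <16 x double>, ptr %A
//   %b = load <16 x double>, ptr %B
//   %c = call <16 x double> @llvm.matrix.multiply.v16f64.v16f64.v16f64(
//            <16 x double> %a, <16 x double> %b, i32 4, i32 4, i32 4)
//   store <16 x double> %c, ptr %C
// where the store is the multiply's only user and %C dominates the call (or
// its address computation can be hoisted above it).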
2008
2009 /// Lowers llvm.matrix.multiply.
2010 void LowerMultiply(CallInst *MatMul) {
2011 IRBuilder<> Builder(MatMul);
2012 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
2013 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2014 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2015
2016 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
2017 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
2018 assert(Lhs.getElementType() == Rhs.getElementType() &&
2019 "Matrix multiply argument element types do not match.");
2020
2021 const unsigned R = LShape.NumRows;
2022 const unsigned C = RShape.NumColumns;
2023 assert(LShape.NumColumns == RShape.NumRows);
2024
2025 // Initialize the output
2026 MatrixTy Result(R, C, EltType);
2027 assert(Lhs.getElementType() == Result.getElementType() &&
2028 "Matrix multiply result element type does not match arguments.");
2029
2030 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
2031 getFastMathFlags(MatMul));
2032 finalizeLowering(MatMul, Result, Builder);
2033 }
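// Reference semantics, for orientation only: for an R x M matrix Lhs and an
// M x C matrix Rhs, the lowering above computes
//   Result(i, j) = sum over k in [0, M) of Lhs(i, k) * Rhs(k, j)
// with operands and result flattened to flat vectors according to the
// configured matrix layout.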
2034
2035 /// Lowers llvm.matrix.transpose.
2036 void LowerTranspose(CallInst *Inst) {
2037 MatrixTy Result;
2038 IRBuilder<> Builder(Inst);
2039 Value *InputVal = Inst->getArgOperand(0);
2040 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
2041 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
2042 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
2043
2044 const unsigned NewNumVecs =
2045 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2046 const unsigned NewNumElts =
2047 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
2048
2049 for (unsigned I = 0; I < NewNumVecs; ++I) {
2050 // Build a single result vector. First initialize it.
2051 Value *ResultVector = PoisonValue::get(
2052 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
2053 // Go through the old elements and insert them into the resulting vector.
2054 for (auto J : enumerate(InputMatrix.vectors())) {
2055 Value *Elt = Builder.CreateExtractElement(J.value(), I);
2056 // Row and column indices are transposed.
2057 ResultVector =
2058 Builder.CreateInsertElement(ResultVector, Elt, J.index());
2059 }
2060 Result.addVector(ResultVector);
2061 }
2062
2063 // TODO: Improve estimate of operations needed for transposes. Currently we
2064 // just count the insertelement/extractelement instructions, but do not
2065 // account for later simplifications/combines.
2066 finalizeLowering(
2067 Inst,
2068 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2069 .addNumExposedTransposes(1),
2070 Builder);
2071 }
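// Illustrative example, not from the original source: transposing a
// column-major 2x3 matrix stored as columns <a00, a10>, <a01, a11>,
// <a02, a12> yields NewNumVecs == 2 result vectors of NewNumElts == 3
// elements, <a00, a01, a02> and <a10, a11, a12> (the rows of the input,
// i.e. the columns of the 3x2 transpose), built one extractelement /
// insertelement pair at a time.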
2072
2073 /// Lower load instructions, if shape information is available.
2074 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2075 auto I = ShapeMap.find(Inst);
2076 if (I == ShapeMap.end())
2077 return false;
2078
2079 LowerLoad(Inst, Ptr, Inst->getAlign(),
2080 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2081 I->second);
2082 return true;
2083 }
2084
2085 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
2086 IRBuilder<> &Builder) {
2087 auto I = ShapeMap.find(StoredVal);
2088 if (I == ShapeMap.end())
2089 return false;
2090
2091 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2092 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2093 I->second);
2094 return true;
2095 }
2096
2097 /// Lower binary operators, if shape information is available.
2098 bool VisitBinaryOperator(BinaryOperator *Inst) {
2099 auto I = ShapeMap.find(Inst);
2100 if (I == ShapeMap.end())
2101 return false;
2102
2103 Value *Lhs = Inst->getOperand(0);
2104 Value *Rhs = Inst->getOperand(1);
2105
2106 IRBuilder<> Builder(Inst);
2107 ShapeInfo &Shape = I->second;
2108
2109 MatrixTy Result;
2110 MatrixTy A = getMatrix(Lhs, Shape, Builder);
2111 MatrixTy B = getMatrix(Rhs, Shape, Builder);
2112 assert(A.isColumnMajor() == B.isColumnMajor() &&
2113 Result.isColumnMajor() == A.isColumnMajor() &&
2114 "operands must agree on matrix layout");
2115
2116 Builder.setFastMathFlags(getFastMathFlags(Inst));
2117
2118 // Helper to perform binary op on vectors.
2119 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2120 switch (Inst->getOpcode()) {
2121 case Instruction::Add:
2122 return Builder.CreateAdd(LHS, RHS);
2123 case Instruction::Mul:
2124 return Builder.CreateMul(LHS, RHS);
2125 case Instruction::Sub:
2126 return Builder.CreateSub(LHS, RHS);
2127 case Instruction::FAdd:
2128 return Builder.CreateFAdd(LHS, RHS);
2129 case Instruction::FMul:
2130 return Builder.CreateFMul(LHS, RHS);
2131 case Instruction::FSub:
2132 return Builder.CreateFSub(LHS, RHS);
2133 default:
2134 llvm_unreachable("Unsupported binary operator for matrix");
2135 }
2136 };
2137
2138 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2139 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2140
2141 finalizeLowering(Inst,
2142 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2143 Result.getNumVectors()),
2144 Builder);
2145 return true;
2146 }
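// Illustrative example, not from the original source: adding two 4x2 float
// matrices in column-major layout lowers to one vector instruction per
// stored column (%*.col* are hypothetical names):
//   %r.col0 = fadd <4 x float> %a.col0, %b.col0
//   %r.col1 = fadd <4 x float> %a.col1, %b.col1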
2147
2148 /// Lower unary operators, if shape information is available.
2149 bool VisitUnaryOperator(UnaryOperator *Inst) {
2150 auto I = ShapeMap.find(Inst);
2151 if (I == ShapeMap.end())
2152 return false;
2153
2154 Value *Op = Inst->getOperand(0);
2155
2156 IRBuilder<> Builder(Inst);
2157 ShapeInfo &Shape = I->second;
2158
2159 MatrixTy Result;
2160 MatrixTy M = getMatrix(Op, Shape, Builder);
2161
2162 Builder.setFastMathFlags(getFastMathFlags(Inst));
2163
2164 // Helper to perform unary op on vectors.
2165 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2166 switch (Inst->getOpcode()) {
2167 case Instruction::FNeg:
2168 return Builder.CreateFNeg(Op);
2169 default:
2170 llvm_unreachable("Unsupported unary operator for matrix");
2171 }
2172 };
2173
2174 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2175 Result.addVector(BuildVectorOp(M.getVector(I)));
2176
2177 finalizeLowering(Inst,
2178 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2179 Result.getNumVectors()),
2180 Builder);
2181 return true;
2182 }
2183
2184 /// Helper to linearize a matrix expression tree into a string. Currently
2185 /// matrix expressions are linearized by starting at an expression leaf and
2186 /// linearizing bottom up.
2187 struct ExprLinearizer {
2188 unsigned LengthToBreak = 100;
2189 std::string Str;
2190 raw_string_ostream Stream;
2191 unsigned LineLength = 0;
2192 const DataLayout &DL;
2193
2194 /// Mapping from instructions to matrixes. It is used to identify
2195 /// matrix instructions.
2196 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2197
2198 /// Mapping from values to the leaves of all expressions that the value is
2199 /// part of.
2200 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2201
2202 /// Set of matrix expressions in the scope of a given DISubprogram.
2203 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2204
2205 /// Leaf node of the expression to linearize.
2206 Value *Leaf;
2207
2208 /// Used to keep track of sub-expressions that get reused while linearizing
2209 /// the expression. Re-used sub-expressions are marked as (reused).
2210 SmallPtrSet<Value *, 8> ReusedExprs;
2211
2212 ExprLinearizer(const DataLayout &DL,
2213 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2214 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2215 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2216 Value *Leaf)
2217 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2218 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2219
2220 void indent(unsigned N) {
2221 LineLength += N;
2222 for (unsigned i = 0; i < N; i++)
2223 Stream << " ";
2224 }
2225
2226 void lineBreak() {
2227 Stream << "\n";
2228 LineLength = 0;
2229 }
2230
2231 void maybeIndent(unsigned Indent) {
2232 if (LineLength >= LengthToBreak)
2233 lineBreak();
2234
2235 if (LineLength == 0)
2236 indent(Indent);
2237 }
2238
2239 void write(StringRef S) {
2240 LineLength += S.size();
2241 Stream << S;
2242 }
2243
2244 Value *getUnderlyingObjectThroughLoads(Value *V) {
2245 if (Value *Ptr = getPointerOperand(V))
2246 return getUnderlyingObjectThroughLoads(Ptr);
2247 else if (V->getType()->isPointerTy())
2248 return getUnderlyingObject(V);
2249 return V;
2250 }
2251
2252 /// Returns true if \p V is a matrix value in the given subprogram.
2253 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2254
2255 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2256 /// \p SS.
2257 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2258 auto M = Inst2Matrix.find(V);
2259 if (M == Inst2Matrix.end())
2260 SS << "unknown";
2261 else {
2262 SS << M->second.getNumRows();
2263 SS << "x";
2264 SS << M->second.getNumColumns();
2265 }
2266 }
2267
2268 /// Write the called function name. Handles calls to llvm.matrix.*
2269 /// specially: we write the name, followed by the dimensions of the input
2270 /// matrixes, followed by the scalar type name.
2271 void writeFnName(CallInst *CI) {
2272 if (!CI->getCalledFunction())
2273 write("<no called fn>");
2274 else {
2275 StringRef Name = CI->getCalledFunction()->getName();
2276 if (!Name.starts_with("llvm.matrix")) {
2277 write(Name);
2278 return;
2279 }
2280 auto *II = cast<IntrinsicInst>(CI);
2281 write(Intrinsic::getBaseName(II->getIntrinsicID())
2282 .drop_front(StringRef("llvm.matrix.").size()));
2283 write(".");
2284 std::string Tmp;
2285 raw_string_ostream SS(Tmp);
2286
2287 switch (II->getIntrinsicID()) {
2288 case Intrinsic::matrix_multiply:
2289 prettyPrintMatrixType(II->getOperand(0), SS);
2290 SS << ".";
2291 prettyPrintMatrixType(II->getOperand(1), SS);
2292 SS << "." << *II->getType()->getScalarType();
2293 break;
2294 case Intrinsic::matrix_transpose:
2295 prettyPrintMatrixType(II->getOperand(0), SS);
2296 SS << "." << *II->getType()->getScalarType();
2297 break;
2298 case Intrinsic::matrix_column_major_load:
2299 prettyPrintMatrixType(II, SS);
2300 SS << "." << *II->getType()->getScalarType();
2301 break;
2302 case Intrinsic::matrix_column_major_store:
2303 prettyPrintMatrixType(II->getOperand(0), SS);
2304 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2305 break;
2306 default:
2307 llvm_unreachable("Unhandled case");
2308 }
2309 SS.flush();
2310 write(Tmp);
2311 }
2312 }
2313
2314 unsigned getNumShapeArgs(CallInst *CI) const {
2315 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2316 switch (II->getIntrinsicID()) {
2317 case Intrinsic::matrix_multiply:
2318 return 3;
2319 case Intrinsic::matrix_transpose:
2320 return 2;
2321 case Intrinsic::matrix_column_major_load:
2322 case Intrinsic::matrix_column_major_store:
2323 return 3;
2324 default:
2325 return 0;
2326 }
2327 }
2328 return 0;
2329 }
2330
2331 /// Special printing for values: for pointers, we print whether they refer to
2332 /// an (function) external address or a stack address; for other values we
2333 /// either print the constant or "scalar"/"matrix".
2334 void write(Value *V) {
2335 V = getUnderlyingObjectThroughLoads(V);
2336 if (V->getType()->isPointerTy()) {
2337 if (isa<AllocaInst>(V)) {
2338 Stream << "stack addr";
2339 LineLength += StringRef("stack addr").size();
2340 } else {
2341 Stream << "addr";
2342 LineLength += StringRef("addr").size();
2343 }
2344 if (!V->getName().empty()) {
2345 Stream << " %" << V->getName() << "";
2346 LineLength += V->getName().size() + 2;
2347 }
2348 return;
2349 }
2350
2351 std::string Tmp;
2352 raw_string_ostream TmpStream(Tmp);
2353
2354 if (auto *CI = dyn_cast<ConstantInt>(V))
2355 TmpStream << CI->getValue();
2356 else if (isa<Constant>(V))
2357 TmpStream << "constant";
2358 else {
2359 if (isMatrix(V))
2360 TmpStream << "matrix";
2361 else
2362 TmpStream << "scalar";
2363 }
2364 TmpStream.flush();
2365 Tmp = std::string(StringRef(Tmp).trim());
2366 LineLength += Tmp.size();
2367 Stream << Tmp;
2368 }
2369
2370 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2371 /// Expressions that are re-used multiple times are prefixed with (reused)
2372 /// at the re-used root instruction.
2373 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2374 bool ParentShared) {
2375 auto *I = cast<Instruction>(Expr);
2376 maybeIndent(Indent);
2377 SmallVector<Value *, 8> Ops;
2378
2379 // Is Expr shared with other expression leaves?
2380 bool ExprShared = false;
2381
2382 // Deal with shared subtrees. Mark them as shared, if required.
2383 if (!ParentShared) {
2384 auto SI = Shared.find(Expr);
2385 assert(SI != Shared.end() && SI->second.count(Leaf));
2386
2387 for (Value *S : SI->second) {
2388 if (S == Leaf)
2389 continue;
2390 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2391 write("shared with remark at line " + std::to_string(DL.getLine()) +
2392 " column " + std::to_string(DL.getCol()) + " (");
2393 }
2394 ExprShared = SI->second.size() > 1;
2395 }
2396
2397 bool Reused = !ReusedExprs.insert(Expr).second;
2398 if (Reused && !ParentReused)
2399 write("(reused) ");
2400
2401 if (auto *CI = dyn_cast<CallInst>(I)) {
2402 writeFnName(CI);
2403
2404 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2405 } else if (isa<BitCastInst>(Expr)) {
2406 // Special case bitcasts, which are used to materialize matrixes from
2407 // non-matrix ops.
2408 write("matrix");
2409 return;
2410 } else {
2411 Ops.append(I->value_op_begin(), I->value_op_end());
2412 write(std::string(I->getOpcodeName()));
2413 }
2414
2415 write(std::string("("));
2416
2417 unsigned NumOpsToBreak = 1;
2418 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2419 NumOpsToBreak = 2;
2420
2421 for (Value *Op : Ops) {
2422 if (Ops.size() > NumOpsToBreak)
2423 lineBreak();
2424
2425 maybeIndent(Indent + 1);
2426 if (isMatrix(Op))
2427 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2428 else
2429 write(Op);
2430 if (Op != Ops.back())
2431 write(", ");
2432 }
2433
2434 write(")");
2435 }
2436
2437 const std::string &getResult() {
2438 Stream.flush();
2439 return Str;
2440 }
2441 };
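// Example of the linearized form, illustrative only (shapes, types and value
// names are hypothetical): a fused load/multiply/store expression is printed
// roughly as
//   store(
//    multiply.2x6.6x2.double(
//     load(addr %A),
//     load(addr %B)),
//    addr %C)
// with re-used subtrees prefixed by "(reused)" and shared ones wrapped in
// "shared with remark at line <L> column <C> (".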
2442
2443 /// Generate remarks for matrix operations in a function. To generate remarks
2444 /// for matrix expressions, the following approach is used:
2445 /// 1. Use the inlined-at debug information to group matrix operations to the
2446 /// DISubprograms they are contained in.
2447 /// 2. Collect leaves of matrix expressions (done in
2448 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2449 /// mapping. Leaves are lowered matrix instructions without other matrix
2450 /// users (like stores) in the current subprogram.
2451 /// 3. For each leaf, create a remark containing a linearized version of the
2452 /// matrix expression. The expression is linearized by a recursive
2453 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2454 /// that multiple leaves can share sub-expressions. Shared subexpressions
2455 /// are explicitly marked as shared().
2456 struct RemarkGenerator {
2457 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2458 OptimizationRemarkEmitter &ORE;
2459 Function &Func;
2460 const DataLayout &DL;
2461
2462 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2463 OptimizationRemarkEmitter &ORE, Function &Func)
2464 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2465 DL(Func.getDataLayout()) {}
2466
2467 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2468 /// instructions in Inst2Matrix returning void or without any users in
2469 /// \p ExprsInSubprogram. Currently that should only include stores.
2470 SmallVector<Value *, 4>
2471 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2472 SmallVector<Value *, 4> Leaves;
2473 for (auto *Expr : ExprsInSubprogram)
2474 if (Expr->getType()->isVoidTy() ||
2475 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2476 return ExprsInSubprogram.count(U);
2477 }))
2478 Leaves.push_back(Expr);
2479 return Leaves;
2480 }
2481
2482 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2483 /// to all visited expressions in \p Shared. Limit the matrix operations to
2484 /// the ones in \p ExprsInSubprogram.
2485 void collectSharedInfo(Value *Leaf, Value *V,
2486 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2487 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2488
2489 if (!ExprsInSubprogram.count(V))
2490 return;
2491
2492 auto I = Shared.insert({V, {}});
2493 I.first->second.insert(Leaf);
2494
2495 for (Value *Op : cast<Instruction>(V)->operand_values())
2496 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2497 }
2498
2499 /// Calculate the number of exclusive and shared op counts for expression
2500 /// starting at \p Root. Expressions used multiple times are counted once.
2501 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2502 std::pair<OpInfoTy, OpInfoTy>
2503 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2504 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2505 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2506 if (!ExprsInSubprogram.count(Root))
2507 return {};
2508
2509 // Already counted this expression. Stop.
2510 if (!ReusedExprs.insert(Root).second)
2511 return {};
2512
2513 OpInfoTy SharedCount;
2514 OpInfoTy Count;
2515
2516 auto I = Shared.find(Root);
2517 auto CM = Inst2Matrix.find(Root);
2518 if (I->second.size() == 1)
2519 Count = CM->second.getOpInfo();
2520 else
2521 SharedCount = CM->second.getOpInfo();
2522
2523 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2524 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2525 Count += C.first;
2526 SharedCount += C.second;
2527 }
2528 return {Count, SharedCount};
2529 }
2530
2531 void emitRemarks() {
2532 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2533 return;
2534
2535 // Map matrix operations to their containing subprograms, by traversing
2536 // the inlinedAt chain. If the function does not have a DISubprogram, we
2537 // only map them to the containing function.
2538 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2539 for (const auto &KV : Inst2Matrix) {
2540 if (Func.getSubprogram()) {
2541 auto *I = cast<Instruction>(KV.first);
2542 DILocation *Context = I->getDebugLoc();
2543 while (Context) {
2544 auto I =
2545 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2546 I.first->second.push_back(KV.first);
2547 Context = DebugLoc(Context).getInlinedAt();
2548 }
2549 } else {
2550 auto I = Subprog2Exprs.insert({nullptr, {}});
2551 I.first->second.push_back(KV.first);
2552 }
2553 }
2554 for (auto &KV : Subprog2Exprs) {
2555 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2556 KV.second.end());
2557 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2558
2559 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2560 for (Value *Leaf : Leaves)
2561 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2562
2563 // Generate remarks for each leaf.
2564 for (auto *L : Leaves) {
2565
2566 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2567 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2568 while (Context) {
2569 if (getSubprogram(Context->getScope()) == KV.first) {
2570 Loc = Context;
2571 break;
2572 }
2573 Context = DebugLoc(Context).getInlinedAt();
2574 }
2575
2576 SmallPtrSet<Value *, 8> ReusedExprs;
2577 OpInfoTy Counts, SharedCounts;
2578 std::tie(Counts, SharedCounts) =
2579 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2580
2581 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2582 cast<Instruction>(L)->getParent());
2583
2584 Rem << "Lowered with ";
2585 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2586 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2587 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2588 << " compute ops, "
2589 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2590 << " exposed transposes";
2591
2592 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2593 SharedCounts.NumComputeOps > 0) {
2594 Rem << ",\nadditionally "
2595 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2596 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2597 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2598 << " compute ops"
2599 << " are shared with other expressions";
2600 }
2601
2602 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2603 ORE.emit(Rem);
2604 }
2605 }
2606 }
2607
2608 std::string
2609 linearize(Value *L,
2610 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2611 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2612 const DataLayout &DL) {
2613 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2614 Lin.linearizeExpr(L, 0, false, false);
2615 return Lin.getResult();
2616 }
2617 };
2618};
2619} // namespace
2620
2621 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2622 FunctionAnalysisManager &AM) {
2623 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2624 OptimizationRemarkEmitter *ORE = nullptr;
2625 AAResults *AA = nullptr;
2626 DominatorTree *DT = nullptr;
2627 LoopInfo *LI = nullptr;
2628
2629 if (!Minimal) {
2630 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2631 AA = &AM.getResult<AAManager>(F);
2632 DT = &AM.getResult<DominatorTreeAnalysis>(F);
2633 LI = &AM.getResult<LoopAnalysis>(F);
2634 }
2635
2636 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2637 if (LMT.Visit()) {
2638 PreservedAnalyses PA;
2639 if (!Minimal) {
2640 PA.preserve<LoopAnalysis>();
2641 PA.preserve<DominatorTreeAnalysis>();
2642 }
2643 return PA;
2644 }
2645 return PreservedAnalyses::all();
2646}
2647
2648 void LowerMatrixIntrinsicsPass::printPipeline(
2649 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2650 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2651 OS, MapClassName2PassName);
2652 OS << '<';
2653 if (Minimal)
2654 OS << "minimal";
2655 OS << '>';
2656}
Rewrite undef for PHI
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static const Function * getParent(const Value *V)
BitTracker BT
Definition: BitTracker.cpp:73
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition: CommandLine.h:686
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
std::string Name
bool End
Definition: ELF_riscv.cpp:480
Hexagon Common GEP
hexagon Hexagon specific predictive commoning for HVX vectors
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iv users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:512
Live Register Matrix
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, BasicBlock &BB)
Erase V from BB and move \II forward to avoid invalidating iterators.
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
#define DEBUG_TYPE
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define T1
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static unsigned getNumElements(Type *Ty)
static unsigned getFastMathFlags(const MachineInstr &I)
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition: SROA.cpp:2528
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition: SROA.cpp:2550
raw_pwrite_stream & OS
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
This file defines the SmallSet class.
This file defines the SmallVector class.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
static const int BlockSize
Definition: TarWriter.cpp:33
This pass exposes codegen information to IR-level passes.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Value * RHS
Value * LHS
BinaryOperator * Mul
A manager for alias analyses.
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
an instruction to allocate memory on the stack
Definition: Instructions.h:61
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition: Instructions.h:122
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:405
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:448
reverse_iterator rbegin()
Definition: BasicBlock.h:464
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:179
reverse_iterator rend()
Definition: BasicBlock.h:466
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
BinaryOps getOpcode() const
Definition: InstrTypes.h:442
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1465
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1385
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1838
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1410
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1391
This class represents a function call, abstracting a target machine's calling convention.
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1650
This is the shared class of boolean and integer constants.
Definition: Constants.h:81
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Debug location.
Base class for scope-like contexts.
Subprogram description.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
A debug info location.
Definition: DebugLoc.h:33
DILocation * getInlinedAt() const
Definition: DebugLoc.cpp:39
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
static constexpr UpdateKind Delete
static constexpr UpdateKind Insert
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:20
void setAllowContract(bool B=true)
Definition: FMF.h:91
bool allowReassoc() const
Flag queries.
Definition: FMF.h:65
bool allowContract() const
Definition: FMF.h:70
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:539
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:680
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition: Function.h:249
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition: Function.h:254
CallInst * CreateFAddReduce(Value *Acc, Value *Src)
Create a sequential vector fadd reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:418
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:2277
Value * CreateFSub(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1577
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2492
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1790
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2480
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition: IRBuilder.h:1824
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1550
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
Definition: IRBuilder.cpp:1193
CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:434
IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)
Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...
Definition: IRBuilder.h:572
void setFastMathFlags(FastMathFlags NewFMF)
Set the fast-math flags to be used with generated fp-math operators.
Definition: IRBuilder.h:308
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Definition: IRBuilder.h:1883
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition: IRBuilder.h:488
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition: IRBuilder.h:2417
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1361
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
Definition: IRBuilder.h:494
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition: IRBuilder.h:1137
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1807
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition: IRBuilder.h:2514
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1344
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2137
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Definition: IRBuilder.h:1843
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args=std::nullopt, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:2432
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1604
Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:1747
CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *TBAAStructTag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)
Create and insert a memcpy between the specified pointers.
Definition: IRBuilder.h:656
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1378
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2686
static InstructionCost getInvalid(CostType Val=0)
void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:92
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:48
An instruction for reading from memory.
Definition: Instructions.h:174
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition: Instructions.h:203
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:209
TypeSize getValue() const
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:36
iterator end()
Definition: MapVector.h:71
iterator find(const KeyT &Key)
Definition: MapVector.h:167
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: MapVector.h:141
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
Representation for a specific memory location.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
const Value * Ptr
The address of the start of the location.
static MemoryLocation getForArgument(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo *TLI)
Return a location representing a particular argument of a call.
The optimization diagnostic interface.
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1852
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:131
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:346
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:435
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:367
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:441
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:502
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
bool empty() const
Definition: SmallSet.h:159
bool erase(const T &V)
Definition: SmallSet.h:207
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
An instruction for storing to memory.
Definition: Instructions.h:290
Align getAlign() const
Definition: Instructions.h:329
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:321
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition: StringRef.h:594
constexpr size_t size() const
size - Get the string size.
Definition: StringRef.h:137
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
TypeSize getRegisterBitWidth(RegisterKind K) const
InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args=std::nullopt, const Instruction *CxtI=nullptr, const TargetLibraryInfo *TLibInfo=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
@ TCK_RecipThroughput
Reciprocal throughput.
unsigned getNumberOfRegisters(unsigned ClassID) const
InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *Tp, ArrayRef< int > Mask=std::nullopt, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args=std::nullopt, const Instruction *CxtI=nullptr) const
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
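A hedged sketch of how these TTI hooks are queried; the helper and the choice of FAdd are illustrative, and TCK_RecipThroughput is the default cost kind shown in the signatures above:

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instruction.h"

// Illustrative: reciprocal-throughput cost of one vector fadd of NumElts x EltTy.
static llvm::InstructionCost
vectorFAddCost(const llvm::TargetTransformInfo &TTI, llvm::Type *EltTy,
               unsigned NumElts) {
  auto *VecTy = llvm::FixedVectorType::get(EltTy, NumElts);
  // Width of a fixed-width vector register in bits (may be zero if unknown).
  llvm::TypeSize RegBits =
      TTI.getRegisterBitWidth(llvm::TargetTransformInfo::RGK_FixedWidthVector);
  (void)RegBits;
  return TTI.getArithmeticInstrCost(
      llvm::Instruction::FAdd, VecTy,
      llvm::TargetTransformInfo::TCK_RecipThroughput);
}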
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:139
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:343
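Illustrative only: the scalar/size queries above compose as follows for a vector type such as <8 x float>:

#include "llvm/IR/Type.h"

// Illustrative: for <8 x float> this returns 32; Ty->getPrimitiveSizeInBits()
// would report the full 256 bits.
static unsigned elementBits(llvm::Type *Ty) {
  llvm::Type *ElemTy = Ty->getScalarType();  // float for <8 x float>
  return ElemTy->getScalarSizeInBits();      // same as Ty->getScalarSizeInBits()
}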
UnaryOps getOpcode() const
Definition: InstrTypes.h:171
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
Value * getOperand(unsigned i) const
Definition: User.h:169
See the file comment.
Definition: ValueMap.h:84
size_type count(const KeyT &Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: ValueMap.h:151
iterator find(const KeyT &Val)
Definition: ValueMap.h:155
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: ValueMap.h:172
iterator end()
Definition: ValueMap.h:135
bool erase(const KeyT &Val)
Definition: ValueMap.h:190
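A sketch with an invented payload (the mapped unsigned is hypothetical); the point of ValueMap is that its keys follow RAUW and deletion of the underlying Value:

#include "llvm/IR/Value.h"
#include "llvm/IR/ValueMap.h"

// Hypothetical cache: number of columns computed for an IR value.
using ColumnCountMap = llvm::ValueMap<llvm::Value *, unsigned>;

static void remember(ColumnCountMap &Columns, llvm::Value *V, unsigned N) {
  Columns.insert({V, N});  // keeps the existing entry if V is already a key
}

static unsigned lookupOrZero(ColumnCountMap &Columns, llvm::Value *V) {
  auto It = Columns.find(V);
  return It == Columns.end() ? 0u : It->second;
}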
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
user_iterator user_begin()
Definition: Value.h:397
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:149
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
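An illustrative use of the use-list queries above (not taken from this pass): replace a single-use instruction and drop it.

#include "llvm/IR/Instruction.h"
#include "llvm/IR/Value.h"

// Illustrative: if I has exactly one use, redirect it to Repl and erase I.
static bool replaceSingleUse(llvm::Instruction *I, llvm::Value *Repl) {
  if (!I->hasOneUse())          // equivalent to I->hasNUses(1)
    return false;
  I->replaceAllUsesWith(Repl);  // every user of I now uses Repl
  I->eraseFromParent();
  return true;
}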
Type * getElementType() const
Definition: DerivedTypes.h:436
constexpr ScalarTy getFixedValue() const
Definition: TypeSize.h:202
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition: ilist_node.h:32
A range adaptor for a pair of iterators.
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
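For illustration, raw_string_ostream is the usual way to build a std::string with raw_ostream formatting (e.g. for remark messages); the helper is invented:

#include "llvm/Support/raw_ostream.h"
#include <string>

// Illustrative: render a matrix shape as "RxC".
static std::string shapeString(unsigned Rows, unsigned Cols) {
  std::string Buf;
  llvm::raw_string_ostream OS(Buf);
  OS << Rows << "x" << Cols;
  return OS.str();  // flushes into Buf and returns it
}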
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:121
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition: Function.cpp:1091
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1539
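A sketch of the declare-then-call pattern for an overloaded intrinsic; llvm.fmuladd is chosen because this pass can emit it, but the helper itself is illustrative (newer trees may spell this Intrinsic::getOrInsertDeclaration):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"

// Illustrative: emit Sum + A*X as a single llvm.fmuladd call.
// A, X and Sum must share the same floating-point (or FP vector) type.
static llvm::Value *emitFMulAdd(llvm::IRBuilder<> &B, llvm::Module *M,
                                llvm::Value *A, llvm::Value *X,
                                llvm::Value *Sum) {
  llvm::Function *Decl = llvm::Intrinsic::getDeclaration(
      M, llvm::Intrinsic::fmuladd, {A->getType()});
  return B.CreateCall(Decl, {A, X, Sum});
}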
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:100
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
Definition: PatternMatch.h:972
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
OneUse_match< T > m_OneUse(const T &SubPattern)
Definition: PatternMatch.h:67
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:239
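An illustrative combination of the matchers listed above: recognize a store whose value operand is a single-use load, i.e. a plain memory-to-memory copy (the helper name is invented):

#include "llvm/IR/Instruction.h"
#include "llvm/IR/PatternMatch.h"

// Illustrative: bind the source and destination pointers of a load+store pair.
static bool matchMemToMemCopy(llvm::Instruction *I, llvm::Value *&Src,
                              llvm::Value *&Dst) {
  using namespace llvm::PatternMatch;
  return match(I, m_Store(m_OneUse(m_Load(m_Value(Src))), m_Value(Dst)));
}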
@ SS
Definition: X86.h:211
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition: CommandLine.h:711
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
ElementType
The element type of an SRV or UAV resource.
Definition: DXILABI.h:58
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
NodeAddr< FuncNode * > Func
Definition: RDFGraph.h:393
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:480
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1680
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
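For illustration, make_scope_exit runs a callable when the scope is left, which is handy for resetting state on every early return:

#include "llvm/ADT/ScopeExit.h"

// Illustrative: InProgress is cleared no matter which return path is taken.
static void runGuarded(bool &InProgress, bool TakeEarlyExit) {
  InProgress = true;
  auto Reset = llvm::make_scope_exit([&] { InProgress = false; });
  if (TakeEarlyExit)
    return;  // Reset still fires here
  // ... normal work ...
}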
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition: STLExtras.h:2431
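An illustrative use of enumerate to walk a container together with its indices; index() and value() are the accessors on each yielded element:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Illustrative: index of the first zero element, or the size if none exists.
static size_t firstZeroIndex(const llvm::SmallVectorImpl<int> &Values) {
  for (const auto &En : llvm::enumerate(Values))
    if (En.value() == 0)
      return En.index();
  return Values.size();
}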
auto successors(const MachineBasicBlock *BB)
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2060
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & operator+=(DynamicAPInt &A, int64_t B)
Definition: DynamicAPInt.h:518
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:656
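An illustrative use of make_early_inc_range: the adaptor advances the iterator before handing out the element, so the current user can be removed from V's use list safely:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/Casting.h"

// Illustrative: drop every dead instruction that still uses V.
static void eraseDeadUsers(llvm::Value *V) {
  for (llvm::User *U : llvm::make_early_inc_range(V->users()))
    if (auto *I = llvm::dyn_cast<llvm::Instruction>(U))
      if (I->use_empty())
        I->eraseFromParent();
}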
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
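A hedged sketch combining getPointerOperand and getUnderlyingObject, the kind of query used when checking whether two accesses touch the same object; MaxLookup is left at its default and the helper is invented:

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Instructions.h"

// Illustrative: true if both instructions demonstrably address the same base object.
static bool sameBaseObject(const llvm::Instruction *A, const llvm::Instruction *B) {
  const llvm::Value *PA = llvm::getPointerOperand(A);
  const llvm::Value *PB = llvm::getPointerOperand(B);
  if (!PA || !PB)
    return false;  // not a load, store or GEP
  return llvm::getUnderlyingObject(PA) == llvm::getUnderlyingObject(PB);
}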
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition: LoopUtils.cpp:214
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1729
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:419
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, OnCuIndexOverflow OverflowOptValue)
Definition: DWP.cpp:625
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1647
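For illustration, the range helpers above replace explicit begin/end bookkeeping; the predicate and data here are arbitrary:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Illustrative: largest value, or 0 if the list is empty or has a negative entry.
static int largestNonNegative(llvm::SmallVectorImpl<int> &Values) {
  if (llvm::any_of(Values, [](int V) { return V < 0; }))
    return 0;
  llvm::sort(Values.begin(), Values.end());  // ascending
  for (int V : llvm::reverse(Values))        // visit largest first
    return V;
  return 0;                                  // empty input
}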
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:167
raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
@ FMulAdd
Sum of float products with llvm.fmuladd(a * b + sum).
@ Add
Sum of integers.
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition: Casting.h:565
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:212
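Illustrative: commonAlignment yields the alignment known to hold Offset bytes past a pointer of alignment BaseAlign, e.g. commonAlignment(Align(8), 4) is Align(4) while commonAlignment(Align(8), 16) stays Align(8):

#include "llvm/Support/Alignment.h"
#include <cstdint>

// Illustrative: alignment of an access Offset bytes beyond an aligned base.
static llvm::Align alignAtOffset(llvm::Align BaseAlign, uint64_t Offset) {
  return llvm::commonAlignment(BaseAlign, Offset);
}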
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
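An illustrative use of createSequentialMask together with IRBuilder: extract NumElts contiguous lanes starting at Lane, with no trailing undef lanes requested:

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"

// Illustrative: shuffle out lanes [Lane, Lane + NumElts) of Vec.
static llvm::Value *extractSlice(llvm::IRBuilder<> &B, llvm::Value *Vec,
                                 unsigned Lane, unsigned NumElts) {
  llvm::SmallVector<int, 16> Mask =
      llvm::createSequentialMask(Lane, NumElts, /*NumUndefs=*/0);
  return B.CreateShuffleVector(Vec, Mask);  // single-source shuffle
}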
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition: PassManager.h:69
A helper struct to create IR loop nests for tiling in IR of the following form: for ColumnLoop....
Definition: MatrixUtils.h:31