1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns
16// columns for tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
19
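// Illustrative example (not part of the upstream file): the kind of IR this
// pass consumes. A 2x2 * 2x2 double multiply expressed with the matrix
// intrinsics, using hypothetical value names %a, %b and %c:
//
//   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
//            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
//
// The pass splits the flat <4 x double> operands into column vectors, emits
// vector multiplies and adds (or fmuladd when contraction is allowed) per
// result column, and only re-embeds the columns into a flat vector for users
// without shape information.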
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/ScopeExit.h"
25#include "llvm/ADT/Statistic.h"
33#include "llvm/IR/CFG.h"
34#include "llvm/IR/DataLayout.h"
37#include "llvm/IR/Function.h"
38#include "llvm/IR/IRBuilder.h"
39#include "llvm/IR/InstrTypes.h"
47#include "llvm/Support/Debug.h"
51
52#include <cmath>
53
54using namespace llvm;
55using namespace PatternMatch;
56
57#define DEBUG_TYPE "lower-matrix-intrinsics"
58
59STATISTIC(FlattenedMatrices, "Number of matrix flattenings");
60STATISTIC(ReshapedMatrices, "Number of matrix reshapes");
61STATISTIC(SplitMatrices, "Number of matrix splits");
62
63static cl::opt<bool>
64 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
65 cl::desc("Enable/disable fusing matrix instructions."));
66// TODO: Allow and use non-square tiles.
67static cl::opt<unsigned> TileSize(
68 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
69 cl::desc(
70 "Tile size for matrix instruction fusion using square-shaped tiles."));
71static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
72 cl::Hidden,
73 cl::desc("Generate loop nest for tiling."));
74static cl::opt<bool> ForceFusion(
75 "force-fuse-matrix", cl::init(false), cl::Hidden,
76 cl::desc("Force matrix instruction fusion even if not profitable."));
77static cl::opt<bool> AllowContractEnabled(
78 "matrix-allow-contract", cl::init(false), cl::Hidden,
79 cl::desc("Allow the use of FMAs if available and profitable. This may "
80 "result in different results, due to less rounding error."));
81
82static cl::opt<bool>
83 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
84 cl::desc("Enable/disable matrix shape verification."),
85 cl::init(false));
86
87enum class MatrixLayoutTy { ColumnMajor, RowMajor };
88
89static cl::opt<MatrixLayoutTy> MatrixLayout(
90 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
91 cl::desc("Sets the default matrix layout"),
92 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
93 "Use column-major layout"),
94 clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
95 "Use row-major layout")));
96
97static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
98 cl::init(false));
99
100/// Helper function to either return \p Scope, if it is a subprogram, or the
101/// attached subprogram for a local scope.
102static DISubprogram *getSubprogram(DIScope *Scope) {
103 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
104 return Subprogram;
105 return cast<DILocalScope>(Scope)->getSubprogram();
106}
107
108/// Return true if V is a splat of a value (which is used when multiplying a
109/// matrix with a scalar).
110static bool isSplat(Value *V) {
111 if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
112 return SV->isZeroEltSplat();
113 return false;
114}
115
116/// Match any mul operation (fp or integer).
117template <typename LTy, typename RTy>
118auto m_AnyMul(const LTy &L, const RTy &R) {
119 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
120}
121
122/// Match any add operation (fp or integer).
123template <typename LTy, typename RTy>
124auto m_AnyAdd(const LTy &L, const RTy &R) {
125 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
126}
127
128namespace {
129
130// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
131// the start address of vector \p VecIdx with type (\p EltType x \p NumElements),
132// assuming \p Stride elements between the starts of two consecutive vectors.
133// \p Stride must be >= \p NumElements.
134// For column-major matrixes, the function computes the address of a column
135// vector and \p NumElements must be set to the number of elements in a column
136// (= number of rows of the matrix). For row-major matrixes, the function
137// computes the address of a row vector and \p NumElements must be set to the
138// number of elements in a row (= number of columns of the matrix).
139//
140// Consider a 4x4 matrix in column-major layout like below
141//
142// 0 1 2 3
143// 0 v_0_0 v_0_1 v_0_2 v_0_3
144// 1 v_1_0 v_1_1 v_1_2 v_1_3
145// 2 v_2_0 v_2_1 v_2_2 v_2_3
146// 3 v_3_0 v_3_1 v_3_2 v_3_3
147
148// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
149// we need a pointer to the first element of the submatrix as base pointer.
150// Then we can use computeVectorAddr to compute the addresses for the columns
151// of the sub-matrix.
152//
153// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
154// -> just returns Base
155// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
156// -> returns Base + (1 * 4)
157// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
158// -> returns Base + (2 * 4)
159//
160// The graphic below illustrates the number of elements in a column (marked
161// with |) and the number of skipped elements (marked with }).
162//
163// v_0_0 v_0_1 {v_0_2 {v_0_3
164// Base Col 1 Col 2
165// | | |
166// v_1_0 |v_1_1 |v_1_2 |v_1_3
167// v_2_0 |v_2_1 |v_2_2 |v_2_3
168// v_3_0 {v_3_1 {v_3_2 v_3_3
169//
170Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
171 unsigned NumElements, Type *EltType,
172 IRBuilder<> &Builder) {
173
174 assert((!isa<ConstantInt>(Stride) ||
175 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
176 "Stride must be >= the number of elements in the result vector.");
177
178 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
179 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
180
181 // Get pointer to the start of the selected vector. Skip GEP creation,
182 // if we select vector 0.
183 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
184 VecStart = BasePtr;
185 else
186 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
187
188 return VecStart;
189}
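// Illustrative sketch (not from the upstream file): for a column-major matrix
// of doubles with 4 rows and a dynamic stride, computeVectorAddr(%base, 2,
// %stride, 4, double, Builder) emits roughly
//
//   %vec.start = mul i64 2, %stride
//   %vec.gep   = getelementptr double, ptr %base, i64 %vec.start
//
// i.e. one multiply plus one GEP; for vector index 0 the GEP is skipped and
// %base is returned directly.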
190
191namespace {
192struct ShapeInfo {
193 unsigned NumRows;
194 unsigned NumColumns;
195
196 bool IsColumnMajor;
197
198 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
199 : NumRows(NumRows), NumColumns(NumColumns),
200 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
201
202 ShapeInfo(Value *NumRows, Value *NumColumns)
203 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
204 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
205
206 bool operator==(const ShapeInfo &other) {
207 return NumRows == other.NumRows && NumColumns == other.NumColumns;
208 }
209 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
210
211 /// Returns true if shape-information is defined, meaning both dimensions
212 /// are != 0.
213 operator bool() const {
214 assert(NumRows == 0 || NumColumns != 0);
215 return NumRows != 0;
216 }
217
218 unsigned getStride() const {
219 if (IsColumnMajor)
220 return NumRows;
221 return NumColumns;
222 }
223
224 unsigned getNumVectors() const {
225 if (IsColumnMajor)
226 return NumColumns;
227 return NumRows;
228 }
229
230 /// Returns the transposed shape.
231 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
232
233 friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI);
234
235 LLVM_DUMP_METHOD void dump() const { dbgs() << *this << '\n'; }
236};
237
238raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
239 return OS << SI.NumRows << 'x' << SI.NumColumns;
240}
241
242} // namespace
243
244static bool isUniformShape(Value *V) {
245 Instruction *I = dyn_cast<Instruction>(V);
246 if (!I)
247 return true;
248
249 if (I->isBinaryOp())
250 return true;
251
252 if (auto *Cast = dyn_cast<CastInst>(V)) {
253 switch (Cast->getOpcode()) {
254 case llvm::Instruction::Trunc:
255 case llvm::Instruction::ZExt:
256 case llvm::Instruction::SExt:
257 case llvm::Instruction::FPToUI:
258 case llvm::Instruction::FPToSI:
259 case llvm::Instruction::UIToFP:
260 case llvm::Instruction::SIToFP:
261 case llvm::Instruction::FPTrunc:
262 case llvm::Instruction::FPExt:
263 return true;
264 case llvm::Instruction::AddrSpaceCast:
265 case CastInst::PtrToAddr:
266 case CastInst::PtrToInt:
267 case CastInst::IntToPtr:
268 return false;
269 case CastInst::BitCast: {
270 if (auto *SrcVTy = dyn_cast<FixedVectorType>(Cast->getSrcTy()))
271 if (auto *DestVTy = dyn_cast<FixedVectorType>(Cast->getDestTy()))
272 return SrcVTy->getNumElements() == DestVTy->getNumElements();
273 return false;
274 }
275 case llvm::Instruction::CastOpsEnd:
276 llvm_unreachable("not an actual cast op");
277 }
278 llvm_unreachable("unhandled cast opcode");
279 }
280
281 if (auto *II = dyn_cast<IntrinsicInst>(V))
282 switch (II->getIntrinsicID()) {
283 case Intrinsic::abs:
284 case Intrinsic::fabs:
285 return true;
286 default:
287 return false;
288 }
289
290 switch (I->getOpcode()) {
291 case Instruction::PHI:
292 case Instruction::FNeg:
293 return true;
294 default:
295 return false;
296 }
297}
298
299/// Return the ShapeInfo for the result of \p I, if it can be determined.
300static std::optional<ShapeInfo>
301computeShapeInfoForInst(Instruction *I,
302 const DenseMap<Value *, ShapeInfo> &ShapeMap) {
303 Value *M;
304 Value *N;
305 Value *K;
306 if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
307 m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
308 return ShapeInfo(M, K);
309 if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
310 m_Value(N)))) {
311 // Flip dimensions.
312 return ShapeInfo(N, M);
313 }
314 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
315 m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
316 m_Value(N))))
317 return ShapeInfo(N, M);
318 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
319 m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
320 return ShapeInfo(M, N);
321 Value *MatrixA;
322 if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
323 auto OpShape = ShapeMap.find(MatrixA);
324 if (OpShape != ShapeMap.end())
325 return OpShape->second;
326 }
327
328 if (isUniformShape(I) || isa<SelectInst>(I)) {
329 auto Ops = I->operands();
330 auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
331 // Find the first operand that has a known shape and use that.
332 for (auto &Op : ShapedOps) {
333 auto OpShape = ShapeMap.find(Op.get());
334 if (OpShape != ShapeMap.end())
335 return OpShape->second;
336 }
337 }
338 return std::nullopt;
339}
340
341/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
342///
343/// Currently, the lowering for each matrix intrinsic is done as follows:
344/// 1. Propagate the shape information from intrinsics to connected
345///    instructions.
346/// 2. Lower instructions with shape information (assuming column-major layout).
347///    The lowering works similarly using row-major layout.
348/// 2.1. Get column vectors for each argument. If we already lowered the
349///      definition of an argument, use the produced column vectors directly.
350///      If not, split the operand vector containing an embedded matrix into
351///      a set of column vectors.
352/// 2.2. Lower the instruction in terms of column major operations, which
353///      yields a set of column vectors containing the result matrix. Note that
354///      we lower all instructions that have shape information. Besides the
355///      intrinsics, this includes stores for example.
356/// 2.3. Update uses of the lowered instruction. If we have shape information
357///      for a user, there is nothing to do, as we will look up the result
358///      column matrix when lowering the user. For other uses, we embed the
359///      result matrix in a flat vector and update the use.
360/// 2.4. Cache the result column matrix for the instruction we lowered.
361/// 3. After we lowered all instructions in a function, remove the now
362///    obsolete instructions.
363///
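// Illustrative sketch of steps 2.1-2.3 (not part of the upstream comment),
// using hypothetical values: a shape-annotated fadd of two 4x2 matrices has
// each <8 x double> operand split (or fetched from the cache) into two
// <4 x double> columns, and the fadd is emitted once per column, e.g.
//
//   %a.col0 = shufflevector <8 x double> %a, <8 x double> poison,
//                           <4 x i32> <i32 0, i32 1, i32 2, i32 3>
//   ...
//   %r.col0 = fadd <4 x double> %a.col0, %b.col0
//   %r.col1 = fadd <4 x double> %a.col1, %b.col1
//
// The two result columns are only concatenated back into an <8 x double> for
// users that have no entry in ShapeMap.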
364class LowerMatrixIntrinsics {
365 Function &Func;
366 const DataLayout &DL;
367 const TargetTransformInfo &TTI;
368 FunctionAnalysisManager *AM = nullptr;
369 AliasAnalysis *AA = nullptr;
370 DominatorTree *DT = nullptr;
371 LoopInfo *LI = nullptr;
372 OptimizationRemarkEmitter *ORE = nullptr;
373
374 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
375 struct OpInfoTy {
376 /// Number of stores emitted to generate this matrix.
377 unsigned NumStores = 0;
378 /// Number of loads emitted to generate this matrix.
379 unsigned NumLoads = 0;
380 /// Number of compute operations emitted to generate this matrix.
381 unsigned NumComputeOps = 0;
382 /// Most of the time transposes can be fused with matrix multiplies or can
383 /// be folded away via algebraic simplifications. This is the number of
384 /// transposes that we failed to make "free" via such optimizations.
385 unsigned NumExposedTransposes = 0;
386
387 OpInfoTy &operator+=(const OpInfoTy &RHS) {
388 NumStores += RHS.NumStores;
389 NumLoads += RHS.NumLoads;
390 NumComputeOps += RHS.NumComputeOps;
391 NumExposedTransposes += RHS.NumExposedTransposes;
392 return *this;
393 }
394 };
395
396 /// Wrapper class representing a matrix as a set of vectors, either in row or
397 /// column major layout. All vectors must have the same vector type.
398 class MatrixTy {
399 SmallVector<Value *, 16> Vectors;
400
401 OpInfoTy OpInfo;
402
403 bool IsColumnMajor = true;
404
405 public:
406 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
407 MatrixTy(ArrayRef<Value *> Vectors)
408 : Vectors(Vectors),
409 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
410 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
411 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
412
413 unsigned D = isColumnMajor() ? NumColumns : NumRows;
414 for (unsigned J = 0; J < D; ++J)
415 addVector(PoisonValue::get(FixedVectorType::get(
416 EltTy, isColumnMajor() ? NumRows : NumColumns)));
417 }
418
419 Value *getVector(unsigned i) const { return Vectors[i]; }
420 Value *getColumn(unsigned i) const {
421 assert(isColumnMajor() && "only supported for column-major matrixes");
422 return Vectors[i];
423 }
424 Value *getRow(unsigned i) const {
425 assert(!isColumnMajor() && "only supported for row-major matrixes");
426 return Vectors[i];
427 }
428
429 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
430
431 Type *getElementType() const { return getVectorTy()->getElementType(); }
432
433 unsigned getNumVectors() const {
434 if (isColumnMajor())
435 return getNumColumns();
436 return getNumRows();
437 }
438
439 unsigned getNumColumns() const {
440 if (isColumnMajor())
441 return Vectors.size();
442 else {
443 assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
444 return getVectorTy()->getNumElements();
445 }
446 }
447 unsigned getNumRows() const {
448 if (isColumnMajor()) {
449 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
450 return getVectorTy()->getNumElements();
451 } else
452 return Vectors.size();
453 }
454
455 void addVector(Value *V) { Vectors.push_back(V); }
456 FixedVectorType *getColumnTy() {
457 assert(isColumnMajor() && "only supported for column-major matrixes");
458 return getVectorTy();
459 }
460
461 FixedVectorType *getVectorTy() const {
462 return cast<FixedVectorType>(Vectors[0]->getType());
463 }
464
465 iterator_range<SmallVector<Value *, 8>::iterator> columns() {
466 assert(isColumnMajor() &&
467 "columns() only supported for column-major matrixes");
468 return make_range(Vectors.begin(), Vectors.end());
469 }
470
471 iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
472 return make_range(Vectors.begin(), Vectors.end());
473 }
474
475 /// Embed the vectors of the matrix into a flat vector by concatenating
476 /// them.
477 Value *embedInVector(IRBuilder<> &Builder) const {
478 return Vectors.size() == 1 ? Vectors[0]
479 : concatenateVectors(Builder, Vectors);
480 }
481
482 MatrixTy &addNumLoads(unsigned N) {
483 OpInfo.NumLoads += N;
484 return *this;
485 }
486
487 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
488
489 MatrixTy &addNumStores(unsigned N) {
490 OpInfo.NumStores += N;
491 return *this;
492 }
493
494 MatrixTy &addNumExposedTransposes(unsigned N) {
495 OpInfo.NumExposedTransposes += N;
496 return *this;
497 }
498
499 MatrixTy &addNumComputeOps(unsigned N) {
500 OpInfo.NumComputeOps += N;
501 return *this;
502 }
503
504 unsigned getNumStores() const { return OpInfo.NumStores; }
505 unsigned getNumLoads() const { return OpInfo.NumLoads; }
506 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
507
508 const OpInfoTy &getOpInfo() const { return OpInfo; }
509
510 bool isColumnMajor() const { return IsColumnMajor; }
511
512 unsigned getStride() const {
513 if (isColumnMajor())
514 return getNumRows();
515 return getNumColumns();
516 }
517
518 ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; }
519
520 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
521 /// matrix is column-major, the result vector is extracted from a column
522 /// vector, otherwise from a row vector.
523 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
524 IRBuilder<> &Builder) const {
525 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
526 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
527 NumElts &&
528 "Extracted vector will contain poison values");
529 return Builder.CreateShuffleVector(
530 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
531 "block");
532 }
533 };
534
535 /// Maps instructions to their shape information. The shape information
536 /// describes the shape to be used while lowering. This matches the shape of
537 /// the result value of the instruction, with the only exceptions being store
538 /// instructions and the matrix_column_major_store intrinsics. For those, the
539 /// shape information indicates that those instructions should be lowered
540 /// using shape information as well. Note that extra care is needed when
541 /// erasing or RAUW'ing a value that is present in ShapeMap. If the
542 /// replacement is also a matrix operation, use
543 /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
544 /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
545 /// want to add shape information for a replacement instruction. When directly
546 /// erasing a value with an entry in ShapeMap, use
547 /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
548 /// accordingly.
549 DenseMap<Value *, ShapeInfo> ShapeMap;
550
551 /// List of instructions to remove. While lowering, we do not replace all
552 /// users of a lowered instruction if shape information is available; those
553 /// users need to be removed after we have finished lowering.
554 SmallVector<Instruction *, 16> ToRemove;
555
556 /// Map from instructions to their produced column matrix.
557 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
558
559private:
560 static FastMathFlags getFastMathFlags(Instruction *Inst) {
561 FastMathFlags FMF;
562
563 if (isa<FPMathOperator>(*Inst))
564 FMF = Inst->getFastMathFlags();
565
566 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
567
568 return FMF;
569 }
570
571public:
572 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
573 FunctionAnalysisManager *AM)
574 : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
575
576 unsigned getNumOps(Type *VT) {
577 assert(isa<FixedVectorType>(VT) && "Expected vector type");
578 return getNumOps(VT->getScalarType(),
579 cast<FixedVectorType>(VT)->getNumElements());
580 }
581
582 /// Is this the minimal version executed in the backend pipelines.
583 bool isMinimal() const {
584 return !DT;
585 }
586
587 /// Return the estimated number of vector ops required for an operation on
588 /// \p VT * N.
589 unsigned getNumOps(Type *ST, unsigned N) {
590 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
591 double(TTI.getRegisterBitWidth(
592 TargetTransformInfo::RGK_FixedWidthVector)
593 .getFixedValue()));
594 }
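 // Worked example for the formula above (illustrative, assuming a target with
 // 128-bit vector registers): for ST = double (64 bits) and N = 4 elements,
 // getNumOps returns ceil((64 * 4) / 128.0) = 2 vector operations.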
595
596 /// Return the set of vectors that a matrix value is lowered to.
597 ///
598 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
599 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
600 /// into vectors.
601 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
602 IRBuilder<> &Builder) {
603 FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType());
604 assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
605 "The vector size must match the number of matrix elements");
606
607 // Check if we lowered MatrixVal using shape information. In that case,
608 // return the existing matrix, if it matches the requested shape
609 // information. If there is a mis-match, embed the result in a flat
610 // vector and split it later.
611 auto Found = Inst2ColumnMatrix.find(MatrixVal);
612 if (Found != Inst2ColumnMatrix.end()) {
613 MatrixTy &M = Found->second;
614 // Return the found matrix, if its shape matches the requested shape
615 // information
616 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
617 return M;
618
619 MatrixVal = M.embedInVector(Builder);
620 }
621
622 // Otherwise split MatrixVal.
623 SmallVector<Value *, 16> SplitVecs;
624 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
625 MaskStart += SI.getStride()) {
626 Value *V = Builder.CreateShuffleVector(
627 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
628 "split");
629 SplitVecs.push_back(V);
630 }
631
632 if (Instruction *Inst = dyn_cast<Instruction>(MatrixVal)) {
633 if (Found != Inst2ColumnMatrix.end()) {
634 // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles
635 // that embedInVector created.
636 LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape()
637 << " to " << SI << " using at least "
638 << SplitVecs.size() << " shuffles on behalf of:\n"
639 << *Inst << '\n');
640 ReshapedMatrices++;
641 } else if (!ShapeMap.contains(MatrixVal)) {
642 LLVM_DEBUG(
643 dbgs()
644 << "splitting a " << SI << " matrix with " << SplitVecs.size()
645 << " shuffles because we do not have a shape-aware lowering for "
646 "its def:\n"
647 << *Inst << '\n');
648 (void)Inst;
649 SplitMatrices++;
650 } else {
651 // The ShapeMap has it, so it's a case where we're being lowered
652 // before the def, and we expect that InstCombine will clean things up
653 // afterward.
654 }
655 }
656
657 return {SplitVecs};
658 }
659
660 /// If \p V already has a known shape return false. Otherwise set the shape
661 /// for instructions that support it.
662 bool setShapeInfo(Value *V, ShapeInfo Shape) {
663 assert(Shape && "Shape not set");
664 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
665 return false;
666
667 auto SIter = ShapeMap.find(V);
668 if (SIter != ShapeMap.end()) {
669 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
670 SIter->second.NumColumns != Shape.NumColumns)) {
671 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
672 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
673 << Shape.NumColumns << ") for " << *V << "\n";
674 report_fatal_error(
675 "Matrix shape verification failed, compilation aborted!");
676 }
677
678 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
679 << SIter->second.NumRows << " "
680 << SIter->second.NumColumns << " for " << *V << "\n");
681 return false;
682 }
683
684 ShapeMap.insert({V, Shape});
685 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
686 << " for " << *V << "\n");
687 return true;
688 }
689
690 /// Returns true if shape information can be used for \p V. The supported
691 /// instructions must match the instructions that can be lowered by this pass.
692 bool supportsShapeInfo(Value *V) {
693 Instruction *Inst = dyn_cast<Instruction>(V);
694 if (!Inst)
695 return false;
696
697 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
698 if (II)
699 switch (II->getIntrinsicID()) {
700 case Intrinsic::matrix_multiply:
701 case Intrinsic::matrix_transpose:
702 case Intrinsic::matrix_column_major_load:
703 case Intrinsic::matrix_column_major_store:
704 return true;
705 default:
706 return isUniformShape(II);
707 }
708 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
709 isa<SelectInst>(V);
710 }
711
712 /// Propagate the shape information of instructions to their users.
713 /// The work list contains instructions for which we can compute the shape,
714 /// either based on the information provided by matrix intrinsics or known
715 /// shapes of operands.
716 SmallVector<Instruction *, 32>
717 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
718 SmallVector<Instruction *, 32> NewWorkList;
719 // Pop an element for which we are guaranteed to have at least one of the
720 // operand shapes. Add the shape for this and then add users to the work
721 // list.
722 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
723 while (!WorkList.empty()) {
724 Instruction *Inst = WorkList.pop_back_val();
725
726 // New entry, set the value and insert operands
727 bool Propagate = false;
728 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
729 Propagate = setShapeInfo(Inst, *SI);
730
731 if (Propagate) {
732 NewWorkList.push_back(Inst);
733 for (auto *User : Inst->users())
734 if (ShapeMap.count(User) == 0)
735 WorkList.push_back(cast<Instruction>(User));
736 }
737 }
738
739 return NewWorkList;
740 }
741
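 // Example of one forward-propagation step (illustrative, hypothetical IR):
 //
 //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
 //            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
 //   %d = fadd <4 x double> %c, %e
 //
 // The multiply seeds the worklist with shape 2x2 for %c; popping %d then
 // assigns the same 2x2 shape to the fadd (a uniform-shape op) and queues its
 // users for further propagation.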
742 /// Propagate the shape to operands of instructions with shape information.
743 /// \p Worklist contains the instructions for which we already know the shape.
744 SmallVector<Instruction *, 32>
745 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
746 SmallVector<Instruction *, 32> NewWorkList;
747
748 auto pushInstruction = [](Value *V,
749 SmallVectorImpl<Instruction *> &WorkList) {
750 Instruction *I = dyn_cast<Instruction>(V);
751 if (I)
752 WorkList.push_back(I);
753 };
754 // Pop an element with known shape. Traverse the operands, if their shape
755 // derives from the result shape and is unknown, add it and add them to the
756 // worklist.
757 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
758 while (!WorkList.empty()) {
759 Value *V = WorkList.pop_back_val();
760
761 size_t BeforeProcessingV = WorkList.size();
762 if (!isa<Instruction>(V))
763 continue;
764
765 Value *MatrixA;
766 Value *MatrixB;
767 Value *M;
768 Value *N;
769 Value *K;
770 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
771 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
772 m_Value(N), m_Value(K)))) {
773 if (setShapeInfo(MatrixA, {M, N}))
774 pushInstruction(MatrixA, WorkList);
775
776 if (setShapeInfo(MatrixB, {N, K}))
777 pushInstruction(MatrixB, WorkList);
778
780 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
781 // Flip dimensions.
782 if (setShapeInfo(MatrixA, {M, N}))
783 pushInstruction(MatrixA, WorkList);
784 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
785 m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
786 m_Value(M), m_Value(N)))) {
787 if (setShapeInfo(MatrixA, {M, N})) {
788 pushInstruction(MatrixA, WorkList);
789 }
790 } else if (isa<LoadInst>(V) ||
791 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
792 // Nothing to do, no matrix input.
793 } else if (isa<StoreInst>(V)) {
794 // Nothing to do. We forward-propagated to this so we would just
795 // backward propagate to an instruction with an already known shape.
796 } else if (isUniformShape(V) || isa<SelectInst>(V)) {
797 auto Ops = cast<Instruction>(V)->operands();
798 auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
799 // Propagate to all operands.
800 ShapeInfo Shape = ShapeMap[V];
801 for (Use &U : ShapedOps) {
802 if (setShapeInfo(U.get(), Shape))
803 pushInstruction(U.get(), WorkList);
804 }
805 }
806 // After we discovered new shape info for new instructions in the
807 // worklist, we use their users as seeds for the next round of forward
808 // propagation.
809 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
810 for (User *U : WorkList[I]->users())
811 if (isa<Instruction>(U) && V != U)
812 NewWorkList.push_back(cast<Instruction>(U));
813 }
814 return NewWorkList;
815 }
816
817 /// (Op0 op Op1)^T -> Op0^T op Op1^T
818 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
819 /// them on both sides of \p Operation.
820 Instruction *distributeTransposes(
821 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
822 MatrixBuilder &Builder,
823 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
824 Operation) {
825 Value *T0 = Builder.CreateMatrixTranspose(
826 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
827 // We are being run after shape prop, add shape for newly created
828 // instructions so that we lower them later.
829 setShapeInfo(T0, Shape0.t());
830 Value *T1 = Builder.CreateMatrixTranspose(
831 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
832 setShapeInfo(T1, Shape1.t());
833 return Operation(T0, Shape0.t(), T1, Shape1.t());
834 }
835
836 /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
837 /// itself.
838 void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
839 ShapeMap.erase(Inst);
840 Inst->eraseFromParent();
841 }
842
843 /// Erase \p V from \p BB and move \p II forward to avoid invalidating
844 /// iterators.
845 void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
846 BasicBlock &BB) {
847 auto *Inst = cast<Instruction>(V);
848 // Still used, don't erase.
849 if (!Inst->use_empty())
850 return;
851 if (II != BB.rend() && Inst == &*II)
852 ++II;
853 eraseFromParentAndRemoveFromShapeMap(Inst);
854 }
855
856 /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
857 /// entry for \p Old and replace all uses of \p Old with \p New.
858 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
859 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
860 // with New. We should only add New if it supportsShapeInfo, so we insert
861 // it conditionally instead.
862 auto S = ShapeMap.find(&Old);
863 if (S != ShapeMap.end()) {
864 ShapeMap.erase(S);
865 if (supportsShapeInfo(New))
866 ShapeMap.insert({New, S->second});
867 }
868 Old.replaceAllUsesWith(New);
869 }
870
871 /// Sink a top-level transpose inside matmuls and adds.
872 /// This creates and erases instructions as needed, and returns the newly
873 /// created instruction while updating the iterator to avoid invalidation. If
874 /// this returns nullptr, no new instruction was created.
875 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II,
876 bool &Changed) {
877 BasicBlock &BB = *I.getParent();
878 IRBuilder<> IB(&I);
879 MatrixBuilder Builder(IB);
880
881 Value *TA, *TAMA, *TAMB;
882 ConstantInt *R, *K, *C;
883 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
884 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
885 return nullptr;
886
887 // Transpose of a transpose is a nop when the shapes match.
888 Value *TATA;
889 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(
890 m_Value(TATA), m_Specific(C), m_Specific(R)))) {
891 updateShapeAndReplaceAllUsesWith(I, TATA);
892 eraseFromParentAndMove(&I, II, BB);
893 eraseFromParentAndMove(TA, II, BB);
894 Changed = true;
895 return nullptr;
896 }
897
898 // k^T -> k
899 if (isSplat(TA)) {
900 updateShapeAndReplaceAllUsesWith(I, TA);
901 eraseFromParentAndMove(&I, II, BB);
902 Changed = true;
903 return nullptr;
904 }
905
906 // (A * B)^t -> B^t * A^t
907 // RxK KxC CxK KxR
908 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
909 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
910 m_ConstantInt(K), m_ConstantInt(C)))) {
911 auto NewInst = distributeTransposes(
912 TAMB, {K, C}, TAMA, {R, K}, Builder,
913 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
914 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
915 Shape0.NumColumns,
916 Shape1.NumColumns, "mmul");
917 });
918 updateShapeAndReplaceAllUsesWith(I, NewInst);
919 eraseFromParentAndMove(&I, II, BB);
920 eraseFromParentAndMove(TA, II, BB);
921 Changed = true;
922 return NewInst;
923 }
924
925 // Same as above, but with a mul, which occurs when multiplied
926 // with a scalar.
927 // (A * k)^t -> A^t * k
928 // R x C RxC
929 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
930 (isSplat(TAMA) || isSplat(TAMB))) {
931 IRBuilder<> LocalBuilder(&I);
932 // We know that the transposed operand is of shape RxC.
933 // And when multiplied with a scalar, the shape is preserved.
934 auto NewInst = distributeTransposes(
935 TAMA, {R, C}, TAMB, {R, C}, Builder,
936 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
937 bool IsFP = I.getType()->isFPOrFPVectorTy();
938 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
939 : LocalBuilder.CreateMul(T0, T1, "mmul");
940 auto *Result = cast<Instruction>(Mul);
941 setShapeInfo(Result, Shape0);
942 return Result;
943 });
944 updateShapeAndReplaceAllUsesWith(I, NewInst);
945 eraseFromParentAndMove(&I, II, BB);
946 eraseFromParentAndMove(TA, II, BB);
947 Changed = true;
948 return NewInst;
949 }
950
951 // (A + B)^t -> A^t + B^t
952 // RxC RxC CxR CxR
953 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
954 IRBuilder<> LocalBuilder(&I);
955 auto NewInst = distributeTransposes(
956 TAMA, {R, C}, TAMB, {R, C}, Builder,
957 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
958 bool IsFP = I.getType()->isFPOrFPVectorTy();
959 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
960 : LocalBuilder.CreateAdd(T0, T1, "madd");
961
962 auto *Result = cast<Instruction>(Add);
963 setShapeInfo(Result, Shape0);
964 return Result;
965 });
966 updateShapeAndReplaceAllUsesWith(I, NewInst);
967 eraseFromParentAndMove(&I, II, BB);
968 eraseFromParentAndMove(TA, II, BB);
969 Changed = true;
970 return NewInst;
971 }
972
973 return nullptr;
974 }
975
976 bool liftTranspose(Instruction &I) {
977 // Erase dead Instructions after lifting transposes from binops.
978 auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
979 if (T.use_empty())
980 eraseFromParentAndRemoveFromShapeMap(&T);
981 if (A->use_empty())
982 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
983 if (A != B && B->use_empty())
984 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
985 };
986
987 Value *A, *B, *AT, *BT;
988 ConstantInt *R, *K, *C;
989 // A^t * B^t -> (B * A)^t
990 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
991 m_Value(A), m_Value(B), m_ConstantInt(R),
992 m_ConstantInt(K), m_ConstantInt(C))) &&
993 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
994 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
995 IRBuilder<> IB(&I);
996 MatrixBuilder Builder(IB);
997 Value *M = Builder.CreateMatrixMultiply(
998 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
999 setShapeInfo(M, {C, R});
1000 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
1001 R->getZExtValue());
1002 updateShapeAndReplaceAllUsesWith(I, NewInst);
1003 CleanupBinOp(I, A, B);
1004 return true;
1005 }
1006 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
1007 // the shape of the second transpose is different, there's a shape conflict
1008 // which gets resolved by picking the shape of the first operand.
1009 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
1010 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
1011 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
1012 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
1013 m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
1014 IRBuilder<> Builder(&I);
1015 auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");
1016 MatrixBuilder MBuilder(Builder);
1017 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
1018 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
1019 updateShapeAndReplaceAllUsesWith(I, NewInst);
1020 assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
1021 computeShapeInfoForInst(&I, ShapeMap) &&
1022 "Shape of new instruction doesn't match original shape.");
1023 CleanupBinOp(I, A, B);
1024 if (auto *AddI = dyn_cast<Instruction>(Add)) {
1025 setShapeInfo(AddI, {R, C});
1026 assert(
1027 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
1028 ShapeMap[AddI] &&
1029 "Shape of updated addition doesn't match cached shape.");
1030 }
1031 return true;
1032 }
1033 return false;
1034 }
1035
1036 /// Try moving transposes in order to fold them away or into multiplies.
1037 bool optimizeTransposes() {
1038 bool Changed = false;
1039 // First sink all transposes inside matmuls and adds, hoping that we end up
1040 // with NN, NT or TN variants.
1041 for (BasicBlock &BB : reverse(Func)) {
1042 for (auto II = BB.rbegin(); II != BB.rend();) {
1043 Instruction &I = *II;
1044 // We may remove II. By default continue on the next/prev instruction.
1045 ++II;
1046 if (Instruction *NewInst = sinkTranspose(I, II, Changed))
1047 II = std::next(BasicBlock::reverse_iterator(NewInst));
1048 }
1049 }
1050
1051 // If we have a TT matmul or a TT add, lift the transpose. We may be able
1052 // to fold into consuming multiply or add.
1053 for (BasicBlock &BB : Func) {
1054 for (Instruction &I : llvm::make_early_inc_range(BB)) {
1055 Changed |= liftTranspose(I);
1056 }
1057 }
1058 return Changed;
1059 }
1060
1061 bool Visit() {
1062 SmallVector<Instruction *, 32> WorkList;
1063
1064 // Initially only the shape of matrix intrinsics is known.
1065 // Initialize the work list with ops carrying shape information.
1066 for (BasicBlock &BB : Func)
1067 for (Instruction &Inst : BB) {
1068 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
1069 if (!II)
1070 continue;
1071
1072 switch (II->getIntrinsicID()) {
1073 case Intrinsic::matrix_multiply:
1074 case Intrinsic::matrix_transpose:
1075 case Intrinsic::matrix_column_major_load:
1076 case Intrinsic::matrix_column_major_store:
1077 WorkList.push_back(&Inst);
1078 break;
1079 default:
1080 break;
1081 }
1082 }
1083
1084 // Avoid unnecessary work if there are no matrix intrinsics in the function.
1085 if (WorkList.empty())
1086 return false;
1087
1088 if (AM) {
1089 ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func);
1090 AA = &AM->getResult<AAManager>(Func);
1091 DT = &AM->getResult<DominatorTreeAnalysis>(Func);
1092 LI = &AM->getResult<LoopAnalysis>(Func);
1093 }
1094
1095 // Propagate shapes until nothing changes any longer.
1096 while (!WorkList.empty()) {
1097 WorkList = propagateShapeForward(WorkList);
1098 WorkList = propagateShapeBackward(WorkList);
1099 }
1100
1101 bool Changed = false;
1102 if (!isMinimal()) {
1103 Changed |= optimizeTransposes();
1104 if (PrintAfterTransposeOpt) {
1105 dbgs() << "Dump after matrix transpose optimization:\n";
1106 Func.print(dbgs());
1107 }
1108 }
1109
1110 SmallVector<CallInst *, 16> MaybeFusableInsts;
1111 SmallVector<Instruction *, 16> MatrixInsts;
1112 SmallVector<IntrinsicInst *, 16> LifetimeEnds;
1113
1114 // First, collect all instructions with shape information and candidates for
1115 // fusion (currently only matrix multiplies).
1116 ReversePostOrderTraversal<Function *> RPOT(&Func);
1117 for (auto *BB : RPOT)
1118 for (Instruction &I : *BB) {
1119 if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
1120 LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
1121 if (!ShapeMap.contains(&I))
1122 continue;
1123 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
1124 MaybeFusableInsts.push_back(cast<CallInst>(&I));
1125 MatrixInsts.push_back(&I);
1126 }
1127
1128 // Second, try to lower any dot products
1129 SmallPtrSet<Instruction *, 16> FusedInsts;
1130 for (CallInst *CI : MaybeFusableInsts)
1131 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1132
1133 // Third, try to fuse candidates.
1134 for (CallInst *CI : MaybeFusableInsts)
1135 if (!FusedInsts.contains(CI))
1136 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
1137
1138 Changed |= !FusedInsts.empty();
1139
1140 // Fourth, pre-process all the PHINode's. The incoming values will be
1141 // assigned later in VisitPHI.
1142 for (Instruction *Inst : MatrixInsts) {
1143 if (FusedInsts.count(Inst))
1144 continue;
1145
1146 auto *PHI = dyn_cast<PHINode>(Inst);
1147 if (!PHI)
1148 continue;
1149
1150 const ShapeInfo &SI = ShapeMap.at(Inst);
1151 auto *EltTy = cast<FixedVectorType>(PHI->getType())->getElementType();
1152 MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy);
1153
1154 IRBuilder<> Builder(Inst);
1155 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
1156 PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
1157 PHI->getNumIncomingValues(),
1158 PHI->getName()));
1159 assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?");
1160 Inst2ColumnMatrix[PHI] = PhiM;
1161 }
1162
1163 // Fifth, lower remaining instructions with shape information.
1164 for (Instruction *Inst : MatrixInsts) {
1165 if (FusedInsts.count(Inst))
1166 continue;
1167
1168 const ShapeInfo &SI = ShapeMap.at(Inst);
1169
1170 Value *Op1;
1171 Value *Op2;
1172 MatrixTy Result;
1173 IRBuilder<> Builder(Inst);
1174 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1175 Result = VisitBinaryOperator(BinOp, SI, Builder);
1176 else if (auto *Cast = dyn_cast<CastInst>(Inst))
1177 Result = VisitCastInstruction(Cast, SI, Builder);
1178 else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1179 Result = VisitUnaryOperator(UnOp, SI, Builder);
1180 else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
1181 Result = VisitIntrinsicInst(Intr, SI, Builder);
1182 else if (auto *Select = dyn_cast<SelectInst>(Inst))
1183 Result = VisitSelectInst(Select, SI, Builder);
1184 else if (match(Inst, m_Load(m_Value(Op1))))
1185 Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
1186 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1187 Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1188 else if (auto *PHI = dyn_cast<PHINode>(Inst))
1189 Result = VisitPHI(PHI, SI, Builder);
1190 else
1191 continue;
1192
1193 finalizeLowering(Inst, Result, Builder);
1194 Changed = true;
1195 }
1196
1197 if (ORE) {
1198 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1199 RemarkGen.emitRemarks();
1200 }
1201
1202 // Delete the instructions backwards, as it has a reduced likelihood of
1203 // having to update as many def-use and use-def chains.
1204 //
1205 // Because we add to ToRemove during fusion we can't guarantee that defs
1206 // are before uses. Change uses to poison temporarily as these should get
1207 // removed as well.
1208 //
1209 // For verification, we keep track of where we changed uses to poison in
1210 // PoisonedInsts and then check that we in fact remove them.
1211 SmallPtrSet<Instruction *, 16> PoisonedInsts;
1212 for (auto *Inst : reverse(ToRemove)) {
1213 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1214 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1215 PoisonedInsts.insert(Poisoned);
1216 U.set(PoisonValue::get(Inst->getType()));
1217 }
1218 Inst->eraseFromParent();
1219 PoisonedInsts.erase(Inst);
1220 }
1221 if (!PoisonedInsts.empty()) {
1222 // If we didn't remove all poisoned instructions, it's a hard error.
1223 dbgs() << "Poisoned but present instructions:\n";
1224 for (auto *I : PoisonedInsts)
1225 dbgs() << *I << "\n";
1226 llvm_unreachable("Poisoned but instruction not removed");
1227 }
1228
1229 return Changed;
1230 }
1231
1232 /// Replace intrinsic calls.
1233 MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI,
1234 IRBuilder<> &Builder) {
1235 assert(Inst->getCalledFunction() &&
1236 Inst->getCalledFunction()->isIntrinsic());
1237
1238 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1239 case Intrinsic::matrix_multiply:
1240 return LowerMultiply(Inst, Builder);
1241 case Intrinsic::matrix_transpose:
1242 return LowerTranspose(Inst, Builder);
1243 case Intrinsic::matrix_column_major_load:
1244 return LowerColumnMajorLoad(Inst, Builder);
1245 case Intrinsic::matrix_column_major_store:
1246 return LowerColumnMajorStore(Inst, Builder);
1247 case Intrinsic::abs:
1248 case Intrinsic::fabs: {
1249 MatrixTy Result;
1250 MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
1251 Builder.setFastMathFlags(getFastMathFlags(Inst));
1252
1253 for (auto *Vector : M.vectors()) {
1254 switch (Inst->getIntrinsicID()) {
1255 case Intrinsic::abs:
1256 Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
1257 Inst->getOperand(1)));
1258 continue;
1259 case Intrinsic::fabs:
1260 Result.addVector(
1261 Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
1262 continue;
1263 default:
1264 llvm_unreachable("unexpected intrinsic");
1265 }
1266 }
1267
1268 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1269 Result.getNumVectors());
1270 }
1271 default:
1272 break;
1273 }
1274 llvm_unreachable(
1275 "only intrinsics supporting shape info should be seen here");
1276 }
1277
1278 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1279 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1280 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1281 /// non-ConstantInt strides, return the common alignment of the initial
1282 /// alignment and the element size in bytes.
1283 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1284 MaybeAlign A) const {
1285 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
1286 if (Idx == 0)
1287 return InitialAlign;
1288
1289 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
1290 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1291 uint64_t StrideInBytes =
1292 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1293 return commonAlignment(InitialAlign, Idx * StrideInBytes);
1294 }
1295 return commonAlignment(InitialAlign, ElementSizeInBits / 8);
1296 }
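 // Worked example (illustrative): with an initial alignment of 16, double
 // elements (8 bytes) and a constant stride of 3 elements, column Idx = 1
 // starts at byte offset 24, so the result is commonAlignment(16, 24) = 8.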
1297
1298 IntegerType *getIndexType(Value *Ptr) const {
1299 return cast<IntegerType>(DL.getIndexType(Ptr->getType()));
1300 }
1301
1302 Value *getIndex(Value *Ptr, uint64_t V) const {
1303 return ConstantInt::get(getIndexType(Ptr), V);
1304 }
1305
1306 Value *castToIndexType(Value *Ptr, Value *V, IRBuilder<> &Builder) const {
1307 assert(isa<IntegerType>(V->getType()) &&
1308 "Attempted to cast non-integral type to integer index");
1309 // In case the data layout's index type differs in width from the type of
1310 // the value we're given, truncate or zero extend to the appropriate width.
1311 // We zero extend here as indices are unsigned.
1312 return Builder.CreateZExtOrTrunc(V, getIndexType(Ptr),
1313 V->getName() + ".cast");
1314 }
1315
1316 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1317 /// vectors.
1318 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1319 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1320 auto *VType = cast<FixedVectorType>(Ty);
1321 Type *EltTy = VType->getElementType();
1322 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1323 Value *EltPtr = Ptr;
1324 MatrixTy Result;
1325 Stride = castToIndexType(Ptr, Stride, Builder);
1326 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1327 Value *GEP = computeVectorAddr(
1328 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1329 Stride, Shape.getStride(), EltTy, Builder);
1330 Value *Vector = Builder.CreateAlignedLoad(
1331 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1332 IsVolatile, "col.load");
1333
1334 Result.addVector(Vector);
1335 }
1336 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
1337 Result.getNumVectors());
1338 }
1339
1340 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1341 /// starting at \p MatrixPtr[I][J].
1342 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1343 ShapeInfo MatrixShape, Value *I, Value *J,
1344 ShapeInfo ResultShape, Type *EltTy,
1345 IRBuilder<> &Builder) {
1346 Value *Offset = Builder.CreateAdd(
1347 Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
1348
1349 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1350 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
1351 ResultShape.NumColumns);
1352
1353 return loadMatrix(TileTy, TileStart, Align,
1354 getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
1355 ResultShape, Builder);
1356 }
1357
1358 /// Lower a load instruction with shape information.
1359 MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
1360 Value *Stride, bool IsVolatile, ShapeInfo Shape,
1361 IRBuilder<> &Builder) {
1362 return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
1363 Builder);
1364 }
1365
1366 /// Lowers llvm.matrix.column.major.load.
1367 ///
1368 /// The intrinsic loads a matrix from memory using a stride between columns.
1369 MatrixTy LowerColumnMajorLoad(CallInst *Inst, IRBuilder<> &Builder) {
1370 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1371 "Intrinsic only supports column-major layout!");
1372 Value *Ptr = Inst->getArgOperand(0);
1373 Value *Stride = Inst->getArgOperand(1);
1374 return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1375 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1376 {Inst->getArgOperand(3), Inst->getArgOperand(4)}, Builder);
1377 }
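 // Illustrative example (hypothetical IR, not from the upstream file): a
 // column-major load of a 4x2 double matrix with a stride of 4 elements,
 //
 //   %m = call <8 x double> @llvm.matrix.column.major.load.v8f64.i64(
 //            ptr align 8 %p, i64 4, i1 false, i32 4, i32 2)
 //
 // is lowered to two <4 x double> "col.load" loads, one from %p and one from
 // %p advanced by 4 elements, yielding one vector per column.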
1378
1379 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1380 /// MatrixPtr[I][J].
1381 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1382 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1383 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1384 Value *Offset = Builder.CreateAdd(
1385 Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
1386
1387 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1388 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1389 StoreVal.getNumColumns());
1390
1391 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1392 getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
1393 Builder);
1394 }
1395
1396 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1397 /// vectors.
1398 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1399 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1400 IRBuilder<> &Builder) {
1401 auto *VType = cast<FixedVectorType>(Ty);
1402 Value *EltPtr = Ptr;
1403 Stride = castToIndexType(Ptr, Stride, Builder);
1404 for (auto Vec : enumerate(StoreVal.vectors())) {
1405 Value *GEP = computeVectorAddr(
1406 EltPtr,
1407 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1408 Vec.index()),
1409 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1410 Builder.CreateAlignedStore(Vec.value(), GEP,
1411 getAlignForIndex(Vec.index(), Stride,
1412 VType->getElementType(),
1413 MAlign),
1414 IsVolatile);
1415 }
1416 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1417 StoreVal.getNumVectors());
1418 }
1419
1420 /// Lower a store instruction with shape information.
1421 MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
1422 MaybeAlign A, Value *Stride, bool IsVolatile,
1423 ShapeInfo Shape, IRBuilder<> &Builder) {
1424 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1425 return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
1426 Builder);
1427 }
1428
1429 /// Lowers llvm.matrix.column.major.store.
1430 ///
1431 /// The intrinsic stores a matrix back to memory using a stride between columns.
1432 MatrixTy LowerColumnMajorStore(CallInst *Inst, IRBuilder<> &Builder) {
1433 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1434 "Intrinsic only supports column-major layout!");
1435 Value *Matrix = Inst->getArgOperand(0);
1436 Value *Ptr = Inst->getArgOperand(1);
1437 Value *Stride = Inst->getArgOperand(2);
1438 return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1439 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1440 {Inst->getArgOperand(4), Inst->getArgOperand(5)},
1441 Builder);
1442 }
1443
1444 // Set elements I..I+NumElts-1 to Block
1445 Value *insertVector(Value *Col, unsigned I, Value *Block,
1446 IRBuilder<> &Builder) {
1447
1448 // First, bring Block to the same size as Col
1449 unsigned BlockNumElts =
1450 cast<FixedVectorType>(Block->getType())->getNumElements();
1451 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1452 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1453
1454 Block = Builder.CreateShuffleVector(
1455 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1456
1457 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1458 // 8, 4, 5, 6
1459 SmallVector<int, 16> Mask;
1460 unsigned i;
1461 for (i = 0; i < I; i++)
1462 Mask.push_back(i);
1463
1464 unsigned VecNumElts =
1465 cast<FixedVectorType>(Col->getType())->getNumElements();
1466 for (; i < I + BlockNumElts; i++)
1467 Mask.push_back(i - I + VecNumElts);
1468
1469 for (; i < VecNumElts; i++)
1470 Mask.push_back(i);
1471
1472 return Builder.CreateShuffleVector(Col, Block, Mask);
1473 }
1474
1475 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1476 IRBuilder<> &Builder, bool AllowContraction,
1477 unsigned &NumComputeOps) {
1478 NumComputeOps += getNumOps(A->getType());
1479 if (!Sum)
1480 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1481
1482 if (UseFPOp) {
1483 if (AllowContraction) {
1484 // Use fmuladd for floating point operations and let the backend decide
1485 // if that's profitable.
1486 return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),
1487 {A, B, Sum});
1488 }
1489 NumComputeOps += getNumOps(A->getType());
1490 Value *Mul = Builder.CreateFMul(A, B);
1491 return Builder.CreateFAdd(Sum, Mul);
1492 }
1493
1494 NumComputeOps += getNumOps(A->getType());
1495 Value *Mul = Builder.CreateMul(A, B);
1496 return Builder.CreateAdd(Sum, Mul);
1497 }
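 // Illustrative result (hypothetical values): with FP operands and contraction
 // allowed, one accumulation step becomes
 //
 //   %sum.next = call <4 x double> @llvm.fmuladd.v4f64(<4 x double> %a,
 //                                                     <4 x double> %b,
 //                                                     <4 x double> %sum)
 //
 // otherwise a separate fmul/fadd (or mul/add for integer elements) pair is
 // emitted.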
1498
1499 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1500 /// users with shape information, there's nothing to do: they will use the
1501 /// cached value when they are lowered. For other users, \p Matrix is
1502 /// flattened and the uses are updated to use it. Also marks \p Inst for
1503 /// deletion.
1504 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1505 IRBuilder<> &Builder) {
1506 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1507 (void)inserted;
1508 assert((inserted.second || isa<PHINode>(Inst)) &&
1509 "multiple matrix lowering mapping");
1510
1511 ToRemove.push_back(Inst);
1512 Value *Flattened = nullptr;
1513 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1514 if (ShapeMap.contains(U.getUser()))
1515 continue;
1516
1517 if (!Flattened) {
1518 Flattened = Matrix.embedInVector(Builder);
1519 LLVM_DEBUG(
1520 if (Instruction *User = dyn_cast<Instruction>(U.getUser())) dbgs()
1521 << "flattening a " << Matrix.shape() << " matrix:\n"
1522 << *Inst
1523 << "\nbecause we do not have a shape-aware lowering for its "
1524 "user:\n"
1525 << *User << '\n';);
1526 FlattenedMatrices++;
1527 }
1528 U.set(Flattened);
1529 }
1530 }
1531
1532 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1533 /// vectors. Lowers to a vector reduction add instead of sequential adds if
1534 /// reassociation is enabled.
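 // Illustrative shape of the result (hypothetical IR): a 1x4 * 4x1 multiply
 //
 //   %d = call <1 x double> @llvm.matrix.multiply.v1f64.v4f64.v4f64(
 //            <4 x double> %row, <4 x double> %col, i32 1, i32 4, i32 1)
 //
 // becomes, when deemed profitable, an elementwise fmul of %row and %col
 // followed by llvm.vector.reduce.fadd, with the scalar result re-inserted
 // into a <1 x double>.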
1535 void lowerDotProduct(CallInst *MatMul,
1536 SmallPtrSet<Instruction *, 16> &FusedInsts,
1537 FastMathFlags FMF) {
1538 if (FusedInsts.contains(MatMul) ||
1539 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1540 return;
1541 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1542 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1543
1544 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1545 return;
1546
1547 Value *LHS = MatMul->getArgOperand(0);
1548 Value *RHS = MatMul->getArgOperand(1);
1549
1550 Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
1551 bool IsIntVec = ElementType->isIntegerTy();
1552
1553 // Floating point reductions require reassociation.
1554 if (!IsIntVec && !FMF.allowReassoc())
1555 return;
1556
1557 auto CanBeFlattened = [](Value *Op) {
1558 if (match(Op, m_BinOp()))
1559 return true;
1560 return match(
1561 Op, m_OneUse(m_CombineOr(
1562 m_Load(m_Value()),
1563 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1564 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1565 m_Value(), m_SpecificInt(1))))));
1566 };
1567 // Returns the cost benefit of using \p Op with the dot product lowering. If
1568 // the returned cost is < 0, the argument is cheaper to use in the
1569 // dot-product lowering.
1570 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1571 if (!ShapeMap.contains(Op))
1572 return InstructionCost::getInvalid();
1573
1574 if (!isa<Instruction>(Op))
1575 return InstructionCost(0);
1576
1577 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1578 Type *EltTy = VecTy->getElementType();
1579
1580 if (!CanBeFlattened(Op)) {
1581 InstructionCost EmbedCost(0);
1582 // Roughly estimate the cost for embedding the columns into a vector.
1583 for (unsigned I = 1; I < N; ++I)
1584 EmbedCost += TTI.getShuffleCost(
1587 return EmbedCost;
1588 }
1589
1590 if (match(Op, m_BinOp()) && ShapeMap.contains(Op)) {
1591 InstructionCost OriginalCost =
1592 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1593 EltTy) *
1594 N;
1595 InstructionCost NewCost = TTI.getArithmeticInstrCost(
1596 cast<Instruction>(Op)->getOpcode(), VecTy);
1597 return NewCost - OriginalCost;
1598 }
1599
1600 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1601 // The transpose can be skipped for the dot product lowering, roughly
1602 // estimate the savings as the cost of embedding the columns in a
1603 // vector.
1604 InstructionCost EmbedCost(0);
1605 for (unsigned I = 1; I < N; ++I)
1606 EmbedCost -= TTI.getShuffleCost(
1609 return EmbedCost;
1610 }
1611
1612 // Costs for loads.
1613 if (N == 1)
1614 return InstructionCost(0);
1615
1616 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1617 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1618 };
1619
1620 // Iterate over LHS and operations feeding LHS and check if it is profitable
1621 // to flatten the visited ops. For each op, we compute the difference
1622 // between the flattened and matrix versions.
1623 SmallPtrSet<Value *, 4> Seen;
1624 SmallVector<Value *> WorkList;
1625 SmallVector<Value *> ToFlatten;
1626 WorkList.push_back(LHS);
1627 InstructionCost LHSCost(0);
1628 while (!WorkList.empty()) {
1629 Value *Op = WorkList.pop_back_val();
1630 if (!Seen.insert(Op).second)
1631 continue;
1632
1633 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1634 if (OpCost + LHSCost >= LHSCost)
1635 continue;
1636
1637 LHSCost += OpCost;
1638 ToFlatten.push_back(Op);
1639 if (auto *I = dyn_cast<Instruction>(Op))
1640 WorkList.append(I->op_begin(), I->op_end());
1641 }
1642
1643 // We compare the costs of a vector.reduce.add to sequential add.
1644 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1645 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1646 InstructionCost ReductionCost =
1647 TTI.getArithmeticReductionCost(
1648 AddOpCode, cast<FixedVectorType>(LHS->getType()),
1649 IsIntVec ? std::nullopt : std::optional(FMF)) +
1650 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1651 InstructionCost SequentialAddCost =
1652 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1653 (LShape.NumColumns - 1) +
1654 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1655 (LShape.NumColumns);
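// Illustrative comparison: for a 1xN * Nx1 dot product, the fused form costs
// one vector multiply plus one vector reduce-add (plus LHSCost for flattening
// the operands), while the unfused form costs N scalar multiplies and N-1
// scalar adds. Lowering only proceeds when the fused form is no more expensive.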
1656 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1657 return;
1658
1659 FusedInsts.insert(MatMul);
1660 IRBuilder<> Builder(MatMul);
1661 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1662 this](Value *Op) {
1663 // Matmul must be the only user of loads because we don't use LowerLoad
1664 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1665 // instead of single vector load).
1666 if (!CanBeFlattened(Op))
1667 return;
1668
1669 if (match(Op, m_BinOp())) {
1670 auto It = ShapeMap.find(Op);
1671 if (It != ShapeMap.end()) {
1672 It->second = It->second.t();
1673 return;
1674 }
1675 }
1676
1677 FusedInsts.insert(cast<Instruction>(Op));
1678 // If vector uses the builtin load, lower to a LoadInst
1679 Value *Arg;
1680 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1681 m_Value(Arg)))) {
1682 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1683 Op->replaceAllUsesWith(NewLoad);
1684 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
1685 return;
1686 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1687 m_Value(Arg)))) {
1688 ToRemove.push_back(cast<Instruction>(Op));
1689 Op->replaceAllUsesWith(Arg);
1690 return;
1691 }
1692 };
1693
1694 for (auto *V : ToFlatten)
1695 FlattenArg(V);
1696
1697 LHS = MatMul->getArgOperand(0);
1698
1699 // Insert mul/fmul and llvm.vector.reduce.fadd
1700 Value *Mul =
1701 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1702
1703 Value *Result;
1704 if (IsIntVec)
1705 Result = Builder.CreateAddReduce(Mul);
1706 else {
1707 Result = Builder.CreateFAddReduce(
1708 ConstantFP::get(
1709 cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0),
1710 Mul);
1711 cast<Instruction>(Result)->setFastMathFlags(FMF);
1712 }
1713
1714 // pack scalar back into a matrix and then replace matmul inst
1715 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1716 Result, uint64_t(0));
1717 MatMul->replaceAllUsesWith(Result);
1718 FusedInsts.insert(MatMul);
1719 ToRemove.push_back(MatMul);
1720 }
1721
1722 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1723 /// addition.
1724 ///
1725 /// We can fold a transpose into the operand that is used to extract scalars.
1726 /// This is the first operand with row-major and the second with
1727 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1728 /// operand is transposed.
1729 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1730 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1731 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1732 const unsigned VF = std::max<unsigned>(
1733 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1734 .getFixedValue() /
1735 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1736 1U);
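// For example, assuming 256-bit fixed-width vector registers and 64-bit
// double elements, VF = 256 / 64 = 4; the std::max with 1U guards against
// element types at least as wide as the vector register.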
1737 unsigned R = Result.getNumRows();
1738 unsigned C = Result.getNumColumns();
1739 unsigned M = A.getNumColumns();
1740
1741 bool IsFP = Result.getElementType()->isFloatingPointTy();
1742 assert(A.isColumnMajor() == B.isColumnMajor() &&
1743 Result.isColumnMajor() == A.isColumnMajor() &&
1744 "operands must agree on matrix layout");
1745 unsigned NumComputeOps = 0;
1746
1747 Builder.setFastMathFlags(FMF);
1748
1749 if (A.isColumnMajor()) {
1750 // Multiply columns from the first operand with scalars from the second
1751 // operand. Then move along the K axes and accumulate the columns. With
1752 // this the adds can be vectorized without reassociation.
1753 for (unsigned J = 0; J < C; ++J) {
1754 unsigned BlockSize = VF;
1755 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1756 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1757
1758 for (unsigned I = 0; I < R; I += BlockSize) {
1759 // Gradually lower the vectorization factor to cover the remainder.
1760 while (I + BlockSize > R)
1761 BlockSize /= 2;
1762
1763 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1764 : nullptr;
1765 for (unsigned K = 0; K < M; ++K) {
1766 Value *L = A.extractVector(I, K, BlockSize, Builder);
1767 Value *RH = Builder.CreateExtractElement(
1768 B.getColumn(IsScalarMatrixTransposed ? K : J),
1769 IsScalarMatrixTransposed ? J : K);
1770 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1771 Sum =
1772 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1773 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1774 }
1775 Result.setVector(J,
1776 insertVector(Result.getVector(J), I, Sum, Builder));
1777 }
1778 }
1779 } else {
1780 // Multiply rows from the second operand with scalars from the first
1781 // operand. Then move along the K axes and accumulate the rows. With this
1782 // the adds can be vectorized without reassociation.
1783 for (unsigned I = 0; I < R; ++I) {
1784 unsigned BlockSize = VF;
1785 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1786 for (unsigned J = 0; J < C; J += BlockSize) {
1787 // Gradually lower the vectorization factor to cover the remainder.
1788 while (J + BlockSize > C)
1789 BlockSize /= 2;
1790
1791 Value *Sum = nullptr;
1792 for (unsigned K = 0; K < M; ++K) {
1793 Value *R = B.extractVector(K, J, BlockSize, Builder);
1794 Value *LH = Builder.CreateExtractElement(
1795 A.getVector(IsScalarMatrixTransposed ? K : I),
1796 IsScalarMatrixTransposed ? I : K);
1797 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1798 Sum =
1799 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1800 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1801 }
1802 Result.setVector(I,
1803 insertVector(Result.getVector(I), J, Sum, Builder));
1804 }
1805 }
1806 }
1807 Result.addNumComputeOps(NumComputeOps);
1808 }
1809
1810 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1811 /// copying it to a new location. The new location, or otherwise the
1812 /// original one, is returned.
1813 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1814 CallInst *MatMul) {
1815 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1816 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1817
1818 // If we can statically determine noalias we're good.
1819 if (AA->isNoAlias(LoadLoc, StoreLoc))
1820 return Load->getPointerOperand();
1821
1822 // Create code to check if the memory locations of the Load and Store
1823 // overlap and if they do, copy Load's operand to a new buffer.
1824
1825 // First, create new blocks for the 2nd part of the check and the copy.
1826 BasicBlock *Check0 = MatMul->getParent();
1827 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1828 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1829 // as we adjust Check0 and Check1's branches.
1830 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1831 for (BasicBlock *Succ : successors(Check0))
1832 DTUpdates.push_back({DT->Delete, Check0, Succ});
1833
1834 BasicBlock *Check1 =
1835 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1836 nullptr, "alias_cont");
1837 BasicBlock *Copy =
1838 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1839 nullptr, "copy");
1840 BasicBlock *Fusion =
1841 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1842 nullptr, "no_alias");
1843
1844 // Check if the loaded memory location begins before the end of the store
1845 // location. If the condition holds, they might overlap, otherwise they are
1846 // guaranteed to not overlap.
1847 IRBuilder<> Builder(MatMul);
1848 Check0->getTerminator()->eraseFromParent();
1849 Builder.SetInsertPoint(Check0);
1850 Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1851 Value *StoreBegin = Builder.CreatePtrToInt(
1852 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1853 Value *StoreEnd = Builder.CreateAdd(
1854 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1855 "store.end", true, true);
1856 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1857 IntPtrTy, "load.begin");
1858 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1859 Fusion);
1860
1861 // Check if the store begins before the end of the load location. If the
1862 // condition holds, they alias, otherwise they are guaranteed to not
1863 // overlap.
1864 Check1->getTerminator()->eraseFromParent();
1865 Builder.SetInsertPoint(Check1, Check1->begin());
1866 Value *LoadEnd = Builder.CreateAdd(
1867 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1868 "load.end", true, true);
1869 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1870 Fusion);
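// Together, the two conditional branches implement the usual interval-overlap
// test: [LoadBegin, LoadEnd) and [StoreBegin, StoreEnd) overlap iff
// LoadBegin < StoreEnd and StoreBegin < LoadEnd; only then is the copy taken.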
1871
1872 // Copy load operand to new alloca.
1873 Builder.SetInsertPoint(Copy, Copy->begin());
1874 auto *VT = cast<FixedVectorType>(Load->getType());
1875 // Use an array type for the alloca, to avoid potentially huge alignment
1876 // requirements for large vector types.
1877 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1878 AllocaInst *Alloca =
1879 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1880
1881 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1882 Load->getAlign(), LoadLoc.Size.getValue());
1883 Builder.SetInsertPoint(Fusion, Fusion->begin());
1884 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1885 PHI->addIncoming(Load->getPointerOperand(), Check0);
1886 PHI->addIncoming(Load->getPointerOperand(), Check1);
1887 PHI->addIncoming(Alloca, Copy);
1888
1889 // Adjust DT.
1890 DTUpdates.push_back({DT->Insert, Check0, Check1});
1891 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1892 DTUpdates.push_back({DT->Insert, Check1, Copy});
1893 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1894 DT->applyUpdates(DTUpdates);
1895 return PHI;
1896 }
1897
1898 bool isFusionProfitable(CallInst *MatMul) {
1899 if (ForceFusion)
1900 return true;
1901
1902 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1903 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1904
1905 const unsigned R = LShape.NumRows;
1906 const unsigned C = RShape.NumColumns;
1907 const unsigned M = LShape.NumColumns;
1908 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
1909
1910 const unsigned VF = std::max<unsigned>(
1911 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1912 .getFixedValue() /
1913 EltType->getPrimitiveSizeInBits().getFixedValue(),
1914 1U);
1915
1916 // Cost model for tiling
1917 //
1918 // For tiling to be beneficial, we need reuse either along the R or
1919 // the C axis. We vectorize along the R axis so that means at least
1920 // 3 elements.
1921 // TODO: Also consider cost of copying if operands alias.
1922 if (R <= VF && C == 1)
1923 return false;
1924 // Then we need enough elements to exceed the number of vector
1925 // registers we have. Note that this is an oversimplification since
1926 // fusing also takes some extra loads which may exceed the number of
1927 // reloads necessary.
1928 unsigned Op0Regs = (R + VF - 1) / VF * M;
1929 unsigned Op1Regs = (M + VF - 1) / VF * C;
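// Worked example: with R = C = M = 8 and VF = 4, each operand needs
// ceil(8 / 4) * 8 = 16 vector registers, so fusion is considered profitable
// once the total of 32 exceeds the number of available vector registers.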
1930 return Op0Regs + Op1Regs >
1931 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1932 }
1933
1934 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1935 MatrixTy Res;
1936 auto *ColumType = FixedVectorType::get(EltType, R);
1937 for (unsigned I = 0; I < C; ++I)
1938 Res.addVector(ConstantAggregateZero::get(ColumType));
1939 return Res;
1940 }
1941
1942 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1943 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1944 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
1945
1946 // Create the main tiling loop nest.
1947 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1948 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1949 Instruction *InsertI = cast<Instruction>(MatMul);
1950 BasicBlock *Start = InsertI->getParent();
1951 BasicBlock *End =
1952 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1953 IRBuilder<> Builder(MatMul);
1954 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1955
1956 Type *TileVecTy =
1957 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1958 MatrixTy TileResult;
1959 // Insert in the inner loop header.
1960 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1961 // Create PHI nodes for the result columns to accumulate across iterations.
1962 SmallVector<PHINode *, 4> ColumnPhis;
1963 for (unsigned I = 0; I < TileSize; I++) {
1964 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1965 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1966 TI.RowLoop.Header->getSingleSuccessor());
1967 TileResult.addVector(Phi);
1968 ColumnPhis.push_back(Phi);
1969 }
1970
1971 // Insert in the inner loop body, which computes
1972 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1973 Builder.SetInsertPoint(InnerBody->getTerminator());
1974 // Load tiles of the operands.
1975 MatrixTy A =
1976 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1977 {TileSize, TileSize}, EltType, Builder);
1978 MatrixTy B =
1979 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1980 {TileSize, TileSize}, EltType, Builder);
1981 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1982 getFastMathFlags(MatMul));
1983 // Store result after the inner loop is done.
1984 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1985 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1986 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1987 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1988
1989 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1990 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1991
1992 // Force unrolling of a few iterations of the inner loop, to make sure there
1993 // is enough work per iteration.
1994 // FIXME: The unroller should make this decision directly instead, but
1995 // currently the cost-model is not up to the task.
1996 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
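// E.g. with TileSize = 4 and 64 columns in the LHS, the inner loop runs 16
// times and is annotated to unroll min(10, 16) = 10 iterations.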
1997 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1998 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1999 }
2000
2001 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
2002 StoreInst *Store,
2003 SmallPtrSetImpl<Instruction *> &FusedInsts) {
2004 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
2005 "Tiling only supported for column-major matrixes at the moment!");
2006 if (!isFusionProfitable(MatMul))
2007 return;
2008
2009 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2010 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2011
2012 const unsigned R = LShape.NumRows;
2013 const unsigned C = RShape.NumColumns;
2014 const unsigned M = LShape.NumColumns;
2015 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
2016
2017 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
2018 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
2019 Value *CPtr = Store->getPointerOperand();
2020
2021 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
2022 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
2023 else {
2024 IRBuilder<> Builder(Store);
2025 for (unsigned J = 0; J < C; J += TileSize)
2026 for (unsigned I = 0; I < R; I += TileSize) {
2027 const unsigned TileR = std::min(R - I, unsigned(TileSize));
2028 const unsigned TileC = std::min(C - J, unsigned(TileSize));
2029 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
2030
2031 for (unsigned K = 0; K < M; K += TileSize) {
2032 const unsigned TileM = std::min(M - K, unsigned(TileSize));
2033 MatrixTy A =
2034 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
2035 LShape, getIndex(APtr, I), getIndex(APtr, K),
2036 {TileR, TileM}, EltType, Builder);
2037 MatrixTy B =
2038 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
2039 RShape, getIndex(BPtr, K), getIndex(BPtr, J),
2040 {TileM, TileC}, EltType, Builder);
2041 emitMatrixMultiply(Res, A, B, Builder, true, false,
2042 getFastMathFlags(MatMul));
2043 }
2044 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
2045 getIndex(CPtr, I), getIndex(CPtr, J), EltType, Builder);
2046 }
2047 }
2048
2049 // Mark eliminated instructions as fused and remove them.
2050 FusedInsts.insert(Store);
2051 FusedInsts.insert(MatMul);
2052 eraseFromParentAndRemoveFromShapeMap(Store);
2053 eraseFromParentAndRemoveFromShapeMap(MatMul);
2054 if (LoadOp0->use_empty()) {
2055 FusedInsts.insert(LoadOp0);
2056 eraseFromParentAndRemoveFromShapeMap(LoadOp0);
2057 }
2058 if (LoadOp1 != LoadOp0 && LoadOp1->use_empty()) {
2059 FusedInsts.insert(LoadOp1);
2060 eraseFromParentAndRemoveFromShapeMap(LoadOp1);
2061 }
2062 }
2063
2064 /// Try to lower matrix multiply chains by fusing operations.
2065 ///
2066 /// Call finalizeLowering on lowered instructions. Instructions that are
2067 /// completely eliminated by fusion are added to \p FusedInsts.
2068 void
2069 LowerMatrixMultiplyFused(CallInst *MatMul,
2070 SmallPtrSetImpl<Instruction *> &FusedInsts,
2071 SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
2072 if (!FuseMatrix || !DT)
2073 return;
2074
2075 assert(AA && LI && "Analyses should be available");
2076
2077 Value *A = MatMul->getArgOperand(0);
2078 Value *B = MatMul->getArgOperand(1);
2079
2080 // We can fold the transpose into the operand that is used to fetch scalars.
2081 Value *T;
2082 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
2083 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
2084 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
2085 IRBuilder<> Builder(MatMul);
2086 auto *EltType =
2087 cast<FixedVectorType>(MatMul->getType())->getElementType();
2088 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2089 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2090 const unsigned R = LShape.NumRows;
2091 const unsigned M = LShape.NumColumns;
2092 const unsigned C = RShape.NumColumns;
2093
2094 MatrixTy MA;
2095 MatrixTy MB;
2096
2097 Value *Transpose;
2098 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
2099 MA = getMatrix(A, ShapeInfo(R, M), Builder);
2100 MB = getMatrix(T, ShapeInfo(C, M), Builder);
2101 Transpose = B;
2102 } else {
2103 MA = getMatrix(T, ShapeInfo(R, M), Builder);
2104 MB = getMatrix(B, ShapeInfo(C, M), Builder);
2105 Transpose = A;
2106 }
2107
2108 // Initialize the output
2109 MatrixTy Result(R, C, EltType);
2110
2111 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
2112 getFastMathFlags(MatMul));
2113
2114 FusedInsts.insert(MatMul);
2115 if (Transpose->hasOneUse()) {
2116 FusedInsts.insert(cast<Instruction>(Transpose));
2117 ToRemove.push_back(cast<Instruction>(Transpose));
2118 // TODO: add a fake entry for the folded instruction so that this is
2119 // included in the expression in the remark.
2120 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
2121 }
2122 finalizeLowering(MatMul, Result, Builder);
2123 return;
2124 }
2125
2126 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
2127 return;
2128
2129 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
2130 // since the single store user will be lowered as part of this.
2131 auto *LoadOp0 = dyn_cast<LoadInst>(A);
2132 auto *LoadOp1 = dyn_cast<LoadInst>(B);
2133 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
2134 if (LoadOp0 && LoadOp1 && Store) {
2135 // The store address must dominate the MatMul instruction, otherwise
2136 // we create invalid IR.
2137 SetVector<Value *> WorkList;
2138 WorkList.insert(Store->getOperand(1));
2139 SmallVector<Instruction *> ToHoist;
2140 for (unsigned I = 0; I != WorkList.size(); ++I) {
2141 Value *Current = WorkList[I];
2142 auto *CurrI = dyn_cast<Instruction>(Current);
2143 if (!CurrI)
2144 continue;
2145 if (isa<PHINode>(CurrI))
2146 return;
2147 if (DT->dominates(CurrI, MatMul))
2148 continue;
2149 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
2150 return;
2151 ToHoist.push_back(CurrI);
2152 WorkList.insert_range(CurrI->operands());
2153 }
2154
2155 sort(ToHoist, [this](Instruction *A, Instruction *B) {
2156 return DT->dominates(A, B);
2157 });
2158 for (Instruction *I : ToHoist)
2159 I->moveBefore(MatMul->getIterator());
2160
2161 // Deal with lifetime.end calls that might be between Load0/Load1 and the
2162 // store. To avoid introducing loads to dead objects (i.e. after the
2163 // lifetime has been terminated by @llvm.lifetime.end), either sink them
2164 // after the store if in the same block, or remove the lifetime.end marker
2165 // otherwise. This might pessimize further optimizations, by extending the
2166 // lifetime of the object until the function returns, but should be
2167 // conservatively correct.
2168 MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
2169 MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
2170 BasicBlock *StoreParent = Store->getParent();
2171 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
2172 LoadOp1->getParent() == StoreParent;
2173 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
2174 IntrinsicInst *End = LifetimeEnds[Idx];
2175 auto Inc = make_scope_exit([&Idx]() { Idx++; });
2176 // If the lifetime.end is guaranteed to be before the loads or after the
2177 // store, it won't interfere with fusion.
2178 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
2179 continue;
2180 if (DT->dominates(Store, End))
2181 continue;
2182 // If all fusable ops are in the same block and the lifetime.end is in a
2183 // different block, it won't interfere with fusion.
2184 if (FusableOpsInSameBlock && End->getParent() != StoreParent)
2185 continue;
2186
2187 // If the loads don't alias the lifetime.end, it won't interfere with
2188 // fusion.
2189 MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 0, nullptr);
2190 if (!EndLoc.Ptr)
2191 continue;
2192 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
2193 continue;
2194
2195 // If both lifetime.end and the store are in the same block, extend the
2196 // lifetime until after the store, so the new lifetime covers the loads
2197 // we introduce later.
2198 if (End->getParent() == StoreParent) {
2199 End->moveAfter(Store);
2200 continue;
2201 }
2202
2203 // Otherwise remove the conflicting lifetime.end marker.
2204 ToRemove.push_back(End);
2205 std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
2206 LifetimeEnds.pop_back();
2207 Inc.release();
2208 }
2209
2210 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2211 return;
2212 }
2213 }
2214
2215 /// Lowers llvm.matrix.multiply.
2216 MatrixTy LowerMultiply(CallInst *MatMul, IRBuilder<> &Builder) {
2217 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
2218 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2219 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2220
2221 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
2222 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
2223 assert(Lhs.getElementType() == Rhs.getElementType() &&
2224 "Matrix multiply argument element types do not match.");
2225
2226 const unsigned R = LShape.NumRows;
2227 const unsigned C = RShape.NumColumns;
2228 assert(LShape.NumColumns == RShape.NumRows);
2229
2230 // Initialize the output
2231 MatrixTy Result(R, C, EltType);
2232 assert(Lhs.getElementType() == Result.getElementType() &&
2233 "Matrix multiply result element type does not match arguments.");
2234
2235 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
2236 getFastMathFlags(MatMul));
2237 return Result;
2238 }
2239
2240 /// Lowers llvm.matrix.transpose.
2241 MatrixTy LowerTranspose(CallInst *Inst, IRBuilder<> &Builder) {
2242 MatrixTy Result;
2243 Value *InputVal = Inst->getArgOperand(0);
2244 FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
2245 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
2246 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
2247
2248 const unsigned NewNumVecs =
2249 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2250 const unsigned NewNumElts =
2251 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
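// E.g. a column-major 2x3 input is held as 3 column vectors of 2 elements;
// its transpose is rebuilt below as NewNumVecs = 2 vectors of
// NewNumElts = 3 elements each, one element at a time.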
2252
2253 for (unsigned I = 0; I < NewNumVecs; ++I) {
2254 // Build a single result vector. First initialize it.
2255 Value *ResultVector = PoisonValue::get(
2256 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
2257 // Go through the old elements and insert it into the resulting vector.
2258 for (auto J : enumerate(InputMatrix.vectors())) {
2259 Value *Elt = Builder.CreateExtractElement(J.value(), I);
2260 // Row and column indices are transposed.
2261 ResultVector =
2262 Builder.CreateInsertElement(ResultVector, Elt, J.index());
2263 }
2264 Result.addVector(ResultVector);
2265 }
2266
2267 // TODO: Improve estimate of operations needed for transposes. Currently we
2268 // just count the insertelement/extractelement instructions, but do not
2269 // account for later simplifications/combines.
2270 return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2271 .addNumExposedTransposes(1);
2272 }
2273
2274 /// Lower load instructions.
2275 MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2276 IRBuilder<> &Builder) {
2277 return LowerLoad(Inst, Ptr, Inst->getAlign(), getIndex(Ptr, SI.getStride()),
2278 Inst->isVolatile(), SI, Builder);
2279 }
2280
2281 MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2282 Value *Ptr, IRBuilder<> &Builder) {
2283 return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2284 getIndex(Ptr, SI.getStride()), Inst->isVolatile(), SI,
2285 Builder);
2286 }
2287
2288 MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
2289 auto BlockIP = Inst->getParent()->getFirstInsertionPt();
2290 Builder.SetInsertPoint(BlockIP);
2291 MatrixTy PhiM = getMatrix(Inst, SI, Builder);
2292
2293 for (auto [IncomingV, IncomingB] :
2294 llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) {
2295 // getMatrix() may insert some instructions to help with reshaping. The
2296 // safest place for those is at the top of the block after the rest of the
2297 // PHIs. Even better, when possible we place them in the incoming block.
2298 Builder.SetInsertPoint(BlockIP);
2299 if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
2300 if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef())
2301 Builder.SetInsertPoint(*MaybeIP);
2302
2303 MatrixTy OpM = getMatrix(IncomingV, SI, Builder);
2304
2305 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
2306 PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI));
2307 NewPHI->addIncoming(OpM.getVector(VI), IncomingB);
2308 }
2309 }
2310
2311 // finalizeLowering() may also insert instructions in some cases. The safe
2312 // place for those is at the end of the initial block of PHIs.
2313 Builder.SetInsertPoint(BlockIP);
2314 return PhiM;
2315 }
2316
2317 /// Lower binary operators.
2318 MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
2319 IRBuilder<> &Builder) {
2320 Value *Lhs = Inst->getOperand(0);
2321 Value *Rhs = Inst->getOperand(1);
2322
2323 MatrixTy Result;
2324 MatrixTy A = getMatrix(Lhs, SI, Builder);
2325 MatrixTy B = getMatrix(Rhs, SI, Builder);
2326 assert(A.isColumnMajor() == B.isColumnMajor() &&
2327 Result.isColumnMajor() == A.isColumnMajor() &&
2328 "operands must agree on matrix layout");
2329
2330 Builder.setFastMathFlags(getFastMathFlags(Inst));
2331
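// Element-wise lowering: e.g. an fadd of two 4x4 column-major matrices
// becomes four fadds on <4 x elt> column vectors, one per stored vector.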
2332 for (auto [AV, BV] : llvm::zip_equal(A.vectors(), B.vectors()))
2333 Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), AV, BV));
2334
2335 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2336 Result.getNumVectors());
2337 }
2338
2339 /// Lower unary operators.
2340 MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI,
2341 IRBuilder<> &Builder) {
2342 Value *Op = Inst->getOperand(0);
2343
2344 MatrixTy Result;
2345 MatrixTy M = getMatrix(Op, SI, Builder);
2346
2347 Builder.setFastMathFlags(getFastMathFlags(Inst));
2348
2349 // Helper to perform unary op on vectors.
2350 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2351 switch (Inst->getOpcode()) {
2352 case Instruction::FNeg:
2353 return Builder.CreateFNeg(Op);
2354 default:
2355 llvm_unreachable("Unsupported unary operator for matrix");
2356 }
2357 };
2358
2359 for (auto *Vector : M.vectors())
2360 Result.addVector(BuildVectorOp(Vector));
2361
2362 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2363 Result.getNumVectors());
2364 }
2365
2366 /// Lower cast instructions.
2367 MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape,
2368 IRBuilder<> &Builder) {
2369 Value *Op = Inst->getOperand(0);
2370
2371 MatrixTy Result;
2372 MatrixTy M = getMatrix(Op, Shape, Builder);
2373
2374 Builder.setFastMathFlags(getFastMathFlags(Inst));
2375
2376 auto *OrigVTy = cast<VectorType>(Inst->getType());
2377 auto *NewVTy = VectorType::get(OrigVTy->getElementType(),
2378 ElementCount::getFixed(M.getStride()));
2379
2380 for (auto *Vector : M.vectors())
2381 Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy));
2382
2383 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2384 Result.getNumVectors());
2385 }
2386
2387 /// Lower selects.
2388 MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape,
2389 IRBuilder<> &Builder) {
2390 Value *Cond = Inst->getOperand(0);
2391 Value *OpA = Inst->getOperand(1);
2392 Value *OpB = Inst->getOperand(2);
2393
2394 MatrixTy Result;
2395 MatrixTy A = getMatrix(OpA, Shape, Builder);
2396 MatrixTy B = getMatrix(OpB, Shape, Builder);
2397
2398 SmallVector<Value*> CondV;
2399 if (isa<FixedVectorType>(Cond->getType())) {
2400 MatrixTy C = getMatrix(Cond, Shape, Builder);
2401 llvm::copy(C.vectors(), std::back_inserter(CondV));
2402 } else {
2403 CondV.resize(A.getNumVectors());
2404 llvm::fill(CondV, Cond);
2405 }
2406
2407 for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors()))
2408 Result.addVector(Builder.CreateSelect(CV, AV, BV));
2409
2410 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2411 Result.getNumVectors());
2412 }
2413
2414 /// Helper to linearize a matrix expression tree into a string. Currently
2415 /// matrix expressions are linearized by starting at an expression leaf and
2416 /// linearizing bottom up.
2417 struct ExprLinearizer {
2418 unsigned LengthToBreak = 100;
2419 std::string Str;
2420 raw_string_ostream Stream;
2421 unsigned LineLength = 0;
2422 const DataLayout &DL;
2423
2424 /// Mapping from instructions to matrixes. It is used to identify
2425 /// matrix instructions.
2426 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2427
2428 /// Mapping from values to the leaves of all expressions that the value is
2429 /// part of.
2430 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2431
2432 /// Set of matrix expressions in the scope of a given DISubprogram.
2433 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2434
2435 /// Leaf node of the expression to linearize.
2436 Value *Leaf;
2437
2438 /// Used to keep track of sub-expressions that get reused while linearizing
2439 /// the expression. Re-used sub-expressions are marked as (reused).
2440 SmallPtrSet<Value *, 8> ReusedExprs;
2441
2442 ExprLinearizer(const DataLayout &DL,
2443 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2444 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2445 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2446 Value *Leaf)
2447 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2448 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2449
2450 void indent(unsigned N) {
2451 LineLength += N;
2452 for (unsigned i = 0; i < N; i++)
2453 Stream << " ";
2454 }
2455
2456 void lineBreak() {
2457 Stream << "\n";
2458 LineLength = 0;
2459 }
2460
2461 void maybeIndent(unsigned Indent) {
2462 if (LineLength >= LengthToBreak)
2463 lineBreak();
2464
2465 if (LineLength == 0)
2466 indent(Indent);
2467 }
2468
2469 void write(StringRef S) {
2470 LineLength += S.size();
2471 Stream << S;
2472 }
2473
2474 Value *getUnderlyingObjectThroughLoads(Value *V) {
2475 if (Value *Ptr = getPointerOperand(V))
2476 return getUnderlyingObjectThroughLoads(Ptr);
2477 else if (V->getType()->isPointerTy())
2478 return getUnderlyingObject(V);
2479 return V;
2480 }
2481
2482 /// Returns true if \p V is a matrix value in the given subprogram.
2483 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2484
2485 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2486 /// \p SS.
2487 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2488 auto M = Inst2Matrix.find(V);
2489 if (M == Inst2Matrix.end())
2490 SS << "unknown";
2491 else {
2492 SS << M->second.getNumRows();
2493 SS << "x";
2494 SS << M->second.getNumColumns();
2495 }
2496 }
2497
2498 /// Write the called function name. Handles calls to llvm.matrix.*
2499 /// specially: we write the name, followed by the dimensions of the input
2500 /// matrixes, followed by the scalar type name.
2501 void writeFnName(CallInst *CI) {
2502 if (!CI->getCalledFunction())
2503 write("<no called fn>");
2504 else {
2505 StringRef Name = CI->getCalledFunction()->getName();
2506 if (!Name.starts_with("llvm.matrix")) {
2507 write(Name);
2508 return;
2509 }
2510 auto *II = cast<IntrinsicInst>(CI);
2511 write(Intrinsic::getBaseName(II->getIntrinsicID())
2512 .drop_front(StringRef("llvm.matrix.").size()));
2513 write(".");
2514 std::string Tmp;
2515 raw_string_ostream SS(Tmp);
2516
2517 switch (II->getIntrinsicID()) {
2518 case Intrinsic::matrix_multiply:
2519 prettyPrintMatrixType(II->getOperand(0), SS);
2520 SS << ".";
2521 prettyPrintMatrixType(II->getOperand(1), SS);
2522 SS << "." << *II->getType()->getScalarType();
2523 break;
2524 case Intrinsic::matrix_transpose:
2525 prettyPrintMatrixType(II->getOperand(0), SS);
2526 SS << "." << *II->getType()->getScalarType();
2527 break;
2528 case Intrinsic::matrix_column_major_load:
2529 prettyPrintMatrixType(II, SS);
2530 SS << "." << *II->getType()->getScalarType();
2531 break;
2532 case Intrinsic::matrix_column_major_store:
2533 prettyPrintMatrixType(II->getOperand(0), SS);
2534 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2535 break;
2536 default:
2537 llvm_unreachable("Unhandled case");
2538 }
2539 write(Tmp);
2540 }
2541 }
2542
2543 unsigned getNumShapeArgs(CallInst *CI) const {
2544 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2545 switch (II->getIntrinsicID()) {
2546 case Intrinsic::matrix_multiply:
2547 return 3;
2548 case Intrinsic::matrix_transpose:
2549 return 2;
2550 case Intrinsic::matrix_column_major_load:
2551 case Intrinsic::matrix_column_major_store:
2552 return 3;
2553 default:
2554 return 0;
2555 }
2556 }
2557 return 0;
2558 }
2559
2560 /// Special printing for values: for pointers, we print if they refer to an
2561 /// (function) external address or a stack address; for other values we
2562 /// either print the constant or "scalar"/"matrix".
2563 void write(Value *V) {
2564 V = getUnderlyingObjectThroughLoads(V);
2565 if (V->getType()->isPointerTy()) {
2566 if (isa<AllocaInst>(V)) {
2567 Stream << "stack addr";
2568 LineLength += StringRef("stack addr").size();
2569 } else {
2570 Stream << "addr";
2571 LineLength += StringRef("addr").size();
2572 }
2573 if (!V->getName().empty()) {
2574 Stream << " %" << V->getName() << "";
2575 LineLength += V->getName().size() + 2;
2576 }
2577 return;
2578 }
2579
2580 std::string Tmp;
2581 raw_string_ostream TmpStream(Tmp);
2582
2583 if (auto *CI = dyn_cast<ConstantInt>(V))
2584 TmpStream << CI->getValue();
2585 else if (isa<Constant>(V))
2586 TmpStream << "constant";
2587 else {
2588 if (isMatrix(V))
2589 TmpStream << "matrix";
2590 else
2591 TmpStream << "scalar";
2592 }
2593 Tmp = std::string(StringRef(Tmp).trim());
2594 LineLength += Tmp.size();
2595 Stream << Tmp;
2596 }
2597
2598 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2599 /// Expressions that are re-used multiple times are prefixed with (reused)
2600 /// at the re-used root instruction.
2601 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2602 bool ParentShared) {
2603 auto *I = cast<Instruction>(Expr);
2604 maybeIndent(Indent);
2605 SmallVector<Value *, 8> Ops;
2606
2607 // Is Expr shared with other expression leaves?
2608 bool ExprShared = false;
2609
2610 // Deal with shared subtrees. Mark them as shared, if required.
2611 if (!ParentShared) {
2612 auto SI = Shared.find(Expr);
2613 assert(SI != Shared.end() && SI->second.count(Leaf));
2614
2615 for (Value *S : SI->second) {
2616 if (S == Leaf)
2617 continue;
2618 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2619 write("shared with remark at line " + std::to_string(DL.getLine()) +
2620 " column " + std::to_string(DL.getCol()) + " (");
2621 }
2622 ExprShared = SI->second.size() > 1;
2623 }
2624
2625 bool Reused = !ReusedExprs.insert(Expr).second;
2626 if (Reused && !ParentReused)
2627 write("(reused) ");
2628
2629 if (auto *CI = dyn_cast<CallInst>(I)) {
2630 writeFnName(CI);
2631
2632 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2633 } else if (isa<BitCastInst>(Expr)) {
2634 // Special case bitcasts, which are used to materialize matrixes from
2635 // non-matrix ops.
2636 write("matrix");
2637 return;
2638 } else {
2639 Ops.append(I->value_op_begin(), I->value_op_end());
2640 write(I->getOpcodeName());
2641 }
2642
2643 write("(");
2644
2645 unsigned NumOpsToBreak = 1;
2646 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2647 NumOpsToBreak = 2;
2648
2649 for (Value *Op : Ops) {
2650 if (Ops.size() > NumOpsToBreak)
2651 lineBreak();
2652
2653 maybeIndent(Indent + 1);
2654 if (isMatrix(Op))
2655 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2656 else
2657 write(Op);
2658 if (Op != Ops.back())
2659 write(", ");
2660 }
2661
2662 write(")");
2663 }
2664
2665 const std::string &getResult() {
2666 return Str;
2667 }
2668 };
2669
2670 /// Generate remarks for matrix operations in a function. To generate remarks
2671 /// for matrix expressions, the following approach is used:
2672 /// 1. Use the inlined-at debug information to group matrix operations to the
2673 /// DISubprograms they are contained in.
2674 /// 2. Collect leaves of matrix expressions (done in
2675 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2676 /// mapping. Leaves are lowered matrix instructions without other matrix
2677 /// users (like stores) in the current subprogram.
2678 /// 3. For each leaf, create a remark containing a linearizied version of the
2679 /// matrix expression. The expression is linearized by a recursive
2680 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2681 /// that multiple leaves can share sub-expressions. Shared subexpressions
2682 /// are explicitly marked as shared().
2683 struct RemarkGenerator {
2684 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2685 OptimizationRemarkEmitter &ORE;
2686 Function &Func;
2687 const DataLayout &DL;
2688
2689 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2690 OptimizationRemarkEmitter &ORE, Function &Func)
2691 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2692 DL(Func.getDataLayout()) {}
2693
2694 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2695 /// instructions in Inst2Matrix returning void or without any users in
2696 /// \p ExprsInSubprogram. Currently that should only include stores.
2697 SmallVector<Value *, 4>
2698 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2699 SmallVector<Value *, 4> Leaves;
2700 for (auto *Expr : ExprsInSubprogram)
2701 if (Expr->getType()->isVoidTy() ||
2702 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2703 return ExprsInSubprogram.count(U);
2704 }))
2705 Leaves.push_back(Expr);
2706 return Leaves;
2707 }
2708
2709 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2710 /// to all visited expressions in \p Shared. Limit the matrix operations to
2711 /// the ones in \p ExprsInSubprogram.
2712 void collectSharedInfo(Value *Leaf, Value *V,
2713 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2714 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2715
2716 if (!ExprsInSubprogram.count(V))
2717 return;
2718
2719 Shared[V].insert(Leaf);
2720
2721 for (Value *Op : cast<Instruction>(V)->operand_values())
2722 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2723 }
2724
2725 /// Calculate the number of exclusive and shared op counts for expression
2726 /// starting at \p V. Expressions used multiple times are counted once.
2727 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2728 std::pair<OpInfoTy, OpInfoTy>
2729 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2730 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2731 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2732 if (!ExprsInSubprogram.count(Root))
2733 return {};
2734
2735 // Already counted this expression. Stop.
2736 if (!ReusedExprs.insert(Root).second)
2737 return {};
2738
2739 OpInfoTy SharedCount;
2740 OpInfoTy Count;
2741
2742 auto I = Shared.find(Root);
2743 auto CM = Inst2Matrix.find(Root);
2744 if (I->second.size() == 1)
2745 Count = CM->second.getOpInfo();
2746 else
2747 SharedCount = CM->second.getOpInfo();
2748
2749 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2750 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2751 Count += C.first;
2752 SharedCount += C.second;
2753 }
2754 return {Count, SharedCount};
2755 }
2756
2757 void emitRemarks() {
2758 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2759 return;
2760
2761 // Map matrix operations to their containing subprograms, by traversing
2762 // the inlinedAt chain. If the function does not have a DISubprogram, we
2763 // only map them to the containing function.
2764 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2765 for (const auto &KV : Inst2Matrix) {
2766 if (Func.getSubprogram()) {
2767 auto *I = cast<Instruction>(KV.first);
2768 DILocation *Context = I->getDebugLoc();
2769 while (Context) {
2770 Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
2771 KV.first);
2772 Context = DebugLoc(Context).getInlinedAt();
2773 }
2774 } else {
2775 Subprog2Exprs[nullptr].push_back(KV.first);
2776 }
2777 }
2778 for (auto &KV : Subprog2Exprs) {
2779 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2780 KV.second.end());
2781 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2782
2783 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2784 for (Value *Leaf : Leaves)
2785 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2786
2787 // Generate remarks for each leaf.
2788 for (auto *L : Leaves) {
2789
2790 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2791 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2792 while (Context) {
2793 if (getSubprogram(Context->getScope()) == KV.first) {
2794 Loc = Context;
2795 break;
2796 }
2797 Context = DebugLoc(Context).getInlinedAt();
2798 }
2799
2800 SmallPtrSet<Value *, 8> ReusedExprs;
2801 OpInfoTy Counts, SharedCounts;
2802 std::tie(Counts, SharedCounts) =
2803 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2804
2805 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2806 cast<Instruction>(L)->getParent());
2807 Rem << "Lowered with ";
2808 Rem << "Lowered with ";
2809 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2810 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2811 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2812 << " compute ops, "
2813 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2814 << " exposed transposes";
2815
2816 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2817 SharedCounts.NumComputeOps > 0) {
2818 Rem << ",\nadditionally "
2819 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2820 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2821 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2822 << " compute ops"
2823 << " are shared with other expressions";
2824 }
2825
2826 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2827 ORE.emit(Rem);
2828 }
2829 }
2830 }
2831
2832 std::string
2833 linearize(Value *L,
2834 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2835 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2836 const DataLayout &DL) {
2837 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2838 Lin.linearizeExpr(L, 0, false, false);
2839 return Lin.getResult();
2840 }
2841 };
2842};
2843} // namespace
2844
2845 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2846 FunctionAnalysisManager &AM) {
2847 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2848
2849 LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);
2850 if (LMT.Visit()) {
2851 PreservedAnalyses PA;
2852 if (!Minimal) {
2853 PA.preserve<LoopAnalysis>();
2854 PA.preserve<DominatorTreeAnalysis>();
2855 }
2856 return PA;
2857 }
2858 return PreservedAnalyses::all();
2859}
2860
2861 void LowerMatrixIntrinsicsPass::printPipeline(
2862 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2863 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2864 OS, MapClassName2PassName);
2865 OS << '<';
2866 if (Minimal)
2867 OS << "minimal";
2868 OS << '>';
2869}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
AMDGPU Register Bank Select
Rewrite undef for PHI
static const Function * getParent(const Value *V)
BitTracker BT
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
Hexagon Common GEP
hexagon Hexagon specific predictive commoning for HVX vectors
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iv users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
Live Register Matrix
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
#define DEBUG_TYPE
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define T
#define T1
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
const SmallVectorImpl< MachineOperand > & Cond
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition SROA.cpp:2605
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition SROA.cpp:2627
This file contains some templates that are useful if you are working with the STL at all.
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
static const int BlockSize
Definition TarWriter.cpp:33
This pass exposes codegen information to IR-level passes.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Value * RHS
Value * LHS
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
reverse_iterator rbegin()
Definition BasicBlock.h:475
InstListType::reverse_iterator reverse_iterator
Definition BasicBlock.h:172
reverse_iterator rend()
Definition BasicBlock.h:477
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Value * getArgOperand(unsigned i) const
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Instruction::CastOps getOpcode() const
Return the opcode of this CastInst.
Definition InstrTypes.h:610
static LLVM_ABI ConstantAggregateZero * get(Type *Ty)
LLVM_ABI DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Base class for scope-like contexts.
Subprogram description. Uses SubclassData1.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:167
iterator end()
Definition DenseMap.h:81
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
static constexpr ElementCount getFixed(ScalarTy MinVal)
Definition TypeSize.h:310
void setAllowContract(bool B=true)
Definition FMF.h:90
bool allowReassoc() const
Flag queries.
Definition FMF.h:64
bool allowContract() const
Definition FMF.h:69
unsigned getNumElements() const
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition Type.cpp:803
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition Function.h:244
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition Function.h:249
LLVM_ABI CallInst * CreateFAddReduce(Value *Acc, Value *Src)
Create a sequential vector fadd reduction intrinsic of the source vector.
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:2348
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2579
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition IRBuilder.h:1833
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2567
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition IRBuilder.h:1867
Value * CreateZExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a ZExt or Trunc from the integer value V to DestTy.
Definition IRBuilder.h:2103
CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, const AAMDNodes &AAInfo=AAMDNodes())
Create and insert a memcpy between the specified pointers.
Definition IRBuilder.h:687
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition IRBuilder.h:1613
LLVM_ABI Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)
Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...
Definition IRBuilder.h:611
Value * CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy, const Twine &Name="", MDNode *FPMathTag=nullptr, FMFSource FMFSource={})
Definition IRBuilder.h:2241
void setFastMathFlags(FastMathFlags NewFMF)
Set the fast-math flags to be used with generated fp-math operators.
Definition IRBuilder.h:345
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Definition IRBuilder.h:1926
LLVM_ABI Value * CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with 2 operands which is mangled on the first type.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition IRBuilder.h:2497
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
Definition IRBuilder.h:533
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition IRBuilder.h:1197
LLVM_ABI CallInst * CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with 1 operand which is mangled on its type.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition IRBuilder.h:1850
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition IRBuilder.h:2601
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1403
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Definition IRBuilder.h:2197
Value * CreateBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1708
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Definition IRBuilder.h:1886
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition IRBuilder.h:1651
Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1793
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1437
This provides a uniform API for creating instructions and inserting them into a basic block: either at the end of a BasicBlock, or at a specific iterator location in a block.
Definition IRBuilder.h:2788
LLVM_ABI void moveAfter(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos lives in, right after MovePos.
LLVM_ABI void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an operator which supports these flags.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports these flags.
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Align getAlign() const
Return the alignment of the access that is being performed.
TypeSize getValue() const
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrices LHS and RHS.
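A minimal sketch of emitting the matrix intrinsics that this pass lowers, assuming the MatrixBuilder helper declared in llvm/IR/MatrixBuilder.h; the shapes and the name emitAtB are illustrative assumptions.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MatrixBuilder.h"
llvm::Value *emitAtB(llvm::IRBuilder<> &Builder, llvm::Value *A,
                     llvm::Value *B) {
  llvm::MatrixBuilder MB(Builder);
  // A is a flattened 4x2 matrix; transposing it gives a 2x4 matrix.
  llvm::Value *AT = MB.CreateMatrixTranspose(A, /*Rows=*/4, /*Columns=*/2, "at");
  // Multiply the 2x4 transpose by the flattened 4x3 matrix B -> flattened 2x3.
  return MB.CreateMatrixMultiply(AT, B, /*LHSRows=*/2, /*LHSColumns=*/4,
                                 /*RHSColumns=*/3, "atb");
}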
static LLVM_ABI MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory referenced by the given instruction.
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
const Value * Ptr
The address of the start of the location.
static LLVM_ABI MemoryLocation getForArgument(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo *TLI)
Return a location representing a particular argument of a call.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
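A minimal sketch of building an induction-variable PHI and wiring up its incoming edges with addIncoming; the block arguments and names are illustrative.
#include "llvm/IR/IRBuilder.h"
llvm::PHINode *makeCounter(llvm::IRBuilder<> &Builder,
                           llvm::BasicBlock *Preheader, llvm::BasicBlock *Latch,
                           llvm::Value *NextIV) {
  // Builder is assumed to be positioned at the top of the loop header.
  llvm::PHINode *IV =
      Builder.CreatePHI(Builder.getInt64Ty(), /*NumReservedValues=*/2, "iv");
  IV->addIncoming(Builder.getInt64(0), Preheader); // initial value on entry
  IV->addIncoming(NextIV, Latch);                  // incremented value per trip
  return IV;
}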
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
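A minimal sketch of how a new-pass-manager function pass reports what it preserved; MyPass, doSomething, and the choice of LoopAnalysis are hypothetical and unrelated to this file's pass.
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/PassManager.h"
struct MyPass : llvm::PassInfoMixin<MyPass> {
  llvm::PreservedAnalyses run(llvm::Function &F,
                              llvm::FunctionAnalysisManager &AM) {
    bool Changed = doSomething(F); // hypothetical transformation
    if (!Changed)
      return llvm::PreservedAnalyses::all(); // nothing was invalidated
    llvm::PreservedAnalyses PA;
    PA.preserve<llvm::LoopAnalysis>(); // claim loop info is still valid
    return PA;
  }
  bool doSomething(llvm::Function &) { return false; } // placeholder
};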
size_type size() const
Determine the number of elements in the SetVector.
Definition SetVector.h:102
void insert_range(Range &&R)
Definition SetVector.h:175
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition SetVector.h:261
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:150
bool erase(PtrType Ptr)
Remove pointer from the set.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void resize(size_type N)
void push_back(const T &Elt)
Align getAlign() const
bool isVolatile() const
Return true if this is a store to a volatile memory location.
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition StringRef.h:55
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition StringRef.h:611
constexpr size_t size() const
size - Get the string size.
Definition StringRef.h:146
Analysis pass providing the TargetTransformInfo.
@ TCK_RecipThroughput
Reciprocal throughput.
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition Type.h:352
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:198
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition Type.cpp:231
bool isVoidTy() const
Return true if this is 'void'.
Definition Type.h:139
UnaryOps getOpcode() const
Definition InstrTypes.h:154
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
user_iterator user_begin()
Definition Value.h:402
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:546
iterator_range< user_iterator > users()
Definition Value.h:426
bool use_empty() const
Definition Value.h:346
iterator_range< use_iterator > uses()
Definition Value.h:380
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
Type * getElementType() const
constexpr ScalarTy getFixedValue() const
Definition TypeSize.h:201
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:123
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
OneUse_match< SubPat > m_OneUse(const SubPat &SP)
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
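A minimal sketch of m_CombineOr matching either an integer or a floating-point multiply and binding its operands; isAnyMul is an illustrative helper, not an API from this file.
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"
bool isAnyMul(llvm::Value *V, llvm::Value *&L, llvm::Value *&R) {
  using namespace llvm::PatternMatch;
  // Succeeds for both 'mul' and 'fmul', binding the two operands.
  return match(V, m_CombineOr(m_Mul(m_Value(L), m_Value(R)),
                              m_FMul(m_Value(L), m_Value(R))));
}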
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to the ValuesClass constructor.
initializer< Ty > init(const Ty &Val)
ElementType
The element type of an SRV or UAV resource.
Definition DXILABI.h:60
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition DWP.cpp:477
void fill(R &&Range, T &&Value)
Provide wrappers to std::fill which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1655
detail::zippy< detail::zip_first, T, U, Args... > zip_equal(T &&t, U &&u, Args &&...args)
zip iterator that assumes that all iteratees have the same length.
Definition STLExtras.h:839
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition ScopeExit.h:59
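A minimal sketch of make_scope_exit for scope-bound cleanup; the flag and function name are illustrative.
#include "llvm/ADT/ScopeExit.h"
void processWithCleanup(bool &InFlight) {
  InFlight = true;
  // The callable runs when Cleanup goes out of scope, on every exit path.
  auto Cleanup = llvm::make_scope_exit([&] { InFlight = false; });
  // ... work that may return early ...
}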
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B, C, ...), where A is the 0-based index of the item and B, C, ... are the corresponding values from each input range.
Definition STLExtras.h:2472
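A minimal sketch of enumerate pairing values with a running index; printIndexed is an illustrative helper.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
void printIndexed(const llvm::SmallVectorImpl<int> &Vals) {
  for (const auto &En : llvm::enumerate(Vals))
    llvm::errs() << En.index() << " -> " << En.value() << "\n";
}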
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:644
auto successors(const MachineBasicBlock *BB)
bool operator!=(uint64_t V1, const APInt &V2)
Definition APInt.h:2113
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & operator+=(DynamicAPInt &A, int64_t B)
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
Definition STLExtras.h:632
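A minimal sketch of make_early_inc_range used to erase instructions during iteration; dropTriviallyDead is an illustrative helper and its deadness check is deliberately conservative.
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
void dropTriviallyDead(llvm::BasicBlock &BB) {
  // The early-increment range advances past I at the point of dereference,
  // so erasing I does not invalidate the loop.
  for (llvm::Instruction &I : llvm::make_early_inc_range(BB))
    if (I.use_empty() && !I.isTerminator() && !I.mayHaveSideEffects())
      I.eraseFromParent();
}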
LLVM_ABI Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
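A minimal sketch of concatenateVectors joining per-column vectors into one flat value; flattenColumns is illustrative and assumes at least two fixed-width vectors with the same element type.
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"
llvm::Value *flattenColumns(llvm::IRBuilder<> &Builder,
                            llvm::ArrayRef<llvm::Value *> Columns) {
  // The result is one wide vector containing the inputs in order.
  return llvm::concatenateVectors(Builder, Columns);
}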
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
LLVM_ABI void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1732
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
LLVM_ABI Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, OnCuIndexOverflow OverflowOptValue)
Definition DWP.cpp:622
void sort(IteratorTy Start, IteratorTy End)
Definition STLExtras.h:1622
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:167
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference sizeof(SmallVector<T, 0>).
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
Definition Casting.h:548
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
TargetTransformInfo TTI
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
@ Mul
Product of integers.
@ Add
Sum of integers.
DWARFExpression::Operation Op
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
OutputIt copy(R &&Range, OutputIt Out)
Definition STLExtras.h:1835
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:560
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition Alignment.h:201
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)
This method strips off any GEP address adjustments, pointer casts, or llvm.threadlocal.address calls from the specified value, returning the original object being addressed.
AAResults AliasAnalysis
Temporary typedef for legacy code that uses a generic AliasAnalysis pointer or reference.
LLVM_ABI llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
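A minimal sketch of createSequentialMask combined with CreateShuffleVector to extract one column from a flattened column-major matrix; extractColumn is an illustrative helper, not this pass's implementation.
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"
llvm::Value *extractColumn(llvm::IRBuilder<> &Builder, llvm::Value *FlatMatrix,
                           unsigned NumRows, unsigned Col) {
  // For a column-major layout, column Col occupies elements
  // [Col*NumRows, Col*NumRows + NumRows) of the flat vector.
  llvm::SmallVector<int, 16> Mask = llvm::createSequentialMask(
      /*Start=*/Col * NumRows, /*NumInts=*/NumRows, /*NumUndefs=*/0);
  return Builder.CreateShuffleVector(FlatMatrix, Mask, "col");
}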
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:869
#define N
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition PassManager.h:70