Bug Summary

File: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Warning: line 1030, column 44
Called C++ object pointer is null
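
The path below enters ExprLinearizer::writeFnName with a call whose callee name
starts with "llvm.matrix" but which, per the analyzer's assumption at step 14, is
not an IntrinsicInst. The dyn_cast at line 1029 therefore yields a null II, which
is dereferenced at line 1030. A minimal sketch of a defensive guard (illustrative
only, not the upstream fix; the bail-out text is invented):

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
    if (!II) {
      // Callee is named llvm.matrix.* but is not an intrinsic call; bail out
      // instead of dereferencing the null dyn_cast result.
      write("<non-intrinsic llvm.matrix call>");
      return;
    }
    write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
              .drop_front(StringRef("llvm.matrix.").size()));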

Annotated Source Code

clang -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -disable-llvm-verifier -discard-value-names -main-file-name LowerMatrixIntrinsics.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mthread-model posix -mframe-pointer=none -fmath-errno -fno-rounding-math -masm-verbose -mconstructor-aliases -munwind-tables -target-cpu x86-64 -dwarf-column-info -fno-split-dwarf-inlining -debugger-tuning=gdb -ffunction-sections -fdata-sections -resource-dir /usr/lib/llvm-11/lib/clang/11.0.0 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I /build/llvm-toolchain-snapshot-11~++20200309111110+2c36c23f347/build-llvm/lib/Transforms/Scalar -I /build/llvm-toolchain-snapshot-11~++20200309111110+2c36c23f347/llvm/lib/Transforms/Scalar -I /build/llvm-toolchain-snapshot-11~++20200309111110+2c36c23f347/build-llvm/include -I /build/llvm-toolchain-snapshot-11~++20200309111110+2c36c23f347/llvm/include -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/6.3.0/../../../../include/c++/6.3.0 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/6.3.0/../../../../include/x86_64-linux-gnu/c++/6.3.0 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/6.3.0/../../../../include/x86_64-linux-gnu/c++/6.3.0 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/6.3.0/../../../../include/c++/6.3.0/backward -internal-isystem /usr/local/include -internal-isystem /usr/lib/llvm-11/lib/clang/11.0.0/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -O2 -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir /build/llvm-toolchain-snapshot-11~++20200309111110+2c36c23f347/build-llvm/lib/Transforms/Scalar -fdebug-prefix-map=/build/llvm-toolchain-snapshot-11~++20200309111110+2c36c23f347=. -ferror-limit 19 -fmessage-length 0 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fobjc-runtime=gcc -fdiagnostics-show-option -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -o /tmp/scan-build-2020-03-09-184146-41876-1 -x c++ /build/llvm-toolchain-snapshot-11~++20200309111110+2c36c23f347/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Implement multiply & add fusion
13// * Add remark, summarizing the available matrix optimization opportunities
14// (WIP).
15//
16//===----------------------------------------------------------------------===//
17
18#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
19#include "llvm/ADT/GraphTraits.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Analysis/OptimizationRemarkEmitter.h"
23#include "llvm/Analysis/TargetTransformInfo.h"
24#include "llvm/Analysis/ValueTracking.h"
25#include "llvm/Analysis/VectorUtils.h"
26#include "llvm/IR/CFG.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/Function.h"
29#include "llvm/IR/IRBuilder.h"
30#include "llvm/IR/Instructions.h"
31#include "llvm/IR/IntrinsicInst.h"
32#include "llvm/IR/PatternMatch.h"
33#include "llvm/InitializePasses.h"
34#include "llvm/Pass.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Transforms/Scalar.h"
37
38using namespace llvm;
39using namespace PatternMatch;
40
41#define DEBUG_TYPE "lower-matrix-intrinsics"
42
43static cl::opt<bool> EnableShapePropagation(
44 "matrix-propagate-shape", cl::init(true), cl::Hidden,
45 cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
46 "instructions."));
47
48static cl::opt<bool> AllowContractEnabled(
49 "matrix-allow-contract", cl::init(false), cl::Hidden,
50 cl::desc("Allow the use of FMAs if available and profitable. This may "
51 "result in different results, due to less rounding error."));
52
53namespace {
54
55// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
56// the start address of column \p Col with type (\p EltType x \p NumRows)
57// assuming \p Stride elements between the starts of two consecutive columns.
58// \p Stride must be >= \p NumRows.
59//
60// Consider a 4x4 matrix like below
61//
62// 0 1 2 3
63// 0 v_0_0 v_0_1 v_0_2 v_0_3
64// 1 v_1_0 v_1_1 v_1_2 v_1_3
65// 2 v_2_0 v_2_1 v_2_2 v_2_3
66// 3 v_3_0 v_3_1 v_3_2 v_3_3
67
68// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
69// we need a pointer to the first element of the submatrix as base pointer.
70// Then we can use computeColumnAddr to compute the addresses for the columns
71// of the sub-matrix.
72//
73// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
74// -> just returns Base
75// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
76// -> returns Base + (1 * 4)
77// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
78// -> returns Base + (2 * 4)
79//
80// The graphic below illustrates the number of elements in a column (marked
81// with |) and the number of skipped elements (marked with {).
82//
83// v_0_0 v_0_1 {v_0_2 {v_0_3
84// Base Col 1 Col 2
85// | | |
86// v_1_0 |v_1_1 |v_1_2 |v_1_3
87// v_2_0 |v_2_1 |v_2_2 |v_2_3
88// v_3_0 {v_3_1 {v_3_2 v_3_3
89//
90Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
91 unsigned NumRows, Type *EltType,
92 IRBuilder<> &Builder) {
93
94 assert((!isa<ConstantInt>(Stride) ||
95 cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
96 "Stride must be >= the number of rows.");
97 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
98
99 // Compute the start of the column with index Col as Col * Stride.
100 Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");
101
102 // Get pointer to the start of the selected column. Skip GEP creation,
103 // if we select column 0.
104 if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
105 ColumnStart = BasePtr;
106 else
107 ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");
108
109 // Cast elementwise column start pointer to a pointer to a column
110 // (EltType x NumRows)*.
111 Type *ColumnType = VectorType::get(EltType, NumRows);
112 Type *ColumnPtrType = PointerType::get(ColumnType, AS);
113 return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
114}
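
For reference, the address arithmetic computeColumnAddr emits as IR can be
restated in plain C++ (columnStartIndex is a hypothetical helper, for
illustration only):

    unsigned columnStartIndex(unsigned Col, unsigned Stride) {
      return Col * Stride; // element offset from BasePtr to the start of Col
    }
    // For the 2x3 sub-matrix example in the comment above (Stride = 4):
    // columnStartIndex(0, 4) == 0 -> just BasePtr
    // columnStartIndex(1, 4) == 4 -> BasePtr + (1 * 4) elements
    // columnStartIndex(2, 4) == 8 -> BasePtr + (2 * 4) elements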
115
116/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
117///
118/// Currently, the lowering for each matrix intrinsic is done as follows:
119/// 1. Propagate the shape information from intrinsics to connected
120/// instructions.
121/// 2. Lower instructions with shape information.
122/// 2.1. Get column vectors for each argument. If we already lowered the
123/// definition of an argument, use the produced column vectors directly.
124/// If not, split the operand vector containing an embedded matrix into
125/// a set of column vectors.
126/// 2.2. Lower the instruction in terms of columnwise operations, which yields
127/// a set of column vectors containing the result matrix. Note that we lower
128/// all instructions that have shape information. Besides the intrinsics,
129/// this includes stores for example.
130/// 2.3. Update uses of the lowered instruction. If we have shape information
131/// for a user, there is nothing to do, as we will look up the result
132/// column matrix when lowering the user. For other uses, we embed the
133/// result matrix in a flat vector and update the use.
134/// 2.4. Cache the result column matrix for the instruction we lowered.
135/// 3. After we lowered all instructions in a function, remove the now
136/// obsolete instructions.
137///
138class LowerMatrixIntrinsics {
139 Function &Func;
140 const DataLayout &DL;
141 const TargetTransformInfo &TTI;
142 OptimizationRemarkEmitter &ORE;
143
144 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
145 struct OpInfoTy {
146 /// Number of stores emitted to generate this matrix.
147 unsigned NumStores = 0;
148 /// Number of loads emitted to generate this matrix.
149 unsigned NumLoads = 0;
150 /// Number of compute operations emitted to generate this matrix.
151 unsigned NumComputeOps = 0;
152
153 OpInfoTy &operator+=(const OpInfoTy &RHS) {
154 NumStores += RHS.NumStores;
155 NumLoads += RHS.NumLoads;
156 NumComputeOps += RHS.NumComputeOps;
157 return *this;
158 }
159 };
160
161 /// Wrapper class representing a matrix as a set of column vectors.
162 /// All column vectors must have the same vector type.
163 class ColumnMatrixTy {
164 SmallVector<Value *, 16> Columns;
165
166 OpInfoTy OpInfo;
167
168 public:
169 ColumnMatrixTy() : Columns() {}
170 ColumnMatrixTy(ArrayRef<Value *> Cols)
171 : Columns(Cols.begin(), Cols.end()) {}
172
173 Value *getColumn(unsigned i) const { return Columns[i]; }
174
175 void setColumn(unsigned i, Value *V) { Columns[i] = V; }
176
177 size_t getNumColumns() const { return Columns.size(); }
178 size_t getNumRows() const {
179 assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
180 return cast<VectorType>(Columns[0]->getType())->getNumElements();
181 }
182
183 const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }
184
185 SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }
186
187 void addColumn(Value *V) { Columns.push_back(V); }
188
189 VectorType *getColumnTy() {
190 return cast<VectorType>(Columns[0]->getType());
191 }
192
193 iterator_range<SmallVector<Value *, 8>::iterator> columns() {
194 return make_range(Columns.begin(), Columns.end());
195 }
196
197 /// Embed the columns of the matrix into a flat vector by concatenating
198 /// them.
199 Value *embedInVector(IRBuilder<> &Builder) const {
200 return Columns.size() == 1 ? Columns[0]
201 : concatenateVectors(Builder, Columns);
202 }
203
204 ColumnMatrixTy &addNumLoads(unsigned N) {
205 OpInfo.NumLoads += N;
206 return *this;
207 }
208
209 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
210
211 ColumnMatrixTy &addNumStores(unsigned N) {
212 OpInfo.NumStores += N;
213 return *this;
214 }
215
216 ColumnMatrixTy &addNumComputeOps(unsigned N) {
217 OpInfo.NumComputeOps += N;
218 return *this;
219 }
220
221 unsigned getNumStores() const { return OpInfo.NumStores; }
222 unsigned getNumLoads() const { return OpInfo.NumLoads; }
223 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
224
225 const OpInfoTy &getOpInfo() const { return OpInfo; }
226 };
227
228 struct ShapeInfo {
229 unsigned NumRows;
230 unsigned NumColumns;
231
232 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
233 : NumRows(NumRows), NumColumns(NumColumns) {}
234
235 ShapeInfo(Value *NumRows, Value *NumColumns)
236 : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
237 NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}
238
239 bool operator==(const ShapeInfo &other) {
240 return NumRows == other.NumRows && NumColumns == other.NumColumns;
241 }
242 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
243
244 /// Returns true if shape-information is defined, meaning both dimensions
245 /// are != 0.
246 operator bool() const {
247 assert(NumRows == 0 || NumColumns != 0);
248 return NumRows != 0;
249 }
250 };
251
252 /// Maps instructions to their shape information. The shape information
253 /// describes the shape to be used while lowering. This matches the shape of
254 /// the result value of the instruction, with the only exceptions being store
255 /// instructions and the matrix_columnwise_store intrinsics. For those, the
256 /// shape information describes the matrix value being stored, and marks them
257 /// for lowering using shape information as well.
258 DenseMap<Value *, ShapeInfo> ShapeMap;
259
260 /// List of instructions to remove. While lowering, we do not replace all
261 /// users of a lowered instruction if shape information is available; the
262 /// obsolete instructions need to be removed after we finish lowering.
263 SmallVector<Instruction *, 16> ToRemove;
264
265 /// Map from instructions to their produced column matrix.
266 MapVector<Value *, ColumnMatrixTy> Inst2ColumnMatrix;
267
268public:
269 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
270 OptimizationRemarkEmitter &ORE)
271 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), ORE(ORE) {}
272
273 unsigned getNumOps(Type *VT) {
274 assert(isa<VectorType>(VT) && "Expected vector type");
275 return getNumOps(VT->getScalarType(),
276 cast<VectorType>(VT)->getNumElements());
277 }
278
279 //
280 /// Return the estimated number of vector ops required for an operation on
281 /// \p VT * N.
282 unsigned getNumOps(Type *ST, unsigned N) {
283 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
284 double(TTI.getRegisterBitWidth(true)));
285 }
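
To make the estimate concrete, the formula can be restated as scalar C++; the
128-bit register width below is an assumed value for illustration:

    unsigned estimateOps(unsigned EltBits, unsigned N, unsigned RegBits) {
      return (EltBits * N + RegBits - 1) / RegBits; // ceil(EltBits*N / RegBits)
    }
    // estimateOps(64, 4, 128) == 2: one operation on a <4 x double> column
    // counts as two ops when split across 128-bit registers.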
286
287 /// Return the set of column vectors that a matrix value is lowered to.
288 ///
289 /// If we lowered \p MatrixVal, just return the cached result column matrix.
290 /// Otherwise split the flat vector \p MatrixVal containing a matrix with
291 /// shape \p SI into column vectors.
292 ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
293 IRBuilder<> &Builder) {
294 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
295 assert(VType && "MatrixVal must be a vector type");
296 assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
297 "The vector size must match the number of matrix elements");
298
299 // Check if we lowered MatrixVal using shape information. In that case,
300 // return the existing column matrix, if it matches the requested shape
301 // information. If there is a mis-match, embed the result in a flat
302 // vector and split it later.
303 auto Found = Inst2ColumnMatrix.find(MatrixVal);
304 if (Found != Inst2ColumnMatrix.end()) {
305 ColumnMatrixTy &M = Found->second;
306 // Return the found matrix, if its shape matches the requested shape
307 // information
308 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
309 return M;
310
311 MatrixVal = M.embedInVector(Builder);
312 }
313
314 // Otherwise split MatrixVal.
315 SmallVector<Value *, 16> SplitVecs;
316 Value *Undef = UndefValue::get(VType);
317 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
318 MaskStart += SI.NumRows) {
319 Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
320 Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
321 SplitVecs.push_back(V);
322 }
323
324 return {SplitVecs};
325 }
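
The split loop above cuts a flat, column-major vector into chunks of SI.NumRows
elements via sequential shuffle masks. A free-standing scalar sketch of the same
slicing (illustrative only; the pass emits shufflevector instructions instead):

    #include <vector>

    std::vector<std::vector<double>> splitColumns(const std::vector<double> &Flat,
                                                  unsigned R, unsigned C) {
      std::vector<std::vector<double>> Cols;
      // Chunk J covers the flat elements [J*R, J*R + R), i.e. column J.
      for (unsigned J = 0; J < C; ++J)
        Cols.emplace_back(Flat.begin() + J * R, Flat.begin() + (J + 1) * R);
      return Cols;
    }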
326
327 /// If \p V already has a known shape return false. Otherwise set the shape
328 /// for instructions that support it.
329 bool setShapeInfo(Value *V, ShapeInfo Shape) {
330 assert(Shape && "Shape not set");
331 if (isa<UndefValue>(V) || !supportsShapeInfo(V))
332 return false;
333
334 auto SIter = ShapeMap.find(V);
335 if (SIter != ShapeMap.end()) {
336 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
337 << SIter->second.NumRows << " "
338 << SIter->second.NumColumns << " for " << *V << "\n");
339 return false;
340 }
341
342 ShapeMap.insert({V, Shape});
343 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
344 << " for " << *V << "\n");
345 return true;
346 }
347
348 bool isUniformShape(Value *V) {
349 Instruction *I = dyn_cast<Instruction>(V);
350 if (!I)
351 return true;
352
353 switch (I->getOpcode()) {
354 case Instruction::FAdd:
355 case Instruction::FSub:
356 case Instruction::FMul: // Scalar multiply.
357 case Instruction::Add:
358 case Instruction::Mul:
359 case Instruction::Sub:
360 return true;
361 default:
362 return false;
363 }
364 }
365
366 /// Returns true if shape information can be used for \p V. The supported
367 /// instructions must match the instructions that can be lowered by this pass.
368 bool supportsShapeInfo(Value *V) {
369 Instruction *Inst = dyn_cast<Instruction>(V);
370 if (!Inst)
371 return false;
372
373 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
374 if (II)
375 switch (II->getIntrinsicID()) {
376 case Intrinsic::matrix_multiply:
377 case Intrinsic::matrix_transpose:
378 case Intrinsic::matrix_columnwise_load:
379 case Intrinsic::matrix_columnwise_store:
380 return true;
381 default:
382 return false;
383 }
384 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
385 }
386
387 /// Propagate the shape information of instructions to their users.
388 /// The work list contains instructions for which we can compute the shape,
389 /// either based on the information provided by matrix intrinsics or known
390 /// shapes of operands.
391 SmallVector<Instruction *, 32>
392 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
393 SmallVector<Instruction *, 32> NewWorkList;
394 // Pop an element for which we are guaranteed to have at least one of the
395 // operand shapes. Add the shape for this and then add users to the work
396 // list.
397 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
398 while (!WorkList.empty()) {
399 Instruction *Inst = WorkList.back();
400 WorkList.pop_back();
401
402 // New entry, set the value and insert operands
403 bool Propagate = false;
404
405 Value *MatrixA;
406 Value *MatrixB;
407 Value *M;
408 Value *N;
409 Value *K;
410 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
411 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
412 m_Value(N), m_Value(K)))) {
413 Propagate = setShapeInfo(Inst, {M, K});
414 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
415 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
416 // Flip dimensions.
417 Propagate = setShapeInfo(Inst, {N, M});
418 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
419 m_Value(MatrixA), m_Value(), m_Value(),
420 m_Value(M), m_Value(N)))) {
421 Propagate = setShapeInfo(Inst, {N, M});
422 } else if (match(Inst,
423 m_Intrinsic<Intrinsic::matrix_columnwise_load>(
424 m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
425 Propagate = setShapeInfo(Inst, {M, N});
426 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
427 auto OpShape = ShapeMap.find(MatrixA);
428 if (OpShape != ShapeMap.end())
429 setShapeInfo(Inst, OpShape->second);
430 continue;
431 } else if (isUniformShape(Inst)) {
432 // Find the first operand that has a known shape and use that.
433 for (auto &Op : Inst->operands()) {
434 auto OpShape = ShapeMap.find(Op.get());
435 if (OpShape != ShapeMap.end()) {
436 Propagate |= setShapeInfo(Inst, OpShape->second);
437 break;
438 }
439 }
440 }
441
442 if (Propagate) {
443 NewWorkList.push_back(Inst);
444 for (auto *User : Inst->users())
445 if (ShapeMap.count(User) == 0)
446 WorkList.push_back(cast<Instruction>(User));
447 }
448 }
449
450 return NewWorkList;
451 }
452
453 /// Propagate the shape to operands of instructions with shape information.
454 /// \p Worklist contains the instructions for which we already know the shape.
455 SmallVector<Instruction *, 32>
456 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
457 SmallVector<Instruction *, 32> NewWorkList;
458
459 auto pushInstruction = [](Value *V,
460 SmallVectorImpl<Instruction *> &WorkList) {
461 Instruction *I = dyn_cast<Instruction>(V);
462 if (I)
463 WorkList.push_back(I);
464 };
465 // Pop an element with a known shape. Traverse its operands; if an operand's
466 // shape derives from the result shape and is still unknown, set it and add
467 // the operand to the worklist.
468 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
469 while (!WorkList.empty()) {
470 Value *V = WorkList.back();
471 WorkList.pop_back();
472
473 size_t BeforeProcessingV = WorkList.size();
474 if (!isa<Instruction>(V))
475 continue;
476
477 Value *MatrixA;
478 Value *MatrixB;
479 Value *M;
480 Value *N;
481 Value *K;
482 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
483 m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
484 m_Value(N), m_Value(K)))) {
485 if (setShapeInfo(MatrixA, {M, N}))
486 pushInstruction(MatrixA, WorkList);
487
488 if (setShapeInfo(MatrixB, {N, K}))
489 pushInstruction(MatrixB, WorkList);
490
491 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
492 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
493 // Flip dimensions.
494 if (setShapeInfo(MatrixA, {M, N}))
495 pushInstruction(MatrixA, WorkList);
496 } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
497 m_Value(MatrixA), m_Value(), m_Value(),
498 m_Value(M), m_Value(N)))) {
499 if (setShapeInfo(MatrixA, {M, N})) {
500 pushInstruction(MatrixA, WorkList);
501 }
502 } else if (isa<LoadInst>(V) ||
503 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
504 // Nothing to do, no matrix input.
505 } else if (isa<StoreInst>(V)) {
506 // Nothing to do. We forward-propagated to this so we would just
507 // backward propagate to an instruction with an already known shape.
508 } else if (isUniformShape(V)) {
509 // Propagate to all operands.
510 ShapeInfo Shape = ShapeMap[V];
511 for (Use &U : cast<Instruction>(V)->operands()) {
512 if (setShapeInfo(U.get(), Shape))
513 pushInstruction(U.get(), WorkList);
514 }
515 }
516 // After we discovered new shape info for new instructions in the
517 // worklist, we use their users as seeds for the next round of forward
518 // propagation.
519 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
520 for (User *U : WorkList[I]->users())
521 if (isa<Instruction>(U) && V != U)
522 NewWorkList.push_back(cast<Instruction>(U));
523 }
524 return NewWorkList;
525 }
526
527 bool Visit() {
528 if (EnableShapePropagation) {
529 SmallVector<Instruction *, 32> WorkList;
530
531 // Initially only the shape of matrix intrinsics is known.
532 // Initialize the work list with ops carrying shape information.
533 for (BasicBlock &BB : Func)
534 for (Instruction &Inst : BB) {
535 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
536 if (!II)
537 continue;
538
539 switch (II->getIntrinsicID()) {
540 case Intrinsic::matrix_multiply:
541 case Intrinsic::matrix_transpose:
542 case Intrinsic::matrix_columnwise_load:
543 case Intrinsic::matrix_columnwise_store:
544 WorkList.push_back(&Inst);
545 break;
546 default:
547 break;
548 }
549 }
550 // Propagate shapes until nothing changes any longer.
551 while (!WorkList.empty()) {
552 WorkList = propagateShapeForward(WorkList);
553 WorkList = propagateShapeBackward(WorkList);
554 }
555 }
556
557 ReversePostOrderTraversal<Function *> RPOT(&Func);
558 bool Changed = false;
559 for (auto *BB : RPOT) {
560 for (Instruction &Inst : make_early_inc_range(*BB)) {
561 IRBuilder<> Builder(&Inst);
562
563 if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
564 Changed |= VisitCallInst(CInst);
565
566 Value *Op1;
567 Value *Op2;
568 if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
569 Changed |= VisitBinaryOperator(BinOp);
570 if (match(&Inst, m_Load(m_Value(Op1))))
571 Changed |= VisitLoad(&Inst, Op1, Builder);
572 else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
573 Changed |= VisitStore(&Inst, Op1, Op2, Builder);
574 }
575 }
576
577 RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, DL);
578 RemarkGen.emitRemarks();
579
580 for (Instruction *Inst : reverse(ToRemove))
581 Inst->eraseFromParent();
582
583 return Changed;
584 }
585
586 LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
587 IRBuilder<> &Builder) {
588 return Builder.CreateAlignedLoad(
589 ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load");
590 }
591
592 StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
593 Type *EltType, IRBuilder<> &Builder) {
594 return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
595 DL.getABITypeAlign(EltType));
596 }
597
598
599 /// Turns \p BasePtr into an elementwise pointer to \p EltType.
600 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
601 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
602 Type *EltPtrType = PointerType::get(EltType, AS);
603 return Builder.CreatePointerCast(BasePtr, EltPtrType);
604 }
605
606 /// Replace intrinsic calls
607 bool VisitCallInst(CallInst *Inst) {
608 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
609 return false;
610
611 switch (Inst->getCalledFunction()->getIntrinsicID()) {
612 case Intrinsic::matrix_multiply:
613 LowerMultiply(Inst);
614 break;
615 case Intrinsic::matrix_transpose:
616 LowerTranspose(Inst);
617 break;
618 case Intrinsic::matrix_columnwise_load:
619 LowerColumnwiseLoad(Inst);
620 break;
621 case Intrinsic::matrix_columnwise_store:
622 LowerColumnwiseStore(Inst);
623 break;
624 default:
625 return false;
626 }
627 return true;
628 }
629
630 void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
631 ShapeInfo Shape) {
632 IRBuilder<> Builder(Inst);
633 auto VType = cast<VectorType>(Inst->getType());
634 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
635 ColumnMatrixTy Result;
636 // Stride is the distance between the start of one column and the next.
637 for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
638 Value *GEP =
639 computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
640 VType->getElementType(), Builder);
641 Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
642 Result.addColumn(Column);
643 }
644
645 finalizeLowering(Inst,
646 Result.addNumLoads(getNumOps(Result.getColumnTy()) *
647 Result.getNumColumns()),
648 Builder);
649 }
650
651 /// Lowers llvm.matrix.columnwise.load.
652 ///
653 /// The intrinsic loads a matrix from memory using a stride between columns.
654 void LowerColumnwiseLoad(CallInst *Inst) {
655 Value *Ptr = Inst->getArgOperand(0);
656 Value *Stride = Inst->getArgOperand(1);
657 LowerLoad(Inst, Ptr, Stride,
658 {Inst->getArgOperand(2), Inst->getArgOperand(3)});
659 }
660
661 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
662 ShapeInfo Shape) {
663 IRBuilder<> Builder(Inst);
664 auto VType = cast<VectorType>(Matrix->getType());
665 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
666 auto LM = getMatrix(Matrix, Shape, Builder);
667 for (auto C : enumerate(LM.columns())) {
668 Value *GEP =
669 computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
670 Shape.NumRows, VType->getElementType(), Builder);
671 createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
672 }
673 Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
674 getNumOps(LM.getColumnTy()) * LM.getNumColumns());
675
676 ToRemove.push_back(Inst);
677 }
678
679 /// Lowers llvm.matrix.columnwise.store.
680 ///
681 /// The intrinsic stores a matrix back to memory using a stride between columns.
682 void LowerColumnwiseStore(CallInst *Inst) {
683 Value *Matrix = Inst->getArgOperand(0);
684 Value *Ptr = Inst->getArgOperand(1);
685 Value *Stride = Inst->getArgOperand(2);
686 LowerStore(Inst, Matrix, Ptr, Stride,
687 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
688 }
689
690 /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
691 /// the matrix \p LM represented as a vector of column vectors.
692 Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
693 unsigned NumElts, IRBuilder<> &Builder) {
694 Value *Col = LM.getColumn(J);
695 Value *Undef = UndefValue::get(Col->getType());
696 Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
697 return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
698 }
699
700 // Set elements I..I+NumElts-1 to Block
701 Value *insertVector(Value *Col, unsigned I, Value *Block,
702 IRBuilder<> &Builder) {
703
704 // First, bring Block to the same size as Col
705 unsigned BlockNumElts =
706 cast<VectorType>(Block->getType())->getNumElements();
707 unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
708 assert(NumElts >= BlockNumElts && "Too few elements for current block");
709
710 Value *ExtendMask =
711 createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
712 Value *Undef = UndefValue::get(Block->getType());
713 Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);
714
715 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
716 // 8, 4, 5, 6
717 SmallVector<Constant *, 16> Mask;
718 unsigned i;
719 for (i = 0; i < I; i++)
720 Mask.push_back(Builder.getInt32(i));
721
722 unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
723 for (; i < I + BlockNumElts; i++)
724 Mask.push_back(Builder.getInt32(i - I + VecNumElts));
725
726 for (; i < VecNumElts; i++)
727 Mask.push_back(Builder.getInt32(i));
728
729 Value *MaskVal = ConstantVector::get(Mask);
730
731 return Builder.CreateShuffleVector(Col, Block, MaskVal);
732 }
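
The shuffle mask assembled above can be reproduced with plain integers. A
free-standing sketch (illustrative only):

    #include <vector>

    std::vector<unsigned> insertMask(unsigned VecNumElts, unsigned I,
                                     unsigned BlockNumElts) {
      std::vector<unsigned> Mask;
      for (unsigned i = 0; i < I; i++)
        Mask.push_back(i);                  // keep the prefix from Col
      for (unsigned i = I; i < I + BlockNumElts; i++)
        Mask.push_back(i - I + VecNumElts); // select elements from Block
      for (unsigned i = I + BlockNumElts; i < VecNumElts; i++)
        Mask.push_back(i);                  // keep the suffix from Col
      return Mask;
    }
    // insertMask(7, 2, 2) yields {0, 1, 7, 8, 4, 5, 6}, matching the example
    // in the comment at lines 715-716.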
733
734 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
735 IRBuilder<> &Builder, bool AllowContraction,
736 unsigned &NumComputeOps) {
737 NumComputeOps += getNumOps(A->getType());
738 if (!Sum)
739 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
740
741 if (UseFPOp) {
742 if (AllowContraction) {
743 // Use fmuladd for floating point operations and let the backend decide
744 // if that's profitable.
745 Function *FMulAdd = Intrinsic::getDeclaration(
746 Func.getParent(), Intrinsic::fmuladd, A->getType());
747 return Builder.CreateCall(FMulAdd, {A, B, Sum});
748 }
749 NumComputeOps += getNumOps(A->getType());
750 Value *Mul = Builder.CreateFMul(A, B);
751 return Builder.CreateFAdd(Sum, Mul);
752 }
753
754 NumComputeOps += getNumOps(A->getType());
755 Value *Mul = Builder.CreateMul(A, B);
756 return Builder.CreateAdd(Sum, Mul);
757 }
758
759 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
760 /// users with shape information, there's nothing to do: they will use the
761 /// cached value when they are lowered. For other users, \p Matrix is
762 /// flattened and the uses are updated to use it. Also marks \p Inst for
763 /// deletion.
764 void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
765 IRBuilder<> &Builder) {
766 Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
767
768 ToRemove.push_back(Inst);
769 Value *Flattened = nullptr;
770 for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
771 Use &U = *I++;
772 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
773 if (!Flattened)
774 Flattened = Matrix.embedInVector(Builder);
775 U.set(Flattened);
776 }
777 }
778 }
779
780 /// Lowers llvm.matrix.multiply.
781 void LowerMultiply(CallInst *MatMul) {
782 IRBuilder<> Builder(MatMul);
783 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
784 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
785 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
786
787 const ColumnMatrixTy &Lhs =
788 getMatrix(MatMul->getArgOperand(0), LShape, Builder);
789 const ColumnMatrixTy &Rhs =
790 getMatrix(MatMul->getArgOperand(1), RShape, Builder);
791
792 const unsigned R = LShape.NumRows;
793 const unsigned M = LShape.NumColumns;
794 const unsigned C = RShape.NumColumns;
795 assert(M == RShape.NumRows);
796
797 // Initialize the output
798 ColumnMatrixTy Result;
799 for (unsigned J = 0; J < C; ++J)
800 Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
801
802 const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
803 EltType->getPrimitiveSizeInBits(),
804 uint64_t(1));
805
806 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
807 MatMul->hasAllowContract());
808 unsigned NumComputeOps = 0;
809 // Multiply columns from the first operand with scalars from the second
810 // operand. Then move along the K axis and accumulate the columns. With
811 // this the adds can be vectorized without reassociation.
812 for (unsigned J = 0; J < C; ++J) {
813 unsigned BlockSize = VF;
814 for (unsigned I = 0; I < R; I += BlockSize) {
815 // Gradually lower the vectorization factor to cover the remainder.
816 while (I + BlockSize > R)
817 BlockSize /= 2;
818
819 Value *Sum = nullptr;
820 for (unsigned K = 0; K < M; ++K) {
821 Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
822 Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
823 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
824 Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
825 Builder, AllowContract, NumComputeOps);
826 }
827 Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
828 }
829 }
830 Result.addNumComputeOps(NumComputeOps);
831 finalizeLowering(MatMul, Result, Builder);
832 }
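
Stripped of blocking and vectorization, the traversal above computes an ordinary
column-major matrix product. A scalar reference for comparison (illustrative
only; Lhs is R x M, Rhs is M x C, Res is R x C, all stored column-major):

    void referenceMultiply(const double *Lhs, const double *Rhs, double *Res,
                           unsigned R, unsigned M, unsigned C) {
      for (unsigned J = 0; J < C; ++J)
        for (unsigned I = 0; I < R; ++I) {
          double Sum = 0;
          for (unsigned K = 0; K < M; ++K)
            Sum += Lhs[K * R + I] * Rhs[J * M + K]; // element (I,K) * (K,J)
          Res[J * R + I] = Sum;
        }
    }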
833
834 /// Lowers llvm.matrix.transpose.
835 void LowerTranspose(CallInst *Inst) {
836 ColumnMatrixTy Result;
837 IRBuilder<> Builder(Inst);
838 Value *InputVal = Inst->getArgOperand(0);
839 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
840 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
841 ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
842
843 for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
844 // Build a single column vector for this row. First initialize it.
845 Value *ResultColumn = UndefValue::get(
846 VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
847
848 // Go through the elements of this row and insert them into the resulting
849 // column vector.
850 for (auto C : enumerate(InputMatrix.columns())) {
851 Value *Elt = Builder.CreateExtractElement(C.value(), Row);
852 // We insert at index Column since that is the row index after the
853 // transpose.
854 ResultColumn =
855 Builder.CreateInsertElement(ResultColumn, Elt, C.index());
856 }
857 Result.addColumn(ResultColumn);
858 }
859
860 // TODO: Improve estimate of operations needed for transposes. Currently we
861 // just count the insertelement/extractelement instructions, but do not
862 // account for later simplifications/combines.
863 finalizeLowering(
864 Inst,
865 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
866 Builder);
867 }
868
869 /// Lower load instructions, if shape information is available.
870 bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
871 auto I = ShapeMap.find(Inst);
872 if (I == ShapeMap.end())
873 return false;
874
875 LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
876 return true;
877 }
878
879 bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
880 IRBuilder<> &Builder) {
881 auto I = ShapeMap.find(StoredVal);
882 if (I == ShapeMap.end())
883 return false;
884
885 LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second);
886 return true;
887 }
888
889 /// Lower binary operators, if shape information is available.
890 bool VisitBinaryOperator(BinaryOperator *Inst) {
891 auto I = ShapeMap.find(Inst);
892 if (I == ShapeMap.end())
893 return false;
894
895 Value *Lhs = Inst->getOperand(0);
896 Value *Rhs = Inst->getOperand(1);
897
898 IRBuilder<> Builder(Inst);
899 ShapeInfo &Shape = I->second;
900
901 ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
902 ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
903
904 // Apply the operator to each pair of columns and collect the result columns.
905 ColumnMatrixTy Result;
906 auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
907 switch (Inst->getOpcode()) {
908 case Instruction::Add:
909 return Builder.CreateAdd(LHS, RHS);
910 case Instruction::Mul:
911 return Builder.CreateMul(LHS, RHS);
912 case Instruction::Sub:
913 return Builder.CreateSub(LHS, RHS);
914 case Instruction::FAdd:
915 return Builder.CreateFAdd(LHS, RHS);
916 case Instruction::FMul:
917 return Builder.CreateFMul(LHS, RHS);
918 case Instruction::FSub:
919 return Builder.CreateFSub(LHS, RHS);
920 default:
921 llvm_unreachable("Unsupported binary operator for matrix");
922 }
923 };
924 for (unsigned C = 0; C < Shape.NumColumns; ++C)
925 Result.addColumn(
926 BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
927
928 finalizeLowering(Inst,
929 Result.addNumComputeOps(getNumOps(Result.getColumnTy()) *
930 Result.getNumColumns()),
931 Builder);
932 return true;
933 }
934
935 /// Helper to linearize a matrix expression tree into a string. Currently
936 /// matrix expressions are linearized by starting at an expression leaf and
937 /// linearizing bottom up.
938 struct ExprLinearizer {
939 unsigned LengthToBreak = 100;
940 std::string Str;
941 raw_string_ostream Stream;
942 unsigned LineLength = 0;
943 const DataLayout &DL;
944
945 /// Mapping from instructions to column matrixes. It is used to identify
946 /// matrix instructions.
947 const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
948
949 /// Mapping from values to the leaves of all expressions that the value is
950 /// part of.
951 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
952
953 /// Leaf node of the expression to linearize.
954 Value *Leaf;
955
956 /// Used to keep track of sub-expressions that get reused while linearizing
957 /// the expression. Re-used sub-expressions are marked as (reused).
958 SmallPtrSet<Value *, 8> ReusedExprs;
959
960 ExprLinearizer(const DataLayout &DL,
961 const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
962 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
963 Value *Leaf)
964 : Str(), Stream(Str), DL(DL), Inst2ColumnMatrix(Inst2ColumnMatrix),
965 Shared(Shared), Leaf(Leaf) {}
966
967 void indent(unsigned N) {
968 LineLength += N;
969 for (unsigned i = 0; i < N; i++)
970 Stream << " ";
971 }
972
973 void lineBreak() {
974 Stream << "\n";
975 LineLength = 0;
976 }
977
978 void maybeIndent(unsigned Indent) {
979 if (LineLength >= LengthToBreak)
980 lineBreak();
981
982 if (LineLength == 0)
983 indent(Indent);
984 }
985
986 void write(StringRef S) {
987 LineLength += S.size();
988 Stream << S;
989 }
990
991 Value *getUnderlyingObjectThroughLoads(Value *V) {
992 if (Value *Ptr = getPointerOperand(V))
993 return getUnderlyingObjectThroughLoads(Ptr);
994 else if (V->getType()->isPointerTy())
995 return GetUnderlyingObject(V, DL);
996 return V;
997 }
998
999 /// Returns true if \p V is a matrix value.
1000 bool isMatrix(Value *V) const {
1001 return Inst2ColumnMatrix.find(V) != Inst2ColumnMatrix.end();
1002 }
1003
1004 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
1005 /// \p SS.
1006 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
1007 auto M = Inst2ColumnMatrix.find(V);
1008 if (M == Inst2ColumnMatrix.end())
1009 SS << "unknown";
1010 else {
1011 SS << M->second.getNumRows();
1012 SS << "x";
1013 SS << M->second.getNumColumns();
1014 }
1015 }
1016
1017 /// Write the called function name. Handles calls to llvm.matrix.*
1018 /// specially: we write the name, followed by the dimensions of the input
1019 /// matrixes, followed by the scalar type name.
1020 void writeFnName(CallInst *CI) {
1021 if (!CI->getCalledFunction())
11
Taking false branch
1022 write("<no called fn>");
1023 else {
1024 StringRef Name = CI->getCalledFunction()->getName();
1025 if (!Name.startswith("llvm.matrix")) {
12
Assuming the condition is false
13
Taking false branch
1026 write(Name);
1027 return;
1028 }
1029 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
14
Assuming 'CI' is not a 'IntrinsicInst'
15
'II' initialized to a null pointer value
1030 write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
16
Called C++ object pointer is null
1031 .drop_front(StringRef("llvm.matrix.").size()));
1032 write(".");
1033 std::string Tmp = "";
1034 raw_string_ostream SS(Tmp);
1035
1036 switch (II->getIntrinsicID()) {
1037 case Intrinsic::matrix_multiply:
1038 prettyPrintMatrixType(II->getOperand(0), SS);
1039 SS << ".";
1040 prettyPrintMatrixType(II->getOperand(1), SS);
1041 SS << "." << *II->getType()->getScalarType();
1042 break;
1043 case Intrinsic::matrix_transpose:
1044 prettyPrintMatrixType(II->getOperand(0), SS);
1045 SS << "." << *II->getType()->getScalarType();
1046 break;
1047 case Intrinsic::matrix_columnwise_load:
1048 prettyPrintMatrixType(II, SS);
1049 SS << "." << *II->getType()->getScalarType();
1050 break;
1051 case Intrinsic::matrix_columnwise_store:
1052 prettyPrintMatrixType(II->getOperand(0), SS);
1053 SS << "." << *II->getOperand(0)->getType()->getScalarType();
1054 break;
1055 default:
1056 llvm_unreachable("Unhandled case");
1057 }
1058 SS.flush();
1059 write(Tmp);
1060 }
1061 }
1062
1063 unsigned getNumShapeArgs(CallInst *CI) const {
1064 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
1065 switch (II->getIntrinsicID()) {
1066 case Intrinsic::matrix_multiply:
1067 return 3;
1068 case Intrinsic::matrix_transpose:
1069 case Intrinsic::matrix_columnwise_load:
1070 case Intrinsic::matrix_columnwise_store:
1071 return 2;
1072 default:
1073 return 0;
1074 }
1075 }
1076 return 0;
1077 }
1078
1079 /// Special printing for values: for pointers, we print whether they refer to
1080 /// an external (function) address or a stack address; for other values we
1081 /// either print the constant or "scalar"/"matrix".
1082 void write(Value *V) {
1083 V = getUnderlyingObjectThroughLoads(V);
1084 if (V->getType()->isPointerTy()) {
1085 if (isa<AllocaInst>(V)) {
1086 Stream << "stack addr";
1087 LineLength += StringRef("stack addr").size();
1088 } else {
1089 Stream << "addr";
1090 LineLength += StringRef("addr").size();
1091 }
1092 if (!V->getName().empty()) {
1093 Stream << " %" << V->getName() << "";
1094 LineLength += V->getName().size() + 2;
1095 }
1096 return;
1097 }
1098
1099 std::string Tmp;
1100 raw_string_ostream TmpStream(Tmp);
1101
1102 if (auto *CI = dyn_cast<ConstantInt>(V))
1103 TmpStream << CI->getValue();
1104 else if (isa<Constant>(V))
1105 TmpStream << "constant";
1106 else {
1107 if (isMatrix(V))
1108 TmpStream << "matrix";
1109 else
1110 TmpStream << "scalar";
1111 }
1112 TmpStream.flush();
1113 Tmp = std::string(StringRef(Tmp).trim());
1114 LineLength += Tmp.size();
1115 Stream << Tmp;
1116 }
1117
1118 /// Linearize expression \p Expr starting at an indentation of \p Indent.
1119 /// Expressions that are re-used multiple times are prefixed with (reused)
1120 /// at the re-used root instruction.
1121 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
1122 bool ParentShared) {
1123 auto *I = cast<Instruction>(Expr);
2
'Expr' is a 'Instruction'
1124 maybeIndent(Indent);
1125 SmallVector<Value *, 8> Ops;
1126
1127 // Is Expr shared with other expression leaves?
1128 bool ExprShared = false;
1129
1130 // Deal with shared subtrees. Mark them as shared, if required.
1131 if (!ParentShared) {
2.1
'ParentShared' is false
3
Taking true branch
1132 auto SI = Shared.find(Expr);
1133 assert(SI != Shared.end() && SI->second.find(Leaf) != SI->second.end());
4
Assuming the condition is true
5
'?' condition is true
1134
1135 for (Value *S : SI->second) {
1136 if (S == Leaf)
1137 continue;
1138 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
1139 write("shared with remark at line " + std::to_string(DL.getLine()) +
1140 " column " + std::to_string(DL.getCol()) + " (");
1141 }
1142 ExprShared = SI->second.size() > 1;
6
Assuming the condition is false
1143 }
1144
1145 bool Reused = !ReusedExprs.insert(Expr).second;
7
Assuming field 'second' is true
1146 if (Reused && !ParentReused)
7.1
'Reused' is false
1147 write("(reused) ");
1148
1149 if (auto *CI = dyn_cast<CallInst>(I)) {
8
Assuming 'I' is a 'CallInst'
8.1
'CI' is non-null
9
Taking true branch
1150 writeFnName(CI);
10
Calling 'ExprLinearizer::writeFnName'
1151
1152 Ops.append(CallSite(CI).arg_begin(),
1153 CallSite(CI).arg_end() - getNumShapeArgs(CI));
1154 } else if (isa<BitCastInst>(Expr)) {
1155 // Special case bitcasts, which are used to materialize matrixes from
1156 // non-matrix ops.
1157 write("matrix");
1158 return;
1159 } else {
1160 Ops.append(I->value_op_begin(), I->value_op_end());
1161 write(std::string(I->getOpcodeName()));
1162 }
1163
1164 write(std::string("("));
1165
1166 unsigned NumOpsToBreak = 1;
1167 if (match(Expr, m_Intrinsic<Intrinsic::matrix_columnwise_load>()))
1168 NumOpsToBreak = 2;
1169
1170 for (Value *Op : Ops) {
1171 if (Ops.size() > NumOpsToBreak)
1172 lineBreak();
1173
1174 maybeIndent(Indent + 1);
1175 if (isMatrix(Op))
1176 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
1177 else
1178 write(Op);
1179 if (Op != Ops.back())
1180 write(", ");
1181 }
1182
1183 write(")");
1184 }
1185
1186 const std::string &getResult() {
1187 Stream.flush();
1188 return Str;
1189 }
1190 };
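
For a sense of the output, a schematic of what the linearizer might produce for
a columnwise store of a multiply of two loaded 2x2 matrices (the shape of the
output follows the code above; names, dimensions, and line breaks are
illustrative, not verbatim tool output):

    columnwise.store.2x2.double(
     multiply.2x2.2x2.double(
      columnwise.load.2x2.double(addr %A, 2),
      columnwise.load.2x2.double(addr %B, 2)),
     addr %C,
     2)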
1191
1192 /// Generate remarks for matrix operations in a function. To generate remarks
1193 /// for matrix expressions, the following approach is used:
1194 /// 1. Collect the leaves of matrix expressions (done in
1195 /// RemarkGenerator::getExpressionLeaves). Leaves are lowered matrix
1196 /// instructions without other matrix users (like stores).
1197 ///
1198 /// 2. For each leaf, create a remark containing a linearized version of the
1199 /// matrix expression.
1200 ///
1201 /// TODO:
1202 /// * Summarize number of vector instructions generated for each expression.
1203 /// * Propagate matrix remarks up the inlining chain.
1204 struct RemarkGenerator {
1205 const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
1206 OptimizationRemarkEmitter &ORE;
1207 const DataLayout &DL;
1208
1209 RemarkGenerator(const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
1210 OptimizationRemarkEmitter &ORE, const DataLayout &DL)
1211 : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), DL(DL) {}
1212
1213 /// Return all leaves of matrix expressions. Those are instructions in
1214 /// Inst2ColumnMatrix returning void. Currently that should only include
1215 /// stores.
1216 SmallVector<Value *, 4> getExpressionLeaves() {
1217 SmallVector<Value *, 4> Leaves;
1218 for (auto &KV : Inst2ColumnMatrix)
1219 if (KV.first->getType()->isVoidTy())
1220 Leaves.push_back(KV.first);
1221
1222 return Leaves;
1223 }
1224
1225 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
1226 /// to all visited expressions in \p Shared.
1227 void collectSharedInfo(Value *Leaf, Value *V,
1228 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
1229
1230 if (Inst2ColumnMatrix.find(V) == Inst2ColumnMatrix.end())
1231 return;
1232
1233 auto I = Shared.insert({V, {}});
1234 I.first->second.insert(Leaf);
1235
1236 for (Value *Op : cast<Instruction>(V)->operand_values())
1237 collectSharedInfo(Leaf, Op, Shared);
1238 return;
1239 }
1240
1241 /// Calculate the exclusive and shared op counts for the expression
1242 /// starting at \p V. Expressions used multiple times are counted once.
1243 std::pair<OpInfoTy, OpInfoTy>
1244 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
1245 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
1246 auto CM = Inst2ColumnMatrix.find(Root);
1247 if (CM == Inst2ColumnMatrix.end())
1248 return {};
1249
1250 // Already counted this expression. Stop.
1251 if (!ReusedExprs.insert(Root).second)
1252 return {};
1253
1254 OpInfoTy SharedCount;
1255 OpInfoTy Count;
1256
1257 auto I = Shared.find(Root);
1258 if (I->second.size() == 1)
1259 Count = CM->second.getOpInfo();
1260 else
1261 SharedCount = CM->second.getOpInfo();
1262
1263 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
1264 auto C = sumOpInfos(Op, ReusedExprs, Shared);
1265 Count += C.first;
1266 SharedCount += C.second;
1267 }
1268 return {Count, SharedCount};
1269 }
1270
1271 void emitRemarks() {
1272 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
1273 return;
1274
1275 // Find leaves of matrix expressions.
1276 auto Leaves = getExpressionLeaves();
1277
1278 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
1279
1280 for (Value *Leaf : Leaves)
1281 collectSharedInfo(Leaf, Leaf, Shared);
1282
1283 // Generate remarks for each leaf.
1284 for (auto *L : Leaves) {
1285 SmallPtrSet<Value *, 8> ReusedExprs;
1286 OpInfoTy Counts, SharedCounts;
1287 std::tie(Counts, SharedCounts) = sumOpInfos(L, ReusedExprs, Shared);
1288
1289 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered",
1290 cast<Instruction>(L)->getDebugLoc(),
1291 cast<Instruction>(L)->getParent());
1292
1293 Rem << "Lowered with ";
1294 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
1295 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
1296 << ore::NV("NumComputeOps", Counts.NumComputeOps) << " compute ops";
1297
1298 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
1299 SharedCounts.NumComputeOps > 0) {
1300 Rem << ",\nadditionally "
1301 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
1302 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
1303 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
1304 << " compute ops"
1305 << " are shared with other expressions";
1306 }
1307
1308 Rem << ("\n" + linearize(L, Shared, DL));
1309 ORE.emit(Rem);
1310 }
1311 }
1312
1313 std::string
1314 linearize(Value *L,
1315 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
1316 const DataLayout &DL) {
1317 ExprLinearizer Lin(DL, Inst2ColumnMatrix, Shared, L);
1318 Lin.linearizeExpr(L, 0, false, false);
1
Calling 'ExprLinearizer::linearizeExpr'
1319 return Lin.getResult();
1320 }
1321 };
1322};
1323} // namespace
1324
1325PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
1326 FunctionAnalysisManager &AM) {
1327 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1328 auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
1329 LowerMatrixIntrinsics LMT(F, TTI, ORE);
1330 if (LMT.Visit()) {
1331 PreservedAnalyses PA;
1332 PA.preserveSet<CFGAnalyses>();
1333 return PA;
1334 }
1335 return PreservedAnalyses::all();
1336}
1337
1338namespace {
1339
1340class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
1341public:
1342 static char ID;
1343
1344 LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
1345 initializeLowerMatrixIntrinsicsLegacyPassPass(
1346 *PassRegistry::getPassRegistry());
1347 }
1348
1349 bool runOnFunction(Function &F) override {
1350 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1351 auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
1352 LowerMatrixIntrinsics LMT(F, TTI, ORE);
1353 bool C = LMT.Visit();
1354 return C;
1355 }
1356
1357 void getAnalysisUsage(AnalysisUsage &AU) const override {
1358 AU.addRequired<TargetTransformInfoWrapperPass>();
1359 AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
1360 AU.setPreservesCFG();
1361 }
1362};
1363} // namespace
1364
1365static const char pass_name[] = "Lower the matrix intrinsics";
1366char LowerMatrixIntrinsicsLegacyPass::ID = 0;
1367INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
1368 false, false)
1369INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
1370INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
1371 false, false)
1372
1373Pass *llvm::createLowerMatrixIntrinsicsPass() {
1374 return new LowerMatrixIntrinsicsLegacyPass();
1375}