LLVM 23.0.0git
MatrixBuilder.h
Go to the documentation of this file.
1//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MatrixBuilder class, which is used as a convenient way
10// to lower matrix operations to LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_IR_MATRIXBUILDER_H
15#define LLVM_IR_MATRIXBUILDER_H
16
17#include "llvm/IR/Constant.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/InstrTypes.h"
21#include "llvm/IR/Instruction.h"
23#include "llvm/IR/Type.h"
24#include "llvm/IR/Value.h"
26
27namespace llvm {
28
29class Function;
30class Twine;
31class Module;
32
35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36
37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38 Value *RHS) {
39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40 "One of the operands must be a matrix (embedded in a vector)");
41 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42 assert(!isa<ScalableVectorType>(LHS->getType()) &&
43 "LHS Assumed to be fixed width");
44 RHS = B.CreateVectorSplat(
45 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46 "scalar.splat");
47 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48 assert(!isa<ScalableVectorType>(RHS->getType()) &&
49 "RHS Assumed to be fixed width");
50 LHS = B.CreateVectorSplat(
51 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52 "scalar.splat");
53 }
54 return {LHS, RHS};
55 }
56
57public:
58 MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
59
60 /// Create a column major, strided matrix load.
61 /// \p EltTy - Matrix element type
62 /// \p DataPtr - Start address of the matrix read
63 /// \p Rows - Number of rows in matrix (must be a constant)
64 /// \p Columns - Number of columns in matrix (must be a constant)
65 /// \p Stride - Space between columns
66 CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
67 Value *Stride, bool IsVolatile, unsigned Rows,
68 unsigned Columns, const Twine &Name = "") {
69 auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
70
71 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
72 B.getInt32(Columns)};
73 Type *OverloadedTypes[] = {RetType, Stride->getType()};
74
76 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
77
78 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
79 Attribute AlignAttr =
80 Attribute::getWithAlignment(Call->getContext(), Alignment);
81 Call->addParamAttr(0, AlignAttr);
82 return Call;
83 }
84
85 /// Create a column major, strided matrix store.
86 /// \p Matrix - Matrix to store
87 /// \p Ptr - Pointer to write back to
88 /// \p Stride - Space between columns
90 Value *Stride, bool IsVolatile,
91 unsigned Rows, unsigned Columns,
92 const Twine &Name = "") {
93 Value *Ops[] = {Matrix, Ptr,
94 Stride, B.getInt1(IsVolatile),
95 B.getInt32(Rows), B.getInt32(Columns)};
96 Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
97
99 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
100
101 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
102 Attribute AlignAttr =
103 Attribute::getWithAlignment(Call->getContext(), Alignment);
104 Call->addParamAttr(1, AlignAttr);
105 return Call;
106 }
107
108 /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
109 /// rows and \p Columns columns.
111 unsigned Columns, const Twine &Name = "") {
112 auto *OpType = cast<VectorType>(Matrix->getType());
113 auto *ReturnType =
114 FixedVectorType::get(OpType->getElementType(), Rows * Columns);
115
116 Type *OverloadedTypes[] = {ReturnType};
117 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
119 getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
120
121 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
122 }
123
124 /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
125 /// RHS.
127 unsigned LHSColumns, unsigned RHSColumns,
128 const Twine &Name = "") {
129 auto *LHSType = cast<VectorType>(LHS->getType());
130 auto *RHSType = cast<VectorType>(RHS->getType());
131
132 auto *ReturnType =
133 FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
134
135 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
136 B.getInt32(RHSColumns)};
137 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
138
140 getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
141 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
142 }
143
144 /// Create a column-major matrix from a row-major matrix with the given
145 /// logical dimensions by transposing it.
146 /// Assumes the matrix transpose assumes column-major matrix memory layout,
147 /// which is true in the case of the DirectX and SPIRV backends, but not
148 /// necessarily true in the case of the LowerMatrixIntrinsics pass.
150 unsigned Columns,
151 const Twine &Name = "") {
152 return CreateMatrixTranspose(Matrix, Columns, Rows, Name);
153 }
154
155 /// Create a row-major matrix from a column-major matrix with the given
156 /// logical dimensions by transposing it.
157 /// Assumes the matrix transpose assumes column-major matrix memory layout,
158 /// which is true in the case of the DirectX and SPIRV backends, but not
159 /// necessarily true in the case of the LowerMatrixIntrinsics pass.
161 unsigned Columns,
162 const Twine &Name = "") {
163 return CreateMatrixTranspose(Matrix, Rows, Columns, Name);
164 }
165
166 /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
167 /// ColumnIdx).
169 Value *ColumnIdx, unsigned NumRows) {
170 return B.CreateInsertElement(
171 Matrix, NewVal,
172 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
173 ColumnIdx->getType(), NumRows)),
174 RowIdx));
175 }
176
177 /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
178 /// matrixes.
180 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
181 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
182 assert(!isa<ScalableVectorType>(LHS->getType()) &&
183 "LHS Assumed to be fixed width");
184 RHS = B.CreateVectorSplat(
185 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
186 "scalar.splat");
187 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
188 assert(!isa<ScalableVectorType>(RHS->getType()) &&
189 "RHS Assumed to be fixed width");
190 LHS = B.CreateVectorSplat(
191 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
192 "scalar.splat");
193 }
194
195 return cast<VectorType>(LHS->getType())
196 ->getElementType()
197 ->isFloatingPointTy()
198 ? B.CreateFAdd(LHS, RHS)
199 : B.CreateAdd(LHS, RHS);
200 }
201
202 /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
203 /// point matrixes.
205 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
206 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
207 assert(!isa<ScalableVectorType>(LHS->getType()) &&
208 "LHS Assumed to be fixed width");
209 RHS = B.CreateVectorSplat(
210 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
211 "scalar.splat");
212 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
213 assert(!isa<ScalableVectorType>(RHS->getType()) &&
214 "RHS Assumed to be fixed width");
215 LHS = B.CreateVectorSplat(
216 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
217 "scalar.splat");
218 }
219
220 return cast<VectorType>(LHS->getType())
221 ->getElementType()
222 ->isFloatingPointTy()
223 ? B.CreateFSub(LHS, RHS)
224 : B.CreateSub(LHS, RHS);
225 }
226
227 /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
228 /// RHS.
230 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
231 if (LHS->getType()->getScalarType()->isFloatingPointTy())
232 return B.CreateFMul(LHS, RHS);
233 return B.CreateMul(LHS, RHS);
234 }
235
236 /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
237 /// IsUnsigned indicates whether UDiv or SDiv should be used.
238 Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
239 assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
240 assert(!isa<ScalableVectorType>(LHS->getType()) &&
241 "LHS Assumed to be fixed width");
242 RHS =
243 B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
244 RHS, "scalar.splat");
245 return cast<VectorType>(LHS->getType())
246 ->getElementType()
247 ->isFloatingPointTy()
248 ? B.CreateFDiv(LHS, RHS)
249 : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
250 }
251
252 /// Create an assumption that \p Idx is less than \p NumElements.
253 void CreateIndexAssumption(Value *Idx, unsigned NumElements,
254 Twine const &Name = "") {
255 Value *NumElts =
256 B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
257 auto *Cmp = B.CreateICmpULT(Idx, NumElts);
258 if (isa<ConstantInt>(Cmp))
259 assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
260 else
261 B.CreateAssumption(Cmp);
262 }
263 /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
264 /// a matrix with \p NumRows or \p NumCols embedded in a vector depending
265 /// on matrix major ordering.
266 Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
267 unsigned NumCols, bool IsMatrixRowMajor = false,
268 Twine const &Name = "") {
269 unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
270 ColumnIdx->getType()->getScalarSizeInBits());
271 Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
272 RowIdx = B.CreateZExt(RowIdx, IntTy);
273 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
274 if (IsMatrixRowMajor) {
275 Value *NumColsV = B.getIntN(MaxWidth, NumCols);
276 return CreateRowMajorIndex(RowIdx, ColumnIdx, NumColsV, Name);
277 }
278 Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
279 return CreateColumnMajorIndex(RowIdx, ColumnIdx, NumRowsV, Name);
280 }
281
282private:
283 /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
284 /// a matrix with \p NumRows embedded in a vector.
285 Value *CreateColumnMajorIndex(Value *RowIdx, Value *ColumnIdx,
286 Value *NumRowsV, Twine const &Name) {
287 return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
288 }
289
290 /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
291 /// a matrix with \p NumCols embedded in a vector.
292 Value *CreateRowMajorIndex(Value *RowIdx, Value *ColumnIdx, Value *NumColsV,
293 Twine const &Name) {
294 return B.CreateAdd(B.CreateMul(RowIdx, NumColsV), ColumnIdx);
295 }
296};
297
298} // end namespace llvm
299
300#endif // LLVM_IR_MATRIXBUILDER_H
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
Live Register Matrix
Value * RHS
Value * LHS
Functions, function parameters, and return types can have attributes to indicate how they should be t...
Definition Attributes.h:105
static LLVM_ABI Attribute getWithAlignment(LLVMContext &Context, Align Alignment)
Return a uniquified Attribute object that has the specific alignment set.
This class represents a function call, abstracting a target machine's calling convention.
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition Type.cpp:873
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition Function.h:211
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:354
Value * CreateScalarMultiply(Value *LHS, Value *RHS)
Multiply matrix LHS with scalar RHS or scalar LHS with matrix RHS.
CallInst * CreateColumnMajorToRowMajorTransform(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a row-major matrix from a column-major matrix with the given logical dimensions by transposing it.
Value * CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned)
Divide matrix LHS by scalar RHS.
Value * CreateSub(Value *LHS, Value *RHS)
Subtract matrixes LHS and RHS.
MatrixBuilder(IRBuilderBase &Builder)
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix store.
Value * CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, unsigned NumCols, bool IsMatrixRowMajor=false, Twine const &Name="")
Compute the index to access the element at (RowIdx, ColumnIdx) from a matrix with NumRows or NumCols embedded in a vector depending on matrix major ordering.
CallInst * CreateRowMajorToColumnMajorTransform(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column-major matrix from a row-major matrix with the given logical dimensions by transposing it.
Value * CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows)
Insert a single element NewVal into Matrix at indices (RowIdx, ColumnIdx).
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
void CreateIndexAssumption(Value *Idx, unsigned NumElements, Twine const &Name="")
Create an assumption that Idx is less than NumElements.
CallInst * CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix load.
Value * CreateAdd(Value *LHS, Value *RHS)
Add matrixes LHS and RHS.
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition Twine.h:82
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition Type.h:130
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition Type.cpp:236
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
CallInst * Call
LLVM_ABI Function * getOrInsertDeclaration(Module *M, ID id, ArrayRef< Type * > Tys={})
Look up the Function declaration of the intrinsic id in the Module M.
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39