LLVM 20.0.0git
MatrixBuilder.h
Go to the documentation of this file.
//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the MatrixBuilder class, which is used as a convenient way
// to lower matrix operations to LLVM IR.
//
//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_IR_MATRIXBUILDER_H
15#define LLVM_IR_MATRIXBUILDER_H
16
17#include "llvm/IR/Constant.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/InstrTypes.h"
21#include "llvm/IR/Instruction.h"
23#include "llvm/IR/Type.h"
24#include "llvm/IR/Value.h"
26
27namespace llvm {
28
29class Function;
30class Twine;
31class Module;
32
35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36
37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38 Value *RHS) {
39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40 "One of the operands must be a matrix (embedded in a vector)");
41 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42 assert(!isa<ScalableVectorType>(LHS->getType()) &&
43 "LHS Assumed to be fixed width");
44 RHS = B.CreateVectorSplat(
45 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46 "scalar.splat");
47 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48 assert(!isa<ScalableVectorType>(RHS->getType()) &&
49 "RHS Assumed to be fixed width");
50 LHS = B.CreateVectorSplat(
51 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52 "scalar.splat");
53 }
54 return {LHS, RHS};
55 }
56
57public:
58 MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
59
60 /// Create a column major, strided matrix load.
61 /// \p EltTy - Matrix element type
62 /// \p DataPtr - Start address of the matrix read
63 /// \p Rows - Number of rows in matrix (must be a constant)
64 /// \p Columns - Number of columns in matrix (must be a constant)
65 /// \p Stride - Space between columns
66 CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
67 Value *Stride, bool IsVolatile, unsigned Rows,
68 unsigned Columns, const Twine &Name = "") {
69 auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
70
71 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
72 B.getInt32(Columns)};
73 Type *OverloadedTypes[] = {RetType, Stride->getType()};
74
76 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
77
78 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
79 Attribute AlignAttr =
80 Attribute::getWithAlignment(Call->getContext(), Alignment);
81 Call->addParamAttr(0, AlignAttr);
82 return Call;
83 }
84
85 /// Create a column major, strided matrix store.
86 /// \p Matrix - Matrix to store
87 /// \p Ptr - Pointer to write back to
88 /// \p Stride - Space between columns
90 Value *Stride, bool IsVolatile,
91 unsigned Rows, unsigned Columns,
92 const Twine &Name = "") {
93 Value *Ops[] = {Matrix, Ptr,
94 Stride, B.getInt1(IsVolatile),
95 B.getInt32(Rows), B.getInt32(Columns)};
96 Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
97
99 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
100
101 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
102 Attribute AlignAttr =
103 Attribute::getWithAlignment(Call->getContext(), Alignment);
104 Call->addParamAttr(1, AlignAttr);
105 return Call;
106 }
107
108 /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
109 /// rows and \p Columns columns.
111 unsigned Columns, const Twine &Name = "") {
112 auto *OpType = cast<VectorType>(Matrix->getType());
113 auto *ReturnType =
114 FixedVectorType::get(OpType->getElementType(), Rows * Columns);
115
116 Type *OverloadedTypes[] = {ReturnType};
117 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
119 getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
120
121 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
122 }
123
124 /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
125 /// RHS.
127 unsigned LHSColumns, unsigned RHSColumns,
128 const Twine &Name = "") {
129 auto *LHSType = cast<VectorType>(LHS->getType());
130 auto *RHSType = cast<VectorType>(RHS->getType());
131
132 auto *ReturnType =
133 FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
134
135 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
136 B.getInt32(RHSColumns)};
137 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
138
140 getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
141 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
142 }
143
144 /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
145 /// ColumnIdx).
147 Value *ColumnIdx, unsigned NumRows) {
148 return B.CreateInsertElement(
149 Matrix, NewVal,
150 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
151 ColumnIdx->getType(), NumRows)),
152 RowIdx));
153 }
154
155 /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
156 /// matrixes.
158 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
159 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
160 assert(!isa<ScalableVectorType>(LHS->getType()) &&
161 "LHS Assumed to be fixed width");
162 RHS = B.CreateVectorSplat(
163 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
164 "scalar.splat");
165 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
166 assert(!isa<ScalableVectorType>(RHS->getType()) &&
167 "RHS Assumed to be fixed width");
168 LHS = B.CreateVectorSplat(
169 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
170 "scalar.splat");
171 }
172
173 return cast<VectorType>(LHS->getType())
174 ->getElementType()
175 ->isFloatingPointTy()
176 ? B.CreateFAdd(LHS, RHS)
177 : B.CreateAdd(LHS, RHS);
178 }
179
180 /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
181 /// point matrixes.
183 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
184 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
185 assert(!isa<ScalableVectorType>(LHS->getType()) &&
186 "LHS Assumed to be fixed width");
187 RHS = B.CreateVectorSplat(
188 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
189 "scalar.splat");
190 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
191 assert(!isa<ScalableVectorType>(RHS->getType()) &&
192 "RHS Assumed to be fixed width");
193 LHS = B.CreateVectorSplat(
194 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
195 "scalar.splat");
196 }
197
198 return cast<VectorType>(LHS->getType())
199 ->getElementType()
200 ->isFloatingPointTy()
201 ? B.CreateFSub(LHS, RHS)
202 : B.CreateSub(LHS, RHS);
203 }
204
205 /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
206 /// RHS.
208 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
209 if (LHS->getType()->getScalarType()->isFloatingPointTy())
210 return B.CreateFMul(LHS, RHS);
211 return B.CreateMul(LHS, RHS);
212 }
213
214 /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
215 /// IsUnsigned indicates whether UDiv or SDiv should be used.
216 Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
217 assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
218 assert(!isa<ScalableVectorType>(LHS->getType()) &&
219 "LHS Assumed to be fixed width");
220 RHS =
221 B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
222 RHS, "scalar.splat");
223 return cast<VectorType>(LHS->getType())
224 ->getElementType()
225 ->isFloatingPointTy()
226 ? B.CreateFDiv(LHS, RHS)
227 : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
228 }
229
230 /// Create an assumption that \p Idx is less than \p NumElements.
231 void CreateIndexAssumption(Value *Idx, unsigned NumElements,
232 Twine const &Name = "") {
233 Value *NumElts =
234 B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
235 auto *Cmp = B.CreateICmpULT(Idx, NumElts);
236 if (isa<ConstantInt>(Cmp))
237 assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
238 else
239 B.CreateAssumption(Cmp);
240 }
241
242 /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
243 /// a matrix with \p NumRows embedded in a vector.
244 Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
245 Twine const &Name = "") {
246 unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
247 ColumnIdx->getType()->getScalarSizeInBits());
248 Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
249 RowIdx = B.CreateZExt(RowIdx, IntTy);
250 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
251 Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
252 return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
253 }
254};
255
256} // end namespace llvm
257
258#endif // LLVM_IR_MATRIXBUILDER_H
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
std::string Name
Live Register Matrix
Machine Check Debug Module
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
Value * RHS
Value * LHS
static Attribute getWithAlignment(LLVMContext &Context, Align Alignment)
Return a uniquified Attribute object that has the specific alignment set.
Definition: Attributes.cpp:233
This class represents a function call, abstracting a target machine's calling convention.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:680
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:214
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:91
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:266
Value * CreateScalarMultiply(Value *LHS, Value *RHS)
Multiply matrix LHS with scalar RHS or scalar LHS with matrix RHS.
Value * CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned)
Divide matrix LHS by scalar RHS.
Value * CreateSub(Value *LHS, Value *RHS)
Subtract matrixes LHS and RHS.
MatrixBuilder(IRBuilderBase &Builder)
Definition: MatrixBuilder.h:58
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix store.
Definition: MatrixBuilder.h:89
Value * CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows)
Insert a single element NewVal into Matrix at indices (RowIdx, ColumnIdx).
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
void CreateIndexAssumption(Value *Idx, unsigned NumElements, Twine const &Name="")
Create an assumption that Idx is less than NumElements.
Value * CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, Twine const &Name="")
Compute the index to access the element at (RowIdx, ColumnIdx) from a matrix with NumRows embedded in...
CallInst * CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix load.
Definition: MatrixBuilder.h:66
Value * CreateAdd(Value *LHS, Value *RHS)
Add matrixes LHS and RHS.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition: Type.h:128
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1539
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39