14#ifndef LLVM_IR_MATRIXBUILDER_H
15#define LLVM_IR_MATRIXBUILDER_H
35 Module *getModule() {
return B.GetInsertBlock()->getParent()->getParent(); }
37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(
Value *
LHS,
39 assert((
LHS->getType()->isVectorTy() ||
RHS->getType()->isVectorTy()) &&
40 "One of the operands must be a matrix (embedded in a vector)");
41 if (
LHS->getType()->isVectorTy() && !
RHS->getType()->isVectorTy()) {
42 assert(!isa<ScalableVectorType>(
LHS->getType()) &&
43 "LHS Assumed to be fixed width");
44 RHS = B.CreateVectorSplat(
45 cast<VectorType>(
LHS->getType())->getElementCount(),
RHS,
47 }
else if (!
LHS->getType()->isVectorTy() &&
RHS->getType()->isVectorTy()) {
48 assert(!isa<ScalableVectorType>(
RHS->getType()) &&
49 "RHS Assumed to be fixed width");
50 LHS = B.CreateVectorSplat(
51 cast<VectorType>(
RHS->getType())->getElementCount(),
LHS,
67 Value *Stride,
bool IsVolatile,
unsigned Rows,
68 unsigned Columns,
const Twine &
Name =
"") {
71 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
73 Type *OverloadedTypes[] = {RetType, Stride->
getType()};
76 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
81 Call->addParamAttr(0, AlignAttr);
90 Value *Stride,
bool IsVolatile,
91 unsigned Rows,
unsigned Columns,
94 Stride, B.getInt1(IsVolatile),
95 B.getInt32(Rows), B.getInt32(Columns)};
99 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
104 Call->addParamAttr(1, AlignAttr);
111 unsigned Columns,
const Twine &
Name =
"") {
112 auto *OpType = cast<VectorType>(
Matrix->getType());
116 Type *OverloadedTypes[] = {ReturnType};
117 Value *Ops[] = {
Matrix, B.getInt32(Rows), B.getInt32(Columns)};
119 getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
127 unsigned LHSColumns,
unsigned RHSColumns,
129 auto *LHSType = cast<VectorType>(
LHS->getType());
130 auto *RHSType = cast<VectorType>(
RHS->getType());
135 Value *Ops[] = {
LHS,
RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
136 B.getInt32(RHSColumns)};
137 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
140 getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
147 Value *ColumnIdx,
unsigned NumRows) {
148 return B.CreateInsertElement(
150 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
151 ColumnIdx->
getType(), NumRows)),
158 assert(
LHS->getType()->isVectorTy() ||
RHS->getType()->isVectorTy());
159 if (
LHS->getType()->isVectorTy() && !
RHS->getType()->isVectorTy()) {
160 assert(!isa<ScalableVectorType>(
LHS->getType()) &&
161 "LHS Assumed to be fixed width");
162 RHS = B.CreateVectorSplat(
163 cast<VectorType>(
LHS->getType())->getElementCount(),
RHS,
165 }
else if (!
LHS->getType()->isVectorTy() &&
RHS->getType()->isVectorTy()) {
166 assert(!isa<ScalableVectorType>(
RHS->getType()) &&
167 "RHS Assumed to be fixed width");
168 LHS = B.CreateVectorSplat(
169 cast<VectorType>(
RHS->getType())->getElementCount(),
LHS,
173 return cast<VectorType>(
LHS->getType())
175 ->isFloatingPointTy()
183 assert(
LHS->getType()->isVectorTy() ||
RHS->getType()->isVectorTy());
184 if (
LHS->getType()->isVectorTy() && !
RHS->getType()->isVectorTy()) {
185 assert(!isa<ScalableVectorType>(
LHS->getType()) &&
186 "LHS Assumed to be fixed width");
187 RHS = B.CreateVectorSplat(
188 cast<VectorType>(
LHS->getType())->getElementCount(),
RHS,
190 }
else if (!
LHS->getType()->isVectorTy() &&
RHS->getType()->isVectorTy()) {
191 assert(!isa<ScalableVectorType>(
RHS->getType()) &&
192 "RHS Assumed to be fixed width");
193 LHS = B.CreateVectorSplat(
194 cast<VectorType>(
RHS->getType())->getElementCount(),
LHS,
198 return cast<VectorType>(
LHS->getType())
200 ->isFloatingPointTy()
208 std::tie(
LHS,
RHS) = splatScalarOperandIfNeeded(
LHS,
RHS);
209 if (
LHS->getType()->getScalarType()->isFloatingPointTy())
210 return B.CreateFMul(
LHS,
RHS);
211 return B.CreateMul(
LHS,
RHS);
217 assert(
LHS->getType()->isVectorTy() && !
RHS->getType()->isVectorTy());
218 assert(!isa<ScalableVectorType>(
LHS->getType()) &&
219 "LHS Assumed to be fixed width");
221 B.CreateVectorSplat(cast<VectorType>(
LHS->getType())->getElementCount(),
222 RHS,
"scalar.splat");
223 return cast<VectorType>(
LHS->getType())
225 ->isFloatingPointTy()
227 : (IsUnsigned ? B.CreateUDiv(
LHS,
RHS) : B.CreateSDiv(
LHS,
RHS));
234 B.getIntN(
Idx->getType()->getScalarSizeInBits(), NumElements);
235 auto *Cmp = B.CreateICmpULT(
Idx, NumElts);
236 if (isa<ConstantInt>(Cmp))
237 assert(cast<ConstantInt>(Cmp)->isOne() &&
"Index must be valid!");
239 B.CreateAssumption(Cmp);
249 RowIdx = B.CreateZExt(RowIdx, IntTy);
250 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
251 Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
252 return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
This file contains the declarations for the subclasses of Constant, which represent the different flavors of constant values that live in LLVM.
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
Machine Check Debug Module
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static Attribute getWithAlignment(LLVMContext &Context, Align Alignment)
Return a uniquified Attribute object that has the specific alignment set.
This class represents a function call, abstracting a target machine's calling convention.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Common base class shared among various IRBuilders.
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Value * CreateScalarMultiply(Value *LHS, Value *RHS)
Multiply matrix LHS with scalar RHS or scalar LHS with matrix RHS.
Value * CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned)
Divide matrix LHS by scalar RHS.
Value * CreateSub(Value *LHS, Value *RHS)
Subtract matrixes LHS and RHS.
MatrixBuilder(IRBuilderBase &Builder)
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix store.
Value * CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows)
Insert a single element NewVal into Matrix at indices (RowIdx, ColumnIdx).
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
void CreateIndexAssumption(Value *Idx, unsigned NumElements, Twine const &Name="")
Create an assumption that Idx is less than NumElements.
Value * CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, Twine const &Name="")
Compute the index to access the element at (RowIdx, ColumnIdx) from a matrix with NumRows embedded in a vector.
CallInst * CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix load.
Value * CreateAdd(Value *LHS, Value *RHS)
Add matrixes LHS and RHS.
A Module instance is used to store all the information related to an LLVM module.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
The instances of the Type class are immutable: once they are created, they are never changed.
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
This is an optimization pass for GlobalISel generic memory operations.
This struct is a compact representation of a valid (non-zero power of two) alignment.