LLVM  12.0.0git
MatrixBuilder.h
Go to the documentation of this file.
1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MatrixBuilder class, which is used as a convenient way
10 // to lower matrix operations to LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_IR_MATRIXBUILDER_H
15 #define LLVM_IR_MATRIXBUILDER_H
16 
17 #include "llvm/IR/Constant.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstrTypes.h"
21 #include "llvm/IR/Instruction.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Type.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Alignment.h"
26 
27 namespace llvm {
28 
29 class Function;
30 class Twine;
31 class Module;
32 
33 template <class IRBuilderTy> class MatrixBuilder {
34  IRBuilderTy &B;
35  Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36 
37  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38  Value *RHS) {
39  assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40  "One of the operands must be a matrix (embedded in a vector)");
41  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42  assert(!isa<ScalableVectorType>(LHS->getType()) &&
43  "LHS Assumed to be fixed width");
44  RHS = B.CreateVectorSplat(
45  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46  "scalar.splat");
47  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48  assert(!isa<ScalableVectorType>(RHS->getType()) &&
49  "RHS Assumed to be fixed width");
50  LHS = B.CreateVectorSplat(
51  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52  "scalar.splat");
53  }
54  return {LHS, RHS};
55  }
56 
57 public:
58  MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
59 
60  /// Create a column major, strided matrix load.
61  /// \p DataPtr - Start address of the matrix read
62  /// \p Rows - Number of rows in matrix (must be a constant)
63  /// \p Columns - Number of columns in matrix (must be a constant)
64  /// \p Stride - Space between columns
65  CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment,
66  Value *Stride, bool IsVolatile, unsigned Rows,
67  unsigned Columns, const Twine &Name = "") {
68 
69  // Deal with the pointer
70  PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
71  Type *EltTy = PtrTy->getElementType();
72 
73  auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
74 
75  Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
76  B.getInt32(Columns)};
77  Type *OverloadedTypes[] = {RetType};
78 
80  getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
81 
82  CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
83  Attribute AlignAttr =
84  Attribute::getWithAlignment(Call->getContext(), Alignment);
85  Call->addAttribute(1, AlignAttr);
86  return Call;
87  }
88 
89  /// Create a column major, strided matrix store.
90  /// \p Matrix - Matrix to store
91  /// \p Ptr - Pointer to write back to
92  /// \p Stride - Space between columns
94  Value *Stride, bool IsVolatile,
95  unsigned Rows, unsigned Columns,
96  const Twine &Name = "") {
97  Value *Ops[] = {Matrix, Ptr,
98  Stride, B.getInt1(IsVolatile),
99  B.getInt32(Rows), B.getInt32(Columns)};
100  Type *OverloadedTypes[] = {Matrix->getType()};
101 
103  getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
104 
105  CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
106  Attribute AlignAttr =
107  Attribute::getWithAlignment(Call->getContext(), Alignment);
108  Call->addAttribute(2, AlignAttr);
109  return Call;
110  }
111 
112  /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
113  /// rows and \p Columns columns.
115  unsigned Columns, const Twine &Name = "") {
116  auto *OpType = cast<VectorType>(Matrix->getType());
117  auto *ReturnType =
118  FixedVectorType::get(OpType->getElementType(), Rows * Columns);
119 
120  Type *OverloadedTypes[] = {ReturnType};
121  Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
123  getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
124 
125  return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
126  }
127 
128  /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
129  /// RHS.
130  CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
131  unsigned LHSColumns, unsigned RHSColumns,
132  const Twine &Name = "") {
133  auto *LHSType = cast<VectorType>(LHS->getType());
134  auto *RHSType = cast<VectorType>(RHS->getType());
135 
136  auto *ReturnType =
137  FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
138 
139  Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
140  B.getInt32(RHSColumns)};
141  Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
142 
144  getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
145  return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
146  }
147 
148  /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
149  /// ColumnIdx).
151  Value *ColumnIdx, unsigned NumRows) {
152  return B.CreateInsertElement(
153  Matrix, NewVal,
154  B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
155  ColumnIdx->getType(), NumRows)),
156  RowIdx));
157  }
158 
159  /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
160  /// matrixes.
161  Value *CreateAdd(Value *LHS, Value *RHS) {
162  assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
163  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
164  assert(!isa<ScalableVectorType>(LHS->getType()) &&
165  "LHS Assumed to be fixed width");
166  RHS = B.CreateVectorSplat(
167  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
168  "scalar.splat");
169  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
170  assert(!isa<ScalableVectorType>(RHS->getType()) &&
171  "RHS Assumed to be fixed width");
172  LHS = B.CreateVectorSplat(
173  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
174  "scalar.splat");
175  }
176 
177  return cast<VectorType>(LHS->getType())
178  ->getElementType()
179  ->isFloatingPointTy()
180  ? B.CreateFAdd(LHS, RHS)
181  : B.CreateAdd(LHS, RHS);
182  }
183 
184  /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
185  /// point matrixes.
186  Value *CreateSub(Value *LHS, Value *RHS) {
187  assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
188  if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
189  assert(!isa<ScalableVectorType>(LHS->getType()) &&
190  "LHS Assumed to be fixed width");
191  RHS = B.CreateVectorSplat(
192  cast<VectorType>(LHS->getType())->getElementCount(), RHS,
193  "scalar.splat");
194  } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
195  assert(!isa<ScalableVectorType>(RHS->getType()) &&
196  "RHS Assumed to be fixed width");
197  LHS = B.CreateVectorSplat(
198  cast<VectorType>(RHS->getType())->getElementCount(), LHS,
199  "scalar.splat");
200  }
201 
202  return cast<VectorType>(LHS->getType())
203  ->getElementType()
204  ->isFloatingPointTy()
205  ? B.CreateFSub(LHS, RHS)
206  : B.CreateSub(LHS, RHS);
207  }
208 
209  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
210  /// RHS.
212  std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
213  if (LHS->getType()->getScalarType()->isFloatingPointTy())
214  return B.CreateFMul(LHS, RHS);
215  return B.CreateMul(LHS, RHS);
216  }
217 
218  /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.
220  unsigned NumRows, Twine const &Name = "") {
221 
222  unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
223  ColumnIdx->getType()->getScalarSizeInBits());
224  Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
225  RowIdx = B.CreateZExt(RowIdx, IntTy);
226  ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
227  Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
228  return B.CreateExtractElement(
229  Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx),
230  "matext");
231  }
232 };
233 
234 } // end namespace llvm
235 
236 #endif // LLVM_IR_MATRIXBUILDER_H
This class represents lattice values for constants.
Definition: AllocatorList.h:23
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
constexpr char IsVolatile[]
Key for Kernel::Arg::Metadata::mIsVolatile.
Value * CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows)
Insert a single element NewVal into Matrix at indices (RowIdx, ColumnIdx).
This class represents a function call, abstracting a target machine's calling convention.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:626
Value * CreateSub(Value *LHS, Value *RHS)
Subtract matrixes LHS and RHS.
Value * CreateScalarMultiply(Value *LHS, Value *RHS)
Multiply matrix LHS with scalar RHS or scalar LHS with matrix RHS.
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:235
Live Register Matrix
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition: Type.h:128
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
Definition: Twine.h:80
bool isFloatingPointTy() const
Return true if this is one of the six floating-point types.
Definition: Type.h:163
static Attribute getWithAlignment(LLVMContext &Context, Align Alignment)
Return a uniquified Attribute object that has the specific alignment set.
Definition: Attributes.cpp:149
CallInst * CreateColumnMajorLoad(Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix load.
Definition: MatrixBuilder.h:65
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:246
Value * CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx, unsigned NumRows, Twine const &Name="")
Extracts the element at (RowIdx, ColumnIdx) from Matrix.
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1250
Class to represent pointers.
Definition: DerivedTypes.h:655
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:313
MatrixBuilder(IRBuilderTy &Builder)
Definition: MatrixBuilder.h:58
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:46
This file contains the declarations for the subclasses of Constant, which represent the different flavors of constant values that live in LLVM.
assume Assume Builder
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:269
Machine Check Debug Module
Align max(MaybeAlign Lhs, Align Rhs)
Definition: Alignment.h:350
Value * CreateAdd(Value *LHS, Value *RHS)
Add matrixes LHS and RHS.
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition: Type.cpp:147
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:867
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:165
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix store.
Definition: MatrixBuilder.h:93
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
Definition: Value.h:75
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
Type * getElementType() const
Definition: DerivedTypes.h:674