LLVM 20.0.0git
SPIRVRegularizer.cpp
Go to the documentation of this file.
1//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10// the pass was taken from SPIRV-LLVM translator.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRV.h"
15#include "SPIRVTargetMachine.h"
18#include "llvm/IR/InstVisitor.h"
19#include "llvm/IR/PassManager.h"
21
22#include <list>
23
24#define DEBUG_TYPE "spirv-regularizer"
25
26using namespace llvm;
27
28namespace llvm {
30}
31
32namespace {
33struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
35
36public:
37 static char ID;
38 SPIRVRegularizer() : FunctionPass(ID) {
40 }
41 bool runOnFunction(Function &F) override;
42 StringRef getPassName() const override { return "SPIR-V Regularizer"; }
43
44 void getAnalysisUsage(AnalysisUsage &AU) const override {
46 }
47 void visitCallInst(CallInst &CI);
48
49private:
50 void visitCallScalToVec(CallInst *CI, StringRef MangledName,
51 StringRef DemangledName);
52 void runLowerConstExpr(Function &F);
53};
54} // namespace
55
56char SPIRVRegularizer::ID = 0;
57
58INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
59 false)
60
61// Since SPIR-V cannot represent constant expression, constant expressions
62// in LLVM IR need to be lowered to instructions. For each function,
63// the constant expressions used by instructions of the function are replaced
64// by instructions placed in the entry block since it dominates all other BBs.
65// Each constant expression only needs to be lowered once in each function
66// and all uses of it by instructions in that function are replaced by
67// one instruction.
68// TODO: remove redundant instructions for common subexpression.
69void SPIRVRegularizer::runLowerConstExpr(Function &F) {
70 LLVMContext &Ctx = F.getContext();
71 std::list<Instruction *> WorkList;
72 for (auto &II : instructions(F))
73 WorkList.push_back(&II);
74
75 auto FBegin = F.begin();
76 while (!WorkList.empty()) {
77 Instruction *II = WorkList.front();
78
79 auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
80 if (isa<Function>(V))
81 return V;
82 auto *CE = cast<ConstantExpr>(V);
83 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
84 auto ReplInst = CE->getAsInstruction();
85 auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
86 ReplInst->insertBefore(InsPoint);
87 LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
88 std::vector<Instruction *> Users;
89 // Do not replace use during iteration of use. Do it in another loop.
90 for (auto U : CE->users()) {
91 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
92 auto InstUser = dyn_cast<Instruction>(U);
93 // Only replace users in scope of current function.
94 if (InstUser && InstUser->getParent()->getParent() == &F)
95 Users.push_back(InstUser);
96 }
97 for (auto &User : Users) {
98 if (ReplInst->getParent() == User->getParent() &&
99 User->comesBefore(ReplInst))
100 ReplInst->moveBefore(User);
101 User->replaceUsesOfWith(CE, ReplInst);
102 }
103 return ReplInst;
104 };
105
106 WorkList.pop_front();
107 auto LowerConstantVec = [&II, &LowerOp, &WorkList,
108 &Ctx](ConstantVector *Vec,
109 unsigned NumOfOp) -> Value * {
110 if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
111 return isa<ConstantExpr>(V) || isa<Function>(V);
112 })) {
113 // Expand a vector of constexprs and construct it back with
114 // series of insertelement instructions.
115 std::list<Value *> OpList;
116 std::transform(Vec->op_begin(), Vec->op_end(),
117 std::back_inserter(OpList),
118 [LowerOp](Value *V) { return LowerOp(V); });
119 Value *Repl = nullptr;
120 unsigned Idx = 0;
121 auto *PhiII = dyn_cast<PHINode>(II);
122 Instruction *InsPoint =
123 PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
124 std::list<Instruction *> ReplList;
125 for (auto V : OpList) {
126 if (auto *Inst = dyn_cast<Instruction>(V))
127 ReplList.push_back(Inst);
129 (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
130 ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "",
131 InsPoint->getIterator());
132 }
133 WorkList.splice(WorkList.begin(), ReplList);
134 return Repl;
135 }
136 return nullptr;
137 };
138 for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
139 auto *Op = II->getOperand(OI);
140 if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
141 Value *ReplInst = LowerConstantVec(Vec, OI);
142 if (ReplInst)
143 II->replaceUsesOfWith(Op, ReplInst);
144 } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
145 WorkList.push_front(cast<Instruction>(LowerOp(CE)));
146 } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
147 auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
148 if (!ConstMD)
149 continue;
150 Constant *C = ConstMD->getValue();
151 Value *ReplInst = nullptr;
152 if (auto *Vec = dyn_cast<ConstantVector>(C))
153 ReplInst = LowerConstantVec(Vec, OI);
154 if (auto *CE = dyn_cast<ConstantExpr>(C))
155 ReplInst = LowerOp(CE);
156 if (!ReplInst)
157 continue;
158 Metadata *RepMD = ValueAsMetadata::get(ReplInst);
159 Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
160 II->setOperand(OI, RepMDVal);
161 WorkList.push_front(cast<Instruction>(ReplInst));
162 }
163 }
164 }
165}
166
167// It fixes calls to OCL builtins that accept vector arguments and one of them
168// is actually a scalar splat.
169void SPIRVRegularizer::visitCallInst(CallInst &CI) {
170 auto F = CI.getCalledFunction();
171 if (!F)
172 return;
173
174 auto MangledName = F->getName();
175 char *NameStr = itaniumDemangle(F->getName().data());
176 if (!NameStr)
177 return;
178 StringRef DemangledName(NameStr);
179
180 // TODO: add support for other builtins.
181 if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
182 DemangledName.starts_with("min") || DemangledName.starts_with("max"))
183 visitCallScalToVec(&CI, MangledName, DemangledName);
184 free(NameStr);
185}
186
187void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
188 StringRef DemangledName) {
189 // Check if all arguments have the same type - it's simple case.
190 auto Uniform = true;
191 Type *Arg0Ty = CI->getOperand(0)->getType();
192 auto IsArg0Vector = isa<VectorType>(Arg0Ty);
193 for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
194 Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
195 if (Uniform)
196 return;
197
198 auto *OldF = CI->getCalledFunction();
199 Function *NewF = nullptr;
200 if (!Old2NewFuncs.count(OldF)) {
202 SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
203 auto *NewFTy =
204 FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
205 NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
206 *OldF->getParent());
208 auto NewFArgIt = NewF->arg_begin();
209 for (auto &Arg : OldF->args()) {
210 auto ArgName = Arg.getName();
211 NewFArgIt->setName(ArgName);
212 VMap[&Arg] = &(*NewFArgIt++);
213 }
215 CloneFunctionInto(NewF, OldF, VMap,
216 CloneFunctionChangeType::LocalChangesOnly, Returns);
217 NewF->setAttributes(Attrs);
218 Old2NewFuncs[OldF] = NewF;
219 } else {
220 NewF = Old2NewFuncs[OldF];
221 }
222 assert(NewF);
223
224 // This produces an instruction sequence that implements a splat of
225 // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
226 // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
227 // For instance (transcoding/OpMin.ll), this call
228 // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
229 // is translated to
230 // %8 = OpUndef %v2uint
231 // %14 = OpConstantComposite %v2uint %uint_1 %uint_10
232 // ...
233 // %10 = OpCompositeInsert %v2uint %uint_5 %8 0
234 // %11 = OpVectorShuffle %v2uint %10 %8 0 0
235 // %call = OpExtInst %v2uint %1 s_min %14 %11
236 auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
237 PoisonValue *PVal = PoisonValue::get(Arg0Ty);
239 PVal, CI->getOperand(1), ConstInt, "", CI->getIterator());
240 ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
241 Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
242 Value *NewVec =
243 new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator());
244 CI->setOperand(1, NewVec);
245 CI->replaceUsesOfWith(OldF, NewF);
247}
248
249bool SPIRVRegularizer::runOnFunction(Function &F) {
250 runLowerConstExpr(F);
251 visit(F);
252 for (auto &OldNew : Old2NewFuncs) {
253 Function *OldF = OldNew.first;
254 Function *NewF = OldNew.second;
255 NewF->takeName(OldF);
256 OldF->eraseFromParent();
257 }
258 return true;
259}
260
262 return new SPIRVRegularizer();
263}
Expand Atomic instructions
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
This header defines various interfaces for pass management in LLVM.
iv Induction Variable Users
Definition: IVUsers.cpp:48
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
#define DEBUG_TYPE
Represent the analysis usage information of a pass.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1349
void mutateFunctionType(FunctionType *FTy)
Definition: InstrTypes.h:1209
unsigned arg_size() const
Definition: InstrTypes.h:1292
This class represents a function call, abstracting a target machine's calling convention.
Constant Vector Declarations.
Definition: Constants.h:511
static Constant * getSplat(ElementCount EC, Constant *Elt)
Return a ConstantVector with the specified constant in each element.
Definition: Constants.cpp:1472
This is an important base class in LLVM.
Definition: Constant.h:42
This class represents an Operation in the Expression.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:173
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:216
AttributeList getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:353
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it.
Definition: Function.cpp:458
arg_iterator arg_begin()
Definition: Function.h:868
void setAttributes(AttributeList Attrs)
Set the attribute list for this Function.
Definition: Function.h:356
static InsertElementInst * Create(Value *Vec, Value *NewElt, Value *Idx, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Base class for instruction visitors.
Definition: InstVisitor.h:78
RetTy visitCallInst(CallInst &I)
Definition: InstVisitor.h:223
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:311
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
static MetadataAsValue * get(LLVMContext &Context, Metadata *MD)
Definition: Metadata.cpp:103
Root of the metadata hierarchy.
Definition: Metadata.h:62
PassRegistry - This class manages the registration and initialization of the pass subsystem as appli...
Definition: PassRegistry.h:37
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
In order to facilitate speculative execution, many instructions do not invoke immediate undefined beh...
Definition: Constants.h:1460
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1878
This instruction constructs a fixed permutation of two input vectors.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static IntegerType * getInt32Ty(LLVMContext &C)
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
void setOperand(unsigned i, Value *Val)
Definition: User.h:233
Value * getOperand(unsigned i) const
Definition: User.h:228
static ValueAsMetadata * get(Value *V)
Definition: Metadata.cpp:501
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
void takeName(Value *V)
Transfer the name from V to this value.
Definition: Value.cpp:383
self_iterator getIterator()
Definition: ilist_node.h:132
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
char * itaniumDemangle(std::string_view mangled_name, bool ParseParams=true)
Returns a non-NULL pointer to a NUL-terminated C style string that should be explicitly freed,...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void initializeSPIRVRegularizerPass(PassRegistry &)
FunctionPass * createSPIRVRegularizerPass()
void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, CloneFunctionChangeType Changes, SmallVectorImpl< ReturnInst * > &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values.