LLVM 22.0.0git
SPIRVRegularizer.cpp
Go to the documentation of this file.
1//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10// the pass was taken from the SPIRV-LLVM Translator.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRV.h"
17#include "llvm/IR/InstVisitor.h"
18#include "llvm/IR/PassManager.h"
20
21#include <list>
22
23#define DEBUG_TYPE "spirv-regularizer"
24
25using namespace llvm;
26
27namespace {
// Legacy function pass that regularizes LLVM IR for SPIR-V. It lowers the
// constant expressions used by a function's instructions into real
// instructions (runLowerConstExpr) and, through the InstVisitor base,
// rewrites calls to selected OpenCL builtins whose vector overload was called
// with a scalar argument (visitCallInst / visitCallScalToVec).
28struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
30
31public:
32 static char ID;
33 SPIRVRegularizer() : FunctionPass(ID) {}
// Lowers constant expressions, visits calls, then lets each cloned callee
// take over the name of the function it replaces and erases the original.
34 bool runOnFunction(Function &F) override;
35 StringRef getPassName() const override { return "SPIR-V Regularizer"; }
36
37 void getAnalysisUsage(AnalysisUsage &AU) const override {
39 }
// InstVisitor hook, invoked for every call instruction reached by visit(F).
40 void visitCallInst(CallInst &CI);
41
42private:
// Splats the scalar operand of a mixed scalar/vector builtin call and
// redirects the call to a clone of the callee whose second parameter has the
// first argument's vector type.
43 void visitCallScalToVec(CallInst *CI, StringRef MangledName,
44 StringRef DemangledName);
// Replaces constant-expression operands used inside F with instructions.
45 void runLowerConstExpr(Function &F);
46};
47} // namespace
48
49char SPIRVRegularizer::ID = 0;
50
51INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
52 false)
53
54// Since SPIR-V cannot represent constant expressions, constant expressions
55// in LLVM IR need to be lowered to instructions. For each function,
56// the constant expressions used by instructions of the function are replaced
57// by instructions placed in the entry block since it dominates all other BBs.
58// Each constant expression only needs to be lowered once in each function
59// and all uses of it by instructions in that function are replaced by
60// one instruction.
61// TODO: remove redundant instructions for common subexpression.
//
// Placement details: a lowered constant expression is inserted before the
// instruction being processed when that instruction is in the entry block,
// otherwise before the entry block's last instruction; it is then moved up
// if a user in the same block would precede it. A ConstantVector operand is
// rebuilt as a chain of insertelement instructions, but only when every
// element is a constant expression or a Function.
62void SPIRVRegularizer::runLowerConstExpr(Function &F) {
63 LLVMContext &Ctx = F.getContext();
// Seed the worklist with every instruction of F. Replacement instructions
// created below are pushed to the front so their own operands get lowered.
64 std::list<Instruction *> WorkList;
65 for (auto &II : instructions(F))
66 WorkList.push_back(&II);
67
68 auto FBegin = F.begin();
69 while (!WorkList.empty()) {
70 Instruction *II = WorkList.front();
71
// Lowers one operand V of II. A Function is returned unchanged. A
// ConstantExpr is materialized with getAsInstruction(), every instruction
// user of it inside F is rewired to the new instruction, and that
// instruction is returned.
72 auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
73 if (isa<Function>(V))
74 return V;
75 auto *CE = cast<ConstantExpr>(V);
76 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
77 auto ReplInst = CE->getAsInstruction();
78 auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
79 ReplInst->insertBefore(InsPoint->getIterator());
80 LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
81 std::vector<Instruction *> Users;
82 // Do not replace use during iteration of use. Do it in another loop.
83 for (auto U : CE->users()) {
84 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
85 auto InstUser = dyn_cast<Instruction>(U);
86 // Only replace users in scope of current function.
87 if (InstUser && InstUser->getParent()->getParent() == &F)
88 Users.push_back(InstUser);
89 }
90 for (auto &User : Users) {
// Keep the definition ahead of any same-block user.
91 if (ReplInst->getParent() == User->getParent() &&
92 User->comesBefore(ReplInst))
93 ReplInst->moveBefore(User->getIterator());
94 User->replaceUsesOfWith(CE, ReplInst);
95 }
96 return ReplInst;
97 };
98
99 WorkList.pop_front();
// Lowers a ConstantVector that is operand NumOfOp of II into an
// insertelement chain and returns the last insertelement, or nullptr when the
// vector does not qualify. For a PHI, the chain is placed before the last
// instruction of the incoming block for that operand; otherwise before II.
// NOTE(review): a vector mixing constant expressions with other constants
// (e.g. plain integers) fails the all_of check, so its constant-expression
// elements are left unlowered — confirm this is acceptable.
100 auto LowerConstantVec = [&II, &LowerOp, &WorkList,
101 &Ctx](ConstantVector *Vec,
102 unsigned NumOfOp) -> Value * {
103 if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
104 return isa<ConstantExpr>(V) || isa<Function>(V);
105 })) {
106 // Expand a vector of constexprs and construct it back with
107 // series of insertelement instructions.
108 std::list<Value *> OpList;
109 std::transform(Vec->op_begin(), Vec->op_end(),
110 std::back_inserter(OpList),
111 [LowerOp](Value *V) { return LowerOp(V); });
112 Value *Repl = nullptr;
113 unsigned Idx = 0;
114 auto *PhiII = dyn_cast<PHINode>(II);
115 Instruction *InsPoint =
116 PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
117 std::list<Instruction *> ReplList;
118 for (auto V : OpList) {
119 if (auto *Inst = dyn_cast<Instruction>(V))
120 ReplList.push_back(Inst);
122 (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
123 ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "",
124 InsPoint->getIterator());
125 }
// Queue the lowered element instructions ahead of everything else.
126 WorkList.splice(WorkList.begin(), ReplList);
127 return Repl;
128 }
129 return nullptr;
130 };
// Inspect every operand of II: constant vectors, constant expressions, and
// constants wrapped in metadata are each lowered.
131 for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
132 auto *Op = II->getOperand(OI);
133 if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
// Only II's uses of the vector are replaced; other users are visited later
// and lower their own copy. NOTE(review): if II is a PHI with this same
// vector on several incoming edges, all of them are rewired to a chain placed
// only in incoming block OI, which may not dominate the other edges — confirm.
134 Value *ReplInst = LowerConstantVec(Vec, OI);
135 if (ReplInst)
136 II->replaceUsesOfWith(Op, ReplInst);
137 } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
138 WorkList.push_front(cast<Instruction>(LowerOp(CE)));
139 } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
// A constant wrapped in metadata: lower the constant, then rewrap the
// replacement instruction as metadata and store it back into II.
140 auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
141 if (!ConstMD)
142 continue;
143 Constant *C = ConstMD->getValue();
144 Value *ReplInst = nullptr;
145 if (auto *Vec = dyn_cast<ConstantVector>(C))
146 ReplInst = LowerConstantVec(Vec, OI);
147 if (auto *CE = dyn_cast<ConstantExpr>(C))
148 ReplInst = LowerOp(CE);
149 if (!ReplInst)
150 continue;
151 Metadata *RepMD = ValueAsMetadata::get(ReplInst);
152 Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
153 II->setOperand(OI, RepMDVal);
154 WorkList.push_front(cast<Instruction>(ReplInst));
155 }
156 }
157 }
158}
159
160// It fixes calls to OCL builtins that accept vector arguments and one of them
161// is actually a scalar splat.
162void SPIRVRegularizer::visitCallInst(CallInst &CI) {
163 auto F = CI.getCalledFunction();
164 if (!F)
165 return;
166
167 auto MangledName = F->getName();
168 char *NameStr = itaniumDemangle(F->getName().data());
169 if (!NameStr)
170 return;
171 StringRef DemangledName(NameStr);
172
173 // TODO: add support for other builtins.
174 if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
175 DemangledName.starts_with("min") || DemangledName.starts_with("max"))
176 visitCallScalToVec(&CI, MangledName, DemangledName);
177 free(NameStr);
178}
179
// Rewrites a call whose arguments disagree on being vectors, for example
// min(<2 x i32>, i32): the scalar operand 1 is splatted to operand 0's vector
// type, and the call is redirected to a clone of the callee whose second
// parameter takes that vector type. Clones are cached per original callee in
// Old2NewFuncs; runOnFunction later gives each clone the original name and
// erases the original.
180void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
181 StringRef DemangledName) {
182 // Simple case: every argument agrees with operand 0 on being a vector or a scalar (element types are not compared), so there is nothing to rewrite.
183 auto Uniform = true;
184 Type *Arg0Ty = CI->getOperand(0)->getType();
185 auto IsArg0Vector = isa<VectorType>(Arg0Ty);
186 for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
187 Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
188 if (Uniform)
189 return;
// NOTE(review): from here on the code assumes operand 0 is the vector and
// operand 1 the scalar, and that the callee has exactly two parameters
// (ArgTypes below has two entries, and the argument-mapping loop walks all of
// OldF's arguments). If operand 0 were the scalar, cast<VectorType>(Arg0Ty)
// further down would fail — confirm the matched builtins guarantee this.
190
191 auto *OldF = CI->getCalledFunction();
192 Function *NewF = nullptr;
// Create the clone only the first time this callee is seen.
193 auto [It, Inserted] = Old2NewFuncs.try_emplace(OldF);
194 if (Inserted) {
// NOTE(review): the old attribute list is copied unchanged onto the clone;
// any parameter attributes that targeted the scalar second argument now apply
// to a vector parameter — confirm they remain valid.
195 AttributeList Attrs = CI->getCalledFunction()->getAttributes();
196 SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
197 auto *NewFTy =
198 FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
199 NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
200 *OldF->getParent());
202 auto NewFArgIt = NewF->arg_begin();
203 for (auto &Arg : OldF->args()) {
204 auto ArgName = Arg.getName();
205 NewFArgIt->setName(ArgName);
206 VMap[&Arg] = &(*NewFArgIt++);
207 }
209 CloneFunctionInto(NewF, OldF, VMap,
210 CloneFunctionChangeType::LocalChangesOnly, Returns);
211 NewF->setAttributes(Attrs);
212 It->second = NewF;
213 } else {
214 NewF = It->second;
215 }
216 assert(NewF);
217
218 // This produces an instruction sequence that implements a splat of
219 // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
220 // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
221 // For instance (transcoding/OpMin.ll), this call
222 // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
223 // is translated to
224 // %8 = OpUndef %v2uint
225 // %14 = OpConstantComposite %v2uint %uint_1 %uint_10
226 // ...
227 // %10 = OpCompositeInsert %v2uint %uint_5 %8 0
228 // %11 = OpVectorShuffle %v2uint %10 %8 0 0
229 // %call = OpExtInst %v2uint %1 s_min %14 %11
// Lane index 0 doubles as every element of the shuffle mask, so the mask is
// all zeros and broadcasts lane 0.
230 auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
231 PoisonValue *PVal = PoisonValue::get(Arg0Ty);
233 PVal, CI->getOperand(1), ConstInt, "", CI->getIterator());
234 ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
235 Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
236 Value *NewVec =
237 new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator());
// Feed the splatted vector to the call and point it at the cloned callee.
238 CI->setOperand(1, NewVec);
239 CI->replaceUsesOfWith(OldF, NewF);
241}
242
243bool SPIRVRegularizer::runOnFunction(Function &F) {
244 runLowerConstExpr(F);
245 visit(F);
246 for (auto &OldNew : Old2NewFuncs) {
247 Function *OldF = OldNew.first;
248 Function *NewF = OldNew.second;
249 NewF->takeName(OldF);
250 OldF->eraseFromParent();
251 }
252 return true;
253}
254
// Factory for the legacy pass: returns a newly heap-allocated
// SPIRVRegularizer; presumably owned by the pass manager it is added to —
// TODO confirm.
256 return new SPIRVRegularizer();
257}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Expand Atomic instructions
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static bool runOnFunction(Function &F, bool PostInlining)
#define DEBUG_TYPE
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
Represent the analysis usage information of a pass.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
void mutateFunctionType(FunctionType *FTy)
unsigned arg_size() const
This class represents a function call, abstracting a target machine's calling convention.
Constant Vector Declarations.
Definition Constants.h:517
FixedVectorType * getType() const
Specialize the getType() method to always return a FixedVectorType, which reduces the amount of casti...
Definition Constants.h:540
static LLVM_ABI Constant * getSplat(ElementCount EC, Constant *Elt)
Return a ConstantVector with the specified constant in each element.
This is an important base class in LLVM.
Definition Constant.h:43
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:229
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition Function.h:166
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition Function.h:209
AttributeList getAttributes() const
Return the attribute list for this Function.
Definition Function.h:352
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it.
Definition Function.cpp:448
arg_iterator arg_begin()
Definition Function.h:866
void setAttributes(AttributeList Attrs)
Set the attribute list for this Function.
Definition Function.h:355
static InsertElementInst * Create(Value *Vec, Value *NewElt, Value *Idx, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Base class for instruction visitors.
Definition InstVisitor.h:78
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:319
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
static LLVM_ABI MetadataAsValue * get(LLVMContext &Context, Metadata *MD)
Definition Metadata.cpp:103
Root of the metadata hierarchy.
Definition Metadata.h:63
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition Pass.cpp:112
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:297
LLVM_ABI bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition User.cpp:21
op_iterator op_begin()
Definition User.h:284
void setOperand(unsigned i, Value *Val)
Definition User.h:237
Value * getOperand(unsigned i) const
Definition User.h:232
op_iterator op_end()
Definition User.h:286
static LLVM_ABI ValueAsMetadata * get(Value *V)
Definition Metadata.cpp:502
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1101
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
Definition Value.cpp:396
self_iterator getIterator()
Definition ilist_node.h:130
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
DEMANGLE_ABI char * itaniumDemangle(std::string_view mangled_name, bool ParseParams=true)
Returns a non-NULL pointer to a NUL-terminated C style string that should be explicitly freed,...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:548
FunctionPass * createSPIRVRegularizerPass()
DWARFExpression::Operation Op
ValueMap< const Value *, WeakTrackingVH > ValueToValueMapTy
LLVM_ABI void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, CloneFunctionChangeType Changes, SmallVectorImpl< ReturnInst * > &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values.
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565