LLVM 22.0.0git
SPIRVLegalizePointerCast.cpp
Go to the documentation of this file.
1//===-- SPIRVLegalizePointerCast.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The LLVM IR has multiple legal patterns we cannot lower to Logical SPIR-V.
10// This pass modifies such loads to have an IR we can directly lower to valid
11// logical SPIR-V.
12// OpenCL can avoid this because they rely on ptrcast, which is not supported
13// by logical SPIR-V.
14//
15// This pass relies on the assign_ptr_type intrinsic to deduce the type of the
16// pointed values, and must replace all occurrences of `ptrcast`. This is why
17// unhandled cases are reported as unreachable: we MUST cover all cases.
18//
19// 1. Loading the first element of an array
20//
21// %array = [10 x i32]
22// %value = load i32, ptr %array
23//
24// LLVM can skip the GEP instruction, and only request loading the first 4
25// bytes. In logical SPIR-V, we need an OpAccessChain to access the first
26// element. This pass will add a getelementptr instruction before the load.
27//
28//
29// 2. Implicit downcast from load
30//
31// %1 = getelementptr <4 x i32>, ptr %vec4, i64 0
32// %2 = load <3 x i32>, ptr %1
33//
34// The pointer in the GEP instruction is only used for offset computations,
35// but it doesn't NEED to match the pointed type. OpAccessChain however
36// requires this. Also, LLVM loads define the bitwidth of the load, not the
37// pointer. In this example, we can guess %vec4 is a vec4 thanks to the GEP
38// instruction basetype, but we only want to load the first 3 elements, hence
39// do a partial load. In logical SPIR-V, this is not legal. What we must do
40// is load the full vector (basetype), extract 3 elements, and recombine them
41// to form a 3-element vector.
42//
43//===----------------------------------------------------------------------===//
44
45#include "SPIRV.h"
46#include "SPIRVSubtarget.h"
47#include "SPIRVTargetMachine.h"
48#include "SPIRVUtils.h"
50#include "llvm/IR/IRBuilder.h"
52#include "llvm/IR/Intrinsics.h"
53#include "llvm/IR/IntrinsicsSPIRV.h"
56
57using namespace llvm;
58
59namespace {
60class SPIRVLegalizePointerCast : public FunctionPass {
61
62 // Builds the `spv_assign_type` assigning |Ty| to |Value| at the current
63 // builder position.
64 void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) {
65 Value *OfType = PoisonValue::get(Ty);
66 CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
67 {Arg->getType()}, OfType, Arg, {}, B);
68 GR->addAssignPtrTypeInstr(Arg, AssignCI);
69 }
70
71 // Loads parts of the vector of type |SourceType| from the pointer |Source|
72 // and create a new vector of type |TargetType|. |TargetType| must be a vector
73 // type, and element types of |TargetType| and |SourceType| must match.
74 // Returns the loaded value.
75 Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
76 FixedVectorType *TargetType, Value *Source) {
77 assert(TargetType->getNumElements() <= SourceType->getNumElements());
78 LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
79 buildAssignType(B, SourceType, NewLoad);
80 Value *AssignValue = NewLoad;
81 if (TargetType->getElementType() != SourceType->getElementType()) {
82 AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
83 {TargetType, SourceType}, {NewLoad});
84 buildAssignType(B, TargetType, AssignValue);
85 }
86
87 SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
88 for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
89 Mask[I] = I;
90 Value *Output = B.CreateShuffleVector(AssignValue, AssignValue, Mask);
91 buildAssignType(B, TargetType, Output);
92 return Output;
93 }
94
95 // Loads the first value in an aggregate pointed by |Source| of containing
96 // elements of type |ElementType|. Load flags will be copied from |BadLoad|,
97 // which should be the load being legalized. Returns the loaded value.
98 Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,
99 Value *Source, LoadInst *BadLoad) {
101 BadLoad->getPointerOperandType()};
102 SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(false), Source,
103 B.getInt32(0), B.getInt32(0)};
104 auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
105 GR->buildAssignPtr(B, ElementType, GEP);
106
107 LoadInst *LI = B.CreateLoad(ElementType, GEP);
108 LI->setAlignment(BadLoad->getAlign());
109 buildAssignType(B, ElementType, LI);
110 return LI;
111 }
112
113 // Replaces the load instruction to get rid of the ptrcast used as source
114 // operand.
115 void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
116 Value *OriginalOperand) {
117 Type *FromTy = GR->findDeducedElementType(OriginalOperand);
118 Type *ToTy = GR->findDeducedElementType(CastedOperand);
119 Value *Output = nullptr;
120
121 auto *SAT = dyn_cast<ArrayType>(FromTy);
122 auto *SVT = dyn_cast<FixedVectorType>(FromTy);
123 auto *SST = dyn_cast<StructType>(FromTy);
124 auto *DVT = dyn_cast<FixedVectorType>(ToTy);
125
126 B.SetInsertPoint(LI);
127
128 // Destination is the element type of Source, and source is an array ->
129 // Loading 1st element.
130 // - float a = array[0];
131 if (SAT && SAT->getElementType() == ToTy)
132 Output = loadFirstValueFromAggregate(B, SAT->getElementType(),
133 OriginalOperand, LI);
134 // Destination is the element type of Source, and source is a vector ->
135 // Vector to scalar.
136 // - float a = vector.x;
137 else if (!DVT && SVT && SVT->getElementType() == ToTy) {
138 Output = loadFirstValueFromAggregate(B, SVT->getElementType(),
139 OriginalOperand, LI);
140 }
141 // Destination is a smaller vector than source or different vector type.
142 // - float3 v3 = vector4;
143 // - float4 v2 = int4;
144 else if (SVT && DVT)
145 Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand);
146 // Destination is the scalar type stored at the start of an aggregate.
147 // - struct S { float m };
148 // - float v = s.m;
149 else if (SST && SST->getTypeAtIndex(0u) == ToTy)
150 Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);
151 else
152 llvm_unreachable("Unimplemented implicit down-cast from load.");
153
154 GR->replaceAllUsesWith(LI, Output, /* DeleteOld= */ true);
155 DeadInstructions.push_back(LI);
156 }
157
158 // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
159 Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
160 unsigned Index) {
161 Type *Int32Ty = Type::getInt32Ty(B.getContext());
162 SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
163 Element->getType(), Int32Ty};
164 SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
165 Instruction *NewI =
166 B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
167 buildAssignType(B, Vector->getType(), NewI);
168 return NewI;
169 }
170
171 // Creates an spv_extractelt instruction (equivalent to llvm's
172 // extractelement).
173 Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
174 unsigned Index) {
175 Type *Int32Ty = Type::getInt32Ty(B.getContext());
176 SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
177 SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
178 Instruction *NewI =
179 B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
180 buildAssignType(B, ElementType, NewI);
181 return NewI;
182 }
183
184 // Stores the given Src vector operand into the Dst vector, adjusting the size
185 // if required.
186 Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
187 Align Alignment) {
188 FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
189 FixedVectorType *DstType =
190 cast<FixedVectorType>(GR->findDeducedElementType(Dst));
191 assert(DstType->getNumElements() >= SrcType->getNumElements());
192
193 LoadInst *LI = B.CreateLoad(DstType, Dst);
194 LI->setAlignment(Alignment);
195 Value *OldValues = LI;
196 buildAssignType(B, OldValues->getType(), OldValues);
197 Value *NewValues = Src;
198
199 for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
200 Value *Element =
201 makeExtractElement(B, SrcType->getElementType(), NewValues, I);
202 OldValues = makeInsertElement(B, OldValues, Element, I);
203 }
204
205 StoreInst *SI = B.CreateStore(OldValues, Dst);
206 SI->setAlignment(Alignment);
207 return SI;
208 }
209
210 void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
211 SmallVectorImpl<Value *> &Indices) {
212 Indices.push_back(B.getInt32(0));
213
214 if (Search == Aggregate)
215 return;
216
217 if (auto *ST = dyn_cast<StructType>(Aggregate))
218 buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);
219 else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
220 buildGEPIndexChain(B, Search, AT->getElementType(), Indices);
221 else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
222 buildGEPIndexChain(B, Search, VT->getElementType(), Indices);
223 else
224 llvm_unreachable("Bad access chain?");
225 }
226
227 // Stores the given Src value into the first entry of the Dst aggregate.
228 Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
229 Type *DstPointeeType, Align Alignment) {
230 SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
231 SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
232 buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
233 auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
234 GR->buildAssignPtr(B, Src->getType(), GEP);
235 StoreInst *SI = B.CreateStore(Src, GEP);
236 SI->setAlignment(Alignment);
237 return SI;
238 }
239
240 bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
241 if (Search == Aggregate)
242 return true;
243 if (auto *ST = dyn_cast<StructType>(Aggregate))
244 return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));
245 if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
246 return isTypeFirstElementAggregate(Search, VT->getElementType());
247 if (auto *AT = dyn_cast<ArrayType>(Aggregate))
248 return isTypeFirstElementAggregate(Search, AT->getElementType());
249 return false;
250 }
251
252 // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
253 // operand into a valid logical SPIR-V store with no ptrcast.
254 void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
255 Value *Dst, Align Alignment) {
256 Type *ToTy = GR->findDeducedElementType(Dst);
257 Type *FromTy = Src->getType();
258
259 auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
260 auto *D_ST = dyn_cast<StructType>(ToTy);
261 auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
262
263 B.SetInsertPoint(BadStore);
264 if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
265 storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);
266 else if (D_VT && S_VT)
267 storeVectorFromVector(B, Src, Dst, Alignment);
268 else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
269 storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
270 else
271 llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
272
273 DeadInstructions.push_back(BadStore);
274 }
275
276 void legalizePointerCast(IntrinsicInst *II) {
277 Value *CastedOperand = II;
278 Value *OriginalOperand = II->getOperand(0);
279
280 IRBuilder<> B(II->getContext());
281 std::vector<Value *> Users;
282 for (Use &U : II->uses())
283 Users.push_back(U.getUser());
284
285 for (Value *User : Users) {
286 if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
287 transformLoad(B, LI, CastedOperand, OriginalOperand);
288 continue;
289 }
290
291 if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
292 transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
293 SI->getAlign());
294 continue;
295 }
296
297 if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
298 if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
299 DeadInstructions.push_back(Intrin);
300 continue;
301 }
302
303 if (Intrin->getIntrinsicID() == Intrinsic::spv_gep) {
304 GR->replaceAllUsesWith(CastedOperand, OriginalOperand,
305 /* DeleteOld= */ false);
306 continue;
307 }
308
309 if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
310 Align Alignment;
311 if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
312 Alignment = Align(C->getZExtValue());
313 transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
314 Alignment);
315 continue;
316 }
317 }
318
319 llvm_unreachable("Unsupported ptrcast user. Please fix.");
320 }
321
322 DeadInstructions.push_back(II);
323 }
324
325public:
326 SPIRVLegalizePointerCast(SPIRVTargetMachine *TM) : FunctionPass(ID), TM(TM) {}
327
328 virtual bool runOnFunction(Function &F) override {
329 const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
330 GR = ST.getSPIRVGlobalRegistry();
331 DeadInstructions.clear();
332
333 std::vector<IntrinsicInst *> WorkList;
334 for (auto &BB : F) {
335 for (auto &I : BB) {
336 auto *II = dyn_cast<IntrinsicInst>(&I);
337 if (II && II->getIntrinsicID() == Intrinsic::spv_ptrcast)
338 WorkList.push_back(II);
339 }
340 }
341
342 for (IntrinsicInst *II : WorkList)
343 legalizePointerCast(II);
344
345 for (Instruction *I : DeadInstructions)
346 I->eraseFromParent();
347
348 return DeadInstructions.size() != 0;
349 }
350
351private:
352 SPIRVTargetMachine *TM = nullptr;
353 SPIRVGlobalRegistry *GR = nullptr;
354 std::vector<Instruction *> DeadInstructions;
355
356public:
357 static char ID;
358};
359} // namespace
360
361char SPIRVLegalizePointerCast::ID = 0;
362INITIALIZE_PASS(SPIRVLegalizePointerCast, "spirv-legalize-bitcast",
363 "SPIRV legalize bitcast pass", false, false)
364
366 return new SPIRVLegalizePointerCast(TM);
367}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Hexagon Common GEP
iv Induction Variable Users
Definition: IVUsers.cpp:48
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:56
This class represents a function call, abstracting a target machine's calling convention.
This is the shared class of boolean and integer constants.
Definition: Constants.h:87
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:592
unsigned getNumElements() const
Definition: DerivedTypes.h:635
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:314
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2780
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:49
An instruction for reading from memory.
Definition: Instructions.h:180
void setAlignment(Align Align)
Definition: Instructions.h:219
Type * getPointerOperandType() const
Definition: Instructions.h:262
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:215
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1885
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:574
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
An instruction for storing to memory.
Definition: Instructions.h:296
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
A Use represents the edge between a Value definition and its users.
Definition: Use.h:35
LLVM Value Representation.
Definition: Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
Type * getElementType() const
Definition: DerivedTypes.h:463
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:126
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
ElementType
The element type of an SRV or UAV resource.
Definition: DXILABI.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
CallInst * buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef< Type * > Types, Value *Arg, Value *Arg2, ArrayRef< Constant * > Imms, IRBuilder<> &B)
Definition: SPIRVUtils.cpp:821
@ SAT
Definition: SPIRVUtils.h:476
FunctionPass * createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM)
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39