LLVM 20.0.0git
DXILDataScalarization.cpp
Go to the documentation of this file.
1//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8
10#include "DirectX.h"
12#include "llvm/ADT/STLExtras.h"
14#include "llvm/IR/IRBuilder.h"
15#include "llvm/IR/InstVisitor.h"
16#include "llvm/IR/Module.h"
17#include "llvm/IR/Operator.h"
18#include "llvm/IR/PassManager.h"
20#include "llvm/IR/Type.h"
23
24#define DEBUG_TYPE "dxil-data-scalarization"
25static const int MaxVecSize = 4;
26
27using namespace llvm;
28
30
31public:
32 bool runOnModule(Module &M) override;
34
35 static char ID; // Pass identification.
36};
37
38static bool findAndReplaceVectors(Module &M);
39
40class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
41public:
42 DataScalarizerVisitor() : GlobalMap() {}
43 bool visit(Instruction &I);
44 // InstVisitor methods. They return true if the instruction was scalarized,
45 // false if nothing changed.
46 bool visitInstruction(Instruction &I) { return false; }
47 bool visitSelectInst(SelectInst &SI) { return false; }
48 bool visitICmpInst(ICmpInst &ICI) { return false; }
49 bool visitFCmpInst(FCmpInst &FCI) { return false; }
50 bool visitUnaryOperator(UnaryOperator &UO) { return false; }
51 bool visitBinaryOperator(BinaryOperator &BO) { return false; }
53 bool visitCastInst(CastInst &CI) { return false; }
54 bool visitBitCastInst(BitCastInst &BCI) { return false; }
55 bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
56 bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
57 bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
58 bool visitPHINode(PHINode &PHI) { return false; }
59 bool visitLoadInst(LoadInst &LI);
60 bool visitStoreInst(StoreInst &SI);
61 bool visitCallInst(CallInst &ICI) { return false; }
62 bool visitFreezeInst(FreezeInst &FI) { return false; }
63 friend bool findAndReplaceVectors(llvm::Module &M);
64
65private:
66 GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
68};
69
71 assert(!GlobalMap.empty());
72 return InstVisitor::visit(I);
73}
74
76DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
77 if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
78 auto It = GlobalMap.find(OldGlobal);
79 if (It != GlobalMap.end()) {
80 return It->second; // Found, return the new global
81 }
82 }
83 return nullptr; // Not found
84}
85
87 unsigned NumOperands = LI.getNumOperands();
88 for (unsigned I = 0; I < NumOperands; ++I) {
89 Value *CurrOpperand = LI.getOperand(I);
90 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
91 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
92 GetElementPtrInst *OldGEP =
93 cast<GetElementPtrInst>(CE->getAsInstruction());
94 OldGEP->insertBefore(&LI);
95 IRBuilder<> Builder(&LI);
96 LoadInst *NewLoad =
97 Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
98 NewLoad->setAlignment(LI.getAlign());
99 LI.replaceAllUsesWith(NewLoad);
100 LI.eraseFromParent();
101 visitGetElementPtrInst(*OldGEP);
102 return true;
103 }
104 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
105 LI.setOperand(I, NewGlobal);
106 }
107 return false;
108}
109
111 unsigned NumOperands = SI.getNumOperands();
112 for (unsigned I = 0; I < NumOperands; ++I) {
113 Value *CurrOpperand = SI.getOperand(I);
114 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
115 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
116 GetElementPtrInst *OldGEP =
117 cast<GetElementPtrInst>(CE->getAsInstruction());
118 OldGEP->insertBefore(&SI);
119 IRBuilder<> Builder(&SI);
120 StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
121 NewStore->setAlignment(SI.getAlign());
122 SI.replaceAllUsesWith(NewStore);
123 SI.eraseFromParent();
124 visitGetElementPtrInst(*OldGEP);
125 return true;
126 }
127 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
128 SI.setOperand(I, NewGlobal);
129 }
130 return false;
131}
132
134
135 unsigned NumOperands = GEPI.getNumOperands();
136 GlobalVariable *NewGlobal = nullptr;
137 for (unsigned I = 0; I < NumOperands; ++I) {
138 Value *CurrOpperand = GEPI.getOperand(I);
139 NewGlobal = lookupReplacementGlobal(CurrOpperand);
140 if (NewGlobal)
141 break;
142 }
143 if (!NewGlobal)
144 return false;
145
146 IRBuilder<> Builder(&GEPI);
148 for (auto &Index : GEPI.indices())
149 Indices.push_back(Index);
150
151 Value *NewGEP =
152 Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
153 GEPI.getName(), GEPI.getNoWrapFlags());
154 GEPI.replaceAllUsesWith(NewGEP);
155 GEPI.eraseFromParent();
156 return true;
157}
158
159// Recursively Creates and Array like version of the given vector like type.
161 if (auto *VecTy = dyn_cast<VectorType>(T))
162 return ArrayType::get(VecTy->getElementType(),
163 dyn_cast<FixedVectorType>(VecTy)->getNumElements());
164 if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
165 Type *NewElementType =
166 replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
167 return ArrayType::get(NewElementType, ArrayTy->getNumElements());
168 }
169 // If it's not a vector or array, return the original type.
170 return T;
171}
172
174 LLVMContext &Ctx) {
175 // Handle ConstantAggregateZero (zero-initialized constants)
176 if (isa<ConstantAggregateZero>(Init)) {
177 return ConstantAggregateZero::get(NewType);
178 }
179
180 // Handle UndefValue (undefined constants)
181 if (isa<UndefValue>(Init)) {
182 return UndefValue::get(NewType);
183 }
184
185 // Handle vector to array transformation
186 if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
187 // Convert vector initializer to array initializer
189 if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
190 for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
191 ArrayElements.push_back(ConstVecInit->getOperand(I));
192 } else if (ConstantDataVector *ConstDataVecInit =
193 llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
194 for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
195 ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
196 } else {
197 assert(false && "Expected a ConstantVector or ConstantDataVector for "
198 "vector initializer!");
199 }
200
201 return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
202 }
203
204 // Handle array of vectors transformation
205 if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
206 auto *ArrayInit = dyn_cast<ConstantArray>(Init);
207 assert(ArrayInit && "Expected a ConstantArray for array initializer!");
208
210 for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
211 // Recursively transform array elements
212 Constant *NewElemInit = transformInitializer(
213 ArrayInit->getOperand(I), ArrayTy->getElementType(),
214 cast<ArrayType>(NewType)->getElementType(), Ctx);
215 NewArrayElements.push_back(NewElemInit);
216 }
217
218 return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);
219 }
220
221 // If not a vector or array, return the original initializer
222 return Init;
223}
224
226 bool MadeChange = false;
227 LLVMContext &Ctx = M.getContext();
228 IRBuilder<> Builder(Ctx);
230 for (GlobalVariable &G : M.globals()) {
231 Type *OrigType = G.getValueType();
232
233 Type *NewType = replaceVectorWithArray(OrigType, Ctx);
234 if (OrigType != NewType) {
235 // Create a new global variable with the updated type
236 // Note: Initializer is set via transformInitializer
237 GlobalVariable *NewGlobal = new GlobalVariable(
238 M, NewType, G.isConstant(), G.getLinkage(),
239 /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
240 G.getThreadLocalMode(), G.getAddressSpace(),
241 G.isExternallyInitialized());
242
243 // Copy relevant attributes
244 NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
245 if (G.getAlignment() > 0) {
246 NewGlobal->setAlignment(G.getAlign());
247 }
248
249 if (G.hasInitializer()) {
250 Constant *Init = G.getInitializer();
251 Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);
252 NewGlobal->setInitializer(NewInit);
253 }
254
255 // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
256 // type equality. Instead we will use the visitor pattern.
257 Impl.GlobalMap[&G] = NewGlobal;
258 for (User *U : make_early_inc_range(G.users())) {
259 if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
260 ConstantExpr *CE = cast<ConstantExpr>(U);
261 for (User *UCE : make_early_inc_range(CE->users())) {
262 if (Instruction *Inst = dyn_cast<Instruction>(UCE))
263 Impl.visit(*Inst);
264 }
265 }
266 if (Instruction *Inst = dyn_cast<Instruction>(U))
267 Impl.visit(*Inst);
268 }
269 }
270 }
271
272 // Remove the old globals after the iteration
273 for (auto &[Old, New] : Impl.GlobalMap) {
274 Old->eraseFromParent();
275 MadeChange = true;
276 }
277 return MadeChange;
278}
279
// Module-level entry point. Its signature is not visible in this listing;
// presumably the new-pass-manager run(Module &, ModuleAnalysisManager &).
// Scalarizes the module's globals. When nothing changed, all analyses are
// reported as preserved.
282 bool MadeChanges = findAndReplaceVectors(M);
283 if (!MadeChanges)
284 return PreservedAnalyses::all();
// NOTE(review): the declaration of PA is not visible in this listing;
// confirm which analyses it preserves after a change.
286 return PA;
287}
288
290 return findAndReplaceVectors(M);
291}
292
294
// NOTE(review): the lines around the next argument list are not visible in
// this listing. It appears to be the tail of the legacy pass registration
// (INITIALIZE_PASS_BEGIN/END) naming the pass "DXIL Data Scalarization".
296 "DXIL Data Scalarization", false, false)
299
// Factory for the legacy pass: heap-allocates a DXILDataScalarizationLegacy.
// NOTE(review): the function header is not visible here; presumably
// createDXILDataScalarizationLegacyPass() — confirm.
301 return new DXILDataScalarizationLegacy();
302}
Rewrite undef for PHI
DXIL Data Scalarization
static bool findAndReplaceVectors(Module &M)
static const int MaxVecSize
Constant * transformInitializer(Constant *Init, Type *OrigType, Type *NewType, LLVMContext &Ctx)
static Type * replaceVectorWithArray(Type *T, LLVMContext &Ctx)
#define DEBUG_TYPE
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define I(x, y, z)
Definition: MD5.cpp:58
#define G(x, y, z)
Definition: MD5.cpp:56
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file contains some templates that are useful if you are working with the STL at all.
bool runOnModule(Module &M) override
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
bool visitCallInst(CallInst &ICI)
bool visitInstruction(Instruction &I)
bool visitBinaryOperator(BinaryOperator &BO)
bool visitInsertElementInst(InsertElementInst &IEI)
bool visitGetElementPtrInst(GetElementPtrInst &GEPI)
bool visitStoreInst(StoreInst &SI)
bool visitFreezeInst(FreezeInst &FI)
bool visitBitCastInst(BitCastInst &BCI)
bool visitCastInst(CastInst &CI)
bool visitFCmpInst(FCmpInst &FCI)
bool visitLoadInst(LoadInst &LI)
bool visitSelectInst(SelectInst &SI)
bool visitUnaryOperator(UnaryOperator &UO)
bool visitICmpInst(ICmpInst &ICI)
bool visitShuffleVectorInst(ShuffleVectorInst &SVI)
bool visit(Instruction &I)
bool visitPHINode(PHINode &PHI)
bool visitExtractElementInst(ExtractElementInst &EEI)
friend bool findAndReplaceVectors(llvm::Module &M)
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
This class represents a no-op cast from one type to another.
This class represents a function call, abstracting a target machine's calling convention.
This is the base class for all instructions that perform data casts.
Definition: InstrTypes.h:444
static ConstantAggregateZero * get(Type *Ty)
Definition: Constants.cpp:1672
static Constant * get(ArrayType *T, ArrayRef< Constant * > V)
Definition: Constants.cpp:1312
A vector constant whose element type is a simple 1/2/4/8-byte integer or float/double,...
Definition: Constants.h:770
A constant value that is initialized with an expression using other constant values.
Definition: Constants.h:1108
Constant Vector Declarations.
Definition: Constants.h:511
This is an important base class in LLVM.
Definition: Constant.h:42
PreservedAnalyses run(Module &M, ModuleAnalysisManager &)
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:156
bool empty() const
Definition: DenseMap.h:98
iterator end()
Definition: DenseMap.h:84
This instruction extracts a single (scalar) element from a VectorType value.
This instruction compares its operands according to the predicate given to the constructor.
This class represents a freeze function that returns random concrete value if an operand is either a ...
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Definition: Instructions.h:933
iterator_range< op_iterator > indices()
GEPNoWrapFlags getNoWrapFlags() const
Get the nowrap flags for the GEP instruction.
void setAlignment(Align Align)
Sets the alignment attribute of the GlobalObject.
Definition: Globals.cpp:143
void setUnnamedAddr(UnnamedAddr Val)
Definition: GlobalValue.h:231
Type * getValueType() const
Definition: GlobalValue.h:296
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition: Globals.cpp:492
This instruction compares its operands according to the predicate given to the constructor.
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Definition: IRBuilder.h:1889
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1813
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition: IRBuilder.h:1826
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
This instruction inserts a single (scalar) element into a VectorType value.
Base class for instruction visitors.
Definition: InstVisitor.h:78
void visit(Iterator Start, Iterator End)
Definition: InstVisitor.h:87
void insertBefore(Instruction *InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified instruction.
Definition: Instruction.cpp:99
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:94
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
An instruction for reading from memory.
Definition: Instructions.h:176
void setAlignment(Align Align)
Definition: Instructions.h:215
Align getAlign() const
Return the alignment of the access that is being performed.
Definition: Instructions.h:211
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
This class represents the LLVM 'select' instruction.
This instruction constructs a fixed permutation of two input vectors.
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
void setAlignment(Align Align)
Definition: Instructions.h:337
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1859
void setOperand(unsigned i, Value *Val)
Definition: User.h:233
Value * getOperand(unsigned i) const
Definition: User.h:228
unsigned getNumOperands() const
Definition: User.h:250
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
ModulePass * createDXILDataScalarizationLegacyPass()
Pass to scalarize llvm global data into a DXIL legal form.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:657