29#define DEBUG_TYPE "dxil-flatten-arrays"
34class DXILFlattenArraysLegacy :
public ModulePass {
48 bool AllIndicesAreConstInt;
51class DXILFlattenArraysVisitor
52 :
public InstVisitor<DXILFlattenArraysVisitor, bool> {
54 DXILFlattenArraysVisitor() {}
76 static bool isMultiDimensionalArray(
Type *
T);
77 static std::pair<unsigned, Type *> getElementCountAndType(
Type *ArrayTy);
92 unsigned &GEPChainUseCount,
95 bool AllIndicesAreConstInt =
true);
97 bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
102bool DXILFlattenArraysVisitor::finish() {
107bool DXILFlattenArraysVisitor::isMultiDimensionalArray(
Type *
T) {
108 if (
ArrayType *ArrType = dyn_cast<ArrayType>(
T))
109 return isa<ArrayType>(ArrType->getElementType());
113std::pair<unsigned, Type *>
114DXILFlattenArraysVisitor::getElementCountAndType(
Type *ArrayTy) {
115 unsigned TotalElements = 1;
116 Type *CurrArrayTy = ArrayTy;
117 while (
auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
118 TotalElements *= InnerArrayTy->getNumElements();
119 CurrArrayTy = InnerArrayTy->getElementType();
121 return std::make_pair(TotalElements, CurrArrayTy);
124ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
127 "Indicies and dimmensions should be the same");
128 unsigned FlatIndex = 0;
129 unsigned Multiplier = 1;
131 for (
int I = Indices.
size() - 1;
I >= 0; --
I) {
132 unsigned DimSize = Dims[
I];
133 ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[
I]);
134 assert(CIndex &&
"This function expects all indicies to be ConstantInt");
136 Multiplier *= DimSize;
141Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
143 if (Indices.
size() == 1)
147 unsigned Multiplier = 1;
149 for (
int I = Indices.
size() - 1;
I >= 0; --
I) {
150 unsigned DimSize = Dims[
I];
153 FlatIndex = Builder.
CreateAdd(FlatIndex, ScaledIndex);
154 Multiplier *= DimSize;
159bool DXILFlattenArraysVisitor::visitLoadInst(
LoadInst &LI) {
161 for (
unsigned I = 0;
I < NumOperands; ++
I) {
164 if (CE &&
CE->getOpcode() == Instruction::GetElementPtr) {
166 cast<GetElementPtrInst>(
CE->getAsInstruction());
175 visitGetElementPtrInst(*OldGEP);
182bool DXILFlattenArraysVisitor::visitStoreInst(
StoreInst &SI) {
183 unsigned NumOperands =
SI.getNumOperands();
184 for (
unsigned I = 0;
I < NumOperands; ++
I) {
185 Value *CurrOpperand =
SI.getOperand(
I);
187 if (CE &&
CE->getOpcode() == Instruction::GetElementPtr) {
189 cast<GetElementPtrInst>(
CE->getAsInstruction());
195 SI.replaceAllUsesWith(NewStore);
196 SI.eraseFromParent();
197 visitGetElementPtrInst(*OldGEP);
204bool DXILFlattenArraysVisitor::visitAllocaInst(
AllocaInst &AI) {
210 auto [TotalElements,
BaseType] = getElementCountAndType(ArrType);
221void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
226 AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
232 if (!IsMultiDimArr) {
236 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
237 std::move(Dims), AllIndicesAreConstInt}});
240 bool GepUses =
false;
243 recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
244 ++GEPChainUseCount, Indices, Dims,
245 AllIndicesAreConstInt);
250 if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
253 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
254 std::move(Dims), AllIndicesAreConstInt}});
258bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
260 GEPData GEPInfo = GEPChainMap.at(&
GEP);
261 return visitGetElementPtrInstInGEPChainBase(GEPInfo,
GEP);
263bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
267 if (GEPInfo.AllIndicesAreConstInt)
268 FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
271 genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
273 ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
275 Builder.
CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
276 GEP.getName() +
".flat",
GEP.isInBounds());
278 GEP.replaceAllUsesWith(FlatGEP);
279 GEP.eraseFromParent();
284 auto It = GEPChainMap.find(&
GEP);
285 if (It != GEPChainMap.end())
286 return visitGetElementPtrInstInGEPChain(
GEP);
287 if (!isMultiDimensionalArray(
GEP.getSourceElementType()))
290 ArrayType *ArrType = cast<ArrayType>(
GEP.getSourceElementType());
292 auto [TotalElements,
BaseType] = getElementCountAndType(ArrType);
295 Value *PtrOperand =
GEP.getPointerOperand();
297 unsigned GEPChainUseCount = 0;
298 recursivelyCollectGEPs(
GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
304 if (GEPChainUseCount == 0) {
307 bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
308 GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
309 std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
310 return visitGetElementPtrInstInGEPChainBase(GEPInfo,
GEP);
313 PotentiallyDeadInstrs.emplace_back(&
GEP);
317bool DXILFlattenArraysVisitor::visit(
Function &
F) {
318 bool MadeChange =
false;
331 auto *ArrayTy = dyn_cast<ArrayType>(
Init->getType());
333 Elements.push_back(
Init);
336 unsigned ArrSize = ArrayTy->getNumElements();
337 if (isa<ConstantAggregateZero>(
Init)) {
338 for (
unsigned I = 0;
I < ArrSize; ++
I)
344 if (
auto *ArrayConstant = dyn_cast<ConstantArray>(
Init)) {
345 for (
unsigned I = 0;
I < ArrayConstant->getNumOperands(); ++
I) {
348 }
else if (
auto *DataArrayConstant = dyn_cast<ConstantDataArray>(
Init)) {
349 for (
unsigned I = 0;
I < DataArrayConstant->getNumElements(); ++
I) {
354 "Expected a ConstantArray or ConstantDataArray for array initializer!");
362 if (isa<ConstantAggregateZero>(
Init))
366 if (isa<UndefValue>(
Init))
369 if (!isa<ArrayType>(OrigType))
374 assert(FlattenedType->getNumElements() == FlattenedElements.
size() &&
375 "The number of collected elements should match the FlattenedType");
384 Type *OrigType =
G.getValueType();
385 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
388 ArrayType *ArrType = cast<ArrayType>(OrigType);
390 DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
397 nullptr,
G.getName() +
".1dim", &
G,
398 G.getThreadLocalMode(),
G.getAddressSpace(),
399 G.isExternallyInitialized());
403 if (
G.getAlignment() > 0) {
407 if (
G.hasInitializer()) {
413 GlobalMap[&
G] = NewGlobal;
418 bool MadeChange =
false;
419 DXILFlattenArraysVisitor Impl;
423 if (
F.isDeclaration())
425 MadeChange |= Impl.visit(
F);
427 for (
auto &[Old, New] : GlobalMap) {
428 Old->replaceAllUsesWith(New);
429 Old->eraseFromParent();
443bool DXILFlattenArraysLegacy::runOnModule(
Module &M) {
447char DXILFlattenArraysLegacy::ID = 0;
450 "DXIL Array Flattener",
false,
false)
455 return new DXILFlattenArraysLegacy();
static void collectElements(Constant *Init, SmallVectorImpl< Constant * > &Elements)
static bool flattenArrays(Module &M)
static void flattenGlobalArrays(Module &M, DenseMap< GlobalVariable *, GlobalVariable * > &GlobalMap)
static Constant * transformInitializer(Constant *Init, Type *OrigType, ArrayType *FlattenedType, LLVMContext &Ctx)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static unsigned getNumElements(Type *Ty)
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Type * getAllocatedType() const
Return the type that is being allocated by the instruction.
void setAlignment(Align Align)
A container for analyses that lazily runs them and caches their results.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
LLVM Basic Block Representation.
This class represents a no-op cast from one type to another.
This class represents a function call, abstracting a target machine's calling convention.
This is the base class for all instructions that perform data casts.
static ConstantAggregateZero * get(Type *Ty)
static Constant * get(ArrayType *T, ArrayRef< Constant * > V)
A constant value that is initialized with an expression using other constant values.
This is the shared class of boolean and integer constants.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This is an important base class in LLVM.
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &)
This instruction compares its operands according to the predicate given to the constructor.
This class represents a freeze function that returns random concrete value if an operand is either a ...
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Type * getSourceElementType() const
void setAlignment(Align Align)
Sets the alignment attribute of the GlobalObject.
void setUnnamedAddr(UnnamedAddr Val)
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
This instruction compares its operands according to the predicate given to the constructor.
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
This instruction inserts a single (scalar) element into a VectorType value.
Base class for instruction visitors.
RetTy visitFreezeInst(FreezeInst &I)
RetTy visitFCmpInst(FCmpInst &I)
RetTy visitExtractElementInst(ExtractElementInst &I)
RetTy visitShuffleVectorInst(ShuffleVectorInst &I)
RetTy visitBitCastInst(BitCastInst &I)
void visit(Iterator Start, Iterator End)
RetTy visitPHINode(PHINode &I)
RetTy visitUnaryOperator(UnaryOperator &I)
RetTy visitStoreInst(StoreInst &I)
RetTy visitInsertElementInst(InsertElementInst &I)
RetTy visitAllocaInst(AllocaInst &I)
RetTy visitBinaryOperator(BinaryOperator &I)
RetTy visitICmpInst(ICmpInst &I)
RetTy visitCallInst(CallInst &I)
RetTy visitCastInst(CastInst &I)
RetTy visitSelectInst(SelectInst &I)
RetTy visitGetElementPtrInst(GetElementPtrInst &I)
void visitInstruction(Instruction &I)
RetTy visitLoadInst(LoadInst &I)
void insertBefore(Instruction *InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified instruction.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
This is an important class for using LLVM in a threaded context.
An instruction for reading from memory.
void setAlignment(Align Align)
Align getAlign() const
Return the alignment of the access that is being performed.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class represents the LLVM 'select' instruction.
This instruction constructs a fixed permutation of two input vectors.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
void setAlignment(Align Align)
The instances of the Type class are immutable: once they are created, they are never changed.
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Value * getOperand(unsigned i) const
unsigned getNumOperands() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< user_iterator > users()
StringRef getName() const
Return a constant reference to the value's name.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ CE
Windows NT (Windows on ARM)
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createDXILFlattenArraysLegacyPass()
Pass to flatten arrays into a one dimensional DXIL legal form.
bool RecursivelyDeleteTriviallyDeadInstructionsPermissive(SmallVectorImpl< WeakTrackingVH > &DeadInsts, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
Same functionality as RecursivelyDeleteTriviallyDeadInstructions, but allow instructions that are not...