30#define DEBUG_TYPE "dxil-flatten-arrays" 
   35class DXILFlattenArraysLegacy : 
public ModulePass {
 
   38  bool runOnModule(
Module &M) 
override;
 
   46  Value *RootPointerOperand;
 
   51class DXILFlattenArraysVisitor
 
   52    : 
public InstVisitor<DXILFlattenArraysVisitor, bool> {
 
   54  DXILFlattenArraysVisitor(
 
   56      : GlobalMap(GlobalMap) {}
 
   64  bool visitICmpInst(
ICmpInst &ICI) { 
return false; }
 
   65  bool visitFCmpInst(
FCmpInst &FCI) { 
return false; }
 
   68  bool visitCastInst(
CastInst &CI) { 
return false; }
 
   69  bool visitBitCastInst(
BitCastInst &BCI) { 
return false; }
 
   73  bool visitPHINode(
PHINode &
PHI) { 
return false; }
 
   76  bool visitCallInst(
CallInst &ICI) { 
return false; }
 
   77  bool visitFreezeInst(
FreezeInst &FI) { 
return false; }
 
   78  static bool isMultiDimensionalArray(
Type *
T);
 
   79  static std::pair<unsigned, Type *> getElementCountAndType(
Type *ArrayTy);
 
   95bool DXILFlattenArraysVisitor::finish() {
 
   96  GEPChainInfoMap.clear();
 
  101bool DXILFlattenArraysVisitor::isMultiDimensionalArray(
Type *
T) {
 
  107std::pair<unsigned, Type *>
 
  108DXILFlattenArraysVisitor::getElementCountAndType(
Type *ArrayTy) {
 
  109  unsigned TotalElements = 1;
 
  110  Type *CurrArrayTy = ArrayTy;
 
  112    TotalElements *= InnerArrayTy->getNumElements();
 
  113    CurrArrayTy = InnerArrayTy->getElementType();
 
  115  return std::make_pair(TotalElements, CurrArrayTy);
 
  118ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
 
  121         "Indicies and dimmensions should be the same");
 
  122  unsigned FlatIndex = 0;
 
  123  unsigned Multiplier = 1;
 
  125  for (
int I = Indices.
size() - 1; 
I >= 0; --
I) {
 
  126    unsigned DimSize = Dims[
I];
 
  128    assert(CIndex && 
"This function expects all indicies to be ConstantInt");
 
  130    Multiplier *= DimSize;
 
  135Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
 
  137  if (Indices.
size() == 1)
 
  141  unsigned Multiplier = 1;
 
  143  for (
int I = Indices.
size() - 1; 
I >= 0; --
I) {
 
  144    unsigned DimSize = Dims[
I];
 
  147    FlatIndex = Builder.
CreateAdd(FlatIndex, ScaledIndex);
 
  148    Multiplier *= DimSize;
 
  153bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
 
  155  for (
unsigned I = 0; 
I < NumOperands; ++
I) {
 
  158    if (CE && 
CE->getOpcode() == Instruction::GetElementPtr) {
 
  159      GetElementPtrInst *OldGEP =
 
  169      visitGetElementPtrInst(*OldGEP);
 
  176bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
 
  177  unsigned NumOperands = 
SI.getNumOperands();
 
  178  for (
unsigned I = 0; 
I < NumOperands; ++
I) {
 
  179    Value *CurrOpperand = 
SI.getOperand(
I);
 
  181    if (CE && 
CE->getOpcode() == Instruction::GetElementPtr) {
 
  182      GetElementPtrInst *OldGEP =
 
  187      StoreInst *NewStore = Builder.
CreateStore(
SI.getValueOperand(), OldGEP);
 
  189      SI.replaceAllUsesWith(NewStore);
 
  190      SI.eraseFromParent();
 
  191      visitGetElementPtrInst(*OldGEP);
 
  198bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
 
  204  auto [TotalElements, 
BaseType] = getElementCountAndType(ArrType);
 
  207  AllocaInst *FlatAlloca =
 
  215bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &
GEP) {
 
  220  Value *PtrOperand = 
GEP.getPointerOperand();
 
  225         "Pointer operand of GEP should not be a PHI Node");
 
  230      PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
 
  231    GetElementPtrInst *OldGEPI =
 
  238        Builder.
CreateGEP(
GEP.getSourceElementType(), OldGEPI, Indices,
 
  239                          GEP.getName(), 
GEP.getNoWrapFlags());
 
  241           "Expected newly-created GEP to be an instruction");
 
  244    GEP.replaceAllUsesWith(NewGEPI);
 
  245    GEP.eraseFromParent();
 
  246    visitGetElementPtrInst(*OldGEPI);
 
  247    visitGetElementPtrInst(*NewGEPI);
 
  255  const DataLayout &
DL = 
GEP.getDataLayout();
 
  256  unsigned BitWidth = 
DL.getIndexTypeSizeInBits(
GEP.getType());
 
  258  [[maybe_unused]] 
bool Success = 
GEP.collectOffset(
 
  270    if (!GEPChainInfoMap.contains(PtrOpGEP))
 
  273    GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
 
  274    Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
 
  275    Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
 
  276    for (
auto &VariableOffset : PGEPInfo.VariableOffsets)
 
  277      Info.VariableOffsets.insert(VariableOffset);
 
  278    Info.ConstantOffset += PGEPInfo.ConstantOffset;
 
  280    Info.RootPointerOperand = PtrOperand;
 
  285    Type *RootTy = 
GEP.getSourceElementType();
 
  287      if (GlobalMap.contains(GlobalVar))
 
  292      RootTy = Alloca->getAllocatedType();
 
  293    assert(!isMultiDimensionalArray(RootTy) &&
 
  294           "Expected root array type to be flattened");
 
  306  bool ReplaceThisGEP = 
GEP.users().empty();
 
  309      ReplaceThisGEP = 
true;
 
  311  if (ReplaceThisGEP) {
 
  312    unsigned BytesPerElem =
 
  313        DL.getTypeAllocSize(
Info.RootFlattenedArrayType->getArrayElementType());
 
  315           "Bytes per element should be a power of 2");
 
  321    uint64_t ConstantOffset =
 
  323    assert(ConstantOffset < UINT32_MAX &&
 
  324           "Constant byte offset for flat GEP index must fit within 32 bits");
 
  326    for (
auto [VarIndex, Multiplier] : 
Info.VariableOffsets) {
 
  327      assert(Multiplier.getActiveBits() <= 32 &&
 
  328             "The multiplier for a flat GEP index must fit within 32 bits");
 
  329      assert(VarIndex->getType()->isIntegerTy(32) &&
 
  330             "Expected i32-typed GEP indices");
 
  332      if (Multiplier.getZExtValue() % BytesPerElem != 0) {
 
  337                               Builder.
getInt32(Multiplier.getZExtValue()));
 
  342            Builder.
getInt32(Multiplier.getZExtValue() / BytesPerElem));
 
  343      FlattenedIndex = Builder.
CreateAdd(FlattenedIndex, VI);
 
  348        Info.RootFlattenedArrayType, 
Info.RootPointerOperand,
 
  349        {ZeroIndex, FlattenedIndex}, 
GEP.getName(), 
GEP.getNoWrapFlags());
 
  357          Info.RootFlattenedArrayType, 
Info.RootPointerOperand,
 
  358          {ZeroIndex, FlattenedIndex}, 
GEP.getNoWrapFlags(), 
GEP.getName(),
 
  364    GEP.replaceAllUsesWith(NewGEP);
 
  365    GEP.eraseFromParent();
 
  373  PotentiallyDeadInstrs.emplace_back(&
GEP);
 
  377bool DXILFlattenArraysVisitor::visit(Function &
F) {
 
  378  bool MadeChange = 
false;
 
  379  ReversePostOrderTraversal<Function *> RPOT(&
F);
 
  393    Elements.push_back(
Init);
 
  396  unsigned ArrSize = ArrayTy->getNumElements();
 
  398    for (
unsigned I = 0; 
I < ArrSize; ++
I)
 
  405    for (
unsigned I = 0; 
I < ArrayConstant->getNumOperands(); ++
I) {
 
  409    for (
unsigned I = 0; 
I < DataArrayConstant->getNumElements(); ++
I) {
 
  414        "Expected a ConstantArray or ConstantDataArray for array initializer!");
 
 
  434  assert(FlattenedType->getNumElements() == FlattenedElements.
size() &&
 
  435         "The number of collected elements should match the FlattenedType");
 
 
  443    Type *OrigType = 
G.getValueType();
 
  444    if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
 
  449        DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
 
  456                           nullptr, 
G.getName() + 
".1dim", &
G,
 
  457                           G.getThreadLocalMode(), 
G.getAddressSpace(),
 
  458                           G.isExternallyInitialized());
 
  462    if (
G.getAlignment() > 0) {
 
  466    if (
G.hasInitializer()) {
 
  472    GlobalMap[&
G] = NewGlobal;
 
 
  477  bool MadeChange = 
false;
 
  480  DXILFlattenArraysVisitor Impl(GlobalMap);
 
  482    if (
F.isDeclaration())
 
  484    MadeChange |= Impl.visit(
F);
 
  486  for (
auto &[Old, New] : GlobalMap) {
 
  487    Old->replaceAllUsesWith(New);
 
  488    Old->eraseFromParent();
 
 
  502bool DXILFlattenArraysLegacy::runOnModule(
Module &M) {
 
  506char DXILFlattenArraysLegacy::ID = 0;
 
  509                      "DXIL Array Flattener", 
false, 
false)
 
  514  return new DXILFlattenArraysLegacy();
 
 
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Analysis containing CSE Info
static Constant * transformInitializer(Constant *Init, Type *OrigType, Type *NewType, LLVMContext &Ctx)
static void collectElements(Constant *Init, SmallVectorImpl< Constant * > &Elements)
static bool flattenArrays(Module &M)
static void flattenGlobalArrays(Module &M, SmallDenseMap< GlobalVariable *, GlobalVariable * > &GlobalMap)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
BaseType
A given derived pointer can have multiple base pointers through phi/selects.
Class for arbitrary precision integers.
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Type * getAllocatedType() const
Return the type that is being allocated by the instruction.
void setAlignment(Align Align)
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
static LLVM_ABI ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
This class represents a no-op cast from one type to another.
This class represents a function call, abstracting a target machine's calling convention.
This is the base class for all instructions that perform data casts.
static LLVM_ABI ConstantAggregateZero * get(Type *Ty)
static LLVM_ABI Constant * get(ArrayType *T, ArrayRef< Constant * > V)
This is the shared class of boolean and integer constants.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This is an important base class in LLVM.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &)
This instruction compares its operands according to the predicate given to the constructor.
This class represents a freeze function that returns random concrete value if an operand is either a ...
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
void setUnnamedAddr(UnnamedAddr Val)
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
void setAlignment(Align Align)
Sets the alignment attribute of the GlobalVariable.
This instruction compares its operands according to the predicate given to the constructor.
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
BasicBlock::iterator GetInsertPoint() const
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
This instruction inserts a single (scalar) element into a VectorType value.
Base class for instruction visitors.
void visit(Iterator Start, Iterator End)
LLVM_ABI void insertBefore(InstListType::iterator InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified position.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
This is an important class for using LLVM in a threaded context.
An instruction for reading from memory.
void setAlignment(Align Align)
Align getAlign() const
Return the alignment of the access that is being performed.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class represents the LLVM 'select' instruction.
This instruction constructs a fixed permutation of two input vectors.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
void setAlignment(Align Align)
The instances of the Type class are immutable: once they are created, they are never changed.
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Value * getOperand(unsigned i) const
unsigned getNumOperands() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ CE
Windows NT (Windows on ARM)
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createDXILFlattenArraysLegacyPass()
Pass to flatten arrays into a one dimensional DXIL legal form.
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructionsPermissive(SmallVectorImpl< WeakTrackingVH > &DeadInsts, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
Same functionality as RecursivelyDeleteTriviallyDeadInstructions, but allow instructions that are not...
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
A MapVector that performs no allocations if smaller than a certain size.