LLVM 20.0.0git
DXILShaderFlags.cpp
Go to the documentation of this file.
1//===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file This file contains helper objects and APIs for working with DXIL
10/// Shader Flags.
11///
12//===----------------------------------------------------------------------===//
13
14#include "DXILShaderFlags.h"
15#include "DirectX.h"
20#include "llvm/IR/Instruction.h"
23#include "llvm/IR/Intrinsics.h"
24#include "llvm/IR/IntrinsicsDirectX.h"
25#include "llvm/IR/Module.h"
29
30using namespace llvm;
31using namespace llvm::dxil;
32
33/// Update the shader flags mask based on the given instruction.
34/// \param CSF Shader flags mask to update.
35/// \param I Instruction to check.
36void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
37 const Instruction &I,
38 DXILResourceTypeMap &DRTM) {
39 if (!CSF.Doubles)
40 CSF.Doubles = I.getType()->isDoubleTy();
41
42 if (!CSF.Doubles) {
43 for (const Value *Op : I.operands()) {
44 if (Op->getType()->isDoubleTy()) {
45 CSF.Doubles = true;
46 break;
47 }
48 }
49 }
50
51 if (CSF.Doubles) {
52 switch (I.getOpcode()) {
53 case Instruction::FDiv:
54 case Instruction::UIToFP:
55 case Instruction::SIToFP:
56 case Instruction::FPToUI:
57 case Instruction::FPToSI:
58 CSF.DX11_1_DoubleExtensions = true;
59 break;
60 }
61 }
62
63 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
64 switch (II->getIntrinsicID()) {
65 default:
66 break;
67 case Intrinsic::dx_resource_load_typedbuffer: {
69 DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
70 if (RTI.isTyped())
71 CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
72 }
73 }
74 }
75 // Handle call instructions
76 if (auto *CI = dyn_cast<CallInst>(&I)) {
77 const Function *CF = CI->getCalledFunction();
78 // Merge-in shader flags mask of the called function in the current module
79 if (FunctionFlags.contains(CF))
80 CSF.merge(FunctionFlags[CF]);
81
82 // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
83 // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
84 }
85}
86
87/// Construct ModuleShaderFlags for module Module M
89 CallGraph CG(M);
90
91 // Compute Shader Flags Mask for all functions using post-order visit of SCC
92 // of the call graph.
93 for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
94 ++SCCI) {
95 const std::vector<CallGraphNode *> &CurSCC = *SCCI;
96
97 // Union of shader masks of all functions in CurSCC
99 // List of functions in CurSCC that are neither external nor declarations
100 // and hence whose flags are collected
101 SmallVector<Function *> CurSCCFuncs;
102 for (CallGraphNode *CGN : CurSCC) {
103 Function *F = CGN->getFunction();
104 if (!F)
105 continue;
106
107 if (F->isDeclaration()) {
108 assert(!F->getName().starts_with("dx.op.") &&
109 "DXIL Shader Flag analysis should not be run post-lowering.");
110 continue;
111 }
112
114 for (const auto &BB : *F)
115 for (const auto &I : BB)
116 updateFunctionFlags(CSF, I, DRTM);
117 // Update combined shader flags mask for all functions in this SCC
118 SCCSF.merge(CSF);
119
120 CurSCCFuncs.push_back(F);
121 }
122
123 // Update combined shader flags mask for all functions of the module
124 CombinedSFMask.merge(SCCSF);
125
126 // Shader flags mask of each of the functions in an SCC of the call graph is
127 // the union of all functions in the SCC. Update shader flags masks of
128 // functions in CurSCC accordingly. This is trivially true if SCC contains
129 // one function.
130 for (Function *F : CurSCCFuncs)
131 // Merge SCCSF with that of F
132 FunctionFlags[F].merge(SCCSF);
133 }
134}
135
137 uint64_t FlagVal = (uint64_t) * this;
138 OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
139 if (FlagVal == 0)
140 return;
141 OS << "; Note: shader requires additional functionality:\n";
142#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str) \
143 if (FlagName) \
144 (OS << ";").indent(7) << Str << "\n";
145#include "llvm/BinaryFormat/DXContainerConstants.def"
146 OS << "; Note: extra DXIL module flags:\n";
147#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
148 if (FlagName) \
149 (OS << ";").indent(7) << Str << "\n";
150#include "llvm/BinaryFormat/DXContainerConstants.def"
151 OS << ";\n";
152}
153
154/// Return the shader flags mask of the specified function Func.
157 auto Iter = FunctionFlags.find(Func);
158 assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
159 "Get Shader Flags : No Shader Flags Mask exists for function");
160 return Iter->second;
161}
162
163//===----------------------------------------------------------------------===//
164// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
165
166// Provide an explicit template instantiation for the static ID.
167AnalysisKey ShaderFlagsAnalysis::Key;
168
// NOTE(review): this listing drops the opening lines of this definition
// (presumably the `ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
// ModuleAnalysisManager &AM)` signature plus the declarations of DRTM and
// MSFI); only its tail is visible here. TODO confirm against the real source.
172
// Compute the per-function and combined shader flags of module M into MSFI.
174 MSFI.initialize(M, DRTM);
175
// Hand the computed flags back as the analysis result.
176 return MSFI;
177}
178
// NOTE(review): the signature of the printer pass's run method is missing from
// this listing (presumably `PreservedAnalyses run(Module &M,
// ModuleAnalysisManager &AM)`); only its body is visible.
//
// Prints the module's combined shader flags, followed by the flags mask (in
// hex) of every function defined in M, to OS.
181 const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
182 // Print description of combined shader flags for all module functions
183 OS << "; Combined Shader Flags for Module\n";
184 FlagsInfo.getCombinedFlags().print(OS);
185 // Print shader flags mask for each of the module functions
186 OS << "; Shader Flags for Module Functions\n";
187 for (const auto &F : M.getFunctionList()) {
// Declarations get no mask during analysis, and getFunctionFlags asserts on
// functions without one, so only defined functions are printed.
188 if (F.isDeclaration())
189 continue;
190 const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
191 OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
192 (uint64_t)(SFMask));
193 }
194
// Printing does not modify the IR, so all analyses remain valid.
195 return PreservedAnalyses::all();
196}
197
198//===----------------------------------------------------------------------===//
199// Legacy pass manager wrapper for ShaderFlagsAnalysis
200
// NOTE(review): the signature line of this method (presumably the legacy
// wrapper pass's `bool runOnModule(Module &M) override`) is missing from this
// listing. TODO confirm.
//
// Computes the module's shader flags into MSFI, using the resource type map
// provided by DXILResourceTypeWrapperPass.
202 DXILResourceTypeMap &DRTM =
203 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
204
205 MSFI.initialize(M, DRTM);
// Returning false reports that M was not modified; the analysis only reads
// the IR.
206 return false;
207}
208
// NOTE(review): two lines of this method are missing from this listing.
// - The signature, presumably `void getAnalysisUsage(AnalysisUsage &AU) const
//   override` of the legacy wrapper pass.
// - The statement after setPreservesAll(), which presumably requires
//   DXILResourceTypeWrapperPass (queried via getAnalysis in runOnModule).
// TODO confirm.
//
// This analysis never changes the IR, so it preserves every other analysis.
210 AU.setPreservesAll();
212}
213
// NOTE(review): several lines of the legacy pass registration are missing from
// this listing, presumably:
// - the wrapper's ID definition;
// - the leading arguments of INITIALIZE_PASS_BEGIN / INITIALIZE_PASS_DEPENDENCY
//   / INITIALIZE_PASS_END.
// Only the trailing arguments below survive. In those macros the final two
// arguments are the `cfg` and `analysis` flags, both set to true here.
215
217 "DXIL Shader Flag Analysis", true, true)
220 "DXIL Shader Flag Analysis", true, true)
basic Basic Alias true
block Block Frequency Analysis
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
Module.h This file contains the declarations for the Module class.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
module summary analysis
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
This builds on the llvm/ADT/GraphTraits.h file to find the strongly connected components (SCCs) of a ...
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
raw_pwrite_stream & OS
This file defines the SmallVector class.
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
A node in the call graph for a module.
Definition: CallGraph.h:165
The basic data container for the call graph of a Module of IR.
Definition: CallGraph.h:71
This class represents an Operation in the Expression.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
LLVM Value Representation.
Definition: Value.h:74
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Wrapper pass for the legacy pass manager.
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
bool runOnModule(Module &M) override
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
ModuleShaderFlags run(Module &M, ModuleAnalysisManager &AM)
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
Enumerate the SCCs of a directed graph in reverse topological order of the SCC DAG.
Definition: SCCIterator.h:49
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
scc_iterator< T > scc_begin(const T &G)
Construct the begin iterator for a deduced graph type T.
Definition: SCCIterator.h:233
auto formatv(bool Validate, const char *Fmt, Ts &&...Vals)
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:28
void merge(const ComputedShaderFlags CSF)
void print(raw_ostream &OS=dbgs()) const
const ComputedShaderFlags & getFunctionFlags(const Function *) const
Return the shader flags mask of the specified function Func.
void initialize(Module &, DXILResourceTypeMap &DRTM)
Construct ModuleShaderFlags for module Module M.
const ComputedShaderFlags & getCombinedFlags() const