LLVM 20.0.0git
DXILOpLowering.cpp
Go to the documentation of this file.
1//===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "DXILOpLowering.h"
10#include "DXILConstants.h"
12#include "DXILOpBuilder.h"
13#include "DirectX.h"
15#include "llvm/CodeGen/Passes.h"
17#include "llvm/IR/IRBuilder.h"
18#include "llvm/IR/Instruction.h"
19#include "llvm/IR/Intrinsics.h"
20#include "llvm/IR/IntrinsicsDirectX.h"
21#include "llvm/IR/Module.h"
22#include "llvm/IR/PassManager.h"
23#include "llvm/Pass.h"
25
26#define DEBUG_TYPE "dxil-op-lower"
27
28using namespace llvm;
29using namespace llvm::dxil;
30
32 switch (F.getIntrinsicID()) {
33 case Intrinsic::dx_dot2:
34 case Intrinsic::dx_dot3:
35 case Intrinsic::dx_dot4:
36 return true;
37 }
38 return false;
39}
40
42 SmallVector<Value *> ExtractedElements;
43 auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
44 for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
45 Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
46 Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
47 ExtractedElements.push_back(ExtractedElement);
48 }
49 return ExtractedElements;
50}
51
53 IRBuilder<> &Builder) {
54 // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
55 unsigned NumOperands = Orig->getNumOperands() - 1;
56 assert(NumOperands > 0);
57 Value *Arg0 = Orig->getOperand(0);
58 [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
59 assert(VecArg0);
60 SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
61 for (unsigned I = 1; I < NumOperands; ++I) {
62 Value *Arg = Orig->getOperand(I);
63 [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
64 assert(VecArg);
65 assert(VecArg0->getElementType() == VecArg->getElementType());
66 assert(VecArg0->getNumElements() == VecArg->getNumElements());
67 auto NextOperandList = populateOperands(Arg, Builder);
68 NewOperands.append(NextOperandList.begin(), NextOperandList.end());
69 }
70 return NewOperands;
71}
72
73namespace {
74class OpLowerer {
75 Module &M;
76 DXILOpBuilder OpBuilder;
77
78public:
79 OpLowerer(Module &M) : M(M), OpBuilder(M) {}
80
81 void replaceFunction(Function &F,
82 llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
83 for (User *U : make_early_inc_range(F.users())) {
84 CallInst *CI = dyn_cast<CallInst>(U);
85 if (!CI)
86 continue;
87
88 if (Error E = ReplaceCall(CI)) {
89 std::string Message(toString(std::move(E)));
90 DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
91 CI->getDebugLoc());
92 M.getContext().diagnose(Diag);
93 continue;
94 }
95 }
96 if (F.user_empty())
97 F.eraseFromParent();
98 }
99
100 void replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) {
101 bool IsVectorArgExpansion = isVectorArgExpansion(F);
102 replaceFunction(F, [&](CallInst *CI) -> Error {
104 OpBuilder.getIRB().SetInsertPoint(CI);
105 if (IsVectorArgExpansion) {
106 SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
107 Args.append(NewArgs.begin(), NewArgs.end());
108 } else
109 Args.append(CI->arg_begin(), CI->arg_end());
110
111 Expected<CallInst *> OpCall =
112 OpBuilder.tryCreateOp(DXILOp, Args, F.getReturnType());
113 if (Error E = OpCall.takeError())
114 return E;
115
116 CI->replaceAllUsesWith(*OpCall);
117 CI->eraseFromParent();
118 return Error::success();
119 });
120 }
121
122 bool lowerIntrinsics() {
123 bool Updated = false;
124
125 for (Function &F : make_early_inc_range(M.functions())) {
126 if (!F.isDeclaration())
127 continue;
128 Intrinsic::ID ID = F.getIntrinsicID();
129 switch (ID) {
130 default:
131 continue;
132#define DXIL_OP_INTRINSIC(OpCode, Intrin) \
133 case Intrin: \
134 replaceFunctionWithOp(F, OpCode); \
135 break;
136#include "DXILOperation.inc"
137 }
138 Updated = true;
139 }
140 return Updated;
141 }
142};
143} // namespace
144
146 if (OpLowerer(M).lowerIntrinsics())
148 return PreservedAnalyses::all();
149}
150
151namespace {
/// Legacy-pass-manager wrapper around OpLowerer. runOnModule reports the
/// module as modified exactly when OpLowerer::lowerIntrinsics() returns true.
152class DXILOpLoweringLegacy : public ModulePass {
153public:
154 bool runOnModule(Module &M) override {
155 return OpLowerer(M).lowerIntrinsics();
156 }
157 StringRef getPassName() const override { return "DXIL Op Lowering"; }
158 DXILOpLoweringLegacy() : ModulePass(ID) {}
159
160 static char ID; // Pass identification.
161 void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
162 // Specify the passes that your pass depends on
 // NOTE(review): original line 163 is missing from this extraction. The
 // page's cross-reference list mentions AnalysisUsage::addRequired(), so the
 // lost line was presumably an AU.addRequired<...>() call. TODO: restore it
 // from the upstream file before relying on this block.
164 }
165};
166char DXILOpLoweringLegacy::ID = 0;
167} // end anonymous namespace
168
169INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
170 false, false)
171INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
172 false)
173
175 return new DXILOpLoweringLegacy();
176}
static bool isVectorArgExpansion(Function &F)
static SmallVector< Value * > argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder)
static SmallVector< Value * > populateOperands(Value *Arg, IRBuilder<> &Builder)
#define DEBUG_TYPE
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1385
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1391
This class represents a function call, abstracting a target machine's calling convention.
This class represents an Operation in the Expression.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Diagnostic information for unsupported feature in backend.
Lightweight error class with error context and mandatory checking.
Definition: Error.h:160
static ErrorSuccess success()
Create a success value.
Definition: Error.h:337
Tagged union holding either a T or a Error.
Definition: Error.h:481
Error takeError()
Take ownership of the stored error.
Definition: Error.h:608
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2480
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2686
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:466
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:92
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:70
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
static IntegerType * getInt32Ty(LLVMContext &C)
Value * getOperand(unsigned i) const
Definition: User.h:169
unsigned getNumOperands() const
Definition: User.h:191
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
Expected< CallInst * > tryCreateOp(dxil::OpCode Op, ArrayRef< Value * > Args, Type *RetTy=nullptr)
Try to create a call instruction for the given DXIL op.
An efficient, type-erasing, non-owning reference to a callable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
std::optional< const char * > toString(const std::optional< DWARFFormValue > &V)
Take an optional DWARFFormValue and try to extract a string value from it.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:656
ModulePass * createDXILOpLoweringLegacyPass()
Pass to lower LLVM intrinsic calls to DXIL op function calls.