LLVM 20.0.0git
DXILIntrinsicExpansion.cpp
Go to the documentation of this file.
1//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file This file contains DXIL intrinsic expansions for those that don't have
10// opcodes in DirectX Intermediate Language (DXIL).
11//===----------------------------------------------------------------------===//
12
14#include "DirectX.h"
15#include "llvm/ADT/STLExtras.h"
17#include "llvm/CodeGen/Passes.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/IR/Instruction.h"
21#include "llvm/IR/Intrinsics.h"
22#include "llvm/IR/IntrinsicsDirectX.h"
23#include "llvm/IR/Module.h"
24#include "llvm/IR/PassManager.h"
25#include "llvm/IR/Type.h"
26#include "llvm/Pass.h"
29
30#define DEBUG_TYPE "dxil-intrinsic-expansion"
31
32using namespace llvm;
33
35 switch (F.getIntrinsicID()) {
36 case Intrinsic::abs:
37 case Intrinsic::exp:
38 case Intrinsic::log:
39 case Intrinsic::log10:
40 case Intrinsic::pow:
41 case Intrinsic::dx_any:
42 case Intrinsic::dx_clamp:
43 case Intrinsic::dx_uclamp:
44 case Intrinsic::dx_lerp:
45 case Intrinsic::dx_length:
46 case Intrinsic::dx_normalize:
47 case Intrinsic::dx_sdot:
48 case Intrinsic::dx_udot:
49 return true;
50 }
51 return false;
52}
53
54static Value *expandAbs(CallInst *Orig) {
55 Value *X = Orig->getOperand(0);
56 IRBuilder<> Builder(Orig->getParent());
57 Builder.SetInsertPoint(Orig);
58 Type *Ty = X->getType();
59 Type *EltTy = Ty->getScalarType();
60 Constant *Zero = Ty->isVectorTy()
63 cast<FixedVectorType>(Ty)->getNumElements()),
64 ConstantInt::get(EltTy, 0))
65 : ConstantInt::get(EltTy, 0);
66 auto *V = Builder.CreateSub(Zero, X);
67 return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr,
68 "dx.max");
69}
70
71static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
72 assert(DotIntrinsic == Intrinsic::dx_sdot ||
73 DotIntrinsic == Intrinsic::dx_udot);
74 Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
75 ? Intrinsic::dx_imad
76 : Intrinsic::dx_umad;
77 Value *A = Orig->getOperand(0);
78 Value *B = Orig->getOperand(1);
79 [[maybe_unused]] Type *ATy = A->getType();
80 [[maybe_unused]] Type *BTy = B->getType();
81 assert(ATy->isVectorTy() && BTy->isVectorTy());
82
83 IRBuilder<> Builder(Orig->getParent());
84 Builder.SetInsertPoint(Orig);
85
86 auto *AVec = dyn_cast<FixedVectorType>(A->getType());
87 Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
88 Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
89 Value *Result = Builder.CreateMul(Elt0, Elt1);
90 for (unsigned I = 1; I < AVec->getNumElements(); I++) {
91 Elt0 = Builder.CreateExtractElement(A, I);
92 Elt1 = Builder.CreateExtractElement(B, I);
93 Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
94 ArrayRef<Value *>{Elt0, Elt1, Result},
95 nullptr, "dx.mad");
96 }
97 return Result;
98}
99
101 Value *X = Orig->getOperand(0);
102 IRBuilder<> Builder(Orig->getParent());
103 Builder.SetInsertPoint(Orig);
104 Type *Ty = X->getType();
105 Type *EltTy = Ty->getScalarType();
106 Constant *Log2eConst =
109 cast<FixedVectorType>(Ty)->getNumElements()),
110 ConstantFP::get(EltTy, numbers::log2ef))
111 : ConstantFP::get(EltTy, numbers::log2ef);
112 Value *NewX = Builder.CreateFMul(Log2eConst, X);
113 auto *Exp2Call =
114 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
115 Exp2Call->setTailCall(Orig->isTailCall());
116 Exp2Call->setAttributes(Orig->getAttributes());
117 return Exp2Call;
118}
119
121 Value *X = Orig->getOperand(0);
122 IRBuilder<> Builder(Orig->getParent());
123 Builder.SetInsertPoint(Orig);
124 Type *Ty = X->getType();
125 Type *EltTy = Ty->getScalarType();
126
127 Value *Result = nullptr;
128 if (!Ty->isVectorTy()) {
129 Result = EltTy->isFloatingPointTy()
130 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
131 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
132 } else {
133 auto *XVec = dyn_cast<FixedVectorType>(Ty);
134 Value *Cond =
135 EltTy->isFloatingPointTy()
136 ? Builder.CreateFCmpUNE(
138 ElementCount::getFixed(XVec->getNumElements()),
139 ConstantFP::get(EltTy, 0)))
140 : Builder.CreateICmpNE(
142 ElementCount::getFixed(XVec->getNumElements()),
143 ConstantInt::get(EltTy, 0)));
144 Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
145 for (unsigned I = 1; I < XVec->getNumElements(); I++) {
146 Value *Elt = Builder.CreateExtractElement(Cond, I);
147 Result = Builder.CreateOr(Result, Elt);
148 }
149 }
150 return Result;
151}
152
154 Value *X = Orig->getOperand(0);
155 IRBuilder<> Builder(Orig->getParent());
156 Builder.SetInsertPoint(Orig);
157 Type *Ty = X->getType();
158 Type *EltTy = Ty->getScalarType();
159
160 // Though dx.length does work on scalar type, we can optimize it to just emit
161 // fabs, in CGBuiltin.cpp. We shouldn't see a scalar type here because
162 // CGBuiltin.cpp should have emitted a fabs call.
163 Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
164 auto *XVec = dyn_cast<FixedVectorType>(Ty);
165 unsigned XVecSize = XVec->getNumElements();
166 if (!(Ty->isVectorTy() && XVecSize > 1))
167 report_fatal_error(Twine("Invalid input type for length intrinsic"),
168 /* gen_crash_diag=*/false);
169
170 Value *Sum = Builder.CreateFMul(Elt, Elt);
171 for (unsigned I = 1; I < XVecSize; I++) {
172 Elt = Builder.CreateExtractElement(X, I);
173 Value *Mul = Builder.CreateFMul(Elt, Elt);
174 Sum = Builder.CreateFAdd(Sum, Mul);
175 }
176 return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
177 nullptr, "elt.sqrt");
178}
179
181 Value *X = Orig->getOperand(0);
182 Value *Y = Orig->getOperand(1);
183 Value *S = Orig->getOperand(2);
184 IRBuilder<> Builder(Orig->getParent());
185 Builder.SetInsertPoint(Orig);
186 auto *V = Builder.CreateFSub(Y, X);
187 V = Builder.CreateFMul(S, V);
188 return Builder.CreateFAdd(X, V, "dx.lerp");
189}
190
192 float LogConstVal = numbers::ln2f) {
193 Value *X = Orig->getOperand(0);
194 IRBuilder<> Builder(Orig->getParent());
195 Builder.SetInsertPoint(Orig);
196 Type *Ty = X->getType();
197 Type *EltTy = Ty->getScalarType();
198 Constant *Ln2Const =
201 cast<FixedVectorType>(Ty)->getNumElements()),
202 ConstantFP::get(EltTy, LogConstVal))
203 : ConstantFP::get(EltTy, LogConstVal);
204 auto *Log2Call =
205 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
206 Log2Call->setTailCall(Orig->isTailCall());
207 Log2Call->setAttributes(Orig->getAttributes());
208 return Builder.CreateFMul(Ln2Const, Log2Call);
209}
212}
213
215 Value *X = Orig->getOperand(0);
216 Type *Ty = Orig->getType();
217 Type *EltTy = Ty->getScalarType();
218 IRBuilder<> Builder(Orig->getParent());
219 Builder.SetInsertPoint(Orig);
220
221 auto *XVec = dyn_cast<FixedVectorType>(Ty);
222 if (!XVec) {
223 if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
224 const APFloat &fpVal = constantFP->getValueAPF();
225 if (fpVal.isZero())
226 report_fatal_error(Twine("Invalid input scalar: length is zero"),
227 /* gen_crash_diag=*/false);
228 }
229 return Builder.CreateFDiv(X, X);
230 }
231
232 unsigned XVecSize = XVec->getNumElements();
233 Value *DotProduct = nullptr;
234 // use the dot intrinsic corresponding to the vector size
235 switch (XVecSize) {
236 case 1:
237 report_fatal_error(Twine("Invalid input vector: length is zero"),
238 /* gen_crash_diag=*/false);
239 break;
240 case 2:
241 DotProduct = Builder.CreateIntrinsic(
242 EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
243 break;
244 case 3:
245 DotProduct = Builder.CreateIntrinsic(
246 EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
247 break;
248 case 4:
249 DotProduct = Builder.CreateIntrinsic(
250 EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
251 break;
252 default:
253 report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
254 /* gen_crash_diag=*/false);
255 }
256
257 // verify that the length is non-zero
258 // (if the dot product is non-zero, then the length is non-zero)
259 if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
260 const APFloat &fpVal = constantFP->getValueAPF();
261 if (fpVal.isZero())
262 report_fatal_error(Twine("Invalid input vector: length is zero"),
263 /* gen_crash_diag=*/false);
264 }
265
266 Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
267 ArrayRef<Value *>{DotProduct},
268 nullptr, "dx.rsqrt");
269
270 Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
271 return Builder.CreateFMul(X, MultiplicandVec);
272}
273
275
276 Value *X = Orig->getOperand(0);
277 Value *Y = Orig->getOperand(1);
278 Type *Ty = X->getType();
279 IRBuilder<> Builder(Orig->getParent());
280 Builder.SetInsertPoint(Orig);
281
282 auto *Log2Call =
283 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
284 auto *Mul = Builder.CreateFMul(Log2Call, Y);
285 auto *Exp2Call =
286 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
287 Exp2Call->setTailCall(Orig->isTailCall());
288 Exp2Call->setAttributes(Orig->getAttributes());
289 return Exp2Call;
290}
291
293 Intrinsic::ID ClampIntrinsic) {
294 if (ClampIntrinsic == Intrinsic::dx_uclamp)
295 return Intrinsic::umax;
296 assert(ClampIntrinsic == Intrinsic::dx_clamp);
297 if (ElemTy->isVectorTy())
298 ElemTy = ElemTy->getScalarType();
299 if (ElemTy->isIntegerTy())
300 return Intrinsic::smax;
301 assert(ElemTy->isFloatingPointTy());
302 return Intrinsic::maxnum;
303}
304
306 Intrinsic::ID ClampIntrinsic) {
307 if (ClampIntrinsic == Intrinsic::dx_uclamp)
308 return Intrinsic::umin;
309 assert(ClampIntrinsic == Intrinsic::dx_clamp);
310 if (ElemTy->isVectorTy())
311 ElemTy = ElemTy->getScalarType();
312 if (ElemTy->isIntegerTy())
313 return Intrinsic::smin;
314 assert(ElemTy->isFloatingPointTy());
315 return Intrinsic::minnum;
316}
317
319 Intrinsic::ID ClampIntrinsic) {
320 Value *X = Orig->getOperand(0);
321 Value *Min = Orig->getOperand(1);
322 Value *Max = Orig->getOperand(2);
323 Type *Ty = X->getType();
324 IRBuilder<> Builder(Orig->getParent());
325 Builder.SetInsertPoint(Orig);
326 auto *MaxCall = Builder.CreateIntrinsic(
327 Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
328 return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
329 {MaxCall, Max}, nullptr, "dx.min");
330}
331
332static bool expandIntrinsic(Function &F, CallInst *Orig) {
333 Value *Result = nullptr;
334 switch (F.getIntrinsicID()) {
335 case Intrinsic::abs:
336 Result = expandAbs(Orig);
337 break;
338 case Intrinsic::exp:
339 Result = expandExpIntrinsic(Orig);
340 break;
341 case Intrinsic::log:
342 Result = expandLogIntrinsic(Orig);
343 break;
344 case Intrinsic::log10:
345 Result = expandLog10Intrinsic(Orig);
346 break;
347 case Intrinsic::pow:
348 Result = expandPowIntrinsic(Orig);
349 break;
350 case Intrinsic::dx_any:
351 Result = expandAnyIntrinsic(Orig);
352 break;
353 case Intrinsic::dx_uclamp:
354 case Intrinsic::dx_clamp:
355 Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
356 break;
357 case Intrinsic::dx_lerp:
358 Result = expandLerpIntrinsic(Orig);
359 break;
360 case Intrinsic::dx_length:
361 Result = expandLengthIntrinsic(Orig);
362 break;
363 case Intrinsic::dx_normalize:
364 Result = expandNormalizeIntrinsic(Orig);
365 break;
366 case Intrinsic::dx_sdot:
367 case Intrinsic::dx_udot:
368 Result = expandIntegerDot(Orig, F.getIntrinsicID());
369 break;
370 }
371
372 if (Result) {
373 Orig->replaceAllUsesWith(Result);
374 Orig->eraseFromParent();
375 return true;
376 }
377 return false;
378}
379
381 for (auto &F : make_early_inc_range(M.functions())) {
383 continue;
384 bool IntrinsicExpanded = false;
385 for (User *U : make_early_inc_range(F.users())) {
386 auto *IntrinsicCall = dyn_cast<CallInst>(U);
387 if (!IntrinsicCall)
388 continue;
389 IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
390 }
391 if (F.user_empty() && IntrinsicExpanded)
392 F.eraseFromParent();
393 }
394 return true;
395}
396
399 if (expansionIntrinsics(M))
401 return PreservedAnalyses::all();
402}
403
405 return expansionIntrinsics(M);
406}
407
409
411 "DXIL Intrinsic Expansion", false, false)
413 "DXIL Intrinsic Expansion", false, false)
414
416 return new DXILIntrinsicExpansionLegacy();
417}
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static Value * expandNormalizeIntrinsic(CallInst *Orig)
DXIL Intrinsic Expansion
static bool expandIntrinsic(Function &F, CallInst *Orig)
static Value * expandLengthIntrinsic(CallInst *Orig)
static Value * expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic)
static bool expansionIntrinsics(Module &M)
static Value * expandLerpIntrinsic(CallInst *Orig)
static Value * expandAnyIntrinsic(CallInst *Orig)
static Value * expandLog10Intrinsic(CallInst *Orig)
static Value * expandPowIntrinsic(CallInst *Orig)
static Value * expandLogIntrinsic(CallInst *Orig, float LogConstVal=numbers::ln2f)
static Value * expandExpIntrinsic(CallInst *Orig)
static Value * expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic)
static Intrinsic::ID getMaxForClamp(Type *ElemTy, Intrinsic::ID ClampIntrinsic)
static Intrinsic::ID getMinForClamp(Type *ElemTy, Intrinsic::ID ClampIntrinsic)
static Value * expandAbs(CallInst *Orig)
static bool isIntrinsicExpansion(Function &F)
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
#define DEBUG_TYPE
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
Module.h This file contains the declarations for the Module class.
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
This header defines various interfaces for pass management in LLVM.
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static unsigned getNumElements(Type *Ty)
This file contains some templates that are useful if you are working with the STL at all.
This file defines the SmallVector class.
BinaryOperator * Mul
bool isZero() const
Definition: APFloat.h:1356
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
AttributeList getAttributes() const
Return the parameter attributes for this call.
Definition: InstrTypes.h:1542
This class represents a function call, abstracting a target machine's calling convention.
bool isTailCall() const
void setTailCall(bool IsTc=true)
static Constant * getSplat(ElementCount EC, Constant *Elt)
Return a ConstantVector with the specified constant in each element.
Definition: Constants.cpp:1450
This is an important base class in LLVM.
Definition: Constant.h:42
bool runOnModule(Module &M) override
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &)
static constexpr ElementCount getFixed(ScalarTy MinVal)
Definition: TypeSize.h:311
Value * CreateFSub(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1577
Value * CreateFDiv(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1631
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2480
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1550
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
Definition: IRBuilder.cpp:1193
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:933
Value * CreateFCmpUNE(Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:2366
Value * CreateICmpNE(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:2265
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1361
Value * CreateOr(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1514
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition: IRBuilder.h:1604
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1378
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2686
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:92
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:261
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition: Type.h:184
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:224
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:343
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
const ParentTy * getParent() const
Definition: ilist_node.h:32
constexpr float ln10f
Definition: MathExtras.h:65
constexpr float log2ef
Definition: MathExtras.h:66
constexpr float ln2f
Definition: MathExtras.h:64
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:656
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:167
ModulePass * createDXILIntrinsicExpansionLegacyPass()
Pass to expand intrinsic operations that lack DXIL opCodes.