LLVM 20.0.0git
DXILOpLowering.cpp
Go to the documentation of this file.
1//===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "DXILOpLowering.h"
10#include "DXILConstants.h"
12#include "DXILOpBuilder.h"
14#include "DXILShaderFlags.h"
15#include "DirectX.h"
19#include "llvm/CodeGen/Passes.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/Instruction.h"
24#include "llvm/IR/Intrinsics.h"
25#include "llvm/IR/IntrinsicsDirectX.h"
26#include "llvm/IR/Module.h"
27#include "llvm/IR/PassManager.h"
29#include "llvm/Pass.h"
31
32#define DEBUG_TYPE "dxil-op-lower"
33
34using namespace llvm;
35using namespace llvm::dxil;
36
38 switch (F.getIntrinsicID()) {
39 case Intrinsic::dx_dot2:
40 case Intrinsic::dx_dot3:
41 case Intrinsic::dx_dot4:
42 return true;
43 }
44 return false;
45}
46
48 SmallVector<Value *> ExtractedElements;
49 auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
50 for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
51 Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
52 Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
53 ExtractedElements.push_back(ExtractedElement);
54 }
55 return ExtractedElements;
56}
57
59 IRBuilder<> &Builder) {
60 // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
61 unsigned NumOperands = Orig->getNumOperands() - 1;
62 assert(NumOperands > 0);
63 Value *Arg0 = Orig->getOperand(0);
64 [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
65 assert(VecArg0);
66 SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
67 for (unsigned I = 1; I < NumOperands; ++I) {
68 Value *Arg = Orig->getOperand(I);
69 [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
70 assert(VecArg);
71 assert(VecArg0->getElementType() == VecArg->getElementType());
72 assert(VecArg0->getNumElements() == VecArg->getNumElements());
73 auto NextOperandList = populateOperands(Arg, Builder);
74 NewOperands.append(NextOperandList.begin(), NextOperandList.end());
75 }
76 return NewOperands;
77}
78
79namespace {
80class OpLowerer {
81 Module &M;
82 DXILOpBuilder OpBuilder;
83 DXILBindingMap &DBM;
85 SmallVector<CallInst *> CleanupCasts;
86
87public:
88 OpLowerer(Module &M, DXILBindingMap &DBM, DXILResourceTypeMap &DRTM)
89 : M(M), OpBuilder(M), DBM(DBM), DRTM(DRTM) {}
90
91 /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
92 /// there is an error replacing a call, we emit a diagnostic and return true.
93 [[nodiscard]] bool
94 replaceFunction(Function &F,
95 llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
96 for (User *U : make_early_inc_range(F.users())) {
97 CallInst *CI = dyn_cast<CallInst>(U);
98 if (!CI)
99 continue;
100
101 if (Error E = ReplaceCall(CI)) {
102 std::string Message(toString(std::move(E)));
103 DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
104 CI->getDebugLoc());
105 M.getContext().diagnose(Diag);
106 return true;
107 }
108 }
109 if (F.user_empty())
110 F.eraseFromParent();
111 return false;
112 }
113
// Describes how one argument of a DXIL op is produced when lowering an
// intrinsic call (see replaceFunctionWithOp): `Index` forwards the call
// argument at position `Value`, while `I8` and `I32` materialize `Value` as an
// i8 / i32 constant.
114 struct IntrinArgSelect {
// The enumerators are generated by expanding
// DXIL_OP_INTRINSIC_ARG_SELECT_TYPE over the entries of DXILOperation.inc.
115 enum class Type {
116#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
117#include "DXILOperation.inc"
118 };
// Kind of selection to perform.
119 Type Type;
// Argument index or constant value, depending on `Type`.
120 int Value;
121 };
122
123 [[nodiscard]] bool
124 replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
125 ArrayRef<IntrinArgSelect> ArgSelects) {
126 bool IsVectorArgExpansion = isVectorArgExpansion(F);
127 assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
128 "Cann't do vector arg expansion when using arg selects.");
129 return replaceFunction(F, [&](CallInst *CI) -> Error {
130 OpBuilder.getIRB().SetInsertPoint(CI);
132 if (ArgSelects.size()) {
133 for (const IntrinArgSelect &A : ArgSelects) {
134 switch (A.Type) {
135 case IntrinArgSelect::Type::Index:
136 Args.push_back(CI->getArgOperand(A.Value));
137 break;
138 case IntrinArgSelect::Type::I8:
139 Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
140 break;
141 case IntrinArgSelect::Type::I32:
142 Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
143 break;
144 }
145 }
146 } else if (IsVectorArgExpansion) {
147 Args = argVectorFlatten(CI, OpBuilder.getIRB());
148 } else {
149 Args.append(CI->arg_begin(), CI->arg_end());
150 }
151
152 Expected<CallInst *> OpCall =
153 OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType());
154 if (Error E = OpCall.takeError())
155 return E;
156
157 CI->replaceAllUsesWith(*OpCall);
158 CI->eraseFromParent();
159 return Error::success();
160 });
161 }
162
163 [[nodiscard]] bool replaceFunctionWithNamedStructOp(
164 Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
165 llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
166 bool IsVectorArgExpansion = isVectorArgExpansion(F);
167 return replaceFunction(F, [&](CallInst *CI) -> Error {
169 OpBuilder.getIRB().SetInsertPoint(CI);
170 if (IsVectorArgExpansion) {
171 SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
172 Args.append(NewArgs.begin(), NewArgs.end());
173 } else
174 Args.append(CI->arg_begin(), CI->arg_end());
175
176 Expected<CallInst *> OpCall =
177 OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
178 if (Error E = OpCall.takeError())
179 return E;
180 if (Error E = ReplaceUses(CI, *OpCall))
181 return E;
182
183 return Error::success();
184 });
185 }
186
187 /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
188 /// is intended to be removed by the end of lowering. This is used to allow
189 /// lowering of ops which need to change their return or argument types in a
190 /// piecemeal way - we can add the casts in to avoid updating all of the uses
191 /// or defs, and by the end all of the casts will be redundant.
192 Value *createTmpHandleCast(Value *V, Type *Ty) {
193 CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic(
194 Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V});
195 CleanupCasts.push_back(Cast);
196 return Cast;
197 }
198
199 void cleanupHandleCasts() {
202
203 for (CallInst *Cast : CleanupCasts) {
204 // These casts were only put in to ease the move from `target("dx")` types
205 // to `dx.types.Handle in a piecemeal way. At this point, all of the
206 // non-cast uses should now be `dx.types.Handle`, and remaining casts
207 // should all form pairs to and from the now unused `target("dx")` type.
208 CastFns.push_back(Cast->getCalledFunction());
209
210 // If the cast is not to `dx.types.Handle`, it should be the first part of
211 // the pair. Keep track so we can remove it once it has no more uses.
212 if (Cast->getType() != OpBuilder.getHandleType()) {
213 ToRemove.push_back(Cast);
214 continue;
215 }
216 // Otherwise, we're the second handle in a pair. Forward the arguments and
217 // remove the (second) cast.
218 CallInst *Def = cast<CallInst>(Cast->getOperand(0));
219 assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
220 "Unbalanced pair of temporary handle casts");
221 Cast->replaceAllUsesWith(Def->getOperand(0));
222 Cast->eraseFromParent();
223 }
224 for (CallInst *Cast : ToRemove) {
225 assert(Cast->user_empty() && "Temporary handle cast still has users");
226 Cast->eraseFromParent();
227 }
228
229 // Deduplicate the cast functions so that we only erase each one once.
230 llvm::sort(CastFns);
231 CastFns.erase(llvm::unique(CastFns), CastFns.end());
232 for (Function *F : CastFns)
233 F->eraseFromParent();
234
235 CleanupCasts.clear();
236 }
237
238 // Remove the resource global associated with the handleFromBinding call
239 // instruction and their uses as they aren't needed anymore.
240 // TODO: We should verify that all the globals get removed.
241 // It's expected we'll need a custom pass in the future that will eliminate
242 // the need for this here.
243 void removeResourceGlobals(CallInst *CI) {
244 for (User *User : make_early_inc_range(CI->users())) {
245 if (StoreInst *Store = dyn_cast<StoreInst>(User)) {
246 Value *V = Store->getOperand(1);
247 Store->eraseFromParent();
248 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
249 if (GV->use_empty()) {
250 GV->removeDeadConstantUsers();
251 GV->eraseFromParent();
252 }
253 }
254 }
255 }
256
257 [[nodiscard]] bool lowerToCreateHandle(Function &F) {
258 IRBuilder<> &IRB = OpBuilder.getIRB();
259 Type *Int8Ty = IRB.getInt8Ty();
260 Type *Int32Ty = IRB.getInt32Ty();
261
262 return replaceFunction(F, [&](CallInst *CI) -> Error {
263 IRB.SetInsertPoint(CI);
264
265 auto *It = DBM.find(CI);
266 assert(It != DBM.end() && "Resource not in map?");
268
269 const auto &Binding = RI.getBinding();
270 dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();
271
272 Value *IndexOp = CI->getArgOperand(3);
273 if (Binding.LowerBound != 0)
274 IndexOp = IRB.CreateAdd(IndexOp,
275 ConstantInt::get(Int32Ty, Binding.LowerBound));
276
277 std::array<Value *, 4> Args{
278 ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
279 ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
280 CI->getArgOperand(4)};
281 Expected<CallInst *> OpCall =
282 OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
283 if (Error E = OpCall.takeError())
284 return E;
285
286 Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
287
288 removeResourceGlobals(CI);
289
290 CI->replaceAllUsesWith(Cast);
291 CI->eraseFromParent();
292 return Error::success();
293 });
294 }
295
296 [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
297 IRBuilder<> &IRB = OpBuilder.getIRB();
298 Type *Int32Ty = IRB.getInt32Ty();
299
300 return replaceFunction(F, [&](CallInst *CI) -> Error {
301 IRB.SetInsertPoint(CI);
302
303 auto *It = DBM.find(CI);
304 assert(It != DBM.end() && "Resource not in map?");
306
307 const auto &Binding = RI.getBinding();
308 dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()];
310
311 Value *IndexOp = CI->getArgOperand(3);
312 if (Binding.LowerBound != 0)
313 IndexOp = IRB.CreateAdd(IndexOp,
314 ConstantInt::get(Int32Ty, Binding.LowerBound));
315
316 std::pair<uint32_t, uint32_t> Props =
317 RI.getAnnotateProps(*F.getParent(), RTI);
318
319 // For `CreateHandleFromBinding` we need the upper bound rather than the
320 // size, so we need to be careful about the difference for "unbounded".
321 uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
322 uint32_t UpperBound = Binding.Size == Unbounded
323 ? Unbounded
324 : Binding.LowerBound + Binding.Size - 1;
325 Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
326 Binding.Space, RC);
327 std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)};
328 Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
329 OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
330 if (Error E = OpBind.takeError())
331 return E;
332
333 std::array<Value *, 2> AnnotateArgs{
334 *OpBind, OpBuilder.getResProps(Props.first, Props.second)};
335 Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp(
336 OpCode::AnnotateHandle, AnnotateArgs,
337 CI->hasName() ? CI->getName() + "_annot" : Twine());
338 if (Error E = OpAnnotate.takeError())
339 return E;
340
341 Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
342
343 removeResourceGlobals(CI);
344
345 CI->replaceAllUsesWith(Cast);
346 CI->eraseFromParent();
347
348 return Error::success();
349 });
350 }
351
352 /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader
353 /// model and taking into account binding information from
354 /// DXILResourceBindingAnalysis.
355 bool lowerHandleFromBinding(Function &F) {
356 Triple TT(Triple(M.getTargetTriple()));
357 if (TT.getDXILVersion() < VersionTuple(1, 6))
358 return lowerToCreateHandle(F);
359 return lowerToBindAndAnnotateHandle(F);
360 }
361
362 Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
363 for (Use &U : make_early_inc_range(Intrin->uses())) {
364 if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
365
366 if (EVI->getNumIndices() != 1)
367 return createStringError(std::errc::invalid_argument,
368 "Splitdouble has only 2 elements");
369 EVI->setOperand(0, Op);
370 } else {
371 return make_error<StringError>(
372 "Splitdouble use is not ExtractValueInst",
374 }
375 }
376
377 Intrin->eraseFromParent();
378
379 return Error::success();
380 }
381
382 /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
383 /// Since we expect to be post-scalarization, make an effort to avoid vectors.
384 Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
385 IRBuilder<> &IRB = OpBuilder.getIRB();
386
387 Instruction *OldResult = Intrin;
388 Type *OldTy = Intrin->getType();
389
390 if (HasCheckBit) {
391 auto *ST = cast<StructType>(OldTy);
392
393 Value *CheckOp = nullptr;
394 Type *Int32Ty = IRB.getInt32Ty();
395 for (Use &U : make_early_inc_range(OldResult->uses())) {
396 if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
397 ArrayRef<unsigned> Indices = EVI->getIndices();
398 assert(Indices.size() == 1);
399 // We're only interested in uses of the check bit for now.
400 if (Indices[0] != 1)
401 continue;
402 if (!CheckOp) {
403 Value *NewEVI = IRB.CreateExtractValue(Op, 4);
404 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
405 OpCode::CheckAccessFullyMapped, {NewEVI},
406 OldResult->hasName() ? OldResult->getName() + "_check"
407 : Twine(),
408 Int32Ty);
409 if (Error E = OpCall.takeError())
410 return E;
411 CheckOp = *OpCall;
412 }
413 EVI->replaceAllUsesWith(CheckOp);
414 EVI->eraseFromParent();
415 }
416 }
417
418 if (OldResult->use_empty()) {
419 // Only the check bit was used, so we're done here.
420 OldResult->eraseFromParent();
421 return Error::success();
422 }
423
424 assert(OldResult->hasOneUse() &&
425 isa<ExtractValueInst>(*OldResult->user_begin()) &&
426 "Expected only use to be extract of first element");
427 OldResult = cast<Instruction>(*OldResult->user_begin());
428 OldTy = ST->getElementType(0);
429 }
430
431 // For scalars, we just extract the first element.
432 if (!isa<FixedVectorType>(OldTy)) {
433 Value *EVI = IRB.CreateExtractValue(Op, 0);
434 OldResult->replaceAllUsesWith(EVI);
435 OldResult->eraseFromParent();
436 if (OldResult != Intrin) {
437 assert(Intrin->use_empty() && "Intrinsic still has uses?");
438 Intrin->eraseFromParent();
439 }
440 return Error::success();
441 }
442
443 std::array<Value *, 4> Extracts = {};
444 SmallVector<ExtractElementInst *> DynamicAccesses;
445
446 // The users of the operation should all be scalarized, so we attempt to
447 // replace the extractelements with extractvalues directly.
448 for (Use &U : make_early_inc_range(OldResult->uses())) {
449 if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
450 if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
451 size_t IndexVal = IndexOp->getZExtValue();
452 assert(IndexVal < 4 && "Index into buffer load out of range");
453 if (!Extracts[IndexVal])
454 Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
455 EEI->replaceAllUsesWith(Extracts[IndexVal]);
456 EEI->eraseFromParent();
457 } else {
458 DynamicAccesses.push_back(EEI);
459 }
460 }
461 }
462
463 const auto *VecTy = cast<FixedVectorType>(OldTy);
464 const unsigned N = VecTy->getNumElements();
465
466 // If there's a dynamic access we need to round trip through stack memory so
467 // that we don't leave vectors around.
468 if (!DynamicAccesses.empty()) {
469 Type *Int32Ty = IRB.getInt32Ty();
470 Constant *Zero = ConstantInt::get(Int32Ty, 0);
471
472 Type *ElTy = VecTy->getElementType();
473 Type *ArrayTy = ArrayType::get(ElTy, N);
474 Value *Alloca = IRB.CreateAlloca(ArrayTy);
475
476 for (int I = 0, E = N; I != E; ++I) {
477 if (!Extracts[I])
478 Extracts[I] = IRB.CreateExtractValue(Op, I);
480 ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
481 IRB.CreateStore(Extracts[I], GEP);
482 }
483
484 for (ExtractElementInst *EEI : DynamicAccesses) {
485 Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
486 {Zero, EEI->getIndexOperand()});
487 Value *Load = IRB.CreateLoad(ElTy, GEP);
488 EEI->replaceAllUsesWith(Load);
489 EEI->eraseFromParent();
490 }
491 }
492
493 // If we still have uses, then we're not fully scalarized and need to
494 // recreate the vector. This should only happen for things like exported
495 // functions from libraries.
496 if (!OldResult->use_empty()) {
497 for (int I = 0, E = N; I != E; ++I)
498 if (!Extracts[I])
499 Extracts[I] = IRB.CreateExtractValue(Op, I);
500
501 Value *Vec = UndefValue::get(OldTy);
502 for (int I = 0, E = N; I != E; ++I)
503 Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
504 OldResult->replaceAllUsesWith(Vec);
505 }
506
507 OldResult->eraseFromParent();
508 if (OldResult != Intrin) {
509 assert(Intrin->use_empty() && "Intrinsic still has uses?");
510 Intrin->eraseFromParent();
511 }
512
513 return Error::success();
514 }
515
516 [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
517 IRBuilder<> &IRB = OpBuilder.getIRB();
518 Type *Int32Ty = IRB.getInt32Ty();
519
520 return replaceFunction(F, [&](CallInst *CI) -> Error {
521 IRB.SetInsertPoint(CI);
522
523 Value *Handle =
524 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
525 Value *Index0 = CI->getArgOperand(1);
526 Value *Index1 = UndefValue::get(Int32Ty);
527
528 Type *OldTy = CI->getType();
529 if (HasCheckBit)
530 OldTy = cast<StructType>(OldTy)->getElementType(0);
531 Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());
532
533 std::array<Value *, 3> Args{Handle, Index0, Index1};
534 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
535 OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
536 if (Error E = OpCall.takeError())
537 return E;
538 if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
539 return E;
540
541 return Error::success();
542 });
543 }
544
545 [[nodiscard]] bool lowerRawBufferLoad(Function &F) {
546 Triple TT(Triple(M.getTargetTriple()));
547 VersionTuple DXILVersion = TT.getDXILVersion();
548 const DataLayout &DL = F.getDataLayout();
549 IRBuilder<> &IRB = OpBuilder.getIRB();
550 Type *Int8Ty = IRB.getInt8Ty();
551 Type *Int32Ty = IRB.getInt32Ty();
552
553 return replaceFunction(F, [&](CallInst *CI) -> Error {
554 IRB.SetInsertPoint(CI);
555
556 Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
557 Type *ScalarTy = OldTy->getScalarType();
558 Type *NewRetTy = OpBuilder.getResRetType(ScalarTy);
559
560 Value *Handle =
561 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
562 Value *Index0 = CI->getArgOperand(1);
563 Value *Index1 = CI->getArgOperand(2);
564 uint64_t NumElements =
565 DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy);
566 Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
567 Value *Align =
568 ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value());
569
570 Expected<CallInst *> OpCall =
571 DXILVersion >= VersionTuple(1, 2)
572 ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad,
573 {Handle, Index0, Index1, Mask, Align},
574 CI->getName(), NewRetTy)
575 : OpBuilder.tryCreateOp(OpCode::BufferLoad,
576 {Handle, Index0, Index1}, CI->getName(),
577 NewRetTy);
578 if (Error E = OpCall.takeError())
579 return E;
580 if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true))
581 return E;
582
583 return Error::success();
584 });
585 }
586
587 [[nodiscard]] bool lowerUpdateCounter(Function &F) {
588 IRBuilder<> &IRB = OpBuilder.getIRB();
589 Type *Int32Ty = IRB.getInt32Ty();
590
591 return replaceFunction(F, [&](CallInst *CI) -> Error {
592 IRB.SetInsertPoint(CI);
593 Value *Handle =
594 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
595 Value *Op1 = CI->getArgOperand(1);
596
597 std::array<Value *, 2> Args{Handle, Op1};
598
599 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
600 OpCode::UpdateCounter, Args, CI->getName(), Int32Ty);
601
602 if (Error E = OpCall.takeError())
603 return E;
604
605 CI->replaceAllUsesWith(*OpCall);
606 CI->eraseFromParent();
607 return Error::success();
608 });
609 }
610
611 [[nodiscard]] bool lowerGetPointer(Function &F) {
612 // These should have already been handled in DXILResourceAccess, so we can
613 // just clean up the dead prototype.
614 assert(F.user_empty() && "getpointer operations should have been removed");
615 F.eraseFromParent();
616 return false;
617 }
618
619 [[nodiscard]] bool lowerTypedBufferStore(Function &F) {
620 IRBuilder<> &IRB = OpBuilder.getIRB();
621 Type *Int8Ty = IRB.getInt8Ty();
622 Type *Int32Ty = IRB.getInt32Ty();
623
624 return replaceFunction(F, [&](CallInst *CI) -> Error {
625 IRB.SetInsertPoint(CI);
626
627 Value *Handle =
628 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
629 Value *Index0 = CI->getArgOperand(1);
630 Value *Index1 = UndefValue::get(Int32Ty);
631 // For typed stores, the mask must always cover all four elements.
632 Constant *Mask = ConstantInt::get(Int8Ty, 0xF);
633
634 Value *Data = CI->getArgOperand(2);
635 auto *DataTy = dyn_cast<FixedVectorType>(Data->getType());
636 if (!DataTy || DataTy->getNumElements() != 4)
637 return make_error<StringError>(
638 "typedBufferStore data must be a vector of 4 elements",
640
641 // Since we're post-scalarizer, we likely have a vector that's constructed
642 // solely for the argument of the store. If so, just use the scalar values
643 // from before they're inserted into the temporary.
644 std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
645 auto *IEI = dyn_cast<InsertElementInst>(Data);
646 while (IEI) {
647 auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
648 if (!IndexOp)
649 break;
650 size_t IndexVal = IndexOp->getZExtValue();
651 assert(IndexVal < 4 && "Too many elements for buffer store");
652 DataElements[IndexVal] = IEI->getOperand(1);
653 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
654 }
655
656 // If for some reason we weren't able to forward the arguments from the
657 // scalarizer artifact, then we need to actually extract elements from the
658 // vector.
659 for (int I = 0, E = 4; I != E; ++I)
660 if (DataElements[I] == nullptr)
661 DataElements[I] =
662 IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
663
664 std::array<Value *, 8> Args{
665 Handle, Index0, Index1, DataElements[0],
666 DataElements[1], DataElements[2], DataElements[3], Mask};
667 Expected<CallInst *> OpCall =
668 OpBuilder.tryCreateOp(OpCode::BufferStore, Args, CI->getName());
669 if (Error E = OpCall.takeError())
670 return E;
671
672 CI->eraseFromParent();
673 // Clean up any leftover `insertelement`s
674 IEI = dyn_cast<InsertElementInst>(Data);
675 while (IEI && IEI->use_empty()) {
676 InsertElementInst *Tmp = IEI;
677 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
678 Tmp->eraseFromParent();
679 }
680
681 return Error::success();
682 });
683 }
684
685 [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
686 IRBuilder<> &IRB = OpBuilder.getIRB();
687 Type *Int32Ty = IRB.getInt32Ty();
688
689 return replaceFunction(F, [&](CallInst *CI) -> Error {
690 IRB.SetInsertPoint(CI);
692 Args.append(CI->arg_begin(), CI->arg_end());
693
694 Type *RetTy = Int32Ty;
695 Type *FRT = F.getReturnType();
696 if (const auto *VT = dyn_cast<VectorType>(FRT))
697 RetTy = VectorType::get(RetTy, VT);
698
699 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
700 dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
701 if (Error E = OpCall.takeError())
702 return E;
703
704 // If the result type is 32 bits we can do a direct replacement.
705 if (FRT->isIntOrIntVectorTy(32)) {
706 CI->replaceAllUsesWith(*OpCall);
707 CI->eraseFromParent();
708 return Error::success();
709 }
710
711 unsigned CastOp;
712 unsigned CastOp2;
713 if (FRT->isIntOrIntVectorTy(16)) {
714 CastOp = Instruction::ZExt;
715 CastOp2 = Instruction::SExt;
716 } else { // must be 64 bits
717 assert(FRT->isIntOrIntVectorTy(64) &&
718 "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
719 is supported.");
720 CastOp = Instruction::Trunc;
721 CastOp2 = Instruction::Trunc;
722 }
723
724 // It is correct to replace the ctpop with the dxil op and
725 // remove all casts to i32
726 bool NeedsCast = false;
728 Instruction *I = dyn_cast<Instruction>(User);
729 if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
730 I->getType() == RetTy) {
731 I->replaceAllUsesWith(*OpCall);
732 I->eraseFromParent();
733 } else
734 NeedsCast = true;
735 }
736
737 // It is correct to replace a ctpop with the dxil op and
738 // a cast from i32 to the return type of the ctpop
739 // the cast is emitted here if there is a non-cast to i32
740 // instr which uses the ctpop
741 if (NeedsCast) {
742 Value *Cast =
743 IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
744 CI->replaceAllUsesWith(Cast);
745 }
746
747 CI->eraseFromParent();
748 return Error::success();
749 });
750 }
751
// Lower every supported intrinsic declaration in the module to DXIL ops.
// Returns true if at least one intrinsic declaration was handled, whether or
// not lowering it succeeded. The temporary handle casts are only cleaned up
// when no error was diagnosed.
752 bool lowerIntrinsics() {
753 bool Updated = false;
754 bool HasErrors = false;
755
// Early-inc range: the lowering helpers may erase the function they handle.
756 for (Function &F : make_early_inc_range(M.functions())) {
// Only declarations are considered; functions with bodies are skipped.
757 if (!F.isDeclaration())
758 continue;
759 Intrinsic::ID ID = F.getIntrinsicID();
760 switch (ID) {
// Anything without a case below is left alone and does not count as a change.
761 default:
762 continue;
// One case per DXIL_OP_INTRINSIC entry in DXILOperation.inc, each lowering the
// intrinsic directly to its DXIL op via replaceFunctionWithOp().
763#define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) \
764 case Intrin: \
765 HasErrors |= replaceFunctionWithOp( \
766 F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__}); \
767 break;
768#include "DXILOperation.inc"
// Intrinsics that need custom lowering logic.
769 case Intrinsic::dx_resource_handlefrombinding:
770 HasErrors |= lowerHandleFromBinding(F);
771 break;
772 case Intrinsic::dx_resource_getpointer:
773 HasErrors |= lowerGetPointer(F);
774 break;
775 case Intrinsic::dx_resource_load_typedbuffer:
776 HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
777 break;
778 case Intrinsic::dx_resource_store_typedbuffer:
779 HasErrors |= lowerTypedBufferStore(F);
780 break;
781 case Intrinsic::dx_resource_load_rawbuffer:
782 HasErrors |= lowerRawBufferLoad(F);
783 break;
784 case Intrinsic::dx_resource_updatecounter:
785 HasErrors |= lowerUpdateCounter(F);
786 break;
787 // TODO: this can be removed when
788 // https://github.com/llvm/llvm-project/issues/113192 is fixed
789 case Intrinsic::dx_splitdouble:
790 HasErrors |= replaceFunctionWithNamedStructOp(
791 F, OpCode::SplitDouble,
792 OpBuilder.getSplitDoubleType(M.getContext()),
793 [&](CallInst *CI, CallInst *Op) {
794 return replaceSplitDoubleCallUsages(CI, Op);
795 });
796 break;
797 case Intrinsic::ctpop:
798 HasErrors |= lowerCtpopToCountBits(F);
799 break;
800 }
// Only reached for intrinsics handled by one of the cases above.
801 Updated = true;
802 }
// Skip the cleanup after an error: it asserts that the temporary casts form
// balanced, unused pairs, which is only expected once lowering has succeeded.
803 if (Updated && !HasErrors)
804 cleanupHandleCasts();
805
806 return Updated;
807 }
808};
809} // namespace
810
814
815 bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics();
816 if (!MadeChanges)
817 return PreservedAnalyses::all();
822 return PA;
823}
824
825namespace {
// Legacy pass manager wrapper that runs OpLowerer over a module.
826class DXILOpLoweringLegacy : public ModulePass {
827public:
// Fetches the binding map and the resource type map from their wrapper
// passes, runs the lowering, and returns whether any intrinsic was handled.
828 bool runOnModule(Module &M) override {
829 DXILBindingMap &DBM =
830 getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
831 DXILResourceTypeMap &DRTM =
832 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
833
834 return OpLowerer(M, DBM, DRTM).lowerIntrinsics();
835 }
836 StringRef getPassName() const override { return "DXIL Op Lowering"; }
837 DXILOpLoweringLegacy() : ModulePass(ID) {}
838
839 static char ID; // Pass identification.
// NOTE(review): the body of getAnalysisUsage is not visible in this listing;
// presumably it requires the two wrapper passes queried in runOnModule --
// confirm against the original file.
840 void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
847 }
848};
849char DXILOpLoweringLegacy::ID = 0;
850} // end anonymous namespace
851
852INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
853 false, false)
856INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
857 false)
858
860 return new DXILOpLoweringLegacy();
861}
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static bool isVectorArgExpansion(Function &F)
static SmallVector< Value * > argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder)
static SmallVector< Value * > populateOperands(Value *Arg, IRBuilder<> &Builder)
return RetTy
#define DEBUG_TYPE
Hexagon Common GEP
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
ModuleAnalysisManager MAM
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:168
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1349
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1269
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1294
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1275
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
Definition: Constant.h:42
This class represents an Operation in the Expression.
iterator find(const CallInst *Key)
Definition: DXILResource.h:444
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
The legacy pass manager's analysis pass to compute DXIL resource information.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
Diagnostic information for unsupported feature in backend.
Lightweight error class with error context and mandatory checking.
Definition: Error.h:160
static ErrorSuccess success()
Create a success value.
Definition: Error.h:337
Tagged union holding either a T or a Error.
Definition: Error.h:481
Error takeError()
Take ownership of the stored error.
Definition: Error.h:608
This instruction extracts a single (scalar) element from a VectorType value.
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2510
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1780
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2498
Value * CreateZExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a ZExt or Trunc from the integer value V to DestTy.
Definition: IRBuilder.h:2050
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition: IRBuilder.h:2554
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
Definition: IRBuilder.h:545
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
Definition: IRBuilder.h:1881
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:890
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1797
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition: IRBuilder.h:1810
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1369
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:199
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Definition: IRBuilder.h:535
This provides a uniform API for creating instructions and inserting them into a basic block: either at the end of a BasicBlock, or at a specific iterator location in a block.
Definition: IRBuilder.h:2704
This instruction inserts a single (scalar) element into a VectorType value.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:475
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:94
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:72
ModulePass class - This class is used to implement unstructured interprocedural optimizations and analyses.
Definition: Pass.h:251
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:131
bool empty() const
Definition: SmallVector.h:81
iterator erase(const_iterator CI)
Definition: SmallVector.h:737
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition: StringRef.h:51
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static IntegerType * getInt32Ty(LLVMContext &C)
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:355
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1859
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
Value * getOperand(unsigned i) const
Definition: User.h:228
unsigned getNumOperands() const
Definition: User.h:250
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
user_iterator user_begin()
Definition: Value.h:397
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
bool use_empty() const
Definition: Value.h:344
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
iterator_range< use_iterator > uses()
Definition: Value.h:376
bool hasName() const
Definition: Value.h:261
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
bool user_empty() const
Definition: Value.h:385
Represents a version number in the form major[.minor[.subminor[.build]]].
Definition: VersionTuple.h:29
StructType * getResRetType(Type *ElementTy)
Get a dx.types.ResRet type with the given element type.
StructType * getSplitDoubleType(LLVMContext &Context)
Get the dx.types.splitdouble type.
Expected< CallInst * > tryCreateOp(dxil::OpCode Op, ArrayRef< Value * > Args, const Twine &Name="", Type *RetTy=nullptr)
Try to create a call instruction for the given DXIL op.
Constant * getResBind(uint32_t LowerBound, uint32_t UpperBound, uint32_t SpaceID, dxil::ResourceClass RC)
Get a constant dx.types.ResBind value.
Constant * getResProps(uint32_t Word0, uint32_t Word1)
Get a constant dx.types.ResourceProperties value.
StructType * getHandleType()
Get the dx.types.Handle type.
std::pair< uint32_t, uint32_t > getAnnotateProps(Module &M, dxil::ResourceTypeInfo &RTI) const
TargetExtType * getHandleTy() const
Definition: DXILResource.h:345
const ResourceBinding & getBinding() const
Definition: DXILResource.h:344
dxil::ResourceClass getResourceClass() const
Definition: DXILResource.h:294
Wrapper pass for the legacy pass manager.
An efficient, type-erasing, non-owning reference to a callable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:125
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
ResourceClass
Definition: DXILABI.h:25
NodeAddr< DefNode * > Def
Definition: RDFGraph.h:384
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::error_code inconvertibleErrorCode()
The value returned by this function can be returned from convertToErrorCode for Error values where no sensible translation to std::error_code exists.
Definition: Error.cpp:98
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
Definition: STLExtras.h:657
auto unique(Range &&R, Predicate P)
Definition: STLExtras.h:2055
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition: Error.h:1291
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1664
constexpr std::underlying_type_t< Enum > to_underlying(Enum E)
Returns underlying integer value of an enum.
ModulePass * createDXILOpLoweringLegacyPass()
Pass to lower LLVM intrinsic calls to DXIL op function calls.
const char * toString(DWARFSectionKind Kind)
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39