LLVM 20.0.0git
DXILOpLowering.cpp
Go to the documentation of this file.
1//===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "DXILOpLowering.h"
10#include "DXILConstants.h"
12#include "DXILOpBuilder.h"
14#include "DXILShaderFlags.h"
15#include "DirectX.h"
19#include "llvm/CodeGen/Passes.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/Instruction.h"
24#include "llvm/IR/Intrinsics.h"
25#include "llvm/IR/IntrinsicsDirectX.h"
26#include "llvm/IR/Module.h"
27#include "llvm/IR/PassManager.h"
29#include "llvm/Pass.h"
31
32#define DEBUG_TYPE "dxil-op-lower"
33
34using namespace llvm;
35using namespace llvm::dxil;
36
38 switch (F.getIntrinsicID()) {
39 case Intrinsic::dx_dot2:
40 case Intrinsic::dx_dot3:
41 case Intrinsic::dx_dot4:
42 return true;
43 }
44 return false;
45}
46
48 SmallVector<Value *> ExtractedElements;
49 auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
50 for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
51 Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
52 Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
53 ExtractedElements.push_back(ExtractedElement);
54 }
55 return ExtractedElements;
56}
57
59 IRBuilder<> &Builder) {
60 // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
61 unsigned NumOperands = Orig->getNumOperands() - 1;
62 assert(NumOperands > 0);
63 Value *Arg0 = Orig->getOperand(0);
64 [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
65 assert(VecArg0);
66 SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
67 for (unsigned I = 1; I < NumOperands; ++I) {
68 Value *Arg = Orig->getOperand(I);
69 [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
70 assert(VecArg);
71 assert(VecArg0->getElementType() == VecArg->getElementType());
72 assert(VecArg0->getNumElements() == VecArg->getNumElements());
73 auto NextOperandList = populateOperands(Arg, Builder);
74 NewOperands.append(NextOperandList.begin(), NextOperandList.end());
75 }
76 return NewOperands;
77}
78
79namespace {
80class OpLowerer {
81 Module &M;
82 DXILOpBuilder OpBuilder;
83 DXILBindingMap &DBM;
85 SmallVector<CallInst *> CleanupCasts;
86
87public:
88 OpLowerer(Module &M, DXILBindingMap &DBM, DXILResourceTypeMap &DRTM)
89 : M(M), OpBuilder(M), DBM(DBM), DRTM(DRTM) {}
90
91 /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
92 /// there is an error replacing a call, we emit a diagnostic and return true.
93 [[nodiscard]] bool
94 replaceFunction(Function &F,
95 llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
96 for (User *U : make_early_inc_range(F.users())) {
97 CallInst *CI = dyn_cast<CallInst>(U);
98 if (!CI)
99 continue;
100
101 if (Error E = ReplaceCall(CI)) {
102 std::string Message(toString(std::move(E)));
103 DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
104 CI->getDebugLoc());
105 M.getContext().diagnose(Diag);
106 return true;
107 }
108 }
109 if (F.user_empty())
110 F.eraseFromParent();
111 return false;
112 }
113
/// Describes how one operand of a DXIL op is obtained when lowering an
/// intrinsic call (consumed by replaceFunctionWithOp).
114 struct IntrinArgSelect {
// The enumerators are generated from DXILOperation.inc; the lowering code in
// this file handles Index, I8 and I32.
115 enum class Type {
116#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
117#include "DXILOperation.inc"
118 };
// Which kind of operand this entry selects.
119 Type Type;
// For Index: the position of the intrinsic call argument to forward.
// For I8 / I32: the immediate constant to materialize.
120 int Value;
121 };
122
123 [[nodiscard]] bool
124 replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
125 ArrayRef<IntrinArgSelect> ArgSelects) {
126 bool IsVectorArgExpansion = isVectorArgExpansion(F);
127 assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
128 "Cann't do vector arg expansion when using arg selects.");
129 return replaceFunction(F, [&](CallInst *CI) -> Error {
130 OpBuilder.getIRB().SetInsertPoint(CI);
132 if (ArgSelects.size()) {
133 for (const IntrinArgSelect &A : ArgSelects) {
134 switch (A.Type) {
135 case IntrinArgSelect::Type::Index:
136 Args.push_back(CI->getArgOperand(A.Value));
137 break;
138 case IntrinArgSelect::Type::I8:
139 Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
140 break;
141 case IntrinArgSelect::Type::I32:
142 Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
143 break;
144 }
145 }
146 } else if (IsVectorArgExpansion) {
147 Args = argVectorFlatten(CI, OpBuilder.getIRB());
148 } else {
149 Args.append(CI->arg_begin(), CI->arg_end());
150 }
151
152 Expected<CallInst *> OpCall =
153 OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType());
154 if (Error E = OpCall.takeError())
155 return E;
156
157 CI->replaceAllUsesWith(*OpCall);
158 CI->eraseFromParent();
159 return Error::success();
160 });
161 }
162
163 [[nodiscard]] bool replaceFunctionWithNamedStructOp(
164 Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
165 llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
166 bool IsVectorArgExpansion = isVectorArgExpansion(F);
167 return replaceFunction(F, [&](CallInst *CI) -> Error {
169 OpBuilder.getIRB().SetInsertPoint(CI);
170 if (IsVectorArgExpansion) {
171 SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
172 Args.append(NewArgs.begin(), NewArgs.end());
173 } else
174 Args.append(CI->arg_begin(), CI->arg_end());
175
176 Expected<CallInst *> OpCall =
177 OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
178 if (Error E = OpCall.takeError())
179 return E;
180 if (Error E = ReplaceUses(CI, *OpCall))
181 return E;
182
183 return Error::success();
184 });
185 }
186
187 /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
188 /// is intended to be removed by the end of lowering. This is used to allow
189 /// lowering of ops which need to change their return or argument types in a
190 /// piecemeal way - we can add the casts in to avoid updating all of the uses
191 /// or defs, and by the end all of the casts will be redundant.
192 Value *createTmpHandleCast(Value *V, Type *Ty) {
193 CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic(
194 Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V});
195 CleanupCasts.push_back(Cast);
196 return Cast;
197 }
198
199 void cleanupHandleCasts() {
202
203 for (CallInst *Cast : CleanupCasts) {
204 // These casts were only put in to ease the move from `target("dx")` types
205 // to `dx.types.Handle in a piecemeal way. At this point, all of the
206 // non-cast uses should now be `dx.types.Handle`, and remaining casts
207 // should all form pairs to and from the now unused `target("dx")` type.
208 CastFns.push_back(Cast->getCalledFunction());
209
210 // If the cast is not to `dx.types.Handle`, it should be the first part of
211 // the pair. Keep track so we can remove it once it has no more uses.
212 if (Cast->getType() != OpBuilder.getHandleType()) {
213 ToRemove.push_back(Cast);
214 continue;
215 }
216 // Otherwise, we're the second handle in a pair. Forward the arguments and
217 // remove the (second) cast.
218 CallInst *Def = cast<CallInst>(Cast->getOperand(0));
219 assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
220 "Unbalanced pair of temporary handle casts");
221 Cast->replaceAllUsesWith(Def->getOperand(0));
222 Cast->eraseFromParent();
223 }
224 for (CallInst *Cast : ToRemove) {
225 assert(Cast->user_empty() && "Temporary handle cast still has users");
226 Cast->eraseFromParent();
227 }
228
229 // Deduplicate the cast functions so that we only erase each one once.
230 llvm::sort(CastFns);
231 CastFns.erase(llvm::unique(CastFns), CastFns.end());
232 for (Function *F : CastFns)
233 F->eraseFromParent();
234
235 CleanupCasts.clear();
236 }
237
238 // Remove the resource global associated with the handleFromBinding call
239 // instruction and their uses as they aren't needed anymore.
240 // TODO: We should verify that all the globals get removed.
241 // It's expected we'll need a custom pass in the future that will eliminate
242 // the need for this here.
243 void removeResourceGlobals(CallInst *CI) {
244 for (User *User : make_early_inc_range(CI->users())) {
245 if (StoreInst *Store = dyn_cast<StoreInst>(User)) {
246 Value *V = Store->getOperand(1);
247 Store->eraseFromParent();
248 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
249 if (GV->use_empty()) {
250 GV->removeDeadConstantUsers();
251 GV->eraseFromParent();
252 }
253 }
254 }
255 }
256
257 [[nodiscard]] bool lowerToCreateHandle(Function &F) {
258 IRBuilder<> &IRB = OpBuilder.getIRB();
259 Type *Int8Ty = IRB.getInt8Ty();
260 Type *Int32Ty = IRB.getInt32Ty();
261
262 return replaceFunction(F, [&](CallInst *CI) -> Error {
263 IRB.SetInsertPoint(CI);
264
265 auto *It = DBM.find(CI);
266 assert(It != DBM.end() && "Resource not in map?");
268
269 const auto &Binding = RI.getBinding();
270 dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();
271
272 Value *IndexOp = CI->getArgOperand(3);
273 if (Binding.LowerBound != 0)
274 IndexOp = IRB.CreateAdd(IndexOp,
275 ConstantInt::get(Int32Ty, Binding.LowerBound));
276
277 std::array<Value *, 4> Args{
278 ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
279 ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
280 CI->getArgOperand(4)};
281 Expected<CallInst *> OpCall =
282 OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
283 if (Error E = OpCall.takeError())
284 return E;
285
286 Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
287
288 removeResourceGlobals(CI);
289
290 CI->replaceAllUsesWith(Cast);
291 CI->eraseFromParent();
292 return Error::success();
293 });
294 }
295
296 [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
297 IRBuilder<> &IRB = OpBuilder.getIRB();
298 Type *Int32Ty = IRB.getInt32Ty();
299
300 return replaceFunction(F, [&](CallInst *CI) -> Error {
301 IRB.SetInsertPoint(CI);
302
303 auto *It = DBM.find(CI);
304 assert(It != DBM.end() && "Resource not in map?");
306
307 const auto &Binding = RI.getBinding();
308 dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()];
310
311 Value *IndexOp = CI->getArgOperand(3);
312 if (Binding.LowerBound != 0)
313 IndexOp = IRB.CreateAdd(IndexOp,
314 ConstantInt::get(Int32Ty, Binding.LowerBound));
315
316 std::pair<uint32_t, uint32_t> Props =
317 RI.getAnnotateProps(*F.getParent(), RTI);
318
319 // For `CreateHandleFromBinding` we need the upper bound rather than the
320 // size, so we need to be careful about the difference for "unbounded".
321 uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
322 uint32_t UpperBound = Binding.Size == Unbounded
323 ? Unbounded
324 : Binding.LowerBound + Binding.Size - 1;
325 Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
326 Binding.Space, RC);
327 std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)};
328 Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
329 OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
330 if (Error E = OpBind.takeError())
331 return E;
332
333 std::array<Value *, 2> AnnotateArgs{
334 *OpBind, OpBuilder.getResProps(Props.first, Props.second)};
335 Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp(
336 OpCode::AnnotateHandle, AnnotateArgs,
337 CI->hasName() ? CI->getName() + "_annot" : Twine());
338 if (Error E = OpAnnotate.takeError())
339 return E;
340
341 Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
342
343 removeResourceGlobals(CI);
344
345 CI->replaceAllUsesWith(Cast);
346 CI->eraseFromParent();
347
348 return Error::success();
349 });
350 }
351
352 /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader
353 /// model and taking into account binding information from
354 /// DXILResourceBindingAnalysis.
355 bool lowerHandleFromBinding(Function &F) {
356 Triple TT(Triple(M.getTargetTriple()));
357 if (TT.getDXILVersion() < VersionTuple(1, 6))
358 return lowerToCreateHandle(F);
359 return lowerToBindAndAnnotateHandle(F);
360 }
361
362 Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
363 for (Use &U : make_early_inc_range(Intrin->uses())) {
364 if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
365
366 if (EVI->getNumIndices() != 1)
367 return createStringError(std::errc::invalid_argument,
368 "Splitdouble has only 2 elements");
369 EVI->setOperand(0, Op);
370 } else {
371 return make_error<StringError>(
372 "Splitdouble use is not ExtractValueInst",
374 }
375 }
376
377 Intrin->eraseFromParent();
378
379 return Error::success();
380 }
381
382 /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
383 /// Since we expect to be post-scalarization, make an effort to avoid vectors.
384 Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
385 IRBuilder<> &IRB = OpBuilder.getIRB();
386
387 Instruction *OldResult = Intrin;
388 Type *OldTy = Intrin->getType();
389
390 if (HasCheckBit) {
391 auto *ST = cast<StructType>(OldTy);
392
393 Value *CheckOp = nullptr;
394 Type *Int32Ty = IRB.getInt32Ty();
395 for (Use &U : make_early_inc_range(OldResult->uses())) {
396 if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
397 ArrayRef<unsigned> Indices = EVI->getIndices();
398 assert(Indices.size() == 1);
399 // We're only interested in uses of the check bit for now.
400 if (Indices[0] != 1)
401 continue;
402 if (!CheckOp) {
403 Value *NewEVI = IRB.CreateExtractValue(Op, 4);
404 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
405 OpCode::CheckAccessFullyMapped, {NewEVI},
406 OldResult->hasName() ? OldResult->getName() + "_check"
407 : Twine(),
408 Int32Ty);
409 if (Error E = OpCall.takeError())
410 return E;
411 CheckOp = *OpCall;
412 }
413 EVI->replaceAllUsesWith(CheckOp);
414 EVI->eraseFromParent();
415 }
416 }
417
418 OldResult = cast<Instruction>(
419 IRB.CreateExtractValue(Op, 0, OldResult->getName()));
420 OldTy = ST->getElementType(0);
421 }
422
423 // For scalars, we just extract the first element.
424 if (!isa<FixedVectorType>(OldTy)) {
425 Value *EVI = IRB.CreateExtractValue(Op, 0);
426 OldResult->replaceAllUsesWith(EVI);
427 OldResult->eraseFromParent();
428 if (OldResult != Intrin) {
429 assert(Intrin->use_empty() && "Intrinsic still has uses?");
430 Intrin->eraseFromParent();
431 }
432 return Error::success();
433 }
434
435 std::array<Value *, 4> Extracts = {};
436 SmallVector<ExtractElementInst *> DynamicAccesses;
437
438 // The users of the operation should all be scalarized, so we attempt to
439 // replace the extractelements with extractvalues directly.
440 for (Use &U : make_early_inc_range(OldResult->uses())) {
441 if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
442 if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
443 size_t IndexVal = IndexOp->getZExtValue();
444 assert(IndexVal < 4 && "Index into buffer load out of range");
445 if (!Extracts[IndexVal])
446 Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
447 EEI->replaceAllUsesWith(Extracts[IndexVal]);
448 EEI->eraseFromParent();
449 } else {
450 DynamicAccesses.push_back(EEI);
451 }
452 }
453 }
454
455 const auto *VecTy = cast<FixedVectorType>(OldTy);
456 const unsigned N = VecTy->getNumElements();
457
458 // If there's a dynamic access we need to round trip through stack memory so
459 // that we don't leave vectors around.
460 if (!DynamicAccesses.empty()) {
461 Type *Int32Ty = IRB.getInt32Ty();
462 Constant *Zero = ConstantInt::get(Int32Ty, 0);
463
464 Type *ElTy = VecTy->getElementType();
465 Type *ArrayTy = ArrayType::get(ElTy, N);
466 Value *Alloca = IRB.CreateAlloca(ArrayTy);
467
468 for (int I = 0, E = N; I != E; ++I) {
469 if (!Extracts[I])
470 Extracts[I] = IRB.CreateExtractValue(Op, I);
472 ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
473 IRB.CreateStore(Extracts[I], GEP);
474 }
475
476 for (ExtractElementInst *EEI : DynamicAccesses) {
477 Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
478 {Zero, EEI->getIndexOperand()});
479 Value *Load = IRB.CreateLoad(ElTy, GEP);
480 EEI->replaceAllUsesWith(Load);
481 EEI->eraseFromParent();
482 }
483 }
484
485 // If we still have uses, then we're not fully scalarized and need to
486 // recreate the vector. This should only happen for things like exported
487 // functions from libraries.
488 if (!OldResult->use_empty()) {
489 for (int I = 0, E = N; I != E; ++I)
490 if (!Extracts[I])
491 Extracts[I] = IRB.CreateExtractValue(Op, I);
492
493 Value *Vec = UndefValue::get(OldTy);
494 for (int I = 0, E = N; I != E; ++I)
495 Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
496 OldResult->replaceAllUsesWith(Vec);
497 }
498
499 OldResult->eraseFromParent();
500 if (OldResult != Intrin) {
501 assert(Intrin->use_empty() && "Intrinsic still has uses?");
502 Intrin->eraseFromParent();
503 }
504
505 return Error::success();
506 }
507
508 [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
509 IRBuilder<> &IRB = OpBuilder.getIRB();
510 Type *Int32Ty = IRB.getInt32Ty();
511
512 return replaceFunction(F, [&](CallInst *CI) -> Error {
513 IRB.SetInsertPoint(CI);
514
515 Value *Handle =
516 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
517 Value *Index0 = CI->getArgOperand(1);
518 Value *Index1 = UndefValue::get(Int32Ty);
519
520 Type *OldTy = CI->getType();
521 if (HasCheckBit)
522 OldTy = cast<StructType>(OldTy)->getElementType(0);
523 Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());
524
525 std::array<Value *, 3> Args{Handle, Index0, Index1};
526 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
527 OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
528 if (Error E = OpCall.takeError())
529 return E;
530 if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
531 return E;
532
533 return Error::success();
534 });
535 }
536
537 [[nodiscard]] bool lowerUpdateCounter(Function &F) {
538 IRBuilder<> &IRB = OpBuilder.getIRB();
539 Type *Int32Ty = IRB.getInt32Ty();
540
541 return replaceFunction(F, [&](CallInst *CI) -> Error {
542 IRB.SetInsertPoint(CI);
543 Value *Handle =
544 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
545 Value *Op1 = CI->getArgOperand(1);
546
547 std::array<Value *, 2> Args{Handle, Op1};
548
549 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
550 OpCode::UpdateCounter, Args, CI->getName(), Int32Ty);
551
552 if (Error E = OpCall.takeError())
553 return E;
554
555 CI->replaceAllUsesWith(*OpCall);
556 CI->eraseFromParent();
557 return Error::success();
558 });
559 }
560
561 [[nodiscard]] bool lowerGetPointer(Function &F) {
562 // These should have already been handled in DXILResourceAccess, so we can
563 // just clean up the dead prototype.
564 assert(F.user_empty() && "getpointer operations should have been removed");
565 F.eraseFromParent();
566 return false;
567 }
568
569 [[nodiscard]] bool lowerTypedBufferStore(Function &F) {
570 IRBuilder<> &IRB = OpBuilder.getIRB();
571 Type *Int8Ty = IRB.getInt8Ty();
572 Type *Int32Ty = IRB.getInt32Ty();
573
574 return replaceFunction(F, [&](CallInst *CI) -> Error {
575 IRB.SetInsertPoint(CI);
576
577 Value *Handle =
578 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
579 Value *Index0 = CI->getArgOperand(1);
580 Value *Index1 = UndefValue::get(Int32Ty);
581 // For typed stores, the mask must always cover all four elements.
582 Constant *Mask = ConstantInt::get(Int8Ty, 0xF);
583
584 Value *Data = CI->getArgOperand(2);
585 auto *DataTy = dyn_cast<FixedVectorType>(Data->getType());
586 if (!DataTy || DataTy->getNumElements() != 4)
587 return make_error<StringError>(
588 "typedBufferStore data must be a vector of 4 elements",
590
591 // Since we're post-scalarizer, we likely have a vector that's constructed
592 // solely for the argument of the store. If so, just use the scalar values
593 // from before they're inserted into the temporary.
594 std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
595 auto *IEI = dyn_cast<InsertElementInst>(Data);
596 while (IEI) {
597 auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
598 if (!IndexOp)
599 break;
600 size_t IndexVal = IndexOp->getZExtValue();
601 assert(IndexVal < 4 && "Too many elements for buffer store");
602 DataElements[IndexVal] = IEI->getOperand(1);
603 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
604 }
605
606 // If for some reason we weren't able to forward the arguments from the
607 // scalarizer artifact, then we need to actually extract elements from the
608 // vector.
609 for (int I = 0, E = 4; I != E; ++I)
610 if (DataElements[I] == nullptr)
611 DataElements[I] =
612 IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
613
614 std::array<Value *, 8> Args{
615 Handle, Index0, Index1, DataElements[0],
616 DataElements[1], DataElements[2], DataElements[3], Mask};
617 Expected<CallInst *> OpCall =
618 OpBuilder.tryCreateOp(OpCode::BufferStore, Args, CI->getName());
619 if (Error E = OpCall.takeError())
620 return E;
621
622 CI->eraseFromParent();
623 // Clean up any leftover `insertelement`s
624 IEI = dyn_cast<InsertElementInst>(Data);
625 while (IEI && IEI->use_empty()) {
626 InsertElementInst *Tmp = IEI;
627 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
628 Tmp->eraseFromParent();
629 }
630
631 return Error::success();
632 });
633 }
634
635 [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
636 IRBuilder<> &IRB = OpBuilder.getIRB();
637 Type *Int32Ty = IRB.getInt32Ty();
638
639 return replaceFunction(F, [&](CallInst *CI) -> Error {
640 IRB.SetInsertPoint(CI);
642 Args.append(CI->arg_begin(), CI->arg_end());
643
644 Type *RetTy = Int32Ty;
645 Type *FRT = F.getReturnType();
646 if (const auto *VT = dyn_cast<VectorType>(FRT))
647 RetTy = VectorType::get(RetTy, VT);
648
649 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
650 dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
651 if (Error E = OpCall.takeError())
652 return E;
653
654 // If the result type is 32 bits we can do a direct replacement.
655 if (FRT->isIntOrIntVectorTy(32)) {
656 CI->replaceAllUsesWith(*OpCall);
657 CI->eraseFromParent();
658 return Error::success();
659 }
660
661 unsigned CastOp;
662 unsigned CastOp2;
663 if (FRT->isIntOrIntVectorTy(16)) {
664 CastOp = Instruction::ZExt;
665 CastOp2 = Instruction::SExt;
666 } else { // must be 64 bits
667 assert(FRT->isIntOrIntVectorTy(64) &&
668 "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
669 is supported.");
670 CastOp = Instruction::Trunc;
671 CastOp2 = Instruction::Trunc;
672 }
673
674 // It is correct to replace the ctpop with the dxil op and
675 // remove all casts to i32
676 bool NeedsCast = false;
678 Instruction *I = dyn_cast<Instruction>(User);
679 if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
680 I->getType() == RetTy) {
681 I->replaceAllUsesWith(*OpCall);
682 I->eraseFromParent();
683 } else
684 NeedsCast = true;
685 }
686
687 // It is correct to replace a ctpop with the dxil op and
688 // a cast from i32 to the return type of the ctpop
689 // the cast is emitted here if there is a non-cast to i32
690 // instr which uses the ctpop
691 if (NeedsCast) {
692 Value *Cast =
693 IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
694 CI->replaceAllUsesWith(Cast);
695 }
696
697 CI->eraseFromParent();
698 return Error::success();
699 });
700 }
701
/// Lower every supported intrinsic declaration in the module to DXIL ops.
///
/// Walks all function declarations, dispatches on the intrinsic ID to the
/// matching lowering routine, and accumulates whether any of them reported an
/// error. The temporary handle casts are only cleaned up when at least one
/// intrinsic was handled and no lowering failed.
///
/// \returns true if at least one intrinsic declaration was handled.
702 bool lowerIntrinsics() {
703 bool Updated = false;
704 bool HasErrors = false;
705
// Early-inc iteration: the lowering routines may erase F.
706 for (Function &F : make_early_inc_range(M.functions())) {
// Functions with a body are skipped; only declarations are considered.
707 if (!F.isDeclaration())
708 continue;
709 Intrinsic::ID ID = F.getIntrinsicID();
710 switch (ID) {
// Anything not listed below is left untouched and does not count as an
// update.
711 default:
712 continue;
// One case per intrinsic listed in DXILOperation.inc; each is lowered by
// replaceFunctionWithOp with the arg selects given in the .inc entry.
713#define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) \
714 case Intrin: \
715 HasErrors |= replaceFunctionWithOp( \
716 F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__}); \
717 break;
718#include "DXILOperation.inc"
719 case Intrinsic::dx_resource_handlefrombinding:
720 HasErrors |= lowerHandleFromBinding(F);
721 break;
722 case Intrinsic::dx_resource_getpointer:
723 HasErrors |= lowerGetPointer(F);
724 break;
725 case Intrinsic::dx_resource_load_typedbuffer:
726 HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
727 break;
728 case Intrinsic::dx_resource_loadchecked_typedbuffer:
729 HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
730 break;
731 case Intrinsic::dx_resource_store_typedbuffer:
732 HasErrors |= lowerTypedBufferStore(F);
733 break;
734 case Intrinsic::dx_resource_updatecounter:
735 HasErrors |= lowerUpdateCounter(F);
736 break;
737 // TODO: this can be removed when
738 // https://github.com/llvm/llvm-project/issues/113192 is fixed
739 case Intrinsic::dx_splitdouble:
740 HasErrors |= replaceFunctionWithNamedStructOp(
741 F, OpCode::SplitDouble,
742 OpBuilder.getSplitDoubleType(M.getContext()),
743 [&](CallInst *CI, CallInst *Op) {
744 return replaceSplitDoubleCallUsages(CI, Op);
745 });
746 break;
747 case Intrinsic::ctpop:
748 HasErrors |= lowerCtpopToCountBits(F);
749 break;
750 }
// Reached only for handled intrinsics (the default case uses `continue`).
751 Updated = true;
752 }
// After a failed lowering the temporary casts are left in place.
753 if (Updated && !HasErrors)
754 cleanupHandleCasts();
755
756 return Updated;
757 }
758};
759} // namespace
760
764
765 bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics();
766 if (!MadeChanges)
767 return PreservedAnalyses::all();
772 return PA;
773}
774
775namespace {
776class DXILOpLoweringLegacy : public ModulePass {
777public:
778 bool runOnModule(Module &M) override {
779 DXILBindingMap &DBM =
780 getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
781 DXILResourceTypeMap &DRTM =
782 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
783
784 return OpLowerer(M, DBM, DRTM).lowerIntrinsics();
785 }
786 StringRef getPassName() const override { return "DXIL Op Lowering"; }
787 DXILOpLoweringLegacy() : ModulePass(ID) {}
788
789 static char ID; // Pass identification.
790 void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
797 }
798};
799char DXILOpLoweringLegacy::ID = 0;
800} // end anonymous namespace
801
802INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
803 false, false)
806INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
807 false)
808
810 return new DXILOpLoweringLegacy();
811}
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
ReachingDefAnalysis InstSet & ToRemove
static bool isVectorArgExpansion(Function &F)
static SmallVector< Value * > argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder)
static SmallVector< Value * > populateOperands(Value *Arg, IRBuilder<> &Builder)
return RetTy
#define DEBUG_TYPE
Hexagon Common GEP
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
ModuleAnalysisManager MAM
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:168
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1349
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1269
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1294
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1275
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
Definition: Constant.h:42
This class represents an Operation in the Expression.
iterator find(const CallInst *Key)
Definition: DXILResource.h:444
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
The legacy pass manager's analysis pass to compute DXIL resource information.
Diagnostic information for unsupported feature in backend.
Lightweight error class with error context and mandatory checking.
Definition: Error.h:160
static ErrorSuccess success()
Create a success value.
Definition: Error.h:337
Tagged union holding either a T or a Error.
Definition: Error.h:481
Error takeError()
Take ownership of the stored error.
Definition: Error.h:608
This instruction extracts a single (scalar) element from a VectorType value.
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2503
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1796
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2491
Value * CreateZExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a ZExt or Trunc from the integer value V to DestTy.
Definition: IRBuilder.h:2066
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition: IRBuilder.h:2547
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:890
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
Definition: IRBuilder.h:523
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
Definition: IRBuilder.h:1897
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1813
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition: IRBuilder.h:1826
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1350
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Definition: IRBuilder.h:513
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
This instruction inserts a single (scalar) element into a VectorType value.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:471
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:92
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:70
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:131
bool empty() const
Definition: SmallVector.h:81
iterator erase(const_iterator CI)
Definition: SmallVector.h:737
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static IntegerType * getInt32Ty(LLVMContext &C)
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:355
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1859
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
Value * getOperand(unsigned i) const
Definition: User.h:228
unsigned getNumOperands() const
Definition: User.h:250
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
bool use_empty() const
Definition: Value.h:344
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
iterator_range< use_iterator > uses()
Definition: Value.h:376
bool hasName() const
Definition: Value.h:261
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
bool user_empty() const
Definition: Value.h:385
Represents a version number in the form major[.minor[.subminor[.build]]].
Definition: VersionTuple.h:29
StructType * getResRetType(Type *ElementTy)
Get a dx.types.ResRet type with the given element type.
StructType * getSplitDoubleType(LLVMContext &Context)
Get the dx.types.splitdouble type.
Expected< CallInst * > tryCreateOp(dxil::OpCode Op, ArrayRef< Value * > Args, const Twine &Name="", Type *RetTy=nullptr)
Try to create a call instruction for the given DXIL op.
Constant * getResBind(uint32_t LowerBound, uint32_t UpperBound, uint32_t SpaceID, dxil::ResourceClass RC)
Get a constant dx.types.ResBind value.
Constant * getResProps(uint32_t Word0, uint32_t Word1)
Get a constant dx.types.ResourceProperties value.
StructType * getHandleType()
Get the dx.types.Handle type.
std::pair< uint32_t, uint32_t > getAnnotateProps(Module &M, dxil::ResourceTypeInfo &RTI) const
TargetExtType * getHandleTy() const
Definition: DXILResource.h:345
const ResourceBinding & getBinding() const
Definition: DXILResource.h:344
dxil::ResourceClass getResourceClass() const
Definition: DXILResource.h:294
Wrapper pass for the legacy pass manager.
An efficient, type-erasing, non-owning reference to a callable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:125
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
ResourceClass
Definition: DXILABI.h:25
NodeAddr< DefNode * > Def
Definition: RDFGraph.h:384
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::error_code inconvertibleErrorCode()
The value returned by this function can be returned from convertToErrorCode for Error values where no sensible translation to std::error_code exists.
Definition: Error.cpp:98
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
Definition: STLExtras.h:657
auto unique(Range &&R, Predicate P)
Definition: STLExtras.h:2055
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition: Error.h:1291
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1664
constexpr std::underlying_type_t< Enum > to_underlying(Enum E)
Returns underlying integer value of an enum.
ModulePass * createDXILOpLoweringLegacyPass()
Pass to lower LLVM intrinsic calls to DXIL op function calls.
const char * toString(DWARFSectionKind Kind)
#define N