25#include "llvm/IR/IntrinsicsDirectX.h"
32#define DEBUG_TYPE "dxil-op-lower"
38 switch (
F.getIntrinsicID()) {
39 case Intrinsic::dx_dot2:
40 case Intrinsic::dx_dot3:
41 case Intrinsic::dx_dot4:
49 auto *VecArg = dyn_cast<FixedVectorType>(Arg->
getType());
50 for (
unsigned I = 0;
I < VecArg->getNumElements(); ++
I) {
53 ExtractedElements.
push_back(ExtractedElement);
55 return ExtractedElements;
64 [[maybe_unused]]
auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->
getType());
67 for (
unsigned I = 1;
I < NumOperands; ++
I) {
69 [[maybe_unused]]
auto *VecArg = dyn_cast<FixedVectorType>(Arg->
getType());
71 assert(VecArg0->getElementType() == VecArg->getElementType());
72 assert(VecArg0->getNumElements() == VecArg->getNumElements());
74 NewOperands.
append(NextOperandList.begin(), NextOperandList.end());
89 :
M(
M), OpBuilder(
M), DBM(DBM), DRTM(DRTM) {}
97 CallInst *CI = dyn_cast<CallInst>(U);
101 if (
Error E = ReplaceCall(CI)) {
102 std::string Message(
toString(std::move(E)));
105 M.getContext().diagnose(Diag);
114 struct IntrinArgSelect {
116#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
117#include "DXILOperation.inc"
127 assert(!(IsVectorArgExpansion && ArgSelects.
size()) &&
128 "Cann't do vector arg expansion when using arg selects.");
130 OpBuilder.
getIRB().SetInsertPoint(CI);
132 if (ArgSelects.
size()) {
133 for (const IntrinArgSelect &A : ArgSelects) {
135 case IntrinArgSelect::Type::Index:
136 Args.push_back(CI->getArgOperand(A.Value));
138 case IntrinArgSelect::Type::I8:
139 Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
141 case IntrinArgSelect::Type::I32:
142 Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
146 }
else if (IsVectorArgExpansion) {
163 [[nodiscard]]
bool replaceFunctionWithNamedStructOp(
169 OpBuilder.
getIRB().SetInsertPoint(CI);
170 if (IsVectorArgExpansion) {
180 if (
Error E = ReplaceUses(CI, *OpCall))
194 Intrinsic::dx_resource_casthandle, {Ty,
V->getType()}, {
V});
199 void cleanupHandleCasts() {
203 for (
CallInst *Cast : CleanupCasts) {
219 assert(
Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
220 "Unbalanced pair of temporary handle casts");
233 F->eraseFromParent();
235 CleanupCasts.clear();
243 void removeResourceGlobals(
CallInst *CI) {
247 Store->eraseFromParent();
249 if (GV->use_empty()) {
250 GV->removeDeadConstantUsers();
251 GV->eraseFromParent();
257 [[nodiscard]]
bool lowerToCreateHandle(
Function &
F) {
265 auto *It = DBM.
find(CI);
266 assert(It != DBM.
end() &&
"Resource not in map?");
273 if (Binding.LowerBound != 0)
275 ConstantInt::get(Int32Ty, Binding.LowerBound));
277 std::array<Value *, 4>
Args{
279 ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
286 Value *Cast = createTmpHandleCast(*OpCall, CI->
getType());
288 removeResourceGlobals(CI);
296 [[nodiscard]]
bool lowerToBindAndAnnotateHandle(
Function &
F) {
303 auto *It = DBM.
find(CI);
304 assert(It != DBM.
end() &&
"Resource not in map?");
312 if (Binding.LowerBound != 0)
314 ConstantInt::get(Int32Ty, Binding.LowerBound));
316 std::pair<uint32_t, uint32_t> Props =
321 uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
322 uint32_t UpperBound = Binding.Size == Unbounded
324 : Binding.LowerBound + Binding.Size - 1;
327 std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->
getArgOperand(4)};
329 OpCode::CreateHandleFromBinding, BindArgs, CI->
getName());
333 std::array<Value *, 2> AnnotateArgs{
334 *OpBind, OpBuilder.
getResProps(Props.first, Props.second)};
336 OpCode::AnnotateHandle, AnnotateArgs,
341 Value *Cast = createTmpHandleCast(*OpAnnotate, CI->
getType());
343 removeResourceGlobals(CI);
355 bool lowerHandleFromBinding(
Function &
F) {
358 return lowerToCreateHandle(
F);
359 return lowerToBindAndAnnotateHandle(
F);
364 if (
auto *EVI = dyn_cast<ExtractValueInst>(
U.getUser())) {
366 if (EVI->getNumIndices() != 1)
368 "Splitdouble has only 2 elements");
369 EVI->setOperand(0,
Op);
371 return make_error<StringError>(
372 "Splitdouble use is not ExtractValueInst",
391 auto *
ST = cast<StructType>(OldTy);
393 Value *CheckOp =
nullptr;
396 if (
auto *EVI = dyn_cast<ExtractValueInst>(
U.getUser())) {
405 OpCode::CheckAccessFullyMapped, {NewEVI},
413 EVI->replaceAllUsesWith(CheckOp);
414 EVI->eraseFromParent();
418 OldResult = cast<Instruction>(
420 OldTy =
ST->getElementType(0);
424 if (!isa<FixedVectorType>(OldTy)) {
428 if (OldResult != Intrin) {
435 std::array<Value *, 4> Extracts = {};
441 if (
auto *EEI = dyn_cast<ExtractElementInst>(
U.getUser())) {
442 if (
auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
443 size_t IndexVal = IndexOp->getZExtValue();
444 assert(IndexVal < 4 &&
"Index into buffer load out of range");
445 if (!Extracts[IndexVal])
448 EEI->eraseFromParent();
455 const auto *VecTy = cast<FixedVectorType>(OldTy);
456 const unsigned N = VecTy->getNumElements();
460 if (!DynamicAccesses.
empty()) {
464 Type *ElTy = VecTy->getElementType();
465 Type *ArrayTy = ArrayType::get(ElTy,
N);
468 for (
int I = 0, E =
N;
I != E; ++
I) {
472 ArrayTy, Alloca, {
Zero, ConstantInt::get(Int32Ty,
I)});
478 {
Zero, EEI->getIndexOperand()});
481 EEI->eraseFromParent();
489 for (
int I = 0, E =
N;
I != E; ++
I)
494 for (
int I = 0, E =
N;
I != E; ++
I)
500 if (OldResult != Intrin) {
508 [[nodiscard]]
bool lowerTypedBufferLoad(
Function &
F,
bool HasCheckBit) {
522 OldTy = cast<StructType>(OldTy)->getElementType(0);
525 std::array<Value *, 3>
Args{Handle, Index0, Index1};
527 OpCode::BufferLoad, Args, CI->
getName(), NewRetTy);
530 if (
Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
537 [[nodiscard]]
bool lowerUpdateCounter(
Function &
F) {
547 std::array<Value *, 2>
Args{Handle, Op1};
550 OpCode::UpdateCounter, Args, CI->
getName(), Int32Ty);
561 [[nodiscard]]
bool lowerGetPointer(
Function &
F) {
564 assert(
F.user_empty() &&
"getpointer operations should have been removed");
569 [[nodiscard]]
bool lowerTypedBufferStore(
Function &
F) {
585 auto *DataTy = dyn_cast<FixedVectorType>(Data->getType());
586 if (!DataTy || DataTy->getNumElements() != 4)
587 return make_error<StringError>(
588 "typedBufferStore data must be a vector of 4 elements",
594 std::array<Value *, 4> DataElements{
nullptr,
nullptr,
nullptr,
nullptr};
595 auto *IEI = dyn_cast<InsertElementInst>(Data);
597 auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
600 size_t IndexVal = IndexOp->getZExtValue();
601 assert(IndexVal < 4 &&
"Too many elements for buffer store");
602 DataElements[IndexVal] = IEI->getOperand(1);
603 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
609 for (
int I = 0, E = 4;
I != E; ++
I)
610 if (DataElements[
I] ==
nullptr)
614 std::array<Value *, 8>
Args{
615 Handle, Index0, Index1, DataElements[0],
616 DataElements[1], DataElements[2], DataElements[3],
Mask};
624 IEI = dyn_cast<InsertElementInst>(Data);
625 while (IEI && IEI->use_empty()) {
627 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
635 [[nodiscard]]
bool lowerCtpopToCountBits(
Function &
F) {
645 Type *FRT =
F.getReturnType();
646 if (
const auto *VT = dyn_cast<VectorType>(FRT))
655 if (FRT->isIntOrIntVectorTy(32)) {
656 CI->replaceAllUsesWith(*OpCall);
657 CI->eraseFromParent();
658 return Error::success();
663 if (FRT->isIntOrIntVectorTy(16)) {
664 CastOp = Instruction::ZExt;
665 CastOp2 = Instruction::SExt;
667 assert(FRT->isIntOrIntVectorTy(64) &&
668 "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
670 CastOp = Instruction::Trunc;
671 CastOp2 = Instruction::Trunc;
676 bool NeedsCast =
false;
678 Instruction *I = dyn_cast<Instruction>(User);
679 if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
680 I->getType() == RetTy) {
681 I->replaceAllUsesWith(*OpCall);
682 I->eraseFromParent();
702 bool lowerIntrinsics() {
703 bool Updated =
false;
704 bool HasErrors =
false;
707 if (!
F.isDeclaration())
713#define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) \
715 HasErrors |= replaceFunctionWithOp( \
716 F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__}); \
718#include "DXILOperation.inc"
719 case Intrinsic::dx_resource_handlefrombinding:
720 HasErrors |= lowerHandleFromBinding(
F);
722 case Intrinsic::dx_resource_getpointer:
723 HasErrors |= lowerGetPointer(
F);
725 case Intrinsic::dx_resource_load_typedbuffer:
726 HasErrors |= lowerTypedBufferLoad(
F,
false);
728 case Intrinsic::dx_resource_loadchecked_typedbuffer:
729 HasErrors |= lowerTypedBufferLoad(
F,
true);
731 case Intrinsic::dx_resource_store_typedbuffer:
732 HasErrors |= lowerTypedBufferStore(
F);
734 case Intrinsic::dx_resource_updatecounter:
735 HasErrors |= lowerUpdateCounter(
F);
739 case Intrinsic::dx_splitdouble:
740 HasErrors |= replaceFunctionWithNamedStructOp(
741 F, OpCode::SplitDouble,
744 return replaceSplitDoubleCallUsages(CI, Op);
747 case Intrinsic::ctpop:
748 HasErrors |= lowerCtpopToCountBits(
F);
753 if (Updated && !HasErrors)
754 cleanupHandleCasts();
765 bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics();
776class DXILOpLoweringLegacy :
public ModulePass {
778 bool runOnModule(
Module &M)
override {
780 getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
782 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
784 return OpLowerer(M, DBM, DRTM).lowerIntrinsics();
// Reports this pass's human-readable name: always the fixed string
// "DXIL Op Lowering". Declared const, so it does not modify the pass object.
786 StringRef getPassName()
const override {
return "DXIL Op Lowering"; }
799char DXILOpLoweringLegacy::ID = 0;
810 return new DXILOpLoweringLegacy();
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
ReachingDefAnalysis InstSet & ToRemove
static bool isVectorArgExpansion(Function &F)
static SmallVector< Value * > argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder)
static SmallVector< Value * > populateOperands(Value *Arg, IRBuilder<> &Builder)
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
ModuleAnalysisManager MAM
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Value * getArgOperand(unsigned i) const
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
This class represents an Operation in the Expression.
iterator find(const CallInst *Key)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
The legacy pass manager's analysis pass to compute DXIL resource information.
Diagnostic information for unsupported feature in backend.
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Tagged union holding either a T or a Error.
Error takeError()
Take ownership of the stored error.
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Value * CreateZExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a ZExt or Trunc from the integer value V to DestTy.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
This instruction inserts a single (scalar) element into a VectorType value.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
const Function * getFunction() const
Return the function this instruction belongs to.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Triple - Helper class for working with autoconf configuration names.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
The instances of the Type class are immutable: once they are created, they are never changed.
static IntegerType * getInt32Ty(LLVMContext &C)
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
A Use represents the edge between a Value definition and its users.
Value * getOperand(unsigned i) const
unsigned getNumOperands() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< user_iterator > users()
LLVMContext & getContext() const
All values hold a context through their type.
iterator_range< use_iterator > uses()
StringRef getName() const
Return a constant reference to the value's name.
Represents a version number in the form major[.minor[.subminor[.build]]].
StructType * getResRetType(Type *ElementTy)
Get a dx.types.ResRet type with the given element type.
StructType * getSplitDoubleType(LLVMContext &Context)
Get the dx.types.splitdouble type.
Expected< CallInst * > tryCreateOp(dxil::OpCode Op, ArrayRef< Value * > Args, const Twine &Name="", Type *RetTy=nullptr)
Try to create a call instruction for the given DXIL op.
Constant * getResBind(uint32_t LowerBound, uint32_t UpperBound, uint32_t SpaceID, dxil::ResourceClass RC)
Get a constant dx.types.ResBind value.
Constant * getResProps(uint32_t Word0, uint32_t Word1)
Get a constant dx.types.ResourceProperties value.
StructType * getHandleType()
Get the dx.types.Handle type.
std::pair< uint32_t, uint32_t > getAnnotateProps(Module &M, dxil::ResourceTypeInfo &RTI) const
TargetExtType * getHandleTy() const
const ResourceBinding & getBinding() const
dxil::ResourceClass getResourceClass() const
Wrapper pass for the legacy pass manager.
An efficient, type-erasing, non-owning reference to a callable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
NodeAddr< DefNode * > Def
This is an optimization pass for GlobalISel generic memory operations.
std::error_code inconvertibleErrorCode()
The value returned by this function can be returned from convertToErrorCode for Error values where no...
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
auto unique(Range &&R, Predicate P)
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
void sort(IteratorTy Start, IteratorTy End)
constexpr std::underlying_type_t< Enum > to_underlying(Enum E)
Returns underlying integer value of an enum.
ModulePass * createDXILOpLoweringLegacyPass()
Pass to lower LLVM intrinsic calls to DXIL op function calls.
const char * toString(DWARFSectionKind Kind)