24#include "llvm/IR/IntrinsicsDirectX.h"
32#define DEBUG_TYPE "dxil-op-lower"
49 : M(M), OpBuilder(M), DRM(DRM), DRTM(DRTM), MMDI(MMDI) {}
57 CallInst *CI = dyn_cast<CallInst>(U);
61 if (
Error E = ReplaceCall(CI)) {
62 std::string Message(
toString(std::move(E)));
74 struct IntrinArgSelect {
76#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
77#include "DXILOperation.inc"
88 auto *IntrinTy = cast<StructType>(Intrin->
getType());
89 auto *DXILOpTy = cast<StructType>(DXILOp->
getType());
90 if (!IntrinTy->isLayoutIdentical(DXILOpTy))
91 return make_error<StringError>(
92 "Type mismatch between intrinsic and DXIL op",
96 if (
auto *EVI = dyn_cast<ExtractValueInst>(
U.getUser()))
97 EVI->setOperand(0, DXILOp);
98 else if (
auto *IVI = dyn_cast<InsertValueInst>(
U.getUser()))
99 IVI->setOperand(0, DXILOp);
101 return make_error<StringError>(
"DXIL ops that return structs may only "
102 "be used by insert- and extractvalue",
111 OpBuilder.
getIRB().SetInsertPoint(CI);
113 if (ArgSelects.
size()) {
114 for (const IntrinArgSelect &A : ArgSelects) {
116 case IntrinArgSelect::Type::Index:
117 Args.push_back(CI->getArgOperand(A.Value));
119 case IntrinArgSelect::Type::I8:
120 Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
122 case IntrinArgSelect::Type::I32:
123 Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
137 if (Error E = replaceNamedStructUses(CI, *OpCall))
154 Intrinsic::dx_resource_casthandle, {Ty,
V->getType()}, {
V});
159 void cleanupHandleCasts() {
163 for (
CallInst *Cast : CleanupCasts) {
179 assert(
Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
180 "Unbalanced pair of temporary handle casts");
193 F->eraseFromParent();
195 CleanupCasts.clear();
203 void removeResourceGlobals(
CallInst *CI) {
207 Store->eraseFromParent();
209 if (GV->use_empty()) {
210 GV->removeDeadConstantUsers();
211 GV->eraseFromParent();
217 void replaceHandleFromBindingCall(
CallInst *CI,
Value *Replacement) {
219 Intrinsic::dx_resource_handlefrombinding);
221 removeResourceGlobals(CI);
223 auto *NameGlobal = dyn_cast<llvm::GlobalVariable>(CI->
getArgOperand(4));
228 if (NameGlobal && NameGlobal->use_empty())
229 NameGlobal->removeFromParent();
232 [[nodiscard]]
bool lowerToCreateHandle(
Function &
F) {
241 auto *It = DRM.
find(CI);
242 assert(It != DRM.
end() &&
"Resource not in map?");
251 ConstantInt::get(Int32Ty,
Binding.LowerBound));
256 std::array<Value *, 4>
Args{
258 ConstantInt::get(Int32Ty,
Binding.RecordID), IndexOp,
259 ConstantInt::get(Int1Ty,
false)};
265 Value *Cast = createTmpHandleCast(*OpCall, CI->
getType());
266 replaceHandleFromBindingCall(CI, Cast);
271 [[nodiscard]]
bool lowerToBindAndAnnotateHandle(
Function &
F) {
279 auto *It = DRM.
find(CI);
280 assert(It != DRM.
end() &&
"Resource not in map?");
290 ConstantInt::get(Int32Ty,
Binding.LowerBound));
292 std::pair<uint32_t, uint32_t> Props =
297 uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
306 Constant *NonUniform = ConstantInt::get(Int1Ty,
false);
307 std::array<Value *, 3> BindArgs{ResBind, IndexOp, NonUniform};
309 OpCode::CreateHandleFromBinding, BindArgs, CI->
getName());
313 std::array<Value *, 2> AnnotateArgs{
314 *OpBind, OpBuilder.
getResProps(Props.first, Props.second)};
316 OpCode::AnnotateHandle, AnnotateArgs,
321 Value *Cast = createTmpHandleCast(*OpAnnotate, CI->
getType());
322 replaceHandleFromBindingCall(CI, Cast);
330 bool lowerHandleFromBinding(
Function &
F) {
332 return lowerToCreateHandle(
F);
333 return lowerToBindAndAnnotateHandle(
F);
345 auto *
ST = cast<StructType>(OldTy);
347 Value *CheckOp =
nullptr;
350 if (
auto *EVI = dyn_cast<ExtractValueInst>(
U.getUser())) {
359 OpCode::CheckAccessFullyMapped, {NewEVI},
367 EVI->replaceAllUsesWith(CheckOp);
368 EVI->eraseFromParent();
379 isa<ExtractValueInst>(*OldResult->
user_begin()) &&
380 "Expected only use to be extract of first element");
381 OldResult = cast<Instruction>(*OldResult->
user_begin());
382 OldTy =
ST->getElementType(0);
386 if (!isa<FixedVectorType>(OldTy)) {
390 if (OldResult != Intrin) {
397 std::array<Value *, 4> Extracts = {};
403 if (
auto *EEI = dyn_cast<ExtractElementInst>(
U.getUser())) {
404 if (
auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
405 size_t IndexVal = IndexOp->getZExtValue();
406 assert(IndexVal < 4 &&
"Index into buffer load out of range");
407 if (!Extracts[IndexVal])
410 EEI->eraseFromParent();
417 const auto *VecTy = cast<FixedVectorType>(OldTy);
418 const unsigned N = VecTy->getNumElements();
422 if (!DynamicAccesses.
empty()) {
426 Type *ElTy = VecTy->getElementType();
427 Type *ArrayTy = ArrayType::get(ElTy,
N);
430 for (
int I = 0, E =
N;
I != E; ++
I) {
434 ArrayTy, Alloca, {
Zero, ConstantInt::get(Int32Ty,
I)});
440 {
Zero, EEI->getIndexOperand()});
443 EEI->eraseFromParent();
451 for (
int I = 0, E =
N;
I != E; ++
I)
456 for (
int I = 0, E =
N;
I != E; ++
I)
462 if (OldResult != Intrin) {
470 [[nodiscard]]
bool lowerTypedBufferLoad(
Function &
F,
bool HasCheckBit) {
484 OldTy = cast<StructType>(OldTy)->getElementType(0);
487 std::array<Value *, 3>
Args{Handle, Index0, Index1};
489 OpCode::BufferLoad, Args, CI->
getName(), NewRetTy);
492 if (
Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
499 [[nodiscard]]
bool lowerRawBufferLoad(
Function &
F) {
508 Type *OldTy = cast<StructType>(CI->
getType())->getElementType(0);
517 DL.getTypeSizeInBits(OldTy) /
DL.getTypeSizeInBits(ScalarTy);
518 Value *
Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
520 ConstantInt::get(Int32Ty,
DL.getPrefTypeAlign(ScalarTy).value());
528 {Handle, Index0, Index1}, CI->
getName(),
532 if (
Error E = replaceResRetUses(CI, *OpCall,
true))
539 [[nodiscard]]
bool lowerCBufferLoad(
Function &
F) {
545 Type *OldTy = cast<StructType>(CI->
getType())->getElementType(0);
554 OpCode::CBufferLoadLegacy, {Handle,
Index}, CI->
getName(), NewRetTy);
557 if (
Error E = replaceNamedStructUses(CI, *OpCall))
565 [[nodiscard]]
bool lowerUpdateCounter(
Function &
F) {
575 std::array<Value *, 2>
Args{Handle, Op1};
578 OpCode::UpdateCounter, Args, CI->
getName(), Int32Ty);
589 [[nodiscard]]
bool lowerGetPointer(
Function &
F) {
592 assert(
F.user_empty() &&
"getpointer operations should have been removed");
597 [[nodiscard]]
bool lowerBufferStore(
Function &
F,
bool IsRaw) {
612 Type *DataTy = Data->getType();
616 DL.getTypeSizeInBits(DataTy) /
DL.getTypeSizeInBits(ScalarTy);
618 ConstantInt::get(Int8Ty, IsRaw ? ~(~0U << NumElements) : 15U);
622 return make_error<StringError>(
623 "Buffer store data must have at most 4 elements",
626 std::array<Value *, 4> DataElements{
nullptr,
nullptr,
nullptr,
nullptr};
627 if (DataTy == ScalarTy)
628 DataElements[0] = Data;
633 auto *IEI = dyn_cast<InsertElementInst>(Data);
635 auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
638 size_t IndexVal = IndexOp->getZExtValue();
639 assert(IndexVal < 4 &&
"Too many elements for buffer store");
640 DataElements[IndexVal] = IEI->getOperand(1);
641 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
648 for (
int I = 0, E = NumElements;
I < E; ++
I)
649 if (DataElements[
I] ==
nullptr)
656 for (
int I = NumElements, E = 4;
I < E; ++
I)
657 if (DataElements[
I] ==
nullptr)
662 Handle, Index0, Index1, DataElements[0],
663 DataElements[1], DataElements[2], DataElements[3],
Mask};
665 Op = OpCode::RawBufferStore;
668 ConstantInt::get(Int32Ty,
DL.getPrefTypeAlign(ScalarTy).value()));
677 auto *IEI = dyn_cast<InsertElementInst>(Data);
678 while (IEI && IEI->use_empty()) {
680 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
688 [[nodiscard]]
bool lowerCtpopToCountBits(
Function &
F) {
698 Type *FRT =
F.getReturnType();
699 if (
const auto *VT = dyn_cast<VectorType>(FRT))
708 if (FRT->isIntOrIntVectorTy(32)) {
709 CI->replaceAllUsesWith(*OpCall);
710 CI->eraseFromParent();
711 return Error::success();
716 if (FRT->isIntOrIntVectorTy(16)) {
717 CastOp = Instruction::ZExt;
718 CastOp2 = Instruction::SExt;
720 assert(FRT->isIntOrIntVectorTy(64) &&
721 "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
723 CastOp = Instruction::Trunc;
724 CastOp2 = Instruction::Trunc;
729 bool NeedsCast =
false;
731 Instruction *I = dyn_cast<Instruction>(User);
732 if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
733 I->getType() == RetTy) {
734 I->replaceAllUsesWith(*OpCall);
735 I->eraseFromParent();
755 [[nodiscard]]
bool lowerLifetimeIntrinsic(
Function &
F) {
761 "Expected operand of lifetime intrinsic to be a pointer");
763 auto ZeroOrUndef = [&](
Type *Ty) {
769 Value *Val =
nullptr;
770 if (
auto *GV = dyn_cast<GlobalVariable>(
Ptr)) {
771 if (GV->hasInitializer() || GV->isExternallyInitialized())
773 Val = ZeroOrUndef(GV->getValueType());
774 }
else if (
auto *AI = dyn_cast<AllocaInst>(
Ptr))
775 Val = ZeroOrUndef(AI->getAllocatedType());
777 assert(Val &&
"Expected operand of lifetime intrinsic to be a global "
778 "variable or alloca instruction");
786 [[nodiscard]]
bool lowerIsFPClass(
Function &
F) {
798 auto *TCI = dyn_cast<ConstantInt>(
T);
799 switch (TCI->getZExtValue()) {
800 case FPClassTest::fcInf:
801 OpCode = dxil::OpCode::IsInf;
803 case FPClassTest::fcNan:
804 OpCode = dxil::OpCode::IsNaN;
806 case FPClassTest::fcNormal:
807 OpCode = dxil::OpCode::IsNormal;
809 case FPClassTest::fcFinite:
810 OpCode = dxil::OpCode::IsFinite;
814 formatv(
"Unsupported FPClassTest {0} for DXIL Op Lowering",
815 TCI->getZExtValue());
830 bool lowerIntrinsics() {
831 bool Updated =
false;
832 bool HasErrors =
false;
835 if (!
F.isDeclaration())
841 case Intrinsic::dx_resource_casthandle:
843 case Intrinsic::dbg_value:
853 "Unsupported intrinsic {0} for DXIL lowering",
F.getName());
854 M.getContext().emitError(Msg);
859#define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) \
861 HasErrors |= replaceFunctionWithOp( \
862 F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__}); \
864#include "DXILOperation.inc"
865 case Intrinsic::dx_resource_handlefrombinding:
866 HasErrors |= lowerHandleFromBinding(
F);
868 case Intrinsic::dx_resource_getpointer:
869 HasErrors |= lowerGetPointer(
F);
871 case Intrinsic::dx_resource_load_typedbuffer:
872 HasErrors |= lowerTypedBufferLoad(
F,
true);
874 case Intrinsic::dx_resource_store_typedbuffer:
875 HasErrors |= lowerBufferStore(
F,
false);
877 case Intrinsic::dx_resource_load_rawbuffer:
878 HasErrors |= lowerRawBufferLoad(
F);
880 case Intrinsic::dx_resource_store_rawbuffer:
881 HasErrors |= lowerBufferStore(
F,
true);
883 case Intrinsic::dx_resource_load_cbufferrow_2:
884 case Intrinsic::dx_resource_load_cbufferrow_4:
885 case Intrinsic::dx_resource_load_cbufferrow_8:
886 HasErrors |= lowerCBufferLoad(
F);
888 case Intrinsic::dx_resource_updatecounter:
889 HasErrors |= lowerUpdateCounter(
F);
891 case Intrinsic::ctpop:
892 HasErrors |= lowerCtpopToCountBits(
F);
894 case Intrinsic::lifetime_start:
895 case Intrinsic::lifetime_end:
900 HasErrors |= lowerLifetimeIntrinsic(
F);
905 case Intrinsic::is_fpclass:
906 HasErrors |= lowerIsFPClass(
F);
911 if (Updated && !HasErrors)
912 cleanupHandleCasts();
924 const bool MadeChanges = OpLowerer(M, DRM, DRTM, MMDI).lowerIntrinsics();
936class DXILOpLoweringLegacy :
public ModulePass {
938 bool runOnModule(
Module &M)
override {
940 getAnalysis<DXILResourceWrapperPass>().getResourceMap();
942 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
944 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
946 return OpLowerer(M, DRM, DRTM, MMDI).lowerIntrinsics();
948 StringRef getPassName()
const override {
return "DXIL Op Lowering"; }
962char DXILOpLoweringLegacy::ID = 0;
973 return new DXILOpLoweringLegacy();
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
DXIL Resource Implicit Binding
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
ModuleAnalysisManager MAM
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file defines the SmallVector class.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Value * getArgOperand(unsigned i) const
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
This class represents an Operation in the Expression.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
iterator find(const CallInst *Key)
A parsed version of the target data layout string in and methods for querying it.
Diagnostic information for unsupported feature in backend.
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Tagged union holding either a T or a Error.
Error takeError()
Take ownership of the stored error.
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
IntegerType * getInt1Ty()
Fetch the type representing a single bit.
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Value * CreateZExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a ZExt or Trunc from the integer value V to DestTy.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
This instruction inserts a single (scalar) element into a VectorType value.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
A Module instance is used to store all the information related to an LLVM module.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
iterator erase(const_iterator CI)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
The instances of the Type class are immutable: once they are created, they are never changed.
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
'undef' values are things that do not have specified contents.
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
A Use represents the edge between a Value definition and its users.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
user_iterator user_begin()
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< user_iterator > users()
iterator_range< use_iterator > uses()
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Represents a version number in the form major[.minor[.subminor[.build]]].
StructType * getResRetType(Type *ElementTy)
Get a dx.types.ResRet type with the given element type.
Expected< CallInst * > tryCreateOp(dxil::OpCode Op, ArrayRef< Value * > Args, const Twine &Name="", Type *RetTy=nullptr)
Try to create a call instruction for the given DXIL op.
Constant * getResBind(uint32_t LowerBound, uint32_t UpperBound, uint32_t SpaceID, dxil::ResourceClass RC)
Get a constant dx.types.ResBind value.
Constant * getResProps(uint32_t Word0, uint32_t Word1)
Get a constant dx.types.ResourceProperties value.
StructType * getHandleType()
Get the dx.types.Handle type.
StructType * getCBufRetType(Type *ElementTy)
Get a dx.types.CBufRet type with the given element type.
TargetExtType * getHandleTy() const
LLVM_ABI std::pair< uint32_t, uint32_t > getAnnotateProps(Module &M, dxil::ResourceTypeInfo &RTI) const
const ResourceBinding & getBinding() const
dxil::ResourceClass getResourceClass() const
Wrapper pass for the legacy pass manager.
Wrapper pass for the legacy pass manager.
An efficient, type-erasing, non-owning reference to a callable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
NodeAddr< DefNode * > Def
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI std::error_code inconvertibleErrorCode()
The value returned by this function can be returned from convertToErrorCode for Error values where no...
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
auto unique(Range &&R, Predicate P)
auto formatv(bool Validate, const char *Fmt, Ts &&...Vals)
decltype(auto) get(const PointerIntPair< PointerTy, IntBits, IntType, PtrTraits, Info > &Pair)
void sort(IteratorTy Start, IteratorTy End)
constexpr std::underlying_type_t< Enum > to_underlying(Enum E)
Returns underlying integer value of an enum.
ModulePass * createDXILOpLoweringLegacyPass()
Pass to lowering LLVM intrinsic call to DXIL op function call.
const char * toString(DWARFSectionKind Kind)
This struct is a compact representation of a valid (non-zero power of two) alignment.