25#include "llvm/IR/IntrinsicsDirectX.h"
32#define DEBUG_TYPE "dxil-op-lower"
38 switch (
F.getIntrinsicID()) {
39 case Intrinsic::dx_dot2:
40 case Intrinsic::dx_dot3:
41 case Intrinsic::dx_dot4:
49 auto *VecArg = dyn_cast<FixedVectorType>(Arg->
getType());
50 for (
unsigned I = 0;
I < VecArg->getNumElements(); ++
I) {
53 ExtractedElements.
push_back(ExtractedElement);
55 return ExtractedElements;
64 [[maybe_unused]]
auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->
getType());
67 for (
unsigned I = 1;
I < NumOperands; ++
I) {
69 [[maybe_unused]]
auto *VecArg = dyn_cast<FixedVectorType>(Arg->
getType());
71 assert(VecArg0->getElementType() == VecArg->getElementType());
72 assert(VecArg0->getNumElements() == VecArg->getNumElements());
74 NewOperands.
append(NextOperandList.begin(), NextOperandList.end());
89 :
M(
M), OpBuilder(
M), DBM(DBM), DRTM(DRTM) {}
97 CallInst *CI = dyn_cast<CallInst>(U);
101 if (
Error E = ReplaceCall(CI)) {
102 std::string Message(
toString(std::move(E)));
105 M.getContext().diagnose(Diag);
114 struct IntrinArgSelect {
116#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
117#include "DXILOperation.inc"
127 assert(!(IsVectorArgExpansion && ArgSelects.
size()) &&
128 "Cann't do vector arg expansion when using arg selects.");
130 OpBuilder.
getIRB().SetInsertPoint(CI);
132 if (ArgSelects.
size()) {
133 for (const IntrinArgSelect &A : ArgSelects) {
135 case IntrinArgSelect::Type::Index:
136 Args.push_back(CI->getArgOperand(A.Value));
138 case IntrinArgSelect::Type::I8:
139 Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
141 case IntrinArgSelect::Type::I32:
142 Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
146 }
else if (IsVectorArgExpansion) {
163 [[nodiscard]]
bool replaceFunctionWithNamedStructOp(
169 OpBuilder.
getIRB().SetInsertPoint(CI);
170 if (IsVectorArgExpansion) {
180 if (
Error E = ReplaceUses(CI, *OpCall))
194 Intrinsic::dx_resource_casthandle, {Ty,
V->getType()}, {
V});
199 void cleanupHandleCasts() {
203 for (
CallInst *Cast : CleanupCasts) {
219 assert(
Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
220 "Unbalanced pair of temporary handle casts");
233 F->eraseFromParent();
235 CleanupCasts.clear();
243 void removeResourceGlobals(
CallInst *CI) {
247 Store->eraseFromParent();
249 if (GV->use_empty()) {
250 GV->removeDeadConstantUsers();
251 GV->eraseFromParent();
257 [[nodiscard]]
bool lowerToCreateHandle(
Function &
F) {
265 auto *It = DBM.
find(CI);
266 assert(It != DBM.
end() &&
"Resource not in map?");
273 if (Binding.LowerBound != 0)
275 ConstantInt::get(Int32Ty, Binding.LowerBound));
277 std::array<Value *, 4>
Args{
279 ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
286 Value *Cast = createTmpHandleCast(*OpCall, CI->
getType());
288 removeResourceGlobals(CI);
296 [[nodiscard]]
bool lowerToBindAndAnnotateHandle(
Function &
F) {
303 auto *It = DBM.
find(CI);
304 assert(It != DBM.
end() &&
"Resource not in map?");
312 if (Binding.LowerBound != 0)
314 ConstantInt::get(Int32Ty, Binding.LowerBound));
316 std::pair<uint32_t, uint32_t> Props =
321 uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
322 uint32_t UpperBound = Binding.Size == Unbounded
324 : Binding.LowerBound + Binding.Size - 1;
327 std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->
getArgOperand(4)};
329 OpCode::CreateHandleFromBinding, BindArgs, CI->
getName());
333 std::array<Value *, 2> AnnotateArgs{
334 *OpBind, OpBuilder.
getResProps(Props.first, Props.second)};
336 OpCode::AnnotateHandle, AnnotateArgs,
341 Value *Cast = createTmpHandleCast(*OpAnnotate, CI->
getType());
343 removeResourceGlobals(CI);
355 bool lowerHandleFromBinding(
Function &
F) {
358 return lowerToCreateHandle(
F);
359 return lowerToBindAndAnnotateHandle(
F);
364 if (
auto *EVI = dyn_cast<ExtractValueInst>(
U.getUser())) {
366 if (EVI->getNumIndices() != 1)
368 "Splitdouble has only 2 elements");
369 EVI->setOperand(0,
Op);
371 return make_error<StringError>(
372 "Splitdouble use is not ExtractValueInst",
391 auto *
ST = cast<StructType>(OldTy);
393 Value *CheckOp =
nullptr;
396 if (
auto *EVI = dyn_cast<ExtractValueInst>(
U.getUser())) {
405 OpCode::CheckAccessFullyMapped, {NewEVI},
413 EVI->replaceAllUsesWith(CheckOp);
414 EVI->eraseFromParent();
425 isa<ExtractValueInst>(*OldResult->
user_begin()) &&
426 "Expected only use to be extract of first element");
427 OldResult = cast<Instruction>(*OldResult->
user_begin());
428 OldTy =
ST->getElementType(0);
432 if (!isa<FixedVectorType>(OldTy)) {
436 if (OldResult != Intrin) {
443 std::array<Value *, 4> Extracts = {};
449 if (
auto *EEI = dyn_cast<ExtractElementInst>(
U.getUser())) {
450 if (
auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
451 size_t IndexVal = IndexOp->getZExtValue();
452 assert(IndexVal < 4 &&
"Index into buffer load out of range");
453 if (!Extracts[IndexVal])
456 EEI->eraseFromParent();
463 const auto *VecTy = cast<FixedVectorType>(OldTy);
464 const unsigned N = VecTy->getNumElements();
468 if (!DynamicAccesses.
empty()) {
472 Type *ElTy = VecTy->getElementType();
473 Type *ArrayTy = ArrayType::get(ElTy,
N);
476 for (
int I = 0, E =
N;
I != E; ++
I) {
480 ArrayTy, Alloca, {
Zero, ConstantInt::get(Int32Ty,
I)});
486 {
Zero, EEI->getIndexOperand()});
489 EEI->eraseFromParent();
497 for (
int I = 0, E =
N;
I != E; ++
I)
502 for (
int I = 0, E =
N;
I != E; ++
I)
508 if (OldResult != Intrin) {
516 [[nodiscard]]
bool lowerTypedBufferLoad(
Function &
F,
bool HasCheckBit) {
530 OldTy = cast<StructType>(OldTy)->getElementType(0);
533 std::array<Value *, 3>
Args{Handle, Index0, Index1};
535 OpCode::BufferLoad, Args, CI->
getName(), NewRetTy);
538 if (
Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
545 [[nodiscard]]
bool lowerRawBufferLoad(
Function &
F) {
556 Type *OldTy = cast<StructType>(CI->
getType())->getElementType(0);
565 DL.getTypeSizeInBits(OldTy) /
DL.getTypeSizeInBits(ScalarTy);
566 Value *
Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
568 ConstantInt::get(Int32Ty,
DL.getPrefTypeAlign(ScalarTy).value());
576 {Handle, Index0, Index1}, CI->
getName(),
580 if (
Error E = replaceResRetUses(CI, *OpCall,
true))
587 [[nodiscard]]
bool lowerUpdateCounter(
Function &
F) {
597 std::array<Value *, 2>
Args{Handle, Op1};
600 OpCode::UpdateCounter, Args, CI->
getName(), Int32Ty);
611 [[nodiscard]]
bool lowerGetPointer(
Function &
F) {
614 assert(
F.user_empty() &&
"getpointer operations should have been removed");
619 [[nodiscard]]
bool lowerTypedBufferStore(
Function &
F) {
635 auto *DataTy = dyn_cast<FixedVectorType>(Data->getType());
636 if (!DataTy || DataTy->getNumElements() != 4)
637 return make_error<StringError>(
638 "typedBufferStore data must be a vector of 4 elements",
644 std::array<Value *, 4> DataElements{
nullptr,
nullptr,
nullptr,
nullptr};
645 auto *IEI = dyn_cast<InsertElementInst>(Data);
647 auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
650 size_t IndexVal = IndexOp->getZExtValue();
651 assert(IndexVal < 4 &&
"Too many elements for buffer store");
652 DataElements[IndexVal] = IEI->getOperand(1);
653 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
659 for (
int I = 0, E = 4;
I != E; ++
I)
660 if (DataElements[
I] ==
nullptr)
664 std::array<Value *, 8>
Args{
665 Handle, Index0, Index1, DataElements[0],
666 DataElements[1], DataElements[2], DataElements[3],
Mask};
674 IEI = dyn_cast<InsertElementInst>(Data);
675 while (IEI && IEI->use_empty()) {
677 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
685 [[nodiscard]]
bool lowerCtpopToCountBits(
Function &
F) {
695 Type *FRT =
F.getReturnType();
696 if (
const auto *VT = dyn_cast<VectorType>(FRT))
705 if (FRT->isIntOrIntVectorTy(32)) {
706 CI->replaceAllUsesWith(*OpCall);
707 CI->eraseFromParent();
708 return Error::success();
713 if (FRT->isIntOrIntVectorTy(16)) {
714 CastOp = Instruction::ZExt;
715 CastOp2 = Instruction::SExt;
717 assert(FRT->isIntOrIntVectorTy(64) &&
718 "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
720 CastOp = Instruction::Trunc;
721 CastOp2 = Instruction::Trunc;
726 bool NeedsCast =
false;
728 Instruction *I = dyn_cast<Instruction>(User);
729 if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
730 I->getType() == RetTy) {
731 I->replaceAllUsesWith(*OpCall);
732 I->eraseFromParent();
752 bool lowerIntrinsics() {
753 bool Updated =
false;
754 bool HasErrors =
false;
757 if (!
F.isDeclaration())
763#define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) \
765 HasErrors |= replaceFunctionWithOp( \
766 F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__}); \
768#include "DXILOperation.inc"
769 case Intrinsic::dx_resource_handlefrombinding:
770 HasErrors |= lowerHandleFromBinding(
F);
772 case Intrinsic::dx_resource_getpointer:
773 HasErrors |= lowerGetPointer(
F);
775 case Intrinsic::dx_resource_load_typedbuffer:
776 HasErrors |= lowerTypedBufferLoad(
F,
true);
778 case Intrinsic::dx_resource_store_typedbuffer:
779 HasErrors |= lowerTypedBufferStore(
F);
781 case Intrinsic::dx_resource_load_rawbuffer:
782 HasErrors |= lowerRawBufferLoad(
F);
784 case Intrinsic::dx_resource_updatecounter:
785 HasErrors |= lowerUpdateCounter(
F);
789 case Intrinsic::dx_splitdouble:
790 HasErrors |= replaceFunctionWithNamedStructOp(
791 F, OpCode::SplitDouble,
794 return replaceSplitDoubleCallUsages(CI, Op);
797 case Intrinsic::ctpop:
798 HasErrors |= lowerCtpopToCountBits(
F);
803 if (Updated && !HasErrors)
804 cleanupHandleCasts();
815 bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics();
826class DXILOpLoweringLegacy :
public ModulePass {
828 bool runOnModule(
Module &M)
override {
830 getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
832 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
834 return OpLowerer(M, DBM, DRTM).lowerIntrinsics();
836 StringRef getPassName()
const override {
return "DXIL Op Lowering"; }
849char DXILOpLoweringLegacy::ID = 0;
860 return new DXILOpLoweringLegacy();
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static bool isVectorArgExpansion(Function &F)
static SmallVector< Value * > argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder)
static SmallVector< Value * > populateOperands(Value *Arg, IRBuilder<> &Builder)
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
ModuleAnalysisManager MAM
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Value * getArgOperand(unsigned i) const
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
This class represents an Operation in the Expression.
iterator find(const CallInst *Key)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
The legacy pass manager's analysis pass to compute DXIL resource information.
A parsed version of the target data layout string, and methods for querying it.
Diagnostic information for unsupported feature in backend.
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Tagged union holding either a T or an Error.
Error takeError()
Take ownership of the stored error.
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Value * CreateZExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a ZExt or Trunc from the integer value V to DestTy.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
This instruction inserts a single (scalar) element into a VectorType value.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
const Function * getFunction() const
Return the function this instruction belongs to.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Triple - Helper class for working with autoconf configuration names.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
The instances of the Type class are immutable: once they are created, they are never changed.
static IntegerType * getInt32Ty(LLVMContext &C)
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
A Use represents the edge between a Value definition and its users.
Value * getOperand(unsigned i) const
unsigned getNumOperands() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
user_iterator user_begin()
bool hasOneUse() const
Return true if there is exactly one use of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< user_iterator > users()
LLVMContext & getContext() const
All values hold a context through their type.
iterator_range< use_iterator > uses()
StringRef getName() const
Return a constant reference to the value's name.
Represents a version number in the form major[.minor[.subminor[.build]]].
StructType * getResRetType(Type *ElementTy)
Get a dx.types.ResRet type with the given element type.
StructType * getSplitDoubleType(LLVMContext &Context)
Get the dx.types.splitdouble type.
Expected< CallInst * > tryCreateOp(dxil::OpCode Op, ArrayRef< Value * > Args, const Twine &Name="", Type *RetTy=nullptr)
Try to create a call instruction for the given DXIL op.
Constant * getResBind(uint32_t LowerBound, uint32_t UpperBound, uint32_t SpaceID, dxil::ResourceClass RC)
Get a constant dx.types.ResBind value.
Constant * getResProps(uint32_t Word0, uint32_t Word1)
Get a constant dx.types.ResourceProperties value.
StructType * getHandleType()
Get the dx.types.Handle type.
std::pair< uint32_t, uint32_t > getAnnotateProps(Module &M, dxil::ResourceTypeInfo &RTI) const
TargetExtType * getHandleTy() const
const ResourceBinding & getBinding() const
dxil::ResourceClass getResourceClass() const
Wrapper pass for the legacy pass manager.
An efficient, type-erasing, non-owning reference to a callable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
NodeAddr< DefNode * > Def
This is an optimization pass for GlobalISel generic memory operations.
std::error_code inconvertibleErrorCode()
The value returned by this function can be returned from convertToErrorCode for Error values where no...
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
auto unique(Range &&R, Predicate P)
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
void sort(IteratorTy Start, IteratorTy End)
constexpr std::underlying_type_t< Enum > to_underlying(Enum E)
Returns underlying integer value of an enum.
ModulePass * createDXILOpLoweringLegacyPass()
Pass to lower LLVM intrinsic calls to DXIL op function calls.
const char * toString(DWARFSectionKind Kind)
This struct is a compact representation of a valid (non-zero power of two) alignment.