#include "llvm/IR/IntrinsicsARM.h"

using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"
STATISTIC(NumSMLAD, "Number of smlad instructions generated");
static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed"));
bool Exchange = false;
bool Paired = false;
bool HasTwoLoadInputs() const {
  return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
}
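// Find the mul feeding an add operand, looking through a sign extension if
// one is present.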
auto GetMulOperand = [](Value *V) -> Instruction* {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
      if (I->getOpcode() == Instruction::Mul)
        return I;
  } else if (auto *I = dyn_cast<Instruction>(V)) {
    if (I->getOpcode() == Instruction::Mul)
      return I;
  }
  return nullptr;
};
auto InsertMul = [this](Instruction *I) {
  Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
  Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
  Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
};
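// Create a MulCandidate for each mul feeding one of the recorded adds.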
for (auto *Add : Adds) {
  if (auto *Mul = GetMulOperand(Add->getOperand(0)))
    InsertMul(Mul);
  if (auto *Mul = GetMulOperand(Add->getOperand(1)))
    InsertMul(Mul);
}
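/// Record the incoming accumulator value; returns true if a value had not
/// already been added.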
bool InsertAcc(Value *V) {
  if (Acc)
    return false;
  Acc = V;
  return true;
}
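/// Pair two MulCandidates that can be executed as a single parallel
/// operation; Exchange indicates that the halfwords of the second operand
/// should be swapped before the arithmetic.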
void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                bool Exchange = false) {
  LLVM_DEBUG(dbgs() << "Pairing:\n"
                    << *Mul0->Root << "\n"
                    << *Mul1->Root << "\n");
  Mul0->Paired = true;
  Mul1->Paired = true;
  if (Exchange)
    Mul1->Exchange = true;
  MulPairs.push_back(std::make_pair(Mul0, Mul1));
}
Value *getAccumulator() { return Acc; }

MulCandList &getMuls() { return Muls; }

MulPairList &getMulPairs() { return MulPairs; }
for (auto *Add : Adds)
  LLVM_DEBUG(dbgs() << *Add << "\n");
for (auto &Mul : Muls)
  LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
                    << "  " << *Mul->LHS << "\n"
                    << "  " << *Mul->RHS << "\n");
std::map<LoadInst*, LoadInst*> LoadPairs;
SmallPtrSet<LoadInst*, 4> OffsetLoads;
std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
template<unsigned>
bool IsNarrowSequence(Value *V);
void InsertParallelMACs(Reduction &Reduction);
bool CreateParallelPairs(Reduction &R);
SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &TPC = getAnalysis<TargetPassConfig>();

M = F.getParent();
DL = &M->getDataLayout();
if (!ST->allowsUnalignedMem()) {
  LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                       "running pass ARMParallelDSP\n");
  return false;
}

if (!ST->hasDSP()) {
  LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                       "ARMParallelDSP\n");
  return false;
}

if (!ST->isLittle()) {
  LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                    << "ARMParallelDSP\n");
  return false;
}
bool Changes = MatchSMLAD(F);
return Changes;
bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}
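// MaxBitWidth is the maximum supported element bitwidth of the DSP
// instructions, which is set to 16; only i16 sequences are collected for
// now, which is why the source type must equal MaxBitWidth rather than
// merely be narrower.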
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    // Check that this load could be paired.
    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0)))
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
  }
  return false;
}
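/// Iterate through the block and record base, offset pairs of loads which
/// can be widened into a single load.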
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;

  // Only record loads which are simple, sign-extended and have a single user.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }
  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;
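  // Record any writes that may alias a load.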
  for (auto *Write : Writes) {
    for (auto *Read : Loads) {
      MemoryLocation ReadLoc =
          MemoryLocation(Read->getPointerOperand(),
                         LocationSize::beforeOrAfterPointer());
      if (!isModOrRefSet(AA->getModRefInfo(Write, ReadLoc)))
        continue;
      if (Write->comesBefore(Read))
        RAWDeps[Read].insert(Write);
    }
  }
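  // Check that there is not a write between the two loads which would
  // prevent them from being safely merged.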
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    bool BaseFirst = Base->comesBefore(Offset);
    LoadInst *Dominator = BaseFirst ? Base : Offset;
    LoadInst *Dominated = BaseFirst ? Offset : Base;
    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];
      for (auto *Before : WritesBefore)
        // We can't move the second load backward, past a write, to merge
        // it with the first load.
        if (Dominator->comesBefore(Before))
          return false;
    }
    return true;
  };
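  // Record base, offset load pairs.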
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset || OffsetLoads.count(Offset))
        continue;
      if (isConsecutiveAccess(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }
  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs)
                 dbgs() << *MapIt.first << ", " << *MapIt.second << "\n";
             });
  return LoadPairs.size() > 1;
}
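// Search recursively back through the operands to find a tree of values that
// form a multiply-accumulate chain, recording the Add and Mul instructions
// of the reduction and a single value to initialise the accumulator.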
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    return false;
  case Instruction::PHI:
    // Could be the accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    // Adds should be adding together two muls, or another add and a mul to
    // be within the mac chain. One of the operands may also be the
    // accumulator value, at which point we should stop searching.
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    // Ensure we don't add the root as the incoming accumulator.
    if (R.getRoot() == I)
      return false;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
  return false;
}
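/// Try to match and generate SMLAD or SMLADX: Signed Multiply Accumulate
/// Dual performs two signed 16x16-bit multiplications and adds the products
/// to a 32-bit accumulate operand; the X form exchanges the halfwords of the
/// second operand first. As an illustrative sketch (the names and exact IR
/// below are ours, not taken from a regression test), the pass rewrites a
/// reduction such as:
///
///   %ld0 = load i16, ptr %a
///   %ld1 = load i16, ptr %a.1     ; consecutive with %ld0
///   %ld2 = load i16, ptr %b
///   %ld3 = load i16, ptr %b.1     ; consecutive with %ld2
///   %sext0 = sext i16 %ld0 to i32
///   %sext1 = sext i16 %ld1 to i32
///   %sext2 = sext i16 %ld2 to i32
///   %sext3 = sext i16 %ld3 to i32
///   %mul0 = mul i32 %sext0, %sext2
///   %mul1 = mul i32 %sext1, %sext3
///   %add = add i32 %mul0, %mul1
///   %res = add i32 %acc, %add
///
/// into two widened i32 loads of %a and %b followed by a call to
/// @llvm.arm.smlad(i32 %wide.a, i32 %wide.b, i32 %acc).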
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (auto &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
      LLVM_DEBUG(dbgs() << "BB after inserting parallel MACs:\n" << BB);
    }
  }

  return Changed;
}
bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly on sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }
  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that it's two pairs of consecutive loads, then these can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;
    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        // Ld2 and Ld3 are reversed, so exchange the second operand.
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };
  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (PMul1->Paired)
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }
  return !R.getMulPairs().empty();
}
void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  // Replace the reduction chain with an intrinsic call.
  auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    Value* Args[] = { WideLd0, WideLd1, Acc };
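    // The intrinsic is chosen by accumulator width and by whether the second
    // operand's halfwords must be exchanged. A sketch of the selection,
    // assuming the 32-bit forms map to SMLAD/SMLADX and the 64-bit forms to
    // SMLALD/SMLALDX:
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32)
                  ? Intrinsic::getDeclaration(M, Intrinsic::arm_smladx)
                  : Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32)
                  ? Intrinsic::getDeclaration(M, Intrinsic::arm_smlad)
                  : Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                ++BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };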
  // Return the insertion point after whichever of A and B is dominated.
  auto GetInsertPoint = [this](Value *A, Value *B) {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return &*++BasicBlock::iterator(cast<Instruction>(V));
  };
  Value *Acc = R.getAccumulator();
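  // For any muls that were discovered but not paired, accumulate their
  // values as before.
  IRBuilder<NoFolder> Builder(R.getRoot()->getParent());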
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;
    Instruction *Mul = cast<Instruction>(MulCand->Root);
    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    // If Acc is the original incoming value to the reduction, it could be a
    // phi. But the phi will dominate Mul, meaning that Mul will be the
    // insertion point.
    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
  }
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }
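  // Roughly sort the mul pairs into their program order.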
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });
  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];
  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());
  assert((BaseSExt && OffsetSExt) &&
         "Loads should have a single, extending, user");
  std::function<void(Value*, Value*)> MoveBefore =
      [&](Value *A, Value *B) -> void {
        if (!isa<Instruction>(A) || !isa<Instruction>(B))
          return;

        auto *Source = cast<Instruction>(A);
        auto *Sink = cast<Instruction>(B);

        if (DT->dominates(Source, Sink) ||
            Source->getParent() != Sink->getParent() ||
            isa<PHINode>(Source) || isa<PHINode>(Sink))
          return;

        Source->moveBefore(Sink);
        for (auto &Op : Source->operands())
          MoveBefore(Op, Source);
      };
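  // Create the wide load from the base pointer while keeping the original
  // alignment: this prevents an ldrd from being generated later when it
  // could be illegal due to memory alignment.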
  Value *VecPtr = Base->getPointerOperand();
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);
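  // From the wide load, recreate the values of the two original loads: the
  // bottom half needs a trunc while the top half needs a lshr and trunc.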
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);
  LLVM_DEBUG(dbgs() << "Created Wide Load:\n"
                    << *WideLoad << "\n"
                    << *NewBaseSExt << "\n"
                    << *NewOffsetSExt << "\n");
  WideLoads.emplace(std::make_pair(Base,
      std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}
Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics", false, false)