#include "llvm/IR/IntrinsicsARM.h"

using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed"));
  bool Exchange = false;

  bool HasTwoLoadInputs() const {
    return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
  }
  // Inside Reduction::InsertMuls(): find the mul feeding an add operand,
  // looking through a sign extend if necessary.
  auto GetMulOperand = [](Value *V) -> Instruction * {
    if (auto *SExt = dyn_cast<SExtInst>(V)) {
      if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
        if (I->getOpcode() == Instruction::Mul)
          return I;
    } else if (auto *I = dyn_cast<Instruction>(V)) {
      if (I->getOpcode() == Instruction::Mul)
        return I;
    }
    return nullptr;
  };

  // Record a mul candidate, stripping the sign extends from its operands.
  auto InsertMul = [this](Instruction *I) {
    Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
    Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
    Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
  };

  for (auto *Add : Adds) {
    if (auto *Mul = GetMulOperand(Add->getOperand(0)))
      InsertMul(Mul);
    if (auto *Mul = GetMulOperand(Add->getOperand(1)))
      InsertMul(Mul);
  }
  // Record the incoming accumulator, if one hasn't already been found.
  bool InsertAcc(Value *V) {
    if (Acc)
      return false;
    Acc = V;
    return true;
  }

  void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                  bool Exchange = false) {
    LLVM_DEBUG(dbgs() << "Pairing:\n"
               << *Mul0->Root << "\n"
               << *Mul1->Root << "\n");
    Mul0->Paired = true;
    Mul1->Paired = true;
    if (Exchange)
      Mul1->Exchange = true;
    MulPairs.push_back(std::make_pair(Mul0, Mul1));
  }
  Value *getAccumulator() { return Acc; }

  MulCandList &getMuls() { return Muls; }

  MulPairList &getMulPairs() { return MulPairs; }
  void dump() {
    LLVM_DEBUG(
      for (auto *Add : Adds)
        dbgs() << *Add << "\n";
      for (auto &Mul : Muls)
        dbgs() << *Mul->Root << "\n"
               << "  " << *Mul->LHS << "\n"
               << "  " << *Mul->RHS << "\n";
    );
  }
  std::map<LoadInst*, LoadInst*> LoadPairs;
  SmallPtrSet<LoadInst*, 4> OffsetLoads;
  std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

  template<unsigned>
  bool IsNarrowSequence(Value *V);
  void InsertParallelMACs(Reduction &Reduction);
  bool CreateParallelPairs(Reduction &R);
  bool runOnFunction(Function &F) override {
    if (DisableParallelDSP)
      return false;
    if (skipFunction(F))
      return false;

    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &TPC = getAnalysis<TargetPassConfig>();

    M = F.getParent();
    DL = &M->getDataLayout();

    auto &TM = TPC.getTM<TargetMachine>();
    auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

    if (!ST->allowsUnalignedMem()) {
      LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                           "running pass ARMParallelDSP\n");
      return false;
    }

    if (!ST->hasDSP()) {
      LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                           "ARMParallelDSP\n");
      return false;
    }

    if (!ST->isLittle()) {
      LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                        << "ARMParallelDSP\n");
      return false;
    }

    bool Changes = MatchSMLAD(F);
    return Changes;
  }
bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
    }
  }
  return false;
}
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;
  LoadPairs.clear();
  WideLoads.clear();

  // Collect the loads and any instructions that may write to memory. Only
  // record simple loads with a single, sign-extending user.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  if (Loads.empty() || Loads.size() > NumLoadLimit)
    return false;
  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  for (auto *Read : Loads)
    for (auto *Write : Writes)
      if (isModOrRefSet(AA->getModRefInfo(
              Write, MemoryLocation(Read->getPointerOperand(),
                                    LocationSize::beforeOrAfterPointer()))) &&
          Write->comesBefore(Read))
        RAWDeps[Read].insert(Write);

  // A pair is only safe to merge when no aliasing write sits between the
  // two loads.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    LoadInst *Dominator = Base->comesBefore(Offset) ? Base : Offset;
    LoadInst *Dominated = Base->comesBefore(Offset) ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];
      for (auto *Before : WritesBefore)
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (Dominator->comesBefore(Before))
          return false;
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads)
    for (auto *Offset : Loads)
      if (Base != Offset && isConsecutiveAccess(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs)
                 dbgs() << *MapIt.first << ", "
                        << *MapIt.second << "\n";
             });
  return LoadPairs.size() > 1;
}
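// --- Illustrative sketch, not part of the pass ------------------------------
// A (base, offset) load pair only qualifies when the offset load reads the
// halfword immediately after the base load; that is the property that
// isConsecutiveAccess establishes at the IR level. A minimal model of the
// same property on raw pointers (names here are hypothetical):
#include <cstdint>

static bool isAdjacentHalfwordLoad(const int16_t *Base, const int16_t *Offset) {
  // Adjacent means exactly sizeof(int16_t) == 2 bytes apart, in that order.
  return Offset == Base + 1;
}
// -----------------------------------------------------------------------------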
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    return false;
  case Instruction::PHI:
    // Could potentially be a mul chain that resides outside the loop.
    return R.InsertAcc(V);
  case Instruction::Add: {
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    // Ensure we don't add the root as an incoming accumulator.
    if (R.getRoot() == I)
      return false;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
}
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
      LLVM_DEBUG(dbgs() << "BB after inserting parallel MACs:\n" << BB);
    }
  }

  return Changed;
}
bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly on sign-extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that it's two pairs of consecutive loads, then these can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        // The loads of the second mul are reversed, so mark the pair for an
        // exchanging multiply (smladx).
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      // The loads of the first mul are reversed, so swap the mul operands.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (PMul1->Paired)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }

  return !R.getMulPairs().empty();
}
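// --- Illustrative sketch, not part of the pass ------------------------------
// The exchange path in CanPair handles cross-term reductions, where the
// second mul consumes the paired loads in reversed order. A source-level
// shape that pairs up via AddMulPair(..., /*Exchange=*/true) and so maps to
// smladx rather than smlad (hypothetical example):
#include <cstdint>

int32_t cross_terms(const int16_t *a, const int16_t *b, int n) {
  int32_t acc = 0;
  for (int i = 0; i < n; i += 2)
    acc += (int32_t)a[i] * b[i + 1] + (int32_t)a[i + 1] * b[i];
  return acc;
}
// -----------------------------------------------------------------------------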
void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLAD = [&](LoadInst *WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
    Value *Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getOrInsertDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getOrInsertDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getOrInsertDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getOrInsertDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  // Return the instruction after the dominated instruction.
  auto GetInsertPoint = [this](Value *A, Value *B) -> Instruction * {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return &*++BasicBlock::iterator(cast<Instruction>(V));
  };

  Value *Acc = R.getAccumulator();

  // For any muls that were discovered but not paired, accumulate their
  // values as before.
  IRBuilder<NoFolder> Builder(R.getRoot()->getParent(),
                              BasicBlock::iterator(R.getRoot()));
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;

    Instruction *Mul = cast<Instruction>(MulCand->Root);
    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Builder.SetInsertPoint(Mul->getParent(), ++BasicBlock::iterator(Mul));
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
  }

  if (!Acc) {
    Acc = R.is64Bit() ?
      ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
      ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }

  // Roughly sort the mul pairs in their program order.
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
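// --- Illustrative sketch, not part of the pass ------------------------------
// Reference model of what the generated i32 intrinsics compute. smlad
// multiplies the bottom and top halfwords of its two operands and adds both
// products to the accumulator; smladx swaps the halfwords of the second
// operand first. This mirrors the ARM SMLAD/SMLADX definitions, except that
// the real instructions also set the Q flag on accumulator overflow, which
// is not modelled here:
#include <cstdint>

int32_t smlad_model(int32_t A, int32_t B, int32_t Acc) {
  int16_t ALo = (int16_t)A, AHi = (int16_t)(A >> 16);
  int16_t BLo = (int16_t)B, BHi = (int16_t)(B >> 16);
  return Acc + (int32_t)ALo * BLo + (int32_t)AHi * BHi;
}

int32_t smladx_model(int32_t A, int32_t B, int32_t Acc) {
  int16_t ALo = (int16_t)A, AHi = (int16_t)(A >> 16);
  int16_t BLo = (int16_t)B, BHi = (int16_t)(B >> 16);
  return Acc + (int32_t)ALo * BHi + (int32_t)AHi * BLo;  // halfwords exchanged
}
// -----------------------------------------------------------------------------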
LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Create the wide load, while making sure to maintain the original
  // alignment as this prevents ldrd from being generated when it could be
  // illegal due to memory alignment.
  Value *VecPtr = Base->getPointerOperand();
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads:
  // Loads[0] needs a trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);

  LLVM_DEBUG(dbgs() << "Created Wide Load:\n"
             << *WideLoad << "\n"
             << *Bottom << "\n"
             << *NewBaseSExt << "\n"
             << *Top << "\n"
             << *Trunc << "\n"
             << *NewOffsetSExt << "\n");
  WideLoads.emplace(std::make_pair(Base,
                    std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}
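// --- Illustrative sketch, not part of the pass ------------------------------
// What CreateWideLoad does, modelled on plain memory: a single 32-bit load
// replaces two adjacent 16-bit loads, and the two original values are then
// recovered with a trunc (bottom half) and a lshr plus trunc (top half).
// Little endian is assumed, matching the pass's isLittle() check; the
// function name is hypothetical:
#include <cstdint>
#include <cstring>

void widen_load_model(const int16_t *P, int16_t &Bottom, int16_t &Top) {
  uint32_t Wide;
  std::memcpy(&Wide, P, sizeof Wide);   // the single wide load
  Bottom = (int16_t)(Wide & 0xFFFF);    // trunc
  Top = (int16_t)(Wide >> 16);          // lshr + trunc
}
// -----------------------------------------------------------------------------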
Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics", false, false)