#include "llvm/IR/IntrinsicsARM.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed"));
  /// Represent a mul instruction together with the two values that feed its
  /// (possibly sign-extended) operands.
  struct MulCandidate {
    Instruction *Root;
    Value *LHS;
    Value *RHS;
    bool Exchange = false;
    bool Paired = false;
    SmallVector<LoadInst*, 2> VecLd; // Container for the loads to widen.

    MulCandidate(Instruction *I, Value *lhs, Value *rhs)
        : Root(I), LHS(lhs), RHS(rhs) {}

    bool HasTwoLoadInputs() const {
      return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
    }

    LoadInst *getBaseLoad() const { return VecLd.front(); }
  };
    /// Record the MulCandidates, rooted at mul instructions, for this
    /// reduction.
    void InsertMuls() {
      auto GetMulOperand = [](Value *V) -> Instruction* {
        if (auto *SExt = dyn_cast<SExtInst>(V)) {
          if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
            if (I->getOpcode() == Instruction::Mul)
              return I;
        } else if (auto *I = dyn_cast<Instruction>(V)) {
          if (I->getOpcode() == Instruction::Mul)
            return I;
        }
        return nullptr;
      };

      auto InsertMul = [this](Instruction *I) {
        Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
        Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
        Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
      };

      for (auto *Add : Adds) {
        if (auto *Mul = GetMulOperand(Add->getOperand(0)))
          InsertMul(Mul);
        if (auto *Mul = GetMulOperand(Add->getOperand(1)))
          InsertMul(Mul);
      }
    }
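    // Illustrative IR shape that GetMulOperand/InsertMul walk (a hand-written
    // sketch, not emitted by the pass):
    //
    //   %ld0 = load i16, ptr %pa
    //   %sx0 = sext i16 %ld0 to i32
    //   %ld1 = load i16, ptr %pb
    //   %sx1 = sext i16 %ld1 to i32
    //   %mul = mul i32 %sx0, %sx1      ; MulCandidate root
    //   %add = add i32 %acc, %mul      ; member of Adds
    //
    // InsertMul peels one level off each mul operand, so the recorded
    // LHS/RHS are the loads %ld0 and %ld1 rather than the sexts.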
    /// Add the incoming accumulator value; returns true if a value had not
    /// already been added. Returning false signals that this reduction
    /// already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }
    /// Set two MulCandidates, rooted at muls, that can be executed as a
    /// single parallel operation.
    void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                    bool Exchange = false) {
      LLVM_DEBUG(dbgs() << "Pairing:\n"
                 << *Mul0->Root << "\n"
                 << *Mul1->Root << "\n");
      Mul0->Paired = true;
      Mul1->Paired = true;
      if (Exchange)
        Mul1->Exchange = true;
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }
    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

    Type *getType() const { return Root->getType(); }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SetVector<Instruction*> &getAdds() { return Adds; }

    /// Return the MulCandidates found for this reduction.
    MulCandList &getMuls() { return Muls; }

    /// Return the MulCandidate pairs that can be executed in parallel.
    MulPairList &getMulPairs() { return MulPairs; }
    void dump() {
      LLVM_DEBUG(dbgs() << "Reduction:\n";
        for (auto *Add : Adds)
          LLVM_DEBUG(dbgs() << *Add << "\n");
        for (auto &Mul : Muls)
          LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
                     << "  " << *Mul->LHS << "\n"
                     << "  " << *Mul->RHS << "\n");
        LLVM_DEBUG(if (Acc) dbgs() << "Acc in: " << *Acc << "\n")
      );
    }
    /// Map a base load to the load that accesses the consecutive elements.
    std::map<LoadInst*, LoadInst*> LoadPairs;
    /// Loads that are the offset half of a recorded pair.
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    /// Map a base load to its widened replacement.
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V);
    bool Search(Value *V, BasicBlock *BB, Reduction &R);
    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);
    bool MatchSMLAD(Function &F);
    bool runOnFunction(Function &F) override {
      if (DisableParallelDSP)
        return false;
      if (skipFunction(F))
        return false;

      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      auto &TPC = getAnalysis<TargetPassConfig>();

      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                          << "ARMParallelDSP\n");
        return false;
      }

      bool Changes = MatchSMLAD(F);
      return Changes;
    }
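// The pass can be exercised in isolation through the legacy pass manager;
// an illustrative invocation (exact flags depend on the LLVM release):
//
//   opt -mtriple=thumbv7em -mattr=+dsp -arm-parallel-dsp -S input.ll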
bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}
// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16.
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
    }
  }
  return false;
}
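// A value accepted by IsNarrowSequence<16> (illustrative):
//
//   %ld = load i16, ptr %p      ; must be recorded in LoadPairs/OffsetLoads
//   %sx = sext i16 %ld to i32   ; the V that is passed in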
/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;
  LoadPairs.clear();
  WideLoads.clear();

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  if (Loads.empty() || Loads.size() > NumLoadLimit)
    return false;

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::beforeOrAfterPointer();
  for (auto *Write : Writes) {
    for (auto *Read : Loads) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(AA->getModRefInfo(Write, ReadLoc)))
        continue;
      if (Write->comesBefore(Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there aren't any writes between the two loads which would
  // prevent them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    bool BaseFirst = Base->comesBefore(Offset);
    LoadInst *Dominator = BaseFirst ? Base : Offset;
    LoadInst *Dominated = BaseFirst ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto *Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (Dominator->comesBefore(Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset || OffsetLoads.count(Offset))
        continue;

      if (isConsecutiveAccess(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs) {
                 LLVM_DEBUG(dbgs() << *MapIt.first << ", "
                            << *MapIt.second << "\n");
               }
             });
  return LoadPairs.size() > 1;
}
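// Example of what RecordMemoryOps pairs up (illustrative): for two adjacent
// i16 loads off the same base pointer,
//
//   %a0 = load i16, ptr %p                          ; base
//   %p1 = getelementptr inbounds i16, ptr %p, i32 1
//   %a1 = load i16, ptr %p1                         ; offset
//
// isConsecutiveAccess proves %a1 reads the 16 bits directly after %a0, so
// LoadPairs[%a0] = %a1, provided no intervening store may alias either load.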
// Search recursively back through the operands to find a chain of adds that
// reduce into an accumulator, with mul(sext(load), sext(load)) leaves.
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::PHI:
    // Could be the accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    // Ensure we don't add the root as the incoming accumulator.
    if (R.getRoot() == I)
      return false;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
  return false;
}
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
      LLVM_DEBUG(dbgs() << "BB after inserting parallel MACs:\n" << BB);
    }
  }

  return Changed;
}
bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign-extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, they can be
    // transformed into two wider loads and the users replaced with DSP
    // intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        // The operands of PMul1 are exchanged.
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      // The pair is usable with the candidates swapped and exchanged.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }
  return !R.getMulPairs().empty();
}
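// Pairing geometry (illustrative): with PMul0 = mul(Ld0, Ld2) and
// PMul1 = mul(Ld1, Ld3), the straightforward case widens Ld0/Ld1 into one
// 32-bit load and Ld2/Ld3 into another:
//
//   acc += sext(Ld0) * sext(Ld2)    -->   smlad(Wide(Ld0,Ld1),
//   acc += sext(Ld1) * sext(Ld3)                Wide(Ld2,Ld3), acc)
//
// If only the swapped order (Ld3 before Ld2) is consecutive, the pair is
// still usable with Exchange set, which later selects SMLADX.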
void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLAD = [&](LoadInst *WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
    Value* Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                ++BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  // Return the instruction after which the new call should be inserted.
  auto GetInsertPoint = [this](Value *A, Value *B) -> Instruction* {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return cast<Instruction>(V);
  };

  Value *Acc = R.getAccumulator();

  // For any muls that were discovered but not paired, accumulate their
  // values as before.
  IRBuilder<NoFolder> Builder(R.getRoot()->getParent(),
                              BasicBlock::iterator(R.getRoot()));
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;

    Instruction *Mul = cast<Instruction>(MulCand->Root);

    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
  }

  if (!Acc) {
    Acc = R.is64Bit() ?
      ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
      ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }

  // Roughly sort the mul pairs in their program order.
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
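// The replacement emitted for one mul pair is a plain intrinsic call
// (illustrative IR; smladx/smlald/smlaldx are chosen analogously):
//
//   %acc.next = call i32 @llvm.arm.smlad(i32 %wide0, i32 %wide1, i32 %acc)
//
// where %wide0 and %wide1 are the widened 32-bit loads, each packing two
// sign-extended 16-bit elements.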
LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Create the wide load, while making sure to maintain the original
  // alignment as this prevents ldrd from being generated when it could be
  // illegal due to memory alignment.
  Value *VecPtr = Base->getPointerOperand();
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);

  LLVM_DEBUG(dbgs() << "Created Wide Load:\n"
             << *WideLoad << "\n"
             << *Bottom << "\n"
             << *NewBaseSExt << "\n"
             << *Top << "\n"
             << *Trunc << "\n"
             << *NewOffsetSExt << "\n");
  WideLoads.emplace(std::make_pair(Base,
                    std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}
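// Net effect of CreateWideLoad on one pair (illustrative, little endian):
//
//   %wide = load i32, ptr %p            ; replaces two i16 loads
//   %bot  = trunc i32 %wide to i16      ; == old base load
//   %sx0  = sext i16 %bot to i32
//   %top  = lshr i32 %wide, 16
//   %t1   = trunc i32 %top to i16       ; == old offset load
//   %sx1  = sext i16 %t1 to i32
//
// The original sexts are replaced with %sx0/%sx1; the now-dead narrow loads
// are left for later cleanup.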
Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics",
                      false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics",
                    false, false)