87#define DEBUG_TYPE "tailcallelim"
89STATISTIC(NumEliminated,
"Number of tail calls removed");
90STATISTIC(NumRetDuped,
"Number of return duplicated");
91STATISTIC(NumAccumAdded,
"Number of accumulators introduced");
95 cl::desc(
"Force disabling recomputing of function entry count, on "
96 "successful tail recursion elimination."));
107 return !AI || AI->isStaticAlloca();
112struct AllocaDerivedValueTracker {
116 void walk(
Value *Root) {
118 SmallPtrSet<Use *, 32> Visited;
120 auto AddUsesToWorklist = [&](
Value *
V) {
121 for (
auto &U :
V->uses()) {
122 if (!Visited.
insert(&U).second)
128 AddUsesToWorklist(Root);
130 while (!Worklist.
empty()) {
134 switch (
I->getOpcode()) {
135 case Instruction::Call:
136 case Instruction::Invoke: {
142 if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U)))
145 CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U));
146 callUsesLocalStack(CB, IsNocapture);
154 case Instruction::Load: {
159 case Instruction::Store: {
160 if (
U->getOperandNo() == 0)
161 EscapePoints.insert(
I);
164 case Instruction::BitCast:
165 case Instruction::GetElementPtr:
166 case Instruction::PHI:
167 case Instruction::Select:
168 case Instruction::AddrSpaceCast:
171 EscapePoints.insert(
I);
175 AddUsesToWorklist(
I);
179 void callUsesLocalStack(CallBase &CB,
bool IsNocapture) {
181 AllocaUsers.insert(&CB);
189 EscapePoints.insert(&CB);
192 SmallPtrSet<Instruction *, 32> AllocaUsers;
193 SmallPtrSet<Instruction *, 32> EscapePoints;
198 if (
F.callsFunctionThatReturnsTwice())
202 AllocaDerivedValueTracker Tracker;
204 if (Arg.hasByValAttr())
240 VisitType Escaped = UNESCAPED;
242 for (
auto &
I : *BB) {
243 if (Tracker.EscapePoints.count(&
I))
256 if (
II->getIntrinsicID() == Intrinsic::stackrestore)
274 bool SafeToTail =
true;
275 for (
auto &Arg : CI->
args()) {
279 if (!
A->hasByValAttr())
288 <<
"marked as tail call candidate (readnone)";
296 if (!IsNoTail && Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI))
301 auto &State = Visited[SuccBB];
302 if (State < Escaped) {
304 if (State == ESCAPED)
311 if (!WorklistEscaped.
empty()) {
316 while (!WorklistUnescaped.
empty()) {
318 if (Visited[NextBB] == UNESCAPED) {
327 for (
CallInst *CI : DeferredTails) {
328 if (Visited[CI->getParent()] != ESCAPED) {
331 LLVM_DEBUG(
dbgs() <<
"Marked as tail call candidate: " << *CI <<
"\n");
346 if (
II->getIntrinsicID() == Intrinsic::lifetime_end)
351 if (
I->mayHaveSideEffects())
364 L->getAlign(),
DL, L))
378 if (!
I->isAssociative() || !
I->isCommutative())
381 assert(
I->getNumOperands() >= 2 &&
382 "Associative/commutative operations should have at least 2 args!");
391 if ((
I->getOperand(0) == CI &&
I->getOperand(1) == CI) ||
392 (
I->getOperand(0) != CI &&
I->getOperand(1) != CI))
403class TailRecursionEliminator {
405 const TargetTransformInfo *TTI;
407 OptimizationRemarkEmitter *ORE;
409 BlockFrequencyInfo *
const BFI;
410 const uint64_t OrigEntryBBFreq;
411 const uint64_t OrigEntryCount;
420 PHINode *RetPN =
nullptr;
423 PHINode *RetKnownPN =
nullptr;
434 PHINode *AccPN =
nullptr;
439 TailRecursionEliminator(Function &F,
const TargetTransformInfo *TTI,
441 DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
442 : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU), BFI(BFI),
444 BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0
U),
445 OrigEntryCount(F.getEntryCount() ? F.getEntryCount()->getCount() : 0) {
448 assert((OrigEntryCount != 0 && OrigEntryBBFreq != 0) &&
449 "If a BFI was provided, the function should have both an entry "
450 "count that is non-zero and an entry basic block with a non-zero "
455 CallInst *findTRECandidate(BasicBlock *BB);
457 void createTailRecurseLoopHeader(CallInst *CI);
459 void insertAccumulator(Instruction *AccRecInstr);
461 bool eliminateCall(CallInst *CI);
463 void cleanupAndFinalize();
465 bool processBlock(BasicBlock &BB);
467 void copyByValueOperandIntoLocalTemp(CallInst *CI,
int OpndIdx);
469 void copyLocalTempOfByValueOperandIntoArguments(CallInst *CI,
int OpndIdx);
472 static bool eliminate(Function &F,
const TargetTransformInfo *TTI,
474 DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
481 if (&BB->
front() == TI)
486 CallInst *CI =
nullptr;
493 if (BBI == BB->
begin())
499 "Incompatible call site attributes(Tail,NoTail)");
507 if (BB == &
F.getEntryBlock() && &BB->
front() == CI &&
514 for (;
I !=
E && FI != FE; ++
I, ++FI)
515 if (*
I != &*FI)
break;
516 if (
I ==
E && FI == FE)
523void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) {
524 HeaderBB = &
F.getEntryBlock();
527 HeaderBB->
setName(
"tailrecurse");
536 NEBI = NewEntry->
begin();
540 AI->moveBefore(NEBI);
550 I->replaceAllUsesWith(PN);
559 Type *RetType =
F.getReturnType();
561 Type *BoolType = Type::getInt1Ty(
F.getContext());
577void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
578 assert(!AccPN &&
"Trying to insert multiple accumulators");
580 AccumulatorRecursionInstr = AccRecInstr;
596 if (
P == &
F.getEntryBlock()) {
610void TailRecursionEliminator::copyByValueOperandIntoLocalTemp(CallInst *CI,
614 const DataLayout &
DL =
F.getDataLayout();
621 Value *NewAlloca =
new AllocaInst(
622 AggTy,
DL.getAllocaAddrSpace(),
nullptr, Alignment,
626 Value *
Size = Builder.getInt64(
DL.getTypeAllocSize(AggTy));
629 Builder.CreateMemCpy(NewAlloca, Alignment,
637void TailRecursionEliminator::copyLocalTempOfByValueOperandIntoArguments(
638 CallInst *CI,
int OpndIdx) {
641 const DataLayout &
DL =
F.getDataLayout();
647 Value *
Size = Builder.getInt64(
DL.getTypeAllocSize(AggTy));
651 Builder.CreateMemCpy(
F.getArg(OpndIdx), Alignment,
656bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
665 for (++BBI; &*BBI !=
Ret; ++BBI) {
685 return OptimizationRemark(
DEBUG_TYPE,
"tailcall-recursion", CI)
686 <<
"transforming tail recursion into loop";
692 createTailRecurseLoopHeader(CI);
697 copyByValueOperandIntoLocalTemp(CI,
I);
705 copyLocalTempOfByValueOperandIntoArguments(CI,
I);
711 F.removeParamAttr(
I, Attribute::ReadOnly);
712 ArgumentPHIs[
I]->addIncoming(
F.getArg(
I), BB);
718 insertAccumulator(AccRecInstr);
728 if (
Ret->getReturnValue() == CI || AccRecInstr) {
738 "current.ret.tr",
Ret->getIterator());
739 SI->setDebugLoc(
Ret->getDebugLoc());
747 AccPN->
addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB);
755 Ret->eraseFromParent();
757 DTU.
applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
759 if (OrigEntryBBFreq) {
760 assert(
F.getEntryCount().has_value());
764 assert(&
F.getEntryBlock() != BB);
765 auto RelativeBBFreq =
766 static_cast<double>(
BFI->getBlockFreq(BB).getFrequency()) /
767 static_cast<double>(OrigEntryBBFreq);
769 static_cast<uint64_t
>(std::round(RelativeBBFreq * OrigEntryCount));
770 auto OldEntryCount =
F.getEntryCount()->getCount();
771 if (OldEntryCount <= ToSubtract) {
773 errs() <<
"[TRE] The entrycount attributable to the recursive call, "
775 <<
", should be strictly lower than the function entry count, "
776 << OldEntryCount <<
"\n");
778 F.setEntryCount(OldEntryCount - ToSubtract,
F.getEntryCount()->getType());
784void TailRecursionEliminator::cleanupAndFinalize() {
790 for (PHINode *PN : ArgumentPHIs) {
799 if (RetSelects.
empty()) {
811 Instruction *AccRecInstr = AccumulatorRecursionInstr;
812 for (BasicBlock &BB :
F) {
818 AccRecInstrNew->
setName(
"accumulator.ret.tr");
829 for (BasicBlock &BB :
F) {
845 Instruction *AccRecInstr = AccumulatorRecursionInstr;
846 for (SelectInst *SI : RetSelects) {
848 AccRecInstrNew->
setName(
"accumulator.ret.tr");
850 SI->getFalseValue());
853 SI->setFalseValue(AccRecInstrNew);
860bool TailRecursionEliminator::processBlock(BasicBlock &BB) {
864 if (BI->isConditional())
873 CallInst *CI = findTRECandidate(&BB);
879 <<
"INTO UNCOND BRANCH PRED: " << BB);
894 CallInst *CI = findTRECandidate(&BB);
897 return eliminateCall(CI);
903bool TailRecursionEliminator::eliminate(Function &
F,
904 const TargetTransformInfo *
TTI,
906 OptimizationRemarkEmitter *ORE,
908 BlockFrequencyInfo *BFI) {
909 if (
F.getFnAttribute(
"disable-tail-calls").getValueAsBool())
912 bool MadeChange =
false;
917 if (
F.getFunctionType()->isVarArg())
924 TailRecursionEliminator TRE(
F,
TTI, AA, ORE, DTU, BFI);
926 for (BasicBlock &BB :
F)
927 MadeChange |= TRE.processBlock(BB);
929 TRE.cleanupAndFinalize();
935struct TailCallElim :
public FunctionPass {
937 TailCallElim() : FunctionPass(
ID) {
941 void getAnalysisUsage(AnalysisUsage &AU)
const override {
944 AU.
addRequired<OptimizationRemarkEmitterWrapperPass>();
954 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
955 auto *DT = DTWP ? &DTWP->getDomTree() :
nullptr;
956 auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
957 auto *PDT = PDTWP ? &PDTWP->getPostDomTree() :
nullptr;
961 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
963 return TailRecursionEliminator::eliminate(
964 F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F),
965 &getAnalysis<AAResultsWrapperPass>().getAAResults(),
966 &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
972char TailCallElim::ID = 0;
982 return new TailCallElim();
994 F.getEntryCount().has_value() &&
F.getEntryCount()->getCount())
1003 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
1005 TailRecursionEliminator::eliminate(
F, &
TTI, &
AA, &ORE, DTU, BFI);
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static bool runOnFunction(Function &F, bool PostInlining)
This is the interface for a simple mod/ref and alias analysis over globals.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
uint64_t IntrinsicInst * II
PassBuilder PB(Machine, PassOpts->PTO, std::nullopt, &PIC)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file defines the SmallPtrSet class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
static cl::opt< bool > ForceDisableBFI("tre-disable-entrycount-recompute", cl::init(false), cl::Hidden, cl::desc("Force disabling recomputing of function entry count, on " "successful tail recursion elimination."))
static bool canTRE(Function &F)
Scan the specified function for alloca instructions.
static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA)
Return true if it is safe to move the specified instruction from after the call to before the call,...
static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI)
static bool markTails(Function &F, OptimizationRemarkEmitter *ORE)
A manager for alias analyses.
an instruction to allocate memory on the stack
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
This class represents an incoming formal argument to a Function.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
const Function * getParent() const
Return the enclosing method, or null if none.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI InstListType::const_iterator getFirstNonPHIOrDbg(bool SkipPseudoOp=true) const
Returns a pointer to the first instruction in this block that is not a PHINode or a debug intrinsic,...
const Instruction & front() const
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Analysis pass which computes BlockFrequencyInfo.
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
bool isByValArgument(unsigned ArgNo) const
Determine whether this argument is passed by value.
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
bool onlyReadsMemory(unsigned OpNo) const
Type * getParamByValType(unsigned ArgNo) const
Extract the byval type for a call or parameter.
bool hasOperandBundlesOtherThan(ArrayRef< uint32_t > IDs) const
Return true if this operand bundle user contains operand bundles with tags other than those specified...
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned arg_size() const
This class represents a function call, abstracting a target machine's calling convention.
bool isNoTailCall() const
void setTailCall(bool IsTc=true)
static LLVM_ABI Constant * getIdentity(Instruction *I, Type *Ty, bool AllowRHSConstant=false, bool NSZ=false)
Return the identity constant for a binary or intrinsic Instruction.
static LLVM_ABI Constant * getIntrinsicIdentity(Intrinsic::ID, Type *Ty)
static LLVM_ABI ConstantInt * getTrue(LLVMContext &Context)
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
A parsed version of the target data layout string in and methods for querying it.
static DebugLoc getCompilerGenerated()
LLVM_ABI void deleteBB(BasicBlock *DelBB)
Delete DelBB.
Analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
void applyUpdates(ArrayRef< UpdateT > Updates)
Submit updates to all available trees.
void recalculate(FuncT &F)
Notify DTU that the entry block was replaced.
LLVM_ABI Instruction * clone() const
Create a copy of 'this' instruction that is identical in all ways except the following:
LLVM_ABI void dropLocation()
Drop the instruction's debug location.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI void insertBefore(InstListType::iterator InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified position.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
A wrapper class for inspecting calls to intrinsic functions.
@ OB_clang_arc_attachedcall
An instruction for reading from memory.
static LLVM_ABI MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Analysis pass which computes a PostDominatorTree.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
static SelectInst * Create(Value *C, Value *S1, Value *S2, const Twine &NameStr="", InsertPosition InsertBefore=nullptr, Instruction *MDFrom=nullptr)
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Analysis pass providing the TargetTransformInfo.
bool isVoidTy() const
Return true if this is 'void'.
void dropAllReferences()
Drop all references to operands.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
const ParentTy * getParent() const
self_iterator getIterator()
Abstract Attribute helper functions.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ BasicBlock
Various leaf nodes.
initializer< Ty > init(const Ty &Val)
Add a small namespace to avoid name clashes with the classes used in the streaming interface.
NodeAddr< UseNode * > Use
friend class Instruction
Iterator for Instructions in a `BasicBlock.
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
LLVM_ABI FunctionPass * createTailCallEliminationPass()
auto pred_end(const MachineBasicBlock *BB)
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
auto successors(const MachineBasicBlock *BB)
LLVM_ABI ReturnInst * FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, BasicBlock *Pred, DomTreeUpdater *DTU=nullptr)
This method duplicates the specified return instruction into a predecessor which ends in an unconditi...
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isModSet(const ModRefInfo MRI)
LLVM_ABI bool isSafeToLoadUnconditionally(Value *V, Align Alignment, const APInt &Size, const DataLayout &DL, Instruction *ScanFrom, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if we know that executing a load from this value cannot trap.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
PredIterator< BasicBlock, Value::user_iterator > pred_iterator
auto pred_begin(const MachineBasicBlock *BB)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool pred_empty(const BasicBlock *BB)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI void initializeTailCallElimPass(PassRegistry &)
AAResults AliasAnalysis
Temporary typedef for legacy code that uses a generic AliasAnalysis pointer or reference.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.