84#define DEBUG_TYPE "tailcallelim"
86STATISTIC(NumEliminated,
"Number of tail calls removed");
87STATISTIC(NumRetDuped,
"Number of return duplicated");
88STATISTIC(NumAccumAdded,
"Number of accumulators introduced");
98 auto *AI = dyn_cast<AllocaInst>(&
I);
99 return !AI || AI->isStaticAlloca();
104struct AllocaDerivedValueTracker {
108 void walk(
Value *Root) {
112 auto AddUsesToWorklist = [&](
Value *
V) {
113 for (
auto &U :
V->uses()) {
114 if (!Visited.
insert(&U).second)
120 AddUsesToWorklist(Root);
122 while (!Worklist.
empty()) {
126 switch (
I->getOpcode()) {
127 case Instruction::Call:
128 case Instruction::Invoke: {
129 auto &CB = cast<CallBase>(*
I);
134 if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U)))
137 CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U));
138 callUsesLocalStack(CB, IsNocapture);
146 case Instruction::Load: {
151 case Instruction::Store: {
152 if (
U->getOperandNo() == 0)
153 EscapePoints.insert(
I);
156 case Instruction::BitCast:
157 case Instruction::GetElementPtr:
158 case Instruction::PHI:
159 case Instruction::Select:
160 case Instruction::AddrSpaceCast:
163 EscapePoints.insert(
I);
167 AddUsesToWorklist(
I);
171 void callUsesLocalStack(
CallBase &CB,
bool IsNocapture) {
173 AllocaUsers.insert(&CB);
181 EscapePoints.insert(&CB);
190 if (
F.callsFunctionThatReturnsTwice())
194 AllocaDerivedValueTracker Tracker;
196 if (Arg.hasByValAttr())
232 VisitType Escaped = UNESCAPED;
234 for (
auto &
I : *BB) {
235 if (Tracker.EscapePoints.count(&
I))
242 if (!CI || CI->
isTailCall() || isa<DbgInfoIntrinsic>(&
I) ||
243 isa<PseudoProbeInst>(&
I))
248 if (
auto *
II = dyn_cast<IntrinsicInst>(CI))
249 if (
II->getIntrinsicID() == Intrinsic::stackrestore)
267 bool SafeToTail =
true;
268 for (
auto &Arg : CI->
args()) {
269 if (isa<Constant>(Arg.getUser()))
271 if (
Argument *
A = dyn_cast<Argument>(Arg.getUser()))
272 if (!
A->hasByValAttr())
281 <<
"marked as tail call candidate (readnone)";
289 if (!IsNoTail && Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI))
294 auto &State = Visited[SuccBB];
295 if (State < Escaped) {
297 if (State == ESCAPED)
304 if (!WorklistEscaped.
empty()) {
309 while (!WorklistUnescaped.
empty()) {
311 if (Visited[NextBB] == UNESCAPED) {
320 for (
CallInst *CI : DeferredTails) {
321 if (Visited[CI->getParent()] != ESCAPED) {
324 LLVM_DEBUG(
dbgs() <<
"Marked as tail call candidate: " << *CI <<
"\n");
338 if (isa<DbgInfoIntrinsic>(
I))
342 if (
II->getIntrinsicID() == Intrinsic::lifetime_end &&
348 if (
I->mayHaveSideEffects())
351 if (
LoadInst *L = dyn_cast<LoadInst>(
I)) {
361 L->getAlign(),
DL, L))
375 if (!
I->isAssociative() || !
I->isCommutative())
378 assert(
I->getNumOperands() >= 2 &&
379 "Associative/commutative operations should have at least 2 args!");
388 if ((
I->getOperand(0) == CI &&
I->getOperand(1) == CI) ||
389 (
I->getOperand(0) != CI &&
I->getOperand(1) != CI))
393 if (!
I->hasOneUse() || !isa<ReturnInst>(
I->user_back()))
400 while (isa<DbgInfoIntrinsic>(
I))
406class TailRecursionEliminator {
442 :
F(
F),
TTI(
TTI), AA(AA), ORE(ORE), DTU(DTU) {}
446 void createTailRecurseLoopHeader(
CallInst *CI);
452 void cleanupAndFinalize();
456 void copyByValueOperandIntoLocalTemp(
CallInst *CI,
int OpndIdx);
458 void copyLocalTempOfByValueOperandIntoArguments(
CallInst *CI,
int OpndIdx);
470 if (&BB->
front() == TI)
478 CI = dyn_cast<CallInst>(BBI);
482 if (BBI == BB->
begin())
488 "Incompatible call site attributes(Tail,NoTail)");
496 if (BB == &
F.getEntryBlock() &&
504 for (;
I != E && FI != FE; ++
I, ++FI)
505 if (*
I != &*FI)
break;
506 if (
I == E && FI == FE)
513void TailRecursionEliminator::createTailRecurseLoopHeader(
CallInst *CI) {
514 HeaderBB = &
F.getEntryBlock();
517 HeaderBB->setName(
"tailrecurse");
525 NEBI = NewEntry->
begin();
527 if (
AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
528 if (isa<ConstantInt>(AI->getArraySize()))
529 AI->moveBefore(&*NEBI);
539 I->replaceAllUsesWith(PN);
541 ArgumentPHIs.push_back(PN);
548 Type *RetType =
F.getReturnType();
552 RetPN->insertBefore(InsertPos);
554 RetKnownPN->insertBefore(InsertPos);
566void TailRecursionEliminator::insertAccumulator(
Instruction *AccRecInstr) {
567 assert(!AccPN &&
"Trying to insert multiple accumulators");
569 AccumulatorRecursionInstr = AccRecInstr;
575 AccPN->insertBefore(HeaderBB->begin());
585 if (
P == &
F.getEntryBlock()) {
588 AccPN->addIncoming(Identity,
P);
590 AccPN->addIncoming(AccPN,
P);
599void TailRecursionEliminator::copyByValueOperandIntoLocalTemp(
CallInst *CI,
611 AggTy,
DL.getAllocaAddrSpace(),
nullptr, Alignment,
615 Value *
Size = Builder.getInt64(
DL.getTypeAllocSize(AggTy));
618 Builder.CreateMemCpy(NewAlloca, Alignment,
626void TailRecursionEliminator::copyLocalTempOfByValueOperandIntoArguments(
636 Value *
Size = Builder.getInt64(
DL.getTypeAllocSize(AggTy));
640 Builder.CreateMemCpy(
F.getArg(OpndIdx), Alignment,
645bool TailRecursionEliminator::eliminateCall(
CallInst *CI) {
654 for (++BBI; &*BBI !=
Ret; ++BBI) {
675 <<
"transforming tail recursion into loop";
681 createTailRecurseLoopHeader(CI);
684 for (
unsigned I = 0, E = CI->
arg_size();
I != E; ++
I) {
686 copyByValueOperandIntoLocalTemp(CI,
I);
692 for (
unsigned I = 0, E = CI->
arg_size();
I != E; ++
I) {
694 copyLocalTempOfByValueOperandIntoArguments(CI,
I);
700 F.removeParamAttr(
I, Attribute::ReadOnly);
701 ArgumentPHIs[
I]->addIncoming(
F.getArg(
I), BB);
707 insertAccumulator(AccRecInstr);
717 if (
Ret->getReturnValue() == CI || AccRecInstr) {
719 RetPN->addIncoming(RetPN, BB);
720 RetKnownPN->addIncoming(RetKnownPN, BB);
727 "current.ret.tr",
Ret->getIterator());
728 RetSelects.push_back(SI);
730 RetPN->addIncoming(SI, BB);
735 AccPN->addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB);
743 Ret->eraseFromParent();
745 DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
750void TailRecursionEliminator::cleanupAndFinalize() {
756 for (
PHINode *PN : ArgumentPHIs) {
765 if (RetSelects.empty()) {
768 RetPN->dropAllReferences();
769 RetPN->eraseFromParent();
771 RetKnownPN->dropAllReferences();
772 RetKnownPN->eraseFromParent();
777 Instruction *AccRecInstr = AccumulatorRecursionInstr;
784 AccRecInstrNew->
setName(
"accumulator.ret.tr");
803 RetSelects.push_back(SI);
810 Instruction *AccRecInstr = AccumulatorRecursionInstr;
813 AccRecInstrNew->
setName(
"accumulator.ret.tr");
815 SI->getFalseValue());
818 SI->setFalseValue(AccRecInstrNew);
825bool TailRecursionEliminator::processBlock(
BasicBlock &BB) {
828 if (
BranchInst *BI = dyn_cast<BranchInst>(TI)) {
829 if (BI->isConditional())
838 CallInst *CI = findTRECandidate(&BB);
844 <<
"INTO UNCOND BRANCH PRED: " << BB);
858 }
else if (isa<ReturnInst>(TI)) {
859 CallInst *CI = findTRECandidate(&BB);
862 return eliminateCall(CI);
868bool TailRecursionEliminator::eliminate(
Function &
F,
873 if (
F.getFnAttribute(
"disable-tail-calls").getValueAsBool())
876 bool MadeChange =
false;
881 if (
F.getFunctionType()->isVarArg())
888 TailRecursionEliminator TRE(
F,
TTI, AA, ORE, DTU);
891 MadeChange |= TRE.processBlock(BB);
893 TRE.cleanupAndFinalize();
918 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
919 auto *DT = DTWP ? &DTWP->getDomTree() :
nullptr;
920 auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
921 auto *PDT = PDTWP ? &PDTWP->getPostDomTree() :
nullptr;
925 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
927 return TailRecursionEliminator::eliminate(
928 F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F),
929 &getAnalysis<AAResultsWrapperPass>().getAAResults(),
930 &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
935char TailCallElim::ID = 0;
945 return new TailCallElim();
959 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
960 bool Changed = TailRecursionEliminator::eliminate(
F, &
TTI, &AA, &ORE, DTU);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This is the interface for a simple mod/ref and alias analysis over globals.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
uint64_t IntrinsicInst * II
PassBuilder PB(Machine, PassOpts->PTO, std::nullopt, &PIC)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallPtrSet class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
static bool canTRE(Function &F)
Scan the specified function for alloca instructions.
static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA)
Return true if it is safe to move the specified instruction from after the call to before the call,...
static Instruction * firstNonDbg(BasicBlock::iterator I)
static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI)
static bool markTails(Function &F, OptimizationRemarkEmitter *ORE)
A manager for alias analyses.
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
ModRefInfo getModRefInfo(const Instruction *I, const std::optional< MemoryLocation > &OptLoc)
Check whether or not an instruction may read or write the optionally specified memory location.
an instruction to allocate memory on the stack
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
This class represents an incoming formal argument to a Function.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
const Instruction & front() const
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const Function * getParent() const
Return the enclosing method, or null if none.
const Instruction * getFirstNonPHIOrDbg(bool SkipPseudoOp=true) const
Returns a pointer to the first instruction in this block that is not a PHINode or a debug intrinsic,...
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Conditional or Unconditional Branch instruction.
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
bool isByValArgument(unsigned ArgNo) const
Determine whether this argument is passed by value.
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
bool onlyReadsMemory(unsigned OpNo) const
Type * getParamByValType(unsigned ArgNo) const
Extract the byval type for a call or parameter.
bool hasOperandBundlesOtherThan(ArrayRef< uint32_t > IDs) const
Return true if this operand bundle user contains operand bundles with tags other than those specified...
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned arg_size() const
This class represents a function call, abstracting a target machine's calling convention.
bool isNoTailCall() const
void setTailCall(bool IsTc=true)
static Constant * getIdentity(Instruction *I, Type *Ty, bool AllowRHSConstant=false, bool NSZ=false)
Return the identity constant for a binary or intrinsic Instruction.
static Constant * getIntrinsicIdentity(Intrinsic::ID, Type *Ty)
static ConstantInt * getTrue(LLVMContext &Context)
static ConstantInt * getFalse(LLVMContext &Context)
This is an important base class in LLVM.
A parsed version of the target data layout string, and methods for querying it.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
bool skipFunction(const Function &F) const
Optional passes call this function to check whether the pass should be skipped.
Legacy wrapper pass to provide the GlobalsAAResult object.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Instruction * clone() const
Create a copy of 'this' instruction that is identical in all ways except the following:
void dropLocation()
Drop the instruction's debug location.
void insertBefore(Instruction *InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified instruction.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
A wrapper class for inspecting calls to intrinsic functions.
@ OB_clang_arc_attachedcall
An instruction for reading from memory.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Analysis pass which computes a PostDominatorTree.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
Return a value (possibly void), from a function.
This class represents the LLVM 'select' instruction.
static SelectInst * Create(Value *C, Value *S1, Value *S2, const Twine &NameStr="", InsertPosition InsertBefore=nullptr, Instruction *MDFrom=nullptr)
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetTransformInfo.
The instances of the Type class are immutable: once they are created, they are never changed.
static IntegerType * getInt1Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
A Use represents the edge between a Value definition and its users.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void setName(const Twine &Name)
Change the name of the value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
StringRef getName() const
Return a constant reference to the value's name.
void takeName(Value *V)
Transfer the name from V to this value.
const ParentTy * getParent() const
self_iterator getIterator()
@ Tail
Attempts to make calls as fast as possible while guaranteeing that tail call optimization can always b...
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
FunctionPass * createTailCallEliminationPass()
auto pred_end(const MachineBasicBlock *BB)
AllocaInst * findAllocaForValue(Value *V, bool OffsetZero=false)
Returns unique alloca where the value comes from, or nullptr.
auto successors(const MachineBasicBlock *BB)
void initializeTailCallElimPass(PassRegistry &)
ReturnInst * FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, BasicBlock *Pred, DomTreeUpdater *DTU=nullptr)
This method duplicates the specified return instruction into a predecessor which ends in an unconditi...
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isModSet(const ModRefInfo MRI)
bool isSafeToLoadUnconditionally(Value *V, Align Alignment, const APInt &Size, const DataLayout &DL, Instruction *ScanFrom, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if we know that executing a load from this value cannot trap.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
auto pred_begin(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool pred_empty(const BasicBlock *BB)
This struct is a compact representation of a valid (non-zero power of two) alignment.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)