84#define DEBUG_TYPE "tailcallelim"
86STATISTIC(NumEliminated,
"Number of tail calls removed");
87STATISTIC(NumRetDuped,
"Number of return duplicated");
88STATISTIC(NumAccumAdded,
"Number of accumulators introduced");
98 auto *AI = dyn_cast<AllocaInst>(&
I);
99 return !AI || AI->isStaticAlloca();
104struct AllocaDerivedValueTracker {
108 void walk(
Value *Root) {
112 auto AddUsesToWorklist = [&](
Value *
V) {
113 for (
auto &U :
V->uses()) {
114 if (!Visited.
insert(&U).second)
120 AddUsesToWorklist(Root);
122 while (!Worklist.
empty()) {
126 switch (
I->getOpcode()) {
127 case Instruction::Call:
128 case Instruction::Invoke: {
129 auto &CB = cast<CallBase>(*
I);
134 if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U)))
137 CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U));
138 callUsesLocalStack(CB, IsNocapture);
146 case Instruction::Load: {
151 case Instruction::Store: {
152 if (
U->getOperandNo() == 0)
153 EscapePoints.insert(
I);
156 case Instruction::BitCast:
157 case Instruction::GetElementPtr:
158 case Instruction::PHI:
159 case Instruction::Select:
160 case Instruction::AddrSpaceCast:
163 EscapePoints.insert(
I);
167 AddUsesToWorklist(
I);
171 void callUsesLocalStack(
CallBase &CB,
bool IsNocapture) {
173 AllocaUsers.insert(&CB);
181 EscapePoints.insert(&CB);
190 if (
F.callsFunctionThatReturnsTwice())
194 AllocaDerivedValueTracker Tracker;
196 if (Arg.hasByValAttr())
232 VisitType Escaped = UNESCAPED;
234 for (
auto &
I : *BB) {
235 if (Tracker.EscapePoints.count(&
I))
242 if (!CI || CI->
isTailCall() || isa<DbgInfoIntrinsic>(&
I) ||
243 isa<PseudoProbeInst>(&
I))
261 bool SafeToTail =
true;
262 for (
auto &Arg : CI->
args()) {
263 if (isa<Constant>(Arg.getUser()))
265 if (
Argument *
A = dyn_cast<Argument>(Arg.getUser()))
266 if (!
A->hasByValAttr())
275 <<
"marked as tail call candidate (readnone)";
283 if (!IsNoTail && Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI))
288 auto &State = Visited[SuccBB];
289 if (State < Escaped) {
291 if (State == ESCAPED)
298 if (!WorklistEscaped.
empty()) {
303 while (!WorklistUnescaped.
empty()) {
305 if (Visited[NextBB] == UNESCAPED) {
314 for (
CallInst *CI : DeferredTails) {
315 if (Visited[CI->getParent()] != ESCAPED) {
318 LLVM_DEBUG(
dbgs() <<
"Marked as tail call candidate: " << *CI <<
"\n");
332 if (isa<DbgInfoIntrinsic>(
I))
336 if (II->getIntrinsicID() == Intrinsic::lifetime_end &&
342 if (
I->mayHaveSideEffects())
345 if (
LoadInst *L = dyn_cast<LoadInst>(
I)) {
355 L->getAlign(),
DL, L))
369 if (!
I->isAssociative() || !
I->isCommutative())
372 assert(
I->getNumOperands() == 2 &&
373 "Associative/commutative operations should have 2 args!");
376 if ((
I->getOperand(0) == CI &&
I->getOperand(1) == CI) ||
377 (
I->getOperand(0) != CI &&
I->getOperand(1) != CI))
381 if (!
I->hasOneUse() || !isa<ReturnInst>(
I->user_back()))
388 while (isa<DbgInfoIntrinsic>(
I))
394class TailRecursionEliminator {
430 :
F(
F),
TTI(
TTI), AA(AA), ORE(ORE), DTU(DTU) {}
434 void createTailRecurseLoopHeader(
CallInst *CI);
440 void cleanupAndFinalize();
444 void copyByValueOperandIntoLocalTemp(
CallInst *CI,
int OpndIdx);
446 void copyLocalTempOfByValueOperandIntoArguments(
CallInst *CI,
int OpndIdx);
458 if (&BB->
front() == TI)
466 CI = dyn_cast<CallInst>(BBI);
470 if (BBI == BB->
begin())
476 "Incompatible call site attributes(Tail,NoTail)");
484 if (BB == &
F.getEntryBlock() &&
492 for (;
I !=
E && FI != FE; ++
I, ++FI)
493 if (*
I != &*FI)
break;
494 if (
I ==
E && FI == FE)
501void TailRecursionEliminator::createTailRecurseLoopHeader(
CallInst *CI) {
502 HeaderBB = &
F.getEntryBlock();
505 HeaderBB->setName(
"tailrecurse");
511 NEBI = NewEntry->
begin();
513 if (
AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
514 if (isa<ConstantInt>(AI->getArraySize()))
515 AI->moveBefore(&*NEBI);
525 I->replaceAllUsesWith(PN);
527 ArgumentPHIs.push_back(PN);
534 Type *RetType =
F.getReturnType();
538 RetPN->insertBefore(InsertPos);
540 RetKnownPN->insertBefore(InsertPos);
552void TailRecursionEliminator::insertAccumulator(
Instruction *AccRecInstr) {
553 assert(!AccPN &&
"Trying to insert multiple accumulators");
555 AccumulatorRecursionInstr = AccRecInstr;
561 AccPN->insertBefore(HeaderBB->begin());
571 if (
P == &
F.getEntryBlock()) {
574 AccPN->addIncoming(Identity,
P);
576 AccPN->addIncoming(AccPN,
P);
585void TailRecursionEliminator::copyByValueOperandIntoLocalTemp(
CallInst *CI,
597 AggTy,
DL.getAllocaAddrSpace(),
nullptr, Alignment,
601 Value *
Size = Builder.getInt64(
DL.getTypeAllocSize(AggTy));
604 Builder.CreateMemCpy(NewAlloca, Alignment,
612void TailRecursionEliminator::copyLocalTempOfByValueOperandIntoArguments(
622 Value *
Size = Builder.getInt64(
DL.getTypeAllocSize(AggTy));
626 Builder.CreateMemCpy(
F.getArg(OpndIdx), Alignment,
631bool TailRecursionEliminator::eliminateCall(
CallInst *CI) {
640 for (++BBI; &*BBI !=
Ret; ++BBI) {
661 <<
"transforming tail recursion into loop";
667 createTailRecurseLoopHeader(CI);
672 copyByValueOperandIntoLocalTemp(CI,
I);
680 copyLocalTempOfByValueOperandIntoArguments(CI,
I);
686 F.removeParamAttr(
I, Attribute::ReadOnly);
687 ArgumentPHIs[
I]->addIncoming(
F.getArg(
I), BB);
693 insertAccumulator(AccRecInstr);
703 if (
Ret->getReturnValue() == CI || AccRecInstr) {
705 RetPN->addIncoming(RetPN, BB);
706 RetKnownPN->addIncoming(RetKnownPN, BB);
712 RetKnownPN, RetPN,
Ret->getReturnValue(),
"current.ret.tr", Ret);
713 RetSelects.push_back(SI);
715 RetPN->addIncoming(SI, BB);
720 AccPN->addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB);
728 Ret->eraseFromParent();
730 DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
735void TailRecursionEliminator::cleanupAndFinalize() {
741 for (
PHINode *PN : ArgumentPHIs) {
750 if (RetSelects.empty()) {
753 RetPN->dropAllReferences();
754 RetPN->eraseFromParent();
756 RetKnownPN->dropAllReferences();
757 RetKnownPN->eraseFromParent();
762 Instruction *AccRecInstr = AccumulatorRecursionInstr;
769 AccRecInstrNew->
setName(
"accumulator.ret.tr");
785 RetKnownPN, RetPN, RI->
getOperand(0),
"current.ret.tr", RI);
786 RetSelects.push_back(SI);
793 Instruction *AccRecInstr = AccumulatorRecursionInstr;
796 AccRecInstrNew->
setName(
"accumulator.ret.tr");
798 SI->getFalseValue());
800 SI->setFalseValue(AccRecInstrNew);
807bool TailRecursionEliminator::processBlock(
BasicBlock &BB) {
810 if (
BranchInst *BI = dyn_cast<BranchInst>(TI)) {
820 CallInst *CI = findTRECandidate(&BB);
826 <<
"INTO UNCOND BRANCH PRED: " << BB);
840 }
else if (isa<ReturnInst>(TI)) {
841 CallInst *CI = findTRECandidate(&BB);
844 return eliminateCall(CI);
850bool TailRecursionEliminator::eliminate(
Function &
F,
855 if (
F.getFnAttribute(
"disable-tail-calls").getValueAsBool())
858 bool MadeChange =
false;
863 if (
F.getFunctionType()->isVarArg())
870 TailRecursionEliminator TRE(
F,
TTI, AA, ORE, DTU);
873 MadeChange |= TRE.processBlock(BB);
875 TRE.cleanupAndFinalize();
900 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
901 auto *DT = DTWP ? &DTWP->getDomTree() :
nullptr;
902 auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
903 auto *PDT = PDTWP ? &PDTWP->getPostDomTree() :
nullptr;
907 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
909 return TailRecursionEliminator::eliminate(
910 F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F),
911 &getAnalysis<AAResultsWrapperPass>().getAAResults(),
912 &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
917char TailCallElim::ID = 0;
927 return new TailCallElim();
942 bool Changed = TailRecursionEliminator::eliminate(
F, &
TTI, &AA, &ORE, DTU);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This is the interface for a simple mod/ref and alias analysis over globals.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Select target instructions out of generic instructions
Module.h This file contains the declarations for the Module class.
PassBuilder PB(Machine, PassOpts->PTO, std::nullopt, &PIC)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallPtrSet class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
static bool canTRE(Function &F)
Scan the specified function for alloca instructions.
static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA)
Return true if it is safe to move the specified instruction from after the call to before the call,...
static Instruction * firstNonDbg(BasicBlock::iterator I)
static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI)
static bool markTails(Function &F, OptimizationRemarkEmitter *ORE)
A manager for alias analyses.
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
ModRefInfo getModRefInfo(const Instruction *I, const std::optional< MemoryLocation > &OptLoc)
Check whether or not an instruction may read or write the optionally specified memory location.
an instruction to allocate memory on the stack
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
This class represents an incoming formal argument to a Function.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
const Instruction & front() const
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const Function * getParent() const
Return the enclosing method, or null if none.
const Instruction * getFirstNonPHIOrDbg(bool SkipPseudoOp=true) const
Returns a pointer to the first instruction in this block that is not a PHINode or a debug intrinsic,...
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Conditional or Unconditional Branch instruction.
bool isConditional() const
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
bool isByValArgument(unsigned ArgNo) const
Determine whether this argument is passed by value.
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
bool onlyReadsMemory(unsigned OpNo) const
Type * getParamByValType(unsigned ArgNo) const
Extract the byval type for a call or parameter.
bool hasOperandBundlesOtherThan(ArrayRef< uint32_t > IDs) const
Return true if this operand bundle user contains operand bundles with tags other than those specified...
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned arg_size() const
This class represents a function call, abstracting a target machine's calling convention.
bool isNoTailCall() const
void setTailCall(bool IsTc=true)
static Constant * getBinOpIdentity(unsigned Opcode, Type *Ty, bool AllowRHSConstant=false, bool NSZ=false)
Return the identity constant for a binary opcode.
static ConstantInt * getTrue(LLVMContext &Context)
static ConstantInt * getFalse(LLVMContext &Context)
This is an important base class in LLVM.
A parsed version of the target data layout string, and methods for querying it.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
bool skipFunction(const Function &F) const
Optional passes call this function to check whether the pass should be skipped.
Legacy wrapper pass to provide the GlobalsAAResult object.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Instruction * clone() const
Create a copy of 'this' instruction that is identical in all ways except the following:
void insertBefore(Instruction *InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified instruction.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
const BasicBlock * getParent() const
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
A wrapper class for inspecting calls to intrinsic functions.
@ OB_clang_arc_attachedcall
An instruction for reading from memory.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Analysis pass which computes a PostDominatorTree.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
Return a value (possibly void), from a function.
This class represents the LLVM 'select' instruction.
static SelectInst * Create(Value *C, Value *S1, Value *S2, const Twine &NameStr="", Instruction *InsertBefore=nullptr, Instruction *MDFrom=nullptr)
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetTransformInfo.
The instances of the Type class are immutable: once they are created, they are never changed.
static IntegerType * getInt1Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
A Use represents the edge between a Value definition and its users.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void setName(const Twine &Name)
Change the name of the value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
StringRef getName() const
Return a constant reference to the value's name.
void takeName(Value *V)
Transfer the name from V to this value.
self_iterator getIterator()
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ Tail
Attempts to make calls as fast as possible while guaranteeing that tail call optimization can always b...
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
FunctionPass * createTailCallEliminationPass()
AllocaInst * findAllocaForValue(Value *V, bool OffsetZero=false)
Returns unique alloca where the value comes from, or nullptr.
auto successors(const MachineBasicBlock *BB)
void initializeTailCallElimPass(PassRegistry &)
ReturnInst * FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, BasicBlock *Pred, DomTreeUpdater *DTU=nullptr)
This method duplicates the specified return instruction into a predecessor which ends in an unconditi...
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
Interval::pred_iterator pred_end(Interval *I)
bool isModSet(const ModRefInfo MRI)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Interval::pred_iterator pred_begin(Interval *I)
pred_begin/pred_end - define methods so that Intervals may be used just like BasicBlocks can with the...
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool isSafeToLoadUnconditionally(Value *V, Align Alignment, APInt &Size, const DataLayout &DL, Instruction *ScanFrom=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if we know that executing a load from this value cannot trap.
bool pred_empty(const BasicBlock *BB)
This struct is a compact representation of a valid (non-zero power of two) alignment.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)