LLVM 20.0.0git
CoroAnnotationElide.cpp
Go to the documentation of this file.
1//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This pass transforms all Call or Invoke instructions that are annotated
11// "coro_elide_safe" to call the `.noalloc` variant of the coroutine instead.
12// The frame of the callee coroutine is allocated inside the caller. A pointer
13// to the allocated frame will be passed into the `.noalloc` ramp function.
14//
15//===----------------------------------------------------------------------===//
16
18
22#include "llvm/IR/Analysis.h"
23#include "llvm/IR/IRBuilder.h"
24#include "llvm/IR/Instruction.h"
25#include "llvm/IR/Module.h"
26#include "llvm/IR/PassManager.h"
29
30#include <cassert>
31
32using namespace llvm;
33
34#define DEBUG_TYPE "coro-annotation-elide"
35
37 for (Instruction &I : F->getEntryBlock())
38 if (!isa<AllocaInst>(&I))
39 return &I;
40 llvm_unreachable("no terminator in the entry block");
41}
42
43// Create an alloca in the caller, using FrameSize and FrameAlign as the callee
44// coroutine's activation frame.
45static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
46 Align FrameAlign) {
47 LLVMContext &C = Caller->getContext();
48 BasicBlock::iterator InsertPt =
50 const DataLayout &DL = Caller->getDataLayout();
51 auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
52 auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
53 Frame->setAlignment(FrameAlign);
54 return Frame;
55}
56
57// Given a call or invoke instruction to the elide safe coroutine, this function
58// does the following:
59// - Allocate a frame for the callee coroutine in the caller using alloca.
60// - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
61// pointer to the frame as an additional argument to NewCallee.
62static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
63 uint64_t FrameSize, Align FrameAlign) {
64 // TODO: generate the lifetime intrinsics for the new frame. This will require
65 // introduction of two pesudo lifetime intrinsics in the frontend around the
66 // `co_await` expression and convert them to real lifetime intrinsics here.
67 auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
68 auto NewCBInsertPt = CB->getIterator();
69 llvm::CallBase *NewCB = nullptr;
71 NewArgs.append(CB->arg_begin(), CB->arg_end());
72 NewArgs.push_back(FramePtr);
73
74 if (auto *CI = dyn_cast<CallInst>(CB)) {
75 auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
76 NewArgs, "", NewCBInsertPt);
77 NewCI->setTailCallKind(CI->getTailCallKind());
78 NewCB = NewCI;
79 } else if (auto *II = dyn_cast<InvokeInst>(CB)) {
80 NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
81 II->getNormalDest(), II->getUnwindDest(),
82 NewArgs, {}, "", NewCBInsertPt);
83 } else {
84 llvm_unreachable("CallBase should either be Call or Invoke!");
85 }
86
87 NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
88 NewCB->setCallingConv(CB->getCallingConv());
89 NewCB->setAttributes(CB->getAttributes());
90 NewCB->setDebugLoc(CB->getDebugLoc());
91 std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
92 NewCB->bundle_op_info_begin());
93
94 NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
95 CB->replaceAllUsesWith(NewCB);
96
98 InlineResult IR = InlineFunction(*NewCB, IFI);
99 if (IR.isSuccess()) {
100 CB->eraseFromParent();
101 } else {
102 NewCB->replaceAllUsesWith(CB);
103 NewCB->eraseFromParent();
104 }
105}
106
109 LazyCallGraph &CG,
110 CGSCCUpdateResult &UR) {
111 bool Changed = false;
112 CallGraphUpdater CGUpdater;
113 CGUpdater.initialize(CG, C, AM, UR);
114
115 auto &FAM =
116 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
117
118 for (LazyCallGraph::Node &N : C) {
119 Function *Callee = &N.getFunction();
120 Function *NewCallee = Callee->getParent()->getFunction(
121 (Callee->getName() + ".noalloc").str());
122 if (!NewCallee)
123 continue;
124
126 for (auto *U : Callee->users()) {
127 if (auto *CB = dyn_cast<CallBase>(U)) {
128 if (CB->getCalledFunction() == Callee)
129 Users.push_back(CB);
130 }
131 }
132 auto FramePtrArgPosition = NewCallee->arg_size() - 1;
133 auto FrameSize =
134 NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
135 auto FrameAlign =
136 NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
137
138 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee);
139
140 for (auto *CB : Users) {
141 auto *Caller = CB->getFunction();
142 if (!Caller)
143 continue;
144
145 bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
146 bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe);
147 if (IsCallerPresplitCoroutine && HasAttr) {
148 auto *CallerN = CG.lookup(*Caller);
149 auto *CallerC = CallerN ? CG.lookupSCC(*CallerN) : nullptr;
150 // If CallerC is nullptr, it means LazyCallGraph hasn't visited Caller
151 // yet. Skip the call graph update.
152 auto ShouldUpdateCallGraph = !!CallerC;
153 processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
154
155 ORE.emit([&]() {
156 return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
157 << "'" << ore::NV("callee", Callee->getName())
158 << "' elided in '" << ore::NV("caller", Caller->getName())
159 << "'";
160 });
161
163 Changed = true;
164 if (ShouldUpdateCallGraph)
165 updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM, UR,
166 FAM);
167
168 } else {
169 ORE.emit([&]() {
170 return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
171 Caller)
172 << "'" << ore::NV("callee", Callee->getName())
173 << "' not elided in '" << ore::NV("caller", Caller->getName())
174 << "' (caller_presplit="
175 << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
176 << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
177 << ")";
178 });
179 }
180 }
181 }
182
183 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
184}
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This header provides classes for managing passes over SCCs of the call graph.
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
static void processCall(CallBase *CB, Function *Caller, Function *NewCallee, uint64_t FrameSize, Align FrameAlign)
static Instruction * getFirstNonAllocaInTheEntryBlock(Function *F)
static Value * allocateFrameInCaller(Function *Caller, uint64_t FrameSize, Align FrameAlign)
#define DEBUG_TYPE
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
iv Induction Variable Users
Definition: IVUsers.cpp:48
Implements a lazy call graph analysis and related passes for the new pass manager.
Legalize the Machine IR a function s Machine IR
Definition: Legalizer.cpp:80
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
uint64_t IntrinsicInst * II
FunctionAnalysisManager FAM
static const unsigned FramePtr
an instruction to allocate memory on the stack
Definition: Instructions.h:63
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
void invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Invalidate cached analyses for an IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:177
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1120
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1411
bundle_op_iterator bundle_op_info_begin()
Return the start of the list of BundleOpInfo instances associated with this OperandBundleUser.
Definition: InstrTypes.h:2212
CallingConv::ID getCallingConv() const
Definition: InstrTypes.h:1407
bundle_op_iterator bundle_op_info_end()
Return the end of the list of BundleOpInfo instances associated with this OperandBundleUser.
Definition: InstrTypes.h:2229
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition: InstrTypes.h:1269
void setAttributes(AttributeList A)
Set the attributes for this call.
Definition: InstrTypes.h:1428
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition: InstrTypes.h:1275
void removeFnAttr(Attribute::AttrKind Kind)
Removes the attribute from the function.
Definition: InstrTypes.h:1529
AttributeList getAttributes() const
Return the attributes for this call.
Definition: InstrTypes.h:1425
void setCalledFunction(Function *Fn)
Sets the function called, including updating the function type.
Definition: InstrTypes.h:1388
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
A proxy from a FunctionAnalysisManager to an SCC.
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:216
uint64_t getParamDereferenceableBytes(unsigned ArgNo) const
Extract the number of dereferenceable bytes for a parameter.
Definition: Function.h:523
MaybeAlign getParamAlign(unsigned ArgNo) const
Definition: Function.h:488
size_t arg_size() const
Definition: Function.h:901
This class captures the data input to the InlineFunction call, and records the auxiliary results prod...
Definition: Cloning.h:255
InlineResult is basically true or false.
Definition: InlineCost.h:179
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:475
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:94
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
Definition: Instruction.h:472
static InvokeInst * Create(FunctionType *Ty, Value *Func, BasicBlock *IfNormal, BasicBlock *IfException, ArrayRef< Value * > Args, const Twine &NameStr, InsertPosition InsertBefore=nullptr)
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
SCC * lookupSCC(Node &N) const
Lookup a function's SCC in the graph.
Node * lookup(const Function &F) const
Lookup a function in the graph which has already been scanned and added.
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
static IntegerType * getInt8Ty(LLVMContext &C)
LLVM Value Representation.
Definition: Value.h:74
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
self_iterator getIterator()
Definition: ilist_node.h:132
Pass manager infrastructure for declaring and invalidating analyses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
DiagnosticInfoOptimizationBase::Argument NV
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
LazyCallGraph::SCC & updateCGAndAnalysisManagerForCGSCCPass(LazyCallGraph &G, LazyCallGraph::SCC &C, LazyCallGraph::Node &N, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR, FunctionAnalysisManager &FAM)
Helper to update the call graph after running a CGSCC pass.
InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, bool MergeAttributes=false, AAResults *CalleeAAR=nullptr, bool InsertLifetime=true, Function *ForwardVarArgsTo=nullptr)
This function inlines the called function into the basic block of the caller.
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
Support structure for SCC passes to communicate updates the call graph back to the CGSCC pass manager...
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
Definition: Alignment.h:141