LLVM 20.0.0git
LoopTermFold.cpp
Go to the documentation of this file.
1//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//===----------------------------------------------------------------------===//
9
11#include "llvm/ADT/Statistic.h"
22#include "llvm/Config/llvm-config.h"
23#include "llvm/IR/BasicBlock.h"
24#include "llvm/IR/Dominators.h"
25#include "llvm/IR/IRBuilder.h"
26#include "llvm/IR/InstrTypes.h"
27#include "llvm/IR/Instruction.h"
29#include "llvm/IR/Module.h"
30#include "llvm/IR/Type.h"
31#include "llvm/IR/Value.h"
33#include "llvm/Pass.h"
34#include "llvm/Support/Debug.h"
42#include <cassert>
43#include <optional>
44#include <utility>
45
46using namespace llvm;
47
// Debug category for this file's LLVM_DEBUG output.
48#define DEBUG_TYPE "loop-term-fold"
49
// Counter incremented once per loop in RunTermFold after canFoldTermCondOfLoop
// has reported the fold as legal.
50STATISTIC(NumTermFold,
51 "Number of terminating condition fold recognized and performed");
52
// Decide whether the latch exit test of loop L can be rewritten in terms of a
// different header PHI. On success the returned tuple holds, in order:
//   - ToFold:         the PHI feeding the current exit compare,
//   - ToHelpFold:     the alternate PHI chosen to carry the new exit test,
//   - TermValueS:     the SCEV of the alternate IV's post-increment value
//                     evaluated at the backedge-taken count,
//   - MustDropPoison: whether the alternate IV's increment has
//                     poison-generating flags that the caller must drop.
// Returns std::nullopt as soon as any structural, legality or cost check
// fails.
53static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
55 const LoopInfo &LI, const TargetTransformInfo &TTI) {
56 if (!L->isInnermost()) {
57 LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
58 return std::nullopt;
59 }
60 // Only inspect on simple loop structure
61 if (!L->isLoopSimplifyForm()) {
62 LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
63 return std::nullopt;
64 }
65
67 LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
68 return std::nullopt;
69 }
70
// The exit test must be a conditional branch in the latch whose condition is
// an icmp used only by that branch.
71 BasicBlock *LoopLatch = L->getLoopLatch();
72 BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
73 if (!BI || BI->isUnconditional())
74 return std::nullopt;
75 auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
76 if (!TermCond) {
78 dbgs() << "Cannot fold on branching condition that is not an ICmpInst\n");
79 return std::nullopt;
80 }
81 if (!TermCond->hasOneUse()) {
83 dbgs()
84 << "Cannot replace terminating condition with more than one use\n");
85 return std::nullopt;
86 }
87
88 BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
89 Value *RHS = TermCond->getOperand(1);
90 if (!LHS || !L->isLoopInvariant(RHS))
91 // We could pattern match the inverse form of the icmp, but that is
92 // non-canonical, and this pass is running *very* late in the pipeline.
93 return std::nullopt;
94
95 // Find the IV used by the current exit condition.
96 PHINode *ToFold;
97 Value *ToFoldStart, *ToFoldStep;
98 if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
99 return std::nullopt;
100
101 // Ensure the simple recurrence is a part of the current loop.
102 if (ToFold->getParent() != L->getHeader())
103 return std::nullopt;
104
105 // If that IV isn't dead after we rewrite the exit condition in terms of
106 // another IV, there's no point in doing the transform.
107 if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
108 return std::nullopt;
109
110 // Inserting instructions in the preheader has a runtime cost, scale
111 // the allowed cost with the loop's trip count as best we can.
112 const unsigned ExpansionBudget = [&]() {
113 unsigned Budget = 2 * SCEVCheapExpansionBudget;
114 if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
115 return std::min(Budget, SmallTC);
116 if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
117 return std::min(Budget, *SmallTC);
118 // Unknown trip count, assume long running by default.
119 return Budget;
120 }();
121
122 const SCEV *BECount = SE.getBackedgeTakenCount(L);
123 const DataLayout &DL = L->getHeader()->getDataLayout();
124 SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
125
// Scan every other header PHI for a usable replacement IV. The loop does not
// stop at the first match, so the last qualifying PHI wins.
126 PHINode *ToHelpFold = nullptr;
127 const SCEV *TermValueS = nullptr;
128 bool MustDropPoison = false;
129 auto InsertPt = L->getLoopPreheader()->getTerminator();
130 for (PHINode &PN : L->getHeader()->phis()) {
131 if (ToFold == &PN)
132 continue;
133
134 if (!SE.isSCEVable(PN.getType())) {
135 LLVM_DEBUG(dbgs() << "IV of phi '" << PN
136 << "' is not SCEV-able, not qualified for the "
137 "terminating condition folding.\n");
138 continue;
139 }
140 const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
141 // Only speculate on affine AddRec
142 if (!AddRec || !AddRec->isAffine()) {
143 LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
144 << "' is not an affine add recursion, not qualified "
145 "for the terminating condition folding.\n");
146 continue;
147 }
148
149 // Check that we can compute the value of AddRec on the exiting iteration
150 // without soundness problems. evaluateAtIteration internally needs
151 // to multiply the stride of the iteration number - which may wrap around.
152 // The issue here is subtle because computing the result accounting for
153 // wrap is insufficient. In order to use the result in an exit test, we
154 // must also know that AddRec doesn't take the same value on any previous
155 // iteration. The simplest case to consider is a candidate IV which is
156 // narrower than the trip count (and thus original IV), but this can
157 // also happen due to non-unit strides on the candidate IVs.
158 if (!AddRec->hasNoSelfWrap() ||
159 !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
160 continue;
161
// The new exit test compares the post-increment value, so evaluate the
// post-increment recurrence at the backedge-taken count.
162 const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
163 const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
164 if (!Expander.isSafeToExpand(TermValueSLocal)) {
166 dbgs() << "Is not safe to expand terminating value for phi node" << PN
167 << "\n");
168 continue;
169 }
170
171 if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
172 InsertPt)) {
174 dbgs() << "Is too expensive to expand terminating value for phi node"
175 << PN << "\n");
176 continue;
177 }
178
179 // The candidate IV may have been otherwise dead and poison from the
180 // very first iteration. If we can't disprove that, we can't use the IV.
181 if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
182 LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
183 continue;
184 }
185
186 // The candidate IV may become poison on the last iteration. If this
187 // value is not branched on, this is a well defined program. We're
188 // about to add a new use to this IV, and we have to ensure we don't
189 // insert UB which didn't previously exist.
190 bool MustDropPoisonLocal = false;
191 Instruction *PostIncV =
192 cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
193 if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
194 &DT)) {
195 LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
196 << "\n");
197
198 // If this is a complex recurrence with multiple instructions computing
199 // the backedge value, we might need to strip poison flags from all of
200 // them.
201 if (PostIncV->getOperand(0) != &PN)
202 continue;
203
204 // In order to perform the transform, we need to drop the poison
205 // generating flags on this instruction (if any).
206 MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
207 }
208
209 // We pick the last legal alternate IV. We could explore choosing an optimal
210 // alternate IV if we had a decent heuristic to do so.
211 ToHelpFold = &PN;
212 TermValueS = TermValueSLocal;
213 MustDropPoison = MustDropPoisonLocal;
214 }
215
216 LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
217 << "Cannot find other AddRec IV to help folding\n";);
218
219 LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
220 << "\nFound loop that can fold terminating condition\n"
221 << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
222 << " TermCond: " << *TermCond << "\n"
223 << " BranchInst: " << *BI << "\n"
224 << " ToFold: " << *ToFold << "\n"
225 << " ToHelpFold: " << *ToHelpFold << "\n");
226
227 if (!ToFold || !ToHelpFold)
228 return std::nullopt;
229 return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
230}
231
233 LoopInfo &LI, const TargetTransformInfo &TTI,
234 TargetLibraryInfo &TLI, MemorySSA *MSSA) {
// Performs the terminating-condition fold on L when canFoldTermCondOfLoop
// reports it legal: expands the terminating value before the preheader
// terminator, replaces the latch branch condition with an equality compare of
// the alternate IV's latch value against that value, erases the old compare
// and deletes header PHIs that became dead. Returns true iff the IR changed.
// MSSA may be null; when present, a MemorySSAUpdater is passed to
// DeleteDeadPHIs.
235 std::unique_ptr<MemorySSAUpdater> MSSAU;
236 if (MSSA)
237 MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
238
239 auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
240 if (!Opt)
241 return false;
242
243 auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
244
245 NumTermFold++;
246
247 BasicBlock *LoopPreheader = L->getLoopPreheader();
248 BasicBlock *LoopLatch = L->getLoopLatch();
249
// ToFold is only referenced from debug output below; the cast silences the
// unused-variable warning when LLVM_DEBUG compiles away.
250 (void)ToFold;
251 LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
252 << *ToFold << "\n"
253 << "New term-cond phi-node:\n"
254 << *ToHelpFold << "\n");
255
256 Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
257 (void)StartValue;
258 Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
259
260 // See comment in canFoldTermCondOfLoop on why this is sufficient.
261 if (MustDrop)
262 cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
263
264 // SCEVExpander for both use in preheader and latch
265 const DataLayout &DL = L->getHeader()->getDataLayout();
266 SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
267
268 assert(Expander.isSafeToExpand(TermValueS) &&
269 "Terminating value was checked safe in canFoldTermCondOfLoop");
270
271 // Create new terminating value at loop preheader
272 Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
273 LoopPreheader->getTerminator());
274
275 LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
276 << *StartValue << "\n"
277 << "Terminating value of new term-cond phi-node:\n"
278 << *TermValue << "\n");
279
280 // Create new terminating condition at loop latch
281 BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
282 ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
283 IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
284 Value *NewTermCond =
285 LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
286 "lsr_fold_term_cond.replaced_term_cond");
287 // Swap successors to exit loop body if IV equals to new TermValue
288 if (BI->getSuccessor(0) == L->getHeader())
289 BI->swapSuccessors();
290
291 LLVM_DEBUG(dbgs() << "Old term-cond:\n"
292 << *OldTermCond << "\n"
293 << "New term-cond:\n"
294 << *NewTermCond << "\n");
295
296 BI->setCondition(NewTermCond);
297
// The old compare is erased only after the branch has been switched to the
// new condition above.
298 Expander.clear();
299 OldTermCond->eraseFromParent();
300 DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
301 return true;
302}
303
304namespace {
305
// LoopPass-derived wrapper for the transform; its runOnLoop gathers the
// analyses and forwards to RunTermFold.
306class LoopTermFold : public LoopPass {
307public:
308 static char ID; // Pass ID, replacement for typeid
309
310 LoopTermFold();
311
312private:
// Per-loop entry point; the returned bool is RunTermFold's result (false when
// the loop is skipped).
313 bool runOnLoop(Loop *L, LPPassManager &LPM) override;
// Analysis requirements of this pass.
314 void getAnalysisUsage(AnalysisUsage &AU) const override;
315};
316
317} // end anonymous namespace
318
// Constructs the pass, handing the static pass ID to the LoopPass base.
// NOTE(review): listing line 320 (the constructor body) is absent from this
// view; presumably it registers the pass via initializeLoopTermFoldPass --
// confirm against the real source.
319LoopTermFold::LoopTermFold() : LoopPass(ID) {
321}
322
// Declares the analyses this pass needs.
// NOTE(review): listing lines 324-334 (the whole body) are absent from this
// view. runOnLoop fetches ScalarEvolution, DominatorTree, LoopInfo,
// TargetTransformInfo and TargetLibraryInfo through getAnalysis, so those are
// presumably required here -- confirm against the real source.
323void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
335}
336
// Entry point of the LoopTermFold pass class: collects the analyses and
// delegates to RunTermFold. Returns false without touching the loop when
// skipLoop(L) is true; otherwise returns RunTermFold's result.
337bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
338 if (skipLoop(L))
339 return false;
340
341 auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
342 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
343 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
// TTI and TLI are queried for the function that contains the loop header.
344 const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
345 *L->getHeader()->getParent());
346 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
347 *L->getHeader()->getParent());
// MemorySSA is optional: it is used only if already available, otherwise a
// null pointer is passed on.
348 auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
349 MemorySSA *MSSA = nullptr;
350 if (MSSAAnalysis)
351 MSSA = &MSSAAnalysis->getMSSA();
352 return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
353}
354
357 LPMUpdater &) {
// Runs the fold using the analyses carried in AR. If RunTermFold reports no
// change, every analysis is preserved.
358 if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA))
359 return PreservedAnalyses::all();
360
// NOTE(review): listing line 361, which declares PA, is absent from this view;
// presumably it is initialised from getLoopPassPreservedAnalyses() -- confirm.
// MemorySSA is marked preserved only when it was supplied.
362 if (AR.MSSA)
363 PA.preserve<MemorySSAAnalysis>();
364 return PA;
365}
366
// Definition of the static pass ID declared in the LoopTermFold class.
367char LoopTermFold::ID = 0;
368
// Registration of the pass under the command-line name "loop-term-fold".
// NOTE(review): listing lines 371-374 and 377 are absent from this view, so
// the dependency list and the closing macro arguments are incomplete here.
369INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
370 false, false)
375INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
376INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
378
// Factory returning a newly allocated LoopTermFold pass object.
379Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
#define LLVM_DEBUG(X)
Definition: Debug.h:101
This header provides classes for managing per-loop analyses.
static std::optional< std::tuple< PHINode *, PHINode *, const SCEV *, bool > > canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, const LoopInfo &LI, const TargetTransformInfo &TTI)
loop term fold
loop term Loop Terminator Folding
static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI, TargetLibraryInfo &TLI, MemorySSA *MSSA)
This file exposes an interface to building/using memory SSA to walk memory instructions using a use/d...
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This pass exposes codegen information to IR-level passes.
Value * RHS
Value * LHS
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
Represent the analysis usage information of a pass.
AnalysisUsage & addRequiredID(const void *ID)
Definition: Pass.cpp:270
AnalysisUsage & addPreservedID(const void *ID)
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
Conditional or Unconditional Branch instruction.
void setCondition(Value *V)
void swapSuccessors()
Swap the successors of this branch instruction.
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
@ ICMP_EQ
equal
Definition: InstrTypes.h:778
A parsed version of the target data layout string, and methods for querying it.
Definition: DataLayout.h:63
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
This instruction compares its operands according to the predicate given to the constructor.
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:2371
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2686
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:92
bool hasPoisonGeneratingFlags() const LLVM_READONLY
Return true if this operator has flags which may cause this instruction to evaluate to poison despite...
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
virtual bool runOnLoop(Loop *L, LPPassManager &LPM)=0
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:39
An analysis that produces MemorySSA for a function.
Definition: MemorySSA.h:924
Legacy analysis pass which computes MemorySSA.
Definition: MemorySSA.h:981
Encapsulates MemorySSA, including all data associated with memory accesses.
Definition: MemorySSA.h:697
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:94
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This class uses information about analyzed scalars to rewrite expressions in canonical form.
bool isSafeToExpand(const SCEV *S) const
Return true if the given expression is safe to expand in the sense that all materialized values are s...
bool isHighCostExpansion(ArrayRef< const SCEV * > Exprs, Loop *L, unsigned Budget, const TargetTransformInfo *TTI, const Instruction *At)
Return true for expressions that can't be evaluated at runtime within given Budget.
void clear()
Erase the contents of the InsertedExpressions map so that users trying to expand the same expression ...
Value * expandCodeFor(const SCEV *SH, Type *Ty, BasicBlock::iterator I)
Insert code to directly compute the specified SCEV expression into the program.
This class represents an analyzed expression in the program.
The main scalar evolution driver.
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
unsigned getSmallConstantMaxTripCount(const Loop *L)
Returns the upper bound of the loop trip count as a normal unsigned value.
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
Provides information about what library functions are available for the current target.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
const ParentTy * getParent() const
Definition: ilist_node.h:32
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, Instruction *OnPathTo, DominatorTree *DT)
Return true if undefined behavior would provably be executed on the path to OnPathTo if Root produced...
std::optional< unsigned > getLoopEstimatedTripCount(Loop *L, unsigned *EstimatedLoopInvocationWeight=nullptr)
Returns a loop's estimated trip count based on branch weight metadata.
Definition: LoopUtils.cpp:849
char & LoopSimplifyID
Pass * createLoopTermFoldPass()
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
bool DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
Examine each PHI in the given block and delete it if it is dead.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
cl::opt< unsigned > SCEVCheapExpansionBudget
PreservedAnalyses getLoopPassPreservedAnalyses()
Returns the minimum set of Analyses that all loop passes must preserve.
bool isAlmostDeadIV(PHINode *IV, BasicBlock *LatchBlock, Value *Cond)
Return true if the induction variable IV in a Loop whose latch is LatchBlock would become dead if the...
Definition: LoopUtils.cpp:469
void initializeLoopTermFoldPass(PassRegistry &)
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...