LLVM 22.0.0git
LoopTermFold.cpp
Go to the documentation of this file.
1//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//===----------------------------------------------------------------------===//
9
11#include "llvm/ADT/Statistic.h"
22#include "llvm/Config/llvm-config.h"
23#include "llvm/IR/BasicBlock.h"
24#include "llvm/IR/Dominators.h"
25#include "llvm/IR/IRBuilder.h"
26#include "llvm/IR/InstrTypes.h"
27#include "llvm/IR/Instruction.h"
29#include "llvm/IR/Type.h"
30#include "llvm/IR/Value.h"
32#include "llvm/Pass.h"
33#include "llvm/Support/Debug.h"
41#include <cassert>
42#include <optional>
43
44using namespace llvm;
45
46#define DEBUG_TYPE "loop-term-fold"
47
48STATISTIC(NumTermFold,
49 "Number of terminating condition fold recognized and performed");
50
51static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
53 const LoopInfo &LI, const TargetTransformInfo &TTI) {
54 if (!L->isInnermost()) {
55 LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
56 return std::nullopt;
57 }
58 // Only inspect on simple loop structure
59 if (!L->isLoopSimplifyForm()) {
60 LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
61 return std::nullopt;
62 }
63
65 LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
66 return std::nullopt;
67 }
68
69 BasicBlock *LoopLatch = L->getLoopLatch();
71 if (!BI || BI->isUnconditional())
72 return std::nullopt;
73 auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
74 if (!TermCond) {
76 dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
77 return std::nullopt;
78 }
79 if (!TermCond->hasOneUse()) {
81 dbgs()
82 << "Cannot replace terminating condition with more than one use\n");
83 return std::nullopt;
84 }
85
86 BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
87 Value *RHS = TermCond->getOperand(1);
88 if (!LHS || !L->isLoopInvariant(RHS))
89 // We could pattern match the inverse form of the icmp, but that is
90 // non-canonical, and this pass is running *very* late in the pipeline.
91 return std::nullopt;
92
93 // Find the IV used by the current exit condition.
94 PHINode *ToFold;
95 Value *ToFoldStart, *ToFoldStep;
96 if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
97 return std::nullopt;
98
99 // Ensure the simple recurrence is a part of the current loop.
100 if (ToFold->getParent() != L->getHeader())
101 return std::nullopt;
102
103 // If that IV isn't dead after we rewrite the exit condition in terms of
104 // another IV, there's no point in doing the transform.
105 if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
106 return std::nullopt;
107
108 // Inserting instructions in the preheader has a runtime cost, scale
109 // the allowed cost with the loops trip count as best we can.
110 const unsigned ExpansionBudget = [&]() {
111 unsigned Budget = 2 * SCEVCheapExpansionBudget;
112 if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
113 return std::min(Budget, SmallTC);
114 if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
115 return std::min(Budget, *SmallTC);
116 // Unknown trip count, assume long running by default.
117 return Budget;
118 }();
119
120 const SCEV *BECount = SE.getBackedgeTakenCount(L);
121 SCEVExpander Expander(SE, "lsr_fold_term_cond");
122
123 PHINode *ToHelpFold = nullptr;
124 const SCEV *TermValueS = nullptr;
125 bool MustDropPoison = false;
126 auto InsertPt = L->getLoopPreheader()->getTerminator();
127 for (PHINode &PN : L->getHeader()->phis()) {
128 if (ToFold == &PN)
129 continue;
130
131 if (!SE.isSCEVable(PN.getType())) {
132 LLVM_DEBUG(dbgs() << "IV of phi '" << PN
133 << "' is not SCEV-able, not qualified for the "
134 "terminating condition folding.\n");
135 continue;
136 }
137 const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
138 // Only speculate on affine AddRec
139 if (!AddRec || !AddRec->isAffine()) {
140 LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
141 << "' is not an affine add recursion, not qualified "
142 "for the terminating condition folding.\n");
143 continue;
144 }
145
146 // Check that we can compute the value of AddRec on the exiting iteration
147 // without soundness problems. evaluateAtIteration internally needs
148 // to multiply the stride of the iteration number - which may wrap around.
149 // The issue here is subtle because computing the result accounting for
150 // wrap is insufficient. In order to use the result in an exit test, we
151 // must also know that AddRec doesn't take the same value on any previous
152 // iteration. The simplest case to consider is a candidate IV which is
153 // narrower than the trip count (and thus original IV), but this can
154 // also happen due to non-unit strides on the candidate IVs.
155 if (!AddRec->hasNoSelfWrap() ||
156 !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
157 continue;
158
159 const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
160 const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
161 if (!Expander.isSafeToExpand(TermValueSLocal)) {
163 dbgs() << "Is not safe to expand terminating value for phi node" << PN
164 << "\n");
165 continue;
166 }
167
168 if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
169 InsertPt)) {
171 dbgs() << "Is too expensive to expand terminating value for phi node"
172 << PN << "\n");
173 continue;
174 }
175
176 // The candidate IV may have been otherwise dead and poison from the
177 // very first iteration. If we can't disprove that, we can't use the IV.
178 if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
179 LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
180 continue;
181 }
182
183 // The candidate IV may become poison on the last iteration. If this
184 // value is not branched on, this is a well defined program. We're
185 // about to add a new use to this IV, and we have to ensure we don't
186 // insert UB which didn't previously exist.
187 bool MustDropPoisonLocal = false;
188 Instruction *PostIncV =
189 cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
190 if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
191 &DT)) {
192 LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
193 << "\n");
194
195 // If this is a complex recurrance with multiple instructions computing
196 // the backedge value, we might need to strip poison flags from all of
197 // them.
198 if (PostIncV->getOperand(0) != &PN)
199 continue;
200
201 // In order to perform the transform, we need to drop the poison
202 // generating flags on this instruction (if any).
203 MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
204 }
205
206 // We pick the last legal alternate IV. We could expore choosing an optimal
207 // alternate IV if we had a decent heuristic to do so.
208 ToHelpFold = &PN;
209 TermValueS = TermValueSLocal;
210 MustDropPoison = MustDropPoisonLocal;
211 }
212
213 LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
214 << "Cannot find other AddRec IV to help folding\n";);
215
216 LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
217 << "\nFound loop that can fold terminating condition\n"
218 << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
219 << " TermCond: " << *TermCond << "\n"
220 << " BrandInst: " << *BI << "\n"
221 << " ToFold: " << *ToFold << "\n"
222 << " ToHelpFold: " << *ToHelpFold << "\n");
223
224 if (!ToFold || !ToHelpFold)
225 return std::nullopt;
226 return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
227}
228
230 LoopInfo &LI, const TargetTransformInfo &TTI,
231 TargetLibraryInfo &TLI, MemorySSA *MSSA) {
232 std::unique_ptr<MemorySSAUpdater> MSSAU;
233 if (MSSA)
234 MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
235
236 auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
237 if (!Opt)
238 return false;
239
240 auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
241
242 NumTermFold++;
243
244 BasicBlock *LoopPreheader = L->getLoopPreheader();
245 BasicBlock *LoopLatch = L->getLoopLatch();
246
247 (void)ToFold;
248 LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
249 << *ToFold << "\n"
250 << "New term-cond phi-node:\n"
251 << *ToHelpFold << "\n");
252
253 Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
254 (void)StartValue;
255 Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
256
257 // See comment in canFoldTermCondOfLoop on why this is sufficient.
258 if (MustDrop)
259 cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
260
261 // SCEVExpander for both use in preheader and latch
262 SCEVExpander Expander(SE, "lsr_fold_term_cond");
263
264 assert(Expander.isSafeToExpand(TermValueS) &&
265 "Terminating value was checked safe in canFoldTerminatingCondition");
266
267 // Create new terminating value at loop preheader
268 Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
269 LoopPreheader->getTerminator());
270
271 LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
272 << *StartValue << "\n"
273 << "Terminating value of new term-cond phi-node:\n"
274 << *TermValue << "\n");
275
276 // Create new terminating condition at loop latch
277 BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
278 ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
279 IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
280 Value *NewTermCond =
281 LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
282 "lsr_fold_term_cond.replaced_term_cond");
283 // Swap successors to exit loop body if IV equals to new TermValue
284 if (BI->getSuccessor(0) == L->getHeader())
285 BI->swapSuccessors();
286
287 LLVM_DEBUG(dbgs() << "Old term-cond:\n"
288 << *OldTermCond << "\n"
289 << "New term-cond:\n"
290 << *NewTermCond << "\n");
291
292 BI->setCondition(NewTermCond);
293
294 Expander.clear();
295 OldTermCond->eraseFromParent();
296 DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
297 return true;
298}
299
300namespace {
301
302class LoopTermFold : public LoopPass {
303public:
304 static char ID; // Pass ID, replacement for typeid
305
306 LoopTermFold();
307
308private:
309 bool runOnLoop(Loop *L, LPPassManager &LPM) override;
310 void getAnalysisUsage(AnalysisUsage &AU) const override;
311};
312
313} // end anonymous namespace
314
315LoopTermFold::LoopTermFold() : LoopPass(ID) {
317}
318
319void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
320 AU.addRequired<LoopInfoWrapperPass>();
321 AU.addPreserved<LoopInfoWrapperPass>();
324 AU.addRequired<DominatorTreeWrapperPass>();
325 AU.addPreserved<DominatorTreeWrapperPass>();
326 AU.addRequired<ScalarEvolutionWrapperPass>();
327 AU.addPreserved<ScalarEvolutionWrapperPass>();
328 AU.addRequired<TargetLibraryInfoWrapperPass>();
329 AU.addRequired<TargetTransformInfoWrapperPass>();
330 AU.addPreserved<MemorySSAWrapperPass>();
331}
332
333bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
334 if (skipLoop(L))
335 return false;
336
337 auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
338 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
339 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
340 const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
341 *L->getHeader()->getParent());
342 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
343 *L->getHeader()->getParent());
344 auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
345 MemorySSA *MSSA = nullptr;
346 if (MSSAAnalysis)
347 MSSA = &MSSAAnalysis->getMSSA();
348 return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
349}
350
353 LPMUpdater &) {
354 if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA))
355 return PreservedAnalyses::all();
356
358 if (AR.MSSA)
359 PA.preserve<MemorySSAAnalysis>();
360 return PA;
361}
362
363char LoopTermFold::ID = 0;
364
365INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
366 false, false)
371INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
372INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
374
375Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
@ PostInc
early cse Early CSE w MemorySSA
This header provides classes for managing per-loop analyses.
static std::optional< std::tuple< PHINode *, PHINode *, const SCEV *, bool > > canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, const LoopInfo &LI, const TargetTransformInfo &TTI)
static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI, TargetLibraryInfo &TLI, MemorySSA *MSSA)
This file exposes an interface to building/using memory SSA to walk memory instructions using a use/d...
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
This pass exposes codegen information to IR-level passes.
Value * RHS
Value * LHS
Represent the analysis usage information of a pass.
LLVM_ABI AnalysisUsage & addRequiredID(const void *ID)
Definition Pass.cpp:284
AnalysisUsage & addPreservedID(const void *ID)
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
Conditional or Unconditional Branch instruction.
void setCondition(Value *V)
LLVM_ABI void swapSuccessors()
Swap the successors of this branch instruction.
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:321
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:164
This instruction compares its operands according to the predicate given to the constructor.
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:2442
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2788
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI bool hasPoisonGeneratingFlags() const LLVM_READONLY
Return true if this operator has flags which may cause this instruction to evaluate to poison despite...
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
An analysis that produces MemorySSA for a function.
Definition MemorySSA.h:936
Encapsulates MemorySSA, including all data associated with memory accesses.
Definition MemorySSA.h:702
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
Definition Pass.h:99
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This class uses information about analyze scalars to rewrite expressions in canonical form.
LLVM_ABI bool isSafeToExpand(const SCEV *S) const
Return true if the given expression is safe to expand in the sense that all materialized values are s...
bool isHighCostExpansion(ArrayRef< const SCEV * > Exprs, Loop *L, unsigned Budget, const TargetTransformInfo *TTI, const Instruction *At)
Return true for expressions that can't be evaluated at runtime within given Budget.
void clear()
Erase the contents of the InsertedExpressions map so that users trying to expand the same expression ...
LLVM_ABI Value * expandCodeFor(const SCEV *SH, Type *Ty, BasicBlock::iterator I)
Insert code to directly compute the specified SCEV expression into the program.
This class represents an analyzed expression in the program.
The main scalar evolution driver.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
Provides information about what library functions are available for the current target.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
const ParentTy * getParent() const
Definition ilist_node.h:34
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, Instruction *OnPathTo, DominatorTree *DT)
Return true if undefined behavior would provably be executed on the path to OnPathTo if Root produced...
LLVM_ABI std::optional< unsigned > getLoopEstimatedTripCount(Loop *L, unsigned *EstimatedLoopInvocationWeight=nullptr)
Return either:
LLVM_ABI void initializeLoopTermFoldPass(PassRegistry &)
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI char & LoopSimplifyID
LLVM_ABI Pass * createLoopTermFoldPass()
AnalysisManager< Loop, LoopStandardAnalysisResults & > LoopAnalysisManager
The loop analysis manager.
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
LLVM_ABI bool DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
Examine each PHI in the given block and delete it if it is dead.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI cl::opt< unsigned > SCEVCheapExpansionBudget
TargetTransformInfo TTI
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
LLVM_ABI PreservedAnalyses getLoopPassPreservedAnalyses()
Returns the minimum set of Analyses that all loop passes must preserve.
LLVM_ABI bool isAlmostDeadIV(PHINode *IV, BasicBlock *LatchBlock, Value *Cond)
Return true if the induction variable IV in a Loop whose latch is LatchBlock would become dead if the...
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...