LLVM 22.0.0git
EVLIndVarSimplify.cpp
Go to the documentation of this file.
1//===---- EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass optimizes a vectorized loop that was tail-folded by predicated EVL
10// by switching its exit condition from the canonical IV to the EVL-based IV.
11//
12//===----------------------------------------------------------------------===//
13
15#include "llvm/ADT/Statistic.h"
23#include "llvm/IR/IRBuilder.h"
26#include "llvm/Support/Debug.h"
31
32#define DEBUG_TYPE "evl-iv-simplify"
33
34using namespace llvm;
35
36STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated");
37
39 "enable-evl-indvar-simplify",
40 cl::desc("Enable EVL-based induction variable simplify Pass"), cl::Hidden,
41 cl::init(true));
42
43namespace {
44struct EVLIndVarSimplifyImpl {
46 OptimizationRemarkEmitter *ORE = nullptr;
47
48 EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR,
50 : SE(LAR.SE), ORE(ORE) {}
51
52 /// Returns true if modify the loop.
53 bool run(Loop &L);
54};
55} // anonymous namespace
56
57/// Returns the constant part of vectorization factor from the induction
58/// variable's step value SCEV expression.
59static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) {
60 if (!Step)
61 return 0U;
62
63 // Looking for loops with IV step value in the form of `(<constant VF> x
64 // vscale)`.
65 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Step)) {
66 if (Mul->getNumOperands() == 2) {
67 const SCEV *LHS = Mul->getOperand(0);
68 const SCEV *RHS = Mul->getOperand(1);
69 if (const auto *Const = dyn_cast<SCEVConstant>(LHS);
70 Const && isa<SCEVVScale>(RHS)) {
71 uint64_t V = Const->getAPInt().getLimitedValue();
72 if (llvm::isUInt<32>(V))
73 return V;
74 }
75 }
76 }
77
78 // If not, see if the vscale_range of the parent function is a fixed value,
79 // which makes the step value to be replaced by a constant.
80 if (F.hasFnAttribute(Attribute::VScaleRange))
81 if (const auto *ConstStep = dyn_cast<SCEVConstant>(Step)) {
82 APInt V = ConstStep->getAPInt().abs();
84 if (const APInt *Fixed = CR.getSingleElement()) {
85 V = V.zextOrTrunc(Fixed->getBitWidth());
86 uint64_t VF = V.udiv(*Fixed).getLimitedValue();
87 if (VF && llvm::isUInt<32>(VF) &&
88 // Make sure step is divisible by vscale.
89 V.urem(*Fixed).isZero())
90 return VF;
91 }
92 }
93
94 return 0U;
95}
96
97bool EVLIndVarSimplifyImpl::run(Loop &L) {
99 return false;
100
101 if (!getBooleanLoopAttribute(&L, "llvm.loop.isvectorized"))
102 return false;
103 const MDOperand *EVLMD =
104 findStringMetadataForLoop(&L, "llvm.loop.isvectorized.tailfoldingstyle")
105 .value_or(nullptr);
106 if (!EVLMD || !EVLMD->equalsStr("evl"))
107 return false;
108
109 BasicBlock *LatchBlock = L.getLoopLatch();
110 ICmpInst *OrigLatchCmp = L.getLatchCmpInst();
111 if (!LatchBlock || !OrigLatchCmp)
112 return false;
113
115 PHINode *IndVar = L.getInductionVariable(SE);
116 if (!IndVar || !L.getInductionDescriptor(SE, IVD)) {
117 const char *Reason = (IndVar ? "induction descriptor is not available"
118 : "cannot recognize induction variable");
119 LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName()
120 << " because" << Reason << "\n");
121 if (ORE) {
122 ORE->emit([&]() {
123 return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar",
124 L.getStartLoc(), L.getHeader())
125 << "Cannot retrieve IV because " << ore::NV("Reason", Reason);
126 });
127 }
128 return false;
129 }
130
131 BasicBlock *InitBlock, *BackEdgeBlock;
132 if (!L.getIncomingAndBackEdge(InitBlock, BackEdgeBlock)) {
133 LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in "
134 << L.getName() << "\n");
135 if (ORE) {
136 ORE->emit([&]() {
137 return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure",
138 L.getStartLoc(), L.getHeader())
139 << "Does not have a unique incoming and backedge";
140 });
141 }
142 return false;
143 }
144
145 // Retrieve the loop bounds.
146 std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE);
147 if (!Bounds) {
148 LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName()
149 << "\n");
150 if (ORE) {
151 ORE->emit([&]() {
152 return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure",
153 L.getStartLoc(), L.getHeader())
154 << "Could not obtain the loop bounds";
155 });
156 }
157 return false;
158 }
159 Value *CanonicalIVInit = &Bounds->getInitialIVValue();
160 Value *CanonicalIVFinal = &Bounds->getFinalIVValue();
161
162 const SCEV *StepV = IVD.getStep();
163 uint32_t VF = getVFFromIndVar(StepV, *L.getHeader()->getParent());
164 if (!VF) {
165 LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV
166 << "'\n");
167 if (ORE) {
168 ORE->emit([&]() {
169 return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar",
170 L.getStartLoc(), L.getHeader())
171 << "Could not infer VF from IndVar step "
172 << ore::NV("Step", StepV);
173 });
174 }
175 return false;
176 }
177 LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName()
178 << "\n");
179
180 // Try to find the EVL-based induction variable.
181 using namespace PatternMatch;
182 BasicBlock *BB = IndVar->getParent();
183
184 Value *EVLIndVar = nullptr;
185 Value *RemTC = nullptr;
186 Value *TC = nullptr;
187 auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
188 m_Value(RemTC), m_SpecificInt(VF),
189 /*Scalable=*/m_SpecificInt(1));
190 for (PHINode &PN : BB->phis()) {
191 if (&PN == IndVar)
192 continue;
193
194 // Check 1: it has to contain both incoming (init) & backedge blocks
195 // from IndVar.
196 if (PN.getBasicBlockIndex(InitBlock) < 0 ||
197 PN.getBasicBlockIndex(BackEdgeBlock) < 0)
198 continue;
199 // Check 2: EVL index is always increasing, thus its inital value has to be
200 // equal to either the initial IV value (when the canonical IV is also
201 // increasing) or the last IV value (when canonical IV is decreasing).
202 Value *Init = PN.getIncomingValueForBlock(InitBlock);
204 switch (Bounds->getDirection()) {
205 case Direction::Increasing:
206 if (Init != CanonicalIVInit)
207 continue;
208 break;
209 case Direction::Decreasing:
210 if (Init != CanonicalIVFinal)
211 continue;
212 break;
213 case Direction::Unknown:
214 // To be more permissive and see if either the initial or final IV value
215 // matches PN's init value.
216 if (Init != CanonicalIVInit && Init != CanonicalIVFinal)
217 continue;
218 break;
219 }
220 Value *RecValue = PN.getIncomingValueForBlock(BackEdgeBlock);
221 assert(RecValue && "expect recurrent IndVar value");
222
223 LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN
224 << "\n");
225
226 // Check 3: Pattern match to find the EVL-based index and total trip count
227 // (TC).
228 if (match(RecValue,
229 m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) &&
230 match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) {
231 EVLIndVar = RecValue;
232 break;
233 }
234 }
235
236 if (!EVLIndVar || !TC)
237 return false;
238
239 LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");
240 if (ORE) {
241 ORE->emit([&]() {
242 DebugLoc DL;
243 BasicBlock *Region = nullptr;
244 if (auto *I = dyn_cast<Instruction>(EVLIndVar)) {
245 DL = I->getDebugLoc();
246 Region = I->getParent();
247 } else {
248 DL = L.getStartLoc();
249 Region = L.getHeader();
250 }
251 return OptimizationRemark(DEBUG_TYPE, "UseEVLIndVar", DL, Region)
252 << "Using " << ore::NV("EVLIndVar", EVLIndVar)
253 << " for EVL-based IndVar";
254 });
255 }
256
257 // Create an EVL-based comparison and replace the branch to use it as
258 // predicate.
259
260 // Loop::getLatchCmpInst check at the beginning of this function has ensured
261 // that latch block ends in a conditional branch.
262 auto *LatchBranch = cast<BranchInst>(LatchBlock->getTerminator());
263 assert(LatchBranch->isConditional() &&
264 "expect the loop latch to be ended with a conditional branch");
266 if (LatchBranch->getSuccessor(0) == L.getHeader())
267 Pred = ICmpInst::ICMP_NE;
268 else
269 Pred = ICmpInst::ICMP_EQ;
270
271 IRBuilder<> Builder(OrigLatchCmp);
272 auto *NewLatchCmp = Builder.CreateICmp(Pred, EVLIndVar, TC);
273 OrigLatchCmp->replaceAllUsesWith(NewLatchCmp);
274
275 // llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are
276 // not used outside the cycles. However, in this case the now-RAUW-ed
277 // OrigLatchCmp will be considered a use outside the cycle while in reality
278 // it's practically dead. Thus we need to remove it before calling
279 // RecursivelyDeleteDeadPHINode.
282 LLVM_DEBUG(dbgs() << "Removed original IndVar\n");
283
284 ++NumEliminatedCanonicalIV;
285
286 return true;
287}
288
291 LPMUpdater &U) {
292 Function &F = *L.getHeader()->getParent();
293 auto &FAMProxy = LAM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR);
295 FAMProxy.getCachedResult<OptimizationRemarkEmitterAnalysis>(F);
296
297 if (EVLIndVarSimplifyImpl(AR, ORE).run(L))
298 return PreservedAnalyses::allInSet<CFGAnalyses>();
299 return PreservedAnalyses::all();
300}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static cl::opt< bool > EnableEVLIndVarSimplify("enable-evl-indvar-simplify", cl::desc("Enable EVL-based induction variable simplify Pass"), cl::Hidden, cl::init(true))
static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F)
Returns the constant part of vectorization factor from the induction variable's step value SCEV expression.
#define DEBUG_TYPE
This header provides classes for managing a pipeline of passes over loops in LLVM IR.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
LoopAnalysisManager LAM
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
#define LLVM_DEBUG(...)
Definition: Debug.h:119
Value * RHS
Value * LHS
BinaryOperator * Mul
Class for arbitrary precision integers.
Definition: APInt.h:78
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:255
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:412
LLVM Basic Block Representation.
Definition: BasicBlock.h:62
iterator_range< const_phi_iterator > phis() const
Returns a range that iterates over the phis in the basic block.
Definition: BasicBlock.h:528
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
Definition: BasicBlock.h:233
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:678
This class represents a range of values.
Definition: ConstantRange.h:47
const APInt * getSingleElement() const
If this set contains a single element, return it, otherwise return null.
A debug info location.
Definition: DebugLoc.h:124
This instruction compares its operands according to the predicate given to the constructor.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2780
A struct for saving information about induction variables.
const SCEV * getStep() const
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:40
Tracking metadata reference owned by Metadata.
Definition: Metadata.h:899
The optimization diagnostic interface.
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
An analysis over an "inner" IR unit that provides access to an analysis manager over a "outer" IR uni...
Definition: PassManager.h:716
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:118
This class represents an analyzed expression in the program.
The main scalar evolution driver.
Value * getOperand(unsigned i) const
Definition: User.h:232
unsigned getNumOperands() const
Definition: User.h:254
LLVM Value Representation.
Definition: Value.h:75
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:546
const ParentTy * getParent() const
Definition: ilist_node.h:34
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
match_combine_or< CastInst_match< OpTy, ZExtInst >, OpTy > m_ZExtOrSelf(const OpTy &Op)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:962
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches a Add with LHS and RHS in either order.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::Sub > m_Sub(const LHS &L, const RHS &R)
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:444
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
DiagnosticInfoOptimizationBase::Argument NV
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
LLVM_ABI bool getBooleanLoopAttribute(const Loop *TheLoop, StringRef Name)
Returns true if Name is applied to TheLoop and enabled.
Definition: LoopInfo.cpp:1121
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:533
LLVM_ABI std::optional< const MDOperand * > findStringMetadataForLoop(const Loop *TheLoop, StringRef Name)
Find string metadata for loop.
Definition: LoopInfo.cpp:1089
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:207
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
LLVM_ABI bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
Definition: Local.cpp:641
PreservedAnalyses run(Loop &L, LoopAnalysisManager &LAM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...
Direction
An enum for the direction of the loop.
Definition: LoopInfo.h:217