LLVM 23.0.0git
ExpandReductions.cpp
Go to the documentation of this file.
//===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for reduction intrinsics, allowing targets
// to enable the intrinsics until just before codegen.
//
//===----------------------------------------------------------------------===//
13
#include "llvm/CodeGen/ExpandReductions.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;
28
29namespace {
30
31bool expandReductions(Function &F, const TargetTransformInfo *TTI,
32 DominatorTree *DT, LoopInfo *LI) {
33 bool Changed = false;
35 for (auto &I : instructions(F)) {
36 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
37 switch (II->getIntrinsicID()) {
38 default: break;
39 case Intrinsic::vector_reduce_fadd:
40 case Intrinsic::vector_reduce_fmul:
41 case Intrinsic::vector_reduce_add:
42 case Intrinsic::vector_reduce_mul:
43 case Intrinsic::vector_reduce_and:
44 case Intrinsic::vector_reduce_or:
45 case Intrinsic::vector_reduce_xor:
46 case Intrinsic::vector_reduce_smax:
47 case Intrinsic::vector_reduce_smin:
48 case Intrinsic::vector_reduce_umax:
49 case Intrinsic::vector_reduce_umin:
50 case Intrinsic::vector_reduce_fmax:
51 case Intrinsic::vector_reduce_fmin:
52 if (TTI->shouldExpandReduction(II))
53 Worklist.push_back(II);
54
55 break;
56 }
57 }
58 }
59
60 for (auto *II : Worklist) {
61 FastMathFlags FMF =
62 isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
63 Intrinsic::ID ID = II->getIntrinsicID();
66 TTI->getPreferredExpandedReductionShuffle(II);
67
68 Value *Rdx = nullptr;
69 IRBuilder<> Builder(II);
70 IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
71 Builder.setFastMathFlags(FMF);
72 switch (ID) {
73 default: llvm_unreachable("Unexpected intrinsic!");
74 case Intrinsic::vector_reduce_fadd:
75 case Intrinsic::vector_reduce_fmul: {
76 // FMFs must be attached to the call, otherwise it's an ordered reduction
77 // and it can't be handled by generating a shuffle sequence.
78 Value *Acc = II->getArgOperand(0);
79 Value *Vec = II->getArgOperand(1);
80 unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
81 if (isa<ScalableVectorType>(Vec->getType())) {
82 Rdx = expandReductionViaLoop(Builder, Vec, RdxOpcode, Acc, DT, LI);
83 break;
84 }
85 if (!FMF.allowReassoc())
86 Rdx = getOrderedReduction(Builder, Acc, Vec, RdxOpcode, RK);
87 else {
88 if (!isPowerOf2_32(
89 cast<FixedVectorType>(Vec->getType())->getNumElements()))
90 continue;
91 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
92 Rdx = Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, Acc, Rdx,
93 "bin.rdx");
94 }
95 break;
96 }
97 case Intrinsic::vector_reduce_and:
98 case Intrinsic::vector_reduce_or: {
99 // Canonicalize logical or/and reductions:
100 // Or reduction for i1 is represented as:
101 // %val = bitcast <ReduxWidth x i1> to iReduxWidth
102 // %res = cmp ne iReduxWidth %val, 0
103 // And reduction for i1 is represented as:
104 // %val = bitcast <ReduxWidth x i1> to iReduxWidth
105 // %res = cmp eq iReduxWidth %val, 11111
106 Value *Vec = II->getArgOperand(0);
107 auto *FTy = cast<FixedVectorType>(Vec->getType());
108 unsigned NumElts = FTy->getNumElements();
109 if (!isPowerOf2_32(NumElts))
110 continue;
111
112 if (FTy->getElementType() == Builder.getInt1Ty()) {
113 Rdx = Builder.CreateBitCast(Vec, Builder.getIntNTy(NumElts));
114 if (ID == Intrinsic::vector_reduce_and) {
115 Rdx = Builder.CreateICmpEQ(
117 } else {
118 assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction.");
119 Rdx = Builder.CreateIsNotNull(Rdx);
120 }
121 break;
122 }
123 unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
124 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
125 break;
126 }
127 case Intrinsic::vector_reduce_add:
128 case Intrinsic::vector_reduce_mul:
129 case Intrinsic::vector_reduce_xor:
130 case Intrinsic::vector_reduce_smax:
131 case Intrinsic::vector_reduce_smin:
132 case Intrinsic::vector_reduce_umax:
133 case Intrinsic::vector_reduce_umin: {
134 Value *Vec = II->getArgOperand(0);
135 unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
136 if (isa<ScalableVectorType>(Vec->getType())) {
137 Type *EltTy = Vec->getType()->getScalarType();
138 Value *Ident = getReductionIdentity(ID, EltTy, FMF);
139 Rdx = expandReductionViaLoop(Builder, Vec, RdxOpcode, Ident, DT, LI);
140 break;
141 }
142 if (!isPowerOf2_32(
143 cast<FixedVectorType>(Vec->getType())->getNumElements()))
144 continue;
145 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
146 break;
147 }
148 case Intrinsic::vector_reduce_fmax:
149 case Intrinsic::vector_reduce_fmin: {
150 // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
151 // semantics of the reduction.
152 Value *Vec = II->getArgOperand(0);
153 if (!isPowerOf2_32(
154 cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
155 !FMF.noNaNs())
156 continue;
157 unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
158 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
159 break;
160 }
161 }
162 II->replaceAllUsesWith(Rdx);
163 II->eraseFromParent();
164 Changed = true;
165 }
166 return Changed;
167}
168
169class ExpandReductions : public FunctionPass {
170public:
171 static char ID;
172 ExpandReductions() : FunctionPass(ID) {}
173
174 bool runOnFunction(Function &F) override {
175 const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
176 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
177 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
178 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
179 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
180 return expandReductions(F, TTI, DT, LI);
181 }
182
183 void getAnalysisUsage(AnalysisUsage &AU) const override {
184 AU.addRequired<TargetTransformInfoWrapperPass>();
185 AU.addPreserved<DominatorTreeWrapperPass>();
186 AU.addPreserved<LoopInfoWrapperPass>();
187 }
188};
189}
190
191char ExpandReductions::ID;
192INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
193 "Expand reduction intrinsics", false, false)
195INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
196 "Expand reduction intrinsics", false, false)
197
199 return new ExpandReductions();
200}
201
204 const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
206 auto *LI = AM.getCachedResult<LoopAnalysis>(F);
207 if (!expandReductions(F, &TTI, DT, LI))
208 return PreservedAnalyses::all();
212 return PA;
213}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Expand Atomic instructions
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
This pass exposes codegen information to IR-level passes.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
static LLVM_ABI Constant * getAllOnesValue(Type *Ty)
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Convenience struct for specifying and reasoning about fast-math flags.
Definition FMF.h:23
bool allowReassoc() const
Flag queries.
Definition FMF.h:67
bool noNaNs() const
Definition FMF.h:68
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2835
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetTransformInfo.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition Type.h:370
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI Value * getReductionIdentity(Intrinsic::ID RdxID, Type *Ty, FastMathFlags FMF)
Given information about an @llvm.vector.reduce.
LLVM_ABI unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID)
Returns the arithmetic instruction opcode used when expanding a reduction.
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition MathExtras.h:279
LLVM_ABI Value * getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op, TargetTransformInfo::ReductionShuffle RS, RecurKind MinMaxKind=RecurKind::None)
Generates a vector reduction using shufflevectors to reduce the value.
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
TargetTransformInfo TTI
RecurKind
These are the kinds of recurrences that we support.
LLVM_ABI FunctionPass * createExpandReductionsPass()
This pass expands the reduction intrinsics into sequences of shuffles.
LLVM_ABI Value * expandReductionViaLoop(IRBuilderBase &Builder, Value *Vec, unsigned RdxOpcode, Value *Acc, DominatorTree *DT=nullptr, LoopInfo *LI=nullptr)
Expand a scalable vector reduction into a runtime loop that applies RdxOpcode element by element,...
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
LLVM_ABI RecurKind getMinMaxReductionRecurKind(Intrinsic::ID RdxID)
Returns the recurence kind used when expanding a min/max reduction.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI Value * getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src, unsigned Op, RecurKind MinMaxKind=RecurKind::None)
Generates an ordered vector reduction using extracts to reduce the value.