LLVM 20.0.0git
BottomUpVec.cpp
Go to the documentation of this file.
1//===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17
18namespace llvm::sandboxir {
19
21 : FunctionPass("bottom-up-vec"),
22 RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
23
24// TODO: This is a temporary function that returns some seeds.
25// Replace this with SeedCollector's function when it lands.
28 for (auto &I : BB)
29 if (auto *SI = llvm::dyn_cast<StoreInst>(&I))
30 Seeds.push_back(SI);
31 return Seeds;
32}
33
35 unsigned OpIdx) {
37 for (Value *BndlV : Bndl) {
38 auto *BndlI = cast<Instruction>(BndlV);
39 Operands.push_back(BndlI->getOperand(OpIdx));
40 }
41 return Operands;
42}
43
46 // TODO: Use the VecUtils function for getting the bottom instr once it lands.
47 auto *BotI = cast<Instruction>(
48 *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
49 return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
50 }));
51 // If Bndl contains Arguments or Constants, use the beginning of the BB.
52 return std::next(BotI->getIterator());
53}
54
// Creates the single vector instruction that stands in for the scalar bundle
// `Bndl`. The kind of instruction is chosen from the opcode of Bndl[0], and the
// result type is Bndl[0]'s element type widened to the bundle's lane count.
// NOTE(review): `Operands` and `WhereIt` are used below, but the listing lines
// that introduce them (56 and 65) are missing from this copy of the file.
55Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
// Creating a vector instruction modifies the IR, so raise the member flag that
// tryVectorize() returns.
57 Change = true;
58 assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
59 "Expect Instructions!");
60 auto &Ctx = Bndl[0]->getContext();
61
// Widen the element type of the first bundle member to one lane per lane of the
// whole bundle.
62 Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
63 auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
64
66
// Only the first member's opcode is inspected here.
67 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
68 switch (Opcode) {
69 case Instruction::Opcode::ZExt:
70 case Instruction::Opcode::SExt:
71 case Instruction::Opcode::FPToUI:
72 case Instruction::Opcode::FPToSI:
73 case Instruction::Opcode::FPExt:
74 case Instruction::Opcode::PtrToInt:
75 case Instruction::Opcode::IntToPtr:
76 case Instruction::Opcode::SIToFP:
77 case Instruction::Opcode::UIToFP:
78 case Instruction::Opcode::Trunc:
79 case Instruction::Opcode::FPTrunc:
80 case Instruction::Opcode::BitCast: {
// All casts share one path: same opcode, destination type widened to VecTy.
81 assert(Operands.size() == 1u && "Casts are unary!");
82 return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
83 }
84 case Instruction::Opcode::FCmp:
85 case Instruction::Opcode::ICmp: {
// The predicate of the first compare is used for the vector compare.
86 auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
// NOTE(review): the opening of this assert (listing line 87) is missing from
// this copy; the visible part checks that bundle members share `Pred`.
88 [Pred](auto *SBV) {
89 return cast<CmpInst>(SBV)->getPredicate() == Pred;
90 }) &&
91 "Expected same predicate across bundle.");
92 return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
93 "VCmp");
94 }
95 case Instruction::Opcode::Select: {
// Operand order is condition, true value, false value.
96 return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
97 Ctx, "Vec");
98 }
99 case Instruction::Opcode::FNeg: {
// Flags are copied from the first scalar only (UOp0 is the CopyFrom argument).
100 auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
101 auto OpC = UOp0->getOpcode();
102 return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
103 Ctx, "Vec");
104 }
105 case Instruction::Opcode::Add:
106 case Instruction::Opcode::FAdd:
107 case Instruction::Opcode::Sub:
108 case Instruction::Opcode::FSub:
109 case Instruction::Opcode::Mul:
110 case Instruction::Opcode::FMul:
111 case Instruction::Opcode::UDiv:
112 case Instruction::Opcode::SDiv:
113 case Instruction::Opcode::FDiv:
114 case Instruction::Opcode::URem:
115 case Instruction::Opcode::SRem:
116 case Instruction::Opcode::FRem:
117 case Instruction::Opcode::Shl:
118 case Instruction::Opcode::LShr:
119 case Instruction::Opcode::AShr:
120 case Instruction::Opcode::And:
121 case Instruction::Opcode::Or:
122 case Instruction::Opcode::Xor: {
// As with FNeg, flags are copied from the first scalar only (BinOp0).
123 auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
124 auto *LHS = Operands[0];
125 auto *RHS = Operands[1];
126 return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
127 BinOp0, WhereIt, Ctx, "Vec");
128 }
129 case Instruction::Opcode::Load: {
// The vector load reuses the pointer and alignment of the first scalar load;
// `Operands` is not consulted in this case.
130 auto *Ld0 = cast<LoadInst>(Bndl[0]);
131 Value *Ptr = Ld0->getPointerOperand();
132 return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
133 }
134 case Instruction::Opcode::Store: {
// Operands[0] is the value to store and Operands[1] the pointer; the alignment
// is taken from the first scalar store.
135 auto Align = cast<StoreInst>(Bndl[0])->getAlign();
136 Value *Val = Operands[0];
137 Value *Ptr = Operands[1];
138 return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
139 }
// The opcodes listed explicitly below are not handled yet; they and every
// other opcode end in llvm_unreachable.
140 case Instruction::Opcode::Br:
141 case Instruction::Opcode::Ret:
142 case Instruction::Opcode::PHI:
143 case Instruction::Opcode::AddrSpaceCast:
144 case Instruction::Opcode::Call:
145 case Instruction::Opcode::GetElementPtr:
146 llvm_unreachable("Unimplemented");
147 break;
148 default:
149 llvm_unreachable("Unimplemented");
150 break;
151 }
// Every handled case returns from inside the switch, so control is not
// expected to get here.
152 llvm_unreachable("Missing switch case!");
153 // TODO: Propagate debug info.
154}
155
156void BottomUpVec::tryEraseDeadInstrs() {
157 // Visiting the dead instructions bottom-to-top.
158 sort(DeadInstrCandidates,
159 [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
160 for (Instruction *I : reverse(DeadInstrCandidates)) {
161 if (I->hasNUses(0))
162 I->eraseFromParent();
163 }
164 DeadInstrCandidates.clear();
165}
166
// Builds one vector out of the values in `ToPack` by chaining insertelement
// operations onto an initial poison vector, and returns the last value of that
// chain. Vector-typed elements are first split into their lanes with
// extractelement.
// NOTE(review): `WhereIt` is used below, but the listing line that introduces
// it (168) is missing from this copy of the file.
167Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
169
// The result type is the bundle's common scalar type widened to the total
// number of lanes in the bundle.
170 Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
171 unsigned Lanes = VecUtils::getNumLanes(ToPack);
172 Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
173
174 // Create a series of pack instructions.
175 Value *LastInsert = PoisonValue::get(VecTy);
176
177 Context &Ctx = ToPack[0]->getContext();
178
// Next lane of the result vector to fill; advances once per inserted lane.
179 unsigned InsertIdx = 0;
180 for (Value *Elm : ToPack) {
181 // An element can be either scalar or vector. We need to generate different
182 // IR for each case.
183 if (Elm->getType()->isVectorTy()) {
184 unsigned NumElms =
185 cast<FixedVectorType>(Elm->getType())->getNumElements();
186 for (auto ExtrLane : seq<int>(0, NumElms)) {
187 // We generate extract-insert pairs, for each lane in `Elm`.
188 Constant *ExtrLaneC =
// NOTE(review): the initializer of ExtrLaneC (listing line 189) is missing
// from this copy of the file.
190 // This may return a Constant if Elm is a Constant.
191 auto *ExtrI =
192 ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
// Keep emitting right after whatever instruction was just created.
193 if (!isa<Constant>(ExtrI))
194 WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
195 Constant *InsertLaneC =
196 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
197 // This may also return a Constant if ExtrI is a Constant.
198 auto *InsertI = InsertElementInst::create(
199 LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
// NOTE(review): when InsertI is a Constant, LastInsert keeps its previous
// value, so this lane is not reflected in the returned vector. The scalar
// branch below always assigns LastInsert. Confirm this is intended.
200 if (!isa<Constant>(InsertI)) {
201 LastInsert = InsertI;
202 WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
203 }
204 }
205 } else {
206 Constant *InsertLaneC =
207 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
208 // This may be folded into a Constant if LastInsert is a Constant. In
209 // that case we only collect the last constant.
210 LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
211 WhereIt, Ctx, "Pack");
// Advance the insertion point only if an instruction was actually emitted.
212 if (auto *NewI = dyn_cast<Instruction>(LastInsert))
213 WhereIt = std::next(NewI->getIterator());
214 }
215 }
216 return LastInsert;
217}
218
// Recursively vectorizes the bundle `Bndl`, operands first, and returns the
// value that replaces it, or nullptr. `Depth` is 0 for the seed bundle and
// grows by one per operand level.
// NOTE(review): both `case` labels of the outer switch (listing lines 223 and
// 255) are missing from this copy of the file, so which legality result
// selects which branch is not visible here.
219Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
220 Value *NewVec = nullptr;
221 const auto &LegalityRes = Legality->canVectorize(Bndl);
222 switch (LegalityRes.getSubclassID()) {
// First branch: vectorize the operands, then create one vector instruction
// for the whole bundle.
224 auto *I = cast<Instruction>(Bndl[0]);
225 SmallVector<Value *, 2> VecOperands;
226 switch (I->getOpcode()) {
227 case Instruction::Opcode::Load:
228 // Don't recurse towards the pointer operand.
229 VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
230 break;
231 case Instruction::Opcode::Store: {
232 // Don't recurse towards the pointer operand.
// Only operand 0 is vectorized; the pointer is taken from the first store of
// the bundle.
233 auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
234 VecOperands.push_back(VecOp);
235 VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
236 break;
237 }
238 default:
239 // Visit all operands.
// NOTE(review): VecOp is pushed without a nullptr check; confirm that a
// recursive call at Depth > 0 cannot return nullptr.
240 for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
241 auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
242 VecOperands.push_back(VecOp);
243 }
244 break;
245 }
246 NewVec = createVectorInstr(Bndl, VecOperands);
247
248 // Collect the original scalar instructions as they may be dead.
249 if (NewVec != nullptr) {
250 for (Value *V : Bndl)
251 DeadInstrCandidates.push_back(cast<Instruction>(V));
252 }
253 break;
254 }
// Second branch: the bundle is not turned into a single vector instruction.
// Below the seeds, its values are packed into a vector instead.
256 // If we can't vectorize the seeds then just return.
257 if (Depth == 0)
258 return nullptr;
259 NewVec = createPack(Bndl);
260 break;
261 }
262 }
// Still nullptr if neither visible branch of the switch ran.
263 return NewVec;
264}
265
266bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
267 DeadInstrCandidates.clear();
268 vectorizeRec(Bndl, /*Depth=*/0);
269 tryEraseDeadInstrs();
270 return Change;
271}
272
274 Legality = std::make_unique<LegalityAnalysis>(
275 A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
276 F.getContext());
277 Change = false;
278 // TODO: Start from innermost BBs first
279 for (auto &BB : F) {
280 // TODO: Replace with proper SeedCollector function.
281 auto Seeds = collectSeeds(BB);
282 // TODO: Slice Seeds into smaller chunks.
283 // TODO: If vectorization succeeds, run the RegionPassManager on the
284 // resulting region.
285 if (Seeds.size() >= 2)
286 Change |= tryVectorize(Seeds);
287 }
288 return Change;
289}
290
291} // namespace llvm::sandboxir
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
Value * RHS
Value * LHS
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
iterator end() const
Definition: ArrayRef.h:157
iterator begin() const
Definition: ArrayRef.h:156
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:177
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
LLVM Value Representation.
Definition: Value.h:74
Contains a list of sandboxir::Instruction's.
Definition: BasicBlock.h:67
static Value * createWithCopiedFlags(Instruction::Opcode Op, Value *LHS, Value *RHS, Value *CopyFrom, InsertPosition Pos, Context &Ctx, const Twine &Name="")
BottomUpVec(StringRef Pipeline)
Definition: BottomUpVec.cpp:20
bool runOnFunction(Function &F, const Analyses &A) final
\Returns true if it modifies F.
static Value * create(Type *DestTy, Opcode Op, Value *Operand, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static CmpInst * create(Predicate Pred, Value *S1, Value *S2, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static ConstantInt * getSigned(IntegerType *Ty, int64_t V)
Return a ConstantInt with the specified value for the specified type.
Definition: Constant.cpp:57
static Value * create(Value *Vec, Value *Idx, InsertPosition Pos, Context &Ctx, const Twine &Name="")
A pass that runs on a sandbox::Function.
Definition: Pass.h:71
static Value * create(Value *Vec, Value *NewElt, Value *Idx, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static LoadInst * create(Type *Ty, Value *Ptr, MaybeAlign Align, InsertPosition Pos, bool IsVolatile, Context &Ctx, const Twine &Name="")
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constant.cpp:238
static Value * create(Value *Cond, Value *True, Value *False, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static StoreInst * create(Value *V, Value *Ptr, MaybeAlign Align, InsertPosition Pos, bool IsVolatile, Context &Ctx)
static Type * getInt32Ty(Context &Ctx)
static Value * createWithCopiedFlags(Instruction::Opcode Op, Value *OpV, Value *CopyFrom, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static Type * getExpectedType(const Value *V)
\Returns the expected type of Value V.
Definition: Utils.h:30
A SandboxIR Value has users. This is the base class.
Definition: Value.h:63
static Type * getCommonScalarType(ArrayRef< Value * > Bndl)
Similar to tryGetCommonScalarType() but will assert that there is a common type.
Definition: VecUtils.h:129
static unsigned getNumLanes(Type *Ty)
\Returns the number of vector lanes of Ty or 1 if not a vector.
Definition: VecUtils.h:72
static Type * getWideType(Type *ElemTy, unsigned NumElts)
\Returns <NumElts x ElemTy>.
Definition: VecUtils.h:95
static Type * getElementType(Type *Ty)
Returns Ty if scalar or its element type if vector.
Definition: VecUtils.h:32
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
Type
MessagePack types as defined in the standard, with the exception of Integer being divided into a sign...
Definition: MsgPackReader.h:53
@ Widen
Collect scalar values.
static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef< Value * > Instrs)
Definition: BottomUpVec.cpp:45
static SmallVector< Value *, 4 > getOperand(ArrayRef< Value * > Bndl, unsigned OpIdx)
Definition: BottomUpVec.cpp:34
static llvm::SmallVector< Value *, 4 > collectSeeds(BasicBlock &BB)
Definition: BottomUpVec.cpp:26
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1664