File: | build/source/llvm/include/llvm/IR/PatternMatch.h |
Warning: | line 961, column 9 Called C++ object pointer is null |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- InstCombineSelect.cpp ----------------------------------------------===// | ||||
2 | // | ||||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
4 | // See https://llvm.org/LICENSE.txt for license information. | ||||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
6 | // | ||||
7 | //===----------------------------------------------------------------------===// | ||||
8 | // | ||||
9 | // This file implements the visitSelect function. | ||||
10 | // | ||||
11 | //===----------------------------------------------------------------------===// | ||||
12 | |||||
13 | #include "InstCombineInternal.h" | ||||
14 | #include "llvm/ADT/APInt.h" | ||||
15 | #include "llvm/ADT/STLExtras.h" | ||||
16 | #include "llvm/ADT/SmallVector.h" | ||||
17 | #include "llvm/Analysis/AssumptionCache.h" | ||||
18 | #include "llvm/Analysis/CmpInstAnalysis.h" | ||||
19 | #include "llvm/Analysis/InstructionSimplify.h" | ||||
20 | #include "llvm/Analysis/OverflowInstAnalysis.h" | ||||
21 | #include "llvm/Analysis/ValueTracking.h" | ||||
22 | #include "llvm/Analysis/VectorUtils.h" | ||||
23 | #include "llvm/IR/BasicBlock.h" | ||||
24 | #include "llvm/IR/Constant.h" | ||||
25 | #include "llvm/IR/ConstantRange.h" | ||||
26 | #include "llvm/IR/Constants.h" | ||||
27 | #include "llvm/IR/DerivedTypes.h" | ||||
28 | #include "llvm/IR/IRBuilder.h" | ||||
29 | #include "llvm/IR/InstrTypes.h" | ||||
30 | #include "llvm/IR/Instruction.h" | ||||
31 | #include "llvm/IR/Instructions.h" | ||||
32 | #include "llvm/IR/IntrinsicInst.h" | ||||
33 | #include "llvm/IR/Intrinsics.h" | ||||
34 | #include "llvm/IR/Operator.h" | ||||
35 | #include "llvm/IR/PatternMatch.h" | ||||
36 | #include "llvm/IR/Type.h" | ||||
37 | #include "llvm/IR/User.h" | ||||
38 | #include "llvm/IR/Value.h" | ||||
39 | #include "llvm/Support/Casting.h" | ||||
40 | #include "llvm/Support/ErrorHandling.h" | ||||
41 | #include "llvm/Support/KnownBits.h" | ||||
42 | #include "llvm/Transforms/InstCombine/InstCombiner.h" | ||||
43 | #include <cassert> | ||||
44 | #include <utility> | ||||
45 | |||||
46 | #define DEBUG_TYPE"instcombine" "instcombine" | ||||
47 | #include "llvm/Transforms/Utils/InstructionWorklist.h" | ||||
48 | |||||
49 | using namespace llvm; | ||||
50 | using namespace PatternMatch; | ||||
51 | |||||
52 | |||||
53 | /// Replace a select operand based on an equality comparison with the identity | ||||
54 | /// constant of a binop. | ||||
55 | static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, | ||||
56 | const TargetLibraryInfo &TLI, | ||||
57 | InstCombinerImpl &IC) { | ||||
58 | // The select condition must be an equality compare with a constant operand. | ||||
59 | Value *X; | ||||
60 | Constant *C; | ||||
61 | CmpInst::Predicate Pred; | ||||
62 | if (!match(Sel.getCondition(), m_Cmp(Pred, m_Value(X), m_Constant(C)))) | ||||
63 | return nullptr; | ||||
64 | |||||
65 | bool IsEq; | ||||
66 | if (ICmpInst::isEquality(Pred)) | ||||
67 | IsEq = Pred == ICmpInst::ICMP_EQ; | ||||
68 | else if (Pred == FCmpInst::FCMP_OEQ) | ||||
69 | IsEq = true; | ||||
70 | else if (Pred == FCmpInst::FCMP_UNE) | ||||
71 | IsEq = false; | ||||
72 | else | ||||
73 | return nullptr; | ||||
74 | |||||
75 | // A select operand must be a binop. | ||||
76 | BinaryOperator *BO; | ||||
77 | if (!match(Sel.getOperand(IsEq ? 1 : 2), m_BinOp(BO))) | ||||
78 | return nullptr; | ||||
79 | |||||
80 | // The compare constant must be the identity constant for that binop. | ||||
81 | // If this a floating-point compare with 0.0, any zero constant will do. | ||||
82 | Type *Ty = BO->getType(); | ||||
83 | Constant *IdC = ConstantExpr::getBinOpIdentity(BO->getOpcode(), Ty, true); | ||||
84 | if (IdC != C) { | ||||
85 | if (!IdC || !CmpInst::isFPPredicate(Pred)) | ||||
86 | return nullptr; | ||||
87 | if (!match(IdC, m_AnyZeroFP()) || !match(C, m_AnyZeroFP())) | ||||
88 | return nullptr; | ||||
89 | } | ||||
90 | |||||
91 | // Last, match the compare variable operand with a binop operand. | ||||
92 | Value *Y; | ||||
93 | if (!BO->isCommutative() && !match(BO, m_BinOp(m_Value(Y), m_Specific(X)))) | ||||
94 | return nullptr; | ||||
95 | if (!match(BO, m_c_BinOp(m_Value(Y), m_Specific(X)))) | ||||
96 | return nullptr; | ||||
97 | |||||
98 | // +0.0 compares equal to -0.0, and so it does not behave as required for this | ||||
99 | // transform. Bail out if we can not exclude that possibility. | ||||
100 | if (isa<FPMathOperator>(BO)) | ||||
101 | if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI)) | ||||
102 | return nullptr; | ||||
103 | |||||
104 | // BO = binop Y, X | ||||
105 | // S = { select (cmp eq X, C), BO, ? } or { select (cmp ne X, C), ?, BO } | ||||
106 | // => | ||||
107 | // S = { select (cmp eq X, C), Y, ? } or { select (cmp ne X, C), ?, Y } | ||||
108 | return IC.replaceOperand(Sel, IsEq ? 1 : 2, Y); | ||||
109 | } | ||||
110 | |||||
111 | /// This folds: | ||||
112 | /// select (icmp eq (and X, C1)), TC, FC | ||||
113 | /// iff C1 is a power 2 and the difference between TC and FC is a power-of-2. | ||||
114 | /// To something like: | ||||
115 | /// (shr (and (X, C1)), (log2(C1) - log2(TC-FC))) + FC | ||||
116 | /// Or: | ||||
117 | /// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC | ||||
118 | /// With some variations depending if FC is larger than TC, or the shift | ||||
119 | /// isn't needed, or the bit widths don't match. | ||||
120 | static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, | ||||
121 | InstCombiner::BuilderTy &Builder) { | ||||
122 | const APInt *SelTC, *SelFC; | ||||
123 | if (!match(Sel.getTrueValue(), m_APInt(SelTC)) || | ||||
124 | !match(Sel.getFalseValue(), m_APInt(SelFC))) | ||||
125 | return nullptr; | ||||
126 | |||||
127 | // If this is a vector select, we need a vector compare. | ||||
128 | Type *SelType = Sel.getType(); | ||||
129 | if (SelType->isVectorTy() != Cmp->getType()->isVectorTy()) | ||||
130 | return nullptr; | ||||
131 | |||||
132 | Value *V; | ||||
133 | APInt AndMask; | ||||
134 | bool CreateAnd = false; | ||||
135 | ICmpInst::Predicate Pred = Cmp->getPredicate(); | ||||
136 | if (ICmpInst::isEquality(Pred)) { | ||||
137 | if (!match(Cmp->getOperand(1), m_Zero())) | ||||
138 | return nullptr; | ||||
139 | |||||
140 | V = Cmp->getOperand(0); | ||||
141 | const APInt *AndRHS; | ||||
142 | if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) | ||||
143 | return nullptr; | ||||
144 | |||||
145 | AndMask = *AndRHS; | ||||
146 | } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1), | ||||
147 | Pred, V, AndMask)) { | ||||
148 | assert(ICmpInst::isEquality(Pred) && "Not equality test?")(static_cast <bool> (ICmpInst::isEquality(Pred) && "Not equality test?") ? void (0) : __assert_fail ("ICmpInst::isEquality(Pred) && \"Not equality test?\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 148 , __extension__ __PRETTY_FUNCTION__)); | ||||
149 | if (!AndMask.isPowerOf2()) | ||||
150 | return nullptr; | ||||
151 | |||||
152 | CreateAnd = true; | ||||
153 | } else { | ||||
154 | return nullptr; | ||||
155 | } | ||||
156 | |||||
157 | // In general, when both constants are non-zero, we would need an offset to | ||||
158 | // replace the select. This would require more instructions than we started | ||||
159 | // with. But there's one special-case that we handle here because it can | ||||
160 | // simplify/reduce the instructions. | ||||
161 | APInt TC = *SelTC; | ||||
162 | APInt FC = *SelFC; | ||||
163 | if (!TC.isZero() && !FC.isZero()) { | ||||
164 | // If the select constants differ by exactly one bit and that's the same | ||||
165 | // bit that is masked and checked by the select condition, the select can | ||||
166 | // be replaced by bitwise logic to set/clear one bit of the constant result. | ||||
167 | if (TC.getBitWidth() != AndMask.getBitWidth() || (TC ^ FC) != AndMask) | ||||
168 | return nullptr; | ||||
169 | if (CreateAnd) { | ||||
170 | // If we have to create an 'and', then we must kill the cmp to not | ||||
171 | // increase the instruction count. | ||||
172 | if (!Cmp->hasOneUse()) | ||||
173 | return nullptr; | ||||
174 | V = Builder.CreateAnd(V, ConstantInt::get(SelType, AndMask)); | ||||
175 | } | ||||
176 | bool ExtraBitInTC = TC.ugt(FC); | ||||
177 | if (Pred == ICmpInst::ICMP_EQ) { | ||||
178 | // If the masked bit in V is clear, clear or set the bit in the result: | ||||
179 | // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) ^ TC | ||||
180 | // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) | TC | ||||
181 | Constant *C = ConstantInt::get(SelType, TC); | ||||
182 | return ExtraBitInTC ? Builder.CreateXor(V, C) : Builder.CreateOr(V, C); | ||||
183 | } | ||||
184 | if (Pred == ICmpInst::ICMP_NE) { | ||||
185 | // If the masked bit in V is set, set or clear the bit in the result: | ||||
186 | // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) | FC | ||||
187 | // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) ^ FC | ||||
188 | Constant *C = ConstantInt::get(SelType, FC); | ||||
189 | return ExtraBitInTC ? Builder.CreateOr(V, C) : Builder.CreateXor(V, C); | ||||
190 | } | ||||
191 | llvm_unreachable("Only expecting equality predicates")::llvm::llvm_unreachable_internal("Only expecting equality predicates" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 191 ); | ||||
192 | } | ||||
193 | |||||
194 | // Make sure one of the select arms is a power-of-2. | ||||
195 | if (!TC.isPowerOf2() && !FC.isPowerOf2()) | ||||
196 | return nullptr; | ||||
197 | |||||
198 | // Determine which shift is needed to transform result of the 'and' into the | ||||
199 | // desired result. | ||||
200 | const APInt &ValC = !TC.isZero() ? TC : FC; | ||||
201 | unsigned ValZeros = ValC.logBase2(); | ||||
202 | unsigned AndZeros = AndMask.logBase2(); | ||||
203 | |||||
204 | // Insert the 'and' instruction on the input to the truncate. | ||||
205 | if (CreateAnd) | ||||
206 | V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); | ||||
207 | |||||
208 | // If types don't match, we can still convert the select by introducing a zext | ||||
209 | // or a trunc of the 'and'. | ||||
210 | if (ValZeros > AndZeros) { | ||||
211 | V = Builder.CreateZExtOrTrunc(V, SelType); | ||||
212 | V = Builder.CreateShl(V, ValZeros - AndZeros); | ||||
213 | } else if (ValZeros < AndZeros) { | ||||
214 | V = Builder.CreateLShr(V, AndZeros - ValZeros); | ||||
215 | V = Builder.CreateZExtOrTrunc(V, SelType); | ||||
216 | } else { | ||||
217 | V = Builder.CreateZExtOrTrunc(V, SelType); | ||||
218 | } | ||||
219 | |||||
220 | // Okay, now we know that everything is set up, we just don't know whether we | ||||
221 | // have a icmp_ne or icmp_eq and whether the true or false val is the zero. | ||||
222 | bool ShouldNotVal = !TC.isZero(); | ||||
223 | ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; | ||||
224 | if (ShouldNotVal) | ||||
225 | V = Builder.CreateXor(V, ValC); | ||||
226 | |||||
227 | return V; | ||||
228 | } | ||||
229 | |||||
230 | /// We want to turn code that looks like this: | ||||
231 | /// %C = or %A, %B | ||||
232 | /// %D = select %cond, %C, %A | ||||
233 | /// into: | ||||
234 | /// %C = select %cond, %B, 0 | ||||
235 | /// %D = or %A, %C | ||||
236 | /// | ||||
237 | /// Assuming that the specified instruction is an operand to the select, return | ||||
238 | /// a bitmask indicating which operands of this instruction are foldable if they | ||||
239 | /// equal the other incoming value of the select. | ||||
240 | static unsigned getSelectFoldableOperands(BinaryOperator *I) { | ||||
241 | switch (I->getOpcode()) { | ||||
242 | case Instruction::Add: | ||||
243 | case Instruction::FAdd: | ||||
244 | case Instruction::Mul: | ||||
245 | case Instruction::FMul: | ||||
246 | case Instruction::And: | ||||
247 | case Instruction::Or: | ||||
248 | case Instruction::Xor: | ||||
249 | return 3; // Can fold through either operand. | ||||
250 | case Instruction::Sub: // Can only fold on the amount subtracted. | ||||
251 | case Instruction::FSub: | ||||
252 | case Instruction::FDiv: // Can only fold on the divisor amount. | ||||
253 | case Instruction::Shl: // Can only fold on the shift amount. | ||||
254 | case Instruction::LShr: | ||||
255 | case Instruction::AShr: | ||||
256 | return 1; | ||||
257 | default: | ||||
258 | return 0; // Cannot fold | ||||
259 | } | ||||
260 | } | ||||
261 | |||||
262 | /// We have (select c, TI, FI), and we know that TI and FI have the same opcode. | ||||
263 | Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, | ||||
264 | Instruction *FI) { | ||||
265 | // Don't break up min/max patterns. The hasOneUse checks below prevent that | ||||
266 | // for most cases, but vector min/max with bitcasts can be transformed. If the | ||||
267 | // one-use restrictions are eased for other patterns, we still don't want to | ||||
268 | // obfuscate min/max. | ||||
269 | if ((match(&SI, m_SMin(m_Value(), m_Value())) || | ||||
270 | match(&SI, m_SMax(m_Value(), m_Value())) || | ||||
271 | match(&SI, m_UMin(m_Value(), m_Value())) || | ||||
272 | match(&SI, m_UMax(m_Value(), m_Value())))) | ||||
273 | return nullptr; | ||||
274 | |||||
275 | // If this is a cast from the same type, merge. | ||||
276 | Value *Cond = SI.getCondition(); | ||||
277 | Type *CondTy = Cond->getType(); | ||||
278 | if (TI->getNumOperands() == 1 && TI->isCast()) { | ||||
279 | Type *FIOpndTy = FI->getOperand(0)->getType(); | ||||
280 | if (TI->getOperand(0)->getType() != FIOpndTy) | ||||
281 | return nullptr; | ||||
282 | |||||
283 | // The select condition may be a vector. We may only change the operand | ||||
284 | // type if the vector width remains the same (and matches the condition). | ||||
285 | if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) { | ||||
286 | if (!FIOpndTy->isVectorTy() || | ||||
287 | CondVTy->getElementCount() != | ||||
288 | cast<VectorType>(FIOpndTy)->getElementCount()) | ||||
289 | return nullptr; | ||||
290 | |||||
291 | // TODO: If the backend knew how to deal with casts better, we could | ||||
292 | // remove this limitation. For now, there's too much potential to create | ||||
293 | // worse codegen by promoting the select ahead of size-altering casts | ||||
294 | // (PR28160). | ||||
295 | // | ||||
296 | // Note that ValueTracking's matchSelectPattern() looks through casts | ||||
297 | // without checking 'hasOneUse' when it matches min/max patterns, so this | ||||
298 | // transform may end up happening anyway. | ||||
299 | if (TI->getOpcode() != Instruction::BitCast && | ||||
300 | (!TI->hasOneUse() || !FI->hasOneUse())) | ||||
301 | return nullptr; | ||||
302 | } else if (!TI->hasOneUse() || !FI->hasOneUse()) { | ||||
303 | // TODO: The one-use restrictions for a scalar select could be eased if | ||||
304 | // the fold of a select in visitLoadInst() was enhanced to match a pattern | ||||
305 | // that includes a cast. | ||||
306 | return nullptr; | ||||
307 | } | ||||
308 | |||||
309 | // Fold this by inserting a select from the input values. | ||||
310 | Value *NewSI = | ||||
311 | Builder.CreateSelect(Cond, TI->getOperand(0), FI->getOperand(0), | ||||
312 | SI.getName() + ".v", &SI); | ||||
313 | return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI, | ||||
314 | TI->getType()); | ||||
315 | } | ||||
316 | |||||
317 | Value *OtherOpT, *OtherOpF; | ||||
318 | bool MatchIsOpZero; | ||||
319 | auto getCommonOp = [&](Instruction *TI, Instruction *FI, bool Commute, | ||||
320 | bool Swapped = false) -> Value * { | ||||
321 | assert(!(Commute && Swapped) &&(static_cast <bool> (!(Commute && Swapped) && "Commute and Swapped can't set at the same time") ? void (0) : __assert_fail ("!(Commute && Swapped) && \"Commute and Swapped can't set at the same time\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 322 , __extension__ __PRETTY_FUNCTION__)) | ||||
322 | "Commute and Swapped can't set at the same time")(static_cast <bool> (!(Commute && Swapped) && "Commute and Swapped can't set at the same time") ? void (0) : __assert_fail ("!(Commute && Swapped) && \"Commute and Swapped can't set at the same time\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 322 , __extension__ __PRETTY_FUNCTION__)); | ||||
323 | if (!Swapped) { | ||||
324 | if (TI->getOperand(0) == FI->getOperand(0)) { | ||||
325 | OtherOpT = TI->getOperand(1); | ||||
326 | OtherOpF = FI->getOperand(1); | ||||
327 | MatchIsOpZero = true; | ||||
328 | return TI->getOperand(0); | ||||
329 | } else if (TI->getOperand(1) == FI->getOperand(1)) { | ||||
330 | OtherOpT = TI->getOperand(0); | ||||
331 | OtherOpF = FI->getOperand(0); | ||||
332 | MatchIsOpZero = false; | ||||
333 | return TI->getOperand(1); | ||||
334 | } | ||||
335 | } | ||||
336 | |||||
337 | if (!Commute && !Swapped) | ||||
338 | return nullptr; | ||||
339 | |||||
340 | // If we are allowing commute or swap of operands, then | ||||
341 | // allow a cross-operand match. In that case, MatchIsOpZero | ||||
342 | // means that TI's operand 0 (FI's operand 1) is the common op. | ||||
343 | if (TI->getOperand(0) == FI->getOperand(1)) { | ||||
344 | OtherOpT = TI->getOperand(1); | ||||
345 | OtherOpF = FI->getOperand(0); | ||||
346 | MatchIsOpZero = true; | ||||
347 | return TI->getOperand(0); | ||||
348 | } else if (TI->getOperand(1) == FI->getOperand(0)) { | ||||
349 | OtherOpT = TI->getOperand(0); | ||||
350 | OtherOpF = FI->getOperand(1); | ||||
351 | MatchIsOpZero = false; | ||||
352 | return TI->getOperand(1); | ||||
353 | } | ||||
354 | return nullptr; | ||||
355 | }; | ||||
356 | |||||
357 | if (TI->hasOneUse() || FI->hasOneUse()) { | ||||
358 | // Cond ? -X : -Y --> -(Cond ? X : Y) | ||||
359 | Value *X, *Y; | ||||
360 | if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y)))) { | ||||
361 | // Intersect FMF from the fneg instructions and union those with the | ||||
362 | // select. | ||||
363 | FastMathFlags FMF = TI->getFastMathFlags(); | ||||
364 | FMF &= FI->getFastMathFlags(); | ||||
365 | FMF |= SI.getFastMathFlags(); | ||||
366 | Value *NewSel = | ||||
367 | Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); | ||||
368 | if (auto *NewSelI = dyn_cast<Instruction>(NewSel)) | ||||
369 | NewSelI->setFastMathFlags(FMF); | ||||
370 | Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel); | ||||
371 | NewFNeg->setFastMathFlags(FMF); | ||||
372 | return NewFNeg; | ||||
373 | } | ||||
374 | |||||
375 | // Min/max intrinsic with a common operand can have the common operand | ||||
376 | // pulled after the select. This is the same transform as below for binops, | ||||
377 | // but specialized for intrinsic matching and without the restrictive uses | ||||
378 | // clause. | ||||
379 | auto *TII = dyn_cast<IntrinsicInst>(TI); | ||||
380 | auto *FII = dyn_cast<IntrinsicInst>(FI); | ||||
381 | if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID()) { | ||||
382 | if (match(TII, m_MaxOrMin(m_Value(), m_Value()))) { | ||||
383 | if (Value *MatchOp = getCommonOp(TI, FI, true)) { | ||||
384 | Value *NewSel = | ||||
385 | Builder.CreateSelect(Cond, OtherOpT, OtherOpF, "minmaxop", &SI); | ||||
386 | return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp}); | ||||
387 | } | ||||
388 | } | ||||
389 | } | ||||
390 | |||||
391 | // icmp with a common operand also can have the common operand | ||||
392 | // pulled after the select. | ||||
393 | ICmpInst::Predicate TPred, FPred; | ||||
394 | if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) && | ||||
395 | match(FI, m_ICmp(FPred, m_Value(), m_Value()))) { | ||||
396 | if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) { | ||||
397 | bool Swapped = TPred != FPred; | ||||
398 | if (Value *MatchOp = | ||||
399 | getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) { | ||||
400 | Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF, | ||||
401 | SI.getName() + ".v", &SI); | ||||
402 | return new ICmpInst( | ||||
403 | MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred), | ||||
404 | MatchOp, NewSel); | ||||
405 | } | ||||
406 | } | ||||
407 | } | ||||
408 | } | ||||
409 | |||||
410 | // Only handle binary operators (including two-operand getelementptr) with | ||||
411 | // one-use here. As with the cast case above, it may be possible to relax the | ||||
412 | // one-use constraint, but that needs be examined carefully since it may not | ||||
413 | // reduce the total number of instructions. | ||||
414 | if (TI->getNumOperands() != 2 || FI->getNumOperands() != 2 || | ||||
415 | !TI->isSameOperationAs(FI) || | ||||
416 | (!isa<BinaryOperator>(TI) && !isa<GetElementPtrInst>(TI)) || | ||||
417 | !TI->hasOneUse() || !FI->hasOneUse()) | ||||
418 | return nullptr; | ||||
419 | |||||
420 | // Figure out if the operations have any operands in common. | ||||
421 | Value *MatchOp = getCommonOp(TI, FI, TI->isCommutative()); | ||||
422 | if (!MatchOp) | ||||
423 | return nullptr; | ||||
424 | |||||
425 | // If the select condition is a vector, the operands of the original select's | ||||
426 | // operands also must be vectors. This may not be the case for getelementptr | ||||
427 | // for example. | ||||
428 | if (CondTy->isVectorTy() && (!OtherOpT->getType()->isVectorTy() || | ||||
429 | !OtherOpF->getType()->isVectorTy())) | ||||
430 | return nullptr; | ||||
431 | |||||
432 | // If we are sinking div/rem after a select, we may need to freeze the | ||||
433 | // condition because div/rem may induce immediate UB with a poison operand. | ||||
434 | // For example, the following transform is not safe if Cond can ever be poison | ||||
435 | // because we can replace poison with zero and then we have div-by-zero that | ||||
436 | // didn't exist in the original code: | ||||
437 | // Cond ? x/y : x/z --> x / (Cond ? y : z) | ||||
438 | auto *BO = dyn_cast<BinaryOperator>(TI); | ||||
439 | if (BO && BO->isIntDivRem() && !isGuaranteedNotToBePoison(Cond)) { | ||||
440 | // A udiv/urem with a common divisor is safe because UB can only occur with | ||||
441 | // div-by-zero, and that would be present in the original code. | ||||
442 | if (BO->getOpcode() == Instruction::SDiv || | ||||
443 | BO->getOpcode() == Instruction::SRem || MatchIsOpZero) | ||||
444 | Cond = Builder.CreateFreeze(Cond); | ||||
445 | } | ||||
446 | |||||
447 | // If we reach here, they do have operations in common. | ||||
448 | Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF, | ||||
449 | SI.getName() + ".v", &SI); | ||||
450 | Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; | ||||
451 | Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; | ||||
452 | if (auto *BO = dyn_cast<BinaryOperator>(TI)) { | ||||
453 | BinaryOperator *NewBO = BinaryOperator::Create(BO->getOpcode(), Op0, Op1); | ||||
454 | NewBO->copyIRFlags(TI); | ||||
455 | NewBO->andIRFlags(FI); | ||||
456 | return NewBO; | ||||
457 | } | ||||
458 | if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { | ||||
459 | auto *FGEP = cast<GetElementPtrInst>(FI); | ||||
460 | Type *ElementType = TGEP->getResultElementType(); | ||||
461 | return TGEP->isInBounds() && FGEP->isInBounds() | ||||
462 | ? GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1}) | ||||
463 | : GetElementPtrInst::Create(ElementType, Op0, {Op1}); | ||||
464 | } | ||||
465 | llvm_unreachable("Expected BinaryOperator or GEP")::llvm::llvm_unreachable_internal("Expected BinaryOperator or GEP" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 465 ); | ||||
466 | return nullptr; | ||||
467 | } | ||||
468 | |||||
469 | static bool isSelect01(const APInt &C1I, const APInt &C2I) { | ||||
470 | if (!C1I.isZero() && !C2I.isZero()) // One side must be zero. | ||||
471 | return false; | ||||
472 | return C1I.isOne() || C1I.isAllOnes() || C2I.isOne() || C2I.isAllOnes(); | ||||
473 | } | ||||
474 | |||||
475 | /// Try to fold the select into one of the operands to allow further | ||||
476 | /// optimization. | ||||
477 | Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, | ||||
478 | Value *FalseVal) { | ||||
479 | // See the comment above GetSelectFoldableOperands for a description of the | ||||
480 | // transformation we are doing here. | ||||
481 | auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal, | ||||
482 | Value *FalseVal, | ||||
483 | bool Swapped) -> Instruction * { | ||||
484 | auto *TVI = dyn_cast<BinaryOperator>(TrueVal); | ||||
485 | if (!TVI || !TVI->hasOneUse() || isa<Constant>(FalseVal)) | ||||
486 | return nullptr; | ||||
487 | |||||
488 | unsigned SFO = getSelectFoldableOperands(TVI); | ||||
489 | unsigned OpToFold = 0; | ||||
490 | if ((SFO & 1) && FalseVal == TVI->getOperand(0)) | ||||
491 | OpToFold = 1; | ||||
492 | else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) | ||||
493 | OpToFold = 2; | ||||
494 | |||||
495 | if (!OpToFold) | ||||
496 | return nullptr; | ||||
497 | |||||
498 | // TODO: We probably ought to revisit cases where the select and FP | ||||
499 | // instructions have different flags and add tests to ensure the | ||||
500 | // behaviour is correct. | ||||
501 | FastMathFlags FMF; | ||||
502 | if (isa<FPMathOperator>(&SI)) | ||||
503 | FMF = SI.getFastMathFlags(); | ||||
504 | Constant *C = ConstantExpr::getBinOpIdentity( | ||||
505 | TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); | ||||
506 | Value *OOp = TVI->getOperand(2 - OpToFold); | ||||
507 | // Avoid creating select between 2 constants unless it's selecting | ||||
508 | // between 0, 1 and -1. | ||||
509 | const APInt *OOpC; | ||||
510 | bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); | ||||
511 | if (!isa<Constant>(OOp) || | ||||
512 | (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { | ||||
513 | Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, | ||||
514 | Swapped ? OOp : C); | ||||
515 | if (isa<FPMathOperator>(&SI)) | ||||
516 | cast<Instruction>(NewSel)->setFastMathFlags(FMF); | ||||
517 | NewSel->takeName(TVI); | ||||
518 | BinaryOperator *BO = | ||||
519 | BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); | ||||
520 | BO->copyIRFlags(TVI); | ||||
521 | return BO; | ||||
522 | } | ||||
523 | return nullptr; | ||||
524 | }; | ||||
525 | |||||
526 | if (Instruction *R = TryFoldSelectIntoOp(SI, TrueVal, FalseVal, false)) | ||||
527 | return R; | ||||
528 | |||||
529 | if (Instruction *R = TryFoldSelectIntoOp(SI, FalseVal, TrueVal, true)) | ||||
530 | return R; | ||||
531 | |||||
532 | return nullptr; | ||||
533 | } | ||||
534 | |||||
535 | /// We want to turn: | ||||
536 | /// (select (icmp eq (and X, Y), 0), (and (lshr X, Z), 1), 1) | ||||
537 | /// into: | ||||
538 | /// zext (icmp ne i32 (and X, (or Y, (shl 1, Z))), 0) | ||||
539 | /// Note: | ||||
540 | /// Z may be 0 if lshr is missing. | ||||
541 | /// Worst-case scenario is that we will replace 5 instructions with 5 different | ||||
542 | /// instructions, but we got rid of select. | ||||
543 | static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, | ||||
544 | Value *TVal, Value *FVal, | ||||
545 | InstCombiner::BuilderTy &Builder) { | ||||
546 | if (!(Cmp->hasOneUse() && Cmp->getOperand(0)->hasOneUse() && | ||||
547 | Cmp->getPredicate() == ICmpInst::ICMP_EQ && | ||||
548 | match(Cmp->getOperand(1), m_Zero()) && match(FVal, m_One()))) | ||||
549 | return nullptr; | ||||
550 | |||||
551 | // The TrueVal has general form of: and %B, 1 | ||||
552 | Value *B; | ||||
553 | if (!match(TVal, m_OneUse(m_And(m_Value(B), m_One())))) | ||||
554 | return nullptr; | ||||
555 | |||||
556 | // Where %B may be optionally shifted: lshr %X, %Z. | ||||
557 | Value *X, *Z; | ||||
558 | const bool HasShift = match(B, m_OneUse(m_LShr(m_Value(X), m_Value(Z)))); | ||||
559 | |||||
560 | // The shift must be valid. | ||||
561 | // TODO: This restricts the fold to constant shift amounts. Is there a way to | ||||
562 | // handle variable shifts safely? PR47012 | ||||
563 | if (HasShift && | ||||
564 | !match(Z, m_SpecificInt_ICMP(CmpInst::ICMP_ULT, | ||||
565 | APInt(SelType->getScalarSizeInBits(), | ||||
566 | SelType->getScalarSizeInBits())))) | ||||
567 | return nullptr; | ||||
568 | |||||
569 | if (!HasShift) | ||||
570 | X = B; | ||||
571 | |||||
572 | Value *Y; | ||||
573 | if (!match(Cmp->getOperand(0), m_c_And(m_Specific(X), m_Value(Y)))) | ||||
574 | return nullptr; | ||||
575 | |||||
576 | // ((X & Y) == 0) ? ((X >> Z) & 1) : 1 --> (X & (Y | (1 << Z))) != 0 | ||||
577 | // ((X & Y) == 0) ? (X & 1) : 1 --> (X & (Y | 1)) != 0 | ||||
578 | Constant *One = ConstantInt::get(SelType, 1); | ||||
579 | Value *MaskB = HasShift ? Builder.CreateShl(One, Z) : One; | ||||
580 | Value *FullMask = Builder.CreateOr(Y, MaskB); | ||||
581 | Value *MaskedX = Builder.CreateAnd(X, FullMask); | ||||
582 | Value *ICmpNeZero = Builder.CreateIsNotNull(MaskedX); | ||||
583 | return new ZExtInst(ICmpNeZero, SelType); | ||||
584 | } | ||||
585 | |||||
586 | /// We want to turn: | ||||
587 | /// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1 | ||||
588 | /// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0 | ||||
589 | /// into: | ||||
590 | /// ashr (X, Y) | ||||
591 | static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal, | ||||
592 | Value *FalseVal, | ||||
593 | InstCombiner::BuilderTy &Builder) { | ||||
594 | ICmpInst::Predicate Pred = IC->getPredicate(); | ||||
595 | Value *CmpLHS = IC->getOperand(0); | ||||
596 | Value *CmpRHS = IC->getOperand(1); | ||||
597 | if (!CmpRHS->getType()->isIntOrIntVectorTy()) | ||||
598 | return nullptr; | ||||
599 | |||||
600 | Value *X, *Y; | ||||
601 | unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits(); | ||||
602 | if ((Pred != ICmpInst::ICMP_SGT || | ||||
603 | !match(CmpRHS, | ||||
604 | m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) && | ||||
605 | (Pred != ICmpInst::ICMP_SLT || | ||||
606 | !match(CmpRHS, | ||||
607 | m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0))))) | ||||
608 | return nullptr; | ||||
609 | |||||
610 | // Canonicalize so that ashr is in FalseVal. | ||||
611 | if (Pred == ICmpInst::ICMP_SLT) | ||||
612 | std::swap(TrueVal, FalseVal); | ||||
613 | |||||
614 | if (match(TrueVal, m_LShr(m_Value(X), m_Value(Y))) && | ||||
615 | match(FalseVal, m_AShr(m_Specific(X), m_Specific(Y))) && | ||||
616 | match(CmpLHS, m_Specific(X))) { | ||||
617 | const auto *Ashr = cast<Instruction>(FalseVal); | ||||
618 | // if lshr is not exact and ashr is, this new ashr must not be exact. | ||||
619 | bool IsExact = Ashr->isExact() && cast<Instruction>(TrueVal)->isExact(); | ||||
620 | return Builder.CreateAShr(X, Y, IC->getName(), IsExact); | ||||
621 | } | ||||
622 | |||||
623 | return nullptr; | ||||
624 | } | ||||
625 | |||||
626 | /// We want to turn: | ||||
627 | /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) | ||||
628 | /// into: | ||||
629 | /// (or (shl (and X, C1), C3), Y) | ||||
630 | /// iff: | ||||
631 | /// C1 and C2 are both powers of 2 | ||||
632 | /// where: | ||||
633 | /// C3 = Log(C2) - Log(C1) | ||||
634 | /// | ||||
635 | /// This transform handles cases where: | ||||
636 | /// 1. The icmp predicate is inverted | ||||
637 | /// 2. The select operands are reversed | ||||
638 | /// 3. The magnitude of C2 and C1 are flipped | ||||
639 | static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, | ||||
640 | Value *FalseVal, | ||||
641 | InstCombiner::BuilderTy &Builder) { | ||||
642 | // Only handle integer compares. Also, if this is a vector select, we need a | ||||
643 | // vector compare. | ||||
644 | if (!TrueVal->getType()->isIntOrIntVectorTy() || | ||||
645 | TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) | ||||
646 | return nullptr; | ||||
647 | |||||
648 | Value *CmpLHS = IC->getOperand(0); | ||||
649 | Value *CmpRHS = IC->getOperand(1); | ||||
650 | |||||
651 | Value *V; | ||||
652 | unsigned C1Log; | ||||
653 | bool IsEqualZero; | ||||
654 | bool NeedAnd = false; | ||||
655 | if (IC->isEquality()) { | ||||
656 | if (!match(CmpRHS, m_Zero())) | ||||
657 | return nullptr; | ||||
658 | |||||
659 | const APInt *C1; | ||||
660 | if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1)))) | ||||
661 | return nullptr; | ||||
662 | |||||
663 | V = CmpLHS; | ||||
664 | C1Log = C1->logBase2(); | ||||
665 | IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ; | ||||
666 | } else if (IC->getPredicate() == ICmpInst::ICMP_SLT || | ||||
667 | IC->getPredicate() == ICmpInst::ICMP_SGT) { | ||||
668 | // We also need to recognize (icmp slt (trunc (X)), 0) and | ||||
669 | // (icmp sgt (trunc (X)), -1). | ||||
670 | IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT; | ||||
671 | if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) || | ||||
672 | (!IsEqualZero && !match(CmpRHS, m_Zero()))) | ||||
673 | return nullptr; | ||||
674 | |||||
675 | if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V))))) | ||||
676 | return nullptr; | ||||
677 | |||||
678 | C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1; | ||||
679 | NeedAnd = true; | ||||
680 | } else { | ||||
681 | return nullptr; | ||||
682 | } | ||||
683 | |||||
684 | const APInt *C2; | ||||
685 | bool OrOnTrueVal = false; | ||||
686 | bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2))); | ||||
687 | if (!OrOnFalseVal) | ||||
688 | OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2))); | ||||
689 | |||||
690 | if (!OrOnFalseVal && !OrOnTrueVal) | ||||
691 | return nullptr; | ||||
692 | |||||
693 | Value *Y = OrOnFalseVal ? TrueVal : FalseVal; | ||||
694 | |||||
695 | unsigned C2Log = C2->logBase2(); | ||||
696 | |||||
697 | bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); | ||||
698 | bool NeedShift = C1Log != C2Log; | ||||
699 | bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() != | ||||
700 | V->getType()->getScalarSizeInBits(); | ||||
701 | |||||
702 | // Make sure we don't create more instructions than we save. | ||||
703 | Value *Or = OrOnFalseVal ? FalseVal : TrueVal; | ||||
704 | if ((NeedShift + NeedXor + NeedZExtTrunc) > | ||||
705 | (IC->hasOneUse() + Or->hasOneUse())) | ||||
706 | return nullptr; | ||||
707 | |||||
708 | if (NeedAnd) { | ||||
709 | // Insert the AND instruction on the input to the truncate. | ||||
710 | APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log); | ||||
711 | V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1)); | ||||
712 | } | ||||
713 | |||||
714 | if (C2Log > C1Log) { | ||||
715 | V = Builder.CreateZExtOrTrunc(V, Y->getType()); | ||||
716 | V = Builder.CreateShl(V, C2Log - C1Log); | ||||
717 | } else if (C1Log > C2Log) { | ||||
718 | V = Builder.CreateLShr(V, C1Log - C2Log); | ||||
719 | V = Builder.CreateZExtOrTrunc(V, Y->getType()); | ||||
720 | } else | ||||
721 | V = Builder.CreateZExtOrTrunc(V, Y->getType()); | ||||
722 | |||||
723 | if (NeedXor) | ||||
724 | V = Builder.CreateXor(V, *C2); | ||||
725 | |||||
726 | return Builder.CreateOr(V, Y); | ||||
727 | } | ||||
728 | |||||
729 | /// Canonicalize a set or clear of a masked set of constant bits to | ||||
730 | /// select-of-constants form. | ||||
731 | static Instruction *foldSetClearBits(SelectInst &Sel, | ||||
732 | InstCombiner::BuilderTy &Builder) { | ||||
733 | Value *Cond = Sel.getCondition(); | ||||
734 | Value *T = Sel.getTrueValue(); | ||||
735 | Value *F = Sel.getFalseValue(); | ||||
736 | Type *Ty = Sel.getType(); | ||||
737 | Value *X; | ||||
738 | const APInt *NotC, *C; | ||||
739 | |||||
740 | // Cond ? (X & ~C) : (X | C) --> (X & ~C) | (Cond ? 0 : C) | ||||
741 | if (match(T, m_And(m_Value(X), m_APInt(NotC))) && | ||||
742 | match(F, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) { | ||||
743 | Constant *Zero = ConstantInt::getNullValue(Ty); | ||||
744 | Constant *OrC = ConstantInt::get(Ty, *C); | ||||
745 | Value *NewSel = Builder.CreateSelect(Cond, Zero, OrC, "masksel", &Sel); | ||||
746 | return BinaryOperator::CreateOr(T, NewSel); | ||||
747 | } | ||||
748 | |||||
749 | // Cond ? (X | C) : (X & ~C) --> (X & ~C) | (Cond ? C : 0) | ||||
750 | if (match(F, m_And(m_Value(X), m_APInt(NotC))) && | ||||
751 | match(T, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) { | ||||
752 | Constant *Zero = ConstantInt::getNullValue(Ty); | ||||
753 | Constant *OrC = ConstantInt::get(Ty, *C); | ||||
754 | Value *NewSel = Builder.CreateSelect(Cond, OrC, Zero, "masksel", &Sel); | ||||
755 | return BinaryOperator::CreateOr(F, NewSel); | ||||
756 | } | ||||
757 | |||||
758 | return nullptr; | ||||
759 | } | ||||
760 | |||||
761 | // select (x == 0), 0, x * y --> freeze(y) * x | ||||
762 | // select (y == 0), 0, x * y --> freeze(x) * y | ||||
763 | // select (x == 0), undef, x * y --> freeze(y) * x | ||||
764 | // select (x == undef), 0, x * y --> freeze(y) * x | ||||
765 | // Usage of mul instead of 0 will make the result more poisonous, | ||||
766 | // so the operand that was not checked in the condition should be frozen. | ||||
767 | // The latter folding is applied only when a constant compared with x is | ||||
768 | // is a vector consisting of 0 and undefs. If a constant compared with x | ||||
769 | // is a scalar undefined value or undefined vector then an expression | ||||
770 | // should be already folded into a constant. | ||||
771 | static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) { | ||||
772 | auto *CondVal = SI.getCondition(); | ||||
773 | auto *TrueVal = SI.getTrueValue(); | ||||
774 | auto *FalseVal = SI.getFalseValue(); | ||||
775 | Value *X, *Y; | ||||
776 | ICmpInst::Predicate Predicate; | ||||
777 | |||||
778 | // Assuming that constant compared with zero is not undef (but it may be | ||||
779 | // a vector with some undef elements). Otherwise (when a constant is undef) | ||||
780 | // the select expression should be already simplified. | ||||
781 | if (!match(CondVal, m_ICmp(Predicate, m_Value(X), m_Zero())) || | ||||
782 | !ICmpInst::isEquality(Predicate)) | ||||
783 | return nullptr; | ||||
784 | |||||
785 | if (Predicate == ICmpInst::ICMP_NE) | ||||
786 | std::swap(TrueVal, FalseVal); | ||||
787 | |||||
788 | // Check that TrueVal is a constant instead of matching it with m_Zero() | ||||
789 | // to handle the case when it is a scalar undef value or a vector containing | ||||
790 | // non-zero elements that are masked by undef elements in the compare | ||||
791 | // constant. | ||||
792 | auto *TrueValC = dyn_cast<Constant>(TrueVal); | ||||
793 | if (TrueValC == nullptr || | ||||
794 | !match(FalseVal, m_c_Mul(m_Specific(X), m_Value(Y))) || | ||||
795 | !isa<Instruction>(FalseVal)) | ||||
796 | return nullptr; | ||||
797 | |||||
798 | auto *ZeroC = cast<Constant>(cast<Instruction>(CondVal)->getOperand(1)); | ||||
799 | auto *MergedC = Constant::mergeUndefsWith(TrueValC, ZeroC); | ||||
800 | // If X is compared with 0 then TrueVal could be either zero or undef. | ||||
801 | // m_Zero match vectors containing some undef elements, but for scalars | ||||
802 | // m_Undef should be used explicitly. | ||||
803 | if (!match(MergedC, m_Zero()) && !match(MergedC, m_Undef())) | ||||
804 | return nullptr; | ||||
805 | |||||
806 | auto *FalseValI = cast<Instruction>(FalseVal); | ||||
807 | auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"), | ||||
808 | *FalseValI); | ||||
809 | IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY); | ||||
810 | return IC.replaceInstUsesWith(SI, FalseValI); | ||||
811 | } | ||||
812 | |||||
813 | /// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b). | ||||
814 | /// There are 8 commuted/swapped variants of this pattern. | ||||
815 | /// TODO: Also support a - UMIN(a,b) patterns. | ||||
816 | static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, | ||||
817 | const Value *TrueVal, | ||||
818 | const Value *FalseVal, | ||||
819 | InstCombiner::BuilderTy &Builder) { | ||||
820 | ICmpInst::Predicate Pred = ICI->getPredicate(); | ||||
821 | Value *A = ICI->getOperand(0); | ||||
822 | Value *B = ICI->getOperand(1); | ||||
823 | |||||
824 | // (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0 | ||||
825 | // (a == 0) ? 0 : a - 1 -> (a != 0) ? a - 1 : 0 | ||||
826 | if (match(TrueVal, m_Zero())) { | ||||
827 | Pred = ICmpInst::getInversePredicate(Pred); | ||||
828 | std::swap(TrueVal, FalseVal); | ||||
829 | } | ||||
830 | |||||
831 | if (!match(FalseVal, m_Zero())) | ||||
832 | return nullptr; | ||||
833 | |||||
834 | // ugt 0 is canonicalized to ne 0 and requires special handling | ||||
835 | // (a != 0) ? a + -1 : 0 -> usub.sat(a, 1) | ||||
836 | if (Pred == ICmpInst::ICMP_NE) { | ||||
837 | if (match(B, m_Zero()) && match(TrueVal, m_Add(m_Specific(A), m_AllOnes()))) | ||||
838 | return Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, | ||||
839 | ConstantInt::get(A->getType(), 1)); | ||||
840 | return nullptr; | ||||
841 | } | ||||
842 | |||||
843 | if (!ICmpInst::isUnsigned(Pred)) | ||||
844 | return nullptr; | ||||
845 | |||||
846 | if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) { | ||||
847 | // (b < a) ? a - b : 0 -> (a > b) ? a - b : 0 | ||||
848 | std::swap(A, B); | ||||
849 | Pred = ICmpInst::getSwappedPredicate(Pred); | ||||
850 | } | ||||
851 | |||||
852 | assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) &&(static_cast <bool> ((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && "Unexpected isUnsigned predicate!" ) ? void (0) : __assert_fail ("(Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && \"Unexpected isUnsigned predicate!\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 853 , __extension__ __PRETTY_FUNCTION__)) | ||||
853 | "Unexpected isUnsigned predicate!")(static_cast <bool> ((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && "Unexpected isUnsigned predicate!" ) ? void (0) : __assert_fail ("(Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && \"Unexpected isUnsigned predicate!\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 853 , __extension__ __PRETTY_FUNCTION__)); | ||||
854 | |||||
855 | // Ensure the sub is of the form: | ||||
856 | // (a > b) ? a - b : 0 -> usub.sat(a, b) | ||||
857 | // (a > b) ? b - a : 0 -> -usub.sat(a, b) | ||||
858 | // Checking for both a-b and a+(-b) as a constant. | ||||
859 | bool IsNegative = false; | ||||
860 | const APInt *C; | ||||
861 | if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))) || | ||||
862 | (match(A, m_APInt(C)) && | ||||
863 | match(TrueVal, m_Add(m_Specific(B), m_SpecificInt(-*C))))) | ||||
864 | IsNegative = true; | ||||
865 | else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))) && | ||||
866 | !(match(B, m_APInt(C)) && | ||||
867 | match(TrueVal, m_Add(m_Specific(A), m_SpecificInt(-*C))))) | ||||
868 | return nullptr; | ||||
869 | |||||
870 | // If we are adding a negate and the sub and icmp are used anywhere else, we | ||||
871 | // would end up with more instructions. | ||||
872 | if (IsNegative && !TrueVal->hasOneUse() && !ICI->hasOneUse()) | ||||
873 | return nullptr; | ||||
874 | |||||
875 | // (a > b) ? a - b : 0 -> usub.sat(a, b) | ||||
876 | // (a > b) ? b - a : 0 -> -usub.sat(a, b) | ||||
877 | Value *Result = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, B); | ||||
878 | if (IsNegative) | ||||
879 | Result = Builder.CreateNeg(Result); | ||||
880 | return Result; | ||||
881 | } | ||||
882 | |||||
883 | static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, | ||||
884 | InstCombiner::BuilderTy &Builder) { | ||||
885 | if (!Cmp->hasOneUse()) | ||||
886 | return nullptr; | ||||
887 | |||||
888 | // Match unsigned saturated add with constant. | ||||
889 | Value *Cmp0 = Cmp->getOperand(0); | ||||
890 | Value *Cmp1 = Cmp->getOperand(1); | ||||
891 | ICmpInst::Predicate Pred = Cmp->getPredicate(); | ||||
892 | Value *X; | ||||
893 | const APInt *C, *CmpC; | ||||
894 | if (Pred == ICmpInst::ICMP_ULT && | ||||
895 | match(TVal, m_Add(m_Value(X), m_APInt(C))) && X == Cmp0 && | ||||
896 | match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) && *CmpC == ~*C) { | ||||
897 | // (X u< ~C) ? (X + C) : -1 --> uadd.sat(X, C) | ||||
898 | return Builder.CreateBinaryIntrinsic( | ||||
899 | Intrinsic::uadd_sat, X, ConstantInt::get(X->getType(), *C)); | ||||
900 | } | ||||
901 | |||||
902 | // Match unsigned saturated add of 2 variables with an unnecessary 'not'. | ||||
903 | // There are 8 commuted variants. | ||||
904 | // Canonicalize -1 (saturated result) to true value of the select. | ||||
905 | if (match(FVal, m_AllOnes())) { | ||||
906 | std::swap(TVal, FVal); | ||||
907 | Pred = CmpInst::getInversePredicate(Pred); | ||||
908 | } | ||||
909 | if (!match(TVal, m_AllOnes())) | ||||
910 | return nullptr; | ||||
911 | |||||
912 | // Canonicalize predicate to less-than or less-or-equal-than. | ||||
913 | if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { | ||||
914 | std::swap(Cmp0, Cmp1); | ||||
915 | Pred = CmpInst::getSwappedPredicate(Pred); | ||||
916 | } | ||||
917 | if (Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_ULE) | ||||
918 | return nullptr; | ||||
919 | |||||
920 | // Match unsigned saturated add of 2 variables with an unnecessary 'not'. | ||||
921 | // Strictness of the comparison is irrelevant. | ||||
922 | Value *Y; | ||||
923 | if (match(Cmp0, m_Not(m_Value(X))) && | ||||
924 | match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) { | ||||
925 | // (~X u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) | ||||
926 | // (~X u< Y) ? -1 : (Y + X) --> uadd.sat(X, Y) | ||||
927 | return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y); | ||||
928 | } | ||||
929 | // The 'not' op may be included in the sum but not the compare. | ||||
930 | // Strictness of the comparison is irrelevant. | ||||
931 | X = Cmp0; | ||||
932 | Y = Cmp1; | ||||
933 | if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) { | ||||
934 | // (X u< Y) ? -1 : (~X + Y) --> uadd.sat(~X, Y) | ||||
935 | // (X u< Y) ? -1 : (Y + ~X) --> uadd.sat(Y, ~X) | ||||
936 | BinaryOperator *BO = cast<BinaryOperator>(FVal); | ||||
937 | return Builder.CreateBinaryIntrinsic( | ||||
938 | Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1)); | ||||
939 | } | ||||
940 | // The overflow may be detected via the add wrapping round. | ||||
941 | // This is only valid for strict comparison! | ||||
942 | if (Pred == ICmpInst::ICMP_ULT && | ||||
943 | match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && | ||||
944 | match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) { | ||||
945 | // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y) | ||||
946 | // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) | ||||
947 | return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp1, Y); | ||||
948 | } | ||||
949 | |||||
950 | return nullptr; | ||||
951 | } | ||||
952 | |||||
953 | /// Try to match patterns with select and subtract as absolute difference. | ||||
954 | static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal, | ||||
955 | InstCombiner::BuilderTy &Builder) { | ||||
956 | auto *TI = dyn_cast<Instruction>(TVal); | ||||
957 | auto *FI = dyn_cast<Instruction>(FVal); | ||||
958 | if (!TI || !FI) | ||||
959 | return nullptr; | ||||
960 | |||||
961 | // Normalize predicate to gt/lt rather than ge/le. | ||||
962 | ICmpInst::Predicate Pred = Cmp->getStrictPredicate(); | ||||
963 | Value *A = Cmp->getOperand(0); | ||||
964 | Value *B = Cmp->getOperand(1); | ||||
965 | |||||
966 | // Normalize "A - B" as the true value of the select. | ||||
967 | if (match(FI, m_Sub(m_Specific(A), m_Specific(B)))) { | ||||
968 | std::swap(FI, TI); | ||||
969 | Pred = ICmpInst::getSwappedPredicate(Pred); | ||||
970 | } | ||||
971 | |||||
972 | // With any pair of no-wrap subtracts: | ||||
973 | // (A > B) ? (A - B) : (B - A) --> abs(A - B) | ||||
974 | if (Pred == CmpInst::ICMP_SGT && | ||||
975 | match(TI, m_Sub(m_Specific(A), m_Specific(B))) && | ||||
976 | match(FI, m_Sub(m_Specific(B), m_Specific(A))) && | ||||
977 | (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap()) && | ||||
978 | (FI->hasNoSignedWrap() || FI->hasNoUnsignedWrap())) { | ||||
979 | // The remaining subtract is not "nuw" any more. | ||||
980 | // If there's one use of the subtract (no other use than the use we are | ||||
981 | // about to replace), then we know that the sub is "nsw" in this context | ||||
982 | // even if it was only "nuw" before. If there's another use, then we can't | ||||
983 | // add "nsw" to the existing instruction because it may not be safe in the | ||||
984 | // other user's context. | ||||
985 | TI->setHasNoUnsignedWrap(false); | ||||
986 | if (!TI->hasNoSignedWrap()) | ||||
987 | TI->setHasNoSignedWrap(TI->hasOneUse()); | ||||
988 | return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TI, Builder.getTrue()); | ||||
989 | } | ||||
990 | |||||
991 | return nullptr; | ||||
992 | } | ||||
993 | |||||
994 | /// Fold the following code sequence: | ||||
995 | /// \code | ||||
996 | /// int a = ctlz(x & -x); | ||||
997 | // x ? 31 - a : a; | ||||
998 | /// \code | ||||
999 | /// | ||||
1000 | /// into: | ||||
1001 | /// cttz(x) | ||||
1002 | static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal, | ||||
1003 | Value *FalseVal, | ||||
1004 | InstCombiner::BuilderTy &Builder) { | ||||
1005 | unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits(); | ||||
1006 | if (!ICI->isEquality() || !match(ICI->getOperand(1), m_Zero())) | ||||
1007 | return nullptr; | ||||
1008 | |||||
1009 | if (ICI->getPredicate() == ICmpInst::ICMP_NE) | ||||
1010 | std::swap(TrueVal, FalseVal); | ||||
1011 | |||||
1012 | if (!match(FalseVal, | ||||
1013 | m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1)))) | ||||
1014 | return nullptr; | ||||
1015 | |||||
1016 | if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>())) | ||||
1017 | return nullptr; | ||||
1018 | |||||
1019 | Value *X = ICI->getOperand(0); | ||||
1020 | auto *II = cast<IntrinsicInst>(TrueVal); | ||||
1021 | if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X))))) | ||||
1022 | return nullptr; | ||||
1023 | |||||
1024 | Function *F = Intrinsic::getDeclaration(II->getModule(), Intrinsic::cttz, | ||||
1025 | II->getType()); | ||||
1026 | return CallInst::Create(F, {X, II->getArgOperand(1)}); | ||||
1027 | } | ||||
1028 | |||||
1029 | /// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single | ||||
1030 | /// call to cttz/ctlz with flag 'is_zero_poison' cleared. | ||||
1031 | /// | ||||
1032 | /// For example, we can fold the following code sequence: | ||||
1033 | /// \code | ||||
1034 | /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 true) | ||||
1035 | /// %1 = icmp ne i32 %x, 0 | ||||
1036 | /// %2 = select i1 %1, i32 %0, i32 32 | ||||
1037 | /// \code | ||||
1038 | /// | ||||
1039 | /// into: | ||||
1040 | /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) | ||||
1041 | static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, | ||||
1042 | InstCombiner::BuilderTy &Builder) { | ||||
1043 | ICmpInst::Predicate Pred = ICI->getPredicate(); | ||||
1044 | Value *CmpLHS = ICI->getOperand(0); | ||||
1045 | Value *CmpRHS = ICI->getOperand(1); | ||||
1046 | |||||
1047 | // Check if the select condition compares a value for equality. | ||||
1048 | if (!ICI->isEquality()) | ||||
1049 | return nullptr; | ||||
1050 | |||||
1051 | Value *SelectArg = FalseVal; | ||||
1052 | Value *ValueOnZero = TrueVal; | ||||
1053 | if (Pred == ICmpInst::ICMP_NE) | ||||
1054 | std::swap(SelectArg, ValueOnZero); | ||||
1055 | |||||
1056 | // Skip zero extend/truncate. | ||||
1057 | Value *Count = nullptr; | ||||
1058 | if (!match(SelectArg, m_ZExt(m_Value(Count))) && | ||||
1059 | !match(SelectArg, m_Trunc(m_Value(Count)))) | ||||
1060 | Count = SelectArg; | ||||
1061 | |||||
1062 | // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the | ||||
1063 | // input to the cttz/ctlz is used as LHS for the compare instruction. | ||||
1064 | Value *X; | ||||
1065 | if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Value(X))) && | ||||
1066 | !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Value(X)))) | ||||
1067 | return nullptr; | ||||
1068 | |||||
1069 | // (X == 0) ? BitWidth : ctz(X) | ||||
1070 | // (X == -1) ? BitWidth : ctz(~X) | ||||
1071 | if ((X != CmpLHS || !match(CmpRHS, m_Zero())) && | ||||
1072 | (!match(X, m_Not(m_Specific(CmpLHS))) || !match(CmpRHS, m_AllOnes()))) | ||||
1073 | return nullptr; | ||||
1074 | |||||
1075 | IntrinsicInst *II = cast<IntrinsicInst>(Count); | ||||
1076 | |||||
1077 | // Check if the value propagated on zero is a constant number equal to the | ||||
1078 | // sizeof in bits of 'Count'. | ||||
1079 | unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits(); | ||||
1080 | if (match(ValueOnZero, m_SpecificInt(SizeOfInBits))) { | ||||
1081 | // Explicitly clear the 'is_zero_poison' flag. It's always valid to go from | ||||
1082 | // true to false on this flag, so we can replace it for all users. | ||||
1083 | II->setArgOperand(1, ConstantInt::getFalse(II->getContext())); | ||||
1084 | return SelectArg; | ||||
1085 | } | ||||
1086 | |||||
1087 | // The ValueOnZero is not the bitwidth. But if the cttz/ctlz (and optional | ||||
1088 | // zext/trunc) have one use (ending at the select), the cttz/ctlz result will | ||||
1089 | // not be used if the input is zero. Relax to 'zero is poison' for that case. | ||||
1090 | if (II->hasOneUse() && SelectArg->hasOneUse() && | ||||
1091 | !match(II->getArgOperand(1), m_One())) | ||||
1092 | II->setArgOperand(1, ConstantInt::getTrue(II->getContext())); | ||||
1093 | |||||
1094 | return nullptr; | ||||
1095 | } | ||||
1096 | |||||
1097 | /// Return true if we find and adjust an icmp+select pattern where the compare | ||||
1098 | /// is with a constant that can be incremented or decremented to match the | ||||
1099 | /// minimum or maximum idiom. | ||||
1100 | static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { | ||||
1101 | ICmpInst::Predicate Pred = Cmp.getPredicate(); | ||||
1102 | Value *CmpLHS = Cmp.getOperand(0); | ||||
1103 | Value *CmpRHS = Cmp.getOperand(1); | ||||
1104 | Value *TrueVal = Sel.getTrueValue(); | ||||
1105 | Value *FalseVal = Sel.getFalseValue(); | ||||
1106 | |||||
1107 | // We may move or edit the compare, so make sure the select is the only user. | ||||
1108 | const APInt *CmpC; | ||||
1109 | if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC))) | ||||
1110 | return false; | ||||
1111 | |||||
1112 | // These transforms only work for selects of integers or vector selects of | ||||
1113 | // integer vectors. | ||||
1114 | Type *SelTy = Sel.getType(); | ||||
1115 | auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType()); | ||||
1116 | if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy()) | ||||
1117 | return false; | ||||
1118 | |||||
1119 | Constant *AdjustedRHS; | ||||
1120 | if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) | ||||
1121 | AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1); | ||||
1122 | else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) | ||||
1123 | AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1); | ||||
1124 | else | ||||
1125 | return false; | ||||
1126 | |||||
1127 | // X > C ? X : C+1 --> X < C+1 ? C+1 : X | ||||
1128 | // X < C ? X : C-1 --> X > C-1 ? C-1 : X | ||||
1129 | if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || | ||||
1130 | (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { | ||||
1131 | ; // Nothing to do here. Values match without any sign/zero extension. | ||||
1132 | } | ||||
1133 | // Types do not match. Instead of calculating this with mixed types, promote | ||||
1134 | // all to the larger type. This enables scalar evolution to analyze this | ||||
1135 | // expression. | ||||
1136 | else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) { | ||||
1137 | Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy); | ||||
1138 | |||||
1139 | // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X | ||||
1140 | // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X | ||||
1141 | // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X | ||||
1142 | // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X | ||||
1143 | if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) { | ||||
1144 | CmpLHS = TrueVal; | ||||
1145 | AdjustedRHS = SextRHS; | ||||
1146 | } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && | ||||
1147 | SextRHS == TrueVal) { | ||||
1148 | CmpLHS = FalseVal; | ||||
1149 | AdjustedRHS = SextRHS; | ||||
1150 | } else if (Cmp.isUnsigned()) { | ||||
1151 | Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy); | ||||
1152 | // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X | ||||
1153 | // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X | ||||
1154 | // zext + signed compare cannot be changed: | ||||
1155 | // 0xff <s 0x00, but 0x00ff >s 0x0000 | ||||
1156 | if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) { | ||||
1157 | CmpLHS = TrueVal; | ||||
1158 | AdjustedRHS = ZextRHS; | ||||
1159 | } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && | ||||
1160 | ZextRHS == TrueVal) { | ||||
1161 | CmpLHS = FalseVal; | ||||
1162 | AdjustedRHS = ZextRHS; | ||||
1163 | } else { | ||||
1164 | return false; | ||||
1165 | } | ||||
1166 | } else { | ||||
1167 | return false; | ||||
1168 | } | ||||
1169 | } else { | ||||
1170 | return false; | ||||
1171 | } | ||||
1172 | |||||
1173 | Pred = ICmpInst::getSwappedPredicate(Pred); | ||||
1174 | CmpRHS = AdjustedRHS; | ||||
1175 | std::swap(FalseVal, TrueVal); | ||||
1176 | Cmp.setPredicate(Pred); | ||||
1177 | Cmp.setOperand(0, CmpLHS); | ||||
1178 | Cmp.setOperand(1, CmpRHS); | ||||
1179 | Sel.setOperand(1, TrueVal); | ||||
1180 | Sel.setOperand(2, FalseVal); | ||||
1181 | Sel.swapProfMetadata(); | ||||
1182 | |||||
1183 | // Move the compare instruction right before the select instruction. Otherwise | ||||
1184 | // the sext/zext value may be defined after the compare instruction uses it. | ||||
1185 | Cmp.moveBefore(&Sel); | ||||
1186 | |||||
1187 | return true; | ||||
1188 | } | ||||
1189 | |||||
1190 | static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, | ||||
1191 | InstCombinerImpl &IC) { | ||||
1192 | Value *LHS, *RHS; | ||||
1193 | // TODO: What to do with pointer min/max patterns? | ||||
1194 | if (!Sel.getType()->isIntOrIntVectorTy()) | ||||
1195 | return nullptr; | ||||
1196 | |||||
1197 | SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor; | ||||
1198 | if (SPF == SelectPatternFlavor::SPF_ABS || | ||||
1199 | SPF == SelectPatternFlavor::SPF_NABS) { | ||||
1200 | if (!Cmp.hasOneUse() && !RHS->hasOneUse()) | ||||
1201 | return nullptr; // TODO: Relax this restriction. | ||||
1202 | |||||
1203 | // Note that NSW flag can only be propagated for normal, non-negated abs! | ||||
1204 | bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS && | ||||
1205 | match(RHS, m_NSWNeg(m_Specific(LHS))); | ||||
1206 | Constant *IntMinIsPoisonC = | ||||
1207 | ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison); | ||||
1208 | Instruction *Abs = | ||||
1209 | IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC); | ||||
1210 | |||||
1211 | if (SPF == SelectPatternFlavor::SPF_NABS) | ||||
1212 | return BinaryOperator::CreateNeg(Abs); // Always without NSW flag! | ||||
1213 | return IC.replaceInstUsesWith(Sel, Abs); | ||||
1214 | } | ||||
1215 | |||||
1216 | if (SelectPatternResult::isMinOrMax(SPF)) { | ||||
1217 | Intrinsic::ID IntrinsicID; | ||||
1218 | switch (SPF) { | ||||
1219 | case SelectPatternFlavor::SPF_UMIN: | ||||
1220 | IntrinsicID = Intrinsic::umin; | ||||
1221 | break; | ||||
1222 | case SelectPatternFlavor::SPF_UMAX: | ||||
1223 | IntrinsicID = Intrinsic::umax; | ||||
1224 | break; | ||||
1225 | case SelectPatternFlavor::SPF_SMIN: | ||||
1226 | IntrinsicID = Intrinsic::smin; | ||||
1227 | break; | ||||
1228 | case SelectPatternFlavor::SPF_SMAX: | ||||
1229 | IntrinsicID = Intrinsic::smax; | ||||
1230 | break; | ||||
1231 | default: | ||||
1232 | llvm_unreachable("Unexpected SPF")::llvm::llvm_unreachable_internal("Unexpected SPF", "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp" , 1232); | ||||
1233 | } | ||||
1234 | return IC.replaceInstUsesWith( | ||||
1235 | Sel, IC.Builder.CreateBinaryIntrinsic(IntrinsicID, LHS, RHS)); | ||||
1236 | } | ||||
1237 | |||||
1238 | return nullptr; | ||||
1239 | } | ||||
1240 | |||||
1241 | static bool replaceInInstruction(Value *V, Value *Old, Value *New, | ||||
1242 | InstCombiner &IC, unsigned Depth = 0) { | ||||
1243 | // Conservatively limit replacement to two instructions upwards. | ||||
1244 | if (Depth == 2) | ||||
1245 | return false; | ||||
1246 | |||||
1247 | auto *I = dyn_cast<Instruction>(V); | ||||
1248 | if (!I || !I->hasOneUse() || !isSafeToSpeculativelyExecute(I)) | ||||
1249 | return false; | ||||
1250 | |||||
1251 | bool Changed = false; | ||||
1252 | for (Use &U : I->operands()) { | ||||
1253 | if (U == Old) { | ||||
1254 | IC.replaceUse(U, New); | ||||
1255 | Changed = true; | ||||
1256 | } else { | ||||
1257 | Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1); | ||||
1258 | } | ||||
1259 | } | ||||
1260 | return Changed; | ||||
1261 | } | ||||
1262 | |||||
1263 | /// If we have a select with an equality comparison, then we know the value in | ||||
1264 | /// one of the arms of the select. See if substituting this value into an arm | ||||
1265 | /// and simplifying the result yields the same value as the other arm. | ||||
1266 | /// | ||||
1267 | /// To make this transform safe, we must drop poison-generating flags | ||||
1268 | /// (nsw, etc) if we simplified to a binop because the select may be guarding | ||||
1269 | /// that poison from propagating. If the existing binop already had no | ||||
1270 | /// poison-generating flags, then this transform can be done by instsimplify. | ||||
1271 | /// | ||||
1272 | /// Consider: | ||||
1273 | /// %cmp = icmp eq i32 %x, 2147483647 | ||||
1274 | /// %add = add nsw i32 %x, 1 | ||||
1275 | /// %sel = select i1 %cmp, i32 -2147483648, i32 %add | ||||
1276 | /// | ||||
1277 | /// We can't replace %sel with %add unless we strip away the flags. | ||||
1278 | /// TODO: Wrapping flags could be preserved in some cases with better analysis. | ||||
1279 | Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, | ||||
1280 | ICmpInst &Cmp) { | ||||
1281 | if (!Cmp.isEquality()) | ||||
1282 | return nullptr; | ||||
1283 | |||||
1284 | // Canonicalize the pattern to ICMP_EQ by swapping the select operands. | ||||
1285 | Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); | ||||
1286 | bool Swapped = false; | ||||
1287 | if (Cmp.getPredicate() == ICmpInst::ICMP_NE) { | ||||
1288 | std::swap(TrueVal, FalseVal); | ||||
1289 | Swapped = true; | ||||
1290 | } | ||||
1291 | |||||
1292 | // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. | ||||
1293 | // Make sure Y cannot be undef though, as we might pick different values for | ||||
1294 | // undef in the icmp and in f(Y). Additionally, take care to avoid replacing | ||||
1295 | // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite | ||||
1296 | // replacement cycle. | ||||
1297 | Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); | ||||
1298 | if (TrueVal != CmpLHS && | ||||
1299 | isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) { | ||||
1300 | if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, | ||||
1301 | /* AllowRefinement */ true)) | ||||
1302 | return replaceOperand(Sel, Swapped ? 2 : 1, V); | ||||
1303 | |||||
1304 | // Even if TrueVal does not simplify, we can directly replace a use of | ||||
1305 | // CmpLHS with CmpRHS, as long as the instruction is not used anywhere | ||||
1306 | // else and is safe to speculatively execute (we may end up executing it | ||||
1307 | // with different operands, which should not cause side-effects or trigger | ||||
1308 | // undefined behavior). Only do this if CmpRHS is a constant, as | ||||
1309 | // profitability is not clear for other cases. | ||||
1310 | // FIXME: Support vectors. | ||||
1311 | if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) && | ||||
1312 | !Cmp.getType()->isVectorTy()) | ||||
1313 | if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this)) | ||||
1314 | return &Sel; | ||||
1315 | } | ||||
1316 | if (TrueVal != CmpRHS && | ||||
1317 | isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) | ||||
1318 | if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, | ||||
1319 | /* AllowRefinement */ true)) | ||||
1320 | return replaceOperand(Sel, Swapped ? 2 : 1, V); | ||||
1321 | |||||
1322 | auto *FalseInst = dyn_cast<Instruction>(FalseVal); | ||||
1323 | if (!FalseInst) | ||||
1324 | return nullptr; | ||||
1325 | |||||
1326 | // InstSimplify already performed this fold if it was possible subject to | ||||
1327 | // current poison-generating flags. Try the transform again with | ||||
1328 | // poison-generating flags temporarily dropped. | ||||
1329 | bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false; | ||||
1330 | if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) { | ||||
1331 | WasNUW = OBO->hasNoUnsignedWrap(); | ||||
1332 | WasNSW = OBO->hasNoSignedWrap(); | ||||
1333 | FalseInst->setHasNoUnsignedWrap(false); | ||||
1334 | FalseInst->setHasNoSignedWrap(false); | ||||
1335 | } | ||||
1336 | if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) { | ||||
1337 | WasExact = PEO->isExact(); | ||||
1338 | FalseInst->setIsExact(false); | ||||
1339 | } | ||||
1340 | if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) { | ||||
1341 | WasInBounds = GEP->isInBounds(); | ||||
1342 | GEP->setIsInBounds(false); | ||||
1343 | } | ||||
1344 | |||||
1345 | // Try each equivalence substitution possibility. | ||||
1346 | // We have an 'EQ' comparison, so the select's false value will propagate. | ||||
1347 | // Example: | ||||
1348 | // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 | ||||
1349 | if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, | ||||
1350 | /* AllowRefinement */ false) == TrueVal || | ||||
1351 | simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, | ||||
1352 | /* AllowRefinement */ false) == TrueVal) { | ||||
1353 | return replaceInstUsesWith(Sel, FalseVal); | ||||
1354 | } | ||||
1355 | |||||
1356 | // Restore poison-generating flags if the transform did not apply. | ||||
1357 | if (WasNUW) | ||||
1358 | FalseInst->setHasNoUnsignedWrap(); | ||||
1359 | if (WasNSW) | ||||
1360 | FalseInst->setHasNoSignedWrap(); | ||||
1361 | if (WasExact) | ||||
1362 | FalseInst->setIsExact(); | ||||
1363 | if (WasInBounds) | ||||
1364 | cast<GetElementPtrInst>(FalseInst)->setIsInBounds(); | ||||
1365 | |||||
1366 | return nullptr; | ||||
1367 | } | ||||
1368 | |||||
1369 | // See if this is a pattern like: | ||||
1370 | // %old_cmp1 = icmp slt i32 %x, C2 | ||||
1371 | // %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high | ||||
1372 | // %old_x_offseted = add i32 %x, C1 | ||||
1373 | // %old_cmp0 = icmp ult i32 %old_x_offseted, C0 | ||||
1374 | // %r = select i1 %old_cmp0, i32 %x, i32 %old_replacement | ||||
1375 | // This can be rewritten as more canonical pattern: | ||||
1376 | // %new_cmp1 = icmp slt i32 %x, -C1 | ||||
1377 | // %new_cmp2 = icmp sge i32 %x, C0-C1 | ||||
1378 | // %new_clamped_low = select i1 %new_cmp1, i32 %target_low, i32 %x | ||||
1379 | // %r = select i1 %new_cmp2, i32 %target_high, i32 %new_clamped_low | ||||
1380 | // Iff -C1 s<= C2 s<= C0-C1 | ||||
1381 | // Also ULT predicate can also be UGT iff C0 != -1 (+invert result) | ||||
1382 | // SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.) | ||||
1383 | static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, | ||||
1384 | InstCombiner::BuilderTy &Builder) { | ||||
1385 | Value *X = Sel0.getTrueValue(); | ||||
1386 | Value *Sel1 = Sel0.getFalseValue(); | ||||
1387 | |||||
1388 | // First match the condition of the outermost select. | ||||
1389 | // Said condition must be one-use. | ||||
1390 | if (!Cmp0.hasOneUse()) | ||||
1391 | return nullptr; | ||||
1392 | ICmpInst::Predicate Pred0 = Cmp0.getPredicate(); | ||||
1393 | Value *Cmp00 = Cmp0.getOperand(0); | ||||
1394 | Constant *C0; | ||||
1395 | if (!match(Cmp0.getOperand(1), | ||||
1396 | m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0)))) | ||||
1397 | return nullptr; | ||||
1398 | |||||
1399 | if (!isa<SelectInst>(Sel1)) { | ||||
1400 | Pred0 = ICmpInst::getInversePredicate(Pred0); | ||||
1401 | std::swap(X, Sel1); | ||||
1402 | } | ||||
1403 | |||||
1404 | // Canonicalize Cmp0 into ult or uge. | ||||
1405 | // FIXME: we shouldn't care about lanes that are 'undef' in the end? | ||||
1406 | switch (Pred0) { | ||||
1407 | case ICmpInst::Predicate::ICMP_ULT: | ||||
1408 | case ICmpInst::Predicate::ICMP_UGE: | ||||
1409 | // Although icmp ult %x, 0 is an unusual thing to try and should generally | ||||
1410 | // have been simplified, it does not verify with undef inputs so ensure we | ||||
1411 | // are not in a strange state. | ||||
1412 | if (!match(C0, m_SpecificInt_ICMP( | ||||
1413 | ICmpInst::Predicate::ICMP_NE, | ||||
1414 | APInt::getZero(C0->getType()->getScalarSizeInBits())))) | ||||
1415 | return nullptr; | ||||
1416 | break; // Great! | ||||
1417 | case ICmpInst::Predicate::ICMP_ULE: | ||||
1418 | case ICmpInst::Predicate::ICMP_UGT: | ||||
1419 | // We want to canonicalize it to 'ult' or 'uge', so we'll need to increment | ||||
1420 | // C0, which again means it must not have any all-ones elements. | ||||
1421 | if (!match(C0, | ||||
1422 | m_SpecificInt_ICMP( | ||||
1423 | ICmpInst::Predicate::ICMP_NE, | ||||
1424 | APInt::getAllOnes(C0->getType()->getScalarSizeInBits())))) | ||||
1425 | return nullptr; // Can't do, have all-ones element[s]. | ||||
1426 | Pred0 = ICmpInst::getFlippedStrictnessPredicate(Pred0); | ||||
1427 | C0 = InstCombiner::AddOne(C0); | ||||
1428 | break; | ||||
1429 | default: | ||||
1430 | return nullptr; // Unknown predicate. | ||||
1431 | } | ||||
1432 | |||||
1433 | // Now that we've canonicalized the ICmp, we know the X we expect; | ||||
1434 | // the select in other hand should be one-use. | ||||
1435 | if (!Sel1->hasOneUse()) | ||||
1436 | return nullptr; | ||||
1437 | |||||
1438 | // If the types do not match, look through any truncs to the underlying | ||||
1439 | // instruction. | ||||
1440 | if (Cmp00->getType() != X->getType() && X->hasOneUse()) | ||||
1441 | match(X, m_TruncOrSelf(m_Value(X))); | ||||
1442 | |||||
1443 | // We now can finish matching the condition of the outermost select: | ||||
1444 | // it should either be the X itself, or an addition of some constant to X. | ||||
1445 | Constant *C1; | ||||
1446 | if (Cmp00 == X) | ||||
1447 | C1 = ConstantInt::getNullValue(X->getType()); | ||||
1448 | else if (!match(Cmp00, | ||||
1449 | m_Add(m_Specific(X), | ||||
1450 | m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C1))))) | ||||
1451 | return nullptr; | ||||
1452 | |||||
1453 | Value *Cmp1; | ||||
1454 | ICmpInst::Predicate Pred1; | ||||
1455 | Constant *C2; | ||||
1456 | Value *ReplacementLow, *ReplacementHigh; | ||||
1457 | if (!match(Sel1, m_Select(m_Value(Cmp1), m_Value(ReplacementLow), | ||||
1458 | m_Value(ReplacementHigh))) || | ||||
1459 | !match(Cmp1, | ||||
1460 | m_ICmp(Pred1, m_Specific(X), | ||||
1461 | m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C2))))) | ||||
1462 | return nullptr; | ||||
1463 | |||||
1464 | if (!Cmp1->hasOneUse() && (Cmp00 == X || !Cmp00->hasOneUse())) | ||||
1465 | return nullptr; // Not enough one-use instructions for the fold. | ||||
1466 | // FIXME: this restriction could be relaxed if Cmp1 can be reused as one of | ||||
1467 | // two comparisons we'll need to build. | ||||
1468 | |||||
1469 | // Canonicalize Cmp1 into the form we expect. | ||||
1470 | // FIXME: we shouldn't care about lanes that are 'undef' in the end? | ||||
1471 | switch (Pred1) { | ||||
1472 | case ICmpInst::Predicate::ICMP_SLT: | ||||
1473 | break; | ||||
1474 | case ICmpInst::Predicate::ICMP_SLE: | ||||
1475 | // We'd have to increment C2 by one, and for that it must not have signed | ||||
1476 | // max element, but then it would have been canonicalized to 'slt' before | ||||
1477 | // we get here. So we can't do anything useful with 'sle'. | ||||
1478 | return nullptr; | ||||
1479 | case ICmpInst::Predicate::ICMP_SGT: | ||||
1480 | // We want to canonicalize it to 'slt', so we'll need to increment C2, | ||||
1481 | // which again means it must not have any signed max elements. | ||||
1482 | if (!match(C2, | ||||
1483 | m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE, | ||||
1484 | APInt::getSignedMaxValue( | ||||
1485 | C2->getType()->getScalarSizeInBits())))) | ||||
1486 | return nullptr; // Can't do, have signed max element[s]. | ||||
1487 | C2 = InstCombiner::AddOne(C2); | ||||
1488 | [[fallthrough]]; | ||||
1489 | case ICmpInst::Predicate::ICMP_SGE: | ||||
1490 | // Also non-canonical, but here we don't need to change C2, | ||||
1491 | // so we don't have any restrictions on C2, so we can just handle it. | ||||
1492 | Pred1 = ICmpInst::Predicate::ICMP_SLT; | ||||
1493 | std::swap(ReplacementLow, ReplacementHigh); | ||||
1494 | break; | ||||
1495 | default: | ||||
1496 | return nullptr; // Unknown predicate. | ||||
1497 | } | ||||
1498 | assert(Pred1 == ICmpInst::Predicate::ICMP_SLT &&(static_cast <bool> (Pred1 == ICmpInst::Predicate::ICMP_SLT && "Unexpected predicate type.") ? void (0) : __assert_fail ("Pred1 == ICmpInst::Predicate::ICMP_SLT && \"Unexpected predicate type.\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 1499 , __extension__ __PRETTY_FUNCTION__)) | ||||
1499 | "Unexpected predicate type.")(static_cast <bool> (Pred1 == ICmpInst::Predicate::ICMP_SLT && "Unexpected predicate type.") ? void (0) : __assert_fail ("Pred1 == ICmpInst::Predicate::ICMP_SLT && \"Unexpected predicate type.\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 1499 , __extension__ __PRETTY_FUNCTION__)); | ||||
1500 | |||||
1501 | // The thresholds of this clamp-like pattern. | ||||
1502 | auto *ThresholdLowIncl = ConstantExpr::getNeg(C1); | ||||
1503 | auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1); | ||||
1504 | |||||
1505 | assert((Pred0 == ICmpInst::Predicate::ICMP_ULT ||(static_cast <bool> ((Pred0 == ICmpInst::Predicate::ICMP_ULT || Pred0 == ICmpInst::Predicate::ICMP_UGE) && "Unexpected predicate type." ) ? void (0) : __assert_fail ("(Pred0 == ICmpInst::Predicate::ICMP_ULT || Pred0 == ICmpInst::Predicate::ICMP_UGE) && \"Unexpected predicate type.\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 1507 , __extension__ __PRETTY_FUNCTION__)) | ||||
1506 | Pred0 == ICmpInst::Predicate::ICMP_UGE) &&(static_cast <bool> ((Pred0 == ICmpInst::Predicate::ICMP_ULT || Pred0 == ICmpInst::Predicate::ICMP_UGE) && "Unexpected predicate type." ) ? void (0) : __assert_fail ("(Pred0 == ICmpInst::Predicate::ICMP_ULT || Pred0 == ICmpInst::Predicate::ICMP_UGE) && \"Unexpected predicate type.\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 1507 , __extension__ __PRETTY_FUNCTION__)) | ||||
1507 | "Unexpected predicate type.")(static_cast <bool> ((Pred0 == ICmpInst::Predicate::ICMP_ULT || Pred0 == ICmpInst::Predicate::ICMP_UGE) && "Unexpected predicate type." ) ? void (0) : __assert_fail ("(Pred0 == ICmpInst::Predicate::ICMP_ULT || Pred0 == ICmpInst::Predicate::ICMP_UGE) && \"Unexpected predicate type.\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 1507 , __extension__ __PRETTY_FUNCTION__)); | ||||
1508 | if (Pred0 == ICmpInst::Predicate::ICMP_UGE) | ||||
1509 | std::swap(ThresholdLowIncl, ThresholdHighExcl); | ||||
1510 | |||||
1511 | // The fold has a precondition 1: C2 s>= ThresholdLow | ||||
1512 | auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2, | ||||
1513 | ThresholdLowIncl); | ||||
1514 | if (!match(Precond1, m_One())) | ||||
1515 | return nullptr; | ||||
1516 | // The fold has a precondition 2: C2 s<= ThresholdHigh | ||||
1517 | auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2, | ||||
1518 | ThresholdHighExcl); | ||||
1519 | if (!match(Precond2, m_One())) | ||||
1520 | return nullptr; | ||||
1521 | |||||
1522 | // If we are matching from a truncated input, we need to sext the | ||||
1523 | // ReplacementLow and ReplacementHigh values. Only do the transform if they | ||||
1524 | // are free to extend due to being constants. | ||||
1525 | if (X->getType() != Sel0.getType()) { | ||||
1526 | Constant *LowC, *HighC; | ||||
1527 | if (!match(ReplacementLow, m_ImmConstant(LowC)) || | ||||
1528 | !match(ReplacementHigh, m_ImmConstant(HighC))) | ||||
1529 | return nullptr; | ||||
1530 | ReplacementLow = ConstantExpr::getSExt(LowC, X->getType()); | ||||
1531 | ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType()); | ||||
1532 | } | ||||
1533 | |||||
1534 | // All good, finally emit the new pattern. | ||||
1535 | Value *ShouldReplaceLow = Builder.CreateICmpSLT(X, ThresholdLowIncl); | ||||
1536 | Value *ShouldReplaceHigh = Builder.CreateICmpSGE(X, ThresholdHighExcl); | ||||
1537 | Value *MaybeReplacedLow = | ||||
1538 | Builder.CreateSelect(ShouldReplaceLow, ReplacementLow, X); | ||||
1539 | |||||
1540 | // Create the final select. If we looked through a truncate above, we will | ||||
1541 | // need to retruncate the result. | ||||
1542 | Value *MaybeReplacedHigh = Builder.CreateSelect( | ||||
1543 | ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow); | ||||
1544 | return Builder.CreateTrunc(MaybeReplacedHigh, Sel0.getType()); | ||||
1545 | } | ||||
1546 | |||||
1547 | // If we have | ||||
1548 | // %cmp = icmp [canonical predicate] i32 %x, C0 | ||||
1549 | // %r = select i1 %cmp, i32 %y, i32 C1 | ||||
1550 | // Where C0 != C1 and %x may be different from %y, see if the constant that we | ||||
1551 | // will have if we flip the strictness of the predicate (i.e. without changing | ||||
1552 | // the result) is identical to the C1 in select. If it matches we can change | ||||
1553 | // original comparison to one with swapped predicate, reuse the constant, | ||||
1554 | // and swap the hands of select. | ||||
1555 | static Instruction * | ||||
1556 | tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, | ||||
1557 | InstCombinerImpl &IC) { | ||||
1558 | ICmpInst::Predicate Pred; | ||||
1559 | Value *X; | ||||
1560 | Constant *C0; | ||||
1561 | if (!match(&Cmp, m_OneUse(m_ICmp( | ||||
1562 | Pred, m_Value(X), | ||||
1563 | m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0)))))) | ||||
1564 | return nullptr; | ||||
1565 | |||||
1566 | // If comparison predicate is non-relational, we won't be able to do anything. | ||||
1567 | if (ICmpInst::isEquality(Pred)) | ||||
1568 | return nullptr; | ||||
1569 | |||||
1570 | // If comparison predicate is non-canonical, then we certainly won't be able | ||||
1571 | // to make it canonical; canonicalizeCmpWithConstant() already tried. | ||||
1572 | if (!InstCombiner::isCanonicalPredicate(Pred)) | ||||
1573 | return nullptr; | ||||
1574 | |||||
1575 | // If the [input] type of comparison and select type are different, lets abort | ||||
1576 | // for now. We could try to compare constants with trunc/[zs]ext though. | ||||
1577 | if (C0->getType() != Sel.getType()) | ||||
1578 | return nullptr; | ||||
1579 | |||||
1580 | // ULT with 'add' of a constant is canonical. See foldICmpAddConstant(). | ||||
1581 | // FIXME: Are there more magic icmp predicate+constant pairs we must avoid? | ||||
1582 | // Or should we just abandon this transform entirely? | ||||
1583 | if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant()))) | ||||
1584 | return nullptr; | ||||
1585 | |||||
1586 | |||||
1587 | Value *SelVal0, *SelVal1; // We do not care which one is from where. | ||||
1588 | match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1))); | ||||
1589 | // At least one of these values we are selecting between must be a constant | ||||
1590 | // else we'll never succeed. | ||||
1591 | if (!match(SelVal0, m_AnyIntegralConstant()) && | ||||
1592 | !match(SelVal1, m_AnyIntegralConstant())) | ||||
1593 | return nullptr; | ||||
1594 | |||||
1595 | // Does this constant C match any of the `select` values? | ||||
1596 | auto MatchesSelectValue = [SelVal0, SelVal1](Constant *C) { | ||||
1597 | return C->isElementWiseEqual(SelVal0) || C->isElementWiseEqual(SelVal1); | ||||
1598 | }; | ||||
1599 | |||||
1600 | // If C0 *already* matches true/false value of select, we are done. | ||||
1601 | if (MatchesSelectValue(C0)) | ||||
1602 | return nullptr; | ||||
1603 | |||||
1604 | // Check the constant we'd have with flipped-strictness predicate. | ||||
1605 | auto FlippedStrictness = | ||||
1606 | InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0); | ||||
1607 | if (!FlippedStrictness) | ||||
1608 | return nullptr; | ||||
1609 | |||||
1610 | // If said constant doesn't match either, then there is no hope, | ||||
1611 | if (!MatchesSelectValue(FlippedStrictness->second)) | ||||
1612 | return nullptr; | ||||
1613 | |||||
1614 | // It matched! Lets insert the new comparison just before select. | ||||
1615 | InstCombiner::BuilderTy::InsertPointGuard Guard(IC.Builder); | ||||
1616 | IC.Builder.SetInsertPoint(&Sel); | ||||
1617 | |||||
1618 | Pred = ICmpInst::getSwappedPredicate(Pred); // Yes, swapped. | ||||
1619 | Value *NewCmp = IC.Builder.CreateICmp(Pred, X, FlippedStrictness->second, | ||||
1620 | Cmp.getName() + ".inv"); | ||||
1621 | IC.replaceOperand(Sel, 0, NewCmp); | ||||
1622 | Sel.swapValues(); | ||||
1623 | Sel.swapProfMetadata(); | ||||
1624 | |||||
1625 | return &Sel; | ||||
1626 | } | ||||
1627 | |||||
1628 | static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal, | ||||
1629 | Value *FVal, | ||||
1630 | InstCombiner::BuilderTy &Builder) { | ||||
1631 | if (!Cmp->hasOneUse()) | ||||
1632 | return nullptr; | ||||
1633 | |||||
1634 | const APInt *CmpC; | ||||
1635 | if (!match(Cmp->getOperand(1), m_APIntAllowUndef(CmpC))) | ||||
1636 | return nullptr; | ||||
1637 | |||||
1638 | // (X u< 2) ? -X : -1 --> sext (X != 0) | ||||
1639 | Value *X = Cmp->getOperand(0); | ||||
1640 | if (Cmp->getPredicate() == ICmpInst::ICMP_ULT && *CmpC == 2 && | ||||
1641 | match(TVal, m_Neg(m_Specific(X))) && match(FVal, m_AllOnes())) | ||||
1642 | return new SExtInst(Builder.CreateIsNotNull(X), TVal->getType()); | ||||
1643 | |||||
1644 | // (X u> 1) ? -1 : -X --> sext (X != 0) | ||||
1645 | if (Cmp->getPredicate() == ICmpInst::ICMP_UGT && *CmpC == 1 && | ||||
1646 | match(FVal, m_Neg(m_Specific(X))) && match(TVal, m_AllOnes())) | ||||
1647 | return new SExtInst(Builder.CreateIsNotNull(X), TVal->getType()); | ||||
1648 | |||||
1649 | return nullptr; | ||||
1650 | } | ||||
1651 | |||||
1652 | static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI, | ||||
1653 | InstCombiner::BuilderTy &Builder) { | ||||
1654 | const APInt *CmpC; | ||||
1655 | Value *V; | ||||
1656 | CmpInst::Predicate Pred; | ||||
1657 | if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC)))) | ||||
1658 | return nullptr; | ||||
1659 | |||||
1660 | // Match clamp away from min/max value as a max/min operation. | ||||
1661 | Value *TVal = SI.getTrueValue(); | ||||
1662 | Value *FVal = SI.getFalseValue(); | ||||
1663 | if (Pred == ICmpInst::ICMP_EQ && V == FVal) { | ||||
1664 | // (V == UMIN) ? UMIN+1 : V --> umax(V, UMIN+1) | ||||
1665 | if (CmpC->isMinValue() && match(TVal, m_SpecificInt(*CmpC + 1))) | ||||
1666 | return Builder.CreateBinaryIntrinsic(Intrinsic::umax, V, TVal); | ||||
1667 | // (V == UMAX) ? UMAX-1 : V --> umin(V, UMAX-1) | ||||
1668 | if (CmpC->isMaxValue() && match(TVal, m_SpecificInt(*CmpC - 1))) | ||||
1669 | return Builder.CreateBinaryIntrinsic(Intrinsic::umin, V, TVal); | ||||
1670 | // (V == SMIN) ? SMIN+1 : V --> smax(V, SMIN+1) | ||||
1671 | if (CmpC->isMinSignedValue() && match(TVal, m_SpecificInt(*CmpC + 1))) | ||||
1672 | return Builder.CreateBinaryIntrinsic(Intrinsic::smax, V, TVal); | ||||
1673 | // (V == SMAX) ? SMAX-1 : V --> smin(V, SMAX-1) | ||||
1674 | if (CmpC->isMaxSignedValue() && match(TVal, m_SpecificInt(*CmpC - 1))) | ||||
1675 | return Builder.CreateBinaryIntrinsic(Intrinsic::smin, V, TVal); | ||||
1676 | } | ||||
1677 | |||||
1678 | BinaryOperator *BO; | ||||
1679 | const APInt *C; | ||||
1680 | CmpInst::Predicate CPred; | ||||
1681 | if (match(&SI, m_Select(m_Specific(ICI), m_APInt(C), m_BinOp(BO)))) | ||||
1682 | CPred = ICI->getPredicate(); | ||||
1683 | else if (match(&SI, m_Select(m_Specific(ICI), m_BinOp(BO), m_APInt(C)))) | ||||
1684 | CPred = ICI->getInversePredicate(); | ||||
1685 | else | ||||
1686 | return nullptr; | ||||
1687 | |||||
1688 | const APInt *BinOpC; | ||||
1689 | if (!match(BO, m_BinOp(m_Specific(V), m_APInt(BinOpC)))) | ||||
1690 | return nullptr; | ||||
1691 | |||||
1692 | ConstantRange R = ConstantRange::makeExactICmpRegion(CPred, *CmpC) | ||||
1693 | .binaryOp(BO->getOpcode(), *BinOpC); | ||||
1694 | if (R == *C) { | ||||
1695 | BO->dropPoisonGeneratingFlags(); | ||||
1696 | return BO; | ||||
1697 | } | ||||
1698 | return nullptr; | ||||
1699 | } | ||||
1700 | |||||
1701 | /// Visit a SelectInst that has an ICmpInst as its first operand. | ||||
1702 | Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, | ||||
1703 | ICmpInst *ICI) { | ||||
1704 | if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI)) | ||||
1705 | return NewSel; | ||||
1706 | |||||
1707 | if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this)) | ||||
1708 | return NewSPF; | ||||
1709 | |||||
1710 | if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder)) | ||||
1711 | return replaceInstUsesWith(SI, V); | ||||
1712 | |||||
1713 | if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) | ||||
1714 | return replaceInstUsesWith(SI, V); | ||||
1715 | |||||
1716 | if (Instruction *NewSel = | ||||
1717 | tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) | ||||
1718 | return NewSel; | ||||
1719 | |||||
1720 | bool Changed = adjustMinMax(SI, *ICI); | ||||
1721 | |||||
1722 | if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) | ||||
1723 | return replaceInstUsesWith(SI, V); | ||||
1724 | |||||
1725 | // NOTE: if we wanted to, this is where to detect integer MIN/MAX | ||||
1726 | Value *TrueVal = SI.getTrueValue(); | ||||
1727 | Value *FalseVal = SI.getFalseValue(); | ||||
1728 | ICmpInst::Predicate Pred = ICI->getPredicate(); | ||||
1729 | Value *CmpLHS = ICI->getOperand(0); | ||||
1730 | Value *CmpRHS = ICI->getOperand(1); | ||||
1731 | if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS) && !isa<Constant>(CmpLHS)) { | ||||
1732 | if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) { | ||||
1733 | // Transform (X == C) ? X : Y -> (X == C) ? C : Y | ||||
1734 | SI.setOperand(1, CmpRHS); | ||||
1735 | Changed = true; | ||||
1736 | } else if (CmpLHS == FalseVal && Pred == ICmpInst::ICMP_NE) { | ||||
1737 | // Transform (X != C) ? Y : X -> (X != C) ? Y : C | ||||
1738 | SI.setOperand(2, CmpRHS); | ||||
1739 | Changed = true; | ||||
1740 | } | ||||
1741 | } | ||||
1742 | |||||
1743 | // Canonicalize a signbit condition to use zero constant by swapping: | ||||
1744 | // (CmpLHS > -1) ? TV : FV --> (CmpLHS < 0) ? FV : TV | ||||
1745 | // To avoid conflicts (infinite loops) with other canonicalizations, this is | ||||
1746 | // not applied with any constant select arm. | ||||
1747 | if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes()) && | ||||
1748 | !match(TrueVal, m_Constant()) && !match(FalseVal, m_Constant()) && | ||||
1749 | ICI->hasOneUse()) { | ||||
1750 | InstCombiner::BuilderTy::InsertPointGuard Guard(Builder); | ||||
1751 | Builder.SetInsertPoint(&SI); | ||||
1752 | Value *IsNeg = Builder.CreateIsNeg(CmpLHS, ICI->getName()); | ||||
1753 | replaceOperand(SI, 0, IsNeg); | ||||
1754 | SI.swapValues(); | ||||
1755 | SI.swapProfMetadata(); | ||||
1756 | return &SI; | ||||
1757 | } | ||||
1758 | |||||
1759 | // FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring | ||||
1760 | // decomposeBitTestICmp() might help. | ||||
1761 | if (TrueVal->getType()->isIntOrIntVectorTy()) { | ||||
1762 | unsigned BitWidth = | ||||
1763 | DL.getTypeSizeInBits(TrueVal->getType()->getScalarType()); | ||||
1764 | APInt MinSignedValue = APInt::getSignedMinValue(BitWidth); | ||||
1765 | Value *X; | ||||
1766 | const APInt *Y, *C; | ||||
1767 | bool TrueWhenUnset; | ||||
1768 | bool IsBitTest = false; | ||||
1769 | if (ICmpInst::isEquality(Pred) && | ||||
1770 | match(CmpLHS, m_And(m_Value(X), m_Power2(Y))) && | ||||
1771 | match(CmpRHS, m_Zero())) { | ||||
1772 | IsBitTest = true; | ||||
1773 | TrueWhenUnset = Pred == ICmpInst::ICMP_EQ; | ||||
1774 | } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { | ||||
1775 | X = CmpLHS; | ||||
1776 | Y = &MinSignedValue; | ||||
1777 | IsBitTest = true; | ||||
1778 | TrueWhenUnset = false; | ||||
1779 | } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { | ||||
1780 | X = CmpLHS; | ||||
1781 | Y = &MinSignedValue; | ||||
1782 | IsBitTest = true; | ||||
1783 | TrueWhenUnset = true; | ||||
1784 | } | ||||
1785 | if (IsBitTest) { | ||||
1786 | Value *V = nullptr; | ||||
1787 | // (X & Y) == 0 ? X : X ^ Y --> X & ~Y | ||||
1788 | if (TrueWhenUnset && TrueVal == X && | ||||
1789 | match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) | ||||
1790 | V = Builder.CreateAnd(X, ~(*Y)); | ||||
1791 | // (X & Y) != 0 ? X ^ Y : X --> X & ~Y | ||||
1792 | else if (!TrueWhenUnset && FalseVal == X && | ||||
1793 | match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) | ||||
1794 | V = Builder.CreateAnd(X, ~(*Y)); | ||||
1795 | // (X & Y) == 0 ? X ^ Y : X --> X | Y | ||||
1796 | else if (TrueWhenUnset && FalseVal == X && | ||||
1797 | match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) | ||||
1798 | V = Builder.CreateOr(X, *Y); | ||||
1799 | // (X & Y) != 0 ? X : X ^ Y --> X | Y | ||||
1800 | else if (!TrueWhenUnset && TrueVal == X && | ||||
1801 | match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) | ||||
1802 | V = Builder.CreateOr(X, *Y); | ||||
1803 | |||||
1804 | if (V) | ||||
1805 | return replaceInstUsesWith(SI, V); | ||||
1806 | } | ||||
1807 | } | ||||
1808 | |||||
1809 | if (Instruction *V = | ||||
1810 | foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) | ||||
1811 | return V; | ||||
1812 | |||||
1813 | if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder)) | ||||
1814 | return V; | ||||
1815 | |||||
1816 | if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder)) | ||||
1817 | return V; | ||||
1818 | |||||
1819 | if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) | ||||
1820 | return replaceInstUsesWith(SI, V); | ||||
1821 | |||||
1822 | if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder)) | ||||
1823 | return replaceInstUsesWith(SI, V); | ||||
1824 | |||||
1825 | if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) | ||||
1826 | return replaceInstUsesWith(SI, V); | ||||
1827 | |||||
1828 | if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder)) | ||||
1829 | return replaceInstUsesWith(SI, V); | ||||
1830 | |||||
1831 | if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder)) | ||||
1832 | return replaceInstUsesWith(SI, V); | ||||
1833 | |||||
1834 | if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder)) | ||||
1835 | return replaceInstUsesWith(SI, V); | ||||
1836 | |||||
1837 | return Changed ? &SI : nullptr; | ||||
1838 | } | ||||
1839 | |||||
1840 | /// SI is a select whose condition is a PHI node (but the two may be in | ||||
1841 | /// different blocks). See if the true/false values (V) are live in all of the | ||||
1842 | /// predecessor blocks of the PHI. For example, cases like this can't be mapped: | ||||
1843 | /// | ||||
1844 | /// X = phi [ C1, BB1], [C2, BB2] | ||||
1845 | /// Y = add | ||||
1846 | /// Z = select X, Y, 0 | ||||
1847 | /// | ||||
1848 | /// because Y is not live in BB1/BB2. | ||||
1849 | static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, | ||||
1850 | const SelectInst &SI) { | ||||
1851 | // If the value is a non-instruction value like a constant or argument, it | ||||
1852 | // can always be mapped. | ||||
1853 | const Instruction *I = dyn_cast<Instruction>(V); | ||||
1854 | if (!I) return true; | ||||
1855 | |||||
1856 | // If V is a PHI node defined in the same block as the condition PHI, we can | ||||
1857 | // map the arguments. | ||||
1858 | const PHINode *CondPHI = cast<PHINode>(SI.getCondition()); | ||||
1859 | |||||
1860 | if (const PHINode *VP = dyn_cast<PHINode>(I)) | ||||
1861 | if (VP->getParent() == CondPHI->getParent()) | ||||
1862 | return true; | ||||
1863 | |||||
1864 | // Otherwise, if the PHI and select are defined in the same block and if V is | ||||
1865 | // defined in a different block, then we can transform it. | ||||
1866 | if (SI.getParent() == CondPHI->getParent() && | ||||
1867 | I->getParent() != CondPHI->getParent()) | ||||
1868 | return true; | ||||
1869 | |||||
1870 | // Otherwise we have a 'hard' case and we can't tell without doing more | ||||
1871 | // detailed dominator based analysis, punt. | ||||
1872 | return false; | ||||
1873 | } | ||||
1874 | |||||
1875 | /// We have an SPF (e.g. a min or max) of an SPF of the form: | ||||
1876 | /// SPF2(SPF1(A, B), C) | ||||
1877 | Instruction *InstCombinerImpl::foldSPFofSPF(Instruction *Inner, | ||||
1878 | SelectPatternFlavor SPF1, Value *A, | ||||
1879 | Value *B, Instruction &Outer, | ||||
1880 | SelectPatternFlavor SPF2, | ||||
1881 | Value *C) { | ||||
1882 | if (Outer.getType() != Inner->getType()) | ||||
1883 | return nullptr; | ||||
1884 | |||||
1885 | if (C == A || C == B) { | ||||
1886 | // MAX(MAX(A, B), B) -> MAX(A, B) | ||||
1887 | // MIN(MIN(a, b), a) -> MIN(a, b) | ||||
1888 | // TODO: This could be done in instsimplify. | ||||
1889 | if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1)) | ||||
1890 | return replaceInstUsesWith(Outer, Inner); | ||||
1891 | } | ||||
1892 | |||||
1893 | return nullptr; | ||||
1894 | } | ||||
1895 | |||||
1896 | /// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))). | ||||
1897 | /// This is even legal for FP. | ||||
1898 | static Instruction *foldAddSubSelect(SelectInst &SI, | ||||
1899 | InstCombiner::BuilderTy &Builder) { | ||||
1900 | Value *CondVal = SI.getCondition(); | ||||
1901 | Value *TrueVal = SI.getTrueValue(); | ||||
1902 | Value *FalseVal = SI.getFalseValue(); | ||||
1903 | auto *TI = dyn_cast<Instruction>(TrueVal); | ||||
1904 | auto *FI = dyn_cast<Instruction>(FalseVal); | ||||
1905 | if (!TI || !FI || !TI->hasOneUse() || !FI->hasOneUse()) | ||||
1906 | return nullptr; | ||||
1907 | |||||
1908 | Instruction *AddOp = nullptr, *SubOp = nullptr; | ||||
1909 | if ((TI->getOpcode() == Instruction::Sub && | ||||
1910 | FI->getOpcode() == Instruction::Add) || | ||||
1911 | (TI->getOpcode() == Instruction::FSub && | ||||
1912 | FI->getOpcode() == Instruction::FAdd)) { | ||||
1913 | AddOp = FI; | ||||
1914 | SubOp = TI; | ||||
1915 | } else if ((FI->getOpcode() == Instruction::Sub && | ||||
1916 | TI->getOpcode() == Instruction::Add) || | ||||
1917 | (FI->getOpcode() == Instruction::FSub && | ||||
1918 | TI->getOpcode() == Instruction::FAdd)) { | ||||
1919 | AddOp = TI; | ||||
1920 | SubOp = FI; | ||||
1921 | } | ||||
1922 | |||||
1923 | if (AddOp) { | ||||
1924 | Value *OtherAddOp = nullptr; | ||||
1925 | if (SubOp->getOperand(0) == AddOp->getOperand(0)) { | ||||
1926 | OtherAddOp = AddOp->getOperand(1); | ||||
1927 | } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { | ||||
1928 | OtherAddOp = AddOp->getOperand(0); | ||||
1929 | } | ||||
1930 | |||||
1931 | if (OtherAddOp) { | ||||
1932 | // So at this point we know we have (Y -> OtherAddOp): | ||||
1933 | // select C, (add X, Y), (sub X, Z) | ||||
1934 | Value *NegVal; // Compute -Z | ||||
1935 | if (SI.getType()->isFPOrFPVectorTy()) { | ||||
1936 | NegVal = Builder.CreateFNeg(SubOp->getOperand(1)); | ||||
1937 | if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) { | ||||
1938 | FastMathFlags Flags = AddOp->getFastMathFlags(); | ||||
1939 | Flags &= SubOp->getFastMathFlags(); | ||||
1940 | NegInst->setFastMathFlags(Flags); | ||||
1941 | } | ||||
1942 | } else { | ||||
1943 | NegVal = Builder.CreateNeg(SubOp->getOperand(1)); | ||||
1944 | } | ||||
1945 | |||||
1946 | Value *NewTrueOp = OtherAddOp; | ||||
1947 | Value *NewFalseOp = NegVal; | ||||
1948 | if (AddOp != TI) | ||||
1949 | std::swap(NewTrueOp, NewFalseOp); | ||||
1950 | Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp, | ||||
1951 | SI.getName() + ".p", &SI); | ||||
1952 | |||||
1953 | if (SI.getType()->isFPOrFPVectorTy()) { | ||||
1954 | Instruction *RI = | ||||
1955 | BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); | ||||
1956 | |||||
1957 | FastMathFlags Flags = AddOp->getFastMathFlags(); | ||||
1958 | Flags &= SubOp->getFastMathFlags(); | ||||
1959 | RI->setFastMathFlags(Flags); | ||||
1960 | return RI; | ||||
1961 | } else | ||||
1962 | return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); | ||||
1963 | } | ||||
1964 | } | ||||
1965 | return nullptr; | ||||
1966 | } | ||||
1967 | |||||
/// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
/// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y
/// Along with a number of patterns similar to:
/// X + Y overflows ? (X < 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
/// X - Y overflows ? (X > 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
///
/// The select must have the shape
///   select (extractvalue WO, 1), Limit, (extractvalue WO, 0)
/// where WO is a *.with.overflow intrinsic call on (X, Y). On success this
/// returns a new call to the matching *.sat intrinsic on (X, Y), created
/// without an insertion point; otherwise it returns nullptr.
/// Note: Builder is not used by this function.
static Instruction *
foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
  Value *CondVal = SI.getCondition();
  Value *TrueVal = SI.getTrueValue();
  Value *FalseVal = SI.getFalseValue();

  // Condition = overflow bit of II; false arm = arithmetic result of the same
  // II.
  WithOverflowInst *II;
  if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) ||
      !match(FalseVal, m_ExtractValue<0>(m_Specific(II))))
    return nullptr;

  Value *X = II->getLHS();
  Value *Y = II->getRHS();

  // Returns true if Limit is `select (icmp Pred Op, C), TV, FV` where Op is X
  // or Y, {TV, FV} are INTMIN/INTMAX for Limit's scalar width, and the
  // predicate/constant pick the saturation value matching the overflow
  // direction for a signed add (IsAdd) or sub (!IsAdd).
  auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) {
    Type *Ty = Limit->getType();

    // These TrueVal/FalseVal intentionally shadow the outer select's arms;
    // they bind the arms of the inner Limit select.
    ICmpInst::Predicate Pred;
    Value *TrueVal, *FalseVal, *Op;
    const APInt *C;
    if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)),
                               m_Value(TrueVal), m_Value(FalseVal))))
      return false;

    auto IsZeroOrOne = [](const APInt &C) { return C.isZero() || C.isOne(); };
    // True if Min is the signed-min constant and Max the signed-max constant
    // of Ty's scalar bit width.
    auto IsMinMax = [&](Value *Min, Value *Max) {
      APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
      APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits());
      return match(Min, m_SpecificInt(MinVal)) &&
             match(Max, m_SpecificInt(MaxVal));
    };

    // The compare must test one of the overflow intrinsic's operands.
    if (Op != X && Op != Y)
      return false;

    if (IsAdd) {
      // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
      // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
      // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
      // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
      if (Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) &&
          IsMinMax(TrueVal, FalseVal))
        return true;
      // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
      // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
      // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
      // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
      // (*C + 1 in {0, 1} means C is -1 or 0.)
      if (Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) &&
          IsMinMax(FalseVal, TrueVal))
        return true;
    } else {
      // Subtraction is not commutative, so the X-based and Y-based patterns
      // use different constants and opposite INTMIN/INTMAX orders.
      // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
      // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
      if (Op == X && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C + 1) &&
          IsMinMax(TrueVal, FalseVal))
        return true;
      // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
      // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
      if (Op == X && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 2) &&
          IsMinMax(FalseVal, TrueVal))
        return true;
      // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
      // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
      if (Op == Y && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) &&
          IsMinMax(FalseVal, TrueVal))
        return true;
      // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
      // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
      if (Op == Y && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) &&
          IsMinMax(TrueVal, FalseVal))
        return true;
    }

    return false;
  };

  // Pick the saturating intrinsic from the overflow intrinsic's kind and the
  // shape of the select's true arm (the value used on overflow).
  Intrinsic::ID NewIntrinsicID;
  if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow &&
      match(TrueVal, m_AllOnes()))
    // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
    NewIntrinsicID = Intrinsic::uadd_sat;
  else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow &&
           match(TrueVal, m_Zero()))
    // X - Y overflows ? 0 : X - Y -> usub_sat X, Y
    NewIntrinsicID = Intrinsic::usub_sat;
  else if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow &&
           IsSignedSaturateLimit(TrueVal, /*IsAdd=*/true))
    // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
    // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
    // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
    // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
    // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
    // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
    // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
    // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
    NewIntrinsicID = Intrinsic::sadd_sat;
  else if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow &&
           IsSignedSaturateLimit(TrueVal, /*IsAdd=*/false))
    // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
    // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
    // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
    // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
    // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
    // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
    // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
    // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
    NewIntrinsicID = Intrinsic::ssub_sat;
  else
    return nullptr;

  // The intrinsic is overloaded on the select's result type.
  Function *F =
      Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType());
  return CallInst::Create(F, {X, Y});
}
2087 | |||||
2088 | Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) { | ||||
2089 | Constant *C; | ||||
2090 | if (!match(Sel.getTrueValue(), m_Constant(C)) && | ||||
2091 | !match(Sel.getFalseValue(), m_Constant(C))) | ||||
2092 | return nullptr; | ||||
2093 | |||||
2094 | Instruction *ExtInst; | ||||
2095 | if (!match(Sel.getTrueValue(), m_Instruction(ExtInst)) && | ||||
2096 | !match(Sel.getFalseValue(), m_Instruction(ExtInst))) | ||||
2097 | return nullptr; | ||||
2098 | |||||
2099 | auto ExtOpcode = ExtInst->getOpcode(); | ||||
2100 | if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) | ||||
2101 | return nullptr; | ||||
2102 | |||||
2103 | // If we are extending from a boolean type or if we can create a select that | ||||
2104 | // has the same size operands as its condition, try to narrow the select. | ||||
2105 | Value *X = ExtInst->getOperand(0); | ||||
2106 | Type *SmallType = X->getType(); | ||||
2107 | Value *Cond = Sel.getCondition(); | ||||
2108 | auto *Cmp = dyn_cast<CmpInst>(Cond); | ||||
2109 | if (!SmallType->isIntOrIntVectorTy(1) && | ||||
2110 | (!Cmp || Cmp->getOperand(0)->getType() != SmallType)) | ||||
2111 | return nullptr; | ||||
2112 | |||||
2113 | // If the constant is the same after truncation to the smaller type and | ||||
2114 | // extension to the original type, we can narrow the select. | ||||
2115 | Type *SelType = Sel.getType(); | ||||
2116 | Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); | ||||
2117 | Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); | ||||
2118 | if (ExtC == C && ExtInst->hasOneUse()) { | ||||
2119 | Value *TruncCVal = cast<Value>(TruncC); | ||||
2120 | if (ExtInst == Sel.getFalseValue()) | ||||
2121 | std::swap(X, TruncCVal); | ||||
2122 | |||||
2123 | // select Cond, (ext X), C --> ext(select Cond, X, C') | ||||
2124 | // select Cond, C, (ext X) --> ext(select Cond, C', X) | ||||
2125 | Value *NewSel = Builder.CreateSelect(Cond, X, TruncCVal, "narrow", &Sel); | ||||
2126 | return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType); | ||||
2127 | } | ||||
2128 | |||||
2129 | // If one arm of the select is the extend of the condition, replace that arm | ||||
2130 | // with the extension of the appropriate known bool value. | ||||
2131 | if (Cond == X) { | ||||
2132 | if (ExtInst == Sel.getTrueValue()) { | ||||
2133 | // select X, (sext X), C --> select X, -1, C | ||||
2134 | // select X, (zext X), C --> select X, 1, C | ||||
2135 | Constant *One = ConstantInt::getTrue(SmallType); | ||||
2136 | Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType); | ||||
2137 | return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel); | ||||
2138 | } else { | ||||
2139 | // select X, C, (sext X) --> select X, C, 0 | ||||
2140 | // select X, C, (zext X) --> select X, C, 0 | ||||
2141 | Constant *Zero = ConstantInt::getNullValue(SelType); | ||||
2142 | return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel); | ||||
2143 | } | ||||
2144 | } | ||||
2145 | |||||
2146 | return nullptr; | ||||
2147 | } | ||||
2148 | |||||
2149 | /// Try to transform a vector select with a constant condition vector into a | ||||
2150 | /// shuffle for easier combining with other shuffles and insert/extract. | ||||
2151 | static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { | ||||
2152 | Value *CondVal = SI.getCondition(); | ||||
2153 | Constant *CondC; | ||||
2154 | auto *CondValTy = dyn_cast<FixedVectorType>(CondVal->getType()); | ||||
2155 | if (!CondValTy || !match(CondVal, m_Constant(CondC))) | ||||
2156 | return nullptr; | ||||
2157 | |||||
2158 | unsigned NumElts = CondValTy->getNumElements(); | ||||
2159 | SmallVector<int, 16> Mask; | ||||
2160 | Mask.reserve(NumElts); | ||||
2161 | for (unsigned i = 0; i != NumElts; ++i) { | ||||
2162 | Constant *Elt = CondC->getAggregateElement(i); | ||||
2163 | if (!Elt) | ||||
2164 | return nullptr; | ||||
2165 | |||||
2166 | if (Elt->isOneValue()) { | ||||
2167 | // If the select condition element is true, choose from the 1st vector. | ||||
2168 | Mask.push_back(i); | ||||
2169 | } else if (Elt->isNullValue()) { | ||||
2170 | // If the select condition element is false, choose from the 2nd vector. | ||||
2171 | Mask.push_back(i + NumElts); | ||||
2172 | } else if (isa<UndefValue>(Elt)) { | ||||
2173 | // Undef in a select condition (choose one of the operands) does not mean | ||||
2174 | // the same thing as undef in a shuffle mask (any value is acceptable), so | ||||
2175 | // give up. | ||||
2176 | return nullptr; | ||||
2177 | } else { | ||||
2178 | // Bail out on a constant expression. | ||||
2179 | return nullptr; | ||||
2180 | } | ||||
2181 | } | ||||
2182 | |||||
2183 | return new ShuffleVectorInst(SI.getTrueValue(), SI.getFalseValue(), Mask); | ||||
2184 | } | ||||
2185 | |||||
2186 | /// If we have a select of vectors with a scalar condition, try to convert that | ||||
2187 | /// to a vector select by splatting the condition. A splat may get folded with | ||||
2188 | /// other operations in IR and having all operands of a select be vector types | ||||
2189 | /// is likely better for vector codegen. | ||||
2190 | static Instruction *canonicalizeScalarSelectOfVecs(SelectInst &Sel, | ||||
2191 | InstCombinerImpl &IC) { | ||||
2192 | auto *Ty = dyn_cast<VectorType>(Sel.getType()); | ||||
2193 | if (!Ty) | ||||
2194 | return nullptr; | ||||
2195 | |||||
2196 | // We can replace a single-use extract with constant index. | ||||
2197 | Value *Cond = Sel.getCondition(); | ||||
2198 | if (!match(Cond, m_OneUse(m_ExtractElt(m_Value(), m_ConstantInt())))) | ||||
2199 | return nullptr; | ||||
2200 | |||||
2201 | // select (extelt V, Index), T, F --> select (splat V, Index), T, F | ||||
2202 | // Splatting the extracted condition reduces code (we could directly create a | ||||
2203 | // splat shuffle of the source vector to eliminate the intermediate step). | ||||
2204 | return IC.replaceOperand( | ||||
2205 | Sel, 0, IC.Builder.CreateVectorSplat(Ty->getElementCount(), Cond)); | ||||
2206 | } | ||||
2207 | |||||
2208 | /// Reuse bitcasted operands between a compare and select: | ||||
2209 | /// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> | ||||
2210 | /// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D)) | ||||
2211 | static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, | ||||
2212 | InstCombiner::BuilderTy &Builder) { | ||||
2213 | Value *Cond = Sel.getCondition(); | ||||
2214 | Value *TVal = Sel.getTrueValue(); | ||||
2215 | Value *FVal = Sel.getFalseValue(); | ||||
2216 | |||||
2217 | CmpInst::Predicate Pred; | ||||
2218 | Value *A, *B; | ||||
2219 | if (!match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B)))) | ||||
2220 | return nullptr; | ||||
2221 | |||||
2222 | // The select condition is a compare instruction. If the select's true/false | ||||
2223 | // values are already the same as the compare operands, there's nothing to do. | ||||
2224 | if (TVal == A || TVal == B || FVal == A || FVal == B) | ||||
2225 | return nullptr; | ||||
2226 | |||||
2227 | Value *C, *D; | ||||
2228 | if (!match(A, m_BitCast(m_Value(C))) || !match(B, m_BitCast(m_Value(D)))) | ||||
2229 | return nullptr; | ||||
2230 | |||||
2231 | // select (cmp (bitcast C), (bitcast D)), (bitcast TSrc), (bitcast FSrc) | ||||
2232 | Value *TSrc, *FSrc; | ||||
2233 | if (!match(TVal, m_BitCast(m_Value(TSrc))) || | ||||
2234 | !match(FVal, m_BitCast(m_Value(FSrc)))) | ||||
2235 | return nullptr; | ||||
2236 | |||||
2237 | // If the select true/false values are *different bitcasts* of the same source | ||||
2238 | // operands, make the select operands the same as the compare operands and | ||||
2239 | // cast the result. This is the canonical select form for min/max. | ||||
2240 | Value *NewSel; | ||||
2241 | if (TSrc == C && FSrc == D) { | ||||
2242 | // select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> | ||||
2243 | // bitcast (select (cmp A, B), A, B) | ||||
2244 | NewSel = Builder.CreateSelect(Cond, A, B, "", &Sel); | ||||
2245 | } else if (TSrc == D && FSrc == C) { | ||||
2246 | // select (cmp (bitcast C), (bitcast D)), (bitcast' D), (bitcast' C) --> | ||||
2247 | // bitcast (select (cmp A, B), B, A) | ||||
2248 | NewSel = Builder.CreateSelect(Cond, B, A, "", &Sel); | ||||
2249 | } else { | ||||
2250 | return nullptr; | ||||
2251 | } | ||||
2252 | return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType()); | ||||
2253 | } | ||||
2254 | |||||
2255 | /// Try to eliminate select instructions that test the returned flag of cmpxchg | ||||
2256 | /// instructions. | ||||
2257 | /// | ||||
2258 | /// If a select instruction tests the returned flag of a cmpxchg instruction and | ||||
2259 | /// selects between the returned value of the cmpxchg instruction its compare | ||||
2260 | /// operand, the result of the select will always be equal to its false value. | ||||
2261 | /// For example: | ||||
2262 | /// | ||||
2263 | /// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst | ||||
2264 | /// %1 = extractvalue { i64, i1 } %0, 1 | ||||
2265 | /// %2 = extractvalue { i64, i1 } %0, 0 | ||||
2266 | /// %3 = select i1 %1, i64 %compare, i64 %2 | ||||
2267 | /// ret i64 %3 | ||||
2268 | /// | ||||
2269 | /// The returned value of the cmpxchg instruction (%2) is the original value | ||||
2270 | /// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2 | ||||
2271 | /// must have been equal to %compare. Thus, the result of the select is always | ||||
2272 | /// equal to %2, and the code can be simplified to: | ||||
2273 | /// | ||||
2274 | /// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst | ||||
2275 | /// %1 = extractvalue { i64, i1 } %0, 0 | ||||
2276 | /// ret i64 %1 | ||||
2277 | /// | ||||
2278 | static Value *foldSelectCmpXchg(SelectInst &SI) { | ||||
2279 | // A helper that determines if V is an extractvalue instruction whose | ||||
2280 | // aggregate operand is a cmpxchg instruction and whose single index is equal | ||||
2281 | // to I. If such conditions are true, the helper returns the cmpxchg | ||||
2282 | // instruction; otherwise, a nullptr is returned. | ||||
2283 | auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * { | ||||
2284 | auto *Extract = dyn_cast<ExtractValueInst>(V); | ||||
2285 | if (!Extract) | ||||
2286 | return nullptr; | ||||
2287 | if (Extract->getIndices()[0] != I) | ||||
2288 | return nullptr; | ||||
2289 | return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand()); | ||||
2290 | }; | ||||
2291 | |||||
2292 | // If the select has a single user, and this user is a select instruction that | ||||
2293 | // we can simplify, skip the cmpxchg simplification for now. | ||||
2294 | if (SI.hasOneUse()) | ||||
2295 | if (auto *Select = dyn_cast<SelectInst>(SI.user_back())) | ||||
2296 | if (Select->getCondition() == SI.getCondition()) | ||||
2297 | if (Select->getFalseValue() == SI.getTrueValue() || | ||||
2298 | Select->getTrueValue() == SI.getFalseValue()) | ||||
2299 | return nullptr; | ||||
2300 | |||||
2301 | // Ensure the select condition is the returned flag of a cmpxchg instruction. | ||||
2302 | auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1); | ||||
2303 | if (!CmpXchg) | ||||
2304 | return nullptr; | ||||
2305 | |||||
2306 | // Check the true value case: The true value of the select is the returned | ||||
2307 | // value of the same cmpxchg used by the condition, and the false value is the | ||||
2308 | // cmpxchg instruction's compare operand. | ||||
2309 | if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0)) | ||||
2310 | if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) | ||||
2311 | return SI.getFalseValue(); | ||||
2312 | |||||
2313 | // Check the false value case: The false value of the select is the returned | ||||
2314 | // value of the same cmpxchg used by the condition, and the true value is the | ||||
2315 | // cmpxchg instruction's compare operand. | ||||
2316 | if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0)) | ||||
2317 | if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) | ||||
2318 | return SI.getFalseValue(); | ||||
2319 | |||||
2320 | return nullptr; | ||||
2321 | } | ||||
2322 | |||||
/// Try to reduce a funnel/rotate pattern that includes a compare and select
/// into a funnel shift intrinsic. Example:
/// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b)))
///              --> call llvm.fshl.i32(a, a, b)
/// fshl32(a, b, c) --> (c == 0 ? a : ((b >> (32 - c)) | (a << c)))
///                 --> call llvm.fshl.i32(a, b, c)
/// fshr32(a, b, c) --> (c == 0 ? b : ((a << (32 - c)) | (b >> c)))
///                 --> call llvm.fshr.i32(a, b, c)
///
/// Matched shape, after canonicalizing the shl to be the first 'or' operand:
///   select (icmp eq ShAmt, 0), TVal, (or (shl SV0, SA0), (lshr SV1, SA1))
/// where one shift amount is `Width - ShAmt` and the other is ShAmt. Either
/// amount may first be zero-extended (m_ZExtOrSelf), so ShAmt is zext'ed to
/// the select type before the call. Returns the new intrinsic call (created
/// without an insertion point) or nullptr.
static Instruction *foldSelectFunnelShift(SelectInst &Sel,
                                          InstCombiner::BuilderTy &Builder) {
  // This must be a power-of-2 type for a bitmasking transform to be valid.
  unsigned Width = Sel.getType()->getScalarSizeInBits();
  if (!isPowerOf2_32(Width))
    return nullptr;

  // The false arm must be a single-use 'or' of two binary operators.
  BinaryOperator *Or0, *Or1;
  if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1)))))
    return nullptr;

  // Both 'or' operands must be single-use logical shifts (shl/lshr) of
  // opposite kinds; their amounts may be zero-extended.
  Value *SV0, *SV1, *SA0, *SA1;
  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(SV0),
                                          m_ZExtOrSelf(m_Value(SA0))))) ||
      !match(Or1, m_OneUse(m_LogicalShift(m_Value(SV1),
                                          m_ZExtOrSelf(m_Value(SA1))))) ||
      Or0->getOpcode() == Or1->getOpcode())
    return nullptr;

  // Canonicalize to or(shl(SV0, SA0), lshr(SV1, SA1)).
  if (Or0->getOpcode() == BinaryOperator::LShr) {
    std::swap(Or0, Or1);
    std::swap(SV0, SV1);
    std::swap(SA0, SA1);
  }
  assert(Or0->getOpcode() == BinaryOperator::Shl &&
         Or1->getOpcode() == BinaryOperator::LShr &&
         "Illegal or(shift,shift) pair");

  // Check the shift amounts to see if they are an opposite pair.
  // If SA1 == Width - SA0, the shl amount is the funnel amount (fshl);
  // if SA0 == Width - SA1, the lshr amount is (fshr).
  Value *ShAmt;
  if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0)))))
    ShAmt = SA0;
  else if (match(SA0, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA1)))))
    ShAmt = SA1;
  else
    return nullptr;

  // We should now have this pattern:
  // select ?, TVal, (or (shl SV0, SA0), (lshr SV1, SA1))
  // The false value of the select must be a funnel-shift of the true value:
  // IsFShl -> TVal must be SV0 else TVal must be SV1.
  bool IsFshl = (ShAmt == SA0);
  Value *TVal = Sel.getTrueValue();
  if ((IsFshl && TVal != SV0) || (!IsFshl && TVal != SV1))
    return nullptr;

  // Finally, see if the select is filtering out a shift-by-zero.
  // The condition must be a single-use `icmp eq ShAmt, 0`.
  Value *Cond = Sel.getCondition();
  ICmpInst::Predicate Pred;
  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) ||
      Pred != ICmpInst::ICMP_EQ)
    return nullptr;

  // If this is not a rotate then the select was blocking poison from the
  // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
  // (For fshl the non-TVal input is SV1; for fshr it is SV0.)
  if (SV0 != SV1) {
    if (IsFshl && !llvm::isGuaranteedNotToBePoison(SV1))
      SV1 = Builder.CreateFreeze(SV1);
    else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(SV0))
      SV0 = Builder.CreateFreeze(SV0);
  }

  // This is a funnel/rotate that avoids shift-by-bitwidth UB in a suboptimal way.
  // Convert to funnel shift intrinsic.
  Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
  Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType());
  // ShAmt may be narrower than the select type (see m_ZExtOrSelf above).
  ShAmt = Builder.CreateZExt(ShAmt, Sel.getType());
  return CallInst::Create(F, { SV0, SV1, ShAmt });
}
2401 | |||||
2402 | static Instruction *foldSelectToCopysign(SelectInst &Sel, | ||||
2403 | InstCombiner::BuilderTy &Builder) { | ||||
2404 | Value *Cond = Sel.getCondition(); | ||||
2405 | Value *TVal = Sel.getTrueValue(); | ||||
2406 | Value *FVal = Sel.getFalseValue(); | ||||
2407 | Type *SelType = Sel.getType(); | ||||
2408 | |||||
2409 | // Match select ?, TC, FC where the constants are equal but negated. | ||||
2410 | // TODO: Generalize to handle a negated variable operand? | ||||
2411 | const APFloat *TC, *FC; | ||||
2412 | if (!match(TVal, m_APFloatAllowUndef(TC)) || | ||||
2413 | !match(FVal, m_APFloatAllowUndef(FC)) || | ||||
2414 | !abs(*TC).bitwiseIsEqual(abs(*FC))) | ||||
2415 | return nullptr; | ||||
2416 | |||||
2417 | assert(TC != FC && "Expected equal select arms to simplify")(static_cast <bool> (TC != FC && "Expected equal select arms to simplify" ) ? void (0) : __assert_fail ("TC != FC && \"Expected equal select arms to simplify\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 2417 , __extension__ __PRETTY_FUNCTION__)); | ||||
2418 | |||||
2419 | Value *X; | ||||
2420 | const APInt *C; | ||||
2421 | bool IsTrueIfSignSet; | ||||
2422 | ICmpInst::Predicate Pred; | ||||
2423 | if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) || | ||||
2424 | !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) || | ||||
2425 | X->getType() != SelType) | ||||
2426 | return nullptr; | ||||
2427 | |||||
2428 | // If needed, negate the value that will be the sign argument of the copysign: | ||||
2429 | // (bitcast X) < 0 ? -TC : TC --> copysign(TC, X) | ||||
2430 | // (bitcast X) < 0 ? TC : -TC --> copysign(TC, -X) | ||||
2431 | // (bitcast X) >= 0 ? -TC : TC --> copysign(TC, -X) | ||||
2432 | // (bitcast X) >= 0 ? TC : -TC --> copysign(TC, X) | ||||
2433 | // Note: FMF from the select can not be propagated to the new instructions. | ||||
2434 | if (IsTrueIfSignSet ^ TC->isNegative()) | ||||
2435 | X = Builder.CreateFNeg(X); | ||||
2436 | |||||
2437 | // Canonicalize the magnitude argument as the positive constant since we do | ||||
2438 | // not care about its sign. | ||||
2439 | Value *MagArg = ConstantFP::get(SelType, abs(*TC)); | ||||
2440 | Function *F = Intrinsic::getDeclaration(Sel.getModule(), Intrinsic::copysign, | ||||
2441 | Sel.getType()); | ||||
2442 | return CallInst::Create(F, { MagArg, X }); | ||||
2443 | } | ||||
2444 | |||||
/// Vector-only select folds. In order:
///  1. Pull a vector reverse out of the condition and arms:
///     select rev(C), rev(X)/splat, rev(Y)/splat --> rev(select C, ...)
///     (applies to fixed and scalable vectors).
///  2. (Fixed-width only) Run demanded-vector-elements simplification on Sel.
///  3. (Fixed-width only) Sink a select through a one-use "select shuffle"
///     (a shufflevector whose isSelect() is true) with no poison mask lanes
///     that shares an operand with the other select arm.
/// Returns a replacement instruction, &Sel, or nullptr.
Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
  if (!isa<VectorType>(Sel.getType()))
    return nullptr;

  Value *Cond = Sel.getCondition();
  Value *TVal = Sel.getTrueValue();
  Value *FVal = Sel.getFalseValue();
  Value *C, *X, *Y;

  if (match(Cond, m_VecReverse(m_Value(C)))) {
    // Builds `select C, X, Y` before Sel (copying Sel's name and IR flags)
    // and returns a new, not-yet-inserted vector-reverse call of it.
    auto createSelReverse = [&](Value *C, Value *X, Value *Y) {
      Value *V = Builder.CreateSelect(C, X, Y, Sel.getName(), &Sel);
      if (auto *I = dyn_cast<Instruction>(V))
        I->copyIRFlags(&Sel);
      Module *M = Sel.getModule();
      Function *F = Intrinsic::getDeclaration(
          M, Intrinsic::experimental_vector_reverse, V->getType());
      return CallInst::Create(F, V);
    };

    // Each fold requires at least one of the reversed inputs to be
    // single-use, so that at least one existing reverse can go away.
    if (match(TVal, m_VecReverse(m_Value(X)))) {
      // select rev(C), rev(X), rev(Y) --> rev(select C, X, Y)
      if (match(FVal, m_VecReverse(m_Value(Y))) &&
          (Cond->hasOneUse() || TVal->hasOneUse() || FVal->hasOneUse()))
        return createSelReverse(C, X, Y);

      // select rev(C), rev(X), FValSplat --> rev(select C, X, FValSplat)
      if ((Cond->hasOneUse() || TVal->hasOneUse()) && isSplatValue(FVal))
        return createSelReverse(C, X, FVal);
    }
    // select rev(C), TValSplat, rev(Y) --> rev(select C, TValSplat, Y)
    else if (isSplatValue(TVal) && match(FVal, m_VecReverse(m_Value(Y))) &&
             (Cond->hasOneUse() || FVal->hasOneUse()))
      return createSelReverse(C, TVal, Y);
  }

  // Everything below needs a known element count.
  auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType());
  if (!VecTy)
    return nullptr;

  // Demand every lane of Sel. A different result value replaces Sel's uses;
  // getting Sel itself back is reported by returning &Sel (presumably meaning
  // "modified in place", per InstCombine convention; confirm).
  unsigned NumElts = VecTy->getNumElements();
  APInt UndefElts(NumElts, 0);
  APInt AllOnesEltMask(APInt::getAllOnes(NumElts));
  if (Value *V = SimplifyDemandedVectorElts(&Sel, AllOnesEltMask, UndefElts)) {
    if (V != &Sel)
      return replaceInstUsesWith(Sel, V);
    return &Sel;
  }

  // A select of a "select shuffle" with a common operand can be rearranged
  // to select followed by "select shuffle". Because of poison, this only works
  // in the case of a shuffle with no undefined mask elements.
  ArrayRef<int> Mask;
  if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
      !is_contained(Mask, PoisonMaskElem) &&
      cast<ShuffleVectorInst>(TVal)->isSelect()) {
    if (X == FVal) {
      // select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X)
      Value *NewSel = Builder.CreateSelect(Cond, Y, X, "sel", &Sel);
      return new ShuffleVectorInst(X, NewSel, Mask);
    }
    if (Y == FVal) {
      // select Cond, (shuf_sel X, Y), Y --> shuf_sel (select Cond, X, Y), Y
      Value *NewSel = Builder.CreateSelect(Cond, X, Y, "sel", &Sel);
      return new ShuffleVectorInst(NewSel, Y, Mask);
    }
  }
  // Mirror image: the select shuffle is on the false arm.
  if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
      !is_contained(Mask, PoisonMaskElem) &&
      cast<ShuffleVectorInst>(FVal)->isSelect()) {
    if (X == TVal) {
      // select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y)
      Value *NewSel = Builder.CreateSelect(Cond, X, Y, "sel", &Sel);
      return new ShuffleVectorInst(X, NewSel, Mask);
    }
    if (Y == TVal) {
      // select Cond, Y, (shuf_sel X, Y) --> shuf_sel (select Cond, Y, X), Y
      Value *NewSel = Builder.CreateSelect(Cond, Y, X, "sel", &Sel);
      return new ShuffleVectorInst(NewSel, Y, Mask);
    }
  }

  return nullptr;
}
2529 | |||||
2530 | static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB, | ||||
2531 | const DominatorTree &DT, | ||||
2532 | InstCombiner::BuilderTy &Builder) { | ||||
2533 | // Find the block's immediate dominator that ends with a conditional branch | ||||
2534 | // that matches select's condition (maybe inverted). | ||||
2535 | auto *IDomNode = DT[BB]->getIDom(); | ||||
2536 | if (!IDomNode) | ||||
2537 | return nullptr; | ||||
2538 | BasicBlock *IDom = IDomNode->getBlock(); | ||||
2539 | |||||
2540 | Value *Cond = Sel.getCondition(); | ||||
2541 | Value *IfTrue, *IfFalse; | ||||
2542 | BasicBlock *TrueSucc, *FalseSucc; | ||||
2543 | if (match(IDom->getTerminator(), | ||||
2544 | m_Br(m_Specific(Cond), m_BasicBlock(TrueSucc), | ||||
2545 | m_BasicBlock(FalseSucc)))) { | ||||
2546 | IfTrue = Sel.getTrueValue(); | ||||
2547 | IfFalse = Sel.getFalseValue(); | ||||
2548 | } else if (match(IDom->getTerminator(), | ||||
2549 | m_Br(m_Not(m_Specific(Cond)), m_BasicBlock(TrueSucc), | ||||
2550 | m_BasicBlock(FalseSucc)))) { | ||||
2551 | IfTrue = Sel.getFalseValue(); | ||||
2552 | IfFalse = Sel.getTrueValue(); | ||||
2553 | } else | ||||
2554 | return nullptr; | ||||
2555 | |||||
2556 | // Make sure the branches are actually different. | ||||
2557 | if (TrueSucc == FalseSucc) | ||||
2558 | return nullptr; | ||||
2559 | |||||
2560 | // We want to replace select %cond, %a, %b with a phi that takes value %a | ||||
2561 | // for all incoming edges that are dominated by condition `%cond == true`, | ||||
2562 | // and value %b for edges dominated by condition `%cond == false`. If %a | ||||
2563 | // or %b are also phis from the same basic block, we can go further and take | ||||
2564 | // their incoming values from the corresponding blocks. | ||||
2565 | BasicBlockEdge TrueEdge(IDom, TrueSucc); | ||||
2566 | BasicBlockEdge FalseEdge(IDom, FalseSucc); | ||||
2567 | DenseMap<BasicBlock *, Value *> Inputs; | ||||
2568 | for (auto *Pred : predecessors(BB)) { | ||||
2569 | // Check implication. | ||||
2570 | BasicBlockEdge Incoming(Pred, BB); | ||||
2571 | if (DT.dominates(TrueEdge, Incoming)) | ||||
2572 | Inputs[Pred] = IfTrue->DoPHITranslation(BB, Pred); | ||||
2573 | else if (DT.dominates(FalseEdge, Incoming)) | ||||
2574 | Inputs[Pred] = IfFalse->DoPHITranslation(BB, Pred); | ||||
2575 | else | ||||
2576 | return nullptr; | ||||
2577 | // Check availability. | ||||
2578 | if (auto *Insn = dyn_cast<Instruction>(Inputs[Pred])) | ||||
2579 | if (!DT.dominates(Insn, Pred->getTerminator())) | ||||
2580 | return nullptr; | ||||
2581 | } | ||||
2582 | |||||
2583 | Builder.SetInsertPoint(&*BB->begin()); | ||||
2584 | auto *PN = Builder.CreatePHI(Sel.getType(), Inputs.size()); | ||||
2585 | for (auto *Pred : predecessors(BB)) | ||||
2586 | PN->addIncoming(Inputs[Pred], Pred); | ||||
2587 | PN->takeName(&Sel); | ||||
2588 | return PN; | ||||
2589 | } | ||||
2590 | |||||
2591 | static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT, | ||||
2592 | InstCombiner::BuilderTy &Builder) { | ||||
2593 | // Try to replace this select with Phi in one of these blocks. | ||||
2594 | SmallSetVector<BasicBlock *, 4> CandidateBlocks; | ||||
2595 | CandidateBlocks.insert(Sel.getParent()); | ||||
2596 | for (Value *V : Sel.operands()) | ||||
2597 | if (auto *I = dyn_cast<Instruction>(V)) | ||||
2598 | CandidateBlocks.insert(I->getParent()); | ||||
2599 | |||||
2600 | for (BasicBlock *BB : CandidateBlocks) | ||||
2601 | if (auto *PN = foldSelectToPhiImpl(Sel, BB, DT, Builder)) | ||||
2602 | return PN; | ||||
2603 | return nullptr; | ||||
2604 | } | ||||
2605 | |||||
2606 | static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) { | ||||
2607 | FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition()); | ||||
2608 | if (!FI) | ||||
2609 | return nullptr; | ||||
2610 | |||||
2611 | Value *Cond = FI->getOperand(0); | ||||
2612 | Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); | ||||
2613 | |||||
2614 | // select (freeze(x == y)), x, y --> y | ||||
2615 | // select (freeze(x != y)), x, y --> x | ||||
2616 | // The freeze should be only used by this select. Otherwise, remaining uses of | ||||
2617 | // the freeze can observe a contradictory value. | ||||
2618 | // c = freeze(x == y) ; Let's assume that y = poison & x = 42; c is 0 or 1 | ||||
2619 | // a = select c, x, y ; | ||||
2620 | // f(a, c) ; f(poison, 1) cannot happen, but if a is folded | ||||
2621 | // ; to y, this can happen. | ||||
2622 | CmpInst::Predicate Pred; | ||||
2623 | if (FI->hasOneUse() && | ||||
2624 | match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) && | ||||
2625 | (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) { | ||||
2626 | return Pred == ICmpInst::ICMP_EQ ? FalseVal : TrueVal; | ||||
2627 | } | ||||
2628 | |||||
2629 | return nullptr; | ||||
2630 | } | ||||
2631 | |||||
2632 | Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, | ||||
2633 | SelectInst &SI, | ||||
2634 | bool IsAnd) { | ||||
2635 | Value *CondVal = SI.getCondition(); | ||||
2636 | Value *A = SI.getTrueValue(); | ||||
2637 | Value *B = SI.getFalseValue(); | ||||
2638 | |||||
2639 | assert(Op->getType()->isIntOrIntVectorTy(1) &&(static_cast <bool> (Op->getType()->isIntOrIntVectorTy (1) && "Op must be either i1 or vector of i1.") ? void (0) : __assert_fail ("Op->getType()->isIntOrIntVectorTy(1) && \"Op must be either i1 or vector of i1.\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 2640 , __extension__ __PRETTY_FUNCTION__)) | ||||
2640 | "Op must be either i1 or vector of i1.")(static_cast <bool> (Op->getType()->isIntOrIntVectorTy (1) && "Op must be either i1 or vector of i1.") ? void (0) : __assert_fail ("Op->getType()->isIntOrIntVectorTy(1) && \"Op must be either i1 or vector of i1.\"" , "llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp", 2640 , __extension__ __PRETTY_FUNCTION__)); | ||||
2641 | |||||
2642 | std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); | ||||
2643 | if (!Res) | ||||
2644 | return nullptr; | ||||
2645 | |||||
2646 | Value *Zero = Constant::getNullValue(A->getType()); | ||||
2647 | Value *One = Constant::getAllOnesValue(A->getType()); | ||||
2648 | |||||
2649 | if (*Res == true) { | ||||
2650 | if (IsAnd) | ||||
2651 | // select op, (select cond, A, B), false => select op, A, false | ||||
2652 | // and op, (select cond, A, B) => select op, A, false | ||||
2653 | // if op = true implies condval = true. | ||||
2654 | return SelectInst::Create(Op, A, Zero); | ||||
2655 | else | ||||
2656 | // select op, true, (select cond, A, B) => select op, true, A | ||||
2657 | // or op, (select cond, A, B) => select op, true, A | ||||
2658 | // if op = false implies condval = true. | ||||
2659 | return SelectInst::Create(Op, One, A); | ||||
2660 | } else { | ||||
2661 | if (IsAnd) | ||||
2662 | // select op, (select cond, A, B), false => select op, B, false | ||||
2663 | // and op, (select cond, A, B) => select op, B, false | ||||
2664 | // if op = true implies condval = false. | ||||
2665 | return SelectInst::Create(Op, B, Zero); | ||||
2666 | else | ||||
2667 | // select op, true, (select cond, A, B) => select op, true, B | ||||
2668 | // or op, (select cond, A, B) => select op, true, B | ||||
2669 | // if op = false implies condval = false. | ||||
2670 | return SelectInst::Create(Op, One, B); | ||||
2671 | } | ||||
2672 | } | ||||
2673 | |||||
2674 | // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need | ||||
2675 | // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. | ||||
2676 | static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, | ||||
2677 | InstCombinerImpl &IC) { | ||||
2678 | Value *CondVal = SI.getCondition(); | ||||
2679 | |||||
2680 | bool ChangedFMF = false; | ||||
2681 | for (bool Swap : {false, true}) { | ||||
2682 | Value *TrueVal = SI.getTrueValue(); | ||||
2683 | Value *X = SI.getFalseValue(); | ||||
2684 | CmpInst::Predicate Pred; | ||||
2685 | |||||
2686 | if (Swap) | ||||
2687 | std::swap(TrueVal, X); | ||||
2688 | |||||
2689 | if (!match(CondVal, m_FCmp(Pred, m_Specific(X), m_AnyZeroFP()))) | ||||
2690 | continue; | ||||
2691 | |||||
2692 | // fold (X <= +/-0.0) ? (0.0 - X) : X to fabs(X), when 'Swap' is false | ||||
2693 | // fold (X > +/-0.0) ? X : (0.0 - X) to fabs(X), when 'Swap' is true | ||||
2694 | if (match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) { | ||||
2695 | if (!Swap && (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) { | ||||
2696 | Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); | ||||
2697 | return IC.replaceInstUsesWith(SI, Fabs); | ||||
2698 | } | ||||
2699 | if (Swap && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) { | ||||
2700 | Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); | ||||
2701 | return IC.replaceInstUsesWith(SI, Fabs); | ||||
2702 | } | ||||
2703 | } | ||||
2704 | |||||
2705 | if (!match(TrueVal, m_FNeg(m_Specific(X)))) | ||||
2706 | return nullptr; | ||||
2707 | |||||
2708 | // Forward-propagate nnan and ninf from the fneg to the select. | ||||
2709 | // If all inputs are not those values, then the select is not either. | ||||
2710 | // Note: nsz is defined differently, so it may not be correct to propagate. | ||||
2711 | FastMathFlags FMF = cast<FPMathOperator>(TrueVal)->getFastMathFlags(); | ||||
2712 | if (FMF.noNaNs() && !SI.hasNoNaNs()) { | ||||
2713 | SI.setHasNoNaNs(true); | ||||
2714 | ChangedFMF = true; | ||||
2715 | } | ||||
2716 | if (FMF.noInfs() && !SI.hasNoInfs()) { | ||||
2717 | SI.setHasNoInfs(true); | ||||
2718 | ChangedFMF = true; | ||||
2719 | } | ||||
2720 | |||||
2721 | // With nsz, when 'Swap' is false: | ||||
2722 | // fold (X < +/-0.0) ? -X : X or (X <= +/-0.0) ? -X : X to fabs(X) | ||||
2723 | // fold (X > +/-0.0) ? -X : X or (X >= +/-0.0) ? -X : X to -fabs(x) | ||||
2724 | // when 'Swap' is true: | ||||
2725 | // fold (X > +/-0.0) ? X : -X or (X >= +/-0.0) ? X : -X to fabs(X) | ||||
2726 | // fold (X < +/-0.0) ? X : -X or (X <= +/-0.0) ? X : -X to -fabs(X) | ||||
2727 | // | ||||
2728 | // Note: We require "nnan" for this fold because fcmp ignores the signbit | ||||
2729 | // of NAN, but IEEE-754 specifies the signbit of NAN values with | ||||
2730 | // fneg/fabs operations. | ||||
2731 | if (!SI.hasNoSignedZeros() || !SI.hasNoNaNs()) | ||||
2732 | return nullptr; | ||||
2733 | |||||
2734 | if (Swap) | ||||
2735 | Pred = FCmpInst::getSwappedPredicate(Pred); | ||||
2736 | |||||
2737 | bool IsLTOrLE = Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE || | ||||
2738 | Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE; | ||||
2739 | bool IsGTOrGE = Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE || | ||||
2740 | Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE; | ||||
2741 | |||||
2742 | if (IsLTOrLE) { | ||||
2743 | Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); | ||||
2744 | return IC.replaceInstUsesWith(SI, Fabs); | ||||
2745 | } | ||||
2746 | if (IsGTOrGE) { | ||||
2747 | Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); | ||||
2748 | Instruction *NewFNeg = UnaryOperator::CreateFNeg(Fabs); | ||||
2749 | NewFNeg->setFastMathFlags(SI.getFastMathFlags()); | ||||
2750 | return NewFNeg; | ||||
2751 | } | ||||
2752 | } | ||||
2753 | |||||
2754 | return ChangedFMF ? &SI : nullptr; | ||||
2755 | } | ||||
2756 | |||||
2757 | // Match the following IR pattern: | ||||
2758 | // %x.lowbits = and i8 %x, %lowbitmask | ||||
2759 | // %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0 | ||||
2760 | // %x.biased = add i8 %x, %bias | ||||
2761 | // %x.biased.highbits = and i8 %x.biased, %highbitmask | ||||
2762 | // %x.roundedup = select i1 %x.lowbits.are.zero, i8 %x, i8 %x.biased.highbits | ||||
2763 | // Define: | ||||
2764 | // %alignment = add i8 %lowbitmask, 1 | ||||
2765 | // Iff 1. an %alignment is a power-of-two (aka, %lowbitmask is a low bit mask) | ||||
2766 | // and 2. %bias is equal to either %lowbitmask or %alignment, | ||||
2767 | // and 3. %highbitmask is equal to ~%lowbitmask (aka, to -%alignment) | ||||
2768 | // then this pattern can be transformed into: | ||||
2769 | // %x.offset = add i8 %x, %lowbitmask | ||||
2770 | // %x.roundedup = and i8 %x.offset, %highbitmask | ||||
2771 | static Value * | ||||
2772 | foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, | ||||
2773 | InstCombiner::BuilderTy &Builder) { | ||||
2774 | Value *Cond = SI.getCondition(); | ||||
2775 | Value *X = SI.getTrueValue(); | ||||
2776 | Value *XBiasedHighBits = SI.getFalseValue(); | ||||
2777 | |||||
2778 | ICmpInst::Predicate Pred; | ||||
2779 | Value *XLowBits; | ||||
2780 | if (!match(Cond, m_ICmp(Pred, m_Value(XLowBits), m_ZeroInt())) || | ||||
2781 | !ICmpInst::isEquality(Pred)) | ||||
2782 | return nullptr; | ||||
2783 | |||||
2784 | if (Pred == ICmpInst::Predicate::ICMP_NE) | ||||
2785 | std::swap(X, XBiasedHighBits); | ||||
2786 | |||||
2787 | // FIXME: we could support non non-splats here. | ||||
2788 | |||||
2789 | const APInt *LowBitMaskCst; | ||||
2790 | if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst)))) | ||||
2791 | return nullptr; | ||||
2792 | |||||
2793 | // Match even if the AND and ADD are swapped. | ||||
2794 | const APInt *BiasCst, *HighBitMaskCst; | ||||
2795 | if (!match(XBiasedHighBits, | ||||
2796 | m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)), | ||||
2797 | m_APIntAllowUndef(HighBitMaskCst))) && | ||||
2798 | !match(XBiasedHighBits, | ||||
2799 | m_Add(m_And(m_Specific(X), m_APIntAllowUndef(HighBitMaskCst)), | ||||
2800 | m_APIntAllowUndef(BiasCst)))) | ||||
2801 | return nullptr; | ||||
2802 | |||||
2803 | if (!LowBitMaskCst->isMask()) | ||||
2804 | return nullptr; | ||||
2805 | |||||
2806 | APInt InvertedLowBitMaskCst = ~*LowBitMaskCst; | ||||
2807 | if (InvertedLowBitMaskCst != *HighBitMaskCst) | ||||
2808 | return nullptr; | ||||
2809 | |||||
2810 | APInt AlignmentCst = *LowBitMaskCst + 1; | ||||
2811 | |||||
2812 | if (*BiasCst != AlignmentCst && *BiasCst != *LowBitMaskCst) | ||||
2813 | return nullptr; | ||||
2814 | |||||
2815 | if (!XBiasedHighBits->hasOneUse()) { | ||||
2816 | if (*BiasCst == *LowBitMaskCst) | ||||
2817 | return XBiasedHighBits; | ||||
2818 | return nullptr; | ||||
2819 | } | ||||
2820 | |||||
2821 | // FIXME: could we preserve undef's here? | ||||
2822 | Type *Ty = X->getType(); | ||||
2823 | Value *XOffset = Builder.CreateAdd(X, ConstantInt::get(Ty, *LowBitMaskCst), | ||||
2824 | X->getName() + ".biased"); | ||||
2825 | Value *R = Builder.CreateAnd(XOffset, ConstantInt::get(Ty, *HighBitMaskCst)); | ||||
2826 | R->takeName(&SI); | ||||
2827 | return R; | ||||
2828 | } | ||||
2829 | |||||
2830 | namespace { | ||||
2831 | struct DecomposedSelect { | ||||
2832 | Value *Cond = nullptr; | ||||
2833 | Value *TrueVal = nullptr; | ||||
2834 | Value *FalseVal = nullptr; | ||||
2835 | }; | ||||
2836 | } // namespace | ||||
2837 | |||||
2838 | /// Look for patterns like | ||||
2839 | /// %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false | ||||
2840 | /// %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f | ||||
2841 | /// %outer.sel = select i1 %outer.cond, i8 %outer.sel.t, i8 %inner.sel | ||||
2842 | /// and rewrite it as | ||||
2843 | /// %inner.sel = select i1 %cond.alternative, i8 %sel.outer.t, i8 %sel.inner.t | ||||
2844 | /// %sel.outer = select i1 %cond.inner, i8 %inner.sel, i8 %sel.inner.f | ||||
2845 | static Instruction *foldNestedSelects(SelectInst &OuterSelVal, | ||||
2846 | InstCombiner::BuilderTy &Builder) { | ||||
2847 | // We must start with a `select`. | ||||
2848 | DecomposedSelect OuterSel; | ||||
2849 | match(&OuterSelVal, | ||||
2850 | m_Select(m_Value(OuterSel.Cond), m_Value(OuterSel.TrueVal), | ||||
2851 | m_Value(OuterSel.FalseVal))); | ||||
2852 | |||||
2853 | // Canonicalize inversion of the outermost `select`'s condition. | ||||
2854 | if (match(OuterSel.Cond, m_Not(m_Value(OuterSel.Cond)))) | ||||
2855 | std::swap(OuterSel.TrueVal, OuterSel.FalseVal); | ||||
2856 | |||||
2857 | // The condition of the outermost select must be an `and`/`or`. | ||||
2858 | if (!match(OuterSel.Cond, m_c_LogicalOp(m_Value(), m_Value()))) | ||||
2859 | return nullptr; | ||||
2860 | |||||
2861 | // Depending on the logical op, inner select might be in different hand. | ||||
2862 | bool IsAndVariant = match(OuterSel.Cond, m_LogicalAnd()); | ||||
2863 | Value *InnerSelVal = IsAndVariant ? OuterSel.FalseVal : OuterSel.TrueVal; | ||||
2864 | |||||
2865 | // Profitability check - avoid increasing instruction count. | ||||
2866 | if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}), | ||||
2867 | [](Value *V) { return V->hasOneUse(); })) | ||||
2868 | return nullptr; | ||||
2869 | |||||
2870 | // The appropriate hand of the outermost `select` must be a select itself. | ||||
2871 | DecomposedSelect InnerSel; | ||||
2872 | if (!match(InnerSelVal, | ||||
2873 | m_Select(m_Value(InnerSel.Cond), m_Value(InnerSel.TrueVal), | ||||
2874 | m_Value(InnerSel.FalseVal)))) | ||||
2875 | return nullptr; | ||||
2876 | |||||
2877 | // Canonicalize inversion of the innermost `select`'s condition. | ||||
2878 | if (match(InnerSel.Cond, m_Not(m_Value(InnerSel.Cond)))) | ||||
2879 | std::swap(InnerSel.TrueVal, InnerSel.FalseVal); | ||||
2880 | |||||
2881 | Value *AltCond = nullptr; | ||||
2882 | auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) { | ||||
2883 | return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond))); | ||||
2884 | }; | ||||
2885 | |||||
2886 | // Finally, match the condition that was driving the outermost `select`, | ||||
2887 | // it should be a logical operation between the condition that was driving | ||||
2888 | // the innermost `select` (after accounting for the possible inversions | ||||
2889 | // of the condition), and some other condition. | ||||
2890 | if (matchOuterCond(m_Specific(InnerSel.Cond))) { | ||||
2891 | // Done! | ||||
2892 | } else if (Value * NotInnerCond; matchOuterCond(m_CombineAnd( | ||||
2893 | m_Not(m_Specific(InnerSel.Cond)), m_Value(NotInnerCond)))) { | ||||
2894 | // Done! | ||||
2895 | std::swap(InnerSel.TrueVal, InnerSel.FalseVal); | ||||
2896 | InnerSel.Cond = NotInnerCond; | ||||
2897 | } else // Not the pattern we were looking for. | ||||
2898 | return nullptr; | ||||
2899 | |||||
2900 | Value *SelInner = Builder.CreateSelect( | ||||
2901 | AltCond, IsAndVariant ? OuterSel.TrueVal : InnerSel.FalseVal, | ||||
2902 | IsAndVariant ? InnerSel.TrueVal : OuterSel.FalseVal); | ||||
2903 | SelInner->takeName(InnerSelVal); | ||||
2904 | return SelectInst::Create(InnerSel.Cond, | ||||
2905 | IsAndVariant ? SelInner : InnerSel.TrueVal, | ||||
2906 | !IsAndVariant ? SelInner : InnerSel.FalseVal); | ||||
2907 | } | ||||
2908 | |||||
2909 | Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { | ||||
2910 | Value *CondVal = SI.getCondition(); | ||||
2911 | Value *TrueVal = SI.getTrueValue(); | ||||
2912 | Value *FalseVal = SI.getFalseValue(); | ||||
2913 | Type *SelType = SI.getType(); | ||||
2914 | |||||
2915 | // Avoid potential infinite loops by checking for non-constant condition. | ||||
2916 | // TODO: Can we assert instead by improving canonicalizeSelectToShuffle()? | ||||
2917 | // Scalar select must have simplified? | ||||
2918 | if (!SelType->isIntOrIntVectorTy(1) || isa<Constant>(CondVal) || | ||||
2919 | TrueVal->getType() != CondVal->getType()) | ||||
2920 | return nullptr; | ||||
2921 | |||||
2922 | auto *One = ConstantInt::getTrue(SelType); | ||||
2923 | auto *Zero = ConstantInt::getFalse(SelType); | ||||
2924 | Value *A, *B, *C, *D; | ||||
2925 | |||||
2926 | // Folding select to and/or i1 isn't poison safe in general. impliesPoison | ||||
2927 | // checks whether folding it does not convert a well-defined value into | ||||
2928 | // poison. | ||||
2929 | if (match(TrueVal, m_One())) { | ||||
2930 | if (impliesPoison(FalseVal, CondVal)) { | ||||
2931 | // Change: A = select B, true, C --> A = or B, C | ||||
2932 | return BinaryOperator::CreateOr(CondVal, FalseVal); | ||||
2933 | } | ||||
2934 | |||||
2935 | if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) | ||||
2936 | if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) | ||||
2937 | if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, | ||||
2938 | /*IsSelectLogical*/ true)) | ||||
2939 | return replaceInstUsesWith(SI, V); | ||||
2940 | |||||
2941 | // (A && B) || (C && B) --> (A || C) && B | ||||
2942 | if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) && | ||||
2943 | match(FalseVal, m_LogicalAnd(m_Value(C), m_Value(D))) && | ||||
2944 | (CondVal->hasOneUse() || FalseVal->hasOneUse())) { | ||||
2945 | bool CondLogicAnd = isa<SelectInst>(CondVal); | ||||
2946 | bool FalseLogicAnd = isa<SelectInst>(FalseVal); | ||||
2947 | auto AndFactorization = [&](Value *Common, Value *InnerCond, | ||||
2948 | Value *InnerVal, | ||||
2949 | bool SelFirst = false) -> Instruction * { | ||||
2950 | Value *InnerSel = Builder.CreateSelect(InnerCond, One, InnerVal); | ||||
2951 | if (SelFirst) | ||||
2952 | std::swap(Common, InnerSel); | ||||
2953 | if (FalseLogicAnd || (CondLogicAnd && Common == A)) | ||||
2954 | return SelectInst::Create(Common, InnerSel, Zero); | ||||
2955 | else | ||||
2956 | return BinaryOperator::CreateAnd(Common, InnerSel); | ||||
2957 | }; | ||||
2958 | |||||
2959 | if (A == C) | ||||
2960 | return AndFactorization(A, B, D); | ||||
2961 | if (A == D) | ||||
2962 | return AndFactorization(A, B, C); | ||||
2963 | if (B == C) | ||||
2964 | return AndFactorization(B, A, D); | ||||
2965 | if (B == D) | ||||
2966 | return AndFactorization(B, A, C, CondLogicAnd && FalseLogicAnd); | ||||
2967 | } | ||||
2968 | } | ||||
2969 | |||||
2970 | if (match(FalseVal, m_Zero())) { | ||||
2971 | if (impliesPoison(TrueVal, CondVal)) { | ||||
2972 | // Change: A = select B, C, false --> A = and B, C | ||||
2973 | return BinaryOperator::CreateAnd(CondVal, TrueVal); | ||||
2974 | } | ||||
2975 | |||||
2976 | if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) | ||||
2977 | if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) | ||||
2978 | if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, | ||||
2979 | /*IsSelectLogical*/ true)) | ||||
2980 | return replaceInstUsesWith(SI, V); | ||||
2981 | |||||
2982 | // (A || B) && (C || B) --> (A && C) || B | ||||
2983 | if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) && | ||||
2984 | match(TrueVal, m_LogicalOr(m_Value(C), m_Value(D))) && | ||||
2985 | (CondVal->hasOneUse() || TrueVal->hasOneUse())) { | ||||
2986 | bool CondLogicOr = isa<SelectInst>(CondVal); | ||||
2987 | bool TrueLogicOr = isa<SelectInst>(TrueVal); | ||||
2988 | auto OrFactorization = [&](Value *Common, Value *InnerCond, | ||||
2989 | Value *InnerVal, | ||||
2990 | bool SelFirst = false) -> Instruction * { | ||||
2991 | Value *InnerSel = Builder.CreateSelect(InnerCond, InnerVal, Zero); | ||||
2992 | if (SelFirst) | ||||
2993 | std::swap(Common, InnerSel); | ||||
2994 | if (TrueLogicOr || (CondLogicOr && Common == A)) | ||||
2995 | return SelectInst::Create(Common, One, InnerSel); | ||||
2996 | else | ||||
2997 | return BinaryOperator::CreateOr(Common, InnerSel); | ||||
2998 | }; | ||||
2999 | |||||
3000 | if (A == C) | ||||
3001 | return OrFactorization(A, B, D); | ||||
3002 | if (A == D) | ||||
3003 | return OrFactorization(A, B, C); | ||||
3004 | if (B == C) | ||||
3005 | return OrFactorization(B, A, D); | ||||
3006 | if (B == D) | ||||
3007 | return OrFactorization(B, A, C, CondLogicOr && TrueLogicOr); | ||||
3008 | } | ||||
3009 | } | ||||
3010 | |||||
3011 | // We match the "full" 0 or 1 constant here to avoid a potential infinite | ||||
3012 | // loop with vectors that may have undefined/poison elements. | ||||
3013 | // select a, false, b -> select !a, b, false | ||||
3014 | if (match(TrueVal, m_Specific(Zero))) { | ||||
3015 | Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); | ||||
3016 | return SelectInst::Create(NotCond, FalseVal, Zero); | ||||
3017 | } | ||||
3018 | // select a, b, true -> select !a, true, b | ||||
3019 | if (match(FalseVal, m_Specific(One))) { | ||||
3020 | Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); | ||||
3021 | return SelectInst::Create(NotCond, One, TrueVal); | ||||
3022 | } | ||||
3023 | |||||
3024 | // DeMorgan in select form: !a && !b --> !(a || b) | ||||
3025 | // select !a, !b, false --> not (select a, true, b) | ||||
3026 | if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) && | ||||
3027 | (CondVal->hasOneUse() || TrueVal->hasOneUse()) && | ||||
3028 | !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) | ||||
3029 | return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B)); | ||||
3030 | |||||
3031 | // DeMorgan in select form: !a || !b --> !(a && b) | ||||
3032 | // select !a, true, !b --> not (select a, b, false) | ||||
3033 | if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) && | ||||
3034 | (CondVal->hasOneUse() || FalseVal->hasOneUse()) && | ||||
3035 | !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) | ||||
3036 | return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero)); | ||||
3037 | |||||
3038 | // select (select a, true, b), true, b -> select a, true, b | ||||
3039 | if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && | ||||
3040 | match(TrueVal, m_One()) && match(FalseVal, m_Specific(B))) | ||||
3041 | return replaceOperand(SI, 0, A); | ||||
3042 | // select (select a, b, false), b, false -> select a, b, false | ||||
3043 | if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && | ||||
3044 | match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) | ||||
3045 | return replaceOperand(SI, 0, A); | ||||
3046 | // select a, (select ~a, true, b), false -> select a, b, false | ||||
3047 | if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) && | ||||
3048 | match(FalseVal, m_Zero())) | ||||
3049 | return replaceOperand(SI, 1, B); | ||||
3050 | // select a, true, (select ~a, b, false) -> select a, true, b | ||||
3051 | if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) && | ||||
3052 | match(TrueVal, m_One())) | ||||
3053 | return replaceOperand(SI, 2, B); | ||||
3054 | |||||
3055 | // ~(A & B) & (A | B) --> A ^ B | ||||
3056 | if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))), | ||||
3057 | m_c_LogicalOr(m_Deferred(A), m_Deferred(B))))) | ||||
3058 | return BinaryOperator::CreateXor(A, B); | ||||
3059 | |||||
3060 | // select (~a | c), a, b -> and a, (or c, freeze(b)) | ||||
3061 | if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && | ||||
3062 | CondVal->hasOneUse()) { | ||||
3063 | FalseVal = Builder.CreateFreeze(FalseVal); | ||||
3064 | return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); | ||||
3065 | } | ||||
3066 | // select (~c & b), a, b -> and b, (or freeze(a), c) | ||||
3067 | if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && | ||||
3068 | CondVal->hasOneUse()) { | ||||
3069 | TrueVal = Builder.CreateFreeze(TrueVal); | ||||
3070 | return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); | ||||
3071 | } | ||||
3072 | |||||
3073 | if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { | ||||
3074 | Use *Y = nullptr; | ||||
3075 | bool IsAnd = match(FalseVal, m_Zero()) ? true : false; | ||||
3076 | Value *Op1 = IsAnd ? TrueVal : FalseVal; | ||||
3077 | if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) { | ||||
3078 | auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr"); | ||||
3079 | InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser())); | ||||
3080 | replaceUse(*Y, FI); | ||||
3081 | return replaceInstUsesWith(SI, Op1); | ||||
3082 | } | ||||
3083 | |||||
3084 | if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) | ||||
3085 | if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, | ||||
3086 | /* IsAnd */ IsAnd)) | ||||
3087 | return I; | ||||
3088 | |||||
3089 | if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) | ||||
3090 | if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) | ||||
3091 | if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, | ||||
3092 | /* IsLogical */ true)) | ||||
3093 | return replaceInstUsesWith(SI, V); | ||||
3094 | } | ||||
3095 | |||||
3096 | // select (a || b), c, false -> select a, c, false | ||||
3097 | // select c, (a || b), false -> select c, a, false | ||||
3098 | // if c implies that b is false. | ||||
3099 | if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) && | ||||
3100 | match(FalseVal, m_Zero())) { | ||||
3101 | std::optional<bool> Res = isImpliedCondition(TrueVal, B, DL); | ||||
3102 | if (Res && *Res == false) | ||||
3103 | return replaceOperand(SI, 0, A); | ||||
3104 | } | ||||
3105 | if (match(TrueVal, m_LogicalOr(m_Value(A), m_Value(B))) && | ||||
3106 | match(FalseVal, m_Zero())) { | ||||
3107 | std::optional<bool> Res = isImpliedCondition(CondVal, B, DL); | ||||
3108 | if (Res && *Res == false) | ||||
3109 | return replaceOperand(SI, 1, A); | ||||
3110 | } | ||||
3111 | // select c, true, (a && b) -> select c, true, a | ||||
3112 | // select (a && b), true, c -> select a, true, c | ||||
3113 | // if c = false implies that b = true | ||||
3114 | if (match(TrueVal, m_One()) && | ||||
3115 | match(FalseVal, m_LogicalAnd(m_Value(A), m_Value(B)))) { | ||||
3116 | std::optional<bool> Res = isImpliedCondition(CondVal, B, DL, false); | ||||
3117 | if (Res && *Res == true) | ||||
3118 | return replaceOperand(SI, 2, A); | ||||
3119 | } | ||||
3120 | if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) && | ||||
3121 | match(TrueVal, m_One())) { | ||||
3122 | std::optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false); | ||||
3123 | if (Res && *Res == true) | ||||
3124 | return replaceOperand(SI, 0, A); | ||||
3125 | } | ||||
3126 | |||||
3127 | if (match(TrueVal, m_One())) { | ||||
3128 | Value *C; | ||||
3129 | |||||
3130 | // (C && A) || (!C && B) --> sel C, A, B | ||||
3131 | // (A && C) || (!C && B) --> sel C, A, B | ||||
3132 | // (C && A) || (B && !C) --> sel C, A, B | ||||
3133 | // (A && C) || (B && !C) --> sel C, A, B (may require freeze) | ||||
3134 | if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(B))) && | ||||
3135 | match(CondVal, m_c_LogicalAnd(m_Specific(C), m_Value(A)))) { | ||||
3136 | auto *SelCond = dyn_cast<SelectInst>(CondVal); | ||||
3137 | auto *SelFVal = dyn_cast<SelectInst>(FalseVal); | ||||
3138 | bool MayNeedFreeze = SelCond && SelFVal && | ||||
3139 | match(SelFVal->getTrueValue(), | ||||
3140 | m_Not(m_Specific(SelCond->getTrueValue()))); | ||||
3141 | if (MayNeedFreeze) | ||||
3142 | C = Builder.CreateFreeze(C); | ||||
3143 | return SelectInst::Create(C, A, B); | ||||
3144 | } | ||||
3145 | |||||
3146 | // (!C && A) || (C && B) --> sel C, B, A | ||||
3147 | // (A && !C) || (C && B) --> sel C, B, A | ||||
3148 | // (!C && A) || (B && C) --> sel C, B, A | ||||
3149 | // (A && !C) || (B && C) --> sel C, B, A (may require freeze) | ||||
3150 | if (match(CondVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(A))) && | ||||
3151 | match(FalseVal, m_c_LogicalAnd(m_Specific(C), m_Value(B)))) { | ||||
3152 | auto *SelCond = dyn_cast<SelectInst>(CondVal); | ||||
3153 | auto *SelFVal = dyn_cast<SelectInst>(FalseVal); | ||||
3154 | bool MayNeedFreeze = SelCond && SelFVal && | ||||
3155 | match(SelCond->getTrueValue(), | ||||
3156 | m_Not(m_Specific(SelFVal->getTrueValue()))); | ||||
3157 | if (MayNeedFreeze) | ||||
3158 | C = Builder.CreateFreeze(C); | ||||
3159 | return SelectInst::Create(C, B, A); | ||||
3160 | } | ||||
3161 | } | ||||
3162 | |||||
3163 | return nullptr; | ||||
3164 | } | ||||
3165 | |||||
3166 | // Return true if we can safely remove the select instruction for std::bit_ceil | ||||
3167 | // pattern. | ||||
3168 | static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, | ||||
3169 | const APInt *Cond1, Value *CtlzOp, | ||||
3170 | unsigned BitWidth) { | ||||
3171 | // The challenge in recognizing std::bit_ceil(X) is that the operand is used | ||||
3172 | // for the CTLZ proper and select condition, each possibly with some | ||||
3173 | // operation like add and sub. | ||||
3174 | // | ||||
3175 | // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the | ||||
3176 | // select instruction would select 1, which allows us to get rid of the select | ||||
3177 | // instruction. | ||||
3178 | // | ||||
3179 | // To see if we can do so, we do some symbolic execution with ConstantRange. | ||||
3180 | // Specifically, we compute the range of values that Cond0 could take when | ||||
3181 | // Cond == false. Then we successively transform the range until we obtain | ||||
3182 | // the range of values that CtlzOp could take. | ||||
3183 | // | ||||
3184 | // Conceptually, we follow the def-use chain backward from Cond0 while | ||||
3185 | // transforming the range for Cond0 until we meet the common ancestor of Cond0 | ||||
3186 | // and CtlzOp. Then we follow the def-use chain forward until we obtain the | ||||
3187 | // range for CtlzOp. That said, we only follow at most one ancestor from | ||||
3188 | // Cond0. Likewise, we only follow at most one ancestor from CtrlOp. | ||||
3189 | |||||
3190 | ConstantRange CR = ConstantRange::makeExactICmpRegion( | ||||
3191 | CmpInst::getInversePredicate(Pred), *Cond1); | ||||
3192 | |||||
3193 | // Match the operation that's used to compute CtlzOp from CommonAncestor. If | ||||
3194 | // CtlzOp == CommonAncestor, return true as no operation is needed. If a | ||||
3195 | // match is found, execute the operation on CR, update CR, and return true. | ||||
3196 | // Otherwise, return false. | ||||
3197 | auto MatchForward = [&](Value *CommonAncestor) { | ||||
3198 | const APInt *C = nullptr; | ||||
3199 | if (CtlzOp == CommonAncestor) | ||||
3200 | return true; | ||||
3201 | if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) { | ||||
3202 | CR = CR.add(*C); | ||||
3203 | return true; | ||||
3204 | } | ||||
3205 | if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) { | ||||
3206 | CR = ConstantRange(*C).sub(CR); | ||||
3207 | return true; | ||||
3208 | } | ||||
3209 | if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) { | ||||
3210 | CR = CR.binaryNot(); | ||||
3211 | return true; | ||||
3212 | } | ||||
3213 | return false; | ||||
3214 | }; | ||||
3215 | |||||
3216 | const APInt *C = nullptr; | ||||
3217 | Value *CommonAncestor; | ||||
3218 | if (MatchForward(Cond0)) { | ||||
3219 | // Cond0 is either CtlzOp or CtlzOp's parent. CR has been updated. | ||||
3220 | } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) { | ||||
3221 | CR = CR.sub(*C); | ||||
3222 | if (!MatchForward(CommonAncestor)) | ||||
3223 | return false; | ||||
3224 | // Cond0's parent is either CtlzOp or CtlzOp's parent. CR has been updated. | ||||
3225 | } else { | ||||
3226 | return false; | ||||
3227 | } | ||||
3228 | |||||
3229 | // Return true if all the values in the range are either 0 or negative (if | ||||
3230 | // treated as signed). We do so by evaluating: | ||||
3231 | // | ||||
3232 | // CR - 1 u>= (1 << BitWidth) - 1. | ||||
3233 | APInt IntMax = APInt::getSignMask(BitWidth) - 1; | ||||
3234 | CR = CR.sub(APInt(BitWidth, 1)); | ||||
3235 | return CR.icmp(ICmpInst::ICMP_UGE, IntMax); | ||||
3236 | } | ||||
3237 | |||||
3238 | // Transform the std::bit_ceil(X) pattern like: | ||||
3239 | // | ||||
3240 | // %dec = add i32 %x, -1 | ||||
3241 | // %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false) | ||||
3242 | // %sub = sub i32 32, %ctlz | ||||
3243 | // %shl = shl i32 1, %sub | ||||
3244 | // %ugt = icmp ugt i32 %x, 1 | ||||
3245 | // %sel = select i1 %ugt, i32 %shl, i32 1 | ||||
3246 | // | ||||
3247 | // into: | ||||
3248 | // | ||||
3249 | // %dec = add i32 %x, -1 | ||||
3250 | // %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false) | ||||
3251 | // %neg = sub i32 0, %ctlz | ||||
3252 | // %masked = and i32 %ctlz, 31 | ||||
3253 | // %shl = shl i32 1, %sub | ||||
3254 | // | ||||
3255 | // Note that the select is optimized away while the shift count is masked with | ||||
3256 | // 31. We handle some variations of the input operand like std::bit_ceil(X + | ||||
3257 | // 1). | ||||
3258 | static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) { | ||||
3259 | Type *SelType = SI.getType(); | ||||
3260 | unsigned BitWidth = SelType->getScalarSizeInBits(); | ||||
3261 | |||||
3262 | Value *FalseVal = SI.getFalseValue(); | ||||
3263 | Value *TrueVal = SI.getTrueValue(); | ||||
3264 | ICmpInst::Predicate Pred; | ||||
3265 | const APInt *Cond1; | ||||
3266 | Value *Cond0, *Ctlz, *CtlzOp; | ||||
3267 | if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1)))) | ||||
3268 | return nullptr; | ||||
3269 | |||||
3270 | if (match(TrueVal, m_One())) { | ||||
3271 | std::swap(FalseVal, TrueVal); | ||||
3272 | Pred = CmpInst::getInversePredicate(Pred); | ||||
3273 | } | ||||
3274 | |||||
3275 | if (!match(FalseVal, m_One()) || | ||||
3276 | !match(TrueVal, | ||||
3277 | m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth), | ||||
3278 | m_Value(Ctlz)))))) || | ||||
3279 | !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) || | ||||
3280 | !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth)) | ||||
3281 | return nullptr; | ||||
3282 | |||||
3283 | // Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a | ||||
3284 | // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth | ||||
3285 | // is an integer constant. Masking with BitWidth-1 comes free on some | ||||
3286 | // hardware as part of the shift instruction. | ||||
3287 | Value *Neg = Builder.CreateNeg(Ctlz); | ||||
3288 | Value *Masked = | ||||
3289 | Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1)); | ||||
3290 | return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1), | ||||
3291 | Masked); | ||||
3292 | } | ||||
3293 | |||||
3294 | Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { | ||||
3295 | Value *CondVal = SI.getCondition(); | ||||
3296 | Value *TrueVal = SI.getTrueValue(); | ||||
3297 | Value *FalseVal = SI.getFalseValue(); | ||||
3298 | Type *SelType = SI.getType(); | ||||
3299 | |||||
3300 | if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal, | ||||
| |||||
3301 | SQ.getWithInstruction(&SI))) | ||||
3302 | return replaceInstUsesWith(SI, V); | ||||
3303 | |||||
3304 | if (Instruction *I
| ||||
3305 | return I; | ||||
3306 | |||||
3307 | if (Instruction *I
| ||||
3308 | return I; | ||||
3309 | |||||
3310 | // If the type of select is not an integer type or if the condition and | ||||
3311 | // the selection type are not both scalar nor both vector types, there is no | ||||
3312 | // point in attempting to match these patterns. | ||||
3313 | Type *CondType = CondVal->getType(); | ||||
3314 | if (!isa<Constant>(CondVal) && SelType->isIntOrIntVectorTy() && | ||||
3315 | CondType->isVectorTy() == SelType->isVectorTy()) { | ||||
3316 | if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, | ||||
3317 | ConstantInt::getTrue(CondType), SQ, | ||||
3318 | /* AllowRefinement */ true)) | ||||
3319 | return replaceOperand(SI, 1, S); | ||||
3320 | |||||
3321 | if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, | ||||
3322 | ConstantInt::getFalse(CondType), SQ, | ||||
3323 | /* AllowRefinement */ true)) | ||||
3324 | return replaceOperand(SI, 2, S); | ||||
3325 | |||||
3326 | // Handle patterns involving sext/zext + not explicitly, | ||||
3327 | // as simplifyWithOpReplaced() only looks past one instruction. | ||||
3328 | Value *NotCond; | ||||
3329 | |||||
3330 | // select a, sext(!a), b -> select !a, b, 0 | ||||
3331 | // select a, zext(!a), b -> select !a, b, 0 | ||||
3332 | if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond), | ||||
3333 | m_Not(m_Specific(CondVal)))))) | ||||
3334 | return SelectInst::Create(NotCond, FalseVal, | ||||
3335 | Constant::getNullValue(SelType)); | ||||
3336 | |||||
3337 | // select a, b, zext(!a) -> select !a, 1, b | ||||
3338 | if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond), | ||||
3339 | m_Not(m_Specific(CondVal)))))) | ||||
3340 | return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal); | ||||
3341 | |||||
3342 | // select a, b, sext(!a) -> select !a, -1, b | ||||
3343 | if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond), | ||||
3344 | m_Not(m_Specific(CondVal)))))) | ||||
3345 | return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType), | ||||
3346 | TrueVal); | ||||
3347 | } | ||||
3348 | |||||
3349 | if (Instruction *R = foldSelectOfBools(SI)) | ||||
3350 | return R; | ||||
3351 | |||||
3352 | // Selecting between two integer or vector splat integer constants? | ||||
3353 | // | ||||
3354 | // Note that we don't handle a scalar select of vectors: | ||||
3355 | // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0> | ||||
3356 | // because that may need 3 instructions to splat the condition value: | ||||
3357 | // extend, insertelement, shufflevector. | ||||
3358 | // | ||||
3359 | // Do not handle i1 TrueVal and FalseVal otherwise would result in | ||||
3360 | // zext/sext i1 to i1. | ||||
3361 | if (SelType->isIntOrIntVectorTy() && !SelType->isIntOrIntVectorTy(1) && | ||||
3362 | CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { | ||||
3363 | // select C, 1, 0 -> zext C to int | ||||
3364 | if (match(TrueVal, m_One()) && match(FalseVal, m_Zero())) | ||||
3365 | return new ZExtInst(CondVal, SelType); | ||||
3366 | |||||
3367 | // select C, -1, 0 -> sext C to int | ||||
3368 | if (match(TrueVal, m_AllOnes()) && match(FalseVal, m_Zero())) | ||||
3369 | return new SExtInst(CondVal, SelType); | ||||
3370 | |||||
3371 | // select C, 0, 1 -> zext !C to int | ||||
3372 | if (match(TrueVal, m_Zero()) && match(FalseVal, m_One())) { | ||||
3373 | Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); | ||||
3374 | return new ZExtInst(NotCond, SelType); | ||||
3375 | } | ||||
3376 | |||||
3377 | // select C, 0, -1 -> sext !C to int | ||||
3378 | if (match(TrueVal, m_Zero()) && match(FalseVal, m_AllOnes())) { | ||||
3379 | Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); | ||||
3380 | return new SExtInst(NotCond, SelType); | ||||
3381 | } | ||||
3382 | } | ||||
3383 | |||||
3384 | if (auto *FCmp
| ||||
3385 | Value *Cmp0 = FCmp->getOperand(0), *Cmp1 = FCmp->getOperand(1); | ||||
3386 | // Are we selecting a value based on a comparison of the two values? | ||||
3387 | if ((Cmp0 == TrueVal && Cmp1 == FalseVal) || | ||||
3388 | (Cmp0 == FalseVal && Cmp1 == TrueVal)) { | ||||
3389 | // Canonicalize to use ordered comparisons by swapping the select | ||||
3390 | // operands. | ||||
3391 | // | ||||
3392 | // e.g. | ||||
3393 | // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X | ||||
3394 | if (FCmp->hasOneUse() && FCmpInst::isUnordered(FCmp->getPredicate())) { | ||||
3395 | FCmpInst::Predicate InvPred = FCmp->getInversePredicate(); | ||||
3396 | IRBuilder<>::FastMathFlagGuard FMFG(Builder); | ||||
3397 | // FIXME: The FMF should propagate from the select, not the fcmp. | ||||
3398 | Builder.setFastMathFlags(FCmp->getFastMathFlags()); | ||||
3399 | Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1, | ||||
3400 | FCmp->getName() + ".inv"); | ||||
3401 | Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); | ||||
3402 | return replaceInstUsesWith(SI, NewSel); | ||||
3403 | } | ||||
3404 | } | ||||
3405 | } | ||||
3406 | |||||
3407 | if (isa<FPMathOperator>(SI)) { | ||||
3408 | // TODO: Try to forward-propagate FMF from select arms to the select. | ||||
3409 | |||||
3410 | // Canonicalize select of FP values where NaN and -0.0 are not valid as | ||||
3411 | // minnum/maxnum intrinsics. | ||||
3412 | if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) { | ||||
3413 | Value *X, *Y; | ||||
3414 | if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y)))) | ||||
3415 | return replaceInstUsesWith( | ||||
3416 | SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI)); | ||||
3417 | |||||
3418 | if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y)))) | ||||
3419 | return replaceInstUsesWith( | ||||
3420 | SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI)); | ||||
3421 | } | ||||
3422 | } | ||||
3423 | |||||
3424 | // Fold selecting to fabs. | ||||
3425 | if (Instruction *Fabs
| ||||
3426 | return Fabs; | ||||
3427 | |||||
3428 | // See if we are selecting two values based on a comparison of the two values. | ||||
3429 | if (ICmpInst *ICI
| ||||
3430 | if (Instruction *Result = foldSelectInstWithICmp(SI, ICI)) | ||||
3431 | return Result; | ||||
3432 | |||||
3433 | if (Instruction *Add
| ||||
3434 | return Add; | ||||
3435 | if (Instruction *Add
| ||||
3436 | return Add; | ||||
3437 | if (Instruction *Or
| ||||
3438 | return Or; | ||||
3439 | if (Instruction *Mul
| ||||
3440 | return Mul; | ||||
3441 | |||||
3442 | // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) | ||||
3443 | auto *TI = dyn_cast<Instruction>(TrueVal); | ||||
3444 | auto *FI = dyn_cast<Instruction>(FalseVal); | ||||
3445 | if (TI
| ||||
3446 | if (Instruction *IV = foldSelectOpOp(SI, TI, FI)) | ||||
3447 | return IV; | ||||
3448 | |||||
3449 | if (Instruction *I
| ||||
3450 | return I; | ||||
3451 | |||||
3452 | // Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0)) | ||||
3453 | // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx)) | ||||
3454 | auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base, | ||||
3455 | bool Swap) -> GetElementPtrInst * { | ||||
3456 | Value *Ptr = Gep->getPointerOperand(); | ||||
3457 | if (Gep->getNumOperands() != 2 || Gep->getPointerOperand() != Base || | ||||
3458 | !Gep->hasOneUse()) | ||||
3459 | return nullptr; | ||||
3460 | Value *Idx = Gep->getOperand(1); | ||||
3461 | if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType())) | ||||
3462 | return nullptr; | ||||
3463 | Type *ElementType = Gep->getResultElementType(); | ||||
3464 | Value *NewT = Idx; | ||||
3465 | Value *NewF = Constant::getNullValue(Idx->getType()); | ||||
3466 | if (Swap) | ||||
3467 | std::swap(NewT, NewF); | ||||
3468 | Value *NewSI = | ||||
3469 | Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI); | ||||
3470 | return GetElementPtrInst::Create(ElementType, Ptr, {NewSI}); | ||||
3471 | }; | ||||
3472 | if (auto *TrueGep
| ||||
3473 | if (auto *NewGep = SelectGepWithBase(TrueGep, FalseVal, false)) | ||||
3474 | return NewGep; | ||||
3475 | if (auto *FalseGep
| ||||
3476 | if (auto *NewGep = SelectGepWithBase(FalseGep, TrueVal, true)) | ||||
3477 | return NewGep; | ||||
3478 | |||||
3479 | // See if we can fold the select into one of our operands. | ||||
3480 | if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) { | ||||
3481 | if (Instruction *FoldI = foldSelectIntoOp(SI, TrueVal, FalseVal)) | ||||
3482 | return FoldI; | ||||
3483 | |||||
3484 | Value *LHS, *RHS; | ||||
3485 | Instruction::CastOps CastOp; | ||||
3486 | SelectPatternResult SPR = matchSelectPattern(&SI, LHS, RHS, &CastOp); | ||||
3487 | auto SPF = SPR.Flavor; | ||||
3488 | if (SPF) { | ||||
3489 | Value *LHS2, *RHS2; | ||||
3490 | if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) | ||||
3491 | if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS), SPF2, LHS2, | ||||
3492 | RHS2, SI, SPF, RHS)) | ||||
3493 | return R; | ||||
3494 | if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) | ||||
3495 | if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS), SPF2, LHS2, | ||||
3496 | RHS2, SI, SPF, LHS)) | ||||
3497 | return R; | ||||
3498 | } | ||||
3499 | |||||
3500 | if (SelectPatternResult::isMinOrMax(SPF)) { | ||||
3501 | // Canonicalize so that | ||||
3502 | // - type casts are outside select patterns. | ||||
3503 | // - float clamp is transformed to min/max pattern | ||||
3504 | |||||
3505 | bool IsCastNeeded = LHS->getType() != SelType; | ||||
3506 | Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0); | ||||
3507 | Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1); | ||||
3508 | if (IsCastNeeded || | ||||
3509 | (LHS->getType()->isFPOrFPVectorTy() && | ||||
3510 | ((CmpLHS != LHS && CmpLHS != RHS) || | ||||
3511 | (CmpRHS != LHS && CmpRHS != RHS)))) { | ||||
3512 | CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered); | ||||
3513 | |||||
3514 | Value *Cmp; | ||||
3515 | if (CmpInst::isIntPredicate(MinMaxPred)) { | ||||
3516 | Cmp = Builder.CreateICmp(MinMaxPred, LHS, RHS); | ||||
3517 | } else { | ||||
3518 | IRBuilder<>::FastMathFlagGuard FMFG(Builder); | ||||
3519 | auto FMF = | ||||
3520 | cast<FPMathOperator>(SI.getCondition())->getFastMathFlags(); | ||||
3521 | Builder.setFastMathFlags(FMF); | ||||
3522 | Cmp = Builder.CreateFCmp(MinMaxPred, LHS, RHS); | ||||
3523 | } | ||||
3524 | |||||
3525 | Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); | ||||
3526 | if (!IsCastNeeded) | ||||
3527 | return replaceInstUsesWith(SI, NewSI); | ||||
3528 | |||||
3529 | Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); | ||||
3530 | return replaceInstUsesWith(SI, NewCast); | ||||
3531 | } | ||||
3532 | } | ||||
3533 | } | ||||
3534 | |||||
3535 | // See if we can fold the select into a phi node if the condition is a select. | ||||
3536 | if (auto *PN
| ||||
3537 | // The true/false values have to be live in the PHI predecessor's blocks. | ||||
3538 | if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) && | ||||
3539 | canSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) | ||||
3540 | if (Instruction *NV = foldOpIntoPhi(SI, PN)) | ||||
3541 | return NV; | ||||
3542 | |||||
3543 | if (SelectInst *TrueSI
| ||||
3544 | if (TrueSI->getCondition()->getType() == CondVal->getType()) { | ||||
3545 | // select(C, select(C, a, b), c) -> select(C, a, c) | ||||
3546 | if (TrueSI->getCondition() == CondVal) { | ||||
3547 | if (SI.getTrueValue() == TrueSI->getTrueValue()) | ||||
3548 | return nullptr; | ||||
3549 | return replaceOperand(SI, 1, TrueSI->getTrueValue()); | ||||
3550 | } | ||||
3551 | // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) | ||||
3552 | // We choose this as normal form to enable folding on the And and | ||||
3553 | // shortening paths for the values (this helps getUnderlyingObjects() for | ||||
3554 | // example). | ||||
3555 | if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) { | ||||
3556 | Value *And = Builder.CreateLogicalAnd(CondVal, TrueSI->getCondition()); | ||||
3557 | replaceOperand(SI, 0, And); | ||||
3558 | replaceOperand(SI, 1, TrueSI->getTrueValue()); | ||||
3559 | return &SI; | ||||
3560 | } | ||||
3561 | } | ||||
3562 | } | ||||
3563 | if (SelectInst *FalseSI
| ||||
3564 | if (FalseSI->getCondition()->getType() == CondVal->getType()) { | ||||
3565 | // select(C, a, select(C, b, c)) -> select(C, a, c) | ||||
3566 | if (FalseSI->getCondition() == CondVal) { | ||||
3567 | if (SI.getFalseValue() == FalseSI->getFalseValue()) | ||||
3568 | return nullptr; | ||||
3569 | return replaceOperand(SI, 2, FalseSI->getFalseValue()); | ||||
3570 | } | ||||
3571 | // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) | ||||
3572 | if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { | ||||
3573 | Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition()); | ||||
3574 | replaceOperand(SI, 0, Or); | ||||
3575 | replaceOperand(SI, 2, FalseSI->getFalseValue()); | ||||
3576 | return &SI; | ||||
3577 | } | ||||
3578 | } | ||||
3579 | } | ||||
3580 | |||||
3581 | // Try to simplify a binop sandwiched between 2 selects with the same | ||||
3582 | // condition. This is not valid for div/rem because the select might be | ||||
3583 | // preventing a division-by-zero. | ||||
3584 | // TODO: A div/rem restriction is conservative; use something like | ||||
3585 | // isSafeToSpeculativelyExecute(). | ||||
3586 | // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) | ||||
3587 | BinaryOperator *TrueBO; | ||||
3588 | if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && !TrueBO->isIntDivRem()) { | ||||
3589 | if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { | ||||
3590 | if (TrueBOSI->getCondition() == CondVal) { | ||||
3591 | replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue()); | ||||
3592 | Worklist.push(TrueBO); | ||||
3593 | return &SI; | ||||
3594 | } | ||||
3595 | } | ||||
3596 | if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) { | ||||
3597 | if (TrueBOSI->getCondition() == CondVal) { | ||||
3598 | replaceOperand(*TrueBO, 1, TrueBOSI->getTrueValue()); | ||||
3599 | Worklist.push(TrueBO); | ||||
3600 | return &SI; | ||||
3601 | } | ||||
3602 | } | ||||
3603 | } | ||||
3604 | |||||
3605 | // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) | ||||
3606 | BinaryOperator *FalseBO; | ||||
3607 | if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && !FalseBO->isIntDivRem()) { | ||||
3608 | if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { | ||||
3609 | if (FalseBOSI->getCondition() == CondVal) { | ||||
3610 | replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue()); | ||||
3611 | Worklist.push(FalseBO); | ||||
3612 | return &SI; | ||||
3613 | } | ||||
3614 | } | ||||
3615 | if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) { | ||||
3616 | if (FalseBOSI->getCondition() == CondVal) { | ||||
3617 | replaceOperand(*FalseBO, 1, FalseBOSI->getFalseValue()); | ||||
3618 | Worklist.push(FalseBO); | ||||
3619 | return &SI; | ||||
3620 | } | ||||
3621 | } | ||||
3622 | } | ||||
3623 | |||||
3624 | Value *NotCond; | ||||
3625 | if (match(CondVal, m_Not(m_Value(NotCond))) && | ||||
3626 | !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) { | ||||
3627 | replaceOperand(SI, 0, NotCond); | ||||
3628 | SI.swapValues(); | ||||
3629 | SI.swapProfMetadata(); | ||||
3630 | return &SI; | ||||
3631 | } | ||||
3632 | |||||
3633 | if (Instruction *I
| ||||
3634 | return I; | ||||
3635 | |||||
3636 | // If we can compute the condition, there's no need for a select. | ||||
3637 | // Like the above fold, we are attempting to reduce compile-time cost by | ||||
3638 | // putting this fold here with limitations rather than in InstSimplify. | ||||
3639 | // The motivation for this call into value tracking is to take advantage of | ||||
3640 | // the assumption cache, so make sure that is populated. | ||||
3641 | if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) { | ||||
3642 | KnownBits Known(1); | ||||
3643 | computeKnownBits(CondVal, Known, 0, &SI); | ||||
3644 | if (Known.One.isOne()) | ||||
3645 | return replaceInstUsesWith(SI, TrueVal); | ||||
3646 | if (Known.Zero.isOne()) | ||||
3647 | return replaceInstUsesWith(SI, FalseVal); | ||||
3648 | } | ||||
3649 | |||||
3650 | if (Instruction *BitCastSel
| ||||
3651 | return BitCastSel; | ||||
3652 | |||||
3653 | // Simplify selects that test the returned flag of cmpxchg instructions. | ||||
3654 | if (Value *V
| ||||
3655 | return replaceInstUsesWith(SI, V); | ||||
3656 | |||||
3657 | if (Instruction *Select
| ||||
3658 | return Select; | ||||
3659 | |||||
3660 | if (Instruction *Funnel
| ||||
3661 | return Funnel; | ||||
3662 | |||||
3663 | if (Instruction *Copysign
| ||||
3664 | return Copysign; | ||||
3665 | |||||
3666 | if (Instruction *PN
| ||||
3667 | return replaceInstUsesWith(SI, PN); | ||||
3668 | |||||
3669 | if (Value *Fr
| ||||
3670 | return replaceInstUsesWith(SI, Fr); | ||||
3671 | |||||
3672 | if (Value *V
| ||||
3673 | return replaceInstUsesWith(SI, V); | ||||
3674 | |||||
3675 | // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0) | ||||
3676 | // Load inst is intentionally not checked for hasOneUse() | ||||
3677 | if (match(FalseVal, m_Zero()) && | ||||
3678 | (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal), | ||||
3679 | m_CombineOr(m_Undef(), m_Zero()))) || | ||||
3680 | match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal), | ||||
3681 | m_CombineOr(m_Undef(), m_Zero()))))) { | ||||
3682 | auto *MaskedInst = cast<IntrinsicInst>(TrueVal); | ||||
3683 | if (isa<UndefValue>(MaskedInst->getArgOperand(3))) | ||||
3684 | MaskedInst->setArgOperand(3, FalseVal /* Zero */); | ||||
3685 | return replaceInstUsesWith(SI, MaskedInst); | ||||
3686 | } | ||||
3687 | |||||
3688 | Value *Mask; | ||||
3689 | if (match(TrueVal, m_Zero()) && | ||||
3690 | (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask), | ||||
3691 | m_CombineOr(m_Undef(), m_Zero()))) || | ||||
3692 | match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask), | ||||
3693 | m_CombineOr(m_Undef(), m_Zero())))) && | ||||
3694 | (CondVal->getType() == Mask->getType())) { | ||||
3695 | // We can remove the select by ensuring the load zeros all lanes the | ||||
3696 | // select would have. We determine this by proving there is no overlap | ||||
3697 | // between the load and select masks. | ||||
3698 | // (i.e (load_mask & select_mask) == 0 == no overlap) | ||||
3699 | bool CanMergeSelectIntoLoad = false; | ||||
3700 | if (Value *V = simplifyAndInst(CondVal, Mask, SQ.getWithInstruction(&SI))) | ||||
3701 | CanMergeSelectIntoLoad = match(V, m_Zero()); | ||||
3702 | |||||
3703 | if (CanMergeSelectIntoLoad) { | ||||
3704 | auto *MaskedInst = cast<IntrinsicInst>(FalseVal); | ||||
3705 | if (isa<UndefValue>(MaskedInst->getArgOperand(3))) | ||||
3706 | MaskedInst->setArgOperand(3, TrueVal /* Zero */); | ||||
3707 | return replaceInstUsesWith(SI, MaskedInst); | ||||
3708 | } | ||||
3709 | } | ||||
3710 | |||||
3711 | if (Instruction *I = foldNestedSelects(SI, Builder)) | ||||
3712 | return I; | ||||
3713 | |||||
3714 | // Match logical variants of the pattern, | ||||
3715 | // and transform them iff that gets rid of inversions. | ||||
3716 | // (~x) | y --> ~(x & (~y)) | ||||
3717 | // (~x) & y --> ~(x | (~y)) | ||||
3718 | if (sinkNotIntoOtherHandOfLogicalOp(SI)) | ||||
3719 | return &SI; | ||||
3720 | |||||
3721 | if (Instruction *I = foldBitCeil(SI, Builder)) | ||||
3722 | return I; | ||||
3723 | |||||
3724 | return nullptr; | ||||
3725 | } |
1 | //===- PatternMatch.h - Match on the LLVM IR --------------------*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file provides a simple and efficient mechanism for performing general | |||
10 | // tree-based pattern matches on the LLVM IR. The power of these routines is | |||
11 | // that it allows you to write concise patterns that are expressive and easy to | |||
12 | // understand. The other major advantage of this is that it allows you to | |||
13 | // trivially capture/bind elements in the pattern to variables. For example, | |||
14 | // you can do something like this: | |||
15 | // | |||
16 | // Value *Exp = ... | |||
17 | // Value *X, *Y; ConstantInt *C1, *C2; // (X & C1) | (Y & C2) | |||
18 | // if (match(Exp, m_Or(m_And(m_Value(X), m_ConstantInt(C1)), | |||
19 | // m_And(m_Value(Y), m_ConstantInt(C2))))) { | |||
20 | // ... Pattern is matched and variables are bound ... | |||
21 | // } | |||
22 | // | |||
23 | // This is primarily useful to things like the instruction combiner, but can | |||
24 | // also be useful for static analysis tools or code generators. | |||
25 | // | |||
26 | //===----------------------------------------------------------------------===// | |||
27 | ||||
28 | #ifndef LLVM_IR_PATTERNMATCH_H | |||
29 | #define LLVM_IR_PATTERNMATCH_H | |||
30 | ||||
31 | #include "llvm/ADT/APFloat.h" | |||
32 | #include "llvm/ADT/APInt.h" | |||
33 | #include "llvm/IR/Constant.h" | |||
34 | #include "llvm/IR/Constants.h" | |||
35 | #include "llvm/IR/DataLayout.h" | |||
36 | #include "llvm/IR/InstrTypes.h" | |||
37 | #include "llvm/IR/Instruction.h" | |||
38 | #include "llvm/IR/Instructions.h" | |||
39 | #include "llvm/IR/IntrinsicInst.h" | |||
40 | #include "llvm/IR/Intrinsics.h" | |||
41 | #include "llvm/IR/Operator.h" | |||
42 | #include "llvm/IR/Value.h" | |||
43 | #include "llvm/Support/Casting.h" | |||
44 | #include <cstdint> | |||
45 | ||||
46 | namespace llvm { | |||
47 | namespace PatternMatch { | |||
48 | ||||
/// Run pattern \p P against value \p V and report whether it matched.
/// Patterns bind captured sub-values as a side effect, so the pattern object
/// is taken by const reference for convenience and mutated through a cast.
template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
  auto &MutablePattern = const_cast<Pattern &>(P);
  return MutablePattern.match(V);
}
52 | ||||
53 | template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) { | |||
54 | return const_cast<Pattern &>(P).match(Mask); | |||
55 | } | |||
56 | ||||
/// Wraps a sub-pattern and only lets it run when the matched value has
/// exactly one use (as reported by hasOneUse()).
template <typename SubPattern_t> struct OneUse_match {
  SubPattern_t SubPattern;

  // Intentionally implicit: m_OneUse returns a bare sub-pattern and relies on
  // this conversion.
  OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {}

  template <typename OpTy> bool match(OpTy *V) {
    // The use-count test comes first so the sub-pattern is never consulted
    // (and never binds anything) for multi-use values.
    if (!V->hasOneUse())
      return false;
    return SubPattern.match(V);
  }
};
66 | ||||
67 | template <typename T> inline OneUse_match<T> m_OneUse(const T &SubPattern) { | |||
68 | return SubPattern; | |||
69 | } | |||
70 | ||||
71 | template <typename Class> struct class_match { | |||
72 | template <typename ITy> bool match(ITy *V) { return isa<Class>(V); } | |||
73 | }; | |||
74 | ||||
75 | /// Match an arbitrary value and ignore it. | |||
76 | inline class_match<Value> m_Value() { return class_match<Value>(); } | |||
77 | ||||
78 | /// Match an arbitrary unary operation and ignore it. | |||
79 | inline class_match<UnaryOperator> m_UnOp() { | |||
80 | return class_match<UnaryOperator>(); | |||
81 | } | |||
82 | ||||
83 | /// Match an arbitrary binary operation and ignore it. | |||
84 | inline class_match<BinaryOperator> m_BinOp() { | |||
85 | return class_match<BinaryOperator>(); | |||
86 | } | |||
87 | ||||
88 | /// Matches any compare instruction and ignore it. | |||
89 | inline class_match<CmpInst> m_Cmp() { return class_match<CmpInst>(); } | |||
90 | ||||
91 | struct undef_match { | |||
92 | static bool check(const Value *V) { | |||
93 | if (isa<UndefValue>(V)) | |||
94 | return true; | |||
95 | ||||
96 | const auto *CA = dyn_cast<ConstantAggregate>(V); | |||
97 | if (!CA) | |||
98 | return false; | |||
99 | ||||
100 | SmallPtrSet<const ConstantAggregate *, 8> Seen; | |||
101 | SmallVector<const ConstantAggregate *, 8> Worklist; | |||
102 | ||||
103 | // Either UndefValue, PoisonValue, or an aggregate that only contains | |||
104 | // these is accepted by matcher. | |||
105 | // CheckValue returns false if CA cannot satisfy this constraint. | |||
106 | auto CheckValue = [&](const ConstantAggregate *CA) { | |||
107 | for (const Value *Op : CA->operand_values()) { | |||
108 | if (isa<UndefValue>(Op)) | |||
109 | continue; | |||
110 | ||||
111 | const auto *CA = dyn_cast<ConstantAggregate>(Op); | |||
112 | if (!CA) | |||
113 | return false; | |||
114 | if (Seen.insert(CA).second) | |||
115 | Worklist.emplace_back(CA); | |||
116 | } | |||
117 | ||||
118 | return true; | |||
119 | }; | |||
120 | ||||
121 | if (!CheckValue(CA)) | |||
122 | return false; | |||
123 | ||||
124 | while (!Worklist.empty()) { | |||
125 | if (!CheckValue(Worklist.pop_back_val())) | |||
126 | return false; | |||
127 | } | |||
128 | return true; | |||
129 | } | |||
130 | template <typename ITy> bool match(ITy *V) { return check(V); } | |||
131 | }; | |||
132 | ||||
133 | /// Match an arbitrary undef constant. This matches poison as well. | |||
134 | /// If this is an aggregate and contains a non-aggregate element that is | |||
135 | /// neither undef nor poison, the aggregate is not matched. | |||
136 | inline auto m_Undef() { return undef_match(); } | |||
137 | ||||
138 | /// Match an arbitrary poison constant. | |||
139 | inline class_match<PoisonValue> m_Poison() { | |||
140 | return class_match<PoisonValue>(); | |||
141 | } | |||
142 | ||||
143 | /// Match an arbitrary Constant and ignore it. | |||
144 | inline class_match<Constant> m_Constant() { return class_match<Constant>(); } | |||
145 | ||||
146 | /// Match an arbitrary ConstantInt and ignore it. | |||
147 | inline class_match<ConstantInt> m_ConstantInt() { | |||
148 | return class_match<ConstantInt>(); | |||
149 | } | |||
150 | ||||
151 | /// Match an arbitrary ConstantFP and ignore it. | |||
152 | inline class_match<ConstantFP> m_ConstantFP() { | |||
153 | return class_match<ConstantFP>(); | |||
154 | } | |||
155 | ||||
156 | struct constantexpr_match { | |||
157 | template <typename ITy> bool match(ITy *V) { | |||
158 | auto *C = dyn_cast<Constant>(V); | |||
159 | return C && (isa<ConstantExpr>(C) || C->containsConstantExpression()); | |||
160 | } | |||
161 | }; | |||
162 | ||||
163 | /// Match a constant expression or a constant that contains a constant | |||
164 | /// expression. | |||
165 | inline constantexpr_match m_ConstantExpr() { return constantexpr_match(); } | |||
166 | ||||
167 | /// Match an arbitrary basic block value and ignore it. | |||
168 | inline class_match<BasicBlock> m_BasicBlock() { | |||
169 | return class_match<BasicBlock>(); | |||
170 | } | |||
171 | ||||
172 | /// Inverting matcher | |||
173 | template <typename Ty> struct match_unless { | |||
174 | Ty M; | |||
175 | ||||
176 | match_unless(const Ty &Matcher) : M(Matcher) {} | |||
177 | ||||
178 | template <typename ITy> bool match(ITy *V) { return !M.match(V); } | |||
179 | }; | |||
180 | ||||
181 | /// Match if the inner matcher does *NOT* match. | |||
182 | template <typename Ty> inline match_unless<Ty> m_Unless(const Ty &M) { | |||
183 | return match_unless<Ty>(M); | |||
184 | } | |||
185 | ||||
186 | /// Matching combinators | |||
187 | template <typename LTy, typename RTy> struct match_combine_or { | |||
188 | LTy L; | |||
189 | RTy R; | |||
190 | ||||
191 | match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} | |||
192 | ||||
193 | template <typename ITy> bool match(ITy *V) { | |||
194 | if (L.match(V)) | |||
195 | return true; | |||
196 | if (R.match(V)) | |||
197 | return true; | |||
198 | return false; | |||
199 | } | |||
200 | }; | |||
201 | ||||
202 | template <typename LTy, typename RTy> struct match_combine_and { | |||
203 | LTy L; | |||
204 | RTy R; | |||
205 | ||||
206 | match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} | |||
207 | ||||
208 | template <typename ITy> bool match(ITy *V) { | |||
209 | if (L.match(V)) | |||
210 | if (R.match(V)) | |||
211 | return true; | |||
212 | return false; | |||
213 | } | |||
214 | }; | |||
215 | ||||
216 | /// Combine two pattern matchers matching L || R | |||
217 | template <typename LTy, typename RTy> | |||
218 | inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) { | |||
219 | return match_combine_or<LTy, RTy>(L, R); | |||
220 | } | |||
221 | ||||
222 | /// Combine two pattern matchers matching L && R | |||
223 | template <typename LTy, typename RTy> | |||
224 | inline match_combine_and<LTy, RTy> m_CombineAnd(const LTy &L, const RTy &R) { | |||
225 | return match_combine_and<LTy, RTy>(L, R); | |||
226 | } | |||
227 | ||||
228 | struct apint_match { | |||
229 | const APInt *&Res; | |||
230 | bool AllowUndef; | |||
231 | ||||
232 | apint_match(const APInt *&Res, bool AllowUndef) | |||
233 | : Res(Res), AllowUndef(AllowUndef) {} | |||
234 | ||||
235 | template <typename ITy> bool match(ITy *V) { | |||
236 | if (auto *CI = dyn_cast<ConstantInt>(V)) { | |||
237 | Res = &CI->getValue(); | |||
238 | return true; | |||
239 | } | |||
240 | if (V->getType()->isVectorTy()) | |||
241 | if (const auto *C = dyn_cast<Constant>(V)) | |||
242 | if (auto *CI = | |||
243 | dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndef))) { | |||
244 | Res = &CI->getValue(); | |||
245 | return true; | |||
246 | } | |||
247 | return false; | |||
248 | } | |||
249 | }; | |||
250 | // Either constexpr if or renaming ConstantFP::getValueAPF to | |||
251 | // ConstantFP::getValue is needed to do it via single template | |||
252 | // function for both apint/apfloat. | |||
253 | struct apfloat_match { | |||
254 | const APFloat *&Res; | |||
255 | bool AllowUndef; | |||
256 | ||||
257 | apfloat_match(const APFloat *&Res, bool AllowUndef) | |||
258 | : Res(Res), AllowUndef(AllowUndef) {} | |||
259 | ||||
260 | template <typename ITy> bool match(ITy *V) { | |||
261 | if (auto *CI = dyn_cast<ConstantFP>(V)) { | |||
262 | Res = &CI->getValueAPF(); | |||
263 | return true; | |||
264 | } | |||
265 | if (V->getType()->isVectorTy()) | |||
266 | if (const auto *C = dyn_cast<Constant>(V)) | |||
267 | if (auto *CI = | |||
268 | dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndef))) { | |||
269 | Res = &CI->getValueAPF(); | |||
270 | return true; | |||
271 | } | |||
272 | return false; | |||
273 | } | |||
274 | }; | |||
275 | ||||
276 | /// Match a ConstantInt or splatted ConstantVector, binding the | |||
277 | /// specified pointer to the contained APInt. | |||
278 | inline apint_match m_APInt(const APInt *&Res) { | |||
279 | // Forbid undefs by default to maintain previous behavior. | |||
280 | return apint_match(Res, /* AllowUndef */ false); | |||
281 | } | |||
282 | ||||
283 | /// Match APInt while allowing undefs in splat vector constants. | |||
284 | inline apint_match m_APIntAllowUndef(const APInt *&Res) { | |||
285 | return apint_match(Res, /* AllowUndef */ true); | |||
286 | } | |||
287 | ||||
288 | /// Match APInt while forbidding undefs in splat vector constants. | |||
289 | inline apint_match m_APIntForbidUndef(const APInt *&Res) { | |||
290 | return apint_match(Res, /* AllowUndef */ false); | |||
291 | } | |||
292 | ||||
293 | /// Match a ConstantFP or splatted ConstantVector, binding the | |||
294 | /// specified pointer to the contained APFloat. | |||
295 | inline apfloat_match m_APFloat(const APFloat *&Res) { | |||
296 | // Forbid undefs by default to maintain previous behavior. | |||
297 | return apfloat_match(Res, /* AllowUndef */ false); | |||
298 | } | |||
299 | ||||
300 | /// Match APFloat while allowing undefs in splat vector constants. | |||
301 | inline apfloat_match m_APFloatAllowUndef(const APFloat *&Res) { | |||
302 | return apfloat_match(Res, /* AllowUndef */ true); | |||
303 | } | |||
304 | ||||
305 | /// Match APFloat while forbidding undefs in splat vector constants. | |||
306 | inline apfloat_match m_APFloatForbidUndef(const APFloat *&Res) { | |||
307 | return apfloat_match(Res, /* AllowUndef */ false); | |||
308 | } | |||
309 | ||||
310 | template <int64_t Val> struct constantint_match { | |||
311 | template <typename ITy> bool match(ITy *V) { | |||
312 | if (const auto *CI = dyn_cast<ConstantInt>(V)) { | |||
313 | const APInt &CIV = CI->getValue(); | |||
314 | if (Val >= 0) | |||
315 | return CIV == static_cast<uint64_t>(Val); | |||
316 | // If Val is negative, and CI is shorter than it, truncate to the right | |||
317 | // number of bits. If it is larger, then we have to sign extend. Just | |||
318 | // compare their negated values. | |||
319 | return -CIV == -Val; | |||
320 | } | |||
321 | return false; | |||
322 | } | |||
323 | }; | |||
324 | ||||
325 | /// Match a ConstantInt with a specific value. | |||
326 | template <int64_t Val> inline constantint_match<Val> m_ConstantInt() { | |||
327 | return constantint_match<Val>(); | |||
328 | } | |||
329 | ||||
330 | /// This helper class is used to match constant scalars, vector splats, | |||
331 | /// and fixed width vectors that satisfy a specified predicate. | |||
332 | /// For fixed width vector constants, undefined elements are ignored. | |||
333 | template <typename Predicate, typename ConstantVal> | |||
334 | struct cstval_pred_ty : public Predicate { | |||
335 | template <typename ITy> bool match(ITy *V) { | |||
336 | if (const auto *CV = dyn_cast<ConstantVal>(V)) | |||
337 | return this->isValue(CV->getValue()); | |||
338 | if (const auto *VTy = dyn_cast<VectorType>(V->getType())) { | |||
339 | if (const auto *C = dyn_cast<Constant>(V)) { | |||
340 | if (const auto *CV = dyn_cast_or_null<ConstantVal>(C->getSplatValue())) | |||
341 | return this->isValue(CV->getValue()); | |||
342 | ||||
343 | // Number of elements of a scalable vector unknown at compile time | |||
344 | auto *FVTy = dyn_cast<FixedVectorType>(VTy); | |||
345 | if (!FVTy) | |||
346 | return false; | |||
347 | ||||
348 | // Non-splat vector constant: check each element for a match. | |||
349 | unsigned NumElts = FVTy->getNumElements(); | |||
350 | assert(NumElts != 0 && "Constant vector with no elements?")(static_cast <bool> (NumElts != 0 && "Constant vector with no elements?" ) ? void (0) : __assert_fail ("NumElts != 0 && \"Constant vector with no elements?\"" , "llvm/include/llvm/IR/PatternMatch.h", 350, __extension__ __PRETTY_FUNCTION__ )); | |||
351 | bool HasNonUndefElements = false; | |||
352 | for (unsigned i = 0; i != NumElts; ++i) { | |||
353 | Constant *Elt = C->getAggregateElement(i); | |||
354 | if (!Elt) | |||
355 | return false; | |||
356 | if (isa<UndefValue>(Elt)) | |||
357 | continue; | |||
358 | auto *CV = dyn_cast<ConstantVal>(Elt); | |||
359 | if (!CV || !this->isValue(CV->getValue())) | |||
360 | return false; | |||
361 | HasNonUndefElements = true; | |||
362 | } | |||
363 | return HasNonUndefElements; | |||
364 | } | |||
365 | } | |||
366 | return false; | |||
367 | } | |||
368 | }; | |||
369 | ||||
370 | /// specialization of cstval_pred_ty for ConstantInt | |||
371 | template <typename Predicate> | |||
372 | using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>; | |||
373 | ||||
374 | /// specialization of cstval_pred_ty for ConstantFP | |||
375 | template <typename Predicate> | |||
376 | using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>; | |||
377 | ||||
378 | /// This helper class is used to match scalar and vector constants that | |||
379 | /// satisfy a specified predicate, and bind them to an APInt. | |||
380 | template <typename Predicate> struct api_pred_ty : public Predicate { | |||
381 | const APInt *&Res; | |||
382 | ||||
383 | api_pred_ty(const APInt *&R) : Res(R) {} | |||
384 | ||||
385 | template <typename ITy> bool match(ITy *V) { | |||
386 | if (const auto *CI = dyn_cast<ConstantInt>(V)) | |||
387 | if (this->isValue(CI->getValue())) { | |||
388 | Res = &CI->getValue(); | |||
389 | return true; | |||
390 | } | |||
391 | if (V->getType()->isVectorTy()) | |||
392 | if (const auto *C = dyn_cast<Constant>(V)) | |||
393 | if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue())) | |||
394 | if (this->isValue(CI->getValue())) { | |||
395 | Res = &CI->getValue(); | |||
396 | return true; | |||
397 | } | |||
398 | ||||
399 | return false; | |||
400 | } | |||
401 | }; | |||
402 | ||||
403 | /// This helper class is used to match scalar and vector constants that | |||
404 | /// satisfy a specified predicate, and bind them to an APFloat. | |||
405 | /// Undefs are allowed in splat vector constants. | |||
406 | template <typename Predicate> struct apf_pred_ty : public Predicate { | |||
407 | const APFloat *&Res; | |||
408 | ||||
409 | apf_pred_ty(const APFloat *&R) : Res(R) {} | |||
410 | ||||
411 | template <typename ITy> bool match(ITy *V) { | |||
412 | if (const auto *CI = dyn_cast<ConstantFP>(V)) | |||
413 | if (this->isValue(CI->getValue())) { | |||
414 | Res = &CI->getValue(); | |||
415 | return true; | |||
416 | } | |||
417 | if (V->getType()->isVectorTy()) | |||
418 | if (const auto *C = dyn_cast<Constant>(V)) | |||
419 | if (auto *CI = dyn_cast_or_null<ConstantFP>( | |||
420 | C->getSplatValue(/* AllowUndef */ true))) | |||
421 | if (this->isValue(CI->getValue())) { | |||
422 | Res = &CI->getValue(); | |||
423 | return true; | |||
424 | } | |||
425 | ||||
426 | return false; | |||
427 | } | |||
428 | }; | |||
429 | ||||
430 | /////////////////////////////////////////////////////////////////////////////// | |||
431 | // | |||
432 | // Encapsulate constant value queries for use in templated predicate matchers. | |||
433 | // This allows checking if constants match using compound predicates and works | |||
434 | // with vector constants, possibly with relaxed constraints. For example, ignore | |||
435 | // undef values. | |||
436 | // | |||
437 | /////////////////////////////////////////////////////////////////////////////// | |||
438 | ||||
439 | struct is_any_apint { | |||
440 | bool isValue(const APInt &C) { return true; } | |||
441 | }; | |||
442 | /// Match an integer or vector with any integral constant. | |||
443 | /// For vectors, this includes constants with undefined elements. | |||
444 | inline cst_pred_ty<is_any_apint> m_AnyIntegralConstant() { | |||
445 | return cst_pred_ty<is_any_apint>(); | |||
446 | } | |||
447 | ||||
448 | struct is_all_ones { | |||
449 | bool isValue(const APInt &C) { return C.isAllOnes(); } | |||
450 | }; | |||
451 | /// Match an integer or vector with all bits set. | |||
452 | /// For vectors, this includes constants with undefined elements. | |||
453 | inline cst_pred_ty<is_all_ones> m_AllOnes() { | |||
454 | return cst_pred_ty<is_all_ones>(); | |||
455 | } | |||
456 | ||||
457 | struct is_maxsignedvalue { | |||
458 | bool isValue(const APInt &C) { return C.isMaxSignedValue(); } | |||
459 | }; | |||
460 | /// Match an integer or vector with values having all bits except for the high | |||
461 | /// bit set (0x7f...). | |||
462 | /// For vectors, this includes constants with undefined elements. | |||
463 | inline cst_pred_ty<is_maxsignedvalue> m_MaxSignedValue() { | |||
464 | return cst_pred_ty<is_maxsignedvalue>(); | |||
465 | } | |||
466 | inline api_pred_ty<is_maxsignedvalue> m_MaxSignedValue(const APInt *&V) { | |||
467 | return V; | |||
468 | } | |||
469 | ||||
470 | struct is_negative { | |||
471 | bool isValue(const APInt &C) { return C.isNegative(); } | |||
472 | }; | |||
473 | /// Match an integer or vector of negative values. | |||
474 | /// For vectors, this includes constants with undefined elements. | |||
475 | inline cst_pred_ty<is_negative> m_Negative() { | |||
476 | return cst_pred_ty<is_negative>(); | |||
477 | } | |||
478 | inline api_pred_ty<is_negative> m_Negative(const APInt *&V) { return V; } | |||
479 | ||||
480 | struct is_nonnegative { | |||
481 | bool isValue(const APInt &C) { return C.isNonNegative(); } | |||
482 | }; | |||
483 | /// Match an integer or vector of non-negative values. | |||
484 | /// For vectors, this includes constants with undefined elements. | |||
485 | inline cst_pred_ty<is_nonnegative> m_NonNegative() { | |||
486 | return cst_pred_ty<is_nonnegative>(); | |||
487 | } | |||
488 | inline api_pred_ty<is_nonnegative> m_NonNegative(const APInt *&V) { return V; } | |||
489 | ||||
490 | struct is_strictlypositive { | |||
491 | bool isValue(const APInt &C) { return C.isStrictlyPositive(); } | |||
492 | }; | |||
493 | /// Match an integer or vector of strictly positive values. | |||
494 | /// For vectors, this includes constants with undefined elements. | |||
495 | inline cst_pred_ty<is_strictlypositive> m_StrictlyPositive() { | |||
496 | return cst_pred_ty<is_strictlypositive>(); | |||
497 | } | |||
498 | inline api_pred_ty<is_strictlypositive> m_StrictlyPositive(const APInt *&V) { | |||
499 | return V; | |||
500 | } | |||
501 | ||||
502 | struct is_nonpositive { | |||
503 | bool isValue(const APInt &C) { return C.isNonPositive(); } | |||
504 | }; | |||
505 | /// Match an integer or vector of non-positive values. | |||
506 | /// For vectors, this includes constants with undefined elements. | |||
507 | inline cst_pred_ty<is_nonpositive> m_NonPositive() { | |||
508 | return cst_pred_ty<is_nonpositive>(); | |||
509 | } | |||
510 | inline api_pred_ty<is_nonpositive> m_NonPositive(const APInt *&V) { return V; } | |||
511 | ||||
512 | struct is_one { | |||
513 | bool isValue(const APInt &C) { return C.isOne(); } | |||
514 | }; | |||
515 | /// Match an integer 1 or a vector with all elements equal to 1. | |||
516 | /// For vectors, this includes constants with undefined elements. | |||
517 | inline cst_pred_ty<is_one> m_One() { return cst_pred_ty<is_one>(); } | |||
518 | ||||
519 | struct is_zero_int { | |||
520 | bool isValue(const APInt &C) { return C.isZero(); } | |||
521 | }; | |||
522 | /// Match an integer 0 or a vector with all elements equal to 0. | |||
523 | /// For vectors, this includes constants with undefined elements. | |||
524 | inline cst_pred_ty<is_zero_int> m_ZeroInt() { | |||
525 | return cst_pred_ty<is_zero_int>(); | |||
526 | } | |||
527 | ||||
528 | struct is_zero { | |||
529 | template <typename ITy> bool match(ITy *V) { | |||
530 | auto *C = dyn_cast<Constant>(V); | |||
531 | // FIXME: this should be able to do something for scalable vectors | |||
532 | return C && (C->isNullValue() || cst_pred_ty<is_zero_int>().match(C)); | |||
533 | } | |||
534 | }; | |||
535 | /// Match any null constant or a vector with all elements equal to 0. | |||
536 | /// For vectors, this includes constants with undefined elements. | |||
537 | inline is_zero m_Zero() { return is_zero(); } | |||
538 | ||||
539 | struct is_power2 { | |||
540 | bool isValue(const APInt &C) { return C.isPowerOf2(); } | |||
541 | }; | |||
542 | /// Match an integer or vector power-of-2. | |||
543 | /// For vectors, this includes constants with undefined elements. | |||
544 | inline cst_pred_ty<is_power2> m_Power2() { return cst_pred_ty<is_power2>(); } | |||
545 | inline api_pred_ty<is_power2> m_Power2(const APInt *&V) { return V; } | |||
546 | ||||
547 | struct is_negated_power2 { | |||
548 | bool isValue(const APInt &C) { return C.isNegatedPowerOf2(); } | |||
549 | }; | |||
550 | /// Match a integer or vector negated power-of-2. | |||
551 | /// For vectors, this includes constants with undefined elements. | |||
552 | inline cst_pred_ty<is_negated_power2> m_NegatedPower2() { | |||
553 | return cst_pred_ty<is_negated_power2>(); | |||
554 | } | |||
555 | inline api_pred_ty<is_negated_power2> m_NegatedPower2(const APInt *&V) { | |||
556 | return V; | |||
557 | } | |||
558 | ||||
559 | struct is_power2_or_zero { | |||
560 | bool isValue(const APInt &C) { return !C || C.isPowerOf2(); } | |||
561 | }; | |||
562 | /// Match an integer or vector of 0 or power-of-2 values. | |||
563 | /// For vectors, this includes constants with undefined elements. | |||
564 | inline cst_pred_ty<is_power2_or_zero> m_Power2OrZero() { | |||
565 | return cst_pred_ty<is_power2_or_zero>(); | |||
566 | } | |||
567 | inline api_pred_ty<is_power2_or_zero> m_Power2OrZero(const APInt *&V) { | |||
568 | return V; | |||
569 | } | |||
570 | ||||
571 | struct is_sign_mask { | |||
572 | bool isValue(const APInt &C) { return C.isSignMask(); } | |||
573 | }; | |||
574 | /// Match an integer or vector with only the sign bit(s) set. | |||
575 | /// For vectors, this includes constants with undefined elements. | |||
576 | inline cst_pred_ty<is_sign_mask> m_SignMask() { | |||
577 | return cst_pred_ty<is_sign_mask>(); | |||
578 | } | |||
579 | ||||
580 | struct is_lowbit_mask { | |||
581 | bool isValue(const APInt &C) { return C.isMask(); } | |||
582 | }; | |||
583 | /// Match an integer or vector with only the low bit(s) set. | |||
584 | /// For vectors, this includes constants with undefined elements. | |||
585 | inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() { | |||
586 | return cst_pred_ty<is_lowbit_mask>(); | |||
587 | } | |||
588 | inline api_pred_ty<is_lowbit_mask> m_LowBitMask(const APInt *&V) { return V; } | |||
589 | ||||
590 | struct icmp_pred_with_threshold { | |||
591 | ICmpInst::Predicate Pred; | |||
592 | const APInt *Thr; | |||
593 | bool isValue(const APInt &C) { return ICmpInst::compare(C, *Thr, Pred); } | |||
594 | }; | |||
595 | /// Match an integer or vector with every element comparing 'pred' (eg/ne/...) | |||
596 | /// to Threshold. For vectors, this includes constants with undefined elements. | |||
597 | inline cst_pred_ty<icmp_pred_with_threshold> | |||
598 | m_SpecificInt_ICMP(ICmpInst::Predicate Predicate, const APInt &Threshold) { | |||
599 | cst_pred_ty<icmp_pred_with_threshold> P; | |||
600 | P.Pred = Predicate; | |||
601 | P.Thr = &Threshold; | |||
602 | return P; | |||
603 | } | |||
604 | ||||
605 | struct is_nan { | |||
606 | bool isValue(const APFloat &C) { return C.isNaN(); } | |||
607 | }; | |||
608 | /// Match an arbitrary NaN constant. This includes quiet and signalling nans. | |||
609 | /// For vectors, this includes constants with undefined elements. | |||
610 | inline cstfp_pred_ty<is_nan> m_NaN() { return cstfp_pred_ty<is_nan>(); } | |||
611 | ||||
612 | struct is_nonnan { | |||
613 | bool isValue(const APFloat &C) { return !C.isNaN(); } | |||
614 | }; | |||
615 | /// Match a non-NaN FP constant. | |||
616 | /// For vectors, this includes constants with undefined elements. | |||
617 | inline cstfp_pred_ty<is_nonnan> m_NonNaN() { | |||
618 | return cstfp_pred_ty<is_nonnan>(); | |||
619 | } | |||
620 | ||||
621 | struct is_inf { | |||
622 | bool isValue(const APFloat &C) { return C.isInfinity(); } | |||
623 | }; | |||
624 | /// Match a positive or negative infinity FP constant. | |||
625 | /// For vectors, this includes constants with undefined elements. | |||
626 | inline cstfp_pred_ty<is_inf> m_Inf() { return cstfp_pred_ty<is_inf>(); } | |||
627 | ||||
628 | struct is_noninf { | |||
629 | bool isValue(const APFloat &C) { return !C.isInfinity(); } | |||
630 | }; | |||
631 | /// Match a non-infinity FP constant, i.e. finite or NaN. | |||
632 | /// For vectors, this includes constants with undefined elements. | |||
633 | inline cstfp_pred_ty<is_noninf> m_NonInf() { | |||
634 | return cstfp_pred_ty<is_noninf>(); | |||
635 | } | |||
636 | ||||
637 | struct is_finite { | |||
638 | bool isValue(const APFloat &C) { return C.isFinite(); } | |||
639 | }; | |||
640 | /// Match a finite FP constant, i.e. not infinity or NaN. | |||
641 | /// For vectors, this includes constants with undefined elements. | |||
642 | inline cstfp_pred_ty<is_finite> m_Finite() { | |||
643 | return cstfp_pred_ty<is_finite>(); | |||
644 | } | |||
645 | inline apf_pred_ty<is_finite> m_Finite(const APFloat *&V) { return V; } | |||
646 | ||||
647 | struct is_finitenonzero { | |||
648 | bool isValue(const APFloat &C) { return C.isFiniteNonZero(); } | |||
649 | }; | |||
650 | /// Match a finite non-zero FP constant. | |||
651 | /// For vectors, this includes constants with undefined elements. | |||
652 | inline cstfp_pred_ty<is_finitenonzero> m_FiniteNonZero() { | |||
653 | return cstfp_pred_ty<is_finitenonzero>(); | |||
654 | } | |||
655 | inline apf_pred_ty<is_finitenonzero> m_FiniteNonZero(const APFloat *&V) { | |||
656 | return V; | |||
657 | } | |||
658 | ||||
659 | struct is_any_zero_fp { | |||
660 | bool isValue(const APFloat &C) { return C.isZero(); } | |||
661 | }; | |||
662 | /// Match a floating-point negative zero or positive zero. | |||
663 | /// For vectors, this includes constants with undefined elements. | |||
664 | inline cstfp_pred_ty<is_any_zero_fp> m_AnyZeroFP() { | |||
665 | return cstfp_pred_ty<is_any_zero_fp>(); | |||
666 | } | |||
667 | ||||
668 | struct is_pos_zero_fp { | |||
669 | bool isValue(const APFloat &C) { return C.isPosZero(); } | |||
670 | }; | |||
671 | /// Match a floating-point positive zero. | |||
672 | /// For vectors, this includes constants with undefined elements. | |||
673 | inline cstfp_pred_ty<is_pos_zero_fp> m_PosZeroFP() { | |||
674 | return cstfp_pred_ty<is_pos_zero_fp>(); | |||
675 | } | |||
676 | ||||
677 | struct is_neg_zero_fp { | |||
678 | bool isValue(const APFloat &C) { return C.isNegZero(); } | |||
679 | }; | |||
680 | /// Match a floating-point negative zero. | |||
681 | /// For vectors, this includes constants with undefined elements. | |||
682 | inline cstfp_pred_ty<is_neg_zero_fp> m_NegZeroFP() { | |||
683 | return cstfp_pred_ty<is_neg_zero_fp>(); | |||
684 | } | |||
685 | ||||
686 | struct is_non_zero_fp { | |||
687 | bool isValue(const APFloat &C) { return C.isNonZero(); } | |||
688 | }; | |||
689 | /// Match a floating-point non-zero. | |||
690 | /// For vectors, this includes constants with undefined elements. | |||
691 | inline cstfp_pred_ty<is_non_zero_fp> m_NonZeroFP() { | |||
692 | return cstfp_pred_ty<is_non_zero_fp>(); | |||
693 | } | |||
694 | ||||
695 | /////////////////////////////////////////////////////////////////////////////// | |||
696 | ||||
697 | template <typename Class> struct bind_ty { | |||
698 | Class *&VR; | |||
699 | ||||
700 | bind_ty(Class *&V) : VR(V) {} | |||
701 | ||||
702 | template <typename ITy> bool match(ITy *V) { | |||
703 | if (auto *CV = dyn_cast<Class>(V)) { | |||
704 | VR = CV; | |||
705 | return true; | |||
706 | } | |||
707 | return false; | |||
708 | } | |||
709 | }; | |||
710 | ||||
711 | /// Match a value, capturing it if we match. | |||
712 | inline bind_ty<Value> m_Value(Value *&V) { return V; } | |||
713 | inline bind_ty<const Value> m_Value(const Value *&V) { return V; } | |||
714 | ||||
715 | /// Match an instruction, capturing it if we match. | |||
716 | inline bind_ty<Instruction> m_Instruction(Instruction *&I) { return I; } | |||
717 | /// Match a unary operator, capturing it if we match. | |||
718 | inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; } | |||
719 | /// Match a binary operator, capturing it if we match. | |||
720 | inline bind_ty<BinaryOperator> m_BinOp(BinaryOperator *&I) { return I; } | |||
721 | /// Match a with overflow intrinsic, capturing it if we match. | |||
722 | inline bind_ty<WithOverflowInst> m_WithOverflowInst(WithOverflowInst *&I) { | |||
723 | return I; | |||
724 | } | |||
725 | inline bind_ty<const WithOverflowInst> | |||
726 | m_WithOverflowInst(const WithOverflowInst *&I) { | |||
727 | return I; | |||
728 | } | |||
729 | ||||
730 | /// Match a Constant, capturing the value if we match. | |||
731 | inline bind_ty<Constant> m_Constant(Constant *&C) { return C; } | |||
732 | ||||
733 | /// Match a ConstantInt, capturing the value if we match. | |||
734 | inline bind_ty<ConstantInt> m_ConstantInt(ConstantInt *&CI) { return CI; } | |||
735 | ||||
736 | /// Match a ConstantFP, capturing the value if we match. | |||
737 | inline bind_ty<ConstantFP> m_ConstantFP(ConstantFP *&C) { return C; } | |||
738 | ||||
739 | /// Match a ConstantExpr, capturing the value if we match. | |||
740 | inline bind_ty<ConstantExpr> m_ConstantExpr(ConstantExpr *&C) { return C; } | |||
741 | ||||
742 | /// Match a basic block value, capturing it if we match. | |||
743 | inline bind_ty<BasicBlock> m_BasicBlock(BasicBlock *&V) { return V; } | |||
744 | inline bind_ty<const BasicBlock> m_BasicBlock(const BasicBlock *&V) { | |||
745 | return V; | |||
746 | } | |||
747 | ||||
748 | /// Match an arbitrary immediate Constant and ignore it. | |||
749 | inline match_combine_and<class_match<Constant>, | |||
750 | match_unless<constantexpr_match>> | |||
751 | m_ImmConstant() { | |||
752 | return m_CombineAnd(m_Constant(), m_Unless(m_ConstantExpr())); | |||
753 | } | |||
754 | ||||
755 | /// Match an immediate Constant, capturing the value if we match. | |||
756 | inline match_combine_and<bind_ty<Constant>, | |||
757 | match_unless<constantexpr_match>> | |||
758 | m_ImmConstant(Constant *&C) { | |||
759 | return m_CombineAnd(m_Constant(C), m_Unless(m_ConstantExpr())); | |||
760 | } | |||
761 | ||||
762 | /// Match a specified Value*. | |||
763 | struct specificval_ty { | |||
764 | const Value *Val; | |||
765 | ||||
766 | specificval_ty(const Value *V) : Val(V) {} | |||
767 | ||||
768 | template <typename ITy> bool match(ITy *V) { return V == Val; } | |||
769 | }; | |||
770 | ||||
771 | /// Match if we have a specific specified value. | |||
772 | inline specificval_ty m_Specific(const Value *V) { return V; } | |||
773 | ||||
774 | /// Stores a reference to the Value *, not the Value * itself, | |||
775 | /// thus can be used in commutative matchers. | |||
776 | template <typename Class> struct deferredval_ty { | |||
777 | Class *const &Val; | |||
778 | ||||
779 | deferredval_ty(Class *const &V) : Val(V) {} | |||
780 | ||||
781 | template <typename ITy> bool match(ITy *const V) { return V == Val; } | |||
782 | }; | |||
783 | ||||
784 | /// Like m_Specific(), but works if the specific value to match is determined | |||
785 | /// as part of the same match() expression. For example: | |||
786 | /// m_Add(m_Value(X), m_Specific(X)) is incorrect, because m_Specific() will | |||
787 | /// bind X before the pattern match starts. | |||
788 | /// m_Add(m_Value(X), m_Deferred(X)) is correct, and will check against | |||
789 | /// whichever value m_Value(X) populated. | |||
790 | inline deferredval_ty<Value> m_Deferred(Value *const &V) { return V; } | |||
791 | inline deferredval_ty<const Value> m_Deferred(const Value *const &V) { | |||
792 | return V; | |||
793 | } | |||
794 | ||||
795 | /// Match a specified floating point value or vector of all elements of | |||
796 | /// that value. | |||
797 | struct specific_fpval { | |||
798 | double Val; | |||
799 | ||||
800 | specific_fpval(double V) : Val(V) {} | |||
801 | ||||
802 | template <typename ITy> bool match(ITy *V) { | |||
803 | if (const auto *CFP = dyn_cast<ConstantFP>(V)) | |||
804 | return CFP->isExactlyValue(Val); | |||
805 | if (V->getType()->isVectorTy()) | |||
806 | if (const auto *C = dyn_cast<Constant>(V)) | |||
807 | if (auto *CFP = dyn_cast_or_null<ConstantFP>(C->getSplatValue())) | |||
808 | return CFP->isExactlyValue(Val); | |||
809 | return false; | |||
810 | } | |||
811 | }; | |||
812 | ||||
813 | /// Match a specific floating point value or vector with all elements | |||
814 | /// equal to the value. | |||
815 | inline specific_fpval m_SpecificFP(double V) { return specific_fpval(V); } | |||
816 | ||||
817 | /// Match a float 1.0 or vector with all elements equal to 1.0. | |||
818 | inline specific_fpval m_FPOne() { return m_SpecificFP(1.0); } | |||
819 | ||||
820 | struct bind_const_intval_ty { | |||
821 | uint64_t &VR; | |||
822 | ||||
823 | bind_const_intval_ty(uint64_t &V) : VR(V) {} | |||
824 | ||||
825 | template <typename ITy> bool match(ITy *V) { | |||
826 | if (const auto *CV = dyn_cast<ConstantInt>(V)) | |||
827 | if (CV->getValue().ule(UINT64_MAX(18446744073709551615UL))) { | |||
828 | VR = CV->getZExtValue(); | |||
829 | return true; | |||
830 | } | |||
831 | return false; | |||
832 | } | |||
833 | }; | |||
834 | ||||
835 | /// Match a specified integer value or vector of all elements of that | |||
836 | /// value. | |||
837 | template <bool AllowUndefs> struct specific_intval { | |||
838 | APInt Val; | |||
839 | ||||
840 | specific_intval(APInt V) : Val(std::move(V)) {} | |||
841 | ||||
842 | template <typename ITy> bool match(ITy *V) { | |||
843 | const auto *CI = dyn_cast<ConstantInt>(V); | |||
844 | if (!CI && V->getType()->isVectorTy()) | |||
845 | if (const auto *C = dyn_cast<Constant>(V)) | |||
846 | CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)); | |||
847 | ||||
848 | return CI && APInt::isSameValue(CI->getValue(), Val); | |||
849 | } | |||
850 | }; | |||
851 | ||||
852 | /// Match a specific integer value or vector with all elements equal to | |||
853 | /// the value. | |||
854 | inline specific_intval<false> m_SpecificInt(APInt V) { | |||
855 | return specific_intval<false>(std::move(V)); | |||
856 | } | |||
857 | ||||
858 | inline specific_intval<false> m_SpecificInt(uint64_t V) { | |||
859 | return m_SpecificInt(APInt(64, V)); | |||
860 | } | |||
861 | ||||
862 | inline specific_intval<true> m_SpecificIntAllowUndef(APInt V) { | |||
863 | return specific_intval<true>(std::move(V)); | |||
864 | } | |||
865 | ||||
866 | inline specific_intval<true> m_SpecificIntAllowUndef(uint64_t V) { | |||
867 | return m_SpecificIntAllowUndef(APInt(64, V)); | |||
868 | } | |||
869 | ||||
870 | /// Match a ConstantInt and bind to its value. This does not match | |||
871 | /// ConstantInts wider than 64-bits. | |||
872 | inline bind_const_intval_ty m_ConstantInt(uint64_t &V) { return V; } | |||
873 | ||||
874 | /// Match a specified basic block value. | |||
875 | struct specific_bbval { | |||
876 | BasicBlock *Val; | |||
877 | ||||
878 | specific_bbval(BasicBlock *Val) : Val(Val) {} | |||
879 | ||||
880 | template <typename ITy> bool match(ITy *V) { | |||
881 | const auto *BB = dyn_cast<BasicBlock>(V); | |||
882 | return BB && BB == Val; | |||
883 | } | |||
884 | }; | |||
885 | ||||
886 | /// Match a specific basic block value. | |||
887 | inline specific_bbval m_SpecificBB(BasicBlock *BB) { | |||
888 | return specific_bbval(BB); | |||
889 | } | |||
890 | ||||
891 | /// A commutative-friendly version of m_Specific(). | |||
892 | inline deferredval_ty<BasicBlock> m_Deferred(BasicBlock *const &BB) { | |||
893 | return BB; | |||
894 | } | |||
895 | inline deferredval_ty<const BasicBlock> | |||
896 | m_Deferred(const BasicBlock *const &BB) { | |||
897 | return BB; | |||
898 | } | |||
899 | ||||
900 | //===----------------------------------------------------------------------===// | |||
901 | // Matcher for any binary operator. | |||
902 | // | |||
903 | template <typename LHS_t, typename RHS_t, bool Commutable = false> | |||
904 | struct AnyBinaryOp_match { | |||
905 | LHS_t L; | |||
906 | RHS_t R; | |||
907 | ||||
908 | // The evaluation order is always stable, regardless of Commutability. | |||
909 | // The LHS is always matched first. | |||
910 | AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} | |||
911 | ||||
912 | template <typename OpTy> bool match(OpTy *V) { | |||
913 | if (auto *I = dyn_cast<BinaryOperator>(V)) | |||
914 | return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || | |||
915 | (Commutable && L.match(I->getOperand(1)) && | |||
916 | R.match(I->getOperand(0))); | |||
917 | return false; | |||
918 | } | |||
919 | }; | |||
920 | ||||
921 | template <typename LHS, typename RHS> | |||
922 | inline AnyBinaryOp_match<LHS, RHS> m_BinOp(const LHS &L, const RHS &R) { | |||
923 | return AnyBinaryOp_match<LHS, RHS>(L, R); | |||
924 | } | |||
925 | ||||
926 | //===----------------------------------------------------------------------===// | |||
927 | // Matcher for any unary operator. | |||
928 | // TODO fuse unary, binary matcher into n-ary matcher | |||
929 | // | |||
930 | template <typename OP_t> struct AnyUnaryOp_match { | |||
931 | OP_t X; | |||
932 | ||||
933 | AnyUnaryOp_match(const OP_t &X) : X(X) {} | |||
934 | ||||
935 | template <typename OpTy> bool match(OpTy *V) { | |||
936 | if (auto *I = dyn_cast<UnaryOperator>(V)) | |||
937 | return X.match(I->getOperand(0)); | |||
938 | return false; | |||
939 | } | |||
940 | }; | |||
941 | ||||
942 | template <typename OP_t> inline AnyUnaryOp_match<OP_t> m_UnOp(const OP_t &X) { | |||
943 | return AnyUnaryOp_match<OP_t>(X); | |||
944 | } | |||
945 | ||||
946 | //===----------------------------------------------------------------------===// | |||
947 | // Matchers for specific binary operators. | |||
948 | // | |||
949 | ||||
950 | template <typename LHS_t, typename RHS_t, unsigned Opcode, | |||
951 | bool Commutable = false> | |||
952 | struct BinaryOp_match { | |||
953 | LHS_t L; | |||
954 | RHS_t R; | |||
955 | ||||
956 | // The evaluation order is always stable, regardless of Commutability. | |||
957 | // The LHS is always matched first. | |||
958 | BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} | |||
959 | ||||
960 | template <typename OpTy> inline bool match(unsigned Opc, OpTy *V) { | |||
961 | if (V->getValueID() == Value::InstructionVal + Opc) { | |||
| ||||
962 | auto *I = cast<BinaryOperator>(V); | |||
963 | return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || | |||
964 | (Commutable && L.match(I->getOperand(1)) && | |||
965 | R.match(I->getOperand(0))); | |||
966 | } | |||
967 | if (auto *CE = dyn_cast<ConstantExpr>(V)) | |||
968 | return CE->getOpcode() == Opc && | |||
969 | ((L.match(CE->getOperand(0)) && R.match(CE->getOperand(1))) || | |||
970 | (Commutable && L.match(CE->getOperand(1)) && | |||
971 | R.match(CE->getOperand(0)))); | |||
972 | return false; | |||
973 | } | |||
974 | ||||
975 | template <typename OpTy> bool match(OpTy *V) { return match(Opcode, V); } | |||
976 | }; | |||
977 | ||||
978 | template <typename LHS, typename RHS> | |||
979 | inline BinaryOp_match<LHS, RHS, Instruction::Add> m_Add(const LHS &L, | |||
980 | const RHS &R) { | |||
981 | return BinaryOp_match<LHS, RHS, Instruction::Add>(L, R); | |||
982 | } | |||
983 | ||||
984 | template <typename LHS, typename RHS> | |||
985 | inline BinaryOp_match<LHS, RHS, Instruction::FAdd> m_FAdd(const LHS &L, | |||
986 | const RHS &R) { | |||
987 | return BinaryOp_match<LHS, RHS, Instruction::FAdd>(L, R); | |||
988 | } | |||
989 | ||||
990 | template <typename LHS, typename RHS> | |||
991 | inline BinaryOp_match<LHS, RHS, Instruction::Sub> m_Sub(const LHS &L, | |||
992 | const RHS &R) { | |||
993 | return BinaryOp_match<LHS, RHS, Instruction::Sub>(L, R); | |||
994 | } | |||
995 | ||||
996 | template <typename LHS, typename RHS> | |||
997 | inline BinaryOp_match<LHS, RHS, Instruction::FSub> m_FSub(const LHS &L, | |||
998 | const RHS &R) { | |||
999 | return BinaryOp_match<LHS, RHS, Instruction::FSub>(L, R); | |||
1000 | } | |||
1001 | ||||
1002 | template <typename Op_t> struct FNeg_match { | |||
1003 | Op_t X; | |||
1004 | ||||
1005 | FNeg_match(const Op_t &Op) : X(Op) {} | |||
1006 | template <typename OpTy> bool match(OpTy *V) { | |||
1007 | auto *FPMO = dyn_cast<FPMathOperator>(V); | |||
1008 | if (!FPMO) | |||
1009 | return false; | |||
1010 | ||||
1011 | if (FPMO->getOpcode() == Instruction::FNeg) | |||
1012 | return X.match(FPMO->getOperand(0)); | |||
1013 | ||||
1014 | if (FPMO->getOpcode() == Instruction::FSub) { | |||
1015 | if (FPMO->hasNoSignedZeros()) { | |||
1016 | // With 'nsz', any zero goes. | |||
1017 | if (!cstfp_pred_ty<is_any_zero_fp>().match(FPMO->getOperand(0))) | |||
1018 | return false; | |||
1019 | } else { | |||
1020 | // Without 'nsz', we need fsub -0.0, X exactly. | |||
1021 | if (!cstfp_pred_ty<is_neg_zero_fp>().match(FPMO->getOperand(0))) | |||
1022 | return false; | |||
1023 | } | |||
1024 | ||||
1025 | return X.match(FPMO->getOperand(1)); | |||
1026 | } | |||
1027 | ||||
1028 | return false; | |||
1029 | } | |||
1030 | }; | |||
1031 | ||||
1032 | /// Match 'fneg X' as 'fsub -0.0, X'. | |||
1033 | template <typename OpTy> inline FNeg_match<OpTy> m_FNeg(const OpTy &X) { | |||
1034 | return FNeg_match<OpTy>(X); | |||
1035 | } | |||
1036 | ||||
1037 | /// Match 'fneg X' as 'fsub +-0.0, X'. | |||
1038 | template <typename RHS> | |||
1039 | inline BinaryOp_match<cstfp_pred_ty<is_any_zero_fp>, RHS, Instruction::FSub> | |||
1040 | m_FNegNSZ(const RHS &X) { | |||
1041 | return m_FSub(m_AnyZeroFP(), X); | |||
1042 | } | |||
1043 | ||||
1044 | template <typename LHS, typename RHS> | |||
1045 | inline BinaryOp_match<LHS, RHS, Instruction::Mul> m_Mul(const LHS &L, | |||
1046 | const RHS &R) { | |||
1047 | return BinaryOp_match<LHS, RHS, Instruction::Mul>(L, R); | |||
1048 | } | |||
1049 | ||||
1050 | template <typename LHS, typename RHS> | |||
1051 | inline BinaryOp_match<LHS, RHS, Instruction::FMul> m_FMul(const LHS &L, | |||
1052 | const RHS &R) { | |||
1053 | return BinaryOp_match<LHS, RHS, Instruction::FMul>(L, R); | |||
1054 | } | |||
1055 | ||||
1056 | template <typename LHS, typename RHS> | |||
1057 | inline BinaryOp_match<LHS, RHS, Instruction::UDiv> m_UDiv(const LHS &L, | |||
1058 | const RHS &R) { | |||
1059 | return BinaryOp_match<LHS, RHS, Instruction::UDiv>(L, R); | |||
1060 | } | |||
1061 | ||||
1062 | template <typename LHS, typename RHS> | |||
1063 | inline BinaryOp_match<LHS, RHS, Instruction::SDiv> m_SDiv(const LHS &L, | |||
1064 | const RHS &R) { | |||
1065 | return BinaryOp_match<LHS, RHS, Instruction::SDiv>(L, R); | |||
1066 | } | |||
1067 | ||||
1068 | template <typename LHS, typename RHS> | |||
1069 | inline BinaryOp_match<LHS, RHS, Instruction::FDiv> m_FDiv(const LHS &L, | |||
1070 | const RHS &R) { | |||
1071 | return BinaryOp_match<LHS, RHS, Instruction::FDiv>(L, R); | |||
1072 | } | |||
1073 | ||||
1074 | template <typename LHS, typename RHS> | |||
1075 | inline BinaryOp_match<LHS, RHS, Instruction::URem> m_URem(const LHS &L, | |||
1076 | const RHS &R) { | |||
1077 | return BinaryOp_match<LHS, RHS, Instruction::URem>(L, R); | |||
1078 | } | |||
1079 | ||||
1080 | template <typename LHS, typename RHS> | |||
1081 | inline BinaryOp_match<LHS, RHS, Instruction::SRem> m_SRem(const LHS &L, | |||
1082 | const RHS &R) { | |||
1083 | return BinaryOp_match<LHS, RHS, Instruction::SRem>(L, R); | |||
1084 | } | |||
1085 | ||||
1086 | template <typename LHS, typename RHS> | |||
1087 | inline BinaryOp_match<LHS, RHS, Instruction::FRem> m_FRem(const LHS &L, | |||
1088 | const RHS &R) { | |||
1089 | return BinaryOp_match<LHS, RHS, Instruction::FRem>(L, R); | |||
1090 | } | |||
1091 | ||||
1092 | template <typename LHS, typename RHS> | |||
1093 | inline BinaryOp_match<LHS, RHS, Instruction::And> m_And(const LHS &L, | |||
1094 | const RHS &R) { | |||
1095 | return BinaryOp_match<LHS, RHS, Instruction::And>(L, R); | |||
1096 | } | |||
1097 | ||||
1098 | template <typename LHS, typename RHS> | |||
1099 | inline BinaryOp_match<LHS, RHS, Instruction::Or> m_Or(const LHS &L, | |||
1100 | const RHS &R) { | |||
1101 | return BinaryOp_match<LHS, RHS, Instruction::Or>(L, R); | |||
1102 | } | |||
1103 | ||||
1104 | template <typename LHS, typename RHS> | |||
1105 | inline BinaryOp_match<LHS, RHS, Instruction::Xor> m_Xor(const LHS &L, | |||
1106 | const RHS &R) { | |||
1107 | return BinaryOp_match<LHS, RHS, Instruction::Xor>(L, R); | |||
1108 | } | |||
1109 | ||||
1110 | template <typename LHS, typename RHS> | |||
1111 | inline BinaryOp_match<LHS, RHS, Instruction::Shl> m_Shl(const LHS &L, | |||
1112 | const RHS &R) { | |||
1113 | return BinaryOp_match<LHS, RHS, Instruction::Shl>(L, R); | |||
1114 | } | |||
1115 | ||||
1116 | template <typename LHS, typename RHS> | |||
1117 | inline BinaryOp_match<LHS, RHS, Instruction::LShr> m_LShr(const LHS &L, | |||
1118 | const RHS &R) { | |||
1119 | return BinaryOp_match<LHS, RHS, Instruction::LShr>(L, R); | |||
1120 | } | |||
1121 | ||||
1122 | template <typename LHS, typename RHS> | |||
1123 | inline BinaryOp_match<LHS, RHS, Instruction::AShr> m_AShr(const LHS &L, | |||
1124 | const RHS &R) { | |||
1125 | return BinaryOp_match<LHS, RHS, Instruction::AShr>(L, R); | |||
1126 | } | |||
1127 | ||||
1128 | template <typename LHS_t, typename RHS_t, unsigned Opcode, | |||
1129 | unsigned WrapFlags = 0> | |||
1130 | struct OverflowingBinaryOp_match { | |||
1131 | LHS_t L; | |||
1132 | RHS_t R; | |||
1133 | ||||
1134 | OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) | |||
1135 | : L(LHS), R(RHS) {} | |||
1136 | ||||
1137 | template <typename OpTy> bool match(OpTy *V) { | |||
1138 | if (auto *Op = dyn_cast<OverflowingBinaryOperator>(V)) { | |||
1139 | if (Op->getOpcode() != Opcode) | |||
1140 | return false; | |||
1141 | if ((WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap) && | |||
1142 | !Op->hasNoUnsignedWrap()) | |||
1143 | return false; | |||
1144 | if ((WrapFlags & OverflowingBinaryOperator::NoSignedWrap) && | |||
1145 | !Op->hasNoSignedWrap()) | |||
1146 | return false; | |||
1147 | return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); | |||
1148 | } | |||
1149 | return false; | |||
1150 | } | |||
1151 | }; | |||
1152 | ||||
1153 | template <typename LHS, typename RHS> | |||
1154 | inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Add, | |||
1155 | OverflowingBinaryOperator::NoSignedWrap> | |||
1156 | m_NSWAdd(const LHS &L, const RHS &R) { | |||
1157 | return OverflowingBinaryOp_match<LHS, RHS, Instruction::Add, | |||
1158 | OverflowingBinaryOperator::NoSignedWrap>(L, | |||
1159 | R); | |||
1160 | } | |||
1161 | template <typename LHS, typename RHS> | |||
1162 | inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Sub, | |||
1163 | OverflowingBinaryOperator::NoSignedWrap> | |||
1164 | m_NSWSub(const LHS &L, const RHS &R) { | |||
1165 | return OverflowingBinaryOp_match<LHS, RHS, Instruction::Sub, | |||
1166 | OverflowingBinaryOperator::NoSignedWrap>(L, | |||
1167 | R); | |||
1168 | } | |||
1169 | template <typename LHS, typename RHS> | |||
1170 | inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Mul, | |||
1171 | OverflowingBinaryOperator::NoSignedWrap> | |||
1172 | m_NSWMul(const LHS &L, const RHS &R) { | |||
1173 | return OverflowingBinaryOp_match<LHS, RHS, Instruction::Mul, | |||
1174 | OverflowingBinaryOperator::NoSignedWrap>(L, | |||
1175 | R); | |||
1176 | } | |||
1177 | template <typename LHS, typename RHS> | |||
1178 | inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Shl, | |||
1179 | OverflowingBinaryOperator::NoSignedWrap> | |||
1180 | m_NSWShl(const LHS &L, const RHS &R) { | |||
1181 | return OverflowingBinaryOp_match<LHS, RHS, Instruction::Shl, | |||
1182 | OverflowingBinaryOperator::NoSignedWrap>(L, | |||
1183 | R); | |||
1184 | } | |||
1185 | ||||
1186 | template <typename LHS, typename RHS> | |||
1187 | inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Add, | |||
1188 | OverflowingBinaryOperator::NoUnsignedWrap> | |||
1189 | m_NUWAdd(const LHS &L, const RHS &R) { | |||
1190 | return OverflowingBinaryOp_match<LHS, RHS, Instruction::Add, | |||
1191 | OverflowingBinaryOperator::NoUnsignedWrap>( | |||
1192 | L, R); | |||
1193 | } | |||
1194 | template <typename LHS, typename RHS> | |||
1195 | inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Sub, | |||
1196 | OverflowingBinaryOperator::NoUnsignedWrap> | |||
1197 | m_NUWSub(const LHS &L, const RHS &R) { | |||
1198 | return OverflowingBinaryOp_match<LHS, RHS, Instruction::Sub, | |||
1199 | OverflowingBinaryOperator::NoUnsignedWrap>( | |||
1200 | L, R); | |||
1201 | } | |||
1202 | template <typename LHS, typename RHS> | |||
1203 | inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Mul, | |||
1204 | OverflowingBinaryOperator::NoUnsignedWrap> | |||
1205 | m_NUWMul(const LHS &L, const RHS &R) { | |||
1206 | return OverflowingBinaryOp_match<LHS, RHS, Instruction::Mul, | |||
1207 | OverflowingBinaryOperator::NoUnsignedWrap>( | |||
1208 | L, R); | |||
1209 | } | |||
1210 | template <typename LHS, typename RHS> | |||
1211 | inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Shl, | |||
1212 | OverflowingBinaryOperator::NoUnsignedWrap> | |||
1213 | m_NUWShl(const LHS &L, const RHS &R) { | |||
1214 | return OverflowingBinaryOp_match<LHS, RHS, Instruction::Shl, | |||
1215 | OverflowingBinaryOperator::NoUnsignedWrap>( | |||
1216 | L, R); | |||
1217 | } | |||
1218 | ||||
1219 | template <typename LHS_t, typename RHS_t, bool Commutable = false> | |||
1220 | struct SpecificBinaryOp_match | |||
1221 | : public BinaryOp_match<LHS_t, RHS_t, 0, Commutable> { | |||
1222 | unsigned Opcode; | |||
1223 | ||||
1224 | SpecificBinaryOp_match(unsigned Opcode, const LHS_t &LHS, const RHS_t &RHS) | |||
1225 | : BinaryOp_match<LHS_t, RHS_t, 0, Commutable>(LHS, RHS), Opcode(Opcode) {} | |||
1226 | ||||
1227 | template <typename OpTy> bool match(OpTy *V) { | |||
1228 | return BinaryOp_match<LHS_t, RHS_t, 0, Commutable>::match(Opcode, V); | |||
1229 | } | |||
1230 | }; | |||
1231 | ||||
1232 | /// Matches a specific opcode. | |||
1233 | template <typename LHS, typename RHS> | |||
1234 | inline SpecificBinaryOp_match<LHS, RHS> m_BinOp(unsigned Opcode, const LHS &L, | |||
1235 | const RHS &R) { | |||
1236 | return SpecificBinaryOp_match<LHS, RHS>(Opcode, L, R); | |||
1237 | } | |||
1238 | ||||
1239 | //===----------------------------------------------------------------------===// | |||
1240 | // Class that matches a group of binary opcodes. | |||
1241 | // | |||
1242 | template <typename LHS_t, typename RHS_t, typename Predicate> | |||
1243 | struct BinOpPred_match : Predicate { | |||
1244 | LHS_t L; | |||
1245 | RHS_t R; | |||
1246 | ||||
1247 | BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} | |||
1248 | ||||
1249 | template <typename OpTy> bool match(OpTy *V) { | |||
1250 | if (auto *I = dyn_cast<Instruction>(V)) | |||
1251 | return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && | |||
1252 | R.match(I->getOperand(1)); | |||
1253 | if (auto *CE = dyn_cast<ConstantExpr>(V)) | |||
1254 | return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && | |||
1255 | R.match(CE->getOperand(1)); | |||
1256 | return false; | |||
1257 | } | |||
1258 | }; | |||
1259 | ||||
1260 | struct is_shift_op { | |||
1261 | bool isOpType(unsigned Opcode) { return Instruction::isShift(Opcode); } | |||
1262 | }; | |||
1263 | ||||
1264 | struct is_right_shift_op { | |||
1265 | bool isOpType(unsigned Opcode) { | |||
1266 | return Opcode == Instruction::LShr || Opcode == Instruction::AShr; | |||
1267 | } | |||
1268 | }; | |||
1269 | ||||
1270 | struct is_logical_shift_op { | |||
1271 | bool isOpType(unsigned Opcode) { | |||
1272 | return Opcode == Instruction::LShr || Opcode == Instruction::Shl; | |||
1273 | } | |||
1274 | }; | |||
1275 | ||||
1276 | struct is_bitwiselogic_op { | |||
1277 | bool isOpType(unsigned Opcode) { | |||
1278 | return Instruction::isBitwiseLogicOp(Opcode); | |||
1279 | } | |||
1280 | }; | |||
1281 | ||||
1282 | struct is_idiv_op { | |||
1283 | bool isOpType(unsigned Opcode) { | |||
1284 | return Opcode == Instruction::SDiv || Opcode == Instruction::UDiv; | |||
1285 | } | |||
1286 | }; | |||
1287 | ||||
1288 | struct is_irem_op { | |||
1289 | bool isOpType(unsigned Opcode) { | |||
1290 | return Opcode == Instruction::SRem || Opcode == Instruction::URem; | |||
1291 | } | |||
1292 | }; | |||
1293 | ||||
1294 | /// Matches shift operations. | |||
1295 | template <typename LHS, typename RHS> | |||
1296 | inline BinOpPred_match<LHS, RHS, is_shift_op> m_Shift(const LHS &L, | |||
1297 | const RHS &R) { | |||
1298 | return BinOpPred_match<LHS, RHS, is_shift_op>(L, R); | |||
1299 | } | |||
1300 | ||||
1301 | /// Matches logical shift operations. | |||
1302 | template <typename LHS, typename RHS> | |||
1303 | inline BinOpPred_match<LHS, RHS, is_right_shift_op> m_Shr(const LHS &L, | |||
1304 | const RHS &R) { | |||
1305 | return BinOpPred_match<LHS, RHS, is_right_shift_op>(L, R); | |||
1306 | } | |||
1307 | ||||
1308 | /// Matches logical shift operations. | |||
1309 | template <typename LHS, typename RHS> | |||
1310 | inline BinOpPred_match<LHS, RHS, is_logical_shift_op> | |||
1311 | m_LogicalShift(const LHS &L, const RHS &R) { | |||
1312 | return BinOpPred_match<LHS, RHS, is_logical_shift_op>(L, R); | |||
1313 | } | |||
1314 | ||||
1315 | /// Matches bitwise logic operations. | |||
1316 | template <typename LHS, typename RHS> | |||
1317 | inline BinOpPred_match<LHS, RHS, is_bitwiselogic_op> | |||
1318 | m_BitwiseLogic(const LHS &L, const RHS &R) { | |||
1319 | return BinOpPred_match<LHS, RHS, is_bitwiselogic_op>(L, R); | |||
1320 | } | |||
1321 | ||||
1322 | /// Matches integer division operations. | |||
1323 | template <typename LHS, typename RHS> | |||
1324 | inline BinOpPred_match<LHS, RHS, is_idiv_op> m_IDiv(const LHS &L, | |||
1325 | const RHS &R) { | |||
1326 | return BinOpPred_match<LHS, RHS, is_idiv_op>(L, R); | |||
1327 | } | |||
1328 | ||||
1329 | /// Matches integer remainder operations. | |||
1330 | template <typename LHS, typename RHS> | |||
1331 | inline BinOpPred_match<LHS, RHS, is_irem_op> m_IRem(const LHS &L, | |||
1332 | const RHS &R) { | |||
1333 | return BinOpPred_match<LHS, RHS, is_irem_op>(L, R); | |||
1334 | } | |||
1335 | ||||
1336 | //===----------------------------------------------------------------------===// | |||
1337 | // Class that matches exact binary ops. | |||
1338 | // | |||
1339 | template <typename SubPattern_t> struct Exact_match { | |||
1340 | SubPattern_t SubPattern; | |||
1341 | ||||
1342 | Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} | |||
1343 | ||||
1344 | template <typename OpTy> bool match(OpTy *V) { | |||
1345 | if (auto *PEO = dyn_cast<PossiblyExactOperator>(V)) | |||
1346 | return PEO->isExact() && SubPattern.match(V); | |||
1347 | return false; | |||
1348 | } | |||
1349 | }; | |||
1350 | ||||
1351 | template <typename T> inline Exact_match<T> m_Exact(const T &SubPattern) { | |||
1352 | return SubPattern; | |||
1353 | } | |||
1354 | ||||
1355 | //===----------------------------------------------------------------------===// | |||
1356 | // Matchers for CmpInst classes | |||
1357 | // | |||
1358 | ||||
1359 | template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy, | |||
1360 | bool Commutable = false> | |||
1361 | struct CmpClass_match { | |||
1362 | PredicateTy &Predicate; | |||
1363 | LHS_t L; | |||
1364 | RHS_t R; | |||
1365 | ||||
1366 | // The evaluation order is always stable, regardless of Commutability. | |||
1367 | // The LHS is always matched first. | |||
1368 | CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) | |||
1369 | : Predicate(Pred), L(LHS), R(RHS) {} | |||
1370 | ||||
1371 | template <typename OpTy> bool match(OpTy *V) { | |||
1372 | if (auto *I = dyn_cast<Class>(V)) { | |||
1373 | if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { | |||
1374 | Predicate = I->getPredicate(); | |||
1375 | return true; | |||
1376 | } else if (Commutable && L.match(I->getOperand(1)) && | |||
1377 | R.match(I->getOperand(0))) { | |||
1378 | Predicate = I->getSwappedPredicate(); | |||
1379 | return true; | |||
1380 | } | |||
1381 | } | |||
1382 | return false; | |||
1383 | } | |||
1384 | }; | |||
1385 | ||||
1386 | template <typename LHS, typename RHS> | |||
1387 | inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate> | |||
1388 | m_Cmp(CmpInst::Predicate &Pred, const LHS &L, const RHS &R) { | |||
1389 | return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(Pred, L, R); | |||
1390 | } | |||
1391 | ||||
1392 | template <typename LHS, typename RHS> | |||
1393 | inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate> | |||
1394 | m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) { | |||
1395 | return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(Pred, L, R); | |||
1396 | } | |||
1397 | ||||
1398 | template <typename LHS, typename RHS> | |||
1399 | inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate> | |||
1400 | m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) { | |||
1401 | return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(Pred, L, R); | |||
1402 | } | |||
1403 | ||||
1404 | //===----------------------------------------------------------------------===// | |||
1405 | // Matchers for instructions with a given opcode and number of operands. | |||
1406 | // | |||
1407 | ||||
1408 | /// Matches instructions with Opcode and three operands. | |||
1409 | template <typename T0, unsigned Opcode> struct OneOps_match { | |||
1410 | T0 Op1; | |||
1411 | ||||
1412 | OneOps_match(const T0 &Op1) : Op1(Op1) {} | |||
1413 | ||||
1414 | template <typename OpTy> bool match(OpTy *V) { | |||
1415 | if (V->getValueID() == Value::InstructionVal + Opcode) { | |||
1416 | auto *I = cast<Instruction>(V); | |||
1417 | return Op1.match(I->getOperand(0)); | |||
1418 | } | |||
1419 | return false; | |||
1420 | } | |||
1421 | }; | |||
1422 | ||||
1423 | /// Matches instructions with Opcode and three operands. | |||
1424 | template <typename T0, typename T1, unsigned Opcode> struct TwoOps_match { | |||
1425 | T0 Op1; | |||
1426 | T1 Op2; | |||
1427 | ||||
1428 | TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} | |||
1429 | ||||
1430 | template <typename OpTy> bool match(OpTy *V) { | |||
1431 | if (V->getValueID() == Value::InstructionVal + Opcode) { | |||
1432 | auto *I = cast<Instruction>(V); | |||
1433 | return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); | |||
1434 | } | |||
1435 | return false; | |||
1436 | } | |||
1437 | }; | |||
1438 | ||||
1439 | /// Matches instructions with Opcode and three operands. | |||
1440 | template <typename T0, typename T1, typename T2, unsigned Opcode> | |||
1441 | struct ThreeOps_match { | |||
1442 | T0 Op1; | |||
1443 | T1 Op2; | |||
1444 | T2 Op3; | |||
1445 | ||||
1446 | ThreeOps_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) | |||
1447 | : Op1(Op1), Op2(Op2), Op3(Op3) {} | |||
1448 | ||||
1449 | template <typename OpTy> bool match(OpTy *V) { | |||
1450 | if (V->getValueID() == Value::InstructionVal + Opcode) { | |||
1451 | auto *I = cast<Instruction>(V); | |||
1452 | return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && | |||
1453 | Op3.match(I->getOperand(2)); | |||
1454 | } | |||
1455 | return false; | |||
1456 | } | |||
1457 | }; | |||
1458 | ||||
1459 | /// Matches SelectInst. | |||
1460 | template <typename Cond, typename LHS, typename RHS> | |||
1461 | inline ThreeOps_match<Cond, LHS, RHS, Instruction::Select> | |||
1462 | m_Select(const Cond &C, const LHS &L, const RHS &R) { | |||
1463 | return ThreeOps_match<Cond, LHS, RHS, Instruction::Select>(C, L, R); | |||
1464 | } | |||
1465 | ||||
1466 | /// This matches a select of two constants, e.g.: | |||
1467 | /// m_SelectCst<-1, 0>(m_Value(V)) | |||
1468 | template <int64_t L, int64_t R, typename Cond> | |||
1469 | inline ThreeOps_match<Cond, constantint_match<L>, constantint_match<R>, | |||
1470 | Instruction::Select> | |||
1471 | m_SelectCst(const Cond &C) { | |||
1472 | return m_Select(C, m_ConstantInt<L>(), m_ConstantInt<R>()); | |||
1473 | } | |||
1474 | ||||
1475 | /// Matches FreezeInst. | |||
1476 | template <typename OpTy> | |||
1477 | inline OneOps_match<OpTy, Instruction::Freeze> m_Freeze(const OpTy &Op) { | |||
1478 | return OneOps_match<OpTy, Instruction::Freeze>(Op); | |||
1479 | } | |||
1480 | ||||
1481 | /// Matches InsertElementInst. | |||
1482 | template <typename Val_t, typename Elt_t, typename Idx_t> | |||
1483 | inline ThreeOps_match<Val_t, Elt_t, Idx_t, Instruction::InsertElement> | |||
1484 | m_InsertElt(const Val_t &Val, const Elt_t &Elt, const Idx_t &Idx) { | |||
1485 | return ThreeOps_match<Val_t, Elt_t, Idx_t, Instruction::InsertElement>( | |||
1486 | Val, Elt, Idx); | |||
1487 | } | |||
1488 | ||||
1489 | /// Matches ExtractElementInst. | |||
1490 | template <typename Val_t, typename Idx_t> | |||
1491 | inline TwoOps_match<Val_t, Idx_t, Instruction::ExtractElement> | |||
1492 | m_ExtractElt(const Val_t &Val, const Idx_t &Idx) { | |||
1493 | return TwoOps_match<Val_t, Idx_t, Instruction::ExtractElement>(Val, Idx); | |||
1494 | } | |||
1495 | ||||
1496 | /// Matches shuffle. | |||
1497 | template <typename T0, typename T1, typename T2> struct Shuffle_match { | |||
1498 | T0 Op1; | |||
1499 | T1 Op2; | |||
1500 | T2 Mask; | |||
1501 | ||||
1502 | Shuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask) | |||
1503 | : Op1(Op1), Op2(Op2), Mask(Mask) {} | |||
1504 | ||||
1505 | template <typename OpTy> bool match(OpTy *V) { | |||
1506 | if (auto *I = dyn_cast<ShuffleVectorInst>(V)) { | |||
1507 | return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && | |||
1508 | Mask.match(I->getShuffleMask()); | |||
1509 | } | |||
1510 | return false; | |||
1511 | } | |||
1512 | }; | |||
1513 | ||||
1514 | struct m_Mask { | |||
1515 | ArrayRef<int> &MaskRef; | |||
1516 | m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {} | |||
1517 | bool match(ArrayRef<int> Mask) { | |||
1518 | MaskRef = Mask; | |||
1519 | return true; | |||
1520 | } | |||
1521 | }; | |||
1522 | ||||
1523 | struct m_ZeroMask { | |||
1524 | bool match(ArrayRef<int> Mask) { | |||
1525 | return all_of(Mask, [](int Elem) { return Elem == 0 || Elem == -1; }); | |||
1526 | } | |||
1527 | }; | |||
1528 | ||||
1529 | struct m_SpecificMask { | |||
1530 | ArrayRef<int> &MaskRef; | |||
1531 | m_SpecificMask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {} | |||
1532 | bool match(ArrayRef<int> Mask) { return MaskRef == Mask; } | |||
1533 | }; | |||
1534 | ||||
1535 | struct m_SplatOrUndefMask { | |||
1536 | int &SplatIndex; | |||
1537 | m_SplatOrUndefMask(int &SplatIndex) : SplatIndex(SplatIndex) {} | |||
1538 | bool match(ArrayRef<int> Mask) { | |||
1539 | const auto *First = find_if(Mask, [](int Elem) { return Elem != -1; }); | |||
1540 | if (First == Mask.end()) | |||
1541 | return false; | |||
1542 | SplatIndex = *First; | |||
1543 | return all_of(Mask, | |||
1544 | [First](int Elem) { return Elem == *First || Elem == -1; }); | |||
1545 | } | |||
1546 | }; | |||
1547 | ||||
1548 | /// Matches ShuffleVectorInst independently of mask value. | |||
1549 | template <typename V1_t, typename V2_t> | |||
1550 | inline TwoOps_match<V1_t, V2_t, Instruction::ShuffleVector> | |||
1551 | m_Shuffle(const V1_t &v1, const V2_t &v2) { | |||
1552 | return TwoOps_match<V1_t, V2_t, Instruction::ShuffleVector>(v1, v2); | |||
1553 | } | |||
1554 | ||||
1555 | template <typename V1_t, typename V2_t, typename Mask_t> | |||
1556 | inline Shuffle_match<V1_t, V2_t, Mask_t> | |||
1557 | m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) { | |||
1558 | return Shuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask); | |||
1559 | } | |||
1560 | ||||
1561 | /// Matches LoadInst. | |||
1562 | template <typename OpTy> | |||
1563 | inline OneOps_match<OpTy, Instruction::Load> m_Load(const OpTy &Op) { | |||
1564 | return OneOps_match<OpTy, Instruction::Load>(Op); | |||
1565 | } | |||
1566 | ||||
1567 | /// Matches StoreInst. | |||
1568 | template <typename ValueOpTy, typename PointerOpTy> | |||
1569 | inline TwoOps_match<ValueOpTy, PointerOpTy, Instruction::Store> | |||
1570 | m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp) { | |||
1571 | return TwoOps_match<ValueOpTy, PointerOpTy, Instruction::Store>(ValueOp, | |||
1572 | PointerOp); | |||
1573 | } | |||
1574 | ||||
1575 | //===----------------------------------------------------------------------===// | |||
1576 | // Matchers for CastInst classes | |||
1577 | // | |||
1578 | ||||
1579 | template <typename Op_t, unsigned Opcode> struct CastClass_match { | |||
1580 | Op_t Op; | |||
1581 | ||||
1582 | CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} | |||
1583 | ||||
1584 | template <typename OpTy> bool match(OpTy *V) { | |||
1585 | if (auto *O = dyn_cast<Operator>(V)) | |||
1586 | return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); | |||
1587 | return false; | |||
1588 | } | |||
1589 | }; | |||
1590 | ||||
1591 | /// Matches BitCast. | |||
1592 | template <typename OpTy> | |||
1593 | inline CastClass_match<OpTy, Instruction::BitCast> m_BitCast(const OpTy &Op) { | |||
1594 | return CastClass_match<OpTy, Instruction::BitCast>(Op); | |||
1595 | } | |||
1596 | ||||
1597 | /// Matches PtrToInt. | |||
1598 | template <typename OpTy> | |||
1599 | inline CastClass_match<OpTy, Instruction::PtrToInt> m_PtrToInt(const OpTy &Op) { | |||
1600 | return CastClass_match<OpTy, Instruction::PtrToInt>(Op); | |||
1601 | } | |||
1602 | ||||
1603 | /// Matches IntToPtr. | |||
1604 | template <typename OpTy> | |||
1605 | inline CastClass_match<OpTy, Instruction::IntToPtr> m_IntToPtr(const OpTy &Op) { | |||
1606 | return CastClass_match<OpTy, Instruction::IntToPtr>(Op); | |||
1607 | } | |||
1608 | ||||
1609 | /// Matches Trunc. | |||
1610 | template <typename OpTy> | |||
1611 | inline CastClass_match<OpTy, Instruction::Trunc> m_Trunc(const OpTy &Op) { | |||
1612 | return CastClass_match<OpTy, Instruction::Trunc>(Op); | |||
1613 | } | |||
1614 | ||||
1615 | template <typename OpTy> | |||
1616 | inline match_combine_or<CastClass_match<OpTy, Instruction::Trunc>, OpTy> | |||
1617 | m_TruncOrSelf(const OpTy &Op) { | |||
1618 | return m_CombineOr(m_Trunc(Op), Op); | |||
1619 | } | |||
1620 | ||||
1621 | /// Matches SExt. | |||
1622 | template <typename OpTy> | |||
1623 | inline CastClass_match<OpTy, Instruction::SExt> m_SExt(const OpTy &Op) { | |||
1624 | return CastClass_match<OpTy, Instruction::SExt>(Op); | |||
1625 | } | |||
1626 | ||||
1627 | /// Matches ZExt. | |||
1628 | template <typename OpTy> | |||
1629 | inline CastClass_match<OpTy, Instruction::ZExt> m_ZExt(const OpTy &Op) { | |||
1630 | return CastClass_match<OpTy, Instruction::ZExt>(Op); | |||
1631 | } | |||
1632 | ||||
1633 | template <typename OpTy> | |||
1634 | inline match_combine_or<CastClass_match<OpTy, Instruction::ZExt>, OpTy> | |||
1635 | m_ZExtOrSelf(const OpTy &Op) { | |||
1636 | return m_CombineOr(m_ZExt(Op), Op); | |||
1637 | } | |||
1638 | ||||
1639 | template <typename OpTy> | |||
1640 | inline match_combine_or<CastClass_match<OpTy, Instruction::SExt>, OpTy> | |||
1641 | m_SExtOrSelf(const OpTy &Op) { | |||
1642 | return m_CombineOr(m_SExt(Op), Op); | |||
1643 | } | |||
1644 | ||||
1645 | template <typename OpTy> | |||
1646 | inline match_combine_or<CastClass_match<OpTy, Instruction::ZExt>, | |||
1647 | CastClass_match<OpTy, Instruction::SExt>> | |||
1648 | m_ZExtOrSExt(const OpTy &Op) { | |||
1649 | return m_CombineOr(m_ZExt(Op), m_SExt(Op)); | |||
1650 | } | |||
1651 | ||||
1652 | template <typename OpTy> | |||
1653 | inline match_combine_or< | |||
1654 | match_combine_or<CastClass_match<OpTy, Instruction::ZExt>, | |||
1655 | CastClass_match<OpTy, Instruction::SExt>>, | |||
1656 | OpTy> | |||
1657 | m_ZExtOrSExtOrSelf(const OpTy &Op) { | |||
1658 | return m_CombineOr(m_ZExtOrSExt(Op), Op); | |||
1659 | } | |||
1660 | ||||
1661 | template <typename OpTy> | |||
1662 | inline CastClass_match<OpTy, Instruction::UIToFP> m_UIToFP(const OpTy &Op) { | |||
1663 | return CastClass_match<OpTy, Instruction::UIToFP>(Op); | |||
1664 | } | |||
1665 | ||||
1666 | template <typename OpTy> | |||
1667 | inline CastClass_match<OpTy, Instruction::SIToFP> m_SIToFP(const OpTy &Op) { | |||
1668 | return CastClass_match<OpTy, Instruction::SIToFP>(Op); | |||
1669 | } | |||
1670 | ||||
1671 | template <typename OpTy> | |||
1672 | inline CastClass_match<OpTy, Instruction::FPToUI> m_FPToUI(const OpTy &Op) { | |||
1673 | return CastClass_match<OpTy, Instruction::FPToUI>(Op); | |||
1674 | } | |||
1675 | ||||
1676 | template <typename OpTy> | |||
1677 | inline CastClass_match<OpTy, Instruction::FPToSI> m_FPToSI(const OpTy &Op) { | |||
1678 | return CastClass_match<OpTy, Instruction::FPToSI>(Op); | |||
1679 | } | |||
1680 | ||||
1681 | template <typename OpTy> | |||
1682 | inline CastClass_match<OpTy, Instruction::FPTrunc> m_FPTrunc(const OpTy &Op) { | |||
1683 | return CastClass_match<OpTy, Instruction::FPTrunc>(Op); | |||
1684 | } | |||
1685 | ||||
1686 | template <typename OpTy> | |||
1687 | inline CastClass_match<OpTy, Instruction::FPExt> m_FPExt(const OpTy &Op) { | |||
1688 | return CastClass_match<OpTy, Instruction::FPExt>(Op); | |||
1689 | } | |||
1690 | ||||
1691 | //===----------------------------------------------------------------------===// | |||
1692 | // Matchers for control flow. | |||
1693 | // | |||
1694 | ||||
1695 | struct br_match { | |||
1696 | BasicBlock *&Succ; | |||
1697 | ||||
1698 | br_match(BasicBlock *&Succ) : Succ(Succ) {} | |||
1699 | ||||
1700 | template <typename OpTy> bool match(OpTy *V) { | |||
1701 | if (auto *BI = dyn_cast<BranchInst>(V)) | |||
1702 | if (BI->isUnconditional()) { | |||
1703 | Succ = BI->getSuccessor(0); | |||
1704 | return true; | |||
1705 | } | |||
1706 | return false; | |||
1707 | } | |||
1708 | }; | |||
1709 | ||||
1710 | inline br_match m_UnconditionalBr(BasicBlock *&Succ) { return br_match(Succ); } | |||
1711 | ||||
1712 | template <typename Cond_t, typename TrueBlock_t, typename FalseBlock_t> | |||
1713 | struct brc_match { | |||
1714 | Cond_t Cond; | |||
1715 | TrueBlock_t T; | |||
1716 | FalseBlock_t F; | |||
1717 | ||||
1718 | brc_match(const Cond_t &C, const TrueBlock_t &t, const FalseBlock_t &f) | |||
1719 | : Cond(C), T(t), F(f) {} | |||
1720 | ||||
1721 | template <typename OpTy> bool match(OpTy *V) { | |||
1722 | if (auto *BI = dyn_cast<BranchInst>(V)) | |||
1723 | if (BI->isConditional() && Cond.match(BI->getCondition())) | |||
1724 | return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); | |||
1725 | return false; | |||
1726 | } | |||
1727 | }; | |||
1728 | ||||
1729 | template <typename Cond_t> | |||
1730 | inline brc_match<Cond_t, bind_ty<BasicBlock>, bind_ty<BasicBlock>> | |||
1731 | m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F) { | |||
1732 | return brc_match<Cond_t, bind_ty<BasicBlock>, bind_ty<BasicBlock>>( | |||
1733 | C, m_BasicBlock(T), m_BasicBlock(F)); | |||
1734 | } | |||
1735 | ||||
1736 | template <typename Cond_t, typename TrueBlock_t, typename FalseBlock_t> | |||
1737 | inline brc_match<Cond_t, TrueBlock_t, FalseBlock_t> | |||
1738 | m_Br(const Cond_t &C, const TrueBlock_t &T, const FalseBlock_t &F) { | |||
1739 | return brc_match<Cond_t, TrueBlock_t, FalseBlock_t>(C, T, F); | |||
1740 | } | |||
1741 | ||||
1742 | //===----------------------------------------------------------------------===// | |||
1743 | // Matchers for max/min idioms, eg: "select (sgt x, y), x, y" -> smax(x,y). | |||
1744 | // | |||
1745 | ||||
1746 | template <typename CmpInst_t, typename LHS_t, typename RHS_t, typename Pred_t, | |||
1747 | bool Commutable = false> | |||
1748 | struct MaxMin_match { | |||
1749 | using PredType = Pred_t; | |||
1750 | LHS_t L; | |||
1751 | RHS_t R; | |||
1752 | ||||
1753 | // The evaluation order is always stable, regardless of Commutability. | |||
1754 | // The LHS is always matched first. | |||
1755 | MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} | |||
1756 | ||||
1757 | template <typename OpTy> bool match(OpTy *V) { | |||
1758 | if (auto *II = dyn_cast<IntrinsicInst>(V)) { | |||
1759 | Intrinsic::ID IID = II->getIntrinsicID(); | |||
1760 | if ((IID == Intrinsic::smax && Pred_t::match(ICmpInst::ICMP_SGT)) || | |||
1761 | (IID == Intrinsic::smin && Pred_t::match(ICmpInst::ICMP_SLT)) || | |||
1762 | (IID == Intrinsic::umax && Pred_t::match(ICmpInst::ICMP_UGT)) || | |||
1763 | (IID == Intrinsic::umin && Pred_t::match(ICmpInst::ICMP_ULT))) { | |||
1764 | Value *LHS = II->getOperand(0), *RHS = II->getOperand(1); | |||
1765 | return (L.match(LHS) && R.match(RHS)) || | |||
1766 | (Commutable && L.match(RHS) && R.match(LHS)); | |||
1767 | } | |||
1768 | } | |||
1769 | // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". | |||
1770 | auto *SI = dyn_cast<SelectInst>(V); | |||
1771 | if (!SI) | |||
1772 | return false; | |||
1773 | auto *Cmp = dyn_cast<CmpInst_t>(SI->getCondition()); | |||
1774 | if (!Cmp) | |||
1775 | return false; | |||
1776 | // At this point we have a select conditioned on a comparison. Check that | |||
1777 | // it is the values returned by the select that are being compared. | |||
1778 | auto *TrueVal = SI->getTrueValue(); | |||
1779 | auto *FalseVal = SI->getFalseValue(); | |||
1780 | auto *LHS = Cmp->getOperand(0); | |||
1781 | auto *RHS = Cmp->getOperand(1); | |||
1782 | if ((TrueVal != LHS || FalseVal != RHS) && | |||
1783 | (TrueVal != RHS || FalseVal != LHS)) | |||
1784 | return false; | |||
1785 | typename CmpInst_t::Predicate Pred = | |||
1786 | LHS == TrueVal ? Cmp->getPredicate() : Cmp->getInversePredicate(); | |||
1787 | // Does "(x pred y) ? x : y" represent the desired max/min operation? | |||
1788 | if (!Pred_t::match(Pred)) | |||
1789 | return false; | |||
1790 | // It does! Bind the operands. | |||
1791 | return (L.match(LHS) && R.match(RHS)) || | |||
1792 | (Commutable && L.match(RHS) && R.match(LHS)); | |||
1793 | } | |||
1794 | }; | |||
1795 | ||||
1796 | /// Helper class for identifying signed max predicates. | |||
1797 | struct smax_pred_ty { | |||
1798 | static bool match(ICmpInst::Predicate Pred) { | |||
1799 | return Pred == CmpInst::ICMP_SGT || Pred == CmpInst::ICMP_SGE; | |||
1800 | } | |||
1801 | }; | |||
1802 | ||||
1803 | /// Helper class for identifying signed min predicates. | |||
1804 | struct smin_pred_ty { | |||
1805 | static bool match(ICmpInst::Predicate Pred) { | |||
1806 | return Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_SLE; | |||
1807 | } | |||
1808 | }; | |||
1809 | ||||
1810 | /// Helper class for identifying unsigned max predicates. | |||
1811 | struct umax_pred_ty { | |||
1812 | static bool match(ICmpInst::Predicate Pred) { | |||
1813 | return Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_UGE; | |||
1814 | } | |||
1815 | }; | |||
1816 | ||||
1817 | /// Helper class for identifying unsigned min predicates. | |||
1818 | struct umin_pred_ty { | |||
1819 | static bool match(ICmpInst::Predicate Pred) { | |||
1820 | return Pred == CmpInst::ICMP_ULT || Pred == CmpInst::ICMP_ULE; | |||
1821 | } | |||
1822 | }; | |||
1823 | ||||
1824 | /// Helper class for identifying ordered max predicates. | |||
1825 | struct ofmax_pred_ty { | |||
1826 | static bool match(FCmpInst::Predicate Pred) { | |||
1827 | return Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_OGE; | |||
1828 | } | |||
1829 | }; | |||
1830 | ||||
1831 | /// Helper class for identifying ordered min predicates. | |||
1832 | struct ofmin_pred_ty { | |||
1833 | static bool match(FCmpInst::Predicate Pred) { | |||
1834 | return Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_OLE; | |||
1835 | } | |||
1836 | }; | |||
1837 | ||||
1838 | /// Helper class for identifying unordered max predicates. | |||
1839 | struct ufmax_pred_ty { | |||
1840 | static bool match(FCmpInst::Predicate Pred) { | |||
1841 | return Pred == CmpInst::FCMP_UGT || Pred == CmpInst::FCMP_UGE; | |||
1842 | } | |||
1843 | }; | |||
1844 | ||||
1845 | /// Helper class for identifying unordered min predicates. | |||
1846 | struct ufmin_pred_ty { | |||
1847 | static bool match(FCmpInst::Predicate Pred) { | |||
1848 | return Pred == CmpInst::FCMP_ULT || Pred == CmpInst::FCMP_ULE; | |||
1849 | } | |||
1850 | }; | |||
1851 | ||||
1852 | template <typename LHS, typename RHS> | |||
1853 | inline MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty> m_SMax(const LHS &L, | |||
1854 | const RHS &R) { | |||
1855 | return MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty>(L, R); | |||
1856 | } | |||
1857 | ||||
1858 | template <typename LHS, typename RHS> | |||
1859 | inline MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty> m_SMin(const LHS &L, | |||
1860 | const RHS &R) { | |||
1861 | return MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty>(L, R); | |||
1862 | } | |||
1863 | ||||
1864 | template <typename LHS, typename RHS> | |||
1865 | inline MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty> m_UMax(const LHS &L, | |||
1866 | const RHS &R) { | |||
1867 | return MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty>(L, R); | |||
1868 | } | |||
1869 | ||||
1870 | template <typename LHS, typename RHS> | |||
1871 | inline MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty> m_UMin(const LHS &L, | |||
1872 | const RHS &R) { | |||
1873 | return MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty>(L, R); | |||
1874 | } | |||
1875 | ||||
1876 | template <typename LHS, typename RHS> | |||
1877 | inline match_combine_or< | |||
1878 | match_combine_or<MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty>, | |||
1879 | MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty>>, | |||
1880 | match_combine_or<MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty>, | |||
1881 | MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty>>> | |||
1882 | m_MaxOrMin(const LHS &L, const RHS &R) { | |||
1883 | return m_CombineOr(m_CombineOr(m_SMax(L, R), m_SMin(L, R)), | |||
1884 | m_CombineOr(m_UMax(L, R), m_UMin(L, R))); | |||
1885 | } | |||
1886 | ||||
1887 | /// Match an 'ordered' floating point maximum function. | |||
1888 | /// Floating point has one special value 'NaN'. Therefore, there is no total | |||
1889 | /// order. However, if we can ignore the 'NaN' value (for example, because of a | |||
1890 | /// 'no-nans-float-math' flag) a combination of a fcmp and select has 'maximum' | |||
1891 | /// semantics. In the presence of 'NaN' we have to preserve the original | |||
1892 | /// select(fcmp(ogt/ge, L, R), L, R) semantics matched by this predicate. | |||
1893 | /// | |||
1894 | /// max(L, R) iff L and R are not NaN | |||
1895 | /// m_OrdFMax(L, R) = R iff L or R are NaN | |||
1896 | template <typename LHS, typename RHS> | |||
1897 | inline MaxMin_match<FCmpInst, LHS, RHS, ofmax_pred_ty> m_OrdFMax(const LHS &L, | |||
1898 | const RHS &R) { | |||
1899 | return MaxMin_match<FCmpInst, LHS, RHS, ofmax_pred_ty>(L, R); | |||
1900 | } | |||
1901 | ||||
1902 | /// Match an 'ordered' floating point minimum function. | |||
1903 | /// Floating point has one special value 'NaN'. Therefore, there is no total | |||
1904 | /// order. However, if we can ignore the 'NaN' value (for example, because of a | |||
1905 | /// 'no-nans-float-math' flag) a combination of a fcmp and select has 'minimum' | |||
1906 | /// semantics. In the presence of 'NaN' we have to preserve the original | |||
1907 | /// select(fcmp(olt/le, L, R), L, R) semantics matched by this predicate. | |||
1908 | /// | |||
1909 | /// min(L, R) iff L and R are not NaN | |||
1910 | /// m_OrdFMin(L, R) = R iff L or R are NaN | |||
1911 | template <typename LHS, typename RHS> | |||
1912 | inline MaxMin_match<FCmpInst, LHS, RHS, ofmin_pred_ty> m_OrdFMin(const LHS &L, | |||
1913 | const RHS &R) { | |||
1914 | return MaxMin_match<FCmpInst, LHS, RHS, ofmin_pred_ty>(L, R); | |||
1915 | } | |||
1916 | ||||
1917 | /// Match an 'unordered' floating point maximum function. | |||
1918 | /// Floating point has one special value 'NaN'. Therefore, there is no total | |||
1919 | /// order. However, if we can ignore the 'NaN' value (for example, because of a | |||
1920 | /// 'no-nans-float-math' flag) a combination of a fcmp and select has 'maximum' | |||
1921 | /// semantics. In the presence of 'NaN' we have to preserve the original | |||
1922 | /// select(fcmp(ugt/ge, L, R), L, R) semantics matched by this predicate. | |||
1923 | /// | |||
1924 | /// max(L, R) iff L and R are not NaN | |||
1925 | /// m_UnordFMax(L, R) = L iff L or R are NaN | |||
1926 | template <typename LHS, typename RHS> | |||
1927 | inline MaxMin_match<FCmpInst, LHS, RHS, ufmax_pred_ty> | |||
1928 | m_UnordFMax(const LHS &L, const RHS &R) { | |||
1929 | return MaxMin_match<FCmpInst, LHS, RHS, ufmax_pred_ty>(L, R); | |||
1930 | } | |||
1931 | ||||
1932 | /// Match an 'unordered' floating point minimum function. | |||
1933 | /// Floating point has one special value 'NaN'. Therefore, there is no total | |||
1934 | /// order. However, if we can ignore the 'NaN' value (for example, because of a | |||
1935 | /// 'no-nans-float-math' flag) a combination of a fcmp and select has 'minimum' | |||
1936 | /// semantics. In the presence of 'NaN' we have to preserve the original | |||
1937 | /// select(fcmp(ult/le, L, R), L, R) semantics matched by this predicate. | |||
1938 | /// | |||
1939 | /// min(L, R) iff L and R are not NaN | |||
1940 | /// m_UnordFMin(L, R) = L iff L or R are NaN | |||
1941 | template <typename LHS, typename RHS> | |||
1942 | inline MaxMin_match<FCmpInst, LHS, RHS, ufmin_pred_ty> | |||
1943 | m_UnordFMin(const LHS &L, const RHS &R) { | |||
1944 | return MaxMin_match<FCmpInst, LHS, RHS, ufmin_pred_ty>(L, R); | |||
1945 | } | |||
1946 | ||||
1947 | //===----------------------------------------------------------------------===// | |||
// Matchers for overflow check patterns: e.g. (a + b) u< a, (a ^ -1) u< b
// Note that S might be matched to instructions other than AddInst.
1950 | // | |||
1951 | ||||
1952 | template <typename LHS_t, typename RHS_t, typename Sum_t> | |||
1953 | struct UAddWithOverflow_match { | |||
1954 | LHS_t L; | |||
1955 | RHS_t R; | |||
1956 | Sum_t S; | |||
1957 | ||||
1958 | UAddWithOverflow_match(const LHS_t &L, const RHS_t &R, const Sum_t &S) | |||
1959 | : L(L), R(R), S(S) {} | |||
1960 | ||||
1961 | template <typename OpTy> bool match(OpTy *V) { | |||
1962 | Value *ICmpLHS, *ICmpRHS; | |||
1963 | ICmpInst::Predicate Pred; | |||
1964 | if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) | |||
1965 | return false; | |||
1966 | ||||
1967 | Value *AddLHS, *AddRHS; | |||
1968 | auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS)); | |||
1969 | ||||
1970 | // (a + b) u< a, (a + b) u< b | |||
1971 | if (Pred == ICmpInst::ICMP_ULT) | |||
1972 | if (AddExpr.match(ICmpLHS) && (ICmpRHS == AddLHS || ICmpRHS == AddRHS)) | |||
1973 | return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS); | |||
1974 | ||||
1975 | // a >u (a + b), b >u (a + b) | |||
1976 | if (Pred == ICmpInst::ICMP_UGT) | |||
1977 | if (AddExpr.match(ICmpRHS) && (ICmpLHS == AddLHS || ICmpLHS == AddRHS)) | |||
1978 | return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS); | |||
1979 | ||||
1980 | Value *Op1; | |||
1981 | auto XorExpr = m_OneUse(m_Xor(m_Value(Op1), m_AllOnes())); | |||
1982 | // (a ^ -1) <u b | |||
1983 | if (Pred == ICmpInst::ICMP_ULT) { | |||
1984 | if (XorExpr.match(ICmpLHS)) | |||
1985 | return L.match(Op1) && R.match(ICmpRHS) && S.match(ICmpLHS); | |||
1986 | } | |||
1987 | // b > u (a ^ -1) | |||
1988 | if (Pred == ICmpInst::ICMP_UGT) { | |||
1989 | if (XorExpr.match(ICmpRHS)) | |||
1990 | return L.match(Op1) && R.match(ICmpLHS) && S.match(ICmpRHS); | |||
1991 | } | |||
1992 | ||||
1993 | // Match special-case for increment-by-1. | |||
1994 | if (Pred == ICmpInst::ICMP_EQ) { | |||
1995 | // (a + 1) == 0 | |||
1996 | // (1 + a) == 0 | |||
1997 | if (AddExpr.match(ICmpLHS) && m_ZeroInt().match(ICmpRHS) && | |||
1998 | (m_One().match(AddLHS) || m_One().match(AddRHS))) | |||
1999 | return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS); | |||
2000 | // 0 == (a + 1) | |||
2001 | // 0 == (1 + a) | |||
2002 | if (m_ZeroInt().match(ICmpLHS) && AddExpr.match(ICmpRHS) && | |||
2003 | (m_One().match(AddLHS) || m_One().match(AddRHS))) | |||
2004 | return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS); | |||
2005 | } | |||
2006 | ||||
2007 | return false; | |||
2008 | } | |||
2009 | }; | |||
2010 | ||||
2011 | /// Match an icmp instruction checking for unsigned overflow on addition. | |||
2012 | /// | |||
2013 | /// S is matched to the addition whose result is being checked for overflow, and | |||
2014 | /// L and R are matched to the LHS and RHS of S. | |||
2015 | template <typename LHS_t, typename RHS_t, typename Sum_t> | |||
2016 | UAddWithOverflow_match<LHS_t, RHS_t, Sum_t> | |||
2017 | m_UAddWithOverflow(const LHS_t &L, const RHS_t &R, const Sum_t &S) { | |||
2018 | return UAddWithOverflow_match<LHS_t, RHS_t, Sum_t>(L, R, S); | |||
2019 | } | |||
2020 | ||||
2021 | template <typename Opnd_t> struct Argument_match { | |||
2022 | unsigned OpI; | |||
2023 | Opnd_t Val; | |||
2024 | ||||
2025 | Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} | |||
2026 | ||||
2027 | template <typename OpTy> bool match(OpTy *V) { | |||
2028 | // FIXME: Should likely be switched to use `CallBase`. | |||
2029 | if (const auto *CI = dyn_cast<CallInst>(V)) | |||
2030 | return Val.match(CI->getArgOperand(OpI)); | |||
2031 | return false; | |||
2032 | } | |||
2033 | }; | |||
2034 | ||||
2035 | /// Match an argument. | |||
2036 | template <unsigned OpI, typename Opnd_t> | |||
2037 | inline Argument_match<Opnd_t> m_Argument(const Opnd_t &Op) { | |||
2038 | return Argument_match<Opnd_t>(OpI, Op); | |||
2039 | } | |||
2040 | ||||
2041 | /// Intrinsic matchers. | |||
2042 | struct IntrinsicID_match { | |||
2043 | unsigned ID; | |||
2044 | ||||
2045 | IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} | |||
2046 | ||||
2047 | template <typename OpTy> bool match(OpTy *V) { | |||
2048 | if (const auto *CI = dyn_cast<CallInst>(V)) | |||
2049 | if (const auto *F = CI->getCalledFunction()) | |||
2050 | return F->getIntrinsicID() == ID; | |||
2051 | return false; | |||
2052 | } | |||
2053 | }; | |||
2054 | ||||
/// Intrinsic matchers are combinations of an ID matcher and argument
/// matchers. Higher-arity matchers are defined recursively by and-ing a
/// lower-arity matcher with one more Argument_match. Here are convenient
/// typedefs for up to six arguments; more can be added as needed.
///
/// m_Intrinsic_Ty<T0, ..., Tn>::Ty is the exact return type of the
/// corresponding m_Intrinsic<ID>(Op0, ..., Opn) overload.
template <typename T0 = void, typename T1 = void, typename T2 = void,
          typename T3 = void, typename T4 = void, typename T5 = void,
          typename T6 = void, typename T7 = void, typename T8 = void,
          typename T9 = void, typename T10 = void>
struct m_Intrinsic_Ty;
// Base case: the intrinsic ID check and-ed with argument 0.
template <typename T0> struct m_Intrinsic_Ty<T0> {
  using Ty = match_combine_and<IntrinsicID_match, Argument_match<T0>>;
};
// Each further specialization appends one Argument_match to the previous Ty.
template <typename T0, typename T1> struct m_Intrinsic_Ty<T0, T1> {
  using Ty =
      match_combine_and<typename m_Intrinsic_Ty<T0>::Ty, Argument_match<T1>>;
};
template <typename T0, typename T1, typename T2>
struct m_Intrinsic_Ty<T0, T1, T2> {
  using Ty = match_combine_and<typename m_Intrinsic_Ty<T0, T1>::Ty,
                               Argument_match<T2>>;
};
template <typename T0, typename T1, typename T2, typename T3>
struct m_Intrinsic_Ty<T0, T1, T2, T3> {
  using Ty = match_combine_and<typename m_Intrinsic_Ty<T0, T1, T2>::Ty,
                               Argument_match<T3>>;
};

template <typename T0, typename T1, typename T2, typename T3, typename T4>
struct m_Intrinsic_Ty<T0, T1, T2, T3, T4> {
  using Ty = match_combine_and<typename m_Intrinsic_Ty<T0, T1, T2, T3>::Ty,
                               Argument_match<T4>>;
};

template <typename T0, typename T1, typename T2, typename T3, typename T4,
          typename T5>
struct m_Intrinsic_Ty<T0, T1, T2, T3, T4, T5> {
  using Ty = match_combine_and<typename m_Intrinsic_Ty<T0, T1, T2, T3, T4>::Ty,
                               Argument_match<T5>>;
};
2094 | ||||
2095 | /// Match intrinsic calls like this: | |||
2096 | /// m_Intrinsic<Intrinsic::fabs>(m_Value(X)) | |||
2097 | template <Intrinsic::ID IntrID> inline IntrinsicID_match m_Intrinsic() { | |||
2098 | return IntrinsicID_match(IntrID); | |||
2099 | } | |||
2100 | ||||
2101 | /// Matches MaskedLoad Intrinsic. | |||
2102 | template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3> | |||
2103 | inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty | |||
2104 | m_MaskedLoad(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2, | |||
2105 | const Opnd3 &Op3) { | |||
2106 | return m_Intrinsic<Intrinsic::masked_load>(Op0, Op1, Op2, Op3); | |||
2107 | } | |||
2108 | ||||
2109 | /// Matches MaskedGather Intrinsic. | |||
2110 | template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3> | |||
2111 | inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty | |||
2112 | m_MaskedGather(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2, | |||
2113 | const Opnd3 &Op3) { | |||
2114 | return m_Intrinsic<Intrinsic::masked_gather>(Op0, Op1, Op2, Op3); | |||
2115 | } | |||
2116 | ||||
2117 | template <Intrinsic::ID IntrID, typename T0> | |||
2118 | inline typename m_Intrinsic_Ty<T0>::Ty m_Intrinsic(const T0 &Op0) { | |||
2119 | return m_CombineAnd(m_Intrinsic<IntrID>(), m_Argument<0>(Op0)); | |||
2120 | } | |||
2121 | ||||
2122 | template <Intrinsic::ID IntrID, typename T0, typename T1> | |||
2123 | inline typename m_Intrinsic_Ty<T0, T1>::Ty m_Intrinsic(const T0 &Op0, | |||
2124 | const T1 &Op1) { | |||
2125 | return m_CombineAnd(m_Intrinsic<IntrID>(Op0), m_Argument<1>(Op1)); | |||
2126 | } | |||
2127 | ||||
2128 | template <Intrinsic::ID IntrID, typename T0, typename T1, typename T2> | |||
2129 | inline typename m_Intrinsic_Ty<T0, T1, T2>::Ty | |||
2130 | m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2) { | |||
2131 | return m_CombineAnd(m_Intrinsic<IntrID>(Op0, Op1), m_Argument<2>(Op2)); | |||
2132 | } | |||
2133 | ||||
2134 | template <Intrinsic::ID IntrID, typename T0, typename T1, typename T2, | |||
2135 | typename T3> | |||
2136 | inline typename m_Intrinsic_Ty<T0, T1, T2, T3>::Ty | |||
2137 | m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) { | |||
2138 | return m_CombineAnd(m_Intrinsic<IntrID>(Op0, Op1, Op2), m_Argument<3>(Op3)); | |||
2139 | } | |||
2140 | ||||
2141 | template <Intrinsic::ID IntrID, typename T0, typename T1, typename T2, | |||
2142 | typename T3, typename T4> | |||
2143 | inline typename m_Intrinsic_Ty<T0, T1, T2, T3, T4>::Ty | |||
2144 | m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3, | |||
2145 | const T4 &Op4) { | |||
2146 | return m_CombineAnd(m_Intrinsic<IntrID>(Op0, Op1, Op2, Op3), | |||
2147 | m_Argument<4>(Op4)); | |||
2148 | } | |||
2149 | ||||
2150 | template <Intrinsic::ID IntrID, typename T0, typename T1, typename T2, | |||
2151 | typename T3, typename T4, typename T5> | |||
2152 | inline typename m_Intrinsic_Ty<T0, T1, T2, T3, T4, T5>::Ty | |||
2153 | m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3, | |||
2154 | const T4 &Op4, const T5 &Op5) { | |||
2155 | return m_CombineAnd(m_Intrinsic<IntrID>(Op0, Op1, Op2, Op3, Op4), | |||
2156 | m_Argument<5>(Op5)); | |||
2157 | } | |||
2158 | ||||
2159 | // Helper intrinsic matching specializations. | |||
2160 | template <typename Opnd0> | |||
2161 | inline typename m_Intrinsic_Ty<Opnd0>::Ty m_BitReverse(const Opnd0 &Op0) { | |||
2162 | return m_Intrinsic<Intrinsic::bitreverse>(Op0); | |||
2163 | } | |||
2164 | ||||
2165 | template <typename Opnd0> | |||
2166 | inline typename m_Intrinsic_Ty<Opnd0>::Ty m_BSwap(const Opnd0 &Op0) { | |||
2167 | return m_Intrinsic<Intrinsic::bswap>(Op0); | |||
2168 | } | |||
2169 | ||||
2170 | template <typename Opnd0> | |||
2171 | inline typename m_Intrinsic_Ty<Opnd0>::Ty m_FAbs(const Opnd0 &Op0) { | |||
2172 | return m_Intrinsic<Intrinsic::fabs>(Op0); | |||
2173 | } | |||
2174 | ||||
2175 | template <typename Opnd0> | |||
2176 | inline typename m_Intrinsic_Ty<Opnd0>::Ty m_FCanonicalize(const Opnd0 &Op0) { | |||
2177 | return m_Intrinsic<Intrinsic::canonicalize>(Op0); | |||
2178 | } | |||
2179 | ||||
2180 | template <typename Opnd0, typename Opnd1> | |||
2181 | inline typename m_Intrinsic_Ty<Opnd0, Opnd1>::Ty m_FMin(const Opnd0 &Op0, | |||
2182 | const Opnd1 &Op1) { | |||
2183 | return m_Intrinsic<Intrinsic::minnum>(Op0, Op1); | |||
2184 | } | |||
2185 | ||||
2186 | template <typename Opnd0, typename Opnd1> | |||
2187 | inline typename m_Intrinsic_Ty<Opnd0, Opnd1>::Ty m_FMax(const Opnd0 &Op0, | |||
2188 | const Opnd1 &Op1) { | |||
2189 | return m_Intrinsic<Intrinsic::maxnum>(Op0, Op1); | |||
2190 | } | |||
2191 | ||||
2192 | template <typename Opnd0, typename Opnd1, typename Opnd2> | |||
2193 | inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2>::Ty | |||
2194 | m_FShl(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2) { | |||
2195 | return m_Intrinsic<Intrinsic::fshl>(Op0, Op1, Op2); | |||
2196 | } | |||
2197 | ||||
2198 | template <typename Opnd0, typename Opnd1, typename Opnd2> | |||
2199 | inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2>::Ty | |||
2200 | m_FShr(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2) { | |||
2201 | return m_Intrinsic<Intrinsic::fshr>(Op0, Op1, Op2); | |||
2202 | } | |||
2203 | ||||
2204 | template <typename Opnd0> | |||
2205 | inline typename m_Intrinsic_Ty<Opnd0>::Ty m_Sqrt(const Opnd0 &Op0) { | |||
2206 | return m_Intrinsic<Intrinsic::sqrt>(Op0); | |||
2207 | } | |||
2208 | ||||
2209 | template <typename Opnd0, typename Opnd1> | |||
2210 | inline typename m_Intrinsic_Ty<Opnd0, Opnd1>::Ty m_CopySign(const Opnd0 &Op0, | |||
2211 | const Opnd1 &Op1) { | |||
2212 | return m_Intrinsic<Intrinsic::copysign>(Op0, Op1); | |||
2213 | } | |||
2214 | ||||
2215 | template <typename Opnd0> | |||
2216 | inline typename m_Intrinsic_Ty<Opnd0>::Ty m_VecReverse(const Opnd0 &Op0) { | |||
2217 | return m_Intrinsic<Intrinsic::experimental_vector_reverse>(Op0); | |||
2218 | } | |||
2219 | ||||
2220 | //===----------------------------------------------------------------------===// | |||
2221 | // Matchers for two-operands operators with the operators in either order | |||
2222 | // | |||
2223 | ||||
2224 | /// Matches a BinaryOperator with LHS and RHS in either order. | |||
2225 | template <typename LHS, typename RHS> | |||
2226 | inline AnyBinaryOp_match<LHS, RHS, true> m_c_BinOp(const LHS &L, const RHS &R) { | |||
2227 | return AnyBinaryOp_match<LHS, RHS, true>(L, R); | |||
2228 | } | |||
2229 | ||||
2230 | /// Matches an ICmp with a predicate over LHS and RHS in either order. | |||
2231 | /// Swaps the predicate if operands are commuted. | |||
2232 | template <typename LHS, typename RHS> | |||
2233 | inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true> | |||
2234 | m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) { | |||
2235 | return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(Pred, L, | |||
2236 | R); | |||
2237 | } | |||
2238 | ||||
2239 | /// Matches a specific opcode with LHS and RHS in either order. | |||
2240 | template <typename LHS, typename RHS> | |||
2241 | inline SpecificBinaryOp_match<LHS, RHS, true> | |||
2242 | m_c_BinOp(unsigned Opcode, const LHS &L, const RHS &R) { | |||
2243 | return SpecificBinaryOp_match<LHS, RHS, true>(Opcode, L, R); | |||
2244 | } | |||
2245 | ||||
2246 | /// Matches a Add with LHS and RHS in either order. | |||
2247 | template <typename LHS, typename RHS> | |||
2248 | inline BinaryOp_match<LHS, RHS, Instruction::Add, true> m_c_Add(const LHS &L, | |||
2249 | const RHS &R) { | |||
2250 | return BinaryOp_match<LHS, RHS, Instruction::Add, true>(L, R); | |||
2251 | } | |||
2252 | ||||
2253 | /// Matches a Mul with LHS and RHS in either order. | |||
2254 | template <typename LHS, typename RHS> | |||
2255 | inline BinaryOp_match<LHS, RHS, Instruction::Mul, true> m_c_Mul(const LHS &L, | |||
2256 | const RHS &R) { | |||
2257 | return BinaryOp_match<LHS, RHS, Instruction::Mul, true>(L, R); | |||
2258 | } | |||
2259 | ||||
2260 | /// Matches an And with LHS and RHS in either order. | |||
2261 | template <typename LHS, typename RHS> | |||
2262 | inline BinaryOp_match<LHS, RHS, Instruction::And, true> m_c_And(const LHS &L, | |||
2263 | const RHS &R) { | |||
2264 | return BinaryOp_match<LHS, RHS, Instruction::And, true>(L, R); | |||
2265 | } | |||
2266 | ||||
2267 | /// Matches an Or with LHS and RHS in either order. | |||
2268 | template <typename LHS, typename RHS> | |||
2269 | inline BinaryOp_match<LHS, RHS, Instruction::Or, true> m_c_Or(const LHS &L, | |||
2270 | const RHS &R) { | |||
2271 | return BinaryOp_match<LHS, RHS, Instruction::Or, true>(L, R); | |||
2272 | } | |||
2273 | ||||
2274 | /// Matches an Xor with LHS and RHS in either order. | |||
2275 | template <typename LHS, typename RHS> | |||
2276 | inline BinaryOp_match<LHS, RHS, Instruction::Xor, true> m_c_Xor(const LHS &L, | |||
2277 | const RHS &R) { | |||
2278 | return BinaryOp_match<LHS, RHS, Instruction::Xor, true>(L, R); | |||
2279 | } | |||
2280 | ||||
2281 | /// Matches a 'Neg' as 'sub 0, V'. | |||
2282 | template <typename ValTy> | |||
2283 | inline BinaryOp_match<cst_pred_ty<is_zero_int>, ValTy, Instruction::Sub> | |||
2284 | m_Neg(const ValTy &V) { | |||
2285 | return m_Sub(m_ZeroInt(), V); | |||
2286 | } | |||
2287 | ||||
2288 | /// Matches a 'Neg' as 'sub nsw 0, V'. | |||
2289 | template <typename ValTy> | |||
2290 | inline OverflowingBinaryOp_match<cst_pred_ty<is_zero_int>, ValTy, | |||
2291 | Instruction::Sub, | |||
2292 | OverflowingBinaryOperator::NoSignedWrap> | |||
2293 | m_NSWNeg(const ValTy &V) { | |||
2294 | return m_NSWSub(m_ZeroInt(), V); | |||
2295 | } | |||
2296 | ||||
2297 | /// Matches a 'Not' as 'xor V, -1' or 'xor -1, V'. | |||
2298 | /// NOTE: we first match the 'Not' (by matching '-1'), | |||
2299 | /// and only then match the inner matcher! | |||
2300 | template <typename ValTy> | |||
2301 | inline BinaryOp_match<cst_pred_ty<is_all_ones>, ValTy, Instruction::Xor, true> | |||
2302 | m_Not(const ValTy &V) { | |||
2303 | return m_c_Xor(m_AllOnes(), V); | |||
2304 | } | |||
2305 | ||||
2306 | template <typename ValTy> struct NotForbidUndef_match { | |||
2307 | ValTy Val; | |||
2308 | NotForbidUndef_match(const ValTy &V) : Val(V) {} | |||
2309 | ||||
2310 | template <typename OpTy> bool match(OpTy *V) { | |||
2311 | // We do not use m_c_Xor because that could match an arbitrary APInt that is | |||
2312 | // not -1 as C and then fail to match the other operand if it is -1. | |||
2313 | // This code should still work even when both operands are constants. | |||
2314 | Value *X; | |||
2315 | const APInt *C; | |||
2316 | if (m_Xor(m_Value(X), m_APIntForbidUndef(C)).match(V) && C->isAllOnes()) | |||
2317 | return Val.match(X); | |||
2318 | if (m_Xor(m_APIntForbidUndef(C), m_Value(X)).match(V) && C->isAllOnes()) | |||
2319 | return Val.match(X); | |||
2320 | return false; | |||
2321 | } | |||
2322 | }; | |||
2323 | ||||
2324 | /// Matches a bitwise 'not' as 'xor V, -1' or 'xor -1, V'. For vectors, the | |||
2325 | /// constant value must be composed of only -1 scalar elements. | |||
2326 | template <typename ValTy> | |||
2327 | inline NotForbidUndef_match<ValTy> m_NotForbidUndef(const ValTy &V) { | |||
2328 | return NotForbidUndef_match<ValTy>(V); | |||
2329 | } | |||
2330 | ||||
2331 | /// Matches an SMin with LHS and RHS in either order. | |||
2332 | template <typename LHS, typename RHS> | |||
2333 | inline MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true> | |||
2334 | m_c_SMin(const LHS &L, const RHS &R) { | |||
2335 | return MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true>(L, R); | |||
2336 | } | |||
2337 | /// Matches an SMax with LHS and RHS in either order. | |||
2338 | template <typename LHS, typename RHS> | |||
2339 | inline MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty, true> | |||
2340 | m_c_SMax(const LHS &L, const RHS &R) { | |||
2341 | return MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty, true>(L, R); | |||
2342 | } | |||
2343 | /// Matches a UMin with LHS and RHS in either order. | |||
2344 | template <typename LHS, typename RHS> | |||
2345 | inline MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty, true> | |||
2346 | m_c_UMin(const LHS &L, const RHS &R) { | |||
2347 | return MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty, true>(L, R); | |||
2348 | } | |||
2349 | /// Matches a UMax with LHS and RHS in either order. | |||
2350 | template <typename LHS, typename RHS> | |||
2351 | inline MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty, true> | |||
2352 | m_c_UMax(const LHS &L, const RHS &R) { | |||
2353 | return MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty, true>(L, R); | |||
2354 | } | |||
2355 | ||||
2356 | template <typename LHS, typename RHS> | |||
2357 | inline match_combine_or< | |||
2358 | match_combine_or<MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty, true>, | |||
2359 | MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true>>, | |||
2360 | match_combine_or<MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty, true>, | |||
2361 | MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty, true>>> | |||
2362 | m_c_MaxOrMin(const LHS &L, const RHS &R) { | |||
2363 | return m_CombineOr(m_CombineOr(m_c_SMax(L, R), m_c_SMin(L, R)), | |||
2364 | m_CombineOr(m_c_UMax(L, R), m_c_UMin(L, R))); | |||
2365 | } | |||
2366 | ||||
2367 | template <Intrinsic::ID IntrID, typename T0, typename T1> | |||
2368 | inline match_combine_or<typename m_Intrinsic_Ty<T0, T1>::Ty, | |||
2369 | typename m_Intrinsic_Ty<T1, T0>::Ty> | |||
2370 | m_c_Intrinsic(const T0 &Op0, const T1 &Op1) { | |||
2371 | return m_CombineOr(m_Intrinsic<IntrID>(Op0, Op1), | |||
2372 | m_Intrinsic<IntrID>(Op1, Op0)); | |||
2373 | } | |||
2374 | ||||
2375 | /// Matches FAdd with LHS and RHS in either order. | |||
2376 | template <typename LHS, typename RHS> | |||
2377 | inline BinaryOp_match<LHS, RHS, Instruction::FAdd, true> | |||
2378 | m_c_FAdd(const LHS &L, const RHS &R) { | |||
2379 | return BinaryOp_match<LHS, RHS, Instruction::FAdd, true>(L, R); | |||
2380 | } | |||
2381 | ||||
2382 | /// Matches FMul with LHS and RHS in either order. | |||
2383 | template <typename LHS, typename RHS> | |||
2384 | inline BinaryOp_match<LHS, RHS, Instruction::FMul, true> | |||
2385 | m_c_FMul(const LHS &L, const RHS &R) { | |||
2386 | return BinaryOp_match<LHS, RHS, Instruction::FMul, true>(L, R); | |||
2387 | } | |||
2388 | ||||
2389 | template <typename Opnd_t> struct Signum_match { | |||
2390 | Opnd_t Val; | |||
2391 | Signum_match(const Opnd_t &V) : Val(V) {} | |||
2392 | ||||
2393 | template <typename OpTy> bool match(OpTy *V) { | |||
2394 | unsigned TypeSize = V->getType()->getScalarSizeInBits(); | |||
2395 | if (TypeSize == 0) | |||
2396 | return false; | |||
2397 | ||||
2398 | unsigned ShiftWidth = TypeSize - 1; | |||
2399 | Value *OpL = nullptr, *OpR = nullptr; | |||
2400 | ||||
2401 | // This is the representation of signum we match: | |||
2402 | // | |||
2403 | // signum(x) == (x >> 63) | (-x >>u 63) | |||
2404 | // | |||
2405 | // An i1 value is its own signum, so it's correct to match | |||
2406 | // | |||
2407 | // signum(x) == (x >> 0) | (-x >>u 0) | |||
2408 | // | |||
2409 | // for i1 values. | |||
2410 | ||||
2411 | auto LHS = m_AShr(m_Value(OpL), m_SpecificInt(ShiftWidth)); | |||
2412 | auto RHS = m_LShr(m_Neg(m_Value(OpR)), m_SpecificInt(ShiftWidth)); | |||
2413 | auto Signum = m_Or(LHS, RHS); | |||
2414 | ||||
2415 | return Signum.match(V) && OpL == OpR && Val.match(OpL); | |||
2416 | } | |||
2417 | }; | |||
2418 | ||||
2419 | /// Matches a signum pattern. | |||
2420 | /// | |||
2421 | /// signum(x) = | |||
2422 | /// x > 0 -> 1 | |||
2423 | /// x == 0 -> 0 | |||
2424 | /// x < 0 -> -1 | |||
2425 | template <typename Val_t> inline Signum_match<Val_t> m_Signum(const Val_t &V) { | |||
2426 | return Signum_match<Val_t>(V); | |||
2427 | } | |||
2428 | ||||
2429 | template <int Ind, typename Opnd_t> struct ExtractValue_match { | |||
2430 | Opnd_t Val; | |||
2431 | ExtractValue_match(const Opnd_t &V) : Val(V) {} | |||
2432 | ||||
2433 | template <typename OpTy> bool match(OpTy *V) { | |||
2434 | if (auto *I = dyn_cast<ExtractValueInst>(V)) { | |||
2435 | // If Ind is -1, don't inspect indices | |||
2436 | if (Ind != -1 && | |||
2437 | !(I->getNumIndices() == 1 && I->getIndices()[0] == (unsigned)Ind)) | |||
2438 | return false; | |||
2439 | return Val.match(I->getAggregateOperand()); | |||
2440 | } | |||
2441 | return false; | |||
2442 | } | |||
2443 | }; | |||
2444 | ||||
2445 | /// Match a single index ExtractValue instruction. | |||
2446 | /// For example m_ExtractValue<1>(...) | |||
2447 | template <int Ind, typename Val_t> | |||
2448 | inline ExtractValue_match<Ind, Val_t> m_ExtractValue(const Val_t &V) { | |||
2449 | return ExtractValue_match<Ind, Val_t>(V); | |||
2450 | } | |||
2451 | ||||
2452 | /// Match an ExtractValue instruction with any index. | |||
2453 | /// For example m_ExtractValue(...) | |||
2454 | template <typename Val_t> | |||
2455 | inline ExtractValue_match<-1, Val_t> m_ExtractValue(const Val_t &V) { | |||
2456 | return ExtractValue_match<-1, Val_t>(V); | |||
2457 | } | |||
2458 | ||||
2459 | /// Matcher for a single index InsertValue instruction. | |||
2460 | template <int Ind, typename T0, typename T1> struct InsertValue_match { | |||
2461 | T0 Op0; | |||
2462 | T1 Op1; | |||
2463 | ||||
2464 | InsertValue_match(const T0 &Op0, const T1 &Op1) : Op0(Op0), Op1(Op1) {} | |||
2465 | ||||
2466 | template <typename OpTy> bool match(OpTy *V) { | |||
2467 | if (auto *I = dyn_cast<InsertValueInst>(V)) { | |||
2468 | return Op0.match(I->getOperand(0)) && Op1.match(I->getOperand(1)) && | |||
2469 | I->getNumIndices() == 1 && Ind == I->getIndices()[0]; | |||
2470 | } | |||
2471 | return false; | |||
2472 | } | |||
2473 | }; | |||
2474 | ||||
2475 | /// Matches a single index InsertValue instruction. | |||
2476 | template <int Ind, typename Val_t, typename Elt_t> | |||
2477 | inline InsertValue_match<Ind, Val_t, Elt_t> m_InsertValue(const Val_t &Val, | |||
2478 | const Elt_t &Elt) { | |||
2479 | return InsertValue_match<Ind, Val_t, Elt_t>(Val, Elt); | |||
2480 | } | |||
2481 | ||||
2482 | /// Matches patterns for `vscale`. This can either be a call to `llvm.vscale` or | |||
2483 | /// the constant expression | |||
2484 | /// `ptrtoint(gep <vscale x 1 x i8>, <vscale x 1 x i8>* null, i32 1>` | |||
2485 | /// under the right conditions determined by DataLayout. | |||
2486 | struct VScaleVal_match { | |||
2487 | template <typename ITy> bool match(ITy *V) { | |||
2488 | if (m_Intrinsic<Intrinsic::vscale>().match(V)) | |||
2489 | return true; | |||
2490 | ||||
2491 | Value *Ptr; | |||
2492 | if (m_PtrToInt(m_Value(Ptr)).match(V)) { | |||
2493 | if (auto *GEP = dyn_cast<GEPOperator>(Ptr)) { | |||
2494 | auto *DerefTy = | |||
2495 | dyn_cast<ScalableVectorType>(GEP->getSourceElementType()); | |||
2496 | if (GEP->getNumIndices() == 1 && DerefTy && | |||
2497 | DerefTy->getElementType()->isIntegerTy(8) && | |||
2498 | m_Zero().match(GEP->getPointerOperand()) && | |||
2499 | m_SpecificInt(1).match(GEP->idx_begin()->get())) | |||
2500 | return true; | |||
2501 | } | |||
2502 | } | |||
2503 | ||||
2504 | return false; | |||
2505 | } | |||
2506 | }; | |||
2507 | ||||
2508 | inline VScaleVal_match m_VScale() { | |||
2509 | return VScaleVal_match(); | |||
2510 | } | |||
2511 | ||||
2512 | template <typename LHS, typename RHS, unsigned Opcode, bool Commutable = false> | |||
2513 | struct LogicalOp_match { | |||
2514 | LHS L; | |||
2515 | RHS R; | |||
2516 | ||||
2517 | LogicalOp_match(const LHS &L, const RHS &R) : L(L), R(R) {} | |||
2518 | ||||
2519 | template <typename T> bool match(T *V) { | |||
2520 | auto *I = dyn_cast<Instruction>(V); | |||
2521 | if (!I || !I->getType()->isIntOrIntVectorTy(1)) | |||
2522 | return false; | |||
2523 | ||||
2524 | if (I->getOpcode() == Opcode) { | |||
2525 | auto *Op0 = I->getOperand(0); | |||
2526 | auto *Op1 = I->getOperand(1); | |||
2527 | return (L.match(Op0) && R.match(Op1)) || | |||
2528 | (Commutable && L.match(Op1) && R.match(Op0)); | |||
2529 | } | |||
2530 | ||||
2531 | if (auto *Select = dyn_cast<SelectInst>(I)) { | |||
2532 | auto *Cond = Select->getCondition(); | |||
2533 | auto *TVal = Select->getTrueValue(); | |||
2534 | auto *FVal = Select->getFalseValue(); | |||
2535 | ||||
2536 | // Don't match a scalar select of bool vectors. | |||
2537 | // Transforms expect a single type for operands if this matches. | |||
2538 | if (Cond->getType() != Select->getType()) | |||
2539 | return false; | |||
2540 | ||||
2541 | if (Opcode == Instruction::And) { | |||
2542 | auto *C = dyn_cast<Constant>(FVal); | |||
2543 | if (C && C->isNullValue()) | |||
2544 | return (L.match(Cond) && R.match(TVal)) || | |||
2545 | (Commutable && L.match(TVal) && R.match(Cond)); | |||
2546 | } else { | |||
2547 | assert(Opcode == Instruction::Or)(static_cast <bool> (Opcode == Instruction::Or) ? void ( 0) : __assert_fail ("Opcode == Instruction::Or", "llvm/include/llvm/IR/PatternMatch.h" , 2547, __extension__ __PRETTY_FUNCTION__)); | |||
2548 | auto *C = dyn_cast<Constant>(TVal); | |||
2549 | if (C && C->isOneValue()) | |||
2550 | return (L.match(Cond) && R.match(FVal)) || | |||
2551 | (Commutable && L.match(FVal) && R.match(Cond)); | |||
2552 | } | |||
2553 | } | |||
2554 | ||||
2555 | return false; | |||
2556 | } | |||
2557 | }; | |||
2558 | ||||
2559 | /// Matches L && R either in the form of L & R or L ? R : false. | |||
2560 | /// Note that the latter form is poison-blocking. | |||
2561 | template <typename LHS, typename RHS> | |||
2562 | inline LogicalOp_match<LHS, RHS, Instruction::And> m_LogicalAnd(const LHS &L, | |||
2563 | const RHS &R) { | |||
2564 | return LogicalOp_match<LHS, RHS, Instruction::And>(L, R); | |||
2565 | } | |||
2566 | ||||
2567 | /// Matches L && R where L and R are arbitrary values. | |||
2568 | inline auto m_LogicalAnd() { return m_LogicalAnd(m_Value(), m_Value()); } | |||
2569 | ||||
2570 | /// Matches L && R with LHS and RHS in either order. | |||
2571 | template <typename LHS, typename RHS> | |||
2572 | inline LogicalOp_match<LHS, RHS, Instruction::And, true> | |||
2573 | m_c_LogicalAnd(const LHS &L, const RHS &R) { | |||
2574 | return LogicalOp_match<LHS, RHS, Instruction::And, true>(L, R); | |||
2575 | } | |||
2576 | ||||
2577 | /// Matches L || R either in the form of L | R or L ? true : R. | |||
2578 | /// Note that the latter form is poison-blocking. | |||
2579 | template <typename LHS, typename RHS> | |||
2580 | inline LogicalOp_match<LHS, RHS, Instruction::Or> m_LogicalOr(const LHS &L, | |||
2581 | const RHS &R) { | |||
2582 | return LogicalOp_match<LHS, RHS, Instruction::Or>(L, R); | |||
2583 | } | |||
2584 | ||||
2585 | /// Matches L || R where L and R are arbitrary values. | |||
2586 | inline auto m_LogicalOr() { return m_LogicalOr(m_Value(), m_Value()); } | |||
2587 | ||||
2588 | /// Matches L || R with LHS and RHS in either order. | |||
2589 | template <typename LHS, typename RHS> | |||
2590 | inline LogicalOp_match<LHS, RHS, Instruction::Or, true> | |||
2591 | m_c_LogicalOr(const LHS &L, const RHS &R) { | |||
2592 | return LogicalOp_match<LHS, RHS, Instruction::Or, true>(L, R); | |||
2593 | } | |||
2594 | ||||
2595 | /// Matches either L && R or L || R, | |||
2596 | /// either one being in the either binary or logical form. | |||
2597 | /// Note that the latter form is poison-blocking. | |||
2598 | template <typename LHS, typename RHS, bool Commutable = false> | |||
2599 | inline auto m_LogicalOp(const LHS &L, const RHS &R) { | |||
2600 | return m_CombineOr( | |||
2601 | LogicalOp_match<LHS, RHS, Instruction::And, Commutable>(L, R), | |||
2602 | LogicalOp_match<LHS, RHS, Instruction::Or, Commutable>(L, R)); | |||
2603 | } | |||
2604 | ||||
2605 | /// Matches either L && R or L || R where L and R are arbitrary values. | |||
2606 | inline auto m_LogicalOp() { return m_LogicalOp(m_Value(), m_Value()); } | |||
2607 | ||||
2608 | /// Matches either L && R or L || R with LHS and RHS in either order. | |||
2609 | template <typename LHS, typename RHS> | |||
2610 | inline auto m_c_LogicalOp(const LHS &L, const RHS &R) { | |||
2611 | return m_LogicalOp<LHS, RHS, /*Commutable=*/true>(L, R); | |||
2612 | } | |||
2613 | ||||
2614 | } // end namespace PatternMatch | |||
2615 | } // end namespace llvm | |||
2616 | ||||
2617 | #endif // LLVM_IR_PATTERNMATCH_H |