Bug Summary

File: build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Warning: line 2424, column 21
The result of the left shift is undefined due to shifting by '4294967295', which is greater or equal to the width of type 'int'
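
For context, a left shift in C++ is only well defined when the shift amount is less than the bit width of the promoted left operand, so shifting an int by 4294967295 (a count that has wrapped around to ~0U) is undefined behaviour. The sketch below is a minimal, hypothetical illustration of this diagnostic class only; it is not the code at line 2424, which lies outside the excerpt shown in this report, and the function names are invented for the example.

// Hypothetical sketch of the warning class, not the flagged LLVM code.
unsigned buildMask(unsigned NumElts) {
  // If NumElts is 0, Amt wraps to 4294967295 (UINT_MAX) and the shift below
  // is undefined, because the shift count is >= the 32-bit width of int.
  unsigned Amt = NumElts - 1;
  return 1 << Amt;
}

unsigned buildMaskSafe(unsigned NumElts) {
  // Guarding the shift amount keeps every shift count in the range 0..31,
  // which is well defined for a 32-bit unsigned operand.
  if (NumElts == 0 || NumElts > 32)
    return 0;
  return 1u << (NumElts - 1);
}

The guard is the usual remedy the analyzer expects: prove (or check) that the computed count cannot wrap before it reaches the shift.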

Annotated Source Code


clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name AArch64TargetTransformInfo.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I lib/Target/AArch64 -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/lib/Target/AArch64 -I include -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -ferror-limit 19 -fvisibility=hidden -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o 
/tmp/scan-build-2022-09-04-125545-48738-1 -x c++ /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

1//===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "AArch64TargetTransformInfo.h"
10#include "AArch64ExpandImm.h"
11#include "AArch64PerfectShuffle.h"
12#include "MCTargetDesc/AArch64AddressingModes.h"
13#include "llvm/Analysis/IVDescriptors.h"
14#include "llvm/Analysis/LoopInfo.h"
15#include "llvm/Analysis/TargetTransformInfo.h"
16#include "llvm/CodeGen/BasicTTIImpl.h"
17#include "llvm/CodeGen/CostTable.h"
18#include "llvm/CodeGen/TargetLowering.h"
19#include "llvm/IR/IntrinsicInst.h"
20#include "llvm/IR/Intrinsics.h"
21#include "llvm/IR/IntrinsicsAArch64.h"
22#include "llvm/IR/PatternMatch.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Transforms/InstCombine/InstCombiner.h"
25#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
26#include <algorithm>
27using namespace llvm;
28using namespace llvm::PatternMatch;
29
30#define DEBUG_TYPE "aarch64tti"
31
32static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
33 cl::init(true), cl::Hidden);
34
35static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
36 cl::Hidden);
37
38static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
39 cl::init(10), cl::Hidden);
40
41class TailFoldingKind {
42private:
43 uint8_t Bits = 0; // Currently defaults to disabled.
44
45public:
46 enum TailFoldingOpts {
47 TFDisabled = 0x0,
48 TFReductions = 0x01,
49 TFRecurrences = 0x02,
50 TFSimple = 0x80,
51 TFAll = TFReductions | TFRecurrences | TFSimple
52 };
53
54 void operator=(const std::string &Val) {
55 if (Val.empty())
56 return;
57 SmallVector<StringRef, 6> TailFoldTypes;
58 StringRef(Val).split(TailFoldTypes, '+', -1, false);
59 for (auto TailFoldType : TailFoldTypes) {
60 if (TailFoldType == "disabled")
61 Bits = 0;
62 else if (TailFoldType == "all")
63 Bits = TFAll;
64 else if (TailFoldType == "default")
65 Bits = 0; // Currently defaults to never tail-folding.
66 else if (TailFoldType == "simple")
67 add(TFSimple);
68 else if (TailFoldType == "reductions")
69 add(TFReductions);
70 else if (TailFoldType == "recurrences")
71 add(TFRecurrences);
72 else if (TailFoldType == "noreductions")
73 remove(TFReductions);
74 else if (TailFoldType == "norecurrences")
75 remove(TFRecurrences);
76 else {
77 errs()
78 << "invalid argument " << TailFoldType.str()
79 << " to -sve-tail-folding=; each element must be one of: disabled, "
80 "all, default, simple, reductions, noreductions, recurrences, "
81 "norecurrences\n";
82 }
83 }
84 }
85
86 operator uint8_t() const { return Bits; }
87
88 void add(uint8_t Flag) { Bits |= Flag; }
89 void remove(uint8_t Flag) { Bits &= ~Flag; }
90};
91
92TailFoldingKind TailFoldingKindLoc;
93
94cl::opt<TailFoldingKind, true, cl::parser<std::string>> SVETailFolding(
95 "sve-tail-folding",
96 cl::desc(
97 "Control the use of vectorisation using tail-folding for SVE:"
98 "\ndisabled No loop types will vectorize using tail-folding"
99 "\ndefault Uses the default tail-folding settings for the target "
100 "CPU"
101 "\nall All legal loop types will vectorize using tail-folding"
102 "\nsimple Use tail-folding for simple loops (not reductions or "
103 "recurrences)"
104 "\nreductions Use tail-folding for loops containing reductions"
105 "\nrecurrences Use tail-folding for loops containing fixed order "
106 "recurrences"),
107 cl::location(TailFoldingKindLoc));
108
109bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
110 const Function *Callee) const {
111 const TargetMachine &TM = getTLI()->getTargetMachine();
112
113 const FeatureBitset &CallerBits =
114 TM.getSubtargetImpl(*Caller)->getFeatureBits();
115 const FeatureBitset &CalleeBits =
116 TM.getSubtargetImpl(*Callee)->getFeatureBits();
117
118 // Inline a callee if its target-features are a subset of the callers
119 // target-features.
120 return (CallerBits & CalleeBits) == CalleeBits;
121}
122
123bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
124 TargetTransformInfo::RegisterKind K) const {
125 assert(K != TargetTransformInfo::RGK_Scalar);
126 return K == TargetTransformInfo::RGK_FixedWidthVector;
127}
128
129/// Calculate the cost of materializing a 64-bit value. This helper
130/// method might only calculate a fraction of a larger immediate. Therefore it
131/// is valid to return a cost of ZERO.
132InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
133 // Check if the immediate can be encoded within an instruction.
134 if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
135 return 0;
136
137 if (Val < 0)
138 Val = ~Val;
139
140 // Calculate how many moves we will need to materialize this constant.
141 SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
142 AArch64_IMM::expandMOVImm(Val, 64, Insn);
143 return Insn.size();
144}
145
146/// Calculate the cost of materializing the given constant.
147InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
148 TTI::TargetCostKind CostKind) {
149 assert(Ty->isIntegerTy());
150
151 unsigned BitSize = Ty->getPrimitiveSizeInBits();
152 if (BitSize == 0)
153 return ~0U;
154
155 // Sign-extend all constants to a multiple of 64-bit.
156 APInt ImmVal = Imm;
157 if (BitSize & 0x3f)
158 ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
159
160 // Split the constant into 64-bit chunks and calculate the cost for each
161 // chunk.
162 InstructionCost Cost = 0;
163 for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
164 APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
165 int64_t Val = Tmp.getSExtValue();
166 Cost += getIntImmCost(Val);
167 }
168 // We need at least one instruction to materialize the constant.
169 return std::max<InstructionCost>(1, Cost);
170}
171
172InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
173 const APInt &Imm, Type *Ty,
174 TTI::TargetCostKind CostKind,
175 Instruction *Inst) {
176 assert(Ty->isIntegerTy());
177
178 unsigned BitSize = Ty->getPrimitiveSizeInBits();
179 // There is no cost model for constants with a bit size of 0. Return TCC_Free
180 // here, so that constant hoisting will ignore this constant.
181 if (BitSize == 0)
182 return TTI::TCC_Free;
183
184 unsigned ImmIdx = ~0U;
185 switch (Opcode) {
186 default:
187 return TTI::TCC_Free;
188 case Instruction::GetElementPtr:
189 // Always hoist the base address of a GetElementPtr.
190 if (Idx == 0)
191 return 2 * TTI::TCC_Basic;
192 return TTI::TCC_Free;
193 case Instruction::Store:
194 ImmIdx = 0;
195 break;
196 case Instruction::Add:
197 case Instruction::Sub:
198 case Instruction::Mul:
199 case Instruction::UDiv:
200 case Instruction::SDiv:
201 case Instruction::URem:
202 case Instruction::SRem:
203 case Instruction::And:
204 case Instruction::Or:
205 case Instruction::Xor:
206 case Instruction::ICmp:
207 ImmIdx = 1;
208 break;
209 // Always return TCC_Free for the shift value of a shift instruction.
210 case Instruction::Shl:
211 case Instruction::LShr:
212 case Instruction::AShr:
213 if (Idx == 1)
214 return TTI::TCC_Free;
215 break;
216 case Instruction::Trunc:
217 case Instruction::ZExt:
218 case Instruction::SExt:
219 case Instruction::IntToPtr:
220 case Instruction::PtrToInt:
221 case Instruction::BitCast:
222 case Instruction::PHI:
223 case Instruction::Call:
224 case Instruction::Select:
225 case Instruction::Ret:
226 case Instruction::Load:
227 break;
228 }
229
230 if (Idx == ImmIdx) {
231 int NumConstants = (BitSize + 63) / 64;
232 InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
233 return (Cost <= NumConstants * TTI::TCC_Basic)
234 ? static_cast<int>(TTI::TCC_Free)
235 : Cost;
236 }
237 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
238}
239
240InstructionCost
241AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
242 const APInt &Imm, Type *Ty,
243 TTI::TargetCostKind CostKind) {
244 assert(Ty->isIntegerTy());
245
246 unsigned BitSize = Ty->getPrimitiveSizeInBits();
247 // There is no cost model for constants with a bit size of 0. Return TCC_Free
248 // here, so that constant hoisting will ignore this constant.
249 if (BitSize == 0)
250 return TTI::TCC_Free;
251
252 // Most (all?) AArch64 intrinsics do not support folding immediates into the
253 // selected instruction, so we compute the materialization cost for the
254 // immediate directly.
255 if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
256 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
257
258 switch (IID) {
259 default:
260 return TTI::TCC_Free;
261 case Intrinsic::sadd_with_overflow:
262 case Intrinsic::uadd_with_overflow:
263 case Intrinsic::ssub_with_overflow:
264 case Intrinsic::usub_with_overflow:
265 case Intrinsic::smul_with_overflow:
266 case Intrinsic::umul_with_overflow:
267 if (Idx == 1) {
268 int NumConstants = (BitSize + 63) / 64;
269 InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
270 return (Cost <= NumConstants * TTI::TCC_Basic)
271 ? static_cast<int>(TTI::TCC_Free)
272 : Cost;
273 }
274 break;
275 case Intrinsic::experimental_stackmap:
276 if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
277 return TTI::TCC_Free;
278 break;
279 case Intrinsic::experimental_patchpoint_void:
280 case Intrinsic::experimental_patchpoint_i64:
281 if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
282 return TTI::TCC_Free;
283 break;
284 case Intrinsic::experimental_gc_statepoint:
285 if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
286 return TTI::TCC_Free;
287 break;
288 }
289 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
290}
291
292TargetTransformInfo::PopcntSupportKind
293AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
294 assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
295 if (TyWidth == 32 || TyWidth == 64)
296 return TTI::PSK_FastHardware;
297 // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
298 return TTI::PSK_Software;
299}
300
301InstructionCost
302AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
303 TTI::TargetCostKind CostKind) {
304 auto *RetTy = ICA.getReturnType();
305 switch (ICA.getID()) {
306 case Intrinsic::umin:
307 case Intrinsic::umax:
308 case Intrinsic::smin:
309 case Intrinsic::smax: {
310 static const auto ValidMinMaxTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
311 MVT::v8i16, MVT::v2i32, MVT::v4i32};
312 auto LT = getTypeLegalizationCost(RetTy);
313 // v2i64 types get converted to cmp+bif hence the cost of 2
314 if (LT.second == MVT::v2i64)
315 return LT.first * 2;
316 if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
317 return LT.first;
318 break;
319 }
320 case Intrinsic::sadd_sat:
321 case Intrinsic::ssub_sat:
322 case Intrinsic::uadd_sat:
323 case Intrinsic::usub_sat: {
324 static const auto ValidSatTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
325 MVT::v8i16, MVT::v2i32, MVT::v4i32,
326 MVT::v2i64};
327 auto LT = getTypeLegalizationCost(RetTy);
328 // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
329 // need to extend the type, as it uses shr(qadd(shl, shl)).
330 unsigned Instrs =
331 LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
332 if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
333 return LT.first * Instrs;
334 break;
335 }
336 case Intrinsic::abs: {
337 static const auto ValidAbsTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
338 MVT::v8i16, MVT::v2i32, MVT::v4i32,
339 MVT::v2i64};
340 auto LT = getTypeLegalizationCost(RetTy);
341 if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
342 return LT.first;
343 break;
344 }
345 case Intrinsic::experimental_stepvector: {
346 InstructionCost Cost = 1; // Cost of the `index' instruction
347 auto LT = getTypeLegalizationCost(RetTy);
348 // Legalisation of illegal vectors involves an `index' instruction plus
349 // (LT.first - 1) vector adds.
350 if (LT.first > 1) {
351 Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
352 InstructionCost AddCost =
353 getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
354 Cost += AddCost * (LT.first - 1);
355 }
356 return Cost;
357 }
358 case Intrinsic::bitreverse: {
359 static const CostTblEntry BitreverseTbl[] = {
360 {Intrinsic::bitreverse, MVT::i32, 1},
361 {Intrinsic::bitreverse, MVT::i64, 1},
362 {Intrinsic::bitreverse, MVT::v8i8, 1},
363 {Intrinsic::bitreverse, MVT::v16i8, 1},
364 {Intrinsic::bitreverse, MVT::v4i16, 2},
365 {Intrinsic::bitreverse, MVT::v8i16, 2},
366 {Intrinsic::bitreverse, MVT::v2i32, 2},
367 {Intrinsic::bitreverse, MVT::v4i32, 2},
368 {Intrinsic::bitreverse, MVT::v1i64, 2},
369 {Intrinsic::bitreverse, MVT::v2i64, 2},
370 };
371 const auto LegalisationCost = getTypeLegalizationCost(RetTy);
372 const auto *Entry =
373 CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
374 if (Entry) {
375 // Cost Model is using the legal type(i32) that i8 and i16 will be
376 // converted to +1 so that we match the actual lowering cost
377 if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
378 TLI->getValueType(DL, RetTy, true) == MVT::i16)
379 return LegalisationCost.first * Entry->Cost + 1;
380
381 return LegalisationCost.first * Entry->Cost;
382 }
383 break;
384 }
385 case Intrinsic::ctpop: {
386 if (!ST->hasNEON()) {
387 // 32-bit or 64-bit ctpop without NEON is 12 instructions.
388 return getTypeLegalizationCost(RetTy).first * 12;
389 }
390 static const CostTblEntry CtpopCostTbl[] = {
391 {ISD::CTPOP, MVT::v2i64, 4},
392 {ISD::CTPOP, MVT::v4i32, 3},
393 {ISD::CTPOP, MVT::v8i16, 2},
394 {ISD::CTPOP, MVT::v16i8, 1},
395 {ISD::CTPOP, MVT::i64, 4},
396 {ISD::CTPOP, MVT::v2i32, 3},
397 {ISD::CTPOP, MVT::v4i16, 2},
398 {ISD::CTPOP, MVT::v8i8, 1},
399 {ISD::CTPOP, MVT::i32, 5},
400 };
401 auto LT = getTypeLegalizationCost(RetTy);
402 MVT MTy = LT.second;
403 if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
404 // Extra cost of +1 when illegal vector types are legalized by promoting
405 // the integer type.
406 int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
407 RetTy->getScalarSizeInBits()
408 ? 1
409 : 0;
410 return LT.first * Entry->Cost + ExtraCost;
411 }
412 break;
413 }
414 case Intrinsic::sadd_with_overflow:
415 case Intrinsic::uadd_with_overflow:
416 case Intrinsic::ssub_with_overflow:
417 case Intrinsic::usub_with_overflow:
418 case Intrinsic::smul_with_overflow:
419 case Intrinsic::umul_with_overflow: {
420 static const CostTblEntry WithOverflowCostTbl[] = {
421 {Intrinsic::sadd_with_overflow, MVT::i8, 3},
422 {Intrinsic::uadd_with_overflow, MVT::i8, 3},
423 {Intrinsic::sadd_with_overflow, MVT::i16, 3},
424 {Intrinsic::uadd_with_overflow, MVT::i16, 3},
425 {Intrinsic::sadd_with_overflow, MVT::i32, 1},
426 {Intrinsic::uadd_with_overflow, MVT::i32, 1},
427 {Intrinsic::sadd_with_overflow, MVT::i64, 1},
428 {Intrinsic::uadd_with_overflow, MVT::i64, 1},
429 {Intrinsic::ssub_with_overflow, MVT::i8, 3},
430 {Intrinsic::usub_with_overflow, MVT::i8, 3},
431 {Intrinsic::ssub_with_overflow, MVT::i16, 3},
432 {Intrinsic::usub_with_overflow, MVT::i16, 3},
433 {Intrinsic::ssub_with_overflow, MVT::i32, 1},
434 {Intrinsic::usub_with_overflow, MVT::i32, 1},
435 {Intrinsic::ssub_with_overflow, MVT::i64, 1},
436 {Intrinsic::usub_with_overflow, MVT::i64, 1},
437 {Intrinsic::smul_with_overflow, MVT::i8, 5},
438 {Intrinsic::umul_with_overflow, MVT::i8, 4},
439 {Intrinsic::smul_with_overflow, MVT::i16, 5},
440 {Intrinsic::umul_with_overflow, MVT::i16, 4},
441 {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
442 {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
443 {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
444 {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
445 };
446 EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
447 if (MTy.isSimple())
448 if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
449 MTy.getSimpleVT()))
450 return Entry->Cost;
451 break;
452 }
453 case Intrinsic::fptosi_sat:
454 case Intrinsic::fptoui_sat: {
455 if (ICA.getArgTypes().empty())
456 break;
457 bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
458 auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
459 EVT MTy = TLI->getValueType(DL, RetTy);
460 // Check for the legal types, which are where the size of the input and the
461 // output are the same, or we are using cvt f64->i32 or f32->i64.
462 if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
463 LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
464 LT.second == MVT::v2f64) &&
465 (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
466 (LT.second == MVT::f64 && MTy == MVT::i32) ||
467 (LT.second == MVT::f32 && MTy == MVT::i64)))
468 return LT.first;
469 // Similarly for fp16 sizes
470 if (ST->hasFullFP16() &&
471 ((LT.second == MVT::f16 && MTy == MVT::i32) ||
472 ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
473 (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits()))))
474 return LT.first;
475
476 // Otherwise we use a legal convert followed by a min+max
477 if ((LT.second.getScalarType() == MVT::f32 ||
478 LT.second.getScalarType() == MVT::f64 ||
479 (ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) &&
480 LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
481 Type *LegalTy =
482 Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
483 if (LT.second.isVector())
484 LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
485 InstructionCost Cost = 1;
486 IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
487 LegalTy, {LegalTy, LegalTy});
488 Cost += getIntrinsicInstrCost(Attrs1, CostKind);
489 IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
490 LegalTy, {LegalTy, LegalTy});
491 Cost += getIntrinsicInstrCost(Attrs2, CostKind);
492 return LT.first * Cost;
493 }
494 break;
495 }
496 default:
497 break;
498 }
499 return BaseT::getIntrinsicInstrCost(ICA, CostKind);
500}
501
502/// The function will remove redundant reinterprets casting in the presence
503/// of the control flow
504static Optional<Instruction *> processPhiNode(InstCombiner &IC,
505 IntrinsicInst &II) {
506 SmallVector<Instruction *, 32> Worklist;
507 auto RequiredType = II.getType();
508
509 auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
510 assert(PN && "Expected Phi Node!");
511
512 // Don't create a new Phi unless we can remove the old one.
513 if (!PN->hasOneUse())
514 return None;
515
516 for (Value *IncValPhi : PN->incoming_values()) {
517 auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
518 if (!Reinterpret ||
519 Reinterpret->getIntrinsicID() !=
520 Intrinsic::aarch64_sve_convert_to_svbool ||
521 RequiredType != Reinterpret->getArgOperand(0)->getType())
522 return None;
523 }
524
525 // Create the new Phi
526 LLVMContext &Ctx = PN->getContext();
527 IRBuilder<> Builder(Ctx);
528 Builder.SetInsertPoint(PN);
529 PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
530 Worklist.push_back(PN);
531
532 for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
533 auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
534 NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
535 Worklist.push_back(Reinterpret);
536 }
537
538 // Cleanup Phi Node and reinterprets
539 return IC.replaceInstUsesWith(II, NPN);
540}
541
542// (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _))))
543// => (binop (pred) (from_svbool _) (from_svbool _))
544//
545// The above transformation eliminates a `to_svbool` in the predicate
546// operand of bitwise operation `binop` by narrowing the vector width of
547// the operation. For example, it would convert a `<vscale x 16 x i1>
548// and` into a `<vscale x 4 x i1> and`. This is profitable because
549// to_svbool must zero the new lanes during widening, whereas
550// from_svbool is free.
551static Optional<Instruction *> tryCombineFromSVBoolBinOp(InstCombiner &IC,
552 IntrinsicInst &II) {
553 auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
554 if (!BinOp)
555 return None;
556
557 auto IntrinsicID = BinOp->getIntrinsicID();
558 switch (IntrinsicID) {
559 case Intrinsic::aarch64_sve_and_z:
560 case Intrinsic::aarch64_sve_bic_z:
561 case Intrinsic::aarch64_sve_eor_z:
562 case Intrinsic::aarch64_sve_nand_z:
563 case Intrinsic::aarch64_sve_nor_z:
564 case Intrinsic::aarch64_sve_orn_z:
565 case Intrinsic::aarch64_sve_orr_z:
566 break;
567 default:
568 return None;
569 }
570
571 auto BinOpPred = BinOp->getOperand(0);
572 auto BinOpOp1 = BinOp->getOperand(1);
573 auto BinOpOp2 = BinOp->getOperand(2);
574
575 auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
576 if (!PredIntr ||
577 PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
578 return None;
579
580 auto PredOp = PredIntr->getOperand(0);
581 auto PredOpTy = cast<VectorType>(PredOp->getType());
582 if (PredOpTy != II.getType())
583 return None;
584
585 IRBuilder<> Builder(II.getContext());
586 Builder.SetInsertPoint(&II);
587
588 SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
589 auto NarrowBinOpOp1 = Builder.CreateIntrinsic(
590 Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
591 NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
592 if (BinOpOp1 == BinOpOp2)
593 NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
594 else
595 NarrowedBinOpArgs.push_back(Builder.CreateIntrinsic(
596 Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));
597
598 auto NarrowedBinOp =
599 Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
600 return IC.replaceInstUsesWith(II, NarrowedBinOp);
601}
602
603static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
604 IntrinsicInst &II) {
605 // If the reinterpret instruction operand is a PHI Node
606 if (isa<PHINode>(II.getArgOperand(0)))
607 return processPhiNode(IC, II);
608
609 if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
610 return BinOpCombine;
611
612 SmallVector<Instruction *, 32> CandidatesForRemoval;
613 Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
614
615 const auto *IVTy = cast<VectorType>(II.getType());
616
617 // Walk the chain of conversions.
618 while (Cursor) {
619 // If the type of the cursor has fewer lanes than the final result, zeroing
620 // must take place, which breaks the equivalence chain.
621 const auto *CursorVTy = cast<VectorType>(Cursor->getType());
622 if (CursorVTy->getElementCount().getKnownMinValue() <
623 IVTy->getElementCount().getKnownMinValue())
624 break;
625
626 // If the cursor has the same type as I, it is a viable replacement.
627 if (Cursor->getType() == IVTy)
628 EarliestReplacement = Cursor;
629
630 auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
631
632 // If this is not an SVE conversion intrinsic, this is the end of the chain.
633 if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
634 Intrinsic::aarch64_sve_convert_to_svbool ||
635 IntrinsicCursor->getIntrinsicID() ==
636 Intrinsic::aarch64_sve_convert_from_svbool))
637 break;
638
639 CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
640 Cursor = IntrinsicCursor->getOperand(0);
641 }
642
643 // If no viable replacement in the conversion chain was found, there is
644 // nothing to do.
645 if (!EarliestReplacement)
646 return None;
647
648 return IC.replaceInstUsesWith(II, EarliestReplacement);
649}
650
651static Optional<Instruction *> instCombineSVESel(InstCombiner &IC,
652 IntrinsicInst &II) {
653 IRBuilder<> Builder(&II);
654 auto Select = Builder.CreateSelect(II.getOperand(0), II.getOperand(1),
655 II.getOperand(2));
656 return IC.replaceInstUsesWith(II, Select);
657}
658
659static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
660 IntrinsicInst &II) {
661 IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
662 if (!Pg)
663 return None;
664
665 if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
666 return None;
667
668 const auto PTruePattern =
669 cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
670 if (PTruePattern != AArch64SVEPredPattern::vl1)
671 return None;
672
673 // The intrinsic is inserting into lane zero so use an insert instead.
674 auto *IdxTy = Type::getInt64Ty(II.getContext());
675 auto *Insert = InsertElementInst::Create(
676 II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
677 Insert->insertBefore(&II);
678 Insert->takeName(&II);
679
680 return IC.replaceInstUsesWith(II, Insert);
681}
682
683static Optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
684 IntrinsicInst &II) {
685 // Replace DupX with a regular IR splat.
686 IRBuilder<> Builder(II.getContext());
687 Builder.SetInsertPoint(&II);
688 auto *RetTy = cast<ScalableVectorType>(II.getType());
689 Value *Splat =
690 Builder.CreateVectorSplat(RetTy->getElementCount(), II.getArgOperand(0));
691 Splat->takeName(&II);
692 return IC.replaceInstUsesWith(II, Splat);
693}
694
695static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
696 IntrinsicInst &II) {
697 LLVMContext &Ctx = II.getContext();
698 IRBuilder<> Builder(Ctx);
699 Builder.SetInsertPoint(&II);
700
701 // Check that the predicate is all active
702 auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
703 if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
704 return None;
705
706 const auto PTruePattern =
707 cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
708 if (PTruePattern != AArch64SVEPredPattern::all)
709 return None;
710
711 // Check that we have a compare of zero..
712 auto *SplatValue =
713 dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
714 if (!SplatValue || !SplatValue->isZero())
715 return None;
716
717 // ..against a dupq
718 auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
719 if (!DupQLane ||
720 DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
721 return None;
722
723 // Where the dupq is a lane 0 replicate of a vector insert
724 if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
725 return None;
726
727 auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
728 if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
729 return None;
730
731 // Where the vector insert is a fixed constant vector insert into undef at
732 // index zero
733 if (!isa<UndefValue>(VecIns->getArgOperand(0)))
734 return None;
735
736 if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
737 return None;
738
739 auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
740 if (!ConstVec)
741 return None;
742
743 auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
744 auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
745 if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
746 return None;
747
748 unsigned NumElts = VecTy->getNumElements();
749 unsigned PredicateBits = 0;
750
751 // Expand intrinsic operands to a 16-bit byte level predicate
752 for (unsigned I = 0; I < NumElts; ++I) {
753 auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
754 if (!Arg)
755 return None;
756 if (!Arg->isZero())
757 PredicateBits |= 1 << (I * (16 / NumElts));
758 }
759
760 // If all bits are zero bail early with an empty predicate
761 if (PredicateBits == 0) {
762 auto *PFalse = Constant::getNullValue(II.getType());
763 PFalse->takeName(&II);
764 return IC.replaceInstUsesWith(II, PFalse);
765 }
766
767 // Calculate largest predicate type used (where byte predicate is largest)
768 unsigned Mask = 8;
769 for (unsigned I = 0; I < 16; ++I)
770 if ((PredicateBits & (1 << I)) != 0)
771 Mask |= (I % 8);
772
773 unsigned PredSize = Mask & -Mask;
774 auto *PredType = ScalableVectorType::get(
775 Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
776
777 // Ensure all relevant bits are set
778 for (unsigned I = 0; I < 16; I += PredSize)
779 if ((PredicateBits & (1 << I)) == 0)
780 return None;
781
782 auto *PTruePat =
783 ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
784 auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
785 {PredType}, {PTruePat});
786 auto *ConvertToSVBool = Builder.CreateIntrinsic(
787 Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
788 auto *ConvertFromSVBool =
789 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
790 {II.getType()}, {ConvertToSVBool});
791
792 ConvertFromSVBool->takeName(&II);
793 return IC.replaceInstUsesWith(II, ConvertFromSVBool);
794}
795
796static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
797 IntrinsicInst &II) {
798 IRBuilder<> Builder(II.getContext());
799 Builder.SetInsertPoint(&II);
800 Value *Pg = II.getArgOperand(0);
801 Value *Vec = II.getArgOperand(1);
802 auto IntrinsicID = II.getIntrinsicID();
803 bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
804
805 // lastX(splat(X)) --> X
806 if (auto *SplatVal = getSplatValue(Vec))
807 return IC.replaceInstUsesWith(II, SplatVal);
808
809 // If x and/or y is a splat value then:
810 // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
811 Value *LHS, *RHS;
812 if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
813 if (isSplatValue(LHS) || isSplatValue(RHS)) {
814 auto *OldBinOp = cast<BinaryOperator>(Vec);
815 auto OpC = OldBinOp->getOpcode();
816 auto *NewLHS =
817 Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
818 auto *NewRHS =
819 Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
820 auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
821 OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
822 return IC.replaceInstUsesWith(II, NewBinOp);
823 }
824 }
825
826 auto *C = dyn_cast<Constant>(Pg);
827 if (IsAfter && C && C->isNullValue()) {
828 // The intrinsic is extracting lane 0 so use an extract instead.
829 auto *IdxTy = Type::getInt64Ty(II.getContext());
830 auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
831 Extract->insertBefore(&II);
832 Extract->takeName(&II);
833 return IC.replaceInstUsesWith(II, Extract);
834 }
835
836 auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
837 if (!IntrPG)
838 return None;
839
840 if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
841 return None;
842
843 const auto PTruePattern =
844 cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
845
846 // Can the intrinsic's predicate be converted to a known constant index?
847 unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
848 if (!MinNumElts)
849 return None;
850
851 unsigned Idx = MinNumElts - 1;
852 // Increment the index if extracting the element after the last active
853 // predicate element.
854 if (IsAfter)
855 ++Idx;
856
857 // Ignore extracts whose index is larger than the known minimum vector
858 // length. NOTE: This is an artificial constraint where we prefer to
859 // maintain what the user asked for until an alternative is proven faster.
860 auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
861 if (Idx >= PgVTy->getMinNumElements())
862 return None;
863
864 // The intrinsic is extracting a fixed lane so use an extract instead.
865 auto *IdxTy = Type::getInt64Ty(II.getContext());
866 auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
867 Extract->insertBefore(&II);
868 Extract->takeName(&II);
869 return IC.replaceInstUsesWith(II, Extract);
870}
871
872static Optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
873 IntrinsicInst &II) {
874 // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
875 // integer variant across a variety of micro-architectures. Replace scalar
876 // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
877 // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
878 // depending on the micro-architecture, but has been observed as generally
879 // being faster, particularly when the CLAST[AB] op is a loop-carried
880 // dependency.
881 IRBuilder<> Builder(II.getContext());
882 Builder.SetInsertPoint(&II);
883 Value *Pg = II.getArgOperand(0);
884 Value *Fallback = II.getArgOperand(1);
885 Value *Vec = II.getArgOperand(2);
886 Type *Ty = II.getType();
887
888 if (!Ty->isIntegerTy())
889 return None;
890
891 Type *FPTy;
892 switch (cast<IntegerType>(Ty)->getBitWidth()) {
893 default:
894 return None;
895 case 16:
896 FPTy = Builder.getHalfTy();
897 break;
898 case 32:
899 FPTy = Builder.getFloatTy();
900 break;
901 case 64:
902 FPTy = Builder.getDoubleTy();
903 break;
904 }
905
906 Value *FPFallBack = Builder.CreateBitCast(Fallback, FPTy);
907 auto *FPVTy = VectorType::get(
908 FPTy, cast<VectorType>(Vec->getType())->getElementCount());
909 Value *FPVec = Builder.CreateBitCast(Vec, FPVTy);
910 auto *FPII = Builder.CreateIntrinsic(II.getIntrinsicID(), {FPVec->getType()},
911 {Pg, FPFallBack, FPVec});
912 Value *FPIItoInt = Builder.CreateBitCast(FPII, II.getType());
913 return IC.replaceInstUsesWith(II, FPIItoInt);
914}
915
916static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
917 IntrinsicInst &II) {
918 LLVMContext &Ctx = II.getContext();
919 IRBuilder<> Builder(Ctx);
920 Builder.SetInsertPoint(&II);
921 // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
922 // can work with RDFFR_PP for ptest elimination.
923 auto *AllPat =
924 ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
925 auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
926 {II.getType()}, {AllPat});
927 auto *RDFFR =
928 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
929 RDFFR->takeName(&II);
930 return IC.replaceInstUsesWith(II, RDFFR);
931}
932
933static Optional<Instruction *>
934instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
935 const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
936
937 if (Pattern == AArch64SVEPredPattern::all) {
938 LLVMContext &Ctx = II.getContext();
939 IRBuilder<> Builder(Ctx);
940 Builder.SetInsertPoint(&II);
941
942 Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
943 auto *VScale = Builder.CreateVScale(StepVal);
944 VScale->takeName(&II);
945 return IC.replaceInstUsesWith(II, VScale);
946 }
947
948 unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
949
950 return MinNumElts && NumElts >= MinNumElts
951 ? Optional<Instruction *>(IC.replaceInstUsesWith(
952 II, ConstantInt::get(II.getType(), MinNumElts)))
953 : None;
954}
955
956static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
957 IntrinsicInst &II) {
958 IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
959 IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
960
961 if (Op1 && Op2 &&
962 Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
963 Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
964 Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {
965
966 IRBuilder<> Builder(II.getContext());
967 Builder.SetInsertPoint(&II);
968
969 Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
970 Type *Tys[] = {Op1->getArgOperand(0)->getType()};
971
972 auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
973
974 PTest->takeName(&II);
975 return IC.replaceInstUsesWith(II, PTest);
976 }
977
978 return None;
979}
980
981static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
982 IntrinsicInst &II) {
983 // fold (fadd p a (fmul p b c)) -> (fma p a b c)
984 Value *P = II.getOperand(0);
985 Value *A = II.getOperand(1);
986 auto FMul = II.getOperand(2);
987 Value *B, *C;
988 if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
989 m_Specific(P), m_Value(B), m_Value(C))))
990 return None;
991
992 if (!FMul->hasOneUse())
993 return None;
994
995 llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
996 // Stop the combine when the flags on the inputs differ in case dropping flags
997 // would lead to us missing out on more beneficial optimizations.
998 if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
999 return None;
1000 if (!FAddFlags.allowContract())
1001 return None;
1002
1003 IRBuilder<> Builder(II.getContext());
1004 Builder.SetInsertPoint(&II);
1005 auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
1006 {II.getType()}, {P, A, B, C}, &II);
1007 FMLA->setFastMathFlags(FAddFlags);
1008 return IC.replaceInstUsesWith(II, FMLA);
1009}
1010
1011static bool isAllActivePredicate(Value *Pred) {
1012 // Look through convert.from.svbool(convert.to.svbool(...) chain.
1013 Value *UncastedPred;
1014 if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
1015 m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
1016 m_Value(UncastedPred)))))
1017 // If the predicate has the same or less lanes than the uncasted
1018 // predicate then we know the casting has no effect.
1019 if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
1020 cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
1021 Pred = UncastedPred;
1022
1023 return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1024 m_ConstantInt<AArch64SVEPredPattern::all>()));
1025}
1026
1027static Optional<Instruction *>
1028instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1029 IRBuilder<> Builder(II.getContext());
1030 Builder.SetInsertPoint(&II);
1031
1032 Value *Pred = II.getOperand(0);
1033 Value *PtrOp = II.getOperand(1);
1034 Type *VecTy = II.getType();
1035 Value *VecPtr = Builder.CreateBitCast(PtrOp, VecTy->getPointerTo());
1036
1037 if (isAllActivePredicate(Pred)) {
1038 LoadInst *Load = Builder.CreateLoad(VecTy, VecPtr);
1039 Load->copyMetadata(II);
1040 return IC.replaceInstUsesWith(II, Load);
1041 }
1042
1043 CallInst *MaskedLoad =
1044 Builder.CreateMaskedLoad(VecTy, VecPtr, PtrOp->getPointerAlignment(DL),
1045 Pred, ConstantAggregateZero::get(VecTy));
1046 MaskedLoad->copyMetadata(II);
1047 return IC.replaceInstUsesWith(II, MaskedLoad);
1048}
1049
1050static Optional<Instruction *>
1051instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1052 IRBuilder<> Builder(II.getContext());
1053 Builder.SetInsertPoint(&II);
1054
1055 Value *VecOp = II.getOperand(0);
1056 Value *Pred = II.getOperand(1);
1057 Value *PtrOp = II.getOperand(2);
1058 Value *VecPtr =
1059 Builder.CreateBitCast(PtrOp, VecOp->getType()->getPointerTo());
1060
1061 if (isAllActivePredicate(Pred)) {
1062 StoreInst *Store = Builder.CreateStore(VecOp, VecPtr);
1063 Store->copyMetadata(II);
1064 return IC.eraseInstFromFunction(II);
1065 }
1066
1067 CallInst *MaskedStore = Builder.CreateMaskedStore(
1068 VecOp, VecPtr, PtrOp->getPointerAlignment(DL), Pred);
1069 MaskedStore->copyMetadata(II);
1070 return IC.eraseInstFromFunction(II);
1071}
1072
1073static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
1074 switch (Intrinsic) {
1075 case Intrinsic::aarch64_sve_fmul:
1076 return Instruction::BinaryOps::FMul;
1077 case Intrinsic::aarch64_sve_fadd:
1078 return Instruction::BinaryOps::FAdd;
1079 case Intrinsic::aarch64_sve_fsub:
1080 return Instruction::BinaryOps::FSub;
1081 default:
1082 return Instruction::BinaryOpsEnd;
1083 }
1084}
1085
1086static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
1087 IntrinsicInst &II) {
1088 auto *OpPredicate = II.getOperand(0);
1089 auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
1090 if (BinOpCode == Instruction::BinaryOpsEnd ||
1091 !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1092 m_ConstantInt<AArch64SVEPredPattern::all>())))
1093 return None;
1094 IRBuilder<> Builder(II.getContext());
1095 Builder.SetInsertPoint(&II);
1096 Builder.setFastMathFlags(II.getFastMathFlags());
1097 auto BinOp =
1098 Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
1099 return IC.replaceInstUsesWith(II, BinOp);
1100}
1101
1102static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
1103 IntrinsicInst &II) {
1104 if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
1105 return FMLA;
1106 return instCombineSVEVectorBinOp(IC, II);
1107}
1108
1109static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
1110 IntrinsicInst &II) {
1111 auto *OpPredicate = II.getOperand(0);
1112 auto *OpMultiplicand = II.getOperand(1);
1113 auto *OpMultiplier = II.getOperand(2);
1114
1115 IRBuilder<> Builder(II.getContext());
1116 Builder.SetInsertPoint(&II);
1117
1118 // Return true if a given instruction is a unit splat value, false otherwise.
1119 auto IsUnitSplat = [](auto *I) {
1120 auto *SplatValue = getSplatValue(I);
1121 if (!SplatValue)
1122 return false;
1123 return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1124 };
1125
1126 // Return true if a given instruction is an aarch64_sve_dup intrinsic call
1127 // with a unit splat value, false otherwise.
1128 auto IsUnitDup = [](auto *I) {
1129 auto *IntrI = dyn_cast<IntrinsicInst>(I);
1130 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
1131 return false;
1132
1133 auto *SplatValue = IntrI->getOperand(2);
1134 return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1135 };
1136
1137 if (IsUnitSplat(OpMultiplier)) {
1138 // [f]mul pg %n, (dupx 1) => %n
1139 OpMultiplicand->takeName(&II);
1140 return IC.replaceInstUsesWith(II, OpMultiplicand);
1141 } else if (IsUnitDup(OpMultiplier)) {
1142 // [f]mul pg %n, (dup pg 1) => %n
1143 auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
1144 auto *DupPg = DupInst->getOperand(1);
1145 // TODO: this is naive. The optimization is still valid if DupPg
1146 // 'encompasses' OpPredicate, not only if they're the same predicate.
1147 if (OpPredicate == DupPg) {
1148 OpMultiplicand->takeName(&II);
1149 return IC.replaceInstUsesWith(II, OpMultiplicand);
1150 }
1151 }
1152
1153 return instCombineSVEVectorBinOp(IC, II);
1154}
1155
1156static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
1157 IntrinsicInst &II) {
1158 IRBuilder<> Builder(II.getContext());
1159 Builder.SetInsertPoint(&II);
1160 Value *UnpackArg = II.getArgOperand(0);
1161 auto *RetTy = cast<ScalableVectorType>(II.getType());
1162 bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
1163 II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
1164
1165 // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
1166 // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
1167 if (auto *ScalarArg = getSplatValue(UnpackArg)) {
1168 ScalarArg =
1169 Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
1170 Value *NewVal =
1171 Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
1172 NewVal->takeName(&II);
1173 return IC.replaceInstUsesWith(II, NewVal);
1174 }
1175
1176 return None;
1177}
1178static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
1179 IntrinsicInst &II) {
1180 auto *OpVal = II.getOperand(0);
1181 auto *OpIndices = II.getOperand(1);
1182 VectorType *VTy = cast<VectorType>(II.getType());
1183
1184 // Check whether OpIndices is a constant splat value < minimal element count
1185 // of result.
1186 auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
1187 if (!SplatValue ||
1188 SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
1189 return None;
1190
1191 // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
1192 // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
1193 IRBuilder<> Builder(II.getContext());
1194 Builder.SetInsertPoint(&II);
1195 auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
1196 auto *VectorSplat =
1197 Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
1198
1199 VectorSplat->takeName(&II);
1200 return IC.replaceInstUsesWith(II, VectorSplat);
1201}
1202
1203static Optional<Instruction *> instCombineSVETupleGet(InstCombiner &IC,
1204 IntrinsicInst &II) {
1205 // Try to remove sequences of tuple get/set.
1206 Value *SetTuple, *SetIndex, *SetValue;
1207 auto *GetTuple = II.getArgOperand(0);
1208 auto *GetIndex = II.getArgOperand(1);
1209 // Check that we have tuple_get(GetTuple, GetIndex) where GetTuple is a
1210 // call to tuple_set i.e. tuple_set(SetTuple, SetIndex, SetValue).
1211 // Make sure that the types of the current intrinsic and SetValue match
1212 // in order to safely remove the sequence.
1213 if (!match(GetTuple,
1214 m_Intrinsic<Intrinsic::aarch64_sve_tuple_set>(
1215 m_Value(SetTuple), m_Value(SetIndex), m_Value(SetValue))) ||
1216 SetValue->getType() != II.getType())
1217 return None;
1218 // Case where we get the same index right after setting it.
1219 // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex) --> SetValue
1220 if (GetIndex == SetIndex)
1221 return IC.replaceInstUsesWith(II, SetValue);
1222 // If we are getting a different index than what was set in the tuple_set
1223 // intrinsic. We can just set the input tuple to the one up in the chain.
1224 // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex)
1225 // --> tuple_get(SetTuple, GetIndex)
1226 return IC.replaceOperand(II, 0, SetTuple);
1227}
1228
1229static Optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
1230 IntrinsicInst &II) {
1231 // zip1(uzp1(A, B), uzp2(A, B)) --> A
1232 // zip2(uzp1(A, B), uzp2(A, B)) --> B
1233 Value *A, *B;
1234 if (match(II.getArgOperand(0),
1235 m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
1236 match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
1237 m_Specific(A), m_Specific(B))))
1238 return IC.replaceInstUsesWith(
1239 II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
1240
1241 return None;
1242}
1243
1244static Optional<Instruction *> instCombineLD1GatherIndex(InstCombiner &IC,
1245 IntrinsicInst &II) {
1246 Value *Mask = II.getOperand(0);
1247 Value *BasePtr = II.getOperand(1);
1248 Value *Index = II.getOperand(2);
1249 Type *Ty = II.getType();
1250 Value *PassThru = ConstantAggregateZero::get(Ty);
1251
1252 // Contiguous gather => masked load.
1253 // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
1254 // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
1255 Value *IndexBase;
1256 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1257 m_Value(IndexBase), m_SpecificInt(1)))) {
1258 IRBuilder<> Builder(II.getContext());
1259 Builder.SetInsertPoint(&II);
1260
1261 Align Alignment =
1262 BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1263
1264 Type *VecPtrTy = PointerType::getUnqual(Ty);
1265 Value *Ptr = Builder.CreateGEP(
1266 cast<VectorType>(Ty)->getElementType(), BasePtr, IndexBase);
1267 Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
1268 CallInst *MaskedLoad =
1269 Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
1270 MaskedLoad->takeName(&II);
1271 return IC.replaceInstUsesWith(II, MaskedLoad);
1272 }
1273
1274 return None;
1275}
1276
1277static Optional<Instruction *> instCombineST1ScatterIndex(InstCombiner &IC,
1278 IntrinsicInst &II) {
1279 Value *Val = II.getOperand(0);
1280 Value *Mask = II.getOperand(1);
1281 Value *BasePtr = II.getOperand(2);
1282 Value *Index = II.getOperand(3);
1283 Type *Ty = Val->getType();
1284
1285 // Contiguous scatter => masked store.
1286 // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
1287 // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
1288 Value *IndexBase;
1289 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1290 m_Value(IndexBase), m_SpecificInt(1)))) {
1291 IRBuilder<> Builder(II.getContext());
1292 Builder.SetInsertPoint(&II);
1293
1294 Align Alignment =
1295 BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1296
1297 Value *Ptr = Builder.CreateGEP(
1298 cast<VectorType>(Ty)->getElementType(), BasePtr, IndexBase);
1299 Type *VecPtrTy = PointerType::getUnqual(Ty);
1300 Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
1301
1302 (void)Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
1303
1304 return IC.eraseInstFromFunction(II);
1305 }
1306
1307 return None;
1308}
1309
1310static Optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
1311 IntrinsicInst &II) {
1312 IRBuilder<> Builder(II.getContext());
1313 Builder.SetInsertPoint(&II);
1314 Type *Int32Ty = Builder.getInt32Ty();
1315 Value *Pred = II.getOperand(0);
1316 Value *Vec = II.getOperand(1);
1317 Value *DivVec = II.getOperand(2);
1318
1319 Value *SplatValue = getSplatValue(DivVec);
1320 ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
1321 if (!SplatConstantInt)
1322 return None;
1323 APInt Divisor = SplatConstantInt->getValue();
1324
1325 if (Divisor.isPowerOf2()) {
1326 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1327 auto ASRD = Builder.CreateIntrinsic(
1328 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1329 return IC.replaceInstUsesWith(II, ASRD);
1330 }
1331 if (Divisor.isNegatedPowerOf2()) {
1332 Divisor.negate();
1333 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1334 auto ASRD = Builder.CreateIntrinsic(
1335 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1336 auto NEG = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_neg,
1337 {ASRD->getType()}, {ASRD, Pred, ASRD});
1338 return IC.replaceInstUsesWith(II, NEG);
1339 }
1340
1341 return None;
1342}
1343
1344static Optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
1345 IntrinsicInst &II) {
1346 Value *A = II.getArgOperand(0);
1347 Value *B = II.getArgOperand(1);
1348 if (A == B)
1349 return IC.replaceInstUsesWith(II, A);
1350
1351 return None;
1352}
1353
1354static Optional<Instruction *> instCombineSVESrshl(InstCombiner &IC,
1355 IntrinsicInst &II) {
1356 IRBuilder<> Builder(&II);
1357 Value *Pred = II.getOperand(0);
1358 Value *Vec = II.getOperand(1);
1359 Value *Shift = II.getOperand(2);
1360
1361 // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic.
1362 Value *AbsPred, *MergedValue;
1363 if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>(
1364 m_Value(MergedValue), m_Value(AbsPred), m_Value())) &&
1365 !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>(
1366 m_Value(MergedValue), m_Value(AbsPred), m_Value())))
1367
1368 return None;
1369
1370 // Transform is valid if any of the following are true:
1371 // * The ABS merge value is an undef or non-negative
1372 // * The ABS predicate is all active
1373 // * The ABS predicate and the SRSHL predicates are the same
1374 if (!isa<UndefValue>(MergedValue) &&
1375 !match(MergedValue, m_NonNegative()) &&
1376 AbsPred != Pred && !isAllActivePredicate(AbsPred))
1377 return None;
1378
1379 // Only valid when the shift amount is non-negative, otherwise the rounding
1380 // behaviour of SRSHL cannot be ignored.
1381 if (!match(Shift, m_NonNegative()))
1382 return None;
1383
1384 auto LSL = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl, {II.getType()},
1385 {Pred, Vec, Shift});
1386
1387 return IC.replaceInstUsesWith(II, LSL);
1388}
1389
1390Optional<Instruction *>
1391AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
1392 IntrinsicInst &II) const {
1393 Intrinsic::ID IID = II.getIntrinsicID();
1394 switch (IID) {
1395 default:
1396 break;
1397 case Intrinsic::aarch64_neon_fmaxnm:
1398 case Intrinsic::aarch64_neon_fminnm:
1399 return instCombineMaxMinNM(IC, II);
1400 case Intrinsic::aarch64_sve_convert_from_svbool:
1401 return instCombineConvertFromSVBool(IC, II);
1402 case Intrinsic::aarch64_sve_dup:
1403 return instCombineSVEDup(IC, II);
1404 case Intrinsic::aarch64_sve_dup_x:
1405 return instCombineSVEDupX(IC, II);
1406 case Intrinsic::aarch64_sve_cmpne:
1407 case Intrinsic::aarch64_sve_cmpne_wide:
1408 return instCombineSVECmpNE(IC, II);
1409 case Intrinsic::aarch64_sve_rdffr:
1410 return instCombineRDFFR(IC, II);
1411 case Intrinsic::aarch64_sve_lasta:
1412 case Intrinsic::aarch64_sve_lastb:
1413 return instCombineSVELast(IC, II);
1414 case Intrinsic::aarch64_sve_clasta_n:
1415 case Intrinsic::aarch64_sve_clastb_n:
1416 return instCombineSVECondLast(IC, II);
1417 case Intrinsic::aarch64_sve_cntd:
1418 return instCombineSVECntElts(IC, II, 2);
1419 case Intrinsic::aarch64_sve_cntw:
1420 return instCombineSVECntElts(IC, II, 4);
1421 case Intrinsic::aarch64_sve_cnth:
1422 return instCombineSVECntElts(IC, II, 8);
1423 case Intrinsic::aarch64_sve_cntb:
1424 return instCombineSVECntElts(IC, II, 16);
1425 case Intrinsic::aarch64_sve_ptest_any:
1426 case Intrinsic::aarch64_sve_ptest_first:
1427 case Intrinsic::aarch64_sve_ptest_last:
1428 return instCombineSVEPTest(IC, II);
1429 case Intrinsic::aarch64_sve_mul:
1430 case Intrinsic::aarch64_sve_fmul:
1431 return instCombineSVEVectorMul(IC, II);
1432 case Intrinsic::aarch64_sve_fadd:
1433 return instCombineSVEVectorFAdd(IC, II);
1434 case Intrinsic::aarch64_sve_fsub:
1435 return instCombineSVEVectorBinOp(IC, II);
1436 case Intrinsic::aarch64_sve_tbl:
1437 return instCombineSVETBL(IC, II);
1438 case Intrinsic::aarch64_sve_uunpkhi:
1439 case Intrinsic::aarch64_sve_uunpklo:
1440 case Intrinsic::aarch64_sve_sunpkhi:
1441 case Intrinsic::aarch64_sve_sunpklo:
1442 return instCombineSVEUnpack(IC, II);
1443 case Intrinsic::aarch64_sve_tuple_get:
1444 return instCombineSVETupleGet(IC, II);
1445 case Intrinsic::aarch64_sve_zip1:
1446 case Intrinsic::aarch64_sve_zip2:
1447 return instCombineSVEZip(IC, II);
1448 case Intrinsic::aarch64_sve_ld1_gather_index:
1449 return instCombineLD1GatherIndex(IC, II);
1450 case Intrinsic::aarch64_sve_st1_scatter_index:
1451 return instCombineST1ScatterIndex(IC, II);
1452 case Intrinsic::aarch64_sve_ld1:
1453 return instCombineSVELD1(IC, II, DL);
1454 case Intrinsic::aarch64_sve_st1:
1455 return instCombineSVEST1(IC, II, DL);
1456 case Intrinsic::aarch64_sve_sdiv:
1457 return instCombineSVESDIV(IC, II);
1458 case Intrinsic::aarch64_sve_sel:
1459 return instCombineSVESel(IC, II);
1460 case Intrinsic::aarch64_sve_srshl:
1461 return instCombineSVESrshl(IC, II);
1462 }
1463
1464 return None;
1465}
1466
1467Optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic(
1468 InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
1469 APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
1470 std::function<void(Instruction *, unsigned, APInt, APInt &)>
1471 SimplifyAndSetOp) const {
1472 switch (II.getIntrinsicID()) {
1473 default:
1474 break;
1475 case Intrinsic::aarch64_neon_fcvtxn:
1476 case Intrinsic::aarch64_neon_rshrn:
1477 case Intrinsic::aarch64_neon_sqrshrn:
1478 case Intrinsic::aarch64_neon_sqrshrun:
1479 case Intrinsic::aarch64_neon_sqshrn:
1480 case Intrinsic::aarch64_neon_sqshrun:
1481 case Intrinsic::aarch64_neon_sqxtn:
1482 case Intrinsic::aarch64_neon_sqxtun:
1483 case Intrinsic::aarch64_neon_uqrshrn:
1484 case Intrinsic::aarch64_neon_uqshrn:
1485 case Intrinsic::aarch64_neon_uqxtn:
1486 SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts);
1487 break;
1488 }
1489
1490 return None;
1491}
1492
1493bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
1494 ArrayRef<const Value *> Args) {
1495
1496 // A helper that returns a vector type from the given type. The number of
1497 // elements in type Ty determines the vector width.
1498 auto toVectorTy = [&](Type *ArgTy) {
1499 return VectorType::get(ArgTy->getScalarType(),
1500 cast<VectorType>(DstTy)->getElementCount());
1501 };
1502
1503 // Exit early if DstTy is not a vector type whose elements are at least
1504 // 16-bits wide.
1505 if (!DstTy->isVectorTy() || DstTy->getScalarSizeInBits() < 16)
1506 return false;
1507
1508 // Determine if the operation has a widening variant. We consider both the
1509 // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
1510 // instructions.
1511 //
1512 // TODO: Add additional widening operations (e.g., shl, etc.) once we
1513 // verify that their extending operands are eliminated during code
1514 // generation.
1515 switch (Opcode) {
1516 case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
1517 case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
1518 case Instruction::Mul: // SMULL(2), UMULL(2)
1519 break;
1520 default:
1521 return false;
1522 }
1523
1524 // To be a widening instruction (either the "wide" or "long" versions), the
1525 // second operand must be a sign- or zero extend.
1526 if (Args.size() != 2 ||
1527 (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])))
1528 return false;
1529 auto *Extend = cast<CastInst>(Args[1]);
1530 auto *Arg0 = dyn_cast<CastInst>(Args[0]);
1531
1532 // A mul only has a mull version (not like addw). Both operands need to be
1533 // extending and the same type.
1534 if (Opcode == Instruction::Mul &&
1535 (!Arg0 || Arg0->getOpcode() != Extend->getOpcode() ||
1536 Arg0->getOperand(0)->getType() != Extend->getOperand(0)->getType()))
1537 return false;
1538
1539 // Legalize the destination type and ensure it can be used in a widening
1540 // operation.
1541 auto DstTyL = getTypeLegalizationCost(DstTy);
1542 unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
1543 if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
1544 return false;
1545
1546 // Legalize the source type and ensure it can be used in a widening
1547 // operation.
1548 auto *SrcTy = toVectorTy(Extend->getSrcTy());
1549 auto SrcTyL = getTypeLegalizationCost(SrcTy);
1550 unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
1551 if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
1552 return false;
1553
1554 // Get the total number of vector elements in the legalized types.
1555 InstructionCost NumDstEls =
1556 DstTyL.first * DstTyL.second.getVectorMinNumElements();
1557 InstructionCost NumSrcEls =
1558 SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
1559
1560 // Return true if the legalized types have the same number of vector elements
1561 // and the destination element type size is twice that of the source type.
1562 return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
1563}
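As a concrete, hypothetical instance of the final check above: an add of <8 x i16> whose second operand is a zext from <8 x i8> legalizes to types with the same element count and a doubled destination element width, so it qualifies for the uaddw form. A compile-time sketch with those assumed numbers (they are illustrative, not computed by the function above):

    constexpr unsigned NumDstEls = 8, NumSrcEls = 8;
    constexpr unsigned DstElTySize = 16, SrcElTySize = 8;
    static_assert(NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize,
                  "the operation is treated as a widening instruction");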
1564
1565InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
1566 Type *Src,
1567 TTI::CastContextHint CCH,
1568 TTI::TargetCostKind CostKind,
1569 const Instruction *I) {
1570 int ISD = TLI->InstructionOpcodeToISD(Opcode);
1571 assert(ISD && "Invalid opcode");
1572
1573 // If the cast is observable, and it is used by a widening instruction (e.g.,
1574 // uaddl, saddw, etc.), it may be free.
1575 if (I && I->hasOneUser()) {
1576 auto *SingleUser = cast<Instruction>(*I->user_begin());
1577 SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
1578 if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) {
1579 // If the cast is the second operand, it is free. We will generate either
1580 // a "wide" or "long" version of the widening instruction.
1581 if (I == SingleUser->getOperand(1))
1582 return 0;
1583 // If the cast is not the second operand, it will be free if it looks the
1584 // same as the second operand. In this case, we will generate a "long"
1585 // version of the widening instruction.
1586 if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
1587 if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
1588 cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
1589 return 0;
1590 }
1591 }
1592
1593 // TODO: Allow non-throughput costs that aren't binary.
1594 auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
1595 if (CostKind != TTI::TCK_RecipThroughput)
1596 return Cost == 0 ? 0 : 1;
1597 return Cost;
1598 };
1599
1600 EVT SrcTy = TLI->getValueType(DL, Src);
1601 EVT DstTy = TLI->getValueType(DL, Dst);
1602
1603 if (!SrcTy.isSimple() || !DstTy.isSimple())
1604 return AdjustCost(
1605 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
1606
1607 static const TypeConversionCostTblEntry
1608 ConversionTbl[] = {
1609 { ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1}, // xtn
1610 { ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1}, // xtn
1611 { ISD::TRUNCATE, MVT::v2i32, MVT::v2i64, 1}, // xtn
1612 { ISD::TRUNCATE, MVT::v4i8, MVT::v4i32, 1}, // xtn
1613 { ISD::TRUNCATE, MVT::v4i8, MVT::v4i64, 3}, // 2 xtn + 1 uzp1
1614 { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 1}, // xtn
1615 { ISD::TRUNCATE, MVT::v4i16, MVT::v4i64, 2}, // 1 uzp1 + 1 xtn
1616 { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 1}, // 1 uzp1
1617 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i16, 1}, // 1 xtn
1618 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 2}, // 1 uzp1 + 1 xtn
1619 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i64, 4}, // 3 x uzp1 + xtn
1620 { ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 1}, // 1 uzp1
1621 { ISD::TRUNCATE, MVT::v8i16, MVT::v8i64, 3}, // 3 x uzp1
1622 { ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 2}, // 2 x uzp1
1623 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i16, 1}, // uzp1
1624 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 3}, // (2 + 1) x uzp1
1625 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i64, 7}, // (4 + 2 + 1) x uzp1
1626 { ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2}, // 2 x uzp1
1627 { ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6}, // (4 + 2) x uzp1
1628 { ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4}, // 4 x uzp1
1629
1630 // Truncations on nxvmiN
1631 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
1632 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
1633 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
1634 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
1635 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
1636 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
1637 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
1638 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
1639 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
1640 { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
1641 { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
1642 { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
1643 { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
1644 { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
1645 { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
1646 { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
1647
1648 // The number of shll instructions for the extension.
1649 { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
1650 { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
1651 { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 2 },
1652 { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 2 },
1653 { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i8, 3 },
1654 { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i8, 3 },
1655 { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 2 },
1656 { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 2 },
1657 { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i8, 7 },
1658 { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i8, 7 },
1659 { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i16, 6 },
1660 { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i16, 6 },
1661 { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
1662 { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
1663 { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
1664 { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
1665
1666 // LowerVectorINT_TO_FP:
1667 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
1668 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
1669 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
1670 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
1671 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
1672 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
1673
1674 // Complex: to v2f32
1675 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8, 3 },
1676 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
1677 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
1678 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8, 3 },
1679 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
1680 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
1681
1682 // Complex: to v4f32
1683 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 4 },
1684 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
1685 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 3 },
1686 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
1687
1688 // Complex: to v8f32
1689 { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 10 },
1690 { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
1691 { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 10 },
1692 { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
1693
1694 // Complex: to v16f32
1695 { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
1696 { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
1697
1698 // Complex: to v2f64
1699 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8, 4 },
1700 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
1701 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
1702 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8, 4 },
1703 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
1704 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
1705
1706 // Complex: to v4f64
1707 { ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32, 4 },
1708 { ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32, 4 },
1709
1710 // LowerVectorFP_TO_INT
1711 { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
1712 { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
1713 { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
1714 { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
1715 { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
1716 { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },
1717
1718 // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
1719 { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
1720 { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
1721 { ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f32, 1 },
1722 { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
1723 { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
1724 { ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f32, 1 },
1725
1726 // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
1727 { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
1728 { ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 2 },
1729 { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
1730 { ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 2 },
1731
1732 // Complex, from nxv2f32.
1733 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
1734 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
1735 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
1736 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f32, 1 },
1737 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
1738 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
1739 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
1740 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f32, 1 },
1741
1742 // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
1743 { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
1744 { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
1745 { ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f64, 2 },
1746 { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
1747 { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
1748 { ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f64, 2 },
1749
1750 // Complex, from nxv2f64.
1751 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
1752 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
1753 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
1754 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f64, 1 },
1755 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
1756 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
1757 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
1758 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f64, 1 },
1759
1760 // Complex, from nxv4f32.
1761 { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
1762 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
1763 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
1764 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f32, 1 },
1765 { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
1766 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
1767 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
1768 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f32, 1 },
1769
1770 // Complex, from nxv8f64. Illegal -> illegal conversions not required.
1771 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
1772 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f64, 7 },
1773 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
1774 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f64, 7 },
1775
1776 // Complex, from nxv4f64. Illegal -> illegal conversions not required.
1777 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
1778 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
1779 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f64, 3 },
1780 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
1781 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
1782 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f64, 3 },
1783
1784 // Complex, from nxv8f32. Illegal -> illegal conversions not required.
1785 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
1786 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f32, 3 },
1787 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
1788 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f32, 3 },
1789
1790 // Complex, from nxv8f16.
1791 { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
1792 { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
1793 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
1794 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f16, 1 },
1795 { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
1796 { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
1797 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
1798 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f16, 1 },
1799
1800 // Complex, from nxv4f16.
1801 { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
1802 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
1803 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
1804 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f16, 1 },
1805 { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
1806 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
1807 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
1808 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f16, 1 },
1809
1810 // Complex, from nxv2f16.
1811 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
1812 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
1813 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
1814 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f16, 1 },
1815 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
1816 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
1817 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
1818 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f16, 1 },
1819
1820 // Truncate from nxvmf32 to nxvmf16.
1821 { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
1822 { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
1823 { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
1824
1825 // Truncate from nxvmf64 to nxvmf16.
1826 { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
1827 { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
1828 { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
1829
1830 // Truncate from nxvmf64 to nxvmf32.
1831 { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
1832 { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
1833 { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
1834
1835 // Extend from nxvmf16 to nxvmf32.
1836 { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
1837 { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
1838 { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
1839
1840 // Extend from nxvmf16 to nxvmf64.
1841 { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
1842 { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
1843 { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
1844
1845 // Extend from nxvmf32 to nxvmf64.
1846 { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
1847 { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
1848 { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
1849
1850 // Bitcasts from float to integer
1851 { ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0 },
1852 { ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0 },
1853 { ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0 },
1854
1855 // Bitcasts from integer to float
1856 { ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0 },
1857 { ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0 },
1858 { ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0 },
1859 };
1860
1861 if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
1862 DstTy.getSimpleVT(),
1863 SrcTy.getSimpleVT()))
1864 return AdjustCost(Entry->Cost);
1865
1866 static const TypeConversionCostTblEntry FP16Tbl[] = {
1867 {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs
1868 {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1},
1869 {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs
1870 {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1},
1871 {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs
1872 {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2},
1873 {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn
1874 {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2},
1875 {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs
1876 {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1},
1877 {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs
1878 {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4},
1879 {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn
1880 {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3},
1881 {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs
1882 {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2},
1883 {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs
1884 {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8},
1885 {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // ushll + ucvtf
1886 {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // sshll + scvtf
1887 {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf
1888 {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf
1889 };
1890
1891 if (ST->hasFullFP16())
1892 if (const auto *Entry = ConvertCostTableLookup(
1893 FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
1894 return AdjustCost(Entry->Cost);
1895
1896 return AdjustCost(
1897 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
1898}
1899
1900InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
1901 Type *Dst,
1902 VectorType *VecTy,
1903 unsigned Index) {
1904
1905 // Make sure we were given a valid extend opcode.
1906 assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
1907        "Invalid opcode");
1908
1909 // We are extending an element we extract from a vector, so the source type
1910 // of the extend is the element type of the vector.
1911 auto *Src = VecTy->getElementType();
1912
1913 // Sign- and zero-extends are for integer types only.
1914 assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
1915
1916 // Get the cost for the extract. We compute the cost (if any) for the extend
1917 // below.
1918 InstructionCost Cost =
1919 getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
1920
1921 // Legalize the types.
1922 auto VecLT = getTypeLegalizationCost(VecTy);
1923 auto DstVT = TLI->getValueType(DL, Dst);
1924 auto SrcVT = TLI->getValueType(DL, Src);
1925 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1926
1927 // If the resulting type is still a vector and the destination type is legal,
1928 // we may get the extension for free. If not, get the default cost for the
1929 // extend.
1930 if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
1931 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1932 CostKind);
1933
1934 // The destination type should be larger than the element type. If not, get
1935 // the default cost for the extend.
1936 if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
1937 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1938 CostKind);
1939
1940 switch (Opcode) {
1941 default:
1942 llvm_unreachable("Opcode should be either SExt or ZExt");
1943
1944 // For sign-extends, we only need a smov, which performs the extension
1945 // automatically.
1946 case Instruction::SExt:
1947 return Cost;
1948
1949 // For zero-extends, the extend is performed automatically by a umov unless
1950 // the destination type is i64 and the element type is i8 or i16.
1951 case Instruction::ZExt:
1952 if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
1953 return Cost;
1954 }
1955
1956 // If we are unable to perform the extend for free, get the default cost.
1957 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1958 CostKind);
1959}
1960
1961InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
1962 TTI::TargetCostKind CostKind,
1963 const Instruction *I) {
1964 if (CostKind != TTI::TCK_RecipThroughput)
1965 return Opcode == Instruction::PHI ? 0 : 1;
1966 assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
1967 // Branches are assumed to be predicted.
1968 return 0;
1969}
1970
1971InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
1972 unsigned Index) {
1973 assert(Val->isVectorTy() && "This must be a vector type");
1974
1975 if (Index != -1U) {
1976 // Legalize the type.
1977 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
1978
1979 // This type is legalized to a scalar type.
1980 if (!LT.second.isVector())
1981 return 0;
1982
1983 // The type may be split. For fixed-width vectors we can normalize the
1984 // index to the new type.
1985 if (LT.second.isFixedLengthVector()) {
1986 unsigned Width = LT.second.getVectorNumElements();
1987 Index = Index % Width;
1988 }
1989
1990 // The element at index zero is already inside the vector.
1991 if (Index == 0)
1992 return 0;
1993 }
1994
1995 // All other insert/extracts cost this much.
1996 return ST->getVectorInsertExtractBaseCost();
1997}
1998
1999InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
2000 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
2001 TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
2002 ArrayRef<const Value *> Args,
2003 const Instruction *CxtI) {
2004
2005 // TODO: Handle more cost kinds.
2006 if (CostKind != TTI::TCK_RecipThroughput)
2007 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2008 Op2Info, Args, CxtI);
2009
2010 // Legalize the type.
2011 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
2012 int ISD = TLI->InstructionOpcodeToISD(Opcode);
2013
2014 switch (ISD) {
2015 default:
2016 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2017 Op2Info);
2018 case ISD::SDIV:
2019 if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
2020 // On AArch64, scalar signed division by a constant power of two is
2021 // normally expanded to the sequence ADD + CMP + SELECT + SRA.
2022 // The OperandValue properties may not be the same as those of the
2023 // previous operation; conservatively assume OP_None.
2024 InstructionCost Cost = getArithmeticInstrCost(
2025 Instruction::Add, Ty, CostKind,
2026 Op1Info.getNoProps(), Op2Info.getNoProps());
2027 Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
2028 Op1Info.getNoProps(), Op2Info.getNoProps());
2029 Cost += getArithmeticInstrCost(
2030 Instruction::Select, Ty, CostKind,
2031 Op1Info.getNoProps(), Op2Info.getNoProps());
2032 Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
2033 Op1Info.getNoProps(), Op2Info.getNoProps());
2034 return Cost;
2035 }
2036 [[fallthrough]];
2037 case ISD::UDIV: {
2038 if (Op2Info.isConstant() && Op2Info.isUniform()) {
2039 auto VT = TLI->getValueType(DL, Ty);
2040 if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
2041 // Vector signed division by a constant is expanded to the
2042 // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
2043 // to MULHS + SUB + SRL + ADD + SRL.
2044 InstructionCost MulCost = getArithmeticInstrCost(
2045 Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2046 InstructionCost AddCost = getArithmeticInstrCost(
2047 Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2048 InstructionCost ShrCost = getArithmeticInstrCost(
2049 Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2050 return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
2051 }
2052 }
2053
2054 InstructionCost Cost = BaseT::getArithmeticInstrCost(
2055 Opcode, Ty, CostKind, Op1Info, Op2Info);
2056 if (Ty->isVectorTy()) {
2057 // On AArch64, vector divisions are not supported natively and are
2058 // expanded into scalar divisions of each pair of elements.
2059 Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, CostKind,
2060 Op1Info, Op2Info);
2061 Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
2062 Op1Info, Op2Info);
2063 // TODO: if one of the arguments is scalar, then it's not necessary to
2064 // double the cost of handling the vector elements.
2065 Cost += Cost;
2066 }
2067 return Cost;
2068 }
2069 case ISD::MUL:
2070 // Since we do not have a MUL.2d instruction, a mul <2 x i64> is expensive
2071 // as elements are extracted from the vectors and the muls scalarized.
2072 // As getScalarizationOverhead is a bit too pessimistic, we estimate the
2073 // cost for a i64 vector directly here, which is:
2074 // - four 2-cost i64 extracts,
2075 // - two 2-cost i64 inserts, and
2076 // - two 1-cost muls.
2077 // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
2078 // LT.first = 2 the cost is 28. If both operands are extensions it will not
2079 // need to scalarize so the cost can be cheaper (smull or umull).
2080 if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
2081 return LT.first;
2082 return LT.first * 14;
2083 case ISD::ADD:
2084 case ISD::XOR:
2085 case ISD::OR:
2086 case ISD::AND:
2087 case ISD::SRL:
2088 case ISD::SRA:
2089 case ISD::SHL:
2090 // These nodes are marked as 'custom' for combining purposes only.
2091 // We know that they are legal. See LowerAdd in ISelLowering.
2092 return LT.first;
2093
2094 case ISD::FADD:
2095 case ISD::FSUB:
2096 case ISD::FMUL:
2097 case ISD::FDIV:
2098 case ISD::FNEG:
2099 // These nodes are marked as 'custom' just to lower them to SVE.
2100 // We know said lowering will incur no additional cost.
2101 if (!Ty->getScalarType()->isFP128Ty())
2102 return 2 * LT.first;
2103
2104 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2105 Op2Info);
2106 }
2107}
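The v2i64 MUL estimate in the comment above works out as follows, using the per-element costs the comment assumes (extract 2, insert 2, scalar mul 1):

    // Worked check of the scalarized v2i64 multiply estimate.
    constexpr int ExtractCost = 2, InsertCost = 2, ScalarMulCost = 1;
    constexpr int V2I64Cost = 4 * ExtractCost + 2 * InsertCost + 2 * ScalarMulCost;
    static_assert(V2I64Cost == 14, "matches 'return LT.first * 14' with LT.first == 1");
    static_assert(2 * V2I64Cost == 28, "v4i64 splits into two v2i64, LT.first == 2");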
2108
2109InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
2110 ScalarEvolution *SE,
2111 const SCEV *Ptr) {
2112 // Address computations in vectorized code with non-consecutive addresses will
2113 // likely result in more instructions compared to scalar code where the
2114 // computation can more often be merged into the index mode. The resulting
2115 // extra micro-ops can significantly decrease throughput.
2116 unsigned NumVectorInstToHideOverhead = 10;
2117 int MaxMergeDistance = 64;
2118
2119 if (Ty->isVectorTy() && SE &&
2120 !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
2121 return NumVectorInstToHideOverhead;
2122
2123 // In many cases the address computation is not merged into the instruction
2124 // addressing mode.
2125 return 1;
2126}
2127
2128InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
2129 Type *CondTy,
2130 CmpInst::Predicate VecPred,
2131 TTI::TargetCostKind CostKind,
2132 const Instruction *I) {
2133 // TODO: Handle other cost kinds.
2134 if (CostKind != TTI::TCK_RecipThroughput)
2135 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
2136 I);
2137
2138 int ISD = TLI->InstructionOpcodeToISD(Opcode);
2139 // We don't lower some vector selects well when they are wider than the
2140 // register width.
2141 if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
2142 // We would need this many instructions to hide the scalarization happening.
2143 const int AmortizationCost = 20;
2144
2145 // If VecPred is not set, check if we can get a predicate from the context
2146 // instruction, if its type matches the requested ValTy.
2147 if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
2148 CmpInst::Predicate CurrentPred;
2149 if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
2150 m_Value())))
2151 VecPred = CurrentPred;
2152 }
2153 // Check if we have a compare/select chain that can be lowered using
2154 // a (F)CMxx & BFI pair.
2155 if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
2156 VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
2157 VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
2158 VecPred == CmpInst::FCMP_UNE) {
2159 static const auto ValidMinMaxTys = {
2160 MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
2161 MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
2162 static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
2163
2164 auto LT = getTypeLegalizationCost(ValTy);
2165 if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }) ||
2166 (ST->hasFullFP16() &&
2167 any_of(ValidFP16MinMaxTys, [&LT](MVT M) { return M == LT.second; })))
2168 return LT.first;
2169 }
2170
2171 static const TypeConversionCostTblEntry
2172 VectorSelectTbl[] = {
2173 { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
2174 { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
2175 { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
2176 { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
2177 { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
2178 { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
2179 };
2180
2181 EVT SelCondTy = TLI->getValueType(DL, CondTy);
2182 EVT SelValTy = TLI->getValueType(DL, ValTy);
2183 if (SelCondTy.isSimple() && SelValTy.isSimple()) {
2184 if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
2185 SelCondTy.getSimpleVT(),
2186 SelValTy.getSimpleVT()))
2187 return Entry->Cost;
2188 }
2189 }
2190 // The base case handles scalable vectors fine for now, since it treats the
2191 // cost as 1 * legalization cost.
2192 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
2193}
2194
2195AArch64TTIImpl::TTI::MemCmpExpansionOptions
2196AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
2197 TTI::MemCmpExpansionOptions Options;
2198 if (ST->requiresStrictAlign()) {
2199 // TODO: Add cost modeling for strict align. Misaligned loads expand to
2200 // a bunch of instructions when strict align is enabled.
2201 return Options;
2202 }
2203 Options.AllowOverlappingLoads = true;
2204 Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
2205 Options.NumLoadsPerBlock = Options.MaxNumLoads;
2206 // TODO: Though vector loads usually perform well on AArch64, in some targets
2207 // they may wake up the FP unit, which raises the power consumption. Perhaps
2208 // they could be used with no holds barred (-O3).
2209 Options.LoadSizes = {8, 4, 2, 1};
2210 return Options;
2211}
2212
2213bool AArch64TTIImpl::prefersVectorizedAddressing() const {
2214 return ST->hasSVE();
2215}
2216
2217InstructionCost
2218AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
2219 Align Alignment, unsigned AddressSpace,
2220 TTI::TargetCostKind CostKind) {
2221 if (useNeonVector(Src))
2222 return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
2223 CostKind);
2224 auto LT = getTypeLegalizationCost(Src);
2225 if (!LT.first.isValid())
2226 return InstructionCost::getInvalid();
2227
2228 // The code-generator is currently not able to handle scalable vectors
2229 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
2230 // it. This change will be removed when code-generation for these types is
2231 // sufficiently reliable.
2232 if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
2233 return InstructionCost::getInvalid();
2234
2235 return LT.first * 2;
2236}
2237
2238static unsigned getSVEGatherScatterOverhead(unsigned Opcode) {
2239 return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead;
2240}
2241
2242InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
2243 unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
2244 Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
2245 if (useNeonVector(DataTy))
2246 return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
2247 Alignment, CostKind, I);
2248 auto *VT = cast<VectorType>(DataTy);
2249 auto LT = getTypeLegalizationCost(DataTy);
2250 if (!LT.first.isValid())
2251 return InstructionCost::getInvalid();
2252
2253 // The code-generator is currently not able to handle scalable vectors
2254 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
2255 // it. This change will be removed when code-generation for these types is
2256 // sufficiently reliable.
2257 if (cast<VectorType>(DataTy)->getElementCount() ==
2258 ElementCount::getScalable(1))
2259 return InstructionCost::getInvalid();
2260
2261 ElementCount LegalVF = LT.second.getVectorElementCount();
2262 InstructionCost MemOpCost =
2263 getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind,
2264 {TTI::OK_AnyValue, TTI::OP_None}, I);
2265 // Add on an overhead cost for using gathers/scatters.
2266 // TODO: At the moment this is applied unilaterally for all CPUs, but at some
2267 // point we may want a per-CPU overhead.
2268 MemOpCost *= getSVEGatherScatterOverhead(Opcode);
2269 return LT.first * MemOpCost * getMaxNumElements(LegalVF);
2270}
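To make the formula at the end of getGatherScatterOpCost concrete, here is the arithmetic under purely assumed inputs: a per-element memory op cost of 1, a gather overhead of 10, and a legal VF capped at 4 elements. None of these numbers come from the file under analysis; they only show how the three factors multiply.

    // Hypothetical inputs; the real values depend on the subtarget and flags.
    constexpr int LTFirst = 1, MemOpCost = 1, GatherOverhead = 10, MaxNumElts = 4;
    constexpr int GatherCost = LTFirst * (MemOpCost * GatherOverhead) * MaxNumElts;
    static_assert(GatherCost == 40, "cost scales with both overhead and element count");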
2271
2272bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
2273 return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
2274}
2275
2276InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
2277 MaybeAlign Alignment,
2278 unsigned AddressSpace,
2279 TTI::TargetCostKind CostKind,
2280 TTI::OperandValueInfo OpInfo,
2281 const Instruction *I) {
2282 EVT VT = TLI->getValueType(DL, Ty, true);
2283 // Type legalization can't handle structs
2284 if (VT == MVT::Other)
2285 return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
2286 CostKind);
2287
2288 auto LT = getTypeLegalizationCost(Ty);
2289 if (!LT.first.isValid())
2290 return InstructionCost::getInvalid();
2291
2292 // The code-generator is currently not able to handle scalable vectors
2293 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
2294 // it. This change will be removed when code-generation for these types is
2295 // sufficiently reliable.
2296 if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
2297 if (VTy->getElementCount() == ElementCount::getScalable(1))
2298 return InstructionCost::getInvalid();
2299
2300 // TODO: consider latency as well for TCK_SizeAndLatency.
2301 if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
2302 return LT.first;
2303
2304 if (CostKind != TTI::TCK_RecipThroughput)
2305 return 1;
2306
2307 if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
2308 LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
2309 // Unaligned stores are extremely inefficient. We don't split all
2310 // unaligned 128-bit stores because of the negative impact that has been
2311 // shown in practice on inlined block copy code.
2312 // We make such stores expensive so that we will only vectorize if there
2313 // are 6 other instructions getting vectorized.
2314 const int AmortizationCost = 6;
2315
2316 return LT.first * 2 * AmortizationCost;
2317 }
2318
2319 // Check truncating stores and extending loads.
2320 if (useNeonVector(Ty) &&
2321 Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
2322 // v4i8 types are lowered to a scalar load/store and sshll/xtn.
2323 if (VT == MVT::v4i8)
2324 return 2;
2325 // Otherwise we need to scalarize.
2326 return cast<FixedVectorType>(Ty)->getNumElements() * 2;
2327 }
2328
2329 return LT.first;
2330}
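A quick worked example of the misaligned 128-bit store branch above, assuming LT.first == 1: the store is charged 1 * 2 * AmortizationCost = 12, which is why the comment says roughly six other vectorized instructions are needed before vectorization pays off.

    constexpr int LTFirst = 1, AmortizationCost = 6;
    static_assert(LTFirst * 2 * AmortizationCost == 12,
                  "misaligned 128-bit store made deliberately expensive");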
2331
2332InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
2333 unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
2334 Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
2335 bool UseMaskForCond, bool UseMaskForGaps) {
2336 assert(Factor >= 2 && "Invalid interleave factor");
2337 auto *VecVTy = cast<FixedVectorType>(VecTy);
2338
2339 if (!UseMaskForCond && !UseMaskForGaps &&
2340 Factor <= TLI->getMaxSupportedInterleaveFactor()) {
2341 unsigned NumElts = VecVTy->getNumElements();
2342 auto *SubVecTy =
2343 FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor);
2344
2345 // ldN/stN only support legal vector types of size 64 or 128 in bits.
2346 // Accesses having vector types that are a multiple of 128 bits can be
2347 // matched to more than one ldN/stN instruction.
2348 bool UseScalable;
2349 if (NumElts % Factor == 0 &&
2350 TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
2351 return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
2352 }
2353
2354 return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
2355 Alignment, AddressSpace, CostKind,
2356 UseMaskForCond, UseMaskForGaps);
2357}
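For the ldN/stN path above, a hypothetical interleaved load of 16 x i32 with Factor == 4 yields a 4 x i32 sub-vector (128 bits), which a single ld4 covers, so the returned cost is Factor * 1 == 4. The numbers below are assumptions for illustration only:

    constexpr unsigned Factor = 4, NumElts = 16, EltBits = 32;
    constexpr unsigned SubVecElts = NumElts / Factor;
    static_assert(SubVecElts * EltBits == 128, "legal sub-vector size for ld4/st4");
    constexpr unsigned NumInterleavedAccesses = 1; // assumed: one ld4 per group
    static_assert(Factor * NumInterleavedAccesses == 4, "cost returned above");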
2358
2359InstructionCost
2360AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
2361 InstructionCost Cost = 0;
2362 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2363 for (auto *I : Tys) {
2364 if (!I->isVectorTy())
2365 continue;
2366 if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
2367 128)
2368 Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
2369 getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
2370 }
2371 return Cost;
2372}
2373
2374unsigned AArch64TTIImpl::getMaxInterleaveFactor(unsigned VF) {
2375 return ST->getMaxInterleaveFactor();
2376}
2377
2378// For Falkor, we want to avoid having too many strided loads in a loop since
2379// that can exhaust the HW prefetcher resources. We adjust the unroller
2380// MaxCount preference below to attempt to ensure unrolling doesn't create too
2381// many strided loads.
2382static void
2383getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
2384 TargetTransformInfo::UnrollingPreferences &UP) {
2385 enum { MaxStridedLoads = 7 };
2386 auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
2387 int StridedLoads = 0;
2388 // FIXME? We could make this more precise by looking at the CFG and
2389 // e.g. not counting loads in each side of an if-then-else diamond.
2390 for (const auto BB : L->blocks()) {
2391 for (auto &I : *BB) {
2392 LoadInst *LMemI = dyn_cast<LoadInst>(&I);
2393 if (!LMemI)
2394 continue;
2395
2396 Value *PtrValue = LMemI->getPointerOperand();
2397 if (L->isLoopInvariant(PtrValue))
2398 continue;
2399
2400 const SCEV *LSCEV = SE.getSCEV(PtrValue);
2401 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
2402 if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
2403 continue;
2404
2405 // FIXME? We could take pairing of unrolled load copies into account
2406 // by looking at the AddRec, but we would probably have to limit this
2407 // to loops with no stores or other memory optimization barriers.
2408 ++StridedLoads;
2409 // We've seen enough strided loads that seeing more won't make a
2410 // difference.
2411 if (StridedLoads > MaxStridedLoads / 2)
2412 return StridedLoads;
2413 }
2414 }
2415 return StridedLoads;
2416 };
2417
2418 int StridedLoads = countStridedLoads(L, SE);
2419 LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
7
Assuming 'DebugFlag' is false
8
Loop condition is false. Exiting loop
2420 << " strided loads\n");
2421 // Pick the largest power of 2 unroll count that won't result in too many
2422 // strided loads.
2423 if (StridedLoads) {
9
Assuming 'StridedLoads' is not equal to 0
10
Taking true branch
2424 UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
11
Calling 'Log2_32'
13
Returning from 'Log2_32'
14
The result of the left shift is undefined due to shifting by '4294967295', which is greater or equal to the width of type 'int'
2425 LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
2426 << UP.MaxCount << '\n');
2427 }
2428}
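To make the diagnostic above concrete: Log2_32 is defined in llvm/Support/MathExtras.h as 31 - countLeadingZeros(Value), and countLeadingZeros(0) is 32, so Log2_32(0) wraps to 4294967295 in unsigned arithmetic. If StridedLoads could ever exceed MaxStridedLoads, the quotient passed to Log2_32 would be 0 and the shift count at line 2424 would exceed the width of int, which is undefined behaviour. A standalone model of that path (not LLVM code; the StridedLoads value is a hypothetical one the analyzer assumes possible):

    #include <cassert>
    #include <cstdint>
    #include <limits>

    // Model of Log2_32: 31 - countLeadingZeros(Value), with clz(0) == 32.
    static unsigned log2_32_model(uint32_t Value) {
      unsigned LeadingZeros = 0;
      for (uint32_t Mask = UINT32_C(1) << 31; Mask != 0 && (Value & Mask) == 0;
           Mask >>= 1)
        ++LeadingZeros;
      return 31u - LeadingZeros; // wraps to UINT_MAX when Value == 0
    }

    int main() {
      const unsigned MaxStridedLoads = 7;
      const unsigned StridedLoads = 8; // hypothetical: larger than MaxStridedLoads
      unsigned Shift = log2_32_model(MaxStridedLoads / StridedLoads); // quotient 0
      assert(Shift == std::numeric_limits<unsigned>::max());
      // Evaluating '1 << Shift' here would shift by a count >= the width of int,
      // which is exactly the undefined behaviour the analyzer reports.
      return 0;
    }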
2429
2430void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
2431 TTI::UnrollingPreferences &UP,
2432 OptimizationRemarkEmitter *ORE) {
2433 // Enable partial unrolling and runtime unrolling.
2434 BaseT::getUnrollingPreferences(L, SE, UP, ORE);
2435
2436 UP.UpperBound = true;
2437
2438 // An inner loop is more likely to be hot, and the runtime check can be
2439 // promoted out of the loop by the LICM pass, so the overhead is lower;
2440 // try a larger threshold to unroll more loops.
2441 if (L->getLoopDepth() > 1)
1
Assuming the condition is false
2
Taking false branch
2442 UP.PartialThreshold *= 2;
2443
2444 // Disable partial & runtime unrolling on -Os.
2445 UP.PartialOptSizeThreshold = 0;
2446
2447 if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
3
Assuming the condition is true
5
Taking true branch
2448 EnableFalkorHWPFUnrollFix)
4
Assuming the condition is true
2449 getFalkorUnrollingPreferences(L, SE, UP);
6
Calling 'getFalkorUnrollingPreferences'
2450
2451 // Scan the loop: don't unroll loops with calls as this could prevent
2452 // inlining. Don't unroll vector loops either, as they don't benefit much from
2453 // unrolling.
2454 for (auto *BB : L->getBlocks()) {
2455 for (auto &I : *BB) {
2456 // Don't unroll vectorised loop.
2457 if (I.getType()->isVectorTy())
2458 return;
2459
2460 if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
2461 if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
2462 if (!isLoweredToCall(F))
2463 continue;
2464 }
2465 return;
2466 }
2467 }
2468 }
2469
2470 // Enable runtime unrolling for in-order models
2471 // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
2472 // checking for that case, we can ensure that the default behaviour is
2473 // unchanged
2474 if (ST->getProcFamily() != AArch64Subtarget::Others &&
2475 !ST->getSchedModel().isOutOfOrder()) {
2476 UP.Runtime = true;
2477 UP.Partial = true;
2478 UP.UnrollRemainder = true;
2479 UP.DefaultUnrollRuntimeCount = 4;
2480
2481 UP.UnrollAndJam = true;
2482 UP.UnrollAndJamInnerLoopThreshold = 60;
2483 }
2484}
2485
2486void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
2487 TTI::PeelingPreferences &PP) {
2488 BaseT::getPeelingPreferences(L, SE, PP);
2489}
2490
2491Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
2492 Type *ExpectedType) {
2493 switch (Inst->getIntrinsicID()) {
2494 default:
2495 return nullptr;
2496 case Intrinsic::aarch64_neon_st2:
2497 case Intrinsic::aarch64_neon_st3:
2498 case Intrinsic::aarch64_neon_st4: {
2499 // Create a struct type
2500 StructType *ST = dyn_cast<StructType>(ExpectedType);
2501 if (!ST)
2502 return nullptr;
2503 unsigned NumElts = Inst->arg_size() - 1;
2504 if (ST->getNumElements() != NumElts)
2505 return nullptr;
2506 for (unsigned i = 0, e = NumElts; i != e; ++i) {
2507 if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
2508 return nullptr;
2509 }
2510 Value *Res = UndefValue::get(ExpectedType);
2511 IRBuilder<> Builder(Inst);
2512 for (unsigned i = 0, e = NumElts; i != e; ++i) {
2513 Value *L = Inst->getArgOperand(i);
2514 Res = Builder.CreateInsertValue(Res, L, i);
2515 }
2516 return Res;
2517 }
2518 case Intrinsic::aarch64_neon_ld2:
2519 case Intrinsic::aarch64_neon_ld3:
2520 case Intrinsic::aarch64_neon_ld4:
2521 if (Inst->getType() == ExpectedType)
2522 return Inst;
2523 return nullptr;
2524 }
2525}
2526
2527bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
2528 MemIntrinsicInfo &Info) {
2529 switch (Inst->getIntrinsicID()) {
2530 default:
2531 break;
2532 case Intrinsic::aarch64_neon_ld2:
2533 case Intrinsic::aarch64_neon_ld3:
2534 case Intrinsic::aarch64_neon_ld4:
2535 Info.ReadMem = true;
2536 Info.WriteMem = false;
2537 Info.PtrVal = Inst->getArgOperand(0);
2538 break;
2539 case Intrinsic::aarch64_neon_st2:
2540 case Intrinsic::aarch64_neon_st3:
2541 case Intrinsic::aarch64_neon_st4:
2542 Info.ReadMem = false;
2543 Info.WriteMem = true;
2544 Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
2545 break;
2546 }
2547
2548 switch (Inst->getIntrinsicID()) {
2549 default:
2550 return false;
2551 case Intrinsic::aarch64_neon_ld2:
2552 case Intrinsic::aarch64_neon_st2:
2553 Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
2554 break;
2555 case Intrinsic::aarch64_neon_ld3:
2556 case Intrinsic::aarch64_neon_st3:
2557 Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
2558 break;
2559 case Intrinsic::aarch64_neon_ld4:
2560 case Intrinsic::aarch64_neon_st4:
2561 Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
2562 break;
2563 }
2564 return true;
2565}
2566
2567/// See if \p I should be considered for address type promotion. We check if \p
2568 /// I is a sext with the right type and used in memory accesses. If it is used in a
2569/// "complex" getelementptr, we allow it to be promoted without finding other
2570/// sext instructions that sign extended the same initial value. A getelementptr
2571/// is considered as "complex" if it has more than 2 operands.
2572bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
2573 const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
2574 bool Considerable = false;
2575 AllowPromotionWithoutCommonHeader = false;
2576 if (!isa<SExtInst>(&I))
2577 return false;
2578 Type *ConsideredSExtType =
2579 Type::getInt64Ty(I.getParent()->getParent()->getContext());
2580 if (I.getType() != ConsideredSExtType)
2581 return false;
2582 // See if the sext is the one with the right type and used in at least one
2583 // GetElementPtrInst.
2584 for (const User *U : I.users()) {
2585 if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
2586 Considerable = true;
2587 // A getelementptr is considered as "complex" if it has more than 2
2588 // operands. We will promote a SExt used in such complex GEP as we
2589 // expect some computation to be merged if they are done on 64 bits.
2590 if (GEPInst->getNumOperands() > 2) {
2591 AllowPromotionWithoutCommonHeader = true;
2592 break;
2593 }
2594 }
2595 }
2596 return Considerable;
2597}
2598
2599bool AArch64TTIImpl::isLegalToVectorizeReduction(
2600 const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
2601 if (!VF.isScalable())
2602 return true;
2603
2604 Type *Ty = RdxDesc.getRecurrenceType();
2605 if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
2606 return false;
2607
2608 switch (RdxDesc.getRecurrenceKind()) {
2609 case RecurKind::Add:
2610 case RecurKind::FAdd:
2611 case RecurKind::And:
2612 case RecurKind::Or:
2613 case RecurKind::Xor:
2614 case RecurKind::SMin:
2615 case RecurKind::SMax:
2616 case RecurKind::UMin:
2617 case RecurKind::UMax:
2618 case RecurKind::FMin:
2619 case RecurKind::FMax:
2620 case RecurKind::SelectICmp:
2621 case RecurKind::SelectFCmp:
2622 case RecurKind::FMulAdd:
2623 return true;
2624 default:
2625 return false;
2626 }
2627}
2628
2629InstructionCost
2630AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
2631 bool IsUnsigned,
2632 TTI::TargetCostKind CostKind) {
2633 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
2634
2635 if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
2636 return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
2637
2638 assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
2639        "Both vector needs to be equally scalable");
2640
2641 InstructionCost LegalizationCost = 0;
2642 if (LT.first > 1) {
2643 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
2644 unsigned MinMaxOpcode =
2645 Ty->isFPOrFPVectorTy()
2646 ? Intrinsic::maxnum
2647 : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
2648 IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
2649 LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
2650 }
2651
2652 return LegalizationCost + /*Cost of horizontal reduction*/ 2;
2653}
2654
2655InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
2656 unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
2657 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
2658 InstructionCost LegalizationCost = 0;
2659 if (LT.first > 1) {
2660 Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
2661 LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
2662 LegalizationCost *= LT.first - 1;
2663 }
2664
2665 int ISD = TLI->InstructionOpcodeToISD(Opcode);
2666 assert(ISD && "Invalid opcode");
2667 // Add the final reduction cost for the legal horizontal reduction
2668 switch (ISD) {
2669 case ISD::ADD:
2670 case ISD::AND:
2671 case ISD::OR:
2672 case ISD::XOR:
2673 case ISD::FADD:
2674 return LegalizationCost + 2;
2675 default:
2676 return InstructionCost::getInvalid();
2677 }
2678}
2679
2680InstructionCost
2681AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
2682 Optional<FastMathFlags> FMF,
2683 TTI::TargetCostKind CostKind) {
2684 if (TTI::requiresOrderedReduction(FMF)) {
2685 if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
2686 InstructionCost BaseCost =
2687 BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
2688 // Add on extra cost to reflect the extra overhead on some CPUs. We still
2689 // end up vectorizing for more computationally intensive loops.
2690 return BaseCost + FixedVTy->getNumElements();
2691 }
2692
2693 if (Opcode != Instruction::FAdd)
2694 return InstructionCost::getInvalid();
2695
2696 auto *VTy = cast<ScalableVectorType>(ValTy);
2697 InstructionCost Cost =
2698 getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
2699 Cost *= getMaxNumElements(VTy->getElementCount());
2700 return Cost;
2701 }
2702
2703 if (isa<ScalableVectorType>(ValTy))
2704 return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
2705
2706 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
2707 MVT MTy = LT.second;
2708 int ISD = TLI->InstructionOpcodeToISD(Opcode);
2709 assert(ISD && "Invalid opcode");
2710
2711 // Horizontal adds can use the 'addv' instruction. We model the cost of these
2712 // instructions as twice a normal vector add, plus 1 for each legalization
2713 // step (LT.first). This is the only arithmetic vector reduction operation for
2714 // which we have an instruction.
2715 // OR, XOR and AND costs should match the codegen from:
2716 // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
2717 // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
2718 // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
2719 static const CostTblEntry CostTblNoPairwise[]{
2720 {ISD::ADD, MVT::v8i8, 2},
2721 {ISD::ADD, MVT::v16i8, 2},
2722 {ISD::ADD, MVT::v4i16, 2},
2723 {ISD::ADD, MVT::v8i16, 2},
2724 {ISD::ADD, MVT::v4i32, 2},
2725 {ISD::ADD, MVT::v2i64, 2},
2726 {ISD::OR, MVT::v8i8, 15},
2727 {ISD::OR, MVT::v16i8, 17},
2728 {ISD::OR, MVT::v4i16, 7},
2729 {ISD::OR, MVT::v8i16, 9},
2730 {ISD::OR, MVT::v2i32, 3},
2731 {ISD::OR, MVT::v4i32, 5},
2732 {ISD::OR, MVT::v2i64, 3},
2733 {ISD::XOR, MVT::v8i8, 15},
2734 {ISD::XOR, MVT::v16i8, 17},
2735 {ISD::XOR, MVT::v4i16, 7},
2736 {ISD::XOR, MVT::v8i16, 9},
2737 {ISD::XOR, MVT::v2i32, 3},
2738 {ISD::XOR, MVT::v4i32, 5},
2739 {ISD::XOR, MVT::v2i64, 3},
2740 {ISD::AND, MVT::v8i8, 15},
2741 {ISD::AND, MVT::v16i8, 17},
2742 {ISD::AND, MVT::v4i16, 7},
2743 {ISD::AND, MVT::v8i16, 9},
2744 {ISD::AND, MVT::v2i32, 3},
2745 {ISD::AND, MVT::v4i32, 5},
2746 {ISD::AND, MVT::v2i64, 3},
2747 };
2748 switch (ISD) {
2749 default:
2750 break;
2751 case ISD::ADD:
2752 if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
2753 return (LT.first - 1) + Entry->Cost;
2754 break;
2755 case ISD::XOR:
2756 case ISD::AND:
2757 case ISD::OR:
2758 const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
2759 if (!Entry)
2760 break;
2761 auto *ValVTy = cast<FixedVectorType>(ValTy);
2762 if (!ValVTy->getElementType()->isIntegerTy(1) &&
2763 MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
2764 isPowerOf2_32(ValVTy->getNumElements())) {
2765 InstructionCost ExtraCost = 0;
2766 if (LT.first != 1) {
2767 // Type needs to be split, so there is an extra cost of LT.first - 1
2768 // arithmetic ops.
2769 auto *Ty = FixedVectorType::get(ValTy->getElementType(),
2770 MTy.getVectorNumElements());
2771 ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
2772 ExtraCost *= LT.first - 1;
2773 }
2774 return Entry->Cost + ExtraCost;
2775 }
2776 break;
2777 }
2778 return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
2779}
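
A similar worked example for the fixed-width table lookup above (an editorial illustration; the legalization of v8i32 into two v4i32 halves is assumed from the 128-bit NEON register width):

// Reducing an integer add over v8i32: the type splits into two v4i32 parts,
// so LT.first == 2 and the table entry {ISD::ADD, MVT::v4i32, 2} applies:
//   Cost = (LT.first - 1) + Entry->Cost = 1 + 2 = 3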
2780
2781InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
2782 static const CostTblEntry ShuffleTbl[] = {
2783 { TTI::SK_Splice, MVT::nxv16i8, 1 },
2784 { TTI::SK_Splice, MVT::nxv8i16, 1 },
2785 { TTI::SK_Splice, MVT::nxv4i32, 1 },
2786 { TTI::SK_Splice, MVT::nxv2i64, 1 },
2787 { TTI::SK_Splice, MVT::nxv2f16, 1 },
2788 { TTI::SK_Splice, MVT::nxv4f16, 1 },
2789 { TTI::SK_Splice, MVT::nxv8f16, 1 },
2790 { TTI::SK_Splice, MVT::nxv2bf16, 1 },
2791 { TTI::SK_Splice, MVT::nxv4bf16, 1 },
2792 { TTI::SK_Splice, MVT::nxv8bf16, 1 },
2793 { TTI::SK_Splice, MVT::nxv2f32, 1 },
2794 { TTI::SK_Splice, MVT::nxv4f32, 1 },
2795 { TTI::SK_Splice, MVT::nxv2f64, 1 },
2796 };
2797
2798 // The code-generator is currently not able to handle scalable vectors
2799 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
2800 // it. This change will be removed when code-generation for these types is
2801 // sufficiently reliable.
2802 if (Tp->getElementCount() == ElementCount::getScalable(1))
2803 return InstructionCost::getInvalid();
2804
2805 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
2806 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
2807 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2808 EVT PromotedVT = LT.second.getScalarType() == MVT::i1
2809 ? TLI->getPromotedVTForPredicate(EVT(LT.second))
2810 : LT.second;
2811 Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
2812 InstructionCost LegalizationCost = 0;
2813 if (Index < 0) {
2814 LegalizationCost =
2815 getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
2816 CmpInst::BAD_ICMP_PREDICATE, CostKind) +
2817 getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
2818 CmpInst::BAD_ICMP_PREDICATE, CostKind);
2819 }
2820
2821 // Predicated splices are promoted when lowering. See AArch64ISelLowering.cpp.
2822 // The cost is computed on the promoted type.
2823 if (LT.second.getScalarType() == MVT::i1) {
2824 LegalizationCost +=
2825 getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
2826 TTI::CastContextHint::None, CostKind) +
2827 getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
2828 TTI::CastContextHint::None, CostKind);
2829 }
2830 const auto *Entry =
2831 CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
2832 assert(Entry && "Illegal Type for Splice");
2833 LegalizationCost += Entry->Cost;
2834 return LegalizationCost * LT.first;
2835}
2836
2837InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
2838 VectorType *Tp,
2839 ArrayRef<int> Mask,
2840 TTI::TargetCostKind CostKind,
2841 int Index, VectorType *SubTp,
2842 ArrayRef<const Value *> Args) {
2843 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
2844 // If we have a Mask, and the LT is being legalized somehow, split the Mask
2845 // into smaller vectors and sum the cost of each shuffle.
2846 if (!Mask.empty() && isa<FixedVectorType>(Tp) && LT.second.isVector() &&
2847 Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
2848 cast<FixedVectorType>(Tp)->getNumElements() >
2849 LT.second.getVectorNumElements() &&
2850 !Index && !SubTp) {
2851 unsigned TpNumElts = cast<FixedVectorType>(Tp)->getNumElements();
2852 assert(Mask.size() == TpNumElts && "Expected Mask and Tp size to match!");
2853 unsigned LTNumElts = LT.second.getVectorNumElements();
2854 unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts;
2855 VectorType *NTp =
2856 VectorType::get(Tp->getScalarType(), LT.second.getVectorElementCount());
2857 InstructionCost Cost;
2858 for (unsigned N = 0; N < NumVecs; N++) {
2859 SmallVector<int> NMask;
2860 // Split the existing mask into chunks of size LTNumElts. Track the source
2861 // sub-vectors to ensure the result has at most 2 inputs.
2862 unsigned Source1, Source2;
2863 unsigned NumSources = 0;
2864 for (unsigned E = 0; E < LTNumElts; E++) {
2865 int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E]
2866 : UndefMaskElem;
2867 if (MaskElt < 0) {
2868 NMask.push_back(UndefMaskElem);
2869 continue;
2870 }
2871
2872 // Calculate which source from the input this comes from and whether it
2873 // is new to us.
2874 unsigned Source = MaskElt / LTNumElts;
2875 if (NumSources == 0) {
2876 Source1 = Source;
2877 NumSources = 1;
2878 } else if (NumSources == 1 && Source != Source1) {
2879 Source2 = Source;
2880 NumSources = 2;
2881 } else if (NumSources >= 2 && Source != Source1 && Source != Source2) {
2882 NumSources++;
2883 }
2884
2885 // Add to the new mask. For the NumSources>2 case these are not correct,
2886 // but are only used for the modular lane number.
2887 if (Source == Source1)
2888 NMask.push_back(MaskElt % LTNumElts);
2889 else if (Source == Source2)
2890 NMask.push_back(MaskElt % LTNumElts + LTNumElts);
2891 else
2892 NMask.push_back(MaskElt % LTNumElts);
2893 }
2894 // If the sub-mask has at most 2 input sub-vectors then re-cost it using
2895 // getShuffleCost. If not then cost it using the worst case.
2896 if (NumSources <= 2)
2897 Cost += getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc
2898 : TTI::SK_PermuteTwoSrc,
2899 NTp, NMask, CostKind, 0, nullptr, Args);
2900 else if (any_of(enumerate(NMask), [&](const auto &ME) {
2901 return ME.value() % LTNumElts == ME.index();
2902 }))
2903 Cost += LTNumElts - 1;
2904 else
2905 Cost += LTNumElts;
2906 }
2907 return Cost;
2908 }
2909
2910 Kind = improveShuffleKindFromMask(Kind, Mask);
2911
2912 // Check for broadcast loads.
2913 if (Kind == TTI::SK_Broadcast) {
2914 bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]);
2915 if (IsLoad && LT.second.isVector() &&
2916 isLegalBroadcastLoad(Tp->getElementType(),
2917 LT.second.getVectorElementCount()))
2918 return 0; // broadcast is handled by ld1r
2919 }
2920
2921 // If we have 4 elements for the shuffle and a Mask, get the cost straight
2922 // from the perfect shuffle tables.
2923 if (Mask.size() == 4 && Tp->getElementCount() == ElementCount::getFixed(4) &&
2924 (Tp->getScalarSizeInBits() == 16 || Tp->getScalarSizeInBits() == 32) &&
2925 all_of(Mask, [](int E) { return E < 8; }))
2926 return getPerfectShuffleCost(Mask);
2927
2928 if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
2929 Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
2930 Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) {
2931 static const CostTblEntry ShuffleTbl[] = {
2932 // Broadcast shuffle kinds can be performed with 'dup'.
2933 { TTI::SK_Broadcast, MVT::v8i8, 1 },
2934 { TTI::SK_Broadcast, MVT::v16i8, 1 },
2935 { TTI::SK_Broadcast, MVT::v4i16, 1 },
2936 { TTI::SK_Broadcast, MVT::v8i16, 1 },
2937 { TTI::SK_Broadcast, MVT::v2i32, 1 },
2938 { TTI::SK_Broadcast, MVT::v4i32, 1 },
2939 { TTI::SK_Broadcast, MVT::v2i64, 1 },
2940 { TTI::SK_Broadcast, MVT::v2f32, 1 },
2941 { TTI::SK_Broadcast, MVT::v4f32, 1 },
2942 { TTI::SK_Broadcast, MVT::v2f64, 1 },
2943 // Transpose shuffle kinds can be performed with 'trn1/trn2' and
2944 // 'zip1/zip2' instructions.
2945 { TTI::SK_Transpose, MVT::v8i8, 1 },
2946 { TTI::SK_Transpose, MVT::v16i8, 1 },
2947 { TTI::SK_Transpose, MVT::v4i16, 1 },
2948 { TTI::SK_Transpose, MVT::v8i16, 1 },
2949 { TTI::SK_Transpose, MVT::v2i32, 1 },
2950 { TTI::SK_Transpose, MVT::v4i32, 1 },
2951 { TTI::SK_Transpose, MVT::v2i64, 1 },
2952 { TTI::SK_Transpose, MVT::v2f32, 1 },
2953 { TTI::SK_Transpose, MVT::v4f32, 1 },
2954 { TTI::SK_Transpose, MVT::v2f64, 1 },
2955 // Select shuffle kinds.
2956 // TODO: handle vXi8/vXi16.
2957 { TTI::SK_Select, MVT::v2i32, 1 }, // mov.
2958 { TTI::SK_Select, MVT::v4i32, 2 }, // rev+trn (or similar).
2959 { TTI::SK_Select, MVT::v2i64, 1 }, // mov.
2960 { TTI::SK_Select, MVT::v2f32, 1 }, // mov.
2961 { TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
2962 { TTI::SK_Select, MVT::v2f64, 1 }, // mov.
2963 // PermuteSingleSrc shuffle kinds.
2964 { TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
2965 { TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
2966 { TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
2967 { TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
2968 { TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
2969 { TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
2970 { TTI::SK_PermuteSingleSrc, MVT::v4i16, 3 }, // perfectshuffle worst case.
2971 { TTI::SK_PermuteSingleSrc, MVT::v4f16, 3 }, // perfectshuffle worst case.
2972 { TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3 }, // perfectshuffle worst case.
2973 { TTI::SK_PermuteSingleSrc, MVT::v8i16, 8 }, // constpool + load + tbl
2974 { TTI::SK_PermuteSingleSrc, MVT::v8f16, 8 }, // constpool + load + tbl
2975 { TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8 }, // constpool + load + tbl
2976 { TTI::SK_PermuteSingleSrc, MVT::v8i8, 8 }, // constpool + load + tbl
2977 { TTI::SK_PermuteSingleSrc, MVT::v16i8, 8 }, // constpool + load + tbl
2978 // Reverse can be lowered with `rev`.
2979 { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
2980 { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
2981 { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
2982 { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
2983 { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
2984 { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
2985 { TTI::SK_Reverse, MVT::v8f16, 2 }, // REV64; EXT
2986 { TTI::SK_Reverse, MVT::v8i16, 2 }, // REV64; EXT
2987 { TTI::SK_Reverse, MVT::v16i8, 2 }, // REV64; EXT
2988 { TTI::SK_Reverse, MVT::v4f16, 1 }, // REV64
2989 { TTI::SK_Reverse, MVT::v4i16, 1 }, // REV64
2990 { TTI::SK_Reverse, MVT::v8i8, 1 }, // REV64
2991 // Splice can all be lowered as `ext`.
2992 { TTI::SK_Splice, MVT::v2i32, 1 },
2993 { TTI::SK_Splice, MVT::v4i32, 1 },
2994 { TTI::SK_Splice, MVT::v2i64, 1 },
2995 { TTI::SK_Splice, MVT::v2f32, 1 },
2996 { TTI::SK_Splice, MVT::v4f32, 1 },
2997 { TTI::SK_Splice, MVT::v2f64, 1 },
2998 { TTI::SK_Splice, MVT::v8f16, 1 },
2999 { TTI::SK_Splice, MVT::v8bf16, 1 },
3000 { TTI::SK_Splice, MVT::v8i16, 1 },
3001 { TTI::SK_Splice, MVT::v16i8, 1 },
3002 { TTI::SK_Splice, MVT::v4bf16, 1 },
3003 { TTI::SK_Splice, MVT::v4f16, 1 },
3004 { TTI::SK_Splice, MVT::v4i16, 1 },
3005 { TTI::SK_Splice, MVT::v8i8, 1 },
3006 // Broadcast shuffle kinds for scalable vectors
3007 { TTI::SK_Broadcast, MVT::nxv16i8, 1 },
3008 { TTI::SK_Broadcast, MVT::nxv8i16, 1 },
3009 { TTI::SK_Broadcast, MVT::nxv4i32, 1 },
3010 { TTI::SK_Broadcast, MVT::nxv2i64, 1 },
3011 { TTI::SK_Broadcast, MVT::nxv2f16, 1 },
3012 { TTI::SK_Broadcast, MVT::nxv4f16, 1 },
3013 { TTI::SK_Broadcast, MVT::nxv8f16, 1 },
3014 { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
3015 { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
3016 { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
3017 { TTI::SK_Broadcast, MVT::nxv2f32, 1 },
3018 { TTI::SK_Broadcast, MVT::nxv4f32, 1 },
3019 { TTI::SK_Broadcast, MVT::nxv2f64, 1 },
3020 { TTI::SK_Broadcast, MVT::nxv16i1, 1 },
3021 { TTI::SK_Broadcast, MVT::nxv8i1, 1 },
3022 { TTI::SK_Broadcast, MVT::nxv4i1, 1 },
3023 { TTI::SK_Broadcast, MVT::nxv2i1, 1 },
3024 // Handle the cases for vector.reverse with scalable vectors
3025 { TTI::SK_Reverse, MVT::nxv16i8, 1 },
3026 { TTI::SK_Reverse, MVT::nxv8i16, 1 },
3027 { TTI::SK_Reverse, MVT::nxv4i32, 1 },
3028 { TTI::SK_Reverse, MVT::nxv2i64, 1 },
3029 { TTI::SK_Reverse, MVT::nxv2f16, 1 },
3030 { TTI::SK_Reverse, MVT::nxv4f16, 1 },
3031 { TTI::SK_Reverse, MVT::nxv8f16, 1 },
3032 { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
3033 { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
3034 { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
3035 { TTI::SK_Reverse, MVT::nxv2f32, 1 },
3036 { TTI::SK_Reverse, MVT::nxv4f32, 1 },
3037 { TTI::SK_Reverse, MVT::nxv2f64, 1 },
3038 { TTI::SK_Reverse, MVT::nxv16i1, 1 },
3039 { TTI::SK_Reverse, MVT::nxv8i1, 1 },
3040 { TTI::SK_Reverse, MVT::nxv4i1, 1 },
3041 { TTI::SK_Reverse, MVT::nxv2i1, 1 },
3042 };
3043 if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
3044 return LT.first * Entry->Cost;
3045 }
3046
3047 if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
3048 return getSpliceCost(Tp, Index);
3049
3050 // Inserting a subvector can often be done with either a D, S or H register
3051 // move, so long as the inserted vector is "aligned".
3052 if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() &&
3053 LT.second.getSizeInBits() <= 128 && SubTp) {
3054 std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp);
3055 if (SubLT.second.isVector()) {
3056 int NumElts = LT.second.getVectorNumElements();
3057 int NumSubElts = SubLT.second.getVectorNumElements();
3058 if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0)
3059 return SubLT.first;
3060 }
3061 }
3062
3063 return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp);
3064}
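
The mask-splitting logic at the top of this function is easier to follow in isolation. Below is a small standalone sketch of just its source-counting step, rewritten over plain integers (an editorial illustration; countSources is a hypothetical helper, not part of this file):

#include <cstdio>
#include <vector>

// Count how many distinct LTNumElts-wide source sub-vectors a mask chunk
// references, mirroring the Source1/Source2 bookkeeping in getShuffleCost.
static unsigned countSources(const std::vector<int> &Chunk, int LTNumElts) {
  int Source1 = -1, Source2 = -1;
  unsigned NumSources = 0;
  for (int MaskElt : Chunk) {
    if (MaskElt < 0)
      continue; // undef lane
    int Source = MaskElt / LTNumElts;
    if (NumSources == 0) {
      Source1 = Source;
      NumSources = 1;
    } else if (NumSources == 1 && Source != Source1) {
      Source2 = Source;
      NumSources = 2;
    } else if (Source != Source1 && Source != Source2) {
      ++NumSources;
    }
  }
  return NumSources;
}

int main() {
  // Lanes drawn from sub-vectors 0 and 2 only: re-costed via getShuffleCost.
  std::printf("%u\n", countSources({0, 1, 8, 9}, 4)); // prints 2
  // Lanes drawn from sub-vectors 0, 1 and 2: falls back to the worst case.
  std::printf("%u\n", countSources({0, 4, 8, 9}, 4)); // prints 3
  return 0;
}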
3065
3066bool AArch64TTIImpl::preferPredicateOverEpilogue(
3067 Loop *L, LoopInfo *LI, ScalarEvolution &SE, AssumptionCache &AC,
3068 TargetLibraryInfo *TLI, DominatorTree *DT, LoopVectorizationLegality *LVL,
3069 InterleavedAccessInfo *IAI) {
3070 if (!ST->hasSVE() || TailFoldingKindLoc == TailFoldingKind::TFDisabled)
3071 return false;
3072
3073 // We don't currently support vectorisation with interleaving for SVE - with
3074 // such loops we're better off not using tail-folding. This gives us a chance
3075 // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc.
3076 if (IAI->hasGroups())
3077 return false;
3078
3079 TailFoldingKind Required; // Defaults to 0.
3080 if (LVL->getReductionVars().size())
3081 Required.add(TailFoldingKind::TFReductions);
3082 if (LVL->getFixedOrderRecurrences().size())
3083 Required.add(TailFoldingKind::TFRecurrences);
3084 if (!Required)
3085 Required.add(TailFoldingKind::TFSimple);
3086
3087 return (TailFoldingKindLoc & Required) == Required;
3088}
3089
3090InstructionCost
3091AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
3092 int64_t BaseOffset, bool HasBaseReg,
3093 int64_t Scale, unsigned AddrSpace) const {
3094 // Scaling factors are not free at all.
3095 // Operands | Rt Latency
3096 // -------------------------------------------
3097 // Rt, [Xn, Xm] | 4
3098 // -------------------------------------------
3099 // Rt, [Xn, Xm, lsl #imm] | Rn: 4 Rm: 5
3100 // Rt, [Xn, Wm, <extend> #imm] |
3101 TargetLoweringBase::AddrMode AM;
3102 AM.BaseGV = BaseGV;
3103 AM.BaseOffs = BaseOffset;
3104 AM.HasBaseReg = HasBaseReg;
3105 AM.Scale = Scale;
3106 if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
3107 // Scale represents reg2 * scale, thus account for 1 if
3108 // it is not equal to 0 or 1.
3109 return AM.Scale != 0 && AM.Scale != 1;
3110 return -1;
3111}

/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/include/llvm/Support/MathExtras.h

1//===-- llvm/Support/MathExtras.h - Useful math functions -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains some functions that are useful for math stuff.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_SUPPORT_MATHEXTRAS_H
14#define LLVM_SUPPORT_MATHEXTRAS_H
15
16#include "llvm/ADT/bit.h"
17#include "llvm/Support/Compiler.h"
18#include <cassert>
19#include <climits>
20#include <cmath>
21#include <cstdint>
22#include <cstring>
23#include <limits>
24#include <type_traits>
25
26#ifdef _MSC_VER
27// Declare these intrinsics manually rather than including intrin.h. It's very
28// expensive, and MathExtras.h is popular.
29// #include <intrin.h>
30extern "C" {
31unsigned char _BitScanForward(unsigned long *_Index, unsigned long _Mask);
32unsigned char _BitScanForward64(unsigned long *_Index, unsigned __int64 _Mask);
33unsigned char _BitScanReverse(unsigned long *_Index, unsigned long _Mask);
34unsigned char _BitScanReverse64(unsigned long *_Index, unsigned __int64 _Mask);
35}
36#endif
37
38namespace llvm {
39
40/// The behavior an operation has on an input of 0.
41enum ZeroBehavior {
42 /// The returned value is undefined.
43 ZB_Undefined,
44 /// The returned value is numeric_limits<T>::max()
45 ZB_Max,
46 /// The returned value is numeric_limits<T>::digits
47 ZB_Width
48};
49
50/// Mathematical constants.
51namespace numbers {
52// TODO: Track C++20 std::numbers.
53// TODO: Favor using the hexadecimal FP constants (requires C++17).
54constexpr double e = 2.7182818284590452354, // (0x1.5bf0a8b145749P+1) https://oeis.org/A001113
55 egamma = .57721566490153286061, // (0x1.2788cfc6fb619P-1) https://oeis.org/A001620
56 ln2 = .69314718055994530942, // (0x1.62e42fefa39efP-1) https://oeis.org/A002162
57 ln10 = 2.3025850929940456840, // (0x1.24bb1bbb55516P+1) https://oeis.org/A002392
58 log2e = 1.4426950408889634074, // (0x1.71547652b82feP+0)
59 log10e = .43429448190325182765, // (0x1.bcb7b1526e50eP-2)
60 pi = 3.1415926535897932385, // (0x1.921fb54442d18P+1) https://oeis.org/A000796
61 inv_pi = .31830988618379067154, // (0x1.45f306bc9c883P-2) https://oeis.org/A049541
62 sqrtpi = 1.7724538509055160273, // (0x1.c5bf891b4ef6bP+0) https://oeis.org/A002161
63 inv_sqrtpi = .56418958354775628695, // (0x1.20dd750429b6dP-1) https://oeis.org/A087197
64 sqrt2 = 1.4142135623730950488, // (0x1.6a09e667f3bcdP+0) https://oeis.org/A002193
65 inv_sqrt2 = .70710678118654752440, // (0x1.6a09e667f3bcdP-1)
66 sqrt3 = 1.7320508075688772935, // (0x1.bb67ae8584caaP+0) https://oeis.org/A002194
67 inv_sqrt3 = .57735026918962576451, // (0x1.279a74590331cP-1)
68 phi = 1.6180339887498948482; // (0x1.9e3779b97f4a8P+0) https://oeis.org/A001622
69constexpr float ef = 2.71828183F, // (0x1.5bf0a8P+1) https://oeis.org/A001113
70 egammaf = .577215665F, // (0x1.2788d0P-1) https://oeis.org/A001620
71 ln2f = .693147181F, // (0x1.62e430P-1) https://oeis.org/A002162
72 ln10f = 2.30258509F, // (0x1.26bb1cP+1) https://oeis.org/A002392
73 log2ef = 1.44269504F, // (0x1.715476P+0)
74 log10ef = .434294482F, // (0x1.bcb7b2P-2)
75 pif = 3.14159265F, // (0x1.921fb6P+1) https://oeis.org/A000796
76 inv_pif = .318309886F, // (0x1.45f306P-2) https://oeis.org/A049541
77 sqrtpif = 1.77245385F, // (0x1.c5bf8aP+0) https://oeis.org/A002161
78 inv_sqrtpif = .564189584F, // (0x1.20dd76P-1) https://oeis.org/A087197
79 sqrt2f = 1.41421356F, // (0x1.6a09e6P+0) https://oeis.org/A002193
80 inv_sqrt2f = .707106781F, // (0x1.6a09e6P-1)
81 sqrt3f = 1.73205081F, // (0x1.bb67aeP+0) https://oeis.org/A002194
82 inv_sqrt3f = .577350269F, // (0x1.279a74P-1)
83 phif = 1.61803399F; // (0x1.9e377aP+0) https://oeis.org/A001622
84} // namespace numbers
85
86namespace detail {
87template <typename T, std::size_t SizeOfT> struct TrailingZerosCounter {
88 static unsigned count(T Val, ZeroBehavior) {
89 if (!Val)
90 return std::numeric_limits<T>::digits;
91 if (Val & 0x1)
92 return 0;
93
94 // Bisection method.
95 unsigned ZeroBits = 0;
96 T Shift = std::numeric_limits<T>::digits >> 1;
97 T Mask = std::numeric_limits<T>::max() >> Shift;
98 while (Shift) {
99 if ((Val & Mask) == 0) {
100 Val >>= Shift;
101 ZeroBits |= Shift;
102 }
103 Shift >>= 1;
104 Mask >>= Shift;
105 }
106 return ZeroBits;
107 }
108};
109
110#if defined(__GNUC__) || defined(_MSC_VER)
111template <typename T> struct TrailingZerosCounter<T, 4> {
112 static unsigned count(T Val, ZeroBehavior ZB) {
113 if (ZB != ZB_Undefined && Val == 0)
114 return 32;
115
116#if __has_builtin(__builtin_ctz) || defined(__GNUC__)
117 return __builtin_ctz(Val);
118#elif defined(_MSC_VER)
119 unsigned long Index;
120 _BitScanForward(&Index, Val);
121 return Index;
122#endif
123 }
124};
125
126#if !defined(_MSC_VER) || defined(_M_X64)
127template <typename T> struct TrailingZerosCounter<T, 8> {
128 static unsigned count(T Val, ZeroBehavior ZB) {
129 if (ZB != ZB_Undefined && Val == 0)
130 return 64;
131
132#if __has_builtin(__builtin_ctzll) || defined(__GNUC__)
133 return __builtin_ctzll(Val);
134#elif defined(_MSC_VER)
135 unsigned long Index;
136 _BitScanForward64(&Index, Val);
137 return Index;
138#endif
139 }
140};
141#endif
142#endif
143} // namespace detail
144
145/// Count the number of 0's from the least significant bit to the most
146/// significant bit, stopping at the first 1.
147///
148/// Only unsigned integral types are allowed.
149///
150/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are
151/// valid arguments.
152template <typename T>
153unsigned countTrailingZeros(T Val, ZeroBehavior ZB = ZB_Width) {
154 static_assert(std::is_unsigned_v<T>,
155 "Only unsigned integral types are allowed.");
156 return llvm::detail::TrailingZerosCounter<T, sizeof(T)>::count(Val, ZB);
157}
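
A minimal usage sketch for the counting helper above (assumes the llvm/Support/MathExtras.h listed here is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

int main() {
  // 0b1000 has three trailing zero bits.
  assert(llvm::countTrailingZeros(8u) == 3u);
  // With the default ZB_Width behavior, an input of 0 returns the type width.
  assert(llvm::countTrailingZeros(uint32_t(0)) == 32u);
  return 0;
}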
158
159namespace detail {
160template <typename T, std::size_t SizeOfT> struct LeadingZerosCounter {
161 static unsigned count(T Val, ZeroBehavior) {
162 if (!Val)
163 return std::numeric_limits<T>::digits;
164
165 // Bisection method.
166 unsigned ZeroBits = 0;
167 for (T Shift = std::numeric_limits<T>::digits >> 1; Shift; Shift >>= 1) {
168 T Tmp = Val >> Shift;
169 if (Tmp)
170 Val = Tmp;
171 else
172 ZeroBits |= Shift;
173 }
174 return ZeroBits;
175 }
176};
177
178#if defined(__GNUC__) || defined(_MSC_VER)
179template <typename T> struct LeadingZerosCounter<T, 4> {
180 static unsigned count(T Val, ZeroBehavior ZB) {
181 if (ZB != ZB_Undefined && Val == 0)
182 return 32;
183
184#if __has_builtin(__builtin_clz) || defined(__GNUC__)
185 return __builtin_clz(Val);
186#elif defined(_MSC_VER)
187 unsigned long Index;
188 _BitScanReverse(&Index, Val);
189 return Index ^ 31;
190#endif
191 }
192};
193
194#if !defined(_MSC_VER) || defined(_M_X64)
195template <typename T> struct LeadingZerosCounter<T, 8> {
196 static unsigned count(T Val, ZeroBehavior ZB) {
197 if (ZB != ZB_Undefined && Val == 0)
198 return 64;
199
200#if __has_builtin(__builtin_clzll) || defined(__GNUC__)
201 return __builtin_clzll(Val);
202#elif defined(_MSC_VER)
203 unsigned long Index;
204 _BitScanReverse64(&Index, Val);
205 return Index ^ 63;
206#endif
207 }
208};
209#endif
210#endif
211} // namespace detail
212
213/// Count the number of 0's from the most significant bit to the least
214/// significant bit, stopping at the first 1.
215///
216/// Only unsigned integral types are allowed.
217///
218/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are
219/// valid arguments.
220template <typename T>
221unsigned countLeadingZeros(T Val, ZeroBehavior ZB = ZB_Width) {
222 static_assert(std::is_unsigned_v<T>,
223 "Only unsigned integral types are allowed.");
224 return llvm::detail::LeadingZerosCounter<T, sizeof(T)>::count(Val, ZB);
225}
226
227/// Get the index of the first set bit starting from the least
228/// significant bit.
229///
230/// Only unsigned integral types are allowed.
231///
232/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are
233/// valid arguments.
234template <typename T> T findFirstSet(T Val, ZeroBehavior ZB = ZB_Max) {
235 if (ZB == ZB_Max && Val == 0)
236 return std::numeric_limits<T>::max();
237
238 return countTrailingZeros(Val, ZB_Undefined);
239}
240
241/// Create a bitmask with the N right-most bits set to 1, and all other
242/// bits set to 0. Only unsigned types are allowed.
243template <typename T> T maskTrailingOnes(unsigned N) {
244 static_assert(std::is_unsigned<T>::value, "Invalid type!");
245 const unsigned Bits = CHAR_BIT * sizeof(T);
246 assert(N <= Bits && "Invalid bit index");
247 return N == 0 ? 0 : (T(-1) >> (Bits - N));
248}
249
250/// Create a bitmask with the N left-most bits set to 1, and all other
251/// bits set to 0. Only unsigned types are allowed.
252template <typename T> T maskLeadingOnes(unsigned N) {
253 return ~maskTrailingOnes<T>(CHAR_BIT * sizeof(T) - N);
254}
255
256/// Create a bitmask with the N right-most bits set to 0, and all other
257/// bits set to 1. Only unsigned types are allowed.
258template <typename T> T maskTrailingZeros(unsigned N) {
259 return maskLeadingOnes<T>(CHAR_BIT * sizeof(T) - N);
260}
261
262/// Create a bitmask with the N left-most bits set to 0, and all other
263/// bits set to 1. Only unsigned types are allowed.
264template <typename T> T maskLeadingZeros(unsigned N) {
265 return maskTrailingOnes<T>(CHAR_BIT * sizeof(T) - N);
266}
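
A quick sketch of the four mask helpers above for a 32-bit type (assumes the header is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

int main() {
  assert(llvm::maskTrailingOnes<uint32_t>(8)  == 0x000000FFu);
  assert(llvm::maskLeadingOnes<uint32_t>(8)   == 0xFF000000u);
  assert(llvm::maskTrailingZeros<uint32_t>(8) == 0xFFFFFF00u);
  assert(llvm::maskLeadingZeros<uint32_t>(8)  == 0x00FFFFFFu);
  return 0;
}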
267
268/// Get the index of the last set bit starting from the least
269/// significant bit.
270///
271/// Only unsigned integral types are allowed.
272///
273/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are
274/// valid arguments.
275template <typename T> T findLastSet(T Val, ZeroBehavior ZB = ZB_Max) {
276 if (ZB == ZB_Max && Val == 0)
277 return std::numeric_limits<T>::max();
278
279 // Use ^ instead of - because both gcc and llvm can remove the associated ^
280 // in the __builtin_clz intrinsic on x86.
281 return countLeadingZeros(Val, ZB_Undefined) ^
282 (std::numeric_limits<T>::digits - 1);
283}
284
285/// Macro compressed bit reversal table for 256 bits.
286///
287/// http://graphics.stanford.edu/~seander/bithacks.html#BitReverseTable
288static const unsigned char BitReverseTable256[256] = {
289#define R2(n) n, n + 2 * 64, n + 1 * 64, n + 3 * 64
290#define R4(n) R2(n), R2(n + 2 * 16), R2(n + 1 * 16), R2(n + 3 * 16)
291#define R6(n) R4(n), R4(n + 2 * 4), R4(n + 1 * 4), R4(n + 3 * 4)
292 R6(0), R6(2), R6(1), R6(3)
293#undef R2
294#undef R4
295#undef R6
296};
297
298/// Reverse the bits in \p Val.
299template <typename T> T reverseBits(T Val) {
300#if __has_builtin(__builtin_bitreverse8)
301 if constexpr (std::is_same_v<T, uint8_t>)
302 return __builtin_bitreverse8(Val);
303#endif
304#if __has_builtin(__builtin_bitreverse16)
305 if constexpr (std::is_same_v<T, uint16_t>)
306 return __builtin_bitreverse16(Val);
307#endif
308#if __has_builtin(__builtin_bitreverse32)
309 if constexpr (std::is_same_v<T, uint32_t>)
310 return __builtin_bitreverse32(Val);
311#endif
312#if __has_builtin(__builtin_bitreverse64)
313 if constexpr (std::is_same_v<T, uint64_t>)
314 return __builtin_bitreverse64(Val);
315#endif
316
317 unsigned char in[sizeof(Val)];
318 unsigned char out[sizeof(Val)];
319 std::memcpy(in, &Val, sizeof(Val));
320 for (unsigned i = 0; i < sizeof(Val); ++i)
321 out[(sizeof(Val) - i) - 1] = BitReverseTable256[in[i]];
322 std::memcpy(&Val, out, sizeof(Val));
323 return Val;
324}
325
326// NOTE: The following support functions use the _32/_64 extensions instead of
327// type overloading so that signed and unsigned integers can be used without
328// ambiguity.
329
330/// Return the high 32 bits of a 64 bit value.
331constexpr inline uint32_t Hi_32(uint64_t Value) {
332 return static_cast<uint32_t>(Value >> 32);
333}
334
335/// Return the low 32 bits of a 64 bit value.
336constexpr inline uint32_t Lo_32(uint64_t Value) {
337 return static_cast<uint32_t>(Value);
338}
339
340/// Make a 64-bit integer from a high / low pair of 32-bit integers.
341constexpr inline uint64_t Make_64(uint32_t High, uint32_t Low) {
342 return ((uint64_t)High << 32) | (uint64_t)Low;
343}
344
345/// Checks if an integer fits into the given bit width.
346template <unsigned N> constexpr inline bool isInt(int64_t x) {
347 if constexpr (N == 8)
348 return static_cast<int8_t>(x) == x;
349 if constexpr (N == 16)
350 return static_cast<int16_t>(x) == x;
351 if constexpr (N == 32)
352 return static_cast<int32_t>(x) == x;
353 if constexpr (N < 64)
354 return -(INT64_C(1) << (N - 1)) <= x && x < (INT64_C(1) << (N - 1));
355 (void)x; // MSVC v19.25 warns that x is unused.
356 return true;
357}
358
359/// Checks if a signed integer is an N bit number shifted left by S.
360template <unsigned N, unsigned S>
361constexpr inline bool isShiftedInt(int64_t x) {
362 static_assert(
363 N > 0, "isShiftedInt<0> doesn't make sense (refers to a 0-bit number).");
364 static_assert(N + S <= 64, "isShiftedInt<N, S> with N + S > 64 is too wide.");
365 return isInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
366}
367
368/// Checks if an unsigned integer fits into the given bit width.
369template <unsigned N> constexpr inline bool isUInt(uint64_t x) {
370 static_assert(N > 0, "isUInt<0> doesn't make sense");
371 if constexpr (N == 8)
372 return static_cast<uint8_t>(x) == x;
373 if constexpr (N == 16)
374 return static_cast<uint16_t>(x) == x;
375 if constexpr (N == 32)
376 return static_cast<uint32_t>(x) == x;
377 if constexpr (N < 64)
378 return x < (UINT64_C(1) << (N));
379 (void)x; // MSVC v19.25 warns that x is unused.
380 return true;
381}
382
383/// Checks if an unsigned integer is an N bit number shifted left by S.
384template <unsigned N, unsigned S>
385constexpr inline bool isShiftedUInt(uint64_t x) {
386 static_assert(
387 N > 0, "isShiftedUInt<0> doesn't make sense (refers to a 0-bit number)");
388 static_assert(N + S <= 64,
389 "isShiftedUInt<N, S> with N + S > 64 is too wide.");
390 // Per the two static_asserts above, S must be strictly less than 64. So
391 // 1 << S is not undefined behavior.
392 return isUInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
393}
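
A small sketch of the width checks above (assumes the header is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cassert>

int main() {
  // isInt<8> accepts exactly the two's-complement range of int8_t.
  assert(llvm::isInt<8>(127) && !llvm::isInt<8>(128));
  // isUInt<8> accepts [0, 255].
  assert(llvm::isUInt<8>(255) && !llvm::isUInt<8>(256));
  // 0x3F0 is a 6-bit value (0x3F) shifted left by 4 bits.
  assert(llvm::isShiftedUInt<6, 4>(0x3F0) && !llvm::isShiftedUInt<6, 4>(0x3F1));
  return 0;
}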
394
395/// Gets the maximum value for a N-bit unsigned integer.
396inline uint64_t maxUIntN(uint64_t N) {
397 assert(N > 0 && N <= 64 && "integer width out of range");
398
399 // uint64_t(1) << 64 is undefined behavior, so we can't do
400 // (uint64_t(1) << N) - 1
401 // without checking first that N != 64. But this works and doesn't have a
402 // branch.
403 return UINT64_MAX >> (64 - N);
404}
405
406/// Gets the minimum value for a N-bit signed integer.
407inline int64_t minIntN(int64_t N) {
408 assert(N > 0 && N <= 64 && "integer width out of range");
409
410 return UINT64_C(1) + ~(UINT64_C(1) << (N - 1));
411}
412
413/// Gets the maximum value for a N-bit signed integer.
414inline int64_t maxIntN(int64_t N) {
415 assert(N > 0 && N <= 64 && "integer width out of range");
416
417 // This relies on two's complement wraparound when N == 64, so we convert to
418 // int64_t only at the very end to avoid UB.
419 return (UINT64_C(1) << (N - 1)) - 1;
420}
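
The three range helpers above can be checked directly (a sketch, assuming the header is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

int main() {
  assert(llvm::maxUIntN(8) == 255u);
  assert(llvm::minIntN(8) == -128 && llvm::maxIntN(8) == 127);
  // The 64-bit cases exercise the wraparound tricks described in the comments.
  assert(llvm::maxUIntN(64) == UINT64_MAX);
  assert(llvm::minIntN(64) == INT64_MIN && llvm::maxIntN(64) == INT64_MAX);
  return 0;
}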
421
422/// Checks if an unsigned integer fits into the given (dynamic) bit width.
423inline bool isUIntN(unsigned N, uint64_t x) {
424 return N >= 64 || x <= maxUIntN(N);
425}
426
427/// Checks if a signed integer fits into the given (dynamic) bit width.
428inline bool isIntN(unsigned N, int64_t x) {
429 return N >= 64 || (minIntN(N) <= x && x <= maxIntN(N));
430}
431
432/// Return true if the argument is a non-empty sequence of ones starting at the
433/// least significant bit with the remainder zero (32 bit version).
434/// Ex. isMask_32(0x0000FFFFU) == true.
435constexpr inline bool isMask_32(uint32_t Value) {
436 return Value && ((Value + 1) & Value) == 0;
437}
438
439/// Return true if the argument is a non-empty sequence of ones starting at the
440/// least significant bit with the remainder zero (64 bit version).
441constexpr inline bool isMask_64(uint64_t Value) {
442 return Value && ((Value + 1) & Value) == 0;
443}
444
445/// Return true if the argument contains a non-empty sequence of ones with the
446/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true.
447constexpr inline bool isShiftedMask_32(uint32_t Value) {
448 return Value && isMask_32((Value - 1) | Value);
449}
450
451/// Return true if the argument contains a non-empty sequence of ones with the
452/// remainder zero (64 bit version.)
453constexpr inline bool isShiftedMask_64(uint64_t Value) {
454 return Value && isMask_64((Value - 1) | Value);
455}
456
457/// Return true if the argument is a power of two > 0.
458/// Ex. isPowerOf2_32(0x00100000U) == true (32 bit edition.)
459constexpr inline bool isPowerOf2_32(uint32_t Value) {
460 return llvm::has_single_bit(Value);
461}
462
463/// Return true if the argument is a power of two > 0 (64 bit edition.)
464constexpr inline bool isPowerOf2_64(uint64_t Value) {
465 return llvm::has_single_bit(Value);
466}
467
468/// Count the number of ones from the most significant bit to the first
469/// zero bit.
470///
471/// Ex. countLeadingOnes(0xFF0FFF00) == 8.
472/// Only unsigned integral types are allowed.
473///
474/// \param ZB the behavior on an input of all ones. Only ZB_Width and
475/// ZB_Undefined are valid arguments.
476template <typename T>
477unsigned countLeadingOnes(T Value, ZeroBehavior ZB = ZB_Width) {
478 static_assert(std::is_unsigned_v<T>,
479 "Only unsigned integral types are allowed.");
480 return countLeadingZeros<T>(~Value, ZB);
481}
482
483/// Count the number of ones from the least significant bit to the first
484/// zero bit.
485///
486/// Ex. countTrailingOnes(0x00FF00FF) == 8.
487/// Only unsigned integral types are allowed.
488///
489/// \param ZB the behavior on an input of all ones. Only ZB_Width and
490/// ZB_Undefined are valid arguments.
491template <typename T>
492unsigned countTrailingOnes(T Value, ZeroBehavior ZB = ZB_Width) {
493 static_assert(std::is_unsigned_v<T>,
494 "Only unsigned integral types are allowed.");
495 return countTrailingZeros<T>(~Value, ZB);
496}
497
498/// Count the number of set bits in a value.
499/// Ex. countPopulation(0xF000F000) = 8
500/// Returns 0 if the word is zero.
501template <typename T>
502inline unsigned countPopulation(T Value) {
503 static_assert(std::is_unsigned_v<T>,
504 "Only unsigned integral types are allowed.");
505 return (unsigned)llvm::popcount(Value);
506}
507
508/// Return true if the argument contains a non-empty sequence of ones with the
509/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true.
510/// If true, \p MaskIdx will specify the index of the lowest set bit and \p
511/// MaskLen is updated to specify the length of the mask, else neither are
512/// updated.
513inline bool isShiftedMask_32(uint32_t Value, unsigned &MaskIdx,
514 unsigned &MaskLen) {
515 if (!isShiftedMask_32(Value))
516 return false;
517 MaskIdx = countTrailingZeros(Value);
518 MaskLen = countPopulation(Value);
519 return true;
520}
521
522/// Return true if the argument contains a non-empty sequence of ones with the
523/// remainder zero (64 bit version.) If true, \p MaskIdx will specify the index
524/// of the lowest set bit and \p MaskLen is updated to specify the length of the
525/// mask, else neither are updated.
526inline bool isShiftedMask_64(uint64_t Value, unsigned &MaskIdx,
527 unsigned &MaskLen) {
528 if (!isShiftedMask_64(Value))
529 return false;
530 MaskIdx = countTrailingZeros(Value);
531 MaskLen = countPopulation(Value);
532 return true;
533}
534
535/// Compile time Log2.
536/// Valid only for positive powers of two.
537template <size_t kValue> constexpr inline size_t CTLog2() {
538 static_assert(kValue > 0 && llvm::isPowerOf2_64(kValue),
539 "Value is not a valid power of 2");
540 return 1 + CTLog2<kValue / 2>();
541}
542
543template <> constexpr inline size_t CTLog2<1>() { return 0; }
544
545/// Return the floor log base 2 of the specified value, -1 if the value is zero.
546/// (32 bit edition.)
547/// Ex. Log2_32(32) == 5, Log2_32(1) == 0, Log2_32(0) == -1, Log2_32(6) == 2
548inline unsigned Log2_32(uint32_t Value) {
549 return 31 - countLeadingZeros(Value);
  12. Returning the value 4294967295
550}
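
This is the value recorded by analyzer event 12 above: for a zero input, countLeadingZeros(0) returns 32 under the default ZB_Width behavior, so 31 - 32 wraps to 4294967295 as an unsigned result, and a value that large must never be used as a shift amount. A minimal reproduction of just the arithmetic (a sketch, assuming the header is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cstdio>

int main() {
  // Log2_32(0) == 31 - countLeadingZeros(0) == 31 - 32, which wraps around to
  // 4294967295 as an unsigned value.
  unsigned Shift = llvm::Log2_32(0);
  std::printf("%u\n", Shift); // prints 4294967295
  // Left-shifting an int by a value greater than or equal to its width, e.g.
  // `1 << Shift` here, is undefined behavior.
  return 0;
}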
551
552/// Return the floor log base 2 of the specified value, -1 if the value is zero.
553/// (64 bit edition.)
554inline unsigned Log2_64(uint64_t Value) {
555 return 63 - countLeadingZeros(Value);
556}
557
558/// Return the ceil log base 2 of the specified value, 32 if the value is zero.
559/// (32 bit edition).
560/// Ex. Log2_32_Ceil(32) == 5, Log2_32_Ceil(1) == 0, Log2_32_Ceil(6) == 3
561inline unsigned Log2_32_Ceil(uint32_t Value) {
562 return 32 - countLeadingZeros(Value - 1);
563}
564
565/// Return the ceil log base 2 of the specified value, 64 if the value is zero.
566/// (64 bit edition.)
567inline unsigned Log2_64_Ceil(uint64_t Value) {
568 return 64 - countLeadingZeros(Value - 1);
569}
570
571/// This function takes a 64-bit integer and returns the bit equivalent double.
572inline double BitsToDouble(uint64_t Bits) {
573 static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes");
574 return llvm::bit_cast<double>(Bits);
575}
576
577/// This function takes a 32-bit integer and returns the bit equivalent float.
578inline float BitsToFloat(uint32_t Bits) {
579 static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes");
580 return llvm::bit_cast<float>(Bits);
581}
582
583/// This function takes a double and returns the bit equivalent 64-bit integer.
584/// Note that copying doubles around changes the bits of NaNs on some hosts,
585/// notably x86, so this routine cannot be used if these bits are needed.
586inline uint64_t DoubleToBits(double Double) {
587 static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes");
588 return llvm::bit_cast<uint64_t>(Double);
589}
590
591/// This function takes a float and returns the bit equivalent 32-bit integer.
592/// Note that copying floats around changes the bits of NaNs on some hosts,
593/// notably x86, so this routine cannot be used if these bits are needed.
594inline uint32_t FloatToBits(float Float) {
595 static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes");
596 return llvm::bit_cast<uint32_t>(Float);
597}
598
599/// A and B are either alignments or offsets. Return the minimum alignment that
600/// may be assumed after adding the two together.
601constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) {
602 // The largest power of 2 that divides both A and B.
603 //
604 // Replace "-Value" by "1+~Value" in the following commented code to avoid
605 // MSVC warning C4146
606 // return (A | B) & -(A | B);
607 return (A | B) & (1 + ~(A | B));
608}
609
610/// Returns the next power of two (in 64-bits) that is strictly greater than A.
611/// Returns zero on overflow.
612constexpr inline uint64_t NextPowerOf2(uint64_t A) {
613 A |= (A >> 1);
614 A |= (A >> 2);
615 A |= (A >> 4);
616 A |= (A >> 8);
617 A |= (A >> 16);
618 A |= (A >> 32);
619 return A + 1;
620}
621
622/// Returns the power of two which is less than or equal to the given value.
623/// Essentially, it is a floor operation across the domain of powers of two.
624inline uint64_t PowerOf2Floor(uint64_t A) {
625 if (!A) return 0;
626 return 1ull << (63 - countLeadingZeros(A, ZB_Undefined));
627}
628
629/// Returns the power of two which is greater than or equal to the given value.
630/// Essentially, it is a ceil operation across the domain of powers of two.
631inline uint64_t PowerOf2Ceil(uint64_t A) {
632 if (!A)
633 return 0;
634 return NextPowerOf2(A - 1);
635}
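
A short sketch of the power-of-two helpers above (assumes the header is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cassert>

int main() {
  // Largest power of two dividing both 24 and 16.
  assert(llvm::MinAlign(24, 16) == 8);
  // NextPowerOf2 is strictly greater than its input.
  assert(llvm::NextPowerOf2(16) == 32 && llvm::NextPowerOf2(17) == 32);
  assert(llvm::PowerOf2Floor(17) == 16);
  // PowerOf2Ceil leaves powers of two unchanged.
  assert(llvm::PowerOf2Ceil(16) == 16 && llvm::PowerOf2Ceil(17) == 32);
  return 0;
}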
636
637/// Returns the next integer (mod 2**64) that is greater than or equal to
638/// \p Value and is a multiple of \p Align. \p Align must be non-zero.
639///
640/// Examples:
641/// \code
642/// alignTo(5, 8) = 8
643/// alignTo(17, 8) = 24
644/// alignTo(~0LL, 8) = 0
645/// alignTo(321, 255) = 510
646/// \endcode
647inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
648 assert(Align != 0u && "Align can't be 0.");
649 return (Value + Align - 1) / Align * Align;
650}
651
652inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
653 assert(Align != 0 && (Align & (Align - 1)) == 0 &&
654        "Align must be a power of 2");
655 return (Value + Align - 1) & -Align;
656}
657
658/// If non-zero \p Skew is specified, the return value will be a minimal integer
659/// that is greater than or equal to \p Value and equal to \p Align * N + \p Skew
660/// for some integer N. If \p Skew is larger than \p Align, its value is adjusted
661/// to '\p Skew mod \p Align'. \p Align must be non-zero.
662///
663/// Examples:
664/// \code
665/// alignTo(5, 8, 7) = 7
666/// alignTo(17, 8, 1) = 17
667/// alignTo(~0LL, 8, 3) = 3
668/// alignTo(321, 255, 42) = 552
669/// \endcode
670inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew) {
671 assert(Align != 0u && "Align can't be 0.");
672 Skew %= Align;
673 return alignTo(Value - Skew, Align) + Skew;
674}
675
676/// Returns the next integer (mod 2**64) that is greater than or equal to
677/// \p Value and is a multiple of \c Align. \c Align must be non-zero.
678template <uint64_t Align> constexpr inline uint64_t alignTo(uint64_t Value) {
679 static_assert(Align != 0u, "Align must be non-zero");
680 return (Value + Align - 1) / Align * Align;
681}
682
683/// Returns the integer ceil(Numerator / Denominator).
684inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
685 return alignTo(Numerator, Denominator) / Denominator;
686}
687
688/// Returns the integer nearest(Numerator / Denominator).
689inline uint64_t divideNearest(uint64_t Numerator, uint64_t Denominator) {
690 return (Numerator + (Denominator / 2)) / Denominator;
691}
692
693/// Returns the largest uint64_t less than or equal to \p Value and is
694/// \p Skew mod \p Align. \p Align must be non-zero
695inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
696 assert(Align != 0u && "Align can't be 0.");
697 Skew %= Align;
698 return (Value - Skew) / Align * Align + Skew;
699}
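
A short sketch of the rounding helpers above (assumes the header is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cassert>

int main() {
  assert(llvm::divideCeil(17, 8) == 3);
  assert(llvm::divideNearest(17, 8) == 2);
  assert(llvm::alignDown(17, 8) == 16);
  // With a skew, the result is the largest value <= 17 that is 3 mod 8.
  assert(llvm::alignDown(17, 8, 3) == 11);
  return 0;
}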
700
701/// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
702/// Requires 0 < B <= 32.
703template <unsigned B> constexpr inline int32_t SignExtend32(uint32_t X) {
704 static_assert(B > 0, "Bit width can't be 0.");
705 static_assert(B <= 32, "Bit width out of range.");
706 return int32_t(X << (32 - B)) >> (32 - B);
707}
708
709/// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
710/// Requires 0 < B <= 32.
711inline int32_t SignExtend32(uint32_t X, unsigned B) {
712 assert(B > 0 && "Bit width can't be 0.");
713 assert(B <= 32 && "Bit width out of range.");
714 return int32_t(X << (32 - B)) >> (32 - B);
715}
716
717/// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
718/// Requires 0 < B <= 64.
719template <unsigned B> constexpr inline int64_t SignExtend64(uint64_t x) {
720 static_assert(B > 0, "Bit width can't be 0.");
721 static_assert(B <= 64, "Bit width out of range.");
722 return int64_t(x << (64 - B)) >> (64 - B);
723}
724
725/// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
726/// Requires 0 < B <= 64.
727inline int64_t SignExtend64(uint64_t X, unsigned B) {
728 assert(B > 0 && "Bit width can't be 0.");
729 assert(B <= 64 && "Bit width out of range.");
730 return int64_t(X << (64 - B)) >> (64 - B);
731}
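
A small sketch of the sign-extension helpers above (assumes the header is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cassert>

int main() {
  // 0xF is -1 when reinterpreted as a 4-bit two's-complement value.
  assert(llvm::SignExtend32<4>(0xF) == -1);
  assert(llvm::SignExtend32<4>(0x7) == 7);
  // Runtime-width variant: 0x80 in 8 bits is -128.
  assert(llvm::SignExtend64(0x80u, 8) == -128);
  return 0;
}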
732
733/// Subtract two unsigned integers, X and Y, of type T and return the absolute
734/// value of the result.
735template <typename T>
736std::enable_if_t<std::is_unsigned<T>::value, T> AbsoluteDifference(T X, T Y) {
737 return X > Y ? (X - Y) : (Y - X);
738}
739
740/// Add two unsigned integers, X and Y, of type T. Clamp the result to the
741/// maximum representable value of T on overflow. ResultOverflowed indicates if
742/// the result is larger than the maximum representable value of type T.
743template <typename T>
744std::enable_if_t<std::is_unsigned<T>::value, T>
745SaturatingAdd(T X, T Y, bool *ResultOverflowed = nullptr) {
746 bool Dummy;
747 bool &Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy;
748 // Hacker's Delight, p. 29
749 T Z = X + Y;
750 Overflowed = (Z < X || Z < Y);
751 if (Overflowed)
752 return std::numeric_limits<T>::max();
753 else
754 return Z;
755}
756
757/// Multiply two unsigned integers, X and Y, of type T. Clamp the result to the
758/// maximum representable value of T on overflow. ResultOverflowed indicates if
759/// the result is larger than the maximum representable value of type T.
760template <typename T>
761std::enable_if_t<std::is_unsigned<T>::value, T>
762SaturatingMultiply(T X, T Y, bool *ResultOverflowed = nullptr) {
763 bool Dummy;
764 bool &Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy;
765
766 // Hacker's Delight, p. 30 has a different algorithm, but we don't use that
767 // because it fails for uint16_t (where multiplication can have undefined
768 // behavior due to promotion to int), and requires a division in addition
769 // to the multiplication.
770
771 Overflowed = false;
772
773 // Log2(Z) would be either Log2Z or Log2Z + 1.
774 // Special case: if X or Y is 0, Log2_64 gives -1, and Log2Z
775 // will necessarily be less than Log2Max as desired.
776 int Log2Z = Log2_64(X) + Log2_64(Y);
777 const T Max = std::numeric_limits<T>::max();
778 int Log2Max = Log2_64(Max);
779 if (Log2Z < Log2Max) {
780 return X * Y;
781 }
782 if (Log2Z > Log2Max) {
783 Overflowed = true;
784 return Max;
785 }
786
787 // We're going to use the top bit, and maybe overflow one
788 // bit past it. Multiply all but the bottom bit then add
789 // that on at the end.
790 T Z = (X >> 1) * Y;
791 if (Z & ~(Max >> 1)) {
792 Overflowed = true;
793 return Max;
794 }
795 Z <<= 1;
796 if (X & 1)
797 return SaturatingAdd(Z, Y, ResultOverflowed);
798
799 return Z;
800}
801
802/// Multiply two unsigned integers, X and Y, and add the unsigned integer, A to
803/// the product. Clamp the result to the maximum representable value of T on
804/// overflow. ResultOverflowed indicates if the result is larger than the
805/// maximum representable value of type T.
806template <typename T>
807std::enable_if_t<std::is_unsigned<T>::value, T>
808SaturatingMultiplyAdd(T X, T Y, T A, bool *ResultOverflowed = nullptr) {
809 bool Dummy;
810 bool &Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy;
811
812 T Product = SaturatingMultiply(X, Y, &Overflowed);
813 if (Overflowed)
814 return Product;
815
816 return SaturatingAdd(A, Product, &Overflowed);
817}
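
A quick sketch of the saturating helpers above, clamping at uint8_t's maximum (assumes the header is on the include path):

#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

int main() {
  bool Overflowed = false;
  // 200 + 100 does not fit in uint8_t, so the result clamps to 255.
  assert(llvm::SaturatingAdd<uint8_t>(200, 100, &Overflowed) == 255 && Overflowed);
  // 20 * 10 == 200 fits, so no clamping occurs.
  assert(llvm::SaturatingMultiply<uint8_t>(20, 10, &Overflowed) == 200 && !Overflowed);
  // 20 * 20 overflows uint8_t, so the multiply-add clamps as well.
  assert(llvm::SaturatingMultiplyAdd<uint8_t>(20, 20, 5, &Overflowed) == 255 && Overflowed);
  return 0;
}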
818
819/// Use this rather than HUGE_VALF; the latter causes warnings on MSVC.
820extern const float huge_valf;
821
822
823/// Add two signed integers, computing the two's complement truncated result,
824/// returning true if overflow occurred.
825template <typename T>
826std::enable_if_t<std::is_signed<T>::value, T> AddOverflow(T X, T Y, T &Result) {
827#if __has_builtin(__builtin_add_overflow)
828 return __builtin_add_overflow(X, Y, &Result);
829#else
830 // Perform the unsigned addition.
831 using U = std::make_unsigned_t<T>;
832 const U UX = static_cast<U>(X);
833 const U UY = static_cast<U>(Y);
834 const U UResult = UX + UY;
835
836 // Convert to signed.
837 Result = static_cast<T>(UResult);
838
839 // Adding two positive numbers should result in a positive number.
840 if (X > 0 && Y > 0)
841 return Result <= 0;
842 // Adding two negatives should result in a negative number.
843 if (X < 0 && Y < 0)
844 return Result >= 0;
845 return false;
846#endif
847}
848
849/// Subtract two signed integers, computing the two's complement truncated
850/// result, returning true if an overflow occurred.
851template <typename T>
852std::enable_if_t<std::is_signed<T>::value, T> SubOverflow(T X, T Y, T &Result) {
853#if __has_builtin(__builtin_sub_overflow)
854 return __builtin_sub_overflow(X, Y, &Result);
855#else
856 // Perform the unsigned addition.
857 using U = std::make_unsigned_t<T>;
858 const U UX = static_cast<U>(X);
859 const U UY = static_cast<U>(Y);
860 const U UResult = UX - UY;
861
862 // Convert to signed.
863 Result = static_cast<T>(UResult);
864
865 // Subtracting a positive number from a negative results in a negative number.
866 if (X <= 0 && Y > 0)
867 return Result >= 0;
868 // Subtracting a negative number from a positive results in a positive number.
869 if (X >= 0 && Y < 0)
870 return Result <= 0;
871 return false;
872#endif
873}
874
875/// Multiply two signed integers, computing the two's complement truncated
876/// result, returning true if an overflow occurred.
877template <typename T>
878std::enable_if_t<std::is_signed<T>::value, T> MulOverflow(T X, T Y, T &Result) {
879 // Perform the unsigned multiplication on absolute values.
880 using U = std::make_unsigned_t<T>;
881 const U UX = X < 0 ? (0 - static_cast<U>(X)) : static_cast<U>(X);
882 const U UY = Y < 0 ? (0 - static_cast<U>(Y)) : static_cast<U>(Y);
883 const U UResult = UX * UY;
884
885 // Convert to signed.
886 const bool IsNegative = (X < 0) ^ (Y < 0);
887 Result = IsNegative ? (0 - UResult) : UResult;
888
889 // If any of the args was 0, result is 0 and no overflow occurs.
890 if (UX == 0 || UY == 0)
891 return false;
892
893 // UX and UY are in [1, 2^n], where n is the number of digits.
894 // Check how the max allowed absolute value (2^n for negative, 2^(n-1) for
895 // positive) divided by an argument compares to the other.
896 if (IsNegative)
897 return UX > (static_cast<U>(std::numeric_limits<T>::max()) + U(1)) / UY;
898 else
899 return UX > (static_cast<U>(std::numeric_limits<T>::max())) / UY;
900}
901
902} // End llvm namespace
903
904#endif