LLVM 23.0.0git
DAGCombiner.cpp
Go to the documentation of this file.
1//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
10// both before and after the DAG is legalized.
11//
12// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
13// primarily intended to handle simplification opportunities that are implicit
14// in the LLVM IR and exposed by the various codegen lowering phases.
15//
16//===----------------------------------------------------------------------===//
17
18#include "llvm/ADT/APFloat.h"
19#include "llvm/ADT/APInt.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/DenseMap.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/SetVector.h"
28#include "llvm/ADT/SmallSet.h"
30#include "llvm/ADT/Statistic.h"
52#include "llvm/IR/Attributes.h"
53#include "llvm/IR/Constant.h"
54#include "llvm/IR/DataLayout.h"
56#include "llvm/IR/Function.h"
57#include "llvm/IR/Metadata.h"
62#include "llvm/Support/Debug.h"
70#include <algorithm>
71#include <cassert>
72#include <cstdint>
73#include <functional>
74#include <iterator>
75#include <optional>
76#include <string>
77#include <tuple>
78#include <utility>
79#include <variant>
80
81#include "MatchContext.h"
82
83using namespace llvm;
84using namespace llvm::SDPatternMatch;
85
86#define DEBUG_TYPE "dagcombine"
87
// Statistics reported by this pass (visible with -stats).
STATISTIC(NodesCombined , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of load sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
95
96DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
97 "Controls whether a DAG combine is performed for a node");
98
// Hidden flag (default off): let the combiner consult IR-level alias
// analysis for load/store queries instead of only its own local analysis.
static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

// Hidden flag (default on): allow TBAA metadata to refine aliasing queries.
static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));
106
107#ifndef NDEBUG
109CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
110 cl::desc("Only use DAG-combiner alias analysis in this"
111 " function"));
112#endif
113
/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

// Hidden flag (default on): permit splitting the indexing computation out
// of loads.
static cl::opt<bool>
    MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                      cl::desc("DAG combiner may split indexing from loads"));
124
// Hidden flag (default on): enable merging of multiple narrow stores into
// a single wider store.
static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));
129
131 "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
132 cl::desc("Limit the number of operands to inline for Token Factors"));
133
135 "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
136 cl::desc("Limit the number of times for the same StoreNode and RootNode "
137 "to bail out in store merging dependence check"));
138
140 "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
141 cl::desc("DAG combiner enable reducing the width of load/op/store "
142 "sequence"));
144 "combiner-reduce-load-op-store-width-force-narrowing-profitable",
145 cl::Hidden, cl::init(false),
146 cl::desc("DAG combiner force override the narrowing profitable check when "
147 "reducing the width of load/op/store sequences"));
148
150 "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
151 cl::desc("DAG combiner enable load/<replace bytes>/store with "
152 "a narrower store"));
153
// Hidden escape hatch (default off): disable the DAG combiner entirely.
static cl::opt<bool> DisableCombines("combiner-disabled", cl::Hidden,
                                     cl::init(false),
                                     cl::desc("Disable the DAG combiner"));
157
158namespace {
159
160 class DAGCombiner {
161 SelectionDAG &DAG;
162 const TargetLowering &TLI;
163 const SelectionDAGTargetInfo *STI;
165 CodeGenOptLevel OptLevel;
166 bool LegalDAG = false;
167 bool LegalOperations = false;
168 bool LegalTypes = false;
169 bool ForCodeSize;
170 bool DisableGenericCombines;
171
172 /// Worklist of all of the nodes that need to be simplified.
173 ///
174 /// This must behave as a stack -- new nodes to process are pushed onto the
175 /// back and when processing we pop off of the back.
176 ///
177 /// The worklist will not contain duplicates but may contain null entries
178 /// due to nodes being deleted from the underlying DAG. For fast lookup and
179 /// deduplication, the index of the node in this vector is stored in the
180 /// node in SDNode::CombinerWorklistIndex.
182
183 /// This records all nodes attempted to be added to the worklist since we
184 /// considered a new worklist entry. As we keep do not add duplicate nodes
185 /// in the worklist, this is different from the tail of the worklist.
187
188 /// Map from candidate StoreNode to the pair of RootNode and count.
189 /// The count is used to track how many times we have seen the StoreNode
190 /// with the same RootNode bail out in dependence check. If we have seen
191 /// the bail out for the same pair many times over a limit, we won't
192 /// consider the StoreNode with the same RootNode as store merging
193 /// candidate again.
195
196 // BatchAA - Used for DAG load/store alias analysis.
197 BatchAAResults *BatchAA;
198
199 /// This caches all chains that have already been processed in
200 /// DAGCombiner::getStoreMergeCandidates() and found to have no mergeable
201 /// stores candidates.
202 SmallPtrSet<SDNode *, 4> ChainsWithoutMergeableStores;
203
204 /// When an instruction is simplified, add all users of the instruction to
205 /// the work lists because they might get more simplified now.
206 void AddUsersToWorklist(SDNode *N) {
207 for (SDNode *Node : N->users())
208 AddToWorklist(Node);
209 }
210
    /// Convenient shorthand to add a node and all of its users to the
    /// worklist. The users are enqueued first and N last, so N is popped
    /// off the stack-like worklist before its users (when newly queued).
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }
216
217 // Prune potentially dangling nodes. This is called after
218 // any visit to a node, but should also be called during a visit after any
219 // failed combine which may have created a DAG node.
220 void clearAddedDanglingWorklistEntries() {
221 // Check any nodes added to the worklist to see if they are prunable.
222 while (!PruningList.empty()) {
223 auto *N = PruningList.pop_back_val();
224 if (N->use_empty())
225 recursivelyDeleteUnusedNodes(N);
226 }
227 }
228
229 SDNode *getNextWorklistEntry() {
230 // Before we do any work, remove nodes that are not in use.
231 clearAddedDanglingWorklistEntries();
232 SDNode *N = nullptr;
233 // The Worklist holds the SDNodes in order, but it may contain null
234 // entries.
235 while (!N && !Worklist.empty()) {
236 N = Worklist.pop_back_val();
237 }
238
239 if (N) {
240 assert(N->getCombinerWorklistIndex() >= 0 &&
241 "Found a worklist entry without a corresponding map entry!");
242 // Set to -2 to indicate that we combined the node.
243 N->setCombinerWorklistIndex(-2);
244 }
245 return N;
246 }
247
248 /// Call the node-specific routine that folds each particular type of node.
249 SDValue visit(SDNode *N);
250
251 public:
252 DAGCombiner(SelectionDAG &D, BatchAAResults *BatchAA, CodeGenOptLevel OL)
253 : DAG(D), TLI(D.getTargetLoweringInfo()),
254 STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL),
255 BatchAA(BatchAA) {
256 ForCodeSize = DAG.shouldOptForSize();
257 DisableGenericCombines =
258 DisableCombines || (STI && STI->disableGenericCombines(OptLevel));
259
260 MaximumLegalStoreInBits = 0;
261 // We use the minimum store size here, since that's all we can guarantee
262 // for the scalable vector types.
263 for (MVT VT : MVT::all_valuetypes())
264 if (EVT(VT).isSimple() && VT != MVT::Other &&
265 TLI.isTypeLegal(EVT(VT)) &&
266 VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
267 MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
268 }
269
    /// Record \p N as a pruning candidate: if it becomes dead it will be
    /// deleted by clearAddedDanglingWorklistEntries().
    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }
274
275 /// Add to the worklist making sure its instance is at the back (next to be
276 /// processed.)
277 void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true,
278 bool SkipIfCombinedBefore = false) {
279 assert(N->getOpcode() != ISD::DELETED_NODE &&
280 "Deleted Node added to Worklist");
281
282 // Skip handle nodes as they can't usefully be combined and confuse the
283 // zero-use deletion strategy.
284 if (N->getOpcode() == ISD::HANDLENODE)
285 return;
286
287 if (SkipIfCombinedBefore && N->getCombinerWorklistIndex() == -2)
288 return;
289
290 if (IsCandidateForPruning)
291 ConsiderForPruning(N);
292
293 if (N->getCombinerWorklistIndex() < 0) {
294 N->setCombinerWorklistIndex(Worklist.size());
295 Worklist.push_back(N);
296 }
297 }
298
299 /// Remove all instances of N from the worklist.
300 void removeFromWorklist(SDNode *N) {
301 PruningList.remove(N);
302 StoreRootCountMap.erase(N);
303
304 int WorklistIndex = N->getCombinerWorklistIndex();
305 // If not in the worklist, the index might be -1 or -2 (was combined
306 // before). As the node gets deleted anyway, there's no need to update
307 // the index.
308 if (WorklistIndex < 0)
309 return; // Not in the worklist.
310
311 // Null out the entry rather than erasing it to avoid a linear operation.
312 Worklist[WorklistIndex] = nullptr;
313 N->setCombinerWorklistIndex(-1);
314 }
315
316 void deleteAndRecombine(SDNode *N);
317 bool recursivelyDeleteUnusedNodes(SDNode *N);
318
319 /// Replaces all uses of the results of one DAG node with new values.
320 SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
321 bool AddTo = true);
322
    /// Replaces all uses of the (single) result of \p N with \p Res.
    /// Forwards to the array-based CombineTo overload.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }
327
    /// Replaces all uses of the two results of \p N with \p Res0 and
    /// \p Res1 respectively. Forwards to the array-based overload.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }
334
    /// Replaces all uses of the results of \p N with the values in \p To.
    /// Forwards to the array-based overload.
    SDValue CombineTo(SDNode *N, SmallVectorImpl<SDValue> *To,
                      bool AddTo = true) {
      return CombineTo(N, To->data(), To->size(), AddTo);
    }
339
340 void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
341
342 private:
343 unsigned MaximumLegalStoreInBits;
344
345 /// Check the specified integer node value to see if it can be simplified or
346 /// if things it uses can be simplified by bit propagation.
347 /// If so, return true.
348 bool SimplifyDemandedBits(SDValue Op) {
349 unsigned BitWidth = Op.getScalarValueSizeInBits();
350 APInt DemandedBits = APInt::getAllOnes(BitWidth);
351 return SimplifyDemandedBits(Op, DemandedBits);
352 }
353
354 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
355 EVT VT = Op.getValueType();
356 APInt DemandedElts = VT.isFixedLengthVector()
358 : APInt(1, 1);
359 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, false);
360 }
361
362 /// Check the specified vector node value to see if it can be simplified or
363 /// if things it uses can be simplified as it only uses some of the
364 /// elements. If so, return true.
365 bool SimplifyDemandedVectorElts(SDValue Op) {
366 // TODO: For now just pretend it cannot be simplified.
367 if (Op.getValueType().isScalableVector())
368 return false;
369
370 unsigned NumElts = Op.getValueType().getVectorNumElements();
371 APInt DemandedElts = APInt::getAllOnes(NumElts);
372 return SimplifyDemandedVectorElts(Op, DemandedElts);
373 }
374
375 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
376 const APInt &DemandedElts,
377 bool AssumeSingleUse = false);
378 bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
379 bool AssumeSingleUse = false);
380
381 bool CombineToPreIndexedLoadStore(SDNode *N);
382 bool CombineToPostIndexedLoadStore(SDNode *N);
383 SDValue SplitIndexingFromLoad(LoadSDNode *LD);
384 bool SliceUpLoad(SDNode *N);
385
386 // Looks up the chain to find a unique (unaliased) store feeding the passed
387 // load. If no such store is found, returns a nullptr.
388 // Note: This will look past a CALLSEQ_START if the load is chained to it so
389 // so that it can find stack stores for byval params.
390 StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
391 // Scalars have size 0 to distinguish from singleton vectors.
392 SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
393 bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
394 bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
395
396 void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
397 SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
398 SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
399 SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
400 SDValue PromoteIntBinOp(SDValue Op);
401 SDValue PromoteIntShiftOp(SDValue Op);
402 SDValue PromoteExtend(SDValue Op);
403 bool PromoteLoad(SDValue Op);
404
405 SDValue foldShiftToAvg(SDNode *N, const SDLoc &DL);
406 // Fold `a bitwiseop (~b +/- c)` -> `a bitwiseop ~(b -/+ c)`
407 SDValue foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT);
408
409 SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
410 SDValue RHS, SDValue True, SDValue False,
411 ISD::CondCode CC);
412
413 /// Call the node-specific routine that knows how to fold each
414 /// particular type of node. If that doesn't do anything, try the
415 /// target-specific DAG combines.
416 SDValue combine(SDNode *N);
417
418 // Visitation implementation - Implement dag node combining for different
419 // node types. The semantics are as follows:
420 // Return Value:
421 // SDValue.getNode() == 0 - No change was made
422 // SDValue.getNode() == N - N was replaced, is dead and has been handled.
423 // otherwise - N should be replaced by the returned Operand.
424 //
425 SDValue visitTokenFactor(SDNode *N);
426 SDValue visitMERGE_VALUES(SDNode *N);
427 SDValue visitADD(SDNode *N);
428 SDValue visitADDLike(SDNode *N);
429 SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
430 SDNode *LocReference);
431 SDValue visitPTRADD(SDNode *N);
432 SDValue visitSUB(SDNode *N);
433 SDValue visitADDSAT(SDNode *N);
434 SDValue visitSUBSAT(SDNode *N);
435 SDValue visitADDC(SDNode *N);
436 SDValue visitADDO(SDNode *N);
437 SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
438 SDValue visitSUBC(SDNode *N);
439 SDValue visitSUBO(SDNode *N);
440 SDValue visitADDE(SDNode *N);
441 SDValue visitUADDO_CARRY(SDNode *N);
442 SDValue visitSADDO_CARRY(SDNode *N);
443 SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
444 SDNode *N);
445 SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
446 SDNode *N);
447 SDValue visitSUBE(SDNode *N);
448 SDValue visitUSUBO_CARRY(SDNode *N);
449 SDValue visitSSUBO_CARRY(SDNode *N);
450 template <class MatchContextClass> SDValue visitMUL(SDNode *N);
451 SDValue visitMULFIX(SDNode *N);
452 SDValue useDivRem(SDNode *N);
453 SDValue visitSDIV(SDNode *N);
454 SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
455 SDValue visitUDIV(SDNode *N);
456 SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
457 SDValue visitREM(SDNode *N);
458 SDValue visitMULHU(SDNode *N);
459 SDValue visitMULHS(SDNode *N);
460 SDValue visitAVG(SDNode *N);
461 SDValue visitABD(SDNode *N);
462 SDValue visitSMUL_LOHI(SDNode *N);
463 SDValue visitUMUL_LOHI(SDNode *N);
464 SDValue visitMULO(SDNode *N);
465 SDValue visitIMINMAX(SDNode *N);
466 SDValue visitAND(SDNode *N);
467 SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
468 SDValue visitOR(SDNode *N);
469 SDValue visitORLike(SDValue N0, SDValue N1, const SDLoc &DL);
470 SDValue visitXOR(SDNode *N);
471 SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
472 SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
473 SDValue visitSHL(SDNode *N);
474 SDValue visitSRA(SDNode *N);
475 SDValue visitSRL(SDNode *N);
476 SDValue visitFunnelShift(SDNode *N);
477 SDValue visitSHLSAT(SDNode *N);
478 SDValue visitRotate(SDNode *N);
479 SDValue visitABS(SDNode *N);
480 SDValue visitCLMUL(SDNode *N);
481 SDValue visitBSWAP(SDNode *N);
482 SDValue visitBITREVERSE(SDNode *N);
483 SDValue visitCTLZ(SDNode *N);
484 SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
485 SDValue visitCTTZ(SDNode *N);
486 SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
487 SDValue visitCTPOP(SDNode *N);
488 SDValue visitSELECT(SDNode *N);
489 SDValue visitVSELECT(SDNode *N);
490 SDValue visitVP_SELECT(SDNode *N);
491 SDValue visitSELECT_CC(SDNode *N);
492 SDValue visitSETCC(SDNode *N);
493 SDValue visitSETCCCARRY(SDNode *N);
494 SDValue visitSIGN_EXTEND(SDNode *N);
495 SDValue visitZERO_EXTEND(SDNode *N);
496 SDValue visitANY_EXTEND(SDNode *N);
497 SDValue visitAssertExt(SDNode *N);
498 SDValue visitAssertAlign(SDNode *N);
499 SDValue visitSIGN_EXTEND_INREG(SDNode *N);
500 SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
501 SDValue visitTRUNCATE(SDNode *N);
502 SDValue visitTRUNCATE_USAT_U(SDNode *N);
503 SDValue visitBITCAST(SDNode *N);
504 SDValue visitFREEZE(SDNode *N);
505 SDValue visitBUILD_PAIR(SDNode *N);
506 SDValue visitFADD(SDNode *N);
507 SDValue visitVP_FADD(SDNode *N);
508 SDValue visitVP_FSUB(SDNode *N);
509 SDValue visitSTRICT_FADD(SDNode *N);
510 SDValue visitFSUB(SDNode *N);
511 SDValue visitFMUL(SDNode *N);
512 template <class MatchContextClass> SDValue visitFMA(SDNode *N);
513 SDValue visitFMAD(SDNode *N);
514 SDValue visitFMULADD(SDNode *N);
515 SDValue visitFDIV(SDNode *N);
516 SDValue visitFREM(SDNode *N);
517 SDValue visitFSQRT(SDNode *N);
518 SDValue visitFCOPYSIGN(SDNode *N);
519 SDValue visitFPOW(SDNode *N);
520 SDValue visitFCANONICALIZE(SDNode *N);
521 SDValue visitSINT_TO_FP(SDNode *N);
522 SDValue visitUINT_TO_FP(SDNode *N);
523 SDValue visitFP_TO_SINT(SDNode *N);
524 SDValue visitFP_TO_UINT(SDNode *N);
525 SDValue visitXROUND(SDNode *N);
526 SDValue visitFP_ROUND(SDNode *N);
527 SDValue visitFP_EXTEND(SDNode *N);
528 SDValue visitFNEG(SDNode *N);
529 SDValue visitFABS(SDNode *N);
530 SDValue visitFCEIL(SDNode *N);
531 SDValue visitFTRUNC(SDNode *N);
532 SDValue visitFFREXP(SDNode *N);
533 SDValue visitFFLOOR(SDNode *N);
534 SDValue visitFMinMax(SDNode *N);
535 SDValue visitBRCOND(SDNode *N);
536 SDValue visitBR_CC(SDNode *N);
537 SDValue visitLOAD(SDNode *N);
538
539 SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
540 SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
541 SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);
542
543 bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);
544
545 SDValue visitSTORE(SDNode *N);
546 SDValue visitATOMIC_STORE(SDNode *N);
547 SDValue visitLIFETIME_END(SDNode *N);
548 SDValue visitINSERT_VECTOR_ELT(SDNode *N);
549 SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
550 SDValue visitBUILD_VECTOR(SDNode *N);
551 SDValue visitCONCAT_VECTORS(SDNode *N);
552 SDValue visitVECTOR_INTERLEAVE(SDNode *N);
553 SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
554 SDValue visitVECTOR_SHUFFLE(SDNode *N);
555 SDValue visitSCALAR_TO_VECTOR(SDNode *N);
556 SDValue visitINSERT_SUBVECTOR(SDNode *N);
557 SDValue visitVECTOR_COMPRESS(SDNode *N);
558 SDValue visitMLOAD(SDNode *N);
559 SDValue visitMSTORE(SDNode *N);
560 SDValue visitMGATHER(SDNode *N);
561 SDValue visitMSCATTER(SDNode *N);
562 SDValue visitMHISTOGRAM(SDNode *N);
563 SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
564 SDValue visitVPGATHER(SDNode *N);
565 SDValue visitVPSCATTER(SDNode *N);
566 SDValue visitVP_STRIDED_LOAD(SDNode *N);
567 SDValue visitVP_STRIDED_STORE(SDNode *N);
568 SDValue visitFP_TO_FP16(SDNode *N);
569 SDValue visitFP16_TO_FP(SDNode *N);
570 SDValue visitFP_TO_BF16(SDNode *N);
571 SDValue visitBF16_TO_FP(SDNode *N);
572 SDValue visitVECREDUCE(SDNode *N);
573 SDValue visitVPOp(SDNode *N);
574 SDValue visitGET_FPENV_MEM(SDNode *N);
575 SDValue visitSET_FPENV_MEM(SDNode *N);
576
577 template <class MatchContextClass>
578 SDValue visitFADDForFMACombine(SDNode *N);
579 template <class MatchContextClass>
580 SDValue visitFSUBForFMACombine(SDNode *N);
581 SDValue visitFMULForFMADistributiveCombine(SDNode *N);
582
583 SDValue XformToShuffleWithZero(SDNode *N);
584 bool reassociationCanBreakAddressingModePattern(unsigned Opc,
585 const SDLoc &DL,
586 SDNode *N,
587 SDValue N0,
588 SDValue N1);
589 SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
590 SDValue N1, SDNodeFlags Flags);
591 SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
592 SDValue N1, SDNodeFlags Flags);
593 SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
594 EVT VT, SDValue N0, SDValue N1,
595 SDNodeFlags Flags = SDNodeFlags());
596
597 SDValue visitShiftByConstant(SDNode *N);
598
599 SDValue foldSelectOfConstants(SDNode *N);
600 SDValue foldVSelectOfConstants(SDNode *N);
601 SDValue foldBinOpIntoSelect(SDNode *BO);
602 bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
603 SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
604 SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
605 SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
606 SDValue N2, SDValue N3, ISD::CondCode CC,
607 bool NotExtCompare = false);
608 SDValue convertSelectOfFPConstantsToLoadOffset(
609 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
610 ISD::CondCode CC);
611 SDValue foldSignChangeInBitcast(SDNode *N);
612 SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
613 SDValue N2, SDValue N3, ISD::CondCode CC);
614 SDValue foldSelectOfBinops(SDNode *N);
615 SDValue foldSextSetcc(SDNode *N);
616 SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
617 const SDLoc &DL);
618 SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
619 SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
620 SDValue foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
621 SDValue False, ISD::CondCode CC, const SDLoc &DL);
622 SDValue foldSelectToUMin(SDValue LHS, SDValue RHS, SDValue True,
623 SDValue False, ISD::CondCode CC, const SDLoc &DL);
624 SDValue unfoldMaskedMerge(SDNode *N);
625 SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
626 SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
627 const SDLoc &DL, bool foldBooleans);
628 SDValue rebuildSetCC(SDValue N);
629
630 bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
631 SDValue &CC, bool MatchStrict = false) const;
632 bool isOneUseSetCC(SDValue N) const;
633
634 SDValue foldAddToAvg(SDNode *N, const SDLoc &DL);
635 SDValue foldSubToAvg(SDNode *N, const SDLoc &DL);
636
637 SDValue foldCTLZToCTLS(SDValue Src, const SDLoc &DL);
638
639 SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
640 unsigned HiOp);
641 SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
642 SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
643 const TargetLowering &TLI);
644 SDValue foldPartialReduceMLAMulOp(SDNode *N);
645 SDValue foldPartialReduceAdd(SDNode *N);
646
647 SDValue CombineExtLoad(SDNode *N);
648 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
649 SDValue combineRepeatedFPDivisors(SDNode *N);
650 SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
651 SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf);
652 SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
653 SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
654 SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
655 SDValue BuildSDIV(SDNode *N);
656 SDValue BuildSDIVPow2(SDNode *N);
657 SDValue BuildUDIV(SDNode *N);
658 SDValue BuildSREMPow2(SDNode *N);
659 SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
660 SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
661 bool KnownNeverZero = false,
662 bool InexpensiveOnly = false,
663 std::optional<EVT> OutVT = std::nullopt);
664 SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
665 SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
666 SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
667 SDValue buildSqrtEstimateImpl(SDValue Op, bool Recip, SDNodeFlags Flags);
668 SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
669 bool Reciprocal);
670 SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
671 bool Reciprocal);
672 SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
673 bool DemandHighBits = true);
674 SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
675 SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
676 SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
677 bool HasPos, unsigned PosOpcode,
678 unsigned NegOpcode, const SDLoc &DL);
679 SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
680 SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
681 bool HasPos, unsigned PosOpcode,
682 unsigned NegOpcode, const SDLoc &DL);
683 SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
684 bool FromAdd);
685 SDValue MatchLoadCombine(SDNode *N);
686 SDValue mergeTruncStores(StoreSDNode *N);
687 SDValue reduceLoadWidth(SDNode *N);
688 SDValue ReduceLoadOpStoreWidth(SDNode *N);
689 SDValue splitMergedValStore(StoreSDNode *ST);
690 SDValue TransformFPLoadStorePair(SDNode *N);
691 SDValue convertBuildVecZextToZext(SDNode *N);
692 SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
693 SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
694 SDValue reduceBuildVecTruncToBitCast(SDNode *N);
695 SDValue reduceBuildVecToShuffle(SDNode *N);
696 SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
697 ArrayRef<int> VectorMask, SDValue VecIn1,
698 SDValue VecIn2, unsigned LeftIdx,
699 bool DidSplitVec);
700 SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
701
702 /// Walk up chain skipping non-aliasing memory nodes,
703 /// looking for aliasing nodes and adding them to the Aliases vector.
704 void GatherAllAliases(SDNode *N, SDValue OriginalChain,
705 SmallVectorImpl<SDValue> &Aliases);
706
707 /// Return true if there is any possibility that the two addresses overlap.
708 bool mayAlias(SDNode *Op0, SDNode *Op1) const;
709
710 /// Walk up chain skipping non-aliasing memory nodes, looking for a better
711 /// chain (aliasing node.)
712 SDValue FindBetterChain(SDNode *N, SDValue Chain);
713
714 /// Try to replace a store and any possibly adjacent stores on
715 /// consecutive chains with better chains. Return true only if St is
716 /// replaced.
717 ///
718 /// Notice that other chains may still be replaced even if the function
719 /// returns false.
720 bool findBetterNeighborChains(StoreSDNode *St);
721
722 // Helper for findBetterNeighborChains. Walk up store chain add additional
723 // chained stores that do not overlap and can be parallelized.
724 bool parallelizeChainedStores(StoreSDNode *St);
725
    /// Holds a pointer to an LSBaseSDNode as well as information on where
    /// it is located in a sequence of memory operations connected by a
    /// chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      // Pairs a memory node with its byte offset from the common base.
      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };
738
739 // Classify the origin of a stored value.
740 enum class StoreSource { Unknown, Constant, Extract, Load };
741 StoreSource getStoreSource(SDValue StoreVal) {
742 switch (StoreVal.getOpcode()) {
743 case ISD::Constant:
744 case ISD::ConstantFP:
745 return StoreSource::Constant;
749 return StoreSource::Constant;
750 return StoreSource::Unknown;
753 return StoreSource::Extract;
754 case ISD::LOAD:
755 return StoreSource::Load;
756 default:
757 return StoreSource::Unknown;
758 }
759 }
760
761 /// This is a helper function for visitMUL to check the profitability
762 /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
763 /// MulNode is the original multiply, AddNode is (add x, c1),
764 /// and ConstNode is c2.
765 bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
766 SDValue ConstNode);
767
768 /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
769 /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
770 /// the type of the loaded value to be extended.
771 bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
772 EVT LoadResultTy, EVT &ExtVT);
773
774 /// Helper function to calculate whether the given Load/Store can have its
775 /// width reduced to ExtVT.
776 bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
777 EVT &MemVT, unsigned ShAmt = 0);
778
779 /// Used by BackwardsPropagateMask to find suitable loads.
780 bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
781 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
782 ConstantSDNode *Mask, SDNode *&NodeToMask);
783 /// Attempt to propagate a given AND node back to load leaves so that they
784 /// can be combined into narrow loads.
785 bool BackwardsPropagateMask(SDNode *N);
786
787 /// Helper function for mergeConsecutiveStores which merges the component
788 /// store chains.
789 SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
790 unsigned NumStores);
791
792 /// Helper function for mergeConsecutiveStores which checks if all the store
793 /// nodes have the same underlying object. We can still reuse the first
794 /// store's pointer info if all the stores are from the same object.
795 bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);
796
797 /// This is a helper function for mergeConsecutiveStores. When the source
798 /// elements of the consecutive stores are all constants or all extracted
799 /// vector elements, try to merge them into one larger store introducing
800 /// bitcasts if necessary. \return True if a merged store was created.
801 bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
802 EVT MemVT, unsigned NumStores,
803 bool IsConstantSrc, bool UseVector,
804 bool UseTrunc);
805
806 /// This is a helper function for mergeConsecutiveStores. Stores that
807 /// potentially may be merged with St are placed in StoreNodes. On success,
808 /// returns a chain predecessor to all store candidates.
809 SDNode *getStoreMergeCandidates(StoreSDNode *St,
810 SmallVectorImpl<MemOpLink> &StoreNodes);
811
812 /// Helper function for mergeConsecutiveStores. Checks if candidate stores
813 /// have indirect dependency through their operands. RootNode is the
814 /// predecessor to all stores calculated by getStoreMergeCandidates and is
815 /// used to prune the dependency check. \return True if safe to merge.
816 bool checkMergeStoreCandidatesForDependencies(
817 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
818 SDNode *RootNode);
819
820 /// Helper function for tryStoreMergeOfLoads. Checks if the load/store
821 /// chain has a call in it. \return True if a call is found.
822 bool hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld);
823
824 /// This is a helper function for mergeConsecutiveStores. Given a list of
825 /// store candidates, find the first N that are consecutive in memory.
826 /// Returns 0 if there are not at least 2 consecutive stores to try merging.
827 unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
828 int64_t ElementSizeBytes) const;
829
830 /// This is a helper function for mergeConsecutiveStores. It is used for
831 /// store chains that are composed entirely of constant values.
832 bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
833 unsigned NumConsecutiveStores,
834 EVT MemVT, SDNode *Root, bool AllowVectors);
835
836 /// This is a helper function for mergeConsecutiveStores. It is used for
837 /// store chains that are composed entirely of extracted vector elements.
838 /// When extracting multiple vector elements, try to store them in one
839 /// vector store rather than a sequence of scalar stores.
840 bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
841 unsigned NumConsecutiveStores, EVT MemVT,
842 SDNode *Root);
843
844 /// This is a helper function for mergeConsecutiveStores. It is used for
845 /// store chains that are composed entirely of loaded values.
846 bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
847 unsigned NumConsecutiveStores, EVT MemVT,
848 SDNode *Root, bool AllowVectors,
849 bool IsNonTemporalStore, bool IsNonTemporalLoad);
850
851 /// Merge consecutive store operations into a wide store.
852 /// This optimization uses wide integers or vectors when possible.
853 /// \return true if stores were merged.
854 bool mergeConsecutiveStores(StoreSDNode *St);
855
856 /// Try to transform a truncation where C is a constant:
857 /// (trunc (and X, C)) -> (and (trunc X), (trunc C))
858 ///
859 /// \p N needs to be a truncation and its first operand an AND. Other
860 /// requirements are checked by the function (e.g. that trunc is
861 /// single-use) and if missed an empty SDValue is returned.
862 SDValue distributeTruncateThroughAnd(SDNode *N);
863
864 /// Helper function to determine whether the target supports operation
865 /// given by \p Opcode for type \p VT, that is, whether the operation
866 /// is legal or custom before legalizing operations, and whether is
867 /// legal (but not custom) after legalization.
  bool hasOperation(unsigned Opcode, EVT VT) {
    // Before operation legalization (LegalOperations == false) custom
    // lowerings are still acceptable; afterwards only truly legal ops count.
    return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
  }
871
872 bool hasUMin(EVT VT) const {
873 auto LK = TLI.getTypeConversion(*DAG.getContext(), VT);
874 return (LK.first == TargetLoweringBase::TypeLegal ||
876 TLI.isOperationLegalOrCustom(ISD::UMIN, LK.second);
877 }
878
879 public:
880 /// Runs the dag combiner on all nodes in the work list
881 void Run(CombineLevel AtLevel);
882
883 SelectionDAG &getDAG() const { return DAG; }
884
  /// Convenience wrapper around TargetLowering::getShiftAmountTy that
  /// supplies the module's DataLayout. \p LHSTy is the type of the value
  /// being shifted.
  EVT getShiftAmountTy(EVT LHSTy) {
    return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout());
  }
889
  /// This method returns true if we are running before type legalization or
  /// if the specified VT is legal.
  bool isTypeLegal(const EVT &VT) {
    // Before type legalization every type is acceptable: illegal ones will
    // be legalized later, so combines may freely create them.
    if (!LegalTypes) return true;
    return TLI.isTypeLegal(VT);
  }
896
  /// Convenience wrapper around TargetLowering::getSetCCResultType that
  /// supplies the DataLayout and LLVMContext. Returns the value type a
  /// SETCC producing a result for \p VT should use.
  EVT getSetCCResultType(EVT VT) const {
    return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  }
901
902 void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
903 SDValue OrigLoad, SDValue ExtLoad,
904 ISD::NodeType ExtType);
905 };
906
907/// This class is a DAGUpdateListener that removes any deleted
908/// nodes from the worklist.
909class WorklistRemover : public SelectionDAG::DAGUpdateListener {
910 DAGCombiner &DC;
911
912public:
913 explicit WorklistRemover(DAGCombiner &dc)
914 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
915
916 void NodeDeleted(SDNode *N, SDNode *E) override {
917 DC.removeFromWorklist(N);
918 }
919};
920
/// DAGUpdateListener that notifies the combiner of newly created nodes so
/// they can be considered for worklist pruning (rather than being fully
/// re-added to the worklist -- see the FIXME below).
class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  // compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};
932
933} // end anonymous namespace
934
935//===----------------------------------------------------------------------===//
936// TargetLowering::DAGCombinerInfo implementation
937//===----------------------------------------------------------------------===//
938
940 ((DAGCombiner*)DC)->AddToWorklist(N);
941}
942
944CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
945 return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
946}
947
949CombineTo(SDNode *N, SDValue Res, bool AddTo) {
950 return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
951}
952
954CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
955 return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
956}
957
960 return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
961}
962
965 return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
966}
967
968//===----------------------------------------------------------------------===//
969// Helper Functions
970//===----------------------------------------------------------------------===//
971
972void DAGCombiner::deleteAndRecombine(SDNode *N) {
973 removeFromWorklist(N);
974
975 // If the operands of this node are only used by the node, they will now be
976 // dead. Make sure to re-visit them and recursively delete dead nodes.
977 for (const SDValue &Op : N->ops())
978 // For an operand generating multiple values, one of the values may
979 // become dead allowing further simplification (e.g. split index
980 // arithmetic from an indexed load).
981 if (Op->hasOneUse() || Op->getNumValues() > 1)
982 AddToWorklist(Op.getNode());
983
984 DAG.DeleteNode(N);
985}
986
987// APInts must be the same size for most operations, this helper
988// function zero extends the shorter of the pair so that they match.
989// We provide an Offset so that we can create bitwidths that won't overflow.
990static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
991 unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
992 LHS = LHS.zext(Bits);
993 RHS = RHS.zext(Bits);
994}
995
996// Return true if this node is a setcc, or is a select_cc
997// that selects between the target values used for true and false, making it
998// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
999// the appropriate nodes based on the type of node we are checking. This
1000// simplifies life a bit for the callers.
1001bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
1002 SDValue &CC, bool MatchStrict) const {
1003 if (N.getOpcode() == ISD::SETCC) {
1004 LHS = N.getOperand(0);
1005 RHS = N.getOperand(1);
1006 CC = N.getOperand(2);
1007 return true;
1008 }
1009
1010 if (MatchStrict &&
1011 (N.getOpcode() == ISD::STRICT_FSETCC ||
1012 N.getOpcode() == ISD::STRICT_FSETCCS)) {
1013 LHS = N.getOperand(1);
1014 RHS = N.getOperand(2);
1015 CC = N.getOperand(3);
1016 return true;
1017 }
1018
1019 if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
1020 !TLI.isConstFalseVal(N.getOperand(3)))
1021 return false;
1022
1023 if (TLI.getBooleanContents(N.getValueType()) ==
1025 return false;
1026
1027 LHS = N.getOperand(0);
1028 RHS = N.getOperand(1);
1029 CC = N.getOperand(4);
1030 return true;
1031}
1032
1033/// Return true if this is a SetCC-equivalent operation with only one use.
1034/// If this is true, it allows the users to invert the operation for free when
1035/// it is profitable to do so.
1036bool DAGCombiner::isOneUseSetCC(SDValue N) const {
1037 SDValue N0, N1, N2;
1038 if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
1039 return true;
1040 return false;
1041}
1042
1044 if (!ScalarTy.isSimple())
1045 return false;
1046
1047 uint64_t MaskForTy = 0ULL;
1048 switch (ScalarTy.getSimpleVT().SimpleTy) {
1049 case MVT::i8:
1050 MaskForTy = 0xFFULL;
1051 break;
1052 case MVT::i16:
1053 MaskForTy = 0xFFFFULL;
1054 break;
1055 case MVT::i32:
1056 MaskForTy = 0xFFFFFFFFULL;
1057 break;
1058 default:
1059 return false;
1060 break;
1061 }
1062
1063 APInt Val;
1064 if (ISD::isConstantSplatVector(N, Val))
1065 return Val.getLimitedValue() == MaskForTy;
1066
1067 return false;
1068}
1069
1070// Determines if it is a constant integer or a splat/build vector of constant
1071// integers (and undefs).
1072// Do not permit build vector implicit truncation unless AllowTruncation is set.
1073static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false,
1074 bool AllowTruncation = false) {
1076 return !(Const->isOpaque() && NoOpaques);
1077 if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
1078 return false;
1079 unsigned BitWidth = N.getScalarValueSizeInBits();
1080 for (const SDValue &Op : N->op_values()) {
1081 if (Op.isUndef())
1082 continue;
1084 if (!Const || (Const->isOpaque() && NoOpaques))
1085 return false;
1086 // When AllowTruncation is true, allow constants that have been promoted
1087 // during type legalization as long as the value fits in the target type.
1088 if ((AllowTruncation &&
1089 Const->getAPIntValue().getActiveBits() > BitWidth) ||
1090 (!AllowTruncation && Const->getAPIntValue().getBitWidth() != BitWidth))
1091 return false;
1092 }
1093 return true;
1094}
1095
1096// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
1097// undef's.
1098static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
1099 if (V.getOpcode() != ISD::BUILD_VECTOR)
1100 return false;
1101 return isConstantOrConstantVector(V, NoOpaques) ||
1103}
1104
1105// Determine if this an indexed load with an opaque target constant index.
1106static bool canSplitIdx(LoadSDNode *LD) {
1107 return MaySplitLoadIndex &&
1108 (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
1109 !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
1110}
1111
1112bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
1113 const SDLoc &DL,
1114 SDNode *N,
1115 SDValue N0,
1116 SDValue N1) {
1117 // Currently this only tries to ensure we don't undo the GEP splits done by
1118 // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
1119 // we check if the following transformation would be problematic:
1120 // (load/store (add, (add, x, offset1), offset2)) ->
1121 // (load/store (add, x, offset1+offset2)).
1122
1123 // (load/store (add, (add, x, y), offset2)) ->
1124 // (load/store (add, (add, x, offset2), y)).
1125
1126 if (!N0.isAnyAdd())
1127 return false;
1128
1129 // Check for vscale addressing modes.
1130 // (load/store (add/sub (add x, y), vscale))
1131 // (load/store (add/sub (add x, y), (lsl vscale, C)))
1132 // (load/store (add/sub (add x, y), (mul vscale, C)))
1133 if ((N1.getOpcode() == ISD::VSCALE ||
1134 ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
1135 N1.getOperand(0).getOpcode() == ISD::VSCALE &&
1137 N1.getValueType().getFixedSizeInBits() <= 64) {
1138 int64_t ScalableOffset = N1.getOpcode() == ISD::VSCALE
1139 ? N1.getConstantOperandVal(0)
1140 : (N1.getOperand(0).getConstantOperandVal(0) *
1141 (N1.getOpcode() == ISD::SHL
1142 ? (1LL << N1.getConstantOperandVal(1))
1143 : N1.getConstantOperandVal(1)));
1144 if (Opc == ISD::SUB)
1145 ScalableOffset = -ScalableOffset;
1146 if (all_of(N->users(), [&](SDNode *Node) {
1147 if (auto *LoadStore = dyn_cast<MemSDNode>(Node);
1148 LoadStore && LoadStore->hasUniqueMemOperand() &&
1149 LoadStore->getBasePtr().getNode() == N) {
1150 TargetLoweringBase::AddrMode AM;
1151 AM.HasBaseReg = true;
1152 AM.ScalableOffset = ScalableOffset;
1153 EVT VT = LoadStore->getMemoryVT();
1154 unsigned AS = LoadStore->getAddressSpace();
1155 Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1156 return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy,
1157 AS);
1158 }
1159 return false;
1160 }))
1161 return true;
1162 }
1163
1164 if (Opc != ISD::ADD && Opc != ISD::PTRADD)
1165 return false;
1166
1167 auto *C2 = dyn_cast<ConstantSDNode>(N1);
1168 if (!C2)
1169 return false;
1170
1171 const APInt &C2APIntVal = C2->getAPIntValue();
1172 if (C2APIntVal.getSignificantBits() > 64)
1173 return false;
1174
1175 if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
1176 if (N0.hasOneUse())
1177 return false;
1178
1179 const APInt &C1APIntVal = C1->getAPIntValue();
1180 const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
1181 if (CombinedValueIntVal.getSignificantBits() > 64)
1182 return false;
1183 const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();
1184
1185 for (SDNode *Node : N->users()) {
1186 if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
1187 if (!LoadStore->hasUniqueMemOperand())
1188 continue;
1189 // Is x[offset2] already not a legal addressing mode? If so then
1190 // reassociating the constants breaks nothing (we test offset2 because
1191 // that's the one we hope to fold into the load or store).
1192 TargetLoweringBase::AddrMode AM;
1193 AM.HasBaseReg = true;
1194 AM.BaseOffs = C2APIntVal.getSExtValue();
1195 EVT VT = LoadStore->getMemoryVT();
1196 unsigned AS = LoadStore->getAddressSpace();
1197 Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1198 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1199 continue;
1200
1201 // Would x[offset1+offset2] still be a legal addressing mode?
1202 AM.BaseOffs = CombinedValue;
1203 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1204 return true;
1205 }
1206 }
1207 } else {
1208 if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
1209 if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
1210 return false;
1211
1212 for (SDNode *Node : N->users()) {
1213 auto *LoadStore = dyn_cast<MemSDNode>(Node);
1214 if (!LoadStore || !LoadStore->hasUniqueMemOperand())
1215 return false;
1216
1217 // Is x[offset2] a legal addressing mode? If so then
1218 // reassociating the constants breaks address pattern
1219 TargetLoweringBase::AddrMode AM;
1220 AM.HasBaseReg = true;
1221 AM.BaseOffs = C2APIntVal.getSExtValue();
1222 EVT VT = LoadStore->getMemoryVT();
1223 unsigned AS = LoadStore->getAddressSpace();
1224 Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1225 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1226 return false;
1227 }
1228 return true;
1229 }
1230
1231 return false;
1232}
1233
1234/// Helper for DAGCombiner::reassociateOps. Try to reassociate (Opc N0, N1) if
1235/// \p N0 is the same kind of operation as \p Opc.
1236SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
1237 SDValue N0, SDValue N1,
1238 SDNodeFlags Flags) {
1239 EVT VT = N0.getValueType();
1240
1241 if (N0.getOpcode() != Opc)
1242 return SDValue();
1243
1244 SDValue N00 = N0.getOperand(0);
1245 SDValue N01 = N0.getOperand(1);
1246
1248 SDNodeFlags NewFlags;
1249 if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
1250 Flags.hasNoUnsignedWrap())
1251 NewFlags |= SDNodeFlags::NoUnsignedWrap;
1252
1254 // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
1255 if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1})) {
1256 NewFlags.setDisjoint(Flags.hasDisjoint() &&
1257 N0->getFlags().hasDisjoint());
1258 return DAG.getNode(Opc, DL, VT, N00, OpNode, NewFlags);
1259 }
1260 return SDValue();
1261 }
1262 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1263 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
1264 // iff (op x, c1) has one use
1265 SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, NewFlags);
1266 return DAG.getNode(Opc, DL, VT, OpNode, N01, NewFlags);
1267 }
1268 }
1269
1270 // Check for repeated operand logic simplifications.
1271 if (Opc == ISD::AND || Opc == ISD::OR) {
1272 // (N00 & N01) & N00 --> N00 & N01
1273 // (N00 & N01) & N01 --> N00 & N01
1274 // (N00 | N01) | N00 --> N00 | N01
1275 // (N00 | N01) | N01 --> N00 | N01
1276 if (N1 == N00 || N1 == N01)
1277 return N0;
1278 }
1279 if (Opc == ISD::XOR) {
1280 // (N00 ^ N01) ^ N00 --> N01
1281 if (N1 == N00)
1282 return N01;
1283 // (N00 ^ N01) ^ N01 --> N00
1284 if (N1 == N01)
1285 return N00;
1286 }
1287
1288 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1289 if (N1 != N01) {
1290 // Reassociate if (op N00, N1) already exist
1291 if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
1292 // if Op (Op N00, N1), N01 already exist
1293 // we need to stop reassciate to avoid dead loop
1294 if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
1295 return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
1296 }
1297 }
1298
1299 if (N1 != N00) {
1300 // Reassociate if (op N01, N1) already exist
1301 if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
1302 // if Op (Op N01, N1), N00 already exist
1303 // we need to stop reassciate to avoid dead loop
1304 if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
1305 return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
1306 }
1307 }
1308
1309 // Reassociate the operands from (OR/AND (OR/AND(N00, N001)), N1) to (OR/AND
1310 // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
1311 // predicate or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
1312 // comparisons with the same predicate. This enables optimizations as the
1313 // following one:
1314 // CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
1315 // CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
1316 if (Opc == ISD::AND || Opc == ISD::OR) {
1317 if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
1318 N01->getOpcode() == ISD::SETCC) {
1319 ISD::CondCode CC1 = cast<CondCodeSDNode>(N1.getOperand(2))->get();
1320 ISD::CondCode CC00 = cast<CondCodeSDNode>(N00.getOperand(2))->get();
1321 ISD::CondCode CC01 = cast<CondCodeSDNode>(N01.getOperand(2))->get();
1322 if (CC1 == CC00 && CC1 != CC01) {
1323 SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, Flags);
1324 return DAG.getNode(Opc, DL, VT, OpNode, N01, Flags);
1325 }
1326 if (CC1 == CC01 && CC1 != CC00) {
1327 SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N01, N1, Flags);
1328 return DAG.getNode(Opc, DL, VT, OpNode, N00, Flags);
1329 }
1330 }
1331 }
1332 }
1333
1334 return SDValue();
1335}
1336
1337/// Try to reassociate commutative (Opc N0, N1) if either \p N0 or \p N1 is the
1338/// same kind of operation as \p Opc.
1339SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1340 SDValue N1, SDNodeFlags Flags) {
1341 assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1342
1343 // Floating-point reassociation is not allowed without loose FP math.
1344 if (N0.getValueType().isFloatingPoint() ||
1346 if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1347 return SDValue();
1348
1349 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
1350 return Combined;
1351 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0, Flags))
1352 return Combined;
1353 return SDValue();
1354}
1355
1356// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
1357// Note that we only expect Flags to be passed from FP operations. For integer
1358// operations they need to be dropped.
1359SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1360 const SDLoc &DL, EVT VT, SDValue N0,
1361 SDValue N1, SDNodeFlags Flags) {
1362 if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1363 N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
1364 N0->hasOneUse() && N1->hasOneUse() &&
1366 TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
1367 SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1368 return DAG.getNode(RedOpc, DL, VT,
1369 DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
1370 N0.getOperand(0), N1.getOperand(0)));
1371 }
1372
1373 // Reassociate op(op(vecreduce(a), b), op(vecreduce(c), d)) into
1374 // op(vecreduce(op(a, c)), op(b, d)), to combine the reductions into a
1375 // single node.
1376 SDValue A, B, C, D, RedA, RedB;
1377 if (sd_match(N0, m_OneUse(m_c_BinOp(
1378 Opc,
1379 m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(A))),
1380 m_Value(RedA)),
1381 m_Value(B)))) &&
1383 Opc,
1384 m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(C))),
1385 m_Value(RedB)),
1386 m_Value(D)))) &&
1387 !sd_match(B, m_UnaryOp(RedOpc, m_Value())) &&
1388 !sd_match(D, m_UnaryOp(RedOpc, m_Value())) &&
1389 A.getValueType() == C.getValueType() &&
1390 hasOperation(Opc, A.getValueType()) &&
1391 TLI.shouldReassociateReduction(RedOpc, VT)) {
1392 if ((Opc == ISD::FADD || Opc == ISD::FMUL) &&
1393 (!N0->getFlags().hasAllowReassociation() ||
1395 !RedA->getFlags().hasAllowReassociation() ||
1396 !RedB->getFlags().hasAllowReassociation()))
1397 return SDValue();
1398 SelectionDAG::FlagInserter FlagsInserter(
1399 DAG, Flags & N0->getFlags() & N1->getFlags() & RedA->getFlags() &
1400 RedB->getFlags());
1401 SDValue Op = DAG.getNode(Opc, DL, A.getValueType(), A, C);
1402 SDValue Red = DAG.getNode(RedOpc, DL, VT, Op);
1403 SDValue Op2 = DAG.getNode(Opc, DL, VT, B, D);
1404 return DAG.getNode(Opc, DL, VT, Red, Op2);
1405 }
1406 return SDValue();
1407}
1408
/// Replace every result of \p N with the corresponding value in \p To,
/// optionally (\p AddTo) pushing the replacements and their users onto the
/// worklist. Deletes \p N afterwards if it became dead. \returns SDValue(N,0)
/// so callers can return it from a visit function.
SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  // The listener must stay in scope across the RAUW so any nodes deleted as
  // a side effect are also removed from the worklist.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode())
        AddToWorklistWithUsers(To[i].getNode());
    }
  }

  // Finally, if the node is now dead, remove it from the graph. The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
  return SDValue(N, 0);
}
1438
/// Apply a replacement recorded by TargetLowering (TLO.Old -> TLO.New) to the
/// DAG and keep the combiner worklist in sync with it.
void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');

  // Replace all uses.
  DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);

  // Push the new node and any (possibly new) users onto the worklist.
  AddToWorklistWithUsers(TLO.New.getNode());

  // Finally, if the node is now dead, remove it from the graph.
  recursivelyDeleteUnusedNodes(TLO.Old.getNode());
}
1455
1456/// Check the specified integer node value to see if it can be simplified or if
1457/// things it uses can be simplified by bit propagation. If so, return true.
1458bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1459 const APInt &DemandedElts,
1460 bool AssumeSingleUse) {
1461 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1462 KnownBits Known;
1463 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
1464 AssumeSingleUse))
1465 return false;
1466
1467 // Revisit the node.
1468 AddToWorklist(Op.getNode());
1469
1470 CommitTargetLoweringOpt(TLO);
1471 return true;
1472}
1473
1474/// Check the specified vector node value to see if it can be simplified or
1475/// if things it uses can be simplified as it only uses some of the elements.
1476/// If so, return true.
1477bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1478 const APInt &DemandedElts,
1479 bool AssumeSingleUse) {
1480 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1481 APInt KnownUndef, KnownZero;
1482 if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1483 TLO, 0, AssumeSingleUse))
1484 return false;
1485
1486 // Revisit the node.
1487 AddToWorklist(Op.getNode());
1488
1489 CommitTargetLoweringOpt(TLO);
1490 return true;
1491}
1492
/// Replace \p Load with \p ExtLoad, a load producing a wider value of the
/// same memory: value users are rewired through a TRUNCATE back to Load's
/// original type, chain users to ExtLoad's chain.
void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
  SDLoc DL(Load);
  EVT VT = Load->getValueType(0);
  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));

  LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
             Trunc.dump(&DAG); dbgs() << '\n');

  // Result 0 (the value) goes to the truncate; result 1 (the chain) to
  // ExtLoad's chain.
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));

  AddToWorklist(Trunc.getNode());
  recursivelyDeleteUnusedNodes(Load);
}
1507
1508SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1509 Replace = false;
1510 SDLoc DL(Op);
1511 if (ISD::isUNINDEXEDLoad(Op.getNode())) {
1512 LoadSDNode *LD = cast<LoadSDNode>(Op);
1513 EVT MemVT = LD->getMemoryVT();
1515 : LD->getExtensionType();
1516 Replace = true;
1517 return DAG.getExtLoad(ExtType, DL, PVT,
1518 LD->getChain(), LD->getBasePtr(),
1519 MemVT, LD->getMemOperand());
1520 }
1521
1522 unsigned Opc = Op.getOpcode();
1523 switch (Opc) {
1524 default: break;
1525 case ISD::AssertSext:
1526 if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1527 return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1528 break;
1529 case ISD::AssertZext:
1530 if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1531 return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1532 break;
1533 case ISD::Constant: {
1534 unsigned ExtOpc =
1535 Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1536 return DAG.getNode(ExtOpc, DL, PVT, Op);
1537 }
1538 }
1539
1540 if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1541 return SDValue();
1542 return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1543}
1544
1545SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1547 return SDValue();
1548 EVT OldVT = Op.getValueType();
1549 SDLoc DL(Op);
1550 bool Replace = false;
1551 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1552 if (!NewOp.getNode())
1553 return SDValue();
1554 AddToWorklist(NewOp.getNode());
1555
1556 if (Replace)
1557 ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1558 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1559 DAG.getValueType(OldVT));
1560}
1561
1562SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1563 EVT OldVT = Op.getValueType();
1564 SDLoc DL(Op);
1565 bool Replace = false;
1566 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1567 if (!NewOp.getNode())
1568 return SDValue();
1569 AddToWorklist(NewOp.getNode());
1570
1571 if (Replace)
1572 ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1573 return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1574}
1575
/// Promote the specified integer binary operation if the target indicates it is
/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
/// i32 since i16 instructions are longer.
SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
  if (!LegalOperations)
    return SDValue();

  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));

    bool Replace0 = false;
    SDValue N0 = Op.getOperand(0);
    SDValue NN0 = PromoteOperand(N0, PVT, Replace0);

    bool Replace1 = false;
    SDValue N1 = Op.getOperand(1);
    SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
    SDLoc DL(Op);

    // Perform the operation in the wider type, then truncate back to VT.
    SDValue RV =
        DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));

    // We are always replacing N0/N1's use in N and only need additional
    // replacements if there are additional uses.
    // Note: We are checking uses of the *nodes* (SDNode) rather than values
    // (SDValue) here because the node may reference multiple values
    // (for example, the chain value of a load node).
    Replace0 &= !N0->hasOneUse();
    Replace1 &= (N0 != N1) && !N1->hasOneUse();

    // Combine Op here so it is preserved past replacements.
    CombineTo(Op.getNode(), RV);

    // If operands have a use ordering, make sure we deal with
    // predecessor first.
    if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
      std::swap(N0, N1);
      std::swap(NN0, NN1);
    }

    if (Replace0) {
      AddToWorklist(NN0.getNode());
      ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
    }
    if (Replace1) {
      AddToWorklist(NN1.getNode());
      ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
    }
    return Op;
  }
  return SDValue();
}
1643
/// Promote the specified integer shift operation if the target indicates it is
/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
/// i32 since i16 instructions are longer.
SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
  if (!LegalOperations)
    return SDValue();

  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));

    SDNodeFlags TruncFlags;
    bool Replace = false;
    SDValue N0 = Op.getOperand(0);
    // Widen the shifted value so its high bits match the shift kind: SRA
    // gets a sign-extended operand, SRL a zero-extended one. For other
    // shifts (SHL), pick the extension matching a no-wrap flag so the flag
    // can be kept on the truncate; otherwise any-extend.
    if (Opc == ISD::SRA) {
      N0 = SExtPromoteOperand(N0, PVT);
    } else if (Opc == ISD::SRL) {
      N0 = ZExtPromoteOperand(N0, PVT);
    } else {
      if (Op->getFlags().hasNoUnsignedWrap()) {
        N0 = ZExtPromoteOperand(N0, PVT);
        TruncFlags = SDNodeFlags::NoUnsignedWrap;
      } else if (Op->getFlags().hasNoSignedWrap()) {
        N0 = SExtPromoteOperand(N0, PVT);
        TruncFlags = SDNodeFlags::NoSignedWrap;
      } else {
        N0 = PromoteOperand(N0, PVT, Replace);
      }
    }

    if (!N0.getNode())
      return SDValue();

    SDLoc DL(Op);
    SDValue N1 = Op.getOperand(1);
    SDValue RV = DAG.getNode(ISD::TRUNCATE, DL, VT,
                             DAG.getNode(Opc, DL, PVT, N0, N1), TruncFlags);

    if (Replace)
      ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());

    // Deal with Op being deleted.
    if (Op && Op.getOpcode() != ISD::DELETED_NODE)
      return RV;
  }
  return SDValue();
}
1705
1706SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1707 if (!LegalOperations)
1708 return SDValue();
1709
1710 EVT VT = Op.getValueType();
1711 if (VT.isVector() || !VT.isInteger())
1712 return SDValue();
1713
1714 // If operation type is 'undesirable', e.g. i16 on x86, consider
1715 // promoting it.
1716 unsigned Opc = Op.getOpcode();
1717 if (TLI.isTypeDesirableForOp(Opc, VT))
1718 return SDValue();
1719
1720 EVT PVT = VT;
1721 // Consult target whether it is a good idea to promote this operation and
1722 // what's the right type to promote it to.
1723 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1724 assert(PVT != VT && "Don't know what type to promote to!");
1725 // fold (aext (aext x)) -> (aext x)
1726 // fold (aext (zext x)) -> (zext x)
1727 // fold (aext (sext x)) -> (sext x)
1728 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1729 return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1730 }
1731 return SDValue();
1732}
1733
/// Promote an unindexed integer load of an 'undesirable' type (e.g. i16 on
/// x86) to the wider type the target prefers, replacing the old load's value
/// result with a truncate of the widened load. Returns true if the DAG was
/// changed.
bool DAGCombiner::PromoteLoad(SDValue Op) {
  if (!LegalOperations)
    return false;

  // Indexed loads produce an updated pointer as well; only plain loads are
  // handled here.
  if (!ISD::isUNINDEXEDLoad(Op.getNode()))
    return false;

  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return false;

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return false;

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    SDLoc DL(Op);
    SDNode *N = Op.getNode();
    LoadSDNode *LD = cast<LoadSDNode>(N);
    EVT MemVT = LD->getMemoryVT();
    // NOTE(review): the first line of the `ExtType` initializer (its
    // declaration) appears to be missing from this copy of the file; verify
    // against upstream before building.
                                         : LD->getExtensionType();
    // Build the widened (extending) load of the same memory type, then
    // truncate the value back down so existing users keep their type.
    SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
                                   LD->getChain(), LD->getBasePtr(),
                                   MemVT, LD->getMemOperand());
    SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);

    LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
               Result.dump(&DAG); dbgs() << '\n');

    // Rewire both results of the old load: the value (result 0) to the
    // truncate, and the chain (result 1) to the new load's chain.
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));

    AddToWorklist(Result.getNode());
    recursivelyDeleteUnusedNodes(N);
    return true;
  }

  return false;
}
1781
1782/// Recursively delete a node which has no uses and any operands for
1783/// which it is the only use.
1784///
1785/// Note that this both deletes the nodes and removes them from the worklist.
1786/// It also adds any nodes who have had a user deleted to the worklist as they
1787/// may now have only one use and subject to other combines.
1788bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1789 if (!N->use_empty())
1790 return false;
1791
1792 SmallSetVector<SDNode *, 16> Nodes;
1793 Nodes.insert(N);
1794 do {
1795 N = Nodes.pop_back_val();
1796 if (!N)
1797 continue;
1798
1799 if (N->use_empty()) {
1800 for (const SDValue &ChildN : N->op_values())
1801 Nodes.insert(ChildN.getNode());
1802
1803 removeFromWorklist(N);
1804 DAG.DeleteNode(N);
1805 } else {
1806 AddToWorklist(N);
1807 }
1808 } while (!Nodes.empty());
1809 return true;
1810}
1811
1812//===----------------------------------------------------------------------===//
1813// Main DAG Combiner implementation
1814//===----------------------------------------------------------------------===//
1815
/// Top-level driver for the combiner: seeds the worklist with every node in
/// the DAG, then repeatedly pops nodes, (re-)legalizes them if needed, runs
/// combine() on them, and splices any replacement value into the DAG, until
/// the worklist is exhausted.
void DAGCombiner::Run(CombineLevel AtLevel) {
  // set the instance variables, so that the various visit routines may use it.
  Level = AtLevel;
  LegalDAG = Level >= AfterLegalizeDAG;
  LegalOperations = Level >= AfterLegalizeVectorOps;
  LegalTypes = Level >= AfterLegalizeTypes;

  WorklistInserter AddNodes(*this);

  // Add all the dag nodes to the worklist.
  //
  // Note: All nodes are not added to PruningList here, this is because the only
  // nodes which can be deleted are those which have no uses and all other nodes
  // which would otherwise be added to the worklist by the first call to
  // getNextWorklistEntry are already present in it.
  for (SDNode &Node : DAG.allnodes())
    AddToWorklist(&Node, /* IsCandidateForPruning */ Node.use_empty());

  // Create a dummy node (which is not added to allnodes), that adds a reference
  // to the root node, preventing it from being deleted, and tracking any
  // changes of the root.
  HandleSDNode Dummy(DAG.getRoot());

  // While we have a valid worklist entry node, try to combine it.
  while (SDNode *N = getNextWorklistEntry()) {
    // If N has no uses, it is dead.  Make sure to revisit all N's operands once
    // N is deleted from the DAG, since they too may now be dead or may have a
    // reduced number of uses, allowing other xforms.
    if (recursivelyDeleteUnusedNodes(N))
      continue;

    WorklistRemover DeadNodes(*this);

    // If this combine is running after legalizing the DAG, re-legalize any
    // nodes pulled off the worklist.
    if (LegalDAG) {
      SmallSetVector<SDNode *, 16> UpdatedNodes;
      bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);

      for (SDNode *LN : UpdatedNodes)
        AddToWorklistWithUsers(LN);

      // If N itself was invalidated by legalization, it cannot be combined.
      if (!NIsValid)
        continue;
    }

    LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));

    // Add any operands of the new node which have not yet been combined to the
    // worklist as well. getNextWorklistEntry flags nodes that have been
    // combined before. Because the worklist uniques things already, this won't
    // repeatedly process the same operand.
    for (const SDValue &ChildN : N->op_values())
      AddToWorklist(ChildN.getNode(), /*IsCandidateForPruning=*/true,
                    /*SkipIfCombinedBefore=*/true);

    SDValue RV = combine(N);

    // No combine fired; move on to the next worklist entry.
    if (!RV.getNode())
      continue;

    ++NodesCombined;

    // Invalidate cached info.
    ChainsWithoutMergeableStores.clear();

    // If we get back the same node we passed in, rather than a new node or
    // zero, we know that the node must have defined multiple values and
    // CombineTo was used.  Since CombineTo takes care of the worklist
    // mechanics for us, we have no work to do in this case.
    if (RV.getNode() == N)
      continue;

    assert(N->getOpcode() != ISD::DELETED_NODE &&
           RV.getOpcode() != ISD::DELETED_NODE &&
           "Node was deleted but visit returned new node!");

    LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));

    // Splice the replacement in: either wholesale (same number of results) or
    // for the single value the node defines.
    if (N->getNumValues() == RV->getNumValues())
      DAG.ReplaceAllUsesWith(N, RV.getNode());
    else {
      assert(N->getValueType(0) == RV.getValueType() &&
             N->getNumValues() == 1 && "Type mismatch");
      DAG.ReplaceAllUsesWith(N, &RV);
    }

    // Push the new node and any users onto the worklist.  Omit this if the
    // new node is the EntryToken (e.g. if a store managed to get optimized
    // out), because re-visiting the EntryToken and its users will not uncover
    // any additional opportunities, but there may be a large number of such
    // users, potentially causing compile time explosion.
    if (RV.getOpcode() != ISD::EntryToken)
      AddToWorklistWithUsers(RV.getNode());

    // Finally, if the node is now dead, remove it from the graph.  The node
    // may not be dead if the replacement process recursively simplified to
    // something else needing this node.  This will also take care of adding any
    // operands which have lost a user to the worklist.
    recursivelyDeleteUnusedNodes(N);
  }

  // If the root changed (e.g. it was a dead load, update the root).
  DAG.setRoot(Dummy.getValue());
  DAG.RemoveDeadNodes();
}
1922
/// Dispatch on the node's opcode to the matching visitXXX combine routine.
/// Returns the replacement value chosen by that routine, or an empty SDValue
/// if no generic combine applies to this opcode.
SDValue DAGCombiner::visit(SDNode *N) {
  // clang-format off
  switch (N->getOpcode()) {
  default: break;
  case ISD::TokenFactor:        return visitTokenFactor(N);
  case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
  case ISD::ADD:                return visitADD(N);
  case ISD::PTRADD:             return visitPTRADD(N);
  case ISD::SUB:                return visitSUB(N);
  case ISD::SADDSAT:
  case ISD::UADDSAT:            return visitADDSAT(N);
  case ISD::SSUBSAT:
  case ISD::USUBSAT:            return visitSUBSAT(N);
  case ISD::ADDC:               return visitADDC(N);
  case ISD::SADDO:
  case ISD::UADDO:              return visitADDO(N);
  case ISD::SUBC:               return visitSUBC(N);
  case ISD::SSUBO:
  case ISD::USUBO:              return visitSUBO(N);
  case ISD::ADDE:               return visitADDE(N);
  case ISD::UADDO_CARRY:        return visitUADDO_CARRY(N);
  case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
  case ISD::SUBE:               return visitSUBE(N);
  case ISD::USUBO_CARRY:        return visitUSUBO_CARRY(N);
  case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
  case ISD::SMULFIX:
  case ISD::SMULFIXSAT:
  case ISD::UMULFIX:
  case ISD::UMULFIXSAT:         return visitMULFIX(N);
  case ISD::MUL:                return visitMUL<EmptyMatchContext>(N);
  case ISD::SDIV:               return visitSDIV(N);
  case ISD::UDIV:               return visitUDIV(N);
  case ISD::SREM:
  case ISD::UREM:               return visitREM(N);
  case ISD::MULHU:              return visitMULHU(N);
  case ISD::MULHS:              return visitMULHS(N);
  case ISD::AVGFLOORS:
  case ISD::AVGFLOORU:
  case ISD::AVGCEILS:
  case ISD::AVGCEILU:           return visitAVG(N);
  case ISD::ABDS:
  case ISD::ABDU:               return visitABD(N);
  case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
  case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
  case ISD::SMULO:
  case ISD::UMULO:              return visitMULO(N);
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::UMIN:
  case ISD::UMAX:               return visitIMINMAX(N);
  case ISD::AND:                return visitAND(N);
  case ISD::OR:                 return visitOR(N);
  case ISD::XOR:                return visitXOR(N);
  case ISD::SHL:                return visitSHL(N);
  case ISD::SRA:                return visitSRA(N);
  case ISD::SRL:                return visitSRL(N);
  case ISD::ROTR:
  case ISD::ROTL:               return visitRotate(N);
  case ISD::FSHL:
  case ISD::FSHR:               return visitFunnelShift(N);
  case ISD::SSHLSAT:
  case ISD::USHLSAT:            return visitSHLSAT(N);
  case ISD::ABS:                return visitABS(N);
  case ISD::CLMUL:
  case ISD::CLMULR:
  case ISD::CLMULH:             return visitCLMUL(N);
  case ISD::BSWAP:              return visitBSWAP(N);
  case ISD::BITREVERSE:         return visitBITREVERSE(N);
  case ISD::CTLZ:               return visitCTLZ(N);
  case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
  case ISD::CTTZ:               return visitCTTZ(N);
  case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
  case ISD::CTPOP:              return visitCTPOP(N);
  case ISD::SELECT:             return visitSELECT(N);
  case ISD::VSELECT:            return visitVSELECT(N);
  case ISD::SELECT_CC:          return visitSELECT_CC(N);
  case ISD::SETCC:              return visitSETCC(N);
  case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
  case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
  case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
  case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
  case ISD::AssertSext:
  case ISD::AssertZext:         return visitAssertExt(N);
  case ISD::AssertAlign:        return visitAssertAlign(N);
  case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
  // NOTE(review): the SIGN/ZERO_EXTEND_VECTOR_INREG case labels appear to be
  // missing from this copy of the file; verify against upstream.
  case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
  case ISD::TRUNCATE:           return visitTRUNCATE(N);
  case ISD::TRUNCATE_USAT_U:    return visitTRUNCATE_USAT_U(N);
  case ISD::BITCAST:            return visitBITCAST(N);
  case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
  case ISD::FADD:               return visitFADD(N);
  case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
  case ISD::FSUB:               return visitFSUB(N);
  case ISD::FMUL:               return visitFMUL(N);
  case ISD::FMA:                return visitFMA<EmptyMatchContext>(N);
  case ISD::FMAD:               return visitFMAD(N);
  case ISD::FMULADD:            return visitFMULADD(N);
  case ISD::FDIV:               return visitFDIV(N);
  case ISD::FREM:               return visitFREM(N);
  case ISD::FSQRT:              return visitFSQRT(N);
  case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
  case ISD::FPOW:               return visitFPOW(N);
  case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
  case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
  case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
  case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
  case ISD::LROUND:
  case ISD::LLROUND:
  case ISD::LRINT:
  case ISD::LLRINT:             return visitXROUND(N);
  case ISD::FP_ROUND:           return visitFP_ROUND(N);
  case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
  case ISD::FNEG:               return visitFNEG(N);
  case ISD::FABS:               return visitFABS(N);
  case ISD::FFLOOR:             return visitFFLOOR(N);
  case ISD::FMINNUM:
  case ISD::FMAXNUM:
  case ISD::FMINIMUM:
  case ISD::FMAXIMUM:
  case ISD::FMINIMUMNUM:
  case ISD::FMAXIMUMNUM:        return visitFMinMax(N);
  case ISD::FCEIL:              return visitFCEIL(N);
  case ISD::FTRUNC:             return visitFTRUNC(N);
  case ISD::FFREXP:             return visitFFREXP(N);
  case ISD::BRCOND:             return visitBRCOND(N);
  case ISD::BR_CC:              return visitBR_CC(N);
  case ISD::LOAD:               return visitLOAD(N);
  case ISD::STORE:              return visitSTORE(N);
  case ISD::ATOMIC_STORE:       return visitATOMIC_STORE(N);
  case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
  case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
  case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
  case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
  case ISD::VECTOR_INTERLEAVE:  return visitVECTOR_INTERLEAVE(N);
  case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
  case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
  case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
  case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
  case ISD::MGATHER:            return visitMGATHER(N);
  case ISD::MLOAD:              return visitMLOAD(N);
  case ISD::MSCATTER:           return visitMSCATTER(N);
  case ISD::MSTORE:             return visitMSTORE(N);
  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
  // NOTE(review): the PARTIAL_REDUCE_* case labels for the return below
  // appear to be missing from this copy of the file; verify against upstream.
    return visitPARTIAL_REDUCE_MLA(N);
  case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
  case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
  case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
  case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
  case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
  case ISD::BF16_TO_FP:         return visitBF16_TO_FP(N);
  case ISD::FREEZE:             return visitFREEZE(N);
  case ISD::GET_FPENV_MEM:      return visitGET_FPENV_MEM(N);
  case ISD::SET_FPENV_MEM:      return visitSET_FPENV_MEM(N);
  case ISD::FCANONICALIZE:      return visitFCANONICALIZE(N);
  // NOTE(review): one or more case labels between FCANONICALIZE and
  // VECREDUCE_ADD appear to be missing from this copy of the file.
  case ISD::VECREDUCE_ADD:
  case ISD::VECREDUCE_MUL:
  case ISD::VECREDUCE_AND:
  case ISD::VECREDUCE_OR:
  case ISD::VECREDUCE_XOR:
  // NOTE(review): additional VECREDUCE_* case labels appear to be missing
  // from this copy of the file; verify against upstream.
  case ISD::VECREDUCE_FMINIMUM: return visitVECREDUCE(N);
#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
#include "llvm/IR/VPIntrinsics.def"
    return visitVPOp(N);
  }
  // clang-format on
  return SDValue();
}
2105
2106SDValue DAGCombiner::combine(SDNode *N) {
2107 if (!DebugCounter::shouldExecute(DAGCombineCounter))
2108 return SDValue();
2109
2110 SDValue RV;
2111 if (!DisableGenericCombines)
2112 RV = visit(N);
2113
2114 // If nothing happened, try a target-specific DAG combine.
2115 if (!RV.getNode()) {
2116 assert(N->getOpcode() != ISD::DELETED_NODE &&
2117 "Node was deleted but visit returned NULL!");
2118
2119 if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2120 TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
2121
2122 // Expose the DAG combiner to the target combiner impls.
2123 TargetLowering::DAGCombinerInfo
2124 DagCombineInfo(DAG, Level, false, this);
2125
2126 RV = TLI.PerformDAGCombine(N, DagCombineInfo);
2127 }
2128 }
2129
2130 // If nothing happened still, try promoting the operation.
2131 if (!RV.getNode()) {
2132 switch (N->getOpcode()) {
2133 default: break;
2134 case ISD::ADD:
2135 case ISD::SUB:
2136 case ISD::MUL:
2137 case ISD::AND:
2138 case ISD::OR:
2139 case ISD::XOR:
2140 RV = PromoteIntBinOp(SDValue(N, 0));
2141 break;
2142 case ISD::SHL:
2143 case ISD::SRA:
2144 case ISD::SRL:
2145 RV = PromoteIntShiftOp(SDValue(N, 0));
2146 break;
2147 case ISD::SIGN_EXTEND:
2148 case ISD::ZERO_EXTEND:
2149 case ISD::ANY_EXTEND:
2150 RV = PromoteExtend(SDValue(N, 0));
2151 break;
2152 case ISD::LOAD:
2153 if (PromoteLoad(SDValue(N, 0)))
2154 RV = SDValue(N, 0);
2155 break;
2156 }
2157 }
2158
2159 // If N is a commutative binary node, try to eliminate it if the commuted
2160 // version is already present in the DAG.
2161 if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
2162 SDValue N0 = N->getOperand(0);
2163 SDValue N1 = N->getOperand(1);
2164
2165 // Constant operands are canonicalized to RHS.
2166 if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
2167 SDValue Ops[] = {N1, N0};
2168 SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
2169 N->getFlags());
2170 if (CSENode)
2171 return SDValue(CSENode, 0);
2172 }
2173 }
2174
2175 return RV;
2176}
2177
2178/// Given a node, return its input chain if it has one, otherwise return a null
2179/// sd operand.
2181 if (unsigned NumOps = N->getNumOperands()) {
2182 if (N->getOperand(0).getValueType() == MVT::Other)
2183 return N->getOperand(0);
2184 if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
2185 return N->getOperand(NumOps-1);
2186 for (unsigned i = 1; i < NumOps-1; ++i)
2187 if (N->getOperand(i).getValueType() == MVT::Other)
2188 return N->getOperand(i);
2189 }
2190 return SDValue();
2191}
2192
2193SDValue DAGCombiner::visitFCANONICALIZE(SDNode *N) {
2194 SDValue Operand = N->getOperand(0);
2195 EVT VT = Operand.getValueType();
2196 SDLoc dl(N);
2197
2198 // Canonicalize undef to quiet NaN.
2199 if (Operand.isUndef()) {
2200 APFloat CanonicalQNaN = APFloat::getQNaN(VT.getFltSemantics());
2201 return DAG.getConstantFP(CanonicalQNaN, dl, VT);
2202 }
2203 return SDValue();
2204}
2205
/// Combine a TokenFactor: drop redundant chains, inline single-use child
/// TokenFactors, deduplicate operands, and prune operands that are already
/// reachable through another operand's chain.
SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
  // If N has two operands, where one has an input chain equal to the other,
  // the 'other' chain is redundant.
  if (N->getNumOperands() == 2) {
    if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
      return N->getOperand(0);
    if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
      return N->getOperand(1);
  }

  // Don't simplify token factors if optnone.
  if (OptLevel == CodeGenOptLevel::None)
    return SDValue();

  // Don't simplify the token factor if the node itself has too many operands.
  if (N->getNumOperands() > TokenFactorInlineLimit)
    return SDValue();

  // If the sole user is a token factor, we should make sure we have a
  // chance to merge them together. This prevents TF chains from inhibiting
  // optimizations.
  if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TokenFactor)
    AddToWorklist(*(N->user_begin()));

  SmallVector<SDNode *, 8> TFs;   // List of token factors to visit.
  SmallVector<SDValue, 8> Ops;    // Ops for replacing token factor.
  SmallPtrSet<SDNode*, 16> SeenOps;
  bool Changed = false;           // If we should replace this token factor.

  // Start out with this token factor.
  TFs.push_back(N);

  // Iterate through token factors. The TFs grows when new token factors are
  // encountered.
  for (unsigned i = 0; i < TFs.size(); ++i) {
    // Limit number of nodes to inline, to avoid quadratic compile times.
    // We have to add the outstanding Token Factors to Ops, otherwise we might
    // drop Ops from the resulting Token Factors.
    if (Ops.size() > TokenFactorInlineLimit) {
      for (unsigned j = i; j < TFs.size(); j++)
        Ops.emplace_back(TFs[j], 0);
      // Drop unprocessed Token Factors from TFs, so we do not add them to the
      // combiner worklist later.
      TFs.resize(i);
      break;
    }

    SDNode *TF = TFs[i];
    // Check each of the operands.
    for (const SDValue &Op : TF->op_values()) {
      switch (Op.getOpcode()) {
      case ISD::EntryToken:
        // Entry tokens don't need to be added to the list. They are
        // redundant.
        Changed = true;
        break;

      case ISD::TokenFactor:
        if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
          // Queue up for processing.
          TFs.push_back(Op.getNode());
          Changed = true;
          break;
        }
        [[fallthrough]];

      default:
        // Only add if it isn't already in the list.
        if (SeenOps.insert(Op.getNode()).second)
          Ops.push_back(Op);
        else
          Changed = true;
        break;
      }
    }
  }

  // Re-visit inlined Token Factors, to clean them up in case they have been
  // removed. Skip the first Token Factor, as this is the current node.
  for (unsigned i = 1, e = TFs.size(); i < e; i++)
    AddToWorklist(TFs[i]);

  // Remove Nodes that are chained to another node in the list. Do so
  // by walking up chains breath-first stopping when we've seen
  // another operand. In general we must climb to the EntryNode, but we can exit
  // early if we find all remaining work is associated with just one operand as
  // no further pruning is possible.

  // List of nodes to search through and original Ops from which they originate.
  // NOTE(review): the declaration of `Worklist` (a vector of
  // (SDNode *, unsigned) pairs, used below) appears to be missing from this
  // copy of the file; verify against upstream.
  SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
  SmallPtrSet<SDNode *, 16> SeenChains;
  bool DidPruneOps = false;

  unsigned NumLeftToConsider = 0;
  for (const SDValue &Op : Ops) {
    Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
    OpWorkCount.push_back(1);
  }

  auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
    // If this is an Op, we can remove the op from the list. Remark any
    // search associated with it as from the current OpNumber.
    if (SeenOps.contains(Op)) {
      Changed = true;
      DidPruneOps = true;
      unsigned OrigOpNumber = 0;
      while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
        OrigOpNumber++;
      assert((OrigOpNumber != Ops.size()) &&
             "expected to find TokenFactor Operand");
      // Re-mark worklist from OrigOpNumber to OpNumber
      for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
        if (Worklist[i].second == OrigOpNumber) {
          Worklist[i].second = OpNumber;
        }
      }
      OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
      OpWorkCount[OrigOpNumber] = 0;
      NumLeftToConsider--;
    }
    // Add if it's a new chain
    if (SeenChains.insert(Op).second) {
      OpWorkCount[OpNumber]++;
      Worklist.push_back(std::make_pair(Op, OpNumber));
    }
  };

  for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
    // We need at least be consider at least 2 Ops to prune.
    if (NumLeftToConsider <= 1)
      break;
    auto CurNode = Worklist[i].first;
    auto CurOpNumber = Worklist[i].second;
    assert((OpWorkCount[CurOpNumber] > 0) &&
           "Node should not appear in worklist");
    switch (CurNode->getOpcode()) {
    case ISD::EntryToken:
      // Hitting EntryToken is the only way for the search to terminate without
      // hitting
      // another operand's search. Prevent us from marking this operand
      // considered.
      NumLeftToConsider++;
      break;
    case ISD::TokenFactor:
      for (const SDValue &Op : CurNode->op_values())
        AddToWorklist(i, Op.getNode(), CurOpNumber);
      break;
    // NOTE(review): a preceding case label (likely ISD::LIFETIME_START)
    // appears to be missing from this copy of the file.
    case ISD::LIFETIME_END:
    case ISD::CopyFromReg:
    case ISD::CopyToReg:
      AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
      break;
    default:
      if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
        AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
      break;
    }
    OpWorkCount[CurOpNumber]--;
    if (OpWorkCount[CurOpNumber] == 0)
      NumLeftToConsider--;
  }

  // If we've changed things around then replace token factor.
  if (Changed) {
    // NOTE(review): the declaration of `Result` (SDValue) appears to be
    // missing from this copy of the file; verify against upstream.
    if (Ops.empty()) {
      // The entry token is the only possible outcome.
      Result = DAG.getEntryNode();
    } else {
      if (DidPruneOps) {
        SmallVector<SDValue, 8> PrunedOps;
        //
        for (const SDValue &Op : Ops) {
          if (SeenChains.count(Op.getNode()) == 0)
            PrunedOps.push_back(Op);
        }
        Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
      } else {
        Result = DAG.getTokenFactor(SDLoc(N), Ops);
      }
    }
    return Result;
  }
  return SDValue();
}
2393
/// MERGE_VALUES can always be eliminated.
SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
  WorklistRemover DeadNodes(*this);
  // Replacing results may cause a different MERGE_VALUES to suddenly
  // be CSE'd with N, and carry its uses with it. Iterate until no
  // uses remain, to ensure that the node can be safely deleted.
  // First add the users of this node to the work list so that they
  // can be tried again once they have new operands.
  AddUsersToWorklist(N);
  do {
    // Do as a single replacement to avoid rewalking use lists.
    // NOTE(review): the declaration of `Ops` (a copy of N's operands used
    // as replacement values) appears to be missing from this copy of the
    // file; verify against upstream.
    DAG.ReplaceAllUsesWith(N, Ops.data());
  } while (!N->use_empty());
  deleteAndRecombine(N);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
2411
2412/// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
2413/// ConstantSDNode pointer else nullptr.
2416 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2417}
2418
2419// isTruncateOf - If N is a truncate of some other value, return true, record
2420// the value being truncated in Op and which of Op's bits are zero/one in Known.
2421// This function computes KnownBits to avoid a duplicated call to
2422// computeKnownBits in the caller.
2424 KnownBits &Known) {
2425 if (N->getOpcode() == ISD::TRUNCATE) {
2426 Op = N->getOperand(0);
2427 Known = DAG.computeKnownBits(Op);
2428 if (N->getFlags().hasNoUnsignedWrap())
2429 Known.Zero.setBitsFrom(N.getScalarValueSizeInBits());
2430 return true;
2431 }
2432
2433 if (N.getValueType().getScalarType() != MVT::i1 ||
2434 !sd_match(
2436 return false;
2437
2438 Known = DAG.computeKnownBits(Op);
2439 return (Known.Zero | 1).isAllOnes();
2440}
2441
2442/// Return true if 'Use' is a load or a store that uses N as its base pointer
2443/// and that N may be folded in the load / store addressing mode.
2445 const TargetLowering &TLI) {
2446 EVT VT;
2447 unsigned AS;
2448
2449 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2450 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2451 return false;
2452 VT = LD->getMemoryVT();
2453 AS = LD->getAddressSpace();
2454 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2455 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2456 return false;
2457 VT = ST->getMemoryVT();
2458 AS = ST->getAddressSpace();
2460 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2461 return false;
2462 VT = LD->getMemoryVT();
2463 AS = LD->getAddressSpace();
2465 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2466 return false;
2467 VT = ST->getMemoryVT();
2468 AS = ST->getAddressSpace();
2469 } else {
2470 return false;
2471 }
2472
2474 if (N->isAnyAdd()) {
2475 AM.HasBaseReg = true;
2477 if (Offset)
2478 // [reg +/- imm]
2479 AM.BaseOffs = Offset->getSExtValue();
2480 else
2481 // [reg +/- reg]
2482 AM.Scale = 1;
2483 } else if (N->getOpcode() == ISD::SUB) {
2484 AM.HasBaseReg = true;
2486 if (Offset)
2487 // [reg +/- imm]
2488 AM.BaseOffs = -Offset->getSExtValue();
2489 else
2490 // [reg +/- reg]
2491 AM.Scale = 1;
2492 } else {
2493 return false;
2494 }
2495
2496 return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2497 VT.getTypeForEVT(*DAG.getContext()), AS);
2498}
2499
2500/// This inverts a canonicalization in IR that replaces a variable select arm
2501/// with an identity constant. Codegen improves if we re-use the variable
2502/// operand rather than load a constant. This can also be converted into a
2503/// masked vector operation if the target supports it.
2505 bool ShouldCommuteOperands) {
2506 SDValue N0 = N->getOperand(0);
2507 SDValue N1 = N->getOperand(1);
2508
2509 // Match a select as operand 1. The identity constant that we are looking for
2510 // is only valid as operand 1 of a non-commutative binop.
2511 if (ShouldCommuteOperands)
2512 std::swap(N0, N1);
2513
2514 SDValue Cond, TVal, FVal;
2516 m_Value(FVal)))))
2517 return SDValue();
2518
2519 // We can't hoist all instructions because of immediate UB (not speculatable).
2520 // For example div/rem by zero.
2522 return SDValue();
2523
2524 unsigned SelOpcode = N1.getOpcode();
2525 unsigned Opcode = N->getOpcode();
2526 EVT VT = N->getValueType(0);
2527 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2528
2529 // This transform increases uses of N0, so freeze it to be safe.
2530 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2531 unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2532 if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) &&
2533 TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2534 FVal)) {
2535 SDValue F0 = DAG.getFreeze(N0);
2536 SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2537 return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2538 }
2539 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2540 if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) &&
2541 TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2542 TVal)) {
2543 SDValue F0 = DAG.getFreeze(N0);
2544 SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2545 return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2546 }
2547
2548 return SDValue();
2549}
2550
/// Try to eliminate a binary operator whose operand is a one-use
/// select-of-constants by folding the constant arithmetic into the select
/// arms: binop (select Cond, CT, CF), CBO --> select Cond, CT*, CF*.
///
/// NOTE(review): several multi-line condition continuations in this function
/// appear to have been elided from this copy of the file (the lines ending in
/// `&&` that are immediately followed by a statement); verify against
/// upstream before building.
SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
         "Unexpected binary operator");

  // First try the variant that keeps a variable select arm and re-uses the
  // other binop operand (see foldSelectWithIdentityConstant).
  if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
    return Sel;

  if (TLI.isCommutativeBinOp(BO->getOpcode()))
    if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
      return Sel;

  // Don't do this unless the old select is going away. We want to eliminate the
  // binary operator, not replace a binop with a select.
  // TODO: Handle ISD::SELECT_CC.
  unsigned SelOpNo = 0;
  SDValue Sel = BO->getOperand(0);
  auto BinOpcode = BO->getOpcode();
  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
    SelOpNo = 1;
    Sel = BO->getOperand(1);

    // Peek through trunc to shift amount type.
    if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
         BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
      // This is valid when the truncated bits of x are already zero.
      SDValue Op;
      KnownBits Known;
      if (isTruncateOf(DAG, Sel, Op, Known) &&
        Sel = Op;
    }
  }

  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
    return SDValue();

  SDValue CT = Sel.getOperand(1);
  if (!isConstantOrConstantVector(CT, true) &&
    return SDValue();

  SDValue CF = Sel.getOperand(2);
  if (!isConstantOrConstantVector(CF, true) &&
    return SDValue();

  // Bail out if any constants are opaque because we can't constant fold those.
  // The exception is "and" and "or" with either 0 or -1 in which case we can
  // propagate non constant operands into select. I.e.:
  // and (select Cond, 0, -1), X --> select Cond, 0, X
  // or X, (select Cond, -1, 0) --> select Cond, -1, X
  bool CanFoldNonConst =
      (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&

  SDValue CBO = BO->getOperand(SelOpNo ^ 1);
  if (!CanFoldNonConst &&
      !isConstantOrConstantVector(CBO, true) &&
    return SDValue();

  SDLoc DL(Sel);
  SDValue NewCT, NewCF;
  EVT VT = BO->getValueType(0);

  if (CanFoldNonConst) {
    // If CBO is an opaque constant, we can't rely on getNode to constant fold.
    if ((BinOpcode == ISD::AND && isNullOrNullSplat(CT)) ||
        (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CT)))
      NewCT = CT;
    else
      NewCT = CBO;

    if ((BinOpcode == ISD::AND && isNullOrNullSplat(CF)) ||
        (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CF)))
      NewCF = CF;
    else
      NewCF = CBO;
  } else {
    // We have a select-of-constants followed by a binary operator with a
    // constant. Eliminate the binop by pulling the constant math into the
    // select. Example: add (select Cond, CT, CF), CBO --> select Cond, CT +
    // CBO, CF + CBO
    NewCT = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CT})
                    : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CT, CBO});
    if (!NewCT)
      return SDValue();

    NewCF = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CF})
                    : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CF, CBO});
    if (!NewCF)
      return SDValue();
  }

  return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF, BO->getFlags());
}
2649
// Fold an add/sub of a constant with a zero-extended, inverted low-bit test:
//   add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
//   sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
// NOTE(review): the first signature line (original line 2650, carrying the
// function name — presumably foldAddSubBoolOfMaskedVal, given the call at
// visitADD) and line 2671 (the condition-code matcher that closes the
// sd_match call below) are missing from this extraction.
2651                                            SelectionDAG &DAG) {
2652 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2653 "Expecting add or sub");
2654
2655 // Match a constant operand and a zext operand for the math instruction:
2656 // add Z, C
2657 // sub C, Z
2658 bool IsAdd = N->getOpcode() == ISD::ADD;
2659 SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2660 SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2661 auto *CN = dyn_cast<ConstantSDNode>(C);
2662 if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2663 return SDValue();
2664
2665 // Match the zext operand as a setcc of a boolean.
2666 if (Z.getOperand(0).getValueType() != MVT::i1)
2667 return SDValue();
2668
2669 // Match the compare as: setcc (X & 1), 0, eq.
// (The tail of this sd_match call spans the elided line 2671.)
2670 if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
2672 return SDValue();
2673
2674 // We are adding/subtracting a constant and an inverted low bit. Turn that
2675 // into a subtract/add of the low bit with incremented/decremented constant:
2676 // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2677 // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2678 EVT VT = C.getValueType();
2679 SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
// The +1/-1 adjustment compensates for flipping (1 - bit) into bit.
2680 SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
2681 : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2682 return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2683}
2684
2685 // Attempt to form avgceil(A, B) from (A | B) - ((A ^ B) >> 1)
// NOTE(review): the heads of both sd_match calls (original lines 2692 and
// 2697) are missing from this extraction; only the shifted-xor halves of the
// patterns are visible.
2686 SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
2687 SDValue N0 = N->getOperand(0);
2688 EVT VT = N0.getValueType();
2689 SDValue A, B;
2690
// Unsigned variant: the halving shift is a logical shift right (srl).
2691 if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
2693 m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
2694 return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
2695 }
// Signed variant: the halving shift is an arithmetic shift right (sra).
2696 if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
2698 m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
2699 return DAG.getNode(ISD::AVGCEILS, DL, VT, A, B);
2700 }
2701 return SDValue();
2702}
2703
2704 /// Try to fold a pointer arithmetic node.
2705 /// This needs to be done separately from normal addition, because pointer
2706 /// addition is not commutative.
// NOTE(review): two source lines are missing from this extraction — line
// 2759 (the opt-in query guarding the out-of-bounds-capable combines; the
// surviving continuation passes the Function and PtrVT to it) and line 2769
// (presumably the dyn_cast<GlobalAddressSDNode> of GAValue initializing GA).
2707 SDValue DAGCombiner::visitPTRADD(SDNode *N) {
2708 SDValue N0 = N->getOperand(0);
2709 SDValue N1 = N->getOperand(1);
2710 EVT PtrVT = N0.getValueType();
2711 EVT IntVT = N1.getValueType();
2712 SDLoc DL(N);
2713
2714 // This is already ensured by an assert in SelectionDAG::getNode(). Several
2715 // combines here depend on this assumption.
2716 assert(PtrVT == IntVT &&
2717 "PTRADD with different operand types is not supported");
2718
2719 // fold (ptradd x, 0) -> x
2720 if (isNullConstant(N1))
2721 return N0;
2722
2723 // fold (ptradd 0, x) -> x
// (The PtrVT == IntVT check is redundant given the assert above; it is kept
// as written.)
2724 if (PtrVT == IntVT && isNullConstant(N0))
2725 return N1;
2726
2727 if (N0.getOpcode() == ISD::PTRADD &&
2728 !reassociationCanBreakAddressingModePattern(ISD::PTRADD, DL, N, N0, N1)) {
2729 SDValue X = N0.getOperand(0);
2730 SDValue Y = N0.getOperand(1);
2731 SDValue Z = N1;
2732 bool N0OneUse = N0.hasOneUse();
2733 bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
2734 bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
2735
2736 // (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
2737 // * y is a constant and (ptradd x, y) has one use; or
2738 // * y and z are both constants.
2739 if ((YIsConstant && N0OneUse) || (YIsConstant && ZIsConstant)) {
2740 // If both additions in the original were NUW, the new ones are as well.
2741 SDNodeFlags Flags =
2742 (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;
2743 SDValue Add = DAG.getNode(ISD::ADD, DL, IntVT, {Y, Z}, Flags);
2744 AddToWorklist(Add.getNode());
2745 // We can't set InBounds even if both original ptradds were InBounds and
2746 // NUW: SDAG usually represents pointers as integers, therefore, the
2747 // matched pattern behaves as if it had implicit casts:
2748 // (ptradd inbounds (inttoptr (ptrtoint (ptradd inbounds x, y))), z)
2749 // The outer inbounds ptradd might therefore rely on a provenance that x
2750 // does not have.
2751 return DAG.getMemBasePlusOffset(X, Add, DL, Flags);
2752 }
2753 }
2754
2755 // The following combines can turn in-bounds pointer arithmetic out of bounds.
2756 // That is problematic for settings like AArch64's CPA, which checks that
2757 // intermediate results of pointer arithmetic remain in bounds. The target
2758 // therefore needs to opt-in to enable them.
2760 DAG.getMachineFunction().getFunction(), PtrVT))
2761 return SDValue();
2762
2763 if (N0.getOpcode() == ISD::PTRADD && isa<ConstantSDNode>(N1)) {
2764 // Fold (ptradd (ptradd GA, v), c) -> (ptradd (ptradd GA, c) v) with
2765 // global address GA and constant c, such that c can be folded into GA.
2766 // TODO: Support constant vector splats.
2767 SDValue GAValue = N0.getOperand(0);
2768 if (const GlobalAddressSDNode *GA =
2770 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2771 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
2772 // If both additions in the original were NUW, reassociation preserves
2773 // that.
2774 SDNodeFlags Flags =
2775 (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;
2776 // We can't set InBounds even if both original ptradds were InBounds and
2777 // NUW: SDAG usually represents pointers as integers, therefore, the
2778 // matched pattern behaves as if it had implicit casts:
2779 // (ptradd inbounds (inttoptr (ptrtoint (ptradd inbounds GA, v))), c)
2780 // The outer inbounds ptradd might therefore rely on a provenance that
2781 // GA does not have.
2782 SDValue Inner = DAG.getMemBasePlusOffset(GAValue, N1, DL, Flags);
2783 AddToWorklist(Inner.getNode());
2784 return DAG.getMemBasePlusOffset(Inner, N0.getOperand(1), DL, Flags);
2785 }
2786 }
2787 }
2788
2789 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse()) {
2790 // (ptradd x, (add y, z)) -> (ptradd (ptradd x, y), z) if z is a constant,
2791 // y is not, and (add y, z) is used only once.
2792 // (ptradd x, (add y, z)) -> (ptradd (ptradd x, z), y) if y is a constant,
2793 // z is not, and (add y, z) is used only once.
2794 // The goal is to move constant offsets to the outermost ptradd, to create
2795 // more opportunities to fold offsets into memory instructions.
2796 // Together with another combine above, this also implements
2797 // (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y)).
2798 SDValue X = N0;
2799 SDValue Y = N1.getOperand(0);
2800 SDValue Z = N1.getOperand(1);
2801 bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
2802 bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
2803
2804 // If both additions in the original were NUW, reassociation preserves that.
2805 SDNodeFlags CommonFlags = N->getFlags() & N1->getFlags();
2806 SDNodeFlags ReassocFlags = CommonFlags & SDNodeFlags::NoUnsignedWrap;
2807 if (CommonFlags.hasNoUnsignedWrap()) {
2808 // If both operations are NUW and the PTRADD is inbounds, the offsets are
2809 // both non-negative, so the reassociated PTRADDs are also inbounds.
2810 ReassocFlags |= N->getFlags() & SDNodeFlags::InBounds;
2811 }
2812
// Fire only when exactly one of y/z is constant; canonicalize so the
// constant (Z after the swap) ends up on the outer ptradd.
2813 if (ZIsConstant != YIsConstant) {
2814 if (YIsConstant)
2815 std::swap(Y, Z);
2816 SDValue Inner = DAG.getMemBasePlusOffset(X, Y, DL, ReassocFlags);
2817 AddToWorklist(Inner.getNode());
2818 return DAG.getMemBasePlusOffset(Inner, Z, DL, ReassocFlags);
2819 }
2820 }
2821
2822 // Transform (ptradd a, b) -> (or disjoint a, b) if it is equivalent and if
2823 // that transformation can't block an offset folding at any use of the ptradd.
2824 // This should be done late, after legalization, so that it doesn't block
2825 // other ptradd combines that could enable more offset folding.
2826 if (LegalOperations && DAG.haveNoCommonBitsSet(N0, N1)) {
2827 bool TransformCannotBreakAddrMode = none_of(N->users(), [&](SDNode *User) {
2828 return canFoldInAddressingMode(N, User, DAG, TLI);
2829 });
2830
2831 if (TransformCannotBreakAddrMode)
2832 return DAG.getNode(ISD::OR, DL, PtrVT, N0, N1, SDNodeFlags::Disjoint);
2833 }
2834
2835 return SDValue();
2836}
2837
2838 /// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
2839 /// a shift and add with a different constant.
// NOTE(review): the first signature line (original line 2840, carrying the
// function name — presumably foldAddSubOfSignBit, given the call at
// visitADD) is missing from this extraction.
2841                                     SelectionDAG &DAG) {
2842 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2843 "Expecting add or sub");
2844
2845 // We need a constant operand for the add/sub, and the other operand is a
2846 // logical shift right: add (srl), C or sub C, (srl).
2847 bool IsAdd = N->getOpcode() == ISD::ADD;
2848 SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2849 SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2850 if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
2851 ShiftOp.getOpcode() != ISD::SRL)
2852 return SDValue();
2853
2854 // The shift must be of a 'not' value.
2855 SDValue Not = ShiftOp.getOperand(0);
2856 if (!Not.hasOneUse() || !isBitwiseNot(Not))
2857 return SDValue();
2858
2859 // The shift must be moving the sign bit to the least-significant-bit.
2860 EVT VT = ShiftOp.getValueType();
2861 SDValue ShAmt = ShiftOp.getOperand(1);
2862 ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2863 if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2864 return SDValue();
2865
2866 // Eliminate the 'not' by adjusting the shift and add/sub constant:
2867 // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2868 // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
// For the add form, FoldConstantArithmetic computes C + 1 (ADD); for the sub
// form it computes C - 1 (SUB), matching the identities above.
2869 if (SDValue NewC = DAG.FoldConstantArithmetic(
2870 IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2871 {ConstantOp, DAG.getConstant(1, DL, VT)})) {
2872 SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2873 Not.getOperand(0), ShAmt);
2874 return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
2875 }
2876
2877 return SDValue();
2878}
2879
// Return true if either operand is the bitwise 'not' of the other.
// NOTE(review): the signature line (original line 2881, carrying the function
// name and parameters — presumably areBitwiseNotOfEachother(SDValue Op0,
// SDValue Op1), given the call in visitADDLike) is missing from this
// extraction.
2880 static bool
2882 return (isBitwiseNot(Op0) && Op0.getOperand(0) == Op1) ||
2883 (isBitwiseNot(Op1) && Op1.getOperand(0) == Op0);
2884}
2885
2886 /// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2887 /// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2888 /// are no common bits set in the operands).
// NOTE(review): this extraction is missing a number of source lines (gaps in
// the embedded numbering, e.g. 2906-2907, 2919, 2949, 2987-2988, 3028, 3033,
// 3043, 3050, 3067, 3118, 3120, 3126-3128, 3131-3132, 3142, 3145, 3156,
// 3161); multi-line conditions spanning those gaps are incomplete as shown.
2889 SDValue DAGCombiner::visitADDLike(SDNode *N) {
2890 SDValue N0 = N->getOperand(0);
2891 SDValue N1 = N->getOperand(1);
2892 EVT VT = N0.getValueType();
2893 SDLoc DL(N);
2894
2895 // fold (add x, undef) -> undef
2896 if (N0.isUndef())
2897 return N0;
2898 if (N1.isUndef())
2899 return N1;
2900
2901 // fold (add c1, c2) -> c1+c2
2902 if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
2903 return C;
2904
2905 // canonicalize constant to RHS
2908 return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
2909
// not(X) + X --> all-ones (X + ~X == -1 for every bit).
2910 if (areBitwiseNotOfEachother(N0, N1))
2911 return DAG.getConstant(APInt::getAllOnes(VT.getScalarSizeInBits()), DL, VT);
2912
2913 // fold vector ops
2914 if (VT.isVector()) {
2915 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2916 return FoldedVOp;
2917
2918 // fold (add x, 0) -> x, vector edition
2920 return N0;
2921 }
2922
2923 // fold (add x, 0) -> x
2924 if (isNullConstant(N1))
2925 return N0;
2926
2927 if (N0.getOpcode() == ISD::SUB) {
2928 SDValue N00 = N0.getOperand(0);
2929 SDValue N01 = N0.getOperand(1);
2930
2931 // fold ((A-c1)+c2) -> (A+(c2-c1))
2932 if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
2933 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
2934
2935 // fold ((c1-A)+c2) -> (c1+c2)-A
2936 if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
2937 return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2938 }
2939
2940 // add (sext i1 X), 1 -> zext (not i1 X)
2941 // We don't transform this pattern:
2942 // add (zext i1 X), -1 -> sext (not i1 X)
2943 // because most (?) targets generate better code for the zext form.
2944 if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2945 isOneOrOneSplat(N1)) {
2946 SDValue X = N0.getOperand(0);
2947 if ((!LegalOperations ||
2948 (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
2950 X.getScalarValueSizeInBits() == 1) {
2951 SDValue Not = DAG.getNOT(DL, X, X.getValueType());
2952 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
2953 }
2954 }
2955
2956 // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2957 // iff (or x, c0) is equivalent to (add x, c0).
2958 // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2959 // iff (xor x, c0) is equivalent to (add x, c0).
2960 if (DAG.isADDLike(N0)) {
2961 SDValue N01 = N0.getOperand(1);
2962 if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
2963 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
2964 }
2965
2966 if (SDValue NewSel = foldBinOpIntoSelect(N))
2967 return NewSel;
2968
2969 // reassociate add
2970 if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
2971 if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
2972 return RADD;
2973
2974 // Reassociate (add (or x, c), y) -> (add add(x, y), c)) if (or x, c) is
2975 // equivalent to (add x, c).
2976 // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is
2977 // equivalent to (add x, c).
2978 // Do this optimization only when adding c does not introduce instructions
2979 // for adding carries.
2980 auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2981 if (DAG.isADDLike(N0) && N0.hasOneUse() &&
2982 isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
2983 // If N0's type does not split or is a sign mask, it does not introduce
2984 // add carry.
2985 auto TyActn = TLI.getTypeAction(*DAG.getContext(), N0.getValueType());
2986 bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
2989 if (NoAddCarry)
2990 return DAG.getNode(
2991 ISD::ADD, DL, VT,
2992 DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
2993 N0.getOperand(1));
2994 }
2995 return SDValue();
2996 };
// Try both operand orders, since the ADD-like node may be on either side.
2997 if (SDValue Add = ReassociateAddOr(N0, N1))
2998 return Add;
2999 if (SDValue Add = ReassociateAddOr(N1, N0))
3000 return Add;
3001
3002 // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
3003 if (SDValue SD =
3004 reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
3005 return SD;
3006 }
3007
3008 SDValue A, B, C, D;
3009
3010 // fold ((0-A) + B) -> B-A
3011 if (sd_match(N0, m_Neg(m_Value(A))))
3012 return DAG.getNode(ISD::SUB, DL, VT, N1, A);
3013
3014 // fold (A + (0-B)) -> A-B
3015 if (sd_match(N1, m_Neg(m_Value(B))))
3016 return DAG.getNode(ISD::SUB, DL, VT, N0, B);
3017
3018 // fold (A+(B-A)) -> B
3019 if (sd_match(N1, m_Sub(m_Value(B), m_Specific(N0))))
3020 return B;
3021
3022 // fold ((B-A)+A) -> B
3023 if (sd_match(N0, m_Sub(m_Value(B), m_Specific(N1))))
3024 return B;
3025
3026 // fold ((A-B)+(C-A)) -> (C-B)
3027 if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
3029 return DAG.getNode(ISD::SUB, DL, VT, C, B);
3030
3031 // fold ((A-B)+(B-C)) -> (A-C)
3032 if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
3034 return DAG.getNode(ISD::SUB, DL, VT, A, C);
3035
3036 // fold (A+(B-(A+C))) to (B-C)
3037 // fold (A+(B-(C+A))) to (B-C)
3038 if (sd_match(N1, m_Sub(m_Value(B), m_Add(m_Specific(N0), m_Value(C)))))
3039 return DAG.getNode(ISD::SUB, DL, VT, B, C);
3040
3041 // fold (A+((B-A)+or-C)) to (B+or-C)
3042 if (sd_match(N1,
3044 m_Sub(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)))))
3045 return DAG.getNode(N1.getOpcode(), DL, VT, B, C);
3046
3047 // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
3048 if (sd_match(N0, m_OneUse(m_Sub(m_Value(A), m_Value(B)))) &&
3049 sd_match(N1, m_OneUse(m_Sub(m_Value(C), m_Value(D)))) &&
3051 return DAG.getNode(ISD::SUB, DL, VT,
3052 DAG.getNode(ISD::ADD, SDLoc(N0), VT, A, C),
3053 DAG.getNode(ISD::ADD, SDLoc(N1), VT, B, D));
3054
3055 // fold (add (umax X, C), -C) --> (usubsat X, C)
3056 if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
// Per-element predicate: either both elements are undef, or the umax
// constant is the exact negation of the added constant.
3057 auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
3058 return (!Max && !Op) ||
3059 (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
3060 };
3061 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
3062 /*AllowUndefs*/ true))
3063 return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
3064 N0.getOperand(1));
3065 }
3066
// (The condition guarding this early return is on the elided line 3067.)
3068 return SDValue(N, 0);
3069
3070 if (isOneOrOneSplat(N1)) {
3071 // fold (add (xor a, -1), 1) -> (sub 0, a)
3072 if (isBitwiseNot(N0))
3073 return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
3074 N0.getOperand(0));
3075
3076 // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
3077 if (N0.getOpcode() == ISD::ADD) {
3078 SDValue A, Xor;
3079
3080 if (isBitwiseNot(N0.getOperand(0))) {
3081 A = N0.getOperand(1);
3082 Xor = N0.getOperand(0);
3083 } else if (isBitwiseNot(N0.getOperand(1))) {
3084 A = N0.getOperand(0);
3085 Xor = N0.getOperand(1);
3086 }
3087
3088 if (Xor)
3089 return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
3090 }
3091
3092 // Look for:
3093 // add (add x, y), 1
3094 // And if the target does not like this form then turn into:
3095 // sub y, (xor x, -1)
3096 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3097 N0.hasOneUse() &&
3098 // Limit this to after legalization if the add has wrap flags
3099 (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
3100 !N->getFlags().hasNoSignedWrap()))) {
3101 SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
3102 return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
3103 }
3104 }
3105
3106 // (x - y) + -1 -> add (xor y, -1), x
3107 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3108 isAllOnesOrAllOnesSplat(N1, /*AllowUndefs=*/true)) {
3109 SDValue Not = DAG.getNOT(DL, N0.getOperand(1), VT);
3110 return DAG.getNode(ISD::ADD, DL, VT, Not, N0.getOperand(0));
3111 }
3112
3113 // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
3114 // This can help if the inner add has multiple uses.
3115 APInt CM, CA;
3116 if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(N1)) {
// Restricted to <= 64-bit scalars so the APInt math below fits getSExtValue.
3117 if (VT.getScalarSizeInBits() <= 64) {
3119 m_ConstInt(CM)))) &&
3121 (CA * CM + CB->getAPIntValue()).getSExtValue())) {
3122 SDNodeFlags Flags;
3123 // If all the inputs are nuw, the outputs can be nuw. If all the inputs
3124 // are _also_ nsw the outputs can be too.
3125 if (N->getFlags().hasNoUnsignedWrap() &&
3126 N0->getFlags().hasNoUnsignedWrap() &&
3129 if (N->getFlags().hasNoSignedWrap() &&
3130 N0->getFlags().hasNoSignedWrap() &&
3133 }
3134 SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
3135 DAG.getConstant(CM, DL, VT), Flags);
3136 return DAG.getNode(
3137 ISD::ADD, DL, VT, Mul,
3138 DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
3139 }
3140 // Also look in case there is an intermediate add.
3141 if (sd_match(N0, m_OneUse(m_Add(
3143 m_ConstInt(CM))),
3144 m_Value(B)))) &&
3146 (CA * CM + CB->getAPIntValue()).getSExtValue())) {
3147 SDNodeFlags Flags;
3148 // If all the inputs are nuw, the outputs can be nuw. If all the inputs
3149 // are _also_ nsw the outputs can be too.
// OMul is whichever operand of the intermediate add is the mul (i.e. not B).
3150 SDValue OMul =
3151 N0.getOperand(0) == B ? N0.getOperand(1) : N0.getOperand(0);
3152 if (N->getFlags().hasNoUnsignedWrap() &&
3153 N0->getFlags().hasNoUnsignedWrap() &&
3154 OMul->getFlags().hasNoUnsignedWrap() &&
3155 OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
3157 if (N->getFlags().hasNoSignedWrap() &&
3158 N0->getFlags().hasNoSignedWrap() &&
3159 OMul->getFlags().hasNoSignedWrap() &&
3160 OMul.getOperand(0)->getFlags().hasNoSignedWrap())
3162 }
3163 SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
3164 DAG.getConstant(CM, DL, VT), Flags);
3165 SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N1), VT, Mul, B, Flags);
3166 return DAG.getNode(
3167 ISD::ADD, DL, VT, Add,
3168 DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
3169 }
3170 }
3171 }
3172
// Try the commutative helper with both operand orders.
3173 if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
3174 return Combined;
3175
3176 if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
3177 return Combined;
3178
3179 return SDValue();
3180}
3181
3182 // Attempt to form avgfloor(A, B) from (A & B) + ((A ^ B) >> 1)
3183 // Attempt to form avgfloor(A, B) from ((A >> 1) + (B >> 1)) + (A & B & 1)
3184 // Attempt to form avgceil(A, B) from ((A >> 1) + (B >> 1)) + ((A | B) & 1)
// NOTE(review): the heads of the sd_match calls (original lines 3192,
// 3194-3195, 3202, 3204-3205, 3213, 3220) are missing from this extraction;
// only the shift/xor sub-patterns are visible below.
3185 SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
3186 SDValue N0 = N->getOperand(0);
3187 EVT VT = N0.getValueType();
3188 SDValue A, B;
3189
// Unsigned floor-average: halving shifts are logical (srl).
3190 if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
3191 (sd_match(N,
3193 m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)), m_One()))) ||
3196 m_Srl(m_Deferred(A), m_One()),
3197 m_Srl(m_Deferred(B), m_One()))))) {
3198 return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
3199 }
// Signed floor-average: halving shifts are arithmetic (sra).
3200 if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
3201 (sd_match(N,
3203 m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)), m_One()))) ||
3206 m_Sra(m_Deferred(A), m_One()),
3207 m_Sra(m_Deferred(B), m_One()))))) {
3208 return DAG.getNode(ISD::AVGFLOORS, DL, VT, A, B);
3209 }
3210
// Ceiling variants of the shifted-halves form.
3211 if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
3212 sd_match(N,
3214 m_Srl(m_Deferred(A), m_One()),
3215 m_Srl(m_Deferred(B), m_One())))) {
3216 return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
3217 }
3218 if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
3219 sd_match(N,
3221 m_Sra(m_Deferred(A), m_One()),
3222 m_Sra(m_Deferred(B), m_One())))) {
3223 return DAG.getNode(ISD::AVGCEILS, DL, VT, A, B);
3224 }
3225
3226 return SDValue();
3227}
3228
// Combine an ISD::ADD node: shared ADD-like folds first (visitADDLike), then
// ADD-specific patterns (masked-bool, sign-bit, rotate, average, disjoint-or,
// and vscale/step_vector constant folds).
// NOTE(review): original line 3284 (the middle of the a+step_vector condition,
// presumably checking N0.getOperand(1) is a STEP_VECTOR) is missing from this
// extraction.
3229 SDValue DAGCombiner::visitADD(SDNode *N) {
3230 SDValue N0 = N->getOperand(0);
3231 SDValue N1 = N->getOperand(1);
3232 EVT VT = N0.getValueType();
3233 SDLoc DL(N);
3234
3235 if (SDValue Combined = visitADDLike(N))
3236 return Combined;
3237
3238 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
3239 return V;
3240
3241 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
3242 return V;
3243
3244 if (SDValue V = MatchRotate(N0, N1, SDLoc(N), /*FromAdd=*/true))
3245 return V;
3246
3247 // Try to match AVGFLOOR fixedwidth pattern
3248 if (SDValue V = foldAddToAvg(N, DL))
3249 return V;
3250
3251 // fold (a+b) -> (a|b) iff a and b share no bits.
3252 if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
3253 DAG.haveNoCommonBitsSet(N0, N1))
3254 return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);
3255
3256 // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
3257 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
3258 const APInt &C0 = N0->getConstantOperandAPInt(0);
3259 const APInt &C1 = N1->getConstantOperandAPInt(0);
3260 return DAG.getVScale(DL, VT, C0 + C1);
3261 }
3262
3263 // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
3264 if (N0.getOpcode() == ISD::ADD &&
3265 N0.getOperand(1).getOpcode() == ISD::VSCALE &&
3266 N1.getOpcode() == ISD::VSCALE) {
3267 const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
3268 const APInt &VS1 = N1->getConstantOperandAPInt(0);
3269 SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
3270 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
3271 }
3272
3273 // Fold (add step_vector(c1), step_vector(c2) to step_vector(c1+c2))
3274 if (N0.getOpcode() == ISD::STEP_VECTOR &&
3275 N1.getOpcode() == ISD::STEP_VECTOR) {
3276 const APInt &C0 = N0->getConstantOperandAPInt(0);
3277 const APInt &C1 = N1->getConstantOperandAPInt(0);
3278 APInt NewStep = C0 + C1;
3279 return DAG.getStepVector(DL, VT, NewStep);
3280 }
3281
3282 // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
3283 if (N0.getOpcode() == ISD::ADD &&
3285 N1.getOpcode() == ISD::STEP_VECTOR) {
3286 const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
3287 const APInt &SV1 = N1->getConstantOperandAPInt(0);
3288 APInt NewStep = SV0 + SV1;
3289 SDValue SV = DAG.getStepVector(DL, VT, NewStep);
3290 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
3291 }
3292
3293 return SDValue();
3294}
3295
// Combine saturating adds (ISD::SADDSAT / ISD::UADDSAT): constant folding,
// canonicalization, identity folds, and lowering to a plain ADD when the
// operation provably cannot overflow.
// NOTE(review): original lines 3313-3314 (the condition of the
// canonicalize-constant-to-RHS fold) and 3323 (the vector add_sat x, 0
// condition) are missing from this extraction.
3296 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
3297 unsigned Opcode = N->getOpcode();
3298 SDValue N0 = N->getOperand(0);
3299 SDValue N1 = N->getOperand(1);
3300 EVT VT = N0.getValueType();
3301 bool IsSigned = Opcode == ISD::SADDSAT;
3302 SDLoc DL(N);
3303
3304 // fold (add_sat x, undef) -> -1
3305 if (N0.isUndef() || N1.isUndef())
3306 return DAG.getAllOnesConstant(DL, VT);
3307
3308 // fold (add_sat c1, c2) -> c3
3309 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
3310 return C;
3311
3312 // canonicalize constant to RHS
3315 return DAG.getNode(Opcode, DL, VT, N1, N0);
3316
3317 // fold vector ops
3318 if (VT.isVector()) {
3319 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3320 return FoldedVOp;
3321
3322 // fold (add_sat x, 0) -> x, vector edition
3324 return N0;
3325 }
3326
3327 // fold (add_sat x, 0) -> x
3328 if (isNullConstant(N1))
3329 return N0;
3330
3331 // If it cannot overflow, transform into an add.
3332 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3333 return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
3334
3335 return SDValue();
3336}
3337
3339 bool ForceCarryReconstruction = false) {
3340 bool Masked = false;
3341
3342 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
3343 while (true) {
3344 if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3345 return V;
3346
3347 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
3348 V = V.getOperand(0);
3349 continue;
3350 }
3351
3352 if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
3353 if (ForceCarryReconstruction)
3354 return V;
3355
3356 Masked = true;
3357 V = V.getOperand(0);
3358 continue;
3359 }
3360
3361 break;
3362 }
3363
3364 // If this is not a carry, return.
3365 if (V.getResNo() != 1)
3366 return SDValue();
3367
3368 if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3369 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3370 return SDValue();
3371
3372 EVT VT = V->getValueType(0);
3373 if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
3374 return SDValue();
3375
3376 // If the result is masked, then no matter what kind of bool it is we can
3377 // return. If it isn't, then we need to make sure the bool type is either 0 or
3378 // 1 and not other values.
3379 if (Masked ||
3380 TLI.getBooleanContents(V.getValueType()) ==
3382 return V;
3383
3384 return SDValue();
3385}
3386
3387/// Given the operands of an add/sub operation, see if the 2nd operand is a
3388/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3389/// the opcode and bypass the mask operation.
3390static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3391 SelectionDAG &DAG, const SDLoc &DL) {
3392 if (N1.getOpcode() == ISD::ZERO_EXTEND)
3393 N1 = N1.getOperand(0);
3394
3395 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
3396 return SDValue();
3397
3398 EVT VT = N0.getValueType();
3399 SDValue N10 = N1.getOperand(0);
3400 if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3401 N10 = N10.getOperand(0);
3402
3403 if (N10.getValueType() != VT)
3404 return SDValue();
3405
3406 if (DAG.ComputeNumSignBits(N10) != VT.getScalarSizeInBits())
3407 return SDValue();
3408
3409 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3410 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
3411 return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N10);
3412}
3413
3414 /// Helper for doing combines based on N0 and N1 being added to each other.
// Called twice from visitADDLike with the operands swapped, so each pattern
// below only needs to match one operand order.
// NOTE(review): original lines 3472 (the trailing condition of the
// sext-i1 fold's if) and 3494 (the condition guarding the final
// add-of-carry fold) are missing from this extraction.
3415 SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3416 SDNode *LocReference) {
3417 EVT VT = N0.getValueType();
3418 SDLoc DL(LocReference);
3419
3420 // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3421 SDValue Y, N;
3422 if (sd_match(N1, m_Shl(m_Neg(m_Value(Y)), m_Value(N))))
3423 return DAG.getNode(ISD::SUB, DL, VT, N0,
3424 DAG.getNode(ISD::SHL, DL, VT, Y, N));
3425
3426 if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
3427 return V;
3428
3429 // Look for:
3430 // add (add x, 1), y
3431 // And if the target does not like this form then turn into:
3432 // sub y, (xor x, -1)
3433 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3434 N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1)) &&
3435 // Limit this to after legalization if the add has wrap flags
3436 (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3437 !N0->getFlags().hasNoSignedWrap()))) {
3438 SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
3439 return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
3440 }
3441
3442 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3443 // Hoist one-use subtraction by non-opaque constant:
3444 // (x - C) + y -> (x + y) - C
3445 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3446 if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3447 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
3448 return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
3449 }
3450 // Hoist one-use subtraction from non-opaque constant:
3451 // (C - x) + y -> (y - x) + C
3452 if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3453 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
3454 return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
3455 }
3456 }
3457
3458 // add (mul x, C), x -> mul x, C+1
3459 if (N0.getOpcode() == ISD::MUL && N0.getOperand(0) == N1 &&
3460 isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true) &&
3461 N0.hasOneUse()) {
3462 SDValue NewC = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
3463 DAG.getConstant(1, DL, VT));
3464 return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), NewC);
3465 }
3466
3467 // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3468 // rather than 'add 0/-1' (the zext should get folded).
3469 // add (sext i1 Y), X --> sub X, (zext i1 Y)
3470 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3471 N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
3473 SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
3474 return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
3475 }
3476
3477 // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3478 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3479 VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3480 if (TN->getVT() == MVT::i1) {
3481 SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3482 DAG.getConstant(1, DL, VT));
3483 return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
3484 }
3485 }
3486
3487 // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3488 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1)) &&
3489 N1.getResNo() == 0)
3490 return DAG.getNode(ISD::UADDO_CARRY, DL, N1->getVTList(),
3491 N0, N1.getOperand(0), N1.getOperand(2));
3492
3493 // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3495 if (SDValue Carry = getAsCarry(TLI, N1))
3496 return DAG.getNode(ISD::UADDO_CARRY, DL,
3497 DAG.getVTList(VT, Carry.getValueType()), N0,
3498 DAG.getConstant(0, DL, VT), Carry);
3499
3500 return SDValue();
3501}
3502
// Combine an ADDC node (the legacy carry-producing add whose carry-out is a
// glue value). Newer targets use UADDO/UADDO_CARRY instead.
SDValue DAGCombiner::visitADDC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  // canonicalize constant to RHS.
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);

  // fold (addc x, 0) -> x + no carry out
  if (isNullConstant(N1))
    return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
                                        DL, MVT::Glue));

  // If it cannot overflow, transform into an add.
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  return SDValue();
}
3532
3533/**
3534 * Flips a boolean if it is cheaper to compute. If the Force parameters is set,
3535 * then the flip also occurs if computing the inverse is the same cost.
3536 * This function returns an empty SDValue in case it cannot flip the boolean
3537 * without increasing the cost of the computation. If you want to flip a boolean
3538 * no matter what, use DAG.getLogicalNOT.
3539 */
                                  const TargetLowering &TLI,
                                  bool Force) {
  // A constant boolean can always be flipped for free by materializing the
  // inverted constant, so do that whenever flipping is forced.
  if (Force && isa<ConstantSDNode>(V))
    return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());

  // Otherwise the only cheap flip we recognize is stripping an explicit XOR.
  if (V.getOpcode() != ISD::XOR)
    return SDValue();

  // (xor X, true) --> X: the inverse is free, just return the inner value.
  if (DAG.isBoolConstant(V.getOperand(1)) == true)
    return V.getOperand(0);
  // (xor X, C) with some other constant: flipping costs one extra node, which
  // is acceptable only when the caller forces it.
  if (Force && isConstOrConstSplat(V.getOperand(1), false))
    return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
  return SDValue();
}
3555
// Combine SADDO/UADDO: an add that also produces an overflow flag in result 1.
SDValue DAGCombiner::visitADDO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  // This visitor handles both the signed and unsigned overflow variants.
  bool IsSigned = (ISD::SADDO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getUNDEF(CarryVT));

  // canonicalize constant to RHS.
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);

  // fold (addo x, 0) -> x + no carry out
  if (isNullOrNullSplat(N1))
    return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));

  // If it cannot overflow, transform into an add.
  if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getConstant(0, DL, CarryVT));

  if (IsSigned) {
    // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
    if (isBitwiseNot(N0) && isOneOrOneSplat(N1))
      return DAG.getNode(ISD::SSUBO, DL, N->getVTList(),
                         DAG.getConstant(0, DL, VT), N0.getOperand(0));
  } else {
    // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
    if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
      SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
                                DAG.getConstant(0, DL, VT), N0.getOperand(0));
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

    // UADDO-only folds; try both operand orders since the add commutes.
    if (SDValue Combined = visitUADDOLike(N0, N1, N))
      return Combined;

    if (SDValue Combined = visitUADDOLike(N1, N0, N))
      return Combined;
  }

  return SDValue();
}
3607
// Helper for visitADDO: UADDO folds tried with operands in one specific order.
SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N0.getValueType();
  // These folds are only applied to scalar types.
  if (VT.isVector())
    return SDValue();

  // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
  // If Y + 1 cannot overflow.
  if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1))) {
    SDValue Y = N1.getOperand(0);
    SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
      return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0, Y,
                         N1.getOperand(2));
  }

  // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
    if (SDValue Carry = getAsCarry(TLI, N1))
      return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0,
                         DAG.getConstant(0, SDLoc(N), VT), Carry);

  return SDValue();
}
3631
3632SDValue DAGCombiner::visitADDE(SDNode *N) {
3633 SDValue N0 = N->getOperand(0);
3634 SDValue N1 = N->getOperand(1);
3635 SDValue CarryIn = N->getOperand(2);
3636
3637 // canonicalize constant to RHS
3638 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3639 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3640 if (N0C && !N1C)
3641 return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
3642 N1, N0, CarryIn);
3643
3644 // fold (adde x, y, false) -> (addc x, y)
3645 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3646 return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
3647
3648 return SDValue();
3649}
3650
3651SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3652 SDValue N0 = N->getOperand(0);
3653 SDValue N1 = N->getOperand(1);
3654 SDValue CarryIn = N->getOperand(2);
3655 SDLoc DL(N);
3656
3657 // canonicalize constant to RHS
3658 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3659 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3660 if (N0C && !N1C)
3661 return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3662
3663 // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3664 if (isNullConstant(CarryIn)) {
3665 if (!LegalOperations ||
3666 TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
3667 return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
3668 }
3669
3670 // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3671 if (isNullConstant(N0) && isNullConstant(N1)) {
3672 EVT VT = N0.getValueType();
3673 EVT CarryVT = CarryIn.getValueType();
3674 SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
3675 AddToWorklist(CarryExt.getNode());
3676 return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
3677 DAG.getConstant(1, DL, VT)),
3678 DAG.getConstant(0, DL, CarryVT));
3679 }
3680
3681 if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3682 return Combined;
3683
3684 if (SDValue Combined = visitUADDO_CARRYLike(N1, N0, CarryIn, N))
3685 return Combined;
3686
3687 // We want to avoid useless duplication.
3688 // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3689 // not a binary operation, this is not really possible to leverage this
3690 // existing mechanism for it. However, if more operations require the same
3691 // deduplication logic, then it may be worth generalize.
3692 SDValue Ops[] = {N1, N0, CarryIn};
3693 SDNode *CSENode =
3694 DAG.getNodeIfExists(ISD::UADDO_CARRY, N->getVTList(), Ops, N->getFlags());
3695 if (CSENode)
3696 return SDValue(CSENode, 0);
3697
3698 return SDValue();
3699}
3700
3701/**
3702 * If we are facing some sort of diamond carry propagation pattern try to
3703 * break it up to generate something like:
3704 * (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3705 *
3706 * The end result is usually an increase in operation required, but because the
3707 * carry is now linearized, other transforms can kick in and optimize the DAG.
3708 *
3709 * Patterns typically look something like
3710 * (uaddo A, B)
3711 * / \
3712 * Carry Sum
3713 * | \
3714 * | (uaddo_carry *, 0, Z)
3715 * | /
3716 * \ Carry
3717 * | /
3718 * (uaddo_carry X, *, *)
3719 *
 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3721 * produce a combine with a single path for carry propagation.
3722 */
                                          SelectionDAG &DAG, SDValue X,
                                          SDValue Carry0, SDValue Carry1,
                                          SDNode *N) {
  // Both Carry0 and Carry1 must be the carry-out (second) result of their
  // defining nodes.
  if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
    return SDValue();
  // Carry1 is expected to be the (uaddo A, B) at the top of the diamond.
  if (Carry1.getOpcode() != ISD::UADDO)
    return SDValue();

  SDValue Z;

  /**
   * First look for a suitable Z. It will present itself in the form of
   * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
   */
  if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
      isNullConstant(Carry0.getOperand(1))) {
    Z = Carry0.getOperand(2);
  } else if (Carry0.getOpcode() == ISD::UADDO &&
             isOneConstant(Carry0.getOperand(1))) {
    // (uaddo Y, 1) behaves like (uaddo_carry Y, 0, true): synthesize a
    // constant-true carry in the node's boolean carry type.
    EVT VT = Carry0->getValueType(1);
    Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
  } else {
    // We couldn't find a suitable Z.
    return SDValue();
  }


  // Rebuild the diamond as a linear chain:
  //   (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
  auto cancelDiamond = [&](SDValue A,SDValue B) {
    SDLoc DL(N);
    SDValue NewY =
        DAG.getNode(ISD::UADDO_CARRY, DL, Carry0->getVTList(), A, B, Z);
    Combiner.AddToWorklist(NewY.getNode());
    return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), X,
                       DAG.getConstant(0, DL, X.getValueType()),
                       NewY.getValue(1));
  };

  /**
   * (uaddo A, B)
   *  |
   * Sum
   *  |
   * (uaddo_carry *, 0, Z)
   */
  if (Carry0.getOperand(0) == Carry1.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
  }

  /**
   * (uaddo_carry A, 0, Z)
   *  |
   * Sum
   *  |
   * (uaddo *, B)
   */
  if (Carry1.getOperand(0) == Carry0.getValue(0)) {
    return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
  }

  // As above, but the uaddo_carry sum feeds the second uaddo operand.
  if (Carry1.getOperand(1) == Carry0.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
  }

  return SDValue();
}
3789
3790// If we are facing some sort of diamond carry/borrow in/out pattern try to
3791// match patterns like:
3792//
3793// (uaddo A, B) CarryIn
3794// | \ |
3795// | \ |
3796// PartialSum PartialCarryOutX /
3797// | | /
3798// | ____|____________/
3799// | / |
3800// (uaddo *, *) \________
3801// | \ \
3802// | \ |
3803// | PartialCarryOutY |
3804// | \ |
3805// | \ /
3806// AddCarrySum | ______/
3807// | /
3808// CarryOut = (or *, *)
3809//
3810// And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3811//
3812// {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3813//
3814// Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3815// with a single path for carry/borrow out propagation.
                                   SDValue N0, SDValue N1, SDNode *N) {
  // Both operands of the merging logic op must be recognizable carry bits.
  SDValue Carry0 = getAsCarry(TLI, N0);
  if (!Carry0)
    return SDValue();
  SDValue Carry1 = getAsCarry(TLI, N1);
  if (!Carry1)
    return SDValue();

  // Both carries must come from the same kind of unsigned overflow op.
  unsigned Opcode = Carry0.getOpcode();
  if (Opcode != Carry1.getOpcode())
    return SDValue();
  if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
    return SDValue();
  // Guarantee identical type of CarryOut
  EVT CarryOutType = N->getValueType(0);
  if (CarryOutType != Carry0.getValue(1).getValueType() ||
      CarryOutType != Carry1.getValue(1).getValueType())
    return SDValue();

  // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
  // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
  if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
    std::swap(Carry0, Carry1);

  // Check if nodes are connected in expected way.
  if (Carry1.getOperand(0) != Carry0.getValue(0) &&
      Carry1.getOperand(1) != Carry0.getValue(0))
    return SDValue();

  // The carry in value must be on the righthand side for subtraction.
  unsigned CarryInOperandNum =
      Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
  if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
    return SDValue();
  SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);

  // The merged node must be supported by the target.
  unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
  if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
    return SDValue();

  // Verify that the carry/borrow in is plausibly a carry/borrow bit.
  CarryIn = getAsCarry(TLI, CarryIn, true);
  if (!CarryIn)
    return SDValue();

  SDLoc DL(N);
  // Normalize the carry-in to the boolean type the merged node expects.
  CarryIn = DAG.getBoolExtOrTrunc(CarryIn, DL, Carry1->getValueType(1),
                                  Carry1->getValueType(0));
  SDValue Merged =
      DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
                  Carry0.getOperand(1), CarryIn);

  // Please note that because we have proven that the result of the UADDO/USUBO
  // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
  // therefore prove that if the first UADDO/USUBO overflows, the second
  // UADDO/USUBO cannot. For example consider 8-bit numbers where 0xFF is the
  // maximum value.
  //
  // 0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
  // 0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
  //
  // This is important because it means that OR and XOR can be used to merge
  // carry flags; and that AND can return a constant zero.
  //
  // TODO: match other operations that can merge flags (ADD, etc)
  DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
  if (N->getOpcode() == ISD::AND)
    return DAG.getConstant(0, DL, CarryOutType);
  return Merged.getValue(1);
}
3887
3888SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3889 SDValue CarryIn, SDNode *N) {
3890 // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3891 // carry.
3892 if (isBitwiseNot(N0))
3893 if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
3894 SDLoc DL(N);
3895 SDValue Sub = DAG.getNode(ISD::USUBO_CARRY, DL, N->getVTList(), N1,
3896 N0.getOperand(0), NotC);
3897 return CombineTo(
3898 N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3899 }
3900
3901 // Iff the flag result is dead:
3902 // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3903 // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3904 // or the dependency between the instructions.
3905 if ((N0.getOpcode() == ISD::ADD ||
3906 (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3907 N0.getValue(1) != CarryIn)) &&
3908 isNullConstant(N1) && !N->hasAnyUseOfValue(1))
3909 return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(),
3910 N0.getOperand(0), N0.getOperand(1), CarryIn);
3911
3912 /**
3913 * When one of the uaddo_carry argument is itself a carry, we may be facing
3914 * a diamond carry propagation. In which case we try to transform the DAG
3915 * to ensure linear carry propagation if that is possible.
3916 */
3917 if (auto Y = getAsCarry(TLI, N1)) {
3918 // Because both are carries, Y and Z can be swapped.
3919 if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
3920 return R;
3921 if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
3922 return R;
3923 }
3924
3925 return SDValue();
3926}
3927
3928SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3929 SDValue CarryIn, SDNode *N) {
3930 // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3931 if (isBitwiseNot(N0)) {
3932 if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true))
3933 return DAG.getNode(ISD::SSUBO_CARRY, SDLoc(N), N->getVTList(), N1,
3934 N0.getOperand(0), NotC);
3935 }
3936
3937 return SDValue();
3938}
3939
3940SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3941 SDValue N0 = N->getOperand(0);
3942 SDValue N1 = N->getOperand(1);
3943 SDValue CarryIn = N->getOperand(2);
3944 SDLoc DL(N);
3945
3946 // canonicalize constant to RHS
3947 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3948 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3949 if (N0C && !N1C)
3950 return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3951
3952 // fold (saddo_carry x, y, false) -> (saddo x, y)
3953 if (isNullConstant(CarryIn)) {
3954 if (!LegalOperations ||
3955 TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
3956 return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
3957 }
3958
3959 if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3960 return Combined;
3961
3962 if (SDValue Combined = visitSADDO_CARRYLike(N1, N0, CarryIn, N))
3963 return Combined;
3964
3965 return SDValue();
3966}
3967
3968// Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3969// clamp/truncation if necessary.
                                     SDValue RHS, SelectionDAG &DAG,
                                     const SDLoc &DL) {
  assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
         "Illegal truncation");

  // Same type: emit the saturating subtract directly, no truncation needed.
  if (DstVT == SrcVT)
    return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);

  // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
  // clamping RHS.
                                          DstVT.getScalarSizeInBits());
  if (!DAG.MaskedValueIsZero(LHS, UpperBits))
    return SDValue();

  SDValue SatLimit =
                                   DstVT.getScalarSizeInBits()),
                      DL, SrcVT);
  RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
  RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
  LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
  return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
}
3995
3996// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3997// usubsat(a,b), optionally as a truncated type.
3998SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
3999 if (N->getOpcode() != ISD::SUB ||
4000 !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
4001 return SDValue();
4002
4003 EVT SubVT = N->getValueType(0);
4004 SDValue Op0 = N->getOperand(0);
4005 SDValue Op1 = N->getOperand(1);
4006
4007 // Try to find umax(a,b) - b or a - umin(a,b) patterns
4008 // they may be converted to usubsat(a,b).
4009 if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
4010 SDValue MaxLHS = Op0.getOperand(0);
4011 SDValue MaxRHS = Op0.getOperand(1);
4012 if (MaxLHS == Op1)
4013 return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, DL);
4014 if (MaxRHS == Op1)
4015 return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, DL);
4016 }
4017
4018 if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
4019 SDValue MinLHS = Op1.getOperand(0);
4020 SDValue MinRHS = Op1.getOperand(1);
4021 if (MinLHS == Op0)
4022 return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, DL);
4023 if (MinRHS == Op0)
4024 return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, DL);
4025 }
4026
4027 // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
4028 if (Op1.getOpcode() == ISD::TRUNCATE &&
4029 Op1.getOperand(0).getOpcode() == ISD::UMIN &&
4030 Op1.getOperand(0).hasOneUse()) {
4031 SDValue MinLHS = Op1.getOperand(0).getOperand(0);
4032 SDValue MinRHS = Op1.getOperand(0).getOperand(1);
4033 if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
4034 return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
4035 DAG, DL);
4036 if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
4037 return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
4038 DAG, DL);
4039 }
4040
4041 return SDValue();
4042}
4043
4044// Refinement of DAG/Type Legalisation (promotion) when CTLZ is used for
// counting leading ones. Broadly, it replaces the subtraction with a left
4046// shift.
4047//
4048// * DAG Legalisation Pattern:
4049//
4050// (sub (ctlz (zeroextend (not Src)))
4051// BitWidthDiff)
4052//
4053// if BitWidthDiff == BitWidth(Node) - BitWidth(Src)
4054// -->
4055//
4056// (ctlz_zero_undef (not (shl (anyextend Src)
4057// BitWidthDiff)))
4058//
4059// * Type Legalisation Pattern:
4060//
4061// (sub (ctlz (and (xor Src XorMask)
4062// AndMask))
4063// BitWidthDiff)
4064//
4065// if AndMask has only trailing ones
4066// and MaskBitWidth(AndMask) == BitWidth(Node) - BitWidthDiff
4067// and XorMask has more trailing ones than AndMask
4068// -->
4069//
4070// (ctlz_zero_undef (not (shl Src BitWidthDiff)))
4071template <class MatchContextClass>
  const SDLoc DL(N);
  SDValue N0 = N->getOperand(0);
  EVT VT = N0.getValueType();
  unsigned BitWidth = VT.getScalarSizeInBits();

  MatchContextClass Matcher(DAG, DAG.getTargetLoweringInfo(), N);

  APInt AndMask;
  APInt XorMask;
  uint64_t BitWidthDiff;

  SDValue CtlzOp;
  SDValue Src;

  // Match the outer (sub (ctlz X), C) shape first.
  if (!sd_context_match(
          N, Matcher, m_Sub(m_Ctlz(m_Value(CtlzOp)), m_ConstInt(BitWidthDiff))))
    return SDValue();

  if (sd_context_match(CtlzOp, Matcher, m_ZExt(m_Not(m_Value(Src))))) {
    // DAG Legalisation Pattern:
    // (sub (ctlz (zero_extend (not Op)) BitWidthDiff))
    if ((BitWidth - Src.getValueType().getScalarSizeInBits()) != BitWidthDiff)
      return SDValue();

    Src = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Src);
  } else if (sd_context_match(CtlzOp, Matcher,
                              m_And(m_Xor(m_Value(Src), m_ConstInt(XorMask)),
                                    m_ConstInt(AndMask)))) {
    // Type Legalisation Pattern:
    // (sub (ctlz (and (xor Op XorMask) AndMask)) BitWidthDiff)
    if (BitWidthDiff >= BitWidth)
      return SDValue();
    unsigned AndMaskWidth = BitWidth - BitWidthDiff;
    if (!(AndMask.isMask(AndMaskWidth) && XorMask.countr_one() >= AndMaskWidth))
      return SDValue();
  } else
    return SDValue();

  // Emit (ctlz_zero_undef (not (shl Src BitWidthDiff))).
  SDValue ShiftConst = DAG.getShiftAmountConstant(BitWidthDiff, VT, DL);
  SDValue LShift = Matcher.getNode(ISD::SHL, DL, VT, Src, ShiftConst);
  SDValue Not =
      Matcher.getNode(ISD::XOR, DL, VT, LShift, DAG.getAllOnesConstant(DL, VT));

  return Matcher.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, Not);
}
4118
4119// Fold sub(x, mul(divrem(x,y)[0], y)) to divrem(x, y)[1]
                                   const SDLoc &DL) {
  assert(N->getOpcode() == ISD::SUB && "Node must be a SUB");
  SDValue Sub0 = N->getOperand(0);
  SDValue Sub1 = N->getOperand(1);

  // Returns divrem(x,y)[1] when DivRem is the quotient divrem(Sub0,MaybeY)[0].
  auto CheckAndFoldMulCase = [&](SDValue DivRem, SDValue MaybeY) -> SDValue {
    if ((DivRem.getOpcode() == ISD::SDIVREM ||
         DivRem.getOpcode() == ISD::UDIVREM) &&
        DivRem.getResNo() == 0 && DivRem.getOperand(0) == Sub0 &&
        DivRem.getOperand(1) == MaybeY) {
      return SDValue(DivRem.getNode(), 1);
    }
    return SDValue();
  };

  if (Sub1.getOpcode() == ISD::MUL) {
    // (sub x, (mul divrem(x,y)[0], y))
    SDValue Mul0 = Sub1.getOperand(0);
    SDValue Mul1 = Sub1.getOperand(1);

    // The multiply is commutative, so try both operand orders.
    if (SDValue Res = CheckAndFoldMulCase(Mul0, Mul1))
      return Res;

    if (SDValue Res = CheckAndFoldMulCase(Mul1, Mul0))
      return Res;

  } else if (Sub1.getOpcode() == ISD::SHL) {
    // Handle (sub x, (shl divrem(x,y)[0], C)) where y = 1 << C
    SDValue Shl0 = Sub1.getOperand(0);
    SDValue Shl1 = Sub1.getOperand(1);
    // Check if Shl0 is divrem(x, Y)[0]
    if ((Shl0.getOpcode() == ISD::SDIVREM ||
         Shl0.getOpcode() == ISD::UDIVREM) &&
        Shl0.getResNo() == 0 && Shl0.getOperand(0) == Sub0) {

      SDValue Divisor = Shl0.getOperand(1);

      ConstantSDNode *DivC = isConstOrConstSplat(Divisor);
      if (!DivC || !ShC)
        return SDValue();

      // The shift must match the divisor exactly: y == 1 << C.
      if (DivC->getAPIntValue().isPowerOf2() &&
          DivC->getAPIntValue().logBase2() == ShC->getAPIntValue())
        return SDValue(Shl0.getNode(), 1);
    }
  }
  return SDValue();
}
4170
4171// Since it may not be valid to emit a fold to zero for vector initializers
4172// check if we can before folding.
4173static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
4174 SelectionDAG &DAG, bool LegalOperations) {
4175 if (!VT.isVector())
4176 return DAG.getConstant(0, DL, VT);
4177 if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
4178 return DAG.getConstant(0, DL, VT);
4179 return SDValue();
4180}
4181
4182SDValue DAGCombiner::visitSUB(SDNode *N) {
4183 SDValue N0 = N->getOperand(0);
4184 SDValue N1 = N->getOperand(1);
4185 EVT VT = N0.getValueType();
4186 unsigned BitWidth = VT.getScalarSizeInBits();
4187 SDLoc DL(N);
4188
4190 return V;
4191
4192 // fold (sub x, x) -> 0
4193 if (N0 == N1)
4194 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4195
4196 // fold (sub c1, c2) -> c3
4197 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
4198 return C;
4199
4200 // fold vector ops
4201 if (VT.isVector()) {
4202 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4203 return FoldedVOp;
4204
4205 // fold (sub x, 0) -> x, vector edition
4207 return N0;
4208 }
4209
4210 // (sub x, ([v]select (ult x, y), 0, y)) -> (umin x, (sub x, y))
4211 // (sub x, ([v]select (uge x, y), y, 0)) -> (umin x, (sub x, y))
4212 if (N1.hasOneUse() && hasUMin(VT)) {
4213 SDValue Y;
4214 auto MS0 = m_Specific(N0);
4215 auto MVY = m_Value(Y);
4216 auto MZ = m_Zero();
4217 auto MCC1 = m_SpecificCondCode(ISD::SETULT);
4218 auto MCC2 = m_SpecificCondCode(ISD::SETUGE);
4219
4220 if (sd_match(N1, m_SelectCCLike(MS0, MVY, MZ, m_Deferred(Y), MCC1)) ||
4221 sd_match(N1, m_SelectCCLike(MS0, MVY, m_Deferred(Y), MZ, MCC2)) ||
4222 sd_match(N1, m_VSelect(m_SetCC(MS0, MVY, MCC1), MZ, m_Deferred(Y))) ||
4223 sd_match(N1, m_VSelect(m_SetCC(MS0, MVY, MCC2), m_Deferred(Y), MZ)))
4224
4225 return DAG.getNode(ISD::UMIN, DL, VT, N0,
4226 DAG.getNode(ISD::SUB, DL, VT, N0, Y));
4227 }
4228
4229 if (SDValue NewSel = foldBinOpIntoSelect(N))
4230 return NewSel;
4231
4232 // fold (sub x, c) -> (add x, -c)
4233 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N1))
4234 return DAG.getNode(ISD::ADD, DL, VT, N0,
4235 DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
4236
4237 if (isNullOrNullSplat(N0)) {
4238 // Right-shifting everything out but the sign bit followed by negation is
4239 // the same as flipping arithmetic/logical shift type without the negation:
4240 // -(X >>u 31) -> (X >>s 31)
4241 // -(X >>s 31) -> (X >>u 31)
4242 if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
4243 ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
4244 if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
4245 auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
4246 if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
4247 return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
4248 }
4249 }
4250
4251 // 0 - X --> 0 if the sub is NUW.
4252 if (N->getFlags().hasNoUnsignedWrap())
4253 return N0;
4254
4256 // N1 is either 0 or the minimum signed value. If the sub is NSW, then
4257 // N1 must be 0 because negating the minimum signed value is undefined.
4258 if (N->getFlags().hasNoSignedWrap())
4259 return N0;
4260
4261 // 0 - X --> X if X is 0 or the minimum signed value.
4262 return N1;
4263 }
4264
4265 // Convert 0 - abs(x).
4266 if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
4268 if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
4269 return Result;
4270
4271 // Similar to the previous rule, but this time targeting an expanded abs.
4272 // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
4273 // as well as
4274 // (sub 0, (min X, (sub 0, X))) --> (max X, (sub 0, X))
4275 // Note that these two are applicable to both signed and unsigned min/max.
4276 SDValue X;
4277 SDValue S0;
4278 auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0));
4279 if (sd_match(N1, m_OneUse(m_AnyOf(m_SMax(m_Value(X), NegPat),
4280 m_UMax(m_Value(X), NegPat),
4281 m_SMin(m_Value(X), NegPat),
4282 m_UMin(m_Value(X), NegPat))))) {
4283 unsigned NewOpc = ISD::getInverseMinMaxOpcode(N1->getOpcode());
4284 if (hasOperation(NewOpc, VT))
4285 return DAG.getNode(NewOpc, DL, VT, X, S0);
4286 }
4287
4288 // Fold neg(splat(neg(x)) -> splat(x)
4289 if (VT.isVector()) {
4290 SDValue N1S = DAG.getSplatValue(N1, true);
4291 if (N1S && N1S.getOpcode() == ISD::SUB &&
4292 isNullConstant(N1S.getOperand(0)))
4293 return DAG.getSplat(VT, DL, N1S.getOperand(1));
4294 }
4295
4296 // sub 0, (and x, 1) --> SIGN_EXTEND_INREG x, i1
4297 if (N1.getOpcode() == ISD::AND && N1.hasOneUse() &&
4298 isOneOrOneSplat(N1->getOperand(1))) {
4299 EVT ExtVT = VT.changeElementType(*DAG.getContext(), MVT::i1);
4302 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N1->getOperand(0),
4303 DAG.getValueType(ExtVT));
4304 }
4305 }
4306 }
4307
4308 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
4310 return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
4311
4312 // fold (A - (0-B)) -> A+B
4313 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
4314 return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
4315
4316 // fold A-(A-B) -> B
4317 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
4318 return N1.getOperand(1);
4319
4320 // fold (A+B)-A -> B
4321 if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
4322 return N0.getOperand(1);
4323
4324 // fold (A+B)-B -> A
4325 if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
4326 return N0.getOperand(0);
4327
4328 // fold (A+C1)-C2 -> A+(C1-C2)
4329 if (N0.getOpcode() == ISD::ADD) {
4330 SDValue N01 = N0.getOperand(1);
4331 if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
4332 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
4333 }
4334
4335 // fold C2-(A+C1) -> (C2-C1)-A
4336 if (N1.getOpcode() == ISD::ADD) {
4337 SDValue N11 = N1.getOperand(1);
4338 if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
4339 return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
4340 }
4341
4342 // fold (A-C1)-C2 -> A-(C1+C2)
4343 if (N0.getOpcode() == ISD::SUB) {
4344 SDValue N01 = N0.getOperand(1);
4345 if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
4346 return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
4347 }
4348
4349 // fold (c1-A)-c2 -> (c1-c2)-A
4350 if (N0.getOpcode() == ISD::SUB) {
4351 SDValue N00 = N0.getOperand(0);
4352 if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
4353 return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
4354 }
4355
4356 SDValue A, B, C;
4357
4358 // fold ((A+(B+C))-B) -> A+C
4359 if (sd_match(N0, m_Add(m_Value(A), m_Add(m_Specific(N1), m_Value(C)))))
4360 return DAG.getNode(ISD::ADD, DL, VT, A, C);
4361
4362 // fold ((A+(B-C))-B) -> A-C
4363 if (sd_match(N0, m_Add(m_Value(A), m_Sub(m_Specific(N1), m_Value(C)))))
4364 return DAG.getNode(ISD::SUB, DL, VT, A, C);
4365
4366 // fold ((A-(B-C))-C) -> A-B
4367 if (sd_match(N0, m_Sub(m_Value(A), m_Sub(m_Value(B), m_Specific(N1)))))
4368 return DAG.getNode(ISD::SUB, DL, VT, A, B);
4369
4370 // fold (A-(B-C)) -> A+(C-B)
4371 if (sd_match(N1, m_OneUse(m_Sub(m_Value(B), m_Value(C)))))
4372 return DAG.getNode(ISD::ADD, DL, VT, N0,
4373 DAG.getNode(ISD::SUB, DL, VT, C, B));
4374
4375 // A - (A & B) -> A & (~B)
4376 if (sd_match(N1, m_And(m_Specific(N0), m_Value(B))) &&
4377 (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true)))
4378 return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getNOT(DL, B, VT));
4379
4380 // fold (A - (-B * C)) -> (A + (B * C))
4381 if (sd_match(N1, m_OneUse(m_Mul(m_Neg(m_Value(B)), m_Value(C)))))
4382 return DAG.getNode(ISD::ADD, DL, VT, N0,
4383 DAG.getNode(ISD::MUL, DL, VT, B, C));
4384
4385 // If either operand of a sub is undef, the result is undef
4386 if (N0.isUndef())
4387 return N0;
4388 if (N1.isUndef())
4389 return N1;
4390
4391 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
4392 return V;
4393
4394 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
4395 return V;
4396
4397 // Try to match AVGCEIL fixedwidth pattern
4398 if (SDValue V = foldSubToAvg(N, DL))
4399 return V;
4400
4401 if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, DL))
4402 return V;
4403
4404 if (SDValue V = foldSubToUSubSat(VT, N, DL))
4405 return V;
4406
4407 if (SDValue V = foldRemainderIdiom(N, DAG, DL))
4408 return V;
4409
4410 // (A - B) - 1 -> add (xor B, -1), A
4412 m_One(/*AllowUndefs=*/true))))
4413 return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
4414
4415 // Look for:
4416 // sub y, (xor x, -1)
4417 // And if the target does not like this form then turn into:
4418 // add (add x, y), 1
4419 if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
4420 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
4421 return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
4422 }
4423
4424 // Hoist one-use addition by non-opaque constant:
4425 // (x + C) - y -> (x - y) + C
4426 if (!reassociationCanBreakAddressingModePattern(ISD::SUB, DL, N, N0, N1) &&
4427 N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
4428 isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
4429 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
4430 return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
4431 }
4432 // y - (x + C) -> (y - x) - C
4433 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
4434 isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
4435 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
4436 return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
4437 }
4438 // (x - C) - y -> (x - y) - C
4439 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
4440 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4441 isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
4442 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
4443 return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
4444 }
4445 // (C - x) - y -> C - (x + y)
4446 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4447 isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
4448 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
4449 return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
4450 }
4451
4452 // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
4453 // rather than 'sub 0/1' (the sext should get folded).
4454 // sub X, (zext i1 Y) --> add X, (sext i1 Y)
4455 if (N1.getOpcode() == ISD::ZERO_EXTEND &&
4456 N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
4457 TLI.getBooleanContents(VT) ==
4459 SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
4460 return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
4461 }
4462
4463 // fold B = sra (A, size(A)-1); sub (xor (A, B), B) -> (abs A)
4464 if ((!LegalOperations || hasOperation(ISD::ABS, VT)) &&
4466 sd_match(N0, m_Xor(m_Specific(A), m_Specific(N1))))
4467 return DAG.getNode(ISD::ABS, DL, VT, A);
4468
4469 // If the relocation model supports it, consider symbol offsets.
4470 if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
4471 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
4472 // fold (sub Sym+c1, Sym+c2) -> c1-c2
4473 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
4474 if (GA->getGlobal() == GB->getGlobal())
4475 return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
4476 DL, VT);
4477 }
4478
4479 // sub X, (sextinreg Y i1) -> add X, (and Y 1)
4480 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4481 VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
4482 if (TN->getVT() == MVT::i1) {
4483 SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
4484 DAG.getConstant(1, DL, VT));
4485 return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
4486 }
4487 }
4488
4489 // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
4490 // avoid if ISD::MUL handling is poor and ISD::SHL isn't an option.
4491 if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
4492 const APInt &IntVal = N1.getConstantOperandAPInt(0);
4493 if (!IntVal.isPowerOf2() ||
4494 hasOperation(ISD::MUL, N1.getOperand(0).getValueType()))
4495 return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
4496 }
4497
4498 // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
4499 if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
4500 APInt NewStep = -N1.getConstantOperandAPInt(0);
4501 return DAG.getNode(ISD::ADD, DL, VT, N0,
4502 DAG.getStepVector(DL, VT, NewStep));
4503 }
4504
4505 // Prefer an add for more folding potential and possibly better codegen:
4506 // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
4507 if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
4508 SDValue ShAmt = N1.getOperand(1);
4509 ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
4510 if (ShAmtC && ShAmtC->getAPIntValue() == (BitWidth - 1)) {
4511 SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
4512 return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
4513 }
4514 }
4515
4516 // As with the previous fold, prefer add for more folding potential.
4517 // Subtracting SMIN/0 is the same as adding SMIN/0:
4518 // N0 - (X << BW-1) --> N0 + (X << BW-1)
4519 if (N1.getOpcode() == ISD::SHL) {
4520 ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
4521 if (ShlC && ShlC->getAPIntValue() == (BitWidth - 1))
4522 return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
4523 }
4524
4525 // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4526 if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(N0.getOperand(1)) &&
4527 N0.getResNo() == 0 && N0.hasOneUse())
4528 return DAG.getNode(ISD::USUBO_CARRY, DL, N0->getVTList(),
4529 N0.getOperand(0), N1, N0.getOperand(2));
4530
4532 // (sub Carry, X) -> (uaddo_carry (sub 0, X), 0, Carry)
4533 if (SDValue Carry = getAsCarry(TLI, N0)) {
4534 SDValue X = N1;
4535 SDValue Zero = DAG.getConstant(0, DL, VT);
4536 SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
4537 return DAG.getNode(ISD::UADDO_CARRY, DL,
4538 DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
4539 Carry);
4540 }
4541 }
4542
4543 // If there's no chance of borrowing from adjacent bits, then sub is xor:
4544 // sub C0, X --> xor X, C0
4545 if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
4546 if (!C0->isOpaque()) {
4547 const APInt &C0Val = C0->getAPIntValue();
4548 const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
4549 if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4550 return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
4551 }
4552 }
4553
4554 // smax(a,b) - smin(a,b) --> abds(a,b)
4555 if ((!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
4556 sd_match(N0, &DAG, m_SMaxLike(m_Value(A), m_Value(B))) &&
4557 sd_match(N1, &DAG, m_SMinLike(m_Specific(A), m_Specific(B))))
4558 return DAG.getNode(ISD::ABDS, DL, VT, A, B);
4559
4560 // smin(a,b) - smax(a,b) --> neg(abds(a,b))
4561 if (hasOperation(ISD::ABDS, VT) &&
4562 sd_match(N0, &DAG, m_SMinLike(m_Value(A), m_Value(B))) &&
4563 sd_match(N1, &DAG, m_SMaxLike(m_Specific(A), m_Specific(B))))
4564 return DAG.getNegative(DAG.getNode(ISD::ABDS, DL, VT, A, B), DL, VT);
4565
4566 // umax(a,b) - umin(a,b) --> abdu(a,b)
4567 if ((!LegalOperations || hasOperation(ISD::ABDU, VT)) &&
4568 sd_match(N0, &DAG, m_UMaxLike(m_Value(A), m_Value(B))) &&
4569 sd_match(N1, &DAG, m_UMinLike(m_Specific(A), m_Specific(B))))
4570 return DAG.getNode(ISD::ABDU, DL, VT, A, B);
4571
4572 // umin(a,b) - umax(a,b) --> neg(abdu(a,b))
4573 if (hasOperation(ISD::ABDU, VT) &&
4574 sd_match(N0, &DAG, m_UMinLike(m_Value(A), m_Value(B))) &&
4575 sd_match(N1, &DAG, m_UMaxLike(m_Specific(A), m_Specific(B))))
4576 return DAG.getNegative(DAG.getNode(ISD::ABDU, DL, VT, A, B), DL, VT);
4577
4578 return SDValue();
4579}
4580
// Combine an SSUBSAT/USUBSAT node: constant-fold, simplify trivial operands,
// and relax to a plain SUB when overflow is provably impossible.
SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  // SSUBSAT saturates to the signed min/max; USUBSAT clamps at zero.
  bool IsSigned = Opcode == ISD::SSUBSAT;
  SDLoc DL(N);

  // fold (sub_sat x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (sub_sat x, x) -> 0
  if (N0 == N1)
    return DAG.getConstant(0, DL, VT);

  // fold (sub_sat c1, c2) -> c3
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (sub_sat x, 0) -> x, vector edition
    // NOTE(review): the guard that N1 is an all-zeros splat appears to be
    // elided in this excerpt; upstream gates this return on that condition.
    return N0;
  }

  // fold (sub_sat x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  // If it cannot overflow, transform into an sub.
  if (DAG.willNotOverflowSub(IsSigned, N0, N1))
    return DAG.getNode(ISD::SUB, DL, VT, N0, N1);

  return SDValue();
}
4621
4622SDValue DAGCombiner::visitSUBC(SDNode *N) {
4623 SDValue N0 = N->getOperand(0);
4624 SDValue N1 = N->getOperand(1);
4625 EVT VT = N0.getValueType();
4626 SDLoc DL(N);
4627
4628 // If the flag result is dead, turn this into an SUB.
4629 if (!N->hasAnyUseOfValue(1))
4630 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4631 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4632
4633 // fold (subc x, x) -> 0 + no borrow
4634 if (N0 == N1)
4635 return CombineTo(N, DAG.getConstant(0, DL, VT),
4636 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4637
4638 // fold (subc x, 0) -> x + no borrow
4639 if (isNullConstant(N1))
4640 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4641
4642 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4643 if (isAllOnesConstant(N0))
4644 return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4645 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4646
4647 return SDValue();
4648}
4649
4650SDValue DAGCombiner::visitSUBO(SDNode *N) {
4651 SDValue N0 = N->getOperand(0);
4652 SDValue N1 = N->getOperand(1);
4653 EVT VT = N0.getValueType();
4654 bool IsSigned = (ISD::SSUBO == N->getOpcode());
4655
4656 EVT CarryVT = N->getValueType(1);
4657 SDLoc DL(N);
4658
4659 // If the flag result is dead, turn this into an SUB.
4660 if (!N->hasAnyUseOfValue(1))
4661 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4662 DAG.getUNDEF(CarryVT));
4663
4664 // fold (subo x, x) -> 0 + no borrow
4665 if (N0 == N1)
4666 return CombineTo(N, DAG.getConstant(0, DL, VT),
4667 DAG.getConstant(0, DL, CarryVT));
4668
4669 // fold (subox, c) -> (addo x, -c)
4670 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N1))
4671 if (IsSigned && !N1C->isMinSignedValue())
4672 return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
4673 DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
4674
4675 // fold (subo x, 0) -> x + no borrow
4676 if (isNullOrNullSplat(N1))
4677 return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
4678
4679 // If it cannot overflow, transform into an sub.
4680 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4681 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4682 DAG.getConstant(0, DL, CarryVT));
4683
4684 // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4685 if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
4686 return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4687 DAG.getConstant(0, DL, CarryVT));
4688
4689 return SDValue();
4690}
4691
4692SDValue DAGCombiner::visitSUBE(SDNode *N) {
4693 SDValue N0 = N->getOperand(0);
4694 SDValue N1 = N->getOperand(1);
4695 SDValue CarryIn = N->getOperand(2);
4696
4697 // fold (sube x, y, false) -> (subc x, y)
4698 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4699 return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
4700
4701 return SDValue();
4702}
4703
4704SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4705 SDValue N0 = N->getOperand(0);
4706 SDValue N1 = N->getOperand(1);
4707 SDValue CarryIn = N->getOperand(2);
4708
4709 // fold (usubo_carry x, y, false) -> (usubo x, y)
4710 if (isNullConstant(CarryIn)) {
4711 if (!LegalOperations ||
4712 TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
4713 return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
4714 }
4715
4716 return SDValue();
4717}
4718
4719SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4720 SDValue N0 = N->getOperand(0);
4721 SDValue N1 = N->getOperand(1);
4722 SDValue CarryIn = N->getOperand(2);
4723
4724 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4725 if (isNullConstant(CarryIn)) {
4726 if (!LegalOperations ||
4727 TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
4728 return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
4729 }
4730
4731 return SDValue();
4732}
4733
// Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
// UMULFIXSAT here.
//
// Combines a fixed-point multiply: fold away undef/zero operands and
// canonicalize a constant LHS to the RHS.
SDValue DAGCombiner::visitMULFIX(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue Scale = N->getOperand(2);
  EVT VT = N0.getValueType();

  // fold (mulfix x, undef, scale) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, SDLoc(N), VT);

  // Canonicalize constant to RHS (vector doesn't have to splat)
  // NOTE(review): the guard checking "N0 is constant and N1 is not" appears
  // to be elided in this excerpt; upstream gates this swap on that condition.
    return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);

  // fold (mulfix x, 0, scale) -> 0
  if (isNullConstant(N1))
    return DAG.getConstant(0, SDLoc(N), VT);

  return SDValue();
}
4757
// Combine a MUL (or VP_MUL, depending on MatchContextClass) node. The folds
// below are ordered: cheap identities first, then strength reduction to
// shifts/adds, then reassociation and demanded-bits simplification.
template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned BitWidth = VT.getScalarSizeInBits();
  SDLoc DL(N);
  // True when we are combining the VP (vector-predicated) flavor; several
  // folds below are disabled for it because their helpers are not VP-aware.
  bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
  MatchContextClass Matcher(DAG, TLI, N);

  // fold (mul x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (mul c1, c2) -> c1*c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS (vector doesn't have to splat)
  // NOTE(review): the guard checking "N0 is constant, N1 is not" appears to
  // be elided in this excerpt; upstream gates this swap on that condition.
    return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);

  bool N1IsConst = false;
  bool N1IsOpaqueConst = false;
  APInt ConstValue1;

  // fold vector ops
  if (VT.isVector()) {
    // TODO: Change this to use SimplifyVBinOp when it supports VP op.
    if (!UseVP)
      if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
        return FoldedVOp;

    N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
    assert((!N1IsConst || ConstValue1.getBitWidth() == BitWidth) &&
           "Splat APInt should be element width");
  } else {
    N1IsConst = isa<ConstantSDNode>(N1);
    if (N1IsConst) {
      ConstValue1 = N1->getAsAPIntVal();
      N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
    }
  }

  // fold (mul x, 0) -> 0
  if (N1IsConst && ConstValue1.isZero())
    return N1;

  // fold (mul x, 1) -> x
  if (N1IsConst && ConstValue1.isOne())
    return N0;

  if (!UseVP)
    if (SDValue NewSel = foldBinOpIntoSelect(N))
      return NewSel;

  // fold (mul x, -1) -> 0-x
  if (N1IsConst && ConstValue1.isAllOnes())
    return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);

  // fold (mul x, (1 << c)) -> x << c
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
    if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
      EVT ShiftVT = getShiftAmountTy(N0.getValueType());
      SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
      SDNodeFlags Flags;
      Flags.setNoUnsignedWrap(N->getFlags().hasNoUnsignedWrap());
      // TODO: Preserve setNoSignedWrap if LogBase2 isn't BitWidth - 1.
      return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc, Flags);
    }
  }

  // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
  if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
    unsigned Log2Val = (-ConstValue1).logBase2();

    // FIXME: If the input is something that is easily negated (e.g. a
    // single-use add), we should put the negate there.
    return Matcher.getNode(
        ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
        Matcher.getNode(ISD::SHL, DL, VT, N0,
                        DAG.getShiftAmountConstant(Log2Val, VT, DL)));
  }

  // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
  // hi result is in use in case we hit this mid-legalization.
  if (!UseVP) {
    for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
      if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
        SDVTList LoHiVT = DAG.getVTList(VT, VT);
        // TODO: Can we match commutable operands with getNodeIfExists?
        if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
          if (LoHi->hasAnyUseOfValue(1))
            return SDValue(LoHi, 0);
        if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
          if (LoHi->hasAnyUseOfValue(1))
            return SDValue(LoHi, 0);
      }
    }
  }

  // Try to transform:
  // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
  // mul x, (2^N + 1) --> add (shl x, N), x
  // mul x, (2^N - 1) --> sub (shl x, N), x
  // Examples: x * 33 --> (x << 5) + x
  //           x * 15 --> (x << 4) - x
  //           x * -33 --> -((x << 5) + x)
  //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
  // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
  // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
  // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
  // Examples: x * 0x8800 --> (x << 15) + (x << 11)
  //           x * 0xf800 --> (x << 16) - (x << 11)
  //           x * -0x8800 --> -((x << 15) + (x << 11))
  //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
  if (!UseVP && N1IsConst &&
      TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
    // TODO: We could handle more general decomposition of any constant by
    //       having the target set a limit on number of ops and making a
    //       callback to determine that sequence (similar to sqrt expansion).
    unsigned MathOp = ISD::DELETED_NODE;
    APInt MulC = ConstValue1.abs();
    // The constant `2` should be treated as (2^0 + 1).
    unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
    MulC.lshrInPlace(TZeros);
    if ((MulC - 1).isPowerOf2())
      MathOp = ISD::ADD;
    else if ((MulC + 1).isPowerOf2())
      MathOp = ISD::SUB;

    if (MathOp != ISD::DELETED_NODE) {
      unsigned ShAmt =
          MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
      ShAmt += TZeros;
      assert(ShAmt < BitWidth &&
             "multiply-by-constant generated out of bounds shift");
      SDValue Shl =
          DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
      SDValue R =
          TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
                               DAG.getNode(ISD::SHL, DL, VT, N0,
                                           DAG.getConstant(TZeros, DL, VT)))
                 : DAG.getNode(MathOp, DL, VT, Shl, N0);
      if (ConstValue1.isNegative())
        R = DAG.getNegative(R, DL, VT);
      return R;
    }
  }

  // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
  if (sd_context_match(N0, Matcher, m_Opc(ISD::SHL))) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
  }

  // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
  // use.
  {
    SDValue Sh, Y;

    // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
    // NOTE(review): the second half of each condition (requiring a constant
    // shift amount) appears to be elided in this excerpt.
    if (sd_context_match(N0, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
      Sh = N0; Y = N1;
    } else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
      Sh = N1; Y = N0;
    }

    if (Sh.getNode()) {
      SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
      return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
    }
  }

  // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
  // NOTE(review): the remaining condition operands (constant checks on N1 and
  // N0.getOperand(1)) appear to be elided in this excerpt.
  if (sd_context_match(N0, Matcher, m_Opc(ISD::ADD)) &&
    return Matcher.getNode(
        ISD::ADD, DL, VT,
        Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
        Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));

  // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
  // avoid if ISD::MUL handling is poor and ISD::SHL isn't an option.
  ConstantSDNode *NC1 = isConstOrConstSplat(N1);
  if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
    const APInt &C0 = N0.getConstantOperandAPInt(0);
    const APInt &C1 = NC1->getAPIntValue();
    if (!C0.isPowerOf2() || C1.isPowerOf2() ||
        hasOperation(ISD::MUL, NC1->getValueType(0)))
      return DAG.getVScale(DL, VT, C0 * C1);
  }

  // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
  APInt MulVal;
  if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
      ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
    const APInt &C0 = N0.getConstantOperandAPInt(0);
    APInt NewStep = C0 * MulVal;
    return DAG.getStepVector(DL, VT, NewStep);
  }

  // Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
  // NOTE(review): parts of the sd_context_match pattern (the m_Or/m_Sra
  // matcher lines) appear to be elided in this excerpt.
  SDValue X;
  if (!UseVP && (!LegalOperations || hasOperation(ISD::ABS, VT)) &&
          N, Matcher,
              m_Deferred(X)))) {
    return Matcher.getNode(ISD::ABS, DL, VT, X);
  }

  // Fold ((mul x, 0/undef) -> 0,
  //       (mul x, 1) -> x) -> x)
  // -> and(x, mask)
  // We can replace vectors with '0' and '1' factors with a clearing mask.
  if (VT.isFixedLengthVector()) {
    unsigned NumElts = VT.getVectorNumElements();
    SmallBitVector ClearMask;
    ClearMask.reserve(NumElts);
    // Records, per element, whether the factor is 0/undef (clear) or 1
    // (keep); any other factor makes the whole transform inapplicable.
    auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
      if (!V || V->isZero()) {
        ClearMask.push_back(true);
        return true;
      }
      ClearMask.push_back(false);
      return V->isOne();
    };
    if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
        ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
      assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
      EVT LegalSVT = N1.getOperand(0).getValueType();
      SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
      SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
      // NOTE(review): the declaration of Mask (a vector of NumElts AllOnes
      // elements) appears to be elided in this excerpt.
      for (unsigned I = 0; I != NumElts; ++I)
        if (ClearMask[I])
          Mask[I] = Zero;
      return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
    }
  }

  // reassociate mul
  // TODO: Change reassociateOps to support vp ops.
  if (!UseVP)
    if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
      return RMUL;

  // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
  // TODO: Change reassociateReduction to support vp ops.
  if (!UseVP)
    if (SDValue SD =
            reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
      return SD;

  // Simplify the operands using demanded-bits information.
  // NOTE(review): the SimplifyDemandedBits guard appears to be elided in this
  // excerpt; upstream returns SDValue(N, 0) only when it made a change.
    return SDValue(N, 0);

  return SDValue();
}
5025
5026/// Return true if divmod libcall is available.
5027static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
5028 const SelectionDAG &DAG) {
5029 RTLIB::Libcall LC;
5030 EVT NodeType = Node->getValueType(0);
5031 if (!NodeType.isSimple())
5032 return false;
5033 switch (NodeType.getSimpleVT().SimpleTy) {
5034 default: return false; // No libcall for vector types.
5035 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
5036 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
5037 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
5038 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
5039 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
5040 }
5041
5042 return DAG.getLibcalls().getLibcallImpl(LC) != RTLIB::Unsupported;
5043}
5044
/// Issue divrem if both quotient and remainder are needed.
///
/// Looks for a sibling node computing the complementary operation (e.g. an
/// SREM next to this SDIV) on the same operands and replaces both with a
/// single SDIVREM/UDIVREM node. Returns the combined node's quotient value,
/// or a null SDValue if no combining was done. Note: this also rewrites the
/// sibling users via CombineTo as a side effect.
SDValue DAGCombiner::useDivRem(SDNode *Node) {
  if (Node->use_empty())
    return SDValue(); // This is a dead node, leave it alone.

  unsigned Opcode = Node->getOpcode();
  bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
  unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;

  // DivMod lib calls can still work on non-legal types if using lib-calls.
  EVT VT = Node->getValueType(0);
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
    return SDValue();

  // If DIVREM is going to get expanded into a libcall,
  // but there is no libcall available, then don't combine.
  if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
      !isDivRemLibcallAvailable(Node, isSigned, DAG))
    return SDValue();

  // If div is legal, it's better to do the normal expansion
  unsigned OtherOpcode = 0;
  if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
    OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
    if (TLI.isOperationLegalOrCustom(Opcode, VT))
      return SDValue();
  } else {
    OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
    if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
      return SDValue();
  }

  // Scan all users of the dividend for div/rem/divrem nodes on the exact same
  // operand pair, funneling them all into one divrem.
  SDValue Op0 = Node->getOperand(0);
  SDValue Op1 = Node->getOperand(1);
  SDValue combined;
  for (SDNode *User : Op0->users()) {
    if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
        User->use_empty())
      continue;
    // Convert the other matching node(s), too;
    // otherwise, the DIVREM may get target-legalized into something
    // target-specific that we won't be able to recognize.
    unsigned UserOpc = User->getOpcode();
    if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
        User->getOperand(0) == Op0 &&
        User->getOperand(1) == Op1) {
      if (!combined) {
        // Create (or reuse) the divrem node lazily on the first match.
        if (UserOpc == OtherOpcode) {
          SDVTList VTs = DAG.getVTList(VT, VT);
          combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
        } else if (UserOpc == DivRemOpc) {
          combined = SDValue(User, 0);
        } else {
          assert(UserOpc == Opcode);
          continue;
        }
      }
      // Divs take result 0 (quotient), rems take result 1 (remainder).
      if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
        CombineTo(User, combined);
      else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
        CombineTo(User, combined.getValue(1));
    }
  }
  return combined;
}
5113
  // NOTE(review): the enclosing function's signature line (a static
  // div/rem simplification helper taking SDNode *N and SelectionDAG &DAG)
  // is elided in this excerpt. The body handles trivial identities shared
  // by SDIV/UDIV/SREM/UREM.
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  unsigned Opc = N->getOpcode();
  bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);

  // X / undef -> undef
  // X % undef -> undef
  // X / 0 -> undef
  // X % 0 -> undef
  // NOTE: This includes vectors where any divisor element is zero/undef.
  if (DAG.isUndef(Opc, {N0, N1}))
    return DAG.getUNDEF(VT);

  // undef / X -> 0
  // undef % X -> 0
  if (N0.isUndef())
    return DAG.getConstant(0, DL, VT);

  // 0 / X -> 0
  // 0 % X -> 0
  // NOTE(review): the declaration of N0C (a constant query of the dividend
  // N0) is elided in this excerpt.
  if (N0C && N0C->isZero())
    return N0;

  // X / X -> 1
  // X % X -> 0
  if (N0 == N1)
    return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);

  // X / 1 -> X
  // X % 1 -> 0
  // If this is a boolean op (single-bit element type), we can't have
  // division-by-zero or remainder-by-zero, so assume the divisor is 1.
  // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
  // it's a 1.
  if (isOneOrOneSplat(N1) || (VT.getScalarType() == MVT::i1))
    return IsDiv ? N0 : DAG.getConstant(0, DL, VT);

  return SDValue();
}
5158
// Combine an SDIV node: constant folds, special divisors (-1, INT_MIN),
// strength reduction to UDIV when sign bits are known zero, the shared
// SDIV-like expansion, and finally sdiv+srem -> sdivrem merging.
SDValue DAGCombiner::visitSDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  SDLoc DL(N);

  // fold (sdiv c1, c2) -> c1/c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // fold (sdiv X, -1) -> 0-X
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && N1C->isAllOnes())
    return DAG.getNegative(N0, DL, VT);

  // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
  // (only X == MIN_SIGNED divides to exactly 1; everything else truncates
  // toward zero, giving 0).
  if (N1C && N1C->isMinSignedValue())
    return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
                         DAG.getConstant(1, DL, VT),
                         DAG.getConstant(0, DL, VT));

  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // If we know the sign bits of both operands are zero, strength reduce to a
  // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
  if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);

  if (SDValue V = visitSDIVLike(N0, N1, N)) {
    // If the corresponding remainder node exists, update its users with
    // (Dividend - (Quotient * Divisor).
    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
                                              { N0, N1 })) {
      // If the sdiv has the exact flag we shouldn't propagate it to the
      // remainder node.
      if (!N->getFlags().hasExact()) {
        SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
        SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
        AddToWorklist(Mul.getNode());
        AddToWorklist(Sub.getNode());
        CombineTo(RemNode, Sub);
      }
    }
    return V;
  }

  // sdiv, srem -> sdivrem
  // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
  // true. Otherwise, we break the simplification logic in visitREM().
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue DivRem = useDivRem(N))
      return DivRem;

  return SDValue();
}
5225
5226static bool isDivisorPowerOfTwo(SDValue Divisor) {
5227 // Helper for determining whether a value is a power-2 constant scalar or a
5228 // vector of such elements.
5229 auto IsPowerOfTwo = [](ConstantSDNode *C) {
5230 if (C->isZero() || C->isOpaque())
5231 return false;
5232 if (C->getAPIntValue().isPowerOf2())
5233 return true;
5234 if (C->getAPIntValue().isNegatedPowerOf2())
5235 return true;
5236 return false;
5237 };
5238
5239 return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo, /*AllowUndefs=*/false,
5240 /*AllowTruncation=*/true);
5241}
5242
// Shared SDIV lowering helper used by visitSDIV: expands division by a
// power-of-two divisor into shift/add arithmetic, or emits the
// multiply-by-magic-constant sequence via BuildSDIV when integer division
// is expensive on the target.
SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  unsigned BitWidth = VT.getScalarSizeInBits();
  unsigned MaxLegalDivRemBitWidth = TLI.getMaxDivRemBitWidthSupported();

  // fold (sdiv X, pow2) -> simple ops after legalize
  // FIXME: We check for the exact bit here because the generic lowering gives
  // better results in that case. The target-specific lowering should learn how
  // to handle exact sdivs efficiently. An exception is made for large bitwidths
  // exceeding what the target can natively support, as division expansion was
  // skipped in favor of this optimization.
  if ((!N->getFlags().hasExact() || BitWidth > MaxLegalDivRemBitWidth) &&
      isDivisorPowerOfTwo(N1)) {
    // Target-specific implementation of sdiv x, pow2.
    if (SDValue Res = BuildSDIVPow2(N))
      return Res;

    // Create constants that are functions of the shift amount value.
    EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
    SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
    SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
    C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
    SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
    if (!isConstantOrConstantVector(Inexact))
      return SDValue();

    // Splat the sign bit into the register
    SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
                               DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
    AddToWorklist(Sign.getNode());

    // Add (N0 < 0) ? abs2 - 1 : 0;
    SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
    AddToWorklist(Srl.getNode());
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
    AddToWorklist(Add.getNode());
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
    AddToWorklist(Sra.getNode());

    // Special case: (sdiv X, 1) -> X
    // Special Case: (sdiv X, -1) -> 0-X
    SDValue One = DAG.getConstant(1, DL, VT);
    // NOTE(review): the declaration of AllOnes (an all-ones constant of VT)
    // appears to be elided in this excerpt.
    SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
    SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
    SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
    Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);

    // If dividing by a positive value, we're done. Otherwise, the result must
    // be negated.
    SDValue Zero = DAG.getConstant(0, DL, VT);
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);

    // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
    SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
    SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
    return Res;
  }

  // If integer divide is expensive and we satisfy the requirements, emit an
  // alternate sequence. Targets may check function attributes for size/speed
  // trade-offs.
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (isConstantOrConstantVector(N1, /*NoOpaques=*/false,
                                 /*AllowTruncation=*/true) &&
      !TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue Op = BuildSDIV(N))
      return Op;

  return SDValue();
}
5316
5317SDValue DAGCombiner::visitUDIV(SDNode *N) {
5318 SDValue N0 = N->getOperand(0);
5319 SDValue N1 = N->getOperand(1);
5320 EVT VT = N->getValueType(0);
5321 EVT CCVT = getSetCCResultType(VT);
5322 SDLoc DL(N);
5323
5324 // fold (udiv c1, c2) -> c1/c2
5325 if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
5326 return C;
5327
5328 // fold vector ops
5329 if (VT.isVector())
5330 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5331 return FoldedVOp;
5332
5333 // fold (udiv X, -1) -> select(X == -1, 1, 0)
5334 ConstantSDNode *N1C = isConstOrConstSplat(N1);
5335 if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
5336 return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
5337 DAG.getConstant(1, DL, VT),
5338 DAG.getConstant(0, DL, VT));
5339 }
5340
5341 if (SDValue V = simplifyDivRem(N, DAG))
5342 return V;
5343
5344 if (SDValue NewSel = foldBinOpIntoSelect(N))
5345 return NewSel;
5346
5347 if (SDValue V = visitUDIVLike(N0, N1, N)) {
5348 // If the corresponding remainder node exists, update its users with
5349 // (Dividend - (Quotient * Divisor).
5350 if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
5351 { N0, N1 })) {
5352 // If the udiv has the exact flag we shouldn't propagate it to the
5353 // remainder node.
5354 if (!N->getFlags().hasExact()) {
5355 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
5356 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
5357 AddToWorklist(Mul.getNode());
5358 AddToWorklist(Sub.getNode());
5359 CombineTo(RemNode, Sub);
5360 }
5361 }
5362 return V;
5363 }
5364
5365 // sdiv, srem -> sdivrem
5366 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5367 // true. Otherwise, we break the simplification logic in visitREM().
5368 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5369 if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
5370 if (SDValue DivRem = useDivRem(N))
5371 return DivRem;
5372
5373 // Simplify the operands using demanded-bits information.
5374 // We don't have demanded bits support for UDIV so this just enables constant
5375 // folding based on known bits.
5377 return SDValue(N, 0);
5378
5379 return SDValue();
5380}
5381
5382SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
5383 SDLoc DL(N);
5384 EVT VT = N->getValueType(0);
5385
5386 // fold (udiv x, (1 << c)) -> x >>u c
5387 if (isConstantOrConstantVector(N1, /*NoOpaques=*/true,
5388 /*AllowTruncation=*/true)) {
5389 if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
5390 AddToWorklist(LogBase2.getNode());
5391
5392 EVT ShiftVT = getShiftAmountTy(N0.getValueType());
5393 SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
5394 AddToWorklist(Trunc.getNode());
5395 return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
5396 }
5397 }
5398
5399 // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
5400 if (N1.getOpcode() == ISD::SHL) {
5401 SDValue N10 = N1.getOperand(0);
5402 if (isConstantOrConstantVector(N10, /*NoOpaques=*/true,
5403 /*AllowTruncation=*/true)) {
5404 if (SDValue LogBase2 = BuildLogBase2(N10, DL)) {
5405 AddToWorklist(LogBase2.getNode());
5406
5407 EVT ADDVT = N1.getOperand(1).getValueType();
5408 SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
5409 AddToWorklist(Trunc.getNode());
5410 SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
5411 AddToWorklist(Add.getNode());
5412 return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
5413 }
5414 }
5415 }
5416
5417 // fold (udiv x, c) -> alternate
5418 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5419 if (isConstantOrConstantVector(N1, /*NoOpaques=*/false,
5420 /*AllowTruncation=*/true) &&
5421 !TLI.isIntDivCheap(N->getValueType(0), Attr))
5422 if (SDValue Op = BuildUDIV(N))
5423 return Op;
5424
5425 return SDValue();
5426}
5427
5428SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
5429 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
5430 !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
5431 // Target-specific implementation of srem x, pow2.
5432 if (SDValue Res = BuildSREMPow2(N))
5433 return Res;
5434 }
5435 return SDValue();
5436}
5437
// handles ISD::SREM and ISD::UREM
// The folds below are order-dependent: cheap structural folds run first, then
// the expensive div-based rewrite, then the DIVREM conversion.
SDValue DAGCombiner::visitREM(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);

  bool isSigned = (Opcode == ISD::SREM);
  SDLoc DL(N);

  // fold (rem c1, c2) -> c1%c2
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // fold (urem X, -1) -> select(FX == -1, 0, FX)
  // Freeze the numerator to avoid a miscompile with an undefined value.
  if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false) &&
      CCVT.isVector() == VT.isVector()) {
    SDValue F0 = DAG.getFreeze(N0);
    SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
    return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
  }

  // Generic div/rem identities (e.g. X % 1).
  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (isSigned) {
    // If we know the sign bits of both operands are zero, strength reduce to a
    // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
    if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
  } else {
    if (DAG.isKnownToBeAPowerOfTwo(N1, /*OrZero=*/true)) {
      // fold (urem x, pow2) -> (and x, pow2-1)
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
  }

  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();

  // If X/C can be simplified by the division-by-constant logic, lower
  // X%C to the equivalent of X-X/C*C.
  // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
  // speculative DIV must not cause a DIVREM conversion.  We guard against this
  // by skipping the simplification if isIntDivCheap().  When div is not cheap,
  // combine will not return a DIVREM.  Regardless, checking cheapness here
  // makes sense since the simplification results in fatter code.
  if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
    if (isSigned) {
      // check if we can build faster implementation for srem
      if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
        return OptimizedRem;
    }

    SDValue OptimizedDiv =
        isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
    if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
      // If the equivalent Div node also exists, update its users.
      unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
      if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
                                                { N0, N1 }))
        CombineTo(DivNode, OptimizedDiv);
      // rem = X - (X/C)*C
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(OptimizedDiv.getNode());
      AddToWorklist(Mul.getNode());
      return Sub;
    }
  }

  // sdiv, srem -> sdivrem
  if (SDValue DivRem = useDivRem(N))
    return DivRem.getValue(1);

  // fold urem(urem(A, BCst), Op1Cst) -> urem(A, Op1Cst)
  // iff urem(BCst, Op1Cst) == 0
  SDValue A;
  APInt Op1Cst, BCst;
  if (sd_match(N, m_URem(m_URem(m_Value(A), m_ConstInt(BCst)),
                         m_ConstInt(Op1Cst))) &&
      BCst.urem(Op1Cst).isZero()) {
    return DAG.getNode(ISD::UREM, DL, VT, A, DAG.getConstant(Op1Cst, DL, VT));
  }

  // fold srem(srem(A, BCst), Op1Cst) -> srem(A, Op1Cst)
  // iff srem(BCst, Op1Cst) == 0 && Op1Cst != 1
  // (the isAllOnes check rejects Op1Cst == -1, where the inner result differs)
  if (sd_match(N, m_SRem(m_SRem(m_Value(A), m_ConstInt(BCst)),
                         m_ConstInt(Op1Cst))) &&
      BCst.srem(Op1Cst).isZero() && !Op1Cst.isAllOnes()) {
    return DAG.getNode(ISD::SREM, DL, VT, A, DAG.getConstant(Op1Cst, DL, VT));
  }

  return SDValue();
}
5539
5540SDValue DAGCombiner::visitMULHS(SDNode *N) {
5541 SDValue N0 = N->getOperand(0);
5542 SDValue N1 = N->getOperand(1);
5543 EVT VT = N->getValueType(0);
5544 SDLoc DL(N);
5545
5546 // fold (mulhs c1, c2)
5547 if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
5548 return C;
5549
5550 // canonicalize constant to RHS.
5553 return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
5554
5555 if (VT.isVector()) {
5556 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5557 return FoldedVOp;
5558
5559 // fold (mulhs x, 0) -> 0
5560 // do not return N1, because undef node may exist.
5562 return DAG.getConstant(0, DL, VT);
5563 }
5564
5565 // fold (mulhs x, 0) -> 0
5566 if (isNullConstant(N1))
5567 return N1;
5568
5569 // fold (mulhs x, 1) -> (sra x, size(x)-1)
5570 if (isOneConstant(N1))
5571 return DAG.getNode(
5572 ISD::SRA, DL, VT, N0,
5574
5575 // fold (mulhs x, undef) -> 0
5576 if (N0.isUndef() || N1.isUndef())
5577 return DAG.getConstant(0, DL, VT);
5578
5579 // If the type twice as wide is legal, transform the mulhs to a wider multiply
5580 // plus a shift.
5581 if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
5582 !VT.isVector()) {
5583 MVT Simple = VT.getSimpleVT();
5584 unsigned SimpleSize = Simple.getSizeInBits();
5585 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5586 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5587 N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5588 N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5589 N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5590 N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5591 DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5592 return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5593 }
5594 }
5595
5596 return SDValue();
5597}
5598
5599SDValue DAGCombiner::visitMULHU(SDNode *N) {
5600 SDValue N0 = N->getOperand(0);
5601 SDValue N1 = N->getOperand(1);
5602 EVT VT = N->getValueType(0);
5603 SDLoc DL(N);
5604
5605 // fold (mulhu c1, c2)
5606 if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
5607 return C;
5608
5609 // canonicalize constant to RHS.
5612 return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
5613
5614 if (VT.isVector()) {
5615 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5616 return FoldedVOp;
5617
5618 // fold (mulhu x, 0) -> 0
5619 // do not return N1, because undef node may exist.
5621 return DAG.getConstant(0, DL, VT);
5622 }
5623
5624 // fold (mulhu x, 0) -> 0
5625 if (isNullConstant(N1))
5626 return N1;
5627
5628 // fold (mulhu x, 1) -> 0
5629 if (isOneConstant(N1))
5630 return DAG.getConstant(0, DL, VT);
5631
5632 // fold (mulhu x, undef) -> 0
5633 if (N0.isUndef() || N1.isUndef())
5634 return DAG.getConstant(0, DL, VT);
5635
5636 // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
5637 if (isConstantOrConstantVector(N1, /*NoOpaques=*/true,
5638 /*AllowTruncation=*/true) &&
5639 hasOperation(ISD::SRL, VT)) {
5640 if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
5641 unsigned NumEltBits = VT.getScalarSizeInBits();
5642 SDValue SRLAmt = DAG.getNode(
5643 ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
5644 EVT ShiftVT = getShiftAmountTy(N0.getValueType());
5645 SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
5646 return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
5647 }
5648 }
5649
5650 // If the type twice as wide is legal, transform the mulhu to a wider multiply
5651 // plus a shift.
5652 if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
5653 !VT.isVector()) {
5654 MVT Simple = VT.getSimpleVT();
5655 unsigned SimpleSize = Simple.getSizeInBits();
5656 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5657 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5658 N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5659 N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5660 N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5661 N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5662 DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5663 return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5664 }
5665 }
5666
5667 // Simplify the operands using demanded-bits information.
5668 // We don't have demanded bits support for MULHU so this just enables constant
5669 // folding based on known bits.
5671 return SDValue(N, 0);
5672
5673 return SDValue();
5674}
5675
5676SDValue DAGCombiner::visitAVG(SDNode *N) {
5677 unsigned Opcode = N->getOpcode();
5678 SDValue N0 = N->getOperand(0);
5679 SDValue N1 = N->getOperand(1);
5680 EVT VT = N->getValueType(0);
5681 SDLoc DL(N);
5682 bool IsSigned = Opcode == ISD::AVGCEILS || Opcode == ISD::AVGFLOORS;
5683
5684 // fold (avg c1, c2)
5685 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5686 return C;
5687
5688 // canonicalize constant to RHS.
5691 return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
5692
5693 if (VT.isVector())
5694 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5695 return FoldedVOp;
5696
5697 // fold (avg x, undef) -> x
5698 if (N0.isUndef())
5699 return N1;
5700 if (N1.isUndef())
5701 return N0;
5702
5703 // fold (avg x, x) --> x
5704 if (N0 == N1 && Level >= AfterLegalizeTypes)
5705 return N0;
5706
5707 // fold (avgfloor x, 0) -> x >> 1
5708 SDValue X, Y;
5710 return DAG.getNode(ISD::SRA, DL, VT, X,
5711 DAG.getShiftAmountConstant(1, VT, DL));
5713 return DAG.getNode(ISD::SRL, DL, VT, X,
5714 DAG.getShiftAmountConstant(1, VT, DL));
5715
5716 // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
5717 // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
5718 if (!IsSigned &&
5719 sd_match(N, m_BinOp(Opcode, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
5720 X.getValueType() == Y.getValueType() &&
5721 hasOperation(Opcode, X.getValueType())) {
5722 SDValue AvgU = DAG.getNode(Opcode, DL, X.getValueType(), X, Y);
5723 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgU);
5724 }
5725 if (IsSigned &&
5726 sd_match(N, m_BinOp(Opcode, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
5727 X.getValueType() == Y.getValueType() &&
5728 hasOperation(Opcode, X.getValueType())) {
5729 SDValue AvgS = DAG.getNode(Opcode, DL, X.getValueType(), X, Y);
5730 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgS);
5731 }
5732
5733 // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
5734 // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
5735 // Check if avgflooru isn't legal/custom but avgceilu is.
5736 if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
5737 (!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
5738 if (DAG.isKnownNeverZero(N1))
5739 return DAG.getNode(
5740 ISD::AVGCEILU, DL, VT, N0,
5741 DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
5742 if (DAG.isKnownNeverZero(N0))
5743 return DAG.getNode(
5744 ISD::AVGCEILU, DL, VT, N1,
5745 DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
5746 }
5747
5748 // Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
5749 // Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
5750 if ((Opcode == ISD::AVGFLOORU && hasOperation(ISD::AVGCEILU, VT)) ||
5751 (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGCEILS, VT))) {
5752 SDValue Add;
5753 if (sd_match(N,
5754 m_c_BinOp(Opcode,
5756 m_One())) ||
5757 sd_match(N, m_c_BinOp(Opcode,
5759 m_Value(Y)))) {
5760
5761 if (IsSigned && Add->getFlags().hasNoSignedWrap())
5762 return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
5763
5764 if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
5765 return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
5766 }
5767 }
5768
5769 // Fold avgfloors(x,y) -> avgflooru(x,y) if both x and y are non-negative
5770 if (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGFLOORU, VT)) {
5771 if (DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1))
5772 return DAG.getNode(ISD::AVGFLOORU, DL, VT, N0, N1);
5773 }
5774
5775 return SDValue();
5776}
5777
/// Combine ABDS/ABDU (absolute-difference) nodes.
/// NOTE(review): several guarding `if` lines in this function were lost during
/// extraction (the code below is syntactically incomplete at the marked
/// spots); restore them from upstream DAGCombiner.cpp before compiling.
SDValue DAGCombiner::visitABD(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (abd c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  // NOTE(review): the `if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
  // !...(N1))` guard is missing here — lost in extraction.
    return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);

  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // fold (abd x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (abd x, x) -> 0
  if (N0 == N1)
    return DAG.getConstant(0, DL, VT);

  SDValue X, Y;

  // fold (abds x, 0) -> abs x
  // NOTE(review): the sd_match(...m_Zero()) half of this condition is missing.
      (!LegalOperations || hasOperation(ISD::ABS, VT)))
    return DAG.getNode(ISD::ABS, DL, VT, X);

  // fold (abdu x, 0) -> x
  // NOTE(review): the sd_match condition for this fold is missing.
    return X;

  // fold (abds x, y) -> (abdu x, y) iff both args are known positive
  if (Opcode == ISD::ABDS && hasOperation(ISD::ABDU, VT) &&
      DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1))
    return DAG.getNode(ISD::ABDU, DL, VT, N1, N0);

  // fold (abd? (?ext x), (?ext y)) -> (zext (abd? x, y))
  // NOTE(review): the matcher `if` opening this block is missing; the body
  // computes the abd in the wider of the two pre-extension types, re-extending
  // both inputs with N0's extension opcode, then zero-extends the result
  // (valid because an absolute difference is non-negative).
    EVT SmallVT = X.getScalarValueSizeInBits() > Y.getScalarValueSizeInBits()
                      ? X.getValueType()
                      : Y.getValueType();
    if (!LegalOperations || hasOperation(Opcode, SmallVT)) {
      SDValue ExtedX = DAG.getExtOrTrunc(X, SDLoc(X), SmallVT, N0->getOpcode());
      SDValue ExtedY = DAG.getExtOrTrunc(Y, SDLoc(Y), SmallVT, N0->getOpcode());
      SDValue SmallABD = DAG.getNode(Opcode, DL, SmallVT, {ExtedX, ExtedY});
      SDValue ZExted = DAG.getZExtOrTrunc(SmallABD, DL, VT);
      return ZExted;
    }
  }

  // fold (abd? (?ext ty:x), small_const:c) -> (zext (abd? x, c))
  // NOTE(review): the matcher `if` opening this block is missing. Inside,
  // RelevantBits bounds Y's value range and TruncatingYIsCheap requires the
  // truncation of Y to be free or Y's constants to fit in SmallVT; the
  // `:` arm of RelevantBits and the `ISD::matchUnaryPredicate(` line were
  // also dropped in extraction.
    EVT SmallVT = X.getValueType();
    if (!LegalOperations || hasOperation(Opcode, SmallVT)) {
      uint64_t Bits = SmallVT.getScalarSizeInBits();
      unsigned RelevantBits =
          (Opcode == ISD::ABDS) ? DAG.ComputeMaxSignificantBits(Y)
      bool TruncatingYIsCheap = TLI.isTruncateFree(Y, SmallVT) ||
              Y,
              [&](auto *C) {
                const APInt &YConst = C->getAsAPIntVal();
                return (Opcode == ISD::ABDS)
                           ? YConst.isSignedIntN(Bits)
                           : YConst.isIntN(Bits);
              },
              /*AllowUndefs=*/true);

      if (RelevantBits <= Bits && TruncatingYIsCheap) {
        SDValue NewY = DAG.getNode(ISD::TRUNCATE, SDLoc(Y), SmallVT, Y);
        SDValue SmallABD = DAG.getNode(Opcode, DL, SmallVT, {X, NewY});
        return DAG.getZExtOrTrunc(SmallABD, DL, VT);
      }
    }
  }

  return SDValue();
}
5867
5868/// Perform optimizations common to nodes that compute two values. LoOp and HiOp
5869/// give the opcodes for the two computations that are being performed. Return
5870/// true if a simplification was made.
5871SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
5872 unsigned HiOp) {
5873 // If the high half is not needed, just compute the low half.
5874 bool HiExists = N->hasAnyUseOfValue(1);
5875 if (!HiExists && (!LegalOperations ||
5876 TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
5877 SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
5878 return CombineTo(N, Res, Res);
5879 }
5880
5881 // If the low half is not needed, just compute the high half.
5882 bool LoExists = N->hasAnyUseOfValue(0);
5883 if (!LoExists && (!LegalOperations ||
5884 TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
5885 SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
5886 return CombineTo(N, Res, Res);
5887 }
5888
5889 // If both halves are used, return as it is.
5890 if (LoExists && HiExists)
5891 return SDValue();
5892
5893 // If the two computed results can be simplified separately, separate them.
5894 if (LoExists) {
5895 SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
5896 AddToWorklist(Lo.getNode());
5897 SDValue LoOpt = combine(Lo.getNode());
5898 if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
5899 (!LegalOperations ||
5900 TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
5901 return CombineTo(N, LoOpt, LoOpt);
5902 }
5903
5904 if (HiExists) {
5905 SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
5906 AddToWorklist(Hi.getNode());
5907 SDValue HiOpt = combine(Hi.getNode());
5908 if (HiOpt.getNode() && HiOpt != Hi &&
5909 (!LegalOperations ||
5910 TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
5911 return CombineTo(N, HiOpt, HiOpt);
5912 }
5913
5914 return SDValue();
5915}
5916
5917SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5918 if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
5919 return Res;
5920
5921 SDValue N0 = N->getOperand(0);
5922 SDValue N1 = N->getOperand(1);
5923 EVT VT = N->getValueType(0);
5924 SDLoc DL(N);
5925
5926 // Constant fold.
5928 return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N0, N1);
5929
5930 // canonicalize constant to RHS (vector doesn't have to splat)
5933 return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);
5934
5935 // If the type is twice as wide is legal, transform the mulhu to a wider
5936 // multiply plus a shift.
5937 if (VT.isSimple() && !VT.isVector()) {
5938 MVT Simple = VT.getSimpleVT();
5939 unsigned SimpleSize = Simple.getSizeInBits();
5940 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5941 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5942 SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5943 SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5944 Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5945 // Compute the high part as N1.
5946 Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5947 DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5948 Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5949 // Compute the low part as N0.
5950 Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5951 return CombineTo(N, Lo, Hi);
5952 }
5953 }
5954
5955 return SDValue();
5956}
5957
5958SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5959 if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
5960 return Res;
5961
5962 SDValue N0 = N->getOperand(0);
5963 SDValue N1 = N->getOperand(1);
5964 EVT VT = N->getValueType(0);
5965 SDLoc DL(N);
5966
5967 // Constant fold.
5969 return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N0, N1);
5970
5971 // canonicalize constant to RHS (vector doesn't have to splat)
5974 return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);
5975
5976 // (umul_lohi N0, 0) -> (0, 0)
5977 if (isNullConstant(N1)) {
5978 SDValue Zero = DAG.getConstant(0, DL, VT);
5979 return CombineTo(N, Zero, Zero);
5980 }
5981
5982 // (umul_lohi N0, 1) -> (N0, 0)
5983 if (isOneConstant(N1)) {
5984 SDValue Zero = DAG.getConstant(0, DL, VT);
5985 return CombineTo(N, N0, Zero);
5986 }
5987
5988 // If the type is twice as wide is legal, transform the mulhu to a wider
5989 // multiply plus a shift.
5990 if (VT.isSimple() && !VT.isVector()) {
5991 MVT Simple = VT.getSimpleVT();
5992 unsigned SimpleSize = Simple.getSizeInBits();
5993 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5994 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5995 SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5996 SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5997 Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5998 // Compute the high part as N1.
5999 Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
6000 DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
6001 Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
6002 // Compute the low part as N0.
6003 Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
6004 return CombineTo(N, Lo, Hi);
6005 }
6006 }
6007
6008 return SDValue();
6009}
6010
6011SDValue DAGCombiner::visitMULO(SDNode *N) {
6012 SDValue N0 = N->getOperand(0);
6013 SDValue N1 = N->getOperand(1);
6014 EVT VT = N0.getValueType();
6015 bool IsSigned = (ISD::SMULO == N->getOpcode());
6016
6017 EVT CarryVT = N->getValueType(1);
6018 SDLoc DL(N);
6019
6020 ConstantSDNode *N0C = isConstOrConstSplat(N0);
6021 ConstantSDNode *N1C = isConstOrConstSplat(N1);
6022
6023 // fold operation with constant operands.
6024 // TODO: Move this to FoldConstantArithmetic when it supports nodes with
6025 // multiple results.
6026 if (N0C && N1C) {
6027 bool Overflow;
6028 APInt Result =
6029 IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
6030 : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
6031 return CombineTo(N, DAG.getConstant(Result, DL, VT),
6032 DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
6033 }
6034
6035 // canonicalize constant to RHS.
6038 return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
6039
6040 // fold (mulo x, 0) -> 0 + no carry out
6041 if (isNullOrNullSplat(N1))
6042 return CombineTo(N, DAG.getConstant(0, DL, VT),
6043 DAG.getConstant(0, DL, CarryVT));
6044
6045 // (mulo x, 2) -> (addo x, x)
6046 // FIXME: This needs a freeze.
6047 if (N1C && N1C->getAPIntValue() == 2 &&
6048 (!IsSigned || VT.getScalarSizeInBits() > 2))
6049 return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
6050 N->getVTList(), N0, N0);
6051
6052 // A 1 bit SMULO overflows if both inputs are 1.
6053 if (IsSigned && VT.getScalarSizeInBits() == 1) {
6054 SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
6055 SDValue Cmp = DAG.getSetCC(DL, CarryVT, And,
6056 DAG.getConstant(0, DL, VT), ISD::SETNE);
6057 return CombineTo(N, And, Cmp);
6058 }
6059
6060 // If it cannot overflow, transform into a mul.
6061 if (DAG.willNotOverflowMul(IsSigned, N0, N1))
6062 return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
6063 DAG.getConstant(0, DL, CarryVT));
6064 return SDValue();
6065}
6066
6067// Function to calculate whether the Min/Max pair of SDNodes (potentially
6068// swapped around) make a signed saturate pattern, clamping to between a signed
6069// saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW.
6070// Returns the node being clamped and the bitwidth of the clamp in BW. Should
6071// work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
6072// same as SimplifySelectCC. N0<N1 ? N2 : N3.
6074 SDValue N3, ISD::CondCode CC, unsigned &BW,
6075 bool &Unsigned, SelectionDAG &DAG) {
6076 auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
6077 ISD::CondCode CC) {
6078 // The compare and select operand should be the same or the select operands
6079 // should be truncated versions of the comparison.
6080 if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
6081 return 0;
6082 // The constants need to be the same or a truncated version of each other.
6085 if (!N1C || !N3C)
6086 return 0;
6087 const APInt &C1 = N1C->getAPIntValue().trunc(N1.getScalarValueSizeInBits());
6088 const APInt &C2 = N3C->getAPIntValue().trunc(N3.getScalarValueSizeInBits());
6089 if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
6090 return 0;
6091 return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
6092 };
6093
6094 // Check the initial value is a SMIN/SMAX equivalent.
6095 unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
6096 if (!Opcode0)
6097 return SDValue();
6098
6099 // We could only need one range check, if the fptosi could never produce
6100 // the upper value.
6101 if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
6102 if (isNullOrNullSplat(N3)) {
6103 EVT IntVT = N0.getValueType().getScalarType();
6104 EVT FPVT = N0.getOperand(0).getValueType().getScalarType();
6105 if (FPVT.isSimple()) {
6106 Type *InputTy = FPVT.getTypeForEVT(*DAG.getContext());
6107 const fltSemantics &Semantics = InputTy->getFltSemantics();
6108 uint32_t MinBitWidth =
6109 APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
6110 if (IntVT.getSizeInBits() >= MinBitWidth) {
6111 Unsigned = true;
6112 BW = PowerOf2Ceil(MinBitWidth);
6113 return N0;
6114 }
6115 }
6116 }
6117 }
6118
6119 SDValue N00, N01, N02, N03;
6120 ISD::CondCode N0CC;
6121 switch (N0.getOpcode()) {
6122 case ISD::SMIN:
6123 case ISD::SMAX:
6124 N00 = N02 = N0.getOperand(0);
6125 N01 = N03 = N0.getOperand(1);
6126 N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
6127 break;
6128 case ISD::SELECT_CC:
6129 N00 = N0.getOperand(0);
6130 N01 = N0.getOperand(1);
6131 N02 = N0.getOperand(2);
6132 N03 = N0.getOperand(3);
6133 N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
6134 break;
6135 case ISD::SELECT:
6136 case ISD::VSELECT:
6137 if (N0.getOperand(0).getOpcode() != ISD::SETCC)
6138 return SDValue();
6139 N00 = N0.getOperand(0).getOperand(0);
6140 N01 = N0.getOperand(0).getOperand(1);
6141 N02 = N0.getOperand(1);
6142 N03 = N0.getOperand(2);
6143 N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
6144 break;
6145 default:
6146 return SDValue();
6147 }
6148
6149 unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
6150 if (!Opcode1 || Opcode0 == Opcode1)
6151 return SDValue();
6152
6153 ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
6154 ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
6155 if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
6156 return SDValue();
6157
6158 const APInt &MinC = MinCOp->getAPIntValue();
6159 const APInt &MaxC = MaxCOp->getAPIntValue();
6160 APInt MinCPlus1 = MinC + 1;
6161 if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
6162 BW = MinCPlus1.exactLogBase2() + 1;
6163 Unsigned = false;
6164 return N02;
6165 }
6166
6167 if (MaxC == 0 && MinC != 0 && MinCPlus1.isPowerOf2()) {
6168 BW = MinCPlus1.exactLogBase2();
6169 Unsigned = true;
6170 return N02;
6171 }
6172
6173 return SDValue();
6174}
6175
// Fold a saturating signed min/max clamp wrapped around an FP_TO_SINT into a
// single FP_TO_SINT_SAT / FP_TO_UINT_SAT node when the target opts in via
// TargetLowering::shouldConvertFpToSat(). isSaturatingMinMax() recognizes the
// clamp and reports the saturation bit-width (BW) and its signedness.
// NOTE(review): the opening line of this definition (its name and leading
// parameters N0/N1/N2) is elided in this view of the file; in upstream LLVM
// this is PerformMinMaxFpToSatCombine -- confirm against the complete source.
6177 SDValue N3, ISD::CondCode CC,
6178 SelectionDAG &DAG) {
6179 unsigned BW;
6180 bool Unsigned;
6181 SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
// Only a clamp around an FP_TO_SINT conversion is handled here.
6182 if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
6183 return SDValue();
6184 EVT FPVT = Fp.getOperand(0).getValueType();
// Build an integer type of the saturation width with the FP type's shape.
6185 EVT NewVT = FPVT.changeElementType(*DAG.getContext(),
6186 EVT::getIntegerVT(*DAG.getContext(), BW));
6187 unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
// The target must explicitly opt in to the saturating-conversion form.
6188 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
6189 return SDValue();
6190 SDLoc DL(Fp);
6191 SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
6192 DAG.getValueType(NewVT.getScalarType()));
// Resize back to the original result type; the extension kind (sign vs.
// zero) follows the detected signedness.
6193 return DAG.getExtOrTrunc(!Unsigned, Sat, DL, N2->getValueType(0));
6194}
6195
// Fold UMIN(FP_TO_UINT(X), (2^n)-1) -- possibly expressed through a
// select/vselect/select_cc with a SETULT compare -- into FP_TO_UINT_SAT when
// the target opts in via shouldConvertFpToSat().
// NOTE(review): the opening line of this definition (its name and leading
// parameters N0/N1/N2) is elided in this view of the file; in upstream LLVM
// this is PerformUMinFpToSatCombine -- confirm against the complete source.
6197 SDValue N3, ISD::CondCode CC,
6198 SelectionDAG &DAG) {
6199 // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
6200 // select/vselect/select_cc. The two operands pairs for the select (N2/N3) may
6201 // be truncated versions of the setcc (N0/N1).
6202 if ((N0 != N2 &&
6203 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
6204 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
6205 return SDValue();
// NOTE(review): the declarations of N1C/N3C (constant-splat views of N1 and
// N3) are elided in this view -- verify against the complete source.
6208 if (!N1C || !N3C)
6209 return SDValue();
6210 const APInt &C1 = N1C->getAPIntValue();
6211 const APInt &C3 = N3C->getAPIntValue();
// The limit must be (2^n)-1, and the compare/select constants must agree
// after zero-extending the (possibly truncated) select constant.
6212 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
6213 C1 != C3.zext(C1.getBitWidth()))
6214 return SDValue();
6215
6216 unsigned BW = (C1 + 1).exactLogBase2();
6217 EVT FPVT = N0.getOperand(0).getValueType();
6218 EVT NewVT = FPVT.changeElementType(*DAG.getContext(),
6219 EVT::getIntegerVT(*DAG.getContext(), BW));
// NOTE(review): the start of this call (shouldConvertFpToSat on the target
// lowering info, with ISD::FP_TO_UINT_SAT) is elided in this view -- verify
// against the complete source.
6221 FPVT, NewVT))
6222 return SDValue();
6223
6224 SDValue Sat =
6225 DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
6226 DAG.getValueType(NewVT.getScalarType()));
6227 return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
6228}
6229
/// Visit an integer min/max node (SMIN/SMAX/UMIN/UMAX) and try, in order:
/// constant folding, x-op-x elimination, constant-to-RHS canonicalization,
/// generic vector-op folding, reassociation, signedness flips when both
/// operands have the same known sign, fp-to-sat formation, reassociation of
/// min/max over vecreduce, and VSCALE-based constant folds.
6230SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
6231 SDValue N0 = N->getOperand(0);
6232 SDValue N1 = N->getOperand(1);
6233 EVT VT = N0.getValueType();
6234 unsigned Opcode = N->getOpcode();
6235 SDLoc DL(N);
6236
6237 // fold operation with constant operands.
6238 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
6239 return C;
6240
6241 // If the operands are the same, this is a no-op.
6242 if (N0 == N1)
6243 return N0;
6244
6245 // canonicalize constant to RHS
// NOTE(review): the guard for this swap (a constant-on-LHS /
// non-constant-on-RHS check) is elided in this view of the file -- verify
// against the complete source.
6248 return DAG.getNode(Opcode, DL, VT, N1, N0);
6249
6250 // fold vector ops
6251 if (VT.isVector())
6252 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
6253 return FoldedVOp;
6254
6255 // reassociate minmax
6256 if (SDValue RMINMAX = reassociateOps(Opcode, DL, N0, N1, N->getFlags()))
6257 return RMINMAX;
6258
6259 // If both operands are known to have the same sign (both non-negative or both
6260 // negative), flip between UMIN/UMAX and SMIN/SMAX.
6261 // Only do this if:
6262 // 1. The current op isn't legal and the flipped is.
6263 // 2. The saturation pattern is broken by canonicalization in InstCombine.
6264 bool IsOpIllegal = !TLI.isOperationLegal(Opcode, VT);
6265 bool IsSatBroken = Opcode == ISD::UMIN && N0.getOpcode() == ISD::SMAX;
6266
6267 if (IsSatBroken || IsOpIllegal) {
// Undef operands compare equal-sign by convention; otherwise both operands
// must be provably non-negative or provably negative.
6268 auto HasKnownSameSign = [&](SDValue A, SDValue B) {
6269 if (A.isUndef() || B.isUndef())
6270 return true;
6271
6272 KnownBits KA = DAG.computeKnownBits(A);
6273 if (!KA.isNonNegative() && !KA.isNegative())
6274 return false;
6275
6276 KnownBits KB = DAG.computeKnownBits(B);
6277 if (KA.isNonNegative())
6278 return KB.isNonNegative();
6279 return KB.isNegative();
6280 };
6281
6282 if (HasKnownSameSign(N0, N1)) {
6283 unsigned AltOpcode = ISD::getOppositeSignednessMinMaxOpcode(Opcode);
6284 if ((IsSatBroken && IsOpIllegal) || TLI.isOperationLegal(AltOpcode, VT))
6285 return DAG.getNode(AltOpcode, DL, VT, N0, N1);
6286 }
6287 }
6288
6289 if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
// NOTE(review): the opening of this call (an `if (SDValue S = ...` invoking
// the signed min/max fp-to-sat combine) is elided in this view -- verify
// against the complete source.
6291 N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
6292 return S;
6293 if (Opcode == ISD::UMIN)
6294 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
6295 return S;
6296
6297 // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
6298 auto ReductionOpcode = [](unsigned Opcode) {
6299 switch (Opcode) {
6300 case ISD::SMIN:
6301 return ISD::VECREDUCE_SMIN;
6302 case ISD::SMAX:
6303 return ISD::VECREDUCE_SMAX;
6304 case ISD::UMIN:
6305 return ISD::VECREDUCE_UMIN;
6306 case ISD::UMAX:
6307 return ISD::VECREDUCE_UMAX;
6308 default:
6309 llvm_unreachable("Unexpected opcode");
6310 }
6311 };
6312 if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode,
6313 SDLoc(N), VT, N0, N1))
6314 return SD;
6315
6316 // Fold operation with vscale operands.
6317 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
// vscale is the same runtime value on both sides, so the unsigned compare
// of the multipliers decides the min/max directly.
6318 uint64_t C0 = N0->getConstantOperandVal(0);
6319 uint64_t C1 = N1->getConstantOperandVal(0);
6320 if (Opcode == ISD::UMAX)
6321 return C0 > C1 ? N0 : N1;
6322 else if (Opcode == ISD::UMIN)
6323 return C0 > C1 ? N1 : N0;
6324 }
6325
6326 // If we know the range of vscale, see if we can fold it given a constant.
6327 // TODO: Generalize this to other nodes by adding computeConstantRange
6328 if (N0.getOpcode() == ISD::VSCALE) {
6329 if (auto *C1 = dyn_cast<ConstantSDNode>(N1)) {
6330 const Function &F = DAG.getMachineFunction().getFunction();
6331 ConstantRange Range =
// NOTE(review): the source of the vscale range (presumably a query of the
// function's vscale_range attribute) is elided in this view -- verify
// against the complete source.
6333 .multiply(ConstantRange(N0.getConstantOperandAPInt(0)));
6334
// If the whole range of N0 is on one side of the constant, the min/max is
// decided statically and the constant operand wins.
6335 const APInt &C1V = C1->getAPIntValue();
6336 if ((Opcode == ISD::UMAX && Range.getUnsignedMax().ule(C1V)) ||
6337 (Opcode == ISD::UMIN && Range.getUnsignedMin().uge(C1V)) ||
6338 (Opcode == ISD::SMAX && Range.getSignedMax().sle(C1V)) ||
6339 (Opcode == ISD::SMIN && Range.getSignedMin().sge(C1V))) {
6340 return N1;
6341 }
6342 }
6343 }
6344
6345 // Simplify the operands using demanded-bits information.
// NOTE(review): the guard for this return (a SimplifyDemandedBits call) is
// elided in this view of the file -- verify against the complete source.
6347 return SDValue(N, 0);
6348
6349 return SDValue();
6350}
6351
6352/// If this is a bitwise logic instruction and both operands have the same
6353/// opcode, try to sink the other opcode after the logic instruction.
6354SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
6355 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
6356 EVT VT = N0.getValueType();
6357 unsigned LogicOpcode = N->getOpcode();
6358 unsigned HandOpcode = N0.getOpcode();
6359 assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
6360 assert(HandOpcode == N1.getOpcode() && "Bad input!");
6361
6362 // Bail early if none of these transforms apply.
6363 if (N0.getNumOperands() == 0)
6364 return SDValue();
6365
6366 // FIXME: We should check number of uses of the operands to not increase
6367 // the instruction count for all transforms.
6368
6369 // Handle size-changing casts (or sign_extend_inreg).
6370 SDValue X = N0.getOperand(0);
6371 SDValue Y = N1.getOperand(0);
6372 EVT XVT = X.getValueType();
6373 SDLoc DL(N);
6374 if (ISD::isExtOpcode(HandOpcode) || ISD::isExtVecInRegOpcode(HandOpcode) ||
6375 (HandOpcode == ISD::SIGN_EXTEND_INREG &&
6376 N0.getOperand(1) == N1.getOperand(1))) {
6377 // If both operands have other uses, this transform would create extra
6378 // instructions without eliminating anything.
6379 if (!N0.hasOneUse() && !N1.hasOneUse())
6380 return SDValue();
6381 // We need matching integer source types.
6382 if (XVT != Y.getValueType())
6383 return SDValue();
6384 // Don't create an illegal op during or after legalization. Don't ever
6385 // create an unsupported vector op.
6386 if ((VT.isVector() || LegalOperations) &&
6387 !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
6388 return SDValue();
6389 // Avoid infinite looping with PromoteIntBinOp.
6390 // TODO: Should we apply desirable/legal constraints to all opcodes?
6391 if ((HandOpcode == ISD::ANY_EXTEND ||
6392 HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
6393 LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
6394 return SDValue();
6395 // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
6396 SDNodeFlags LogicFlags;
6397 LogicFlags.setDisjoint(N->getFlags().hasDisjoint() &&
6398 ISD::isExtOpcode(HandOpcode));
6399 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y, LogicFlags);
6400 if (HandOpcode == ISD::SIGN_EXTEND_INREG)
6401 return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
6402 return DAG.getNode(HandOpcode, DL, VT, Logic);
6403 }
6404
6405 // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
6406 if (HandOpcode == ISD::TRUNCATE) {
6407 // If both operands have other uses, this transform would create extra
6408 // instructions without eliminating anything.
6409 if (!N0.hasOneUse() && !N1.hasOneUse())
6410 return SDValue();
6411 // We need matching source types.
6412 if (XVT != Y.getValueType())
6413 return SDValue();
6414 // Don't create an illegal op during or after legalization.
6415 if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
6416 return SDValue();
6417 // Be extra careful sinking truncate. If it's free, there's no benefit in
6418 // widening a binop. Also, don't create a logic op on an illegal type.
6419 if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
6420 return SDValue();
6421 if (!TLI.isTypeLegal(XVT))
6422 return SDValue();
6423 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6424 return DAG.getNode(HandOpcode, DL, VT, Logic);
6425 }
6426
6427 // For binops SHL/SRL/SRA/AND:
6428 // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
6429 if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
6430 HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
6431 N0.getOperand(1) == N1.getOperand(1)) {
6432 // If either operand has other uses, this transform is not an improvement.
6433 if (!N0.hasOneUse() || !N1.hasOneUse())
6434 return SDValue();
6435 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6436 return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
6437 }
6438
6439 // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
6440 if (HandOpcode == ISD::BSWAP) {
6441 // If either operand has other uses, this transform is not an improvement.
6442 if (!N0.hasOneUse() || !N1.hasOneUse())
6443 return SDValue();
6444 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6445 return DAG.getNode(HandOpcode, DL, VT, Logic);
6446 }
6447
6448 // For funnel shifts FSHL/FSHR:
6449 // logic_op (OP x, x1, s), (OP y, y1, s) -->
6450 // --> OP (logic_op x, y), (logic_op, x1, y1), s
6451 if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
6452 N0.getOperand(2) == N1.getOperand(2)) {
6453 if (!N0.hasOneUse() || !N1.hasOneUse())
6454 return SDValue();
6455 SDValue X1 = N0.getOperand(1);
6456 SDValue Y1 = N1.getOperand(1);
6457 SDValue S = N0.getOperand(2);
6458 SDValue Logic0 = DAG.getNode(LogicOpcode, DL, VT, X, Y);
6459 SDValue Logic1 = DAG.getNode(LogicOpcode, DL, VT, X1, Y1);
6460 return DAG.getNode(HandOpcode, DL, VT, Logic0, Logic1, S);
6461 }
6462
6463 // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
6464 // Only perform this optimization up until type legalization, before
6465 // LegalizeVectorOprs. LegalizeVectorOprs promotes vector operations by
6466 // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
6467 // we don't want to undo this promotion.
6468 // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
6469 // on scalars.
6470 if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
6471 Level <= AfterLegalizeTypes) {
6472 // Input types must be integer and the same.
6473 if (XVT.isInteger() && XVT == Y.getValueType() &&
6474 !(VT.isVector() && TLI.isTypeLegal(VT) &&
6475 !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
6476 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6477 return DAG.getNode(HandOpcode, DL, VT, Logic);
6478 }
6479 }
6480
6481 // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
6482 // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
6483 // If both shuffles use the same mask, and both shuffle within a single
6484 // vector, then it is worthwhile to move the swizzle after the operation.
6485 // The type-legalizer generates this pattern when loading illegal
6486 // vector types from memory. In many cases this allows additional shuffle
6487 // optimizations.
6488 // There are other cases where moving the shuffle after the xor/and/or
6489 // is profitable even if shuffles don't perform a swizzle.
6490 // If both shuffles use the same mask, and both shuffles have the same first
6491 // or second operand, then it might still be profitable to move the shuffle
6492 // after the xor/and/or operation.
6493 if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
6494 auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
6495 auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
6496 assert(X.getValueType() == Y.getValueType() &&
6497 "Inputs to shuffles are not the same type");
6498
6499 // Check that both shuffles use the same mask. The masks are known to be of
6500 // the same length because the result vector type is the same.
6501 // Check also that shuffles have only one use to avoid introducing extra
6502 // instructions.
6503 if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
6504 !SVN0->getMask().equals(SVN1->getMask()))
6505 return SDValue();
6506
6507 // Don't try to fold this node if it requires introducing a
6508 // build vector of all zeros that might be illegal at this stage.
6509 SDValue ShOp = N0.getOperand(1);
6510 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6511 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6512
6513 // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
6514 if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
6515 SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
6516 N0.getOperand(0), N1.getOperand(0));
6517 return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
6518 }
6519
6520 // Don't try to fold this node if it requires introducing a
6521 // build vector of all zeros that might be illegal at this stage.
6522 ShOp = N0.getOperand(0);
6523 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6524 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6525
6526 // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
6527 if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
6528 SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
6529 N1.getOperand(1));
6530 return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
6531 }
6532 }
6533
6534 return SDValue();
6535}
6536
6537/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
6538SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
6539 const SDLoc &DL) {
6540 SDValue LL, LR, RL, RR, N0CC, N1CC;
6541 if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
6542 !isSetCCEquivalent(N1, RL, RR, N1CC))
6543 return SDValue();
6544
6545 assert(N0.getValueType() == N1.getValueType() &&
6546 "Unexpected operand types for bitwise logic op");
6547 assert(LL.getValueType() == LR.getValueType() &&
6548 RL.getValueType() == RR.getValueType() &&
6549 "Unexpected operand types for setcc");
6550
6551 // If we're here post-legalization or the logic op type is not i1, the logic
6552 // op type must match a setcc result type. Also, all folds require new
6553 // operations on the left and right operands, so those types must match.
6554 EVT VT = N0.getValueType();
6555 EVT OpVT = LL.getValueType();
6556 if (LegalOperations || VT.getScalarType() != MVT::i1)
6557 if (VT != getSetCCResultType(OpVT))
6558 return SDValue();
6559 if (OpVT != RL.getValueType())
6560 return SDValue();
6561
6562 ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
6563 ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
6564 bool IsInteger = OpVT.isInteger();
6565 if (LR == RR && CC0 == CC1 && IsInteger) {
6566 bool IsZero = isNullOrNullSplat(LR);
6567 bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
6568
6569 // All bits clear?
6570 bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
6571 // All sign bits clear?
6572 bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
6573 // Any bits set?
6574 bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
6575 // Any sign bits set?
6576 bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
6577
6578 // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
6579 // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
6580 // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
6581 // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
6582 if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
6583 SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
6584 AddToWorklist(Or.getNode());
6585 return DAG.getSetCC(DL, VT, Or, LR, CC1);
6586 }
6587
6588 // All bits set?
6589 bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
6590 // All sign bits set?
6591 bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
6592 // Any bits clear?
6593 bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
6594 // Any sign bits clear?
6595 bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
6596
6597 // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
6598 // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
6599 // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
6600 // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
6601 if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
6602 SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
6603 AddToWorklist(And.getNode());
6604 return DAG.getSetCC(DL, VT, And, LR, CC1);
6605 }
6606 }
6607
6608 // TODO: What is the 'or' equivalent of this fold?
6609 // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
6610 if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
6611 IsInteger && CC0 == ISD::SETNE &&
6612 ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
6613 (isAllOnesConstant(LR) && isNullConstant(RR)))) {
6614 SDValue One = DAG.getConstant(1, DL, OpVT);
6615 SDValue Two = DAG.getConstant(2, DL, OpVT);
6616 SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
6617 AddToWorklist(Add.getNode());
6618 return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
6619 }
6620
6621 // Try more general transforms if the predicates match and the only user of
6622 // the compares is the 'and' or 'or'.
6623 if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
6624 N0.hasOneUse() && N1.hasOneUse()) {
6625 // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
6626 // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
6627 if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
6628 SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
6629 SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
6630 SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
6631 SDValue Zero = DAG.getConstant(0, DL, OpVT);
6632 return DAG.getSetCC(DL, VT, Or, Zero, CC1);
6633 }
6634
6635 // Turn compare of constants whose difference is 1 bit into add+and+setcc.
6636 if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
6637 // Match a shared variable operand and 2 non-opaque constant operands.
6638 auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
6639 // The difference of the constants must be a single bit.
6640 const APInt &CMax =
6641 APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
6642 const APInt &CMin =
6643 APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
6644 return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
6645 };
6646 if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
6647 // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
6648 // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
6649 SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
6650 SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
6651 SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
6652 SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
6653 SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
6654 SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
6655 SDValue Zero = DAG.getConstant(0, DL, OpVT);
6656 return DAG.getSetCC(DL, VT, And, Zero, CC0);
6657 }
6658 }
6659 }
6660
6661 // Canonicalize equivalent operands to LL == RL.
6662 if (LL == RR && LR == RL) {
6664 std::swap(RL, RR);
6665 }
6666
6667 // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6668 // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6669 if (LL == RL && LR == RR) {
6670 ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
6671 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
6672 if (NewCC != ISD::SETCC_INVALID &&
6673 (!LegalOperations ||
6674 (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
6675 TLI.isOperationLegal(ISD::SETCC, OpVT))))
6676 return DAG.getSetCC(DL, VT, LL, LR, NewCC);
6677 }
6678
6679 return SDValue();
6680}
6681
6682static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
6683 SelectionDAG &DAG) {
6684 return DAG.isKnownNeverSNaN(Operand2) && DAG.isKnownNeverSNaN(Operand1);
6685}
6686
6687static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
6688 SelectionDAG &DAG) {
6689 return DAG.isKnownNeverNaN(Operand2) && DAG.isKnownNeverNaN(Operand1);
6690}
6691
6692/// Returns an appropriate FP min/max opcode for clamping operations.
6693static unsigned getMinMaxOpcodeForClamp(bool IsMin, SDValue Operand1,
6694 SDValue Operand2, SelectionDAG &DAG,
6695 const TargetLowering &TLI) {
6696 EVT VT = Operand1.getValueType();
6697 unsigned IEEEOp = IsMin ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
6698 if (TLI.isOperationLegalOrCustom(IEEEOp, VT) &&
6699 arebothOperandsNotNan(Operand1, Operand2, DAG))
6700 return IEEEOp;
6701 unsigned PreferredOp = IsMin ? ISD::FMINNUM : ISD::FMAXNUM;
6702 if (TLI.isOperationLegalOrCustom(PreferredOp, VT))
6703 return PreferredOp;
6704 return ISD::DELETED_NODE;
6705}
6706
// Select an FP min/max opcode equivalent to and/or of two ordered compares
// against a common value, honouring the different NaN semantics of
// FMINNUM/FMAXNUM vs. FMINNUM_IEEE/FMAXNUM_IEEE. Returns ISD::DELETED_NODE
// when no opcode is valid for the predicate/operand combination.
// NOTE(review): the opening line of this definition (its name and return
// type) is elided in this view; in upstream LLVM this helper is
// getMinMaxOpcodeForFP -- confirm against the complete source.
6707 // FIXME: use FMINIMUMNUM if possible, such as for RISC-V.
6709 SDValue Operand1, SDValue Operand2, bool SetCCNoNaNs, ISD::CondCode CC,
6710 unsigned OrAndOpcode, SelectionDAG &DAG, bool isFMAXNUMFMINNUM_IEEE,
6711 bool isFMAXNUMFMINNUM) {
6712 // The optimization cannot be applied for all the predicates because
6713 // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
6714 // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
6715 // applied at all if one of the operands is a signaling NaN.
6716
6717 // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
6718 // are non NaN values.
6719 if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
6720 ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND))) {
// NOTE(review): the result operands of this conditional expression (the
// `? ... : ...` arms selecting the min opcode or ISD::DELETED_NODE) are
// elided in this view -- verify against the complete source.
6721 return (SetCCNoNaNs || arebothOperandsNotNan(Operand1, Operand2, DAG)) &&
6722 isFMAXNUMFMINNUM_IEEE
6725 }
6726
6727 if (((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::OR)) ||
6728 ((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::AND))) {
// NOTE(review): the `? ... : ...` arms are elided here as well -- verify
// against the complete source.
6729 return (SetCCNoNaNs || arebothOperandsNotNan(Operand1, Operand2, DAG)) &&
6730 isFMAXNUMFMINNUM_IEEE
6733 }
6734
6735 // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
6736 // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6737 // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6738 // that there are not any sNaNs, then the optimization is not valid
6739 // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6740 // the optimization using FMINNUM/FMAXNUM for the following cases. If
6741 // we can prove that we do not have any sNaNs, then we can do the
6742 // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6743 // cases.
6744 if (((CC == ISD::SETOLT || CC == ISD::SETOLE) && (OrAndOpcode == ISD::OR)) ||
6745 ((CC == ISD::SETUGT || CC == ISD::SETUGE) && (OrAndOpcode == ISD::AND))) {
// NOTE(review): the trailing `? ... : ...` arms of the nested conditional
// are elided here -- verify against the complete source.
6746 return isFMAXNUMFMINNUM ? ISD::FMINNUM
6747 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6748 isFMAXNUMFMINNUM_IEEE
6751 }
6752
6753 if (((CC == ISD::SETOGT || CC == ISD::SETOGE) && (OrAndOpcode == ISD::OR)) ||
6754 ((CC == ISD::SETULT || CC == ISD::SETULE) && (OrAndOpcode == ISD::AND))) {
// NOTE(review): the trailing `? ... : ...` arms are elided here as well --
// verify against the complete source.
6755 return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6756 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6757 isFMAXNUMFMINNUM_IEEE
6760 }
6761
// No valid min/max form for this predicate/logic-op combination.
6762 return ISD::DELETED_NODE;
6763}
6764
// Fold and/or of two SETCC nodes: replace and/or-cmp-cmp sequences that
// compare against a common value with a min/max followed by a single compare,
// fold ord/uno pairs, and apply target-preferred ABS / add-and / not-and
// rewrites for eq/ne compares against constants.
// NOTE(review): the opening lines of this definition (its name, and the
// LogicOp/DAG parameters) are elided in this view of the file; in upstream
// LLVM this is foldAndOrOfSETCC -- confirm against the complete source.
6767 assert(
6768 (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
6769 "Invalid Op to combine SETCC with");
6770
6771 // TODO: Search past casts/truncates.
6772 SDValue LHS = LogicOp->getOperand(0);
6773 SDValue RHS = LogicOp->getOperand(1);
6774 if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
6775 !LHS->hasOneUse() || !RHS->hasOneUse())
6776 return SDValue();
6777
6778 SDNodeFlags LHSSetCCFlags = LHS->getFlags();
6779 SDNodeFlags RHSSetCCFlags = RHS->getFlags();
6780 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
// NOTE(review): the declaration of TargetPreference (a query of the target's
// desired AndOrSETCCFoldKind) is elided in this view -- verify against the
// complete source.
6782 LogicOp, LHS.getNode(), RHS.getNode());
6783
6784 SDValue LHS0 = LHS->getOperand(0);
6785 SDValue RHS0 = RHS->getOperand(0);
6786 SDValue LHS1 = LHS->getOperand(1);
6787 SDValue RHS1 = RHS->getOperand(1);
6788 // TODO: We don't actually need a splat here, for vectors we just need the
6789 // invariants to hold for each element.
6790 auto *LHS1C = isConstOrConstSplat(LHS1);
6791 auto *RHS1C = isConstOrConstSplat(RHS1);
6792 ISD::CondCode CCL = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
6793 ISD::CondCode CCR = cast<CondCodeSDNode>(RHS.getOperand(2))->get();
6794 EVT VT = LogicOp->getValueType(0);
6795 EVT OpVT = LHS0.getValueType();
6796 SDLoc DL(LogicOp);
6797
6798 // Check if the operands of an and/or operation are comparisons and if they
6799 // compare against the same value. Replace the and/or-cmp-cmp sequence with
6800 // min/max cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
6801 // sequence will be replaced with min-cmp sequence:
6802 // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
6803 // and and-cmp-cmp will be replaced with max-cmp sequence:
6804 // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
6805 // The optimization does not work for `==` or `!=` .
6806 // The two comparisons should have either the same predicate or the
6807 // predicate of one of the comparisons is the opposite of the other one.
// NOTE(review): the right-hand sides of these two flag initializers (the
// matching FMINNUM_IEEE / FMINNUM legality checks) are partially elided in
// this view -- verify against the complete source.
6808 bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(ISD::FMAXNUM_IEEE, OpVT) &&
6810 bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(ISD::FMAXNUM, OpVT) &&
6812 if (((OpVT.isInteger() && TLI.isOperationLegal(ISD::UMAX, OpVT) &&
6813 TLI.isOperationLegal(ISD::SMAX, OpVT) &&
6814 TLI.isOperationLegal(ISD::UMIN, OpVT) &&
6815 TLI.isOperationLegal(ISD::SMIN, OpVT)) ||
6816 (OpVT.isFloatingPoint() &&
6817 (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
// NOTE(review): one condition line (excluding equality predicates) is elided
// here -- verify against the complete source.
6819 CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
6820 CCL != ISD::SETTRUE &&
6821 (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(CCR))) {
6822
6823 SDValue CommonValue, Operand1, Operand2;
// NOTE(review): the initialization of CC (to ISD::SETCC_INVALID, judging by
// the checks below) is elided in this view -- verify against the complete
// source.
6825 if (CCL == CCR) {
6826 if (LHS0 == RHS0) {
6827 CommonValue = LHS0;
6828 Operand1 = LHS1;
6829 Operand2 = RHS1;
// NOTE(review): the assignment to CC for this branch is elided in this view.
6831 } else if (LHS1 == RHS1) {
6832 CommonValue = LHS1;
6833 Operand1 = LHS0;
6834 Operand2 = RHS0;
6835 CC = CCL;
6836 }
6837 } else {
6838 assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
6839 if (LHS0 == RHS1) {
6840 CommonValue = LHS0;
6841 Operand1 = LHS1;
6842 Operand2 = RHS0;
6843 CC = CCR;
6844 } else if (RHS0 == LHS1) {
6845 CommonValue = LHS1;
6846 Operand1 = LHS0;
6847 Operand2 = RHS1;
6848 CC = CCL;
6849 }
6850 }
6851
6852 // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
6853 // handle it using OR/AND.
6854 if (CC == ISD::SETLT && isNullOrNullSplat(CommonValue))
6855 CC = ISD::SETCC_INVALID;
6856 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CommonValue))
6857 CC = ISD::SETCC_INVALID;
6858
6859 if (CC != ISD::SETCC_INVALID) {
6860 unsigned NewOpcode = ISD::DELETED_NODE;
6861 bool IsSigned = isSignedIntSetCC(CC);
6862 if (OpVT.isInteger()) {
// A "less" predicate under OR becomes MIN; under AND it becomes MAX (and
// vice versa for "greater"), with signedness taken from the predicate.
6863 bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
6864 CC == ISD::SETLT || CC == ISD::SETULT);
6865 bool IsOr = (LogicOp->getOpcode() == ISD::OR);
6866 if (IsLess == IsOr)
6867 NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
6868 else
6869 NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
6870 } else if (OpVT.isFloatingPoint())
// NOTE(review): the callee line here (the FP min/max opcode selection
// helper) is elided in this view -- verify against the complete source.
6872 Operand1, Operand2,
6873 LHSSetCCFlags.hasNoNaNs() && RHSSetCCFlags.hasNoNaNs(), CC,
6874 LogicOp->getOpcode(), DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);
6875
6876 if (NewOpcode != ISD::DELETED_NODE) {
6877 // Propagate fast-math flags from setcc.
6878 SDNodeFlags Flags = LHS->getFlags() & RHS->getFlags();
6879 SDValue MinMaxValue =
6880 DAG.getNode(NewOpcode, DL, OpVT, Operand1, Operand2, Flags);
6881 return DAG.getSetCC(DL, VT, MinMaxValue, CommonValue, CC, /*Chain=*/{},
6882 /*IsSignaling=*/false, Flags);
6883 }
6884 }
6885 }
6886
// and(seto(x, x), seto(y, y)) -> seto(x, y); or(setuo, setuo) likewise.
6887 if (LHS0 == LHS1 && RHS0 == RHS1 && CCL == CCR &&
6888 LHS0.getValueType() == RHS0.getValueType() &&
6889 ((LogicOp->getOpcode() == ISD::AND && CCL == ISD::SETO) ||
6890 (LogicOp->getOpcode() == ISD::OR && CCL == ISD::SETUO)))
6891 return DAG.getSetCC(DL, VT, LHS0, RHS0, CCL);
6892
6893 if (TargetPreference == AndOrSETCCFoldKind::None)
6894 return SDValue();
6895
// eq under OR / ne under AND, same LHS, both RHS constant: try the
// target-preferred ABS / AddAnd / NotAnd rewrites below.
6896 if (CCL == CCR &&
6897 CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
6898 LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
6899 const APInt &APLhs = LHS1C->getAPIntValue();
6900 const APInt &APRhs = RHS1C->getAPIntValue();
6901
6902 // Preference is to use ISD::ABS or we already have an ISD::ABS (in which
6903 // case this is just a compare).
6904 if (APLhs == (-APRhs) &&
6905 ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
6906 DAG.doesNodeExist(ISD::ABS, DAG.getVTList(OpVT), {LHS0}))) {
6907 const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
6908 // (icmp eq A, C) | (icmp eq A, -C)
6909 // -> (icmp eq Abs(A), C)
6910 // (icmp ne A, C) & (icmp ne A, -C)
6911 // -> (icmp ne Abs(A), C)
6912 SDValue AbsOp = DAG.getNode(ISD::ABS, DL, OpVT, LHS0);
6913 return DAG.getNode(ISD::SETCC, DL, VT, AbsOp,
6914 DAG.getConstant(C, DL, OpVT), LHS.getOperand(2));
6915 } else if (TargetPreference &
// NOTE(review): the remainder of this condition (the AddAnd/NotAnd mask of
// the preference kind) is elided in this view -- verify against the
// complete source.
6917
6918 // AndOrSETCCFoldKind::AddAnd:
6919 // A == C0 | A == C1
6920 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6921 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
6922 // A != C0 & A != C1
6923 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6924 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0
6925
6926 // AndOrSETCCFoldKind::NotAnd:
6927 // A == C0 | A == C1
6928 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6929 // -> ~A & smin(C0, C1) == 0
6930 // A != C0 & A != C1
6931 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6932 // -> ~A & smin(C0, C1) != 0
6933
6934 const APInt &MaxC = APIntOps::smax(APRhs, APLhs);
6935 const APInt &MinC = APIntOps::smin(APRhs, APLhs);
6936 APInt Dif = MaxC - MinC;
6937 if (!Dif.isZero() && Dif.isPowerOf2()) {
6938 if (MaxC.isAllOnes() &&
6939 (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
6940 SDValue NotOp = DAG.getNOT(DL, LHS0, OpVT);
6941 SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, NotOp,
6942 DAG.getConstant(MinC, DL, OpVT));
6943 return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
6944 DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
6945 } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {
6946
6947 SDValue AddOp = DAG.getNode(ISD::ADD, DL, OpVT, LHS0,
6948 DAG.getConstant(-MinC, DL, OpVT));
6949 SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, AddOp,
6950 DAG.getConstant(~Dif, DL, OpVT));
6951 return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
6952 DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
6953 }
6954 }
6955 }
6956 }
6957
6958 return SDValue();
6959}
6960
6961// Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6962// We canonicalize to the `select` form in the middle end, but the `and` form
6963// gets better codegen and all tested targets (arm, x86, riscv)
// NOTE(review): the opening line of this definition (its name and the
// Cond/T/F parameters) is elided in this view of the file; in upstream LLVM
// this is combineSelectAsExtAnd -- confirm against the complete source.
6965 const SDLoc &DL, SelectionDAG &DAG) {
6966 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
// The false arm must be the constant zero for the and-form to be valid.
6967 if (!isNullConstant(F))
6968 return SDValue();
6969
6970 EVT CondVT = Cond.getValueType();
// NOTE(review): the right-hand side of this comparison (the required
// boolean-contents kind, presumably zero-or-one) is elided in this view --
// verify against the complete source.
6971 if (TLI.getBooleanContents(CondVT) !=
6973 return SDValue();
6974
// The true arm must be exactly (X & 1).
6975 if (T.getOpcode() != ISD::AND)
6976 return SDValue();
6977
6978 if (!isOneConstant(T.getOperand(1)))
6979 return SDValue();
6980
6981 EVT OpVT = T.getValueType();
6982
// Bring the condition to the operand type, then mask X with it.
6983 SDValue CondMask =
6984 OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Cond, DL, OpVT, CondVT);
6985 return DAG.getNode(ISD::AND, DL, OpVT, CondMask, T.getOperand(0));
6986}
6987
/// This contains all DAGCombine rules which reduce two values combined by
/// an And operation to a single value. This makes them reusable in the context
/// of visitSELECT(). Rules involving constants are not included as
/// visitSELECT() already handles those cases.
SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N1.getValueType();
  SDLoc DL(N);

  // fold (and x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // Try the shared setcc-combining rules first (true => this is an AND).
  if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
    return V;

  // Canonicalize:
  // and(x, add) -> and(add, x)
  if (N1.getOpcode() == ISD::ADD)
    std::swap(N0, N1);

  // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
      VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
    if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
      if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
        // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
        // immediate for an add, but it is legal if its top c2 bits are set,
        // transform the ADD so the immediate doesn't need to be materialized
        // in a register.
        APInt ADDC = ADDI->getAPIntValue();
        APInt SRLC = SRLI->getAPIntValue();
        if (ADDC.getSignificantBits() <= 64 && SRLC.ult(VT.getSizeInBits()) &&
            !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
          // NOTE(review): the first line of this statement (presumably
          // `APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),`) is
          // missing from this listing — confirm against the full source.
                                             SRLC.getZExtValue());
          // Only proceed if the masked (top) bits of the add immediate are
          // currently zero, so setting them below changes nothing but the
          // materialization cost.
          if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
            ADDC |= Mask;
            if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
              SDLoc DL0(N0);
              SDValue NewAdd =
                DAG.getNode(ISD::ADD, DL0, VT,
                            N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
              CombineTo(N0.getNode(), NewAdd);
              // Return N so it doesn't get rechecked!
              return SDValue(N, 0);
            }
          }
        }
      }
    }
  }

  return SDValue();
}
7042
7043bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
7044 EVT LoadResultTy, EVT &ExtVT) {
7045 if (!AndC->getAPIntValue().isMask())
7046 return false;
7047
7048 unsigned ActiveBits = AndC->getAPIntValue().countr_one();
7049
7050 ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
7051 EVT LoadedVT = LoadN->getMemoryVT();
7052
7053 if (ExtVT == LoadedVT &&
7054 (!LegalOperations ||
7055 TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
7056 // ZEXTLOAD will match without needing to change the size of the value being
7057 // loaded.
7058 return true;
7059 }
7060
7061 // Do not change the width of a volatile or atomic loads.
7062 if (!LoadN->isSimple())
7063 return false;
7064
7065 // Do not generate loads of non-round integer types since these can
7066 // be expensive (and would be wrong if the type is not byte sized).
7067 if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
7068 return false;
7069
7070 if (LegalOperations &&
7071 !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
7072 return false;
7073
7074 if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT, /*ByteOffset=*/0))
7075 return false;
7076
7077 return true;
7078}
7079
7080bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
7081 ISD::LoadExtType ExtType, EVT &MemVT,
7082 unsigned ShAmt) {
7083 if (!LDST)
7084 return false;
7085
7086 // Only allow byte offsets.
7087 if (ShAmt % 8)
7088 return false;
7089 const unsigned ByteShAmt = ShAmt / 8;
7090
7091 // Do not generate loads of non-round integer types since these can
7092 // be expensive (and would be wrong if the type is not byte sized).
7093 if (!MemVT.isRound())
7094 return false;
7095
7096 // Don't change the width of a volatile or atomic loads.
7097 if (!LDST->isSimple())
7098 return false;
7099
7100 EVT LdStMemVT = LDST->getMemoryVT();
7101
7102 // Bail out when changing the scalable property, since we can't be sure that
7103 // we're actually narrowing here.
7104 if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
7105 return false;
7106
7107 // Verify that we are actually reducing a load width here.
7108 if (LdStMemVT.bitsLT(MemVT))
7109 return false;
7110
7111 // Ensure that this isn't going to produce an unsupported memory access.
7112 if (ShAmt) {
7113 const Align LDSTAlign = LDST->getAlign();
7114 const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
7115 if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
7116 LDST->getAddressSpace(), NarrowAlign,
7117 LDST->getMemOperand()->getFlags()))
7118 return false;
7119 }
7120
7121 // It's not possible to generate a constant of extended or untyped type.
7122 EVT PtrType = LDST->getBasePtr().getValueType();
7123 if (PtrType == MVT::Untyped || PtrType.isExtended())
7124 return false;
7125
7126 if (isa<LoadSDNode>(LDST)) {
7127 LoadSDNode *Load = cast<LoadSDNode>(LDST);
7128 // Don't transform one with multiple uses, this would require adding a new
7129 // load.
7130 if (!SDValue(Load, 0).hasOneUse())
7131 return false;
7132
7133 if (LegalOperations &&
7134 !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
7135 return false;
7136
7137 // For the transform to be legal, the load must produce only two values
7138 // (the value loaded and the chain). Don't transform a pre-increment
7139 // load, for example, which produces an extra value. Otherwise the
7140 // transformation is not equivalent, and the downstream logic to replace
7141 // uses gets things wrong.
7142 if (Load->getNumValues() > 2)
7143 return false;
7144
7145 // If the load that we're shrinking is an extload and we're not just
7146 // discarding the extension we can't simply shrink the load. Bail.
7147 // TODO: It would be possible to merge the extensions in some cases.
7148 if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
7149 Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
7150 return false;
7151
7152 if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT, ByteShAmt))
7153 return false;
7154 } else {
7155 assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
7156 StoreSDNode *Store = cast<StoreSDNode>(LDST);
7157 // Can't write outside the original store
7158 if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
7159 return false;
7160
7161 if (LegalOperations &&
7162 !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
7163 return false;
7164 }
7165 return true;
7166}
7167
/// Recursively walk the operands of \p N (a tree of bitwise logic ops),
/// looking for loads that can be narrowed to the width implied by \p Mask.
/// Narrowable loads are collected in \p Loads, logic nodes whose constant
/// operands may later need re-masking are recorded in \p NodesWithConsts,
/// and at most one other node (\p NodeToMask) is tolerated — it will be
/// explicitly AND-masked by the caller. Returns false if any operand cannot
/// be handled.
bool DAGCombiner::SearchForAndLoads(SDNode *N,
                                    SmallVectorImpl<LoadSDNode*> &Loads,
                                    SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                                    ConstantSDNode *Mask,
                                    SDNode *&NodeToMask) {
  // Recursively search for the operands, looking for loads which can be
  // narrowed.
  for (SDValue Op : N->op_values()) {
    // Only scalar values are handled by this narrowing.
    if (Op.getValueType().isVector())
      return false;

    // Some constants may need fixing up later if they are too large.
    if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
      assert(ISD::isBitwiseLogicOp(N->getOpcode()) &&
             "Expected bitwise logic operation");
      // Remember the parent node so its constant can be re-masked later.
      if (!C->getAPIntValue().isSubsetOf(Mask->getAPIntValue()))
        NodesWithConsts.insert(N);
      continue;
    }

    if (!Op.hasOneUse())
      return false;

    switch(Op.getOpcode()) {
    case ISD::LOAD: {
      auto *Load = cast<LoadSDNode>(Op);
      EVT ExtVT;
      if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
          isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {

        // ZEXTLOAD is already small enough.
        if (Load->getExtensionType() == ISD::ZEXTLOAD &&
            ExtVT.bitsGE(Load->getMemoryVT()))
          continue;

        // Use LE to convert equal sized loads to zext.
        if (ExtVT.bitsLE(Load->getMemoryVT()))
          Loads.push_back(Load);

        continue;
      }
      return false;
    }
    case ISD::ZERO_EXTEND:
    case ISD::AssertZext: {
      unsigned ActiveBits = Mask->getAPIntValue().countr_one();
      EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
      EVT VT = Op.getOpcode() == ISD::AssertZext ?
        cast<VTSDNode>(Op.getOperand(1))->getVT() :
        Op.getOperand(0).getValueType();

      // We can accept extending nodes if the mask is wider or an equal
      // width to the original type.
      if (ExtVT.bitsGE(VT))
        continue;
      // Otherwise fall out of the switch and treat this as the single
      // maskable node.
      break;
    }
    case ISD::OR:
    case ISD::XOR:
    case ISD::AND:
      // Recurse through further bitwise logic.
      if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
                             NodeToMask))
        return false;
      continue;
    }

    // Allow one node which will be masked along with any loads found.
    if (NodeToMask)
      return false;

    // Also ensure that the node to be masked only produces one data result.
    NodeToMask = Op.getNode();
    if (NodeToMask->getNumValues() > 1) {
      bool HasValue = false;
      for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
        MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
        if (VT != MVT::Glue && VT != MVT::Other) {
          if (HasValue) {
            // More than one data result: reject.
            NodeToMask = nullptr;
            return false;
          }
          HasValue = true;
        }
      }
      assert(HasValue && "Node to be masked has no data result?");
    }
  }
  return true;
}
7257
/// Try to push the AND mask of \p N backwards through a tree of logic ops to
/// the loads feeding it, so those loads can be narrowed and the AND removed.
/// Returns true and rewrites the DAG on success.
bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
  auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
  if (!Mask)
    return false;

  // Only contiguous low-bit masks can map onto a narrower load width.
  if (!Mask->getAPIntValue().isMask())
    return false;

  // No need to do anything if the and directly uses a load.
  if (isa<LoadSDNode>(N->getOperand(0)))
    return false;

  // NOTE(review): the declaration of `Loads` (presumably a
  // SmallVector<LoadSDNode *, N>) is missing from this listing — confirm
  // against the full source.
  SmallPtrSet<SDNode*, 2> NodesWithConsts;
  SDNode *FixupNode = nullptr;
  if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
    // Without at least one narrowable load there is nothing to gain.
    if (Loads.empty())
      return false;

    LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
    SDValue MaskOp = N->getOperand(1);

    // If it exists, fixup the single node we allow in the tree that needs
    // masking.
    if (FixupNode) {
      LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
      SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
                                FixupNode->getValueType(0),
                                SDValue(FixupNode, 0), MaskOp);
      DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
      // The RAUW above may have rewritten And's own operand; if And is still
      // an AND node, restore its intended operands.
      if (And.getOpcode() == ISD ::AND)
        DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
    }

    // Narrow any constants that need it.
    for (auto *LogicN : NodesWithConsts) {
      SDValue Op0 = LogicN->getOperand(0);
      SDValue Op1 = LogicN->getOperand(1);

      // We only need to fix AND if both inputs are constants. And we only need
      // to fix one of the constants.
      // NOTE(review): the rest of this condition is missing from this
      // listing — confirm against the full source.
      if (LogicN->getOpcode() == ISD::AND &&
        continue;

      if (isa<ConstantSDNode>(Op0) && LogicN->getOpcode() != ISD::AND)
        Op0 =
            DAG.getNode(ISD::AND, SDLoc(Op0), Op0.getValueType(), Op0, MaskOp);

      if (isa<ConstantSDNode>(Op1))
        Op1 =
            DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(), Op1, MaskOp);

      // Canonicalize any remaining constant into operand 1.
      if (isa<ConstantSDNode>(Op0) && !isa<ConstantSDNode>(Op1))
        std::swap(Op0, Op1);

      DAG.UpdateNodeOperands(LogicN, Op0, Op1);
    }

    // Create narrow loads.
    for (auto *Load : Loads) {
      LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
      SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
                                SDValue(Load, 0), MaskOp);
      DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
      // As above: undo the self-replacement performed by the RAUW.
      if (And.getOpcode() == ISD ::AND)
        And = SDValue(
            DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
      SDValue NewLoad = reduceLoadWidth(And.getNode());
      assert(NewLoad &&
             "Shouldn't be masking the load if it can't be narrowed");
      CombineTo(Load, NewLoad, NewLoad.getValue(1));
    }
    // The mask is now redundant everywhere; drop the original AND.
    DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
    return true;
  }
  return false;
}
7336
// Unfold
//    x &  (-1 'logical shift' y)
// To
//    (x 'opposite logical shift' y) 'logical shift' y
// if it is better for performance.
SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
  assert(N->getOpcode() == ISD::AND);

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);

  // Do we actually prefer shifts over mask?
  // NOTE(review): the condition guarding this early exit (presumably a
  // TargetLowering profitability query) is missing from this listing —
  // confirm against the full source.
    return SDValue();

  // Try to match  (-1 '[outer] logical shift' y)
  unsigned OuterShift;
  unsigned InnerShift; // The opposite direction to the OuterShift.
  SDValue Y;           // Shift amount.
  auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
    if (!M.hasOneUse())
      return false;
    OuterShift = M->getOpcode();
    // Only the two logical shifts are invertible this way.
    if (OuterShift == ISD::SHL)
      InnerShift = ISD::SRL;
    else if (OuterShift == ISD::SRL)
      InnerShift = ISD::SHL;
    else
      return false;
    // The shifted value must be all-ones for M to be a pure edge mask.
    if (!isAllOnesConstant(M->getOperand(0)))
      return false;
    Y = M->getOperand(1);
    return true;
  };

  // The mask may appear on either side of the AND.
  SDValue X;
  if (matchMask(N1))
    X = N0;
  else if (matchMask(N0))
    X = N1;
  else
    return SDValue();

  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  // tmp = x 'opposite logical shift' y
  SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
  // ret = tmp 'logical shift' y
  SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);

  return T1;
}
7390
/// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
/// For a target with a bit test, this is expected to become test + set and save
/// at least 1 instruction.
// NOTE(review): this function's signature line (taking the `And` node and the
// `DAG`, per the uses below) is missing from this listing — confirm against
// the full source.
  assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");

  // Look through an optional extension.
  SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
  if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
    And0 = And0.getOperand(0);
  // Must be ANDing with exactly 1, i.e. a single-bit test.
  if (!isOneConstant(And1) || !And0.hasOneUse())
    return SDValue();

  SDValue Src = And0;

  // Attempt to find a 'not' op.
  // TODO: Should we favor test+set even without the 'not' op?
  bool FoundNot = false;
  if (isBitwiseNot(Src)) {
    FoundNot = true;
    Src = Src.getOperand(0);

    // Look though an optional truncation. The source operand may not be the
    // same type as the original 'and', but that is ok because we are masking
    // off everything but the low bit.
    if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
      Src = Src.getOperand(0);
  }

  // Match a shift-right by constant.
  if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
    return SDValue();

  // This is probably not worthwhile without a supported type.
  EVT SrcVT = Src.getValueType();
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.isTypeLegal(SrcVT))
    return SDValue();

  // We might have looked through casts that make this transform invalid.
  // The shift amount must be a constant strictly less than the bit width.
  unsigned BitWidth = SrcVT.getScalarSizeInBits();
  SDValue ShiftAmt = Src.getOperand(1);
  auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
  if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(BitWidth))
    return SDValue();

  // Set source to shift source.
  Src = Src.getOperand(0);

  // Try again to find a 'not' op.
  // TODO: Should we favor test+set even with two 'not' ops?
  if (!FoundNot) {
    if (!isBitwiseNot(Src))
      return SDValue();
    Src = Src.getOperand(0);
  }

  // Only worthwhile if the target has a native bit-test for this operand.
  if (!TLI.hasBitTest(Src, ShiftAmt))
    return SDValue();

  // Turn this into a bit-test pattern using mask op + setcc:
  // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
  // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0
  SDLoc DL(And);
  SDValue X = DAG.getZExtOrTrunc(Src, DL, SrcVT);
  EVT CCVT =
      TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
  SDValue Mask = DAG.getConstant(
      APInt::getOneBitSet(BitWidth, ShiftAmtC->getZExtValue()), DL, SrcVT);
  SDValue NewAnd = DAG.getNode(ISD::AND, DL, SrcVT, X, Mask);
  SDValue Zero = DAG.getConstant(0, DL, SrcVT);
  SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
  return DAG.getZExtOrTrunc(Setcc, DL, And->getValueType(0));
}
7465
/// For targets that support usubsat, match a bit-hack form of that operation
/// that ends in 'and' and convert it.
// NOTE(review): this function's signature line (taking the node `N` plus the
// `DAG`/`DL` used below) is missing from this listing — confirm against the
// full source.
  EVT VT = N->getValueType(0);
  unsigned BitWidth = VT.getScalarSizeInBits();
  APInt SignMask = APInt::getSignMask(BitWidth);

  // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
  // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
  // xor/add with SMIN (signmask) are logically equivalent.
  SDValue X;
  // NOTE(review): parts of this sd_match pattern (the sra operand lines and
  // the add-based alternative) are missing from this listing — confirm
  // against the full source.
  if (!sd_match(N, m_And(m_OneUse(m_Xor(m_Value(X), m_SpecificInt(SignMask))),
                         m_SpecificInt(BitWidth - 1))))) &&
                         m_SpecificInt(BitWidth - 1))))))
    return SDValue();

  return DAG.getNode(ISD::USUBSAT, DL, VT, X,
                     DAG.getConstant(SignMask, DL, VT));
}
7488
/// Given a bitwise logic operation N with a matching bitwise logic operand,
/// fold a pattern where 2 of the source operands are identically shifted
/// values. For example:
/// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
// NOTE(review): this function's signature line (the N/LogicOp/ShiftOp
// parameters, per the uses below) is missing from this listing — confirm
// against the full source.
                              SelectionDAG &DAG) {
  unsigned LogicOpcode = N->getOpcode();
  assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
         "Expected bitwise logic operation");

  // Both intermediate nodes disappear in the fold, so they must be
  // single-use.
  if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
    return SDValue();

  // Match another bitwise logic op and a shift.
  unsigned ShiftOpcode = ShiftOp.getOpcode();
  if (LogicOp.getOpcode() != LogicOpcode ||
      !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
        ShiftOpcode == ISD::SRA))
    return SDValue();

  // Match another shift op inside the first logic operand. Handle both commuted
  // possibilities.
  // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
  // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
  SDValue X1 = ShiftOp.getOperand(0);
  SDValue Y = ShiftOp.getOperand(1);
  SDValue X0, Z;
  if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
      LogicOp.getOperand(0).getOperand(1) == Y) {
    X0 = LogicOp.getOperand(0).getOperand(0);
    Z = LogicOp.getOperand(1);
  } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
             LogicOp.getOperand(1).getOperand(1) == Y) {
    X0 = LogicOp.getOperand(1).getOperand(0);
    Z = LogicOp.getOperand(0);
  } else {
    return SDValue();
  }

  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  // Combine the identically-shifted values first, then apply the common
  // shift once.
  SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
  SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
  return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
}
7534
/// Given a tree of logic operations with shape like
/// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
/// try to match and fold shift operations with the same shift amount.
/// For example:
/// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
/// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
// NOTE(review): this function's signature line (the N/LeftHand parameters,
// per the uses below) is missing from this listing — confirm against the
// full source.
                                   SDValue RightHand, SelectionDAG &DAG) {
  unsigned LogicOpcode = N->getOpcode();
  assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
         "Expected bitwise logic operation");
  // Both sub-trees must use the same logic opcode as the root.
  if (LeftHand.getOpcode() != LogicOpcode ||
      RightHand.getOpcode() != LogicOpcode)
    return SDValue();
  // Both sub-trees are replaced by the fold, so they must be single-use.
  if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
    return SDValue();

  // Try to match one of following patterns:
  // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
  // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
  // Note that foldLogicOfShifts will handle commuted versions of the left hand
  // itself.
  SDValue CombinedShifts, W;
  SDValue R0 = RightHand.getOperand(0);
  SDValue R1 = RightHand.getOperand(1);
  if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R0, DAG)))
    W = R1;
  else if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R1, DAG)))
    W = R0;
  else
    return SDValue();

  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  return DAG.getNode(LogicOpcode, DL, VT, CombinedShifts, W);
}
7571
/// Fold "masked merge" expressions like `(m & x) | (~m & y)` and its DeMorgan
/// variant `(~m | x) & (m | y)` into the equivalent `((x ^ y) & m) ^ y)`
/// pattern. This is typically a better representation for targets without a
/// fused "and-not" operation.
// NOTE(review): this function's signature line (taking the `Node` plus the
// `DAG`, with the TLI/DL parameters continued below) is missing from this
// listing — confirm against the full source.
                               const TargetLowering &TLI, const SDLoc &DL) {
  // Note that masked-merge variants using XOR or ADD expressions are
  // normalized to OR by InstCombine so we only check for OR or AND.
  assert((Node->getOpcode() == ISD::OR || Node->getOpcode() == ISD::AND) &&
         "Must be called with ISD::OR or ISD::AND node");

  // If the target supports and-not, don't fold this.
  if (TLI.hasAndNot(SDValue(Node, 0)))
    return SDValue();

  SDValue M, X, Y;

  // NOTE(review): the first operand of each sd_match pattern (matching the
  // inverted-mask side and binding M together with Y resp. X) is missing
  // from this listing — confirm against the full source.
  if (sd_match(Node,
               m_OneUse(m_And(m_Deferred(M), m_Value(X))))) ||
      sd_match(Node,
               m_OneUse(m_Or(m_Deferred(M), m_Value(Y)))))) {
    EVT VT = M.getValueType();
    // ((x ^ y) & m) ^ y yields x where m is set and y where m is clear.
    SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, X, Y);
    SDValue And = DAG.getNode(ISD::AND, DL, VT, Xor, M);
    return DAG.getNode(ISD::XOR, DL, VT, And, Y);
  }
  return SDValue();
}
7602
7603SDValue DAGCombiner::visitAND(SDNode *N) {
7604 SDValue N0 = N->getOperand(0);
7605 SDValue N1 = N->getOperand(1);
7606 EVT VT = N1.getValueType();
7607 SDLoc DL(N);
7608
7609 // x & x --> x
7610 if (N0 == N1)
7611 return N0;
7612
7613 // fold (and c1, c2) -> c1&c2
7614 if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, DL, VT, {N0, N1}))
7615 return C;
7616
7617 // canonicalize constant to RHS
7620 return DAG.getNode(ISD::AND, DL, VT, N1, N0);
7621
7622 if (areBitwiseNotOfEachother(N0, N1))
7623 return DAG.getConstant(APInt::getZero(VT.getScalarSizeInBits()), DL, VT);
7624
7625 // fold vector ops
7626 if (VT.isVector()) {
7627 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
7628 return FoldedVOp;
7629
7630 // fold (and x, 0) -> 0, vector edition
7632 // do not return N1, because undef node may exist in N1
7634 N1.getValueType());
7635
7636 // fold (and x, -1) -> x, vector edition
7638 return N0;
7639
7640 // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
7641 bool Frozen = N0.getOpcode() == ISD::FREEZE;
7642 auto *MLoad = dyn_cast<MaskedLoadSDNode>(Frozen ? N0.getOperand(0) : N0);
7643 ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
7644 if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
7645 EVT MemVT = MLoad->getMemoryVT();
7646 if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) {
7647 // For this AND to be a zero extension of the masked load the elements
7648 // of the BuildVec must mask the bottom bits of the extended element
7649 // type
7650 if (Splat->getAPIntValue().isMask(MemVT.getScalarSizeInBits())) {
7651 SDValue NewLoad = DAG.getMaskedLoad(
7652 VT, DL, MLoad->getChain(), MLoad->getBasePtr(),
7653 MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(), MemVT,
7654 MLoad->getMemOperand(), MLoad->getAddressingMode(), ISD::ZEXTLOAD,
7655 MLoad->isExpandingLoad());
7656 CombineTo(N, Frozen ? N0 : NewLoad);
7657 CombineTo(MLoad, NewLoad, NewLoad.getValue(1));
7658 return SDValue(N, 0);
7659 }
7660 }
7661 }
7662 }
7663
7664 // fold (and x, -1) -> x
7665 if (isAllOnesConstant(N1))
7666 return N0;
7667
7668 // if (and x, c) is known to be zero, return 0
7669 unsigned BitWidth = VT.getScalarSizeInBits();
7670 ConstantSDNode *N1C = isConstOrConstSplat(N1);
7672 return DAG.getConstant(0, DL, VT);
7673
7674 if (SDValue R = foldAndOrOfSETCC(N, DAG))
7675 return R;
7676
7677 if (SDValue NewSel = foldBinOpIntoSelect(N))
7678 return NewSel;
7679
7680 // reassociate and
7681 if (SDValue RAND = reassociateOps(ISD::AND, DL, N0, N1, N->getFlags()))
7682 return RAND;
7683
7684 // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
7685 if (SDValue SD =
7686 reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, DL, VT, N0, N1))
7687 return SD;
7688
7689 // fold (and (or x, C), D) -> D if (C & D) == D
7690 auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7691 return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
7692 };
7693 if (N0.getOpcode() == ISD::OR &&
7694 ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
7695 return N1;
7696
7697 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
7698 SDValue N0Op0 = N0.getOperand(0);
7699 EVT SrcVT = N0Op0.getValueType();
7700 unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
7701 APInt Mask = ~N1C->getAPIntValue();
7702 Mask = Mask.trunc(SrcBitWidth);
7703
7704 // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
7705 if (DAG.MaskedValueIsZero(N0Op0, Mask))
7706 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0Op0);
7707
7708 // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
7709 if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
7710 TLI.isTruncateFree(VT, SrcVT) && TLI.isZExtFree(SrcVT, VT) &&
7711 TLI.isTypeDesirableForOp(ISD::AND, SrcVT) &&
7712 TLI.isNarrowingProfitable(N, VT, SrcVT))
7713 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
7714 DAG.getNode(ISD::AND, DL, SrcVT, N0Op0,
7715 DAG.getZExtOrTrunc(N1, DL, SrcVT)));
7716 }
7717
7718 // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
7719 if (ISD::isExtOpcode(N0.getOpcode())) {
7720 unsigned ExtOpc = N0.getOpcode();
7721 SDValue N0Op0 = N0.getOperand(0);
7722 if (N0Op0.getOpcode() == ISD::AND &&
7723 (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0Op0, VT)) &&
7724 N0->hasOneUse() && N0Op0->hasOneUse()) {
7725 if (SDValue NewExt = DAG.FoldConstantArithmetic(ExtOpc, DL, VT,
7726 {N0Op0.getOperand(1)})) {
7727 if (SDValue NewMask =
7728 DAG.FoldConstantArithmetic(ISD::AND, DL, VT, {N1, NewExt})) {
7729 return DAG.getNode(ISD::AND, DL, VT,
7730 DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(0)),
7731 NewMask);
7732 }
7733 }
7734 }
7735 }
7736
7737 // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
7738 // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
7739 // already be zero by virtue of the width of the base type of the load.
7740 //
7741 // the 'X' node here can either be nothing or an extract_vector_elt to catch
7742 // more cases.
7743 if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7745 N0.getOperand(0).getOpcode() == ISD::LOAD &&
7746 N0.getOperand(0).getResNo() == 0) ||
7747 (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
7748 auto *Load =
7749 cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));
7750
7751 // Get the constant (if applicable) the zero'th operand is being ANDed with.
7752 // This can be a pure constant or a vector splat, in which case we treat the
7753 // vector as a scalar and use the splat value.
7754 APInt Constant = APInt::getZero(1);
7755 if (const ConstantSDNode *C = isConstOrConstSplat(
7756 N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
7757 Constant = C->getAPIntValue();
7758 } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
7759 unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
7760 APInt SplatValue, SplatUndef;
7761 unsigned SplatBitSize;
7762 bool HasAnyUndefs;
7763 // Endianness should not matter here. Code below makes sure that we only
7764 // use the result if the SplatBitSize is a multiple of the vector element
7765 // size. And after that we AND all element sized parts of the splat
7766 // together. So the end result should be the same regardless of in which
7767 // order we do those operations.
7768 const bool IsBigEndian = false;
7769 bool IsSplat =
7770 Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
7771 HasAnyUndefs, EltBitWidth, IsBigEndian);
7772
7773 // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
7774 // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
7775 if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
7776 // Undef bits can contribute to a possible optimisation if set, so
7777 // set them.
7778 SplatValue |= SplatUndef;
7779
7780 // The splat value may be something like "0x00FFFFFF", which means 0 for
7781 // the first vector value and FF for the rest, repeating. We need a mask
7782 // that will apply equally to all members of the vector, so AND all the
7783 // lanes of the constant together.
7784 Constant = APInt::getAllOnes(EltBitWidth);
7785 for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7786 Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
7787 }
7788 }
7789
7790 // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7791 // actually legal and isn't going to get expanded, else this is a false
7792 // optimisation.
7793 bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
7794 Load->getValueType(0),
7795 Load->getMemoryVT());
7796
7797 // Resize the constant to the same size as the original memory access before
7798 // extension. If it is still the AllOnesValue then this AND is completely
7799 // unneeded.
7800 Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
7801
7802 bool B;
7803 switch (Load->getExtensionType()) {
7804 default: B = false; break;
7805 case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7806 case ISD::ZEXTLOAD:
7807 case ISD::NON_EXTLOAD: B = true; break;
7808 }
7809
7810 if (B && Constant.isAllOnes()) {
7811 // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7812 // preserve semantics once we get rid of the AND.
7813 SDValue NewLoad(Load, 0);
7814
7815 // Fold the AND away. NewLoad may get replaced immediately.
7816 CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
7817
7818 if (Load->getExtensionType() == ISD::EXTLOAD) {
7819 NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
7820 Load->getValueType(0), SDLoc(Load),
7821 Load->getChain(), Load->getBasePtr(),
7822 Load->getOffset(), Load->getMemoryVT(),
7823 Load->getMemOperand());
7824 // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7825 if (Load->getNumValues() == 3) {
7826 // PRE/POST_INC loads have 3 values.
7827 SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
7828 NewLoad.getValue(2) };
7829 CombineTo(Load, To, 3, true);
7830 } else {
7831 CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
7832 }
7833 }
7834
7835 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7836 }
7837 }
7838
7839 // Try to convert a constant mask AND into a shuffle clear mask.
7840 if (VT.isVector())
7841 if (SDValue Shuffle = XformToShuffleWithZero(N))
7842 return Shuffle;
7843
7844 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7845 return Combined;
7846
7847 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7849 SDValue Ext = N0.getOperand(0);
7850 EVT ExtVT = Ext->getValueType(0);
7851 SDValue Extendee = Ext->getOperand(0);
7852
7853 unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7854 if (N1C->getAPIntValue().isMask(ScalarWidth) &&
7855 (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
7856 // (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7857 // => (extract_subvector (iN_zeroext v))
7858 SDValue ZeroExtExtendee =
7859 DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVT, Extendee);
7860
7861 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ZeroExtExtendee,
7862 N0.getOperand(1));
7863 }
7864 }
7865
7866 // fold (and (masked_gather x)) -> (zext_masked_gather x)
7867 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
7868 EVT MemVT = GN0->getMemoryVT();
7869 EVT ScalarVT = MemVT.getScalarType();
7870
7871 if (SDValue(GN0, 0).hasOneUse() &&
7872 isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
7874 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
7875 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
7876
7877 SDValue ZExtLoad = DAG.getMaskedGather(
7878 DAG.getVTList(VT, MVT::Other), MemVT, DL, Ops, GN0->getMemOperand(),
7879 GN0->getIndexType(), ISD::ZEXTLOAD);
7880
7881 CombineTo(N, ZExtLoad);
7882 AddToWorklist(ZExtLoad.getNode());
7883 // Avoid recheck of N.
7884 return SDValue(N, 0);
7885 }
7886 }
7887
7888 // fold (and (load x), 255) -> (zextload x, i8)
7889 // fold (and (extload x, i16), 255) -> (zextload x, i8)
7890 if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7891 if (SDValue Res = reduceLoadWidth(N))
7892 return Res;
7893
7894 if (LegalTypes) {
7895 // Attempt to propagate the AND back up to the leaves which, if they're
7896 // loads, can be combined to narrow loads and the AND node can be removed.
7897 // Perform after legalization so that extend nodes will already be
7898 // combined into the loads.
7899 if (BackwardsPropagateMask(N))
7900 return SDValue(N, 0);
7901 }
7902
7903 if (SDValue Combined = visitANDLike(N0, N1, N))
7904 return Combined;
7905
7906 // Simplify: (and (op x...), (op y...)) -> (op (and x, y))
7907 if (N0.getOpcode() == N1.getOpcode())
7908 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7909 return V;
7910
7911 if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
7912 return R;
7913 if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
7914 return R;
7915
7916 // Fold (and X, (bswap (not Y))) -> (and X, (not (bswap Y)))
7917 // Fold (and X, (bitreverse (not Y))) -> (and X, (not (bitreverse Y)))
7918 SDValue X, Y, Z, NotY;
7919 for (unsigned Opc : {ISD::BSWAP, ISD::BITREVERSE})
7920 if (sd_match(N,
7921 m_And(m_Value(X), m_OneUse(m_UnaryOp(Opc, m_Value(NotY))))) &&
7922 sd_match(NotY, m_Not(m_Value(Y))) &&
7923 (TLI.hasAndNot(SDValue(N, 0)) || NotY->hasOneUse()))
7924 return DAG.getNode(ISD::AND, DL, VT, X,
7925 DAG.getNOT(DL, DAG.getNode(Opc, DL, VT, Y), VT));
7926
7927 // Fold (and X, (rot (not Y), Z)) -> (and X, (not (rot Y, Z)))
7928 for (unsigned Opc : {ISD::ROTL, ISD::ROTR})
7929 if (sd_match(N, m_And(m_Value(X),
7930 m_OneUse(m_BinOp(Opc, m_Value(NotY), m_Value(Z))))) &&
7931 sd_match(NotY, m_Not(m_Value(Y))) &&
7932 (TLI.hasAndNot(SDValue(N, 0)) || NotY->hasOneUse()))
7933 return DAG.getNode(ISD::AND, DL, VT, X,
7934 DAG.getNOT(DL, DAG.getNode(Opc, DL, VT, Y, Z), VT));
7935
7936 // Fold (and X, (add (not Y), Z)) -> (and X, (not (sub Y, Z)))
7937 // Fold (and X, (sub (not Y), Z)) -> (and X, (not (add Y, Z)))
7938 if (TLI.hasAndNot(SDValue(N, 0)))
7939 if (SDValue Folded = foldBitwiseOpWithNeg(N, DL, VT))
7940 return Folded;
7941
7942 // Fold (and (srl X, C), 1) -> (srl X, BW-1) for signbit extraction
7943 // If we are shifting down an extended sign bit, see if we can simplify
7944 // this to shifting the MSB directly to expose further simplifications.
7945 // This pattern often appears after sext_inreg legalization.
7946 APInt Amt;
7947 if (sd_match(N, m_And(m_Srl(m_Value(X), m_ConstInt(Amt)), m_One())) &&
7948 Amt.ult(BitWidth - 1) && Amt.uge(BitWidth - DAG.ComputeNumSignBits(X)))
7949 return DAG.getNode(ISD::SRL, DL, VT, X,
7950 DAG.getShiftAmountConstant(BitWidth - 1, VT, DL));
7951
7952 // Masking the negated extension of a boolean is just the zero-extended
7953 // boolean:
7954 // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7955 // and (sub 0, sext(bool X)), 1 --> zext(bool X)
7956 //
7957 // Note: the SimplifyDemandedBits fold below can make an information-losing
7958 // transform, and then we have no way to find this better fold.
7959 if (sd_match(N, m_And(m_Sub(m_Zero(), m_Value(X)), m_One()))) {
7960 if (X.getOpcode() == ISD::ZERO_EXTEND &&
7961 X.getOperand(0).getScalarValueSizeInBits() == 1)
7962 return X;
7963 if (X.getOpcode() == ISD::SIGN_EXTEND &&
7964 X.getOperand(0).getScalarValueSizeInBits() == 1)
7965 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, X.getOperand(0));
7966 }
7967
7968 // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7969 // fold (and (sra)) -> (and (srl)) when possible.
7971 return SDValue(N, 0);
7972
7973 // fold (zext_inreg (extload x)) -> (zextload x)
7974 // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
7975 if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
7976 (ISD::isEXTLoad(N0.getNode()) ||
7977 (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
7978 auto *LN0 = cast<LoadSDNode>(N0);
7979 EVT MemVT = LN0->getMemoryVT();
7980 // If we zero all the possible extended bits, then we can turn this into
7981 // a zextload if we are running before legalize or the operation is legal.
7982 unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7983 unsigned MemBitSize = MemVT.getScalarSizeInBits();
7984 APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
7985 if (DAG.MaskedValueIsZero(N1, ExtBits) &&
7986 ((!LegalOperations && LN0->isSimple()) ||
7987 TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
7988 SDValue ExtLoad =
7989 DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
7990 LN0->getBasePtr(), MemVT, LN0->getMemOperand());
7991 AddToWorklist(N);
7992 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
7993 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7994 }
7995 }
7996
7997 // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7998 if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7999 if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
8000 N0.getOperand(1), false))
8001 return BSwap;
8002 }
8003
8004 if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
8005 return Shifts;
8006
8007 if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
8008 return V;
8009
8010 // Recognize the following pattern:
8011 //
8012 // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
8013 //
8014 // where bitmask is a mask that clears the upper bits of AndVT. The
8015 // number of bits in bitmask must be a power of two.
8016 auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
8017 if (LHS->getOpcode() != ISD::SIGN_EXTEND)
8018 return false;
8019
8021 if (!C)
8022 return false;
8023
8024 if (!C->getAPIntValue().isMask(
8025 LHS.getOperand(0).getValueType().getFixedSizeInBits()))
8026 return false;
8027
8028 return true;
8029 };
8030
8031 // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
8032 if (IsAndZeroExtMask(N0, N1))
8033 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
8034
8035 if (hasOperation(ISD::USUBSAT, VT))
8036 if (SDValue V = foldAndToUsubsat(N, DAG, DL))
8037 return V;
8038
8039 // Postpone until legalization completed to avoid interference with bswap
8040 // folding
8041 if (LegalOperations || VT.isVector())
8042 if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
8043 return R;
8044
8045 if (VT.isScalarInteger() && VT != MVT::i1)
8046 if (SDValue R = foldMaskedMerge(N, DAG, TLI, DL))
8047 return R;
8048
8049 return SDValue();
8050}
8051
8052/// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
// Rewrites the low-halfword byteswap idiom on i16/i32/i64 into a BSWAP whose
// result is shifted right so the swapped halfword lands in the low 16 bits.
// DemandHighBits tells us whether callers care about bits above the low
// halfword, which relaxes the masking requirements checked below.
8053SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
8054                                        bool DemandHighBits) {
8055  if (!LegalOperations)
8056    return SDValue();
8057
8058  EVT VT = N->getValueType(0);
8059  if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
8060    return SDValue();
  // NOTE(review): source extraction dropped original line 8061 here — the
  // condition guarding the early return below (presumably a BSWAP legality
  // check); verify against upstream LLVM before editing this spot.
8062    return SDValue();
8063
8064  // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
  // Canonicalize so an AND-of-SHL ends up in N0 and an AND-of-SRL in N1;
  // LookPassAnd* remember that a mask was already looked through on that side.
8065  bool LookPassAnd0 = false;
8066  bool LookPassAnd1 = false;
8067  if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
8068    std::swap(N0, N1);
8069  if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
8070    std::swap(N0, N1);
8071  if (N0.getOpcode() == ISD::AND) {
8072    if (!N0->hasOneUse())
8073      return SDValue();
8074    ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
8075    // Also handle 0xffff since the LHS is guaranteed to have zeros there.
8076    // This is needed for X86.
8077    if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
8078                  N01C->getZExtValue() != 0xFFFF))
8079      return SDValue();
8080    N0 = N0.getOperand(0);
8081    LookPassAnd0 = true;
8082  }
8083
8084  if (N1.getOpcode() == ISD::AND) {
8085    if (!N1->hasOneUse())
8086      return SDValue();
8087    ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
8088    if (!N11C || N11C->getZExtValue() != 0xFF)
8089      return SDValue();
8090    N1 = N1.getOperand(0);
8091    LookPassAnd1 = true;
8092  }
8093
8094  if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
8095    std::swap(N0, N1);
8096  if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
8097    return SDValue();
8098  if (!N0->hasOneUse() || !N1->hasOneUse())
8099    return SDValue();
8100
  // Both shift amounts must be exactly 8 for this to be a halfword swap.
8101  ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
8102  ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
8103  if (!N01C || !N11C)
8104    return SDValue();
8105  if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
8106    return SDValue();
8107
8108  // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
8109  SDValue N00 = N0->getOperand(0);
8110  if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
8111    if (!N00->hasOneUse())
8112      return SDValue();
8113    ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
8114    if (!N001C || N001C->getZExtValue() != 0xFF)
8115      return SDValue();
8116    N00 = N00.getOperand(0);
8117    LookPassAnd0 = true;
8118  }
8119
8120  SDValue N10 = N1->getOperand(0);
8121  if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
8122    if (!N10->hasOneUse())
8123      return SDValue();
8124    ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
8125    // Also allow 0xFFFF since the bits will be shifted out. This is needed
8126    // for X86.
8127    if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
8128                   N101C->getZExtValue() != 0xFFFF))
8129      return SDValue();
8130    N10 = N10.getOperand(0);
8131    LookPassAnd1 = true;
8132  }
8133
  // Both halves must originate from the same value 'a'.
8134  if (N00 != N10)
8135    return SDValue();
8136
8137  // Make sure everything beyond the low halfword gets set to zero since the SRL
8138  // 16 will clear the top bits.
8139  unsigned OpSizeInBits = VT.getSizeInBits();
8140  if (OpSizeInBits > 16) {
8141    // If the left-shift isn't masked out then the only way this is a bswap is
8142    // if all bits beyond the low 8 are 0. In that case the entire pattern
8143    // reduces to a left shift anyway: leave it for other parts of the combiner.
8144    if (DemandHighBits && !LookPassAnd0)
8145      return SDValue();
8146
8147    // However, if the right shift isn't masked out then it might be because
8148    // it's not needed. See if we can spot that too. If the high bits aren't
8149    // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
8150    // upper bits to be zero.
8151    if (!LookPassAnd1) {
8152      unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
8153      if (!DAG.MaskedValueIsZero(N10,
8154                                 APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
8155        return SDValue();
8156    }
8157  }
8158
8159  SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
8160  if (OpSizeInBits > 16) {
8161    SDLoc DL(N);
8162    Res = DAG.getNode(ISD::SRL, DL, VT, Res,
8163                      DAG.getShiftAmountConstant(OpSizeInBits - 16, VT, DL));
8164  }
8165  return Res;
8166}
8167
8168/// Return true if the specified node is an element that makes up a 32-bit
8169/// packed halfword byteswap.
8170/// ((x & 0x000000ff) << 8) |
8171/// ((x & 0x0000ff00) >> 8) |
8172/// ((x & 0x00ff0000) << 8) |
8173/// ((x & 0xff000000) >> 8)
// On success records the source node of the matched byte in Parts[byte index].
// NOTE(review): source extraction dropped original line 8174 here (the
// function signature of isBSwapHWordElement, taking the value N and the
// Parts array); verify against upstream LLVM.
8175  if (!N->hasOneUse())
8176    return false;
8177
  // Each element is an AND/SHL/SRL whose operand is itself one of the three.
8178  unsigned Opc = N.getOpcode();
8179  if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
8180    return false;
8181
8182  SDValue N0 = N.getOperand(0);
8183  unsigned Opc0 = N0.getOpcode();
8184  if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
8185    return false;
8186
8187  ConstantSDNode *N1C = nullptr;
8188  // SHL or SRL: look upstream for AND mask operand
8189  if (Opc == ISD::AND)
8190    N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
8191  else if (Opc0 == ISD::AND)
  // NOTE(review): source extraction dropped original line 8192 here (the
  // assignment to N1C from the inner AND's mask operand); verify against
  // upstream LLVM.
8193  if (!N1C)
8194    return false;
8195
  // Map the mask constant to the byte index of the 32-bit value it selects.
8196  unsigned MaskByteOffset;
8197  switch (N1C->getZExtValue()) {
8198  default:
8199    return false;
8200  case 0xFF:       MaskByteOffset = 0; break;
8201  case 0xFF00:     MaskByteOffset = 1; break;
8202  case 0xFFFF:
8203    // In case demanded bits didn't clear the bits that will be shifted out.
8204    // This is needed for X86.
8205    if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
8206      MaskByteOffset = 1;
8207      break;
8208    }
8209    return false;
8210  case 0xFF0000:   MaskByteOffset = 2; break;
8211  case 0xFF000000: MaskByteOffset = 3; break;
8212  }
8213
8214  // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
8215  if (Opc == ISD::AND) {
8216    if (MaskByteOffset == 0 || MaskByteOffset == 2) {
8217      // (x >> 8) & 0xff
8218      // (x >> 8) & 0xff0000
8219      if (Opc0 != ISD::SRL)
8220        return false;
      // NOTE(review): source extraction dropped original line 8221 here (the
      // declaration of the shift-amount constant C for the inner SRL); verify
      // against upstream LLVM.
8222      if (!C || C->getZExtValue() != 8)
8223        return false;
8224    } else {
8225      // (x << 8) & 0xff00
8226      // (x << 8) & 0xff000000
8227      if (Opc0 != ISD::SHL)
8228        return false;
      // NOTE(review): source extraction dropped original line 8229 here (the
      // declaration of the shift-amount constant C for the inner SHL); verify
      // against upstream LLVM.
8230      if (!C || C->getZExtValue() != 8)
8231        return false;
8232    }
8233  } else if (Opc == ISD::SHL) {
8234    // (x & 0xff) << 8
8235    // (x & 0xff0000) << 8
8236    if (MaskByteOffset != 0 && MaskByteOffset != 2)
8237      return false;
8238    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
8239    if (!C || C->getZExtValue() != 8)
8240      return false;
8241  } else { // Opc == ISD::SRL
8242    // (x & 0xff00) >> 8
8243    // (x & 0xff000000) >> 8
8244    if (MaskByteOffset != 1 && MaskByteOffset != 3)
8245      return false;
8246    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
8247    if (!C || C->getZExtValue() != 8)
8248      return false;
8249  }
8250
  // Each byte slot may only be claimed once across the whole pattern.
8251  if (Parts[MaskByteOffset])
8252    return false;
8253
8254  Parts[MaskByteOffset] = N0.getOperand(0).getNode();
8255  return true;
8256}
8257
8258// Match 2 elements of a packed halfword bswap.
// NOTE(review): source extraction dropped original line 8259 here (the
// signature of isBSwapHWordPair, taking the value N and the Parts array);
// verify against upstream LLVM.
8260  if (N.getOpcode() == ISD::OR)
8261    return isBSwapHWordElement(N.getOperand(0), Parts) &&
8262           isBSwapHWordElement(N.getOperand(1), Parts);
8263
  // Also accept (srl (bswap x), 16), which supplies both low-half bytes at
  // once: record x in both Parts[0] and Parts[1].
8264  if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
8265    ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
8266    if (!C || C->getAPIntValue() != 16)
8267      return false;
8268    Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
8269    return true;
8270  }
8271
8272  return false;
8273}
8274
8275// Match this pattern:
8276// (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
8277// And rewrite this to:
8278// (rotr (bswap A), 16)
// Only fires when ROTR is legal or custom for VT; requires i32.
// NOTE(review): source extraction dropped original line 8279 here (the first
// line of the matchBSwapHWordOrAndAnd signature, which also takes the
// TargetLowering); verify against upstream LLVM.
8280                                       SelectionDAG &DAG, SDNode *N, SDValue N0,
8281                                       SDValue N1, EVT VT) {
8282  assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
8283         "MatchBSwapHWordOrAndAnd: expecting i32");
8284  if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
8285    return SDValue();
8286  if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
8287    return SDValue();
8288  // TODO: this is too restrictive; lifting this restriction requires more tests
8289  if (!N0->hasOneUse() || !N1->hasOneUse())
8290    return SDValue();
  // NOTE(review): source extraction dropped original lines 8291-8292 here
  // (the constant-splat extraction initializing Mask0/Mask1 from the AND mask
  // operands); verify against upstream LLVM.
8293  if (!Mask0 || !Mask1)
8294    return SDValue();
  // The two masks must select exactly the alternating byte pairs.
8295  if (Mask0->getAPIntValue() != 0xff00ff00 ||
8296      Mask1->getAPIntValue() != 0x00ff00ff)
8297    return SDValue();
8298  SDValue Shift0 = N0.getOperand(0);
8299  SDValue Shift1 = N1.getOperand(0);
8300  if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
8301    return SDValue();
8302  ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
8303  ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
8304  if (!ShiftAmt0 || !ShiftAmt1)
8305    return SDValue();
8306  if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
8307    return SDValue();
  // Both shifts must read the same source value A.
8308  if (Shift0.getOperand(0) != Shift1.getOperand(0))
8309    return SDValue();
8310
8311  SDLoc DL(N);
8312  SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
8313  SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
8314  return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
8315}
8316
8317/// Match a 32-bit packed halfword bswap. That is
8318/// ((x & 0x000000ff) << 8) |
8319/// ((x & 0x0000ff00) >> 8) |
8320/// ((x & 0x00ff0000) << 8) |
8321/// ((x & 0xff000000) >> 8)
8322/// => (rotl (bswap x), 16)
// Tries the and/and form first (both operand orders), then assembles the
// four byte elements via isBSwapHWordPair/isBSwapHWordElement.
8323SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
8324  if (!LegalOperations)
8325    return SDValue();
8326
8327  EVT VT = N->getValueType(0);
8328  if (VT != MVT::i32)
8329    return SDValue();
  // NOTE(review): source extraction dropped original line 8330 here — the
  // condition guarding the early return below (presumably a BSWAP legality
  // check); verify against upstream LLVM.
8331    return SDValue();
8332
8333  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
8334    return BSwap;
8335
8336  // Try again with commuted operands.
8337  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT))
8338    return BSwap;
8339
8340
8341  // Look for either
8342  // (or (bswaphpair), (bswaphpair))
8343  // (or (or (bswaphpair), (and)), (and))
8344  // (or (or (and), (bswaphpair)), (and))
8345  SDNode *Parts[4] = {};
8346
8347  if (isBSwapHWordPair(N0, Parts)) {
8348    // (or (or (and), (and)), (or (and), (and)))
8349    if (!isBSwapHWordPair(N1, Parts))
8350      return SDValue();
8351  } else if (N0.getOpcode() == ISD::OR) {
8352    // (or (or (or (and), (and)), (and)), (and))
8353    if (!isBSwapHWordElement(N1, Parts))
8354      return SDValue();
8355    SDValue N00 = N0.getOperand(0);
8356    SDValue N01 = N0.getOperand(1);
8357    if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
8358        !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
8359      return SDValue();
8360  } else {
8361    return SDValue();
8362  }
8363
8364  // Make sure the parts are all coming from the same node.
8365  if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
8366    return SDValue();
8367
8368  SDLoc DL(N);
8369  SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
8370                              SDValue(Parts[0], 0));
8371
8372  // Result of the bswap should be rotated by 16. If it's not legal, then
8373  // do (x << 16) | (x >> 16).
8374  SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
  // NOTE(review): source extraction dropped original line 8375 here (the
  // guard for the ROTL form below); verify against upstream LLVM.
8376    return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
  // NOTE(review): source extraction dropped original line 8377 here (the
  // guard for the ROTR form below); verify against upstream LLVM.
8378    return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
8379  return DAG.getNode(ISD::OR, DL, VT,
8380                     DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
8381                     DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
8382}
8383
8384/// This contains all DAGCombine rules which reduce two values combined by
8385/// an Or operation to a single value \see visitANDLike().
8386SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, const SDLoc &DL) {
8387  EVT VT = N1.getValueType();
8388
8389  // fold (or x, undef) -> -1
8390  if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
8391    return DAG.getAllOnesConstant(DL, VT);
8392
8393  if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
8394    return V;
8395
8396  // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
8397  if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
8398      // Don't increase # computations.
8399      (N0->hasOneUse() || N1->hasOneUse())) {
8400    // We can only do this xform if we know that bits from X that are set in C2
8401    // but not in C1 are already zero. Likewise for Y.
8402    if (const ConstantSDNode *N0O1C =
    // NOTE(review): source extraction dropped original lines 8403 and 8405
    // here (the constant-extraction initializers for N0O1C/N1O1C from the
    // AND mask operands); verify against upstream LLVM.
8404      if (const ConstantSDNode *N1O1C =
8406        // We can only do this xform if we know that bits from X that are set in
8407        // C2 but not in C1 are already zero. Likewise for Y.
8408        const APInt &LHSMask = N0O1C->getAPIntValue();
8409        const APInt &RHSMask = N1O1C->getAPIntValue();
8410
8411        if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
8412            DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
8413          SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
8414                                  N0.getOperand(0), N1.getOperand(0));
8415          return DAG.getNode(ISD::AND, DL, VT, X,
8416                             DAG.getConstant(LHSMask | RHSMask, DL, VT));
8417        }
8418      }
8419    }
8420  }
8421
8422  // (or (and X, M), (and X, N)) -> (and X, (or M, N))
8423  if (N0.getOpcode() == ISD::AND &&
8424      N1.getOpcode() == ISD::AND &&
8425      N0.getOperand(0) == N1.getOperand(0) &&
8426      // Don't increase # computations.
8427      (N0->hasOneUse() || N1->hasOneUse())) {
8428    SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
8429                            N0.getOperand(1), N1.getOperand(1));
8430    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
8431  }
8432
8433  return SDValue();
8434}
8435
8436/// OR combines for which the commuted variant will be tried as well.
// visitOR calls this twice, once with (N0, N1) and once with (N1, N0), so
// each pattern below only needs to match one operand order.
// NOTE(review): source extraction dropped original line 8437 here (the first
// line of the visitORCommutative signature, taking the SelectionDAG and the
// two operands); verify against upstream LLVM.
8438                                  SDNode *N) {
8439  EVT VT = N0.getValueType();
8440  unsigned BW = VT.getScalarSizeInBits();
8441  SDLoc DL(N);
8442
  // Treat zext/trunc as transparent when matching the AND patterns below.
8443  auto peekThroughResize = [](SDValue V) {
8444    if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
8445      return V->getOperand(0);
8446    return V;
8447  };
8448
8449  SDValue N0Resized = peekThroughResize(N0);
8450  if (N0Resized.getOpcode() == ISD::AND) {
8451    SDValue N1Resized = peekThroughResize(N1);
8452    SDValue N00 = N0Resized.getOperand(0);
8453    SDValue N01 = N0Resized.getOperand(1);
8454
8455    // fold or (and x, y), x --> x
8456    if (N00 == N1Resized || N01 == N1Resized)
8457      return N1;
8458
8459    // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
8460    // TODO: Set AllowUndefs = true.
8461    if (SDValue NotOperand = getBitwiseNotOperand(N01, N00,
8462                                                  /* AllowUndefs */ false)) {
8463      if (peekThroughResize(NotOperand) == N1Resized)
8464        return DAG.getNode(ISD::OR, DL, VT, DAG.getZExtOrTrunc(N00, DL, VT),
8465                           N1);
8466    }
8467
8468    // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
8469    if (SDValue NotOperand = getBitwiseNotOperand(N00, N01,
8470                                                  /* AllowUndefs */ false)) {
8471      if (peekThroughResize(NotOperand) == N1Resized)
8472        return DAG.getNode(ISD::OR, DL, VT, DAG.getZExtOrTrunc(N01, DL, VT),
8473                           N1);
8474    }
8475  }
8476
8477  SDValue X, Y;
8478
8479  // fold or (xor X, N1), N1 --> or X, N1
8480  if (sd_match(N0, m_Xor(m_Value(X), m_Specific(N1))))
8481    return DAG.getNode(ISD::OR, DL, VT, X, N1);
8482
8483  // fold or (xor x, y), (x and/or y) --> or x, y
8484  if (sd_match(N0, m_Xor(m_Value(X), m_Value(Y))) &&
8485      (sd_match(N1, m_And(m_Specific(X), m_Specific(Y))) ||
  // NOTE(review): source extraction dropped original line 8486 here (the
  // m_Or alternative of this disjunction); verify against upstream LLVM.
8487    return DAG.getNode(ISD::OR, DL, VT, X, Y);
8488
8489  if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
8490    return R;
8491
  // Shift amounts of funnel shifts may be zero-extended; look through that.
8492  auto peekThroughZext = [](SDValue V) {
8493    if (V->getOpcode() == ISD::ZERO_EXTEND)
8494      return V->getOperand(0);
8495    return V;
8496  };
8497
8498  if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
8499      peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1))) {
8500    // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
8501    if (N0.getOperand(0) == N1.getOperand(0))
8502      return N0;
8503    // (fshl A, X, Y) | (shl X, Y) --> fshl (A|X), X, Y
8504    if (N0.getOperand(1) == N1.getOperand(0) && N0.hasOneUse() &&
8505        N1.hasOneUse()) {
8506      SDValue A = N0.getOperand(0);
8507      SDValue X = N1.getOperand(0);
8508      SDValue NewLHS = DAG.getNode(ISD::OR, DL, VT, A, X);
8509      return DAG.getNode(ISD::FSHL, DL, VT, NewLHS, X, N0.getOperand(2));
8510    }
8511  }
8512
8513  if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
8514      peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1))) {
8515    // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
8516    if (N0.getOperand(1) == N1.getOperand(0))
8517      return N0;
8518    // (fshr X, B, Y) | (srl X, Y) --> fshr X, (X|B), Y
8519    if (N0.getOperand(0) == N1.getOperand(0) && N0.hasOneUse() &&
8520        N1.hasOneUse()) {
8521      SDValue X = N1.getOperand(0);
8522      SDValue B = N0.getOperand(1);
8523      SDValue NewRHS = DAG.getNode(ISD::OR, DL, VT, X, B);
8524      return DAG.getNode(ISD::FSHR, DL, VT, X, NewRHS, N0.getOperand(2));
8525    }
8526  }
8527
8528  // (fshl A, B, S0) | (fshr C, D, S1) --> fshl (A|C), (B|D), S0
8529  // iff S0 + S1 == bitwidth(S1)
8530  if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::FSHR &&
8531      N0.hasOneUse() && N1.hasOneUse()) {
8532    auto *S0 = dyn_cast<ConstantSDNode>(N0.getOperand(2));
8533    auto *S1 = dyn_cast<ConstantSDNode>(N1.getOperand(2));
8534    if (S0 && S1 && S0->getZExtValue() < BW && S1->getZExtValue() < BW &&
8535        S0->getZExtValue() == (BW - S1->getZExtValue())) {
8536      SDValue A = N0.getOperand(0);
8537      SDValue B = N0.getOperand(1);
8538      SDValue C = N1.getOperand(0);
8539      SDValue D = N1.getOperand(1);
8540      SDValue NewLHS = DAG.getNode(ISD::OR, DL, VT, A, C);
8541      SDValue NewRHS = DAG.getNode(ISD::OR, DL, VT, B, D);
8542      return DAG.getNode(ISD::FSHL, DL, VT, NewLHS, NewRHS, N0.getOperand(2));
8543    }
8544  }
8545
8546  // Attempt to match a legalized build_pair-esque pattern:
8547  // or(shl(aext(Hi),BW/2),zext(Lo))
8548  SDValue Lo, Hi;
8549  if (sd_match(N0,
  // NOTE(review): source extraction dropped original line 8550 here (the
  // pattern matching the shl-of-anyext(Hi) half); verify against upstream
  // LLVM.
8551      sd_match(N1, m_ZExt(m_Value(Lo))) &&
8552      Lo.getScalarValueSizeInBits() == (BW / 2) &&
8553      Lo.getValueType() == Hi.getValueType()) {
8554    // Fold build_pair(not(Lo),not(Hi)) -> not(build_pair(Lo,Hi)).
8555    SDValue NotLo, NotHi;
8556    if (sd_match(Lo, m_OneUse(m_Not(m_Value(NotLo)))) &&
8557        sd_match(Hi, m_OneUse(m_Not(m_Value(NotHi))))) {
8558      Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotLo);
8559      Hi = DAG.getNode(ISD::ANY_EXTEND, DL, VT, NotHi);
8560      Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
8561                       DAG.getShiftAmountConstant(BW / 2, VT, DL));
8562      return DAG.getNOT(DL, DAG.getNode(ISD::OR, DL, VT, Lo, Hi), VT);
8563    }
8564  }
8565
8566  return SDValue();
8567}
8568
/// Top-level DAG combine for ISD::OR nodes: constant folding, vector
/// simplifications, bswap/rotate matching, and delegation to the OR-specific
/// helpers (visitORLike, visitORCommutative, MatchBSwapHWord*).
8569SDValue DAGCombiner::visitOR(SDNode *N) {
8570  SDValue N0 = N->getOperand(0);
8571  SDValue N1 = N->getOperand(1);
8572  EVT VT = N1.getValueType();
8573  SDLoc DL(N);
8574
8575  // x | x --> x
8576  if (N0 == N1)
8577    return N0;
8578
8579  // fold (or c1, c2) -> c1|c2
8580  if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, DL, VT, {N0, N1}))
8581    return C;
8582
8583  // canonicalize constant to RHS
  // NOTE(review): source extraction dropped original lines 8584-8585 here
  // (the condition checking that N0 is constant and N1 is not); verify
  // against upstream LLVM.
8586    return DAG.getNode(ISD::OR, DL, VT, N1, N0);
8587
8588  // fold vector ops
8589  if (VT.isVector()) {
8590    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
8591      return FoldedVOp;
8592
8593    // fold (or x, 0) -> x, vector edition
    // NOTE(review): source extraction dropped original line 8594 here (the
    // all-zeros splat check on N1); verify against upstream LLVM.
8595      return N0;
8596
8597    // fold (or x, -1) -> -1, vector edition
    // NOTE(review): source extraction dropped original line 8598 here (the
    // all-ones splat check on N1); verify against upstream LLVM.
8599      // do not return N1, because undef node may exist in N1
8600      return DAG.getAllOnesConstant(DL, N1.getValueType());
8601
8602    // fold (or buildvector(x,0,-1,w), buildvector(0,y,z,w))
8603    //  --> buildvector(x,y,-1,w)
8604    auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
8605    auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
8606    if (BV0 && BV1 && !BV0->getSplatValue() && !BV1->getSplatValue() &&
8607        N0.hasOneUse() && N1.hasOneUse() &&
8608        BV0->getOperand(0).getValueType() ==
8609            BV1->getOperand(0).getValueType()) {
8610      SmallVector<SDValue> MergedOps;
8611      unsigned NumElts = VT.getVectorNumElements();
8612      EVT EltVT = BV0->getOperand(0).getValueType();
      // Resolve each lane independently: constant|constant folds, a zero lane
      // yields the other side, an all-ones lane dominates, equal lanes merge.
8613      for (unsigned I = 0; I != NumElts; ++I) {
8614        auto *C0 = dyn_cast<ConstantSDNode>(BV0->getOperand(I));
8615        auto *C1 = dyn_cast<ConstantSDNode>(BV1->getOperand(I));
8616        if (C0 && C1)
8617          MergedOps.push_back(DAG.getConstant(
8618              C0->getAPIntValue() | C1->getAPIntValue(), DL, EltVT));
8619        else if (C0 && C0->isZero())
8620          MergedOps.push_back(BV1->getOperand(I));
8621        else if (C1 && C1->isZero())
8622          MergedOps.push_back(BV0->getOperand(I));
8623        else if (C0 && C0->isAllOnes())
8624          MergedOps.push_back(BV0->getOperand(I));
8625        else if (C1 && C1->isAllOnes())
8626          MergedOps.push_back(BV1->getOperand(I));
8627        else if (BV0->getOperand(I) == BV1->getOperand(I))
8628          MergedOps.push_back(BV0->getOperand(I));
8629        else
8630          break;
8631      }
      // Only fold when every lane was resolvable.
8632      if (MergedOps.size() == NumElts)
8633        return DAG.getBuildVector(VT, DL, MergedOps);
8634    }
8635
8636    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
8637    // Do this only if the resulting type / shuffle is legal.
8638    auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
8639    auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
8640    if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
8641      bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
8642      bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
8643      bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
8644      bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
8645      // Ensure both shuffles have a zero input.
8646      if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
8647        assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
8648        assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
8649        bool CanFold = true;
8650        int NumElts = VT.getVectorNumElements();
8651        SmallVector<int, 4> Mask(NumElts, -1);
8652
8653        for (int i = 0; i != NumElts; ++i) {
8654          int M0 = SV0->getMaskElt(i);
8655          int M1 = SV1->getMaskElt(i);
8656
8657          // Determine if either index is pointing to a zero vector.
8658          bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
8659          bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
8660
8661          // If one element is zero and the otherside is undef, keep undef.
8662          // This also handles the case that both are undef.
8663          if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
8664            continue;
8665
8666          // Make sure only one of the elements is zero.
8667          if (M0Zero == M1Zero) {
8668            CanFold = false;
8669            break;
8670          }
8671
8672          assert((M0 >= 0 || M1 >= 0) && "Undef index!");
8673
8674          // We have a zero and non-zero element. If the non-zero came from
8675          // SV0 make the index a LHS index. If it came from SV1, make it
8676          // a RHS index. We need to mod by NumElts because we don't care
8677          // which operand it came from in the original shuffles.
8678          Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
8679        }
8680
8681        if (CanFold) {
8682          SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
8683          SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
8684          SDValue LegalShuffle =
8685              TLI.buildLegalVectorShuffle(VT, DL, NewLHS, NewRHS, Mask, DAG);
8686          if (LegalShuffle)
8687            return LegalShuffle;
8688        }
8689      }
8690    }
8691  }
8692
8693  // fold (or x, 0) -> x
8694  if (isNullConstant(N1))
8695    return N0;
8696
8697  // fold (or x, -1) -> -1
8698  if (isAllOnesConstant(N1))
8699    return N1;
8700
8701  if (SDValue NewSel = foldBinOpIntoSelect(N))
8702    return NewSel;
8703
8704  // fold (or x, c) -> c iff (x & ~c) == 0
8705  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
8706  if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
8707    return N1;
8708
8709  if (SDValue R = foldAndOrOfSETCC(N, DAG))
8710    return R;
8711
8712  if (SDValue Combined = visitORLike(N0, N1, DL))
8713    return Combined;
8714
8715  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
8716    return Combined;
8717
8718  // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
8719  if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
8720    return BSwap;
8721  if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
8722    return BSwap;
8723
8724  // reassociate or
8725  if (SDValue ROR = reassociateOps(ISD::OR, DL, N0, N1, N->getFlags()))
8726    return ROR;
8727
8728  // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
8729  if (SDValue SD =
8730          reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, DL, VT, N0, N1))
8731    return SD;
8732
8733  // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
8734  // iff (c1 & c2) != 0 or c1/c2 are undef.
8735  auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
8736    return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
8737  };
8738  if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
8739      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
8740    if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
8741                                                 {N1, N0.getOperand(1)})) {
8742      SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
8743      AddToWorklist(IOR.getNode());
8744      return DAG.getNode(ISD::AND, DL, VT, COR, IOR);
8745    }
8746  }
8747
  // Try the commutable OR patterns with both operand orders.
8748  if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
8749    return Combined;
8750  if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
8751    return Combined;
8752
8753  // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
8754  if (N0.getOpcode() == N1.getOpcode())
8755    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
8756      return V;
8757
8758  // See if this is some rotate idiom.
8759  if (SDValue Rot = MatchRotate(N0, N1, DL, /*FromAdd=*/false))
8760    return Rot;
8761
8762  if (SDValue Load = MatchLoadCombine(N))
8763    return Load;
8764
8765  // Simplify the operands using demanded-bits information.
  // NOTE(review): source extraction dropped original line 8766 here (the
  // SimplifyDemandedBits guard for the return below); verify against
  // upstream LLVM.
8767    return SDValue(N, 0);
8768
8769  // If OR can be rewritten into ADD, try combines based on ADD.
8770  if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
8771      DAG.isADDLike(SDValue(N, 0)))
8772    if (SDValue Combined = visitADDLike(N))
8773      return Combined;
8774
8775  // Postpone until legalization completed to avoid interference with bswap
8776  // folding
8777  if (LegalOperations || VT.isVector())
8778    if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
8779      return R;
8780
8781  if (VT.isScalarInteger() && VT != MVT::i1)
8782    if (SDValue R = foldMaskedMerge(N, DAG, TLI, DL))
8783      return R;
8784
8785  return SDValue();
8786}
8787
8789 SDValue &Mask) {
8790 if (Op.getOpcode() == ISD::AND &&
8791 DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
8792 Mask = Op.getOperand(1);
8793 return Op.getOperand(0);
8794 }
8795 return Op;
8796}
8797
8798/// Match "(X shl/srl V1) & V2" where V2 may not be present.
8799static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
8800 SDValue &Mask) {
8801 Op = stripConstantMask(DAG, Op, Mask);
8802 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
8803 Shift = Op;
8804 return true;
8805 }
8806 return false;
8807}
8808
8809/// Helper function for visitOR to extract the needed side of a rotate idiom
8810/// from a shl/srl/mul/udiv. This is meant to handle cases where
8811/// InstCombine merged some outside op with one of the shifts from
8812/// the rotate pattern.
8813/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
8814/// Otherwise, returns an expansion of \p ExtractFrom based on the following
8815/// patterns:
8816///
8817/// (or (add v v) (shrl v bitwidth-1)):
8818/// expands (add v v) -> (shl v 1)
8819///
8820/// (or (mul v c0) (shrl (mul v c1) c2)):
8821/// expands (mul v c0) -> (shl (mul v c1) c3)
8822///
8823/// (or (udiv v c0) (shl (udiv v c1) c2)):
8824/// expands (udiv v c0) -> (shrl (udiv v c1) c3)
8825///
8826/// (or (shl v c0) (shrl (shl v c1) c2)):
8827/// expands (shl v c0) -> (shl (shl v c1) c3)
8828///
8829/// (or (shrl v c0) (shl (shrl v c1) c2)):
8830/// expands (shrl v c0) -> (shrl (shrl v c1) c3)
8831///
8832/// Such that in all cases, c3+c2==bitwidth(op v c1).
8834 SDValue ExtractFrom, SDValue &Mask,
8835 const SDLoc &DL) {
8836 assert(OppShift && ExtractFrom && "Empty SDValue");
8837 if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
8838 return SDValue();
8839
8840 ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
8841
8842 // Value and Type of the shift.
8843 SDValue OppShiftLHS = OppShift.getOperand(0);
8844 EVT ShiftedVT = OppShiftLHS.getValueType();
8845
8846 // Amount of the existing shift.
8847 ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
8848
8849 // (add v v) -> (shl v 1)
8850 // TODO: Should this be a general DAG canonicalization?
8851 if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
8852 ExtractFrom.getOpcode() == ISD::ADD &&
8853 ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
8854 ExtractFrom.getOperand(0) == OppShiftLHS &&
8855 OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
8856 return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
8857 DAG.getShiftAmountConstant(1, ShiftedVT, DL));
8858
8859 // Preconditions:
8860 // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
8861 //
8862 // Find opcode of the needed shift to be extracted from (op0 v c0).
8863 unsigned Opcode = ISD::DELETED_NODE;
8864 bool IsMulOrDiv = false;
8865 // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
8866 // opcode or its arithmetic (mul or udiv) variant.
8867 auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
8868 IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
8869 if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
8870 return false;
8871 Opcode = NeededShift;
8872 return true;
8873 };
8874 // op0 must be either the needed shift opcode or the mul/udiv equivalent
8875 // that the needed shift can be extracted from.
8876 if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
8877 (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
8878 return SDValue();
8879
8880 // op0 must be the same opcode on both sides, have the same LHS argument,
8881 // and produce the same value type.
8882 if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
8883 OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
8884 ShiftedVT != ExtractFrom.getValueType())
8885 return SDValue();
8886
8887 // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
8888 ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
8889 // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
8890 ConstantSDNode *ExtractFromCst =
8891 isConstOrConstSplat(ExtractFrom.getOperand(1));
8892 // TODO: We should be able to handle non-uniform constant vectors for these values
8893 // Check that we have constant values.
8894 if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
8895 !OppLHSCst || !OppLHSCst->getAPIntValue() ||
8896 !ExtractFromCst || !ExtractFromCst->getAPIntValue())
8897 return SDValue();
8898
8899 // Compute the shift amount we need to extract to complete the rotate.
8900 const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
8901 if (OppShiftCst->getAPIntValue().ugt(VTWidth))
8902 return SDValue();
8903 APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
8904 // Normalize the bitwidth of the two mul/udiv/shift constant operands.
8905 APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
8906 APInt OppLHSAmt = OppLHSCst->getAPIntValue();
8907 zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
8908
8909 // Now try extract the needed shift from the ExtractFrom op and see if the
8910 // result matches up with the existing shift's LHS op.
8911 if (IsMulOrDiv) {
8912 // Op to extract from is a mul or udiv by a constant.
8913 // Check:
8914 // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
8915 // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
8916 const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
8917 NeededShiftAmt.getZExtValue());
8918 APInt ResultAmt;
8919 APInt Rem;
8920 APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
8921 if (Rem != 0 || ResultAmt != OppLHSAmt)
8922 return SDValue();
8923 } else {
8924 // Op to extract from is a shift by a constant.
8925 // Check:
8926 // c2 - (bitwidth(op0 v c0) - c1) == c0
8927 if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
8928 ExtractFromAmt.getBitWidth()))
8929 return SDValue();
8930 }
8931
8932 // Return the expanded shift op that should allow a rotate to be formed.
8933 EVT ShiftVT = OppShift.getOperand(1).getValueType();
8934 EVT ResVT = ExtractFrom.getValueType();
8935 SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
8936 return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
8937}
8938
8939// Return true if we can prove that, whenever Neg and Pos are both in the
8940// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
8941// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
8942//
8943// (or (shift1 X, Neg), (shift2 X, Pos))
8944//
8945// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8946// in direction shift1 by Neg. The range [0, EltSize) means that we only need
8947// to consider shift amounts with defined behavior.
8948//
8949// The IsRotate flag should be set when the LHS of both shifts is the same.
8950// Otherwise if matching a general funnel shift, it should be clear.
8951static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8952 SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
8953 const auto &TLI = DAG.getTargetLoweringInfo();
8954 // If EltSize is a power of 2 then:
8955 //
8956 // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8957 // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8958 //
8959 // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8960 // for the stronger condition:
8961 //
8962 // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
8963 //
8964 // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8965 // we can just replace Neg with Neg' for the rest of the function.
8966 //
8967 // In other cases we check for the even stronger condition:
8968 //
8969 // Neg == EltSize - Pos [B]
8970 //
8971 // for all Neg and Pos. Note that the (or ...) then invokes undefined
8972 // behavior if Pos == 0 (and consequently Neg == EltSize).
8973 //
8974 // We could actually use [A] whenever EltSize is a power of 2, but the
8975 // only extra cases that it would match are those uninteresting ones
8976 // where Neg and Pos are never in range at the same time. E.g. for
8977 // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8978 // as well as (sub 32, Pos), but:
8979 //
8980 // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8981 //
8982 // always invokes undefined behavior for 32-bit X.
8983 //
8984 // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8985 // This allows us to peek through any operations that only affect Mask's
8986 // un-demanded bits.
8987 //
8988 // NOTE: We can only do this when matching operations which won't modify the
8989 // least Log2(EltSize) significant bits and not a general funnel shift.
8990 unsigned MaskLoBits = 0;
8991 if (IsRotate && !FromAdd && isPowerOf2_64(EltSize)) {
8992 unsigned Bits = Log2_64(EltSize);
8993 unsigned NegBits = Neg.getScalarValueSizeInBits();
8994 if (NegBits >= Bits) {
8995 APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
8996 if (SDValue Inner =
8998 Neg = Inner;
8999 MaskLoBits = Bits;
9000 }
9001 }
9002 }
9003
9004 // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
9005 if (Neg.getOpcode() != ISD::SUB)
9006 return false;
9008 if (!NegC)
9009 return false;
9010 SDValue NegOp1 = Neg.getOperand(1);
9011
9012 // On the RHS of [A], if Pos is the result of operation on Pos' that won't
9013 // affect Mask's demanded bits, just replace Pos with Pos'. These operations
9014 // are redundant for the purpose of the equality.
9015 if (MaskLoBits) {
9016 unsigned PosBits = Pos.getScalarValueSizeInBits();
9017 if (PosBits >= MaskLoBits) {
9018 APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
9019 if (SDValue Inner =
9021 Pos = Inner;
9022 }
9023 }
9024 }
9025
9026 // The condition we need is now:
9027 //
9028 // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
9029 //
9030 // If NegOp1 == Pos then we need:
9031 //
9032 // EltSize & Mask == NegC & Mask
9033 //
9034 // (because "x & Mask" is a truncation and distributes through subtraction).
9035 //
9036 // We also need to account for a potential truncation of NegOp1 if the amount
9037 // has already been legalized to a shift amount type.
9038 APInt Width;
9039 if ((Pos == NegOp1) ||
9040 (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
9041 Width = NegC->getAPIntValue();
9042
9043 // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
9044 // Then the condition we want to prove becomes:
9045 //
9046 // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
9047 //
9048 // which, again because "x & Mask" is a truncation, becomes:
9049 //
9050 // NegC & Mask == (EltSize - PosC) & Mask
9051 // EltSize & Mask == (NegC + PosC) & Mask
9052 else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
9053 if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
9054 Width = PosC->getAPIntValue() + NegC->getAPIntValue();
9055 else
9056 return false;
9057 } else
9058 return false;
9059
9060 // Now we just need to check that EltSize & Mask == Width & Mask.
9061 if (MaskLoBits)
9062 // EltSize & Mask is 0 since Mask is EltSize - 1.
9063 return Width.getLoBits(MaskLoBits) == 0;
9064 return Width == EltSize;
9065}
9066
9067// A subroutine of MatchRotate used once we have found an OR of two opposite
9068// shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
9069// to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
9070// former being preferred if supported. InnerPos and InnerNeg are Pos and
9071// Neg with outer conversions stripped away.
9072SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
9073 SDValue Neg, SDValue InnerPos,
9074 SDValue InnerNeg, bool FromAdd,
9075 bool HasPos, unsigned PosOpcode,
9076 unsigned NegOpcode, const SDLoc &DL) {
9077 // fold (or/add (shl x, (*ext y)),
9078 // (srl x, (*ext (sub 32, y)))) ->
9079 // (rotl x, y) or (rotr x, (sub 32, y))
9080 //
9081 // fold (or/add (shl x, (*ext (sub 32, y))),
9082 // (srl x, (*ext y))) ->
9083 // (rotr x, y) or (rotl x, (sub 32, y))
9084 EVT VT = Shifted.getValueType();
9085 if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
9086 /*IsRotate*/ true, FromAdd))
9087 return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
9088 HasPos ? Pos : Neg);
9089
9090 return SDValue();
9091}
9092
9093// A subroutine of MatchRotate used once we have found an OR of two opposite
9094// shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
9095// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
9096// former being preferred if supported. InnerPos and InnerNeg are Pos and
9097// Neg with outer conversions stripped away.
9098// TODO: Merge with MatchRotatePosNeg.
9099SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
9100 SDValue Neg, SDValue InnerPos,
9101 SDValue InnerNeg, bool FromAdd,
9102 bool HasPos, unsigned PosOpcode,
9103 unsigned NegOpcode, const SDLoc &DL) {
9104 EVT VT = N0.getValueType();
9105 unsigned EltBits = VT.getScalarSizeInBits();
9106
9107 // fold (or/add (shl x0, (*ext y)),
9108 // (srl x1, (*ext (sub 32, y)))) ->
9109 // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
9110 //
9111 // fold (or/add (shl x0, (*ext (sub 32, y))),
9112 // (srl x1, (*ext y))) ->
9113 // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
9114 if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1,
9115 FromAdd))
9116 return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
9117 HasPos ? Pos : Neg);
9118
9119 // Matching the shift+xor cases, we can't easily use the xor'd shift amount
9120 // so for now just use the PosOpcode case if its legal.
9121 // TODO: When can we use the NegOpcode case?
9122 if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
9123 SDValue X;
9124 // fold (or/add (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
9125 // -> (fshl x0, x1, y)
9126 if (sd_match(N1, m_Srl(m_Value(X), m_One())) &&
9127 sd_match(InnerNeg,
9128 m_Xor(m_Specific(InnerPos), m_SpecificInt(EltBits - 1))) &&
9130 return DAG.getNode(ISD::FSHL, DL, VT, N0, X, Pos);
9131 }
9132
9133 // fold (or/add (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
9134 // -> (fshr x0, x1, y)
9135 if (sd_match(N0, m_Shl(m_Value(X), m_One())) &&
9136 sd_match(InnerPos,
9137 m_Xor(m_Specific(InnerNeg), m_SpecificInt(EltBits - 1))) &&
9139 return DAG.getNode(ISD::FSHR, DL, VT, X, N1, Neg);
9140 }
9141
9142 // fold (or/add (shl (add x0, x0), (xor y, 31)), (srl x1, y))
9143 // -> (fshr x0, x1, y)
9144 // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
9145 if (sd_match(N0, m_Add(m_Value(X), m_Deferred(X))) &&
9146 sd_match(InnerPos,
9147 m_Xor(m_Specific(InnerNeg), m_SpecificInt(EltBits - 1))) &&
9149 return DAG.getNode(ISD::FSHR, DL, VT, X, N1, Neg);
9150 }
9151 }
9152
9153 return SDValue();
9154}
9155
9156// MatchRotate - Handle an 'or' or 'add' of two operands. If this is one of the
9157// many idioms for rotate, and if the target supports rotation instructions,
9158// generate a rot[lr]. This also matches funnel shift patterns, similar to
9159// rotation but with different shifted sources.
9160SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
9161 bool FromAdd) {
9162 EVT VT = LHS.getValueType();
9163
9164 // The target must have at least one rotate/funnel flavor.
9165 // We still try to match rotate by constant pre-legalization.
9166 // TODO: Support pre-legalization funnel-shift by constant.
9167 bool HasROTL = hasOperation(ISD::ROTL, VT);
9168 bool HasROTR = hasOperation(ISD::ROTR, VT);
9169 bool HasFSHL = hasOperation(ISD::FSHL, VT);
9170 bool HasFSHR = hasOperation(ISD::FSHR, VT);
9171
9172 // If the type is going to be promoted and the target has enabled custom
9173 // lowering for rotate, allow matching rotate by non-constants. Only allow
9174 // this for scalar types.
9175 if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
9179 }
9180
9181 if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
9182 return SDValue();
9183
9184 // Check for truncated rotate.
9185 if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
9186 LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
9187 assert(LHS.getValueType() == RHS.getValueType());
9188 if (SDValue Rot =
9189 MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL, FromAdd))
9190 return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
9191 }
9192
9193 // Match "(X shl/srl V1) & V2" where V2 may not be present.
9194 SDValue LHSShift; // The shift.
9195 SDValue LHSMask; // AND value if any.
9196 matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
9197
9198 SDValue RHSShift; // The shift.
9199 SDValue RHSMask; // AND value if any.
9200 matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
9201
9202 // If neither side matched a rotate half, bail
9203 if (!LHSShift && !RHSShift)
9204 return SDValue();
9205
9206 // InstCombine may have combined a constant shl, srl, mul, or udiv with one
9207 // side of the rotate, so try to handle that here. In all cases we need to
9208 // pass the matched shift from the opposite side to compute the opcode and
9209 // needed shift amount to extract. We still want to do this if both sides
9210 // matched a rotate half because one half may be a potential overshift that
9211 // can be broken down (ie if InstCombine merged two shl or srl ops into a
9212 // single one).
9213
9214 // Have LHS side of the rotate, try to extract the needed shift from the RHS.
9215 if (LHSShift)
9216 if (SDValue NewRHSShift =
9217 extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
9218 RHSShift = NewRHSShift;
9219 // Have RHS side of the rotate, try to extract the needed shift from the LHS.
9220 if (RHSShift)
9221 if (SDValue NewLHSShift =
9222 extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
9223 LHSShift = NewLHSShift;
9224
9225 // If a side is still missing, nothing else we can do.
9226 if (!RHSShift || !LHSShift)
9227 return SDValue();
9228
9229 // At this point we've matched or extracted a shift op on each side.
9230
9231 if (LHSShift.getOpcode() == RHSShift.getOpcode())
9232 return SDValue(); // Shifts must disagree.
9233
9234 // Canonicalize shl to left side in a shl/srl pair.
9235 if (RHSShift.getOpcode() == ISD::SHL) {
9236 std::swap(LHS, RHS);
9237 std::swap(LHSShift, RHSShift);
9238 std::swap(LHSMask, RHSMask);
9239 }
9240
9241 // Something has gone wrong - we've lost the shl/srl pair - bail.
9242 if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
9243 return SDValue();
9244
9245 unsigned EltSizeInBits = VT.getScalarSizeInBits();
9246 SDValue LHSShiftArg = LHSShift.getOperand(0);
9247 SDValue LHSShiftAmt = LHSShift.getOperand(1);
9248 SDValue RHSShiftArg = RHSShift.getOperand(0);
9249 SDValue RHSShiftAmt = RHSShift.getOperand(1);
9250
9251 auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
9252 ConstantSDNode *RHS) {
9253 return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
9254 };
9255
9256 auto ApplyMasks = [&](SDValue Res) {
9257 // If there is an AND of either shifted operand, apply it to the result.
9258 if (LHSMask.getNode() || RHSMask.getNode()) {
9261
9262 if (LHSMask.getNode()) {
9263 SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
9264 Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
9265 DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
9266 }
9267 if (RHSMask.getNode()) {
9268 SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
9269 Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
9270 DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
9271 }
9272
9273 Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
9274 }
9275
9276 return Res;
9277 };
9278
9279 // TODO: Support pre-legalization funnel-shift by constant.
9280 bool IsRotate = LHSShiftArg == RHSShiftArg;
9281 if (!IsRotate && !(HasFSHL || HasFSHR)) {
9282 if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
9283 ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
9284 // Look for a disguised rotate by constant.
9285 // The common shifted operand X may be hidden inside another 'or'.
9286 SDValue X, Y;
9287 auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
9288 if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
9289 return false;
9290 if (CommonOp == Or.getOperand(0)) {
9291 X = CommonOp;
9292 Y = Or.getOperand(1);
9293 return true;
9294 }
9295 if (CommonOp == Or.getOperand(1)) {
9296 X = CommonOp;
9297 Y = Or.getOperand(0);
9298 return true;
9299 }
9300 return false;
9301 };
9302
9303 SDValue Res;
9304 if (matchOr(LHSShiftArg, RHSShiftArg)) {
9305 // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
9306 SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
9307 SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
9308 Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
9309 } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
9310 // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
9311 SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
9312 SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
9313 Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
9314 } else {
9315 return SDValue();
9316 }
9317
9318 return ApplyMasks(Res);
9319 }
9320
9321 return SDValue(); // Requires funnel shift support.
9322 }
9323
9324 // fold (or/add (shl x, C1), (srl x, C2)) -> (rotl x, C1)
9325 // fold (or/add (shl x, C1), (srl x, C2)) -> (rotr x, C2)
9326 // fold (or/add (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
9327 // fold (or/add (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
9328 // iff C1+C2 == EltSizeInBits
9329 if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
9330 SDValue Res;
9331 if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
9332 bool UseROTL = !LegalOperations || HasROTL;
9333 Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
9334 UseROTL ? LHSShiftAmt : RHSShiftAmt);
9335 } else {
9336 bool UseFSHL = !LegalOperations || HasFSHL;
9337 Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
9338 RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
9339 }
9340
9341 return ApplyMasks(Res);
9342 }
9343
9344 // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
9345 // shift.
9346 if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
9347 return SDValue();
9348
9349 // If there is a mask here, and we have a variable shift, we can't be sure
9350 // that we're masking out the right stuff.
9351 if (LHSMask.getNode() || RHSMask.getNode())
9352 return SDValue();
9353
9354 // If the shift amount is sign/zext/any-extended just peel it off.
9355 SDValue LExtOp0 = LHSShiftAmt;
9356 SDValue RExtOp0 = RHSShiftAmt;
9357 if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9358 LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9359 LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9360 LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
9361 (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9362 RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9363 RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9364 RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
9365 LExtOp0 = LHSShiftAmt.getOperand(0);
9366 RExtOp0 = RHSShiftAmt.getOperand(0);
9367 }
9368
9369 if (IsRotate && (HasROTL || HasROTR)) {
9370 if (SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
9371 LExtOp0, RExtOp0, FromAdd, HasROTL,
9373 return TryL;
9374
9375 if (SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
9376 RExtOp0, LExtOp0, FromAdd, HasROTR,
9378 return TryR;
9379 }
9380
9381 if (SDValue TryL = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt,
9382 RHSShiftAmt, LExtOp0, RExtOp0, FromAdd,
9383 HasFSHL, ISD::FSHL, ISD::FSHR, DL))
9384 return TryL;
9385
9386 if (SDValue TryR = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt,
9387 LHSShiftAmt, RExtOp0, LExtOp0, FromAdd,
9388 HasFSHR, ISD::FSHR, ISD::FSHL, DL))
9389 return TryR;
9390
9391 return SDValue();
9392}
9393
9394/// Recursively traverses the expression calculating the origin of the requested
9395/// byte of the given value. Returns std::nullopt if the provider can't be
9396/// calculated.
9397///
9398/// For all the values except the root of the expression, we verify that the
9399/// value has exactly one use and if not then return std::nullopt. This way if
9400/// the origin of the byte is returned it's guaranteed that the values which
9401/// contribute to the byte are not used outside of this expression.
9402
9403/// However, there is a special case when dealing with vector loads -- we allow
9404/// more than one use if the load is a vector type. Since the values that
9405/// contribute to the byte ultimately come from the ExtractVectorElements of the
9406/// Load, we don't care if the Load has uses other than ExtractVectorElements,
9407/// because those operations are independent from the pattern to be combined.
9408/// For vector loads, we simply care that the ByteProviders are adjacent
9409/// positions of the same vector, and their index matches the byte that is being
9410/// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
9411/// is the index used in an ExtractVectorElement, and \p StartingIndex is the
9412/// byte position we are trying to provide for the LoadCombine. If these do
9413/// not match, then we can not combine the vector loads. \p Index uses the
9414/// byte position we are trying to provide for and is matched against the
9415/// shl and load size. The \p Index algorithm ensures the requested byte is
9416/// provided for by the pattern, and the pattern does not over provide bytes.
9417///
9418///
9419/// The supported LoadCombine pattern for vector loads is as follows
9420/// or
9421/// / \
9422/// or shl
9423/// / \ |
9424/// or shl zext
9425/// / \ | |
9426/// shl zext zext EVE*
9427/// | | | |
9428/// zext EVE* EVE* LOAD
9429/// | | |
9430/// EVE* LOAD LOAD
9431/// |
9432/// LOAD
9433///
9434/// *ExtractVectorElement
9436
9437static std::optional<SDByteProvider>
9438calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
9439 std::optional<uint64_t> VectorIndex,
9440 unsigned StartingIndex = 0) {
9441
9442 // Typical i64 by i8 pattern requires recursion up to 8 calls depth
9443 if (Depth == 10)
9444 return std::nullopt;
9445
9446 // Only allow multiple uses if the instruction is a vector load (in which
9447 // case we will use the load for every ExtractVectorElement)
9448 if (Depth && !Op.hasOneUse() &&
9449 (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
9450 return std::nullopt;
9451
9452 // Fail to combine if we have encountered anything but a LOAD after handling
9453 // an ExtractVectorElement.
9454 if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
9455 return std::nullopt;
9456
9457 unsigned BitWidth = Op.getScalarValueSizeInBits();
9458 if (BitWidth % 8 != 0)
9459 return std::nullopt;
9460 unsigned ByteWidth = BitWidth / 8;
9461 assert(Index < ByteWidth && "invalid index requested");
9462 (void) ByteWidth;
9463
9464 switch (Op.getOpcode()) {
9465 case ISD::OR: {
9466 auto LHS =
9467 calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
9468 if (!LHS)
9469 return std::nullopt;
9470 auto RHS =
9471 calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
9472 if (!RHS)
9473 return std::nullopt;
9474
9475 if (LHS->isConstantZero())
9476 return RHS;
9477 if (RHS->isConstantZero())
9478 return LHS;
9479 return std::nullopt;
9480 }
9481 case ISD::SHL: {
9482 auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
9483 if (!ShiftOp)
9484 return std::nullopt;
9485
9486 uint64_t BitShift = ShiftOp->getZExtValue();
9487
9488 if (BitShift % 8 != 0)
9489 return std::nullopt;
9490 uint64_t ByteShift = BitShift / 8;
9491
9492 // If we are shifting by an amount greater than the index we are trying to
9493 // provide, then do not provide anything. Otherwise, subtract the index by
9494 // the amount we shifted by.
9495 return Index < ByteShift
9497 : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
9498 Depth + 1, VectorIndex, Index);
9499 }
9500 case ISD::ANY_EXTEND:
9501 case ISD::SIGN_EXTEND:
9502 case ISD::ZERO_EXTEND: {
9503 SDValue NarrowOp = Op->getOperand(0);
9504 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9505 if (NarrowBitWidth % 8 != 0)
9506 return std::nullopt;
9507 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9508
9509 if (Index >= NarrowByteWidth)
9510 return Op.getOpcode() == ISD::ZERO_EXTEND
9511 ? std::optional<SDByteProvider>(
9513 : std::nullopt;
9514 return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
9515 StartingIndex);
9516 }
9517 case ISD::BSWAP:
9518 return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
9519 Depth + 1, VectorIndex, StartingIndex);
9521 auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
9522 if (!OffsetOp)
9523 return std::nullopt;
9524
9525 VectorIndex = OffsetOp->getZExtValue();
9526
9527 SDValue NarrowOp = Op->getOperand(0);
9528 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9529 if (NarrowBitWidth % 8 != 0)
9530 return std::nullopt;
9531 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9532 // EXTRACT_VECTOR_ELT can extend the element type to the width of the return
9533 // type, leaving the high bits undefined.
9534 if (Index >= NarrowByteWidth)
9535 return std::nullopt;
9536
9537 // Check to see if the position of the element in the vector corresponds
9538 // with the byte we are trying to provide for. In the case of a vector of
9539 // i8, this simply means the VectorIndex == StartingIndex. For non i8 cases,
9540 // the element will provide a range of bytes. For example, if we have a
9541 // vector of i16s, each element provides two bytes (V[1] provides byte 2 and
9542 // 3).
9543 if (*VectorIndex * NarrowByteWidth > StartingIndex)
9544 return std::nullopt;
9545 if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
9546 return std::nullopt;
9547
9548 return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
9549 VectorIndex, StartingIndex);
9550 }
9551 case ISD::LOAD: {
9552 auto L = cast<LoadSDNode>(Op.getNode());
9553 if (!L->isSimple() || L->isIndexed())
9554 return std::nullopt;
9555
9556 unsigned NarrowBitWidth = L->getMemoryVT().getScalarSizeInBits();
9557 if (NarrowBitWidth % 8 != 0)
9558 return std::nullopt;
9559 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9560
9561 // If the width of the load does not reach byte we are trying to provide for
9562 // and it is not a ZEXTLOAD, then the load does not provide for the byte in
9563 // question
9564 if (Index >= NarrowByteWidth)
9565 return L->getExtensionType() == ISD::ZEXTLOAD
9566 ? std::optional<SDByteProvider>(
9568 : std::nullopt;
9569
9570 unsigned BPVectorIndex = VectorIndex.value_or(0U);
9571 return SDByteProvider::getSrc(L, Index, BPVectorIndex);
9572 }
9573 }
9574
9575 return std::nullopt;
9576}
9577
/// Map byte index \p i of a little-endian value to its memory offset.
/// For little endian, byte 0 of the value is stored at offset 0, so the
/// mapping is the identity; \p BW (the value's byte width) is unused but kept
/// for signature symmetry with bigEndianByteAt.
static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
  return i;
}
9581
/// Map byte index \p i of a big-endian value to its memory offset.
/// For big endian, byte 0 of the value (the least significant byte) is stored
/// at the highest offset, so the mapping is reversed over the byte width
/// \p BW.
static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
  return BW - i - 1;
}
9585
9586// Check if the bytes offsets we are looking at match with either big or
9587// little endian value loaded. Return true for big endian, false for little
9588// endian, and std::nullopt if match failed.
9589static std::optional<bool> isBigEndian(ArrayRef<int64_t> ByteOffsets,
9590 int64_t FirstOffset) {
9591 // The endian can be decided only when it is 2 bytes at least.
9592 unsigned Width = ByteOffsets.size();
9593 if (Width < 2)
9594 return std::nullopt;
9595
9596 bool BigEndian = true, LittleEndian = true;
9597 for (unsigned i = 0; i < Width; i++) {
9598 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
9599 LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
9600 BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
9601 if (!BigEndian && !LittleEndian)
9602 return std::nullopt;
9603 }
9604
9605 assert((BigEndian != LittleEndian) && "It should be either big endian or"
9606 "little endian");
9607 return BigEndian;
9608}
9609
9610// Look through one layer of truncate or extend.
9612 switch (Value.getOpcode()) {
9613 case ISD::TRUNCATE:
9614 case ISD::ZERO_EXTEND:
9615 case ISD::SIGN_EXTEND:
9616 case ISD::ANY_EXTEND:
9617 return Value.getOperand(0);
9618 }
9619 return SDValue();
9620}
9621
9622/// Match a pattern where a wide type scalar value is stored by several narrow
9623/// stores. Fold it into a single store or a BSWAP and a store if the targets
9624/// supports it.
9625///
9626/// Assuming little endian target:
9627/// i8 *p = ...
9628/// i32 val = ...
9629/// p[0] = (val >> 0) & 0xFF;
9630/// p[1] = (val >> 8) & 0xFF;
9631/// p[2] = (val >> 16) & 0xFF;
9632/// p[3] = (val >> 24) & 0xFF;
9633/// =>
9634/// *((i32)p) = val;
9635///
9636/// i8 *p = ...
9637/// i32 val = ...
9638/// p[0] = (val >> 24) & 0xFF;
9639/// p[1] = (val >> 16) & 0xFF;
9640/// p[2] = (val >> 8) & 0xFF;
9641/// p[3] = (val >> 0) & 0xFF;
9642/// =>
9643/// *((i32)p) = BSWAP(val);
9644SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
9645 // The matching looks for "store (trunc x)" patterns that appear early but are
9646 // likely to be replaced by truncating store nodes during combining.
9647 // TODO: If there is evidence that running this later would help, this
9648 // limitation could be removed. Legality checks may need to be added
9649 // for the created store and optional bswap/rotate.
9650 if (LegalOperations || OptLevel == CodeGenOptLevel::None)
9651 return SDValue();
9652
9653 // We only handle merging simple stores of 1-4 bytes.
9654 // TODO: Allow unordered atomics when wider type is legal (see D66309)
9655 EVT MemVT = N->getMemoryVT();
9656 if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
9657 !N->isSimple() || N->isIndexed())
9658 return SDValue();
9659
9660 // Collect all of the stores in the chain, upto the maximum store width (i64).
9661 SDValue Chain = N->getChain();
9663 unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
9664 unsigned MaxWideNumBits = 64;
9665 unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
9666 while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
9667 // All stores must be the same size to ensure that we are writing all of the
9668 // bytes in the wide value.
9669 // This store should have exactly one use as a chain operand for another
9670 // store in the merging set. If there are other chain uses, then the
9671 // transform may not be safe because order of loads/stores outside of this
9672 // set may not be preserved.
9673 // TODO: We could allow multiple sizes by tracking each stored byte.
9674 if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
9675 Store->isIndexed() || !Store->hasOneUse())
9676 return SDValue();
9677 Stores.push_back(Store);
9678 Chain = Store->getChain();
9679 if (MaxStores < Stores.size())
9680 return SDValue();
9681 }
9682 // There is no reason to continue if we do not have at least a pair of stores.
9683 if (Stores.size() < 2)
9684 return SDValue();
9685
9686 // Handle simple types only.
9687 LLVMContext &Context = *DAG.getContext();
9688 unsigned NumStores = Stores.size();
9689 unsigned WideNumBits = NumStores * NarrowNumBits;
9690 if (WideNumBits != 16 && WideNumBits != 32 && WideNumBits != 64)
9691 return SDValue();
9692
9693 // Check if all bytes of the source value that we are looking at are stored
9694 // to the same base address. Collect offsets from Base address into OffsetMap.
9695 SDValue SourceValue;
9696 SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
9697 int64_t FirstOffset = INT64_MAX;
9698 StoreSDNode *FirstStore = nullptr;
9699 std::optional<BaseIndexOffset> Base;
9700 for (auto *Store : Stores) {
9701 // All the stores store different parts of the CombinedValue. A truncate is
9702 // required to get the partial value.
9703 SDValue Trunc = Store->getValue();
9704 if (Trunc.getOpcode() != ISD::TRUNCATE)
9705 return SDValue();
9706 // Other than the first/last part, a shift operation is required to get the
9707 // offset.
9708 int64_t Offset = 0;
9709 SDValue WideVal = Trunc.getOperand(0);
9710 if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
9711 isa<ConstantSDNode>(WideVal.getOperand(1))) {
9712 // The shift amount must be a constant multiple of the narrow type.
9713 // It is translated to the offset address in the wide source value "y".
9714 //
9715 // x = srl y, ShiftAmtC
9716 // i8 z = trunc x
9717 // store z, ...
9718 uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
9719 if (ShiftAmtC % NarrowNumBits != 0)
9720 return SDValue();
9721
9722 // Make sure we aren't reading bits that are shifted in.
9723 if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
9724 return SDValue();
9725
9726 Offset = ShiftAmtC / NarrowNumBits;
9727 WideVal = WideVal.getOperand(0);
9728 }
9729
9730 // Stores must share the same source value with different offsets.
9731 if (!SourceValue)
9732 SourceValue = WideVal;
9733 else if (SourceValue != WideVal) {
9734 // Truncate and extends can be stripped to see if the values are related.
9735 if (stripTruncAndExt(SourceValue) != WideVal &&
9736 stripTruncAndExt(WideVal) != SourceValue)
9737 return SDValue();
9738
9739 if (WideVal.getScalarValueSizeInBits() >
9740 SourceValue.getScalarValueSizeInBits())
9741 SourceValue = WideVal;
9742
9743 // Give up if the source value type is smaller than the store size.
9744 if (SourceValue.getScalarValueSizeInBits() < WideNumBits)
9745 return SDValue();
9746 }
9747
9748 // Stores must share the same base address.
9749 BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
9750 int64_t ByteOffsetFromBase = 0;
9751 if (!Base)
9752 Base = Ptr;
9753 else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9754 return SDValue();
9755
9756 // Remember the first store.
9757 if (ByteOffsetFromBase < FirstOffset) {
9758 FirstStore = Store;
9759 FirstOffset = ByteOffsetFromBase;
9760 }
9761 // Map the offset in the store and the offset in the combined value, and
9762 // early return if it has been set before.
9763 if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
9764 return SDValue();
9765 OffsetMap[Offset] = ByteOffsetFromBase;
9766 }
9767
9768 EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
9769
9770 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9771 assert(FirstStore && "First store must be set");
9772
9773 // Check that a store of the wide type is both allowed and fast on the target
9774 const DataLayout &Layout = DAG.getDataLayout();
9775 unsigned Fast = 0;
9776 bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
9777 *FirstStore->getMemOperand(), &Fast);
9778 if (!Allowed || !Fast)
9779 return SDValue();
9780
9781 // Check if the pieces of the value are going to the expected places in memory
9782 // to merge the stores.
9783 auto checkOffsets = [&](bool MatchLittleEndian) {
9784 if (MatchLittleEndian) {
9785 for (unsigned i = 0; i != NumStores; ++i)
9786 if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9787 return false;
9788 } else { // MatchBigEndian by reversing loop counter.
9789 for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9790 if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9791 return false;
9792 }
9793 return true;
9794 };
9795
9796 // Check if the offsets line up for the native data layout of this target.
9797 bool NeedBswap = false;
9798 bool NeedRotate = false;
9799 if (!checkOffsets(Layout.isLittleEndian())) {
9800 // Special-case: check if byte offsets line up for the opposite endian.
9801 if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9802 NeedBswap = true;
9803 else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9804 NeedRotate = true;
9805 else
9806 return SDValue();
9807 }
9808
9809 SDLoc DL(N);
9810 if (WideVT != SourceValue.getValueType()) {
9811 assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9812 "Unexpected store value to merge");
9813 SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
9814 }
9815
9816 // Before legalize we can introduce illegal bswaps/rotates which will be later
9817 // converted to an explicit bswap sequence. This way we end up with a single
9818 // store and byte shuffling instead of several stores and byte shuffling.
9819 if (NeedBswap) {
9820 SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
9821 } else if (NeedRotate) {
9822 assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9823 SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
9824 SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
9825 }
9826
9827 SDValue NewStore =
9828 DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
9829 FirstStore->getPointerInfo(), FirstStore->getAlign());
9830
9831 // Rely on other DAG combine rules to remove the other individual stores.
9832 DAG.ReplaceAllUsesWith(N, NewStore.getNode());
9833 return NewStore;
9834}
9835
9836/// Match a pattern where a wide type scalar value is loaded by several narrow
9837/// loads and combined by shifts and ors. Fold it into a single load or a load
9838/// and a BSWAP if the targets supports it.
9839///
9840/// Assuming little endian target:
9841/// i8 *a = ...
9842/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9843/// =>
9844/// i32 val = *((i32)a)
9845///
9846/// i8 *a = ...
9847/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9848/// =>
9849/// i32 val = BSWAP(*((i32)a))
9850///
9851/// TODO: This rule matches complex patterns with OR node roots and doesn't
9852/// interact well with the worklist mechanism. When a part of the pattern is
9853/// updated (e.g. one of the loads) its direct users are put into the worklist,
9854/// but the root node of the pattern which triggers the load combine is not
9855/// necessarily a direct user of the changed node. For example, once the address
9856/// of t28 load is reassociated load combine won't be triggered:
9857/// t25: i32 = add t4, Constant:i32<2>
9858/// t26: i64 = sign_extend t25
9859/// t27: i64 = add t2, t26
9860/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9861/// t29: i32 = zero_extend t28
9862/// t32: i32 = shl t29, Constant:i8<8>
9863/// t33: i32 = or t23, t32
9864/// As a possible fix visitLoad can check if the load can be a part of a load
9865/// combine pattern and add corresponding OR roots to the worklist.
9866SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9867 assert(N->getOpcode() == ISD::OR &&
9868 "Can only match load combining against OR nodes");
9869
9870 // Handles simple types only
9871 EVT VT = N->getValueType(0);
9872 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9873 return SDValue();
9874 unsigned ByteWidth = VT.getSizeInBits() / 8;
9875
9876 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
9877 auto MemoryByteOffset = [&](SDByteProvider P) {
9878 assert(P.hasSrc() && "Must be a memory byte provider");
9879 auto *Load = cast<LoadSDNode>(P.Src.value());
9880
9881 unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9882
9883 assert(LoadBitWidth % 8 == 0 &&
9884 "can only analyze providers for individual bytes not bit");
9885 unsigned LoadByteWidth = LoadBitWidth / 8;
9886 return IsBigEndianTarget ? bigEndianByteAt(LoadByteWidth, P.DestOffset)
9887 : littleEndianByteAt(LoadByteWidth, P.DestOffset);
9888 };
9889
9890 std::optional<BaseIndexOffset> Base;
9891 SDValue Chain;
9892
9893 SmallPtrSet<LoadSDNode *, 8> Loads;
9894 std::optional<SDByteProvider> FirstByteProvider;
9895 int64_t FirstOffset = INT64_MAX;
9896
9897 // Check if all the bytes of the OR we are looking at are loaded from the same
9898 // base address. Collect bytes offsets from Base address in ByteOffsets.
9899 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9900 unsigned ZeroExtendedBytes = 0;
9901 for (int i = ByteWidth - 1; i >= 0; --i) {
9902 auto P =
9903 calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
9904 /*StartingIndex*/ i);
9905 if (!P)
9906 return SDValue();
9907
9908 if (P->isConstantZero()) {
9909 // It's OK for the N most significant bytes to be 0, we can just
9910 // zero-extend the load.
9911 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9912 return SDValue();
9913 continue;
9914 }
9915 assert(P->hasSrc() && "provenance should either be memory or zero");
9916 auto *L = cast<LoadSDNode>(P->Src.value());
9917
9918 // All loads must share the same chain
9919 SDValue LChain = L->getChain();
9920 if (!Chain)
9921 Chain = LChain;
9922 else if (Chain != LChain)
9923 return SDValue();
9924
9925 // Loads must share the same base address
9926 BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
9927 int64_t ByteOffsetFromBase = 0;
9928
9929 // For vector loads, the expected load combine pattern will have an
9930 // ExtractElement for each index in the vector. While each of these
9931 // ExtractElements will be accessing the same base address as determined
9932 // by the load instruction, the actual bytes they interact with will differ
9933 // due to different ExtractElement indices. To accurately determine the
9934 // byte position of an ExtractElement, we offset the base load ptr with
9935 // the index multiplied by the byte size of each element in the vector.
9936 if (L->getMemoryVT().isVector()) {
9937 unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9938 if (LoadWidthInBit % 8 != 0)
9939 return SDValue();
9940 unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9941 Ptr.addToOffset(ByteOffsetFromVector);
9942 }
9943
9944 if (!Base)
9945 Base = Ptr;
9946
9947 else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9948 return SDValue();
9949
9950 // Calculate the offset of the current byte from the base address
9951 ByteOffsetFromBase += MemoryByteOffset(*P);
9952 ByteOffsets[i] = ByteOffsetFromBase;
9953
9954 // Remember the first byte load
9955 if (ByteOffsetFromBase < FirstOffset) {
9956 FirstByteProvider = P;
9957 FirstOffset = ByteOffsetFromBase;
9958 }
9959
9960 Loads.insert(L);
9961 }
9962
9963 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9964 "memory, so there must be at least one load which produces the value");
9965 assert(Base && "Base address of the accessed memory location must be set");
9966 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9967
9968 bool NeedsZext = ZeroExtendedBytes > 0;
9969
9970 EVT MemVT =
9971 EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
9972
9973 if (!MemVT.isSimple())
9974 return SDValue();
9975
9976 // Before legalize we can introduce too wide illegal loads which will be later
9977 // split into legal sized loads. This enables us to combine i64 load by i8
9978 // patterns to a couple of i32 loads on 32 bit targets.
9979 if (LegalOperations &&
9980 !TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
9981 MemVT))
9982 return SDValue();
9983
9984 // Check if the bytes of the OR we are looking at match with either big or
9985 // little endian value load
9986 std::optional<bool> IsBigEndian = isBigEndian(
9987 ArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
9988 if (!IsBigEndian)
9989 return SDValue();
9990
9991 assert(FirstByteProvider && "must be set");
9992
9993 // Ensure that the first byte is loaded from zero offset of the first load.
9994 // So the combined value can be loaded from the first load address.
9995 if (MemoryByteOffset(*FirstByteProvider) != 0)
9996 return SDValue();
9997 auto *FirstLoad = cast<LoadSDNode>(FirstByteProvider->Src.value());
9998
9999 // The node we are looking at matches with the pattern, check if we can
10000 // replace it with a single (possibly zero-extended) load and bswap + shift if
10001 // needed.
10002
10003 // If the load needs byte swap check if the target supports it
10004 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
10005
10006 // Before legalize we can introduce illegal bswaps which will be later
10007 // converted to an explicit bswap sequence. This way we end up with a single
10008 // load and byte shuffling instead of several loads and byte shuffling.
10009 // We do not introduce illegal bswaps when zero-extending as this tends to
10010 // introduce too many arithmetic instructions.
10011 if (NeedsBswap && (LegalOperations || NeedsZext) &&
10012 !TLI.isOperationLegal(ISD::BSWAP, VT))
10013 return SDValue();
10014
10015 // If we need to bswap and zero extend, we have to insert a shift. Check that
10016 // it is legal.
10017 if (NeedsBswap && NeedsZext && LegalOperations &&
10018 !TLI.isOperationLegal(ISD::SHL, VT))
10019 return SDValue();
10020
10021 // Check that a load of the wide type is both allowed and fast on the target
10022 unsigned Fast = 0;
10023 bool Allowed =
10024 TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
10025 *FirstLoad->getMemOperand(), &Fast);
10026 if (!Allowed || !Fast)
10027 return SDValue();
10028
10029 SDValue NewLoad =
10030 DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
10031 Chain, FirstLoad->getBasePtr(),
10032 FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
10033
10034 // Transfer chain users from old loads to the new load.
10035 for (LoadSDNode *L : Loads)
10036 DAG.makeEquivalentMemoryOrdering(L, NewLoad);
10037
10038 if (!NeedsBswap)
10039 return NewLoad;
10040
10041 SDValue ShiftedLoad =
10042 NeedsZext ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
10043 DAG.getShiftAmountConstant(ZeroExtendedBytes * 8,
10044 VT, SDLoc(N)))
10045 : NewLoad;
10046 return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
10047}
10048
10049// If the target has andn, bsl, or a similar bit-select instruction,
10050// we want to unfold masked merge, with canonical pattern of:
10051// | A | |B|
10052// ((x ^ y) & m) ^ y
10053// | D |
10054// Into:
10055// (x & m) | (y & ~m)
10056// If y is a constant, m is not a 'not', and the 'andn' does not work with
10057// immediates, we unfold into a different pattern:
10058// ~(~x & m) & (m | y)
10059// If x is a constant, m is a 'not', and the 'andn' does not work with
10060// immediates, we unfold into a different pattern:
10061// (x | ~m) & ~(~m & ~y)
10062// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
10063// the very least that breaks andnpd / andnps patterns, and because those
10064// patterns are simplified in IR and shouldn't be created in the DAG
10065SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
10066 assert(N->getOpcode() == ISD::XOR);
10067
10068 // Don't touch 'not' (i.e. where y = -1).
10069 if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
10070 return SDValue();
10071
10072 EVT VT = N->getValueType(0);
10073
10074 // There are 3 commutable operators in the pattern,
10075 // so we have to deal with 8 possible variants of the basic pattern.
10076 SDValue X, Y, M;
10077 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
10078 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
10079 return false;
10080 SDValue Xor = And.getOperand(XorIdx);
10081 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
10082 return false;
10083 SDValue Xor0 = Xor.getOperand(0);
10084 SDValue Xor1 = Xor.getOperand(1);
10085 // Don't touch 'not' (i.e. where y = -1).
10086 if (isAllOnesOrAllOnesSplat(Xor1))
10087 return false;
10088 if (Other == Xor0)
10089 std::swap(Xor0, Xor1);
10090 if (Other != Xor1)
10091 return false;
10092 X = Xor0;
10093 Y = Xor1;
10094 M = And.getOperand(XorIdx ? 0 : 1);
10095 return true;
10096 };
10097
10098 SDValue N0 = N->getOperand(0);
10099 SDValue N1 = N->getOperand(1);
10100 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
10101 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
10102 return SDValue();
10103
10104 // Don't do anything if the mask is constant. This should not be reachable.
10105 // InstCombine should have already unfolded this pattern, and DAGCombiner
10106 // probably shouldn't produce it, too.
10107 if (isa<ConstantSDNode>(M.getNode()))
10108 return SDValue();
10109
10110 // We can transform if the target has AndNot
10111 if (!TLI.hasAndNot(M))
10112 return SDValue();
10113
10114 SDLoc DL(N);
10115
10116 // If Y is a constant, check that 'andn' works with immediates. Unless M is
10117 // a bitwise not that would already allow ANDN to be used.
10118 if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
10119 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
10120 // If not, we need to do a bit more work to make sure andn is still used.
10121 SDValue NotX = DAG.getNOT(DL, X, VT);
10122 SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
10123 SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
10124 SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
10125 return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
10126 }
10127
10128 // If X is a constant and M is a bitwise not, check that 'andn' works with
10129 // immediates.
10130 if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
10131 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
10132 // If not, we need to do a bit more work to make sure andn is still used.
10133 SDValue NotM = M.getOperand(0);
10134 SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
10135 SDValue NotY = DAG.getNOT(DL, Y, VT);
10136 SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
10137 SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
10138 return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
10139 }
10140
10141 SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
10142 SDValue NotM = DAG.getNOT(DL, M, VT);
10143 SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
10144
10145 return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
10146}
10147
10148SDValue DAGCombiner::visitXOR(SDNode *N) {
10149 SDValue N0 = N->getOperand(0);
10150 SDValue N1 = N->getOperand(1);
10151 EVT VT = N0.getValueType();
10152 SDLoc DL(N);
10153
10154 // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
10155 if (N0.isUndef() && N1.isUndef())
10156 return DAG.getConstant(0, DL, VT);
10157
10158 // fold (xor x, undef) -> undef
10159 if (N0.isUndef())
10160 return N0;
10161 if (N1.isUndef())
10162 return N1;
10163
10164 // fold (xor c1, c2) -> c1^c2
10165 if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
10166 return C;
10167
10168 // canonicalize constant to RHS
10171 return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
10172
10173 // fold vector ops
10174 if (VT.isVector()) {
10175 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10176 return FoldedVOp;
10177
10178 // fold (xor x, 0) -> x, vector edition
10180 return N0;
10181 }
10182
10183 // fold (xor x, 0) -> x
10184 if (isNullConstant(N1))
10185 return N0;
10186
10187 if (SDValue NewSel = foldBinOpIntoSelect(N))
10188 return NewSel;
10189
10190 // reassociate xor
10191 if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
10192 return RXOR;
10193
10194 // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
10195 if (SDValue SD =
10196 reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1))
10197 return SD;
10198
10199 // fold (a^b) -> (a|b) iff a and b share no bits.
10200 if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
10201 DAG.haveNoCommonBitsSet(N0, N1))
10202 return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);
10203
10204 // look for 'add-like' folds:
10205 // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
10206 if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
10208 if (SDValue Combined = visitADDLike(N))
10209 return Combined;
10210
10211 // fold not (setcc x, y, cc) -> setcc x y !cc
10212 // Avoid breaking: and (not(setcc x, y, cc), z) -> andn for vec
10213 unsigned N0Opcode = N0.getOpcode();
10214 SDValue LHS, RHS, CC;
10215 if (TLI.isConstTrueVal(N1) &&
10216 isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true) &&
10217 !(VT.isVector() && TLI.hasAndNot(SDValue(N, 0)) && N->hasOneUse() &&
10218 N->use_begin()->getUser()->getOpcode() == ISD::AND)) {
10220 LHS.getValueType());
10221 if (!LegalOperations ||
10222 TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
10223 switch (N0Opcode) {
10224 default:
10225 llvm_unreachable("Unhandled SetCC Equivalent!");
10226 case ISD::SETCC:
10227 return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
10228 case ISD::SELECT_CC:
10229 return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
10230 N0.getOperand(3), NotCC);
10231 case ISD::STRICT_FSETCC:
10232 case ISD::STRICT_FSETCCS: {
10233 if (N0.hasOneUse()) {
10234 // FIXME Can we handle multiple uses? Could we token factor the chain
10235 // results from the new/old setcc?
10236 SDValue SetCC =
10237 DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
10238 N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
10239 CombineTo(N, SetCC);
10240 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
10241 recursivelyDeleteUnusedNodes(N0.getNode());
10242 return SDValue(N, 0); // Return N so it doesn't get rechecked!
10243 }
10244 break;
10245 }
10246 }
10247 }
10248 }
10249
10250 // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
10251 if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
10252 isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
10253 SDValue V = N0.getOperand(0);
10254 SDLoc DL0(N0);
10255 V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
10256 DAG.getConstant(1, DL0, V.getValueType()));
10257 AddToWorklist(V.getNode());
10258 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
10259 }
10260
10261 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
10262 // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are setcc
10263 if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
10264 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
10265 SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
10266 if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
10267 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
10268 N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
10269 N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
10270 AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
10271 return DAG.getNode(NewOpcode, DL, VT, N00, N01);
10272 }
10273 }
10274 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
10275 // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are constants
10276 if (isAllOnesConstant(N1) && N0.hasOneUse() &&
10277 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
10278 SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
10279 if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
10280 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
10281 N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
10282 N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
10283 AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
10284 return DAG.getNode(NewOpcode, DL, VT, N00, N01);
10285 }
10286 }
10287
10288 // fold (not (sub Y, X)) -> (add X, ~Y) if Y is a constant
10289 if (N0.getOpcode() == ISD::SUB && isAllOnesConstant(N1)) {
10290 SDValue Y = N0.getOperand(0);
10291 SDValue X = N0.getOperand(1);
10292
10293 if (auto *YConst = dyn_cast<ConstantSDNode>(Y)) {
10294 APInt NotYValue = ~YConst->getAPIntValue();
10295 SDValue NotY = DAG.getConstant(NotYValue, DL, VT);
10296 return DAG.getNode(ISD::ADD, DL, VT, X, NotY, N->getFlags());
10297 }
10298 }
10299
10300 // fold (not (add X, -1)) -> (neg X)
10301 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() && isAllOnesConstant(N1) &&
10303 return DAG.getNegative(N0.getOperand(0), DL, VT);
10304 }
10305
10306 // fold (xor (and x, y), y) -> (and (not x), y)
10307 if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
10308 SDValue X = N0.getOperand(0);
10309 SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
10310 AddToWorklist(NotX.getNode());
10311 return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
10312 }
10313
10314 // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
10315 if (!LegalOperations || hasOperation(ISD::ABS, VT)) {
10316 SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
10317 SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
10318 if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
10319 SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
10320 SDValue S0 = S.getOperand(0);
10321 if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
10322 if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
10323 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
10324 return DAG.getNode(ISD::ABS, DL, VT, S0);
10325 }
10326 }
10327
10328 // fold (xor x, x) -> 0
10329 if (N0 == N1)
10330 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
10331
10332 // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
10333 // Here is a concrete example of this equivalence:
10334 // i16 x == 14
10335 // i16 shl == 1 << 14 == 16384 == 0b0100000000000000
10336 // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
10337 //
10338 // =>
10339 //
10340 // i16 ~1 == 0b1111111111111110
10341 // i16 rol(~1, 14) == 0b1011111111111111
10342 //
10343 // Some additional tips to help conceptualize this transform:
10344 // - Try to see the operation as placing a single zero in a value of all ones.
10345 // - There exists no value for x which would allow the result to contain zero.
10346 // - Values of x larger than the bitwidth are undefined and do not require a
10347 // consistent result.
10348 // - Pushing the zero left requires shifting one bits in from the right.
10349 // A rotate left of ~1 is a nice way of achieving the desired result.
10350 if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
10352 return DAG.getNode(ISD::ROTL, DL, VT, DAG.getSignedConstant(~1, DL, VT),
10353 N0.getOperand(1));
10354 }
10355
10356 // Simplify: xor (op x...), (op y...) -> (op (xor x, y))
10357 if (N0Opcode == N1.getOpcode())
10358 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
10359 return V;
10360
10361 if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
10362 return R;
10363 if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
10364 return R;
10365 if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
10366 return R;
10367
10368 // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
10369 if (SDValue MM = unfoldMaskedMerge(N))
10370 return MM;
10371
10372 // Simplify the expression using non-local knowledge.
10374 return SDValue(N, 0);
10375
10376 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
10377 return Combined;
10378
10379 // fold (xor (smin(x, C), C)) -> select (x < C), xor(x, C), 0
10380 // fold (xor (smax(x, C), C)) -> select (x > C), xor(x, C), 0
10381 // fold (xor (umin(x, C), C)) -> select (x < C), xor(x, C), 0
10382 // fold (xor (umax(x, C), C)) -> select (x > C), xor(x, C), 0
10383 SDValue Op0;
10384 if (sd_match(N0, m_OneUse(m_AnyOf(m_SMin(m_Value(Op0), m_Specific(N1)),
10385 m_SMax(m_Value(Op0), m_Specific(N1)),
10386 m_UMin(m_Value(Op0), m_Specific(N1)),
10387 m_UMax(m_Value(Op0), m_Specific(N1)))))) {
10388
10389 if (isa<ConstantSDNode>(N1) ||
10391 // For vectors, only optimize when the constant is zero or all-ones to
10392 // avoid generating more instructions
10393 if (VT.isVector()) {
10394 ConstantSDNode *N1C = isConstOrConstSplat(N1);
10395 if (!N1C || (!N1C->isZero() && !N1C->isAllOnes()))
10396 return SDValue();
10397 }
10398
10399 // Avoid the fold if the minmax operation is legal and select is expensive
10400 if (TLI.isOperationLegal(N0.getOpcode(), VT) &&
10402 return SDValue();
10403
10404 EVT CCVT = getSetCCResultType(VT);
10405 ISD::CondCode CC;
10406 switch (N0.getOpcode()) {
10407 case ISD::SMIN:
10408 CC = ISD::SETLT;
10409 break;
10410 case ISD::SMAX:
10411 CC = ISD::SETGT;
10412 break;
10413 case ISD::UMIN:
10414 CC = ISD::SETULT;
10415 break;
10416 case ISD::UMAX:
10417 CC = ISD::SETUGT;
10418 break;
10419 }
10420 SDValue FN1 = DAG.getFreeze(N1);
10421 SDValue Cmp = DAG.getSetCC(DL, CCVT, Op0, FN1, CC);
10422 SDValue XorXC = DAG.getNode(ISD::XOR, DL, VT, Op0, FN1);
10423 SDValue Zero = DAG.getConstant(0, DL, VT);
10424 return DAG.getSelect(DL, VT, Cmp, XorXC, Zero);
10425 }
10426 }
10427
10428 return SDValue();
10429}
10430
10431/// If we have a shift-by-constant of a bitwise logic op that itself has a
10432/// shift-by-constant operand with identical opcode, we may be able to convert
10433/// that into 2 independent shifts followed by the logic op. This is a
10434/// throughput improvement.
/// NOTE(review): the defining signature line (source line 10435, presumably
/// 'static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {')
/// is missing from this extraction — confirm against the upstream file.
10436  // Match a one-use bitwise logic op.
10437  SDValue LogicOp = Shift->getOperand(0);
10438  if (!LogicOp.hasOneUse())
10439    return SDValue();
10440
10441  unsigned LogicOpcode = LogicOp.getOpcode();
10442  if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
10443      LogicOpcode != ISD::XOR)
10444    return SDValue();
10445
10446  // Find a matching one-use shift by constant.
10447  unsigned ShiftOpcode = Shift->getOpcode();
10448  SDValue C1 = Shift->getOperand(1);
10449  ConstantSDNode *C1Node = isConstOrConstSplat(C1);
10450  assert(C1Node && "Expected a shift with constant operand");
10451  const APInt &C1Val = C1Node->getAPIntValue();
  // Predicate: true iff V is a one-use '(ShiftOpcode X, C0)' whose constant C0
  // has the same bit width as C1 and where C1 + C0 neither overflows the shift
  // amount type nor reaches the operand's bit width. On success, captures the
  // shifted operand in ShiftOp and the amount in ShiftAmtVal.
10452  auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
10453                             const APInt *&ShiftAmtVal) {
10454    if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
10455      return false;
10456
10457    ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
10458    if (!ShiftCNode)
10459      return false;
10460
10461    // Capture the shifted operand and shift amount value.
10462    ShiftOp = V.getOperand(0);
10463    ShiftAmtVal = &ShiftCNode->getAPIntValue();
10464
10465    // Shift amount types do not have to match their operand type, so check that
10466    // the constants are the same width.
10467    if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
10468      return false;
10469
10470    // The fold is not valid if the sum of the shift values doesn't fit in the
10471    // given shift amount type.
10472    bool Overflow = false;
10473    APInt NewShiftAmt = C1Val.uadd_ov(*ShiftAmtVal, Overflow);
10474    if (Overflow)
10475      return false;
10476
10477    // The fold is not valid if the sum of the shift values exceeds bitwidth.
10478    if (NewShiftAmt.uge(V.getScalarValueSizeInBits()))
10479      return false;
10480
10481    return true;
10482  };
10483
10484  // Logic ops are commutative, so check each operand for a match.
10485  SDValue X, Y;
10486  const APInt *C0Val;
10487  if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
10488    Y = LogicOp.getOperand(1);
10489  else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
10490    Y = LogicOp.getOperand(0);
10491  else
10492    return SDValue();
10493
10494  // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
10495  SDLoc DL(Shift);
10496  EVT VT = Shift->getValueType(0);
10497  EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
10498  SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
10499  SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
10500  SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
  // Preserve the logic op's flags on the rebuilt node.
10501  return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2,
10502                     LogicOp->getFlags());
10503}
10504
10505/// Handle transforms common to the three shifts, when the shift amount is a
10506/// constant.
10507/// We are looking for: (shift being one of shl/sra/srl)
10508/// shift (binop X, C0), C1
10509/// And want to transform into:
10510/// binop (shift X, C1), (shift C0, C1)
10511SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
10512 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
10513
10514 // Do not turn a 'not' into a regular xor.
10515 if (isBitwiseNot(N->getOperand(0)))
10516 return SDValue();
10517
10518 // The inner binop must be one-use, since we want to replace it.
10519 SDValue LHS = N->getOperand(0);
10520 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
10521 return SDValue();
10522
10523 // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
10524 if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
10525 return R;
10526
10527 // We want to pull some binops through shifts, so that we have (and (shift))
10528 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
10529 // thing happens with address calculations, so it's important to canonicalize
10530 // it.
10531 switch (LHS.getOpcode()) {
10532 default:
10533 return SDValue();
10534 case ISD::OR:
10535 case ISD::XOR:
10536 case ISD::AND:
10537 break;
10538 case ISD::ADD:
10539 if (N->getOpcode() != ISD::SHL)
10540 return SDValue(); // only shl(add) not sr[al](add).
10541 break;
10542 }
10543
10544 // FIXME: disable this unless the input to the binop is a shift by a constant
10545 // or is copy/select. Enable this in other cases when figure out it's exactly
10546 // profitable.
10547 SDValue BinOpLHSVal = LHS.getOperand(0);
10548 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
10549 BinOpLHSVal.getOpcode() == ISD::SRA ||
10550 BinOpLHSVal.getOpcode() == ISD::SRL) &&
10551 isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
10552 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
10553 BinOpLHSVal.getOpcode() == ISD::SELECT;
10554
10555 if (!IsShiftByConstant && !IsCopyOrSelect)
10556 return SDValue();
10557
10558 if (IsCopyOrSelect && N->hasOneUse())
10559 return SDValue();
10560
10561 // Attempt to fold the constants, shifting the binop RHS by the shift amount.
10562 SDLoc DL(N);
10563 EVT VT = N->getValueType(0);
10564 if (SDValue NewRHS = DAG.FoldConstantArithmetic(
10565 N->getOpcode(), DL, VT, {LHS.getOperand(1), N->getOperand(1)})) {
10566 SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
10567 N->getOperand(1));
10568 return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
10569 }
10570
10571 return SDValue();
10572}
10573
10574SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
10575 assert(N->getOpcode() == ISD::TRUNCATE);
10576 assert(N->getOperand(0).getOpcode() == ISD::AND);
10577
10578 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
10579 EVT TruncVT = N->getValueType(0);
10580 if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
10581 TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
10582 SDValue N01 = N->getOperand(0).getOperand(1);
10583 if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
10584 SDLoc DL(N);
10585 SDValue N00 = N->getOperand(0).getOperand(0);
10586 SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
10587 SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
10588 AddToWorklist(Trunc00.getNode());
10589 AddToWorklist(Trunc01.getNode());
10590 return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
10591 }
10592 }
10593
10594 return SDValue();
10595}
10596
/// Combine patterns rooted at a rotate (ISD::ROTL / ISD::ROTR) node:
/// identity rotates, amount normalization modulo the bit width, rot->bswap,
/// distributing truncate through 'and' on the amount, and merging nested
/// rotates.
10597SDValue DAGCombiner::visitRotate(SDNode *N) {
10598  SDLoc dl(N);
10599  SDValue N0 = N->getOperand(0);
10600  SDValue N1 = N->getOperand(1);
10601  EVT VT = N->getValueType(0);
10602  unsigned Bitsize = VT.getScalarSizeInBits();
10603
10604  // fold (rot x, 0) -> x
10605  if (isNullOrNullSplat(N1))
10606    return N0;
10607
10608  // fold (rot x, c) -> x iff (c % BitSize) == 0
10609  if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
    // For power-of-2 bit widths, 'c % Bitsize == 0' is equivalent to the low
    // log2(Bitsize) bits of the amount all being known zero.
10610    APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
10611    if (DAG.MaskedValueIsZero(N1, ModuloMask))
10612      return N0;
10613  }
10614
10615  // fold (rot x, c) -> (rot x, c % BitSize)
10616  bool OutOfRange = false;
10617  auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
10618    OutOfRange |= C->getAPIntValue().uge(Bitsize);
10619    return true;
10620  };
10621  if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
10622    EVT AmtVT = N1.getValueType();
10623    SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
10624    if (SDValue Amt =
10625            DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
10626      return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
10627  }
10628
10629  // rot i16 X, 8 --> bswap X
10630  auto *RotAmtC = isConstOrConstSplat(N1);
10631  if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
10632      VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
10633    return DAG.getNode(ISD::BSWAP, dl, VT, N0);
10634
10635  // Simplify the operands using demanded-bits information.
  // NOTE(review): source line 10636 (the condition guarding this return,
  // presumably 'if (SimplifyDemandedBits(SDValue(N, 0)))') is missing from
  // this extraction — confirm against the upstream file.
10637    return SDValue(N, 0);
10638
10639  // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
10640  if (N1.getOpcode() == ISD::TRUNCATE &&
10641      N1.getOperand(0).getOpcode() == ISD::AND) {
10642    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10643      return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
10644  }
10645
10646  unsigned NextOp = N0.getOpcode();
10647
10648  // fold (rot* (rot* x, c2), c1)
10649  //   -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
10650  if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
10651    bool C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
    // NOTE(review): source line 10652 (defining C2, presumably the same
    // constant check applied to N0.getOperand(1)) is missing from this
    // extraction — confirm upstream.
10653    if (C1 && C2 && N1.getValueType() == N0.getOperand(1).getValueType()) {
10654      EVT ShiftVT = N1.getValueType();
      // Same-direction rotates add their amounts; opposite directions
      // subtract.
10655      bool SameSide = (N->getOpcode() == NextOp);
10656      unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
10657      SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
10658      SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
10659                                                 {N1, BitsizeC});
10660      SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
10661                                                 {N0.getOperand(1), BitsizeC});
10662      if (Norm1 && Norm2)
10663        if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
10664                CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
          // Add Bitsize before the final modulo so a negative difference
          // (from the SUB case) wraps into range.
10665          CombinedShift = DAG.FoldConstantArithmetic(ISD::ADD, dl, ShiftVT,
10666                                                     {CombinedShift, BitsizeC});
10667          SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
10668              ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
10669          return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
10670                             CombinedShiftNorm);
10671        }
10672    }
10673  }
10674  return SDValue();
10675}
10676
/// Combine patterns rooted at an ISD::SHL node: constant folding, vector
/// simplifications, merging of stacked/opposing shifts, pulling binops
/// through the shift, shl-of-cttz, and vscale/step_vector folds.
/// NOTE(review): several source lines are missing from this extraction (see
/// inline notes below) — the affected conditions must be confirmed upstream.
10677SDValue DAGCombiner::visitSHL(SDNode *N) {
10678  SDValue N0 = N->getOperand(0);
10679  SDValue N1 = N->getOperand(1);
10680  if (SDValue V = DAG.simplifyShift(N0, N1))
10681    return V;
10682
10683  SDLoc DL(N);
10684  EVT VT = N0.getValueType();
10685  EVT ShiftVT = N1.getValueType();
10686  unsigned OpSizeInBits = VT.getScalarSizeInBits();
10687
10688  // fold (shl c1, c2) -> c1<<c2
10689  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N0, N1}))
10690    return C;
10691
10692  // fold vector ops
10693  if (VT.isVector()) {
10694    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10695      return FoldedVOp;
10696
10697    BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
10698    // If setcc produces all-one true value then:
10699    // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
10700    if (N1CV && N1CV->isConstant()) {
10701      if (N0.getOpcode() == ISD::AND) {
10702        SDValue N00 = N0->getOperand(0);
10703        SDValue N01 = N0->getOperand(1);
10704        BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
10705
10706        if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
        // NOTE(review): source lines 10707-10708 (the remainder of this
        // condition — presumably the check that setcc booleans are
        // zero-or-all-ones for this type) are missing from this extraction.
10709          if (SDValue C =
10710                  DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N01, N1}))
10711            return DAG.getNode(ISD::AND, DL, VT, N00, C);
10712        }
10713      }
10714    }
10715  }
10716
10717  if (SDValue NewSel = foldBinOpIntoSelect(N))
10718    return NewSel;
10719
10720  // if (shl x, c) is known to be zero, return 0
10721  if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
10722    return DAG.getConstant(0, DL, VT);
10723
10724  // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
10725  if (N1.getOpcode() == ISD::TRUNCATE &&
10726      N1.getOperand(0).getOpcode() == ISD::AND) {
10727    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10728      return DAG.getNode(ISD::SHL, DL, VT, N0, NewOp1);
10729  }
10730
10731  // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
10732  if (N0.getOpcode() == ISD::SHL) {
    // Sum reaches or exceeds the bit width: everything is shifted out.
10733    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10734                                          ConstantSDNode *RHS) {
10735      APInt c1 = LHS->getAPIntValue();
10736      APInt c2 = RHS->getAPIntValue();
10737      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10738      return (c1 + c2).uge(OpSizeInBits);
10739    };
10740    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
10741      return DAG.getConstant(0, DL, VT);
10742
    // Sum stays in range: merge into a single shift.
10743    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10744                                       ConstantSDNode *RHS) {
10745      APInt c1 = LHS->getAPIntValue();
10746      APInt c2 = RHS->getAPIntValue();
10747      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10748      return (c1 + c2).ult(OpSizeInBits);
10749    };
10750    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
10751      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
10752      return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
10753    }
10754  }
10755
10756  // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
10757  // For this to be valid, the second form must not preserve any of the bits
10758  // that are shifted out by the inner shift in the first form. This means
10759  // the outer shift size must be >= the number of bits added by the ext.
10760  // As a corollary, we don't care what kind of ext it is.
10761  if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
10762       N0.getOpcode() == ISD::ANY_EXTEND ||
10763       N0.getOpcode() == ISD::SIGN_EXTEND) &&
10764      N0.getOperand(0).getOpcode() == ISD::SHL) {
10765    SDValue N0Op0 = N0.getOperand(0);
10766    SDValue InnerShiftAmt = N0Op0.getOperand(1);
10767    EVT InnerVT = N0Op0.getValueType();
10768    uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
10769
10770    auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10771                                                         ConstantSDNode *RHS) {
10772      APInt c1 = LHS->getAPIntValue();
10773      APInt c2 = RHS->getAPIntValue();
10774      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10775      return c2.uge(OpSizeInBits - InnerBitwidth) &&
10776             (c1 + c2).uge(OpSizeInBits);
10777    };
10778    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
10779                                  /*AllowUndefs*/ false,
10780                                  /*AllowTypeMismatch*/ true))
10781      return DAG.getConstant(0, DL, VT);
10782
10783    auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10784                                                      ConstantSDNode *RHS) {
10785      APInt c1 = LHS->getAPIntValue();
10786      APInt c2 = RHS->getAPIntValue();
10787      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10788      return c2.uge(OpSizeInBits - InnerBitwidth) &&
10789             (c1 + c2).ult(OpSizeInBits);
10790    };
10791    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
10792                                  /*AllowUndefs*/ false,
10793                                  /*AllowTypeMismatch*/ true)) {
10794      SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
10795      SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
10796      Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
10797      return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
10798    }
10799  }
10800
10801  // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
10802  // Only fold this if the inner zext has no other uses to avoid increasing
10803  // the total number of instructions.
10804  if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
10805      N0.getOperand(0).getOpcode() == ISD::SRL) {
10806    SDValue N0Op0 = N0.getOperand(0);
10807    SDValue InnerShiftAmt = N0Op0.getOperand(1);
10808
10809    auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10810      APInt c1 = LHS->getAPIntValue();
10811      APInt c2 = RHS->getAPIntValue();
10812      zeroExtendToMatch(c1, c2);
10813      return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
10814    };
10815    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
10816                                  /*AllowUndefs*/ false,
10817                                  /*AllowTypeMismatch*/ true)) {
10818      EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
10819      SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
10820      NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
10821      AddToWorklist(NewSHL.getNode());
10822      return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
10823    }
10824  }
10825
10826  if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
    // True iff both amounts are in range and LHS's amount <= RHS's amount.
10827    auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10828                                           ConstantSDNode *RHS) {
10829      const APInt &LHSC = LHS->getAPIntValue();
10830      const APInt &RHSC = RHS->getAPIntValue();
10831      return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
10832             LHSC.getZExtValue() <= RHSC.getZExtValue();
10833    };
10834
10835    // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
10836    // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
10837    if (N0->getFlags().hasExact()) {
10838      if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10839                                    /*AllowUndefs*/ false,
10840                                    /*AllowTypeMismatch*/ true)) {
10841        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10842        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10843        return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10844      }
10845      if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10846                                    /*AllowUndefs*/ false,
10847                                    /*AllowTypeMismatch*/ true)) {
10848        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10849        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10850        return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
10851      }
10852    }
10853
10854    // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or
10855    //                               (and (srl x, (sub c1, c2), MASK)
10856    // Only fold this if the inner shift has no other uses -- if it does,
10857    // folding this will increase the total number of instructions.
10858    if (N0.getOpcode() == ISD::SRL &&
10859        (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
      // NOTE(review): source line 10860 (the final clause of this condition —
      // presumably a TLI profitability hook such as
      // shouldFoldConstantShiftPairToMask) is missing from this extraction.
10861      if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10862                                    /*AllowUndefs*/ false,
10863                                    /*AllowTypeMismatch*/ true)) {
10864        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10865        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10866        SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10867        Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
10868        Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
10869        SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
10870        return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10871      }
10872      if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10873                                    /*AllowUndefs*/ false,
10874                                    /*AllowTypeMismatch*/ true)) {
10875        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10876        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10877        SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10878        Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
10879        SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10880        return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10881      }
10882    }
10883  }
10884
10885  // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
10886  if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
10887      isConstantOrConstantVector(N1, /* No Opaques */ true)) {
10888    SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10889    SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
10890    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
10891  }
10892
10893  // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10894  // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
10895  // Variant of version done on multiply, except mul by a power of 2 is turned
10896  // into a shift.
10897  if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10898      TLI.isDesirableToCommuteWithShift(N, Level)) {
10899    SDValue N01 = N0.getOperand(1);
10900    if (SDValue Shl1 =
10901            DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) {
10902      SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
10903      AddToWorklist(Shl0.getNode());
10904      SDNodeFlags Flags;
10905      // Preserve the disjoint flag for Or.
10906      if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
        // NOTE(review): source line 10907 (the body of this if — presumably
        // 'Flags.setDisjoint(true);') is missing from this extraction.
10908      return DAG.getNode(N0.getOpcode(), DL, VT, Shl0, Shl1, Flags);
10909    }
10910  }
10911
10912  // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10913  // TODO: Add zext/add_nuw variant with suitable test coverage
10914  // TODO: Should we limit this with isLegalAddImmediate?
10915  if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10916      N0.getOperand(0).getOpcode() == ISD::ADD &&
10917      N0.getOperand(0)->getFlags().hasNoSignedWrap() &&
10918      TLI.isDesirableToCommuteWithShift(N, Level)) {
10919    SDValue Add = N0.getOperand(0);
10920    SDLoc DL(N0);
10921    if (SDValue ExtC = DAG.FoldConstantArithmetic(N0.getOpcode(), DL, VT,
10922                                                  {Add.getOperand(1)})) {
10923      if (SDValue ShlC =
10924              DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {ExtC, N1})) {
10925        SDValue ExtX = DAG.getNode(N0.getOpcode(), DL, VT, Add.getOperand(0));
10926        SDValue ShlX = DAG.getNode(ISD::SHL, DL, VT, ExtX, N1);
10927        return DAG.getNode(ISD::ADD, DL, VT, ShlX, ShlC);
10928      }
10929    }
10930  }
10931
10932  // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10933  if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10934    SDValue N01 = N0.getOperand(1);
10935    if (SDValue Shl =
10936            DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
10937      return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), Shl);
10938  }
10939
10940  ConstantSDNode *N1C = isConstOrConstSplat(N1);
10941  if (N1C && !N1C->isOpaque())
10942    if (SDValue NewSHL = visitShiftByConstant(N))
10943      return NewSHL;
10944
10945  // fold (shl X, cttz(Y)) -> (mul (Y & -Y), X) if cttz is unsupported on the
10946  // target.
10947  if (((N1.getOpcode() == ISD::CTTZ &&
10948        VT.getScalarSizeInBits() <= ShiftVT.getScalarSizeInBits()) ||
      // NOTE(review): source lines 10949 and 10951 (further clauses of this
      // condition — likely a CTTZ-through-zext alternative and a
      // CTTZ_ZERO_UNDEF legality check) are missing from this extraction.
10950       N1.hasOneUse() && !TLI.isOperationLegalOrCustom(ISD::CTTZ, ShiftVT) &&
10952    SDValue Y = N1.getOperand(0);
10953    SDLoc DL(N);
    // Y & -Y isolates the lowest set bit, i.e. 1 << cttz(Y).
10954    SDValue NegY = DAG.getNegative(Y, DL, ShiftVT);
10955    SDValue And =
10956        DAG.getZExtOrTrunc(DAG.getNode(ISD::AND, DL, ShiftVT, Y, NegY), DL, VT);
10957    return DAG.getNode(ISD::MUL, DL, VT, And, N0);
10958  }
10959
  // NOTE(review): source line 10960 (the condition guarding this return,
  // presumably 'if (SimplifyDemandedBits(SDValue(N, 0)))') is missing from
  // this extraction.
10961    return SDValue(N, 0);
10962
10963  // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10964  if (N0.getOpcode() == ISD::VSCALE && N1C) {
10965    const APInt &C0 = N0.getConstantOperandAPInt(0);
10966    const APInt &C1 = N1C->getAPIntValue();
10967    return DAG.getVScale(DL, VT, C0 << C1);
10968  }
10969
10970  SDValue X;
10971  APInt VS0;
10972
10973  // fold (shl (X * vscale(VS0)), C1) -> (X * vscale(VS0 << C1))
10974  if (N1C && sd_match(N0, m_Mul(m_Value(X), m_VScale(m_ConstInt(VS0))))) {
10975    SDNodeFlags Flags;
    // nuw survives only if both the shift and the multiply had it.
10976    Flags.setNoUnsignedWrap(N->getFlags().hasNoUnsignedWrap() &&
10977                            N0->getFlags().hasNoUnsignedWrap());
10978
10979    SDValue VScale = DAG.getVScale(DL, VT, VS0 << N1C->getAPIntValue());
10980    return DAG.getNode(ISD::MUL, DL, VT, X, VScale, Flags);
10981  }
10982
10983  // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10984  APInt ShlVal;
10985  if (N0.getOpcode() == ISD::STEP_VECTOR &&
10986      ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
10987    const APInt &C0 = N0.getConstantOperandAPInt(0);
10988    if (ShlVal.ult(C0.getBitWidth())) {
10989      APInt NewStep = C0 << ShlVal;
10990      return DAG.getStepVector(DL, VT, NewStep);
10991    }
10992  }
10993
10994  return SDValue();
10995}
10996
10997// Transform a right shift of a multiply into a multiply-high.
10998// Examples:
10999// (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
11000// (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
// NOTE(review): the first line of the signature (source line 11001, presumably
// 'static SDValue combineShiftToMULH(SDNode *N, const SDLoc &DL,
// SelectionDAG &DAG,') is missing from this extraction — confirm upstream.
11002                                  const TargetLowering &TLI) {
11003  assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
11004         "SRL or SRA node is required here!");
11005
11006  // Check the shift amount. Proceed with the transformation if the shift
11007  // amount is constant.
11008  ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
11009  if (!ShiftAmtSrc)
11010    return SDValue();
11011
11012  // The operation feeding into the shift must be a multiply.
11013  SDValue ShiftOperand = N->getOperand(0);
11014  if (ShiftOperand.getOpcode() != ISD::MUL)
11015    return SDValue();
11016
11017  // Both operands must be equivalent extend nodes.
11018  SDValue LeftOp = ShiftOperand.getOperand(0);
11019  SDValue RightOp = ShiftOperand.getOperand(1);
11020
11021  bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
11022  bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
11023
11024  if (!IsSignExt && !IsZeroExt)
11025    return SDValue();
11026
11027  EVT NarrowVT = LeftOp.getOperand(0).getValueType();
11028  unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
11029
11030  // return true if U may use the lower bits of its operands
11031  auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
11032    if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
11033      return true;
11034    }
11035    ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(U->getOperand(1));
11036    if (!UShiftAmtSrc) {
11037      return true;
11038    }
11039    unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
11040    return UShiftAmt < NarrowVTSize;
11041  };
11042
11043  // If the lower part of the MUL is also used and MUL_LOHI is supported
11044  // do not introduce the MULH in favor of MUL_LOHI
11045  unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
11046  if (!ShiftOperand.hasOneUse() &&
11047      TLI.isOperationLegalOrCustom(MulLoHiOp, NarrowVT) &&
11048      llvm::any_of(ShiftOperand->users(), UserOfLowerBits)) {
11049    return SDValue();
11050  }
11051
11052  SDValue MulhRightOp;
  // NOTE(review): source line 11053 (the condition opening this branch —
  // presumably 'if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {')
  // is missing from this extraction; 'Constant' below is bound by it.
    // Constant RHS: it must fit in the narrow type after truncation.
11054    unsigned ActiveBits = IsSignExt
11055                              ? Constant->getAPIntValue().getSignificantBits()
11056                              : Constant->getAPIntValue().getActiveBits();
11057    if (ActiveBits > NarrowVTSize)
11058      return SDValue();
11059    MulhRightOp = DAG.getConstant(
11060        Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
11061        NarrowVT);
11062  } else {
11063    if (LeftOp.getOpcode() != RightOp.getOpcode())
11064      return SDValue();
11065    // Check that the two extend nodes are the same type.
11066    if (NarrowVT != RightOp.getOperand(0).getValueType())
11067      return SDValue();
11068    MulhRightOp = RightOp.getOperand(0);
11069  }
11070
11071  EVT WideVT = LeftOp.getValueType();
11072  // Proceed with the transformation if the wide types match.
11073  assert((WideVT == RightOp.getValueType()) &&
11074         "Cannot have a multiply node with two different operand types.");
11075
11076  // Proceed with the transformation if the wide type is twice as large
11077  // as the narrow type.
11078  if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
11079    return SDValue();
11080
11081  // Check the shift amount with the narrow type size.
11082  // Proceed with the transformation if the shift amount is the width
11083  // of the narrow type.
11084  unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
11085  if (ShiftAmt != NarrowVTSize)
11086    return SDValue();
11087
11088  // If the operation feeding into the MUL is a sign extend (sext),
11089  // we use mulhs. Otherwise, zero extends (zext) use mulhu.
11090  unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
11091
11092  // Combine to mulh if mulh is legal/custom for the narrow type on the target
11093  // or if it is a vector type then we could transform to an acceptable type and
11094  // rely on legalization to split/combine the result.
11095  EVT TransformVT = NarrowVT;
11096  if (NarrowVT.isVector()) {
11097    TransformVT = TLI.getLegalTypeToTransformTo(*DAG.getContext(), NarrowVT);
11098    if (TransformVT.getScalarType() != NarrowVT.getScalarType())
11099      return SDValue();
11100  }
11101  if (!TLI.isOperationLegalOrCustom(MulhOpcode, TransformVT))
11102    return SDValue();
11103
11104  SDValue Result =
11105      DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
  // Re-extend the narrow result back to the shift's wide type.
11106  bool IsSigned = N->getOpcode() == ISD::SRA;
11107  return DAG.getExtOrTrunc(IsSigned, Result, DL, WideVT);
11108}
11109
11110// fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
11111// This helper accepts SDNodes with opcode ISD::BSWAP or ISD::BITREVERSE.
// NOTE(review): the defining signature line (source line 11112) is missing
// from this extraction — confirm against the upstream file.
11113  unsigned Opcode = N->getOpcode();
11114  if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
11115    return SDValue();
11116
11117  SDValue N0 = N->getOperand(0);
11118  EVT VT = N->getValueType(0);
11119  SDLoc DL(N);
11120  SDValue X, Y;
11121
11122  // If both operands are bswap/bitreverse, ignore the multiuse
  // NOTE(review): source line 11123 (the sd_match call opening this condition,
  // matching a bitwise logic op of two same-opcode unary ops) is missing from
  // this extraction.
11124                                  m_UnaryOp(Opcode, m_Value(Y))))))
11125    return DAG.getNode(N0.getOpcode(), DL, VT, X, Y);
11126
11127  // Otherwise need to ensure logic_op and bswap/bitreverse(x) have one use.
  // NOTE(review): source line 11128 (the sd_match call opening this
  // condition) is missing from this extraction.
11129                   m_OneUse(m_UnaryOp(Opcode, m_Value(X))), m_Value(Y))))) {
11130    SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Y);
11131    return DAG.getNode(N0.getOpcode(), DL, VT, X, NewBitReorder);
11132  }
11133
11134  return SDValue();
11135}
11136
/// Combine step for ISD::SRA (arithmetic shift right). Attempts a sequence of
/// folds and returns the replacement SDValue on success, or an empty SDValue
/// when no fold applies.
/// NOTE(review): this file appears to be a doxygen scrape; a few interior
/// condition lines were dropped by the extraction (see NOTEs below) — confirm
/// against the upstream LLVM sources before editing.
11137SDValue DAGCombiner::visitSRA(SDNode *N) {
11138  SDValue N0 = N->getOperand(0);
11139  SDValue N1 = N->getOperand(1);
11140  if (SDValue V = DAG.simplifyShift(N0, N1))
11141    return V;
11142
11143  SDLoc DL(N);
11144  EVT VT = N0.getValueType();
11145  unsigned OpSizeInBits = VT.getScalarSizeInBits();
11146
11147  // fold (sra c1, c2) -> c1 >>s c2 (constant fold)
11148  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, DL, VT, {N0, N1}))
11149    return C;
11150
11151  // Arithmetic shifting an all-sign-bit value is a no-op.
11152  // fold (sra 0, x) -> 0
11153  // fold (sra -1, x) -> -1
11154  if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
11155    return N0;
11156
11157  // fold vector ops
11158  if (VT.isVector())
11159    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
11160      return FoldedVOp;
11161
11162  if (SDValue NewSel = foldBinOpIntoSelect(N))
11163    return NewSel;
11164
11165  ConstantSDNode *N1C = isConstOrConstSplat(N1);
11166
11167  // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
11168  // clamp (add c1, c2) to max shift.
11169  if (N0.getOpcode() == ISD::SRA) {
11170    EVT ShiftVT = N1.getValueType();
11171    EVT ShiftSVT = ShiftVT.getScalarType();
11172    SmallVector<SDValue, 16> ShiftValues;
11173
       // Per-element callback: compute the clamped sum of the two shift
       // amounts and collect it into ShiftValues.
11174    auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
11175      APInt c1 = LHS->getAPIntValue();
11176      APInt c2 = RHS->getAPIntValue();
       // Widen both by one bit so the addition below cannot overflow.
11177      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
11178      APInt Sum = c1 + c2;
11179      unsigned ShiftSum =
11180          Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
11181      ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
11182      return true;
11183    };
11184    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
11185      SDValue ShiftValue;
11186      if (N1.getOpcode() == ISD::BUILD_VECTOR)
11187        ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
11188      else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
11189        assert(ShiftValues.size() == 1 &&
11190               "Expected matchBinaryPredicate to return one element for "
11191               "SPLAT_VECTORs");
11192        ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
11193      } else
11194        ShiftValue = ShiftValues[0];
11195      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
11196    }
11197  }
11198
11199  // fold (sra (xor (sra x, c1), -1), c2) -> (xor (sra x, c3), -1)
11200  // This allows merging two arithmetic shifts even when there's a NOT in
11201  // between.
11202  SDValue X;
11203  APInt C1;
11204  if (N1C && sd_match(N0, m_OneUse(m_Not(
11205                              m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1))))))) {
11206    APInt C2 = N1C->getAPIntValue();
11207    zeroExtendToMatch(C1, C2, 1 /* Overflow Bit */);
11208    APInt Sum = C1 + C2;
       // Clamp the combined shift amount to the maximum meaningful shift.
11209    unsigned ShiftSum = Sum.getLimitedValue(OpSizeInBits - 1);
11210    SDValue NewShift = DAG.getNode(
11211        ISD::SRA, DL, VT, X, DAG.getShiftAmountConstant(ShiftSum, VT, DL));
11212    return DAG.getNOT(DL, NewShift, VT);
11213  }
11214
11215  // fold (sra (shl X, m), (sub result_size, n))
11216  // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
11217  // result_size - n != m.
11218  // If truncate is free for the target sext(shl) is likely to result in better
11219  // code.
11220  if (N0.getOpcode() == ISD::SHL && N1C) {
11221    // Get the two constants of the shifts, CN0 = m, CN = n.
11222    const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
11223    if (N01C) {
11224      LLVMContext &Ctx = *DAG.getContext();
11225      // Determine what the truncate's result bitsize and type would be.
11226      EVT TruncVT = VT.changeElementType(
11227          Ctx, EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue()));
11228
11229      // Determine the residual right-shift amount.
11230      int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
11231
11232      // If the shift is not a no-op (in which case this should be just a sign
11233      // extend already), the truncated to type is legal, sign_extend is legal
11234      // on that type, and the truncate to that type is both legal and free,
11235      // perform the transform.
      // NOTE(review): condition lines 11237-11238 were dropped by the
      // extraction (presumably legality checks on SIGN_EXTEND/TRUNCATE for
      // TruncVT) — confirm against upstream before relying on this snippet.
11236      if ((ShiftAmt > 0) &&
11239          TLI.isTruncateFree(VT, TruncVT)) {
11240        SDValue Amt = DAG.getShiftAmountConstant(ShiftAmt, VT, DL);
11241        SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
11242                                    N0.getOperand(0), Amt);
11243        SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
11244                                    Shift);
11245        return DAG.getNode(ISD::SIGN_EXTEND, DL,
11246                           N->getValueType(0), Trunc);
11247      }
11248    }
11249  }
11250
11251  // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
11252  //   sra (add (shl X, N1C), AddC), N1C -->
11253  //   sext (add (trunc X to (width - N1C)), AddC')
11254  //   sra (sub AddC, (shl X, N1C)), N1C -->
11255  //   sext (sub AddC1',(trunc X to (width - N1C)))
11256  if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
11257      N0.hasOneUse()) {
11258    bool IsAdd = N0.getOpcode() == ISD::ADD;
       // For SUB the shl is the subtrahend (operand 1); for ADD it is
       // operand 0.
11259    SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
11260    if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
11261        Shl.hasOneUse()) {
11262      // TODO: AddC does not need to be a splat.
11263      if (ConstantSDNode *AddC =
11264              isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
11265        // Determine what the truncate's type would be and ask the target if
11266        // that is a free operation.
11267        LLVMContext &Ctx = *DAG.getContext();
11268        unsigned ShiftAmt = N1C->getZExtValue();
11269        EVT TruncVT = VT.changeElementType(
11270            Ctx, EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt));
11271
11272        // TODO: The simple type check probably belongs in the default hook
11273        // implementation and/or target-specific overrides (because
11274        // non-simple types likely require masking when legalized), but
11275        // that restriction may conflict with other transforms.
11276        if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
11277            TLI.isTruncateFree(VT, TruncVT)) {
11278          SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
             // Pre-shift the constant so the add/sub happens in the narrow
             // type.
11279          SDValue ShiftC =
11280              DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
11281                                  TruncVT.getScalarSizeInBits()),
11282                              DL, TruncVT);
11283          SDValue Add;
11284          if (IsAdd)
11285            Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
11286          else
11287            Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
11288          return DAG.getSExtOrTrunc(Add, DL, VT);
11289        }
11290      }
11291    }
11292  }
11293
11294  // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
11295  if (N1.getOpcode() == ISD::TRUNCATE &&
11296      N1.getOperand(0).getOpcode() == ISD::AND) {
11297    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
11298      return DAG.getNode(ISD::SRA, DL, VT, N0, NewOp1);
11299  }
11300
11301  // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
11302  // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
11303  // if c1 is equal to the number of bits the trunc removes
11304  // TODO - support non-uniform vector shift amounts.
11305  if (N0.getOpcode() == ISD::TRUNCATE &&
11306      (N0.getOperand(0).getOpcode() == ISD::SRL ||
11307       N0.getOperand(0).getOpcode() == ISD::SRA) &&
11308      N0.getOperand(0).hasOneUse() &&
11309      N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
11310    SDValue N0Op0 = N0.getOperand(0);
11311    if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
11312      EVT LargeVT = N0Op0.getValueType();
11313      unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
11314      if (LargeShift->getAPIntValue() == TruncBits) {
11315        EVT LargeShiftVT = getShiftAmountTy(LargeVT);
11316        SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
11317        Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
11318                          DAG.getConstant(TruncBits, DL, LargeShiftVT));
11319        SDValue SRA =
11320            DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
11321        return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
11322      }
11323    }
11324  }
11325
11326  // Simplify, based on bits shifted out of the LHS.
  // NOTE(review): the guarding condition line 11327 was dropped by the
  // extraction (presumably a SimplifyDemandedBits(SDValue(N, 0)) call) —
  // confirm against upstream.
11328    return SDValue(N, 0);
11329
11330  // If the sign bit is known to be zero, switch this to a SRL.
11331  if (DAG.SignBitIsZero(N0))
11332    return DAG.getNode(ISD::SRL, DL, VT, N0, N1);
11333
11334  if (N1C && !N1C->isOpaque())
11335    if (SDValue NewSRA = visitShiftByConstant(N))
11336      return NewSRA;
11337
11338  // Try to transform this shift into a multiply-high if
11339  // it matches the appropriate pattern detected in combineShiftToMULH.
11340  if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
11341    return MULH;
11342
11343  // Attempt to convert a sra of a load into a narrower sign-extending load.
11344  if (SDValue NarrowLoad = reduceLoadWidth(N))
11345    return NarrowLoad;
11346
11347  if (SDValue AVG = foldShiftToAvg(N, DL))
11348    return AVG;
11349
11350  return SDValue();
11351}
11352
/// Combine step for ISD::SRL (logical shift right). Attempts a sequence of
/// folds and returns the replacement SDValue on success, or an empty SDValue
/// when no fold applies.
/// NOTE(review): several interior lines were dropped by the doxygen
/// extraction (see NOTEs below) — confirm against upstream before editing.
11353SDValue DAGCombiner::visitSRL(SDNode *N) {
11354  SDValue N0 = N->getOperand(0);
11355  SDValue N1 = N->getOperand(1);
11356  if (SDValue V = DAG.simplifyShift(N0, N1))
11357    return V;
11358
11359  SDLoc DL(N);
11360  EVT VT = N0.getValueType();
11361  EVT ShiftVT = N1.getValueType();
11362  unsigned OpSizeInBits = VT.getScalarSizeInBits();
11363
11364  // fold (srl c1, c2) -> c1 >>u c2
11365  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, DL, VT, {N0, N1}))
11366    return C;
11367
11368  // fold vector ops
11369  if (VT.isVector())
11370    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
11371      return FoldedVOp;
11372
11373  if (SDValue NewSel = foldBinOpIntoSelect(N))
11374    return NewSel;
11375
11376  // if (srl x, c) is known to be zero, return 0
11377  ConstantSDNode *N1C = isConstOrConstSplat(N1);
11378  if (N1C &&
11379      DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
11380    return DAG.getConstant(0, DL, VT);
11381
11382  // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
11383  if (N0.getOpcode() == ISD::SRL) {
       // The combined shift amount shifts out everything -> constant zero.
11384    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
11385                                          ConstantSDNode *RHS) {
11386      APInt c1 = LHS->getAPIntValue();
11387      APInt c2 = RHS->getAPIntValue();
       // Widen both by one bit so the addition cannot overflow.
11388      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
11389      return (c1 + c2).uge(OpSizeInBits);
11390    };
11391    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
11392      return DAG.getConstant(0, DL, VT);
11393
       // The combined shift amount stays in range -> merge into one shift.
11394    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
11395                                       ConstantSDNode *RHS) {
11396      APInt c1 = LHS->getAPIntValue();
11397      APInt c2 = RHS->getAPIntValue();
11398      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
11399      return (c1 + c2).ult(OpSizeInBits);
11400    };
11401    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
11402      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
11403      return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
11404    }
11405  }
11406
11407  if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
11408      N0.getOperand(0).getOpcode() == ISD::SRL) {
11409    SDValue InnerShift = N0.getOperand(0);
11410    // TODO - support non-uniform vector shift amounts.
11411    if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
11412      uint64_t c1 = N001C->getZExtValue();
11413      uint64_t c2 = N1C->getZExtValue();
11414      EVT InnerShiftVT = InnerShift.getValueType();
11415      EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
11416      uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
11417      // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
11418      // This is only valid if the OpSizeInBits + c1 = size of inner shift.
11419      if (c1 + OpSizeInBits == InnerShiftSize) {
11420        if (c1 + c2 >= InnerShiftSize)
11421          return DAG.getConstant(0, DL, VT);
11422        SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
11423        SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
11424                                       InnerShift.getOperand(0), NewShiftAmt);
11425        return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
11426      }
11427      // In the more general case, we can clear the high bits after the shift:
11428      // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
11429      if (N0.hasOneUse() && InnerShift.hasOneUse() &&
11430          c1 + c2 < InnerShiftSize) {
11431        SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
11432        SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
11433                                       InnerShift.getOperand(0), NewShiftAmt);
11434        SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
11435                                                            OpSizeInBits - c2),
11436                                       DL, InnerShiftVT);
11437        SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
11438        return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
11439      }
11440    }
11441  }
11442
11443  if (N0.getOpcode() == ISD::SHL) {
11444    // fold (srl (shl nuw x, c), c) -> x
11445    if (N0.getOperand(1) == N1 && N0->getFlags().hasNoUnsignedWrap())
11446      return N0.getOperand(0);
11447
11448    // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2), MASK) or
11449    //                               (and (srl x, (sub c2, c1), MASK)
    // NOTE(review): continuation line 11451 of this condition was dropped by
    // the extraction — confirm against upstream.
11450    if ((N0.getOperand(1) == N1 || N0->hasOneUse()) &&
11452      auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
11453                                             ConstantSDNode *RHS) {
11454        const APInt &LHSC = LHS->getAPIntValue();
11455        const APInt &RHSC = RHS->getAPIntValue();
11456        return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
11457               LHSC.getZExtValue() <= RHSC.getZExtValue();
11458      };
       // Case c2 <= c1: residual left shift plus mask.
11459      if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
11460                                    /*AllowUndefs*/ false,
11461                                    /*AllowTypeMismatch*/ true)) {
11462        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
11463        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
11464        SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11465        Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
11466        Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
11467        SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
11468        return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
11469      }
       // Case c1 <= c2: residual right shift plus mask.
11470      if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
11471                                    /*AllowUndefs*/ false,
11472                                    /*AllowTypeMismatch*/ true)) {
11473        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
11474        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
11475        SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11476        Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
11477        SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
11478        return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
11479      }
11480    }
11481  }
11482
11483  // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
11484  // TODO - support non-uniform vector shift amounts.
11485  if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
11486    // Shifting in all undef bits?
11487    EVT SmallVT = N0.getOperand(0).getValueType();
11488    unsigned BitSize = SmallVT.getScalarSizeInBits();
11489    if (N1C->getAPIntValue().uge(BitSize))
11490      return DAG.getUNDEF(VT);
11491
11492    if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
11493      uint64_t ShiftAmt = N1C->getZExtValue();
11494      SDLoc DL0(N0);
11495      SDValue SmallShift =
11496          DAG.getNode(ISD::SRL, DL0, SmallVT, N0.getOperand(0),
11497                      DAG.getShiftAmountConstant(ShiftAmt, SmallVT, DL0));
11498      AddToWorklist(SmallShift.getNode());
         // Mask off the (undef) bits that were shifted in by the extend.
11499      APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
11500      return DAG.getNode(ISD::AND, DL, VT,
11501                         DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
11502                         DAG.getConstant(Mask, DL, VT));
11503    }
11504  }
11505
11506  // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
11507  // bit, which is unmodified by sra.
11508  if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
11509    if (N0.getOpcode() == ISD::SRA)
11510      return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), N1);
11511  }
11512
11513  // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and x has a power
11514  // of two bitwidth. The "5" represents (log2 (bitwidth x)).
11515  if (N1C && N0.getOpcode() == ISD::CTLZ &&
11516      isPowerOf2_32(OpSizeInBits) &&
11517      N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
11518    KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
11519
11520    // If any of the input bits are KnownOne, then the input couldn't be all
11521    // zeros, thus the result of the srl will always be zero.
11522    if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
11523
11524    // If all of the bits input the to ctlz node are known to be zero, then
11525    // the result of the ctlz is "32" and the result of the shift is one.
11526    APInt UnknownBits = ~Known.Zero;
11527    if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
11528
11529    // Otherwise, check to see if there is exactly one bit input to the ctlz.
11530    if (UnknownBits.isPowerOf2()) {
11531      // Okay, we know that only that the single bit specified by UnknownBits
11532      // could be set on input to the CTLZ node. If this bit is set, the SRL
11533      // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
11534      // to an SRL/XOR pair, which is likely to simplify more.
11535      unsigned ShAmt = UnknownBits.countr_zero();
11536      SDValue Op = N0.getOperand(0);
11537
11538      if (ShAmt) {
11539        SDLoc DL(N0);
11540        Op = DAG.getNode(ISD::SRL, DL, VT, Op,
11541                         DAG.getShiftAmountConstant(ShAmt, VT, DL));
11542        AddToWorklist(Op.getNode());
11543      }
11544      return DAG.getNode(ISD::XOR, DL, VT, Op, DAG.getConstant(1, DL, VT));
11545    }
11546  }
11547
11548  // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
11549  if (N1.getOpcode() == ISD::TRUNCATE &&
11550      N1.getOperand(0).getOpcode() == ISD::AND) {
11551    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
11552      return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
11553  }
11554
11555  // fold (srl (logic_op x, (shl (zext y), c1)), c1)
11556  // -> (logic_op (srl x, c1), (zext y))
11557  // c1 <= leadingzeros(zext(y))
11558  SDValue X, ZExtY;
  // NOTE(review): pattern lines 11561-11562 and continuation line 11565 were
  // dropped by the extraction — confirm against upstream.
11559  if (N1C && sd_match(N0, m_OneUse(m_BitwiseLogic(
11560                              m_Value(X),
11563                              m_Specific(N1))))))) {
11564    unsigned NumLeadingZeros = ZExtY.getScalarValueSizeInBits() -
11566    if (N1C->getZExtValue() <= NumLeadingZeros)
11567      return DAG.getNode(N0.getOpcode(), SDLoc(N0), VT,
11568                         DAG.getNode(ISD::SRL, SDLoc(N0), VT, X, N1), ZExtY);
11569  }
11570
11571  // fold operands of srl based on knowledge that the low bits are not
11572  // demanded.
  // NOTE(review): the guarding condition line 11573 was dropped by the
  // extraction (presumably a SimplifyDemandedBits(SDValue(N, 0)) call) —
  // confirm against upstream.
11574    return SDValue(N, 0);
11575
11576  if (N1C && !N1C->isOpaque())
11577    if (SDValue NewSRL = visitShiftByConstant(N))
11578      return NewSRL;
11579
11580  // Attempt to convert a srl of a load into a narrower zero-extending load.
11581  if (SDValue NarrowLoad = reduceLoadWidth(N))
11582    return NarrowLoad;
11583
11584  // Here is a common situation. We want to optimize:
11585  //
11586  //   %a = ...
11587  //   %b = and i32 %a, 2
11588  //   %c = srl i32 %b, 1
11589  //   brcond i32 %c ...
11590  //
11591  // into
11592  //
11593  //   %a = ...
11594  //   %b = and %a, 2
11595  //   %c = setcc eq %b, 0
11596  //   brcond %c ...
11597  //
11598  // However when after the source operand of SRL is optimized into AND, the SRL
11599  // itself may not be optimized further. Look for it and add the BRCOND into
11600  // the worklist.
11601  //
11602  // This also tends to happen for binary operations when SimplifyDemandedBits
11603  // is involved.
11604  //
11605  // FIXME: This is unnecessary if we process the DAG in topological order,
11606  // which we plan to do. This workaround can be removed once the DAG is
11607  // processed in topological order.
11608  if (N->hasOneUse()) {
11609    SDNode *User = *N->user_begin();
11610
11611    // Look past the truncate.
11612    if (User->getOpcode() == ISD::TRUNCATE && User->hasOneUse())
11613      User = *User->user_begin();
11614
11615    if (User->getOpcode() == ISD::BRCOND || User->getOpcode() == ISD::AND ||
11616        User->getOpcode() == ISD::OR || User->getOpcode() == ISD::XOR)
11617      AddToWorklist(User);
11618  }
11619
11620  // Try to transform this shift into a multiply-high if
11621  // it matches the appropriate pattern detected in combineShiftToMULH.
11622  if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
11623    return MULH;
11624
11625  if (SDValue AVG = foldShiftToAvg(N, DL))
11626    return AVG;
11627
11628  SDValue Y;
11629  if (VT.getScalarSizeInBits() % 2 == 0 && N1C) {
11630    // Fold clmul(zext(x), zext(y)) >> (BW - 1 | BW) -> clmul(r|h)(x, y).
11631    unsigned HalfBW = VT.getScalarSizeInBits() / 2;
11632    if (sd_match(N0, m_Clmul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
11633        X.getScalarValueSizeInBits() == HalfBW &&
11634        Y.getScalarValueSizeInBits() == HalfBW) {
11635      if (N1C->getZExtValue() == HalfBW - 1 &&
11636          (!LegalOperations ||
11637           TLI.isOperationLegalOrCustom(ISD::CLMULR, X.getValueType())))
11638        return DAG.getNode(
11639            ISD::ZERO_EXTEND, DL, VT,
11640            DAG.getNode(ISD::CLMULR, DL, X.getValueType(), X, Y));
11641      if (N1C->getZExtValue() == HalfBW &&
11642          (!LegalOperations ||
11643           TLI.isOperationLegalOrCustom(ISD::CLMULH, X.getValueType())))
11644        return DAG.getNode(
11645            ISD::ZERO_EXTEND, DL, VT,
11646            DAG.getNode(ISD::CLMULH, DL, X.getValueType(), X, Y));
11647    }
11648  }
11649
11650  // Fold bitreverse(clmul(bitreverse(x), bitreverse(y))) >> 1 ->
11651  // clmulh(x, y).
  // NOTE(review): pattern line 11653 was dropped by the extraction —
  // confirm against upstream.
11652  if (N1C && N1C->getZExtValue() == 1 &&
11654               m_BitReverse(m_Value(Y))))))
11655    return DAG.getNode(ISD::CLMULH, DL, VT, X, Y);
11656
11657  return SDValue();
11658}
11659
/// Combine step for ISD::FSHL/ISD::FSHR (funnel shifts). Attempts constant
/// folds, undef/zero-operand simplifications, rotate formation and a
/// consecutive-load fold; returns the replacement SDValue or an empty one.
/// NOTE(review): several interior lines were dropped by the doxygen
/// extraction (see NOTEs below) — confirm against upstream before editing.
11660SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
11661  EVT VT = N->getValueType(0);
11662  SDValue N0 = N->getOperand(0);
11663  SDValue N1 = N->getOperand(1);
11664  SDValue N2 = N->getOperand(2);
11665  bool IsFSHL = N->getOpcode() == ISD::FSHL;
11666  unsigned BitWidth = VT.getScalarSizeInBits();
11667  SDLoc DL(N);
11668
11669  // fold (fshl/fshr C0, C1, C2) -> C3
11670  if (SDValue C =
11671          DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1, N2}))
11672    return C;
11673
11674  // fold (fshl N0, N1, 0) -> N0
11675  // fold (fshr N0, N1, 0) -> N1
  // NOTE(review): guarding line 11676 was dropped by the extraction
  // (presumably an isPowerOf2_32(BitWidth) check) — confirm against upstream.
11677  if (DAG.MaskedValueIsZero(
11678          N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
11679    return IsFSHL ? N0 : N1;
11680
11681  auto IsUndefOrZero = [](SDValue V) {
11682    return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
11683  };
11684
11685  // TODO - support non-uniform vector shift amounts.
11686  if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
11687    EVT ShAmtTy = N2.getValueType();
11688
11689    // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
11690    if (Cst->getAPIntValue().uge(BitWidth)) {
11691      uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
11692      return DAG.getNode(N->getOpcode(), DL, VT, N0, N1,
11693                         DAG.getConstant(RotAmt, DL, ShAmtTy));
11694    }
11695
11696    unsigned ShAmt = Cst->getZExtValue();
11697    if (ShAmt == 0)
11698      return IsFSHL ? N0 : N1;
11699
11700    // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
11701    // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
11702    // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
11703    // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
11704    if (IsUndefOrZero(N0))
11705      return DAG.getNode(
11706          ISD::SRL, DL, VT, N1,
11707          DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt, DL, ShAmtTy));
11708    if (IsUndefOrZero(N1))
11709      return DAG.getNode(
11710          ISD::SHL, DL, VT, N0,
11711          DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt, DL, ShAmtTy));
11712
11713    // fold fshl(N0, N1, c) -> x and fshr(N0, N1, c) -> x
11714    // where N0 is any node that contributes "x >> C0" to the result:
11715    //   lshr(x, C0) | fshr(_, x, C0) | fshl(_, x, C1)
11716    // and N1 is any node that contributes "x << C1" to the result:
11717    //   shl(x, C1) | fshl(x, _, C1) | fshr(x, _, C0)
11718    // with C0 = IsFSHL ? amnt : BW-amnt, C1 = BW - C0
11719
11720    // ShAmt == 0 was handled above; uge(BitWidth) was reduced via modulo above.
11721    assert(ShAmt >= 1 && ShAmt < BitWidth &&
11722           "ShAmt must be in [1, BW-1] for the identity fold to be valid");
11723    SDValue Val;
11724    unsigned C0Expected = IsFSHL ? ShAmt : BitWidth - ShAmt;
11725    unsigned C1Expected = IsFSHL ? BitWidth - ShAmt : ShAmt;
11726
    // NOTE(review): sd_match pattern lines 11728, 11730, 11733 and 11735 were
    // dropped by the extraction (the m_Fshr/m_Fshl alternatives listed in the
    // comment above) — confirm against upstream.
11727    if ((sd_match(N0, m_Srl(m_Value(Val), m_SpecificInt(C0Expected))) ||
11729                  m_SpecificInt(C0Expected))) ||
11731                  m_SpecificInt(C1Expected)))) &&
11732        (sd_match(N1, m_Shl(m_Specific(Val), m_SpecificInt(C1Expected))) ||
11734                  m_SpecificInt(C1Expected))) ||
11736                  m_SpecificInt(C0Expected)))))
11737      return Val;
11738
11739    // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
11740    // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
11741    // TODO - bigendian support once we have test coverage.
11742    // TODO - can we merge this with CombineConseutiveLoads/MatchLoadCombine?
11743    // TODO - permit LHS EXTLOAD if extensions are shifted out.
11744    if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
11745        !DAG.getDataLayout().isBigEndian()) {
11746      auto *LHS = dyn_cast<LoadSDNode>(N0);
11747      auto *RHS = dyn_cast<LoadSDNode>(N1);
11748      if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
11749          LHS->getAddressSpace() == RHS->getAddressSpace() &&
11750          (LHS->hasNUsesOfValue(1, 0) || RHS->hasNUsesOfValue(1, 0)) &&
      // NOTE(review): continuation line 11751 of this condition was dropped
      // by the extraction — confirm against upstream.
11752        if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
11753          SDLoc DL(RHS);
             // Byte offset into the combined 2*BitWidth window selected by
             // the constant shift amount.
11754          uint64_t PtrOff =
11755              IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
11756          Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
11757          unsigned Fast = 0;
11758          if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
11759                                     RHS->getAddressSpace(), NewAlign,
11760                                     RHS->getMemOperand()->getFlags(), &Fast) &&
11761              Fast) {
11762            SDValue NewPtr = DAG.getMemBasePlusOffset(
11763                RHS->getBasePtr(), TypeSize::getFixed(PtrOff), DL);
11764            AddToWorklist(NewPtr.getNode());
11765            SDValue Load = DAG.getLoad(
11766                VT, DL, RHS->getChain(), NewPtr,
11767                RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
11768                RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
               // Keep both original loads' chains ordered against the new
               // load.
11769            DAG.makeEquivalentMemoryOrdering(LHS, Load.getValue(1));
11770            DAG.makeEquivalentMemoryOrdering(RHS, Load.getValue(1));
11771            return Load;
11772          }
11773        }
11774      }
11775    }
11776  }
11777
11778  // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
11779  // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
11780  // iff We know the shift amount is in range.
11781  // TODO: when is it worth doing SUB(BW, N2) as well?
11782  if (isPowerOf2_32(BitWidth)) {
11783    APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
11784    if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
11785      return DAG.getNode(ISD::SRL, DL, VT, N1, N2);
11786    if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
11787      return DAG.getNode(ISD::SHL, DL, VT, N0, N2);
11788  }
11789
11790  // fold (fshl N0, N0, N2) -> (rotl N0, N2)
11791  // fold (fshr N0, N0, N2) -> (rotr N0, N2)
11792  // TODO: Investigate flipping this rotate if only one is legal.
11793  // If funnel shift is legal as well we might be better off avoiding
11794  // non-constant (BW - N2).
11795  unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
11796  if (N0 == N1 && hasOperation(RotOpc, VT))
11797    return DAG.getNode(RotOpc, DL, VT, N0, N2);
11798
11799  // Simplify, based on bits shifted out of N0/N1.
  // NOTE(review): the guarding condition line 11800 was dropped by the
  // extraction (presumably a SimplifyDemandedBits call on this node) —
  // confirm against upstream.
11801    return SDValue(N, 0);
11802
11803  return SDValue();
11804}
11805
/// Combine step for ISD::SSHLSAT/ISD::USHLSAT (saturating shift left).
/// Tries to constant-fold, and to relax the saturating shift to a plain SHL
/// when the shift provably cannot saturate.
11806SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
11807  SDValue N0 = N->getOperand(0);
11808  SDValue N1 = N->getOperand(1);
11809  if (SDValue V = DAG.simplifyShift(N0, N1))
11810    return V;
11811
11812  SDLoc DL(N);
11813  EVT VT = N0.getValueType();
11814
11815  // fold (*shlsat c1, c2) -> c1<<c2
11816  if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
11817    return C;
11818
11819  ConstantSDNode *N1C = isConstOrConstSplat(N1);
11820
11821  if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
11822    // fold (sshlsat x, c) -> (shl x, c)
       // Safe because x has more than c redundant sign bits, so no signed
       // overflow (and hence no saturation) can occur.
11823    if (N->getOpcode() == ISD::SSHLSAT && N1C &&
11824        N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
11825      return DAG.getNode(ISD::SHL, DL, VT, N0, N1);
11826
11827    // fold (ushlsat x, c) -> (shl x, c)
    // NOTE(review): continuation line 11830 of this condition was dropped by
    // the extraction (presumably the known-leading-zero count of N0) —
    // confirm against upstream.
11828    if (N->getOpcode() == ISD::USHLSAT && N1C &&
11829        N1C->getAPIntValue().ule(
11831      return DAG.getNode(ISD::SHL, DL, VT, N0, N1);
11832  }
11833
11834  return SDValue();
11835}
11836
11837// Given a ABS node, detect the following patterns:
11838// (ABS (SUB (EXTEND a), (EXTEND b))).
11839// (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
11840// Generates UABD/SABD instruction.
11841SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
11842 EVT SrcVT = N->getValueType(0);
11843
11844 if (N->getOpcode() == ISD::TRUNCATE)
11845 N = N->getOperand(0).getNode();
11846
11847 EVT VT = N->getValueType(0);
11848 SDValue Op0, Op1;
11849
11850 if (!sd_match(N, m_Abs(m_Sub(m_Value(Op0), m_Value(Op1)))))
11851 return SDValue();
11852
11853 SDValue AbsOp0 = N->getOperand(0);
11854 unsigned Opc0 = Op0.getOpcode();
11855
11856 // Check if the operands of the sub are (zero|sign)-extended, otherwise
11857 // fallback to ValueTracking.
11858 if (Opc0 != Op1.getOpcode() ||
11859 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
11860 Opc0 != ISD::SIGN_EXTEND_INREG)) {
11861 // fold (abs (sub nsw x, y)) -> abds(x, y)
11862 // Don't fold this for unsupported types as we lose the NSW handling.
11863 if (hasOperation(ISD::ABDS, VT) && TLI.preferABDSToABSWithNSW(VT) &&
11864 (AbsOp0->getFlags().hasNoSignedWrap() ||
11865 DAG.willNotOverflowSub(/*IsSigned=*/true, Op0, Op1))) {
11866 SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
11867 return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11868 }
11869 // fold (abs (sub x, y)) -> abdu(x, y)
11870 if (hasOperation(ISD::ABDU, VT) && DAG.SignBitIsZero(Op0) &&
11871 DAG.SignBitIsZero(Op1)) {
11872 SDValue ABD = DAG.getNode(ISD::ABDU, DL, VT, Op0, Op1);
11873 return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11874 }
11875 return SDValue();
11876 }
11877
11878 EVT VT0, VT1;
11879 if (Opc0 == ISD::SIGN_EXTEND_INREG) {
11880 VT0 = cast<VTSDNode>(Op0.getOperand(1))->getVT();
11881 VT1 = cast<VTSDNode>(Op1.getOperand(1))->getVT();
11882 } else {
11883 VT0 = Op0.getOperand(0).getValueType();
11884 VT1 = Op1.getOperand(0).getValueType();
11885 }
11886 unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
11887
11888 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
11889 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
11890 EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
11891 if ((VT0 == MaxVT || Op0->hasOneUse()) &&
11892 (VT1 == MaxVT || Op1->hasOneUse()) &&
11893 (!LegalTypes || hasOperation(ABDOpcode, MaxVT))) {
11894 SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
11895 DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
11896 DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
11897 ABD = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ABD);
11898 return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11899 }
11900
11901 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
11902 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
11903 if (!LegalOperations || hasOperation(ABDOpcode, VT)) {
11904 SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
11905 return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11906 }
11907
11908 return SDValue();
11909}
11910
11911SDValue DAGCombiner::visitABS(SDNode *N) {
11912 SDValue N0 = N->getOperand(0);
11913 EVT VT = N->getValueType(0);
11914 SDLoc DL(N);
11915
11916 // fold (abs c1) -> c2
11917 if (SDValue C = DAG.FoldConstantArithmetic(ISD::ABS, DL, VT, {N0}))
11918 return C;
11919 // fold (abs (abs x)) -> (abs x)
11920 if (N0.getOpcode() == ISD::ABS)
11921 return N0;
11922 // fold (abs x) -> x iff not-negative
11923 if (DAG.SignBitIsZero(N0))
11924 return N0;
11925
11926 if (SDValue ABD = foldABSToABD(N, DL))
11927 return ABD;
11928
11929 // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
11930 // iff zero_extend/truncate are free.
11931 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
11932 EVT ExtVT = cast<VTSDNode>(N0.getOperand(1))->getVT();
11933 if (TLI.isTruncateFree(VT, ExtVT) && TLI.isZExtFree(ExtVT, VT) &&
11934 TLI.isTypeDesirableForOp(ISD::ABS, ExtVT) &&
11935 hasOperation(ISD::ABS, ExtVT)) {
11936 return DAG.getNode(
11937 ISD::ZERO_EXTEND, DL, VT,
11938 DAG.getNode(ISD::ABS, DL, ExtVT,
11939 DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N0.getOperand(0))));
11940 }
11941 }
11942
11943 return SDValue();
11944}
11945
/// Combine step for carry-less multiply nodes (ISD::CLMUL and variants).
/// NOTE(review): two guarding condition lines were dropped by the doxygen
/// extraction (see NOTEs below) — confirm against upstream before editing.
11946SDValue DAGCombiner::visitCLMUL(SDNode *N) {
11947  unsigned Opcode = N->getOpcode();
11948  SDValue N0 = N->getOperand(0);
11949  SDValue N1 = N->getOperand(1);
11950  EVT VT = N->getValueType(0);
11951  SDLoc DL(N);
11952
11953  // fold (clmul c1, c2)
11954  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
11955    return C;
11956
11957  // canonicalize constant to RHS
  // NOTE(review): the guarding condition lines 11958-11959 were dropped by
  // the extraction (presumably "LHS is constant, RHS is not" checks) —
  // confirm against upstream.
11960    return DAG.getNode(Opcode, DL, VT, N1, N0);
11961
11962  // fold (clmul x, 0) -> 0
  // NOTE(review): the guarding condition line 11963 was dropped by the
  // extraction (presumably a null/null-splat check on N1) — confirm against
  // upstream.
11964    return DAG.getConstant(0, DL, VT);
11965
11966  // fold (clmul x, c_pow2) -> (shl x, log2(c_pow2))
11967  // This also handles (clmul x, 1) -> x since (shl x, 0) simplifies to x.
11968  if (Opcode == ISD::CLMUL) {
11969    if (ConstantSDNode *C = isConstOrConstSplat(N1)) {
11970      APInt CV = C->getAPIntValue().trunc(VT.getScalarSizeInBits());
       // A power-of-two multiplicand cannot generate cross-bit carries, so a
       // plain shift is equivalent.
11971      if (CV.isPowerOf2() &&
11972          (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)))
11973        return DAG.getNode(ISD::SHL, DL, VT, N0,
11974                           DAG.getShiftAmountConstant(CV.logBase2(), VT, DL));
11975    }
11976  }
11977
11978  return SDValue();
11979}
11980
11981SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11982 SDValue N0 = N->getOperand(0);
11983 EVT VT = N->getValueType(0);
11984 SDLoc DL(N);
11985
11986 // fold (bswap c1) -> c2
11987 if (SDValue C = DAG.FoldConstantArithmetic(ISD::BSWAP, DL, VT, {N0}))
11988 return C;
11989 // fold (bswap (bswap x)) -> x
11990 if (N0.getOpcode() == ISD::BSWAP)
11991 return N0.getOperand(0);
11992
11993 // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11994 // isn't supported, it will be expanded to bswap followed by a manual reversal
11995 // of bits in each byte. By placing bswaps before bitreverse, we can remove
11996 // the two bswaps if the bitreverse gets expanded.
11997 if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11998 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11999 return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
12000 }
12001
12002 // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
12003 // iff x >= bw/2 (i.e. lower half is known zero)
12004 unsigned BW = VT.getScalarSizeInBits();
12005 if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
12006 auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
12007 EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
12008 if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
12009 ShAmt->getZExtValue() >= (BW / 2) &&
12010 (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
12011 TLI.isTruncateFree(VT, HalfVT) &&
12012 (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
12013 SDValue Res = N0.getOperand(0);
12014 if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
12015 Res = DAG.getNode(ISD::SHL, DL, VT, Res,
12016 DAG.getShiftAmountConstant(NewShAmt, VT, DL));
12017 Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
12018 Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
12019 return DAG.getZExtOrTrunc(Res, DL, VT);
12020 }
12021 }
12022
12023 // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
12024 // inverse-shift-of-bswap:
12025 // bswap (X u<< C) --> (bswap X) u>> C
12026 // bswap (X u>> C) --> (bswap X) u<< C
12027 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
12028 N0.hasOneUse()) {
12029 auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
12030 if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
12031 ShAmt->getZExtValue() % 8 == 0) {
12032 SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
12033 unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
12034 return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
12035 }
12036 }
12037
12038 if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
12039 return V;
12040
12041 return SDValue();
12042}
12043
/// Combine a BITREVERSE node: constant-fold, cancel double reversals, and
/// rewrite reversals of shifted reversals as the opposite shift.
SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (bitreverse c1) -> c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::BITREVERSE, DL, VT, {N0}))
    return C;

  // fold (bitreverse (bitreverse x)) -> x
  if (N0.getOpcode() == ISD::BITREVERSE)
    return N0.getOperand(0);

  // Values bound by the sd_match patterns below.
  SDValue X, Y;

  // fold (bitreverse (lshr (bitreverse x), y)) -> (shl x, y)
  // NOTE(review): the sd_match condition line appears truncated in this copy
  // of the file -- confirm the full pattern against upstream.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
    return DAG.getNode(ISD::SHL, DL, VT, X, Y);

  // fold (bitreverse (shl (bitreverse x), y)) -> (lshr x, y)
  // NOTE(review): sd_match condition line truncated here as well.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::SRL, VT)) &&
    return DAG.getNode(ISD::SRL, DL, VT, X, Y);

  // fold bitreverse(clmul(bitreverse(x), bitreverse(y))) -> clmulr(x, y)
  // NOTE(review): sd_match condition line truncated here as well.
  if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::CLMULR, VT)) &&
    return DAG.getNode(ISD::CLMULR, DL, VT, X, Y);

  return SDValue();
}
12076
// Fold (ctlz (xor x, (sra x, bitwidth-1))) -> (add (ctls x), 1).
// Fold (ctlz (or (shl (xor x, (sra x, bitwidth-1)), 1), 1) -> (ctls x)
SDValue DAGCombiner::foldCTLZToCTLS(SDValue Src, const SDLoc &DL) {
  EVT VT = Src.getValueType();

  // Only worthwhile when CTLS is legal/custom for the type VT converts to.
  // NOTE(review): part of this legality condition looks truncated in this
  // copy of the file -- confirm the full check against upstream.
  auto LK = TLI.getTypeConversion(*DAG.getContext(), VT);
  if ((LK.first != TargetLoweringBase::TypeLegal &&
      !TLI.isOperationLegalOrCustom(ISD::CTLS, LK.second))
    return SDValue();

  unsigned BitWidth = VT.getScalarSizeInBits();

  // Assume the first (add (ctls x), 1) form unless the "or/shl" variant
  // matches below.
  bool NeedAdd = true;

  SDValue X;
  // Try the second pattern from the header comment; on success, a plain CTLS
  // suffices. NOTE(review): the opening lines of this sd_match are truncated.
        m_SpecificInt(1))))) {
    NeedAdd = false;
    Src = X;
  }

  // Match the core x ^ (x >>s BitWidth-1) idiom.
  // NOTE(review): interior sd_match pattern lines are truncated here too.
  if (!sd_match(Src,
                m_SpecificInt(BitWidth - 1)))))))
    return SDValue();

  SDValue Res = DAG.getNode(ISD::CTLS, DL, VT, X);
  if (!NeedAdd)
    return Res;

  // The plain-xor form needs the +1 adjustment (see header comment).
  return DAG.getNode(ISD::ADD, DL, VT, Res, DAG.getConstant(1, DL, VT));
}
12111
12112SDValue DAGCombiner::visitCTLZ(SDNode *N) {
12113 SDValue N0 = N->getOperand(0);
12114 EVT VT = N->getValueType(0);
12115 SDLoc DL(N);
12116
12117 // fold (ctlz c1) -> c2
12118 if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTLZ, DL, VT, {N0}))
12119 return C;
12120
12121 // If the value is known never to be zero, switch to the undef version.
12122 if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT))
12123 if (DAG.isKnownNeverZero(N0))
12124 return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, N0);
12125
12126 if (SDValue V = foldCTLZToCTLS(N0, DL))
12127 return V;
12128
12129 return SDValue();
12130}
12131
/// Combine a CTLZ_ZERO_UNDEF node: constant-fold, then try the
/// count-leading-sign fold.
SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (ctlz_zero_undef c1) -> c2
  // NOTE(review): the FoldConstantArithmetic argument line is truncated in
  // this copy of the file -- confirm against upstream.
  if (SDValue C =
    return C;

  if (SDValue V = foldCTLZToCTLS(N0, DL))
    return V;

  return SDValue();
}
12147
12148SDValue DAGCombiner::visitCTTZ(SDNode *N) {
12149 SDValue N0 = N->getOperand(0);
12150 EVT VT = N->getValueType(0);
12151 SDLoc DL(N);
12152
12153 // fold (cttz c1) -> c2
12154 if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTTZ, DL, VT, {N0}))
12155 return C;
12156
12157 // If the value is known never to be zero, switch to the undef version.
12158 if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT))
12159 if (DAG.isKnownNeverZero(N0))
12160 return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, DL, VT, N0);
12161
12162 return SDValue();
12163}
12164
/// Combine a CTTZ_ZERO_UNDEF node: constant-fold only.
SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (cttz_zero_undef c1) -> c2
  // NOTE(review): the FoldConstantArithmetic argument line is truncated in
  // this copy of the file -- confirm against upstream.
  if (SDValue C =
    return C;
  return SDValue();
}
12176
12177SDValue DAGCombiner::visitCTPOP(SDNode *N) {
12178 SDValue N0 = N->getOperand(0);
12179 EVT VT = N->getValueType(0);
12180 unsigned NumBits = VT.getScalarSizeInBits();
12181 SDLoc DL(N);
12182
12183 // fold (ctpop c1) -> c2
12184 if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTPOP, DL, VT, {N0}))
12185 return C;
12186
12187 // If the source is being shifted, but doesn't affect any active bits,
12188 // then we can call CTPOP on the shift source directly.
12189 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SHL) {
12190 if (ConstantSDNode *AmtC = isConstOrConstSplat(N0.getOperand(1))) {
12191 const APInt &Amt = AmtC->getAPIntValue();
12192 if (Amt.ult(NumBits)) {
12193 KnownBits KnownSrc = DAG.computeKnownBits(N0.getOperand(0));
12194 if ((N0.getOpcode() == ISD::SRL &&
12195 Amt.ule(KnownSrc.countMinTrailingZeros())) ||
12196 (N0.getOpcode() == ISD::SHL &&
12197 Amt.ule(KnownSrc.countMinLeadingZeros()))) {
12198 return DAG.getNode(ISD::CTPOP, DL, VT, N0.getOperand(0));
12199 }
12200 }
12201 }
12202 }
12203
12204 // If the upper bits are known to be zero, then see if its profitable to
12205 // only count the lower bits.
12206 if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
12207 EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), NumBits / 2);
12208 if (hasOperation(ISD::CTPOP, HalfVT) &&
12209 TLI.isTypeDesirableForOp(ISD::CTPOP, HalfVT) &&
12210 TLI.isTruncateFree(N0, HalfVT) && TLI.isZExtFree(HalfVT, VT)) {
12211 APInt UpperBits = APInt::getHighBitsSet(NumBits, NumBits / 2);
12212 if (DAG.MaskedValueIsZero(N0, UpperBits)) {
12213 SDValue PopCnt = DAG.getNode(ISD::CTPOP, DL, HalfVT,
12214 DAG.getZExtOrTrunc(N0, DL, HalfVT));
12215 return DAG.getZExtOrTrunc(PopCnt, DL, VT);
12216 }
12217 }
12218 }
12219
12220 return SDValue();
12221}
12222
                                       SDValue RHS, const SDNodeFlags Flags,
                                       const TargetLowering &TLI) {
  // A floating-point select may only be folded to min/max when signed zeros
  // are irrelevant and neither operand can be NaN.
  EVT VT = LHS.getValueType();
  if (!VT.isFloatingPoint())
    return false;

  // NOTE(review): one conjunct of this return expression appears truncated in
  // this copy of the file -- confirm the full condition against upstream.
  return Flags.hasNoSignedZeros() &&
         (Flags.hasNoNaNs() ||
          (DAG.isKnownNeverNaN(RHS) && DAG.isKnownNeverNaN(LHS)));
}
12235
                                       SDValue RHS, SDValue True, SDValue False,
                                       ISD::CondCode CC,
                                       const TargetLowering &TLI,
                                       SelectionDAG &DAG) {
  // Map select(setcc(LHS, RHS, CC), True, False) onto FMINNUM/FMAXNUM (or the
  // *_IEEE flavors) based on the comparison direction and which compare
  // operand is selected on the true edge. Callers are expected to have
  // established NaN-safety already (see the comment in the SETOLT case).
  EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
  switch (CC) {
  case ISD::SETOLT:
  case ISD::SETOLE:
  case ISD::SETLT:
  case ISD::SETLE:
  case ISD::SETULT:
  case ISD::SETULE: {
    // Since it's known never nan to get here already, either fminnum or
    // fminnum_ieee are OK. Try the ieee version first, since it's fminnum is
    // expanded in terms of it.
    unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
    if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
      return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);

    // Legality of the non-IEEE flavor is checked on the legalized type.
    unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
    if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
      return DAG.getNode(Opcode, DL, VT, LHS, RHS);
    return SDValue();
  }
  case ISD::SETOGT:
  case ISD::SETOGE:
  case ISD::SETGT:
  case ISD::SETGE:
  case ISD::SETUGT:
  case ISD::SETUGE: {
    // Mirror of the "less than" cases with min/max roles swapped.
    unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
    if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
      return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);

    unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
    if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
      return DAG.getNode(Opcode, DL, VT, LHS, RHS);
    return SDValue();
  }
  default:
    // Equality and unordered comparisons do not map to min/max.
    return SDValue();
  }
}
12280
// Convert (sr[al] (add n[su]w x, y)) -> (avgfloor[su] x, y)
SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
  const unsigned Opcode = N->getOpcode();
  // Only right shifts (by one, matched below) can form a floor-average.
  if (Opcode != ISD::SRA && Opcode != ISD::SRL)
    return SDValue();

  EVT VT = N->getValueType(0);
  bool IsUnsigned = Opcode == ISD::SRL;

  // Captured values.
  SDValue A, B;

  // Match floor average as it is common to both floor/ceil avgs, ensure the add
  // doesn't wrap.
  // NOTE(review): the initializer of Flags is truncated in this copy of the
  // file (presumably the no-wrap flag matching the shift's signedness) --
  // confirm against upstream.
  SDNodeFlags Flags =
  if (sd_match(N, m_BinOp(Opcode,
                          m_c_BinOp(ISD::ADD, m_Value(A), m_Value(B), Flags),
                          m_One()))) {
    // Decide whether signed or unsigned.
    unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
    if (hasOperation(FloorISD, VT))
      return DAG.getNode(FloorISD, DL, VT, {A, B});
  }

  return SDValue();
}
12308
/// Rewrite a bitwise op whose operand mixes a negated value so the negation is
/// expressed via getNOT of the reassociated sub/add.
SDValue DAGCombiner::foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT) {
  unsigned Opc = N->getOpcode();
  SDValue X, Y, Z;
  // NOTE(review): the sd_match pattern for this first form is truncated in
  // this copy of the file -- confirm against upstream.
  if (sd_match(
    return DAG.getNode(Opc, DL, VT, X,
                       DAG.getNOT(DL, DAG.getNode(ISD::SUB, DL, VT, Y, Z), VT));

  // NOTE(review): the opening of this second sd_match is truncated as well.
      m_Value(Z)))))
    return DAG.getNode(Opc, DL, VT, X,
                       DAG.getNOT(DL, DAG.getNode(ISD::ADD, DL, VT, Y, Z), VT));

  return SDValue();
}
12324
/// Generate Min/Max node from a select(setcc(LHS, RHS, CC), True, False),
/// either directly or after pulling an fneg out of the select operands.
SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                         SDValue RHS, SDValue True,
                                         SDValue False, ISD::CondCode CC) {
  // Direct match: the select operands are the compare operands (in either
  // order).
  if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
    return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);

  // If we can't directly match this, try to see if we can pull an fneg out of
  // the select.
  // NOTE(review): the call producing NegTrue (presumably
  // TLI.getNegatedExpression) is truncated in this copy of the file.
                                             True, DAG, LegalOperations, ForCodeSize);
  if (!NegTrue)
    return SDValue();

  // Keep the temporary node alive across further DAG mutation.
  HandleSDNode NegTrueHandle(NegTrue);

  // Try to unfold an fneg from the select if we are comparing the negated
  // constant.
  //
  // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
  //
  // TODO: Handle fabs
  if (LHS == NegTrue) {
    // If we can't directly match this, try to see if we can pull an fneg out of
    // the select.
    // NOTE(review): the call producing NegRHS is truncated here as well.
                                               RHS, DAG, LegalOperations, ForCodeSize);
    if (NegRHS) {
      HandleSDNode NegRHSHandle(NegRHS);
      if (NegRHS == False) {
        SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
                                                   False, CC, TLI, DAG);
        if (Combined)
          return DAG.getNode(ISD::FNEG, DL, VT, Combined);
      }
    }
  }

  return SDValue();
}
12365
/// If a (v)select has a condition value that is a sign-bit test, try to smear
/// the condition operand sign-bit across the value width and use it as a mask.
/// NOTE(review): the first line of this function's signature (upstream:
/// foldSelectOfConstantsUsingSra) is missing from this copy of the file.
                                  SelectionDAG &DAG) {
  SDValue Cond = N->getOperand(0);
  SDValue C1 = N->getOperand(1);
  SDValue C2 = N->getOperand(2);
  // NOTE(review): the condition guarding this early return (presumably a
  // constant-operand check on C1/C2) is truncated in this copy.
    return SDValue();

  EVT VT = N->getValueType(0);
  // Only handle a setcc condition with a single use whose compared type
  // matches the select's result type.
  if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
      VT != Cond.getOperand(0).getValueType())
    return SDValue();

  // The inverted-condition + commuted-select variants of these patterns are
  // canonicalized to these forms in IR.
  SDValue X = Cond.getOperand(0);
  SDValue CondC = Cond.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
  // NOTE(review): the tail of this condition (likely a check that C2 is
  // all-ones) is truncated in this copy.
  if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
    // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
    SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
    return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
  }
  if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
    // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
    SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
    return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
  }
  return SDValue();
}
12401
                                const TargetLowering &TLI) {
  // Decide whether a select-of-constants should be lowered to arithmetic.
  // First defer to the target's blanket preference.
  if (!TLI.convertSelectOfConstantsToMath(VT))
    return false;

  // A non-setcc or multi-use condition is always worth converting.
  if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
    return true;
  // NOTE(review): the condition guarding this second "return true" is
  // truncated in this copy of the file -- confirm against upstream.
    return true;

  // Sign-bit tests are also profitable to convert.
  ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
  if (CC == ISD::SETLT && isNullOrNullSplat(Cond.getOperand(1)))
    return true;
  if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond.getOperand(1)))
    return true;

  return false;
}
12420
/// Fold select(Cond, C1, C2) with integer constants C1/C2 into bitwise or
/// arithmetic forms (xor/zext/sext/add/shl/or) where profitable.
SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
  SDValue Cond = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  EVT CondVT = Cond.getValueType();
  SDLoc DL(N);

  if (!VT.isInteger())
    return SDValue();

  auto *C1 = dyn_cast<ConstantSDNode>(N1);
  auto *C2 = dyn_cast<ConstantSDNode>(N2);
  if (!C1 || !C2)
    return SDValue();

  if (CondVT != MVT::i1 || LegalOperations) {
    // fold (select Cond, 0, 1) -> (xor Cond, 1)
    // We can't do this reliably if integer based booleans have different contents
    // to floating point based booleans. This is because we can't tell whether we
    // have an integer-based boolean or a floating-point-based boolean unless we
    // can find the SETCC that produced it and inspect its operands. This is
    // fairly easy if C is the SETCC node, but it can potentially be
    // undiscoverable (or not reasonably discoverable). For example, it could be
    // in another basic block or it could require searching a complicated
    // expression.
    // NOTE(review): the boolean-contents comparison operands of this condition
    // are truncated in this copy of the file -- confirm against upstream.
    if (CondVT.isInteger() &&
        TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
        TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
        C1->isZero() && C2->isOne()) {
      SDValue NotCond =
        DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
      if (VT.bitsEq(CondVT))
        return NotCond;
      return DAG.getZExtOrTrunc(NotCond, DL, VT);
    }

    return SDValue();
  }

  // Only do this before legalization to avoid conflicting with target-specific
  // transforms in the other direction (create a select from a zext/sext). There
  // is also a target-independent combine here in DAGCombiner in the other
  // direction for (select Cond, -1, 0) when the condition is not i1.
  assert(CondVT == MVT::i1 && !LegalOperations);

  // select Cond, 1, 0 --> zext (Cond)
  if (C1->isOne() && C2->isZero())
    return DAG.getZExtOrTrunc(Cond, DL, VT);

  // select Cond, -1, 0 --> sext (Cond)
  if (C1->isAllOnes() && C2->isZero())
    return DAG.getSExtOrTrunc(Cond, DL, VT);

  // select Cond, 0, 1 --> zext (!Cond)
  if (C1->isZero() && C2->isOne()) {
    SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
    NotCond = DAG.getZExtOrTrunc(NotCond, DL, VT);
    return NotCond;
  }

  // select Cond, 0, -1 --> sext (!Cond)
  if (C1->isZero() && C2->isAllOnes()) {
    SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
    NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
    return NotCond;
  }

  // Use a target hook because some targets may prefer to transform in the
  // other direction.
  // NOTE(review): the condition guarding this return (likely a
  // shouldConvertSelectOfConstantsToMath check) is truncated in this copy.
    return SDValue();

  // For any constants that differ by 1, we can transform the select into
  // an extend and add.
  const APInt &C1Val = C1->getAPIntValue();
  const APInt &C2Val = C2->getAPIntValue();

  // select Cond, C1, C1-1 --> add (zext Cond), C1-1
  if (C1Val - 1 == C2Val) {
    Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
    return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
  }

  // select Cond, C1, C1+1 --> add (sext Cond), C1+1
  if (C1Val + 1 == C2Val) {
    Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
    return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
  }

  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
  if (C1Val.isPowerOf2() && C2Val.isZero()) {
    Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
    SDValue ShAmtC =
        DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
    return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
  }

  // select Cond, -1, C --> or (sext Cond), C
  if (C1->isAllOnes()) {
    Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
    return DAG.getNode(ISD::OR, DL, VT, Cond, N2);
  }

  // select Cond, C, -1 --> or (sext (not Cond)), C
  if (C2->isAllOnes()) {
    SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
    NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
    return DAG.getNode(ISD::OR, DL, VT, NotCond, N1);
  }

  // NOTE(review): the call guarded here (likely foldSelectOfConstantsUsingSra)
  // is truncated in this copy of the file.
    return V;

  return SDValue();
}
12539
/// Fold a boolean-typed select into plain logic ops (and/or/xor) when one arm
/// is the condition itself or a constant 0/1 splat. The freeze calls keep the
/// untaken arm from introducing poison into the logic result.
/// NOTE(review): the first line of the function signature (upstream:
/// foldBoolSelectToLogic) is missing from this copy of the file.
template <class MatchContextClass>
                                 SelectionDAG &DAG) {
  assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
          N->getOpcode() == ISD::VP_SELECT) &&
         "Expected a (v)(vp.)select");
  SDValue Cond = N->getOperand(0);
  SDValue T = N->getOperand(1), F = N->getOperand(2);
  EVT VT = N->getValueType(0);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  MatchContextClass matcher(DAG, TLI, N);

  // Only applies when the select result is an i1 (or i1-vector) of the same
  // type as the condition.
  if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
    return SDValue();

  // select Cond, Cond, F --> or Cond, freeze(F)
  // select Cond, 1, F --> or Cond, freeze(F)
  if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
    return matcher.getNode(ISD::OR, DL, VT, Cond, DAG.getFreeze(F));

  // select Cond, T, Cond --> and Cond, freeze(T)
  // select Cond, T, 0 --> and Cond, freeze(T)
  if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
    return matcher.getNode(ISD::AND, DL, VT, Cond, DAG.getFreeze(T));

  // select Cond, T, 1 --> or (not Cond), freeze(T)
  if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
    SDValue NotCond =
        matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
    return matcher.getNode(ISD::OR, DL, VT, NotCond, DAG.getFreeze(T));
  }

  // select Cond, 0, F --> and (not Cond), freeze(F)
  if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
    SDValue NotCond =
        matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
    return matcher.getNode(ISD::AND, DL, VT, NotCond, DAG.getFreeze(F));
  }

  return SDValue();
}
12581
  // Turn a select on a sign-bit test into a splatted sign mask combined with
  // the selected values via and/or/andnot.
  // NOTE(review): this function's signature line (upstream:
  // foldVSelectToSignBitSplatMask) is missing from this copy of the file.
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  unsigned EltSizeInBits = VT.getScalarSizeInBits();

  SDValue Cond0, Cond1;
  ISD::CondCode CC;
  // Require a single-use setcc whose compared type matches the result type.
  if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
                                     m_CondCode(CC)))) ||
      VT != Cond0.getValueType())
    return SDValue();

  // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
  // compare is inverted from that pattern ("Cond0 s> -1").
  if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
    ; // This is the pattern we are looking for.
  else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
    std::swap(N1, N2);
  else
    return SDValue();

  // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & freeze(N1)
  if (isNullOrNullSplat(N2)) {
    SDLoc DL(N);
    SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
    return DAG.getNode(ISD::AND, DL, VT, Sra, DAG.getFreeze(N1));
  }

  // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | freeze(N2)
  if (isAllOnesOrAllOnesSplat(N1)) {
    SDLoc DL(N);
    SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
    return DAG.getNode(ISD::OR, DL, VT, Sra, DAG.getFreeze(N2));
  }

  // If we have to invert the sign bit mask, only do that transform if the
  // target has a bitwise 'and not' instruction (the invert is free).
  // (Cond0 s< -0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & freeze(N2)
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
    SDLoc DL(N);
    SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
    SDValue Not = DAG.getNOT(DL, Sra, VT);
    return DAG.getNode(ISD::AND, DL, VT, Not, DAG.getFreeze(N2));
  }

  // TODO: There's another pattern in this family, but it may require
  // implementing hasOrNot() to check for profitability:
  // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | freeze(N2)

  return SDValue();
}
12639
// Match SELECTs with absolute difference patterns.
// (select (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
// (select (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
// (select (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
// (select (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
                                     SDValue False, ISD::CondCode CC,
                                     const SDLoc &DL) {
  // The compare's signedness selects between ABDS and ABDU.
  bool IsSigned = isSignedIntSetCC(CC);
  unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
  EVT VT = LHS.getValueType();

  if (LegalOperations && !hasOperation(ABDOpc, VT))
    return SDValue();

  // (setcc 0, b set???) --> (setcc b, 0, set???)
  // NOTE(review): the statement updating CC after the swap appears to be
  // missing from this copy of the file -- confirm against upstream.
  if (isZeroOrZeroSplat(LHS)) {
    std::swap(LHS, RHS);
  }

  // (setcc (add nsw A, Const), 0, sets??) --> (setcc A, -Const, sets??)
  // NOTE(review): the sd_match tail of this condition is truncated in this
  // copy of the file.
  SDValue A, B;
  if (ISD::isSignedIntSetCC(CC) && LHS->getFlags().hasNoSignedWrap() &&
    RHS = DAG.getNegative(B, LHS, B.getValueType());
    LHS = A;
  }

  // NOTE(review): in each case below, the sd_match conditions on True/False
  // (matching the sub operand order from the header comment) are truncated in
  // this copy of the file.
  switch (CC) {
  case ISD::SETGT:
  case ISD::SETGE:
  case ISD::SETUGT:
  case ISD::SETUGE:
      return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
      hasOperation(ABDOpc, VT))
      return DAG.getNegative(DAG.getNode(ABDOpc, DL, VT, LHS, RHS), DL, VT);
    break;
  case ISD::SETLT:
  case ISD::SETLE:
  case ISD::SETULT:
  case ISD::SETULE:
      return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
      hasOperation(ABDOpc, VT))
      return DAG.getNegative(DAG.getNode(ABDOpc, DL, VT, LHS, RHS), DL, VT);
    break;
  default:
    break;
  }

  return SDValue();
}
12709
12710// ([v]select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
12711// ([v]select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
12712SDValue DAGCombiner::foldSelectToUMin(SDValue LHS, SDValue RHS, SDValue True,
12713 SDValue False, ISD::CondCode CC,
12714 const SDLoc &DL) {
12715 APInt C;
12716 EVT VT = True.getValueType();
12717 if (sd_match(RHS, m_ConstInt(C)) && hasUMin(VT)) {
12718 if (CC == ISD::SETUGT && LHS == False &&
12719 sd_match(True, m_Add(m_Specific(False), m_SpecificInt(~C)))) {
12720 SDValue AddC = DAG.getConstant(~C, DL, VT);
12721 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, False, AddC);
12722 return DAG.getNode(ISD::UMIN, DL, VT, Add, False);
12723 }
12724 if (CC == ISD::SETULT && LHS == True &&
12725 sd_match(False, m_Add(m_Specific(True), m_SpecificInt(-C)))) {
12726 SDValue AddC = DAG.getConstant(-C, DL, VT);
12727 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, True, AddC);
12728 return DAG.getNode(ISD::UMIN, DL, VT, True, Add);
12729 }
12730 }
12731 return SDValue();
12732}
12733
/// Top-level combine for ISD::SELECT: simplify, canonicalize inverted
/// conditions, fold constant arms, normalize and/or conditions into select
/// chains (and back), and form min/max/abd/uaddo/select_cc where profitable.
SDValue DAGCombiner::visitSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  EVT VT0 = N0.getValueType();
  SDLoc DL(N);
  SDNodeFlags Flags = N->getFlags();

  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  // NOTE(review): the condition guarding this return (likely the
  // foldBoolSelectToLogic call) is truncated in this copy of the file.
    return V;

  // select (not Cond), N1, N2 -> select Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
    return DAG.getSelect(DL, VT, F, N2, N1, Flags);

  if (SDValue V = foldSelectOfConstants(N))
    return V;

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0); // Don't revisit N.

  if (VT0 == MVT::i1) {
    // The code in this block deals with the following 2 equivalences:
    // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
    // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
    // The target can specify its preferred form with the
    // shouldNormalizeToSelectSequence() callback. However we always transform
    // to the right anyway if we find the inner select exists in the DAG anyway
    // and we always transform to the left side if we know that we can further
    // optimize the combination of the conditions.
    // NOTE(review): the initializer of normalizeToSequence (presumably the
    // shouldNormalizeToSelectSequence query) is truncated in this copy.
    bool normalizeToSequence =
    // select (and Cond0, Cond1), X, Y
    // -> select Cond0, (select Cond1, X, Y), Y
    if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect =
          DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
                           InnerSelect, N2, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }
    // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
    if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
                                        Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
                           InnerSelect, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }

    // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
    if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
      SDValue N1_0 = N1->getOperand(0);
      SDValue N1_1 = N1->getOperand(1);
      SDValue N1_2 = N1->getOperand(2);
      if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
        // Create the actual and node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
                             N2, Flags);
        }
        // Otherwise see if we can optimize the "and" to a better pattern.
        if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
                             N2, Flags);
        }
      }
    }
    // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
    if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
      SDValue N2_0 = N2->getOperand(0);
      SDValue N2_1 = N2->getOperand(1);
      SDValue N2_2 = N2->getOperand(2);
      if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
        // Create the actual or node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
                             N2_2, Flags);
        }
        // Otherwise see if we can optimize to a better pattern.
        if (SDValue Combined = visitORLike(N0, N2_0, DL))
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
                             N2_2, Flags);
      }
    }

    // select usubo(x, y).overflow, (sub y, x), (usubo x, y) -> abdu(x, y)
    if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
        N2.getNode() == N0.getNode() && N2.getResNo() == 0 &&
        N1.getOpcode() == ISD::SUB && N2.getOperand(0) == N1.getOperand(1) &&
        N2.getOperand(1) == N1.getOperand(0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ABDU, VT)))
      return DAG.getNode(ISD::ABDU, DL, VT, N0.getOperand(0), N0.getOperand(1));

    // select usubo(x, y).overflow, (usubo x, y), (sub y, x) -> neg (abdu x, y)
    if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
        N1.getNode() == N0.getNode() && N1.getResNo() == 0 &&
        N2.getOpcode() == ISD::SUB && N2.getOperand(0) == N1.getOperand(1) &&
        N2.getOperand(1) == N1.getOperand(0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ABDU, VT)))
      return DAG.getNegative(
          DAG.getNode(ISD::ABDU, DL, VT, N0.getOperand(0), N0.getOperand(1)),
          DL, VT);
  }

  // Fold selects based on a setcc into other things, such as min/max/abs.
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
    // NOTE(review): the declaration of CC (presumably extracted from
    // N0.getOperand(2)) is missing from this copy of the file.

    // select (fcmp lt x, y), x, y -> fminnum x, y
    // select (fcmp gt x, y), x, y -> fmaxnum x, y
    //
    // This is OK if we don't care what happens if either operand is a NaN.
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, Flags, TLI))
      if (SDValue FMinMax =
              combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
        return FMinMax;

    // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
    // This is conservatively limited to pre-legal-operations to give targets
    // a chance to reverse the transform if they want to do that. Also, it is
    // unlikely that the pattern would be formed late, so it's probably not
    // worth going through the other checks.
    if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
        CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
        N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
      auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
      auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
      if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
        // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
        // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
        //
        // The IR equivalent of this transform would have this form:
        //   %a = add %x, C
        //   %c = icmp ugt %x, ~C
        //   %r = select %c, -1, %a
        //   =>
        //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
        //   %u0 = extractvalue %u, 0
        //   %u1 = extractvalue %u, 1
        //   %r = select %u1, -1, %u0
        SDVTList VTs = DAG.getVTList(VT, VT0);
        SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
        return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
      }
    }

    // NOTE(review): the second conjunct of this condition (likely a SELECT_CC
    // legality/custom query) is truncated in this copy of the file.
    if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
        (!LegalOperations &&
      // Any flags available in a select/setcc fold will be on the setcc as they
      // migrated from fcmp
      return DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1, N2,
                         N0.getOperand(2), N0->getFlags());
    }

    if (SDValue ABD = foldSelectToABD(Cond0, Cond1, N1, N2, CC, DL))
      return ABD;

    if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
      return NewSel;

    // (select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
    // (select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
    if (SDValue UMin = foldSelectToUMin(Cond0, Cond1, N1, N2, CC, DL))
      return UMin;
  }

  if (!VT.isVector())
    if (SDValue BinOp = foldSelectOfBinops(N))
      return BinOp;

  if (SDValue R = combineSelectAsExtAnd(N0, N1, N2, DL, DAG))
    return R;

  return SDValue();
}
12929
// This function assumes all the vselect's arguments are CONCAT_VECTOR
// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
  SDLoc DL(N);
  SDValue Cond = N->getOperand(0);
  SDValue LHS = N->getOperand(1);
  SDValue RHS = N->getOperand(2);
  EVT VT = N->getValueType(0);
  int NumElems = VT.getVectorNumElements();
  assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
         RHS.getOpcode() == ISD::CONCAT_VECTORS &&
         Cond.getOpcode() == ISD::BUILD_VECTOR);

  // CONCAT_VECTOR can take an arbitrary number of arguments. We only care about
  // binary ones here.
  if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
    return SDValue();

  // We're sure we have an even number of elements due to the
  // concat_vectors we have as arguments to vselect.
  // Skip BV elements until we find one that's not an UNDEF
  // After we find an UNDEF element, keep looping until we get to half the
  // length of the BV and see if all the non-undef nodes are the same.
  ConstantSDNode *BottomHalf = nullptr;
  for (int i = 0; i < NumElems / 2; ++i) {
    if (Cond->getOperand(i)->isUndef())
      continue;

    // Remember the first non-undef condition element; every later non-undef
    // element in this half must be the exact same node or the fold is invalid.
    if (BottomHalf == nullptr)
      BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
    else if (Cond->getOperand(i).getNode() != BottomHalf)
      return SDValue();
  }

  // Do the same for the second half of the BuildVector
  ConstantSDNode *TopHalf = nullptr;
  for (int i = NumElems / 2; i < NumElems; ++i) {
    if (Cond->getOperand(i)->isUndef())
      continue;

    if (TopHalf == nullptr)
      TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
    else if (Cond->getOperand(i).getNode() != TopHalf)
      return SDValue();
  }

  assert(TopHalf && BottomHalf &&
         "One half of the selector was all UNDEFs and the other was all the "
         "same value. This should have been addressed before this function.");
  // Each half of the condition is uniform, so pick the corresponding
  // half-vector: RHS half when the condition element is zero, LHS otherwise.
  return DAG.getNode(
      BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
      TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
}
12984
12985bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
12986 SelectionDAG &DAG, const SDLoc &DL) {
12987
12988 // Only perform the transformation when existing operands can be reused.
12989 if (IndexIsScaled)
12990 return false;
12991
12992 if (!isNullConstant(BasePtr) && !Index.hasOneUse())
12993 return false;
12994
12995 EVT VT = BasePtr.getValueType();
12996
12997 if (SDValue SplatVal = DAG.getSplatValue(Index);
12998 SplatVal && !isNullConstant(SplatVal) &&
12999 SplatVal.getValueType() == VT) {
13000 BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
13001 Index = DAG.getSplat(Index.getValueType(), DL, DAG.getConstant(0, DL, VT));
13002 return true;
13003 }
13004
13005 if (Index.getOpcode() != ISD::ADD)
13006 return false;
13007
13008 if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
13009 SplatVal && SplatVal.getValueType() == VT) {
13010 BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
13011 Index = Index.getOperand(1);
13012 return true;
13013 }
13014 if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
13015 SplatVal && SplatVal.getValueType() == VT) {
13016 BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
13017 Index = Index.getOperand(0);
13018 return true;
13019 }
13020 return false;
13021}
13022
13023// Fold sext/zext of index into index type.
13024bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
13025 SelectionDAG &DAG) {
13026 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13027
13028 // It's always safe to look through zero extends.
13029 if (Index.getOpcode() == ISD::ZERO_EXTEND) {
13030 if (TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
13031 IndexType = ISD::UNSIGNED_SCALED;
13032 Index = Index.getOperand(0);
13033 return true;
13034 }
13035 if (ISD::isIndexTypeSigned(IndexType)) {
13036 IndexType = ISD::UNSIGNED_SCALED;
13037 return true;
13038 }
13039 }
13040
13041 // It's only safe to look through sign extends when Index is signed.
13042 if (Index.getOpcode() == ISD::SIGN_EXTEND &&
13043 ISD::isIndexTypeSigned(IndexType) &&
13044 TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
13045 Index = Index.getOperand(0);
13046 return true;
13047 }
13048
13049 return false;
13050}
13051
/// Combine a VP scatter: remove it for a known-zero mask, otherwise try to
/// simplify its base-pointer and index operands.
SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
  VPScatterSDNode *MSC = cast<VPScatterSDNode>(N);
  SDValue Mask = MSC->getMask();
  SDValue Chain = MSC->getChain();
  SDValue Index = MSC->getIndex();
  SDValue Scale = MSC->getScale();
  SDValue StoreVal = MSC->getValue();
  SDValue BasePtr = MSC->getBasePtr();
  SDValue VL = MSC->getVectorLength();
  ISD::MemIndexType IndexType = MSC->getIndexType();
  SDLoc DL(N);

  // Zap scatters with a zero mask.
    return Chain;

  // Move a splat component of the vector index into the scalar base pointer
  // and rebuild the scatter from the simplified operands.
  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
    SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
    return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                            DL, Ops, MSC->getMemOperand(), IndexType);
  }

  // Strip a redundant extension from the index; IndexType is updated to match.
  if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
    SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
    return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                            DL, Ops, MSC->getMemOperand(), IndexType);
  }

  return SDValue();
}
13082
/// Combine a masked scatter: remove it for a known-zero mask, otherwise try
/// to simplify its base-pointer and index operands.
SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
  SDValue Mask = MSC->getMask();
  SDValue Chain = MSC->getChain();
  SDValue Index = MSC->getIndex();
  SDValue Scale = MSC->getScale();
  SDValue StoreVal = MSC->getValue();
  SDValue BasePtr = MSC->getBasePtr();
  ISD::MemIndexType IndexType = MSC->getIndexType();
  SDLoc DL(N);

  // Zap scatters with a zero mask.
    return Chain;

  // Move a splat component of the vector index into the scalar base pointer
  // and rebuild the scatter from the simplified operands.
  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                                DL, Ops, MSC->getMemOperand(), IndexType,
                                MSC->isTruncatingStore());
  }

  // Strip a redundant extension from the index; IndexType is updated to match.
  if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                                DL, Ops, MSC->getMemOperand(), IndexType,
                                MSC->isTruncatingStore());
  }

  return SDValue();
}
13114
/// Combine a masked store: dead-store elimination, conversion to an unmasked
/// store for an all-ones mask, indexed-store formation, and truncstore folds.
SDValue DAGCombiner::visitMSTORE(SDNode *N) {
  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
  SDValue Mask = MST->getMask();
  SDValue Chain = MST->getChain();
  SDValue Value = MST->getValue();
  SDValue Ptr = MST->getBasePtr();

  // Zap masked stores with a zero mask.
    return Chain;

  // Remove a masked store if base pointers and masks are equal.
  if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Chain)) {
    if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
        MST1->isSimple() && MST1->getBasePtr() == Ptr &&
        !MST->getBasePtr().isUndef() &&
        ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
                                         MST1->getMemoryVT().getStoreSize()) ||
         TypeSize::isKnownLE(MST1->getMemoryVT().getStoreSize(),
                             MST->getMemoryVT().getStoreSize())) {
      // The earlier store is fully overwritten by this one: splice it out of
      // the chain and re-visit N, since its chain operand changed.
      CombineTo(MST1, MST1->getChain());
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // If this is a masked store with an all ones mask, we can use an unmasked
  // store.
  // FIXME: Can we do this for indexed, compressing, or truncating stores?
  if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
      !MST->isCompressingStore() && !MST->isTruncatingStore())
    return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
                        MST->getBasePtr(), MST->getPointerInfo(),
                        MST->getBaseAlign(), MST->getMemOperand()->getFlags(),
                        MST->getAAInfo());

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // For a truncating store only the low bits of the value are demanded;
  // let SimplifyDemandedBits exploit that.
  if (MST->isTruncatingStore() && MST->isUnindexed() &&
      Value.getValueType().isInteger() &&
      !cast<ConstantSDNode>(Value)->isOpaque())) {
    APInt TruncDemandedBits =
        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),

    // See if we can simplify the operation with
    // SimplifyDemandedBits, which only works if the value has a single use.
    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been merged
      // with another node (N is deleted) SimplifyDemandedBits will add Value's
      // node back to the worklist if necessary, but we also need to re-visit
      // the Store node itself.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // If this is a TRUNC followed by a masked store, fold this into a masked
  // truncating store. We can do this even if this is already a masked
  // truncstore.
  // TODO: Try combine to masked compress store if possible.
  if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
      MST->isUnindexed() && !MST->isCompressingStore() &&
      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
                               MST->getMemoryVT(), LegalOperations)) {
    auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
                                         Value.getOperand(0).getValueType());
    return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                              MST->getOffset(), Mask, MST->getMemoryVT(),
                              MST->getMemOperand(), MST->getAddressingMode(),
                              /*IsTruncating=*/true);
  }

  return SDValue();
}
13195
13196SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
13197 auto *SST = cast<VPStridedStoreSDNode>(N);
13198 EVT EltVT = SST->getValue().getValueType().getVectorElementType();
13199 // Combine strided stores with unit-stride to a regular VP store.
13200 if (auto *CStride = dyn_cast<ConstantSDNode>(SST->getStride());
13201 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
13202 return DAG.getStoreVP(SST->getChain(), SDLoc(N), SST->getValue(),
13203 SST->getBasePtr(), SST->getOffset(), SST->getMask(),
13204 SST->getVectorLength(), SST->getMemoryVT(),
13205 SST->getMemOperand(), SST->getAddressingMode(),
13206 SST->isTruncatingStore(), SST->isCompressingStore());
13207 }
13208 return SDValue();
13209}
13210
/// Simplify VECTOR_COMPRESS when the mask is known: an all-true/all-false
/// splat selects the whole vector or the passthru, and a fully-constant mask
/// lets the compressed result be built directly as a BUILD_VECTOR.
SDValue DAGCombiner::visitVECTOR_COMPRESS(SDNode *N) {
  SDLoc DL(N);
  SDValue Vec = N->getOperand(0);
  SDValue Mask = N->getOperand(1);
  SDValue Passthru = N->getOperand(2);
  EVT VecVT = Vec.getValueType();

  bool HasPassthru = !Passthru.isUndef();

  // A splat-constant mask selects either every element or none of them.
  APInt SplatVal;
  if (ISD::isConstantSplatVector(Mask.getNode(), SplatVal))
    return TLI.isConstTrueVal(Mask) ? Vec : Passthru;

  if (Vec.isUndef() || Mask.isUndef())
    return Passthru;

  // No need for potentially expensive compress if the mask is constant.
    EVT ScalarVT = VecVT.getVectorElementType();
    unsigned NumSelected = 0;
    unsigned NumElmts = VecVT.getVectorNumElements();
    for (unsigned I = 0; I < NumElmts; ++I) {
      SDValue MaskI = Mask.getOperand(I);
      // We treat undef mask entries as "false".
      if (MaskI.isUndef())
        continue;

      // Selected lanes are extracted in order and packed at the front.
      if (TLI.isConstTrueVal(MaskI)) {
        SDValue VecI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec,
                                   DAG.getVectorIdxConstant(I, DL));
        Ops.push_back(VecI);
        NumSelected++;
      }
    }
    // Fill the tail with passthru elements (or undef when there is no
    // passthru), matching VECTOR_COMPRESS semantics.
    for (unsigned Rest = NumSelected; Rest < NumElmts; ++Rest) {
      SDValue Val =
          HasPassthru
              ? DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Passthru,
                            DAG.getVectorIdxConstant(Rest, DL))
              : DAG.getUNDEF(ScalarVT);
      Ops.push_back(Val);
    }
    return DAG.getBuildVector(VecVT, DL, Ops);
  }

  return SDValue();
}
13259
13260SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
13261 VPGatherSDNode *MGT = cast<VPGatherSDNode>(N);
13262 SDValue Mask = MGT->getMask();
13263 SDValue Chain = MGT->getChain();
13264 SDValue Index = MGT->getIndex();
13265 SDValue Scale = MGT->getScale();
13266 SDValue BasePtr = MGT->getBasePtr();
13267 SDValue VL = MGT->getVectorLength();
13268 ISD::MemIndexType IndexType = MGT->getIndexType();
13269 SDLoc DL(N);
13270
13271 if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
13272 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
13273 return DAG.getGatherVP(
13274 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
13275 Ops, MGT->getMemOperand(), IndexType);
13276 }
13277
13278 if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
13279 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
13280 return DAG.getGatherVP(
13281 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
13282 Ops, MGT->getMemOperand(), IndexType);
13283 }
13284
13285 return SDValue();
13286}
13287
/// Combine a masked gather: replace it with the passthru for a known-zero
/// mask, otherwise try to simplify its base-pointer and index operands.
SDValue DAGCombiner::visitMGATHER(SDNode *N) {
  MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
  SDValue Mask = MGT->getMask();
  SDValue Chain = MGT->getChain();
  SDValue Index = MGT->getIndex();
  SDValue Scale = MGT->getScale();
  SDValue PassThru = MGT->getPassThru();
  SDValue BasePtr = MGT->getBasePtr();
  ISD::MemIndexType IndexType = MGT->getIndexType();
  SDLoc DL(N);

  // Zap gathers with a zero mask.
    return CombineTo(N, PassThru, MGT->getChain());

  // Move a splat component of the vector index into the scalar base pointer
  // and rebuild the gather from the simplified operands.
  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedGather(
        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
  }

  // Strip a redundant extension from the index; IndexType is updated to match.
  if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedGather(
        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
  }

  return SDValue();
}
13319
/// Combine a masked load: replace with passthru for a zero mask, use a plain
/// load for an all-ones mask, and try indexed-load formation.
SDValue DAGCombiner::visitMLOAD(SDNode *N) {
  MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
  SDValue Mask = MLD->getMask();

  // Zap masked loads with a zero mask.
    return CombineTo(N, MLD->getPassThru(), MLD->getChain());

  // If this is a masked load with an all ones mask, we can use an unmasked
  // load.
  // FIXME: Can we do this for indexed, expanding, or extending loads?
  if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
      !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
    SDValue NewLd = DAG.getLoad(
        N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
        MLD->getPointerInfo(), MLD->getBaseAlign(),
        MLD->getMemOperand()->getFlags(), MLD->getAAInfo(), MLD->getRanges());
    // Replace both the loaded value and the chain result of N.
    return CombineTo(N, NewLd, NewLd.getValue(1));
  }

  // Try transforming N to an indexed load.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  return SDValue();
}
13345
/// Combine a masked histogram by simplifying its base-pointer and index
/// operands (a zero mask reduces it to its chain).
SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
  MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
  SDValue Chain = HG->getChain();
  SDValue Inc = HG->getInc();
  SDValue Mask = HG->getMask();
  SDValue BasePtr = HG->getBasePtr();
  SDValue Index = HG->getIndex();
  SDLoc DL(HG);

  EVT MemVT = HG->getMemoryVT();
  EVT DataVT = Index.getValueType();
  MachineMemOperand *MMO = HG->getMemOperand();
  ISD::MemIndexType IndexType = HG->getIndexType();

    return Chain;

  // Fold a splat index component into the base pointer and/or strip a
  // redundant index extension, then rebuild the histogram node.
  if (refineUniformBase(BasePtr, Index, HG->isIndexScaled(), DAG, DL) ||
      refineIndexType(Index, IndexType, DataVT, DAG)) {
    SDValue Ops[] = {Chain, Inc, Mask, BasePtr, Index,
                     HG->getScale(), HG->getIntID()};
    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
                                  MMO, IndexType);
  }

  return SDValue();
}
13373
13374SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
13375 if (SDValue Res = foldPartialReduceMLAMulOp(N))
13376 return Res;
13377 if (SDValue Res = foldPartialReduceAdd(N))
13378 return Res;
13379 return SDValue();
13380}
13381
// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, splat(C))
//
// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)
//
// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C))
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
  SDLoc DL(N);
  auto *Context = DAG.getContext();
  SDValue Acc = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  SDValue Op2 = N->getOperand(2);
  unsigned Opc = Op1->getOpcode();

  // Handle predication by moving the SELECT into the operand of the MUL.
  SDValue Pred;
  if (Opc == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) ||
                              isZeroOrZeroSplatFP(Op1->getOperand(2)))) {
    Pred = Op1->getOperand(0);
    Op1 = Op1->getOperand(1);
    Opc = Op1->getOpcode();
  }

  if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
    return SDValue();

  SDValue LHS = Op1->getOperand(0);
  SDValue RHS = Op1->getOperand(1);

  // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
  if (Opc == ISD::SHL) {
    APInt C;
    if (!ISD::isConstantSplatVector(RHS.getNode(), C))
      return SDValue();

    RHS =
        DAG.getSplatVector(RHS.getValueType(), DL,
                           DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL,
                                           RHS.getValueType().getScalarType()));
    Opc = ISD::MUL;
  }

  // The third MLA operand must be a splat of one for the node to reduce to a
  // plain partial reduction of the product.
  if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) &&
    return SDValue();

  auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
    return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND);
  };

  unsigned LHSOpcode = LHS->getOpcode();
  if (!IsIntOrFPExtOpcode(LHSOpcode))
    return SDValue();

  SDValue LHSExtOp = LHS->getOperand(0);
  EVT LHSExtOpVT = LHSExtOp.getValueType();

  // When Pred is non-zero, set Op = select(Pred, Op, splat(0)) and freeze
  // OtherOp to keep the same semantics when moving the selects into the MUL
  // operands.
  auto ApplyPredicate = [&](SDValue &Op, SDValue &OtherOp) {
    if (Pred) {
      EVT OpVT = Op.getValueType();
      SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
                                            : DAG.getConstant(0, DL, OpVT);
      Op = DAG.getSelect(DL, OpVT, Pred, Op, Zero);
      OtherOp = DAG.getFreeze(OtherOp);
    }
  };

  // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
  // -> partial_reduce_*mla(acc, x, C)
  APInt C;
  if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
    // TODO: Make use of partial_reduce_sumla here
    // The constant must round-trip through truncation to the pre-extension
    // type under the matching (zero/sign) extension semantics.
    APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
    unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
    if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
        (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
      return SDValue();

    unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND

    // Only perform these combines if the target supports folding
    // the extends into the operation.
        NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
        TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
      return SDValue();

    SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT);
    ApplyPredicate(C, LHSExtOp);
    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C);
  }

  unsigned RHSOpcode = RHS->getOpcode();
  if (!IsIntOrFPExtOpcode(RHSOpcode))
    return SDValue();

  SDValue RHSExtOp = RHS->getOperand(0);
  if (LHSExtOpVT != RHSExtOp.getValueType())
    return SDValue();

  // Pick the partial-reduce flavour from the pair of extension kinds.
  unsigned NewOpc;
  if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
    NewOpc = ISD::PARTIAL_REDUCE_SMLA;
  else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
    NewOpc = ISD::PARTIAL_REDUCE_UMLA;
  else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
  else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
    std::swap(LHSExtOp, RHSExtOp);
  } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
    NewOpc = ISD::PARTIAL_REDUCE_FMLA;
  } else
    return SDValue();
  // For a 2-stage extend the signedness of both of the extends must match
  // If the mul has the same type, there is no outer extend, and thus we
  // can simply use the inner extends to pick the result node.
  // TODO: extend to handle nonneg zext as sext
  EVT AccElemVT = Acc.getValueType().getVectorElementType();
  if (Op1.getValueType().getVectorElementType() != AccElemVT &&
      NewOpc != N->getOpcode())
    return SDValue();

  // Only perform these combines if the target supports folding
  // the extends into the operation.
      NewOpc, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
      TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
    return SDValue();

  ApplyPredicate(RHSExtOp, LHSExtOp);
  return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
13525
// partial.reduce.*mla(acc, *ext(op), splat(1))
// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
//
// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
// -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
  SDLoc DL(N);
  SDValue Acc = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  SDValue Op2 = N->getOperand(2);

    return SDValue();

  // As in foldPartialReduceMLAMulOp, peel a zero-padded VSELECT off Op1 and
  // remember its predicate so it can be re-applied below.
  SDValue Pred;
  unsigned Op1Opcode = Op1.getOpcode();
  if (Op1Opcode == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) ||
                                    isZeroOrZeroSplatFP(Op1->getOperand(2)))) {
    Pred = Op1->getOperand(0);
    Op1 = Op1->getOperand(1);
    Op1Opcode = Op1->getOpcode();
  }

  if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
    return SDValue();

  // The extension kind must agree with the node's signedness unless the
  // element types already match (then no real extension is being folded).
  bool Op1IsSigned =
      Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
  bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
  EVT AccElemVT = Acc.getValueType().getVectorElementType();
  if (Op1IsSigned != NodeIsSigned &&
      Op1.getValueType().getVectorElementType() != AccElemVT)
    return SDValue();

  unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
                       : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA

  SDValue UnextOp1 = Op1.getOperand(0);
  EVT UnextOp1VT = UnextOp1.getValueType();
  auto *Context = DAG.getContext();
      NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
      TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
    return SDValue();

  SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
                         ? DAG.getConstantFP(1, DL, UnextOp1VT)
                         : DAG.getConstant(1, DL, UnextOp1VT);

  // Predicated form: select between 1 and 0 so that masked-off lanes
  // contribute nothing to the reduction.
  if (Pred) {
    SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
                       ? DAG.getConstantFP(0, DL, UnextOp1VT)
                       : DAG.getConstant(0, DL, UnextOp1VT);
    Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
  }
  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
                     Constant);
}
13588
13589SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
13590 auto *SLD = cast<VPStridedLoadSDNode>(N);
13591 EVT EltVT = SLD->getValueType(0).getVectorElementType();
13592 // Combine strided loads with unit-stride to a regular VP load.
13593 if (auto *CStride = dyn_cast<ConstantSDNode>(SLD->getStride());
13594 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
13595 SDValue NewLd = DAG.getLoadVP(
13596 SLD->getAddressingMode(), SLD->getExtensionType(), SLD->getValueType(0),
13597 SDLoc(N), SLD->getChain(), SLD->getBasePtr(), SLD->getOffset(),
13598 SLD->getMask(), SLD->getVectorLength(), SLD->getMemoryVT(),
13599 SLD->getMemOperand(), SLD->isExpandingLoad());
13600 return CombineTo(N, NewLd, NewLd.getValue(1));
13601 }
13602 return SDValue();
13603}
13604
/// A vector select of 2 constant vectors can be simplified to math/logic to
/// avoid a variable select instruction and possibly avoid constant loads.
SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
  SDValue Cond = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
    return SDValue();

  // Check if we can use the condition value to increment/decrement a single
  // constant value. This simplifies a select to an add and removes a constant
  // load/materialization from the general case.
  bool AllAddOne = true;
  bool AllSubOne = true;
  unsigned Elts = VT.getVectorNumElements();
  for (unsigned i = 0; i != Elts; ++i) {
    SDValue N1Elt = N1.getOperand(i);
    SDValue N2Elt = N2.getOperand(i);
    if (N1Elt.isUndef())
      continue;
    // N2 should not contain undef values since it will be reused in the fold.
    if (N2Elt.isUndef() || N1Elt.getValueType() != N2Elt.getValueType()) {
      AllAddOne = false;
      AllSubOne = false;
      break;
    }

    // Track whether every pair satisfies C1 == C2+1 (add) or C1 == C2-1 (sub).
    const APInt &C1 = N1Elt->getAsAPIntVal();
    const APInt &C2 = N2Elt->getAsAPIntVal();
    if (C1 != C2 + 1)
      AllAddOne = false;
    if (C1 != C2 - 1)
      AllSubOne = false;
  }

  // Further simplifications for the extra-special cases where the constants are
  // all 0 or all -1 should be implemented as folds of these patterns.
  SDLoc DL(N);
  if (AllAddOne || AllSubOne) {
    // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
    // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
    auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
    SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
    return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
  }

  // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
  APInt Pow2C;
  if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
      isNullOrNullSplat(N2)) {
    SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
    SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
    return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
  }

    return V;

  // The general case for select-of-constants:
  // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
  // ...but that only makes sense if a vselect is slower than 2 logic ops, so
  // leave that to a machine-specific pass.
  return SDValue();
}
13673
/// Simplify a VP (vector-predicated) select via the generic select
/// simplifications.
SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  SDLoc DL(N);

  // Generic folds (e.g. identical true/false operands, constant condition).
  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

    return V;

  return SDValue();
}
13688
                                                SDValue FVal,
                                                const TargetLowering &TLI,
                                                SelectionDAG &DAG,
                                                const SDLoc &DL) {
  EVT VT = TVal.getValueType();
  if (!TLI.isTypeLegal(VT))
    return SDValue();

  EVT CondVT = Cond.getValueType();
  assert(CondVT.isVector() && "Vector select expects a vector selector!");

  // Classify the two select arms as all-zeros / all-ones splats.
  bool IsTAllZero = ISD::isConstantSplatVectorAllZeros(TVal.getNode());
  bool IsTAllOne = ISD::isConstantSplatVectorAllOnes(TVal.getNode());
  bool IsFAllZero = ISD::isConstantSplatVectorAllZeros(FVal.getNode());
  bool IsFAllOne = ISD::isConstantSplatVectorAllOnes(FVal.getNode());

  // no vselect(cond, 0/-1, X) or vselect(cond, X, 0/-1), return
  if (!IsTAllZero && !IsTAllOne && !IsFAllZero && !IsFAllOne)
    return SDValue();

  // select Cond, 0, 0 → 0
  if (IsTAllZero && IsFAllZero) {
    return VT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, VT)
                                : DAG.getConstant(0, DL, VT);
  }

  // check select(setgt lhs, -1), 1, -1 --> or (sra lhs, bitwidth - 1), 1
  APInt TValAPInt;
  if (Cond.getOpcode() == ISD::SETCC &&
      Cond.getOperand(2) == DAG.getCondCode(ISD::SETGT) &&
      Cond.getOperand(0).getValueType() == VT && VT.isSimple() &&
      ISD::isConstantSplatVector(TVal.getNode(), TValAPInt) &&
      TValAPInt.isOne() &&
      ISD::isConstantSplatVectorAllOnes(Cond.getOperand(1).getNode()) &&
    return SDValue();
  }

  // To use the condition operand as a bitwise mask, it must have elements that
  // are the same size as the select elements. i.e, the condition operand must
  // have already been promoted from the IR select condition type <N x i1>.
  // Don't check if the types themselves are equal because that excludes
  // vector floating-point selects.
  if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
    return SDValue();

  // Cond value must be 'sign splat' to be converted to a logical op.
  if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
    return SDValue();

  // Try inverting Cond and swapping T/F if it gives all-ones/all-zeros form
  if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
      Cond.getOpcode() == ISD::SETCC &&
      TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
          CondVT) {
    if (IsTAllZero || IsFAllOne) {
      SDValue CC = Cond.getOperand(2);
          cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
      // Rebuild the compare with the inverted predicate and swap the arms
      // (and their classifications) to expose one of the folds below.
      Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
                          InverseCC);
      std::swap(TVal, FVal);
      std::swap(IsTAllOne, IsFAllOne);
      std::swap(IsTAllZero, IsFAllZero);
    }
  }

         "Select condition no longer all-sign bits");

  // select Cond, -1, 0 → bitcast Cond
  if (IsTAllOne && IsFAllZero)
    return DAG.getBitcast(VT, Cond);

  // select Cond, -1, x → or Cond, x
  if (IsTAllOne) {
    SDValue X = DAG.getBitcast(CondVT, DAG.getFreeze(FVal));
    SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, X);
    return DAG.getBitcast(VT, Or);
  }

  // select Cond, x, 0 → and Cond, x
  if (IsFAllZero) {
    SDValue X = DAG.getBitcast(CondVT, DAG.getFreeze(TVal));
    SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, X);
    return DAG.getBitcast(VT, And);
  }

  // select Cond, 0, x -> and not(Cond), x
  if (IsTAllZero &&
    SDValue X = DAG.getBitcast(CondVT, DAG.getFreeze(FVal));
    SDValue And =
        DAG.getNode(ISD::AND, DL, CondVT, DAG.getNOT(DL, Cond, CondVT), X);
    return DAG.getBitcast(VT, And);
  }

  return SDValue();
}
13789
// Combine a VSELECT node. Tries, in order: generic select simplification,
// condition-flip canonicalization, integer-abs recognition, fmin/fmax and
// fp-to-sat patterns, compare widening, ABD, saturating add/sub matching,
// umin matching, operand simplification, and constant-mask folds.
// NOTE(review): several source lines are elided in this rendering (e.g.
// 13800, 13810-13813, 13826, 13833-13839, 13875, 13907, 13954-13957,
// 14038-14061); comments describe only the visible code.
13790SDValue DAGCombiner::visitVSELECT(SDNode *N) {
13791 SDValue N0 = N->getOperand(0);
13792 SDValue N1 = N->getOperand(1);
13793 SDValue N2 = N->getOperand(2);
13794 EVT VT = N->getValueType(0);
13795 SDLoc DL(N);
13796
 // Trivial simplifications (identical arms, constant condition, ...).
13797 if (SDValue V = DAG.simplifySelect(N0, N1, N2))
13798 return V;
13799
 // (The combine guarded here, line 13800, is elided in this rendering.)
13801 return V;
13802
13803 // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
13804 if (!TLI.isTargetCanonicalSelect(N))
13805 if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
13806 return DAG.getSelect(DL, VT, F, N2, N1, N->getFlags());
13807
13808 // select (sext m), (add X, C), X --> (add X, (and C, (sext m))))
13809 if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&
13812 TLI.getBooleanContents(N0.getValueType()) ==
13814 return DAG.getNode(
13815 ISD::ADD, DL, N1.getValueType(), N2,
13816 DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1), N0));
13817 }
13818
13819 // Canonicalize integer abs.
13820 // vselect (setg[te] X, 0), X, -X ->
13821 // vselect (setgt X, -1), X, -X ->
13822 // vselect (setl[te] X, 0), -X, X ->
13823 // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
13824 if (N0.getOpcode() == ISD::SETCC) {
13825 SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
13827 bool isAbs = false;
13828 bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
13829
 // Match X >[=] 0 / X > -1 selecting X vs 0-X (and the mirrored form).
13830 if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
13831 (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
13832 N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
13834 else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
13835 N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
13837
13838 if (isAbs) {
 // Prefer a native ABS when available; otherwise expand via sra/add/xor.
13840 return DAG.getNode(ISD::ABS, DL, VT, LHS);
13841
13842 SDValue Shift = DAG.getNode(
13843 ISD::SRA, DL, VT, LHS,
13844 DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, DL));
13845 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
13846 AddToWorklist(Shift.getNode());
13847 AddToWorklist(Add.getNode());
13848 return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
13849 }
13850
13851 // vselect x, y (fcmp lt x, y) -> fminnum x, y
13852 // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
13853 //
13854 // This is OK if we don't care about what happens if either operand is a
13855 // NaN.
13856 //
13857 if (N0.hasOneUse() &&
13858 isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, N->getFlags(), TLI)) {
13859 if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
13860 return FMinMax;
13861 }
13862
13863 if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
13864 return S;
13865 if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
13866 return S;
13867
13868 // If this select has a condition (setcc) with narrower operands than the
13869 // select, try to widen the compare to match the select width.
13870 // TODO: This should be extended to handle any constant.
13871 // TODO: This could be extended to handle non-loading patterns, but that
13872 // requires thorough testing to avoid regressions.
13873 if (isNullOrNullSplat(RHS)) {
13874 EVT NarrowVT = LHS.getValueType();
13876 EVT SetCCVT = getSetCCResultType(LHS.getValueType());
13877 unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
13878 unsigned WideWidth = WideVT.getScalarSizeInBits();
13879 bool IsSigned = isSignedIntSetCC(CC);
13880 auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13881 if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
13882 SetCCWidth != 1 && SetCCWidth < WideWidth &&
13883 TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
13884 TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
13885 // Both compare operands can be widened for free. The LHS can use an
13886 // extended load, and the RHS is a constant:
13887 // vselect (ext (setcc load(X), C)), N1, N2 -->
13888 // vselect (setcc extload(X), C'), N1, N2
13889 auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13890 SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
13891 SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
13892 EVT WideSetCCVT = getSetCCResultType(WideVT);
13893 SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
13894 return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
13895 }
13896 }
13897
13898 if (SDValue ABD = foldSelectToABD(LHS, RHS, N1, N2, CC, DL))
13899 return ABD;
13900
13901 // Match VSELECTs into add with unsigned saturation.
13902 if (hasOperation(ISD::UADDSAT, VT)) {
13903 // Check if one of the arms of the VSELECT is vector with all bits set.
13904 // If it's on the left side invert the predicate to simplify logic below.
13905 SDValue Other;
13906 ISD::CondCode SatCC = CC;
13908 Other = N2;
13909 SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
13910 } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
13911 Other = N1;
13912 }
13913
13914 if (Other && Other.getOpcode() == ISD::ADD) {
13915 SDValue CondLHS = LHS, CondRHS = RHS;
13916 SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
13917
13918 // Canonicalize condition operands.
13919 if (SatCC == ISD::SETUGE) {
13920 std::swap(CondLHS, CondRHS);
13921 SatCC = ISD::SETULE;
13922 }
13923
13924 // We can test against either of the addition operands.
13925 // x <= x+y ? x+y : ~0 --> uaddsat x, y
13926 // x+y >= x ? x+y : ~0 --> uaddsat x, y
13927 if (SatCC == ISD::SETULE && Other == CondRHS &&
13928 (OpLHS == CondLHS || OpRHS == CondLHS))
13929 return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
13930
13931 if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
13932 (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
13933 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
13934 CondLHS == OpLHS) {
13935 // If the RHS is a constant we have to reverse the const
13936 // canonicalization.
13937 // x >= ~C ? x+C : ~0 --> uaddsat x, C
13938 auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
13939 return Cond->getAPIntValue() == ~Op->getAPIntValue();
13940 };
13941 if (SatCC == ISD::SETULE &&
13942 ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
13943 return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
13944 }
13945 }
13946 }
13947
13948 // Match VSELECTs into sub with unsigned saturation.
13949 if (hasOperation(ISD::USUBSAT, VT)) {
13950 // Check if one of the arms of the VSELECT is a zero vector. If it's on
13951 // the left side invert the predicate to simplify logic below.
13952 SDValue Other;
13953 ISD::CondCode SatCC = CC;
13955 Other = N2;
13956 SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
13958 Other = N1;
13959 }
13960
13961 // zext(x) >= y ? trunc(zext(x) - y) : 0
13962 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
13963 // zext(x) > y ? trunc(zext(x) - y) : 0
13964 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
13965 if (Other && Other.getOpcode() == ISD::TRUNCATE &&
13966 Other.getOperand(0).getOpcode() == ISD::SUB &&
13967 (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
13968 SDValue OpLHS = Other.getOperand(0).getOperand(0);
13969 SDValue OpRHS = Other.getOperand(0).getOperand(1);
13970 if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
13971 if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
13972 DAG, DL))
13973 return R;
13974 }
13975
13976 if (Other && Other.getNumOperands() == 2) {
13977 SDValue CondRHS = RHS;
13978 SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
13979
13980 if (OpLHS == LHS) {
13981 // Look for a general sub with unsigned saturation first.
13982 // x >= y ? x-y : 0 --> usubsat x, y
13983 // x > y ? x-y : 0 --> usubsat x, y
13984 if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
13985 Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
13986 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
13987
13988 if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
13989 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
13990 if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
13991 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
13992 // If the RHS is a constant we have to reverse the const
13993 // canonicalization.
13994 // x > C-1 ? x+-C : 0 --> usubsat x, C
13995 auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
13996 return (!Op && !Cond) ||
13997 (Op && Cond &&
13998 Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
13999 };
14000 if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
14001 ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
14002 /*AllowUndefs*/ true)) {
14003 OpRHS = DAG.getNegative(OpRHS, DL, VT);
14004 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
14005 }
14006
14007 // Another special case: If C was a sign bit, the sub has been
14008 // canonicalized into a xor.
14009 // FIXME: Would it be better to use computeKnownBits to
14010 // determine whether it's safe to decanonicalize the xor?
14011 // x s< 0 ? x^C : 0 --> usubsat x, C
14012 APInt SplatValue;
14013 if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
14014 ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
14016 SplatValue.isSignMask()) {
14017 // Note that we have to rebuild the RHS constant here to
14018 // ensure we don't rely on particular values of undef lanes.
14019 OpRHS = DAG.getConstant(SplatValue, DL, VT);
14020 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
14021 }
14022 }
14023 }
14024 }
14025 }
14026 }
14027
14028 // (vselect (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
14029 // (vselect (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
14030 if (SDValue UMin = foldSelectToUMin(LHS, RHS, N1, N2, CC, DL))
14031 return UMin;
14032 }
14033
14034 if (SimplifySelectOps(N, N1, N2))
14035 return SDValue(N, 0); // Don't revisit N.
14036
14037 // Fold (vselect all_ones, N1, N2) -> N1
14039 return N1;
14040 // Fold (vselect all_zeros, N1, N2) -> N2
14042 return N2;
14043
14044 // The ConvertSelectToConcatVector function is assuming both the above
14045 // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
14046 // and addressed.
14047 if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
14050 if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
14051 return CV;
14052 }
14053
14054 if (SDValue V = foldVSelectOfConstants(N))
14055 return V;
14056
 // (The combine guarded here, line 14058, is elided in this rendering.)
14057 if (hasOperation(ISD::SRA, VT))
14059 return V;
14060
 // (The demanded-elements check on line 14061 is elided in this rendering.)
14062 return SDValue(N, 0);
14063
 // Finally, try turning constant-splat arms into bitwise mask logic.
14064 if (SDValue V = combineVSelectWithAllOnesOrZeros(N0, N1, N2, TLI, DAG, DL))
14065 return V;
14066
14067 return SDValue();
14068}
14069
14070SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
14071 SDValue N0 = N->getOperand(0);
14072 SDValue N1 = N->getOperand(1);
14073 SDValue N2 = N->getOperand(2);
14074 SDValue N3 = N->getOperand(3);
14075 SDValue N4 = N->getOperand(4);
14076 ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
14077 SDLoc DL(N);
14078
14079 // fold select_cc lhs, rhs, x, x, cc -> x
14080 if (N2 == N3)
14081 return N2;
14082
14083 // select_cc bool, 0, x, y, seteq -> select bool, y, x
14084 if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
14085 isNullConstant(N1))
14086 return DAG.getSelect(DL, N2.getValueType(), N0, N3, N2);
14087
14088 // Determine if the condition we're dealing with is constant
14089 if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
14090 CC, DL, false)) {
14091 AddToWorklist(SCC.getNode());
14092
14093 // cond always true -> true val
14094 // cond always false -> false val
14095 if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
14096 return SCCC->isZero() ? N3 : N2;
14097
14098 // When the condition is UNDEF, just return the first operand. This is
14099 // coherent the DAG creation, no setcc node is created in this case
14100 if (SCC->isUndef())
14101 return N2;
14102
14103 // Fold to a simpler select_cc
14104 if (SCC.getOpcode() == ISD::SETCC) {
14105 return DAG.getNode(ISD::SELECT_CC, DL, N2.getValueType(),
14106 SCC.getOperand(0), SCC.getOperand(1), N2, N3,
14107 SCC.getOperand(2), SCC->getFlags());
14108 }
14109 }
14110
14111 // If we can fold this based on the true/false value, do so.
14112 if (SimplifySelectOps(N, N2, N3))
14113 return SDValue(N, 0); // Don't revisit N.
14114
14115 // fold select_cc into other things, such as min/max/abs
14116 return SimplifySelectCC(DL, N0, N1, N2, N3, CC);
14117}
14118
// Combine a SETCC node: first run the generic SimplifySetCC machinery
// (preserving a setcc form when the only user is a brcond), then try the
// target-preference-driven shift+and <-> rotate interchange for eq/ne
// whole-value self-comparisons.
14119SDValue DAGCombiner::visitSETCC(SDNode *N) {
14120 // setcc is very commonly used as an argument to brcond. This pattern
14121 // also lends itself to numerous combines and, as a result, it is desired
14122 // we keep the argument to a brcond as a setcc as much as possible.
14123 bool PreferSetCC =
14124 N->hasOneUse() && N->user_begin()->getOpcode() == ISD::BRCOND;
14125
14126 ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
14127 EVT VT = N->getValueType(0);
14128 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
14129 SDLoc DL(N);
14130
14131 if (SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL, !PreferSetCC)) {
14132 // If we prefer to have a setcc, and we don't, we'll try our best to
14133 // recreate one using rebuildSetCC.
14134 if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
14135 SDValue NewSetCC = rebuildSetCC(Combined);
14136
14137 // We don't have anything interesting to combine to.
14138 if (NewSetCC.getNode() == N)
14139 return SDValue();
14140
14141 if (NewSetCC)
14142 return NewSetCC;
14143 }
14144 return Combined;
14145 }
14146
14147 // Optimize
14148 // 1) (icmp eq/ne (and X, C0), (shift X, C1))
14149 // or
14150 // 2) (icmp eq/ne X, (rotate X, C1))
14151 // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
14152 // remaining bits (i.e something like `(x64 & UINT32_MAX) == (x64 >> 32)`)
14153 // Then:
14154 // If C1 is a power of 2, then the rotate and shift+and versions are
14155 // equivalent, so we can interchange them depending on target preference.
14156 // Otherwise, if we have the shift+and version we can interchange srl/shl
14157 // which in turn affects the constant C0. We can use this to get better
14158 // constants again determined by target preference.
14159 if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
 // Matcher for pattern 1: (and X, C0) compared with (srl/shl X, C1).
14160 auto IsAndWithShift = [](SDValue A, SDValue B) {
14161 return A.getOpcode() == ISD::AND &&
14162 (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
14163 A.getOperand(0) == B.getOperand(0);
14164 };
 // Matcher for pattern 2: X compared with (rotl/rotr X, C1).
14165 auto IsRotateWithOp = [](SDValue A, SDValue B) {
14166 return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
14167 B.getOperand(0) == A;
14168 };
14169 SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
14170 bool IsRotate = false;
14171
14172 // Find either shift+and or rotate pattern.
14173 if (IsAndWithShift(N0, N1)) {
14174 AndOrOp = N0;
14175 ShiftOrRotate = N1;
14176 } else if (IsAndWithShift(N1, N0)) {
14177 AndOrOp = N1;
14178 ShiftOrRotate = N0;
14179 } else if (IsRotateWithOp(N0, N1)) {
14180 IsRotate = true;
14181 AndOrOp = N0;
14182 ShiftOrRotate = N1;
14183 } else if (IsRotateWithOp(N1, N0)) {
14184 IsRotate = true;
14185 AndOrOp = N1;
14186 ShiftOrRotate = N0;
14187 }
14188
14189 if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
14190 (IsRotate || AndOrOp.hasOneUse())) {
14191 EVT OpVT = N0.getValueType();
14192 // Get constant shift/rotate amount and possibly mask (if its shift+and
14193 // variant).
14194 auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
14195 ConstantSDNode *CNode = isConstOrConstSplat(Op, /*AllowUndefs*/ false,
14196 /*AllowTrunc*/ false);
14197 if (CNode == nullptr)
14198 return std::nullopt;
14199 return CNode->getAPIntValue();
14200 };
14201 std::optional<APInt> AndCMask =
14202 IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(1));
14203 std::optional<APInt> ShiftCAmt =
14204 GetAPIntValue(ShiftOrRotate.getOperand(1));
14205 unsigned NumBits = OpVT.getScalarSizeInBits();
14206
14207 // We found constants.
14208 if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(NumBits)) {
14209 unsigned ShiftOpc = ShiftOrRotate.getOpcode();
14210 // Check that the constants meet the constraints.
14211 bool CanTransform = IsRotate;
14212 if (!CanTransform) {
14213 // Check that mask and shift complement each other.
14214 CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
14215 // Check that we are comparing all bits
14216 CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
14217 // Check that the and mask is correct for the shift
14218 CanTransform &=
14219 ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
14220 }
14221
14222 // See if target prefers another shift/rotate opcode.
14223 unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
14224 OpVT, ShiftOpc, ShiftCAmt->isPowerOf2(), *ShiftCAmt, AndCMask);
14225 // Transform is valid and we have a new preference.
14226 if (CanTransform && NewShiftOpc != ShiftOpc) {
14227 SDValue NewShiftOrRotate =
14228 DAG.getNode(NewShiftOpc, DL, OpVT, ShiftOrRotate.getOperand(0),
14229 ShiftOrRotate.getOperand(1));
14230 SDValue NewAndOrOp = SDValue();
14231
 // For a shift form, rebuild the mask so it matches the new shift
 // direction; for a rotate form the mask operand is unnecessary.
14232 if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
14233 APInt NewMask =
14234 NewShiftOpc == ISD::SHL
14235 ? APInt::getHighBitsSet(NumBits,
14236 NumBits - ShiftCAmt->getZExtValue())
14237 : APInt::getLowBitsSet(NumBits,
14238 NumBits - ShiftCAmt->getZExtValue());
14239 NewAndOrOp =
14240 DAG.getNode(ISD::AND, DL, OpVT, ShiftOrRotate.getOperand(0),
14241 DAG.getConstant(NewMask, DL, OpVT));
14242 } else {
14243 NewAndOrOp = ShiftOrRotate.getOperand(0);
14244 }
14245
14246 return DAG.getSetCC(DL, VT, NewAndOrOp, NewShiftOrRotate, Cond);
14247 }
14248 }
14249 }
14250 }
14251 return SDValue();
14252}
14253
14254SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
14255 SDValue LHS = N->getOperand(0);
14256 SDValue RHS = N->getOperand(1);
14257 SDValue Carry = N->getOperand(2);
14258 SDValue Cond = N->getOperand(3);
14259
14260 // If Carry is false, fold to a regular SETCC.
14261 if (isNullConstant(Carry))
14262 return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
14263
14264 return SDValue();
14265}
14266
14267/// Check if N satisfies:
14268/// N is used once.
14269/// N is a Load.
14270/// The load is compatible with ExtOpcode. It means
14271/// If the load has an explicit zero/sign extension, ExtOpcode must have the same
14272/// extension.
14273/// Otherwise returns true.
14274static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
14275 if (!N.hasOneUse())
14276 return false;
14277
14278 if (!isa<LoadSDNode>(N))
14279 return false;
14280
14281 LoadSDNode *Load = cast<LoadSDNode>(N);
14282 ISD::LoadExtType LoadExt = Load->getExtensionType();
14283 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
14284 return true;
14285
14286 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
14287 // extension.
14288 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
14289 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
14290 return false;
14291
14292 return true;
14293}
14294
14295/// Fold
14296/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
14297/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
14298/// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
14299/// This function is called by the DAGCombiner when visiting sext/zext/aext
14300/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
// Fold an extend of a select of two compatible loads into a select of two
// extending loads (see the doc comment above). NOTE(review): the opening
// signature line (14301) and two interior lines (14312, 14332) are elided in
// this rendering; comments describe only the visible code.
14302 SelectionDAG &DAG, const SDLoc &DL,
14303 CombineLevel Level) {
14304 unsigned Opcode = N->getOpcode();
14305 SDValue N0 = N->getOperand(0);
14306 EVT VT = N->getValueType(0);
14307 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
14308 Opcode == ISD::ANY_EXTEND) &&
14309 "Expected EXTEND dag node in input!");
14310
 // Match (select Cond, Op1, Op2) as the extended operand; the sd_match call
 // (line 14312) is elided in this rendering.
14311 SDValue Cond, Op1, Op2;
14313 m_Value(Op2)))))
14314 return SDValue();
14315
 // Both arms must be single-use loads whose extension (if any) matches.
14316 if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
14317 return SDValue();
14318
 // Map the extend opcode to the corresponding extending-load kind.
14319 auto ExtLoadOpcode = ISD::EXTLOAD;
14320 if (Opcode == ISD::SIGN_EXTEND)
14321 ExtLoadOpcode = ISD::SEXTLOAD;
14322 else if (Opcode == ISD::ZERO_EXTEND)
14323 ExtLoadOpcode = ISD::ZEXTLOAD;
14324
14325 // An illegal VSELECT may fail instruction selection if it is created after
14326 // legalization (DAG Combine2), so conservatively check the OperationAction.
14327 LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
14328 LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
14329 if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
14330 !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) ||
14331 (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
14333 return SDValue();
14334
 // Build the select of the two (to-be-combined) extends.
14335 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
14336 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
14337 return DAG.getSelect(DL, VT, Cond, Ext1, Ext2);
14338}
14339
14340/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
14341/// a build_vector of constants.
14342/// This function is called by the DAGCombiner when visiting sext/zext/aext
14343/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
14344/// Vector extends are not folded if operations are legal; this is to
14345/// avoid introducing illegal build_vector dag nodes.
// Fold a sext/zext/aext of constants into constants (see the doc comment
// above). NOTE(review): the opening signature line (14346) and interior lines
// 14393 and 14399 are elided in this rendering.
14347 const TargetLowering &TLI,
14348 SelectionDAG &DAG, bool LegalTypes) {
14349 unsigned Opcode = N->getOpcode();
14350 SDValue N0 = N->getOperand(0);
14351 EVT VT = N->getValueType(0);
14352
14353 assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
14354 "Expected EXTEND dag node in input!");
14355
14356 // fold (sext c1) -> c1
14357 // fold (zext c1) -> c1
14358 // fold (aext c1) -> c1
14359 if (isa<ConstantSDNode>(N0))
14360 return DAG.getNode(Opcode, DL, VT, N0);
14361
14362 // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
14363 // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
14364 // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
14365 if (N0->getOpcode() == ISD::SELECT) {
14366 SDValue Op1 = N0->getOperand(1);
14367 SDValue Op2 = N0->getOperand(2);
14368 if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
14369 (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
14370 // For any_extend, choose sign extension of the constants to allow a
14371 // possible further transform to sign_extend_inreg.i.e.
14372 //
14373 // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
14374 // t2: i64 = any_extend t1
14375 // -->
14376 // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
14377 // -->
14378 // t4: i64 = sign_extend_inreg t3
14379 unsigned FoldOpc = Opcode;
14380 if (FoldOpc == ISD::ANY_EXTEND)
14381 FoldOpc = ISD::SIGN_EXTEND;
14382 return DAG.getSelect(DL, VT, N0->getOperand(0),
14383 DAG.getNode(FoldOpc, DL, VT, Op1),
14384 DAG.getNode(FoldOpc, DL, VT, Op2));
14385 }
14386 }
14387
14388 // fold (sext (build_vector AllConstants) -> (build_vector AllConstants)
14389 // fold (zext (build_vector AllConstants) -> (build_vector AllConstants)
14390 // fold (aext (build_vector AllConstants) -> (build_vector AllConstants)
 // Vector extends are only folded when the (legal-typed) scalar result can
 // hold the extended constants; the remaining clause of this condition
 // (line 14393) is elided in this rendering.
14391 EVT SVT = VT.getScalarType();
14392 if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
14394 return SDValue();
14395
14396 // We can fold this node into a build_vector.
14397 unsigned VTBits = SVT.getSizeInBits();
14398 unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
 // (The declaration of the element vector, line 14399, is elided here.)
14400 unsigned NumElts = VT.getVectorNumElements();
14401
14402 for (unsigned i = 0; i != NumElts; ++i) {
14403 SDValue Op = N0.getOperand(i);
14404 if (Op.isUndef()) {
 // Undef lanes stay undef for any_extend, otherwise become zero.
14405 if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
14406 Elts.push_back(DAG.getUNDEF(SVT));
14407 else
14408 Elts.push_back(DAG.getConstant(0, DL, SVT));
14409 continue;
14410 }
14411
14412 SDLoc DL(Op);
14413 // Get the constant value and if needed trunc it to the size of the type.
14414 // Nodes like build_vector might have constants wider than the scalar type.
14415 APInt C = Op->getAsAPIntVal().zextOrTrunc(EVTBits);
14416 if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
14417 Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
14418 else
14419 Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
14420 }
14421
14422 return DAG.getBuildVector(VT, DL, Elts);
14423}
14424
14425// ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
14426// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
14427// transformation. Returns true if extension are possible and the above
14428// mentioned transformation is profitable.
// Decide whether extending all uses of N0 (to enable the extload fold
// described above) is possible and profitable; extendable SETCC users are
// collected in ExtendNodes. NOTE(review): the opening signature line (14429)
// and the condition-code declaration line (14443) are elided here.
14430 unsigned ExtOpc,
14431 SmallVectorImpl<SDNode *> &ExtendNodes,
14432 const TargetLowering &TLI) {
14433 bool HasCopyToRegUses = false;
14434 bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
 // Walk every use of the value being extended.
14435 for (SDUse &Use : N0->uses()) {
14436 SDNode *User = Use.getUser();
14437 if (User == N)
14438 continue;
14439 if (Use.getResNo() != N0.getResNo())
14440 continue;
14441 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
14442 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
14444 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
14445 // Sign bits will be lost after a zext.
14446 return false;
 // The other setcc operand must be a constant (or N0 itself) for the
 // comparison to be extendable.
14447 bool Add = false;
14448 for (unsigned i = 0; i != 2; ++i) {
14449 SDValue UseOp = User->getOperand(i);
14450 if (UseOp == N0)
14451 continue;
14452 if (!isa<ConstantSDNode>(UseOp))
14453 return false;
14454 Add = true;
14455 }
14456 if (Add)
14457 ExtendNodes.push_back(User);
14458 continue;
14459 }
14460 // If truncates aren't free and there are users we can't
14461 // extend, it isn't worthwhile.
14462 if (!isTruncFree)
14463 return false;
14464 // Remember if this value is live-out.
14465 if (User->getOpcode() == ISD::CopyToReg)
14466 HasCopyToRegUses = true;
14467 }
14468
14469 if (HasCopyToRegUses) {
14470 bool BothLiveOut = false;
14471 for (SDUse &Use : N->uses()) {
14472 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
14473 BothLiveOut = true;
14474 break;
14475 }
14476 }
14477 if (BothLiveOut)
14478 // Both unextended and extended values are live out. There had better be
14479 // a good reason for the transformation.
14480 return !ExtendNodes.empty();
14481 }
14482 return true;
14483}
14484
// Rewrite the collected SETCC users of OrigLoad to compare the extended
// value instead: OrigLoad operands become ExtLoad, everything else gets an
// explicit ExtType node. NOTE(review): the declaration of the operand vector
// (line 14491) is elided in this rendering.
14485void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
14486 SDValue OrigLoad, SDValue ExtLoad,
14487 ISD::NodeType ExtType) {
14488 // Extend SetCC uses if necessary.
14489 SDLoc DL(ExtLoad);
14490 for (SDNode *SetCC : SetCCs) {
14492
 // Rebuild both comparison operands at the extended type.
14493 for (unsigned j = 0; j != 2; ++j) {
14494 SDValue SOp = SetCC->getOperand(j);
14495 if (SOp == OrigLoad)
14496 Ops.push_back(ExtLoad);
14497 else
14498 Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
14499 }
14500
 // Keep the original condition code and replace the setcc in place.
14501 Ops.push_back(SetCC->getOperand(2));
14502 CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
14503 }
14504}
14505
14506// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
// Split an illegal-but-splittable (sext/zext (load x)) into multiple legal
// extending loads concatenated together; the original load's other uses are
// redirected through a truncate of the concatenation. NOTE(review): interior
// lines 14541, 14544, 14569-14570 and 14576 are elided in this rendering.
14507SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
14508 SDValue N0 = N->getOperand(0);
14509 EVT DstVT = N->getValueType(0);
14510 EVT SrcVT = N0.getValueType();
14511
14512 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
14513 N->getOpcode() == ISD::ZERO_EXTEND) &&
14514 "Unexpected node type (not an extend)!");
14515
14516 // fold (sext (load x)) to multiple smaller sextloads; same for zext.
14517 // For example, on a target with legal v4i32, but illegal v8i32, turn:
14518 // (v8i32 (sext (v8i16 (load x))))
14519 // into:
14520 // (v8i32 (concat_vectors (v4i32 (sextload x)),
14521 // (v4i32 (sextload (x + 16)))))
14522 // Where uses of the original load, i.e.:
14523 // (v8i16 (load x))
14524 // are replaced with:
14525 // (v8i16 (truncate
14526 // (v8i32 (concat_vectors (v4i32 (sextload x)),
14527 // (v4i32 (sextload (x + 16)))))))
14528 //
14529 // This combine is only applicable to illegal, but splittable, vectors.
14530 // All legal types, and illegal non-vector types, are handled elsewhere.
14531 // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
14532 //
14533 if (N0->getOpcode() != ISD::LOAD)
14534 return SDValue();
14535
14536 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14537
 // Only simple, unindexed, non-extending single-use loads of power-of-two
 // vectors qualify (final clause of this condition, line 14541, elided).
14538 if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
14539 !N0.hasOneUse() || !LN0->isSimple() ||
14540 !DstVT.isVector() || !DstVT.isPow2VectorType() ||
14542 return SDValue();
14543
 // (Declaration of the SetCCs vector, line 14544, is elided here.)
14545 if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
14546 return SDValue();
14547
14548 ISD::LoadExtType ExtType =
14549 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
14550
14551 // Try to split the vector types to get down to legal types.
14552 EVT SplitSrcVT = SrcVT;
14553 EVT SplitDstVT = DstVT;
14554 while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
14555 SplitSrcVT.getVectorNumElements() > 1) {
14556 SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
14557 SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
14558 }
14559
14560 if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
14561 return SDValue();
14562
14563 assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
14564
14565 SDLoc DL(N);
14566 const unsigned NumSplits =
14567 DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
14568 const unsigned Stride = SplitSrcVT.getStoreSize();
14571
 // Emit one extending load per split, advancing the base pointer by Stride
 // bytes each iteration.
14572 SDValue BasePtr = LN0->getBasePtr();
14573 for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
14574 const unsigned Offset = Idx * Stride;
14575
14577 DAG.getExtLoad(ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(),
14578 BasePtr, LN0->getPointerInfo().getWithOffset(Offset),
14579 SplitSrcVT, LN0->getBaseAlign(),
14580 LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
14581
14582 BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::getFixed(Stride), DL);
14583
14584 Loads.push_back(SplitLoad.getValue(0));
14585 Chains.push_back(SplitLoad.getValue(1));
14586 }
14587
 // Merge the split chains and values back together.
14588 SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
14589 SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
14590
14591 // Simplify TF.
14592 AddToWorklist(NewChain.getNode());
14593
14594 CombineTo(N, NewValue);
14595
14596 // Replace uses of the original load (before extension)
14597 // with a truncate of the concatenated sextloaded vectors.
14598 SDValue Trunc =
14599 DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
14600 ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
14601 CombineTo(N0.getNode(), Trunc, NewChain);
14602 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14603}
14604
// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
  assert(N->getOpcode() == ISD::ZERO_EXTEND);
  EVT VT = N->getValueType(0);
  EVT OrigVT = N->getOperand(0).getValueType();
  // If the target considers the zext free anyway, folding it into the load
  // gains nothing.
  if (TLI.isZExtFree(OrigVT, VT))
    return SDValue();

  // and/or/xor
  SDValue N0 = N->getOperand(0);
  if (!ISD::isBitwiseLogicOp(N0.getOpcode()) ||
      N0.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
    return SDValue();

  // shl/shr
  SDValue N1 = N0->getOperand(0);
  if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
      N1.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
    return SDValue();

  // load
  if (!isa<LoadSDNode>(N1.getOperand(0)))
    return SDValue();
  LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
  EVT MemVT = Load->getMemoryVT();
  // The zextload must be legal for the wide type, and the existing load must
  // not already be sign-extending or indexed.
  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
      Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
    return SDValue();


  // If the shift op is SHL, the logic op must be AND, otherwise the result
  // will be wrong.
  if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
    return SDValue();

  if (!N0.hasOneUse() || !N1.hasOneUse())
    return SDValue();

  // NOTE(review): the declaration of the 'SetCCs' vector appears to have been
  // dropped from this excerpt -- verify against the upstream source.
  if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
                               ISD::ZERO_EXTEND, SetCCs, TLI))
    return SDValue();

  // Actually do the transformation.
  SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
                                   Load->getChain(), Load->getBasePtr(),
                                   Load->getMemoryVT(), Load->getMemOperand());

  SDLoc DL1(N1);
  SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
                              N1.getOperand(1));

  // Zero-extend the logic-op constant to the wide type.
  APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
  SDLoc DL0(N0);
  SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
                            DAG.getConstant(Mask, DL0, VT));

  ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
  CombineTo(N, And);
  if (SDValue(Load, 0).hasOneUse()) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
  } else {
    // Other users of the load still need the narrow value: feed them a
    // truncate of the new extending load.
    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
                                Load->getValueType(0), ExtLoad);
    CombineTo(Load, Trunc, ExtLoad.getValue(1));
  }

  // N0 is dead at this point.
  recursivelyDeleteUnusedNodes(N0.getNode());

  return SDValue(N,0); // Return N so it doesn't get rechecked!
}
14680
14681/// If we're narrowing or widening the result of a vector select and the final
14682/// size is the same size as a setcc (compare) feeding the select, then try to
14683/// apply the cast operation to the select's operands because matching vector
14684/// sizes for a select condition and other operands should be more efficient.
14685SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
14686 unsigned CastOpcode = Cast->getOpcode();
14687 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
14688 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
14689 CastOpcode == ISD::FP_ROUND) &&
14690 "Unexpected opcode for vector select narrowing/widening");
14691
14692 // We only do this transform before legal ops because the pattern may be
14693 // obfuscated by target-specific operations after legalization. Do not create
14694 // an illegal select op, however, because that may be difficult to lower.
14695 EVT VT = Cast->getValueType(0);
14696 if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
14697 return SDValue();
14698
14699 SDValue VSel = Cast->getOperand(0);
14700 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
14701 VSel.getOperand(0).getOpcode() != ISD::SETCC)
14702 return SDValue();
14703
14704 // Does the setcc have the same vector size as the casted select?
14705 SDValue SetCC = VSel.getOperand(0);
14706 EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
14707 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
14708 return SDValue();
14709
14710 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
14711 SDValue A = VSel.getOperand(1);
14712 SDValue B = VSel.getOperand(2);
14713 SDValue CastA, CastB;
14714 SDLoc DL(Cast);
14715 if (CastOpcode == ISD::FP_ROUND) {
14716 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
14717 CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
14718 CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
14719 } else {
14720 CastA = DAG.getNode(CastOpcode, DL, VT, A);
14721 CastB = DAG.getNode(CastOpcode, DL, VT, B);
14722 }
14723 return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
14724}
14725
// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// NOTE(review): the opening line of this helper's signature (function name
// and leading parameters) is missing from this excerpt -- verify upstream.
                                     const TargetLowering &TLI, EVT VT,
                                     bool LegalOperations, SDNode *N,
                                     SDValue N0, ISD::LoadExtType ExtLoadType) {
  // Look through a FREEZE to find the extending load; the freeze is
  // re-applied to the widened result below.
  bool Frozen = N0.getOpcode() == ISD::FREEZE;
  auto *OldExtLoad = dyn_cast<LoadSDNode>(Frozen ? N0.getOperand(0) : N0);
  if (!OldExtLoad)
    return SDValue();

  // Accept only a load whose extension kind matches the requested one (or an
  // any-extload), that is unindexed, and whose value result has one use.
  bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD)
                        ? ISD::isSEXTLoad(OldExtLoad)
                        : ISD::isZEXTLoad(OldExtLoad);
  if ((!isAExtLoad && !ISD::isEXTLoad(OldExtLoad)) ||
      !ISD::isUNINDEXEDLoad(OldExtLoad) || !OldExtLoad->hasNUsesOfValue(1, 0))
    return SDValue();

  EVT MemVT = OldExtLoad->getMemoryVT();
  if ((LegalOperations || !OldExtLoad->isSimple() || VT.isVector()) &&
      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
    return SDValue();

  SDLoc DL(OldExtLoad);
  SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, DL, VT, OldExtLoad->getChain(),
                                   OldExtLoad->getBasePtr(), MemVT,
                                   OldExtLoad->getMemOperand());
  SDValue Res = ExtLoad;
  if (Frozen) {
    // Re-freeze, then assert the extended bits so later combines still know
    // the value was sign/zero extended from the original scalar type.
    Res = DAG.getFreeze(ExtLoad);
    Res = DAG.getNode(
        ExtLoadType == ISD::SEXTLOAD ? ISD::AssertSext : ISD::AssertZext, DL,
        Res.getValueType(), Res,
        DAG.getValueType(OldExtLoad->getValueType(0).getScalarType()));
  }
  Combiner.CombineTo(N, Res);
  // Rewire chain users of the old load to the new one.
  DAG.ReplaceAllUsesOfValueWith(SDValue(OldExtLoad, 1), ExtLoad.getValue(1));
  if (N0->use_empty())
    Combiner.recursivelyDeleteUnusedNodes(N0.getNode());
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
14767
// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// Only generate vector extloads when 1) they're legal, and 2) they are
// deemed desirable by the target. NonNegZExt can be set to true if a zero
// extend has the nonneg flag to allow use of sextload if profitable.
// NOTE(review): the opening line of this helper's signature (function name
// and leading parameters) is missing from this excerpt -- verify upstream.
                                  const TargetLowering &TLI, EVT VT,
                                  bool LegalOperations, SDNode *N, SDValue N0,
                                  ISD::LoadExtType ExtLoadType,
                                  ISD::NodeType ExtOpc,
                                  bool NonNegZExt = false) {
  // Look through a FREEZE wrapping the load; it is re-applied to the result.
  bool Frozen = N0.getOpcode() == ISD::FREEZE;
  SDValue Freeze = Frozen ? N0 : SDValue();
  auto *Load = dyn_cast<LoadSDNode>(Frozen ? N0.getOperand(0) : N0);
  // TODO: Support multiple uses of the load when frozen.
  if (!Load || !ISD::isNON_EXTLoad(Load) || !ISD::isUNINDEXEDLoad(Load) ||
      (Frozen && !Load->hasNUsesOfValue(1, 0)))
    return {};

  // If this is zext nneg, see if it would make sense to treat it as a sext.
  if (NonNegZExt) {
    assert(ExtLoadType == ISD::ZEXTLOAD && ExtOpc == ISD::ZERO_EXTEND &&
           "Unexpected load type or opcode");
    // Prefer a sextload when some user compares the value with a signed
    // predicate.
    for (SDNode *User : Load->users()) {
      if (User->getOpcode() == ISD::SETCC) {
        // NOTE(review): the line extracting the condition code 'CC' from the
        // SETCC user is missing from this excerpt -- verify upstream.
        if (ISD::isSignedIntSetCC(CC)) {
          ExtLoadType = ISD::SEXTLOAD;
          ExtOpc = ISD::SIGN_EXTEND;
          break;
        }
      }
    }
  }

  // TODO: isFixedLengthVector() should be removed and any negative effects on
  // code generation being the result of that target's implementation of
  // isVectorLoadExtDesirable().
  if ((LegalOperations || VT.isFixedLengthVector() || !Load->isSimple()) &&
      !TLI.isLoadExtLegal(ExtLoadType, VT, Load->getValueType(0)))
    return {};

  bool DoXform = true;
  // NOTE(review): the declaration of the 'SetCCs' vector is missing from this
  // excerpt -- verify upstream.
  if (!N0->hasOneUse())
    DoXform = ExtendUsesToFormExtLoad(VT, N, Frozen ? Freeze : SDValue(Load, 0),
                                      ExtOpc, SetCCs, TLI);
  if (VT.isVector())
    DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
  if (!DoXform)
    return {};

  SDLoc DL(Load);
  // If the load value is used only by N, replace it via CombineTo N.
  bool NoReplaceTrunc = N0.hasOneUse();
  SDValue ExtLoad =
      DAG.getExtLoad(ExtLoadType, DL, VT, Load->getChain(), Load->getBasePtr(),
                     Load->getValueType(0), Load->getMemOperand());
  SDValue Res = ExtLoad;
  if (Frozen) {
    // Re-freeze and assert the extended bits so later combines still know
    // the value is properly extended.
    Res = DAG.getFreeze(ExtLoad);
    // NOTE(review): the ': ISD::AssertZext,' arm of this conditional is
    // missing from this excerpt -- verify upstream.
    Res = DAG.getNode(ExtLoadType == ISD::SEXTLOAD ? ISD::AssertSext
                      DL, Res.getValueType(), Res,
                      DAG.getValueType(Load->getValueType(0).getScalarType()));
  }
  Combiner.ExtendSetCCUses(SetCCs, N0, Res, ExtOpc);
  Combiner.CombineTo(N, Res);
  if (NoReplaceTrunc) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
    Combiner.recursivelyDeleteUnusedNodes(N0.getNode());
  } else {
    // Other users of the load need the original narrow value: feed them a
    // truncate of the extended result.
    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, Load->getValueType(0), Res);
    if (Frozen) {
      Combiner.CombineTo(Freeze.getNode(), Trunc);
      DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
    } else {
      Combiner.CombineTo(Load, Trunc, ExtLoad.getValue(1));
    }
  }
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
14849
// Fold an extend of a non-extending masked load into an extending masked
// load, when the target finds the extload legal/custom and desirable.
static SDValue
// NOTE(review): the line carrying this helper's name and leading parameters
// is missing from this excerpt -- verify upstream.
                         bool LegalOperations, SDNode *N, SDValue N0,
                         ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
  // Only fold a single-use masked load.
  if (!N0.hasOneUse())
    return SDValue();

  // NOTE(review): the dyn_cast of N0 to a MaskedLoadSDNode 'Ld' is missing
  // from this excerpt -- verify upstream.
  if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
    return SDValue();

  if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) &&
      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
    return SDValue();

  if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  SDLoc dl(Ld);
  // The pass-through value must be extended the same way as the loaded value.
  SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
  SDValue NewLoad = DAG.getMaskedLoad(
      VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
      PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
      ExtLoadType, Ld->isExpandingLoad());
  // Rewire chain users of the old load to the new one.
  DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
  return NewLoad;
}
14877
// fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
// NOTE(review): the opening line of this helper's signature is missing from
// this excerpt -- verify upstream.
                                        const TargetLowering &TLI, EVT VT,
                                        SDValue N0,
                                        ISD::LoadExtType ExtLoadType) {
  auto *ALoad = dyn_cast<AtomicSDNode>(N0);
  if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
    return {};
  EVT MemoryVT = ALoad->getMemoryVT();
  if (!TLI.isAtomicLoadExtLegal(ExtLoadType, VT, MemoryVT))
    return {};
  // Can't fold into ALoad if it is already extending differently.
  ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
  if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
      (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
    return {};

  EVT OrigVT = ALoad->getValueType(0);
  assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
  auto *NewALoad = cast<AtomicSDNode>(DAG.getAtomicLoad(
      ExtLoadType, SDLoc(ALoad), MemoryVT, VT, ALoad->getChain(),
      ALoad->getBasePtr(), ALoad->getMemOperand()));
  // Give existing users of the old (narrow) value a truncate of the new one.
  // NOTE(review): the line opening this replace-all-uses call is missing from
  // this excerpt -- verify upstream.
      SDValue(ALoad, 0),
      DAG.getNode(ISD::TRUNCATE, SDLoc(ALoad), OrigVT, SDValue(NewALoad, 0)));
  // Update the chain uses.
  DAG.ReplaceAllUsesOfValueWith(SDValue(ALoad, 1), SDValue(NewALoad, 1));
  return SDValue(NewALoad, 0);
}
14907
// Fold sext/zext of an i1 "is non-negative" test (setgt X, -1) into an
// invert-and-shift of the sign bit.
// NOTE(review): the opening line of this helper's signature (name and
// leading parameters) is missing from this excerpt -- verify upstream.
                                        bool LegalOperations) {
  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");

  // Only handle a single-use i1 setcc, and only before legalization.
  SDValue SetCC = N->getOperand(0);
  if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
      !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
    return SDValue();

  SDValue X = SetCC.getOperand(0);
  SDValue Ones = SetCC.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT XVT = X.getValueType();
  // setge X, C is canonicalized to setgt, so we do not need to match that
  // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
  // not require the 'not' op.
  if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
    // Invert and smear/shift the sign bit:
    // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
    // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
    SDLoc DL(N);
    unsigned ShCt = VT.getSizeInBits() - 1;
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
    if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
      SDValue NotX = DAG.getNOT(DL, X, VT);
      SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
      // sext smears the sign bit (arithmetic shift); zext isolates it
      // (logical shift).
      auto ShiftOpcode =
          N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
      return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
    }
  }
  return SDValue();
}
14943
// Try several folds of (sext (setcc ...)): widening the setcc itself,
// zext/sext'ing the compare operands, or turning the sext into a select.
SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  if (N0.getOpcode() != ISD::SETCC)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);
  // NOTE(review): the line extracting the condition code 'CC' from the setcc
  // is missing from this excerpt -- verify upstream.
  EVT VT = N->getValueType(0);
  EVT N00VT = N00.getValueType();
  SDLoc DL(N);

  // Propagate fast-math-flags.
  SDNodeFlags Flags = N0->getFlags();

  // On some architectures (such as SSE/NEON/etc) the SETCC result type is
  // the same size as the compared operands. Try to optimize sext(setcc())
  // if this is the case.
  // NOTE(review): the continuation of this condition (the boolean-contents
  // value being compared against) is missing from this excerpt.
  if (VT.isVector() && !LegalOperations &&
      TLI.getBooleanContents(N00VT) ==
    EVT SVT = getSetCCResultType(N00VT);

    // If we already have the desired type, don't change it.
    if (SVT != N0.getValueType()) {
      // We know that the # elements of the results is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter). Check to see that they are the same size. If so,
      // we know that the element size of the sext'd result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == SVT.getSizeInBits())
        return DAG.getSetCC(DL, VT, N00, N01, CC, /*Chain=*/{},
                            /*Signaling=*/false, Flags);

      // If the desired elements are smaller or larger than the source
      // elements, we can use a matching integer vector type and then
      // truncate/sign extend.
      EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
      if (SVT == MatchingVecType) {
        SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC,
                                      /*Chain=*/{}, /*Signaling=*/false, Flags);
        return DAG.getSExtOrTrunc(VsetCC, DL, VT);
      }
    }

    // Try to eliminate the sext of a setcc by zexting the compare operands.
    // NOTE(review): the continuation of this condition is missing from this
    // excerpt -- verify upstream.
    if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
      bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
      unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;

      // We have an unsupported narrow vector compare op that would be legal
      // if extended to the destination type. See if the compare operands
      // can be freely extended to the destination type.
      auto IsFreeToExtend = [&](SDValue V) {
        if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
          return true;
        // Match a simple, non-extended load that can be converted to a
        // legal {z/s}ext-load.
        // TODO: Allow widening of an existing {z/s}ext-load?
        if (!(ISD::isNON_EXTLoad(V.getNode()) &&
              ISD::isUNINDEXEDLoad(V.getNode()) &&
              cast<LoadSDNode>(V)->isSimple() &&
              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
          return false;

        // Non-chain users of this value must either be the setcc in this
        // sequence or extends that can be folded into the new {z/s}ext-load.
        for (SDUse &Use : V->uses()) {
          // Skip uses of the chain and the setcc.
          SDNode *User = Use.getUser();
          if (Use.getResNo() != 0 || User == N0.getNode())
            continue;
          // Extra users must have exactly the same cast we are about to create.
          // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
          // is enhanced similarly.
          if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
            return false;
        }
        return true;
      };

      if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
        SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
        SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
        return DAG.getSetCC(DL, VT, Ext0, Ext1, CC, /*Chain=*/{},
                            /*Signaling=*/false, Flags);
      }
    }
  }

  // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
  // Here, T can be 1 or -1, depending on the type of the setcc and
  // getBooleanContents().
  unsigned SetCCWidth = N0.getScalarValueSizeInBits();

  // To determine the "true" side of the select, we need to know the high bit
  // of the value returned by the setcc if it evaluates to true.
  // If the type of the setcc is i1, then the true case of the select is just
  // sext(i1 1), that is, -1.
  // If the type of the setcc is larger (say, i8) then the value of the high
  // bit depends on getBooleanContents(), so ask TLI for a real "true" value
  // of the appropriate width.
  SDValue ExtTrueVal = (SetCCWidth == 1)
                           ? DAG.getAllOnesConstant(DL, VT)
                           : DAG.getBoolConstant(true, DL, VT, N00VT);
  SDValue Zero = DAG.getConstant(0, DL, VT);
  if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
    return SCC;

  if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(N0, VT, TLI)) {
    EVT SetCCVT = getSetCCResultType(N00VT);
    // Don't do this transform for i1 because there's a select transform
    // that would reverse it.
    // TODO: We should not do this transform at all without a target hook
    // because a sext is likely cheaper than a select?
    if (SetCCVT.getScalarSizeInBits() != 1 &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
      SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC, /*Chain=*/{},
                                   /*Signaling=*/false, Flags);
      return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero, Flags);
    }
  }

  return SDValue();
}
15071
// Main combine entry point for ISD::SIGN_EXTEND nodes: tries constant
// folding, ext-of-ext collapsing, load folding, setcc folds, and algebraic
// rewrites. Returns the replacement value, or SDValue() if nothing applied.
SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // Vector casts get a first shot at generic vector-cast simplification.
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
      return FoldedVOp;

  // sext(undef) = 0 because the top bit will all be the same.
  if (N0.isUndef())
    return DAG.getConstant(0, DL, VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
    return Res;

  // fold (sext (sext x)) -> (sext x)
  // fold (sext (aext x)) -> (sext x)
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));

  // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
  // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
  // NOTE(review): the 'if' condition guarding this return is missing from
  // this excerpt -- verify upstream.
    return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
                       N0.getOperand(0));

  if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    SDValue N00 = N0.getOperand(0);
    EVT ExtVT = cast<VTSDNode>(N0->getOperand(1))->getVT();
    if (N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(N00, ExtVT)) {
      // fold (sext (sext_inreg x)) -> (sext (trunc x))
      if ((!LegalTypes || TLI.isTypeLegal(ExtVT))) {
        SDValue T = DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N00);
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, T);
      }

      // If the trunc wasn't legal, try to fold to (sext_inreg (anyext x))
      if (!LegalTypes || TLI.isTypeLegal(VT)) {
        SDValue ExtSrc = DAG.getAnyExtOrTrunc(N00, DL, VT);
        return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, ExtSrc,
                           N0->getOperand(1));
      }
    }
  }

  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (sext (truncate (load x))) -> (sext (smaller load x))
    // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }

    // See if the value being truncated is already sign extended. If so, just
    // eliminate the trunc/sext pair.
    SDValue Op = N0.getOperand(0);
    unsigned OpBits = Op.getScalarValueSizeInBits();
    unsigned MidBits = N0.getScalarValueSizeInBits();
    unsigned DestBits = VT.getScalarSizeInBits();

    if (N0->getFlags().hasNoSignedWrap() ||
        DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
      if (OpBits == DestBits) {
        // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
        // bits, it is already ready.
        return Op;
      }

      if (OpBits < DestBits) {
        // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
        // bits, just sext from i32.
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
      }

      // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
      // bits, just truncate to i32.
      SDNodeFlags Flags;
      Flags.setNoSignedWrap(true);
      Flags.setNoUnsignedWrap(N0->getFlags().hasNoUnsignedWrap());
      return DAG.getNode(ISD::TRUNCATE, DL, VT, Op, Flags);
    }

    // fold (sext (truncate x)) -> (sextinreg x).
    if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
                                                 N0.getValueType())) {
      if (OpBits < DestBits)
        Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
      else if (OpBits > DestBits)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
                         DAG.getValueType(N0.getValueType()));
    }
  }

  // Try to simplify (sext (load x)).
  // NOTE(review): the closing argument line of this call is missing from
  // this excerpt -- verify upstream.
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
    return foldedExt;

  // NOTE(review): the closing argument line of this call is missing from
  // this excerpt -- verify upstream.
  if (SDValue foldedExt =
          tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
    return foldedExt;

  // fold (sext (load x)) to multiple smaller sextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // Try to simplify (sext (sextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
    return foldedExt;

  // Try to simplify (sext (atomic_load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::SEXTLOAD))
    return foldedExt;

  // fold (sext (and/or/xor (load x), cst)) ->
  // (and/or/xor (sextload x), (sext cst))
  if (ISD::isBitwiseLogicOp(N0.getOpcode()) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
      // NOTE(review): the declaration of the 'SetCCs' vector is missing from
      // this excerpt -- verify upstream.
      bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                             ISD::SIGN_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        // Sign-extend the logic-op constant to the wide type.
        APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          // NOTE(review): the right-hand side of this assignment is missing
          // from this excerpt -- verify upstream.
          SDValue TruncAnd =
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (SDValue V = foldSextSetcc(N))
    return V;

  // fold (sext x) -> (zext x) if the sign bit is known zero.
  if (!TLI.isSExtCheaperThanZExt(N0.getValueType(), VT) &&
      (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
      DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, SDNodeFlags::NonNeg);

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Eliminate this sign extend by doing a negation in the destination type:
  // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
  // NOTE(review): the remaining condition lines of this 'if' are missing
  // from this excerpt -- verify upstream.
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
    return DAG.getNegative(Zext, DL, VT);
  }
  // Eliminate this sign extend by doing a decrement in the destination type:
  // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
  // NOTE(review): the remaining condition lines of this 'if' are missing
  // from this excerpt -- verify upstream.
  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  // fold sext (not i1 X) -> add (zext i1 X), -1
  // TODO: This could be extended to handle bool vectors.
  if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
      (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
                            TLI.isOperationLegal(ISD::ADD, VT)))) {
    // If we can eliminate the 'not', the sext form should be better
    if (SDValue NewXor = visitXOR(N0.getNode())) {
      // Returning N0 is a form of in-visit replacement that may have
      // invalidated N0.
      if (NewXor.getNode() == N0.getNode()) {
        // Return SDValue here as the xor should have already been replaced in
        // this sext.
        return SDValue();
      }

      // Return a new sext with the new xor.
      return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
    }

    SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
    return Res;

  return SDValue();
}
15304
/// Given an extending node with a pop-count operand, if the target does not
/// support a pop-count in the narrow source type but does support it in the
/// destination type, widen the pop-count to the destination type.
static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG, const SDLoc &DL) {
  assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
          Extend->getOpcode() == ISD::ANY_EXTEND) &&
         "Expected extend op");

  // Only fold a single-use CTPOP feeding this extend.
  SDValue CtPop = Extend->getOperand(0);
  if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
    return SDValue();

  EVT VT = Extend->getValueType(0);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  // NOTE(review): the legality condition guarding this early return is
  // missing from this excerpt -- verify upstream.
    return SDValue();

  // zext (ctpop X) --> ctpop (zext X)
  SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
  return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
}
15327
// If we have (zext (abs X)) where X is a type that will be promoted by type
// legalization, convert to (abs (sext X)). But don't extend past a legal type.
static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
  assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");

  // Scalar-only transform.
  EVT VT = Extend->getValueType(0);
  if (VT.isVector())
    return SDValue();

  // Only fold a single-use ABS feeding the zext.
  SDValue Abs = Extend->getOperand(0);
  if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
    return SDValue();

  EVT AbsVT = Abs.getValueType();
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  // NOTE(review): the type-action value this is compared against is missing
  // from this excerpt -- verify upstream.
  if (TLI.getTypeAction(*DAG.getContext(), AbsVT) !=
    return SDValue();

  EVT LegalVT = TLI.getTypeToTransformTo(*DAG.getContext(), AbsVT);

  // Sign-extend the abs input to the legal (promoted) type, take abs there,
  // then zext/trunc back to the requested destination type.
  SDValue SExt =
      DAG.getNode(ISD::SIGN_EXTEND, SDLoc(Abs), LegalVT, Abs.getOperand(0));
  SDValue NewAbs = DAG.getNode(ISD::ABS, SDLoc(Abs), LegalVT, SExt);
  return DAG.getZExtOrTrunc(NewAbs, SDLoc(Extend), VT);
}
15354
15355SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
15356 SDValue N0 = N->getOperand(0);
15357 EVT VT = N->getValueType(0);
15358 SDLoc DL(N);
15359
15360 if (VT.isVector())
15361 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
15362 return FoldedVOp;
15363
15364 // zext(undef) = 0
15365 if (N0.isUndef())
15366 return DAG.getConstant(0, DL, VT);
15367
15368 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15369 return Res;
15370
15371 // fold (zext (zext x)) -> (zext x)
15372 // fold (zext (aext x)) -> (zext x)
15373 if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
15374 SDNodeFlags Flags;
15375 if (N0.getOpcode() == ISD::ZERO_EXTEND)
15376 Flags.setNonNeg(N0->getFlags().hasNonNeg());
15377 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0), Flags);
15378 }
15379
15380 // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15381 // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15384 return DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, DL, VT, N0.getOperand(0));
15385
15386 // fold (zext (truncate x)) -> (zext x) or
15387 // (zext (truncate x)) -> (truncate x)
15388 // This is valid when the truncated bits of x are already zero.
15389 SDValue Op;
15390 KnownBits Known;
15391 if (isTruncateOf(DAG, N0, Op, Known)) {
15392 APInt TruncatedBits =
15393 (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
15394 APInt(Op.getScalarValueSizeInBits(), 0) :
15395 APInt::getBitsSet(Op.getScalarValueSizeInBits(),
15396 N0.getScalarValueSizeInBits(),
15397 std::min(Op.getScalarValueSizeInBits(),
15398 VT.getScalarSizeInBits()));
15399 if (TruncatedBits.isSubsetOf(Known.Zero)) {
15400 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
15401 DAG.salvageDebugInfo(*N0.getNode());
15402
15403 return ZExtOrTrunc;
15404 }
15405 }
15406
15407 // fold (zext (truncate x)) -> (and x, mask)
15408 if (N0.getOpcode() == ISD::TRUNCATE) {
15409 // fold (zext (truncate (load x))) -> (zext (smaller load x))
15410 // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
15411 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
15412 SDNode *oye = N0.getOperand(0).getNode();
15413 if (NarrowLoad.getNode() != N0.getNode()) {
15414 CombineTo(N0.getNode(), NarrowLoad);
15415 // CombineTo deleted the truncate, if needed, but not what's under it.
15416 AddToWorklist(oye);
15417 }
15418 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15419 }
15420
15421 EVT SrcVT = N0.getOperand(0).getValueType();
15422 EVT MinVT = N0.getValueType();
15423
15424 if (N->getFlags().hasNonNeg()) {
15425 SDValue Op = N0.getOperand(0);
15426 unsigned OpBits = SrcVT.getScalarSizeInBits();
15427 unsigned MidBits = MinVT.getScalarSizeInBits();
15428 unsigned DestBits = VT.getScalarSizeInBits();
15429
15430 if (N0->getFlags().hasNoSignedWrap() ||
15431 DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
15432 if (OpBits == DestBits) {
15433 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
15434 // bits, it is already ready.
15435 return Op;
15436 }
15437
15438 if (OpBits < DestBits) {
15439 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
15440 // bits, just sext from i32.
15441 // FIXME: This can probably be ZERO_EXTEND nneg?
15442 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
15443 }
15444
15445 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
15446 // bits, just truncate to i32.
15447 SDNodeFlags Flags;
15448 Flags.setNoSignedWrap(true);
15449 Flags.setNoUnsignedWrap(true);
15450 return DAG.getNode(ISD::TRUNCATE, DL, VT, Op, Flags);
15451 }
15452 }
15453
15454 // Try to mask before the extension to avoid having to generate a larger mask,
15455 // possibly over several sub-vectors.
15456 if (SrcVT.bitsLT(VT) && VT.isVector()) {
15457 if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
15459 SDValue Op = N0.getOperand(0);
15460 Op = DAG.getZeroExtendInReg(Op, DL, MinVT);
15461 AddToWorklist(Op.getNode());
15462 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
15463 // Transfer the debug info; the new node is equivalent to N0.
15464 DAG.transferDbgValues(N0, ZExtOrTrunc);
15465 return ZExtOrTrunc;
15466 }
15467 }
15468
15469 if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
15470 SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
15471 AddToWorklist(Op.getNode());
15472 SDValue And = DAG.getZeroExtendInReg(Op, DL, MinVT);
15473 // We may safely transfer the debug info describing the truncate node over
15474 // to the equivalent and operation.
15475 DAG.transferDbgValues(N0, And);
15476 return And;
15477 }
15478 }
15479
15480 // Fold (zext (and (trunc x), cst)) -> (and x, cst),
15481 // if either of the casts is not free.
15482 if (N0.getOpcode() == ISD::AND &&
15483 N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
15484 N0.getOperand(1).getOpcode() == ISD::Constant &&
15485 (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType()) ||
15486 !TLI.isZExtFree(N0.getValueType(), VT))) {
15487 SDValue X = N0.getOperand(0).getOperand(0);
15488 X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
15489 APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
15490 return DAG.getNode(ISD::AND, DL, VT,
15491 X, DAG.getConstant(Mask, DL, VT));
15492 }
15493
15494 // Try to simplify (zext (load x)).
15495 if (SDValue foldedExt = tryToFoldExtOfLoad(
15496 DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD,
15497 ISD::ZERO_EXTEND, N->getFlags().hasNonNeg()))
15498 return foldedExt;
15499
15500 if (SDValue foldedExt =
15501 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
15503 return foldedExt;
15504
15505 // fold (zext (load x)) to multiple smaller zextloads.
15506 // Only on illegal but splittable vectors.
15507 if (SDValue ExtLoad = CombineExtLoad(N))
15508 return ExtLoad;
15509
15510 // Try to simplify (zext (atomic_load x)).
15511 if (SDValue foldedExt =
15512 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::ZEXTLOAD))
15513 return foldedExt;
15514
15515 // fold (zext (and/or/xor (load x), cst)) ->
15516 // (and/or/xor (zextload x), (zext cst))
15517 // Unless (and (load x) cst) will match as a zextload already and has
15518 // additional users, or the zext is already free.
15519 if (ISD::isBitwiseLogicOp(N0.getOpcode()) && !TLI.isZExtFree(N0, VT) &&
15520 isa<LoadSDNode>(N0.getOperand(0)) &&
15521 N0.getOperand(1).getOpcode() == ISD::Constant &&
15522 (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
15523 LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
15524 EVT MemVT = LN00->getMemoryVT();
15525 if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
15526 LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
15527 bool DoXform = true;
15529 if (!N0.hasOneUse()) {
15530 if (N0.getOpcode() == ISD::AND) {
15531 auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
15532 EVT LoadResultTy = AndC->getValueType(0);
15533 EVT ExtVT;
15534 if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
15535 DoXform = false;
15536 }
15537 }
15538 if (DoXform)
15539 DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
15540 ISD::ZERO_EXTEND, SetCCs, TLI);
15541 if (DoXform) {
15542 SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
15543 LN00->getChain(), LN00->getBasePtr(),
15544 LN00->getMemoryVT(),
15545 LN00->getMemOperand());
15546 APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
15547 SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
15548 ExtLoad, DAG.getConstant(Mask, DL, VT));
15549 ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
15550 bool NoReplaceTruncAnd = !N0.hasOneUse();
15551 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
15552 CombineTo(N, And);
15553 // If N0 has multiple uses, change other uses as well.
15554 if (NoReplaceTruncAnd) {
15555 SDValue TruncAnd =
15557 CombineTo(N0.getNode(), TruncAnd);
15558 }
15559 if (NoReplaceTrunc) {
15560 DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
15561 } else {
15562 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
15563 LN00->getValueType(0), ExtLoad);
15564 CombineTo(LN00, Trunc, ExtLoad.getValue(1));
15565 }
15566 return SDValue(N,0); // Return N so it doesn't get rechecked!
15567 }
15568 }
15569 }
15570
15571 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
15572 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
15573 if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
15574 return ZExtLoad;
15575
15576 // Try to simplify (zext (zextload x)).
15577 if (SDValue foldedExt = tryToFoldExtOfExtload(
15578 DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
15579 return foldedExt;
15580
15581 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
15582 return V;
15583
15584 if (N0.getOpcode() == ISD::SETCC) {
15585 // Propagate fast-math-flags.
15586 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
15587
15588 // Only do this before legalize for now.
15589 if (!LegalOperations && VT.isVector() &&
15590 N0.getValueType().getVectorElementType() == MVT::i1) {
15591 EVT N00VT = N0.getOperand(0).getValueType();
15592 if (getSetCCResultType(N00VT) == N0.getValueType())
15593 return SDValue();
15594
15595 // We know that the # elements of the results is the same as the #
15596 // elements of the compare (and the # elements of the compare result for
15597 // that matter). Check to see that they are the same size. If so, we know
15598 // that the element size of the sext'd result matches the element size of
15599 // the compare operands.
15600 if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
15601 // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
15602 SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
15603 N0.getOperand(1), N0.getOperand(2));
15604 return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
15605 }
15606
15607 // If the desired elements are smaller or larger than the source
15608 // elements we can use a matching integer vector type and then
15609 // truncate/any extend followed by zext_in_reg.
15610 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
15611 SDValue VsetCC =
15612 DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
15613 N0.getOperand(1), N0.getOperand(2));
15614 return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
15615 N0.getValueType());
15616 }
15617
15618 // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
15619 EVT N0VT = N0.getValueType();
15620 EVT N00VT = N0.getOperand(0).getValueType();
15621 if (SDValue SCC = SimplifySelectCC(
15622 DL, N0.getOperand(0), N0.getOperand(1),
15623 DAG.getBoolConstant(true, DL, N0VT, N00VT),
15624 DAG.getBoolConstant(false, DL, N0VT, N00VT),
15625 cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
15626 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
15627 }
15628
15629 // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
15630 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
15631 !TLI.isZExtFree(N0, VT)) {
15632 SDValue ShVal = N0.getOperand(0);
15633 SDValue ShAmt = N0.getOperand(1);
15634 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt)) {
15635 if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
15636 if (N0.getOpcode() == ISD::SHL) {
15637 // If the original shl may be shifting out bits, do not perform this
15638 // transformation.
15639 unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
15640 ShVal.getOperand(0).getValueSizeInBits();
15641 if (ShAmtC->getAPIntValue().ugt(KnownZeroBits)) {
15642 // If the shift is too large, then see if we can deduce that the
15643 // shift is safe anyway.
15644
15645 // Check if the bits being shifted out are known to be zero.
15646 KnownBits KnownShVal = DAG.computeKnownBits(ShVal);
15647 if (ShAmtC->getAPIntValue().ugt(KnownShVal.countMinLeadingZeros()))
15648 return SDValue();
15649 }
15650 }
15651
15652 // Ensure that the shift amount is wide enough for the shifted value.
15653 if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
15654 ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
15655
15656 return DAG.getNode(N0.getOpcode(), DL, VT,
15657 DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ShVal), ShAmt);
15658 }
15659 }
15660 }
15661
15662 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
15663 return NewVSel;
15664
15665 if (SDValue NewCtPop = widenCtPop(N, DAG, DL))
15666 return NewCtPop;
15667
15668 if (SDValue V = widenAbs(N, DAG))
15669 return V;
15670
15671 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15672 return Res;
15673
15674 // CSE zext nneg with sext if the zext is not free.
15675 if (N->getFlags().hasNonNeg() && !TLI.isZExtFree(N0.getValueType(), VT)) {
15676 SDNode *CSENode = DAG.getNodeIfExists(ISD::SIGN_EXTEND, N->getVTList(), N0);
15677 if (CSENode)
15678 return SDValue(CSENode, 0);
15679 }
15680
15681 return SDValue();
15682}
15683
15684SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
15685 SDValue N0 = N->getOperand(0);
15686 EVT VT = N->getValueType(0);
15687 SDLoc DL(N);
15688
15689 // aext(undef) = undef
15690 if (N0.isUndef())
15691 return DAG.getUNDEF(VT);
15692
15693 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15694 return Res;
15695
15696 // fold (aext (aext x)) -> (aext x)
15697 // fold (aext (zext x)) -> (zext x)
15698 // fold (aext (sext x)) -> (sext x)
15699 if (N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::ZERO_EXTEND ||
15700 N0.getOpcode() == ISD::SIGN_EXTEND) {
15701 SDNodeFlags Flags;
15702 if (N0.getOpcode() == ISD::ZERO_EXTEND)
15703 Flags.setNonNeg(N0->getFlags().hasNonNeg());
15704 return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Flags);
15705 }
15706
15707 // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
15708 // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15709 // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
15713 return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0));
15714
15715 // fold (aext (truncate (load x))) -> (aext (smaller load x))
15716 // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
15717 if (N0.getOpcode() == ISD::TRUNCATE) {
15718 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
15719 SDNode *oye = N0.getOperand(0).getNode();
15720 if (NarrowLoad.getNode() != N0.getNode()) {
15721 CombineTo(N0.getNode(), NarrowLoad);
15722 // CombineTo deleted the truncate, if needed, but not what's under it.
15723 AddToWorklist(oye);
15724 }
15725 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15726 }
15727 }
15728
15729 // fold (aext (truncate x))
15730 if (N0.getOpcode() == ISD::TRUNCATE)
15731 return DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
15732
15733 // Fold (aext (and (trunc x), cst)) -> (and x, cst)
15734 // if the trunc is not free.
15735 if (N0.getOpcode() == ISD::AND &&
15736 N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
15737 N0.getOperand(1).getOpcode() == ISD::Constant &&
15738 !TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType())) {
15739 SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
15740 SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
15741 assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
15742 return DAG.getNode(ISD::AND, DL, VT, X, Y);
15743 }
15744
15745 // fold (aext (load x)) -> (aext (truncate (extload x)))
15746 // None of the supported targets knows how to perform load and any_ext
15747 // on vectors in one instruction, so attempt to fold to zext instead.
15748 if (VT.isVector()) {
15749 // Try to simplify (zext (load x)).
15750 if (SDValue foldedExt =
15751 tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
15753 return foldedExt;
15754 } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
15757 bool DoXform = true;
15759 if (!N0.hasOneUse())
15760 DoXform =
15761 ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
15762 if (DoXform) {
15763 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15764 SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, VT, LN0->getChain(),
15765 LN0->getBasePtr(), N0.getValueType(),
15766 LN0->getMemOperand());
15767 ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
15768 // If the load value is used only by N, replace it via CombineTo N.
15769 bool NoReplaceTrunc = N0.hasOneUse();
15770 CombineTo(N, ExtLoad);
15771 if (NoReplaceTrunc) {
15772 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
15773 recursivelyDeleteUnusedNodes(LN0);
15774 } else {
15775 SDValue Trunc =
15776 DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
15777 CombineTo(LN0, Trunc, ExtLoad.getValue(1));
15778 }
15779 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15780 }
15781 }
15782
15783 // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
15784 // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
15785 // fold (aext ( extload x)) -> (aext (truncate (extload x)))
15786 if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
15787 ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
15788 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15789 ISD::LoadExtType ExtType = LN0->getExtensionType();
15790 EVT MemVT = LN0->getMemoryVT();
15791 if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
15792 SDValue ExtLoad =
15793 DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), LN0->getBasePtr(),
15794 MemVT, LN0->getMemOperand());
15795 CombineTo(N, ExtLoad);
15796 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
15797 recursivelyDeleteUnusedNodes(LN0);
15798 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15799 }
15800 }
15801
15802 if (N0.getOpcode() == ISD::SETCC) {
15803 // Propagate fast-math-flags.
15804 SDNodeFlags Flags = N0->getFlags();
15805 SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
15806
15807 // For vectors:
15808 // aext(setcc) -> vsetcc
15809 // aext(setcc) -> truncate(vsetcc)
15810 // aext(setcc) -> aext(vsetcc)
15811 // Only do this before legalize for now.
15812 if (VT.isVector() && !LegalOperations) {
15813 EVT N00VT = N0.getOperand(0).getValueType();
15814 if (getSetCCResultType(N00VT) == N0.getValueType())
15815 return SDValue();
15816
15817 // We know that the # elements of the results is the same as the
15818 // # elements of the compare (and the # elements of the compare result
15819 // for that matter). Check to see that they are the same size. If so,
15820 // we know that the element size of the sext'd result matches the
15821 // element size of the compare operands.
15822 if (VT.getSizeInBits() == N00VT.getSizeInBits())
15823 return DAG.getSetCC(DL, VT, N0.getOperand(0), N0.getOperand(1),
15824 cast<CondCodeSDNode>(N0.getOperand(2))->get(),
15825 /*Chain=*/{}, /*Signaling=*/false, Flags);
15826
15827 // If the desired elements are smaller or larger than the source
15828 // elements we can use a matching integer vector type and then
15829 // truncate/any extend
15830 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
15831 SDValue VsetCC = DAG.getSetCC(
15832 DL, MatchingVectorType, N0.getOperand(0), N0.getOperand(1),
15833 cast<CondCodeSDNode>(N0.getOperand(2))->get(), /*Chain=*/{},
15834 /*Signaling=*/false, Flags);
15835 return DAG.getAnyExtOrTrunc(VsetCC, DL, VT);
15836 }
15837
15838 // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
15839 if (SDValue SCC = SimplifySelectCC(
15840 DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
15841 DAG.getConstant(0, DL, VT),
15842 cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
15843 return SCC;
15844 }
15845
15846 if (SDValue NewCtPop = widenCtPop(N, DAG, DL))
15847 return NewCtPop;
15848
15849 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15850 return Res;
15851
15852 return SDValue();
15853}
15854
15855SDValue DAGCombiner::visitAssertExt(SDNode *N) {
15856 unsigned Opcode = N->getOpcode();
15857 SDValue N0 = N->getOperand(0);
15858 SDValue N1 = N->getOperand(1);
15859 EVT AssertVT = cast<VTSDNode>(N1)->getVT();
15860
15861 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
15862 if (N0.getOpcode() == Opcode &&
15863 AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
15864 return N0;
15865
15866 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15867 N0.getOperand(0).getOpcode() == Opcode) {
15868 // We have an assert, truncate, assert sandwich. Make one stronger assert
15869 // by asserting on the smallest asserted type to the larger source type.
15870 // This eliminates the later assert:
15871 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
15872 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
15873 SDLoc DL(N);
15874 SDValue BigA = N0.getOperand(0);
15875 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15876 EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
15877 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
15878 SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
15879 BigA.getOperand(0), MinAssertVTVal);
15880 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
15881 }
15882
15883 // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
15884 // than X. Just move the AssertZext in front of the truncate and drop the
15885 // AssertSExt.
15886 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15888 Opcode == ISD::AssertZext) {
15889 SDValue BigA = N0.getOperand(0);
15890 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15891 if (AssertVT.bitsLT(BigA_AssertVT)) {
15892 SDLoc DL(N);
15893 SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
15894 BigA.getOperand(0), N1);
15895 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
15896 }
15897 }
15898
15899 if (Opcode == ISD::AssertZext && N0.getOpcode() == ISD::AND &&
15901 const APInt &Mask = N0.getConstantOperandAPInt(1);
15902
15903 // If we have (AssertZext (and (AssertSext X, iX), M), iY) and Y is smaller
15904 // than X, and the And doesn't change the lower iX bits, we can move the
15905 // AssertZext in front of the And and drop the AssertSext.
15906 if (N0.getOperand(0).getOpcode() == ISD::AssertSext && N0.hasOneUse()) {
15907 SDValue BigA = N0.getOperand(0);
15908 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15909 if (AssertVT.bitsLT(BigA_AssertVT) &&
15910 Mask.countr_one() >= BigA_AssertVT.getScalarSizeInBits()) {
15911 SDLoc DL(N);
15912 SDValue NewAssert =
15913 DAG.getNode(Opcode, DL, N->getValueType(0), BigA.getOperand(0), N1);
15914 return DAG.getNode(ISD::AND, DL, N->getValueType(0), NewAssert,
15915 N0.getOperand(1));
15916 }
15917 }
15918
15919 // Remove AssertZext entirely if the mask guarantees the assertion cannot
15920 // fail.
15921 // TODO: Use KB countMinLeadingZeros to handle non-constant masks?
15922 if (Mask.isIntN(AssertVT.getScalarSizeInBits()))
15923 return N0;
15924 }
15925
15926 return SDValue();
15927}
15928
15929SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
15930 SDLoc DL(N);
15931
15932 Align AL = cast<AssertAlignSDNode>(N)->getAlign();
15933 SDValue N0 = N->getOperand(0);
15934
15935 // Fold (assertalign (assertalign x, AL0), AL1) ->
15936 // (assertalign x, max(AL0, AL1))
15937 if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
15938 return DAG.getAssertAlign(DL, N0.getOperand(0),
15939 std::max(AL, AAN->getAlign()));
15940
15941 // In rare cases, there are trivial arithmetic ops in source operands. Sink
15942 // this assert down to source operands so that those arithmetic ops could be
15943 // exposed to the DAG combining.
15944 switch (N0.getOpcode()) {
15945 default:
15946 break;
15947 case ISD::ADD:
15948 case ISD::PTRADD:
15949 case ISD::SUB: {
15950 unsigned AlignShift = Log2(AL);
15951 SDValue LHS = N0.getOperand(0);
15952 SDValue RHS = N0.getOperand(1);
15953 unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
15954 unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
15955 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
15956 if (LHSAlignShift < AlignShift)
15957 LHS = DAG.getAssertAlign(DL, LHS, AL);
15958 if (RHSAlignShift < AlignShift)
15959 RHS = DAG.getAssertAlign(DL, RHS, AL);
15960 return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
15961 }
15962 break;
15963 }
15964 }
15965
15966 return SDValue();
15967}
15968
15969/// If the result of a load is shifted/masked/truncated to an effectively
15970/// narrower type, try to transform the load to a narrower type and/or
15971/// use an extending load.
15972SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
15973 unsigned Opc = N->getOpcode();
15974
15976 SDValue N0 = N->getOperand(0);
15977 EVT VT = N->getValueType(0);
15978 EVT ExtVT = VT;
15979
15980 // This transformation isn't valid for vector loads.
15981 if (VT.isVector())
15982 return SDValue();
15983
15984 // The ShAmt variable is used to indicate that we've consumed a right
15985 // shift. I.e. we want to narrow the width of the load by skipping to load the
15986 // ShAmt least significant bits.
15987 unsigned ShAmt = 0;
15988 // A special case is when the least significant bits from the load are masked
15989 // away, but using an AND rather than a right shift. HasShiftedOffset is used
15990 // to indicate that the narrowed load should be left-shifted ShAmt bits to get
15991 // the result.
15992 unsigned ShiftedOffset = 0;
15993 // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
15994 // extended to VT.
15995 if (Opc == ISD::SIGN_EXTEND_INREG) {
15996 ExtType = ISD::SEXTLOAD;
15997 ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
15998 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
15999 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
16000 // value, or it may be shifting a higher subword, half or byte into the
16001 // lowest bits.
16002
16003 // Only handle shift with constant shift amount, and the shiftee must be a
16004 // load.
16005 auto *LN = dyn_cast<LoadSDNode>(N0);
16006 auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
16007 if (!N1C || !LN)
16008 return SDValue();
16009 // If the shift amount is larger than the memory type then we're not
16010 // accessing any of the loaded bytes.
16011 ShAmt = N1C->getZExtValue();
16012 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
16013 if (MemoryWidth <= ShAmt)
16014 return SDValue();
16015 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
16016 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
16017 ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
16018 // If original load is a SEXTLOAD then we can't simply replace it by a
16019 // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
16020 // followed by a ZEXT, but that is not handled at the moment). Similarly if
16021 // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
16022 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
16023 LN->getExtensionType() == ISD::ZEXTLOAD) &&
16024 LN->getExtensionType() != ExtType)
16025 return SDValue();
16026 } else if (Opc == ISD::AND) {
16027 // An AND with a constant mask is the same as a truncate + zero-extend.
16028 auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
16029 if (!AndC)
16030 return SDValue();
16031
16032 const APInt &Mask = AndC->getAPIntValue();
16033 unsigned ActiveBits = 0;
16034 if (Mask.isMask()) {
16035 ActiveBits = Mask.countr_one();
16036 } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
16037 ShiftedOffset = ShAmt;
16038 } else {
16039 return SDValue();
16040 }
16041
16042 ExtType = ISD::ZEXTLOAD;
16043 ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
16044 }
16045
16046 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
16047 // a right shift. Here we redo some of those checks, to possibly adjust the
16048 // ExtVT even further based on "a masking AND". We could also end up here for
16049 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
16050 // need to be done here as well.
16051 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
16052 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
16053 // Bail out when the SRL has more than one use. This is done for historical
16054 // (undocumented) reasons. Maybe intent was to guard the AND-masking below
16055 // check below? And maybe it could be non-profitable to do the transform in
16056 // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
16057 // FIXME: Can't we just skip this check for the Opc==ISD::SRL case.
16058 if (!SRL.hasOneUse())
16059 return SDValue();
16060
16061 // Only handle shift with constant shift amount, and the shiftee must be a
16062 // load.
16063 auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
16064 auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
16065 if (!SRL1C || !LN)
16066 return SDValue();
16067
16068 // If the shift amount is larger than the input type then we're not
16069 // accessing any of the loaded bytes. If the load was a zextload/extload
16070 // then the result of the shift+trunc is zero/undef (handled elsewhere).
16071 ShAmt = SRL1C->getZExtValue();
16072 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
16073 if (ShAmt >= MemoryWidth)
16074 return SDValue();
16075
16076 // Because a SRL must be assumed to *need* to zero-extend the high bits
16077 // (as opposed to anyext the high bits), we can't combine the zextload
16078 // lowering of SRL and an sextload.
16079 if (LN->getExtensionType() == ISD::SEXTLOAD)
16080 return SDValue();
16081
16082 // Avoid reading outside the memory accessed by the original load (could
16083 // happened if we only adjust the load base pointer by ShAmt). Instead we
16084 // try to narrow the load even further. The typical scenario here is:
16085 // (i64 (truncate (i96 (srl (load x), 64)))) ->
16086 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
16087 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
16088 // Don't replace sextload by zextload.
16089 if (ExtType == ISD::SEXTLOAD)
16090 return SDValue();
16091 // Narrow the load.
16092 ExtType = ISD::ZEXTLOAD;
16093 ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
16094 }
16095
16096 // If the SRL is only used by a masking AND, we may be able to adjust
16097 // the ExtVT to make the AND redundant.
16098 SDNode *Mask = *(SRL->user_begin());
16099 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
16100 isa<ConstantSDNode>(Mask->getOperand(1))) {
16101 unsigned Offset, ActiveBits;
16102 const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
16103 if (ShiftMask.isMask()) {
16104 EVT MaskedVT =
16105 EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countr_one());
16106 // If the mask is smaller, recompute the type.
16107 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
16108 TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
16109 ExtVT = MaskedVT;
16110 } else if (ExtType == ISD::ZEXTLOAD &&
16111 ShiftMask.isShiftedMask(Offset, ActiveBits) &&
16112 (Offset + ShAmt) < VT.getScalarSizeInBits()) {
16113 EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
16114 // If the mask is shifted we can use a narrower load and a shl to insert
16115 // the trailing zeros.
16116 if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
16117 TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT)) {
16118 ExtVT = MaskedVT;
16119 ShAmt = Offset + ShAmt;
16120 ShiftedOffset = Offset;
16121 }
16122 }
16123 }
16124
16125 N0 = SRL.getOperand(0);
16126 }
16127
16128 // If the load is shifted left (and the result isn't shifted back right), we
16129 // can fold a truncate through the shift. The typical scenario is that N
16130 // points at a TRUNCATE here so the attempted fold is:
16131 // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
16132 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
16133 unsigned ShLeftAmt = 0;
16134 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
16135 ExtVT == VT && TLI.isNarrowingProfitable(N, N0.getValueType(), VT)) {
16136 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
16137 ShLeftAmt = N01->getZExtValue();
16138 N0 = N0.getOperand(0);
16139 }
16140 }
16141
16142 // If we haven't found a load, we can't narrow it.
16143 if (!isa<LoadSDNode>(N0))
16144 return SDValue();
16145
16146 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
16147 // Reducing the width of a volatile load is illegal. For atomics, we may be
16148 // able to reduce the width provided we never widen again. (see D66309)
16149 if (!LN0->isSimple() ||
16150 !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
16151 return SDValue();
16152
16153 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
16154 unsigned LVTStoreBits =
16156 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
16157 return LVTStoreBits - EVTStoreBits - ShAmt;
16158 };
16159
16160 // We need to adjust the pointer to the load by ShAmt bits in order to load
16161 // the correct bytes.
16162 unsigned PtrAdjustmentInBits =
16163 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
16164
16165 uint64_t PtrOff = PtrAdjustmentInBits / 8;
16166 SDLoc DL(LN0);
16167 // The original load itself didn't wrap, so an offset within it doesn't.
16168 SDValue NewPtr =
16171 AddToWorklist(NewPtr.getNode());
16172
16173 SDValue Load;
16174 if (ExtType == ISD::NON_EXTLOAD) {
16175 const MDNode *OldRanges = LN0->getRanges();
16176 const MDNode *NewRanges = nullptr;
16177 // If LSBs are loaded and the truncated ConstantRange for the OldRanges
16178 // metadata is not the full-set for the new width then create a NewRanges
16179 // metadata for the truncated load
16180 if (ShAmt == 0 && OldRanges) {
16181 ConstantRange CR = getConstantRangeFromMetadata(*OldRanges);
16182 unsigned BitSize = VT.getScalarSizeInBits();
16183
16184 // It is possible for an 8-bit extending load with 8-bit range
16185 // metadata to be narrowed to an 8-bit load. This guard is necessary to
16186 // ensure that truncation is strictly smaller.
16187 if (CR.getBitWidth() > BitSize) {
16188 ConstantRange TruncatedCR = CR.truncate(BitSize);
16189 if (!TruncatedCR.isFullSet()) {
16190 Metadata *Bounds[2] = {
16192 ConstantInt::get(*DAG.getContext(), TruncatedCR.getLower())),
16194 ConstantInt::get(*DAG.getContext(), TruncatedCR.getUpper()))};
16195 NewRanges = MDNode::get(*DAG.getContext(), Bounds);
16196 }
16197 } else if (CR.getBitWidth() == BitSize)
16198 NewRanges = OldRanges;
16199 }
16200 Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
16201 LN0->getPointerInfo().getWithOffset(PtrOff),
16202 LN0->getBaseAlign(), LN0->getMemOperand()->getFlags(),
16203 LN0->getAAInfo(), NewRanges);
16204 } else
16205 Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
16206 LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
16207 LN0->getBaseAlign(), LN0->getMemOperand()->getFlags(),
16208 LN0->getAAInfo());
16209
16210 // Replace the old load's chain with the new load's chain.
16211 WorklistRemover DeadNodes(*this);
16212 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
16213
16214 // Shift the result left, if we've swallowed a left shift.
16216 if (ShLeftAmt != 0) {
16217 // If the shift amount is as large as the result size (but, presumably,
16218 // no larger than the source) then the useful bits of the result are
16219 // zero; we can't simply return the shortened shift, because the result
16220 // of that operation is undefined.
16221 if (ShLeftAmt >= VT.getScalarSizeInBits())
16222 Result = DAG.getConstant(0, DL, VT);
16223 else
16224 Result = DAG.getNode(ISD::SHL, DL, VT, Result,
16225 DAG.getShiftAmountConstant(ShLeftAmt, VT, DL));
16226 }
16227
16228 if (ShiftedOffset != 0) {
16229 // We're using a shifted mask, so the load now has an offset. This means
16230 // that data has been loaded into the lower bytes than it would have been
16231 // before, so we need to shl the loaded data into the correct position in the
16232 // register.
16233 SDValue ShiftC = DAG.getConstant(ShiftedOffset, DL, VT);
16234 Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
16235 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
16236 }
16237
16238 // Return the new loaded value.
16239 return Result;
16240}
16241
/// Combine step for SIGN_EXTEND_INREG (sext_in_reg): tries a sequence of
/// independent folds — constant folding, dropping redundant extensions,
/// converting loads into narrower sign-extending loads, and rewriting
/// masked-load/masked-gather extensions — and returns the replacement value,
/// or SDValue() if no fold applies.
/// NOTE(review): this excerpt comes from a rendered listing; a few original
/// source lines (parts of several `if` conditions) are not visible here, so
/// the comments below describe only what the visible code shows.
16242SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
16243  SDValue N0 = N->getOperand(0);
16244  SDValue N1 = N->getOperand(1);
16245  EVT VT = N->getValueType(0);
  // N1 is a VTSDNode carrying the type being sign-extended *from*.
16246  EVT ExtVT = cast<VTSDNode>(N1)->getVT();
16247  unsigned VTBits = VT.getScalarSizeInBits();
16248  unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
16249  SDLoc DL(N);
16250
16251  // sext_vector_inreg(undef) = 0 because the top bit will all be the same.
16252  if (N0.isUndef())
16253    return DAG.getConstant(0, DL, VT);
16254
16255  // fold (sext_in_reg c1) -> c1
  // (condition continues on a line not visible in this excerpt)
16256  if (SDValue C =
16258    return C;
16259
16260  // If the input is already sign extended, just drop the extension.
16261  if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
16262    return N0;
16263
16264  // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
16265  if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
16266      ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
16267    return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N0.getOperand(0), N1);
16268
16269  // fold (sext_in_reg (sext x)) -> (sext x)
16270  // fold (sext_in_reg (aext x)) -> (sext x)
16271  // if x is small enough or if we know that x has more than 1 sign bit and the
16272  // sign_extend_inreg is extending from one of them.
16273  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
16274    SDValue N00 = N0.getOperand(0);
16275    unsigned N00Bits = N00.getScalarValueSizeInBits();
16276    if ((N00Bits <= ExtVTBits ||
16277         DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
16278        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
16279      return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N00);
16280  }
16281
16282  // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
16283  // if x is small enough or if we know that x has more than 1 sign bit and the
16284  // sign_extend_inreg is extending from one of them.
  // (the guarding `if` condition line is not visible in this excerpt)
16286    SDValue N00 = N0.getOperand(0);
16287    unsigned N00Bits = N00.getScalarValueSizeInBits();
16288    bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
  // For a zext_vector_inreg source, only the exact-width case is safe — the
  // narrower cases rely on sign-bit knowledge that zext does not provide.
16289    if ((N00Bits == ExtVTBits ||
16290         (!IsZext && (N00Bits < ExtVTBits ||
16291                      DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
16292        (!LegalOperations ||
16294      return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, DL, VT, N00);
16295  }
16296
16297  // fold (sext_in_reg (zext x)) -> (sext x)
16298  // iff we are extending the source sign bit.
16299  if (N0.getOpcode() == ISD::ZERO_EXTEND) {
16300    SDValue N00 = N0.getOperand(0);
16301    if (N00.getScalarValueSizeInBits() == ExtVTBits &&
16302        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
16303      return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N00);
16304  }
16305
16306  // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
16307  if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
16308    return DAG.getZeroExtendInReg(N0, DL, ExtVT);
16309
16310  // fold operands of sext_in_reg based on knowledge that the top bits are not
16311  // demanded.
  // (the simplify-demanded-bits call is on a line not visible in this excerpt)
16313    return SDValue(N, 0);
16314
16315  // fold (sext_in_reg (load x)) -> (smaller sextload x)
16316  // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
16317  if (SDValue NarrowLoad = reduceLoadWidth(N))
16318    return NarrowLoad;
16319
16320  // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
16321  // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
16322  // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
16323  if (N0.getOpcode() == ISD::SRL) {
16324    if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
16325      if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
16326        // We can turn this into an SRA iff the input to the SRL is already sign
16327        // extended enough.
16328        unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
16329        if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
16330          return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
16331                             N0.getOperand(1));
16332      }
16333  }
16334
16335  // fold (sext_inreg (extload x)) -> (sextload x)
16336  // If sextload is not supported by target, we can only do the combine when
16337  // load has one use. Doing otherwise can block folding the extload with other
16338  // extends that the target does support.
  // (the opening `if` line of this condition is not visible in this excerpt)
16340      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
16341      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
16342        N0.hasOneUse()) ||
16343       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
16344    auto *LN0 = cast<LoadSDNode>(N0);
16345    SDValue ExtLoad =
16346        DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
16347                       LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
16348    CombineTo(N, ExtLoad);
16349    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
16350    AddToWorklist(ExtLoad.getNode());
16351    return SDValue(N, 0); // Return N so it doesn't get rechecked!
16352  }
16353
16354  // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
  // (the opening `if` line of this condition is not visible in this excerpt)
16356      N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
16357      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
16358       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
16359    auto *LN0 = cast<LoadSDNode>(N0);
16360    SDValue ExtLoad =
16361        DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
16362                       LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
16363    CombineTo(N, ExtLoad);
16364    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
16365    return SDValue(N, 0); // Return N so it doesn't get rechecked!
16366  }
16367
16368  // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
16369  // ignore it if the masked load is already sign extended
  // Look through a one-use FREEZE so a frozen masked load can still combine.
16370  bool Frozen = N0.getOpcode() == ISD::FREEZE && N0.hasOneUse();
16371  if (auto *Ld = dyn_cast<MaskedLoadSDNode>(Frozen ? N0.getOperand(0) : N0)) {
16372    if (ExtVT == Ld->getMemoryVT() && Ld->hasNUsesOfValue(1, 0) &&
16373        Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
16374        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
16375      SDValue ExtMaskedLoad = DAG.getMaskedLoad(
16376          VT, DL, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
16377          Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
16378          Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
      // When looking through FREEZE, keep the freeze as N's replacement so
      // the frozen value's users still see a frozen result.
16379      CombineTo(N, Frozen ? N0 : ExtMaskedLoad);
16380      CombineTo(Ld, ExtMaskedLoad, ExtMaskedLoad.getValue(1));
16381      return SDValue(N, 0); // Return N so it doesn't get rechecked!
16382    }
16383  }
16384
16385  // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
16386  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
  // (the condition's final clause is on a line not visible in this excerpt)
16387    if (SDValue(GN0, 0).hasOneUse() && ExtVT == GN0->getMemoryVT() &&
16389      SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
16390                       GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
16391
16392      SDValue ExtLoad = DAG.getMaskedGather(
16393          DAG.getVTList(VT, MVT::Other), ExtVT, DL, Ops, GN0->getMemOperand(),
16394          GN0->getIndexType(), ISD::SEXTLOAD);
16395
16396      CombineTo(N, ExtLoad);
16397      CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
16398      AddToWorklist(ExtLoad.getNode());
16399      return SDValue(N, 0); // Return N so it doesn't get rechecked!
16400    }
16401  }
16402
16403  // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
16404  if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
16405    if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
16406                                           N0.getOperand(1), false))
16407      return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, BSwap, N1);
16408  }
16409
16410  // Fold (iM_signext_inreg
16411  //        (extract_subvector (zext|anyext|sext iN_v to _) _)
16412  //        from iN)
16413  //      -> (extract_subvector (signext iN_v to iM))
  // (the condition's second line is not visible in this excerpt)
16414  if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
16416    SDValue InnerExt = N0.getOperand(0);
16417    EVT InnerExtVT = InnerExt->getValueType(0);
16418    SDValue Extendee = InnerExt->getOperand(0);
16419
16420    if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
16421        (!LegalOperations ||
16422         TLI.isOperationLegal(ISD::SIGN_EXTEND, InnerExtVT))) {
16423      SDValue SignExtExtendee =
16424          DAG.getNode(ISD::SIGN_EXTEND, DL, InnerExtVT, Extendee);
16425      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, SignExtExtendee,
16426                         N0.getOperand(1));
16427    }
16428  }
16429
  // No fold applied; leave the node unchanged.
16430  return SDValue();
16431}
16432
/// Fold an *_EXTEND_VECTOR_INREG node whose operand is a one-use
/// CONCAT_VECTORS into a plain extend of the concat's first operand:
///   ext_vector_inreg (concat_vectors X, ...) --> ext X
/// Returns the new extend node, or SDValue() if the pattern/profitability
/// checks fail.
/// NOTE(review): the function's name/return-type line is not visible in this
/// excerpt (presumably `static SDValue foldExtendVectorInregToExtendOfSubvector(`
/// — TODO confirm against the full source).
16434    SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
16435    bool LegalOperations) {
16436  unsigned InregOpcode = N->getOpcode();
  // Map the *_EXTEND_VECTOR_INREG opcode to its full *_EXTEND counterpart.
16437  unsigned Opcode = DAG.getOpcode_EXTEND(InregOpcode);
16438
16439  SDValue Src = N->getOperand(0);
16440  EVT VT = N->getValueType(0);
  // The subvector we may extend must have the result's element count but the
  // source's element type.
16441  EVT SrcVT = VT.changeVectorElementType(
16442      *DAG.getContext(), Src.getValueType().getVectorElementType());
16443
16444  assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
16445         "Expected EXTEND_VECTOR_INREG dag node in input!");
16446
16447  // Profitability check: our operand must be an one-use CONCAT_VECTORS.
16448  // FIXME: one-use check may be overly restrictive
16449  if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
16450    return SDValue();
16451
16452  // Profitability check: we must be extending exactly one of it's operands.
16453  // FIXME: this is probably overly restrictive.
16454  Src = Src.getOperand(0);
16455  if (Src.getValueType() != SrcVT)
16456    return SDValue();
16457
  // Respect legality of the full extend once operations are legalized.
16458  if (LegalOperations && !TLI.isOperationLegal(Opcode, VT))
16459    return SDValue();
16460
16461  return DAG.getNode(Opcode, DL, VT, Src);
16462}
16463
/// Combine step shared by ANY/SIGN/ZERO_EXTEND_VECTOR_INREG nodes: fold
/// undef/constant operands, simplify based on demanded elements, and try the
/// concat-of-subvector fold. Returns the replacement value or SDValue().
/// NOTE(review): two lines of this function (a demanded-elements check and
/// the first line of the helper call) are not visible in this excerpt.
16464SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
16465  SDValue N0 = N->getOperand(0);
16466  EVT VT = N->getValueType(0);
16467  SDLoc DL(N);
16468
16469  if (N0.isUndef()) {
16470    // aext_vector_inreg(undef) = undef because the top bits are undefined.
16471    // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
16472    return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
16473               ? DAG.getUNDEF(VT)
16474               : DAG.getConstant(0, DL, VT);
16475  }
16476
  // fold (*ext_vector_inreg c) -> c (constant-fold the extension).
16477  if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
16478    return Res;
16479
  // (demanded-elements simplification; its `if` line is not visible here)
16481    return SDValue(N, 0);
16482
  // (call into the concat_vectors helper above; first line not visible here)
16484                                                         LegalOperations))
16485    return R;
16486
16487  return SDValue();
16488}
16489
/// Combine step for TRUNCATE_USAT_U: fold a saturating-truncate of an
/// fp_to_uint into a single FP_TO_UINT_SAT node when the pattern matches.
/// NOTE(review): one line of the condition (between the sd_match and the
/// FP_TO_UINT_SAT legality arguments) is not visible in this excerpt.
16490SDValue DAGCombiner::visitTRUNCATE_USAT_U(SDNode *N) {
16491  EVT VT = N->getValueType(0);
16492  SDValue N0 = N->getOperand(0);
16493
16494  SDValue FPVal;
  // trunc_usat_u (fp_to_uint X) --> fp_to_uint_sat X, saturating at VT.
16495  if (sd_match(N0, m_FPToUI(m_Value(FPVal))) &&
16497          ISD::FP_TO_UINT_SAT, FPVal.getValueType(), VT))
16498    return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), VT, FPVal,
16499                       DAG.getValueType(VT.getScalarType()));
16500
16501  return SDValue();
16502}
16503
16504/// Detect patterns of truncation with unsigned saturation:
16505///
16506/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
16507/// Return the source value x to be truncated or SDValue() if the pattern was
16508/// not matched.
16509///
/// NOTE(review): the signature line is not visible in this excerpt; based on
/// the body it takes (SDValue In, EVT VT) — TODO confirm.
16511  unsigned NumDstBits = VT.getScalarSizeInBits();
16512  unsigned NumSrcBits = In.getScalarValueSizeInBits();
16513  // Saturation with truncation. We truncate from InVT to VT.
16514  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
16515
16516  SDValue Min;
  // The destination type's unsigned max, widened to the source width so it
  // can be matched against the umin's constant operand.
16517  APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
16518  if (sd_match(In, m_UMin(m_Value(Min), m_SpecificInt(UnsignedMax))))
16519    return Min;
16520
16521  return SDValue();
16522}
16523
16524/// Detect patterns of truncation with signed saturation:
16525/// (truncate (smin (smax (x, signed_min_of_dest_type),
16526///                 signed_max_of_dest_type)) to dest_type)
16527/// or:
16528/// (truncate (smax (smin (x, signed_max_of_dest_type),
16529///                 signed_min_of_dest_type)) to dest_type).
16530///
16531/// Return the source value to be truncated or SDValue() if the pattern was not
16532/// matched.
/// NOTE(review): the signature line is not visible in this excerpt; based on
/// the body it takes (SDValue In, EVT VT) — TODO confirm.
16534  unsigned NumDstBits = VT.getScalarSizeInBits();
16535  unsigned NumSrcBits = In.getScalarValueSizeInBits();
16536  // Saturation with truncation. We truncate from InVT to VT.
16537  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
16538
16539  SDValue Val;
  // Destination type's signed bounds, sign-extended to the source width so
  // they can be matched against the clamp constants.
16540  APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
16541  APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
16542
  // smin(smax(x, MIN), MAX) — clamp with smax innermost.
16543  if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_SpecificInt(SignedMin)),
16544                          m_SpecificInt(SignedMax))))
16545    return Val;
16546
  // smax(smin(x, MAX), MIN) — the commuted clamp order.
16547  if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(SignedMax)),
16548                          m_SpecificInt(SignedMin))))
16549    return Val;
16550
16551  return SDValue();
16552}
16553
16554/// Detect patterns of truncation with unsigned saturation:
/// Matches clamps of a signed source to [0, unsigned_max_of_dest_type] in any
/// of three operand orders, returning the value to truncate or SDValue().
/// NOTE(review): the first signature line is not visible in this excerpt; the
/// visible continuation shows it also takes `const SDLoc &DL` — TODO confirm
/// the full parameter list.
16556                                  const SDLoc &DL) {
16557  unsigned NumDstBits = VT.getScalarSizeInBits();
16558  unsigned NumSrcBits = In.getScalarValueSizeInBits();
16559  // Saturation with truncation. We truncate from InVT to VT.
16560  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
16561
16562  SDValue Val;
  // Destination type's unsigned max, zero-extended to the source width.
16563  APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
16564  // Min == 0, Max is unsigned max of destination type.
16565  if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(UnsignedMax)),
16566                          m_Zero())))
16567    return Val;
16568
16569  if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_Zero()),
16570                          m_SpecificInt(UnsignedMax))))
16571    return Val;
16572
  // After the smax with zero the value is non-negative, so a umin against the
  // unsigned max completes the same clamp.
16573  if (sd_match(In, m_UMin(m_SMax(m_Value(Val), m_Zero()),
16574                          m_SpecificInt(UnsignedMax))))
16575    return Val;
16576
16577  return SDValue();
16578}
16579
16580static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
16581 SDLoc &DL, const TargetLowering &TLI,
16582 SelectionDAG &DAG) {
16583 auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool {
16584 return (TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
16585 TLI.isTypeDesirableForOp(Opc, VT));
16586 };
16587
16588 if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
16589 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT))
16590 if (SDValue SSatVal = detectSSatSPattern(Src, VT))
16591 return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
16592 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
16593 if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
16594 return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
16595 } else if (Src.getOpcode() == ISD::UMIN) {
16596 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
16597 if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
16598 return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
16599 if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT))
16600 if (SDValue USatVal = detectUSatUPattern(Src, VT))
16601 return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
16602 }
16603
16604 return SDValue();
16605}
16606
16607SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
16608 SDValue N0 = N->getOperand(0);
16609 EVT VT = N->getValueType(0);
16610 EVT SrcVT = N0.getValueType();
16611 bool isLE = DAG.getDataLayout().isLittleEndian();
16612 SDLoc DL(N);
16613
16614 // trunc(undef) = undef
16615 if (N0.isUndef())
16616 return DAG.getUNDEF(VT);
16617
16618 // fold (truncate (truncate x)) -> (truncate x)
16619 if (N0.getOpcode() == ISD::TRUNCATE)
16620 return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
16621
16622 // fold saturated truncate
16623 if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG))
16624 return SaturatedTR;
16625
16626 // fold (truncate c1) -> c1
16627 if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
16628 return C;
16629
16630 // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
16631 if (N0.getOpcode() == ISD::ZERO_EXTEND ||
16632 N0.getOpcode() == ISD::SIGN_EXTEND ||
16633 N0.getOpcode() == ISD::ANY_EXTEND) {
16634 // if the source is smaller than the dest, we still need an extend.
16635 if (N0.getOperand(0).getValueType().bitsLT(VT)) {
16636 SDNodeFlags Flags;
16637 if (N0.getOpcode() == ISD::ZERO_EXTEND)
16638 Flags.setNonNeg(N0->getFlags().hasNonNeg());
16639 return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Flags);
16640 }
16641 // if the source is larger than the dest, than we just need the truncate.
16642 if (N0.getOperand(0).getValueType().bitsGT(VT))
16643 return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
16644 // if the source and dest are the same type, we can drop both the extend
16645 // and the truncate.
16646 return N0.getOperand(0);
16647 }
16648
16649 // Try to narrow a truncate-of-sext_in_reg to the destination type:
16650 // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
16651 if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
16652 N0.hasOneUse()) {
16653 SDValue X = N0.getOperand(0);
16654 SDValue ExtVal = N0.getOperand(1);
16655 EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
16656 if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(VT, SrcVT, ExtVT)) {
16657 SDValue TrX = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
16658 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, TrX, ExtVal);
16659 }
16660 }
16661
16662 // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
16663 if (N->hasOneUse() && (N->user_begin()->getOpcode() == ISD::ANY_EXTEND))
16664 return SDValue();
16665
16666 // Fold extract-and-trunc into a narrow extract. For example:
16667 // i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
16668 // i32 y = TRUNCATE(i64 x)
16669 // -- becomes --
16670 // v16i8 b = BITCAST (v2i64 val)
16671 // i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
16672 //
16673 // Note: We only run this optimization after type legalization (which often
16674 // creates this pattern) and before operation legalization after which
16675 // we need to be more careful about the vector instructions that we generate.
16676 if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
16677 N0->hasOneUse()) {
16678 EVT TrTy = N->getValueType(0);
16679 SDValue Src = N0;
16680
16681 // Check for cases where we shift down an upper element before truncation.
16682 int EltOffset = 0;
16683 if (Src.getOpcode() == ISD::SRL && Src.getOperand(0)->hasOneUse()) {
16684 if (auto ShAmt = DAG.getValidShiftAmount(Src)) {
16685 if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
16686 Src = Src.getOperand(0);
16687 EltOffset = *ShAmt / TrTy.getSizeInBits();
16688 }
16689 }
16690 }
16691
16692 if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
16693 EVT VecTy = Src.getOperand(0).getValueType();
16694 EVT ExTy = Src.getValueType();
16695
16696 auto EltCnt = VecTy.getVectorElementCount();
16697 unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
16698 auto NewEltCnt = EltCnt * SizeRatio;
16699
16700 EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
16701 assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
16702
16703 SDValue EltNo = Src->getOperand(1);
16704 if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
16705 int Elt = EltNo->getAsZExtVal();
16706 int Index = isLE ? (Elt * SizeRatio + EltOffset)
16707 : (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
16708 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
16709 DAG.getBitcast(NVT, Src.getOperand(0)),
16710 DAG.getVectorIdxConstant(Index, DL));
16711 }
16712 }
16713 }
16714
16715 // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
16716 if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse() &&
16717 TLI.isTruncateFree(SrcVT, VT)) {
16718 if (!LegalOperations ||
16719 (TLI.isOperationLegal(ISD::SELECT, SrcVT) &&
16720 TLI.isNarrowingProfitable(N0.getNode(), SrcVT, VT))) {
16721 SDLoc SL(N0);
16722 SDValue Cond = N0.getOperand(0);
16723 SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
16724 SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
16725 return DAG.getNode(ISD::SELECT, DL, VT, Cond, TruncOp0, TruncOp1);
16726 }
16727 }
16728
16729 // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
16730 if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
16731 (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
16732 TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
16733 SDValue Amt = N0.getOperand(1);
16734 KnownBits Known = DAG.computeKnownBits(Amt);
16735 unsigned Size = VT.getScalarSizeInBits();
16736 if (Known.countMaxActiveBits() <= Log2_32(Size)) {
16737 EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
16738 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
16739 if (AmtVT != Amt.getValueType()) {
16740 Amt = DAG.getZExtOrTrunc(Amt, DL, AmtVT);
16741 AddToWorklist(Amt.getNode());
16742 }
16743 return DAG.getNode(ISD::SHL, DL, VT, Trunc, Amt);
16744 }
16745 }
16746
16747 if (SDValue V = foldSubToUSubSat(VT, N0.getNode(), DL))
16748 return V;
16749
16750 if (SDValue ABD = foldABSToABD(N, DL))
16751 return ABD;
16752
16753 // Attempt to pre-truncate BUILD_VECTOR sources.
16754 if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
16755 N0.hasOneUse() &&
16756 // Avoid creating illegal types if running after type legalizer.
16757 (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
16758 if (TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()))
16759 return DAG.UnrollVectorOp(N);
16760
16761 // trunc(build_vector(ext(x), ext(x)) -> build_vector(x,x)
16762 if (SDValue SplatVal = DAG.getSplatValue(N0)) {
16763 if (ISD::isExtOpcode(SplatVal.getOpcode()) &&
16764 SrcVT.getScalarType() == SplatVal.getValueType())
16765 return DAG.UnrollVectorOp(N);
16766 }
16767 }
16768
16769 // trunc (splat_vector x) -> splat_vector (trunc x)
16770 if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
16771 (!LegalTypes || TLI.isTypeLegal(VT.getScalarType())) &&
16772 (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, VT))) {
16773 EVT SVT = VT.getScalarType();
16774 return DAG.getSplatVector(
16775 VT, DL, DAG.getNode(ISD::TRUNCATE, DL, SVT, N0->getOperand(0)));
16776 }
16777
16778 // Fold a series of buildvector, bitcast, and truncate if possible.
16779 // For example fold
16780 // (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
16781 // (2xi32 (buildvector x, y)).
16782 if (Level == AfterLegalizeVectorOps && VT.isVector() &&
16783 N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
16785 N0.getOperand(0).hasOneUse()) {
16786 SDValue BuildVect = N0.getOperand(0);
16787 EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
16788 EVT TruncVecEltTy = VT.getVectorElementType();
16789
16790 // Check that the element types match.
16791 if (BuildVectEltTy == TruncVecEltTy) {
16792 // Now we only need to compute the offset of the truncated elements.
16793 unsigned BuildVecNumElts = BuildVect.getNumOperands();
16794 unsigned TruncVecNumElts = VT.getVectorNumElements();
16795 unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
16796 unsigned FirstElt = isLE ? 0 : (TruncEltOffset - 1);
16797
16798 assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
16799 "Invalid number of elements");
16800
16802 for (unsigned i = FirstElt, e = BuildVecNumElts; i < e;
16803 i += TruncEltOffset)
16804 Opnds.push_back(BuildVect.getOperand(i));
16805
16806 return DAG.getBuildVector(VT, DL, Opnds);
16807 }
16808 }
16809
16810 // fold (truncate (load x)) -> (smaller load x)
16811 // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
16812 if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
16813 if (SDValue Reduced = reduceLoadWidth(N))
16814 return Reduced;
16815
16816 // Handle the case where the truncated result is at least as wide as the
16817 // loaded type.
16818 if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
16819 auto *LN0 = cast<LoadSDNode>(N0);
16820 if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
16821 SDValue NewLoad = DAG.getExtLoad(
16822 LN0->getExtensionType(), SDLoc(LN0), VT, LN0->getChain(),
16823 LN0->getBasePtr(), LN0->getMemoryVT(), LN0->getMemOperand());
16824 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
16825 return NewLoad;
16826 }
16827 }
16828 }
16829
16830 // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
16831 // where ... are all 'undef'.
16832 if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
16834 SDValue V;
16835 unsigned Idx = 0;
16836 unsigned NumDefs = 0;
16837
16838 for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
16839 SDValue X = N0.getOperand(i);
16840 if (!X.isUndef()) {
16841 V = X;
16842 Idx = i;
16843 NumDefs++;
16844 }
16845 // Stop if more than one members are non-undef.
16846 if (NumDefs > 1)
16847 break;
16848
16851 X.getValueType().getVectorElementCount()));
16852 }
16853
16854 if (NumDefs == 0)
16855 return DAG.getUNDEF(VT);
16856
16857 if (NumDefs == 1) {
16858 assert(V.getNode() && "The single defined operand is empty!");
16860 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
16861 if (i != Idx) {
16862 Opnds.push_back(DAG.getUNDEF(VTs[i]));
16863 continue;
16864 }
16865 SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
16866 AddToWorklist(NV.getNode());
16867 Opnds.push_back(NV);
16868 }
16869 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Opnds);
16870 }
16871 }
16872
16873 // Fold truncate of a bitcast of a vector to an extract of the low vector
16874 // element.
16875 //
16876 // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
16877 if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
16878 SDValue VecSrc = N0.getOperand(0);
16879 EVT VecSrcVT = VecSrc.getValueType();
16880 if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
16881 (!LegalOperations ||
16882 TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
16883 unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
16884 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, VecSrc,
16885 DAG.getVectorIdxConstant(Idx, DL));
16886 }
16887 }
16888
16889 // Simplify the operands using demanded-bits information.
16891 return SDValue(N, 0);
16892
16893 // fold (truncate (extract_subvector(ext x))) ->
16894 // (extract_subvector x)
16895 // TODO: This can be generalized to cover cases where the truncate and extract
16896 // do not fully cancel each other out.
16897 if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
16898 SDValue N00 = N0.getOperand(0);
16899 if (N00.getOpcode() == ISD::SIGN_EXTEND ||
16900 N00.getOpcode() == ISD::ZERO_EXTEND ||
16901 N00.getOpcode() == ISD::ANY_EXTEND) {
16902 if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
16904 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
16905 N00.getOperand(0), N0.getOperand(1));
16906 }
16907 }
16908
16909 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
16910 return NewVSel;
16911
16912 // Narrow a suitable binary operation with a non-opaque constant operand by
16913 // moving it ahead of the truncate. This is limited to pre-legalization
16914 // because targets may prefer a wider type during later combines and invert
16915 // this transform.
16916 switch (N0.getOpcode()) {
16917 case ISD::ADD:
16918 case ISD::SUB:
16919 case ISD::MUL:
16920 case ISD::AND:
16921 case ISD::OR:
16922 case ISD::XOR:
16923 if (!LegalOperations && N0.hasOneUse() &&
16924 (N0.getOperand(0) == N0.getOperand(1) ||
16926 isConstantOrConstantVector(N0.getOperand(1), true))) {
16927 // TODO: We already restricted this to pre-legalization, but for vectors
16928 // we are extra cautious to not create an unsupported operation.
16929 // Target-specific changes are likely needed to avoid regressions here.
16930 if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
16931 SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
16932 SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
16933 SDNodeFlags Flags;
16934 // Propagate nuw for sub.
16935 if (N0->getOpcode() == ISD::SUB && N0->getFlags().hasNoUnsignedWrap() &&
16937 N0->getOperand(0),
16939 VT.getScalarSizeInBits())))
16940 Flags.setNoUnsignedWrap(true);
16941 return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR, Flags);
16942 }
16943 }
16944 break;
16945 case ISD::ADDE:
16946 case ISD::UADDO_CARRY:
16947 // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
16948 // (trunc uaddo_carry(X, Y, Carry)) ->
16949 // (uaddo_carry trunc(X), trunc(Y), Carry)
16950 // When the adde's carry is not used.
16951 // We only do for uaddo_carry before legalize operation
16952 if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
16953 TLI.isOperationLegal(N0.getOpcode(), VT)) &&
16954 N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
16955 SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
16956 SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
16957 SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
16958 return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
16959 }
16960 break;
16961 case ISD::USUBSAT:
16962 // Truncate the USUBSAT only if LHS is a known zero-extension, its not
16963 // enough to know that the upper bits are zero we must ensure that we don't
16964 // introduce an extra truncate.
16965 if (!LegalOperations && N0.hasOneUse() &&
16968 VT.getScalarSizeInBits() &&
16969 hasOperation(N0.getOpcode(), VT)) {
16970 return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
16971 DAG, DL);
16972 }
16973 break;
16974 case ISD::AVGCEILS:
16975 case ISD::AVGCEILU:
16976 // trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
16977 // trunc (avgceils (zext (x), zext (y))) -> avgceilu(x, y)
16978 if (N0.hasOneUse()) {
16979 SDValue Op0 = N0.getOperand(0);
16980 SDValue Op1 = N0.getOperand(1);
16981 if (N0.getOpcode() == ISD::AVGCEILU) {
16983 Op0.getOpcode() == ISD::SIGN_EXTEND &&
16984 Op1.getOpcode() == ISD::SIGN_EXTEND &&
16985 Op0.getOperand(0).getValueType() == VT &&
16986 Op1.getOperand(0).getValueType() == VT)
16987 return DAG.getNode(ISD::AVGCEILS, DL, VT, Op0.getOperand(0),
16988 Op1.getOperand(0));
16989 } else {
16991 Op0.getOpcode() == ISD::ZERO_EXTEND &&
16992 Op1.getOpcode() == ISD::ZERO_EXTEND &&
16993 Op0.getOperand(0).getValueType() == VT &&
16994 Op1.getOperand(0).getValueType() == VT)
16995 return DAG.getNode(ISD::AVGCEILU, DL, VT, Op0.getOperand(0),
16996 Op1.getOperand(0));
16997 }
16998 }
16999 [[fallthrough]];
17000 case ISD::AVGFLOORS:
17001 case ISD::AVGFLOORU:
17002 case ISD::ABDS:
17003 case ISD::ABDU:
17004 // (trunc (avg a, b)) -> (avg (trunc a), (trunc b))
17005 // (trunc (abdu/abds a, b)) -> (abdu/abds (trunc a), (trunc b))
17006 if (!LegalOperations && N0.hasOneUse() &&
17007 TLI.isOperationLegal(N0.getOpcode(), VT)) {
17008 EVT TruncVT = VT;
17009 unsigned SrcBits = SrcVT.getScalarSizeInBits();
17010 unsigned TruncBits = TruncVT.getScalarSizeInBits();
17011
17012 SDValue A = N0.getOperand(0);
17013 SDValue B = N0.getOperand(1);
17014 bool CanFold = false;
17015
17016 if (N0.getOpcode() == ISD::AVGFLOORU || N0.getOpcode() == ISD::AVGCEILU ||
17017 N0.getOpcode() == ISD::ABDU) {
17018 APInt UpperBits = APInt::getBitsSetFrom(SrcBits, TruncBits);
17019 CanFold = DAG.MaskedValueIsZero(B, UpperBits) &&
17020 DAG.MaskedValueIsZero(A, UpperBits);
17021 } else {
17022 unsigned NeededBits = SrcBits - TruncBits;
17023 CanFold = DAG.ComputeNumSignBits(B) > NeededBits &&
17024 DAG.ComputeNumSignBits(A) > NeededBits;
17025 }
17026
17027 if (CanFold) {
17028 SDValue NewA = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, A);
17029 SDValue NewB = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, B);
17030 return DAG.getNode(N0.getOpcode(), DL, TruncVT, NewA, NewB);
17031 }
17032 }
17033 break;
17034 }
17035
17036 return SDValue();
17037}
17038
17039static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
17040 SDValue Elt = N->getOperand(i);
17041 if (Elt.getOpcode() != ISD::MERGE_VALUES)
17042 return Elt.getNode();
17043 return Elt.getOperand(Elt.getResNo()).getNode();
17044}
17045
17046/// build_pair (load, load) -> load
17047/// if load locations are consecutive.
17048SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
17049 assert(N->getOpcode() == ISD::BUILD_PAIR);
17050
17051 auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
17052 auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
17053
17054 // A BUILD_PAIR is always having the least significant part in elt 0 and the
17055 // most significant part in elt 1. So when combining into one large load, we
17056 // need to consider the endianness.
17057 if (DAG.getDataLayout().isBigEndian())
17058 std::swap(LD1, LD2);
17059
17060 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
17061 !LD1->hasOneUse() || !LD2->hasOneUse() ||
17062 LD1->getAddressSpace() != LD2->getAddressSpace())
17063 return SDValue();
17064
17065 unsigned LD1Fast = 0;
17066 EVT LD1VT = LD1->getValueType(0);
17067 unsigned LD1Bytes = LD1VT.getStoreSize();
17068 if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
17069 DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
17070 TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
17071 *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
17072 return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
17073 LD1->getPointerInfo(), LD1->getAlign());
17074
17075 return SDValue();
17076}
17077
17078static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
17079 // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
17080 // and Lo parts; on big-endian machines it doesn't.
17081 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
17082}
17083
17084SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
17085 const TargetLowering &TLI) {
17086 // If this is not a bitcast to an FP type or if the target doesn't have
17087 // IEEE754-compliant FP logic, we're done.
17088 EVT VT = N->getValueType(0);
17089 SDValue N0 = N->getOperand(0);
17090 EVT SourceVT = N0.getValueType();
17091
17092 if (!VT.isFloatingPoint())
17093 return SDValue();
17094
17095 // TODO: Handle cases where the integer constant is a different scalar
17096 // bitwidth to the FP.
17097 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
17098 return SDValue();
17099
17100 unsigned FPOpcode;
17101 APInt SignMask;
17102 switch (N0.getOpcode()) {
17103 case ISD::AND:
17104 FPOpcode = ISD::FABS;
17105 SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
17106 break;
17107 case ISD::XOR:
17108 FPOpcode = ISD::FNEG;
17109 SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
17110 break;
17111 case ISD::OR:
17112 FPOpcode = ISD::FABS;
17113 SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
17114 break;
17115 default:
17116 return SDValue();
17117 }
17118
17119 if (LegalOperations && !TLI.isOperationLegal(FPOpcode, VT))
17120 return SDValue();
17121
17122 // This needs to be the inverse of logic in foldSignChangeInBitcast.
17123 // FIXME: I don't think looking for bitcast intrinsically makes sense, but
17124 // removing this would require more changes.
17125 auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
17126 if (sd_match(Op, m_BitCast(m_SpecificVT(VT))))
17127 return true;
17128
17129 return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
17130 };
17131
17132 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
17133 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
17134 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
17135 // fneg (fabs X)
17136 SDValue LogicOp0 = N0.getOperand(0);
17137 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
17138 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
17139 IsBitCastOrFree(LogicOp0, VT)) {
17140 SDValue CastOp0 = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, LogicOp0);
17141 SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, CastOp0);
17142 NumFPLogicOpsConv++;
17143 if (N0.getOpcode() == ISD::OR)
17144 return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
17145 return FPOp;
17146 }
17147
17148 return SDValue();
17149}
17150
17151SDValue DAGCombiner::visitBITCAST(SDNode *N) {
17152 SDValue N0 = N->getOperand(0);
17153 EVT VT = N->getValueType(0);
17154
17155 if (N0.isUndef())
17156 return DAG.getUNDEF(VT);
17157
17158 // If the input is a BUILD_VECTOR with all constant elements, fold this now.
17159 // Only do this before legalize types, unless both types are integer and the
17160 // scalar type is legal. Only do this before legalize ops, since the target
17161 // maybe depending on the bitcast.
17162 // First check to see if this is all constant.
17163 // TODO: Support FP bitcasts after legalize types.
17164 if (VT.isVector() &&
17165 (!LegalTypes ||
17166 (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
17167 TLI.isTypeLegal(VT.getVectorElementType()))) &&
17168 N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
17169 cast<BuildVectorSDNode>(N0)->isConstant())
17170 return DAG.FoldConstantBuildVector(cast<BuildVectorSDNode>(N0), SDLoc(N),
17172
17173 // If the input is a constant, let getNode fold it.
17174 if (isIntOrFPConstant(N0)) {
17175 // If we can't allow illegal operations, we need to check that this is just
17176 // a fp -> int or int -> conversion and that the resulting operation will
17177 // be legal.
17178 if (!LegalOperations ||
17179 (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
17181 (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
17182 TLI.isOperationLegal(ISD::Constant, VT))) {
17183 SDValue C = DAG.getBitcast(VT, N0);
17184 if (C.getNode() != N)
17185 return C;
17186 }
17187 }
17188
17189 // (conv (conv x, t1), t2) -> (conv x, t2)
17190 if (N0.getOpcode() == ISD::BITCAST)
17191 return DAG.getBitcast(VT, N0.getOperand(0));
17192
17193 // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
17194 // iff the current bitwise logicop type isn't legal
17195 if (ISD::isBitwiseLogicOp(N0.getOpcode()) && VT.isInteger() &&
17196 !TLI.isTypeLegal(N0.getOperand(0).getValueType())) {
17197 auto IsFreeBitcast = [VT](SDValue V) {
17198 return (V.getOpcode() == ISD::BITCAST &&
17199 V.getOperand(0).getValueType() == VT) ||
17201 V->hasOneUse());
17202 };
17203 if (IsFreeBitcast(N0.getOperand(0)) && IsFreeBitcast(N0.getOperand(1)))
17204 return DAG.getNode(N0.getOpcode(), SDLoc(N), VT,
17205 DAG.getBitcast(VT, N0.getOperand(0)),
17206 DAG.getBitcast(VT, N0.getOperand(1)));
17207 }
17208
17209 // fold (conv (load x)) -> (load (conv*)x)
17210 // fold (conv (freeze (load x))) -> (freeze (load (conv*)x))
17211 // If the resultant load doesn't need a higher alignment than the original!
17212 auto CastLoad = [this, &VT](SDValue N0, const SDLoc &DL) {
17213 if (N0.getOpcode() == ISD::AssertNoFPClass)
17214 N0 = N0.getOperand(0);
17215 if (!ISD::isNormalLoad(N0.getNode()) || !N0.hasOneUse())
17216 return SDValue();
17217
17218 // Do not remove the cast if the types differ in endian layout.
17221 return SDValue();
17222
17223 // If the load is volatile, we only want to change the load type if the
17224 // resulting load is legal. Otherwise we might increase the number of
17225 // memory accesses. We don't care if the original type was legal or not
17226 // as we assume software couldn't rely on the number of accesses of an
17227 // illegal type.
17228 auto *LN0 = cast<LoadSDNode>(N0);
17229 if ((LegalOperations || !LN0->isSimple()) &&
17230 !TLI.isOperationLegal(ISD::LOAD, VT))
17231 return SDValue();
17232
17233 if (!TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
17234 *LN0->getMemOperand()))
17235 return SDValue();
17236
17237 // If the range metadata type does not match the new memory
17238 // operation type, remove the range metadata.
17239 if (const MDNode *MD = LN0->getRanges()) {
17240 ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
17241 if (Lower->getBitWidth() != VT.getScalarSizeInBits() || !VT.isInteger()) {
17242 LN0->getMemOperand()->clearRanges();
17243 }
17244 }
17245 SDValue Load = DAG.getLoad(VT, DL, LN0->getChain(), LN0->getBasePtr(),
17246 LN0->getMemOperand());
17247 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
17248 return Load;
17249 };
17250
17251 if (SDValue NewLd = CastLoad(N0, SDLoc(N)))
17252 return NewLd;
17253
17254 if (N0.getOpcode() == ISD::FREEZE && N0.hasOneUse())
17255 if (SDValue NewLd = CastLoad(N0.getOperand(0), SDLoc(N)))
17256 return DAG.getFreeze(NewLd);
17257
17258 if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
17259 return V;
17260
17261 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
17262 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
17263 //
17264 // For ppc_fp128:
17265 // fold (bitcast (fneg x)) ->
17266 // flipbit = signbit
17267 // (xor (bitcast x) (build_pair flipbit, flipbit))
17268 //
17269 // fold (bitcast (fabs x)) ->
17270 // flipbit = (and (extract_element (bitcast x), 0), signbit)
17271 // (xor (bitcast x) (build_pair flipbit, flipbit))
17272 // This often reduces constant pool loads.
17273 if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
17274 (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
17275 N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
17276 !N0.getValueType().isVector()) {
17277 SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
17278 AddToWorklist(NewConv.getNode());
17279
17280 SDLoc DL(N);
17281 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
17282 assert(VT.getSizeInBits() == 128);
17283 SDValue SignBit = DAG.getConstant(
17284 APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
17285 SDValue FlipBit;
17286 if (N0.getOpcode() == ISD::FNEG) {
17287 FlipBit = SignBit;
17288 AddToWorklist(FlipBit.getNode());
17289 } else {
17290 assert(N0.getOpcode() == ISD::FABS);
17291 SDValue Hi =
17292 DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
17294 SDLoc(NewConv)));
17295 AddToWorklist(Hi.getNode());
17296 FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
17297 AddToWorklist(FlipBit.getNode());
17298 }
17299 SDValue FlipBits =
17300 DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
17301 AddToWorklist(FlipBits.getNode());
17302 return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
17303 }
17304 APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
17305 if (N0.getOpcode() == ISD::FNEG)
17306 return DAG.getNode(ISD::XOR, DL, VT,
17307 NewConv, DAG.getConstant(SignBit, DL, VT));
17308 assert(N0.getOpcode() == ISD::FABS);
17309 return DAG.getNode(ISD::AND, DL, VT,
17310 NewConv, DAG.getConstant(~SignBit, DL, VT));
17311 }
17312
17313 // fold (bitconvert (fcopysign cst, x)) ->
17314 // (or (and (bitconvert x), sign), (and cst, (not sign)))
17315 // Note that we don't handle (copysign x, cst) because this can always be
17316 // folded to an fneg or fabs.
17317 //
17318 // For ppc_fp128:
17319 // fold (bitcast (fcopysign cst, x)) ->
17320 // flipbit = (and (extract_element
17321 // (xor (bitcast cst), (bitcast x)), 0),
17322 // signbit)
17323 // (xor (bitcast cst) (build_pair flipbit, flipbit))
17324 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
17326 !VT.isVector()) {
17327 unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
17328 EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
17329 if (isTypeLegal(IntXVT)) {
17330 SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
17331 AddToWorklist(X.getNode());
17332
17333 // If X has a different width than the result/lhs, sext it or truncate it.
17334 unsigned VTWidth = VT.getSizeInBits();
17335 if (OrigXWidth < VTWidth) {
17336 X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
17337 AddToWorklist(X.getNode());
17338 } else if (OrigXWidth > VTWidth) {
17339 // To get the sign bit in the right place, we have to shift it right
17340 // before truncating.
17341 SDLoc DL(X);
17342 X = DAG.getNode(ISD::SRL, DL,
17343 X.getValueType(), X,
17344 DAG.getConstant(OrigXWidth-VTWidth, DL,
17345 X.getValueType()));
17346 AddToWorklist(X.getNode());
17347 X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
17348 AddToWorklist(X.getNode());
17349 }
17350
17351 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
17352 APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
17353 SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
17354 AddToWorklist(Cst.getNode());
17355 SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
17356 AddToWorklist(X.getNode());
17357 SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
17358 AddToWorklist(XorResult.getNode());
17359 SDValue XorResult64 = DAG.getNode(
17360 ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
17362 SDLoc(XorResult)));
17363 AddToWorklist(XorResult64.getNode());
17364 SDValue FlipBit =
17365 DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
17366 DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
17367 AddToWorklist(FlipBit.getNode());
17368 SDValue FlipBits =
17369 DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
17370 AddToWorklist(FlipBits.getNode());
17371 return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
17372 }
17373 APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
17374 X = DAG.getNode(ISD::AND, SDLoc(X), VT,
17375 X, DAG.getConstant(SignBit, SDLoc(X), VT));
17376 AddToWorklist(X.getNode());
17377
17378 SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
17379 Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
17380 Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
17381 AddToWorklist(Cst.getNode());
17382
17383 return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
17384 }
17385 }
17386
17387 // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
17388 if (N0.getOpcode() == ISD::BUILD_PAIR)
17389 if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
17390 return CombineLD;
17391
17392 // int_vt (bitcast (vec_vt (scalar_to_vector elt_vt:x)))
17393 // => int_vt (any_extend elt_vt:x)
17394 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && VT.isScalarInteger()) {
17395 SDValue SrcScalar = N0.getOperand(0);
17396 if (SrcScalar.getValueType().isScalarInteger())
17397 return DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, SrcScalar);
17398 }
17399
17400 // Remove double bitcasts from shuffles - this is often a legacy of
17401 // XformToShuffleWithZero being used to combine bitmaskings (of
17402 // float vectors bitcast to integer vectors) into shuffles.
17403 // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
17404 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
17405 N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
17408 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
17409
17410 // If operands are a bitcast, peek through if it casts the original VT.
17411 // If operands are a constant, just bitcast back to original VT.
17412 auto PeekThroughBitcast = [&](SDValue Op) {
17413 if (Op.getOpcode() == ISD::BITCAST &&
17414 Op.getOperand(0).getValueType() == VT)
17415 return SDValue(Op.getOperand(0));
17416 if (Op.isUndef() || isAnyConstantBuildVector(Op))
17417 return DAG.getBitcast(VT, Op);
17418 return SDValue();
17419 };
17420
17421 // FIXME: If either input vector is bitcast, try to convert the shuffle to
17422 // the result type of this bitcast. This would eliminate at least one
17423 // bitcast. See the transform in InstCombine.
17424 SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
17425 SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
17426 if (!(SV0 && SV1))
17427 return SDValue();
17428
17429 int MaskScale =
17431 SmallVector<int, 8> NewMask;
17432 for (int M : SVN->getMask())
17433 for (int i = 0; i != MaskScale; ++i)
17434 NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);
17435
17436 SDValue LegalShuffle =
17437 TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
17438 if (LegalShuffle)
17439 return LegalShuffle;
17440 }
17441
17442 return SDValue();
17443}
17444
17445SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
17446 EVT VT = N->getValueType(0);
17447 return CombineConsecutiveLoads(N, VT);
17448}
17449
17450SDValue DAGCombiner::visitFREEZE(SDNode *N) {
17451 SDValue N0 = N->getOperand(0);
17452
17453 if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
17454 return N0;
17455
17456 // If we have frozen and unfrozen users of N0, update so everything uses N.
17457 if (!N0.isUndef() && !N0.hasOneUse()) {
17458 SDValue FrozenN0(N, 0);
17459 // Unfreeze all uses of N to avoid double deleting N from the CSE map.
17460 DAG.ReplaceAllUsesOfValueWith(FrozenN0, N0);
17461 DAG.ReplaceAllUsesOfValueWith(N0, FrozenN0);
17462 // ReplaceAllUsesOfValueWith will have also updated the use in N, thus
17463 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
17464 assert(N->getOperand(0) == FrozenN0 && "Expected cycle in DAG");
17465 DAG.UpdateNodeOperands(N, N0);
17466 return FrozenN0;
17467 }
17468
17469 // We currently avoid folding freeze over SRA/SRL, due to the problems seen
17470 // with (freeze (assert ext)) blocking simplifications of SRA/SRL. See for
17471 // example https://reviews.llvm.org/D136529#4120959.
17472 if (N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::SRL)
17473 return SDValue();
17474
17475 // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
17476 // Try to push freeze through instructions that propagate but don't produce
17477 // poison as far as possible. If an operand of freeze follows three
17478 // conditions 1) one-use, 2) does not produce poison, and 3) has all but one
17479 // guaranteed-non-poison operands (or is a BUILD_VECTOR or similar) then push
17480 // the freeze through to the operands that are not guaranteed non-poison.
17481 // NOTE: we will strip poison-generating flags, so ignore them here.
17482 if (DAG.canCreateUndefOrPoison(N0, /*PoisonOnly*/ false,
17483 /*ConsiderFlags*/ false) ||
17484 N0->getNumValues() != 1 || !N0->hasOneUse())
17485 return SDValue();
17486
17487 // TOOD: we should always allow multiple operands, however this increases the
17488 // likelihood of infinite loops due to the ReplaceAllUsesOfValueWith call
17489 // below causing later nodes that share frozen operands to fold again and no
17490 // longer being able to confirm other operands are not poison due to recursion
17491 // depth limits on isGuaranteedNotToBeUndefOrPoison.
17492 bool AllowMultipleMaybePoisonOperands =
17493 N0.getOpcode() == ISD::SELECT_CC || N0.getOpcode() == ISD::SETCC ||
17494 N0.getOpcode() == ISD::BUILD_VECTOR ||
17496 N0.getOpcode() == ISD::BUILD_PAIR ||
17499
17500 // Avoid turning a BUILD_VECTOR that can be recognized as "all zeros", "all
17501 // ones" or "constant" into something that depends on FrozenUndef. We can
17502 // instead pick undef values to keep those properties, while at the same time
17503 // folding away the freeze.
17504 // If we implement a more general solution for folding away freeze(undef) in
17505 // the future, then this special handling can be removed.
17506 if (N0.getOpcode() == ISD::BUILD_VECTOR) {
17507 SDLoc DL(N0);
17508 EVT VT = N0.getValueType();
17510 return DAG.getAllOnesConstant(DL, VT);
17513 for (const SDValue &Op : N0->op_values())
17514 NewVecC.push_back(
17515 Op.isUndef() ? DAG.getConstant(0, DL, Op.getValueType()) : Op);
17516 return DAG.getBuildVector(VT, DL, NewVecC);
17517 }
17518 }
17519
17520 SmallSet<SDValue, 8> MaybePoisonOperands;
17521 SmallVector<unsigned, 8> MaybePoisonOperandNumbers;
17522 for (auto [OpNo, Op] : enumerate(N0->ops())) {
17523 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly=*/false))
17524 continue;
17525 bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
17526 bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(Op).second;
17527 if (IsNewMaybePoisonOperand)
17528 MaybePoisonOperandNumbers.push_back(OpNo);
17529 if (!HadMaybePoisonOperands)
17530 continue;
17531 if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
17532 // Multiple maybe-poison ops when not allowed - bail out.
17533 return SDValue();
17534 }
17535 }
17536 // NOTE: the whole op may be not guaranteed to not be undef or poison because
17537 // it could create undef or poison due to it's poison-generating flags.
17538 // So not finding any maybe-poison operands is fine.
17539
17540 for (unsigned OpNo : MaybePoisonOperandNumbers) {
17541 // N0 can mutate during iteration, so make sure to refetch the maybe poison
17542 // operands via the operand numbers. The typical scenario is that we have
17543 // something like this
17544 // t262: i32 = freeze t181
17545 // t150: i32 = ctlz_zero_undef t262
17546 // t184: i32 = ctlz_zero_undef t181
17547 // t268: i32 = select_cc t181, Constant:i32<0>, t184, t186, setne:ch
17548 // When freezing the t181 operand we get t262 back, and then the
17549 // ReplaceAllUsesOfValueWith call will not only replace t181 by t262, but
17550 // also recursively replace t184 by t150.
17551 SDValue MaybePoisonOperand = N->getOperand(0).getOperand(OpNo);
17552 // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
17553 if (MaybePoisonOperand.isUndef())
17554 continue;
17555 // First, freeze each offending operand.
17556 SDValue FrozenMaybePoisonOperand = DAG.getFreeze(MaybePoisonOperand);
17557 // Then, change all other uses of unfrozen operand to use frozen operand.
17558 DAG.ReplaceAllUsesOfValueWith(MaybePoisonOperand, FrozenMaybePoisonOperand);
17559 if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
17560 FrozenMaybePoisonOperand.getOperand(0) == FrozenMaybePoisonOperand) {
17561 // But, that also updated the use in the freeze we just created, thus
17562 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
17563 DAG.UpdateNodeOperands(FrozenMaybePoisonOperand.getNode(),
17564 MaybePoisonOperand);
17565 }
17566
17567 // This node has been merged with another.
17568 if (N->getOpcode() == ISD::DELETED_NODE)
17569 return SDValue(N, 0);
17570 }
17571
17572 assert(N->getOpcode() != ISD::DELETED_NODE && "Node was deleted!");
17573
17574 // The whole node may have been updated, so the value we were holding
17575 // may no longer be valid. Re-fetch the operand we're `freeze`ing.
17576 N0 = N->getOperand(0);
17577
17578 // Finally, recreate the node, it's operands were updated to use
17579 // frozen operands, so we just need to use it's "original" operands.
17581 // TODO: ISD::UNDEF and ISD::POISON should get separate handling, but best
17582 // leave for a future patch.
17583 for (SDValue &Op : Ops) {
17584 if (Op.isUndef())
17585 Op = DAG.getFreeze(Op);
17586 }
17587
17588 SDLoc DL(N0);
17589
17590 // Special case handling for ShuffleVectorSDNode nodes.
17591 if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(N0))
17592 return DAG.getVectorShuffle(N0.getValueType(), DL, Ops[0], Ops[1],
17593 SVN->getMask());
17594
17595 // NOTE: this strips poison generating flags.
17596 // Folding freeze(op(x, ...)) -> op(freeze(x), ...) does not require nnan,
17597 // ninf, nsz, or fast.
17598 // However, contract, reassoc, afn, and arcp should be preserved,
17599 // as these fast-math flags do not introduce poison values.
17600 SDNodeFlags SrcFlags = N0->getFlags();
17601 SDNodeFlags SafeFlags;
17602 SafeFlags.setAllowContract(SrcFlags.hasAllowContract());
17603 SafeFlags.setAllowReassociation(SrcFlags.hasAllowReassociation());
17604 SafeFlags.setApproximateFuncs(SrcFlags.hasApproximateFuncs());
17605 SafeFlags.setAllowReciprocal(SrcFlags.hasAllowReciprocal());
17606 return DAG.getNode(N0.getOpcode(), DL, N0->getVTList(), Ops, SafeFlags);
17607}
17608
17609// Returns true if floating point contraction is allowed on the FMUL-SDValue
17610// `N`
17612 assert(N.getOpcode() == ISD::FMUL);
17613
17614 return Options.AllowFPOpFusion == FPOpFusion::Fast ||
17615 N->getFlags().hasAllowContract();
17616}
17617
17618/// Try to perform FMA combining on a given FADD node.
17619template <class MatchContextClass>
17620SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
17621 SDValue N0 = N->getOperand(0);
17622 SDValue N1 = N->getOperand(1);
17623 EVT VT = N->getValueType(0);
17624 SDLoc SL(N);
17625 MatchContextClass matcher(DAG, TLI, N);
17626 const TargetOptions &Options = DAG.getTarget().Options;
17627
17628 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
17629
17630 // Floating-point multiply-add with intermediate rounding.
17631 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
17632 // FIXME: Add VP_FMAD opcode.
17633 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
17634
17635 // Floating-point multiply-add without intermediate rounding.
17636 bool HasFMA =
17637 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17639
17640 // No valid opcode, do not combine.
17641 if (!HasFMAD && !HasFMA)
17642 return SDValue();
17643
17644 bool AllowFusionGlobally =
17645 Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
17646 // If the addition is not contractable, do not combine.
17647 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
17648 return SDValue();
17649
17650 // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
17651 // beneficial. It does not reduce latency. It increases register pressure. It
17652 // replaces an fadd with an fma which is a more complex instruction, so is
17653 // likely to have a larger encoding, use more functional units, etc.
17654 if (N0 == N1)
17655 return SDValue();
17656
17657 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
17658 return SDValue();
17659
17660 // Always prefer FMAD to FMA for precision.
17661 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17663
17664 auto isFusedOp = [&](SDValue N) {
17665 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
17666 };
17667
17668 // Is the node an FMUL and contractable either due to global flags or
17669 // SDNodeFlags.
17670 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
17671 if (!matcher.match(N, ISD::FMUL))
17672 return false;
17673 return AllowFusionGlobally || N->getFlags().hasAllowContract();
17674 };
17675 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
17676 // prefer to fold the multiply with fewer uses.
17678 if (N0->use_size() > N1->use_size())
17679 std::swap(N0, N1);
17680 }
17681
17682 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
17683 if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
17684 return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
17685 N0.getOperand(1), N1);
17686 }
17687
17688 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
17689 // Note: Commutes FADD operands.
17690 if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
17691 return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
17692 N1.getOperand(1), N0);
17693 }
17694
17695 // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
17696 // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
17697 // This also works with nested fma instructions:
17698 // fadd (fma A, B, (fma (C, D, (fmul (E, F))))), G -->
17699 // fma A, B, (fma C, D, fma (E, F, G))
17700 // fadd (G, (fma A, B, (fma (C, D, (fmul (E, F)))))) -->
17701 // fma A, B, (fma C, D, fma (E, F, G)).
17702 // This requires reassociation because it changes the order of operations.
17703 bool CanReassociate = N->getFlags().hasAllowReassociation();
17704 if (CanReassociate) {
17705 SDValue FMA, E;
17706 if (isFusedOp(N0) && N0.hasOneUse()) {
17707 FMA = N0;
17708 E = N1;
17709 } else if (isFusedOp(N1) && N1.hasOneUse()) {
17710 FMA = N1;
17711 E = N0;
17712 }
17713
17714 SDValue TmpFMA = FMA;
17715 while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
17716 SDValue FMul = TmpFMA->getOperand(2);
17717 if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
17718 SDValue C = FMul.getOperand(0);
17719 SDValue D = FMul.getOperand(1);
17720 SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
17722 // Replacing the inner FMul could cause the outer FMA to be simplified
17723 // away.
17724 return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
17725 }
17726
17727 TmpFMA = TmpFMA->getOperand(2);
17728 }
17729 }
17730
17731 // Look through FP_EXTEND nodes to do more combining.
17732
17733 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
17734 if (matcher.match(N0, ISD::FP_EXTEND)) {
17735 SDValue N00 = N0.getOperand(0);
17736 if (isContractableFMUL(N00) &&
17737 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17738 N00.getValueType())) {
17739 return matcher.getNode(
17740 PreferredFusedOpcode, SL, VT,
17741 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
17742 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1);
17743 }
17744 }
17745
17746 // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
17747 // Note: Commutes FADD operands.
17748 if (matcher.match(N1, ISD::FP_EXTEND)) {
17749 SDValue N10 = N1.getOperand(0);
17750 if (isContractableFMUL(N10) &&
17751 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17752 N10.getValueType())) {
17753 return matcher.getNode(
17754 PreferredFusedOpcode, SL, VT,
17755 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
17756 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
17757 }
17758 }
17759
17760 // More folding opportunities when target permits.
17761 if (Aggressive) {
17762 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
17763 // -> (fma x, y, (fma (fpext u), (fpext v), z))
17764 auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17765 SDValue Z) {
17766 return matcher.getNode(
17767 PreferredFusedOpcode, SL, VT, X, Y,
17768 matcher.getNode(PreferredFusedOpcode, SL, VT,
17769 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17770 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17771 };
17772 if (isFusedOp(N0)) {
17773 SDValue N02 = N0.getOperand(2);
17774 if (matcher.match(N02, ISD::FP_EXTEND)) {
17775 SDValue N020 = N02.getOperand(0);
17776 if (isContractableFMUL(N020) &&
17777 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17778 N020.getValueType())) {
17779 return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
17780 N020.getOperand(0), N020.getOperand(1),
17781 N1);
17782 }
17783 }
17784 }
17785
17786 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
17787 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
17788 // FIXME: This turns two single-precision and one double-precision
17789 // operation into two double-precision operations, which might not be
17790 // interesting for all targets, especially GPUs.
17791 auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17792 SDValue Z) {
17793 return matcher.getNode(
17794 PreferredFusedOpcode, SL, VT,
17795 matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
17796 matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
17797 matcher.getNode(PreferredFusedOpcode, SL, VT,
17798 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17799 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17800 };
17801 if (N0.getOpcode() == ISD::FP_EXTEND) {
17802 SDValue N00 = N0.getOperand(0);
17803 if (isFusedOp(N00)) {
17804 SDValue N002 = N00.getOperand(2);
17805 if (isContractableFMUL(N002) &&
17806 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17807 N00.getValueType())) {
17808 return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
17809 N002.getOperand(0), N002.getOperand(1),
17810 N1);
17811 }
17812 }
17813 }
17814
17815 // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
17816 // -> (fma y, z, (fma (fpext u), (fpext v), x))
17817 if (isFusedOp(N1)) {
17818 SDValue N12 = N1.getOperand(2);
17819 if (N12.getOpcode() == ISD::FP_EXTEND) {
17820 SDValue N120 = N12.getOperand(0);
17821 if (isContractableFMUL(N120) &&
17822 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17823 N120.getValueType())) {
17824 return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
17825 N120.getOperand(0), N120.getOperand(1),
17826 N0);
17827 }
17828 }
17829 }
17830
17831 // fold (fadd x, (fpext (fma y, z, (fmul u, v)))
17832 // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
17833 // FIXME: This turns two single-precision and one double-precision
17834 // operation into two double-precision operations, which might not be
17835 // interesting for all targets, especially GPUs.
17836 if (N1.getOpcode() == ISD::FP_EXTEND) {
17837 SDValue N10 = N1.getOperand(0);
17838 if (isFusedOp(N10)) {
17839 SDValue N102 = N10.getOperand(2);
17840 if (isContractableFMUL(N102) &&
17841 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17842 N10.getValueType())) {
17843 return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
17844 N102.getOperand(0), N102.getOperand(1),
17845 N0);
17846 }
17847 }
17848 }
17849 }
17850
17851 return SDValue();
17852}
17853
17854/// Try to perform FMA combining on a given FSUB node.
17855template <class MatchContextClass>
17856SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
17857 SDValue N0 = N->getOperand(0);
17858 SDValue N1 = N->getOperand(1);
17859 EVT VT = N->getValueType(0);
17860 SDLoc SL(N);
17861 MatchContextClass matcher(DAG, TLI, N);
17862 const TargetOptions &Options = DAG.getTarget().Options;
17863
17864 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
17865
17866 // Floating-point multiply-add with intermediate rounding.
17867 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
17868 // FIXME: Add VP_FMAD opcode.
17869 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
17870
17871 // Floating-point multiply-add without intermediate rounding.
17872 bool HasFMA =
17873 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17875
17876 // No valid opcode, do not combine.
17877 if (!HasFMAD && !HasFMA)
17878 return SDValue();
17879
17880 const SDNodeFlags Flags = N->getFlags();
17881 bool AllowFusionGlobally =
17882 (Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD);
17883
17884 // If the subtraction is not contractable, do not combine.
17885 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
17886 return SDValue();
17887
17888 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
17889 return SDValue();
17890
17891 // Always prefer FMAD to FMA for precision.
17892 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17894 bool NoSignedZero = Flags.hasNoSignedZeros();
17895
17896 // Is the node an FMUL and contractable either due to global flags or
17897 // SDNodeFlags.
17898 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
17899 if (!matcher.match(N, ISD::FMUL))
17900 return false;
17901 return AllowFusionGlobally || N->getFlags().hasAllowContract();
17902 };
17903
17904 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17905 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
17906 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
17907 return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
17908 XY.getOperand(1),
17909 matcher.getNode(ISD::FNEG, SL, VT, Z));
17910 }
17911 return SDValue();
17912 };
17913
17914 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17915 // Note: Commutes FSUB operands.
17916 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
17917 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
17918 return matcher.getNode(
17919 PreferredFusedOpcode, SL, VT,
17920 matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
17921 YZ.getOperand(1), X);
17922 }
17923 return SDValue();
17924 };
17925
17926 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
17927 // prefer to fold the multiply with fewer uses.
17928 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
17929 (N0->use_size() > N1->use_size())) {
17930 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
17931 if (SDValue V = tryToFoldXSubYZ(N0, N1))
17932 return V;
17933 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
17934 if (SDValue V = tryToFoldXYSubZ(N0, N1))
17935 return V;
17936 } else {
17937 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17938 if (SDValue V = tryToFoldXYSubZ(N0, N1))
17939 return V;
17940 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17941 if (SDValue V = tryToFoldXSubYZ(N0, N1))
17942 return V;
17943 }
17944
17945 // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
17946 if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) &&
17947 (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
17948 SDValue N00 = N0.getOperand(0).getOperand(0);
17949 SDValue N01 = N0.getOperand(0).getOperand(1);
17950 return matcher.getNode(PreferredFusedOpcode, SL, VT,
17951 matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
17952 matcher.getNode(ISD::FNEG, SL, VT, N1));
17953 }
17954
17955 // Look through FP_EXTEND nodes to do more combining.
17956
17957 // fold (fsub (fpext (fmul x, y)), z)
17958 // -> (fma (fpext x), (fpext y), (fneg z))
17959 if (matcher.match(N0, ISD::FP_EXTEND)) {
17960 SDValue N00 = N0.getOperand(0);
17961 if (isContractableFMUL(N00) &&
17962 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17963 N00.getValueType())) {
17964 return matcher.getNode(
17965 PreferredFusedOpcode, SL, VT,
17966 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
17967 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
17968 matcher.getNode(ISD::FNEG, SL, VT, N1));
17969 }
17970 }
17971
17972 // fold (fsub x, (fpext (fmul y, z)))
17973 // -> (fma (fneg (fpext y)), (fpext z), x)
17974 // Note: Commutes FSUB operands.
17975 if (matcher.match(N1, ISD::FP_EXTEND)) {
17976 SDValue N10 = N1.getOperand(0);
17977 if (isContractableFMUL(N10) &&
17978 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17979 N10.getValueType())) {
17980 return matcher.getNode(
17981 PreferredFusedOpcode, SL, VT,
17982 matcher.getNode(
17983 ISD::FNEG, SL, VT,
17984 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
17985 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
17986 }
17987 }
17988
17989 // fold (fsub (fpext (fneg (fmul, x, y))), z)
17990 // -> (fneg (fma (fpext x), (fpext y), z))
17991 // Note: This could be removed with appropriate canonicalization of the
17992 // input expression into (fneg (fadd (fpext (fmul, x, y)), z)). However, the
17993 // command line flag -fp-contract=fast and fast-math flag contract prevent
17994 // from implementing the canonicalization in visitFSUB.
17995 if (matcher.match(N0, ISD::FP_EXTEND)) {
17996 SDValue N00 = N0.getOperand(0);
17997 if (matcher.match(N00, ISD::FNEG)) {
17998 SDValue N000 = N00.getOperand(0);
17999 if (isContractableFMUL(N000) &&
18000 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
18001 N00.getValueType())) {
18002 return matcher.getNode(
18003 ISD::FNEG, SL, VT,
18004 matcher.getNode(
18005 PreferredFusedOpcode, SL, VT,
18006 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
18007 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
18008 N1));
18009 }
18010 }
18011 }
18012
18013 // fold (fsub (fneg (fpext (fmul, x, y))), z)
18014 // -> (fneg (fma (fpext x)), (fpext y), z)
18015 // Note: This could be removed with appropriate canonicalization of the
18016 // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
18017 // command line flag -fp-contract=fast and fast-math flag contract prevent
18018 // from implementing the canonicalization in visitFSUB.
18019 if (matcher.match(N0, ISD::FNEG)) {
18020 SDValue N00 = N0.getOperand(0);
18021 if (matcher.match(N00, ISD::FP_EXTEND)) {
18022 SDValue N000 = N00.getOperand(0);
18023 if (isContractableFMUL(N000) &&
18024 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
18025 N000.getValueType())) {
18026 return matcher.getNode(
18027 ISD::FNEG, SL, VT,
18028 matcher.getNode(
18029 PreferredFusedOpcode, SL, VT,
18030 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
18031 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
18032 N1));
18033 }
18034 }
18035 }
18036
18037 auto isContractableAndReassociableFMUL = [&isContractableFMUL](SDValue N) {
18038 return isContractableFMUL(N) && N->getFlags().hasAllowReassociation();
18039 };
18040
18041 auto isFusedOp = [&](SDValue N) {
18042 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
18043 };
18044
18045 // More folding opportunities when target permits.
18046 if (Aggressive && N->getFlags().hasAllowReassociation()) {
18047 bool CanFuse = N->getFlags().hasAllowContract();
18048 // fold (fsub (fma x, y, (fmul u, v)), z)
18049 // -> (fma x, y (fma u, v, (fneg z)))
18050 if (CanFuse && isFusedOp(N0) &&
18051 isContractableAndReassociableFMUL(N0.getOperand(2)) &&
18052 N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
18053 return matcher.getNode(
18054 PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
18055 matcher.getNode(PreferredFusedOpcode, SL, VT,
18056 N0.getOperand(2).getOperand(0),
18057 N0.getOperand(2).getOperand(1),
18058 matcher.getNode(ISD::FNEG, SL, VT, N1)));
18059 }
18060
18061 // fold (fsub x, (fma y, z, (fmul u, v)))
18062 // -> (fma (fneg y), z, (fma (fneg u), v, x))
18063 if (CanFuse && isFusedOp(N1) &&
18064 isContractableAndReassociableFMUL(N1.getOperand(2)) &&
18065 N1->hasOneUse() && NoSignedZero) {
18066 SDValue N20 = N1.getOperand(2).getOperand(0);
18067 SDValue N21 = N1.getOperand(2).getOperand(1);
18068 return matcher.getNode(
18069 PreferredFusedOpcode, SL, VT,
18070 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
18071 N1.getOperand(1),
18072 matcher.getNode(PreferredFusedOpcode, SL, VT,
18073 matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
18074 }
18075
18076 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
18077 // -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
18078 if (isFusedOp(N0) && N0->hasOneUse()) {
18079 SDValue N02 = N0.getOperand(2);
18080 if (matcher.match(N02, ISD::FP_EXTEND)) {
18081 SDValue N020 = N02.getOperand(0);
18082 if (isContractableAndReassociableFMUL(N020) &&
18083 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
18084 N020.getValueType())) {
18085 return matcher.getNode(
18086 PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
18087 matcher.getNode(
18088 PreferredFusedOpcode, SL, VT,
18089 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
18090 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
18091 matcher.getNode(ISD::FNEG, SL, VT, N1)));
18092 }
18093 }
18094 }
18095
18096 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
18097 // -> (fma (fpext x), (fpext y),
18098 // (fma (fpext u), (fpext v), (fneg z)))
18099 // FIXME: This turns two single-precision and one double-precision
18100 // operation into two double-precision operations, which might not be
18101 // interesting for all targets, especially GPUs.
18102 if (matcher.match(N0, ISD::FP_EXTEND)) {
18103 SDValue N00 = N0.getOperand(0);
18104 if (isFusedOp(N00)) {
18105 SDValue N002 = N00.getOperand(2);
18106 if (isContractableAndReassociableFMUL(N002) &&
18107 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
18108 N00.getValueType())) {
18109 return matcher.getNode(
18110 PreferredFusedOpcode, SL, VT,
18111 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
18112 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
18113 matcher.getNode(
18114 PreferredFusedOpcode, SL, VT,
18115 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
18116 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
18117 matcher.getNode(ISD::FNEG, SL, VT, N1)));
18118 }
18119 }
18120 }
18121
18122 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
18123 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
18124 if (isFusedOp(N1) && matcher.match(N1.getOperand(2), ISD::FP_EXTEND) &&
18125 N1->hasOneUse()) {
18126 SDValue N120 = N1.getOperand(2).getOperand(0);
18127 if (isContractableAndReassociableFMUL(N120) &&
18128 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
18129 N120.getValueType())) {
18130 SDValue N1200 = N120.getOperand(0);
18131 SDValue N1201 = N120.getOperand(1);
18132 return matcher.getNode(
18133 PreferredFusedOpcode, SL, VT,
18134 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
18135 N1.getOperand(1),
18136 matcher.getNode(
18137 PreferredFusedOpcode, SL, VT,
18138 matcher.getNode(ISD::FNEG, SL, VT,
18139 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
18140 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
18141 }
18142 }
18143
18144 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
18145 // -> (fma (fneg (fpext y)), (fpext z),
18146 // (fma (fneg (fpext u)), (fpext v), x))
18147 // FIXME: This turns two single-precision and one double-precision
18148 // operation into two double-precision operations, which might not be
18149 // interesting for all targets, especially GPUs.
18150 if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) {
18151 SDValue CvtSrc = N1.getOperand(0);
18152 SDValue N100 = CvtSrc.getOperand(0);
18153 SDValue N101 = CvtSrc.getOperand(1);
18154 SDValue N102 = CvtSrc.getOperand(2);
18155 if (isContractableAndReassociableFMUL(N102) &&
18156 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
18157 CvtSrc.getValueType())) {
18158 SDValue N1020 = N102.getOperand(0);
18159 SDValue N1021 = N102.getOperand(1);
18160 return matcher.getNode(
18161 PreferredFusedOpcode, SL, VT,
18162 matcher.getNode(ISD::FNEG, SL, VT,
18163 matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
18164 matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
18165 matcher.getNode(
18166 PreferredFusedOpcode, SL, VT,
18167 matcher.getNode(ISD::FNEG, SL, VT,
18168 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
18169 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
18170 }
18171 }
18172 }
18173
18174 return SDValue();
18175}
18176
18177/// Try to perform FMA combining on a given FMUL node based on the distributive
18178/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
18179/// subtraction instead of addition).
18180SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
18181 SDValue N0 = N->getOperand(0);
18182 SDValue N1 = N->getOperand(1);
18183 EVT VT = N->getValueType(0);
18184 SDLoc SL(N);
18185
18186 assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
18187
18188 const TargetOptions &Options = DAG.getTarget().Options;
18189
18190 // The transforms below are incorrect when x == 0 and y == inf, because the
18191 // intermediate multiplication produces a nan.
18192 SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
18193 if (!FAdd->getFlags().hasNoInfs())
18194 return SDValue();
18195
18196 // Floating-point multiply-add without intermediate rounding.
18197 bool HasFMA =
18199 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)) &&
18201
18202 // Floating-point multiply-add with intermediate rounding. This can result
18203 // in a less precise result due to the changed rounding order.
18204 bool HasFMAD = LegalOperations && TLI.isFMADLegal(DAG, N);
18205
18206 // No valid opcode, do not combine.
18207 if (!HasFMAD && !HasFMA)
18208 return SDValue();
18209
18210 // Always prefer FMAD to FMA for precision.
18211 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
18213
18214 // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
18215 // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
18216 auto FuseFADD = [&](SDValue X, SDValue Y) {
18217 if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
18218 if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
18219 if (C->isExactlyValue(+1.0))
18220 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
18221 Y);
18222 if (C->isExactlyValue(-1.0))
18223 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
18224 DAG.getNode(ISD::FNEG, SL, VT, Y));
18225 }
18226 }
18227 return SDValue();
18228 };
18229
18230 if (SDValue FMA = FuseFADD(N0, N1))
18231 return FMA;
18232 if (SDValue FMA = FuseFADD(N1, N0))
18233 return FMA;
18234
18235 // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
18236 // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
18237 // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
18238 // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
18239 auto FuseFSUB = [&](SDValue X, SDValue Y) {
18240 if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
18241 if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
18242 if (C0->isExactlyValue(+1.0))
18243 return DAG.getNode(PreferredFusedOpcode, SL, VT,
18244 DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
18245 Y);
18246 if (C0->isExactlyValue(-1.0))
18247 return DAG.getNode(PreferredFusedOpcode, SL, VT,
18248 DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
18249 DAG.getNode(ISD::FNEG, SL, VT, Y));
18250 }
18251 if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
18252 if (C1->isExactlyValue(+1.0))
18253 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
18254 DAG.getNode(ISD::FNEG, SL, VT, Y));
18255 if (C1->isExactlyValue(-1.0))
18256 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
18257 Y);
18258 }
18259 }
18260 return SDValue();
18261 };
18262
18263 if (SDValue FMA = FuseFSUB(N0, N1))
18264 return FMA;
18265 if (SDValue FMA = FuseFSUB(N1, N0))
18266 return FMA;
18267
18268 return SDValue();
18269}
18270
18271SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
18272 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18273
18274 // FADD -> FMA combines:
18275 if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
18276 if (Fused.getOpcode() != ISD::DELETED_NODE)
18277 AddToWorklist(Fused.getNode());
18278 return Fused;
18279 }
18280 return SDValue();
18281}
18282
18283SDValue DAGCombiner::visitFADD(SDNode *N) {
18284 SDValue N0 = N->getOperand(0);
18285 SDValue N1 = N->getOperand(1);
18286 bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
18287 bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
18288 EVT VT = N->getValueType(0);
18289 SDLoc DL(N);
18290 SDNodeFlags Flags = N->getFlags();
18291 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18292
18293 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
18294 return R;
18295
18296 // fold (fadd c1, c2) -> c1 + c2
18297 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
18298 return C;
18299
18300 // canonicalize constant to RHS
18301 if (N0CFP && !N1CFP)
18302 return DAG.getNode(ISD::FADD, DL, VT, N1, N0);
18303
18304 // fold vector ops
18305 if (VT.isVector())
18306 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18307 return FoldedVOp;
18308
18309 // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
18310 ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
18311 if (N1C && N1C->isZero())
18312 if (N1C->isNegative() || DAG.canIgnoreSignBitOfZero(SDValue(N, 0)))
18313 return N0;
18314
18315 if (SDValue NewSel = foldBinOpIntoSelect(N))
18316 return NewSel;
18317
18318 // fold (fadd A, (fneg B)) -> (fsub A, B)
18319 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
18320 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
18321 N1, DAG, LegalOperations, ForCodeSize))
18322 return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);
18323
18324 // fold (fadd (fneg A), B) -> (fsub B, A)
18325 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
18326 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
18327 N0, DAG, LegalOperations, ForCodeSize))
18328 return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);
18329
18330 auto isFMulNegTwo = [](SDValue FMul) {
18331 if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
18332 return false;
18333 auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
18334 return C && C->isExactlyValue(-2.0);
18335 };
18336
18337 // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
18338 if (isFMulNegTwo(N0)) {
18339 SDValue B = N0.getOperand(0);
18340 SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
18341 return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
18342 }
18343 // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
18344 if (isFMulNegTwo(N1)) {
18345 SDValue B = N1.getOperand(0);
18346 SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
18347 return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
18348 }
18349
18350 // No FP constant should be created after legalization as Instruction
18351 // Selection pass has a hard time dealing with FP constants.
18352 bool AllowNewConst = (Level < AfterLegalizeDAG);
18353
18354 // If nnan is enabled, fold lots of things.
18355 if (Flags.hasNoNaNs() && AllowNewConst) {
18356 // If allowed, fold (fadd (fneg x), x) -> 0.0
18357 if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
18358 return DAG.getConstantFP(0.0, DL, VT);
18359
18360 // If allowed, fold (fadd x, (fneg x)) -> 0.0
18361 if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
18362 return DAG.getConstantFP(0.0, DL, VT);
18363 }
18364
18365 // If reassoc and nsz, fold lots of things.
18366 // TODO: break out portions of the transformations below for which Unsafe is
18367 // considered and which do not require both nsz and reassoc
18368 if (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros() &&
18369 AllowNewConst) {
18370 // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
18371 if (N1CFP && N0.getOpcode() == ISD::FADD &&
18373 SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
18374 return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
18375 }
18376
18377 // We can fold chains of FADD's of the same value into multiplications.
18378 // This transform is not safe in general because we are reducing the number
18379 // of rounding steps.
18380 if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
18381 if (N0.getOpcode() == ISD::FMUL) {
18382 bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
18383 bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
18384
18385 // (fadd (fmul x, c), x) -> (fmul x, c+1)
18386 if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
18387 SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
18388 DAG.getConstantFP(1.0, DL, VT));
18389 return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
18390 }
18391
18392 // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
18393 if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
18394 N1.getOperand(0) == N1.getOperand(1) &&
18395 N0.getOperand(0) == N1.getOperand(0)) {
18396 SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
18397 DAG.getConstantFP(2.0, DL, VT));
18398 return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
18399 }
18400 }
18401
18402 if (N1.getOpcode() == ISD::FMUL) {
18403 bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
18404 bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
18405
18406 // (fadd x, (fmul x, c)) -> (fmul x, c+1)
18407 if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
18408 SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
18409 DAG.getConstantFP(1.0, DL, VT));
18410 return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
18411 }
18412
18413 // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
18414 if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
18415 N0.getOperand(0) == N0.getOperand(1) &&
18416 N1.getOperand(0) == N0.getOperand(0)) {
18417 SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
18418 DAG.getConstantFP(2.0, DL, VT));
18419 return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
18420 }
18421 }
18422
18423 if (N0.getOpcode() == ISD::FADD) {
18424 bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
18425 // (fadd (fadd x, x), x) -> (fmul x, 3.0)
18426 if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
18427 (N0.getOperand(0) == N1)) {
18428 return DAG.getNode(ISD::FMUL, DL, VT, N1,
18429 DAG.getConstantFP(3.0, DL, VT));
18430 }
18431 }
18432
18433 if (N1.getOpcode() == ISD::FADD) {
18434 bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
18435 // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
18436 if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
18437 N1.getOperand(0) == N0) {
18438 return DAG.getNode(ISD::FMUL, DL, VT, N0,
18439 DAG.getConstantFP(3.0, DL, VT));
18440 }
18441 }
18442
18443 // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
18444 if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
18445 N0.getOperand(0) == N0.getOperand(1) &&
18446 N1.getOperand(0) == N1.getOperand(1) &&
18447 N0.getOperand(0) == N1.getOperand(0)) {
18448 return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
18449 DAG.getConstantFP(4.0, DL, VT));
18450 }
18451 }
18452 } // reassoc && nsz && AllowNewConst
18453
18454 if (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()) {
18455 // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
18456 if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
18457 VT, N0, N1, Flags))
18458 return SD;
18459 }
18460
18461 // FADD -> FMA combines:
18462 if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
18463 if (Fused.getOpcode() != ISD::DELETED_NODE)
18464 AddToWorklist(Fused.getNode());
18465 return Fused;
18466 }
18467 return SDValue();
18468}
18469
18470SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
18471 SDValue Chain = N->getOperand(0);
18472 SDValue N0 = N->getOperand(1);
18473 SDValue N1 = N->getOperand(2);
18474 EVT VT = N->getValueType(0);
18475 EVT ChainVT = N->getValueType(1);
18476 SDLoc DL(N);
18477 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18478
18479 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
18480 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
18481 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
18482 N1, DAG, LegalOperations, ForCodeSize)) {
18483 return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
18484 {Chain, N0, NegN1});
18485 }
18486
18487 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
18488 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
18489 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
18490 N0, DAG, LegalOperations, ForCodeSize)) {
18491 return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
18492 {Chain, N1, NegN0});
18493 }
18494 return SDValue();
18495}
18496
/// Visit an FSUB node: constant-fold, apply flag-gated algebraic identities
/// (in a deliberate order — earlier folds take precedence), and finally try
/// to form fused multiply-add nodes.
SDValue DAGCombiner::visitFSUB(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  // Splat-aware FP constant matchers (second arg allows undef lanes).
  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const SDNodeFlags Flags = N->getFlags();
  // Newly created nodes inherit N's fast-math flags for the rest of this
  // function.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fsub c1, c2) -> c1-c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // (fsub A, 0) -> A
  // Subtracting +0.0 is always safe; subtracting -0.0 could flip the sign of
  // a zero result, so that case needs the sign-of-zero to be ignorable.
  if (N1CFP && N1CFP->isZero()) {
    if (!N1CFP->isNegative() || DAG.canIgnoreSignBitOfZero(SDValue(N, 0))) {
      return N0;
    }
  }

  if (N0 == N1) {
    // (fsub x, x) -> 0.0
    // Only valid with nnan: NaN - NaN is NaN, not 0.
    if (Flags.hasNoNaNs())
      return DAG.getConstantFP(0.0f, DL, VT);
  }

  // (fsub -0.0, N1) -> -N1
  if (N0CFP && N0CFP->isZero()) {
    if (N0CFP->isNegative() || DAG.canIgnoreSignBitOfZero(SDValue(N, 0))) {
      // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
      // flushed to zero, unless all users treat denorms as zero (DAZ).
      // FIXME: This transform will change the sign of a NaN and the behavior
      // of a signaling NaN. It is only valid when a NoNaN flag is present.
      DenormalMode DenormMode = DAG.getDenormalMode(VT);
      if (DenormMode == DenormalMode::getIEEE()) {
        // Prefer a cheaper negated form of N1 if one exists; otherwise fall
        // back to an explicit FNEG when that is legal (or we are pre-legalize).
        if (SDValue NegN1 =
                TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
          return NegN1;
        if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
          return DAG.getNode(ISD::FNEG, DL, VT, N1);
      }
    }
  }

  // Reassociation folds: both reassoc and nsz are required because the
  // results below can differ in the sign of zero from the original.
  if (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros() &&
      N1.getOpcode() == ISD::FADD) {
    // X - (X + Y) -> -Y
    if (N0 == N1->getOperand(0))
      return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
    // X - (Y + X) -> -Y
    if (N0 == N1->getOperand(1))
      return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
  }

  // fold (fsub A, (fneg B)) -> (fadd A, B)
  if (SDValue NegN1 =
          TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
    return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);

  // FSUB -> FMA combines:
  if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  return SDValue();
}
18576
18577// Transform IEEE Floats:
18578// (fmul C, (uitofp Pow2))
18579// -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
18580// (fdiv C, (uitofp Pow2))
18581// -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
18582//
18583// The rationale is that fmul/fdiv by a power of 2 just changes the exponent, so
18584// there is no need for more than an add/sub.
18585//
18586// This is valid under the following circumstances:
18587// 1) We are dealing with IEEE floats
18588// 2) C is normal
18589// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
18590// TODO: Much of this could also be used for generating `ldexp` on targets that
18591// prefer it.
18592SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
18593 EVT VT = N->getValueType(0);
 // NOTE(review): the guard condition for this early return is missing from
 // this excerpt - verify against the upstream source.
18595 return SDValue();
18596
18597 SDValue ConstOp, Pow2Op;
18598
18599 std::optional<int> Mantissa;
 // Identify which operand is the FP constant and which is the integer
 // power-of-2 fed through a [us]itofp; returns true on a successful match.
18600 auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
 // For FDIV only the numerator may be the constant operand.
18601 if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
18602 return false;
18603
18604 ConstOp = peekThroughBitcasts(N->getOperand(ConstOpIdx));
18605 Pow2Op = N->getOperand(1 - ConstOpIdx);
 // A sint_to_fp source is acceptable only when it is provably non-negative.
18606 if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
18607 (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
18608 !DAG.computeKnownBits(Pow2Op).isNonNegative()))
18609 return false;
18610
18611 Pow2Op = Pow2Op.getOperand(0);
18612
18613 // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
18614 // TODO: We could use knownbits to make this bound more precise.
18615 int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
18616
18617 auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
18618 if (CFP == nullptr)
18619 return false;
18620
18621 const APFloat &APF = CFP->getValueAPF();
18622
18623 // Make sure we have a normal constant.
18624 if (!APF.isNormal())
18625 return false;
18626
18627 // Make sure the float's exponent is within the bounds where this
18628 // transform produces a bitwise-equal value.
18629 int CurExp = ilogb(APF);
18630 // FMul by pow2 will only increase exponent.
18631 int MinExp =
18632 N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
18633 // FDiv by pow2 will only decrease exponent.
18634 int MaxExp =
18635 N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
18636 if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
 // NOTE(review): the second half of this bounds check (presumably testing
 // MaxExp against the max exponent) is missing from this excerpt.
18638 return false;
18639
18640 // Finally make sure we actually know the mantissa for the float type.
18641 int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
18642 if (!Mantissa)
18643 Mantissa = ThisMantissa;
18644
 // Every constant element must share one float semantics/mantissa width.
18645 return *Mantissa == ThisMantissa && ThisMantissa > 0;
18646 };
18647
18648 // TODO: We may be able to include undefs.
18649 return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid);
18650 };
18651
18652 if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
18653 return SDValue();
18654
 // Let the target veto the shift+add+bitcast expansion.
18655 if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, ConstOp, Pow2Op))
18656 return SDValue();
18657
18658 // Get log2 after all other checks have taken place. This is because
18659 // BuildLogBase2 may create a new node.
18660 SDLoc DL(N);
18661 // Get Log2 type with same bitwidth as the float type (VT).
18662 EVT NewIntVT = VT.changeElementType(
18663 *DAG.getContext(),
 // NOTE(review): the argument computing the integer element type is missing
 // from this excerpt.
18665
18666 SDValue Log2 = BuildLogBase2(Pow2Op, DL, DAG.isKnownNeverZero(Pow2Op),
18667 /*InexpensiveOnly*/ true, NewIntVT);
18668 if (!Log2)
18669 return SDValue();
18670
18671 // Perform actual transform.
18672 SDValue MantissaShiftCnt =
18673 DAG.getShiftAmountConstant(*Mantissa, NewIntVT, DL);
18674 // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
18675 // `(X << C1) + (C << C1)`, but that isn't always the case because of the
18676 // cast. We could implement a handler here to deal with the casts.
18677 SDValue Shift = DAG.getNode(ISD::SHL, DL, NewIntVT, Log2, MantissaShiftCnt);
 // Adding/subtracting Log2(Pow2) at the mantissa-width bit position adjusts
 // the IEEE exponent field directly on the integer bit pattern.
18678 SDValue ResAsInt =
18679 DAG.getNode(N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
18680 NewIntVT, DAG.getBitcast(NewIntVT, ConstOp), Shift);
18681 SDValue ResAsFP = DAG.getBitcast(VT, ResAsInt);
18682 return ResAsFP;
18683}
18684
/// Combine an FMUL node: constant folding, canonicalization, and a series of
/// fast-math-flag-gated algebraic folds, followed by FMA formation.
18685SDValue DAGCombiner::visitFMUL(SDNode *N) {
18686 SDValue N0 = N->getOperand(0);
18687 SDValue N1 = N->getOperand(1);
18688 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
18689 EVT VT = N->getValueType(0);
18690 SDLoc DL(N);
18691 const SDNodeFlags Flags = N->getFlags();
18692 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18693
18694 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
18695 return R;
18696
18697 // fold (fmul c1, c2) -> c1*c2
18698 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
18699 return C;
18700
18701 // canonicalize constant to RHS
 // NOTE(review): the guard condition for this canonicalization is missing
 // from this excerpt - verify against the upstream source.
18704 return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);
18705
18706 // fold vector ops
18707 if (VT.isVector())
18708 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18709 return FoldedVOp;
18710
18711 if (SDValue NewSel = foldBinOpIntoSelect(N))
18712 return NewSel;
18713
 // Folds that are only legal under reassociation.
18714 if (Flags.hasAllowReassociation()) {
18715 // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
 // NOTE(review): the first part of this condition is missing from this
 // excerpt.
18717 N0.getOpcode() == ISD::FMUL) {
18718 SDValue N00 = N0.getOperand(0);
18719 SDValue N01 = N0.getOperand(1);
18720 // Avoid an infinite loop by making sure that N00 is not a constant
18721 // (the inner multiply has not been constant folded yet).
 // NOTE(review): the condition lines are missing from this excerpt.
18724 SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
18725 return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
18726 }
18727 }
18728
18729 // Match a special-case: we convert X * 2.0 into fadd.
18730 // fmul (fadd X, X), C -> fmul X, 2.0 * C
18731 if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
18732 N0.getOperand(0) == N0.getOperand(1)) {
18733 const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
18734 SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
18735 return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
18736 }
18737
18738 // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
18739 if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL,
18740 VT, N0, N1, Flags))
18741 return SD;
18742 }
18743
18744 // fold (fmul X, 2.0) -> (fadd X, X)
18745 if (N1CFP && N1CFP->isExactlyValue(+2.0))
18746 return DAG.getNode(ISD::FADD, DL, VT, N0, N0);
18747
18748 // fold (fmul X, -1.0) -> (fsub -0.0, X)
18749 if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
18750 if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
18751 return DAG.getNode(ISD::FSUB, DL, VT,
18752 DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
18753 }
18754 }
18755
18756 // -N0 * -N1 --> N0 * N1
 // NOTE(review): the cost-tracking declarations (CostN0/CostN1) appear to be
 // missing from this excerpt.
18761 SDValue NegN0 =
18762 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
18763 if (NegN0) {
 // Keep the speculatively-created negation alive while negating N1.
18764 HandleSDNode NegN0Handle(NegN0);
18765 SDValue NegN1 =
18766 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
18767 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
 // NOTE(review): the second half of this condition is missing from this
 // excerpt.
18769 return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);
18770 }
18771
18772 // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
18773 // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
18774 if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
18775 (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
18776 TLI.isOperationLegal(ISD::FABS, VT)) {
18777 SDValue Select = N0, X = N1;
18778 if (Select.getOpcode() != ISD::SELECT)
18779 std::swap(Select, X);
18780
18781 SDValue Cond = Select.getOperand(0);
18782 auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
18783 auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
18784
18785 if (TrueOpnd && FalseOpnd &&
18786 Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
18787 isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
18788 cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
18789 ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
18790 switch (CC) {
18791 default: break;
 // For "less than" compares, swap the select arms so the logic below can
 // treat them like the "greater than" cases.
18792 case ISD::SETOLT:
18793 case ISD::SETULT:
18794 case ISD::SETOLE:
18795 case ISD::SETULE:
18796 case ISD::SETLT:
18797 case ISD::SETLE:
18798 std::swap(TrueOpnd, FalseOpnd);
18799 [[fallthrough]];
18800 case ISD::SETOGT:
18801 case ISD::SETUGT:
18802 case ISD::SETOGE:
18803 case ISD::SETUGE:
18804 case ISD::SETGT:
18805 case ISD::SETGE:
18806 if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
18807 TLI.isOperationLegal(ISD::FNEG, VT))
18808 return DAG.getNode(ISD::FNEG, DL, VT,
18809 DAG.getNode(ISD::FABS, DL, VT, X));
18810 if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
18811 return DAG.getNode(ISD::FABS, DL, VT, X);
18812
18813 break;
18814 }
18815 }
18816 }
18817
18818 // FMUL -> FMA combines:
18819 if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
18820 AddToWorklist(Fused.getNode());
18821 return Fused;
18822 }
18823
18824 // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
18825 // able to run.
18826 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
18827 return R;
18828
18829 return SDValue();
18830}
18831
/// Combine an FMA node (templated over the match context so the same logic
/// serves plain and VP/masked FMA forms): constant folding, sign-flip
/// elimination, and flag-gated reassociation folds.
18832template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
18833 SDValue N0 = N->getOperand(0);
18834 SDValue N1 = N->getOperand(1);
18835 SDValue N2 = N->getOperand(2);
18836 ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
18837 ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
18838 ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2);
18839 EVT VT = N->getValueType(0);
18840 SDLoc DL(N);
18841 // FMA nodes have flags that propagate to the created nodes.
18842 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18843 MatchContextClass matcher(DAG, TLI, N);
18844
18845 // Constant fold FMA.
18846 if (SDValue C =
18847 DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1, N2}))
18848 return C;
18849
18850 // (-N0 * -N1) + N2 --> (N0 * N1) + N2
 // NOTE(review): the cost-tracking declarations (CostN0/CostN1) appear to be
 // missing from this excerpt.
18855 SDValue NegN0 =
18856 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
18857 if (NegN0) {
 // Keep the speculatively-created negation alive while negating N1.
18858 HandleSDNode NegN0Handle(NegN0);
18859 SDValue NegN1 =
18860 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
18861 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
 // NOTE(review): the second half of this condition is missing from this
 // excerpt.
18863 return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
18864 }
18865
 // With no NaNs/Infs, fma(0, x, y) / fma(x, 0, y) reduce to y, provided
 // signed zeros cannot be observed (or N2 is a constant other than -0.0).
18866 if (N->getFlags().hasNoNaNs() && N->getFlags().hasNoInfs()) {
18867 if (N->getFlags().hasNoSignedZeros() ||
18868 (N2CFP && !N2CFP->isExactlyValue(-0.0))) {
18869 if (N0CFP && N0CFP->isZero())
18870 return N2;
18871 if (N1CFP && N1CFP->isZero())
18872 return N2;
18873 }
18874 }
18875
18876 // FIXME: Support splat of constant.
18877 if (N0CFP && N0CFP->isExactlyValue(1.0))
18878 return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
18879 if (N1CFP && N1CFP->isExactlyValue(1.0))
18880 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18881
18882 // Canonicalize (fma c, x, y) -> (fma x, c, y)
 // NOTE(review): the guard condition for this canonicalization is missing
 // from this excerpt.
18885 return matcher.getNode(ISD::FMA, DL, VT, N1, N0, N2);
18886
18887 bool CanReassociate = N->getFlags().hasAllowReassociation();
18888 if (CanReassociate) {
18889 // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
18890 if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) &&
 // NOTE(review): part of this condition is missing from this excerpt.
18893 return matcher.getNode(
18894 ISD::FMUL, DL, VT, N0,
18895 matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
18896 }
18897
18898 // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
18899 if (matcher.match(N0, ISD::FMUL) &&
 // NOTE(review): part of this condition is missing from this excerpt.
18902 return matcher.getNode(
18903 ISD::FMA, DL, VT, N0.getOperand(0),
18904 matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2);
18905 }
18906 }
18907
18908 // (fma x, -1, y) -> (fadd (fneg x), y)
18909 // FIXME: Support splat of constant.
18910 if (N1CFP) {
18911 if (N1CFP->isExactlyValue(1.0))
18912 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18913
18914 if (N1CFP->isExactlyValue(-1.0) &&
18915 (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
18916 SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
18917 AddToWorklist(RHSNeg.getNode());
18918 return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
18919 }
18920
18921 // fma (fneg x), K, y -> fma x -K, y
18922 if (matcher.match(N0, ISD::FNEG) &&
 // NOTE(review): part of this condition is missing from this excerpt.
18924 (N1.hasOneUse() &&
18925 !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
18926 return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
18927 matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
18928 }
18929 }
18930
18931 // FIXME: Support splat of constant.
18932 if (CanReassociate) {
18933 // (fma x, c, x) -> (fmul x, (c+1))
18934 if (N1CFP && N0 == N2) {
18935 return matcher.getNode(ISD::FMUL, DL, VT, N0,
18936 matcher.getNode(ISD::FADD, DL, VT, N1,
18937 DAG.getConstantFP(1.0, DL, VT)));
18938 }
18939
18940 // (fma x, c, (fneg x)) -> (fmul x, (c-1))
18941 if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) {
18942 return matcher.getNode(ISD::FMUL, DL, VT, N0,
18943 matcher.getNode(ISD::FADD, DL, VT, N1,
18944 DAG.getConstantFP(-1.0, DL, VT)));
18945 }
18946 }
18947
18948 // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
18949 // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
18950 if (!TLI.isFNegFree(VT))
 // NOTE(review): the start of this statement is missing from this excerpt.
18952 SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
18953 return matcher.getNode(ISD::FNEG, DL, VT, Neg);
18954 return SDValue();
18955}
18956
18957SDValue DAGCombiner::visitFMAD(SDNode *N) {
18958 SDValue N0 = N->getOperand(0);
18959 SDValue N1 = N->getOperand(1);
18960 SDValue N2 = N->getOperand(2);
18961 EVT VT = N->getValueType(0);
18962 SDLoc DL(N);
18963
18964 // Constant fold FMAD.
18965 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMAD, DL, VT, {N0, N1, N2}))
18966 return C;
18967
18968 return SDValue();
18969}
18970
18971SDValue DAGCombiner::visitFMULADD(SDNode *N) {
18972 SDValue N0 = N->getOperand(0);
18973 SDValue N1 = N->getOperand(1);
18974 SDValue N2 = N->getOperand(2);
18975 EVT VT = N->getValueType(0);
18976 SDLoc DL(N);
18977
18978 // Constant fold FMULADD.
18979 if (SDValue C =
18980 DAG.FoldConstantArithmetic(ISD::FMULADD, DL, VT, {N0, N1, N2}))
18981 return C;
18982
18983 return SDValue();
18984}
18985
18986// Combine multiple FDIVs with the same divisor into multiple FMULs by the
18987// reciprocal.
18988// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
18989// Notice that this is not always beneficial. One reason is different targets
18990// may have different costs for FDIV and FMUL, so sometimes the cost of two
18991// FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
18992// is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
18993SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
18994 // TODO: Limit this transform based on optsize/minsize - it always creates at
18995 // least 1 extra instruction. But the perf win may be substantial enough
18996 // that only minsize should restrict this.
18997 const SDNodeFlags Flags = N->getFlags();
18998 if (LegalDAG || !Flags.hasAllowReciprocal())
18999 return SDValue();
19000
19001 // Skip if current node is a reciprocal/fneg-reciprocal.
19002 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
19003 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
19004 if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
19005 return SDValue();
19006
19007 // Exit early if the target does not want this transform or if there can't
19008 // possibly be enough uses of the divisor to make the transform worthwhile.
19009 unsigned MinUses = TLI.combineRepeatedFPDivisors();
19010
19011 // For splat vectors, scale the number of uses by the splat factor. If we can
19012 // convert the division into a scalar op, that will likely be much faster.
19013 unsigned NumElts = 1;
19014 EVT VT = N->getValueType(0);
19015 if (VT.isVector() && DAG.isSplatValue(N1))
19016 NumElts = VT.getVectorMinNumElements();
19017
19018 if (!MinUses || (N1->use_size() * NumElts) < MinUses)
19019 return SDValue();
19020
19021 // Find all FDIV users of the same divisor.
19022 // Use a set because duplicates may be present in the user list.
19023 SetVector<SDNode *> Users;
19024 for (auto *U : N1->users()) {
19025 if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
19026 // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
19027 if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
19028 U->getOperand(0) == U->getOperand(1).getOperand(0) &&
19029 U->getFlags().hasAllowReassociation() &&
19030 U->getFlags().hasNoSignedZeros())
19031 continue;
19032
19033 // This division is eligible for optimization only if global unsafe math
19034 // is enabled or if this division allows reciprocal formation.
19035 if (U->getFlags().hasAllowReciprocal())
19036 Users.insert(U);
19037 }
19038 }
19039
19040 // Now that we have the actual number of divisor uses, make sure it meets
19041 // the minimum threshold specified by the target.
19042 if ((Users.size() * NumElts) < MinUses)
19043 return SDValue();
19044
19045 SDLoc DL(N);
19046 SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
19047 SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
19048
19049 // Dividend / Divisor -> Dividend * Reciprocal
19050 for (auto *U : Users) {
19051 SDValue Dividend = U->getOperand(0);
19052 if (Dividend != FPOne) {
19053 SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
19054 Reciprocal, Flags);
19055 CombineTo(U, NewNode);
19056 } else if (U != Reciprocal.getNode()) {
19057 // In the absence of fast-math-flags, this user node is always the
19058 // same node as Reciprocal, but with FMF they may be different nodes.
19059 CombineTo(U, Reciprocal);
19060 }
19061 }
19062 return SDValue(N, 0); // N was replaced.
19063}
19064
/// Combine an FDIV node: constant folding, reciprocal-multiply rewrites, and
/// rsqrt/reciprocal estimate formation gated on fast-math flags.
19065SDValue DAGCombiner::visitFDIV(SDNode *N) {
19066 SDValue N0 = N->getOperand(0);
19067 SDValue N1 = N->getOperand(1);
19068 EVT VT = N->getValueType(0);
19069 SDLoc DL(N);
19070 SDNodeFlags Flags = N->getFlags();
19071 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19072
19073 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
19074 return R;
19075
19076 // fold (fdiv c1, c2) -> c1/c2
19077 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
19078 return C;
19079
19080 // fold vector ops
19081 if (VT.isVector())
19082 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
19083 return FoldedVOp;
19084
19085 if (SDValue NewSel = foldBinOpIntoSelect(N))
19086 return NewSel;
19087
 // NOTE(review): the statement guarding this return (presumably the
 // repeated-divisor combine) is missing from this excerpt.
19089 return V;
19090
19091 // fold (fdiv X, c2) -> (fmul X, 1/c2) if there is no loss in precision, or
19092 // the loss is acceptable with AllowReciprocal.
19093 if (auto *N1CFP = isConstOrConstSplatFP(N1, true)) {
19094 // Compute the reciprocal 1.0 / c2.
19095 const APFloat &N1APF = N1CFP->getValueAPF();
19096 APFloat Recip = APFloat::getOne(N1APF.getSemantics());
 // NOTE(review): the divide producing the opStatus `st` is missing from
 // this excerpt.
19098 // Only do the transform if the reciprocal is a legal fp immediate that
19099 // isn't too nasty (eg NaN, denormal, ...).
19100 if (((st == APFloat::opOK && !Recip.isDenormal()) ||
19101 (st == APFloat::opInexact && Flags.hasAllowReciprocal())) &&
19102 (!LegalOperations ||
19103 // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
19104 // backend)... we should handle this gracefully after Legalize.
19105 // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
 // NOTE(review): part of this condition is missing from this excerpt.
19107 TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
19108 return DAG.getNode(ISD::FMUL, DL, VT, N0,
19109 DAG.getConstantFP(Recip, DL, VT));
19110 }
19111
19112 if (Flags.hasAllowReciprocal()) {
19113 // If this FDIV is part of a reciprocal square root, it may be folded
19114 // into a target-specific square root estimate instruction.
19115 bool N1AllowReciprocal = N1->getFlags().hasAllowReciprocal();
19116 if (N1.getOpcode() == ISD::FSQRT) {
19117 if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), N1->getFlags()))
19118 return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
19119 } else if (N1.getOpcode() == ISD::FP_EXTEND &&
19120 N1.getOperand(0).getOpcode() == ISD::FSQRT &&
19121 N1AllowReciprocal) {
 // Build the estimate in the narrow type, then extend the result.
19122 if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
19123 N1.getOperand(0)->getFlags())) {
19124 RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
19125 AddToWorklist(RV.getNode());
19126 return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
19127 }
19128 } else if (N1.getOpcode() == ISD::FP_ROUND &&
19129 N1.getOperand(0).getOpcode() == ISD::FSQRT) {
 // Build the estimate in the wide type, then round the result.
19130 if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
19131 N1.getOperand(0)->getFlags())) {
19132 RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
19133 AddToWorklist(RV.getNode());
19134 return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
19135 }
19136 } else if (N1.getOpcode() == ISD::FMUL) {
19137 // Look through an FMUL. Even though this won't remove the FDIV directly,
19138 // it's still worthwhile to get rid of the FSQRT if possible.
19139 SDValue Sqrt, Y;
19140 if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
19141 Sqrt = N1.getOperand(0);
19142 Y = N1.getOperand(1);
19143 } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
19144 Sqrt = N1.getOperand(1);
19145 Y = N1.getOperand(0);
19146 }
19147 if (Sqrt.getNode()) {
19148 // If the other multiply operand is known positive, pull it into the
19149 // sqrt. That will eliminate the division if we convert to an estimate.
19150 if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
19151 N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
19152 SDValue A;
19153 if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
19154 A = Y.getOperand(0);
19155 else if (Y == Sqrt.getOperand(0))
19156 A = Y;
19157 if (A) {
19158 // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
19159 // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
19160 SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
19161 SDValue AAZ =
19162 DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
19163 if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Sqrt->getFlags()))
19164 return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);
19165
19166 // Estimate creation failed. Clean up speculatively created nodes.
19167 recursivelyDeleteUnusedNodes(AAZ.getNode());
19168 }
19169 }
19170
19171 // We found a FSQRT, so try to make this fold:
19172 // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
19173 if (SDValue Rsqrt =
19174 buildRsqrtEstimate(Sqrt.getOperand(0), Sqrt->getFlags())) {
19175 SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
19176 AddToWorklist(Div.getNode());
19177 return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
19178 }
19179 }
19180 }
19181
19182 // Fold into a reciprocal estimate and multiply instead of a real divide.
19183 if (Flags.hasNoInfs())
19184 if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
19185 return RV;
19186 }
19187
19188 // Fold X/Sqrt(X) -> Sqrt(X)
19189 if (DAG.canIgnoreSignBitOfZero(SDValue(N, 0)) &&
19190 Flags.hasAllowReassociation())
19191 if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
19192 return N1;
19193
19194 // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
 // NOTE(review): the cost-tracking declarations (CostN0/CostN1) appear to be
 // missing from this excerpt.
19199 SDValue NegN0 =
19200 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
19201 if (NegN0) {
 // Keep the speculatively-created negation alive while negating N1.
19202 HandleSDNode NegN0Handle(NegN0);
19203 SDValue NegN1 =
19204 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
19205 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
 // NOTE(review): the second half of this condition is missing from this
 // excerpt.
19207 return DAG.getNode(ISD::FDIV, DL, VT, NegN0, NegN1);
19208 }
19209
19210 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
19211 return R;
19212
19213 return SDValue();
19214}
19215
/// Combine an FREM node: constant folding, select folding, and expansion of
/// frem-by-power-of-two into trunc(div)*mul arithmetic.
19216SDValue DAGCombiner::visitFREM(SDNode *N) {
19217 SDValue N0 = N->getOperand(0);
19218 SDValue N1 = N->getOperand(1);
19219 EVT VT = N->getValueType(0);
19220 SDNodeFlags Flags = N->getFlags();
19221 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19222 SDLoc DL(N);
19223
19224 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
19225 return R;
19226
19227 // fold (frem c1, c2) -> fmod(c1,c2)
19228 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, DL, VT, {N0, N1}))
19229 return C;
19230
19231 if (SDValue NewSel = foldBinOpIntoSelect(N))
19232 return NewSel;
19233
19234 // Lower frem N0, N1 => x - trunc(N0 / N1) * N1, providing N1 is an integer
19235 // power of 2.
19236 if (!TLI.isOperationLegal(ISD::FREM, VT) &&
 // NOTE(review): several condition lines are missing from this excerpt.
19240 DAG.isKnownToBeAPowerOfTwoFP(N1)) {
 // A copysign is needed when the sign of a zero result matters and the
 // expansion could otherwise produce a zero of the wrong sign.
19241 bool NeedsCopySign = !DAG.canIgnoreSignBitOfZero(SDValue(N, 0)) &&
 // NOTE(review): the rest of this initializer is missing from this excerpt.
19243 SDValue Div = DAG.getNode(ISD::FDIV, DL, VT, N0, N1);
19244 SDValue Rnd = DAG.getNode(ISD::FTRUNC, DL, VT, Div);
19245 SDValue MLA;
 // NOTE(review): the guard for choosing the FMA path is missing from this
 // excerpt.
19247 MLA = DAG.getNode(ISD::FMA, DL, VT, DAG.getNode(ISD::FNEG, DL, VT, Rnd),
19248 N1, N0);
19249 } else {
19250 SDValue Mul = DAG.getNode(ISD::FMUL, DL, VT, Rnd, N1);
19251 MLA = DAG.getNode(ISD::FSUB, DL, VT, N0, Mul);
19252 }
19253 return NeedsCopySign ? DAG.getNode(ISD::FCOPYSIGN, DL, VT, MLA, N0) : MLA;
19254 }
19255
19256 return SDValue();
19257}
19258
19259SDValue DAGCombiner::visitFSQRT(SDNode *N) {
19260 SDNodeFlags Flags = N->getFlags();
19261
19262 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
19263 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
19264 if (!Flags.hasApproximateFuncs() || !Flags.hasNoInfs())
19265 return SDValue();
19266
19267 SDValue N0 = N->getOperand(0);
19268 if (TLI.isFsqrtCheap(N0, DAG))
19269 return SDValue();
19270
19271 // FSQRT nodes have flags that propagate to the created nodes.
19272 SelectionDAG::FlagInserter FlagInserter(DAG, Flags);
19273 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
19274 // transform the fdiv, we may produce a sub-optimal estimate sequence
19275 // because the reciprocal calculation may not have to filter out a
19276 // 0.0 input.
19277 return buildSqrtEstimate(N0, Flags);
19278}
19279
19280/// copysign(x, fp_extend(y)) -> copysign(x, y)
19281/// copysign(x, fp_round(y)) -> copysign(x, y)
19282/// Operands to the functions are the type of X and Y respectively.
19283static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
19284 // Always fold no-op FP casts.
19285 if (XTy == YTy)
19286 return true;
19287
19288 // Do not optimize out type conversion of f128 type yet.
19289 // For some targets like x86_64, configuration is changed to keep one f128
19290 // value in one SSE register, but instruction selection cannot handle
19291 // FCOPYSIGN on SSE registers yet.
19292 if (YTy == MVT::f128)
19293 return false;
19294
19295 // Avoid mismatched vector operand types, for better instruction selection.
19296 return !YTy.isVector();
19297}
19298
 // NOTE(review): the signature line of this helper (an overload that takes
 // the FCOPYSIGN node itself) is missing from this excerpt.
19300 SDValue N1 = N->getOperand(1);
 // Only an FP_EXTEND or FP_ROUND on the sign operand is a candidate fold.
19301 if (N1.getOpcode() != ISD::FP_EXTEND &&
19302 N1.getOpcode() != ISD::FP_ROUND)
19303 return false;
19304 EVT N1VT = N1->getValueType(0);
19305 EVT N1Op0VT = N1->getOperand(0).getValueType();
 // Defer to the type-based overload with the cast's result/source types.
19306 return CanCombineFCOPYSIGN_EXTEND_ROUND(N1VT, N1Op0VT);
19307}
19308
/// Combine an FCOPYSIGN node: constant folding, cast-through folding on the
/// sign operand, and conversion to an integer OR when the magnitude operand
/// has a known-zero sign bit.
19309SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
19310 SDValue N0 = N->getOperand(0);
19311 SDValue N1 = N->getOperand(1);
19312 EVT VT = N->getValueType(0);
19313 SDLoc DL(N);
19314
19315 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
19316 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, DL, VT, {N0, N1}))
19317 return C;
19318
19319 // copysign(x, fp_extend(y)) -> copysign(x, y)
19320 // copysign(x, fp_round(y)) -> copysign(x, y)
 // NOTE(review): the guard for this fold is missing from this excerpt.
19322 return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N0, N1.getOperand(0));
19323
 // NOTE(review): the statement guarding this return is missing from this
 // excerpt.
19325 return SDValue(N, 0);
19326
19327 if (VT != N1.getValueType())
19328 return SDValue();
19329
19330 // If this is equivalent to a disjoint or, replace it with one. This can
19331 // happen if the sign operand is a sign mask (i.e., x << sign_bit_position).
19332 if (DAG.SignBitIsZeroFP(N0) &&
 // NOTE(review): the second half of this condition is missing from this
 // excerpt.
19334 // TODO: Just directly match the shift pattern. computeKnownBits is heavy
19335 // for such a narrowly targeted case.
19336 EVT IntVT = VT.changeTypeToInteger();
19337 // TODO: It appears to be profitable in some situations to unconditionally
19338 // emit a fabs(n0) to perform this combine.
19339 SDValue CastSrc0 = DAG.getNode(ISD::BITCAST, DL, IntVT, N0);
19340 SDValue CastSrc1 = DAG.getNode(ISD::BITCAST, DL, IntVT, N1);
19341
19342 SDValue SignOr = DAG.getNode(ISD::OR, DL, IntVT, CastSrc0, CastSrc1,
 // NOTE(review): the final argument (presumably disjoint node flags) is
 // missing from this excerpt.
19344 return DAG.getNode(ISD::BITCAST, DL, VT, SignOr);
19345 }
19346
19347 return SDValue();
19348}
19349
/// Simplify FPOW with a constant exponent into cheaper operations (cbrt,
/// sqrt chains) when fast-math flags permit.
19350SDValue DAGCombiner::visitFPOW(SDNode *N) {
19351 ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
19352 if (!ExponentC)
19353 return SDValue();
19354 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19355
19356 // Try to convert x ** (1/3) into cube root.
19357 // TODO: Handle the various flavors of long double.
19358 // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
19359 // Some range near 1/3 should be fine.
19360 EVT VT = N->getValueType(0);
19361 EVT ScalarVT = VT.getScalarType();
19362 if ((ScalarVT == MVT::f32 &&
19363 ExponentC->getValueAPF().isExactlyValue(1.0f / 3.0f)) ||
19364 (ScalarVT == MVT::f64 &&
19365 ExponentC->getValueAPF().isExactlyValue(1.0 / 3.0))) {
19366 // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
19367 // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
19368 // pow(-val, 1/3) = nan; cbrt(-val) = -num.
19369 // For regular numbers, rounding may cause the results to differ.
19370 // Therefore, we require { nsz ninf nnan afn } for this transform.
19371 // TODO: We could select out the special cases if we don't have nsz/ninf.
19372 SDNodeFlags Flags = N->getFlags();
19373 if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
19374 !Flags.hasApproximateFuncs())
19375 return SDValue();
19376
19377 // Do not create a cbrt() libcall if the target does not have it, and do not
19378 // turn a pow that has lowering support into a cbrt() libcall.
19379 RTLIB::Libcall LC = RTLIB::getCBRT(VT);
19380 bool HasLibCall =
19381 DAG.getLibcalls().getLibcallImpl(LC) != RTLIB::Unsupported;
19382 if (!HasLibCall ||
 // NOTE(review): the rest of this condition is missing from this excerpt.
19385 return SDValue();
19386
19387 return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
19388 }
19389
19390 // Try to convert x ** (1/4) and x ** (3/4) into square roots.
19391 // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
19392 // TODO: This could be extended (using a target hook) to handle smaller
19393 // power-of-2 fractional exponents.
19394 bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
19395 bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
19396 if (ExponentIs025 || ExponentIs075) {
19397 // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
19398 // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
19399 // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
19400 // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
19401 // For regular numbers, rounding may cause the results to differ.
19402 // Therefore, we require { nsz ninf afn } for this transform.
19403 // TODO: We could select out the special cases if we don't have nsz/ninf.
19404 SDNodeFlags Flags = N->getFlags();
19405
19406 // We only need no signed zeros for the 0.25 case.
19407 if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
19408 !Flags.hasApproximateFuncs())
19409 return SDValue();
19410
19411 // Don't double the number of libcalls. We are trying to inline fast code.
 // NOTE(review): the guard condition for this return is missing from this
 // excerpt.
19413 return SDValue();
19414
19415 // Assume that libcalls are the smallest code.
19416 // TODO: This restriction should probably be lifted for vectors.
19417 if (ForCodeSize)
19418 return SDValue();
19419
19420 // pow(X, 0.25) --> sqrt(sqrt(X))
19421 SDLoc DL(N);
19422 SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
19423 SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
19424 if (ExponentIs025)
19425 return SqrtSqrt;
19426 // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
19427 return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
19428 }
19429
19430 return SDValue();
19431}
19432
19434 const TargetLowering &TLI) {
19435 // We can fold the fpto[us]i -> [us]itofp pattern into a single ftrunc.
19436 // Additionally, if there are clamps ([us]min or [us]max) around
19437 // the fpto[us]i, we can fold those into fminnum/fmaxnum around the ftrunc.
19438 // If NoSignedZerosFPMath is enabled, this is a direct replacement.
19439 // Otherwise, for strict math, we must handle edge cases:
19440 // 1. For unsigned conversions, use FABS to handle negative cases. Take -0.0
19441 // as example, it first becomes integer 0, and is converted back to +0.0.
19442 // FTRUNC on its own could produce -0.0.
19443
19444 // FIXME: We should be able to use node-level FMF here.
19445 EVT VT = N->getValueType(0);
 // The whole rewrite produces an FTRUNC, so bail out unless it is legal.
19446 if (!TLI.isOperationLegal(ISD::FTRUNC, VT))
19447 return SDValue();
19448
 // N is the [us]int_to_fp node being visited; determine which flavor.
19449 bool IsUnsigned = N->getOpcode() == ISD::UINT_TO_FP;
19450 bool IsSigned = N->getOpcode() == ISD::SINT_TO_FP;
19451 assert(IsSigned || IsUnsigned);
19452
19453 bool IsSignedZeroSafe = DAG.getTarget().Options.NoSignedZerosFPMath ||
19455 // For signed conversions: The optimization changes signed zero behavior.
19456 if (IsSigned && !IsSignedZeroSafe)
19457 return SDValue();
19458 // For unsigned conversions, we need FABS to canonicalize -0.0 to +0.0
19459 // (unless outputting a signed zero is OK).
19460 if (IsUnsigned && !IsSignedZeroSafe && !TLI.isFAbsFree(VT))
19461 return SDValue();
19462
19463 // Collect potential clamp operations (outermost to innermost) and peel.
19464 struct ClampInfo {
 // True for a [us]min clamp, false for a [us]max clamp.
19465 bool IsMin;
19467 };
19468 constexpr unsigned MaxClamps = 2;
19470 unsigned MinOp = IsUnsigned ? ISD::UMIN : ISD::SMIN;
19471 unsigned MaxOp = IsUnsigned ? ISD::UMAX : ISD::SMAX;
19472 SDValue IntVal = N->getOperand(0);
19473 for (unsigned Level = 0; Level < MaxClamps; ++Level) {
19474 if (!IntVal.hasOneUse() ||
19475 (IntVal.getOpcode() != MinOp && IntVal.getOpcode() != MaxOp))
19476 break;
 // The clamp bound must be a scalar constant or a constant splat.
19477 SDValue RHS = IntVal.getOperand(1);
19478 APInt IntConst;
19479 if (auto *IntConstNode = dyn_cast<ConstantSDNode>(RHS))
19480 IntConst = IntConstNode->getAPIntValue();
19481 else if (!ISD::isConstantSplatVector(RHS.getNode(), IntConst))
19482 return SDValue();
 // Moving the clamp into the FP domain is only valid if the bound survives
 // the int -> fp -> int round trip exactly.
19483 APFloat FPConst(VT.getFltSemantics());
19484 FPConst.convertFromAPInt(IntConst, IsSigned, APFloat::rmNearestTiesToEven);
19485 // Verify roundtrip exactness.
19486 APSInt RoundTrip(IntConst.getBitWidth(), IsUnsigned);
19487 bool IsExact;
19488 if (FPConst.convertToInteger(RoundTrip, APFloat::rmTowardZero, &IsExact) !=
19489 APFloat::opOK ||
19490 !IsExact || static_cast<const APInt &>(RoundTrip) != IntConst)
19491 return SDValue();
19492 bool IsMin = IntVal.getOpcode() == MinOp;
19493 Clamps.push_back({IsMin, DAG.getConstantFP(FPConst, DL, VT)});
19494 IntVal = IntVal.getOperand(0);
19495 }
19496
19497 // Check that the sequence ends with the correct kind of fpto[us]i.
19498 unsigned FPToIntOp = IsUnsigned ? ISD::FP_TO_UINT : ISD::FP_TO_SINT;
19499 if (IntVal.getOpcode() != FPToIntOp ||
19500 IntVal.getOperand(0).getValueType() != VT)
19501 return SDValue();
19502
 // Rebuild the computation: optional FABS, then FTRUNC, then the FP clamps.
19503 SDValue Result = IntVal.getOperand(0);
19504 if (IsUnsigned && !IsSignedZeroSafe && TLI.isFAbsFree(VT))
19505 Result = DAG.getNode(ISD::FABS, DL, VT, Result);
19506 Result = DAG.getNode(ISD::FTRUNC, DL, VT, Result);
19507 // Apply clamps, if any, in reverse order (innermost first).
19508 for (const ClampInfo &Clamp : reverse(Clamps)) {
19509 unsigned FPClampOp =
19510 getMinMaxOpcodeForClamp(Clamp.IsMin, Result, Clamp.Constant, DAG, TLI);
 // getMinMaxOpcodeForClamp signals "no usable FP min/max op" via
 // DELETED_NODE, in which case the transform must be abandoned.
19511 if (FPClampOp == ISD::DELETED_NODE)
19512 return SDValue();
19513 Result = DAG.getNode(FPClampOp, DL, VT, Result, Clamp.Constant);
19514 }
19515 return Result;
19516}
19517
// Combine SINT_TO_FP: constant folding, conversion to UINT_TO_FP when the
// sign bit is known zero, select-based folds of setcc inputs, elimination of
// fptoint->itofp round trips, and hoisting through nsw truncates.
19518SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
19519 SDValue N0 = N->getOperand(0);
19520 EVT VT = N->getValueType(0);
19521 EVT OpVT = N0.getValueType();
19522 SDLoc DL(N);
19523
19524 // [us]itofp(undef) = 0, because the result value is bounded.
19525 if (N0.isUndef())
19526 return DAG.getConstantFP(0.0, DL, VT);
19527
19528 // fold (sint_to_fp c1) -> c1fp
19529 // ...but only if the target supports immediate floating-point values
19530 if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
19531 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SINT_TO_FP, DL, VT, {N0}))
19532 return C;
19533
19534 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
19535 // but UINT_TO_FP is legal on this target, try to convert.
19536 if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
19537 hasOperation(ISD::UINT_TO_FP, OpVT)) {
19538 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
19539 if (DAG.SignBitIsZero(N0))
19540 return DAG.getNode(ISD::UINT_TO_FP, DL, VT, N0);
19541 }
19542
19543 // The next optimizations are desirable only if SELECT_CC can be lowered.
19544 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
19545 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
19546 !VT.isVector() &&
19547 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
19548 return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
19549 DAG.getConstantFP(0.0, DL, VT));
19550
19551 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
19552 // (select (setcc x, y, cc), 1.0, 0.0)
19553 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
19554 N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
19555 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
19556 return DAG.getSelect(DL, VT, N0.getOperand(0),
19557 DAG.getConstantFP(1.0, DL, VT),
19558 DAG.getConstantFP(0.0, DL, VT));
19559
 // Try to fold a fpto[us]i -> [us]itofp round trip (optionally with integer
 // clamps around it) into an ftrunc-based sequence.
19560 if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
19561 return FTrunc;
19562
19563 // fold (sint_to_fp (trunc nsw x)) -> (sint_to_fp x)
19564 if (N0.getOpcode() == ISD::TRUNCATE && N0->getFlags().hasNoSignedWrap() &&
19566 N0.getOperand(0).getValueType()))
19567 return DAG.getNode(ISD::SINT_TO_FP, DL, VT, N0.getOperand(0));
19568
19569 return SDValue();
19570}
19571
// Combine UINT_TO_FP: constant folding, conversion to SINT_TO_FP when the
// sign bit is known zero, a setcc select fold, elimination of fptoint->itofp
// round trips, and hoisting through nuw truncates.
19572SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
19573 SDValue N0 = N->getOperand(0);
19574 EVT VT = N->getValueType(0);
19575 EVT OpVT = N0.getValueType();
19576 SDLoc DL(N);
19577
19578 // [us]itofp(undef) = 0, because the result value is bounded.
19579 if (N0.isUndef())
19580 return DAG.getConstantFP(0.0, DL, VT);
19581
19582 // fold (uint_to_fp c1) -> c1fp
19583 // ...but only if the target supports immediate floating-point values
19584 if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
19585 if (SDValue C = DAG.FoldConstantArithmetic(ISD::UINT_TO_FP, DL, VT, {N0}))
19586 return C;
19587
19588 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
19589 // but SINT_TO_FP is legal on this target, try to convert.
19590 if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
19591 hasOperation(ISD::SINT_TO_FP, OpVT)) {
19592 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
19593 if (DAG.SignBitIsZero(N0))
19594 return DAG.getNode(ISD::SINT_TO_FP, DL, VT, N0);
19595 }
19596
19597 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
19598 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
19599 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
19600 return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
19601 DAG.getConstantFP(0.0, DL, VT));
19602
 // Try to fold a fpto[us]i -> [us]itofp round trip (optionally with integer
 // clamps around it) into an ftrunc-based sequence.
19603 if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
19604 return FTrunc;
19605
19606 // fold (uint_to_fp (trunc nuw x)) -> (uint_to_fp x)
19607 if (N0.getOpcode() == ISD::TRUNCATE && N0->getFlags().hasNoUnsignedWrap() &&
19609 N0.getOperand(0).getValueType()))
19610 return DAG.getNode(ISD::UINT_TO_FP, DL, VT, N0.getOperand(0));
19611
19612 return SDValue();
19613}
19614
19615// Fold (fp_to_{s/u}int ({s/u}int_to_fpx)) -> zext x, sext x, trunc x, or x
19617 SDValue N0 = N->getOperand(0);
19618 EVT VT = N->getValueType(0);
19619
 // Only applies when the fptoint's input is itself an int-to-fp conversion.
19620 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
19621 return SDValue();
19622
19623 SDValue Src = N0.getOperand(0);
19624 EVT SrcVT = Src.getValueType();
19625 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
19626 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
19627
19628 // We can safely assume the conversion won't overflow the output range,
19629 // because (for example) (uint8_t)18293.f is undefined behavior.
19630
19631 // Since we can assume the conversion won't overflow, our decision as to
19632 // whether the input will fit in the float should depend on the minimum
19633 // of the input range and output range.
19634
19635 // This means this is also safe for a signed input and unsigned output, since
19636 // a negative input would lead to undefined behavior.
19637 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
19638 unsigned OutputSize = (int)VT.getScalarSizeInBits();
19639 unsigned ActualSize = std::min(InputSize, OutputSize);
19640 const fltSemantics &Sem = N0.getValueType().getFltSemantics();
19641
19642 // We can only fold away the float conversion if the input range can be
19643 // represented exactly in the float range.
19644 if (APFloat::semanticsPrecision(Sem) >= ActualSize) {
 // Output wider than input: extend. Sign-extend only when both sides are
 // signed; otherwise the value is non-negative (see comments above).
19645 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
19646 unsigned ExtOp =
19647 IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
19648 return DAG.getNode(ExtOp, DL, VT, Src);
19649 }
 // Output narrower than input: truncate.
19650 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
19651 return DAG.getNode(ISD::TRUNCATE, DL, VT, Src);
 // Equal widths: no extension or truncation is needed.
19652 return DAG.getBitcast(VT, Src);
19653 }
19654 return SDValue();
19655}
19656
19657SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
19658 SDValue N0 = N->getOperand(0);
19659 EVT VT = N->getValueType(0);
19660 SDLoc DL(N);
19661
19662 // fold (fp_to_sint undef) -> undef
19663 if (N0.isUndef())
19664 return DAG.getUNDEF(VT);
19665
19666 // fold (fp_to_sint c1fp) -> c1
19667 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_TO_SINT, DL, VT, {N0}))
19668 return C;
19669
19670 return FoldIntToFPToInt(N, DL, DAG);
19671}
19672
19673SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
19674 SDValue N0 = N->getOperand(0);
19675 EVT VT = N->getValueType(0);
19676 SDLoc DL(N);
19677
19678 // fold (fp_to_uint undef) -> undef
19679 if (N0.isUndef())
19680 return DAG.getUNDEF(VT);
19681
19682 // fold (fp_to_uint c1fp) -> c1
19683 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_TO_UINT, DL, VT, {N0}))
19684 return C;
19685
19686 return FoldIntToFPToInt(N, DL, DAG);
19687}
19688
19689SDValue DAGCombiner::visitXROUND(SDNode *N) {
19690 SDValue N0 = N->getOperand(0);
19691 EVT VT = N->getValueType(0);
19692
19693 // fold (lrint|llrint undef) -> undef
19694 // fold (lround|llround undef) -> undef
19695 if (N0.isUndef())
19696 return DAG.getUNDEF(VT);
19697
19698 // fold (lrint|llrint c1fp) -> c1
19699 // fold (lround|llround c1fp) -> c1
19700 if (SDValue C =
19701 DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0}))
19702 return C;
19703
19704 return SDValue();
19705}
19706
// Combine FP_ROUND: constant folding, cancelling a preceding fp_extend,
// merging nested fp_rounds, and sinking the round into copysign.
19707SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
19708 SDValue N0 = N->getOperand(0);
 // Operand 1 is the "truncation is value preserving" flag (1 = exact).
19709 SDValue N1 = N->getOperand(1);
19710 EVT VT = N->getValueType(0);
19711 SDLoc DL(N);
19712
19713 // fold (fp_round c1fp) -> c1fp
19714 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_ROUND, DL, VT, {N0, N1}))
19715 return C;
19716
19717 // fold (fp_round (fp_extend x)) -> x
19718 if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
19719 return N0.getOperand(0);
19720
19721 // fold (fp_round (fp_round x)) -> (fp_round x)
19722 if (N0.getOpcode() == ISD::FP_ROUND) {
19723 const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
19724 const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
19725
19726 // Avoid folding legal fp_rounds into non-legal ones.
19727 if (!hasOperation(ISD::FP_ROUND, VT))
19728 return SDValue();
19729
19730 // Skip this folding if it results in an fp_round from f80 to f16.
19731 //
19732 // f80 to f16 always generates an expensive (and as yet, unimplemented)
19733 // libcall to __truncxfhf2 instead of selecting native f16 conversion
19734 // instructions from f32 or f64. Moreover, the first (value-preserving)
19735 // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
19736 // x86.
19737 if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
19738 return SDValue();
19739
19740 // If the first fp_round isn't a value preserving truncation, it might
19741 // introduce a tie in the second fp_round, that wouldn't occur in the
19742 // single-step fp_round we want to fold to.
19743 // In other words, double rounding isn't the same as rounding.
19744 // Also, this is a value preserving truncation iff both fp_round's are.
19745 if ((N->getFlags().hasAllowContract() &&
19746 N0->getFlags().hasAllowContract()) ||
19747 N0IsTrunc)
19748 return DAG.getNode(
19749 ISD::FP_ROUND, DL, VT, N0.getOperand(0),
19750 DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
19751 }
19752
19753 // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
19754 // Note: From a legality perspective, this is a two step transform. First,
19755 // we duplicate the fp_round to the arguments of the copysign, then we
19756 // eliminate the fp_round on Y. The second step requires an additional
19757 // predicate to match the implementation above.
19758 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
19760 N0.getValueType())) {
19761 SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
19762 N0.getOperand(0), N1);
19763 AddToWorklist(Tmp.getNode());
19764 return DAG.getNode(ISD::FCOPYSIGN, DL, VT, Tmp, N0.getOperand(1));
19765 }
19766
19767 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
19768 return NewVSel;
19769
19770 return SDValue();
19771}
19772
19773// Eliminate a floating-point widening of a narrowed value if the fast math
19774// flags allow it.
19776 SDValue N0 = N->getOperand(0);
19777 EVT VT = N->getValueType(0);
19778
 // Map the widening cast being visited to the narrowing cast that it can
 // cancel with.
19779 unsigned NarrowingOp;
19780 switch (N->getOpcode()) {
19781 case ISD::FP16_TO_FP:
19782 NarrowingOp = ISD::FP_TO_FP16;
19783 break;
19784 case ISD::BF16_TO_FP:
19785 NarrowingOp = ISD::FP_TO_BF16;
19786 break;
19787 case ISD::FP_EXTEND:
19788 NarrowingOp = ISD::FP_ROUND;
19789 break;
19790 default:
19791 llvm_unreachable("Expected widening FP cast");
19792 }
19793
 // The fold only fires for widen(narrow(X)) where X already has the widened
 // result type.
19794 if (N0.getOpcode() == NarrowingOp && N0.getOperand(0).getValueType() == VT) {
19795 const SDNodeFlags NarrowFlags = N0->getFlags();
19796 const SDNodeFlags WidenFlags = N->getFlags();
19797 // Narrowing can introduce inf and change the encoding of a nan, so the
19798 // widen must have the nnan and ninf flags to indicate that we don't need to
19799 // care about that. We are also removing a rounding step, and that requires
19800 // both the narrow and widen to allow contraction.
19801 if (WidenFlags.hasNoNaNs() && WidenFlags.hasNoInfs() &&
19802 NarrowFlags.hasAllowContract() && WidenFlags.hasAllowContract()) {
19803 return N0.getOperand(0);
19804 }
19805 }
19806
19807 return SDValue();
19808}
19809
// Combine FP_EXTEND: vector cast simplification, constant folding,
// cancellation against fp16_to_fp / fp_round, widening of a load into an
// extending load, and fast-math cast-pair elimination.
19810SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
19811 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19812 SDValue N0 = N->getOperand(0);
19813 EVT VT = N->getValueType(0);
19814 SDLoc DL(N);
19815
19816 if (VT.isVector())
19817 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
19818 return FoldedVOp;
19819
19820 // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
19821 if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::FP_ROUND)
19822 return SDValue();
19823
19824 // fold (fp_extend c1fp) -> c1fp
19825 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_EXTEND, DL, VT, {N0}))
19826 return C;
19827
19828 // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
19829 if (N0.getOpcode() == ISD::FP16_TO_FP &&
19831 return DAG.getNode(ISD::FP16_TO_FP, DL, VT, N0.getOperand(0));
19832
19833 // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
19834 // value of X.
19835 if (N0.getOpcode() == ISD::FP_ROUND && N0.getConstantOperandVal(1) == 1) {
19836 SDValue In = N0.getOperand(0);
19837 if (In.getValueType() == VT) return In;
19838 if (VT.bitsLT(In.getValueType()))
19839 return DAG.getNode(ISD::FP_ROUND, DL, VT, In, N0.getOperand(1));
19840 return DAG.getNode(ISD::FP_EXTEND, DL, VT, In);
19841 }
19842
19843 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
19844 if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
19846 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
19847 SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, VT,
19848 LN0->getChain(),
19849 LN0->getBasePtr(), N0.getValueType(),
19850 LN0->getMemOperand());
19851 CombineTo(N, ExtLoad);
 // Redirect other users of the original load to a round of the extending
 // load, so the old load (and its chain result) can be replaced.
19852 CombineTo(
19853 N0.getNode(),
19854 DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
19855 DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
19856 ExtLoad.getValue(1));
19857 return SDValue(N, 0); // Return N so it doesn't get rechecked!
19858 }
19859
19860 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
19861 return NewVSel;
19862
19863 if (SDValue CastEliminated = eliminateFPCastPair(N))
19864 return CastEliminated;
19865
19866 return SDValue();
19867}
19868
19869SDValue DAGCombiner::visitFCEIL(SDNode *N) {
19870 SDValue N0 = N->getOperand(0);
19871 EVT VT = N->getValueType(0);
19872
19873 // fold (fceil c1) -> fceil(c1)
19874 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FCEIL, SDLoc(N), VT, {N0}))
19875 return C;
19876
19877 return SDValue();
19878}
19879
19880SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
19881 SDValue N0 = N->getOperand(0);
19882 EVT VT = N->getValueType(0);
19883
19884 // fold (ftrunc c1) -> ftrunc(c1)
19885 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FTRUNC, SDLoc(N), VT, {N0}))
19886 return C;
19887
19888 // fold ftrunc (known rounded int x) -> x
19889 // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
19890 // likely to be generated to extract integer from a rounded floating value.
19891 switch (N0.getOpcode()) {
19892 default: break;
19893 case ISD::FRINT:
19894 case ISD::FTRUNC:
19895 case ISD::FNEARBYINT:
19896 case ISD::FROUNDEVEN:
19897 case ISD::FFLOOR:
19898 case ISD::FCEIL:
19899 return N0;
19900 }
19901
19902 return SDValue();
19903}
19904
// Combine FFREXP. FFREXP produces multiple results (see the getVTList use
// below), so constant folding rebuilds the node directly instead of going
// through FoldConstantArithmetic.
// NOTE(review): the guard condition for the rebuild is not fully visible in
// this listing; it appears to apply only to constant operands -- confirm.
19905SDValue DAGCombiner::visitFFREXP(SDNode *N) {
19906 SDValue N0 = N->getOperand(0);
19907
19908 // fold (ffrexp c1) -> ffrexp(c1)
19910 return DAG.getNode(ISD::FFREXP, SDLoc(N), N->getVTList(), N0);
19911 return SDValue();
19912}
19913
19914SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
19915 SDValue N0 = N->getOperand(0);
19916 EVT VT = N->getValueType(0);
19917
19918 // fold (ffloor c1) -> ffloor(c1)
19919 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FFLOOR, SDLoc(N), VT, {N0}))
19920 return C;
19921
19922 return SDValue();
19923}
19924
// Combine FNEG: constant folding, negation through getNegatedExpression, and
// the nsz-only -(X-Y) -> (Y-X) fold.
19925SDValue DAGCombiner::visitFNEG(SDNode *N) {
19926 SDValue N0 = N->getOperand(0);
19927 EVT VT = N->getValueType(0);
19928 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19929
19930 // Constant fold FNEG.
19931 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FNEG, SDLoc(N), VT, {N0}))
19932 return C;
19933
 // Let the target fold the negation into the expression that computes N0,
 // when that is known to be cheaper than an explicit fneg.
19934 if (SDValue NegN0 =
19935 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
19936 return NegN0;
19937
19938 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
19939 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
19940 // know it was called from a context with a nsz flag if the input fsub does
19941 // not.
19942 if (N0.getOpcode() == ISD::FSUB && N->getFlags().hasNoSignedZeros() &&
19943 N0.hasOneUse()) {
19944 return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
19945 N0.getOperand(0));
19946 }
19948
19949 return SDValue(N, 0);
19950
19951 if (SDValue Cast = foldSignChangeInBitcast(N))
19952 return Cast;
19953
19954 return SDValue();
19955}
19956
19957SDValue DAGCombiner::visitFMinMax(SDNode *N) {
19958 SDValue N0 = N->getOperand(0);
19959 SDValue N1 = N->getOperand(1);
19960 EVT VT = N->getValueType(0);
19961 const SDNodeFlags Flags = N->getFlags();
19962 unsigned Opc = N->getOpcode();
19963 bool PropAllNaNsToQNaNs = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
19964 bool PropOnlySNaNsToQNaNs = Opc == ISD::FMINNUM || Opc == ISD::FMAXNUM;
19965 bool IsMin =
19967 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19968
19969 // Constant fold.
19970 if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
19971 return C;
19972
19973 // Canonicalize to constant on RHS.
19976 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
19977
19978 if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
19979 const APFloat &AF = N1CFP->getValueAPF();
19980
19981 // minnum(X, qnan) -> X
19982 // maxnum(X, qnan) -> X
19983 // minnum(X, snan) -> qnan
19984 // maxnum(X, snan) -> qnan
19985 // minimum(X, nan) -> qnan
19986 // maximum(X, nan) -> qnan
19987 // minimumnum(X, nan) -> X
19988 // maximumnum(X, nan) -> X
19989 if (AF.isNaN()) {
19990 if (PropAllNaNsToQNaNs || (AF.isSignaling() && PropOnlySNaNsToQNaNs)) {
19991 if (AF.isSignaling())
19992 return DAG.getConstantFP(AF.makeQuiet(), SDLoc(N), VT);
19993 return N->getOperand(1);
19994 }
19995 return N->getOperand(0);
19996 }
19997
19998 // In the following folds, inf can be replaced with the largest finite
19999 // float, if the ninf flag is set.
20000 if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
20001 // minnum(X, -inf) -> -inf (ignoring sNaN -> qNaN propagation)
20002 // maxnum(X, +inf) -> +inf (ignoring sNaN -> qNaN propagation)
20003 // minimum(X, -inf) -> -inf if nnan
20004 // maximum(X, +inf) -> +inf if nnan
20005 // minimumnum(X, -inf) -> -inf
20006 // maximumnum(X, +inf) -> +inf
20007 if (IsMin == AF.isNegative() &&
20008 (!PropAllNaNsToQNaNs || Flags.hasNoNaNs()))
20009 return N->getOperand(1);
20010
20011 // minnum(X, +inf) -> X if nnan
20012 // maxnum(X, -inf) -> X if nnan
20013 // minimum(X, +inf) -> X (ignoring quieting of sNaNs)
20014 // maximum(X, -inf) -> X (ignoring quieting of sNaNs)
20015 // minimumnum(X, +inf) -> X if nnan
20016 // maximumnum(X, -inf) -> X if nnan
20017 if (IsMin != AF.isNegative() && (PropAllNaNsToQNaNs || Flags.hasNoNaNs()))
20018 return N->getOperand(0);
20019 }
20020 }
20021
20022 // There are no VECREDUCE variants of FMINIMUMNUM or FMAXIMUMNUM
20024 return SDValue();
20025
20026 if (SDValue SD = reassociateReduction(
20027 PropAllNaNsToQNaNs
20030 Opc, SDLoc(N), VT, N0, N1, Flags))
20031 return SD;