LLVM 22.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(true),
90
91/// Checks the given mask, and determines whether said mask is interleaving.
92///
93/// To be interleaving, a mask must alternate between `i` and `i + (Length /
94/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
95/// 4x vector interleaving mask would be <0, 2, 1, 3>).
96static bool isInterleavingMask(ArrayRef<int> Mask);
97
98/// Checks the given mask, and determines whether said mask is deinterleaving.
99///
100/// To be deinterleaving, a mask must increment in steps of 2, and either start
101/// with 0 or 1.
102/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
103/// <1, 3, 5, 7>).
104static bool isDeinterleavingMask(ArrayRef<int> Mask);
105
106/// Returns true if the operation is a negation of V, and it works for both
107/// integers and floats.
108static bool isNeg(Value *V);
109
110/// Returns the operand for negation operation.
111static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
122hash_code hash_value(const ComplexValue &Arg) {
125}
126} // end namespace
128
129namespace llvm {
130template <> struct DenseMapInfo<ComplexValue> {
131 static inline ComplexValue getEmptyKey() {
134 }
135 static inline ComplexValue getTombstoneKey() {
138 }
139 static unsigned getHashValue(const ComplexValue &Val) {
142 }
143 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
144 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
145 }
146};
147} // end namespace llvm
148
149namespace {
150template <typename T, typename IterT>
151std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
152 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
153 if (Common != A.end())
154 return std::make_optional(*Common);
155 return std::nullopt;
156}
157
158class ComplexDeinterleavingLegacyPass : public FunctionPass {
159public:
160 static char ID;
161
162 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
163 : FunctionPass(ID), TM(TM) {
166 }
167
168 StringRef getPassName() const override {
169 return "Complex Deinterleaving Pass";
170 }
171
172 bool runOnFunction(Function &F) override;
173 void getAnalysisUsage(AnalysisUsage &AU) const override {
174 AU.addRequired<TargetLibraryInfoWrapperPass>();
175 AU.setPreservesCFG();
176 }
177
178private:
179 const TargetMachine *TM;
180};
181
182class ComplexDeinterleavingGraph;
183struct ComplexDeinterleavingCompositeNode {
184
185 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
186 Value *R, Value *I)
187 : Operation(Op) {
188 Vals.push_back({R, I});
189 }
190
191 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
193 : Operation(Op), Vals(Other) {}
194
195private:
196 friend class ComplexDeinterleavingGraph;
197 using CompositeNode = ComplexDeinterleavingCompositeNode;
198 bool OperandsValid = true;
199
200public:
202 ComplexValues Vals;
203
  // These two members are required exclusively for generating
  // ComplexDeinterleavingOperation::Symmetric operations.
206 unsigned Opcode;
207 std::optional<FastMathFlags> Flags;
208
210 ComplexDeinterleavingRotation::Rotation_0;
212 Value *ReplacementNode = nullptr;
213
214 void addOperand(CompositeNode *Node) {
215 if (!Node)
216 OperandsValid = false;
217 Operands.push_back(Node);
218 }
219
220 void dump() { dump(dbgs()); }
221 void dump(raw_ostream &OS) {
222 auto PrintValue = [&](Value *V) {
223 if (V) {
224 OS << "\"";
225 V->print(OS, true);
226 OS << "\"\n";
227 } else
228 OS << "nullptr\n";
229 };
230 auto PrintNodeRef = [&](CompositeNode *Ptr) {
231 if (Ptr)
232 OS << Ptr << "\n";
233 else
234 OS << "nullptr\n";
235 };
236
237 OS << "- CompositeNode: " << this << "\n";
238 for (unsigned I = 0; I < Vals.size(); I++) {
239 OS << " Real(" << I << ") : ";
240 PrintValue(Vals[I].Real);
241 OS << " Imag(" << I << ") : ";
242 PrintValue(Vals[I].Imag);
243 }
244 OS << " ReplacementNode: ";
245 PrintValue(ReplacementNode);
246 OS << " Operation: " << (int)Operation << "\n";
247 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
248 OS << " Operands: \n";
249 for (const auto &Op : Operands) {
250 OS << " - ";
251 PrintNodeRef(Op);
252 }
253 }
254
255 bool areOperandsValid() { return OperandsValid; }
256};
257
258class ComplexDeinterleavingGraph {
259public:
260 struct Product {
261 Value *Multiplier;
262 Value *Multiplicand;
263 bool IsPositive;
264 };
265
266 using Addend = std::pair<Value *, bool>;
267 using AddendList = BumpPtrList<Addend>;
268 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
269
270 // Helper struct for holding info about potential partial multiplication
271 // candidates
272 struct PartialMulCandidate {
273 Value *Common;
274 CompositeNode *Node;
275 unsigned RealIdx;
276 unsigned ImagIdx;
277 bool IsNodeInverted;
278 };
279
280 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
281 const TargetLibraryInfo *TLI,
282 unsigned Factor)
283 : TL(TL), TLI(TLI), Factor(Factor) {}
284
285private:
286 const TargetLowering *TL = nullptr;
287 const TargetLibraryInfo *TLI = nullptr;
288 unsigned Factor;
289 SmallVector<CompositeNode *> CompositeNodes;
290 DenseMap<ComplexValues, CompositeNode *> CachedResult;
291 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
292
293 SmallPtrSet<Instruction *, 16> FinalInstructions;
294
295 /// Root instructions are instructions from which complex computation starts
296 DenseMap<Instruction *, CompositeNode *> RootToNode;
297
298 /// Topologically sorted root instructions
300
  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
304 BasicBlock *BackEdge = nullptr;
305 BasicBlock *Incoming = nullptr;
306
307 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
308 /// %OutsideUser as it is shown in the IR:
309 ///
310 /// vector.body:
311 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
312 /// [ %ReductionOp, %vector.body ]
313 /// ...
314 /// %ReductionOp = fadd i64 ...
315 /// ...
316 /// br i1 %condition, label %vector.body, %middle.block
317 ///
318 /// middle.block:
319 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
320 ///
321 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
322 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
323 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
324
325 /// In the process of detecting a reduction, we consider a pair of
326 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
327 /// traverse the use-tree to detect complex operations. As this is a reduction
328 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
329 /// to the %ReductionOPs that we suspect to be complex.
330 /// RealPHI and ImagPHI are used by the identifyPHINode method.
331 PHINode *RealPHI = nullptr;
332 PHINode *ImagPHI = nullptr;
333
334 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
335 /// detection.
336 bool PHIsFound = false;
337
338 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
339 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
340 /// This mapping is populated during
341 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
342 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
343 /// replacement process.
344 DenseMap<PHINode *, PHINode *> OldToNewPHI;
345
346 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
347 Value *R, Value *I) {
348 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
350 (R && I)) &&
351 "Reduction related nodes must have Real and Imaginary parts");
352 return new (Allocator.Allocate())
353 ComplexDeinterleavingCompositeNode(Operation, R, I);
354 }
355
356 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
357 ComplexValues &Vals) {
358#ifndef NDEBUG
359 for (auto &V : Vals) {
360 assert(
361 ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
362 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
363 (V.Real && V.Imag)) &&
364 "Reduction related nodes must have Real and Imaginary parts");
365 }
366#endif
367 return new (Allocator.Allocate())
368 ComplexDeinterleavingCompositeNode(Operation, Vals);
369 }
370
371 CompositeNode *submitCompositeNode(CompositeNode *Node) {
372 CompositeNodes.push_back(Node);
373 if (Node->Vals[0].Real)
374 CachedResult[Node->Vals] = Node;
375 return Node;
376 }
377
378 /// Identifies a complex partial multiply pattern and its rotation, based on
379 /// the following patterns
380 ///
381 /// 0: r: cr + ar * br
382 /// i: ci + ar * bi
383 /// 90: r: cr - ai * bi
384 /// i: ci + ai * br
385 /// 180: r: cr - ar * br
386 /// i: ci - ar * bi
387 /// 270: r: cr + ai * bi
388 /// i: ci - ai * br
389 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
390
391 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
392 /// is partially known from identifyPartialMul, filling in the other half of
393 /// the complex pair.
394 CompositeNode *
395 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
396 std::pair<Value *, Value *> &CommonOperandI);
397
398 /// Identifies a complex add pattern and its rotation, based on the following
399 /// patterns.
400 ///
401 /// 90: r: ar - bi
402 /// i: ai + br
403 /// 270: r: ar + bi
404 /// i: ai - br
405 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
406 CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
407 CompositeNode *identifyPartialReduction(Value *R, Value *I);
408 CompositeNode *identifyDotProduct(Value *Inst);
409
410 CompositeNode *identifyNode(ComplexValues &Vals);
411
412 CompositeNode *identifyNode(Value *R, Value *I) {
413 ComplexValues Vals;
414 Vals.push_back({R, I});
415 return identifyNode(Vals);
416 }
417
  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
422 CompositeNode *identifyAdditions(AddendList &RealAddends,
423 AddendList &ImagAddends,
424 std::optional<FastMathFlags> Flags,
425 CompositeNode *Accumulator);
426
  /// Extract one addend that has both real and imaginary parts positive.
428 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
429 AddendList &ImagAddends);
430
431 /// Determine if sum of multiplications of complex numbers can be formed from
432 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
433 /// to it. Return nullptr if it is not possible to construct a complex number.
434 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
435 SmallVectorImpl<Product> &ImagMuls,
436 CompositeNode *Accumulator);
437
  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
441 bool collectPartialMuls(ArrayRef<Product> RealMuls,
442 ArrayRef<Product> ImagMuls,
443 SmallVectorImpl<PartialMulCandidate> &Candidates);
444
445 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
446 /// the order of complex computation operations may be significantly altered,
447 /// and the real and imaginary parts may not be executed in parallel. This
448 /// function takes this into consideration and employs a more general approach
449 /// to identify complex computations. Initially, it gathers all the addends
450 /// and multiplicands and then constructs a complex expression from them.
451 CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
452
453 CompositeNode *identifyRoot(Instruction *I);
454
455 /// Identifies the Deinterleave operation applied to a vector containing
456 /// complex numbers. There are two ways to represent the Deinterleave
457 /// operation:
  /// * Using two shufflevectors with even indices for \p Real instruction and
  ///   odd indices for \p Imag instructions (only for fixed-width vectors)
460 /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
461 /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
462 /// 2.
463 CompositeNode *identifyDeinterleave(ComplexValues &Vals);
464
  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
469 CompositeNode *identifySplat(ComplexValues &Vals);
470
471 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
472
473 /// Identifies SelectInsts in a loop that has reduction with predication masks
474 /// and/or predicated tail folding
475 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
476
477 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
478
479 /// Complete IR modifications after producing new reduction operation:
480 /// * Populate the PHINode generated for
481 /// ComplexDeinterleavingOperation::ReductionPHI
482 /// * Deinterleave the final value outside of the loop and repurpose original
483 /// reduction users
484 void processReductionOperation(Value *OperationReplacement,
485 CompositeNode *Node);
486 void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
487
488public:
489 void dump() { dump(dbgs()); }
490 void dump(raw_ostream &OS) {
491 for (const auto &Node : CompositeNodes)
492 Node->dump(OS);
493 }
494
495 /// Returns false if the deinterleaving operation should be cancelled for the
496 /// current graph.
497 bool identifyNodes(Instruction *RootI);
498
  /// In case \p B is a one-block loop, this function seeks potential
  /// reductions and populates ReductionInfo. Returns true if any reductions
  /// were identified.
502 bool collectPotentialReductions(BasicBlock *B);
503
504 void identifyReductionNodes();
505
506 /// Check that every instruction, from the roots to the leaves, has internal
507 /// uses.
508 bool checkNodes();
509
510 /// Perform the actual replacement of the underlying instruction graph.
511 void replaceNodes();
512};
513
514class ComplexDeinterleaving {
515public:
516 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
517 : TL(tl), TLI(tli) {}
518 bool runOnFunction(Function &F);
519
520private:
521 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
522
523 const TargetLowering *TL = nullptr;
524 const TargetLibraryInfo *TLI = nullptr;
525};
526
527} // namespace
528
529char ComplexDeinterleavingLegacyPass::ID = 0;
530
531INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
532 "Complex Deinterleaving", false, false)
533INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
534 "Complex Deinterleaving", false, false)
535
538 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
539 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
540 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
541 return PreservedAnalyses::all();
542
545 return PA;
546}
547
549 return new ComplexDeinterleavingLegacyPass(TM);
550}
551
552bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
553 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
554 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
555 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
556}
557
558bool ComplexDeinterleaving::runOnFunction(Function &F) {
561 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
562 return false;
563 }
564
567 dbgs() << "Complex deinterleaving has been disabled, target does "
568 "not support lowering of complex number operations.\n");
569 return false;
570 }
571
572 bool Changed = false;
573 for (auto &B : F)
574 Changed |= evaluateBasicBlock(&B, 2);
575
576 // TODO: Permit changes for both interleave factors in the same function.
577 if (!Changed) {
578 for (auto &B : F)
579 Changed |= evaluateBasicBlock(&B, 4);
580 }
581
582 // TODO: We can also support interleave factors of 6 and 8 if needed.
583
584 return Changed;
585}
586
588 // If the size is not even, it's not an interleaving mask
589 if ((Mask.size() & 1))
590 return false;
591
592 int HalfNumElements = Mask.size() / 2;
593 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
594 int MaskIdx = Idx * 2;
595 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
596 return false;
597 }
598
599 return true;
600}
601
603 int Offset = Mask[0];
604 int HalfNumElements = Mask.size() / 2;
605
606 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
607 if (Mask[Idx] != (Idx * 2) + Offset)
608 return false;
609 }
610
611 return true;
612}
613
614bool isNeg(Value *V) {
615 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
616}
617
619 assert(isNeg(V));
620 auto *I = cast<Instruction>(V);
621 if (I->getOpcode() == Instruction::FNeg)
622 return I->getOperand(0);
623
624 return I->getOperand(1);
625}
626
627bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
628 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
629 if (Graph.collectPotentialReductions(B))
630 Graph.identifyReductionNodes();
631
632 for (auto &I : *B)
633 Graph.identifyNodes(&I);
634
635 if (Graph.checkNodes()) {
636 Graph.replaceNodes();
637 return true;
638 }
639
640 return false;
641}
642
643ComplexDeinterleavingGraph::CompositeNode *
644ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
645 Instruction *Real, Instruction *Imag,
646 std::pair<Value *, Value *> &PartialMatch) {
647 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
648 << "\n");
649
650 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
651 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
652 return nullptr;
653 }
654
655 if ((Real->getOpcode() != Instruction::FMul &&
656 Real->getOpcode() != Instruction::Mul) ||
657 (Imag->getOpcode() != Instruction::FMul &&
658 Imag->getOpcode() != Instruction::Mul)) {
660 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
661 return nullptr;
662 }
663
664 Value *R0 = Real->getOperand(0);
665 Value *R1 = Real->getOperand(1);
666 Value *I0 = Imag->getOperand(0);
667 Value *I1 = Imag->getOperand(1);
668
669 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
670 // rotations and use the operand.
671 unsigned Negs = 0;
672 Value *Op;
673 if (match(R0, m_Neg(m_Value(Op)))) {
674 Negs |= 1;
675 R0 = Op;
676 } else if (match(R1, m_Neg(m_Value(Op)))) {
677 Negs |= 1;
678 R1 = Op;
679 }
680
681 if (isNeg(I0)) {
682 Negs |= 2;
683 Negs ^= 1;
684 I0 = Op;
685 } else if (match(I1, m_Neg(m_Value(Op)))) {
686 Negs |= 2;
687 Negs ^= 1;
688 I1 = Op;
689 }
690
692
693 Value *CommonOperand;
694 Value *UncommonRealOp;
695 Value *UncommonImagOp;
696
697 if (R0 == I0 || R0 == I1) {
698 CommonOperand = R0;
699 UncommonRealOp = R1;
700 } else if (R1 == I0 || R1 == I1) {
701 CommonOperand = R1;
702 UncommonRealOp = R0;
703 } else {
704 LLVM_DEBUG(dbgs() << " - No equal operand\n");
705 return nullptr;
706 }
707
708 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
709 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
710 Rotation == ComplexDeinterleavingRotation::Rotation_270)
711 std::swap(UncommonRealOp, UncommonImagOp);
712
713 // Between identifyPartialMul and here we need to have found a complete valid
714 // pair from the CommonOperand of each part.
715 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
716 Rotation == ComplexDeinterleavingRotation::Rotation_180)
717 PartialMatch.first = CommonOperand;
718 else
719 PartialMatch.second = CommonOperand;
720
721 if (!PartialMatch.first || !PartialMatch.second) {
722 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
723 return nullptr;
724 }
725
726 CompositeNode *CommonNode =
727 identifyNode(PartialMatch.first, PartialMatch.second);
728 if (!CommonNode) {
729 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
730 return nullptr;
731 }
732
733 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
734 if (!UncommonNode) {
735 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
736 return nullptr;
737 }
738
739 CompositeNode *Node = prepareCompositeNode(
740 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
741 Node->Rotation = Rotation;
742 Node->addOperand(CommonNode);
743 Node->addOperand(UncommonNode);
744 return submitCompositeNode(Node);
745}
746
747ComplexDeinterleavingGraph::CompositeNode *
748ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
749 Instruction *Imag) {
750 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
751 << "\n");
752
753 // Determine rotation
754 auto IsAdd = [](unsigned Op) {
755 return Op == Instruction::FAdd || Op == Instruction::Add;
756 };
757 auto IsSub = [](unsigned Op) {
758 return Op == Instruction::FSub || Op == Instruction::Sub;
759 };
761 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
762 Rotation = ComplexDeinterleavingRotation::Rotation_0;
763 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
764 Rotation = ComplexDeinterleavingRotation::Rotation_90;
765 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
766 Rotation = ComplexDeinterleavingRotation::Rotation_180;
767 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
768 Rotation = ComplexDeinterleavingRotation::Rotation_270;
769 else {
770 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
771 return nullptr;
772 }
773
774 if (isa<FPMathOperator>(Real) &&
775 (!Real->getFastMathFlags().allowContract() ||
776 !Imag->getFastMathFlags().allowContract())) {
777 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
778 return nullptr;
779 }
780
781 Value *CR = Real->getOperand(0);
782 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
783 if (!RealMulI)
784 return nullptr;
785 Value *CI = Imag->getOperand(0);
786 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
787 if (!ImagMulI)
788 return nullptr;
789
790 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
791 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
792 return nullptr;
793 }
794
795 Value *R0 = RealMulI->getOperand(0);
796 Value *R1 = RealMulI->getOperand(1);
797 Value *I0 = ImagMulI->getOperand(0);
798 Value *I1 = ImagMulI->getOperand(1);
799
800 Value *CommonOperand;
801 Value *UncommonRealOp;
802 Value *UncommonImagOp;
803
804 if (R0 == I0 || R0 == I1) {
805 CommonOperand = R0;
806 UncommonRealOp = R1;
807 } else if (R1 == I0 || R1 == I1) {
808 CommonOperand = R1;
809 UncommonRealOp = R0;
810 } else {
811 LLVM_DEBUG(dbgs() << " - No equal operand\n");
812 return nullptr;
813 }
814
815 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
816 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_270)
818 std::swap(UncommonRealOp, UncommonImagOp);
819
820 std::pair<Value *, Value *> PartialMatch(
821 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
822 Rotation == ComplexDeinterleavingRotation::Rotation_180)
823 ? CommonOperand
824 : nullptr,
825 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
826 Rotation == ComplexDeinterleavingRotation::Rotation_270)
827 ? CommonOperand
828 : nullptr);
829
830 auto *CRInst = dyn_cast<Instruction>(CR);
831 auto *CIInst = dyn_cast<Instruction>(CI);
832
833 if (!CRInst || !CIInst) {
834 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
835 return nullptr;
836 }
837
838 CompositeNode *CNode =
839 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
840 if (!CNode) {
841 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
842 return nullptr;
843 }
844
845 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
846 if (!UncommonRes) {
847 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
848 return nullptr;
849 }
850
851 assert(PartialMatch.first && PartialMatch.second);
852 CompositeNode *CommonRes =
853 identifyNode(PartialMatch.first, PartialMatch.second);
854 if (!CommonRes) {
855 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
856 return nullptr;
857 }
858
859 CompositeNode *Node = prepareCompositeNode(
860 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
861 Node->Rotation = Rotation;
862 Node->addOperand(CommonRes);
863 Node->addOperand(UncommonRes);
864 Node->addOperand(CNode);
865 return submitCompositeNode(Node);
866}
867
868ComplexDeinterleavingGraph::CompositeNode *
869ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
870 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
871
872 // Determine rotation
874 if ((Real->getOpcode() == Instruction::FSub &&
875 Imag->getOpcode() == Instruction::FAdd) ||
876 (Real->getOpcode() == Instruction::Sub &&
877 Imag->getOpcode() == Instruction::Add))
878 Rotation = ComplexDeinterleavingRotation::Rotation_90;
879 else if ((Real->getOpcode() == Instruction::FAdd &&
880 Imag->getOpcode() == Instruction::FSub) ||
881 (Real->getOpcode() == Instruction::Add &&
882 Imag->getOpcode() == Instruction::Sub))
883 Rotation = ComplexDeinterleavingRotation::Rotation_270;
884 else {
885 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
886 return nullptr;
887 }
888
889 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
890 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
891 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
892 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
893
894 if (!AR || !AI || !BR || !BI) {
895 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
896 return nullptr;
897 }
898
899 CompositeNode *ResA = identifyNode(AR, AI);
900 if (!ResA) {
901 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
902 return nullptr;
903 }
904 CompositeNode *ResB = identifyNode(BR, BI);
905 if (!ResB) {
906 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
907 return nullptr;
908 }
909
910 CompositeNode *Node =
911 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
912 Node->Rotation = Rotation;
913 Node->addOperand(ResA);
914 Node->addOperand(ResB);
915 return submitCompositeNode(Node);
916}
917
919 unsigned OpcA = A->getOpcode();
920 unsigned OpcB = B->getOpcode();
921
922 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
923 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
924 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
925 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
926}
927
929 auto Pattern =
931
932 return match(A, Pattern) && match(B, Pattern);
933}
934
936 switch (I->getOpcode()) {
937 case Instruction::FAdd:
938 case Instruction::FSub:
939 case Instruction::FMul:
940 case Instruction::FNeg:
941 case Instruction::Add:
942 case Instruction::Sub:
943 case Instruction::Mul:
944 return true;
945 default:
946 return false;
947 }
948}
949
950ComplexDeinterleavingGraph::CompositeNode *
951ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
952 auto *FirstReal = cast<Instruction>(Vals[0].Real);
953 unsigned FirstOpc = FirstReal->getOpcode();
954 for (auto &V : Vals) {
955 auto *Real = cast<Instruction>(V.Real);
956 auto *Imag = cast<Instruction>(V.Imag);
957 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
958 return nullptr;
959
962 return nullptr;
963
964 if (isa<FPMathOperator>(FirstReal))
965 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
966 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
967 return nullptr;
968 }
969
970 ComplexValues OpVals;
971 for (auto &V : Vals) {
972 auto *R0 = cast<Instruction>(V.Real)->getOperand(0);
973 auto *I0 = cast<Instruction>(V.Imag)->getOperand(0);
974 OpVals.push_back({R0, I0});
975 }
976
977 CompositeNode *Op0 = identifyNode(OpVals);
978 CompositeNode *Op1 = nullptr;
979 if (Op0 == nullptr)
980 return nullptr;
981
982 if (FirstReal->isBinaryOp()) {
983 OpVals.clear();
984 for (auto &V : Vals) {
985 auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
986 auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
987 OpVals.push_back({R1, I1});
988 }
989 Op1 = identifyNode(OpVals);
990 if (Op1 == nullptr)
991 return nullptr;
992 }
993
994 auto Node =
995 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
996 Node->Opcode = FirstReal->getOpcode();
997 if (isa<FPMathOperator>(FirstReal))
998 Node->Flags = FirstReal->getFastMathFlags();
999
1000 Node->addOperand(Op0);
1001 if (FirstReal->isBinaryOp())
1002 Node->addOperand(Op1);
1003
1004 return submitCompositeNode(Node);
1005}
1006
1007ComplexDeinterleavingGraph::CompositeNode *
1008ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
1010 ComplexDeinterleavingOperation::CDot, V->getType())) {
1011 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
1012 "operation CDot with the type "
1013 << *V->getType() << "\n");
1014 return nullptr;
1015 }
1016
1017 auto *Inst = cast<Instruction>(V);
1018 auto *RealUser = cast<Instruction>(*Inst->user_begin());
1019
1020 CompositeNode *CN =
1021 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
1022
1023 CompositeNode *ANode = nullptr;
1024
1025 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1026
1027 Value *AReal = nullptr;
1028 Value *AImag = nullptr;
1029 Value *BReal = nullptr;
1030 Value *BImag = nullptr;
1031 Value *Phi = nullptr;
1032
1033 auto UnwrapCast = [](Value *V) -> Value * {
1034 if (auto *CI = dyn_cast<CastInst>(V))
1035 return CI->getOperand(0);
1036 return V;
1037 };
1038
1039 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1041 m_Mul(m_Value(BReal), m_Value(AReal))),
1042 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
1043
1044 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1046 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
1047 m_Mul(m_Value(BImag), m_Value(AReal)));
1048
1049 if (match(Inst, PatternRot0)) {
1050 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1051 } else if (match(Inst, PatternRot270)) {
1052 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1053 } else {
1054 Value *A0, *A1;
1055 // The rotations 90 and 180 share the same operation pattern, so inspect the
1056 // order of the operands, identifying where the real and imaginary
1057 // components of A go, to discern between the aforementioned rotations.
1058 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1060 m_Mul(m_Value(BReal), m_Value(A0))),
1061 m_Mul(m_Value(BImag), m_Value(A1)));
1062
1063 if (!match(Inst, PatternRot90Rot180))
1064 return nullptr;
1065
1066 A0 = UnwrapCast(A0);
1067 A1 = UnwrapCast(A1);
1068
1069 // Test if A0 is real/A1 is imag
1070 ANode = identifyNode(A0, A1);
1071 if (!ANode) {
1072 // Test if A0 is imag/A1 is real
1073 ANode = identifyNode(A1, A0);
1074 // Unable to identify operand components, thus unable to identify rotation
1075 if (!ANode)
1076 return nullptr;
1077 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1078 AReal = A1;
1079 AImag = A0;
1080 } else {
1081 AReal = A0;
1082 AImag = A1;
1083 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1084 }
1085 }
1086
1087 AReal = UnwrapCast(AReal);
1088 AImag = UnwrapCast(AImag);
1089 BReal = UnwrapCast(BReal);
1090 BImag = UnwrapCast(BImag);
1091
1092 VectorType *VTy = cast<VectorType>(V->getType());
1093 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1094 if (AReal->getType() != ExpectedOperandTy)
1095 return nullptr;
1096 if (AImag->getType() != ExpectedOperandTy)
1097 return nullptr;
1098 if (BReal->getType() != ExpectedOperandTy)
1099 return nullptr;
1100 if (BImag->getType() != ExpectedOperandTy)
1101 return nullptr;
1102
1103 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1104 return nullptr;
1105
1106 CompositeNode *Node = identifyNode(AReal, AImag);
1107
1108 // In the case that a node was identified to figure out the rotation, ensure
1109 // that trying to identify a node with AReal and AImag post-unwrap results in
1110 // the same node
1111 if (ANode && Node != ANode) {
1112 LLVM_DEBUG(
1113 dbgs()
1114 << "Identified node is different from previously identified node. "
1115 "Unable to confidently generate a complex operation node\n");
1116 return nullptr;
1117 }
1118
1119 CN->addOperand(Node);
1120 CN->addOperand(identifyNode(BReal, BImag));
1121 CN->addOperand(identifyNode(Phi, RealUser));
1122
1123 return submitCompositeNode(CN);
1124}
1125
1126ComplexDeinterleavingGraph::CompositeNode *
1127ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1128 // Partial reductions don't support non-vector types, so check these first
1129 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1130 return nullptr;
1131
1132 if (!R->hasUseList() || !I->hasUseList())
1133 return nullptr;
1134
1135 auto CommonUser =
1136 findCommonBetweenCollections<Value *>(R->users(), I->users());
1137 if (!CommonUser)
1138 return nullptr;
1139
1140 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1141 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1142 return nullptr;
1143
1144 if (CompositeNode *CN = identifyDotProduct(IInst))
1145 return CN;
1146
1147 return nullptr;
1148}
1149
1150ComplexDeinterleavingGraph::CompositeNode *
1151ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1152 auto It = CachedResult.find(Vals);
1153 if (It != CachedResult.end()) {
1154 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1155 return It->second;
1156 }
1157
1158 if (Vals.size() == 1) {
1159 assert(Factor == 2 && "Can only handle interleave factors of 2");
1160 Value *R = Vals[0].Real;
1161 Value *I = Vals[0].Imag;
1162 if (CompositeNode *CN = identifyPartialReduction(R, I))
1163 return CN;
1164 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1165 if (!IsReduction && R->getType() != I->getType())
1166 return nullptr;
1167 }
1168
1169 if (CompositeNode *CN = identifySplat(Vals))
1170 return CN;
1171
1172 for (auto &V : Vals) {
1173 auto *Real = dyn_cast<Instruction>(V.Real);
1174 auto *Imag = dyn_cast<Instruction>(V.Imag);
1175 if (!Real || !Imag)
1176 return nullptr;
1177 }
1178
1179 if (CompositeNode *CN = identifyDeinterleave(Vals))
1180 return CN;
1181
1182 if (Vals.size() == 1) {
1183 assert(Factor == 2 && "Can only handle interleave factors of 2");
1184 auto *Real = dyn_cast<Instruction>(Vals[0].Real);
1185 auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
1186 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1187 return CN;
1188
1189 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1190 return CN;
1191
1192 auto *VTy = cast<VectorType>(Real->getType());
1193 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1194
1195 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1196 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1197 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1198 ComplexDeinterleavingOperation::CAdd, NewVTy);
1199
1200 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1201 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1202 return CN;
1203 }
1204
1205 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1206 if (CompositeNode *CN = identifyAdd(Real, Imag))
1207 return CN;
1208 }
1209
1210 if (HasCMulSupport && HasCAddSupport) {
1211 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1212 return CN;
1213 }
1214 }
1215 }
1216
1217 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1218 return CN;
1219
1220 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1221 CachedResult[Vals] = nullptr;
1222 return nullptr;
1223}
1224
1225ComplexDeinterleavingGraph::CompositeNode *
1226ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1227 Instruction *Imag) {
1228 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1229 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1230 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1231 Opcode == Instruction::Sub;
1232 };
1233
1234 if (!IsOperationSupported(Real->getOpcode()) ||
1235 !IsOperationSupported(Imag->getOpcode()))
1236 return nullptr;
1237
1238 std::optional<FastMathFlags> Flags;
1239 if (isa<FPMathOperator>(Real)) {
1240 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1241 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1242 "not identical\n");
1243 return nullptr;
1244 }
1245
1246 Flags = Real->getFastMathFlags();
1247 if (!Flags->allowReassoc()) {
1248 LLVM_DEBUG(
1249 dbgs()
1250 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1251 return nullptr;
1252 }
1253 }
1254
1255 // Collect multiplications and addend instructions from the given instruction
1256 // while traversing it operands. Additionally, verify that all instructions
1257 // have the same fast math flags.
1258 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1259 AddendList &Addends) -> bool {
1260 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1261 SmallPtrSet<Value *, 8> Visited;
1262 while (!Worklist.empty()) {
1263 auto [V, IsPositive] = Worklist.pop_back_val();
1264 if (!Visited.insert(V).second)
1265 continue;
1266
1268 if (!I) {
1269 Addends.emplace_back(V, IsPositive);
1270 continue;
1271 }
1272
1273 // If an instruction has more than one user, it indicates that it either
1274 // has an external user, which will be later checked by the checkNodes
1275 // function, or it is a subexpression utilized by multiple expressions. In
1276 // the latter case, we will attempt to separately identify the complex
1277 // operation from here in order to create a shared
1278 // ComplexDeinterleavingCompositeNode.
1279 if (I != Insn && I->hasNUsesOrMore(2)) {
1280 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1281 Addends.emplace_back(I, IsPositive);
1282 continue;
1283 }
1284 switch (I->getOpcode()) {
1285 case Instruction::FAdd:
1286 case Instruction::Add:
1287 Worklist.emplace_back(I->getOperand(1), IsPositive);
1288 Worklist.emplace_back(I->getOperand(0), IsPositive);
1289 break;
1290 case Instruction::FSub:
1291 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1292 Worklist.emplace_back(I->getOperand(0), IsPositive);
1293 break;
1294 case Instruction::Sub:
1295 if (isNeg(I)) {
1296 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1297 } else {
1298 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1299 Worklist.emplace_back(I->getOperand(0), IsPositive);
1300 }
1301 break;
1302 case Instruction::FMul:
1303 case Instruction::Mul: {
1304 Value *A, *B;
1305 if (isNeg(I->getOperand(0))) {
1306 A = getNegOperand(I->getOperand(0));
1307 IsPositive = !IsPositive;
1308 } else {
1309 A = I->getOperand(0);
1310 }
1311
1312 if (isNeg(I->getOperand(1))) {
1313 B = getNegOperand(I->getOperand(1));
1314 IsPositive = !IsPositive;
1315 } else {
1316 B = I->getOperand(1);
1317 }
1318 Muls.push_back(Product{A, B, IsPositive});
1319 break;
1320 }
1321 case Instruction::FNeg:
1322 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1323 break;
1324 default:
1325 Addends.emplace_back(I, IsPositive);
1326 continue;
1327 }
1328
1329 if (Flags && I->getFastMathFlags() != *Flags) {
1330 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1331 "inconsistent with the root instructions' flags: "
1332 << *I << "\n");
1333 return false;
1334 }
1335 }
1336 return true;
1337 };
1338
1339 SmallVector<Product> RealMuls, ImagMuls;
1340 AddendList RealAddends, ImagAddends;
1341 if (!Collect(Real, RealMuls, RealAddends) ||
1342 !Collect(Imag, ImagMuls, ImagAddends))
1343 return nullptr;
1344
1345 if (RealAddends.size() != ImagAddends.size())
1346 return nullptr;
1347
1348 CompositeNode *FinalNode = nullptr;
1349 if (!RealMuls.empty() || !ImagMuls.empty()) {
1350 // If there are multiplicands, extract positive addend and use it as an
1351 // accumulator
1352 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1353 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1354 if (!FinalNode)
1355 return nullptr;
1356 }
1357
1358 // Identify and process remaining additions
1359 if (!RealAddends.empty() || !ImagAddends.empty()) {
1360 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1361 if (!FinalNode)
1362 return nullptr;
1363 }
1364 assert(FinalNode && "FinalNode can not be nullptr here");
1365 assert(FinalNode->Vals.size() == 1);
1366 // Set the Real and Imag fields of the final node and submit it
1367 FinalNode->Vals[0].Real = Real;
1368 FinalNode->Vals[0].Imag = Imag;
1369 submitCompositeNode(FinalNode);
1370 return FinalNode;
1371}
1372
1373bool ComplexDeinterleavingGraph::collectPartialMuls(
1374 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1375 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1376 // Helper function to extract a common operand from two products
1377 auto FindCommonInstruction = [](const Product &Real,
1378 const Product &Imag) -> Value * {
1379 if (Real.Multiplicand == Imag.Multiplicand ||
1380 Real.Multiplicand == Imag.Multiplier)
1381 return Real.Multiplicand;
1382
1383 if (Real.Multiplier == Imag.Multiplicand ||
1384 Real.Multiplier == Imag.Multiplier)
1385 return Real.Multiplier;
1386
1387 return nullptr;
1388 };
1389
1390 // Iterating over real and imaginary multiplications to find common operands
1391 // If a common operand is found, a partial multiplication candidate is created
1392 // and added to the candidates vector The function returns false if no common
1393 // operands are found for any product
1394 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1395 bool FoundCommon = false;
1396 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1397 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1398 if (!Common)
1399 continue;
1400
1401 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1402 : RealMuls[i].Multiplicand;
1403 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1404 : ImagMuls[j].Multiplicand;
1405
1406 auto Node = identifyNode(A, B);
1407 if (Node) {
1408 FoundCommon = true;
1409 PartialMulCandidates.push_back({Common, Node, i, j, false});
1410 }
1411
1412 Node = identifyNode(B, A);
1413 if (Node) {
1414 FoundCommon = true;
1415 PartialMulCandidates.push_back({Common, Node, i, j, true});
1416 }
1417 }
1418 if (!FoundCommon)
1419 return false;
1420 }
1421 return true;
1422}
1423
1424ComplexDeinterleavingGraph::CompositeNode *
1425ComplexDeinterleavingGraph::identifyMultiplications(
1426 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1427 CompositeNode *Accumulator = nullptr) {
1428 if (RealMuls.size() != ImagMuls.size())
1429 return nullptr;
1430
1432 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1433 return nullptr;
1434
1435 // Map to store common instruction to node pointers
1436 DenseMap<Value *, CompositeNode *> CommonToNode;
1437 SmallVector<bool> Processed(Info.size(), false);
1438 for (unsigned I = 0; I < Info.size(); ++I) {
1439 if (Processed[I])
1440 continue;
1441
1442 PartialMulCandidate &InfoA = Info[I];
1443 for (unsigned J = I + 1; J < Info.size(); ++J) {
1444 if (Processed[J])
1445 continue;
1446
1447 PartialMulCandidate &InfoB = Info[J];
1448 auto *InfoReal = &InfoA;
1449 auto *InfoImag = &InfoB;
1450
1451 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1452 if (!NodeFromCommon) {
1453 std::swap(InfoReal, InfoImag);
1454 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1455 }
1456 if (!NodeFromCommon)
1457 continue;
1458
1459 CommonToNode[InfoReal->Common] = NodeFromCommon;
1460 CommonToNode[InfoImag->Common] = NodeFromCommon;
1461 Processed[I] = true;
1462 Processed[J] = true;
1463 }
1464 }
1465
1466 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1467 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1468 CompositeNode *Result = Accumulator;
1469 for (auto &PMI : Info) {
1470 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1471 continue;
1472
1473 auto It = CommonToNode.find(PMI.Common);
1474 // TODO: Process independent complex multiplications. Cases like this:
1475 // A.real() * B where both A and B are complex numbers.
1476 if (It == CommonToNode.end()) {
1477 LLVM_DEBUG({
1478 dbgs() << "Unprocessed independent partial multiplication:\n";
1479 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1480 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1481 << " multiplied by " << *Mul->Multiplicand << "\n";
1482 });
1483 return nullptr;
1484 }
1485
1486 auto &RealMul = RealMuls[PMI.RealIdx];
1487 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1488
1489 auto NodeA = It->second;
1490 auto NodeB = PMI.Node;
1491 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1492 // The following table illustrates the relationship between multiplications
1493 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1494 // can see:
1495 //
1496 // Rotation | Real | Imag |
1497 // ---------+--------+--------+
1498 // 0 | x * u | x * v |
1499 // 90 | -y * v | y * u |
1500 // 180 | -x * u | -x * v |
1501 // 270 | y * v | -y * u |
1502 //
1503 // Check if the candidate can indeed be represented by partial
1504 // multiplication
1505 // TODO: Add support for multiplication by complex one
1506 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1507 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1508 continue;
1509
1510 // Determine the rotation based on the multiplications
1512 if (IsMultiplicandReal) {
1513 // Detect 0 and 180 degrees rotation
1514 if (RealMul.IsPositive && ImagMul.IsPositive)
1516 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1518 else
1519 continue;
1520
1521 } else {
1522 // Detect 90 and 270 degrees rotation
1523 if (!RealMul.IsPositive && ImagMul.IsPositive)
1525 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1527 else
1528 continue;
1529 }
1530
1531 LLVM_DEBUG({
1532 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1533 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1534 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1535 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1536 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1537 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1538 });
1539
1540 CompositeNode *NodeMul = prepareCompositeNode(
1541 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1542 NodeMul->Rotation = Rotation;
1543 NodeMul->addOperand(NodeA);
1544 NodeMul->addOperand(NodeB);
1545 if (Result)
1546 NodeMul->addOperand(Result);
1547 submitCompositeNode(NodeMul);
1548 Result = NodeMul;
1549 ProcessedReal[PMI.RealIdx] = true;
1550 ProcessedImag[PMI.ImagIdx] = true;
1551 }
1552
1553 // Ensure all products have been processed, if not return nullptr.
1554 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1555 !all_of(ProcessedImag, [](bool V) { return V; })) {
1556
1557 // Dump debug information about which partial multiplications are not
1558 // processed.
1559 LLVM_DEBUG({
1560 dbgs() << "Unprocessed products (Real):\n";
1561 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1562 if (!ProcessedReal[i])
1563 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1564 << *RealMuls[i].Multiplier << " multiplied by "
1565 << *RealMuls[i].Multiplicand << "\n";
1566 }
1567 dbgs() << "Unprocessed products (Imag):\n";
1568 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1569 if (!ProcessedImag[i])
1570 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1571 << *ImagMuls[i].Multiplier << " multiplied by "
1572 << *ImagMuls[i].Multiplicand << "\n";
1573 }
1574 });
1575 return nullptr;
1576 }
1577
1578 return Result;
1579}
1580
1581ComplexDeinterleavingGraph::CompositeNode *
1582ComplexDeinterleavingGraph::identifyAdditions(
1583 AddendList &RealAddends, AddendList &ImagAddends,
1584 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1585 if (RealAddends.size() != ImagAddends.size())
1586 return nullptr;
1587
1588 CompositeNode *Result = nullptr;
1589 // If we have accumulator use it as first addend
1590 if (Accumulator)
1592 // Otherwise find an element with both positive real and imaginary parts.
1593 else
1594 Result = extractPositiveAddend(RealAddends, ImagAddends);
1595
1596 if (!Result)
1597 return nullptr;
1598
1599 while (!RealAddends.empty()) {
1600 auto ItR = RealAddends.begin();
1601 auto [R, IsPositiveR] = *ItR;
1602
1603 bool FoundImag = false;
1604 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1605 auto [I, IsPositiveI] = *ItI;
1607 if (IsPositiveR && IsPositiveI)
1608 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1609 else if (!IsPositiveR && IsPositiveI)
1610 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1611 else if (!IsPositiveR && !IsPositiveI)
1612 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1613 else
1614 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1615
1616 CompositeNode *AddNode = nullptr;
1617 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1618 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1619 AddNode = identifyNode(R, I);
1620 } else {
1621 AddNode = identifyNode(I, R);
1622 }
1623 if (AddNode) {
1624 LLVM_DEBUG({
1625 dbgs() << "Identified addition:\n";
1626 dbgs().indent(4) << "X: " << *R << "\n";
1627 dbgs().indent(4) << "Y: " << *I << "\n";
1628 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1629 });
1630
1631 CompositeNode *TmpNode = nullptr;
1633 TmpNode = prepareCompositeNode(
1634 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1635 if (Flags) {
1636 TmpNode->Opcode = Instruction::FAdd;
1637 TmpNode->Flags = *Flags;
1638 } else {
1639 TmpNode->Opcode = Instruction::Add;
1640 }
1641 } else if (Rotation ==
1643 TmpNode = prepareCompositeNode(
1644 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1645 if (Flags) {
1646 TmpNode->Opcode = Instruction::FSub;
1647 TmpNode->Flags = *Flags;
1648 } else {
1649 TmpNode->Opcode = Instruction::Sub;
1650 }
1651 } else {
1652 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1653 nullptr, nullptr);
1654 TmpNode->Rotation = Rotation;
1655 }
1656
1657 TmpNode->addOperand(Result);
1658 TmpNode->addOperand(AddNode);
1659 submitCompositeNode(TmpNode);
1660 Result = TmpNode;
1661 RealAddends.erase(ItR);
1662 ImagAddends.erase(ItI);
1663 FoundImag = true;
1664 break;
1665 }
1666 }
1667 if (!FoundImag)
1668 return nullptr;
1669 }
1670 return Result;
1671}
1672
1673ComplexDeinterleavingGraph::CompositeNode *
1674ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1675 AddendList &ImagAddends) {
1676 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1677 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1678 auto [R, IsPositiveR] = *ItR;
1679 auto [I, IsPositiveI] = *ItI;
1680 if (IsPositiveR && IsPositiveI) {
1681 auto Result = identifyNode(R, I);
1682 if (Result) {
1683 RealAddends.erase(ItR);
1684 ImagAddends.erase(ItI);
1685 return Result;
1686 }
1687 }
1688 }
1689 }
1690 return nullptr;
1691}
1692
1693bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1694 // This potential root instruction might already have been recognized as
1695 // reduction. Because RootToNode maps both Real and Imaginary parts to
1696 // CompositeNode we should choose only one either Real or Imag instruction to
1697 // use as an anchor for generating complex instruction.
1698 auto It = RootToNode.find(RootI);
1699 if (It != RootToNode.end()) {
1700 auto RootNode = It->second;
1701 assert(RootNode->Operation ==
1702 ComplexDeinterleavingOperation::ReductionOperation ||
1703 RootNode->Operation ==
1704 ComplexDeinterleavingOperation::ReductionSingle);
1705 assert(RootNode->Vals.size() == 1 &&
1706 "Cannot handle reductions involving multiple complex values");
1707 // Find out which part, Real or Imag, comes later, and only if we come to
1708 // the latest part, add it to OrderedRoots.
1709 auto *R = cast<Instruction>(RootNode->Vals[0].Real);
1710 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
1711 : nullptr;
1712
1713 Instruction *ReplacementAnchor;
1714 if (I)
1715 ReplacementAnchor = R->comesBefore(I) ? I : R;
1716 else
1717 ReplacementAnchor = R;
1718
1719 if (ReplacementAnchor != RootI)
1720 return false;
1721 OrderedRoots.push_back(RootI);
1722 return true;
1723 }
1724
1725 auto RootNode = identifyRoot(RootI);
1726 if (!RootNode)
1727 return false;
1728
1729 LLVM_DEBUG({
1730 Function *F = RootI->getFunction();
1731 BasicBlock *B = RootI->getParent();
1732 dbgs() << "Complex deinterleaving graph for " << F->getName()
1733 << "::" << B->getName() << ".\n";
1734 dump(dbgs());
1735 dbgs() << "\n";
1736 });
1737 RootToNode[RootI] = RootNode;
1738 OrderedRoots.push_back(RootI);
1739 return true;
1740}
1741
1742bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1743 bool FoundPotentialReduction = false;
1744 if (Factor != 2)
1745 return false;
1746
1747 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1748 if (!Br || Br->getNumSuccessors() != 2)
1749 return false;
1750
1751 // Identify simple one-block loop
1752 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1753 return false;
1754
1755 for (auto &PHI : B->phis()) {
1756 if (PHI.getNumIncomingValues() != 2)
1757 continue;
1758
1759 if (!PHI.getType()->isVectorTy())
1760 continue;
1761
1762 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1763 if (!ReductionOp)
1764 continue;
1765
1766 // Check if final instruction is reduced outside of current block
1767 Instruction *FinalReduction = nullptr;
1768 auto NumUsers = 0u;
1769 for (auto *U : ReductionOp->users()) {
1770 ++NumUsers;
1771 if (U == &PHI)
1772 continue;
1773 FinalReduction = dyn_cast<Instruction>(U);
1774 }
1775
1776 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1777 isa<PHINode>(FinalReduction))
1778 continue;
1779
1780 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1781 BackEdge = B;
1782 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1783 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1784 Incoming = PHI.getIncomingBlock(IncomingIdx);
1785 FoundPotentialReduction = true;
1786
1787 // If the initial value of PHINode is an Instruction, consider it a leaf
1788 // value of a complex deinterleaving graph.
1789 if (auto *InitPHI =
1790 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1791 FinalInstructions.insert(InitPHI);
1792 }
1793 return FoundPotentialReduction;
1794}
1795
1796void ComplexDeinterleavingGraph::identifyReductionNodes() {
1797 assert(Factor == 2 && "Cannot handle multiple complex values");
1798
1799 SmallVector<bool> Processed(ReductionInfo.size(), false);
1800 SmallVector<Instruction *> OperationInstruction;
1801 for (auto &P : ReductionInfo)
1802 OperationInstruction.push_back(P.first);
1803
1804 // Identify a complex computation by evaluating two reduction operations that
1805 // potentially could be involved
1806 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1807 if (Processed[i])
1808 continue;
1809 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1810 if (Processed[j])
1811 continue;
1812 auto *Real = OperationInstruction[i];
1813 auto *Imag = OperationInstruction[j];
1814 if (Real->getType() != Imag->getType())
1815 continue;
1816
1817 RealPHI = ReductionInfo[Real].first;
1818 ImagPHI = ReductionInfo[Imag].first;
1819 PHIsFound = false;
1820 auto Node = identifyNode(Real, Imag);
1821 if (!Node) {
1822 std::swap(Real, Imag);
1823 std::swap(RealPHI, ImagPHI);
1824 Node = identifyNode(Real, Imag);
1825 }
1826
1827 // If a node is identified and reduction PHINode is used in the chain of
1828 // operations, mark its operation instructions as used to prevent
1829 // re-identification and attach the node to the real part
1830 if (Node && PHIsFound) {
1831 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1832 << *Real << " / " << *Imag << "\n");
1833 Processed[i] = true;
1834 Processed[j] = true;
1835 auto RootNode = prepareCompositeNode(
1836 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1837 RootNode->addOperand(Node);
1838 RootToNode[Real] = RootNode;
1839 RootToNode[Imag] = RootNode;
1840 submitCompositeNode(RootNode);
1841 break;
1842 }
1843 }
1844
1845 auto *Real = OperationInstruction[i];
1846 // We want to check that we have 2 operands, but the function attributes
1847 // being counted as operands bloats this value.
1848 if (Processed[i] || Real->getNumOperands() < 2)
1849 continue;
1850
1851 // Can only combined integer reductions at the moment.
1852 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1853 continue;
1854
1855 RealPHI = ReductionInfo[Real].first;
1856 ImagPHI = nullptr;
1857 PHIsFound = false;
1858 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1859 if (Node && PHIsFound) {
1860 LLVM_DEBUG(
1861 dbgs() << "Identified single reduction starting from instruction: "
1862 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1863
1864 // Reducing to a single vector is not supported, only permit reducing down
1865 // to scalar values.
1866 // Doing this here will leave the prior node in the graph,
1867 // however with no uses the node will be unreachable by the replacement
1868 // process. That along with the usage outside the graph should prevent the
1869 // replacement process from kicking off at all for this graph.
1870 // TODO Add support for reducing to a single vector value
1871 if (ReductionInfo[Real].second->getType()->isVectorTy())
1872 continue;
1873
1874 Processed[i] = true;
1875 auto RootNode = prepareCompositeNode(
1876 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1877 RootNode->addOperand(Node);
1878 RootToNode[Real] = RootNode;
1879 submitCompositeNode(RootNode);
1880 }
1881 }
1882
1883 RealPHI = nullptr;
1884 ImagPHI = nullptr;
1885}
1886
1887bool ComplexDeinterleavingGraph::checkNodes() {
1888 bool FoundDeinterleaveNode = false;
1889 for (CompositeNode *N : CompositeNodes) {
1890 if (!N->areOperandsValid())
1891 return false;
1892
1893 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1894 FoundDeinterleaveNode = true;
1895 }
1896
1897 // We need a deinterleave node in order to guarantee that we're working with
1898 // complex numbers.
1899 if (!FoundDeinterleaveNode) {
1900 LLVM_DEBUG(
1901 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1902 "guarantee safety during graph transformation.\n");
1903 return false;
1904 }
1905
1906 // Collect all instructions from roots to leaves
1907 SmallPtrSet<Instruction *, 16> AllInstructions;
1908 SmallVector<Instruction *, 8> Worklist;
1909 for (auto &Pair : RootToNode)
1910 Worklist.push_back(Pair.first);
1911
1912 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1913 // chains
1914 while (!Worklist.empty()) {
1915 auto *I = Worklist.pop_back_val();
1916
1917 if (!AllInstructions.insert(I).second)
1918 continue;
1919
1920 for (Value *Op : I->operands()) {
1921 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1922 if (!FinalInstructions.count(I))
1923 Worklist.emplace_back(OpI);
1924 }
1925 }
1926 }
1927
1928 // Find instructions that have users outside of chain
1929 for (auto *I : AllInstructions) {
1930 // Skip root nodes
1931 if (RootToNode.count(I))
1932 continue;
1933
1934 for (User *U : I->users()) {
1935 if (AllInstructions.count(cast<Instruction>(U)))
1936 continue;
1937
1938 // Found an instruction that is not used by XCMLA/XCADD chain
1939 Worklist.emplace_back(I);
1940 break;
1941 }
1942 }
1943
1944 // If any instructions are found to be used outside, find and remove roots
1945 // that somehow connect to those instructions.
1946 SmallPtrSet<Instruction *, 16> Visited;
1947 while (!Worklist.empty()) {
1948 auto *I = Worklist.pop_back_val();
1949 if (!Visited.insert(I).second)
1950 continue;
1951
1952 // Found an impacted root node. Removing it from the nodes to be
1953 // deinterleaved
1954 if (RootToNode.count(I)) {
1955 LLVM_DEBUG(dbgs() << "Instruction " << *I
1956 << " could be deinterleaved but its chain of complex "
1957 "operations have an outside user\n");
1958 RootToNode.erase(I);
1959 }
1960
1961 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1962 continue;
1963
1964 for (User *U : I->users())
1965 Worklist.emplace_back(cast<Instruction>(U));
1966
1967 for (Value *Op : I->operands()) {
1968 if (auto *OpI = dyn_cast<Instruction>(Op))
1969 Worklist.emplace_back(OpI);
1970 }
1971 }
1972 return !RootToNode.empty();
1973}
1974
1975ComplexDeinterleavingGraph::CompositeNode *
1976ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1977 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1979 Intrinsic->getIntrinsicID())
1980 return nullptr;
1981
1982 ComplexValues Vals;
1983 for (unsigned I = 0; I < Factor; I += 2) {
1984 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(I));
1985 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(I + 1));
1986 if (!Real || !Imag)
1987 return nullptr;
1988 Vals.push_back({Real, Imag});
1989 }
1990
1991 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1992 if (!Node1)
1993 return nullptr;
1994 return Node1;
1995 }
1996
1997 // TODO: We could also add support for fixed-width interleave factors of 4
1998 // and above, but currently for symmetric operations the interleaves and
1999 // deinterleaves are already removed by VectorCombine. If we extend this to
2000 // permit complex multiplications, reductions, etc. then we should also add
2001 // support for fixed-width here.
2002 if (Factor != 2)
2003 return nullptr;
2004
2005 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
2006 if (!SVI)
2007 return nullptr;
2008
2009 // Look for a shufflevector that takes separate vectors of the real and
2010 // imaginary components and recombines them into a single vector.
2011 if (!isInterleavingMask(SVI->getShuffleMask()))
2012 return nullptr;
2013
2014 Instruction *Real;
2015 Instruction *Imag;
2016 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
2017 return nullptr;
2018
2019 return identifyNode(Real, Imag);
2020}
2021
2022ComplexDeinterleavingGraph::CompositeNode *
2023ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2024 Instruction *II = nullptr;
2025
2026 // Must be at least one complex value.
2027 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2028 Instruction *ExpectedInsn) -> ExtractValueInst * {
2029 auto *EVI = dyn_cast<ExtractValueInst>(V);
2030 if (!EVI || EVI->getNumIndices() != 1 ||
2031 EVI->getIndices()[0] != ExpectedIdx ||
2032 !isa<Instruction>(EVI->getAggregateOperand()) ||
2033 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2034 return nullptr;
2035 return EVI;
2036 };
2037
2038 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2039 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2040 if (RealEVI && Idx == 0)
2042 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2043 II = nullptr;
2044 break;
2045 }
2046 }
2047
2048 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(II)) {
2049 if (IntrinsicII->getIntrinsicID() !=
2051 return nullptr;
2052
2053 // The remaining should match too.
2054 CompositeNode *PlaceholderNode = prepareCompositeNode(
2056 PlaceholderNode->ReplacementNode = II->getOperand(0);
2057 for (auto &V : Vals) {
2058 FinalInstructions.insert(cast<Instruction>(V.Real));
2059 FinalInstructions.insert(cast<Instruction>(V.Imag));
2060 }
2061 return submitCompositeNode(PlaceholderNode);
2062 }
2063
2064 if (Vals.size() != 1)
2065 return nullptr;
2066
2067 Value *Real = Vals[0].Real;
2068 Value *Imag = Vals[0].Imag;
2069 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
2070 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
2071 if (!RealShuffle || !ImagShuffle) {
2072 if (RealShuffle || ImagShuffle)
2073 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2074 return nullptr;
2075 }
2076
2077 Value *RealOp1 = RealShuffle->getOperand(1);
2078 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
2079 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2080 return nullptr;
2081 }
2082 Value *ImagOp1 = ImagShuffle->getOperand(1);
2083 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
2084 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2085 return nullptr;
2086 }
2087
2088 Value *RealOp0 = RealShuffle->getOperand(0);
2089 Value *ImagOp0 = ImagShuffle->getOperand(0);
2090
2091 if (RealOp0 != ImagOp0) {
2092 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2093 return nullptr;
2094 }
2095
2096 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2097 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2098 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
2099 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2100 return nullptr;
2101 }
2102
2103 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2104 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2105 return nullptr;
2106 }
2107
2108 // Type checking, the shuffle type should be a vector type of the same
2109 // scalar type, but half the size
2110 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2111 Value *Op = Shuffle->getOperand(0);
2112 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
2113 auto *OpTy = cast<FixedVectorType>(Op->getType());
2114
2115 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2116 return false;
2117 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2118 return false;
2119
2120 return true;
2121 };
2122
2123 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2124 if (!CheckType(Shuffle))
2125 return false;
2126
2127 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2128 int Last = *Mask.rbegin();
2129
2130 Value *Op = Shuffle->getOperand(0);
2131 auto *OpTy = cast<FixedVectorType>(Op->getType());
2132 int NumElements = OpTy->getNumElements();
2133
2134 // Ensure that the deinterleaving shuffle only pulls from the first
2135 // shuffle operand.
2136 return Last < NumElements;
2137 };
2138
2139 if (RealShuffle->getType() != ImagShuffle->getType()) {
2140 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2141 return nullptr;
2142 }
2143 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2144 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2145 return nullptr;
2146 }
2147 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2148 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2149 return nullptr;
2150 }
2151
2152 CompositeNode *PlaceholderNode =
2154 RealShuffle, ImagShuffle);
2155 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2156 FinalInstructions.insert(RealShuffle);
2157 FinalInstructions.insert(ImagShuffle);
2158 return submitCompositeNode(PlaceholderNode);
2159}
2160
2161ComplexDeinterleavingGraph::CompositeNode *
2162ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2163 auto IsSplat = [](Value *V) -> bool {
2164 // Fixed-width vector with constants
2166 return true;
2167
2168 if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2169 return isa<VectorType>(V->getType());
2170
2171 VectorType *VTy;
2172 ArrayRef<int> Mask;
2173 // Splats are represented differently depending on whether the repeated
2174 // value is a constant or an Instruction
2175 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2176 if (Const->getOpcode() != Instruction::ShuffleVector)
2177 return false;
2178 VTy = cast<VectorType>(Const->getType());
2179 Mask = Const->getShuffleMask();
2180 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2181 VTy = Shuf->getType();
2182 Mask = Shuf->getShuffleMask();
2183 } else {
2184 return false;
2185 }
2186
2187 // When the data type is <1 x Type>, it's not possible to differentiate
2188 // between the ComplexDeinterleaving::Deinterleave and
2189 // ComplexDeinterleaving::Splat operations.
2190 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2191 return false;
2192
2193 return all_equal(Mask) && Mask[0] == 0;
2194 };
2195
2196 // The splats must meet the following requirements:
2197 // 1. Must either be all instructions or all values.
2198 // 2. Non-constant splats must live in the same block.
2199 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Vals[0].Real)) {
2200 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2201 for (auto &V : Vals) {
2202 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2203 return nullptr;
2204
2205 auto *Real = dyn_cast<Instruction>(V.Real);
2206 auto *Imag = dyn_cast<Instruction>(V.Imag);
2207 if (!Real || !Imag || Real->getParent() != FirstBB ||
2208 Imag->getParent() != FirstBB)
2209 return nullptr;
2210 }
2211 } else {
2212 for (auto &V : Vals) {
2213 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(V.Real) ||
2214 isa<Instruction>(V.Imag))
2215 return nullptr;
2216 }
2217 }
2218
2219 for (auto &V : Vals) {
2220 auto *Real = dyn_cast<Instruction>(V.Real);
2221 auto *Imag = dyn_cast<Instruction>(V.Imag);
2222 if (Real && Imag) {
2223 FinalInstructions.insert(Real);
2224 FinalInstructions.insert(Imag);
2225 }
2226 }
2227 CompositeNode *PlaceholderNode =
2228 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2229 return submitCompositeNode(PlaceholderNode);
2230}
2231
2232ComplexDeinterleavingGraph::CompositeNode *
2233ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2234 Instruction *Imag) {
2235 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2236 return nullptr;
2237
2238 PHIsFound = true;
2239 CompositeNode *PlaceholderNode = prepareCompositeNode(
2240 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2241 return submitCompositeNode(PlaceholderNode);
2242}
2243
2244ComplexDeinterleavingGraph::CompositeNode *
2245ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2246 Instruction *Imag) {
2247 auto *SelectReal = dyn_cast<SelectInst>(Real);
2248 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2249 if (!SelectReal || !SelectImag)
2250 return nullptr;
2251
2252 Instruction *MaskA, *MaskB;
2253 Instruction *AR, *AI, *RA, *BI;
2254 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2255 m_Instruction(RA))) ||
2256 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2257 m_Instruction(BI))))
2258 return nullptr;
2259
2260 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2261 return nullptr;
2262
2263 if (!MaskA->getType()->isVectorTy())
2264 return nullptr;
2265
2266 auto NodeA = identifyNode(AR, AI);
2267 if (!NodeA)
2268 return nullptr;
2269
2270 auto NodeB = identifyNode(RA, BI);
2271 if (!NodeB)
2272 return nullptr;
2273
2274 CompositeNode *PlaceholderNode = prepareCompositeNode(
2275 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2276 PlaceholderNode->addOperand(NodeA);
2277 PlaceholderNode->addOperand(NodeB);
2278 FinalInstructions.insert(MaskA);
2279 FinalInstructions.insert(MaskB);
2280 return submitCompositeNode(PlaceholderNode);
2281}
2282
2283static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2284 std::optional<FastMathFlags> Flags,
2285 Value *InputA, Value *InputB) {
2286 Value *I;
2287 switch (Opcode) {
2288 case Instruction::FNeg:
2289 I = B.CreateFNeg(InputA);
2290 break;
2291 case Instruction::FAdd:
2292 I = B.CreateFAdd(InputA, InputB);
2293 break;
2294 case Instruction::Add:
2295 I = B.CreateAdd(InputA, InputB);
2296 break;
2297 case Instruction::FSub:
2298 I = B.CreateFSub(InputA, InputB);
2299 break;
2300 case Instruction::Sub:
2301 I = B.CreateSub(InputA, InputB);
2302 break;
2303 case Instruction::FMul:
2304 I = B.CreateFMul(InputA, InputB);
2305 break;
2306 case Instruction::Mul:
2307 I = B.CreateMul(InputA, InputB);
2308 break;
2309 default:
2310 llvm_unreachable("Incorrect symmetric opcode");
2311 }
2312 if (Flags)
2313 cast<Instruction>(I)->setFastMathFlags(*Flags);
2314 return I;
2315}
2316
2317Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2318 CompositeNode *Node) {
2319 if (Node->ReplacementNode)
2320 return Node->ReplacementNode;
2321
2322 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2323 unsigned Idx) -> Value * {
2324 return Node->Operands.size() > Idx
2325 ? replaceNode(Builder, Node->Operands[Idx])
2326 : nullptr;
2327 };
2328
2329 Value *ReplacementNode = nullptr;
2330 switch (Node->Operation) {
2331 case ComplexDeinterleavingOperation::CDot: {
2332 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2333 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2334 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2335 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2336 "Node inputs need to be of the same type"));
2337 ReplacementNode = TL->createComplexDeinterleavingIR(
2338 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2339 break;
2340 }
2341 case ComplexDeinterleavingOperation::CAdd:
2342 case ComplexDeinterleavingOperation::CMulPartial:
2343 case ComplexDeinterleavingOperation::Symmetric: {
2344 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2345 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2346 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2347 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2348 "Node inputs need to be of the same type"));
2350 (Input0->getType() == Accumulator->getType() &&
2351 "Accumulator and input need to be of the same type"));
2352 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2353 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2354 Input0, Input1);
2355 else
2356 ReplacementNode = TL->createComplexDeinterleavingIR(
2357 Builder, Node->Operation, Node->Rotation, Input0, Input1,
2358 Accumulator);
2359 break;
2360 }
2361 case ComplexDeinterleavingOperation::Deinterleave:
2362 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2363 break;
2364 case ComplexDeinterleavingOperation::Splat: {
2366 for (auto &V : Node->Vals) {
2367 Ops.push_back(V.Real);
2368 Ops.push_back(V.Imag);
2369 }
2370 auto *R = dyn_cast<Instruction>(Node->Vals[0].Real);
2371 auto *I = dyn_cast<Instruction>(Node->Vals[0].Imag);
2372 if (R && I) {
2373 // Splats that are not constant are interleaved where they are located
2374 Instruction *InsertPoint = R;
2375 for (auto V : Node->Vals) {
2376 if (InsertPoint->comesBefore(cast<Instruction>(V.Real)))
2377 InsertPoint = cast<Instruction>(V.Real);
2378 if (InsertPoint->comesBefore(cast<Instruction>(V.Imag)))
2379 InsertPoint = cast<Instruction>(V.Imag);
2380 }
2381 InsertPoint = InsertPoint->getNextNode();
2382 IRBuilder<> IRB(InsertPoint);
2383 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2384 } else {
2385 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2386 }
2387 break;
2388 }
2389 case ComplexDeinterleavingOperation::ReductionPHI: {
2390 // If Operation is ReductionPHI, a new empty PHINode is created.
2391 // It is filled later when the ReductionOperation is processed.
2392 auto *OldPHI = cast<PHINode>(Node->Vals[0].Real);
2393 auto *VTy = cast<VectorType>(Node->Vals[0].Real->getType());
2394 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2395 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2396 OldToNewPHI[OldPHI] = NewPHI;
2397 ReplacementNode = NewPHI;
2398 break;
2399 }
2400 case ComplexDeinterleavingOperation::ReductionSingle:
2401 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2402 processReductionSingle(ReplacementNode, Node);
2403 break;
2404 case ComplexDeinterleavingOperation::ReductionOperation:
2405 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2406 processReductionOperation(ReplacementNode, Node);
2407 break;
2408 case ComplexDeinterleavingOperation::ReductionSelect: {
2409 auto *MaskReal = cast<Instruction>(Node->Vals[0].Real)->getOperand(0);
2410 auto *MaskImag = cast<Instruction>(Node->Vals[0].Imag)->getOperand(0);
2411 auto *A = replaceNode(Builder, Node->Operands[0]);
2412 auto *B = replaceNode(Builder, Node->Operands[1]);
2413 auto *NewMask = Builder.CreateVectorInterleave({MaskReal, MaskImag});
2414 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2415 break;
2416 }
2417 }
2418
2419 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2420 NumComplexTransformations += 1;
2421 Node->ReplacementNode = ReplacementNode;
2422 return ReplacementNode;
2423}
2424
2425void ComplexDeinterleavingGraph::processReductionSingle(
2426 Value *OperationReplacement, CompositeNode *Node) {
2427 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2428 auto *OldPHI = ReductionInfo[Real].first;
2429 auto *NewPHI = OldToNewPHI[OldPHI];
2430 auto *VTy = cast<VectorType>(Real->getType());
2431 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2432
2433 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2434
2435 IRBuilder<> Builder(Incoming->getTerminator());
2436
2437 Value *NewInit = nullptr;
2438 if (auto *C = dyn_cast<Constant>(Init)) {
2439 if (C->isZeroValue())
2440 NewInit = Constant::getNullValue(NewVTy);
2441 }
2442
2443 if (!NewInit)
2444 NewInit =
2445 Builder.CreateVectorInterleave({Init, Constant::getNullValue(VTy)});
2446
2447 NewPHI->addIncoming(NewInit, Incoming);
2448 NewPHI->addIncoming(OperationReplacement, BackEdge);
2449
2450 auto *FinalReduction = ReductionInfo[Real].second;
2451 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2452
2453 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2454 FinalReduction->replaceAllUsesWith(AddReduce);
2455}
2456
2457void ComplexDeinterleavingGraph::processReductionOperation(
2458 Value *OperationReplacement, CompositeNode *Node) {
2459 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2460 auto *Imag = cast<Instruction>(Node->Vals[0].Imag);
2461 auto *OldPHIReal = ReductionInfo[Real].first;
2462 auto *OldPHIImag = ReductionInfo[Imag].first;
2463 auto *NewPHI = OldToNewPHI[OldPHIReal];
2464
2465 // We have to interleave initial origin values coming from IncomingBlock
2466 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2467 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2468
2469 IRBuilder<> Builder(Incoming->getTerminator());
2470 auto *NewInit = Builder.CreateVectorInterleave({InitReal, InitImag});
2471
2472 NewPHI->addIncoming(NewInit, Incoming);
2473 NewPHI->addIncoming(OperationReplacement, BackEdge);
2474
2475 // Deinterleave complex vector outside of loop so that it can be finally
2476 // reduced
2477 auto *FinalReductionReal = ReductionInfo[Real].second;
2478 auto *FinalReductionImag = ReductionInfo[Imag].second;
2479
2480 Builder.SetInsertPoint(
2481 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2482 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2483 OperationReplacement->getType(),
2484 OperationReplacement);
2485
2486 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2487 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2488
2489 Builder.SetInsertPoint(FinalReductionImag);
2490 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2491 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2492}
2493
2494void ComplexDeinterleavingGraph::replaceNodes() {
2495 SmallVector<Instruction *, 16> DeadInstrRoots;
2496 for (auto *RootInstruction : OrderedRoots) {
2497 // Check if this potential root went through check process and we can
2498 // deinterleave it
2499 if (!RootToNode.count(RootInstruction))
2500 continue;
2501
2502 IRBuilder<> Builder(RootInstruction);
2503 auto RootNode = RootToNode[RootInstruction];
2504 Value *R = replaceNode(Builder, RootNode);
2505
2506 if (RootNode->Operation ==
2507 ComplexDeinterleavingOperation::ReductionOperation) {
2508 auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
2509 auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
2510 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2511 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2512 DeadInstrRoots.push_back(RootReal);
2513 DeadInstrRoots.push_back(RootImag);
2514 } else if (RootNode->Operation ==
2515 ComplexDeinterleavingOperation::ReductionSingle) {
2516 auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
2517 auto &Info = ReductionInfo[RootInst];
2518 Info.first->removeIncomingValue(BackEdge);
2519 DeadInstrRoots.push_back(Info.second);
2520 } else {
2521 assert(R && "Unable to find replacement for RootInstruction");
2522 DeadInstrRoots.push_back(RootInstruction);
2523 RootInstruction->replaceAllUsesWith(R);
2524 }
2525 }
2526
2527 for (auto *I : DeadInstrRoots)
2529}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Analysis containing CSE Info
Definition CSEInfo.cpp:27
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
#define DEBUG_TYPE
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
This file implements a map that provides insertion order iteration.
#define T
uint64_t IntrinsicInst * II
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
This file describes how to lower LLVM code to machine code.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition ArrayRef.h:147
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:165
iterator end()
Definition DenseMap.h:81
bool allowContract() const
Definition FMF.h:69
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2618
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
size_type size() const
Definition MapVector.h:56
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:273
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:546
An opaque object representing a hash code.
Definition Hashing.h:76
const ParentTy * getParent() const
Definition ilist_node.h:34
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:355
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1705
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:533
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:759
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number arithmetic.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference `sizeof(SmallVector<T, 0>)`.
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
Definition Casting.h:548
@ Other
Any other memory.
Definition ModRef.h:68
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1877
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer list are equal or the list is empty.
Definition STLExtras.h:2088
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:584
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:853
#define N
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given value type `T`.