LLVM 23.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(true),
90
91/// Checks the given mask, and determines whether said mask is interleaving.
92///
93/// To be interleaving, a mask must alternate between `i` and `i + (Length /
94/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
95/// 4x vector interleaving mask would be <0, 2, 1, 3>).
96static bool isInterleavingMask(ArrayRef<int> Mask);
97
98/// Checks the given mask, and determines whether said mask is deinterleaving.
99///
100/// To be deinterleaving, a mask must increment in steps of 2, and either start
101/// with 0 or 1.
102/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
103/// <1, 3, 5, 7>).
104static bool isDeinterleavingMask(ArrayRef<int> Mask);
105
106/// Returns true if the operation is a negation of V, and it works for both
107/// integers and floats.
108static bool isNeg(Value *V);
109
110/// Returns the operand for negation operation.
111static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
122hash_code hash_value(const ComplexValue &Arg) {
125}
126} // end namespace
128
129template <> struct llvm::DenseMapInfo<ComplexValue> {
130 static inline ComplexValue getEmptyKey() {
133 }
134 static inline ComplexValue getTombstoneKey() {
137 }
138 static unsigned getHashValue(const ComplexValue &Val) {
141 }
142 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
143 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
144 }
145};
146
147namespace {
148template <typename T, typename IterT>
149std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
150 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
151 if (Common != A.end())
152 return std::make_optional(*Common);
153 return std::nullopt;
154}
155
156class ComplexDeinterleavingLegacyPass : public FunctionPass {
157public:
158 static char ID;
159
160 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
161 : FunctionPass(ID), TM(TM) {}
162
163 StringRef getPassName() const override {
164 return "Complex Deinterleaving Pass";
165 }
166
167 bool runOnFunction(Function &F) override;
168 void getAnalysisUsage(AnalysisUsage &AU) const override {
169 AU.addRequired<TargetLibraryInfoWrapperPass>();
170 AU.setPreservesCFG();
171 }
172
173private:
174 const TargetMachine *TM;
175};
176
177class ComplexDeinterleavingGraph;
178struct ComplexDeinterleavingCompositeNode {
179
180 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
181 Value *R, Value *I)
182 : Operation(Op) {
183 Vals.push_back({R, I});
184 }
185
186 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
188 : Operation(Op), Vals(Other) {}
189
190private:
191 friend class ComplexDeinterleavingGraph;
192 using CompositeNode = ComplexDeinterleavingCompositeNode;
193 bool OperandsValid = true;
194
195public:
197 ComplexValues Vals;
198
199 // This two members are required exclusively for generating
200 // ComplexDeinterleavingOperation::Symmetric operations.
201 unsigned Opcode;
202 std::optional<FastMathFlags> Flags;
203
205 ComplexDeinterleavingRotation::Rotation_0;
207 Value *ReplacementNode = nullptr;
208
209 void addOperand(CompositeNode *Node) {
210 if (!Node)
211 OperandsValid = false;
212 Operands.push_back(Node);
213 }
214
215 void dump() { dump(dbgs()); }
216 void dump(raw_ostream &OS) {
217 auto PrintValue = [&](Value *V) {
218 if (V) {
219 OS << "\"";
220 V->print(OS, true);
221 OS << "\"\n";
222 } else
223 OS << "nullptr\n";
224 };
225 auto PrintNodeRef = [&](CompositeNode *Ptr) {
226 if (Ptr)
227 OS << Ptr << "\n";
228 else
229 OS << "nullptr\n";
230 };
231
232 OS << "- CompositeNode: " << this << "\n";
233 for (unsigned I = 0; I < Vals.size(); I++) {
234 OS << " Real(" << I << ") : ";
235 PrintValue(Vals[I].Real);
236 OS << " Imag(" << I << ") : ";
237 PrintValue(Vals[I].Imag);
238 }
239 OS << " ReplacementNode: ";
240 PrintValue(ReplacementNode);
241 OS << " Operation: " << (int)Operation << "\n";
242 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
243 OS << " Operands: \n";
244 for (const auto &Op : Operands) {
245 OS << " - ";
246 PrintNodeRef(Op);
247 }
248 }
249
250 bool areOperandsValid() { return OperandsValid; }
251};
252
/// Graph of CompositeNodes built while searching one basic block for complex
/// arithmetic at a given interleave factor. Nodes are identified starting from
/// root instructions (and, when present, loop reductions), validated by
/// checkNodes(), and finally emitted into the IR by replaceNodes().
class ComplexDeinterleavingGraph {
public:
  /// One product term gathered while reassociating complex expressions.
  /// NOTE(review): IsPositive presumably records the sign of the term; confirm
  /// against identifyReassocNodes.
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  /// An addend paired with a flag (presumably its sign -- TODO confirm).
  using Addend = std::pair<Value *, bool>;
  using AddendList = BumpPtrList<Addend>;
  using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;

  // Helper struct for holding info about potential partial multiplication
  // candidates
  struct PartialMulCandidate {
    Value *Common;
    CompositeNode *Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI,
                                      unsigned Factor)
      : TL(TL), TLI(TLI), Factor(Factor) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
  /// Interleave factor this graph searches for (the driver in this file tries
  /// 2 first and then 4).
  unsigned Factor;
  /// Every node registered through submitCompositeNode, in submission order.
  SmallVector<CompositeNode *> CompositeNodes;
  /// Submitted nodes keyed by their value lists; filled by
  /// submitCompositeNode when the node's first real value is non-null.
  DenseMap<ComplexValues, CompositeNode *> CachedResult;
  /// Backing storage for every node created by prepareCompositeNode.
  SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  DenseMap<Instruction *, CompositeNode *> RootToNode;

  /// Topologically sorted root instructions

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd <vector type> ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  DenseMap<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates (but does not register) a node for the single pair (R, I).
  /// Reduction-related nodes must have both parts.
  CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                                      Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return new (Allocator.Allocate())
        ComplexDeinterleavingCompositeNode(Operation, R, I);
  }

  /// Allocates (but does not register) a node covering every pair in \p Vals.
  /// Reduction-related nodes must have both parts in every pair.
  CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                                      ComplexValues &Vals) {
#ifndef NDEBUG
    for (auto &V : Vals) {
      assert(
          ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
            Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
           (V.Real && V.Imag)) &&
          "Reduction related nodes must have Real and Imaginary parts");
    }
#endif
    return new (Allocator.Allocate())
        ComplexDeinterleavingCompositeNode(Operation, Vals);
  }

  /// Registers \p Node in CompositeNodes and, when its first real value is
  /// non-null, records it in CachedResult. Returns \p Node.
  CompositeNode *submitCompositeNode(CompositeNode *Node) {
    CompositeNodes.push_back(Node);
    if (Node->Vals[0].Real)
      CachedResult[Node->Vals] = Node;
    return Node;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  CompositeNode *
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
  CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
  CompositeNode *identifyPartialReduction(Value *R, Value *I);
  CompositeNode *identifyDotProduct(Value *Inst);

  /// Identifies a node for the given value pairs; callers treat a nullptr
  /// result as "not identified".
  CompositeNode *identifyNode(ComplexValues &Vals);

  /// Convenience overload for a single (R, I) pair.
  CompositeNode *identifyNode(Value *R, Value *I) {
    ComplexValues Vals;
    Vals.push_back({R, I});
    return identifyNode(Vals);
  }

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  CompositeNode *identifyAdditions(AddendList &RealAddends,
                                   AddendList &ImagAddends,
                                   std::optional<FastMathFlags> Flags,
                                   CompositeNode *Accumulator);

  /// Extract one addend whose real and imaginary parts are both positive.
  CompositeNode *extractPositiveAddend(AddendList &RealAddends,
                                       AddendList &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
                                         SmallVectorImpl<Product> &ImagMuls,
                                         CompositeNode *Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a partner with a common
  /// operand.
  bool collectPartialMuls(ArrayRef<Product> RealMuls,
                          ArrayRef<Product> ImagMuls,
                          SmallVectorImpl<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);

  CompositeNode *identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for the real part and odd
  ///   indices for the imaginary part (only for fixed-width vectors)
  /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
  ///   intrinsics (for both fixed and scalable vectors) where N is a multiple of
  ///   2.
  CompositeNode *identifyDeinterleave(ComplexValues &Vals);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
  CompositeNode *identifySplat(ComplexValues &Vals);

  CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement,
                                 CompositeNode *Node);
  void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);

public:
  void dump() { dump(dbgs()); }
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// If \p B is a one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};
508
509class ComplexDeinterleaving {
510public:
511 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
512 : TL(tl), TLI(tli) {}
513 bool runOnFunction(Function &F);
514
515private:
516 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
517
518 const TargetLowering *TL = nullptr;
519 const TargetLibraryInfo *TLI = nullptr;
520};
521
522} // namespace
523
524char ComplexDeinterleavingLegacyPass::ID = 0;
525
526INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
527 "Complex Deinterleaving", false, false)
528INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
529 "Complex Deinterleaving", false, false)
530
533 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
534 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
535 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
536 return PreservedAnalyses::all();
537
540 return PA;
541}
542
544 return new ComplexDeinterleavingLegacyPass(TM);
545}
546
547bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
548 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
549 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
550 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
551}
552
553bool ComplexDeinterleaving::runOnFunction(Function &F) {
556 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
557 return false;
558 }
559
562 dbgs() << "Complex deinterleaving has been disabled, target does "
563 "not support lowering of complex number operations.\n");
564 return false;
565 }
566
567 bool Changed = false;
568 for (auto &B : F)
569 Changed |= evaluateBasicBlock(&B, 2);
570
571 // TODO: Permit changes for both interleave factors in the same function.
572 if (!Changed) {
573 for (auto &B : F)
574 Changed |= evaluateBasicBlock(&B, 4);
575 }
576
577 // TODO: We can also support interleave factors of 6 and 8 if needed.
578
579 return Changed;
580}
581
583 // If the size is not even, it's not an interleaving mask
584 if ((Mask.size() & 1))
585 return false;
586
587 int HalfNumElements = Mask.size() / 2;
588 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
589 int MaskIdx = Idx * 2;
590 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
591 return false;
592 }
593
594 return true;
595}
596
598 int Offset = Mask[0];
599 int HalfNumElements = Mask.size() / 2;
600
601 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
602 if (Mask[Idx] != (Idx * 2) + Offset)
603 return false;
604 }
605
606 return true;
607}
608
609bool isNeg(Value *V) {
610 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
611}
612
614 assert(isNeg(V));
615 auto *I = cast<Instruction>(V);
616 if (I->getOpcode() == Instruction::FNeg)
617 return I->getOperand(0);
618
619 return I->getOperand(1);
620}
621
622bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
623 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
624 if (Graph.collectPotentialReductions(B))
625 Graph.identifyReductionNodes();
626
627 for (auto &I : *B)
628 Graph.identifyNodes(&I);
629
630 if (Graph.checkNodes()) {
631 Graph.replaceNodes();
632 return true;
633 }
634
635 return false;
636}
637
638ComplexDeinterleavingGraph::CompositeNode *
639ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
640 Instruction *Real, Instruction *Imag,
641 std::pair<Value *, Value *> &PartialMatch) {
642 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
643 << "\n");
644
645 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
646 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
647 return nullptr;
648 }
649
650 if ((Real->getOpcode() != Instruction::FMul &&
651 Real->getOpcode() != Instruction::Mul) ||
652 (Imag->getOpcode() != Instruction::FMul &&
653 Imag->getOpcode() != Instruction::Mul)) {
655 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
656 return nullptr;
657 }
658
659 Value *R0 = Real->getOperand(0);
660 Value *R1 = Real->getOperand(1);
661 Value *I0 = Imag->getOperand(0);
662 Value *I1 = Imag->getOperand(1);
663
664 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
665 // rotations and use the operand.
666 unsigned Negs = 0;
667 Value *Op;
668 if (match(R0, m_Neg(m_Value(Op)))) {
669 Negs |= 1;
670 R0 = Op;
671 } else if (match(R1, m_Neg(m_Value(Op)))) {
672 Negs |= 1;
673 R1 = Op;
674 }
675
676 if (isNeg(I0)) {
677 Negs |= 2;
678 Negs ^= 1;
679 I0 = Op;
680 } else if (match(I1, m_Neg(m_Value(Op)))) {
681 Negs |= 2;
682 Negs ^= 1;
683 I1 = Op;
684 }
685
687
688 Value *CommonOperand;
689 Value *UncommonRealOp;
690 Value *UncommonImagOp;
691
692 if (R0 == I0 || R0 == I1) {
693 CommonOperand = R0;
694 UncommonRealOp = R1;
695 } else if (R1 == I0 || R1 == I1) {
696 CommonOperand = R1;
697 UncommonRealOp = R0;
698 } else {
699 LLVM_DEBUG(dbgs() << " - No equal operand\n");
700 return nullptr;
701 }
702
703 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
704 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
705 Rotation == ComplexDeinterleavingRotation::Rotation_270)
706 std::swap(UncommonRealOp, UncommonImagOp);
707
708 // Between identifyPartialMul and here we need to have found a complete valid
709 // pair from the CommonOperand of each part.
710 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
711 Rotation == ComplexDeinterleavingRotation::Rotation_180)
712 PartialMatch.first = CommonOperand;
713 else
714 PartialMatch.second = CommonOperand;
715
716 if (!PartialMatch.first || !PartialMatch.second) {
717 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
718 return nullptr;
719 }
720
721 CompositeNode *CommonNode =
722 identifyNode(PartialMatch.first, PartialMatch.second);
723 if (!CommonNode) {
724 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
725 return nullptr;
726 }
727
728 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
729 if (!UncommonNode) {
730 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
731 return nullptr;
732 }
733
734 CompositeNode *Node = prepareCompositeNode(
735 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
736 Node->Rotation = Rotation;
737 Node->addOperand(CommonNode);
738 Node->addOperand(UncommonNode);
739 return submitCompositeNode(Node);
740}
741
742ComplexDeinterleavingGraph::CompositeNode *
743ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
744 Instruction *Imag) {
745 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
746 << "\n");
747
748 // Determine rotation
749 auto IsAdd = [](unsigned Op) {
750 return Op == Instruction::FAdd || Op == Instruction::Add;
751 };
752 auto IsSub = [](unsigned Op) {
753 return Op == Instruction::FSub || Op == Instruction::Sub;
754 };
756 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
757 Rotation = ComplexDeinterleavingRotation::Rotation_0;
758 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
759 Rotation = ComplexDeinterleavingRotation::Rotation_90;
760 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
761 Rotation = ComplexDeinterleavingRotation::Rotation_180;
762 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
763 Rotation = ComplexDeinterleavingRotation::Rotation_270;
764 else {
765 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
766 return nullptr;
767 }
768
769 if (isa<FPMathOperator>(Real) &&
770 (!Real->getFastMathFlags().allowContract() ||
771 !Imag->getFastMathFlags().allowContract())) {
772 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
773 return nullptr;
774 }
775
776 Value *CR = Real->getOperand(0);
777 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
778 if (!RealMulI)
779 return nullptr;
780 Value *CI = Imag->getOperand(0);
781 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
782 if (!ImagMulI)
783 return nullptr;
784
785 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
786 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
787 return nullptr;
788 }
789
790 Value *R0 = RealMulI->getOperand(0);
791 Value *R1 = RealMulI->getOperand(1);
792 Value *I0 = ImagMulI->getOperand(0);
793 Value *I1 = ImagMulI->getOperand(1);
794
795 Value *CommonOperand;
796 Value *UncommonRealOp;
797 Value *UncommonImagOp;
798
799 if (R0 == I0 || R0 == I1) {
800 CommonOperand = R0;
801 UncommonRealOp = R1;
802 } else if (R1 == I0 || R1 == I1) {
803 CommonOperand = R1;
804 UncommonRealOp = R0;
805 } else {
806 LLVM_DEBUG(dbgs() << " - No equal operand\n");
807 return nullptr;
808 }
809
810 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
811 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
812 Rotation == ComplexDeinterleavingRotation::Rotation_270)
813 std::swap(UncommonRealOp, UncommonImagOp);
814
815 std::pair<Value *, Value *> PartialMatch(
816 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_180)
818 ? CommonOperand
819 : nullptr,
820 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
821 Rotation == ComplexDeinterleavingRotation::Rotation_270)
822 ? CommonOperand
823 : nullptr);
824
825 auto *CRInst = dyn_cast<Instruction>(CR);
826 auto *CIInst = dyn_cast<Instruction>(CI);
827
828 if (!CRInst || !CIInst) {
829 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
830 return nullptr;
831 }
832
833 CompositeNode *CNode =
834 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
835 if (!CNode) {
836 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
837 return nullptr;
838 }
839
840 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
841 if (!UncommonRes) {
842 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
843 return nullptr;
844 }
845
846 assert(PartialMatch.first && PartialMatch.second);
847 CompositeNode *CommonRes =
848 identifyNode(PartialMatch.first, PartialMatch.second);
849 if (!CommonRes) {
850 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
851 return nullptr;
852 }
853
854 CompositeNode *Node = prepareCompositeNode(
855 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
856 Node->Rotation = Rotation;
857 Node->addOperand(CommonRes);
858 Node->addOperand(UncommonRes);
859 Node->addOperand(CNode);
860 return submitCompositeNode(Node);
861}
862
863ComplexDeinterleavingGraph::CompositeNode *
864ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
865 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
866
867 // Determine rotation
869 if ((Real->getOpcode() == Instruction::FSub &&
870 Imag->getOpcode() == Instruction::FAdd) ||
871 (Real->getOpcode() == Instruction::Sub &&
872 Imag->getOpcode() == Instruction::Add))
873 Rotation = ComplexDeinterleavingRotation::Rotation_90;
874 else if ((Real->getOpcode() == Instruction::FAdd &&
875 Imag->getOpcode() == Instruction::FSub) ||
876 (Real->getOpcode() == Instruction::Add &&
877 Imag->getOpcode() == Instruction::Sub))
878 Rotation = ComplexDeinterleavingRotation::Rotation_270;
879 else {
880 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
881 return nullptr;
882 }
883
884 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
885 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
886 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
887 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
888
889 if (!AR || !AI || !BR || !BI) {
890 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
891 return nullptr;
892 }
893
894 CompositeNode *ResA = identifyNode(AR, AI);
895 if (!ResA) {
896 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
897 return nullptr;
898 }
899 CompositeNode *ResB = identifyNode(BR, BI);
900 if (!ResB) {
901 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
902 return nullptr;
903 }
904
905 CompositeNode *Node =
906 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
907 Node->Rotation = Rotation;
908 Node->addOperand(ResA);
909 Node->addOperand(ResB);
910 return submitCompositeNode(Node);
911}
912
914 unsigned OpcA = A->getOpcode();
915 unsigned OpcB = B->getOpcode();
916
917 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
918 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
919 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
920 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
921}
922
924 auto Pattern =
926
927 return match(A, Pattern) && match(B, Pattern);
928}
929
931 switch (I->getOpcode()) {
932 case Instruction::FAdd:
933 case Instruction::FSub:
934 case Instruction::FMul:
935 case Instruction::FNeg:
936 case Instruction::Add:
937 case Instruction::Sub:
938 case Instruction::Mul:
939 return true;
940 default:
941 return false;
942 }
943}
944
// Identify a Symmetric composite node for Vals: the same operation applied
// independently to the real and imaginary parts of every complex value.
// All parts must use FirstReal's opcode and, for floating-point operations,
// carry FirstReal's fast-math flags. Operand 0 of every part (and operand 1,
// for binary operators) is identified as a complex node and attached as an
// operand of the new node. Returns the submitted node, or nullptr on any
// mismatch or unidentifiable operand group.
945ComplexDeinterleavingGraph::CompositeNode *
946ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
947 auto *FirstReal = cast<Instruction>(Vals[0].Real);
948 unsigned FirstOpc = FirstReal->getOpcode();
// Every real and imaginary instruction must match the first real part.
949 for (auto &V : Vals) {
950 auto *Real = cast<Instruction>(V.Real);
951 auto *Imag = cast<Instruction>(V.Imag);
952 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
953 return nullptr;
954
// NOTE(review): source lines 955-956 are missing from this listing; they hold
// a further condition guarding the early return below -- confirm against the
// original file.
957 return nullptr;
958
// Floating-point parts must all carry the same fast-math flags.
959 if (isa<FPMathOperator>(FirstReal))
960 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
961 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
962 return nullptr;
963 }
964
// Group operand 0 of every real/imag part and identify it as a complex node.
965 ComplexValues OpVals;
966 for (auto &V : Vals) {
967 auto *R0 = cast<Instruction>(V.Real)->getOperand(0);
968 auto *I0 = cast<Instruction>(V.Imag)->getOperand(0);
969 OpVals.push_back({R0, I0});
970 }
971
972 CompositeNode *Op0 = identifyNode(OpVals);
973 CompositeNode *Op1 = nullptr;
974 if (Op0 == nullptr)
975 return nullptr;
976
// Binary operators additionally need their operand-1 group identified.
977 if (FirstReal->isBinaryOp()) {
978 OpVals.clear();
979 for (auto &V : Vals) {
980 auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
981 auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
982 OpVals.push_back({R1, I1});
983 }
984 Op1 = identifyNode(OpVals);
985 if (Op1 == nullptr)
986 return nullptr;
987 }
988
// Record the opcode (and FP flags) so replacement can re-emit the operation.
989 auto Node =
990 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
991 Node->Opcode = FirstReal->getOpcode();
992 if (isa<FPMathOperator>(FirstReal))
993 Node->Flags = FirstReal->getFastMathFlags();
994
995 Node->addOperand(Op0);
996 if (FirstReal->isBinaryOp())
997 Node->addOperand(Op1);
998
999 return submitCompositeNode(Node);
1000}
1001
// Try to match V -- a llvm.vector.partial.reduce.add call -- as a complex dot
// product (CDot) and return the submitted CDot node, or nullptr.
//
// V is matched against operand patterns built from products of B's and A's
// components (several pattern lines are missing from this listing). Which
// pattern matches -- and, for the pattern shared by rotations 90 and 180,
// which operand order identifies as a complex node -- determines
// CN->Rotation. Casts wrapping the A/B components are looked through, and
// each component must have type VectorType::getSubdividedVectorType(VTy, 2).
// The node's operands are the A node, the B node and the node identified
// from (Phi, first user of V).
1002ComplexDeinterleavingGraph::CompositeNode *
1003ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
// NOTE(review): source line 1004 is missing from this listing; from the
// continuation below it opens a target-support check, presumably
// `if (!TL->isComplexDeinterleavingOperationSupported(` -- confirm.
1005 ComplexDeinterleavingOperation::CDot, V->getType())) {
1006 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
1007 "operation CDot with the type "
1008 << *V->getType() << "\n");
1009 return nullptr;
1010 }
1011
1012 auto *Inst = cast<Instruction>(V);
// NOTE(review): assumes Inst has at least one user; user_begin() is
// dereferenced unconditionally.
1013 auto *RealUser = cast<Instruction>(*Inst->user_begin());
1014
1015 CompositeNode *CN =
1016 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
1017
// Set only when the 90/180 pattern had to identify A to pick a rotation.
1018 CompositeNode *ANode = nullptr;
1019
1020 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1021
1022 Value *AReal = nullptr;
1023 Value *AImag = nullptr;
1024 Value *BReal = nullptr;
1025 Value *BImag = nullptr;
1026 Value *Phi = nullptr;
1027
// Look through a single cast (e.g. an extension) to the underlying value.
1028 auto UnwrapCast = [](Value *V) -> Value * {
1029 if (auto *CI = dyn_cast<CastInst>(V))
1030 return CI->getOperand(0);
1031 return V;
1032 };
1033
1034 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
// NOTE(review): source line 1035 is missing from this listing.
1036 m_Mul(m_Value(BReal), m_Value(AReal))),
1037 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
1038
1039 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
// NOTE(review): source line 1040 is missing from this listing.
1041 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
1042 m_Mul(m_Value(BImag), m_Value(AReal)));
1043
1044 if (match(Inst, PatternRot0)) {
1045 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1046 } else if (match(Inst, PatternRot270)) {
1047 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1048 } else {
1049 Value *A0, *A1;
1050 // The rotations 90 and 180 share the same operation pattern, so inspect the
1051 // order of the operands, identifying where the real and imaginary
1052 // components of A go, to discern between the aforementioned rotations.
1053 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
// NOTE(review): source line 1054 is missing from this listing.
1055 m_Mul(m_Value(BReal), m_Value(A0))),
1056 m_Mul(m_Value(BImag), m_Value(A1)));
1057
1058 if (!match(Inst, PatternRot90Rot180))
1059 return nullptr;
1060
1061 A0 = UnwrapCast(A0);
1062 A1 = UnwrapCast(A1);
1063
1064 // Test if A0 is real/A1 is imag
1065 ANode = identifyNode(A0, A1);
1066 if (!ANode) {
1067 // Test if A0 is imag/A1 is real
1068 ANode = identifyNode(A1, A0);
1069 // Unable to identify operand components, thus unable to identify rotation
1070 if (!ANode)
1071 return nullptr;
1072 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1073 AReal = A1;
1074 AImag = A0;
1075 } else {
1076 AReal = A0;
1077 AImag = A1;
1078 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1079 }
1080 }
1081
1082 AReal = UnwrapCast(AReal);
1083 AImag = UnwrapCast(AImag);
1084 BReal = UnwrapCast(BReal);
1085 BImag = UnwrapCast(BImag);
1086
// Every component must be the accumulator type subdivided by two.
1087 VectorType *VTy = cast<VectorType>(V->getType());
1088 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1089 if (AReal->getType() != ExpectedOperandTy)
1090 return nullptr;
1091 if (AImag->getType() != ExpectedOperandTy)
1092 return nullptr;
1093 if (BReal->getType() != ExpectedOperandTy)
1094 return nullptr;
1095 if (BImag->getType() != ExpectedOperandTy)
1096 return nullptr;
1097
// NOTE(review): this rejects only when *both* Phi and RealUser differ from
// VTy; confirm `&&` (rather than `||`) is intended.
1098 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1099 return nullptr;
1100
1101 CompositeNode *Node = identifyNode(AReal, AImag);
1102
1103 // In the case that a node was identified to figure out the rotation, ensure
1104 // that trying to identify a node with AReal and AImag post-unwrap results in
1105 // the same node
1106 if (ANode && Node != ANode) {
1107 LLVM_DEBUG(
1108 dbgs()
1109 << "Identified node is different from previously identified node. "
1110 "Unable to confidently generate a complex operation node\n");
1111 return nullptr;
1112 }
1113
// NOTE(review): these identifyNode results are not null-checked here;
// checkNodes() later rejects graphs whose nodes fail areOperandsValid(),
// which presumably covers null operands -- confirm.
1114 CN->addOperand(Node);
1115 CN->addOperand(identifyNode(BReal, BImag));
1116 CN->addOperand(identifyNode(Phi, RealUser));
1117
1118 return submitCompositeNode(CN);
1119}
1120
1121ComplexDeinterleavingGraph::CompositeNode *
1122ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1123 // Partial reductions don't support non-vector types, so check these first
1124 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1125 return nullptr;
1126
1127 if (!R->hasUseList() || !I->hasUseList())
1128 return nullptr;
1129
1130 auto CommonUser =
1131 findCommonBetweenCollections<Value *>(R->users(), I->users());
1132 if (!CommonUser)
1133 return nullptr;
1134
1135 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1136 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1137 return nullptr;
1138
1139 if (CompositeNode *CN = identifyDotProduct(IInst))
1140 return CN;
1141
1142 return nullptr;
1143}
1144
1145ComplexDeinterleavingGraph::CompositeNode *
1146ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1147 auto It = CachedResult.find(Vals);
1148 if (It != CachedResult.end()) {
1149 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1150 return It->second;
1151 }
1152
1153 if (Vals.size() == 1) {
1154 assert(Factor == 2 && "Can only handle interleave factors of 2");
1155 Value *R = Vals[0].Real;
1156 Value *I = Vals[0].Imag;
1157 if (CompositeNode *CN = identifyPartialReduction(R, I))
1158 return CN;
1159 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1160 if (!IsReduction && R->getType() != I->getType())
1161 return nullptr;
1162 }
1163
1164 if (CompositeNode *CN = identifySplat(Vals))
1165 return CN;
1166
1167 for (auto &V : Vals) {
1168 auto *Real = dyn_cast<Instruction>(V.Real);
1169 auto *Imag = dyn_cast<Instruction>(V.Imag);
1170 if (!Real || !Imag)
1171 return nullptr;
1172 }
1173
1174 if (CompositeNode *CN = identifyDeinterleave(Vals))
1175 return CN;
1176
1177 if (Vals.size() == 1) {
1178 assert(Factor == 2 && "Can only handle interleave factors of 2");
1179 auto *Real = dyn_cast<Instruction>(Vals[0].Real);
1180 auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
1181 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1182 return CN;
1183
1184 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1185 return CN;
1186
1187 auto *VTy = cast<VectorType>(Real->getType());
1188 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1189
1190 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1191 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1192 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1193 ComplexDeinterleavingOperation::CAdd, NewVTy);
1194
1195 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1196 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1197 return CN;
1198 }
1199
1200 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1201 if (CompositeNode *CN = identifyAdd(Real, Imag))
1202 return CN;
1203 }
1204
1205 if (HasCMulSupport && HasCAddSupport) {
1206 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1207 return CN;
1208 }
1209 }
1210 }
1211
1212 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1213 return CN;
1214
1215 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1216 CachedResult[Vals] = nullptr;
1217 return nullptr;
1218}
1219
// Identify a reassociable add/sub expression tree rooted at Real/Imag and
// rebuild it as partial complex multiplications and additions.
//
// Both roots must be FAdd/FSub/FNeg/Add/Sub; floating-point roots must carry
// identical fast-math flags that allow reassociation. Each root is flattened
// by Collect into signed products (Muls) and signed leaf addends. Products are
// combined by identifyMultiplications (seeded, when one exists, with a
// positive addend pair as accumulator) and the remaining addends by
// identifyAdditions. The final node is re-labelled with Real/Imag, submitted
// and returned; nullptr is returned on any failure.
1220ComplexDeinterleavingGraph::CompositeNode *
1221ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1222 Instruction *Imag) {
1223 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1224 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1225 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1226 Opcode == Instruction::Sub;
1227 };
1228
1229 if (!IsOperationSupported(Real->getOpcode()) ||
1230 !IsOperationSupported(Imag->getOpcode()))
1231 return nullptr;
1232
// Flags stays empty for integer roots; for FP roots it holds the flags every
// collected instruction must match.
1233 std::optional<FastMathFlags> Flags;
1234 if (isa<FPMathOperator>(Real)) {
1235 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1236 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1237 "not identical\n");
1238 return nullptr;
1239 }
1240
1241 Flags = Real->getFastMathFlags();
1242 if (!Flags->allowReassoc()) {
1243 LLVM_DEBUG(
1244 dbgs()
1245 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1246 return nullptr;
1247 }
1248 }
1249
1250 // Collect multiplications and addend instructions from the given instruction
1251 // while traversing its operands. Additionally, verify that all instructions
1252 // have the same fast math flags.
// Each worklist entry pairs a value with its accumulated sign (true = +).
1253 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1254 AddendList &Addends) -> bool {
1255 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1256 SmallPtrSet<Value *, 8> Visited;
1257 while (!Worklist.empty()) {
1258 auto [V, IsPositive] = Worklist.pop_back_val();
1259 if (!Visited.insert(V).second)
1260 continue;
1261
// NOTE(review): source line 1262 is missing from this listing; given the use
// below it presumably defines `I` as `dyn_cast<Instruction>(V)` -- confirm.
1263 if (!I) {
1264 Addends.emplace_back(V, IsPositive);
1265 continue;
1266 }
1267
1268 // If an instruction has more than one user, it indicates that it either
1269 // has an external user, which will be later checked by the checkNodes
1270 // function, or it is a subexpression utilized by multiple expressions. In
1271 // the latter case, we will attempt to separately identify the complex
1272 // operation from here in order to create a shared
1273 // ComplexDeinterleavingCompositeNode.
1274 if (I != Insn && I->hasNUsesOrMore(2)) {
1275 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1276 Addends.emplace_back(I, IsPositive);
1277 continue;
1278 }
1279 switch (I->getOpcode()) {
1280 case Instruction::FAdd:
1281 case Instruction::Add:
1282 Worklist.emplace_back(I->getOperand(1), IsPositive);
1283 Worklist.emplace_back(I->getOperand(0), IsPositive);
1284 break;
1285 case Instruction::FSub:
1286 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1287 Worklist.emplace_back(I->getOperand(0), IsPositive);
1288 break;
1289 case Instruction::Sub:
1290 if (isNeg(I)) {
1291 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1292 } else {
1293 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1294 Worklist.emplace_back(I->getOperand(0), IsPositive);
1295 }
1296 break;
// A product is a leaf: negated factors are stripped and folded into its sign.
1297 case Instruction::FMul:
1298 case Instruction::Mul: {
1299 Value *A, *B;
1300 if (isNeg(I->getOperand(0))) {
1301 A = getNegOperand(I->getOperand(0));
1302 IsPositive = !IsPositive;
1303 } else {
1304 A = I->getOperand(0);
1305 }
1306
1307 if (isNeg(I->getOperand(1))) {
1308 B = getNegOperand(I->getOperand(1));
1309 IsPositive = !IsPositive;
1310 } else {
1311 B = I->getOperand(1);
1312 }
1313 Muls.push_back(Product{A, B, IsPositive});
1314 break;
1315 }
1316 case Instruction::FNeg:
1317 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1318 break;
// Any other instruction is an opaque addend; `continue` skips the flag check.
1319 default:
1320 Addends.emplace_back(I, IsPositive);
1321 continue;
1322 }
1323
1324 if (Flags && I->getFastMathFlags() != *Flags) {
1325 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1326 "inconsistent with the root instructions' flags: "
1327 << *I << "\n");
1328 return false;
1329 }
1330 }
1331 return true;
1332 };
1333
1334 SmallVector<Product> RealMuls, ImagMuls;
1335 AddendList RealAddends, ImagAddends;
1336 if (!Collect(Real, RealMuls, RealAddends) ||
1337 !Collect(Imag, ImagMuls, ImagAddends))
1338 return nullptr;
1339
1340 if (RealAddends.size() != ImagAddends.size())
1341 return nullptr;
1342
1343 CompositeNode *FinalNode = nullptr;
1344 if (!RealMuls.empty() || !ImagMuls.empty()) {
1345 // If there are multiplicands, extract positive addend and use it as an
1346 // accumulator
1347 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1348 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1349 if (!FinalNode)
1350 return nullptr;
1351 }
1352
1353 // Identify and process remaining additions
1354 if (!RealAddends.empty() || !ImagAddends.empty()) {
1355 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1356 if (!FinalNode)
1357 return nullptr;
1358 }
1359 assert(FinalNode && "FinalNode can not be nullptr here");
1360 assert(FinalNode->Vals.size() == 1);
1361 // Set the Real and Imag fields of the final node and submit it
1362 FinalNode->Vals[0].Real = Real;
1363 FinalNode->Vals[0].Imag = Imag;
1364 submitCompositeNode(FinalNode);
1365 return FinalNode;
1366}
1367
1368bool ComplexDeinterleavingGraph::collectPartialMuls(
1369 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1370 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1371 // Helper function to extract a common operand from two products
1372 auto FindCommonInstruction = [](const Product &Real,
1373 const Product &Imag) -> Value * {
1374 if (Real.Multiplicand == Imag.Multiplicand ||
1375 Real.Multiplicand == Imag.Multiplier)
1376 return Real.Multiplicand;
1377
1378 if (Real.Multiplier == Imag.Multiplicand ||
1379 Real.Multiplier == Imag.Multiplier)
1380 return Real.Multiplier;
1381
1382 return nullptr;
1383 };
1384
1385 // Iterating over real and imaginary multiplications to find common operands
1386 // If a common operand is found, a partial multiplication candidate is created
1387 // and added to the candidates vector The function returns false if no common
1388 // operands are found for any product
1389 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1390 bool FoundCommon = false;
1391 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1392 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1393 if (!Common)
1394 continue;
1395
1396 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1397 : RealMuls[i].Multiplicand;
1398 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1399 : ImagMuls[j].Multiplicand;
1400
1401 auto Node = identifyNode(A, B);
1402 if (Node) {
1403 FoundCommon = true;
1404 PartialMulCandidates.push_back({Common, Node, i, j, false});
1405 }
1406
1407 Node = identifyNode(B, A);
1408 if (Node) {
1409 FoundCommon = true;
1410 PartialMulCandidates.push_back({Common, Node, i, j, true});
1411 }
1412 }
1413 if (!FoundCommon)
1414 return false;
1415 }
1416 return true;
1417}
1418
1419ComplexDeinterleavingGraph::CompositeNode *
1420ComplexDeinterleavingGraph::identifyMultiplications(
1421 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1422 CompositeNode *Accumulator = nullptr) {
1423 if (RealMuls.size() != ImagMuls.size())
1424 return nullptr;
1425
1427 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1428 return nullptr;
1429
1430 // Map to store common instruction to node pointers
1431 DenseMap<Value *, CompositeNode *> CommonToNode;
1432 SmallVector<bool> Processed(Info.size(), false);
1433 for (unsigned I = 0; I < Info.size(); ++I) {
1434 if (Processed[I])
1435 continue;
1436
1437 PartialMulCandidate &InfoA = Info[I];
1438 for (unsigned J = I + 1; J < Info.size(); ++J) {
1439 if (Processed[J])
1440 continue;
1441
1442 PartialMulCandidate &InfoB = Info[J];
1443 auto *InfoReal = &InfoA;
1444 auto *InfoImag = &InfoB;
1445
1446 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1447 if (!NodeFromCommon) {
1448 std::swap(InfoReal, InfoImag);
1449 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1450 }
1451 if (!NodeFromCommon)
1452 continue;
1453
1454 CommonToNode[InfoReal->Common] = NodeFromCommon;
1455 CommonToNode[InfoImag->Common] = NodeFromCommon;
1456 Processed[I] = true;
1457 Processed[J] = true;
1458 }
1459 }
1460
1461 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1462 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1463 CompositeNode *Result = Accumulator;
1464 for (auto &PMI : Info) {
1465 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1466 continue;
1467
1468 auto It = CommonToNode.find(PMI.Common);
1469 // TODO: Process independent complex multiplications. Cases like this:
1470 // A.real() * B where both A and B are complex numbers.
1471 if (It == CommonToNode.end()) {
1472 LLVM_DEBUG({
1473 dbgs() << "Unprocessed independent partial multiplication:\n";
1474 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1475 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1476 << " multiplied by " << *Mul->Multiplicand << "\n";
1477 });
1478 return nullptr;
1479 }
1480
1481 auto &RealMul = RealMuls[PMI.RealIdx];
1482 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1483
1484 auto NodeA = It->second;
1485 auto NodeB = PMI.Node;
1486 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1487 // The following table illustrates the relationship between multiplications
1488 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1489 // can see:
1490 //
1491 // Rotation | Real | Imag |
1492 // ---------+--------+--------+
1493 // 0 | x * u | x * v |
1494 // 90 | -y * v | y * u |
1495 // 180 | -x * u | -x * v |
1496 // 270 | y * v | -y * u |
1497 //
1498 // Check if the candidate can indeed be represented by partial
1499 // multiplication
1500 // TODO: Add support for multiplication by complex one
1501 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1502 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1503 continue;
1504
1505 // Determine the rotation based on the multiplications
1507 if (IsMultiplicandReal) {
1508 // Detect 0 and 180 degrees rotation
1509 if (RealMul.IsPositive && ImagMul.IsPositive)
1511 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1513 else
1514 continue;
1515
1516 } else {
1517 // Detect 90 and 270 degrees rotation
1518 if (!RealMul.IsPositive && ImagMul.IsPositive)
1520 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1522 else
1523 continue;
1524 }
1525
1526 LLVM_DEBUG({
1527 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1528 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1529 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1530 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1531 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1532 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1533 });
1534
1535 CompositeNode *NodeMul = prepareCompositeNode(
1536 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1537 NodeMul->Rotation = Rotation;
1538 NodeMul->addOperand(NodeA);
1539 NodeMul->addOperand(NodeB);
1540 if (Result)
1541 NodeMul->addOperand(Result);
1542 submitCompositeNode(NodeMul);
1543 Result = NodeMul;
1544 ProcessedReal[PMI.RealIdx] = true;
1545 ProcessedImag[PMI.ImagIdx] = true;
1546 }
1547
1548 // Ensure all products have been processed, if not return nullptr.
1549 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1550 !all_of(ProcessedImag, [](bool V) { return V; })) {
1551
1552 // Dump debug information about which partial multiplications are not
1553 // processed.
1554 LLVM_DEBUG({
1555 dbgs() << "Unprocessed products (Real):\n";
1556 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1557 if (!ProcessedReal[i])
1558 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1559 << *RealMuls[i].Multiplier << " multiplied by "
1560 << *RealMuls[i].Multiplicand << "\n";
1561 }
1562 dbgs() << "Unprocessed products (Imag):\n";
1563 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1564 if (!ProcessedImag[i])
1565 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1566 << *ImagMuls[i].Multiplier << " multiplied by "
1567 << *ImagMuls[i].Multiplicand << "\n";
1568 }
1569 });
1570 return nullptr;
1571 }
1572
1573 return Result;
1574}
1575
// Fold paired real/imaginary addends into a chain of addition nodes.
//
// The chain starts from Accumulator when one is given (the assignment line is
// missing from this listing), otherwise from a positive (real, imag) addend
// pair removed by extractPositiveAddend. Then, repeatedly, the first remaining
// real addend is paired with the first imaginary addend for which a complex
// node can be identified. The signs of the pair select the rotation; the
// first two branches below build Symmetric nodes (FAdd/FSub with Flags when
// Flags is set, integer Add/Sub otherwise) and the last builds a CAdd node
// carrying the rotation. Presumably rotation 0 maps to add and 180 to sub;
// those branch conditions are missing from this listing -- confirm.
// Returns the last node built, or nullptr if the lists differ in size, no
// starting node exists, or some real addend finds no imaginary partner.
1576ComplexDeinterleavingGraph::CompositeNode *
1577ComplexDeinterleavingGraph::identifyAdditions(
1578 AddendList &RealAddends, AddendList &ImagAddends,
1579 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1580 if (RealAddends.size() != ImagAddends.size())
1581 return nullptr;
1582
1583 CompositeNode *Result = nullptr;
1584 // If we have accumulator use it as first addend
1585 if (Accumulator)
// NOTE(review): source line 1586 is missing from this listing; presumably
// `Result = Accumulator;` -- confirm.
1587 // Otherwise find an element with both positive real and imaginary parts.
1588 else
1589 Result = extractPositiveAddend(RealAddends, ImagAddends);
1590
1591 if (!Result)
1592 return nullptr;
1593
// Each iteration consumes the first real addend and its matching imag addend.
1594 while (!RealAddends.empty()) {
1595 auto ItR = RealAddends.begin();
1596 auto [R, IsPositiveR] = *ItR;
1597
1598 bool FoundImag = false;
1599 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1600 auto [I, IsPositiveI] = *ItI;
// NOTE(review): source line 1601 (presumably the declaration of Rotation) is
// missing from this listing.
1602 if (IsPositiveR && IsPositiveI)
1603 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1604 else if (!IsPositiveR && IsPositiveI)
1605 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1606 else if (!IsPositiveR && !IsPositiveI)
1607 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1608 else
1609 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1610
// For 90/270 the operand order is swapped before identification.
1611 CompositeNode *AddNode = nullptr;
1612 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1613 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1614 AddNode = identifyNode(R, I);
1615 } else {
1616 AddNode = identifyNode(I, R);
1617 }
1618 if (AddNode) {
1619 LLVM_DEBUG({
1620 dbgs() << "Identified addition:\n";
1621 dbgs().indent(4) << "X: " << *R << "\n";
1622 dbgs().indent(4) << "Y: " << *I << "\n";
1623 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1624 });
1625
1626 CompositeNode *TmpNode = nullptr;
// NOTE(review): source line 1627 is missing from this listing; it opens the
// first branch below, presumably
// `if (Rotation == ComplexDeinterleavingRotation::Rotation_0) {` -- confirm.
1628 TmpNode = prepareCompositeNode(
1629 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1630 if (Flags) {
1631 TmpNode->Opcode = Instruction::FAdd;
1632 TmpNode->Flags = *Flags;
1633 } else {
1634 TmpNode->Opcode = Instruction::Add;
1635 }
1636 } else if (Rotation ==
// NOTE(review): source line 1637, completing this condition (presumably
// Rotation_180), is missing from this listing.
1638 TmpNode = prepareCompositeNode(
1639 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1640 if (Flags) {
1641 TmpNode->Opcode = Instruction::FSub;
1642 TmpNode->Flags = *Flags;
1643 } else {
1644 TmpNode->Opcode = Instruction::Sub;
1645 }
1646 } else {
1647 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1648 nullptr, nullptr);
1649 TmpNode->Rotation = Rotation;
1650 }
1651
// Chain the new node onto the running result and drop the consumed pair.
1652 TmpNode->addOperand(Result);
1653 TmpNode->addOperand(AddNode);
1654 submitCompositeNode(TmpNode);
1655 Result = TmpNode;
1656 RealAddends.erase(ItR);
1657 ImagAddends.erase(ItI);
1658 FoundImag = true;
1659 break;
1660 }
1661 }
1662 if (!FoundImag)
1663 return nullptr;
1664 }
1665 return Result;
1666}
1667
1668ComplexDeinterleavingGraph::CompositeNode *
1669ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1670 AddendList &ImagAddends) {
1671 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1672 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1673 auto [R, IsPositiveR] = *ItR;
1674 auto [I, IsPositiveI] = *ItI;
1675 if (IsPositiveR && IsPositiveI) {
1676 auto Result = identifyNode(R, I);
1677 if (Result) {
1678 RealAddends.erase(ItR);
1679 ImagAddends.erase(ItI);
1680 return Result;
1681 }
1682 }
1683 }
1684 }
1685 return nullptr;
1686}
1687
1688bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1689 // This potential root instruction might already have been recognized as
1690 // reduction. Because RootToNode maps both Real and Imaginary parts to
1691 // CompositeNode we should choose only one either Real or Imag instruction to
1692 // use as an anchor for generating complex instruction.
1693 auto It = RootToNode.find(RootI);
1694 if (It != RootToNode.end()) {
1695 auto RootNode = It->second;
1696 assert(RootNode->Operation ==
1697 ComplexDeinterleavingOperation::ReductionOperation ||
1698 RootNode->Operation ==
1699 ComplexDeinterleavingOperation::ReductionSingle);
1700 assert(RootNode->Vals.size() == 1 &&
1701 "Cannot handle reductions involving multiple complex values");
1702 // Find out which part, Real or Imag, comes later, and only if we come to
1703 // the latest part, add it to OrderedRoots.
1704 auto *R = cast<Instruction>(RootNode->Vals[0].Real);
1705 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
1706 : nullptr;
1707
1708 Instruction *ReplacementAnchor;
1709 if (I)
1710 ReplacementAnchor = R->comesBefore(I) ? I : R;
1711 else
1712 ReplacementAnchor = R;
1713
1714 if (ReplacementAnchor != RootI)
1715 return false;
1716 OrderedRoots.push_back(RootI);
1717 return true;
1718 }
1719
1720 auto RootNode = identifyRoot(RootI);
1721 if (!RootNode)
1722 return false;
1723
1724 LLVM_DEBUG({
1725 Function *F = RootI->getFunction();
1726 BasicBlock *B = RootI->getParent();
1727 dbgs() << "Complex deinterleaving graph for " << F->getName()
1728 << "::" << B->getName() << ".\n";
1729 dump(dbgs());
1730 dbgs() << "\n";
1731 });
1732 RootToNode[RootI] = RootNode;
1733 OrderedRoots.push_back(RootI);
1734 return true;
1735}
1736
1737bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1738 bool FoundPotentialReduction = false;
1739 if (Factor != 2)
1740 return false;
1741
1742 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1743 if (!Br || Br->getNumSuccessors() != 2)
1744 return false;
1745
1746 // Identify simple one-block loop
1747 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1748 return false;
1749
1750 for (auto &PHI : B->phis()) {
1751 if (PHI.getNumIncomingValues() != 2)
1752 continue;
1753
1754 if (!PHI.getType()->isVectorTy())
1755 continue;
1756
1757 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1758 if (!ReductionOp)
1759 continue;
1760
1761 // Check if final instruction is reduced outside of current block
1762 Instruction *FinalReduction = nullptr;
1763 auto NumUsers = 0u;
1764 for (auto *U : ReductionOp->users()) {
1765 ++NumUsers;
1766 if (U == &PHI)
1767 continue;
1768 FinalReduction = dyn_cast<Instruction>(U);
1769 }
1770
1771 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1772 isa<PHINode>(FinalReduction))
1773 continue;
1774
1775 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1776 BackEdge = B;
1777 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1778 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1779 Incoming = PHI.getIncomingBlock(IncomingIdx);
1780 FoundPotentialReduction = true;
1781
1782 // If the initial value of PHINode is an Instruction, consider it a leaf
1783 // value of a complex deinterleaving graph.
1784 if (auto *InitPHI =
1785 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1786 FinalInstructions.insert(InitPHI);
1787 }
1788 return FoundPotentialReduction;
1789}
1790
1791void ComplexDeinterleavingGraph::identifyReductionNodes() {
1792 assert(Factor == 2 && "Cannot handle multiple complex values");
1793
1794 SmallVector<bool> Processed(ReductionInfo.size(), false);
1795 SmallVector<Instruction *> OperationInstruction;
1796 for (auto &P : ReductionInfo)
1797 OperationInstruction.push_back(P.first);
1798
1799 // Identify a complex computation by evaluating two reduction operations that
1800 // potentially could be involved
1801 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1802 if (Processed[i])
1803 continue;
1804 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1805 if (Processed[j])
1806 continue;
1807 auto *Real = OperationInstruction[i];
1808 auto *Imag = OperationInstruction[j];
1809 if (Real->getType() != Imag->getType())
1810 continue;
1811
1812 RealPHI = ReductionInfo[Real].first;
1813 ImagPHI = ReductionInfo[Imag].first;
1814 PHIsFound = false;
1815 auto Node = identifyNode(Real, Imag);
1816 if (!Node) {
1817 std::swap(Real, Imag);
1818 std::swap(RealPHI, ImagPHI);
1819 Node = identifyNode(Real, Imag);
1820 }
1821
1822 // If a node is identified and reduction PHINode is used in the chain of
1823 // operations, mark its operation instructions as used to prevent
1824 // re-identification and attach the node to the real part
1825 if (Node && PHIsFound) {
1826 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1827 << *Real << " / " << *Imag << "\n");
1828 Processed[i] = true;
1829 Processed[j] = true;
1830 auto RootNode = prepareCompositeNode(
1831 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1832 RootNode->addOperand(Node);
1833 RootToNode[Real] = RootNode;
1834 RootToNode[Imag] = RootNode;
1835 submitCompositeNode(RootNode);
1836 break;
1837 }
1838 }
1839
1840 auto *Real = OperationInstruction[i];
1841 // We want to check that we have 2 operands, but the function attributes
1842 // being counted as operands bloats this value.
1843 if (Processed[i] || Real->getNumOperands() < 2)
1844 continue;
1845
1846 // Can only combined integer reductions at the moment.
1847 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1848 continue;
1849
1850 RealPHI = ReductionInfo[Real].first;
1851 ImagPHI = nullptr;
1852 PHIsFound = false;
1853 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1854 if (Node && PHIsFound) {
1855 LLVM_DEBUG(
1856 dbgs() << "Identified single reduction starting from instruction: "
1857 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1858
1859 // Reducing to a single vector is not supported, only permit reducing down
1860 // to scalar values.
1861 // Doing this here will leave the prior node in the graph,
1862 // however with no uses the node will be unreachable by the replacement
1863 // process. That along with the usage outside the graph should prevent the
1864 // replacement process from kicking off at all for this graph.
1865 // TODO Add support for reducing to a single vector value
1866 if (ReductionInfo[Real].second->getType()->isVectorTy())
1867 continue;
1868
1869 Processed[i] = true;
1870 auto RootNode = prepareCompositeNode(
1871 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1872 RootNode->addOperand(Node);
1873 RootToNode[Real] = RootNode;
1874 submitCompositeNode(RootNode);
1875 }
1876 }
1877
1878 RealPHI = nullptr;
1879 ImagPHI = nullptr;
1880}
1881
1882bool ComplexDeinterleavingGraph::checkNodes() {
1883 bool FoundDeinterleaveNode = false;
1884 for (CompositeNode *N : CompositeNodes) {
1885 if (!N->areOperandsValid())
1886 return false;
1887
1888 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1889 FoundDeinterleaveNode = true;
1890 }
1891
1892 // We need a deinterleave node in order to guarantee that we're working with
1893 // complex numbers.
1894 if (!FoundDeinterleaveNode) {
1895 LLVM_DEBUG(
1896 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1897 "guarantee safety during graph transformation.\n");
1898 return false;
1899 }
1900
1901 // Collect all instructions from roots to leaves
1902 SmallPtrSet<Instruction *, 16> AllInstructions;
1903 SmallVector<Instruction *, 8> Worklist;
1904 for (auto &Pair : RootToNode)
1905 Worklist.push_back(Pair.first);
1906
1907 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1908 // chains
1909 while (!Worklist.empty()) {
1910 auto *I = Worklist.pop_back_val();
1911
1912 if (!AllInstructions.insert(I).second)
1913 continue;
1914
1915 for (Value *Op : I->operands()) {
1916 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1917 if (!FinalInstructions.count(I))
1918 Worklist.emplace_back(OpI);
1919 }
1920 }
1921 }
1922
1923 // Find instructions that have users outside of chain
1924 for (auto *I : AllInstructions) {
1925 // Skip root nodes
1926 if (RootToNode.count(I))
1927 continue;
1928
1929 for (User *U : I->users()) {
1930 if (AllInstructions.count(cast<Instruction>(U)))
1931 continue;
1932
1933 // Found an instruction that is not used by XCMLA/XCADD chain
1934 Worklist.emplace_back(I);
1935 break;
1936 }
1937 }
1938
1939 // If any instructions are found to be used outside, find and remove roots
1940 // that somehow connect to those instructions.
1941 SmallPtrSet<Instruction *, 16> Visited;
1942 while (!Worklist.empty()) {
1943 auto *I = Worklist.pop_back_val();
1944 if (!Visited.insert(I).second)
1945 continue;
1946
1947 // Found an impacted root node. Removing it from the nodes to be
1948 // deinterleaved
1949 if (RootToNode.count(I)) {
1950 LLVM_DEBUG(dbgs() << "Instruction " << *I
1951 << " could be deinterleaved but its chain of complex "
1952 "operations have an outside user\n");
1953 RootToNode.erase(I);
1954 }
1955
1956 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1957 continue;
1958
1959 for (User *U : I->users())
1960 Worklist.emplace_back(cast<Instruction>(U));
1961
1962 for (Value *Op : I->operands()) {
1963 if (auto *OpI = dyn_cast<Instruction>(Op))
1964 Worklist.emplace_back(OpI);
1965 }
1966 }
1967 return !RootToNode.empty();
1968}
1969
1970ComplexDeinterleavingGraph::CompositeNode *
1971ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1972 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1974 Intrinsic->getIntrinsicID())
1975 return nullptr;
1976
1977 ComplexValues Vals;
1978 for (unsigned I = 0; I < Factor; I += 2) {
1979 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(I));
1980 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(I + 1));
1981 if (!Real || !Imag)
1982 return nullptr;
1983 Vals.push_back({Real, Imag});
1984 }
1985
1986 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1987 if (!Node1)
1988 return nullptr;
1989 return Node1;
1990 }
1991
1992 // TODO: We could also add support for fixed-width interleave factors of 4
1993 // and above, but currently for symmetric operations the interleaves and
1994 // deinterleaves are already removed by VectorCombine. If we extend this to
1995 // permit complex multiplications, reductions, etc. then we should also add
1996 // support for fixed-width here.
1997 if (Factor != 2)
1998 return nullptr;
1999
2000 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
2001 if (!SVI)
2002 return nullptr;
2003
2004 // Look for a shufflevector that takes separate vectors of the real and
2005 // imaginary components and recombines them into a single vector.
2006 if (!isInterleavingMask(SVI->getShuffleMask()))
2007 return nullptr;
2008
2009 Instruction *Real;
2010 Instruction *Imag;
2011 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
2012 return nullptr;
2013
2014 return identifyNode(Real, Imag);
2015}
2016
2017ComplexDeinterleavingGraph::CompositeNode *
2018ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2019 Instruction *II = nullptr;
2020
2021 // Must be at least one complex value.
2022 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2023 Instruction *ExpectedInsn) -> ExtractValueInst * {
2024 auto *EVI = dyn_cast<ExtractValueInst>(V);
2025 if (!EVI || EVI->getNumIndices() != 1 ||
2026 EVI->getIndices()[0] != ExpectedIdx ||
2027 !isa<Instruction>(EVI->getAggregateOperand()) ||
2028 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2029 return nullptr;
2030 return EVI;
2031 };
2032
2033 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2034 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2035 if (RealEVI && Idx == 0)
2037 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2038 II = nullptr;
2039 break;
2040 }
2041 }
2042
2043 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(II)) {
2044 if (IntrinsicII->getIntrinsicID() !=
2046 return nullptr;
2047
2048 // The remaining should match too.
2049 CompositeNode *PlaceholderNode = prepareCompositeNode(
2051 PlaceholderNode->ReplacementNode = II->getOperand(0);
2052 for (auto &V : Vals) {
2053 FinalInstructions.insert(cast<Instruction>(V.Real));
2054 FinalInstructions.insert(cast<Instruction>(V.Imag));
2055 }
2056 return submitCompositeNode(PlaceholderNode);
2057 }
2058
2059 if (Vals.size() != 1)
2060 return nullptr;
2061
2062 Value *Real = Vals[0].Real;
2063 Value *Imag = Vals[0].Imag;
2064 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
2065 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
2066 if (!RealShuffle || !ImagShuffle) {
2067 if (RealShuffle || ImagShuffle)
2068 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2069 return nullptr;
2070 }
2071
2072 Value *RealOp1 = RealShuffle->getOperand(1);
2073 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
2074 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2075 return nullptr;
2076 }
2077 Value *ImagOp1 = ImagShuffle->getOperand(1);
2078 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
2079 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2080 return nullptr;
2081 }
2082
2083 Value *RealOp0 = RealShuffle->getOperand(0);
2084 Value *ImagOp0 = ImagShuffle->getOperand(0);
2085
2086 if (RealOp0 != ImagOp0) {
2087 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2088 return nullptr;
2089 }
2090
2091 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2092 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2093 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
2094 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2095 return nullptr;
2096 }
2097
2098 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2099 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2100 return nullptr;
2101 }
2102
2103 // Type checking, the shuffle type should be a vector type of the same
2104 // scalar type, but half the size
2105 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2106 Value *Op = Shuffle->getOperand(0);
2107 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
2108 auto *OpTy = cast<FixedVectorType>(Op->getType());
2109
2110 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2111 return false;
2112 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2113 return false;
2114
2115 return true;
2116 };
2117
2118 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2119 if (!CheckType(Shuffle))
2120 return false;
2121
2122 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2123 int Last = *Mask.rbegin();
2124
2125 Value *Op = Shuffle->getOperand(0);
2126 auto *OpTy = cast<FixedVectorType>(Op->getType());
2127 int NumElements = OpTy->getNumElements();
2128
2129 // Ensure that the deinterleaving shuffle only pulls from the first
2130 // shuffle operand.
2131 return Last < NumElements;
2132 };
2133
2134 if (RealShuffle->getType() != ImagShuffle->getType()) {
2135 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2136 return nullptr;
2137 }
2138 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2139 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2140 return nullptr;
2141 }
2142 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2143 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2144 return nullptr;
2145 }
2146
2147 CompositeNode *PlaceholderNode =
2149 RealShuffle, ImagShuffle);
2150 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2151 FinalInstructions.insert(RealShuffle);
2152 FinalInstructions.insert(ImagShuffle);
2153 return submitCompositeNode(PlaceholderNode);
2154}
2155
2156ComplexDeinterleavingGraph::CompositeNode *
2157ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2158 auto IsSplat = [](Value *V) -> bool {
2159 // Fixed-width vector with constants
2161 return true;
2162
2163 if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2164 return isa<VectorType>(V->getType());
2165
2166 VectorType *VTy;
2167 ArrayRef<int> Mask;
2168 // Splats are represented differently depending on whether the repeated
2169 // value is a constant or an Instruction
2170 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2171 if (Const->getOpcode() != Instruction::ShuffleVector)
2172 return false;
2173 VTy = cast<VectorType>(Const->getType());
2174 Mask = Const->getShuffleMask();
2175 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2176 VTy = Shuf->getType();
2177 Mask = Shuf->getShuffleMask();
2178 } else {
2179 return false;
2180 }
2181
2182 // When the data type is <1 x Type>, it's not possible to differentiate
2183 // between the ComplexDeinterleaving::Deinterleave and
2184 // ComplexDeinterleaving::Splat operations.
2185 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2186 return false;
2187
2188 return all_equal(Mask) && Mask[0] == 0;
2189 };
2190
2191 // The splats must meet the following requirements:
2192 // 1. Must either be all instructions or all values.
2193 // 2. Non-constant splats must live in the same block.
2194 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Vals[0].Real)) {
2195 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2196 for (auto &V : Vals) {
2197 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2198 return nullptr;
2199
2200 auto *Real = dyn_cast<Instruction>(V.Real);
2201 auto *Imag = dyn_cast<Instruction>(V.Imag);
2202 if (!Real || !Imag || Real->getParent() != FirstBB ||
2203 Imag->getParent() != FirstBB)
2204 return nullptr;
2205 }
2206 } else {
2207 for (auto &V : Vals) {
2208 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(V.Real) ||
2209 isa<Instruction>(V.Imag))
2210 return nullptr;
2211 }
2212 }
2213
2214 for (auto &V : Vals) {
2215 auto *Real = dyn_cast<Instruction>(V.Real);
2216 auto *Imag = dyn_cast<Instruction>(V.Imag);
2217 if (Real && Imag) {
2218 FinalInstructions.insert(Real);
2219 FinalInstructions.insert(Imag);
2220 }
2221 }
2222 CompositeNode *PlaceholderNode =
2223 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2224 return submitCompositeNode(PlaceholderNode);
2225}
2226
2227ComplexDeinterleavingGraph::CompositeNode *
2228ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2229 Instruction *Imag) {
2230 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2231 return nullptr;
2232
2233 PHIsFound = true;
2234 CompositeNode *PlaceholderNode = prepareCompositeNode(
2235 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2236 return submitCompositeNode(PlaceholderNode);
2237}
2238
2239ComplexDeinterleavingGraph::CompositeNode *
2240ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2241 Instruction *Imag) {
2242 auto *SelectReal = dyn_cast<SelectInst>(Real);
2243 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2244 if (!SelectReal || !SelectImag)
2245 return nullptr;
2246
2247 Instruction *MaskA, *MaskB;
2248 Instruction *AR, *AI, *RA, *BI;
2249 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2250 m_Instruction(RA))) ||
2251 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2252 m_Instruction(BI))))
2253 return nullptr;
2254
2255 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2256 return nullptr;
2257
2258 if (!MaskA->getType()->isVectorTy())
2259 return nullptr;
2260
2261 auto NodeA = identifyNode(AR, AI);
2262 if (!NodeA)
2263 return nullptr;
2264
2265 auto NodeB = identifyNode(RA, BI);
2266 if (!NodeB)
2267 return nullptr;
2268
2269 CompositeNode *PlaceholderNode = prepareCompositeNode(
2270 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2271 PlaceholderNode->addOperand(NodeA);
2272 PlaceholderNode->addOperand(NodeB);
2273 FinalInstructions.insert(MaskA);
2274 FinalInstructions.insert(MaskB);
2275 return submitCompositeNode(PlaceholderNode);
2276}
2277
2278static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2279 std::optional<FastMathFlags> Flags,
2280 Value *InputA, Value *InputB) {
2281 Value *I;
2282 switch (Opcode) {
2283 case Instruction::FNeg:
2284 I = B.CreateFNeg(InputA);
2285 break;
2286 case Instruction::FAdd:
2287 I = B.CreateFAdd(InputA, InputB);
2288 break;
2289 case Instruction::Add:
2290 I = B.CreateAdd(InputA, InputB);
2291 break;
2292 case Instruction::FSub:
2293 I = B.CreateFSub(InputA, InputB);
2294 break;
2295 case Instruction::Sub:
2296 I = B.CreateSub(InputA, InputB);
2297 break;
2298 case Instruction::FMul:
2299 I = B.CreateFMul(InputA, InputB);
2300 break;
2301 case Instruction::Mul:
2302 I = B.CreateMul(InputA, InputB);
2303 break;
2304 default:
2305 llvm_unreachable("Incorrect symmetric opcode");
2306 }
2307 if (Flags)
2308 cast<Instruction>(I)->setFastMathFlags(*Flags);
2309 return I;
2310}
2311
2312Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2313 CompositeNode *Node) {
2314 if (Node->ReplacementNode)
2315 return Node->ReplacementNode;
2316
2317 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2318 unsigned Idx) -> Value * {
2319 return Node->Operands.size() > Idx
2320 ? replaceNode(Builder, Node->Operands[Idx])
2321 : nullptr;
2322 };
2323
2324 Value *ReplacementNode = nullptr;
2325 switch (Node->Operation) {
2326 case ComplexDeinterleavingOperation::CDot: {
2327 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2328 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2329 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2330 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2331 "Node inputs need to be of the same type"));
2332 ReplacementNode = TL->createComplexDeinterleavingIR(
2333 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2334 break;
2335 }
2336 case ComplexDeinterleavingOperation::CAdd:
2337 case ComplexDeinterleavingOperation::CMulPartial:
2338 case ComplexDeinterleavingOperation::Symmetric: {
2339 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2340 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2341 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2342 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2343 "Node inputs need to be of the same type"));
2345 (Input0->getType() == Accumulator->getType() &&
2346 "Accumulator and input need to be of the same type"));
2347 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2348 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2349 Input0, Input1);
2350 else
2351 ReplacementNode = TL->createComplexDeinterleavingIR(
2352 Builder, Node->Operation, Node->Rotation, Input0, Input1,
2353 Accumulator);
2354 break;
2355 }
2356 case ComplexDeinterleavingOperation::Deinterleave:
2357 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2358 break;
2359 case ComplexDeinterleavingOperation::Splat: {
2361 for (auto &V : Node->Vals) {
2362 Ops.push_back(V.Real);
2363 Ops.push_back(V.Imag);
2364 }
2365 auto *R = dyn_cast<Instruction>(Node->Vals[0].Real);
2366 auto *I = dyn_cast<Instruction>(Node->Vals[0].Imag);
2367 if (R && I) {
2368 // Splats that are not constant are interleaved where they are located
2369 Instruction *InsertPoint = R;
2370 for (auto V : Node->Vals) {
2371 if (InsertPoint->comesBefore(cast<Instruction>(V.Real)))
2372 InsertPoint = cast<Instruction>(V.Real);
2373 if (InsertPoint->comesBefore(cast<Instruction>(V.Imag)))
2374 InsertPoint = cast<Instruction>(V.Imag);
2375 }
2376 InsertPoint = InsertPoint->getNextNode();
2377 IRBuilder<> IRB(InsertPoint);
2378 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2379 } else {
2380 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2381 }
2382 break;
2383 }
2384 case ComplexDeinterleavingOperation::ReductionPHI: {
2385 // If Operation is ReductionPHI, a new empty PHINode is created.
2386 // It is filled later when the ReductionOperation is processed.
2387 auto *OldPHI = cast<PHINode>(Node->Vals[0].Real);
2388 auto *VTy = cast<VectorType>(Node->Vals[0].Real->getType());
2389 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2390 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2391 OldToNewPHI[OldPHI] = NewPHI;
2392 ReplacementNode = NewPHI;
2393 break;
2394 }
2395 case ComplexDeinterleavingOperation::ReductionSingle:
2396 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2397 processReductionSingle(ReplacementNode, Node);
2398 break;
2399 case ComplexDeinterleavingOperation::ReductionOperation:
2400 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2401 processReductionOperation(ReplacementNode, Node);
2402 break;
2403 case ComplexDeinterleavingOperation::ReductionSelect: {
2404 auto *MaskReal = cast<Instruction>(Node->Vals[0].Real)->getOperand(0);
2405 auto *MaskImag = cast<Instruction>(Node->Vals[0].Imag)->getOperand(0);
2406 auto *A = replaceNode(Builder, Node->Operands[0]);
2407 auto *B = replaceNode(Builder, Node->Operands[1]);
2408 auto *NewMask = Builder.CreateVectorInterleave({MaskReal, MaskImag});
2409 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2410 break;
2411 }
2412 }
2413
2414 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2415 NumComplexTransformations += 1;
2416 Node->ReplacementNode = ReplacementNode;
2417 return ReplacementNode;
2418}
2419
2420void ComplexDeinterleavingGraph::processReductionSingle(
2421 Value *OperationReplacement, CompositeNode *Node) {
2422 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2423 auto *OldPHI = ReductionInfo[Real].first;
2424 auto *NewPHI = OldToNewPHI[OldPHI];
2425 auto *VTy = cast<VectorType>(Real->getType());
2426 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2427
2428 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2429
2430 IRBuilder<> Builder(Incoming->getTerminator());
2431
2432 Value *NewInit = nullptr;
2433 if (auto *C = dyn_cast<Constant>(Init)) {
2434 if (C->isZeroValue())
2435 NewInit = Constant::getNullValue(NewVTy);
2436 }
2437
2438 if (!NewInit)
2439 NewInit =
2440 Builder.CreateVectorInterleave({Init, Constant::getNullValue(VTy)});
2441
2442 NewPHI->addIncoming(NewInit, Incoming);
2443 NewPHI->addIncoming(OperationReplacement, BackEdge);
2444
2445 auto *FinalReduction = ReductionInfo[Real].second;
2446 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2447
2448 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2449 FinalReduction->replaceAllUsesWith(AddReduce);
2450}
2451
2452void ComplexDeinterleavingGraph::processReductionOperation(
2453 Value *OperationReplacement, CompositeNode *Node) {
2454 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2455 auto *Imag = cast<Instruction>(Node->Vals[0].Imag);
2456 auto *OldPHIReal = ReductionInfo[Real].first;
2457 auto *OldPHIImag = ReductionInfo[Imag].first;
2458 auto *NewPHI = OldToNewPHI[OldPHIReal];
2459
2460 // We have to interleave initial origin values coming from IncomingBlock
2461 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2462 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2463
2464 IRBuilder<> Builder(Incoming->getTerminator());
2465 auto *NewInit = Builder.CreateVectorInterleave({InitReal, InitImag});
2466
2467 NewPHI->addIncoming(NewInit, Incoming);
2468 NewPHI->addIncoming(OperationReplacement, BackEdge);
2469
2470 // Deinterleave complex vector outside of loop so that it can be finally
2471 // reduced
2472 auto *FinalReductionReal = ReductionInfo[Real].second;
2473 auto *FinalReductionImag = ReductionInfo[Imag].second;
2474
2475 Builder.SetInsertPoint(
2476 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2477 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2478 OperationReplacement->getType(),
2479 OperationReplacement);
2480
2481 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2482 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2483
2484 Builder.SetInsertPoint(FinalReductionImag);
2485 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2486 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2487}
2488
2489void ComplexDeinterleavingGraph::replaceNodes() {
2490 SmallVector<Instruction *, 16> DeadInstrRoots;
2491 for (auto *RootInstruction : OrderedRoots) {
2492 // Check if this potential root went through check process and we can
2493 // deinterleave it
2494 if (!RootToNode.count(RootInstruction))
2495 continue;
2496
2497 IRBuilder<> Builder(RootInstruction);
2498 auto RootNode = RootToNode[RootInstruction];
2499 Value *R = replaceNode(Builder, RootNode);
2500
2501 if (RootNode->Operation ==
2502 ComplexDeinterleavingOperation::ReductionOperation) {
2503 auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
2504 auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
2505 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2506 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2507 DeadInstrRoots.push_back(RootReal);
2508 DeadInstrRoots.push_back(RootImag);
2509 } else if (RootNode->Operation ==
2510 ComplexDeinterleavingOperation::ReductionSingle) {
2511 auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
2512 auto &Info = ReductionInfo[RootInst];
2513 Info.first->removeIncomingValue(BackEdge);
2514 DeadInstrRoots.push_back(Info.second);
2515 } else {
2516 assert(R && "Unable to find replacement for RootInstruction");
2517 DeadInstrRoots.push_back(RootInstruction);
2518 RootInstruction->replaceAllUsesWith(R);
2519 }
2520 }
2521
2522 for (auto *I : DeadInstrRoots)
2524}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
#define DEBUG_TYPE
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
This file implements a map that provides insertion order iteration.
#define T
uint64_t IntrinsicInst * II
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
This file describes how to lower LLVM code to machine code.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
iterator end()
Definition DenseMap.h:81
bool allowContract() const
Definition FMF.h:69
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2598
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
size_type size() const
Definition MapVector.h:56
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:273
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
An opaque object representing a hash code.
Definition Hashing.h:76
const ParentTy * getParent() const
Definition ilist_node.h:34
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BR
Control flow instructions. These all have token chains.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1737
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:533
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
@ Other
Any other memory.
Definition ModRef.h:68
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1770
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1945
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
Definition STLExtras.h:2156
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:592
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
ComplexDeinterleavingPass(const TargetMachine &TM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...