LLVM 20.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validation of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
63#include "llvm/ADT/MapVector.h"
64#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
75#include <algorithm>
76
77using namespace llvm;
78using namespace PatternMatch;
79
80#define DEBUG_TYPE "complex-deinterleaving"
81
82STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83
85 "enable-complex-deinterleaving",
86 cl::desc("Enable generation of complex instructions"), cl::init(true),
88
89/// Checks the given mask, and determines whether said mask is interleaving.
90///
91/// To be interleaving, a mask must alternate between `i` and `i + (Length /
92/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93/// 4x vector interleaving mask would be <0, 2, 1, 3>).
94static bool isInterleavingMask(ArrayRef<int> Mask);
95
96/// Checks the given mask, and determines whether said mask is deinterleaving.
97///
98/// To be deinterleaving, a mask must increment in steps of 2, and either start
99/// with 0 or 1.
100/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101/// <1, 3, 5, 7>).
102static bool isDeinterleavingMask(ArrayRef<int> Mask);
103
104/// Returns true if \p V is itself a negation operation; both integer and
105/// floating-point negations are recognised.
106static bool isNeg(Value *V);
107
108/// Returns the operand for negation operation.
109static Value *getNegOperand(Value *V);
110
111namespace {
112
/// Legacy pass-manager wrapper for the complex deinterleaving transformation.
/// The class only stores the TargetMachine it was created with; the actual
/// work is done by the out-of-line runOnFunction().
113class ComplexDeinterleavingLegacyPass : public FunctionPass {
114public:
 /// Pass identifier handed to the FunctionPass base class.
115 static char ID;
116
 /// \param TM target machine queried later by runOnFunction().
 /// NOTE(review): TM defaults to nullptr, yet runOnFunction() dereferences it
 /// unconditionally, so callers presumably always supply one; confirm.
117 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118 : FunctionPass(ID), TM(TM) {
121 }
122
 /// Human-readable name of the pass.
123 StringRef getPassName() const override {
124 return "Complex Deinterleaving Pass";
125 }
126
 /// Entry point invoked per function; defined out of line.
127 bool runOnFunction(Function &F) override;
 /// Declares that this pass leaves the control-flow graph intact.
128 void getAnalysisUsage(AnalysisUsage &AU) const override {
130 AU.setPreservesCFG();
131 }
132
133private:
 /// Target machine supplied at construction (null when default-constructed).
134 const TargetMachine *TM;
135};
136
137class ComplexDeinterleavingGraph;
/// One node of the complex deinterleaving graph: a pair of values (real and
/// imaginary part) together with the operation they were identified as, the
/// operand nodes they depend on, and the value that replaces them.
138struct ComplexDeinterleavingCompositeNode {
139
 /// \param Op the operation this node represents.
 /// \param R  the value holding the real part.
 /// \param I  the value holding the imaginary part.
140 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
141 Value *R, Value *I)
142 : Operation(Op), Real(R), Imag(I) {}
143
144private:
 // The graph is the only client allowed to use the pointer aliases below.
145 friend class ComplexDeinterleavingGraph;
 // Owning handle to a node.
146 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
 // Non-owning handle to a node.
147 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148
149public:
 // Values carrying the real and imaginary components of this node.
151 Value *Real;
152 Value *Imag;
153
154 // These two members are required exclusively for generating
155 // ComplexDeinterleavingOperation::Symmetric operations.
 // NOTE(review): Opcode has no initializer here; presumably it is only read
 // for Symmetric nodes, where it is assigned after construction. Confirm.
156 unsigned Opcode;
157 std::optional<FastMathFlags> Flags;
158
160 ComplexDeinterleavingRotation::Rotation_0;
 // Value printed by dump() under "ReplacementNode"; null until assigned.
162 Value *ReplacementNode = nullptr;
163
 /// Records \p Node as an operand. Only the raw pointer is stored, so this
 /// node does not share ownership of its operands.
164 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165
 /// Prints this node to the debug stream.
166 void dump() { dump(dbgs()); }
 /// Prints this node's address, its real/imaginary/replacement values, the
 /// operation (as an integer), the rotation (in degrees) and the addresses of
 /// its operand nodes to \p OS.
167 void dump(raw_ostream &OS) {
 // Prints a value in quotes, or "nullptr" when absent.
168 auto PrintValue = [&](Value *V) {
169 if (V) {
170 OS << "\"";
171 V->print(OS, true);
172 OS << "\"\n";
173 } else
174 OS << "nullptr\n";
175 };
 // Prints the address of an operand node, or "nullptr" when absent.
176 auto PrintNodeRef = [&](RawNodePtr Ptr) {
177 if (Ptr)
178 OS << Ptr << "\n";
179 else
180 OS << "nullptr\n";
181 };
182
183 OS << "- CompositeNode: " << this << "\n";
184 OS << " Real: ";
185 PrintValue(Real);
186 OS << " Imag: ";
187 PrintValue(Imag);
188 OS << " ReplacementNode: ";
189 PrintValue(ReplacementNode);
190 OS << " Operation: " << (int)Operation << "\n";
 // The rotation enumerator is scaled by 90 so it reads as an angle.
191 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
192 OS << " Operands: \n";
193 for (const auto &Op : Operands) {
194 OS << " - ";
195 PrintNodeRef(Op);
196 }
197 }
198};
199
/// Builds and owns the graph of composite nodes for one basic block: the
/// identify* members recognise complex patterns and create nodes, checkNodes()
/// validates the result and replaceNodes() performs the replacement.
200class ComplexDeinterleavingGraph {
201public:
 /// One multiplication term collected from an expression.
 /// NOTE(review): IsPositive presumably records the sign with which the
 /// product contributes to the enclosing sum; confirm against its users.
202 struct Product {
203 Value *Multiplier;
204 Value *Multiplicand;
205 bool IsPositive;
206 };
207
 /// An addend together with a flag telling whether it is added (true) or
 /// subtracted (false).
208 using Addend = std::pair<Value *, bool>;
209 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
211
212 // Helper struct for holding info about potential partial multiplication
213 // candidates.
214 struct PartialMulCandidate {
215 Value *Common;
216 NodePtr Node;
217 unsigned RealIdx;
218 unsigned ImagIdx;
219 bool IsNodeInverted;
220 };
221
 /// Stores the target lowering and library info pointers for later queries;
 /// neither pointer is owned by the graph.
222 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
223 const TargetLibraryInfo *TLI)
224 : TL(TL), TLI(TLI) {}
225
226private:
227 const TargetLowering *TL = nullptr;
228 const TargetLibraryInfo *TLI = nullptr;
 /// Every node handed to submitCompositeNode(), in submission order.
229 SmallVector<NodePtr> CompositeNodes;
 /// Memoised identifyNode() results keyed by the (real, imaginary) value pair.
 /// A null entry records a pair that was not recognised.
230 DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
231
232 SmallPtrSet<Instruction *, 16> FinalInstructions;
233
234 /// Root instructions are instructions from which complex computation starts
235 std::map<Instruction *, NodePtr> RootToNode;
236
237 /// Topologically sorted root instructions
239
240 /// When examining a basic block for complex deinterleaving, if it is a simple
241 /// one-block loop, then the only incoming block is 'Incoming' and the
242 /// 'BackEdge' block is the block itself.
243 BasicBlock *BackEdge = nullptr;
244 BasicBlock *Incoming = nullptr;
245
246 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
247 /// %OutsideUser as it is shown in the IR:
248 ///
249 /// vector.body:
250 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
251 /// [ %ReductionOp, %vector.body ]
252 /// ...
253 /// %ReductionOp = fadd i64 ...
254 /// ...
255 /// br i1 %condition, label %vector.body, %middle.block
256 ///
257 /// middle.block:
258 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
259 ///
260 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
261 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
263
264 /// In the process of detecting a reduction, we consider a pair of
265 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
266 /// traverse the use-tree to detect complex operations. As this is a reduction
267 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
268 /// to the %ReductionOPs that we suspect to be complex.
269 /// RealPHI and ImagPHI are used by the identifyPHINode method.
270 PHINode *RealPHI = nullptr;
271 PHINode *ImagPHI = nullptr;
272
273 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
274 /// detection.
275 bool PHIsFound = false;
276
277 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
278 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
279 /// This mapping is populated during
280 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
281 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
282 /// replacement process.
283 std::map<PHINode *, PHINode *> OldToNewPHI;
284
 /// Allocates a node for \p Operation with parts \p R and \p I without
 /// registering it in the graph. Reduction-related operations must be given
 /// both parts (checked by the assertion).
285 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
286 Value *R, Value *I) {
287 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
288 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
289 (R && I)) &&
290 "Reduction related nodes must have Real and Imaginary parts");
291 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292 I);
293 }
294
 /// Registers \p Node in the graph and returns it. The node is also cached
 /// under its (Real, Imag) pair, but only when both parts are non-null.
295 NodePtr submitCompositeNode(NodePtr Node) {
296 CompositeNodes.push_back(Node);
297 if (Node->Real && Node->Imag)
298 CachedResult[{Node->Real, Node->Imag}] = Node;
299 return Node;
300 }
301
302 /// Identifies a complex partial multiply pattern and its rotation, based on
303 /// the following patterns
304 ///
305 /// 0: r: cr + ar * br
306 /// i: ci + ar * bi
307 /// 90: r: cr - ai * bi
308 /// i: ci + ai * br
309 /// 180: r: cr - ar * br
310 /// i: ci - ar * bi
311 /// 270: r: cr + ai * bi
312 /// i: ci - ai * br
313 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314
315 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316 /// is partially known from identifyPartialMul, filling in the other half of
317 /// the complex pair.
318 NodePtr
319 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
320 std::pair<Value *, Value *> &CommonOperandI);
321
322 /// Identifies a complex add pattern and its rotation, based on the following
323 /// patterns.
324 ///
325 /// 90: r: ar - bi
326 /// i: ai + br
327 /// 270: r: ar + bi
328 /// i: ai - br
329 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
 /// Identifies a pair of instructions with the same opcode whose operands are
 /// themselves recognisable complex pairs.
330 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331
 /// Identifies the node formed by the pair (\p R, \p I), consulting the cache
 /// first. Returns nullptr when no pattern matches.
332 NodePtr identifyNode(Value *R, Value *I);
333
334 /// Determine if a sum of complex numbers can be formed from \p RealAddends
335 /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
336 /// Return nullptr if it is not possible to construct a complex number.
337 /// \p Flags are needed to generate symmetric Add and Sub operations.
338 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
339 std::list<Addend> &ImagAddends,
340 std::optional<FastMathFlags> Flags,
341 NodePtr Accumulator);
342
343 /// Extract one addend that has both real and imaginary parts positive.
344 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
345 std::list<Addend> &ImagAddends);
346
347 /// Determine if sum of multiplications of complex numbers can be formed from
348 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
349 /// to it. Return nullptr if it is not possible to construct a complex number.
350 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
351 std::vector<Product> &ImagMuls,
352 NodePtr Accumulator);
353
354 /// Go through pairs of multiplication (one Real and one Imag) and find all
355 /// possible candidates for partial multiplication and put them into \p
356 /// Candidates. Returns true if every Product has a pair with a common operand.
357 bool collectPartialMuls(const std::vector<Product> &RealMuls,
358 const std::vector<Product> &ImagMuls,
359 std::vector<PartialMulCandidate> &Candidates);
360
361 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
362 /// the order of complex computation operations may be significantly altered,
363 /// and the real and imaginary parts may not be executed in parallel. This
364 /// function takes this into consideration and employs a more general approach
365 /// to identify complex computations. Initially, it gathers all the addends
366 /// and multiplicands and then constructs a complex expression from them.
367 NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
368
369 NodePtr identifyRoot(Instruction *I);
370
371 /// Identifies the Deinterleave operation applied to a vector containing
372 /// complex numbers. There are two ways to represent the Deinterleave
373 /// operation:
374 /// * Using two shufflevectors with even indices for \p Real instruction and
375 /// odd indices for \p Imag instructions (only for fixed-width vectors)
376 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
377 /// intrinsic (for both fixed and scalable vectors)
378 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
379
380 /// Identifies the operation that represents a complex number repeated in a
381 /// Splat vector. There are two possible types of splats: ConstantExpr with
382 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
383 /// initialization mask with all values set to zero.
384 NodePtr identifySplat(Value *Real, Value *Imag);
385
386 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
387
388 /// Identifies SelectInsts in a loop that has reduction with predication masks
389 /// and/or predicated tail folding
390 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
391
392 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
393
394 /// Complete IR modifications after producing new reduction operation:
395 /// * Populate the PHINode generated for
396 /// ComplexDeinterleavingOperation::ReductionPHI
397 /// * Deinterleave the final value outside of the loop and repurpose original
398 /// reduction users
399 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400
401public:
 /// Prints every submitted node to the debug stream.
402 void dump() { dump(dbgs()); }
 /// Prints every submitted node, in submission order, to \p OS.
403 void dump(raw_ostream &OS) {
404 for (const auto &Node : CompositeNodes)
405 Node->dump(OS);
406 }
407
408 /// Returns false if the deinterleaving operation should be cancelled for the
409 /// current graph.
410 bool identifyNodes(Instruction *RootI);
411
412 /// In case \p B is a one-block loop, this function seeks potential reductions
413 /// and populates ReductionInfo. Returns true if any reductions were
414 /// identified.
415 bool collectPotentialReductions(BasicBlock *B);
416
417 void identifyReductionNodes();
418
419 /// Check that every instruction, from the roots to the leaves, has internal
420 /// uses.
421 bool checkNodes();
422
423 /// Perform the actual replacement of the underlying instruction graph.
424 void replaceNodes();
425};
426
427class ComplexDeinterleaving {
428public:
429 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430 : TL(tl), TLI(tli) {}
431 bool runOnFunction(Function &F);
432
433private:
434 bool evaluateBasicBlock(BasicBlock *B);
435
436 const TargetLowering *TL = nullptr;
437 const TargetLibraryInfo *TLI = nullptr;
438};
439
440} // namespace
441
442char ComplexDeinterleavingLegacyPass::ID = 0;
443
444INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445 "Complex Deinterleaving", false, false)
446INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
448
451 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454 return PreservedAnalyses::all();
455
458 return PA;
459}
460
462 return new ComplexDeinterleavingLegacyPass(TM);
463}
464
465bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469}
470
/// Runs complex deinterleaving over every basic block of \p F.
/// \returns true if at least one basic block was changed; false when nothing
/// changed or when one of the two early exits below is taken.
471bool ComplexDeinterleaving::runOnFunction(Function &F) {
 // First early exit: the transformation has been switched off explicitly.
474 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475 return false;
476 }
477
 // Second early exit: the target cannot lower complex number operations.
480 dbgs() << "Complex deinterleaving has been disabled, target does "
481 "not support lowering of complex number operations.\n");
482 return false;
483 }
484
 // Accumulate with |= so that every block is visited even after a change.
485 bool Changed = false;
486 for (auto &B : F)
487 Changed |= evaluateBasicBlock(&B);
488
489 return Changed;
490}
491
493 // If the size is not even, it's not an interleaving mask
494 if ((Mask.size() & 1))
495 return false;
496
497 int HalfNumElements = Mask.size() / 2;
498 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499 int MaskIdx = Idx * 2;
500 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501 return false;
502 }
503
504 return true;
505}
506
508 int Offset = Mask[0];
509 int HalfNumElements = Mask.size() / 2;
510
511 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512 if (Mask[Idx] != (Idx * 2) + Offset)
513 return false;
514 }
515
516 return true;
517}
518
519bool isNeg(Value *V) {
520 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
521}
522
524 assert(isNeg(V));
525 auto *I = cast<Instruction>(V);
526 if (I->getOpcode() == Instruction::FNeg)
527 return I->getOperand(0);
528
529 return I->getOperand(1);
530}
531
532bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
533 ComplexDeinterleavingGraph Graph(TL, TLI);
534 if (Graph.collectPotentialReductions(B))
535 Graph.identifyReductionNodes();
536
537 for (auto &I : *B)
538 Graph.identifyNodes(&I);
539
540 if (Graph.checkNodes()) {
541 Graph.replaceNodes();
542 return true;
543 }
544
545 return false;
546}
547
548ComplexDeinterleavingGraph::NodePtr
549ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550 Instruction *Real, Instruction *Imag,
551 std::pair<Value *, Value *> &PartialMatch) {
552 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553 << "\n");
554
555 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
557 return nullptr;
558 }
559
560 if ((Real->getOpcode() != Instruction::FMul &&
561 Real->getOpcode() != Instruction::Mul) ||
562 (Imag->getOpcode() != Instruction::FMul &&
563 Imag->getOpcode() != Instruction::Mul)) {
565 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
566 return nullptr;
567 }
568
569 Value *R0 = Real->getOperand(0);
570 Value *R1 = Real->getOperand(1);
571 Value *I0 = Imag->getOperand(0);
572 Value *I1 = Imag->getOperand(1);
573
574 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575 // rotations and use the operand.
576 unsigned Negs = 0;
577 Value *Op;
578 if (match(R0, m_Neg(m_Value(Op)))) {
579 Negs |= 1;
580 R0 = Op;
581 } else if (match(R1, m_Neg(m_Value(Op)))) {
582 Negs |= 1;
583 R1 = Op;
584 }
585
586 if (isNeg(I0)) {
587 Negs |= 2;
588 Negs ^= 1;
589 I0 = Op;
590 } else if (match(I1, m_Neg(m_Value(Op)))) {
591 Negs |= 2;
592 Negs ^= 1;
593 I1 = Op;
594 }
595
597
598 Value *CommonOperand;
599 Value *UncommonRealOp;
600 Value *UncommonImagOp;
601
602 if (R0 == I0 || R0 == I1) {
603 CommonOperand = R0;
604 UncommonRealOp = R1;
605 } else if (R1 == I0 || R1 == I1) {
606 CommonOperand = R1;
607 UncommonRealOp = R0;
608 } else {
609 LLVM_DEBUG(dbgs() << " - No equal operand\n");
610 return nullptr;
611 }
612
613 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615 Rotation == ComplexDeinterleavingRotation::Rotation_270)
616 std::swap(UncommonRealOp, UncommonImagOp);
617
618 // Between identifyPartialMul and here we need to have found a complete valid
619 // pair from the CommonOperand of each part.
620 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621 Rotation == ComplexDeinterleavingRotation::Rotation_180)
622 PartialMatch.first = CommonOperand;
623 else
624 PartialMatch.second = CommonOperand;
625
626 if (!PartialMatch.first || !PartialMatch.second) {
627 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
628 return nullptr;
629 }
630
631 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632 if (!CommonNode) {
633 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
634 return nullptr;
635 }
636
637 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638 if (!UncommonNode) {
639 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
640 return nullptr;
641 }
642
643 NodePtr Node = prepareCompositeNode(
644 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645 Node->Rotation = Rotation;
646 Node->addOperand(CommonNode);
647 Node->addOperand(UncommonNode);
648 return submitCompositeNode(Node);
649}
650
651ComplexDeinterleavingGraph::NodePtr
652ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653 Instruction *Imag) {
654 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655 << "\n");
656 // Determine rotation
657 auto IsAdd = [](unsigned Op) {
658 return Op == Instruction::FAdd || Op == Instruction::Add;
659 };
660 auto IsSub = [](unsigned Op) {
661 return Op == Instruction::FSub || Op == Instruction::Sub;
662 };
664 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665 Rotation = ComplexDeinterleavingRotation::Rotation_0;
666 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667 Rotation = ComplexDeinterleavingRotation::Rotation_90;
668 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669 Rotation = ComplexDeinterleavingRotation::Rotation_180;
670 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671 Rotation = ComplexDeinterleavingRotation::Rotation_270;
672 else {
673 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
674 return nullptr;
675 }
676
677 if (isa<FPMathOperator>(Real) &&
678 (!Real->getFastMathFlags().allowContract() ||
679 !Imag->getFastMathFlags().allowContract())) {
680 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
681 return nullptr;
682 }
683
684 Value *CR = Real->getOperand(0);
685 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686 if (!RealMulI)
687 return nullptr;
688 Value *CI = Imag->getOperand(0);
689 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690 if (!ImagMulI)
691 return nullptr;
692
693 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
695 return nullptr;
696 }
697
698 Value *R0 = RealMulI->getOperand(0);
699 Value *R1 = RealMulI->getOperand(1);
700 Value *I0 = ImagMulI->getOperand(0);
701 Value *I1 = ImagMulI->getOperand(1);
702
703 Value *CommonOperand;
704 Value *UncommonRealOp;
705 Value *UncommonImagOp;
706
707 if (R0 == I0 || R0 == I1) {
708 CommonOperand = R0;
709 UncommonRealOp = R1;
710 } else if (R1 == I0 || R1 == I1) {
711 CommonOperand = R1;
712 UncommonRealOp = R0;
713 } else {
714 LLVM_DEBUG(dbgs() << " - No equal operand\n");
715 return nullptr;
716 }
717
718 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720 Rotation == ComplexDeinterleavingRotation::Rotation_270)
721 std::swap(UncommonRealOp, UncommonImagOp);
722
723 std::pair<Value *, Value *> PartialMatch(
724 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725 Rotation == ComplexDeinterleavingRotation::Rotation_180)
726 ? CommonOperand
727 : nullptr,
728 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729 Rotation == ComplexDeinterleavingRotation::Rotation_270)
730 ? CommonOperand
731 : nullptr);
732
733 auto *CRInst = dyn_cast<Instruction>(CR);
734 auto *CIInst = dyn_cast<Instruction>(CI);
735
736 if (!CRInst || !CIInst) {
737 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
738 return nullptr;
739 }
740
741 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742 if (!CNode) {
743 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
744 return nullptr;
745 }
746
747 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748 if (!UncommonRes) {
749 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
750 return nullptr;
751 }
752
753 assert(PartialMatch.first && PartialMatch.second);
754 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755 if (!CommonRes) {
756 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
757 return nullptr;
758 }
759
760 NodePtr Node = prepareCompositeNode(
761 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762 Node->Rotation = Rotation;
763 Node->addOperand(CommonRes);
764 Node->addOperand(UncommonRes);
765 Node->addOperand(CNode);
766 return submitCompositeNode(Node);
767}
768
769ComplexDeinterleavingGraph::NodePtr
770ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772
773 // Determine rotation
775 if ((Real->getOpcode() == Instruction::FSub &&
776 Imag->getOpcode() == Instruction::FAdd) ||
777 (Real->getOpcode() == Instruction::Sub &&
778 Imag->getOpcode() == Instruction::Add))
779 Rotation = ComplexDeinterleavingRotation::Rotation_90;
780 else if ((Real->getOpcode() == Instruction::FAdd &&
781 Imag->getOpcode() == Instruction::FSub) ||
782 (Real->getOpcode() == Instruction::Add &&
783 Imag->getOpcode() == Instruction::Sub))
784 Rotation = ComplexDeinterleavingRotation::Rotation_270;
785 else {
786 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787 return nullptr;
788 }
789
790 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794
795 if (!AR || !AI || !BR || !BI) {
796 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797 return nullptr;
798 }
799
800 NodePtr ResA = identifyNode(AR, AI);
801 if (!ResA) {
802 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803 return nullptr;
804 }
805 NodePtr ResB = identifyNode(BR, BI);
806 if (!ResB) {
807 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808 return nullptr;
809 }
810
811 NodePtr Node =
812 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813 Node->Rotation = Rotation;
814 Node->addOperand(ResA);
815 Node->addOperand(ResB);
816 return submitCompositeNode(Node);
817}
818
820 unsigned OpcA = A->getOpcode();
821 unsigned OpcB = B->getOpcode();
822
823 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827}
828
830 auto Pattern =
832
833 return match(A, Pattern) && match(B, Pattern);
834}
835
837 switch (I->getOpcode()) {
838 case Instruction::FAdd:
839 case Instruction::FSub:
840 case Instruction::FMul:
841 case Instruction::FNeg:
842 case Instruction::Add:
843 case Instruction::Sub:
844 case Instruction::Mul:
845 return true;
846 default:
847 return false;
848 }
849}
850
/// Identifies a pair of instructions that apply the same opcode to the real
/// and imaginary parts. The operand pairs (operand 0 of each, and operand 1 of
/// each for binary operators) must themselves be identifiable nodes, and for
/// floating-point operations both instructions must carry identical fast-math
/// flags.
/// \returns the submitted Symmetric node (with Opcode and, for FP operations,
/// Flags recorded), or nullptr if the pair does not qualify.
851ComplexDeinterleavingGraph::NodePtr
852ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
853 Instruction *Imag) {
 // Both halves must perform the same operation.
854 if (Real->getOpcode() != Imag->getOpcode())
855 return nullptr;
856
859 return nullptr;
860
861 auto *R0 = Real->getOperand(0);
862 auto *I0 = Imag->getOperand(0);
863
864 NodePtr Op0 = identifyNode(R0, I0);
865 NodePtr Op1 = nullptr;
866 if (Op0 == nullptr)
867 return nullptr;
868
 // Unary operations only have the first operand pair.
869 if (Real->isBinaryOp()) {
870 auto *R1 = Real->getOperand(1);
871 auto *I1 = Imag->getOperand(1);
872 Op1 = identifyNode(R1, I1);
873 if (Op1 == nullptr)
874 return nullptr;
875 }
876
 // A single set of flags is recorded on the node, so both halves must agree.
877 if (isa<FPMathOperator>(Real) &&
878 Real->getFastMathFlags() != Imag->getFastMathFlags())
879 return nullptr;
880
881 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
882 Real, Imag);
883 Node->Opcode = Real->getOpcode();
884 if (isa<FPMathOperator>(Real))
885 Node->Flags = Real->getFastMathFlags();
886
887 Node->addOperand(Op0);
888 if (Real->isBinaryOp())
889 Node->addOperand(Op1);
890
891 return submitCompositeNode(Node);
892}
893
894ComplexDeinterleavingGraph::NodePtr
895ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
896 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
897 assert(R->getType() == I->getType() &&
898 "Real and imaginary parts should not have different types");
899
900 auto It = CachedResult.find({R, I});
901 if (It != CachedResult.end()) {
902 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903 return It->second;
904 }
905
906 if (NodePtr CN = identifySplat(R, I))
907 return CN;
908
909 auto *Real = dyn_cast<Instruction>(R);
910 auto *Imag = dyn_cast<Instruction>(I);
911 if (!Real || !Imag)
912 return nullptr;
913
914 if (NodePtr CN = identifyDeinterleave(Real, Imag))
915 return CN;
916
917 if (NodePtr CN = identifyPHINode(Real, Imag))
918 return CN;
919
920 if (NodePtr CN = identifySelectNode(Real, Imag))
921 return CN;
922
923 auto *VTy = cast<VectorType>(Real->getType());
924 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
925
926 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
927 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
928 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
929 ComplexDeinterleavingOperation::CAdd, NewVTy);
930
931 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
932 if (NodePtr CN = identifyPartialMul(Real, Imag))
933 return CN;
934 }
935
936 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
937 if (NodePtr CN = identifyAdd(Real, Imag))
938 return CN;
939 }
940
941 if (HasCMulSupport && HasCAddSupport) {
942 if (NodePtr CN = identifyReassocNodes(Real, Imag))
943 return CN;
944 }
945
946 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
947 return CN;
948
949 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
950 CachedResult[{R, I}] = nullptr;
951 return nullptr;
952}
953
954ComplexDeinterleavingGraph::NodePtr
955ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
956 Instruction *Imag) {
957 auto IsOperationSupported = [](unsigned Opcode) -> bool {
958 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
959 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
960 Opcode == Instruction::Sub;
961 };
962
963 if (!IsOperationSupported(Real->getOpcode()) ||
964 !IsOperationSupported(Imag->getOpcode()))
965 return nullptr;
966
967 std::optional<FastMathFlags> Flags;
968 if (isa<FPMathOperator>(Real)) {
969 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
970 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
971 "not identical\n");
972 return nullptr;
973 }
974
975 Flags = Real->getFastMathFlags();
976 if (!Flags->allowReassoc()) {
978 dbgs()
979 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
980 return nullptr;
981 }
982 }
983
984 // Collect multiplications and addend instructions from the given instruction
985 // while traversing it operands. Additionally, verify that all instructions
986 // have the same fast math flags.
987 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
988 std::list<Addend> &Addends) -> bool {
991 while (!Worklist.empty()) {
992 auto [V, IsPositive] = Worklist.back();
993 Worklist.pop_back();
994 if (!Visited.insert(V).second)
995 continue;
996
997 Instruction *I = dyn_cast<Instruction>(V);
998 if (!I) {
999 Addends.emplace_back(V, IsPositive);
1000 continue;
1001 }
1002
1003 // If an instruction has more than one user, it indicates that it either
1004 // has an external user, which will be later checked by the checkNodes
1005 // function, or it is a subexpression utilized by multiple expressions. In
1006 // the latter case, we will attempt to separately identify the complex
1007 // operation from here in order to create a shared
1008 // ComplexDeinterleavingCompositeNode.
1009 if (I != Insn && I->getNumUses() > 1) {
1010 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1011 Addends.emplace_back(I, IsPositive);
1012 continue;
1013 }
1014 switch (I->getOpcode()) {
1015 case Instruction::FAdd:
1016 case Instruction::Add:
1017 Worklist.emplace_back(I->getOperand(1), IsPositive);
1018 Worklist.emplace_back(I->getOperand(0), IsPositive);
1019 break;
1020 case Instruction::FSub:
1021 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1022 Worklist.emplace_back(I->getOperand(0), IsPositive);
1023 break;
1024 case Instruction::Sub:
1025 if (isNeg(I)) {
1026 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1027 } else {
1028 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1029 Worklist.emplace_back(I->getOperand(0), IsPositive);
1030 }
1031 break;
1032 case Instruction::FMul:
1033 case Instruction::Mul: {
1034 Value *A, *B;
1035 if (isNeg(I->getOperand(0))) {
1036 A = getNegOperand(I->getOperand(0));
1037 IsPositive = !IsPositive;
1038 } else {
1039 A = I->getOperand(0);
1040 }
1041
1042 if (isNeg(I->getOperand(1))) {
1043 B = getNegOperand(I->getOperand(1));
1044 IsPositive = !IsPositive;
1045 } else {
1046 B = I->getOperand(1);
1047 }
1048 Muls.push_back(Product{A, B, IsPositive});
1049 break;
1050 }
1051 case Instruction::FNeg:
1052 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1053 break;
1054 default:
1055 Addends.emplace_back(I, IsPositive);
1056 continue;
1057 }
1058
1059 if (Flags && I->getFastMathFlags() != *Flags) {
1060 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1061 "inconsistent with the root instructions' flags: "
1062 << *I << "\n");
1063 return false;
1064 }
1065 }
1066 return true;
1067 };
1068
1069 std::vector<Product> RealMuls, ImagMuls;
1070 std::list<Addend> RealAddends, ImagAddends;
1071 if (!Collect(Real, RealMuls, RealAddends) ||
1072 !Collect(Imag, ImagMuls, ImagAddends))
1073 return nullptr;
1074
1075 if (RealAddends.size() != ImagAddends.size())
1076 return nullptr;
1077
1078 NodePtr FinalNode;
1079 if (!RealMuls.empty() || !ImagMuls.empty()) {
1080 // If there are multiplicands, extract positive addend and use it as an
1081 // accumulator
1082 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1083 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1084 if (!FinalNode)
1085 return nullptr;
1086 }
1087
1088 // Identify and process remaining additions
1089 if (!RealAddends.empty() || !ImagAddends.empty()) {
1090 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1091 if (!FinalNode)
1092 return nullptr;
1093 }
1094 assert(FinalNode && "FinalNode can not be nullptr here");
1095 // Set the Real and Imag fields of the final node and submit it
1096 FinalNode->Real = Real;
1097 FinalNode->Imag = Imag;
1098 submitCompositeNode(FinalNode);
1099 return FinalNode;
1100}
1101
1102bool ComplexDeinterleavingGraph::collectPartialMuls(
1103 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1104 std::vector<PartialMulCandidate> &PartialMulCandidates) {
1105 // Helper function to extract a common operand from two products
1106 auto FindCommonInstruction = [](const Product &Real,
1107 const Product &Imag) -> Value * {
1108 if (Real.Multiplicand == Imag.Multiplicand ||
1109 Real.Multiplicand == Imag.Multiplier)
1110 return Real.Multiplicand;
1111
1112 if (Real.Multiplier == Imag.Multiplicand ||
1113 Real.Multiplier == Imag.Multiplier)
1114 return Real.Multiplier;
1115
1116 return nullptr;
1117 };
1118
1119 // Iterating over real and imaginary multiplications to find common operands
1120 // If a common operand is found, a partial multiplication candidate is created
1121 // and added to the candidates vector The function returns false if no common
1122 // operands are found for any product
1123 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1124 bool FoundCommon = false;
1125 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1126 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1127 if (!Common)
1128 continue;
1129
1130 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1131 : RealMuls[i].Multiplicand;
1132 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1133 : ImagMuls[j].Multiplicand;
1134
1135 auto Node = identifyNode(A, B);
1136 if (Node) {
1137 FoundCommon = true;
1138 PartialMulCandidates.push_back({Common, Node, i, j, false});
1139 }
1140
1141 Node = identifyNode(B, A);
1142 if (Node) {
1143 FoundCommon = true;
1144 PartialMulCandidates.push_back({Common, Node, i, j, true});
1145 }
1146 }
1147 if (!FoundCommon)
1148 return false;
1149 }
1150 return true;
1151}
1152
1153ComplexDeinterleavingGraph::NodePtr
1154ComplexDeinterleavingGraph::identifyMultiplications(
1155 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1156 NodePtr Accumulator = nullptr) {
1157 if (RealMuls.size() != ImagMuls.size())
1158 return nullptr;
1159
1160 std::vector<PartialMulCandidate> Info;
1161 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1162 return nullptr;
1163
1164 // Map to store common instruction to node pointers
1165 std::map<Value *, NodePtr> CommonToNode;
1166 std::vector<bool> Processed(Info.size(), false);
1167 for (unsigned I = 0; I < Info.size(); ++I) {
1168 if (Processed[I])
1169 continue;
1170
1171 PartialMulCandidate &InfoA = Info[I];
1172 for (unsigned J = I + 1; J < Info.size(); ++J) {
1173 if (Processed[J])
1174 continue;
1175
1176 PartialMulCandidate &InfoB = Info[J];
1177 auto *InfoReal = &InfoA;
1178 auto *InfoImag = &InfoB;
1179
1180 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1181 if (!NodeFromCommon) {
1182 std::swap(InfoReal, InfoImag);
1183 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1184 }
1185 if (!NodeFromCommon)
1186 continue;
1187
1188 CommonToNode[InfoReal->Common] = NodeFromCommon;
1189 CommonToNode[InfoImag->Common] = NodeFromCommon;
1190 Processed[I] = true;
1191 Processed[J] = true;
1192 }
1193 }
1194
1195 std::vector<bool> ProcessedReal(RealMuls.size(), false);
1196 std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1197 NodePtr Result = Accumulator;
1198 for (auto &PMI : Info) {
1199 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1200 continue;
1201
1202 auto It = CommonToNode.find(PMI.Common);
1203 // TODO: Process independent complex multiplications. Cases like this:
1204 // A.real() * B where both A and B are complex numbers.
1205 if (It == CommonToNode.end()) {
1206 LLVM_DEBUG({
1207 dbgs() << "Unprocessed independent partial multiplication:\n";
1208 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1209 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1210 << " multiplied by " << *Mul->Multiplicand << "\n";
1211 });
1212 return nullptr;
1213 }
1214
1215 auto &RealMul = RealMuls[PMI.RealIdx];
1216 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1217
1218 auto NodeA = It->second;
1219 auto NodeB = PMI.Node;
1220 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1221 // The following table illustrates the relationship between multiplications
1222 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1223 // can see:
1224 //
1225 // Rotation | Real | Imag |
1226 // ---------+--------+--------+
1227 // 0 | x * u | x * v |
1228 // 90 | -y * v | y * u |
1229 // 180 | -x * u | -x * v |
1230 // 270 | y * v | -y * u |
1231 //
1232 // Check if the candidate can indeed be represented by partial
1233 // multiplication
1234 // TODO: Add support for multiplication by complex one
1235 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1236 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1237 continue;
1238
1239 // Determine the rotation based on the multiplications
1241 if (IsMultiplicandReal) {
1242 // Detect 0 and 180 degrees rotation
1243 if (RealMul.IsPositive && ImagMul.IsPositive)
1245 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1247 else
1248 continue;
1249
1250 } else {
1251 // Detect 90 and 270 degrees rotation
1252 if (!RealMul.IsPositive && ImagMul.IsPositive)
1254 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1256 else
1257 continue;
1258 }
1259
1260 LLVM_DEBUG({
1261 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1262 dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1263 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1264 dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1265 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1266 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1267 });
1268
1269 NodePtr NodeMul = prepareCompositeNode(
1270 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1271 NodeMul->Rotation = Rotation;
1272 NodeMul->addOperand(NodeA);
1273 NodeMul->addOperand(NodeB);
1274 if (Result)
1275 NodeMul->addOperand(Result);
1276 submitCompositeNode(NodeMul);
1277 Result = NodeMul;
1278 ProcessedReal[PMI.RealIdx] = true;
1279 ProcessedImag[PMI.ImagIdx] = true;
1280 }
1281
1282 // Ensure all products have been processed, if not return nullptr.
1283 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1284 !all_of(ProcessedImag, [](bool V) { return V; })) {
1285
1286 // Dump debug information about which partial multiplications are not
1287 // processed.
1288 LLVM_DEBUG({
1289 dbgs() << "Unprocessed products (Real):\n";
1290 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1291 if (!ProcessedReal[i])
1292 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1293 << *RealMuls[i].Multiplier << " multiplied by "
1294 << *RealMuls[i].Multiplicand << "\n";
1295 }
1296 dbgs() << "Unprocessed products (Imag):\n";
1297 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1298 if (!ProcessedImag[i])
1299 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1300 << *ImagMuls[i].Multiplier << " multiplied by "
1301 << *ImagMuls[i].Multiplicand << "\n";
1302 }
1303 });
1304 return nullptr;
1305 }
1306
1307 return Result;
1308}
1309
1310ComplexDeinterleavingGraph::NodePtr
1311ComplexDeinterleavingGraph::identifyAdditions(
1312 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1313 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1314 if (RealAddends.size() != ImagAddends.size())
1315 return nullptr;
1316
1317 NodePtr Result;
1318 // If we have accumulator use it as first addend
1319 if (Accumulator)
1321 // Otherwise find an element with both positive real and imaginary parts.
1322 else
1323 Result = extractPositiveAddend(RealAddends, ImagAddends);
1324
1325 if (!Result)
1326 return nullptr;
1327
1328 while (!RealAddends.empty()) {
1329 auto ItR = RealAddends.begin();
1330 auto [R, IsPositiveR] = *ItR;
1331
1332 bool FoundImag = false;
1333 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1334 auto [I, IsPositiveI] = *ItI;
1336 if (IsPositiveR && IsPositiveI)
1337 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1338 else if (!IsPositiveR && IsPositiveI)
1339 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1340 else if (!IsPositiveR && !IsPositiveI)
1341 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1342 else
1343 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1344
1345 NodePtr AddNode;
1346 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1347 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1348 AddNode = identifyNode(R, I);
1349 } else {
1350 AddNode = identifyNode(I, R);
1351 }
1352 if (AddNode) {
1353 LLVM_DEBUG({
1354 dbgs() << "Identified addition:\n";
1355 dbgs().indent(4) << "X: " << *R << "\n";
1356 dbgs().indent(4) << "Y: " << *I << "\n";
1357 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1358 });
1359
1360 NodePtr TmpNode;
1362 TmpNode = prepareCompositeNode(
1363 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1364 if (Flags) {
1365 TmpNode->Opcode = Instruction::FAdd;
1366 TmpNode->Flags = *Flags;
1367 } else {
1368 TmpNode->Opcode = Instruction::Add;
1369 }
1370 } else if (Rotation ==
1372 TmpNode = prepareCompositeNode(
1373 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1374 if (Flags) {
1375 TmpNode->Opcode = Instruction::FSub;
1376 TmpNode->Flags = *Flags;
1377 } else {
1378 TmpNode->Opcode = Instruction::Sub;
1379 }
1380 } else {
1381 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1382 nullptr, nullptr);
1383 TmpNode->Rotation = Rotation;
1384 }
1385
1386 TmpNode->addOperand(Result);
1387 TmpNode->addOperand(AddNode);
1388 submitCompositeNode(TmpNode);
1389 Result = TmpNode;
1390 RealAddends.erase(ItR);
1391 ImagAddends.erase(ItI);
1392 FoundImag = true;
1393 break;
1394 }
1395 }
1396 if (!FoundImag)
1397 return nullptr;
1398 }
1399 return Result;
1400}
1401
1402ComplexDeinterleavingGraph::NodePtr
1403ComplexDeinterleavingGraph::extractPositiveAddend(
1404 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1405 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1406 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1407 auto [R, IsPositiveR] = *ItR;
1408 auto [I, IsPositiveI] = *ItI;
1409 if (IsPositiveR && IsPositiveI) {
1410 auto Result = identifyNode(R, I);
1411 if (Result) {
1412 RealAddends.erase(ItR);
1413 ImagAddends.erase(ItI);
1414 return Result;
1415 }
1416 }
1417 }
1418 }
1419 return nullptr;
1420}
1421
1422bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1423 // This potential root instruction might already have been recognized as
1424 // reduction. Because RootToNode maps both Real and Imaginary parts to
1425 // CompositeNode we should choose only one either Real or Imag instruction to
1426 // use as an anchor for generating complex instruction.
1427 auto It = RootToNode.find(RootI);
1428 if (It != RootToNode.end()) {
1429 auto RootNode = It->second;
1430 assert(RootNode->Operation ==
1431 ComplexDeinterleavingOperation::ReductionOperation);
1432 // Find out which part, Real or Imag, comes later, and only if we come to
1433 // the latest part, add it to OrderedRoots.
1434 auto *R = cast<Instruction>(RootNode->Real);
1435 auto *I = cast<Instruction>(RootNode->Imag);
1436 auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437 if (ReplacementAnchor != RootI)
1438 return false;
1439 OrderedRoots.push_back(RootI);
1440 return true;
1441 }
1442
1443 auto RootNode = identifyRoot(RootI);
1444 if (!RootNode)
1445 return false;
1446
1447 LLVM_DEBUG({
1448 Function *F = RootI->getFunction();
1449 BasicBlock *B = RootI->getParent();
1450 dbgs() << "Complex deinterleaving graph for " << F->getName()
1451 << "::" << B->getName() << ".\n";
1452 dump(dbgs());
1453 dbgs() << "\n";
1454 });
1455 RootToNode[RootI] = RootNode;
1456 OrderedRoots.push_back(RootI);
1457 return true;
1458}
1459
1460bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1461 bool FoundPotentialReduction = false;
1462
1463 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1464 if (!Br || Br->getNumSuccessors() != 2)
1465 return false;
1466
1467 // Identify simple one-block loop
1468 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1469 return false;
1470
1472 for (auto &PHI : B->phis()) {
1473 if (PHI.getNumIncomingValues() != 2)
1474 continue;
1475
1476 if (!PHI.getType()->isVectorTy())
1477 continue;
1478
1479 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1480 if (!ReductionOp)
1481 continue;
1482
1483 // Check if final instruction is reduced outside of current block
1484 Instruction *FinalReduction = nullptr;
1485 auto NumUsers = 0u;
1486 for (auto *U : ReductionOp->users()) {
1487 ++NumUsers;
1488 if (U == &PHI)
1489 continue;
1490 FinalReduction = dyn_cast<Instruction>(U);
1491 }
1492
1493 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1494 isa<PHINode>(FinalReduction))
1495 continue;
1496
1497 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1498 BackEdge = B;
1499 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1500 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1501 Incoming = PHI.getIncomingBlock(IncomingIdx);
1502 FoundPotentialReduction = true;
1503
1504 // If the initial value of PHINode is an Instruction, consider it a leaf
1505 // value of a complex deinterleaving graph.
1506 if (auto *InitPHI =
1507 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1508 FinalInstructions.insert(InitPHI);
1509 }
1510 return FoundPotentialReduction;
1511}
1512
1513void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514 SmallVector<bool> Processed(ReductionInfo.size(), false);
1515 SmallVector<Instruction *> OperationInstruction;
1516 for (auto &P : ReductionInfo)
1517 OperationInstruction.push_back(P.first);
1518
1519 // Identify a complex computation by evaluating two reduction operations that
1520 // potentially could be involved
1521 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1522 if (Processed[i])
1523 continue;
1524 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1525 if (Processed[j])
1526 continue;
1527
1528 auto *Real = OperationInstruction[i];
1529 auto *Imag = OperationInstruction[j];
1530 if (Real->getType() != Imag->getType())
1531 continue;
1532
1533 RealPHI = ReductionInfo[Real].first;
1534 ImagPHI = ReductionInfo[Imag].first;
1535 PHIsFound = false;
1536 auto Node = identifyNode(Real, Imag);
1537 if (!Node) {
1538 std::swap(Real, Imag);
1539 std::swap(RealPHI, ImagPHI);
1540 Node = identifyNode(Real, Imag);
1541 }
1542
1543 // If a node is identified and reduction PHINode is used in the chain of
1544 // operations, mark its operation instructions as used to prevent
1545 // re-identification and attach the node to the real part
1546 if (Node && PHIsFound) {
1547 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548 << *Real << " / " << *Imag << "\n");
1549 Processed[i] = true;
1550 Processed[j] = true;
1551 auto RootNode = prepareCompositeNode(
1552 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1553 RootNode->addOperand(Node);
1554 RootToNode[Real] = RootNode;
1555 RootToNode[Imag] = RootNode;
1556 submitCompositeNode(RootNode);
1557 break;
1558 }
1559 }
1560 }
1561
1562 RealPHI = nullptr;
1563 ImagPHI = nullptr;
1564}
1565
1566bool ComplexDeinterleavingGraph::checkNodes() {
1567 // Collect all instructions from roots to leaves
1568 SmallPtrSet<Instruction *, 16> AllInstructions;
1570 for (auto &Pair : RootToNode)
1571 Worklist.push_back(Pair.first);
1572
1573 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1574 // chains
1575 while (!Worklist.empty()) {
1576 auto *I = Worklist.back();
1577 Worklist.pop_back();
1578
1579 if (!AllInstructions.insert(I).second)
1580 continue;
1581
1582 for (Value *Op : I->operands()) {
1583 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1584 if (!FinalInstructions.count(I))
1585 Worklist.emplace_back(OpI);
1586 }
1587 }
1588 }
1589
1590 // Find instructions that have users outside of chain
1591 SmallVector<Instruction *, 2> OuterInstructions;
1592 for (auto *I : AllInstructions) {
1593 // Skip root nodes
1594 if (RootToNode.count(I))
1595 continue;
1596
1597 for (User *U : I->users()) {
1598 if (AllInstructions.count(cast<Instruction>(U)))
1599 continue;
1600
1601 // Found an instruction that is not used by XCMLA/XCADD chain
1602 Worklist.emplace_back(I);
1603 break;
1604 }
1605 }
1606
1607 // If any instructions are found to be used outside, find and remove roots
1608 // that somehow connect to those instructions.
1610 while (!Worklist.empty()) {
1611 auto *I = Worklist.back();
1612 Worklist.pop_back();
1613 if (!Visited.insert(I).second)
1614 continue;
1615
1616 // Found an impacted root node. Removing it from the nodes to be
1617 // deinterleaved
1618 if (RootToNode.count(I)) {
1619 LLVM_DEBUG(dbgs() << "Instruction " << *I
1620 << " could be deinterleaved but its chain of complex "
1621 "operations have an outside user\n");
1622 RootToNode.erase(I);
1623 }
1624
1625 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1626 continue;
1627
1628 for (User *U : I->users())
1629 Worklist.emplace_back(cast<Instruction>(U));
1630
1631 for (Value *Op : I->operands()) {
1632 if (auto *OpI = dyn_cast<Instruction>(Op))
1633 Worklist.emplace_back(OpI);
1634 }
1635 }
1636 return !RootToNode.empty();
1637}
1638
1639ComplexDeinterleavingGraph::NodePtr
1640ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1641 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642 if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
1643 return nullptr;
1644
1645 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1646 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1647 if (!Real || !Imag)
1648 return nullptr;
1649
1650 return identifyNode(Real, Imag);
1651 }
1652
1653 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1654 if (!SVI)
1655 return nullptr;
1656
1657 // Look for a shufflevector that takes separate vectors of the real and
1658 // imaginary components and recombines them into a single vector.
1659 if (!isInterleavingMask(SVI->getShuffleMask()))
1660 return nullptr;
1661
1662 Instruction *Real;
1663 Instruction *Imag;
1664 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1665 return nullptr;
1666
1667 return identifyNode(Real, Imag);
1668}
1669
1670ComplexDeinterleavingGraph::NodePtr
1671ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1672 Instruction *Imag) {
1673 Instruction *I = nullptr;
1674 Value *FinalValue = nullptr;
1675 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1676 match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1677 match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
1678 m_Value(FinalValue)))) {
1679 NodePtr PlaceholderNode = prepareCompositeNode(
1681 PlaceholderNode->ReplacementNode = FinalValue;
1682 FinalInstructions.insert(Real);
1683 FinalInstructions.insert(Imag);
1684 return submitCompositeNode(PlaceholderNode);
1685 }
1686
1687 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1688 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1689 if (!RealShuffle || !ImagShuffle) {
1690 if (RealShuffle || ImagShuffle)
1691 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1692 return nullptr;
1693 }
1694
1695 Value *RealOp1 = RealShuffle->getOperand(1);
1696 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1697 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1698 return nullptr;
1699 }
1700 Value *ImagOp1 = ImagShuffle->getOperand(1);
1701 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1702 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1703 return nullptr;
1704 }
1705
1706 Value *RealOp0 = RealShuffle->getOperand(0);
1707 Value *ImagOp0 = ImagShuffle->getOperand(0);
1708
1709 if (RealOp0 != ImagOp0) {
1710 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1711 return nullptr;
1712 }
1713
1714 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1715 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1716 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1717 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1718 return nullptr;
1719 }
1720
1721 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1722 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1723 return nullptr;
1724 }
1725
1726 // Type checking, the shuffle type should be a vector type of the same
1727 // scalar type, but half the size
1728 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1729 Value *Op = Shuffle->getOperand(0);
1730 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1731 auto *OpTy = cast<FixedVectorType>(Op->getType());
1732
1733 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1734 return false;
1735 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1736 return false;
1737
1738 return true;
1739 };
1740
1741 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1742 if (!CheckType(Shuffle))
1743 return false;
1744
1745 ArrayRef<int> Mask = Shuffle->getShuffleMask();
1746 int Last = *Mask.rbegin();
1747
1748 Value *Op = Shuffle->getOperand(0);
1749 auto *OpTy = cast<FixedVectorType>(Op->getType());
1750 int NumElements = OpTy->getNumElements();
1751
1752 // Ensure that the deinterleaving shuffle only pulls from the first
1753 // shuffle operand.
1754 return Last < NumElements;
1755 };
1756
1757 if (RealShuffle->getType() != ImagShuffle->getType()) {
1758 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1759 return nullptr;
1760 }
1761 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1762 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1763 return nullptr;
1764 }
1765 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1766 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1767 return nullptr;
1768 }
1769
1770 NodePtr PlaceholderNode =
1772 RealShuffle, ImagShuffle);
1773 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1774 FinalInstructions.insert(RealShuffle);
1775 FinalInstructions.insert(ImagShuffle);
1776 return submitCompositeNode(PlaceholderNode);
1777}
1778
1779ComplexDeinterleavingGraph::NodePtr
1780ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1781 auto IsSplat = [](Value *V) -> bool {
1782 // Fixed-width vector with constants
1783 if (isa<ConstantDataVector>(V))
1784 return true;
1785
1786 VectorType *VTy;
1788 // Splats are represented differently depending on whether the repeated
1789 // value is a constant or an Instruction
1790 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1791 if (Const->getOpcode() != Instruction::ShuffleVector)
1792 return false;
1793 VTy = cast<VectorType>(Const->getType());
1794 Mask = Const->getShuffleMask();
1795 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1796 VTy = Shuf->getType();
1797 Mask = Shuf->getShuffleMask();
1798 } else {
1799 return false;
1800 }
1801
1802 // When the data type is <1 x Type>, it's not possible to differentiate
1803 // between the ComplexDeinterleaving::Deinterleave and
1804 // ComplexDeinterleaving::Splat operations.
1805 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1806 return false;
1807
1808 return all_equal(Mask) && Mask[0] == 0;
1809 };
1810
1811 if (!IsSplat(R) || !IsSplat(I))
1812 return nullptr;
1813
1814 auto *Real = dyn_cast<Instruction>(R);
1815 auto *Imag = dyn_cast<Instruction>(I);
1816 if ((!Real && Imag) || (Real && !Imag))
1817 return nullptr;
1818
1819 if (Real && Imag) {
1820 // Non-constant splats should be in the same basic block
1821 if (Real->getParent() != Imag->getParent())
1822 return nullptr;
1823
1824 FinalInstructions.insert(Real);
1825 FinalInstructions.insert(Imag);
1826 }
1827 NodePtr PlaceholderNode =
1828 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1829 return submitCompositeNode(PlaceholderNode);
1830}
1831
1832ComplexDeinterleavingGraph::NodePtr
1833ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1834 Instruction *Imag) {
1835 if (Real != RealPHI || Imag != ImagPHI)
1836 return nullptr;
1837
1838 PHIsFound = true;
1839 NodePtr PlaceholderNode = prepareCompositeNode(
1840 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1841 return submitCompositeNode(PlaceholderNode);
1842}
1843
1844ComplexDeinterleavingGraph::NodePtr
1845ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1846 Instruction *Imag) {
1847 auto *SelectReal = dyn_cast<SelectInst>(Real);
1848 auto *SelectImag = dyn_cast<SelectInst>(Imag);
1849 if (!SelectReal || !SelectImag)
1850 return nullptr;
1851
1852 Instruction *MaskA, *MaskB;
1853 Instruction *AR, *AI, *RA, *BI;
1854 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1855 m_Instruction(RA))) ||
1856 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1857 m_Instruction(BI))))
1858 return nullptr;
1859
1860 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1861 return nullptr;
1862
1863 if (!MaskA->getType()->isVectorTy())
1864 return nullptr;
1865
1866 auto NodeA = identifyNode(AR, AI);
1867 if (!NodeA)
1868 return nullptr;
1869
1870 auto NodeB = identifyNode(RA, BI);
1871 if (!NodeB)
1872 return nullptr;
1873
1874 NodePtr PlaceholderNode = prepareCompositeNode(
1875 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1876 PlaceholderNode->addOperand(NodeA);
1877 PlaceholderNode->addOperand(NodeB);
1878 FinalInstructions.insert(MaskA);
1879 FinalInstructions.insert(MaskB);
1880 return submitCompositeNode(PlaceholderNode);
1881}
1882
1883static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1884 std::optional<FastMathFlags> Flags,
1885 Value *InputA, Value *InputB) {
1886 Value *I;
1887 switch (Opcode) {
1888 case Instruction::FNeg:
1889 I = B.CreateFNeg(InputA);
1890 break;
1891 case Instruction::FAdd:
1892 I = B.CreateFAdd(InputA, InputB);
1893 break;
1894 case Instruction::Add:
1895 I = B.CreateAdd(InputA, InputB);
1896 break;
1897 case Instruction::FSub:
1898 I = B.CreateFSub(InputA, InputB);
1899 break;
1900 case Instruction::Sub:
1901 I = B.CreateSub(InputA, InputB);
1902 break;
1903 case Instruction::FMul:
1904 I = B.CreateFMul(InputA, InputB);
1905 break;
1906 case Instruction::Mul:
1907 I = B.CreateMul(InputA, InputB);
1908 break;
1909 default:
1910 llvm_unreachable("Incorrect symmetric opcode");
1911 }
1912 if (Flags)
1913 cast<Instruction>(I)->setFastMathFlags(*Flags);
1914 return I;
1915}
1916
1917Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1918 RawNodePtr Node) {
1919 if (Node->ReplacementNode)
1920 return Node->ReplacementNode;
1921
1922 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1923 return Node->Operands.size() > Idx
1924 ? replaceNode(Builder, Node->Operands[Idx])
1925 : nullptr;
1926 };
1927
1928 Value *ReplacementNode;
1929 switch (Node->Operation) {
1930 case ComplexDeinterleavingOperation::CAdd:
1931 case ComplexDeinterleavingOperation::CMulPartial:
1932 case ComplexDeinterleavingOperation::Symmetric: {
1933 Value *Input0 = ReplaceOperandIfExist(Node, 0);
1934 Value *Input1 = ReplaceOperandIfExist(Node, 1);
1935 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1936 assert(!Input1 || (Input0->getType() == Input1->getType() &&
1937 "Node inputs need to be of the same type"));
1939 (Input0->getType() == Accumulator->getType() &&
1940 "Accumulator and input need to be of the same type"));
1941 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1942 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1943 Input0, Input1);
1944 else
1945 ReplacementNode = TL->createComplexDeinterleavingIR(
1946 Builder, Node->Operation, Node->Rotation, Input0, Input1,
1947 Accumulator);
1948 break;
1949 }
1950 case ComplexDeinterleavingOperation::Deinterleave:
1951 llvm_unreachable("Deinterleave node should already have ReplacementNode");
1952 break;
1953 case ComplexDeinterleavingOperation::Splat: {
1954 auto *NewTy = VectorType::getDoubleElementsVectorType(
1955 cast<VectorType>(Node->Real->getType()));
1956 auto *R = dyn_cast<Instruction>(Node->Real);
1957 auto *I = dyn_cast<Instruction>(Node->Imag);
1958 if (R && I) {
1959 // Splats that are not constant are interleaved where they are located
1960 Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1961 IRBuilder<> IRB(InsertPoint);
1962 ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
1963 NewTy, {Node->Real, Node->Imag});
1964 } else {
1965 ReplacementNode = Builder.CreateIntrinsic(
1966 Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
1967 }
1968 break;
1969 }
1970 case ComplexDeinterleavingOperation::ReductionPHI: {
1971 // If Operation is ReductionPHI, a new empty PHINode is created.
1972 // It is filled later when the ReductionOperation is processed.
1973 auto *VTy = cast<VectorType>(Node->Real->getType());
1974 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1975 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
1976 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1977 ReplacementNode = NewPHI;
1978 break;
1979 }
1980 case ComplexDeinterleavingOperation::ReductionOperation:
1981 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1982 processReductionOperation(ReplacementNode, Node);
1983 break;
1984 case ComplexDeinterleavingOperation::ReductionSelect: {
1985 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1986 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1987 auto *A = replaceNode(Builder, Node->Operands[0]);
1988 auto *B = replaceNode(Builder, Node->Operands[1]);
1989 auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1990 cast<VectorType>(MaskReal->getType()));
1991 auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
1992 NewMaskTy, {MaskReal, MaskImag});
1993 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1994 break;
1995 }
1996 }
1997
1998 assert(ReplacementNode && "Target failed to create Intrinsic call.");
1999 NumComplexTransformations += 1;
2000 Node->ReplacementNode = ReplacementNode;
2001 return ReplacementNode;
2002}
2003
2004void ComplexDeinterleavingGraph::processReductionOperation(
2005 Value *OperationReplacement, RawNodePtr Node) {
2006 auto *Real = cast<Instruction>(Node->Real);
2007 auto *Imag = cast<Instruction>(Node->Imag);
2008 auto *OldPHIReal = ReductionInfo[Real].first;
2009 auto *OldPHIImag = ReductionInfo[Imag].first;
2010 auto *NewPHI = OldToNewPHI[OldPHIReal];
2011
2012 auto *VTy = cast<VectorType>(Real->getType());
2013 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2014
2015 // We have to interleave initial origin values coming from IncomingBlock
2016 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2017 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2018
2019 IRBuilder<> Builder(Incoming->getTerminator());
2020 auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2021 {InitReal, InitImag});
2022
2023 NewPHI->addIncoming(NewInit, Incoming);
2024 NewPHI->addIncoming(OperationReplacement, BackEdge);
2025
2026 // Deinterleave complex vector outside of loop so that it can be finally
2027 // reduced
2028 auto *FinalReductionReal = ReductionInfo[Real].second;
2029 auto *FinalReductionImag = ReductionInfo[Imag].second;
2030
2031 Builder.SetInsertPoint(
2032 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2033 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2034 OperationReplacement->getType(),
2035 OperationReplacement);
2036
2037 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2038 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2039
2040 Builder.SetInsertPoint(FinalReductionImag);
2041 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2042 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2043}
2044
2045void ComplexDeinterleavingGraph::replaceNodes() {
2046 SmallVector<Instruction *, 16> DeadInstrRoots;
2047 for (auto *RootInstruction : OrderedRoots) {
2048 // Check if this potential root went through check process and we can
2049 // deinterleave it
2050 if (!RootToNode.count(RootInstruction))
2051 continue;
2052
2053 IRBuilder<> Builder(RootInstruction);
2054 auto RootNode = RootToNode[RootInstruction];
2055 Value *R = replaceNode(Builder, RootNode.get());
2056
2057 if (RootNode->Operation ==
2058 ComplexDeinterleavingOperation::ReductionOperation) {
2059 auto *RootReal = cast<Instruction>(RootNode->Real);
2060 auto *RootImag = cast<Instruction>(RootNode->Imag);
2061 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2062 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2063 DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2064 DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2065 } else {
2066 assert(R && "Unable to find replacement for RootInstruction");
2067 DeadInstrRoots.push_back(RootInstruction);
2068 RootInstruction->replaceAllUsesWith(R);
2069 }
2070 }
2071
2072 for (auto *I : DeadInstrRoots)
2074}
SmallVector< AArch64_IMM::ImmInsnModel, 4 > Insn
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
VarLocInsertPt getNextNode(const DbgRecord *DVR)
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
Complex Deinterleaving
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
static bool isInstructionPairMul(Instruction *A, Instruction *B)
#define DEBUG_TYPE
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
This file implements a map that provides insertion order iteration.
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
raw_pwrite_stream & OS
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
DEMANGLE_DUMP_METHOD void dump() const
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:256
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
InstListType::const_iterator getFirstNonPHIIt() const
Iterator returning form of getFirstNonPHI.
Definition: BasicBlock.cpp:374
This class represents an Operation in the Expression.
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:155
iterator end()
Definition: DenseMap.h:84
bool allowContract() const
Definition: FMF.h:70
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:91
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition: IRBuilder.h:2536
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:933
Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Definition: IRBuilder.cpp:1091
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2686
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:563
bool isBinaryOp() const
Definition: Instruction.h:279
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:70
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:274
bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:36
size_type size() const
Definition: MapVector.h:60
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:131
This instruction constructs a fixed permutation of two input vectors.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:435
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:367
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:502
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:77
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:261
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:121
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ BR
Control flow instructions. These all have token chains.
Definition: ISDOpcodes.h:1115
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:100
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
Definition: PatternMatch.h:816
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:875
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition: DWP.cpp:480
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:540
void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
DWARFExpression::Operation Op
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
Definition: STLExtras.h:2070
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...