LLVM 20.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
63#include "llvm/ADT/MapVector.h"
64#include "llvm/ADT/Statistic.h"
69#include "llvm/IR/IRBuilder.h"
74#include <algorithm>
75
76using namespace llvm;
77using namespace PatternMatch;
78
79#define DEBUG_TYPE "complex-deinterleaving"
80
81STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82
84 "enable-complex-deinterleaving",
85 cl::desc("Enable generation of complex instructions"), cl::init(true),
87
88/// Checks the given mask, and determines whether said mask is interleaving.
89///
90/// To be interleaving, a mask must alternate between `i` and `i + (Length /
91/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92/// 4x vector interleaving mask would be <0, 2, 1, 3>).
93static bool isInterleavingMask(ArrayRef<int> Mask);
94
95/// Checks the given mask, and determines whether said mask is deinterleaving.
96///
97/// To be deinterleaving, a mask must increment in steps of 2, and either start
98/// with 0 or 1.
99/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100/// <1, 3, 5, 7>).
101static bool isDeinterleavingMask(ArrayRef<int> Mask);
102
103/// Returns true if the operation is a negation of V, and it works for both
104/// integers and floats.
105static bool isNeg(Value *V);
106
107/// Returns the operand for negation operation.
108static Value *getNegOperand(Value *V);
109
110namespace {
111
112class ComplexDeinterleavingLegacyPass : public FunctionPass {
113public:
114 static char ID;
115
116 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
117 : FunctionPass(ID), TM(TM) {
120 }
121
122 StringRef getPassName() const override {
123 return "Complex Deinterleaving Pass";
124 }
125
126 bool runOnFunction(Function &F) override;
127 void getAnalysisUsage(AnalysisUsage &AU) const override {
129 AU.setPreservesCFG();
130 }
131
132private:
133 const TargetMachine *TM;
134};
135
136class ComplexDeinterleavingGraph;
137struct ComplexDeinterleavingCompositeNode {
138
139 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
140 Value *R, Value *I)
141 : Operation(Op), Real(R), Imag(I) {}
142
143private:
144 friend class ComplexDeinterleavingGraph;
145 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
146 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
147
148public:
150 Value *Real;
151 Value *Imag;
152
153 // These two members are required exclusively for generating
154 // ComplexDeinterleavingOperation::Symmetric operations.
155 unsigned Opcode;
156 std::optional<FastMathFlags> Flags;
157
159 ComplexDeinterleavingRotation::Rotation_0;
161 Value *ReplacementNode = nullptr;
162
163 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
164
165 void dump() { dump(dbgs()); }
166 void dump(raw_ostream &OS) {
167 auto PrintValue = [&](Value *V) {
168 if (V) {
169 OS << "\"";
170 V->print(OS, true);
171 OS << "\"\n";
172 } else
173 OS << "nullptr\n";
174 };
175 auto PrintNodeRef = [&](RawNodePtr Ptr) {
176 if (Ptr)
177 OS << Ptr << "\n";
178 else
179 OS << "nullptr\n";
180 };
181
182 OS << "- CompositeNode: " << this << "\n";
183 OS << " Real: ";
184 PrintValue(Real);
185 OS << " Imag: ";
186 PrintValue(Imag);
187 OS << " ReplacementNode: ";
188 PrintValue(ReplacementNode);
189 OS << " Operation: " << (int)Operation << "\n";
190 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
191 OS << " Operands: \n";
192 for (const auto &Op : Operands) {
193 OS << " - ";
194 PrintNodeRef(Op);
195 }
196 }
197};
198
199class ComplexDeinterleavingGraph {
200public:
201 struct Product {
202 Value *Multiplier;
203 Value *Multiplicand;
204 bool IsPositive;
205 };
206
207 using Addend = std::pair<Value *, bool>;
208 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
209 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
210
211 // Helper struct for holding info about potential partial multiplication
212 // candidates
213 struct PartialMulCandidate {
214 Value *Common;
215 NodePtr Node;
216 unsigned RealIdx;
217 unsigned ImagIdx;
218 bool IsNodeInverted;
219 };
220
221 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
222 const TargetLibraryInfo *TLI)
223 : TL(TL), TLI(TLI) {}
224
225private:
226 const TargetLowering *TL = nullptr;
227 const TargetLibraryInfo *TLI = nullptr;
228 SmallVector<NodePtr> CompositeNodes;
229 DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
230
231 SmallPtrSet<Instruction *, 16> FinalInstructions;
232
233 /// Root instructions are instructions from which complex computation starts
234 std::map<Instruction *, NodePtr> RootToNode;
235
236 /// Topologically sorted root instructions
238
239 /// When examining a basic block for complex deinterleaving, if it is a simple
240 /// one-block loop, then the only incoming block is 'Incoming' and the
241 /// 'BackEdge' block is the block itself.
242 BasicBlock *BackEdge = nullptr;
243 BasicBlock *Incoming = nullptr;
244
245 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
246 /// %OutsideUser as it is shown in the IR:
247 ///
248 /// vector.body:
249 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
250 /// [ %ReductionOp, %vector.body ]
251 /// ...
252 /// %ReductionOp = fadd i64 ...
253 /// ...
254 /// br i1 %condition, label %vector.body, %middle.block
255 ///
256 /// middle.block:
257 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
258 ///
259 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
260 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
262
263 /// In the process of detecting a reduction, we consider a pair of
264 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
265 /// traverse the use-tree to detect complex operations. As this is a reduction
266 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
267 /// to the %ReductionOPs that we suspect to be complex.
268 /// RealPHI and ImagPHI are used by the identifyPHINode method.
269 PHINode *RealPHI = nullptr;
270 PHINode *ImagPHI = nullptr;
271
272 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
273 /// detection.
274 bool PHIsFound = false;
275
276 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
277 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
278 /// This mapping is populated during
279 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
280 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
281 /// replacement process.
282 std::map<PHINode *, PHINode *> OldToNewPHI;
283
284 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
285 Value *R, Value *I) {
286 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
287 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
288 (R && I)) &&
289 "Reduction related nodes must have Real and Imaginary parts");
290 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
291 I);
292 }
293
294 NodePtr submitCompositeNode(NodePtr Node) {
295 CompositeNodes.push_back(Node);
296 if (Node->Real && Node->Imag)
297 CachedResult[{Node->Real, Node->Imag}] = Node;
298 return Node;
299 }
300
301 /// Identifies a complex partial multiply pattern and its rotation, based on
302 /// the following patterns
303 ///
304 /// 0: r: cr + ar * br
305 /// i: ci + ar * bi
306 /// 90: r: cr - ai * bi
307 /// i: ci + ai * br
308 /// 180: r: cr - ar * br
309 /// i: ci - ar * bi
310 /// 270: r: cr + ai * bi
311 /// i: ci - ai * br
312 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
313
314 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
315 /// is partially known from identifyPartialMul, filling in the other half of
316 /// the complex pair.
317 NodePtr
318 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
319 std::pair<Value *, Value *> &CommonOperandI);
320
321 /// Identifies a complex add pattern and its rotation, based on the following
322 /// patterns.
323 ///
324 /// 90: r: ar - bi
325 /// i: ai + br
326 /// 270: r: ar + bi
327 /// i: ai - br
328 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
329 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
330
331 NodePtr identifyNode(Value *R, Value *I);
332
333 /// Determine if a sum of complex numbers can be formed from \p RealAddends
334 /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
335 /// Return nullptr if it is not possible to construct a complex number.
336 /// \p Flags are needed to generate symmetric Add and Sub operations.
337 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
338 std::list<Addend> &ImagAddends,
339 std::optional<FastMathFlags> Flags,
340 NodePtr Accumulator);
341
342 /// Extract one addend that has both real and imaginary parts positive.
343 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
344 std::list<Addend> &ImagAddends);
345
346 /// Determine if sum of multiplications of complex numbers can be formed from
347 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
348 /// to it. Return nullptr if it is not possible to construct a complex number.
349 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
350 std::vector<Product> &ImagMuls,
351 NodePtr Accumulator);
352
353 /// Go through pairs of multiplication (one Real and one Imag) and find all
354 /// possible candidates for partial multiplication and put them into \p
355 /// Candidates. Returns true if every Product has a pair with a common operand.
356 bool collectPartialMuls(const std::vector<Product> &RealMuls,
357 const std::vector<Product> &ImagMuls,
358 std::vector<PartialMulCandidate> &Candidates);
359
360 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
361 /// the order of complex computation operations may be significantly altered,
362 /// and the real and imaginary parts may not be executed in parallel. This
363 /// function takes this into consideration and employs a more general approach
364 /// to identify complex computations. Initially, it gathers all the addends
365 /// and multiplicands and then constructs a complex expression from them.
366 NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
367
368 NodePtr identifyRoot(Instruction *I);
369
370 /// Identifies the Deinterleave operation applied to a vector containing
371 /// complex numbers. There are two ways to represent the Deinterleave
372 /// operation:
373 /// * Using two shufflevectors with even indices for \p Real instruction and
374 /// odd indices for \p Imag instructions (only for fixed-width vectors)
375 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
376 /// intrinsic (for both fixed and scalable vectors)
377 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
378
379 /// Identifies the operation that represents a complex number repeated in a
380 /// Splat vector. There are two possible types of splats: ConstantExpr with
381 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
382 /// initialization mask with all values set to zero.
383 NodePtr identifySplat(Value *Real, Value *Imag);
384
385 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
386
387 /// Identifies SelectInsts in a loop that has reduction with predication masks
388 /// and/or predicated tail folding
389 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
390
391 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
392
393 /// Complete IR modifications after producing new reduction operation:
394 /// * Populate the PHINode generated for
395 /// ComplexDeinterleavingOperation::ReductionPHI
396 /// * Deinterleave the final value outside of the loop and repurpose original
397 /// reduction users
398 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
399
400public:
401 void dump() { dump(dbgs()); }
402 void dump(raw_ostream &OS) {
403 for (const auto &Node : CompositeNodes)
404 Node->dump(OS);
405 }
406
407 /// Returns false if the deinterleaving operation should be cancelled for the
408 /// current graph.
409 bool identifyNodes(Instruction *RootI);
410
411 /// In case \p B is a one-block loop, this function seeks potential reductions
412 /// and populates ReductionInfo. Returns true if any reductions were
413 /// identified.
414 bool collectPotentialReductions(BasicBlock *B);
415
416 void identifyReductionNodes();
417
418 /// Check that every instruction, from the roots to the leaves, has internal
419 /// uses.
420 bool checkNodes();
421
422 /// Perform the actual replacement of the underlying instruction graph.
423 void replaceNodes();
424};
425
426class ComplexDeinterleaving {
427public:
428 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
429 : TL(tl), TLI(tli) {}
430 bool runOnFunction(Function &F);
431
432private:
433 bool evaluateBasicBlock(BasicBlock *B);
434
435 const TargetLowering *TL = nullptr;
436 const TargetLibraryInfo *TLI = nullptr;
437};
438
439} // namespace
440
441char ComplexDeinterleavingLegacyPass::ID = 0;
442
443INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
444 "Complex Deinterleaving", false, false)
445INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447
450 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
451 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
452 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
453 return PreservedAnalyses::all();
454
457 return PA;
458}
459
461 return new ComplexDeinterleavingLegacyPass(TM);
462}
463
464bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
465 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
466 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
467 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
468}
469
470bool ComplexDeinterleaving::runOnFunction(Function &F) {
473 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
474 return false;
475 }
476
479 dbgs() << "Complex deinterleaving has been disabled, target does "
480 "not support lowering of complex number operations.\n");
481 return false;
482 }
483
484 bool Changed = false;
485 for (auto &B : F)
486 Changed |= evaluateBasicBlock(&B);
487
488 return Changed;
489}
490
492 // If the size is not even, it's not an interleaving mask
493 if ((Mask.size() & 1))
494 return false;
495
496 int HalfNumElements = Mask.size() / 2;
497 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
498 int MaskIdx = Idx * 2;
499 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
500 return false;
501 }
502
503 return true;
504}
505
507 int Offset = Mask[0];
508 int HalfNumElements = Mask.size() / 2;
509
510 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
511 if (Mask[Idx] != (Idx * 2) + Offset)
512 return false;
513 }
514
515 return true;
516}
517
518bool isNeg(Value *V) {
519 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
520}
521
523 assert(isNeg(V));
524 auto *I = cast<Instruction>(V);
525 if (I->getOpcode() == Instruction::FNeg)
526 return I->getOperand(0);
527
528 return I->getOperand(1);
529}
530
531bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
532 ComplexDeinterleavingGraph Graph(TL, TLI);
533 if (Graph.collectPotentialReductions(B))
534 Graph.identifyReductionNodes();
535
536 for (auto &I : *B)
537 Graph.identifyNodes(&I);
538
539 if (Graph.checkNodes()) {
540 Graph.replaceNodes();
541 return true;
542 }
543
544 return false;
545}
546
547ComplexDeinterleavingGraph::NodePtr
548ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
549 Instruction *Real, Instruction *Imag,
550 std::pair<Value *, Value *> &PartialMatch) {
551 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
552 << "\n");
553
554 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
555 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
556 return nullptr;
557 }
558
559 if ((Real->getOpcode() != Instruction::FMul &&
560 Real->getOpcode() != Instruction::Mul) ||
561 (Imag->getOpcode() != Instruction::FMul &&
562 Imag->getOpcode() != Instruction::Mul)) {
564 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
565 return nullptr;
566 }
567
568 Value *R0 = Real->getOperand(0);
569 Value *R1 = Real->getOperand(1);
570 Value *I0 = Imag->getOperand(0);
571 Value *I1 = Imag->getOperand(1);
572
573 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
574 // rotations and use the operand.
575 unsigned Negs = 0;
576 Value *Op;
577 if (match(R0, m_Neg(m_Value(Op)))) {
578 Negs |= 1;
579 R0 = Op;
580 } else if (match(R1, m_Neg(m_Value(Op)))) {
581 Negs |= 1;
582 R1 = Op;
583 }
584
585 if (isNeg(I0)) {
586 Negs |= 2;
587 Negs ^= 1;
588 I0 = Op;
589 } else if (match(I1, m_Neg(m_Value(Op)))) {
590 Negs |= 2;
591 Negs ^= 1;
592 I1 = Op;
593 }
594
596
597 Value *CommonOperand;
598 Value *UncommonRealOp;
599 Value *UncommonImagOp;
600
601 if (R0 == I0 || R0 == I1) {
602 CommonOperand = R0;
603 UncommonRealOp = R1;
604 } else if (R1 == I0 || R1 == I1) {
605 CommonOperand = R1;
606 UncommonRealOp = R0;
607 } else {
608 LLVM_DEBUG(dbgs() << " - No equal operand\n");
609 return nullptr;
610 }
611
612 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
613 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
614 Rotation == ComplexDeinterleavingRotation::Rotation_270)
615 std::swap(UncommonRealOp, UncommonImagOp);
616
617 // Between identifyPartialMul and here we need to have found a complete valid
618 // pair from the CommonOperand of each part.
619 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
620 Rotation == ComplexDeinterleavingRotation::Rotation_180)
621 PartialMatch.first = CommonOperand;
622 else
623 PartialMatch.second = CommonOperand;
624
625 if (!PartialMatch.first || !PartialMatch.second) {
626 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
627 return nullptr;
628 }
629
630 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
631 if (!CommonNode) {
632 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
633 return nullptr;
634 }
635
636 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
637 if (!UncommonNode) {
638 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
639 return nullptr;
640 }
641
642 NodePtr Node = prepareCompositeNode(
643 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
644 Node->Rotation = Rotation;
645 Node->addOperand(CommonNode);
646 Node->addOperand(UncommonNode);
647 return submitCompositeNode(Node);
648}
649
650ComplexDeinterleavingGraph::NodePtr
651ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
652 Instruction *Imag) {
653 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
654 << "\n");
655 // Determine rotation
656 auto IsAdd = [](unsigned Op) {
657 return Op == Instruction::FAdd || Op == Instruction::Add;
658 };
659 auto IsSub = [](unsigned Op) {
660 return Op == Instruction::FSub || Op == Instruction::Sub;
661 };
663 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
664 Rotation = ComplexDeinterleavingRotation::Rotation_0;
665 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
666 Rotation = ComplexDeinterleavingRotation::Rotation_90;
667 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
668 Rotation = ComplexDeinterleavingRotation::Rotation_180;
669 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
670 Rotation = ComplexDeinterleavingRotation::Rotation_270;
671 else {
672 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
673 return nullptr;
674 }
675
676 if (isa<FPMathOperator>(Real) &&
677 (!Real->getFastMathFlags().allowContract() ||
678 !Imag->getFastMathFlags().allowContract())) {
679 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
680 return nullptr;
681 }
682
683 Value *CR = Real->getOperand(0);
684 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
685 if (!RealMulI)
686 return nullptr;
687 Value *CI = Imag->getOperand(0);
688 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
689 if (!ImagMulI)
690 return nullptr;
691
692 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
693 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
694 return nullptr;
695 }
696
697 Value *R0 = RealMulI->getOperand(0);
698 Value *R1 = RealMulI->getOperand(1);
699 Value *I0 = ImagMulI->getOperand(0);
700 Value *I1 = ImagMulI->getOperand(1);
701
702 Value *CommonOperand;
703 Value *UncommonRealOp;
704 Value *UncommonImagOp;
705
706 if (R0 == I0 || R0 == I1) {
707 CommonOperand = R0;
708 UncommonRealOp = R1;
709 } else if (R1 == I0 || R1 == I1) {
710 CommonOperand = R1;
711 UncommonRealOp = R0;
712 } else {
713 LLVM_DEBUG(dbgs() << " - No equal operand\n");
714 return nullptr;
715 }
716
717 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
718 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
719 Rotation == ComplexDeinterleavingRotation::Rotation_270)
720 std::swap(UncommonRealOp, UncommonImagOp);
721
722 std::pair<Value *, Value *> PartialMatch(
723 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
724 Rotation == ComplexDeinterleavingRotation::Rotation_180)
725 ? CommonOperand
726 : nullptr,
727 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
728 Rotation == ComplexDeinterleavingRotation::Rotation_270)
729 ? CommonOperand
730 : nullptr);
731
732 auto *CRInst = dyn_cast<Instruction>(CR);
733 auto *CIInst = dyn_cast<Instruction>(CI);
734
735 if (!CRInst || !CIInst) {
736 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
737 return nullptr;
738 }
739
740 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
741 if (!CNode) {
742 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
743 return nullptr;
744 }
745
746 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
747 if (!UncommonRes) {
748 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
749 return nullptr;
750 }
751
752 assert(PartialMatch.first && PartialMatch.second);
753 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
754 if (!CommonRes) {
755 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
756 return nullptr;
757 }
758
759 NodePtr Node = prepareCompositeNode(
760 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
761 Node->Rotation = Rotation;
762 Node->addOperand(CommonRes);
763 Node->addOperand(UncommonRes);
764 Node->addOperand(CNode);
765 return submitCompositeNode(Node);
766}
767
768ComplexDeinterleavingGraph::NodePtr
769ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
770 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
771
772 // Determine rotation
774 if ((Real->getOpcode() == Instruction::FSub &&
775 Imag->getOpcode() == Instruction::FAdd) ||
776 (Real->getOpcode() == Instruction::Sub &&
777 Imag->getOpcode() == Instruction::Add))
778 Rotation = ComplexDeinterleavingRotation::Rotation_90;
779 else if ((Real->getOpcode() == Instruction::FAdd &&
780 Imag->getOpcode() == Instruction::FSub) ||
781 (Real->getOpcode() == Instruction::Add &&
782 Imag->getOpcode() == Instruction::Sub))
783 Rotation = ComplexDeinterleavingRotation::Rotation_270;
784 else {
785 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
786 return nullptr;
787 }
788
789 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
790 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
791 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
792 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
793
794 if (!AR || !AI || !BR || !BI) {
795 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
796 return nullptr;
797 }
798
799 NodePtr ResA = identifyNode(AR, AI);
800 if (!ResA) {
801 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
802 return nullptr;
803 }
804 NodePtr ResB = identifyNode(BR, BI);
805 if (!ResB) {
806 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
807 return nullptr;
808 }
809
810 NodePtr Node =
811 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
812 Node->Rotation = Rotation;
813 Node->addOperand(ResA);
814 Node->addOperand(ResB);
815 return submitCompositeNode(Node);
816}
817
819 unsigned OpcA = A->getOpcode();
820 unsigned OpcB = B->getOpcode();
821
822 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
823 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
824 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
825 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
826}
827
829 auto Pattern =
831
832 return match(A, Pattern) && match(B, Pattern);
833}
834
836 switch (I->getOpcode()) {
837 case Instruction::FAdd:
838 case Instruction::FSub:
839 case Instruction::FMul:
840 case Instruction::FNeg:
841 case Instruction::Add:
842 case Instruction::Sub:
843 case Instruction::Mul:
844 return true;
845 default:
846 return false;
847 }
848}
849
850ComplexDeinterleavingGraph::NodePtr
851ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
852 Instruction *Imag) {
853 if (Real->getOpcode() != Imag->getOpcode())
854 return nullptr;
855
858 return nullptr;
859
860 auto *R0 = Real->getOperand(0);
861 auto *I0 = Imag->getOperand(0);
862
863 NodePtr Op0 = identifyNode(R0, I0);
864 NodePtr Op1 = nullptr;
865 if (Op0 == nullptr)
866 return nullptr;
867
868 if (Real->isBinaryOp()) {
869 auto *R1 = Real->getOperand(1);
870 auto *I1 = Imag->getOperand(1);
871 Op1 = identifyNode(R1, I1);
872 if (Op1 == nullptr)
873 return nullptr;
874 }
875
876 if (isa<FPMathOperator>(Real) &&
877 Real->getFastMathFlags() != Imag->getFastMathFlags())
878 return nullptr;
879
880 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
881 Real, Imag);
882 Node->Opcode = Real->getOpcode();
883 if (isa<FPMathOperator>(Real))
884 Node->Flags = Real->getFastMathFlags();
885
886 Node->addOperand(Op0);
887 if (Real->isBinaryOp())
888 Node->addOperand(Op1);
889
890 return submitCompositeNode(Node);
891}
892
893ComplexDeinterleavingGraph::NodePtr
894ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
895 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
896 assert(R->getType() == I->getType() &&
897 "Real and imaginary parts should not have different types");
898
899 auto It = CachedResult.find({R, I});
900 if (It != CachedResult.end()) {
901 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
902 return It->second;
903 }
904
905 if (NodePtr CN = identifySplat(R, I))
906 return CN;
907
908 auto *Real = dyn_cast<Instruction>(R);
909 auto *Imag = dyn_cast<Instruction>(I);
910 if (!Real || !Imag)
911 return nullptr;
912
913 if (NodePtr CN = identifyDeinterleave(Real, Imag))
914 return CN;
915
916 if (NodePtr CN = identifyPHINode(Real, Imag))
917 return CN;
918
919 if (NodePtr CN = identifySelectNode(Real, Imag))
920 return CN;
921
922 auto *VTy = cast<VectorType>(Real->getType());
923 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
924
925 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
926 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
927 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
928 ComplexDeinterleavingOperation::CAdd, NewVTy);
929
930 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
931 if (NodePtr CN = identifyPartialMul(Real, Imag))
932 return CN;
933 }
934
935 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
936 if (NodePtr CN = identifyAdd(Real, Imag))
937 return CN;
938 }
939
940 if (HasCMulSupport && HasCAddSupport) {
941 if (NodePtr CN = identifyReassocNodes(Real, Imag))
942 return CN;
943 }
944
945 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
946 return CN;
947
948 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
949 CachedResult[{R, I}] = nullptr;
950 return nullptr;
951}
952
953ComplexDeinterleavingGraph::NodePtr
954ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
955 Instruction *Imag) {
956 auto IsOperationSupported = [](unsigned Opcode) -> bool {
957 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
958 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
959 Opcode == Instruction::Sub;
960 };
961
962 if (!IsOperationSupported(Real->getOpcode()) ||
963 !IsOperationSupported(Imag->getOpcode()))
964 return nullptr;
965
966 std::optional<FastMathFlags> Flags;
967 if (isa<FPMathOperator>(Real)) {
968 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
969 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
970 "not identical\n");
971 return nullptr;
972 }
973
974 Flags = Real->getFastMathFlags();
975 if (!Flags->allowReassoc()) {
977 dbgs()
978 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
979 return nullptr;
980 }
981 }
982
983 // Collect multiplications and addend instructions from the given instruction
984 // while traversing it operands. Additionally, verify that all instructions
985 // have the same fast math flags.
986 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
987 std::list<Addend> &Addends) -> bool {
990 while (!Worklist.empty()) {
991 auto [V, IsPositive] = Worklist.back();
992 Worklist.pop_back();
993 if (!Visited.insert(V).second)
994 continue;
995
996 Instruction *I = dyn_cast<Instruction>(V);
997 if (!I) {
998 Addends.emplace_back(V, IsPositive);
999 continue;
1000 }
1001
1002 // If an instruction has more than one user, it indicates that it either
1003 // has an external user, which will be later checked by the checkNodes
1004 // function, or it is a subexpression utilized by multiple expressions. In
1005 // the latter case, we will attempt to separately identify the complex
1006 // operation from here in order to create a shared
1007 // ComplexDeinterleavingCompositeNode.
1008 if (I != Insn && I->getNumUses() > 1) {
1009 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1010 Addends.emplace_back(I, IsPositive);
1011 continue;
1012 }
1013 switch (I->getOpcode()) {
1014 case Instruction::FAdd:
1015 case Instruction::Add:
1016 Worklist.emplace_back(I->getOperand(1), IsPositive);
1017 Worklist.emplace_back(I->getOperand(0), IsPositive);
1018 break;
1019 case Instruction::FSub:
1020 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1021 Worklist.emplace_back(I->getOperand(0), IsPositive);
1022 break;
1023 case Instruction::Sub:
1024 if (isNeg(I)) {
1025 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1026 } else {
1027 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1028 Worklist.emplace_back(I->getOperand(0), IsPositive);
1029 }
1030 break;
1031 case Instruction::FMul:
1032 case Instruction::Mul: {
1033 Value *A, *B;
1034 if (isNeg(I->getOperand(0))) {
1035 A = getNegOperand(I->getOperand(0));
1036 IsPositive = !IsPositive;
1037 } else {
1038 A = I->getOperand(0);
1039 }
1040
1041 if (isNeg(I->getOperand(1))) {
1042 B = getNegOperand(I->getOperand(1));
1043 IsPositive = !IsPositive;
1044 } else {
1045 B = I->getOperand(1);
1046 }
1047 Muls.push_back(Product{A, B, IsPositive});
1048 break;
1049 }
1050 case Instruction::FNeg:
1051 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1052 break;
1053 default:
1054 Addends.emplace_back(I, IsPositive);
1055 continue;
1056 }
1057
1058 if (Flags && I->getFastMathFlags() != *Flags) {
1059 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1060 "inconsistent with the root instructions' flags: "
1061 << *I << "\n");
1062 return false;
1063 }
1064 }
1065 return true;
1066 };
1067
1068 std::vector<Product> RealMuls, ImagMuls;
1069 std::list<Addend> RealAddends, ImagAddends;
1070 if (!Collect(Real, RealMuls, RealAddends) ||
1071 !Collect(Imag, ImagMuls, ImagAddends))
1072 return nullptr;
1073
1074 if (RealAddends.size() != ImagAddends.size())
1075 return nullptr;
1076
1077 NodePtr FinalNode;
1078 if (!RealMuls.empty() || !ImagMuls.empty()) {
1079 // If there are multiplicands, extract positive addend and use it as an
1080 // accumulator
1081 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1082 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1083 if (!FinalNode)
1084 return nullptr;
1085 }
1086
1087 // Identify and process remaining additions
1088 if (!RealAddends.empty() || !ImagAddends.empty()) {
1089 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1090 if (!FinalNode)
1091 return nullptr;
1092 }
1093 assert(FinalNode && "FinalNode can not be nullptr here");
1094 // Set the Real and Imag fields of the final node and submit it
1095 FinalNode->Real = Real;
1096 FinalNode->Imag = Imag;
1097 submitCompositeNode(FinalNode);
1098 return FinalNode;
1099}
1100
1101bool ComplexDeinterleavingGraph::collectPartialMuls(
1102 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1103 std::vector<PartialMulCandidate> &PartialMulCandidates) {
1104 // Helper function to extract a common operand from two products
1105 auto FindCommonInstruction = [](const Product &Real,
1106 const Product &Imag) -> Value * {
1107 if (Real.Multiplicand == Imag.Multiplicand ||
1108 Real.Multiplicand == Imag.Multiplier)
1109 return Real.Multiplicand;
1110
1111 if (Real.Multiplier == Imag.Multiplicand ||
1112 Real.Multiplier == Imag.Multiplier)
1113 return Real.Multiplier;
1114
1115 return nullptr;
1116 };
1117
1118 // Iterating over real and imaginary multiplications to find common operands
1119 // If a common operand is found, a partial multiplication candidate is created
1120 // and added to the candidates vector The function returns false if no common
1121 // operands are found for any product
1122 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1123 bool FoundCommon = false;
1124 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1125 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1126 if (!Common)
1127 continue;
1128
1129 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1130 : RealMuls[i].Multiplicand;
1131 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1132 : ImagMuls[j].Multiplicand;
1133
1134 auto Node = identifyNode(A, B);
1135 if (Node) {
1136 FoundCommon = true;
1137 PartialMulCandidates.push_back({Common, Node, i, j, false});
1138 }
1139
1140 Node = identifyNode(B, A);
1141 if (Node) {
1142 FoundCommon = true;
1143 PartialMulCandidates.push_back({Common, Node, i, j, true});
1144 }
1145 }
1146 if (!FoundCommon)
1147 return false;
1148 }
1149 return true;
1150}
1151
1152ComplexDeinterleavingGraph::NodePtr
1153ComplexDeinterleavingGraph::identifyMultiplications(
1154 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1155 NodePtr Accumulator = nullptr) {
1156 if (RealMuls.size() != ImagMuls.size())
1157 return nullptr;
1158
1159 std::vector<PartialMulCandidate> Info;
1160 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1161 return nullptr;
1162
1163 // Map to store common instruction to node pointers
1164 std::map<Value *, NodePtr> CommonToNode;
1165 std::vector<bool> Processed(Info.size(), false);
1166 for (unsigned I = 0; I < Info.size(); ++I) {
1167 if (Processed[I])
1168 continue;
1169
1170 PartialMulCandidate &InfoA = Info[I];
1171 for (unsigned J = I + 1; J < Info.size(); ++J) {
1172 if (Processed[J])
1173 continue;
1174
1175 PartialMulCandidate &InfoB = Info[J];
1176 auto *InfoReal = &InfoA;
1177 auto *InfoImag = &InfoB;
1178
1179 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1180 if (!NodeFromCommon) {
1181 std::swap(InfoReal, InfoImag);
1182 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1183 }
1184 if (!NodeFromCommon)
1185 continue;
1186
1187 CommonToNode[InfoReal->Common] = NodeFromCommon;
1188 CommonToNode[InfoImag->Common] = NodeFromCommon;
1189 Processed[I] = true;
1190 Processed[J] = true;
1191 }
1192 }
1193
1194 std::vector<bool> ProcessedReal(RealMuls.size(), false);
1195 std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1196 NodePtr Result = Accumulator;
1197 for (auto &PMI : Info) {
1198 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1199 continue;
1200
1201 auto It = CommonToNode.find(PMI.Common);
1202 // TODO: Process independent complex multiplications. Cases like this:
1203 // A.real() * B where both A and B are complex numbers.
1204 if (It == CommonToNode.end()) {
1205 LLVM_DEBUG({
1206 dbgs() << "Unprocessed independent partial multiplication:\n";
1207 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1208 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1209 << " multiplied by " << *Mul->Multiplicand << "\n";
1210 });
1211 return nullptr;
1212 }
1213
1214 auto &RealMul = RealMuls[PMI.RealIdx];
1215 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1216
1217 auto NodeA = It->second;
1218 auto NodeB = PMI.Node;
1219 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1220 // The following table illustrates the relationship between multiplications
1221 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1222 // can see:
1223 //
1224 // Rotation | Real | Imag |
1225 // ---------+--------+--------+
1226 // 0 | x * u | x * v |
1227 // 90 | -y * v | y * u |
1228 // 180 | -x * u | -x * v |
1229 // 270 | y * v | -y * u |
1230 //
1231 // Check if the candidate can indeed be represented by partial
1232 // multiplication
1233 // TODO: Add support for multiplication by complex one
1234 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1235 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1236 continue;
1237
1238 // Determine the rotation based on the multiplications
1240 if (IsMultiplicandReal) {
1241 // Detect 0 and 180 degrees rotation
1242 if (RealMul.IsPositive && ImagMul.IsPositive)
1244 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1246 else
1247 continue;
1248
1249 } else {
1250 // Detect 90 and 270 degrees rotation
1251 if (!RealMul.IsPositive && ImagMul.IsPositive)
1253 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1255 else
1256 continue;
1257 }
1258
1259 LLVM_DEBUG({
1260 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1261 dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1262 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1263 dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1264 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1265 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1266 });
1267
1268 NodePtr NodeMul = prepareCompositeNode(
1269 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1270 NodeMul->Rotation = Rotation;
1271 NodeMul->addOperand(NodeA);
1272 NodeMul->addOperand(NodeB);
1273 if (Result)
1274 NodeMul->addOperand(Result);
1275 submitCompositeNode(NodeMul);
1276 Result = NodeMul;
1277 ProcessedReal[PMI.RealIdx] = true;
1278 ProcessedImag[PMI.ImagIdx] = true;
1279 }
1280
1281 // Ensure all products have been processed, if not return nullptr.
1282 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1283 !all_of(ProcessedImag, [](bool V) { return V; })) {
1284
1285 // Dump debug information about which partial multiplications are not
1286 // processed.
1287 LLVM_DEBUG({
1288 dbgs() << "Unprocessed products (Real):\n";
1289 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1290 if (!ProcessedReal[i])
1291 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1292 << *RealMuls[i].Multiplier << " multiplied by "
1293 << *RealMuls[i].Multiplicand << "\n";
1294 }
1295 dbgs() << "Unprocessed products (Imag):\n";
1296 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1297 if (!ProcessedImag[i])
1298 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1299 << *ImagMuls[i].Multiplier << " multiplied by "
1300 << *ImagMuls[i].Multiplicand << "\n";
1301 }
1302 });
1303 return nullptr;
1304 }
1305
1306 return Result;
1307}
1308
1309ComplexDeinterleavingGraph::NodePtr
1310ComplexDeinterleavingGraph::identifyAdditions(
1311 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1312 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1313 if (RealAddends.size() != ImagAddends.size())
1314 return nullptr;
1315
1316 NodePtr Result;
1317 // If we have accumulator use it as first addend
1318 if (Accumulator)
1320 // Otherwise find an element with both positive real and imaginary parts.
1321 else
1322 Result = extractPositiveAddend(RealAddends, ImagAddends);
1323
1324 if (!Result)
1325 return nullptr;
1326
1327 while (!RealAddends.empty()) {
1328 auto ItR = RealAddends.begin();
1329 auto [R, IsPositiveR] = *ItR;
1330
1331 bool FoundImag = false;
1332 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1333 auto [I, IsPositiveI] = *ItI;
1335 if (IsPositiveR && IsPositiveI)
1336 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1337 else if (!IsPositiveR && IsPositiveI)
1338 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1339 else if (!IsPositiveR && !IsPositiveI)
1340 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1341 else
1342 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1343
1344 NodePtr AddNode;
1345 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1346 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1347 AddNode = identifyNode(R, I);
1348 } else {
1349 AddNode = identifyNode(I, R);
1350 }
1351 if (AddNode) {
1352 LLVM_DEBUG({
1353 dbgs() << "Identified addition:\n";
1354 dbgs().indent(4) << "X: " << *R << "\n";
1355 dbgs().indent(4) << "Y: " << *I << "\n";
1356 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1357 });
1358
1359 NodePtr TmpNode;
1361 TmpNode = prepareCompositeNode(
1362 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1363 if (Flags) {
1364 TmpNode->Opcode = Instruction::FAdd;
1365 TmpNode->Flags = *Flags;
1366 } else {
1367 TmpNode->Opcode = Instruction::Add;
1368 }
1369 } else if (Rotation ==
1371 TmpNode = prepareCompositeNode(
1372 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1373 if (Flags) {
1374 TmpNode->Opcode = Instruction::FSub;
1375 TmpNode->Flags = *Flags;
1376 } else {
1377 TmpNode->Opcode = Instruction::Sub;
1378 }
1379 } else {
1380 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1381 nullptr, nullptr);
1382 TmpNode->Rotation = Rotation;
1383 }
1384
1385 TmpNode->addOperand(Result);
1386 TmpNode->addOperand(AddNode);
1387 submitCompositeNode(TmpNode);
1388 Result = TmpNode;
1389 RealAddends.erase(ItR);
1390 ImagAddends.erase(ItI);
1391 FoundImag = true;
1392 break;
1393 }
1394 }
1395 if (!FoundImag)
1396 return nullptr;
1397 }
1398 return Result;
1399}
1400
1401ComplexDeinterleavingGraph::NodePtr
1402ComplexDeinterleavingGraph::extractPositiveAddend(
1403 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1404 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1405 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1406 auto [R, IsPositiveR] = *ItR;
1407 auto [I, IsPositiveI] = *ItI;
1408 if (IsPositiveR && IsPositiveI) {
1409 auto Result = identifyNode(R, I);
1410 if (Result) {
1411 RealAddends.erase(ItR);
1412 ImagAddends.erase(ItI);
1413 return Result;
1414 }
1415 }
1416 }
1417 }
1418 return nullptr;
1419}
1420
1421bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1422 // This potential root instruction might already have been recognized as
1423 // reduction. Because RootToNode maps both Real and Imaginary parts to
1424 // CompositeNode we should choose only one either Real or Imag instruction to
1425 // use as an anchor for generating complex instruction.
1426 auto It = RootToNode.find(RootI);
1427 if (It != RootToNode.end()) {
1428 auto RootNode = It->second;
1429 assert(RootNode->Operation ==
1430 ComplexDeinterleavingOperation::ReductionOperation);
1431 // Find out which part, Real or Imag, comes later, and only if we come to
1432 // the latest part, add it to OrderedRoots.
1433 auto *R = cast<Instruction>(RootNode->Real);
1434 auto *I = cast<Instruction>(RootNode->Imag);
1435 auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1436 if (ReplacementAnchor != RootI)
1437 return false;
1438 OrderedRoots.push_back(RootI);
1439 return true;
1440 }
1441
1442 auto RootNode = identifyRoot(RootI);
1443 if (!RootNode)
1444 return false;
1445
1446 LLVM_DEBUG({
1447 Function *F = RootI->getFunction();
1448 BasicBlock *B = RootI->getParent();
1449 dbgs() << "Complex deinterleaving graph for " << F->getName()
1450 << "::" << B->getName() << ".\n";
1451 dump(dbgs());
1452 dbgs() << "\n";
1453 });
1454 RootToNode[RootI] = RootNode;
1455 OrderedRoots.push_back(RootI);
1456 return true;
1457}
1458
1459bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1460 bool FoundPotentialReduction = false;
1461
1462 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1463 if (!Br || Br->getNumSuccessors() != 2)
1464 return false;
1465
1466 // Identify simple one-block loop
1467 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1468 return false;
1469
1471 for (auto &PHI : B->phis()) {
1472 if (PHI.getNumIncomingValues() != 2)
1473 continue;
1474
1475 if (!PHI.getType()->isVectorTy())
1476 continue;
1477
1478 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1479 if (!ReductionOp)
1480 continue;
1481
1482 // Check if final instruction is reduced outside of current block
1483 Instruction *FinalReduction = nullptr;
1484 auto NumUsers = 0u;
1485 for (auto *U : ReductionOp->users()) {
1486 ++NumUsers;
1487 if (U == &PHI)
1488 continue;
1489 FinalReduction = dyn_cast<Instruction>(U);
1490 }
1491
1492 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1493 isa<PHINode>(FinalReduction))
1494 continue;
1495
1496 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1497 BackEdge = B;
1498 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1499 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1500 Incoming = PHI.getIncomingBlock(IncomingIdx);
1501 FoundPotentialReduction = true;
1502
1503 // If the initial value of PHINode is an Instruction, consider it a leaf
1504 // value of a complex deinterleaving graph.
1505 if (auto *InitPHI =
1506 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1507 FinalInstructions.insert(InitPHI);
1508 }
1509 return FoundPotentialReduction;
1510}
1511
1512void ComplexDeinterleavingGraph::identifyReductionNodes() {
1513 SmallVector<bool> Processed(ReductionInfo.size(), false);
1514 SmallVector<Instruction *> OperationInstruction;
1515 for (auto &P : ReductionInfo)
1516 OperationInstruction.push_back(P.first);
1517
1518 // Identify a complex computation by evaluating two reduction operations that
1519 // potentially could be involved
1520 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1521 if (Processed[i])
1522 continue;
1523 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1524 if (Processed[j])
1525 continue;
1526
1527 auto *Real = OperationInstruction[i];
1528 auto *Imag = OperationInstruction[j];
1529 if (Real->getType() != Imag->getType())
1530 continue;
1531
1532 RealPHI = ReductionInfo[Real].first;
1533 ImagPHI = ReductionInfo[Imag].first;
1534 PHIsFound = false;
1535 auto Node = identifyNode(Real, Imag);
1536 if (!Node) {
1537 std::swap(Real, Imag);
1538 std::swap(RealPHI, ImagPHI);
1539 Node = identifyNode(Real, Imag);
1540 }
1541
1542 // If a node is identified and reduction PHINode is used in the chain of
1543 // operations, mark its operation instructions as used to prevent
1544 // re-identification and attach the node to the real part
1545 if (Node && PHIsFound) {
1546 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1547 << *Real << " / " << *Imag << "\n");
1548 Processed[i] = true;
1549 Processed[j] = true;
1550 auto RootNode = prepareCompositeNode(
1551 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1552 RootNode->addOperand(Node);
1553 RootToNode[Real] = RootNode;
1554 RootToNode[Imag] = RootNode;
1555 submitCompositeNode(RootNode);
1556 break;
1557 }
1558 }
1559 }
1560
1561 RealPHI = nullptr;
1562 ImagPHI = nullptr;
1563}
1564
1565bool ComplexDeinterleavingGraph::checkNodes() {
1566 // Collect all instructions from roots to leaves
1567 SmallPtrSet<Instruction *, 16> AllInstructions;
1569 for (auto &Pair : RootToNode)
1570 Worklist.push_back(Pair.first);
1571
1572 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1573 // chains
1574 while (!Worklist.empty()) {
1575 auto *I = Worklist.back();
1576 Worklist.pop_back();
1577
1578 if (!AllInstructions.insert(I).second)
1579 continue;
1580
1581 for (Value *Op : I->operands()) {
1582 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1583 if (!FinalInstructions.count(I))
1584 Worklist.emplace_back(OpI);
1585 }
1586 }
1587 }
1588
1589 // Find instructions that have users outside of chain
1590 SmallVector<Instruction *, 2> OuterInstructions;
1591 for (auto *I : AllInstructions) {
1592 // Skip root nodes
1593 if (RootToNode.count(I))
1594 continue;
1595
1596 for (User *U : I->users()) {
1597 if (AllInstructions.count(cast<Instruction>(U)))
1598 continue;
1599
1600 // Found an instruction that is not used by XCMLA/XCADD chain
1601 Worklist.emplace_back(I);
1602 break;
1603 }
1604 }
1605
1606 // If any instructions are found to be used outside, find and remove roots
1607 // that somehow connect to those instructions.
1609 while (!Worklist.empty()) {
1610 auto *I = Worklist.back();
1611 Worklist.pop_back();
1612 if (!Visited.insert(I).second)
1613 continue;
1614
1615 // Found an impacted root node. Removing it from the nodes to be
1616 // deinterleaved
1617 if (RootToNode.count(I)) {
1618 LLVM_DEBUG(dbgs() << "Instruction " << *I
1619 << " could be deinterleaved but its chain of complex "
1620 "operations have an outside user\n");
1621 RootToNode.erase(I);
1622 }
1623
1624 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1625 continue;
1626
1627 for (User *U : I->users())
1628 Worklist.emplace_back(cast<Instruction>(U));
1629
1630 for (Value *Op : I->operands()) {
1631 if (auto *OpI = dyn_cast<Instruction>(Op))
1632 Worklist.emplace_back(OpI);
1633 }
1634 }
1635 return !RootToNode.empty();
1636}
1637
1638ComplexDeinterleavingGraph::NodePtr
1639ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1640 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1641 if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
1642 return nullptr;
1643
1644 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1645 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1646 if (!Real || !Imag)
1647 return nullptr;
1648
1649 return identifyNode(Real, Imag);
1650 }
1651
1652 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1653 if (!SVI)
1654 return nullptr;
1655
1656 // Look for a shufflevector that takes separate vectors of the real and
1657 // imaginary components and recombines them into a single vector.
1658 if (!isInterleavingMask(SVI->getShuffleMask()))
1659 return nullptr;
1660
1661 Instruction *Real;
1662 Instruction *Imag;
1663 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1664 return nullptr;
1665
1666 return identifyNode(Real, Imag);
1667}
1668
1669ComplexDeinterleavingGraph::NodePtr
1670ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1671 Instruction *Imag) {
1672 Instruction *I = nullptr;
1673 Value *FinalValue = nullptr;
1674 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1675 match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1676 match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
1677 m_Value(FinalValue)))) {
1678 NodePtr PlaceholderNode = prepareCompositeNode(
1680 PlaceholderNode->ReplacementNode = FinalValue;
1681 FinalInstructions.insert(Real);
1682 FinalInstructions.insert(Imag);
1683 return submitCompositeNode(PlaceholderNode);
1684 }
1685
1686 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1687 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1688 if (!RealShuffle || !ImagShuffle) {
1689 if (RealShuffle || ImagShuffle)
1690 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1691 return nullptr;
1692 }
1693
1694 Value *RealOp1 = RealShuffle->getOperand(1);
1695 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1696 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1697 return nullptr;
1698 }
1699 Value *ImagOp1 = ImagShuffle->getOperand(1);
1700 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1701 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1702 return nullptr;
1703 }
1704
1705 Value *RealOp0 = RealShuffle->getOperand(0);
1706 Value *ImagOp0 = ImagShuffle->getOperand(0);
1707
1708 if (RealOp0 != ImagOp0) {
1709 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1710 return nullptr;
1711 }
1712
1713 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1714 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1715 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1716 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1717 return nullptr;
1718 }
1719
1720 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1721 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1722 return nullptr;
1723 }
1724
1725 // Type checking, the shuffle type should be a vector type of the same
1726 // scalar type, but half the size
1727 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1728 Value *Op = Shuffle->getOperand(0);
1729 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1730 auto *OpTy = cast<FixedVectorType>(Op->getType());
1731
1732 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1733 return false;
1734 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1735 return false;
1736
1737 return true;
1738 };
1739
1740 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1741 if (!CheckType(Shuffle))
1742 return false;
1743
1744 ArrayRef<int> Mask = Shuffle->getShuffleMask();
1745 int Last = *Mask.rbegin();
1746
1747 Value *Op = Shuffle->getOperand(0);
1748 auto *OpTy = cast<FixedVectorType>(Op->getType());
1749 int NumElements = OpTy->getNumElements();
1750
1751 // Ensure that the deinterleaving shuffle only pulls from the first
1752 // shuffle operand.
1753 return Last < NumElements;
1754 };
1755
1756 if (RealShuffle->getType() != ImagShuffle->getType()) {
1757 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1758 return nullptr;
1759 }
1760 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1761 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1762 return nullptr;
1763 }
1764 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1765 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1766 return nullptr;
1767 }
1768
1769 NodePtr PlaceholderNode =
1771 RealShuffle, ImagShuffle);
1772 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1773 FinalInstructions.insert(RealShuffle);
1774 FinalInstructions.insert(ImagShuffle);
1775 return submitCompositeNode(PlaceholderNode);
1776}
1777
1778ComplexDeinterleavingGraph::NodePtr
1779ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1780 auto IsSplat = [](Value *V) -> bool {
1781 // Fixed-width vector with constants
1782 if (isa<ConstantDataVector>(V))
1783 return true;
1784
1785 VectorType *VTy;
1787 // Splats are represented differently depending on whether the repeated
1788 // value is a constant or an Instruction
1789 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1790 if (Const->getOpcode() != Instruction::ShuffleVector)
1791 return false;
1792 VTy = cast<VectorType>(Const->getType());
1793 Mask = Const->getShuffleMask();
1794 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1795 VTy = Shuf->getType();
1796 Mask = Shuf->getShuffleMask();
1797 } else {
1798 return false;
1799 }
1800
1801 // When the data type is <1 x Type>, it's not possible to differentiate
1802 // between the ComplexDeinterleaving::Deinterleave and
1803 // ComplexDeinterleaving::Splat operations.
1804 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1805 return false;
1806
1807 return all_equal(Mask) && Mask[0] == 0;
1808 };
1809
1810 if (!IsSplat(R) || !IsSplat(I))
1811 return nullptr;
1812
1813 auto *Real = dyn_cast<Instruction>(R);
1814 auto *Imag = dyn_cast<Instruction>(I);
1815 if ((!Real && Imag) || (Real && !Imag))
1816 return nullptr;
1817
1818 if (Real && Imag) {
1819 // Non-constant splats should be in the same basic block
1820 if (Real->getParent() != Imag->getParent())
1821 return nullptr;
1822
1823 FinalInstructions.insert(Real);
1824 FinalInstructions.insert(Imag);
1825 }
1826 NodePtr PlaceholderNode =
1827 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1828 return submitCompositeNode(PlaceholderNode);
1829}
1830
1831ComplexDeinterleavingGraph::NodePtr
1832ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1833 Instruction *Imag) {
1834 if (Real != RealPHI || Imag != ImagPHI)
1835 return nullptr;
1836
1837 PHIsFound = true;
1838 NodePtr PlaceholderNode = prepareCompositeNode(
1839 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1840 return submitCompositeNode(PlaceholderNode);
1841}
1842
1843ComplexDeinterleavingGraph::NodePtr
1844ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1845 Instruction *Imag) {
1846 auto *SelectReal = dyn_cast<SelectInst>(Real);
1847 auto *SelectImag = dyn_cast<SelectInst>(Imag);
1848 if (!SelectReal || !SelectImag)
1849 return nullptr;
1850
1851 Instruction *MaskA, *MaskB;
1852 Instruction *AR, *AI, *RA, *BI;
1853 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1854 m_Instruction(RA))) ||
1855 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1856 m_Instruction(BI))))
1857 return nullptr;
1858
1859 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1860 return nullptr;
1861
1862 if (!MaskA->getType()->isVectorTy())
1863 return nullptr;
1864
1865 auto NodeA = identifyNode(AR, AI);
1866 if (!NodeA)
1867 return nullptr;
1868
1869 auto NodeB = identifyNode(RA, BI);
1870 if (!NodeB)
1871 return nullptr;
1872
1873 NodePtr PlaceholderNode = prepareCompositeNode(
1874 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1875 PlaceholderNode->addOperand(NodeA);
1876 PlaceholderNode->addOperand(NodeB);
1877 FinalInstructions.insert(MaskA);
1878 FinalInstructions.insert(MaskB);
1879 return submitCompositeNode(PlaceholderNode);
1880}
1881
1882static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1883 std::optional<FastMathFlags> Flags,
1884 Value *InputA, Value *InputB) {
1885 Value *I;
1886 switch (Opcode) {
1887 case Instruction::FNeg:
1888 I = B.CreateFNeg(InputA);
1889 break;
1890 case Instruction::FAdd:
1891 I = B.CreateFAdd(InputA, InputB);
1892 break;
1893 case Instruction::Add:
1894 I = B.CreateAdd(InputA, InputB);
1895 break;
1896 case Instruction::FSub:
1897 I = B.CreateFSub(InputA, InputB);
1898 break;
1899 case Instruction::Sub:
1900 I = B.CreateSub(InputA, InputB);
1901 break;
1902 case Instruction::FMul:
1903 I = B.CreateFMul(InputA, InputB);
1904 break;
1905 case Instruction::Mul:
1906 I = B.CreateMul(InputA, InputB);
1907 break;
1908 default:
1909 llvm_unreachable("Incorrect symmetric opcode");
1910 }
1911 if (Flags)
1912 cast<Instruction>(I)->setFastMathFlags(*Flags);
1913 return I;
1914}
1915
1916Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1917 RawNodePtr Node) {
1918 if (Node->ReplacementNode)
1919 return Node->ReplacementNode;
1920
1921 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1922 return Node->Operands.size() > Idx
1923 ? replaceNode(Builder, Node->Operands[Idx])
1924 : nullptr;
1925 };
1926
1927 Value *ReplacementNode;
1928 switch (Node->Operation) {
1929 case ComplexDeinterleavingOperation::CAdd:
1930 case ComplexDeinterleavingOperation::CMulPartial:
1931 case ComplexDeinterleavingOperation::Symmetric: {
1932 Value *Input0 = ReplaceOperandIfExist(Node, 0);
1933 Value *Input1 = ReplaceOperandIfExist(Node, 1);
1934 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1935 assert(!Input1 || (Input0->getType() == Input1->getType() &&
1936 "Node inputs need to be of the same type"));
1938 (Input0->getType() == Accumulator->getType() &&
1939 "Accumulator and input need to be of the same type"));
1940 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1941 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1942 Input0, Input1);
1943 else
1944 ReplacementNode = TL->createComplexDeinterleavingIR(
1945 Builder, Node->Operation, Node->Rotation, Input0, Input1,
1946 Accumulator);
1947 break;
1948 }
1949 case ComplexDeinterleavingOperation::Deinterleave:
1950 llvm_unreachable("Deinterleave node should already have ReplacementNode");
1951 break;
1952 case ComplexDeinterleavingOperation::Splat: {
1953 auto *NewTy = VectorType::getDoubleElementsVectorType(
1954 cast<VectorType>(Node->Real->getType()));
1955 auto *R = dyn_cast<Instruction>(Node->Real);
1956 auto *I = dyn_cast<Instruction>(Node->Imag);
1957 if (R && I) {
1958 // Splats that are not constant are interleaved where they are located
1959 Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1960 IRBuilder<> IRB(InsertPoint);
1961 ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
1962 NewTy, {Node->Real, Node->Imag});
1963 } else {
1964 ReplacementNode = Builder.CreateIntrinsic(
1965 Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
1966 }
1967 break;
1968 }
1969 case ComplexDeinterleavingOperation::ReductionPHI: {
1970 // If Operation is ReductionPHI, a new empty PHINode is created.
1971 // It is filled later when the ReductionOperation is processed.
1972 auto *VTy = cast<VectorType>(Node->Real->getType());
1973 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1974 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
1975 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1976 ReplacementNode = NewPHI;
1977 break;
1978 }
1979 case ComplexDeinterleavingOperation::ReductionOperation:
1980 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1981 processReductionOperation(ReplacementNode, Node);
1982 break;
1983 case ComplexDeinterleavingOperation::ReductionSelect: {
1984 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1985 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1986 auto *A = replaceNode(Builder, Node->Operands[0]);
1987 auto *B = replaceNode(Builder, Node->Operands[1]);
1988 auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1989 cast<VectorType>(MaskReal->getType()));
1990 auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
1991 NewMaskTy, {MaskReal, MaskImag});
1992 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1993 break;
1994 }
1995 }
1996
1997 assert(ReplacementNode && "Target failed to create Intrinsic call.");
1998 NumComplexTransformations += 1;
1999 Node->ReplacementNode = ReplacementNode;
2000 return ReplacementNode;
2001}
2002
2003void ComplexDeinterleavingGraph::processReductionOperation(
2004 Value *OperationReplacement, RawNodePtr Node) {
2005 auto *Real = cast<Instruction>(Node->Real);
2006 auto *Imag = cast<Instruction>(Node->Imag);
2007 auto *OldPHIReal = ReductionInfo[Real].first;
2008 auto *OldPHIImag = ReductionInfo[Imag].first;
2009 auto *NewPHI = OldToNewPHI[OldPHIReal];
2010
2011 auto *VTy = cast<VectorType>(Real->getType());
2012 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2013
2014 // We have to interleave initial origin values coming from IncomingBlock
2015 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2016 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2017
2018 IRBuilder<> Builder(Incoming->getTerminator());
2019 auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2020 {InitReal, InitImag});
2021
2022 NewPHI->addIncoming(NewInit, Incoming);
2023 NewPHI->addIncoming(OperationReplacement, BackEdge);
2024
2025 // Deinterleave complex vector outside of loop so that it can be finally
2026 // reduced
2027 auto *FinalReductionReal = ReductionInfo[Real].second;
2028 auto *FinalReductionImag = ReductionInfo[Imag].second;
2029
2030 Builder.SetInsertPoint(
2031 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2032 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2033 OperationReplacement->getType(),
2034 OperationReplacement);
2035
2036 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2037 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2038
2039 Builder.SetInsertPoint(FinalReductionImag);
2040 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2041 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2042}
2043
2044void ComplexDeinterleavingGraph::replaceNodes() {
2045 SmallVector<Instruction *, 16> DeadInstrRoots;
2046 for (auto *RootInstruction : OrderedRoots) {
2047 // Check if this potential root went through check process and we can
2048 // deinterleave it
2049 if (!RootToNode.count(RootInstruction))
2050 continue;
2051
2052 IRBuilder<> Builder(RootInstruction);
2053 auto RootNode = RootToNode[RootInstruction];
2054 Value *R = replaceNode(Builder, RootNode.get());
2055
2056 if (RootNode->Operation ==
2057 ComplexDeinterleavingOperation::ReductionOperation) {
2058 auto *RootReal = cast<Instruction>(RootNode->Real);
2059 auto *RootImag = cast<Instruction>(RootNode->Imag);
2060 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2061 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2062 DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2063 DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2064 } else {
2065 assert(R && "Unable to find replacement for RootInstruction");
2066 DeadInstrRoots.push_back(RootInstruction);
2067 RootInstruction->replaceAllUsesWith(R);
2068 }
2069 }
2070
2071 for (auto *I : DeadInstrRoots)
2073}
SmallVector< AArch64_IMM::ImmInsnModel, 4 > Insn
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
VarLocInsertPt getNextNode(const DbgRecord *DVR)
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
Complex Deinterleaving
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
static bool isInstructionPairMul(Instruction *A, Instruction *B)
#define DEBUG_TYPE
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
This file implements a map that provides insertion order iteration.
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
raw_pwrite_stream & OS
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file describes how to lower LLVM code to machine code.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
DEMANGLE_DUMP_METHOD void dump() const
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:256
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
InstListType::const_iterator getFirstNonPHIIt() const
Iterator returning form of getFirstNonPHI.
Definition: BasicBlock.cpp:374
This class represents an Operation in the Expression.
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:156
iterator end()
Definition: DenseMap.h:84
bool allowContract() const
Definition: FMF.h:70
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:91
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition: IRBuilder.h:2547
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:890
Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Definition: IRBuilder.cpp:1048
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:567
bool isBinaryOp() const
Definition: Instruction.h:279
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:72
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:274
bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:36
size_type size() const
Definition: MapVector.h:60
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:131
This instruction constructs a fixed permutation of two input vectors.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:452
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:77
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:270
Value * getOperand(unsigned i) const
Definition: User.h:228
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:125
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ BR
Control flow instructions. These all have token chains.
Definition: ISDOpcodes.h:1118
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:100
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
Definition: PatternMatch.h:826
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:885
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition: DWP.cpp:480
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:546
void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
DWARFExpression::Operation Op
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer list are equal or the list is empty.
Definition: STLExtras.h:2087
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...