LLVM 23.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Each node is expected to be validated upon creation, and any
19// validation error should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(true),
90
91/// Checks the given mask, and determines whether said mask is interleaving.
92///
93/// To be interleaving, a mask must alternate between `i` and `i + (Length /
94/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
95/// 4x vector interleaving mask would be <0, 2, 1, 3>).
96static bool isInterleavingMask(ArrayRef<int> Mask);
97
98/// Checks the given mask, and determines whether said mask is deinterleaving.
99///
100/// To be deinterleaving, a mask must increment in steps of 2, and either start
101/// with 0 or 1.
102/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
103/// <1, 3, 5, 7>).
104static bool isDeinterleavingMask(ArrayRef<int> Mask);
105
106/// Returns true if the operation is a negation of V, and it works for both
107/// integers and floats.
108static bool isNeg(Value *V);
109
110/// Returns the operand for negation operation.
111static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
122hash_code hash_value(const ComplexValue &Arg) {
125}
126} // end namespace
128
129template <> struct llvm::DenseMapInfo<ComplexValue> {
130 static unsigned getHashValue(const ComplexValue &Val) {
133 }
134 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
135 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
136 }
137};
138
139namespace {
140template <typename T, typename IterT>
141std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
142 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
143 if (Common != A.end())
144 return std::make_optional(*Common);
145 return std::nullopt;
146}
147
148class ComplexDeinterleavingLegacyPass : public FunctionPass {
149public:
150 static char ID;
151
152 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
153 : FunctionPass(ID), TM(TM) {}
154
155 StringRef getPassName() const override {
156 return "Complex Deinterleaving Pass";
157 }
158
159 bool runOnFunction(Function &F) override;
160 void getAnalysisUsage(AnalysisUsage &AU) const override {
161 AU.addRequired<TargetLibraryInfoWrapperPass>();
162 AU.setPreservesCFG();
163 }
164
165private:
166 const TargetMachine *TM;
167};
168
169class ComplexDeinterleavingGraph;
170struct ComplexDeinterleavingCompositeNode {
171
172 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
173 Value *R, Value *I)
174 : Operation(Op) {
175 Vals.push_back({R, I});
176 }
177
178 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
180 : Operation(Op), Vals(Other) {}
181
182private:
183 friend class ComplexDeinterleavingGraph;
184 using CompositeNode = ComplexDeinterleavingCompositeNode;
185 bool OperandsValid = true;
186
187public:
189 ComplexValues Vals;
190
191 // These two members are required exclusively for generating
192 // ComplexDeinterleavingOperation::Symmetric operations.
193 unsigned Opcode;
194 std::optional<FastMathFlags> Flags;
195
197 ComplexDeinterleavingRotation::Rotation_0;
199 Value *ReplacementNode = nullptr;
200
201 void addOperand(CompositeNode *Node) {
202 if (!Node)
203 OperandsValid = false;
204 Operands.push_back(Node);
205 }
206
207 void dump() { dump(dbgs()); }
208 void dump(raw_ostream &OS) {
209 auto PrintValue = [&](Value *V) {
210 if (V) {
211 OS << "\"";
212 V->print(OS, true);
213 OS << "\"\n";
214 } else
215 OS << "nullptr\n";
216 };
217 auto PrintNodeRef = [&](CompositeNode *Ptr) {
218 if (Ptr)
219 OS << Ptr << "\n";
220 else
221 OS << "nullptr\n";
222 };
223
224 OS << "- CompositeNode: " << this << "\n";
225 for (unsigned I = 0; I < Vals.size(); I++) {
226 OS << " Real(" << I << ") : ";
227 PrintValue(Vals[I].Real);
228 OS << " Imag(" << I << ") : ";
229 PrintValue(Vals[I].Imag);
230 }
231 OS << " ReplacementNode: ";
232 PrintValue(ReplacementNode);
233 OS << " Operation: " << (int)Operation << "\n";
234 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
235 OS << " Operands: \n";
236 for (const auto &Op : Operands) {
237 OS << " - ";
238 PrintNodeRef(Op);
239 }
240 }
241
242 bool areOperandsValid() { return OperandsValid; }
243};
244
245class ComplexDeinterleavingGraph {
246public:
247 struct Product {
248 Value *Multiplier;
249 Value *Multiplicand;
250 bool IsPositive;
251 };
252
253 using Addend = std::pair<Value *, bool>;
254 using AddendList = BumpPtrList<Addend>;
255 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
256
257 // Helper struct for holding info about potential partial multiplication
258 // candidates
259 struct PartialMulCandidate {
260 Value *Common;
261 CompositeNode *Node;
262 unsigned RealIdx;
263 unsigned ImagIdx;
264 bool IsNodeInverted;
265 };
266
267 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
268 const TargetLibraryInfo *TLI,
269 unsigned Factor)
270 : TL(TL), TLI(TLI), Factor(Factor) {}
271
272private:
273 const TargetLowering *TL = nullptr;
274 const TargetLibraryInfo *TLI = nullptr;
275 unsigned Factor;
276 SmallVector<CompositeNode *> CompositeNodes;
277 DenseMap<ComplexValues, CompositeNode *> CachedResult;
278 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
279
280 SmallPtrSet<Instruction *, 16> FinalInstructions;
281
282 /// Root instructions are instructions from which complex computation starts
283 DenseMap<Instruction *, CompositeNode *> RootToNode;
284
285 /// Topologically sorted root instructions
287
288 /// When examining a basic block for complex deinterleaving, if it is a simple
289 /// one-block loop, then the only incoming block is 'Incoming' and the
290 /// 'BackEdge' block is the block itself.
291 BasicBlock *BackEdge = nullptr;
292 BasicBlock *Incoming = nullptr;
293
294 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
295 /// %OutsideUser as it is shown in the IR:
296 ///
297 /// vector.body:
298 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
299 /// [ %ReductionOp, %vector.body ]
300 /// ...
301 /// %ReductionOp = fadd i64 ...
302 /// ...
303 /// br i1 %condition, label %vector.body, %middle.block
304 ///
305 /// middle.block:
306 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
307 ///
308 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
309 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
310 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
311
312 /// In the process of detecting a reduction, we consider a pair of
313 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
314 /// traverse the use-tree to detect complex operations. As this is a reduction
315 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
316 /// to the %ReductionOPs that we suspect to be complex.
317 /// RealPHI and ImagPHI are used by the identifyPHINode method.
318 PHINode *RealPHI = nullptr;
319 PHINode *ImagPHI = nullptr;
320
321 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
322 /// detection.
323 bool PHIsFound = false;
324
325 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
326 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
327 /// This mapping is populated during
328 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
329 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
330 /// replacement process.
331 DenseMap<PHINode *, PHINode *> OldToNewPHI;
332
333 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
334 Value *R, Value *I) {
335 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
336 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
337 (R && I)) &&
338 "Reduction related nodes must have Real and Imaginary parts");
339 return new (Allocator.Allocate())
340 ComplexDeinterleavingCompositeNode(Operation, R, I);
341 }
342
343 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
344 ComplexValues &Vals) {
345#ifndef NDEBUG
346 for (auto &V : Vals) {
347 assert(
348 ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
350 (V.Real && V.Imag)) &&
351 "Reduction related nodes must have Real and Imaginary parts");
352 }
353#endif
354 return new (Allocator.Allocate())
355 ComplexDeinterleavingCompositeNode(Operation, Vals);
356 }
357
358 CompositeNode *submitCompositeNode(CompositeNode *Node) {
359 CompositeNodes.push_back(Node);
360 if (Node->Vals[0].Real)
361 CachedResult[Node->Vals] = Node;
362 return Node;
363 }
364
365 /// Identifies a complex partial multiply pattern and its rotation, based on
366 /// the following patterns
367 ///
368 /// 0: r: cr + ar * br
369 /// i: ci + ar * bi
370 /// 90: r: cr - ai * bi
371 /// i: ci + ai * br
372 /// 180: r: cr - ar * br
373 /// i: ci - ar * bi
374 /// 270: r: cr + ai * bi
375 /// i: ci - ai * br
376 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
377
378 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
379 /// is partially known from identifyPartialMul, filling in the other half of
380 /// the complex pair.
381 CompositeNode *
382 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
383 std::pair<Value *, Value *> &CommonOperandI);
384
385 /// Identifies a complex add pattern and its rotation, based on the following
386 /// patterns.
387 ///
388 /// 90: r: ar - bi
389 /// i: ai + br
390 /// 270: r: ar + bi
391 /// i: ai - br
392 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
393 CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
394 CompositeNode *identifyPartialReduction(Value *R, Value *I);
395 CompositeNode *identifyDotProduct(Value *Inst);
396
397 CompositeNode *identifyNode(ComplexValues &Vals);
398
399 CompositeNode *identifyNode(Value *R, Value *I) {
400 ComplexValues Vals;
401 Vals.push_back({R, I});
402 return identifyNode(Vals);
403 }
404
405 /// Determine if a sum of complex numbers can be formed from \p RealAddends
406 /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
407 /// Return nullptr if it is not possible to construct a complex number.
408 /// \p Flags are needed to generate symmetric Add and Sub operations.
409 CompositeNode *identifyAdditions(AddendList &RealAddends,
410 AddendList &ImagAddends,
411 std::optional<FastMathFlags> Flags,
412 CompositeNode *Accumulator);
413
414 /// Extract one addend whose real and imaginary parts are both positive.
415 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
416 AddendList &ImagAddends);
417
418 /// Determine if sum of multiplications of complex numbers can be formed from
419 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
420 /// to it. Return nullptr if it is not possible to construct a complex number.
421 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
422 SmallVectorImpl<Product> &ImagMuls,
423 CompositeNode *Accumulator);
424
425 /// Go through pairs of multiplication (one Real and one Imag) and find all
426 /// possible candidates for partial multiplication and put them into \p
427 /// Candidates. Returns true if every Product has a pair with a common operand.
428 bool collectPartialMuls(ArrayRef<Product> RealMuls,
429 ArrayRef<Product> ImagMuls,
430 SmallVectorImpl<PartialMulCandidate> &Candidates);
431
432 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
433 /// the order of complex computation operations may be significantly altered,
434 /// and the real and imaginary parts may not be executed in parallel. This
435 /// function takes this into consideration and employs a more general approach
436 /// to identify complex computations. Initially, it gathers all the addends
437 /// and multiplicands and then constructs a complex expression from them.
438 CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
439
440 CompositeNode *identifyRoot(Instruction *I);
441
442 /// Identifies the Deinterleave operation applied to a vector containing
443 /// complex numbers. There are two ways to represent the Deinterleave
444 /// operation:
445 /// * Using two shufflevectors with even indices for the real instruction and
446 /// odd indices for the imaginary instruction (only for fixed-width vectors)
447 /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
448 /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
449 /// 2.
450 CompositeNode *identifyDeinterleave(ComplexValues &Vals);
451
452 /// Identifies the operation that represents a complex number repeated in a
453 /// Splat vector. There are two possible types of splats: ConstantExpr with
454 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
455 /// initialization mask with all values set to zero.
456 CompositeNode *identifySplat(ComplexValues &Vals);
457
458 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
459
460 /// Identifies SelectInsts in a loop that has reduction with predication masks
461 /// and/or predicated tail folding
462 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
463
464 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
465
466 /// Complete IR modifications after producing new reduction operation:
467 /// * Populate the PHINode generated for
468 /// ComplexDeinterleavingOperation::ReductionPHI
469 /// * Deinterleave the final value outside of the loop and repurpose original
470 /// reduction users
471 void processReductionOperation(Value *OperationReplacement,
472 CompositeNode *Node);
473 void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
474
475public:
476 void dump() { dump(dbgs()); }
477 void dump(raw_ostream &OS) {
478 for (const auto &Node : CompositeNodes)
479 Node->dump(OS);
480 }
481
482 /// Returns false if the deinterleaving operation should be cancelled for the
483 /// current graph.
484 bool identifyNodes(Instruction *RootI);
485
486 /// In case \p B is a one-block loop, this function seeks potential reductions
487 /// and populates ReductionInfo. Returns true if any reductions were
488 /// identified.
489 bool collectPotentialReductions(BasicBlock *B);
490
491 void identifyReductionNodes();
492
493 /// Check that every instruction, from the roots to the leaves, has internal
494 /// uses.
495 bool checkNodes();
496
497 /// Perform the actual replacement of the underlying instruction graph.
498 void replaceNodes();
499};
500
501class ComplexDeinterleaving {
502public:
503 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
504 : TL(tl), TLI(tli) {}
505 bool runOnFunction(Function &F);
506
507private:
508 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
509
510 const TargetLowering *TL = nullptr;
511 const TargetLibraryInfo *TLI = nullptr;
512};
513
514} // namespace
515
516char ComplexDeinterleavingLegacyPass::ID = 0;
517
518INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
519 "Complex Deinterleaving", false, false)
520INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
521 "Complex Deinterleaving", false, false)
522
525 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
526 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
527 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
528 return PreservedAnalyses::all();
529
532 return PA;
533}
534
536 return new ComplexDeinterleavingLegacyPass(TM);
537}
538
539bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
540 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
541 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
542 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
543}
544
545bool ComplexDeinterleaving::runOnFunction(Function &F) {
548 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
549 return false;
550 }
551
554 dbgs() << "Complex deinterleaving has been disabled, target does "
555 "not support lowering of complex number operations.\n");
556 return false;
557 }
558
559 bool Changed = false;
560 for (auto &B : F)
561 Changed |= evaluateBasicBlock(&B, 2);
562
563 // TODO: Permit changes for both interleave factors in the same function.
564 if (!Changed) {
565 for (auto &B : F)
566 Changed |= evaluateBasicBlock(&B, 4);
567 }
568
569 // TODO: We can also support interleave factors of 6 and 8 if needed.
570
571 return Changed;
572}
573
575 // If the size is not even, it's not an interleaving mask
576 if ((Mask.size() & 1))
577 return false;
578
579 int HalfNumElements = Mask.size() / 2;
580 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
581 int MaskIdx = Idx * 2;
582 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
583 return false;
584 }
585
586 return true;
587}
588
590 int Offset = Mask[0];
591 int HalfNumElements = Mask.size() / 2;
592
593 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
594 if (Mask[Idx] != (Idx * 2) + Offset)
595 return false;
596 }
597
598 return true;
599}
600
601bool isNeg(Value *V) {
602 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
603}
604
606 assert(isNeg(V));
607 auto *I = cast<Instruction>(V);
608 if (I->getOpcode() == Instruction::FNeg)
609 return I->getOperand(0);
610
611 return I->getOperand(1);
612}
613
614bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
615 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
616 if (Graph.collectPotentialReductions(B))
617 Graph.identifyReductionNodes();
618
619 for (auto &I : *B)
620 Graph.identifyNodes(&I);
621
622 if (Graph.checkNodes()) {
623 Graph.replaceNodes();
624 return true;
625 }
626
627 return false;
628}
629
630ComplexDeinterleavingGraph::CompositeNode *
631ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
632 Instruction *Real, Instruction *Imag,
633 std::pair<Value *, Value *> &PartialMatch) {
634 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
635 << "\n");
636
637 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
638 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
639 return nullptr;
640 }
641
642 if ((Real->getOpcode() != Instruction::FMul &&
643 Real->getOpcode() != Instruction::Mul) ||
644 (Imag->getOpcode() != Instruction::FMul &&
645 Imag->getOpcode() != Instruction::Mul)) {
647 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
648 return nullptr;
649 }
650
651 Value *R0 = Real->getOperand(0);
652 Value *R1 = Real->getOperand(1);
653 Value *I0 = Imag->getOperand(0);
654 Value *I1 = Imag->getOperand(1);
655
656 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
657 // rotations and use the operand.
658 unsigned Negs = 0;
659 Value *Op;
660 if (match(R0, m_Neg(m_Value(Op)))) {
661 Negs |= 1;
662 R0 = Op;
663 } else if (match(R1, m_Neg(m_Value(Op)))) {
664 Negs |= 1;
665 R1 = Op;
666 }
667
668 if (isNeg(I0)) {
669 Negs |= 2;
670 Negs ^= 1;
671 I0 = Op;
672 } else if (match(I1, m_Neg(m_Value(Op)))) {
673 Negs |= 2;
674 Negs ^= 1;
675 I1 = Op;
676 }
677
679
680 Value *CommonOperand;
681 Value *UncommonRealOp;
682 Value *UncommonImagOp;
683
684 if (R0 == I0 || R0 == I1) {
685 CommonOperand = R0;
686 UncommonRealOp = R1;
687 } else if (R1 == I0 || R1 == I1) {
688 CommonOperand = R1;
689 UncommonRealOp = R0;
690 } else {
691 LLVM_DEBUG(dbgs() << " - No equal operand\n");
692 return nullptr;
693 }
694
695 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
696 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
697 Rotation == ComplexDeinterleavingRotation::Rotation_270)
698 std::swap(UncommonRealOp, UncommonImagOp);
699
700 // Between identifyPartialMul and here we need to have found a complete valid
701 // pair from the CommonOperand of each part.
702 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
703 Rotation == ComplexDeinterleavingRotation::Rotation_180)
704 PartialMatch.first = CommonOperand;
705 else
706 PartialMatch.second = CommonOperand;
707
708 if (!PartialMatch.first || !PartialMatch.second) {
709 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
710 return nullptr;
711 }
712
713 CompositeNode *CommonNode =
714 identifyNode(PartialMatch.first, PartialMatch.second);
715 if (!CommonNode) {
716 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
717 return nullptr;
718 }
719
720 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
721 if (!UncommonNode) {
722 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
723 return nullptr;
724 }
725
726 CompositeNode *Node = prepareCompositeNode(
727 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
728 Node->Rotation = Rotation;
729 Node->addOperand(CommonNode);
730 Node->addOperand(UncommonNode);
731 return submitCompositeNode(Node);
732}
733
734ComplexDeinterleavingGraph::CompositeNode *
735ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
736 Instruction *Imag) {
737 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
738 << "\n");
739
740 // Determine rotation
741 auto IsAdd = [](unsigned Op) {
742 return Op == Instruction::FAdd || Op == Instruction::Add;
743 };
744 auto IsSub = [](unsigned Op) {
745 return Op == Instruction::FSub || Op == Instruction::Sub;
746 };
748 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
749 Rotation = ComplexDeinterleavingRotation::Rotation_0;
750 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
751 Rotation = ComplexDeinterleavingRotation::Rotation_90;
752 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
753 Rotation = ComplexDeinterleavingRotation::Rotation_180;
754 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
755 Rotation = ComplexDeinterleavingRotation::Rotation_270;
756 else {
757 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
758 return nullptr;
759 }
760
761 if (isa<FPMathOperator>(Real) &&
762 (!Real->getFastMathFlags().allowContract() ||
763 !Imag->getFastMathFlags().allowContract())) {
764 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
765 return nullptr;
766 }
767
768 Value *CR = Real->getOperand(0);
769 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
770 if (!RealMulI)
771 return nullptr;
772 Value *CI = Imag->getOperand(0);
773 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
774 if (!ImagMulI)
775 return nullptr;
776
777 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
778 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
779 return nullptr;
780 }
781
782 Value *R0 = RealMulI->getOperand(0);
783 Value *R1 = RealMulI->getOperand(1);
784 Value *I0 = ImagMulI->getOperand(0);
785 Value *I1 = ImagMulI->getOperand(1);
786
787 Value *CommonOperand;
788 Value *UncommonRealOp;
789 Value *UncommonImagOp;
790
791 if (R0 == I0 || R0 == I1) {
792 CommonOperand = R0;
793 UncommonRealOp = R1;
794 } else if (R1 == I0 || R1 == I1) {
795 CommonOperand = R1;
796 UncommonRealOp = R0;
797 } else {
798 LLVM_DEBUG(dbgs() << " - No equal operand\n");
799 return nullptr;
800 }
801
802 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
803 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
804 Rotation == ComplexDeinterleavingRotation::Rotation_270)
805 std::swap(UncommonRealOp, UncommonImagOp);
806
807 std::pair<Value *, Value *> PartialMatch(
808 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
809 Rotation == ComplexDeinterleavingRotation::Rotation_180)
810 ? CommonOperand
811 : nullptr,
812 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
813 Rotation == ComplexDeinterleavingRotation::Rotation_270)
814 ? CommonOperand
815 : nullptr);
816
817 auto *CRInst = dyn_cast<Instruction>(CR);
818 auto *CIInst = dyn_cast<Instruction>(CI);
819
820 if (!CRInst || !CIInst) {
821 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
822 return nullptr;
823 }
824
825 CompositeNode *CNode =
826 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
827 if (!CNode) {
828 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
829 return nullptr;
830 }
831
832 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
833 if (!UncommonRes) {
834 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
835 return nullptr;
836 }
837
838 assert(PartialMatch.first && PartialMatch.second);
839 CompositeNode *CommonRes =
840 identifyNode(PartialMatch.first, PartialMatch.second);
841 if (!CommonRes) {
842 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
843 return nullptr;
844 }
845
846 CompositeNode *Node = prepareCompositeNode(
847 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
848 Node->Rotation = Rotation;
849 Node->addOperand(CommonRes);
850 Node->addOperand(UncommonRes);
851 Node->addOperand(CNode);
852 return submitCompositeNode(Node);
853}
854
855ComplexDeinterleavingGraph::CompositeNode *
856ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
857 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
858
859 // Determine rotation
861 if ((Real->getOpcode() == Instruction::FSub &&
862 Imag->getOpcode() == Instruction::FAdd) ||
863 (Real->getOpcode() == Instruction::Sub &&
864 Imag->getOpcode() == Instruction::Add))
865 Rotation = ComplexDeinterleavingRotation::Rotation_90;
866 else if ((Real->getOpcode() == Instruction::FAdd &&
867 Imag->getOpcode() == Instruction::FSub) ||
868 (Real->getOpcode() == Instruction::Add &&
869 Imag->getOpcode() == Instruction::Sub))
870 Rotation = ComplexDeinterleavingRotation::Rotation_270;
871 else {
872 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
873 return nullptr;
874 }
875
876 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
877 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
878 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
879 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
880
881 if (!AR || !AI || !BR || !BI) {
882 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
883 return nullptr;
884 }
885
886 CompositeNode *ResA = identifyNode(AR, AI);
887 if (!ResA) {
888 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
889 return nullptr;
890 }
891 CompositeNode *ResB = identifyNode(BR, BI);
892 if (!ResB) {
893 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
894 return nullptr;
895 }
896
897 CompositeNode *Node =
898 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
899 Node->Rotation = Rotation;
900 Node->addOperand(ResA);
901 Node->addOperand(ResB);
902 return submitCompositeNode(Node);
903}
904
906 unsigned OpcA = A->getOpcode();
907 unsigned OpcB = B->getOpcode();
908
909 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
910 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
911 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
912 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
913}
914
916 auto Pattern =
918
919 return match(A, Pattern) && match(B, Pattern);
920}
921
923 switch (I->getOpcode()) {
924 case Instruction::FAdd:
925 case Instruction::FSub:
926 case Instruction::FMul:
927 case Instruction::FNeg:
928 case Instruction::Add:
929 case Instruction::Sub:
930 case Instruction::Mul:
931 return true;
932 default:
933 return false;
934 }
935}
936
937ComplexDeinterleavingGraph::CompositeNode *
938ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
939 auto *FirstReal = cast<Instruction>(Vals[0].Real);
940 unsigned FirstOpc = FirstReal->getOpcode();
941 for (auto &V : Vals) {
942 auto *Real = cast<Instruction>(V.Real);
943 auto *Imag = cast<Instruction>(V.Imag);
944 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
945 return nullptr;
946
949 return nullptr;
950
951 if (isa<FPMathOperator>(FirstReal))
952 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
953 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
954 return nullptr;
955 }
956
957 ComplexValues OpVals;
958 for (auto &V : Vals) {
959 auto *R0 = cast<Instruction>(V.Real)->getOperand(0);
960 auto *I0 = cast<Instruction>(V.Imag)->getOperand(0);
961 OpVals.push_back({R0, I0});
962 }
963
964 CompositeNode *Op0 = identifyNode(OpVals);
965 CompositeNode *Op1 = nullptr;
966 if (Op0 == nullptr)
967 return nullptr;
968
969 if (FirstReal->isBinaryOp()) {
970 OpVals.clear();
971 for (auto &V : Vals) {
972 auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
973 auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
974 OpVals.push_back({R1, I1});
975 }
976 Op1 = identifyNode(OpVals);
977 if (Op1 == nullptr)
978 return nullptr;
979 }
980
981 auto Node =
982 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
983 Node->Opcode = FirstReal->getOpcode();
984 if (isa<FPMathOperator>(FirstReal))
985 Node->Flags = FirstReal->getFastMathFlags();
986
987 Node->addOperand(Op0);
988 if (FirstReal->isBinaryOp())
989 Node->addOperand(Op1);
990
991 return submitCompositeNode(Node);
992}
993
994ComplexDeinterleavingGraph::CompositeNode *
995ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
997 ComplexDeinterleavingOperation::CDot, V->getType())) {
998 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
999 "operation CDot with the type "
1000 << *V->getType() << "\n");
1001 return nullptr;
1002 }
1003
1004 auto *Inst = cast<Instruction>(V);
1005 auto *RealUser = cast<Instruction>(*Inst->user_begin());
1006
1007 CompositeNode *CN =
1008 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
1009
1010 CompositeNode *ANode = nullptr;
1011
1012 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1013
1014 Value *AReal = nullptr;
1015 Value *AImag = nullptr;
1016 Value *BReal = nullptr;
1017 Value *BImag = nullptr;
1018 Value *Phi = nullptr;
1019
1020 auto UnwrapCast = [](Value *V) -> Value * {
1021 if (auto *CI = dyn_cast<CastInst>(V))
1022 return CI->getOperand(0);
1023 return V;
1024 };
1025
1026 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1028 m_Mul(m_Value(BReal), m_Value(AReal))),
1029 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
1030
1031 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1033 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
1034 m_Mul(m_Value(BImag), m_Value(AReal)));
1035
1036 if (match(Inst, PatternRot0)) {
1037 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1038 } else if (match(Inst, PatternRot270)) {
1039 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1040 } else {
1041 Value *A0, *A1;
1042 // The rotations 90 and 180 share the same operation pattern, so inspect the
1043 // order of the operands, identifying where the real and imaginary
1044 // components of A go, to discern between the aforementioned rotations.
1045 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1047 m_Mul(m_Value(BReal), m_Value(A0))),
1048 m_Mul(m_Value(BImag), m_Value(A1)));
1049
1050 if (!match(Inst, PatternRot90Rot180))
1051 return nullptr;
1052
1053 A0 = UnwrapCast(A0);
1054 A1 = UnwrapCast(A1);
1055
1056 // Test if A0 is real/A1 is imag
1057 ANode = identifyNode(A0, A1);
1058 if (!ANode) {
1059 // Test if A0 is imag/A1 is real
1060 ANode = identifyNode(A1, A0);
1061 // Unable to identify operand components, thus unable to identify rotation
1062 if (!ANode)
1063 return nullptr;
1064 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1065 AReal = A1;
1066 AImag = A0;
1067 } else {
1068 AReal = A0;
1069 AImag = A1;
1070 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1071 }
1072 }
1073
1074 AReal = UnwrapCast(AReal);
1075 AImag = UnwrapCast(AImag);
1076 BReal = UnwrapCast(BReal);
1077 BImag = UnwrapCast(BImag);
1078
1079 VectorType *VTy = cast<VectorType>(V->getType());
1080 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1081 if (AReal->getType() != ExpectedOperandTy)
1082 return nullptr;
1083 if (AImag->getType() != ExpectedOperandTy)
1084 return nullptr;
1085 if (BReal->getType() != ExpectedOperandTy)
1086 return nullptr;
1087 if (BImag->getType() != ExpectedOperandTy)
1088 return nullptr;
1089
1090 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1091 return nullptr;
1092
1093 CompositeNode *Node = identifyNode(AReal, AImag);
1094
1095 // In the case that a node was identified to figure out the rotation, ensure
1096 // that trying to identify a node with AReal and AImag post-unwrap results in
1097 // the same node
1098 if (ANode && Node != ANode) {
1099 LLVM_DEBUG(
1100 dbgs()
1101 << "Identified node is different from previously identified node. "
1102 "Unable to confidently generate a complex operation node\n");
1103 return nullptr;
1104 }
1105
1106 CN->addOperand(Node);
1107 CN->addOperand(identifyNode(BReal, BImag));
1108 CN->addOperand(identifyNode(Phi, RealUser));
1109
1110 return submitCompositeNode(CN);
1111}
1112
1113ComplexDeinterleavingGraph::CompositeNode *
1114ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1115 // Partial reductions don't support non-vector types, so check these first
1116 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1117 return nullptr;
1118
1119 if (!R->hasUseList() || !I->hasUseList())
1120 return nullptr;
1121
1122 auto CommonUser =
1123 findCommonBetweenCollections<Value *>(R->users(), I->users());
1124 if (!CommonUser)
1125 return nullptr;
1126
1127 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1128 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1129 return nullptr;
1130
1131 if (CompositeNode *CN = identifyDotProduct(IInst))
1132 return CN;
1133
1134 return nullptr;
1135}
1136
1137ComplexDeinterleavingGraph::CompositeNode *
1138ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1139 auto It = CachedResult.find(Vals);
1140 if (It != CachedResult.end()) {
1141 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1142 return It->second;
1143 }
1144
1145 if (Vals.size() == 1) {
1146 assert(Factor == 2 && "Can only handle interleave factors of 2");
1147 Value *R = Vals[0].Real;
1148 Value *I = Vals[0].Imag;
1149 if (CompositeNode *CN = identifyPartialReduction(R, I))
1150 return CN;
1151 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1152 if (!IsReduction && R->getType() != I->getType())
1153 return nullptr;
1154 }
1155
1156 if (CompositeNode *CN = identifySplat(Vals))
1157 return CN;
1158
1159 for (auto &V : Vals) {
1160 auto *Real = dyn_cast<Instruction>(V.Real);
1161 auto *Imag = dyn_cast<Instruction>(V.Imag);
1162 if (!Real || !Imag)
1163 return nullptr;
1164 }
1165
1166 if (CompositeNode *CN = identifyDeinterleave(Vals))
1167 return CN;
1168
1169 if (Vals.size() == 1) {
1170 assert(Factor == 2 && "Can only handle interleave factors of 2");
1171 auto *Real = dyn_cast<Instruction>(Vals[0].Real);
1172 auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
1173 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1174 return CN;
1175
1176 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1177 return CN;
1178
1179 auto *VTy = cast<VectorType>(Real->getType());
1180 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1181
1182 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1183 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1184 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1185 ComplexDeinterleavingOperation::CAdd, NewVTy);
1186
1187 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1188 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1189 return CN;
1190 }
1191
1192 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1193 if (CompositeNode *CN = identifyAdd(Real, Imag))
1194 return CN;
1195 }
1196
1197 if (HasCMulSupport && HasCAddSupport) {
1198 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1199 return CN;
1200 }
1201 }
1202 }
1203
1204 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1205 return CN;
1206
1207 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1208 CachedResult[Vals] = nullptr;
1209 return nullptr;
1210}
1211
1212ComplexDeinterleavingGraph::CompositeNode *
1213ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1214 Instruction *Imag) {
1215 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1216 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1217 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1218 Opcode == Instruction::Sub;
1219 };
1220
1221 if (!IsOperationSupported(Real->getOpcode()) ||
1222 !IsOperationSupported(Imag->getOpcode()))
1223 return nullptr;
1224
1225 std::optional<FastMathFlags> Flags;
1226 if (isa<FPMathOperator>(Real)) {
1227 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1228 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1229 "not identical\n");
1230 return nullptr;
1231 }
1232
1233 Flags = Real->getFastMathFlags();
1234 if (!Flags->allowReassoc()) {
1235 LLVM_DEBUG(
1236 dbgs()
1237 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1238 return nullptr;
1239 }
1240 }
1241
1242 // Collect multiplications and addend instructions from the given instruction
1243 // while traversing it operands. Additionally, verify that all instructions
1244 // have the same fast math flags.
1245 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1246 AddendList &Addends) -> bool {
1247 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1248 while (!Worklist.empty()) {
1249 auto [V, IsPositive] = Worklist.pop_back_val();
1250
1252 if (!I) {
1253 Addends.emplace_back(V, IsPositive);
1254 continue;
1255 }
1256
1257 // If an instruction has more than one user, it indicates that it either
1258 // has an external user, which will be later checked by the checkNodes
1259 // function, or it is a subexpression utilized by multiple expressions. In
1260 // the latter case, we will attempt to separately identify the complex
1261 // operation from here in order to create a shared
1262 // ComplexDeinterleavingCompositeNode.
1263 if (I != Insn && I->hasNUsesOrMore(2)) {
1264 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1265 Addends.emplace_back(I, IsPositive);
1266 continue;
1267 }
1268 switch (I->getOpcode()) {
1269 case Instruction::FAdd:
1270 case Instruction::Add:
1271 Worklist.emplace_back(I->getOperand(1), IsPositive);
1272 Worklist.emplace_back(I->getOperand(0), IsPositive);
1273 break;
1274 case Instruction::FSub:
1275 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1276 Worklist.emplace_back(I->getOperand(0), IsPositive);
1277 break;
1278 case Instruction::Sub:
1279 if (isNeg(I)) {
1280 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1281 } else {
1282 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1283 Worklist.emplace_back(I->getOperand(0), IsPositive);
1284 }
1285 break;
1286 case Instruction::FMul:
1287 case Instruction::Mul: {
1288 Value *A, *B;
1289 if (isNeg(I->getOperand(0))) {
1290 A = getNegOperand(I->getOperand(0));
1291 IsPositive = !IsPositive;
1292 } else {
1293 A = I->getOperand(0);
1294 }
1295
1296 if (isNeg(I->getOperand(1))) {
1297 B = getNegOperand(I->getOperand(1));
1298 IsPositive = !IsPositive;
1299 } else {
1300 B = I->getOperand(1);
1301 }
1302 Muls.push_back(Product{A, B, IsPositive});
1303 break;
1304 }
1305 case Instruction::FNeg:
1306 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1307 break;
1308 default:
1309 Addends.emplace_back(I, IsPositive);
1310 continue;
1311 }
1312
1313 if (Flags && I->getFastMathFlags() != *Flags) {
1314 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1315 "inconsistent with the root instructions' flags: "
1316 << *I << "\n");
1317 return false;
1318 }
1319 }
1320 return true;
1321 };
1322
1323 SmallVector<Product> RealMuls, ImagMuls;
1324 AddendList RealAddends, ImagAddends;
1325 if (!Collect(Real, RealMuls, RealAddends) ||
1326 !Collect(Imag, ImagMuls, ImagAddends))
1327 return nullptr;
1328
1329 if (RealAddends.size() != ImagAddends.size())
1330 return nullptr;
1331
1332 CompositeNode *FinalNode = nullptr;
1333 if (!RealMuls.empty() || !ImagMuls.empty()) {
1334 // If there are multiplicands, extract positive addend and use it as an
1335 // accumulator
1336 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1337 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1338 if (!FinalNode)
1339 return nullptr;
1340 }
1341
1342 // Identify and process remaining additions
1343 if (!RealAddends.empty() || !ImagAddends.empty()) {
1344 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1345 if (!FinalNode)
1346 return nullptr;
1347 }
1348 assert(FinalNode && "FinalNode can not be nullptr here");
1349 assert(FinalNode->Vals.size() == 1);
1350 // Set the Real and Imag fields of the final node and submit it
1351 FinalNode->Vals[0].Real = Real;
1352 FinalNode->Vals[0].Imag = Imag;
1353 submitCompositeNode(FinalNode);
1354 return FinalNode;
1355}
1356
1357bool ComplexDeinterleavingGraph::collectPartialMuls(
1358 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1359 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1360 // Helper function to extract a common operand from two products
1361 auto FindCommonInstruction = [](const Product &Real,
1362 const Product &Imag) -> Value * {
1363 if (Real.Multiplicand == Imag.Multiplicand ||
1364 Real.Multiplicand == Imag.Multiplier)
1365 return Real.Multiplicand;
1366
1367 if (Real.Multiplier == Imag.Multiplicand ||
1368 Real.Multiplier == Imag.Multiplier)
1369 return Real.Multiplier;
1370
1371 return nullptr;
1372 };
1373
1374 // Iterating over real and imaginary multiplications to find common operands
1375 // If a common operand is found, a partial multiplication candidate is created
1376 // and added to the candidates vector The function returns false if no common
1377 // operands are found for any product
1378 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1379 bool FoundCommon = false;
1380 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1381 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1382 if (!Common)
1383 continue;
1384
1385 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1386 : RealMuls[i].Multiplicand;
1387 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1388 : ImagMuls[j].Multiplicand;
1389
1390 auto Node = identifyNode(A, B);
1391 if (Node) {
1392 FoundCommon = true;
1393 PartialMulCandidates.push_back({Common, Node, i, j, false});
1394 }
1395
1396 Node = identifyNode(B, A);
1397 if (Node) {
1398 FoundCommon = true;
1399 PartialMulCandidates.push_back({Common, Node, i, j, true});
1400 }
1401 }
1402 if (!FoundCommon)
1403 return false;
1404 }
1405 return true;
1406}
1407
1408ComplexDeinterleavingGraph::CompositeNode *
1409ComplexDeinterleavingGraph::identifyMultiplications(
1410 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1411 CompositeNode *Accumulator = nullptr) {
1412 if (RealMuls.size() != ImagMuls.size())
1413 return nullptr;
1414
1416 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1417 return nullptr;
1418
1419 // Map to store common instruction to node pointers
1420 DenseMap<Value *, CompositeNode *> CommonToNode;
1421 SmallVector<bool> Processed(Info.size(), false);
1422 for (unsigned I = 0; I < Info.size(); ++I) {
1423 if (Processed[I])
1424 continue;
1425
1426 PartialMulCandidate &InfoA = Info[I];
1427 for (unsigned J = I + 1; J < Info.size(); ++J) {
1428 if (Processed[J])
1429 continue;
1430
1431 PartialMulCandidate &InfoB = Info[J];
1432 auto *InfoReal = &InfoA;
1433 auto *InfoImag = &InfoB;
1434
1435 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1436 if (!NodeFromCommon) {
1437 std::swap(InfoReal, InfoImag);
1438 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1439 }
1440 if (!NodeFromCommon)
1441 continue;
1442
1443 CommonToNode[InfoReal->Common] = NodeFromCommon;
1444 CommonToNode[InfoImag->Common] = NodeFromCommon;
1445 Processed[I] = true;
1446 Processed[J] = true;
1447 }
1448 }
1449
1450 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1451 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1452 CompositeNode *Result = Accumulator;
1453 for (auto &PMI : Info) {
1454 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1455 continue;
1456
1457 auto It = CommonToNode.find(PMI.Common);
1458 // TODO: Process independent complex multiplications. Cases like this:
1459 // A.real() * B where both A and B are complex numbers.
1460 if (It == CommonToNode.end()) {
1461 LLVM_DEBUG({
1462 dbgs() << "Unprocessed independent partial multiplication:\n";
1463 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1464 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1465 << " multiplied by " << *Mul->Multiplicand << "\n";
1466 });
1467 return nullptr;
1468 }
1469
1470 auto &RealMul = RealMuls[PMI.RealIdx];
1471 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1472
1473 auto NodeA = It->second;
1474 auto NodeB = PMI.Node;
1475 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1476 // The following table illustrates the relationship between multiplications
1477 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1478 // can see:
1479 //
1480 // Rotation | Real | Imag |
1481 // ---------+--------+--------+
1482 // 0 | x * u | x * v |
1483 // 90 | -y * v | y * u |
1484 // 180 | -x * u | -x * v |
1485 // 270 | y * v | -y * u |
1486 //
1487 // Check if the candidate can indeed be represented by partial
1488 // multiplication
1489 // TODO: Add support for multiplication by complex one
1490 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1491 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1492 continue;
1493
1494 // Determine the rotation based on the multiplications
1496 if (IsMultiplicandReal) {
1497 // Detect 0 and 180 degrees rotation
1498 if (RealMul.IsPositive && ImagMul.IsPositive)
1500 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1502 else
1503 continue;
1504
1505 } else {
1506 // Detect 90 and 270 degrees rotation
1507 if (!RealMul.IsPositive && ImagMul.IsPositive)
1509 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1511 else
1512 continue;
1513 }
1514
1515 LLVM_DEBUG({
1516 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1517 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1518 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1519 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1520 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1521 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1522 });
1523
1524 CompositeNode *NodeMul = prepareCompositeNode(
1525 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1526 NodeMul->Rotation = Rotation;
1527 NodeMul->addOperand(NodeA);
1528 NodeMul->addOperand(NodeB);
1529 if (Result)
1530 NodeMul->addOperand(Result);
1531 submitCompositeNode(NodeMul);
1532 Result = NodeMul;
1533 ProcessedReal[PMI.RealIdx] = true;
1534 ProcessedImag[PMI.ImagIdx] = true;
1535 }
1536
1537 // Ensure all products have been processed, if not return nullptr.
1538 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1539 !all_of(ProcessedImag, [](bool V) { return V; })) {
1540
1541 // Dump debug information about which partial multiplications are not
1542 // processed.
1543 LLVM_DEBUG({
1544 dbgs() << "Unprocessed products (Real):\n";
1545 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1546 if (!ProcessedReal[i])
1547 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1548 << *RealMuls[i].Multiplier << " multiplied by "
1549 << *RealMuls[i].Multiplicand << "\n";
1550 }
1551 dbgs() << "Unprocessed products (Imag):\n";
1552 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1553 if (!ProcessedImag[i])
1554 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1555 << *ImagMuls[i].Multiplier << " multiplied by "
1556 << *ImagMuls[i].Multiplicand << "\n";
1557 }
1558 });
1559 return nullptr;
1560 }
1561
1562 return Result;
1563}
1564
1565ComplexDeinterleavingGraph::CompositeNode *
1566ComplexDeinterleavingGraph::identifyAdditions(
1567 AddendList &RealAddends, AddendList &ImagAddends,
1568 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1569 if (RealAddends.size() != ImagAddends.size())
1570 return nullptr;
1571
1572 CompositeNode *Result = nullptr;
1573 // If we have accumulator use it as first addend
1574 if (Accumulator)
1576 // Otherwise find an element with both positive real and imaginary parts.
1577 else
1578 Result = extractPositiveAddend(RealAddends, ImagAddends);
1579
1580 if (!Result)
1581 return nullptr;
1582
1583 while (!RealAddends.empty()) {
1584 auto ItR = RealAddends.begin();
1585 auto [R, IsPositiveR] = *ItR;
1586
1587 bool FoundImag = false;
1588 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1589 auto [I, IsPositiveI] = *ItI;
1591 if (IsPositiveR && IsPositiveI)
1592 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1593 else if (!IsPositiveR && IsPositiveI)
1594 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1595 else if (!IsPositiveR && !IsPositiveI)
1596 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1597 else
1598 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1599
1600 CompositeNode *AddNode = nullptr;
1601 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1602 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1603 AddNode = identifyNode(R, I);
1604 } else {
1605 AddNode = identifyNode(I, R);
1606 }
1607 if (AddNode) {
1608 LLVM_DEBUG({
1609 dbgs() << "Identified addition:\n";
1610 dbgs().indent(4) << "X: " << *R << "\n";
1611 dbgs().indent(4) << "Y: " << *I << "\n";
1612 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1613 });
1614
1615 CompositeNode *TmpNode = nullptr;
1617 TmpNode = prepareCompositeNode(
1618 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1619 if (Flags) {
1620 TmpNode->Opcode = Instruction::FAdd;
1621 TmpNode->Flags = *Flags;
1622 } else {
1623 TmpNode->Opcode = Instruction::Add;
1624 }
1625 } else if (Rotation ==
1627 TmpNode = prepareCompositeNode(
1628 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1629 if (Flags) {
1630 TmpNode->Opcode = Instruction::FSub;
1631 TmpNode->Flags = *Flags;
1632 } else {
1633 TmpNode->Opcode = Instruction::Sub;
1634 }
1635 } else {
1636 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1637 nullptr, nullptr);
1638 TmpNode->Rotation = Rotation;
1639 }
1640
1641 TmpNode->addOperand(Result);
1642 TmpNode->addOperand(AddNode);
1643 submitCompositeNode(TmpNode);
1644 Result = TmpNode;
1645 RealAddends.erase(ItR);
1646 ImagAddends.erase(ItI);
1647 FoundImag = true;
1648 break;
1649 }
1650 }
1651 if (!FoundImag)
1652 return nullptr;
1653 }
1654 return Result;
1655}
1656
1657ComplexDeinterleavingGraph::CompositeNode *
1658ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1659 AddendList &ImagAddends) {
1660 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1661 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1662 auto [R, IsPositiveR] = *ItR;
1663 auto [I, IsPositiveI] = *ItI;
1664 if (IsPositiveR && IsPositiveI) {
1665 auto Result = identifyNode(R, I);
1666 if (Result) {
1667 RealAddends.erase(ItR);
1668 ImagAddends.erase(ItI);
1669 return Result;
1670 }
1671 }
1672 }
1673 }
1674 return nullptr;
1675}
1676
1677bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1678 // This potential root instruction might already have been recognized as
1679 // reduction. Because RootToNode maps both Real and Imaginary parts to
1680 // CompositeNode we should choose only one either Real or Imag instruction to
1681 // use as an anchor for generating complex instruction.
1682 auto It = RootToNode.find(RootI);
1683 if (It != RootToNode.end()) {
1684 auto RootNode = It->second;
1685 assert(RootNode->Operation ==
1686 ComplexDeinterleavingOperation::ReductionOperation ||
1687 RootNode->Operation ==
1688 ComplexDeinterleavingOperation::ReductionSingle);
1689 assert(RootNode->Vals.size() == 1 &&
1690 "Cannot handle reductions involving multiple complex values");
1691 // Find out which part, Real or Imag, comes later, and only if we come to
1692 // the latest part, add it to OrderedRoots.
1693 auto *R = cast<Instruction>(RootNode->Vals[0].Real);
1694 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
1695 : nullptr;
1696
1697 Instruction *ReplacementAnchor;
1698 if (I)
1699 ReplacementAnchor = R->comesBefore(I) ? I : R;
1700 else
1701 ReplacementAnchor = R;
1702
1703 if (ReplacementAnchor != RootI)
1704 return false;
1705 OrderedRoots.push_back(RootI);
1706 return true;
1707 }
1708
1709 auto RootNode = identifyRoot(RootI);
1710 if (!RootNode)
1711 return false;
1712
1713 LLVM_DEBUG({
1714 Function *F = RootI->getFunction();
1715 BasicBlock *B = RootI->getParent();
1716 dbgs() << "Complex deinterleaving graph for " << F->getName()
1717 << "::" << B->getName() << ".\n";
1718 dump(dbgs());
1719 dbgs() << "\n";
1720 });
1721 RootToNode[RootI] = RootNode;
1722 OrderedRoots.push_back(RootI);
1723 return true;
1724}
1725
1726bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1727 bool FoundPotentialReduction = false;
1728 if (Factor != 2)
1729 return false;
1730
1731 auto *Br = dyn_cast<CondBrInst>(B->getTerminator());
1732 if (!Br)
1733 return false;
1734
1735 // Identify simple one-block loop
1736 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1737 return false;
1738
1739 for (auto &PHI : B->phis()) {
1740 if (PHI.getNumIncomingValues() != 2)
1741 continue;
1742
1743 if (!PHI.getType()->isVectorTy())
1744 continue;
1745
1746 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1747 if (!ReductionOp)
1748 continue;
1749
1750 // Check if final instruction is reduced outside of current block
1751 Instruction *FinalReduction = nullptr;
1752 auto NumUsers = 0u;
1753 for (auto *U : ReductionOp->users()) {
1754 ++NumUsers;
1755 if (U == &PHI)
1756 continue;
1757 FinalReduction = dyn_cast<Instruction>(U);
1758 }
1759
1760 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1761 isa<PHINode>(FinalReduction))
1762 continue;
1763
1764 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1765 BackEdge = B;
1766 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1767 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1768 Incoming = PHI.getIncomingBlock(IncomingIdx);
1769 FoundPotentialReduction = true;
1770
1771 // If the initial value of PHINode is an Instruction, consider it a leaf
1772 // value of a complex deinterleaving graph.
1773 if (auto *InitPHI =
1774 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1775 FinalInstructions.insert(InitPHI);
1776 }
1777 return FoundPotentialReduction;
1778}
1779
1780void ComplexDeinterleavingGraph::identifyReductionNodes() {
1781 assert(Factor == 2 && "Cannot handle multiple complex values");
1782
1783 SmallVector<bool> Processed(ReductionInfo.size(), false);
1784 SmallVector<Instruction *> OperationInstruction;
1785 for (auto &P : ReductionInfo)
1786 OperationInstruction.push_back(P.first);
1787
1788 // Identify a complex computation by evaluating two reduction operations that
1789 // potentially could be involved
1790 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1791 if (Processed[i])
1792 continue;
1793 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1794 if (Processed[j])
1795 continue;
1796 auto *Real = OperationInstruction[i];
1797 auto *Imag = OperationInstruction[j];
1798 if (Real->getType() != Imag->getType())
1799 continue;
1800
1801 RealPHI = ReductionInfo[Real].first;
1802 ImagPHI = ReductionInfo[Imag].first;
1803 PHIsFound = false;
1804 auto Node = identifyNode(Real, Imag);
1805 if (!Node) {
1806 std::swap(Real, Imag);
1807 std::swap(RealPHI, ImagPHI);
1808 Node = identifyNode(Real, Imag);
1809 }
1810
1811 // If a node is identified and reduction PHINode is used in the chain of
1812 // operations, mark its operation instructions as used to prevent
1813 // re-identification and attach the node to the real part
1814 if (Node && PHIsFound) {
1815 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1816 << *Real << " / " << *Imag << "\n");
1817 Processed[i] = true;
1818 Processed[j] = true;
1819 auto RootNode = prepareCompositeNode(
1820 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1821 RootNode->addOperand(Node);
1822 RootToNode[Real] = RootNode;
1823 RootToNode[Imag] = RootNode;
1824 submitCompositeNode(RootNode);
1825 break;
1826 }
1827 }
1828
1829 auto *Real = OperationInstruction[i];
1830 // We want to check that we have 2 operands, but the function attributes
1831 // being counted as operands bloats this value.
1832 if (Processed[i] || Real->getNumOperands() < 2)
1833 continue;
1834
1835 // Can only combined integer reductions at the moment.
1836 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1837 continue;
1838
1839 RealPHI = ReductionInfo[Real].first;
1840 ImagPHI = nullptr;
1841 PHIsFound = false;
1842 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1843 if (Node && PHIsFound) {
1844 LLVM_DEBUG(
1845 dbgs() << "Identified single reduction starting from instruction: "
1846 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1847
1848 // Reducing to a single vector is not supported, only permit reducing down
1849 // to scalar values.
1850 // Doing this here will leave the prior node in the graph,
1851 // however with no uses the node will be unreachable by the replacement
1852 // process. That along with the usage outside the graph should prevent the
1853 // replacement process from kicking off at all for this graph.
1854 // TODO Add support for reducing to a single vector value
1855 if (ReductionInfo[Real].second->getType()->isVectorTy())
1856 continue;
1857
1858 Processed[i] = true;
1859 auto RootNode = prepareCompositeNode(
1860 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1861 RootNode->addOperand(Node);
1862 RootToNode[Real] = RootNode;
1863 submitCompositeNode(RootNode);
1864 }
1865 }
1866
1867 RealPHI = nullptr;
1868 ImagPHI = nullptr;
1869}
1870
1871bool ComplexDeinterleavingGraph::checkNodes() {
1872 bool FoundDeinterleaveNode = false;
1873 for (CompositeNode *N : CompositeNodes) {
1874 if (!N->areOperandsValid())
1875 return false;
1876
1877 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1878 FoundDeinterleaveNode = true;
1879 }
1880
1881 // We need a deinterleave node in order to guarantee that we're working with
1882 // complex numbers.
1883 if (!FoundDeinterleaveNode) {
1884 LLVM_DEBUG(
1885 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1886 "guarantee safety during graph transformation.\n");
1887 return false;
1888 }
1889
1890 // Collect all instructions from roots to leaves
1891 SmallPtrSet<Instruction *, 16> AllInstructions;
1892 SmallVector<Instruction *, 8> Worklist;
1893 for (auto &Pair : RootToNode)
1894 Worklist.push_back(Pair.first);
1895
1896 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1897 // chains
1898 while (!Worklist.empty()) {
1899 auto *I = Worklist.pop_back_val();
1900
1901 if (!AllInstructions.insert(I).second)
1902 continue;
1903
1904 for (Value *Op : I->operands()) {
1905 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1906 if (!FinalInstructions.count(I))
1907 Worklist.emplace_back(OpI);
1908 }
1909 }
1910 }
1911
1912 // Find instructions that have users outside of chain
1913 for (auto *I : AllInstructions) {
1914 // Skip root nodes
1915 if (RootToNode.count(I))
1916 continue;
1917
1918 for (User *U : I->users()) {
1919 if (AllInstructions.count(cast<Instruction>(U)))
1920 continue;
1921
1922 // Found an instruction that is not used by XCMLA/XCADD chain
1923 Worklist.emplace_back(I);
1924 break;
1925 }
1926 }
1927
1928 // If any instructions are found to be used outside, find and remove roots
1929 // that somehow connect to those instructions.
1930 SmallPtrSet<Instruction *, 16> Visited;
1931 while (!Worklist.empty()) {
1932 auto *I = Worklist.pop_back_val();
1933 if (!Visited.insert(I).second)
1934 continue;
1935
1936 // Found an impacted root node. Removing it from the nodes to be
1937 // deinterleaved
1938 if (RootToNode.count(I)) {
1939 LLVM_DEBUG(dbgs() << "Instruction " << *I
1940 << " could be deinterleaved but its chain of complex "
1941 "operations have an outside user\n");
1942 RootToNode.erase(I);
1943 }
1944
1945 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1946 continue;
1947
1948 for (User *U : I->users())
1949 Worklist.emplace_back(cast<Instruction>(U));
1950
1951 for (Value *Op : I->operands()) {
1952 if (auto *OpI = dyn_cast<Instruction>(Op))
1953 Worklist.emplace_back(OpI);
1954 }
1955 }
1956 return !RootToNode.empty();
1957}
1958
1959ComplexDeinterleavingGraph::CompositeNode *
1960ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1961 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1963 Intrinsic->getIntrinsicID())
1964 return nullptr;
1965
1966 ComplexValues Vals;
1967 for (unsigned I = 0; I < Factor; I += 2) {
1968 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(I));
1969 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(I + 1));
1970 if (!Real || !Imag)
1971 return nullptr;
1972 Vals.push_back({Real, Imag});
1973 }
1974
1975 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1976 if (!Node1)
1977 return nullptr;
1978 return Node1;
1979 }
1980
1981 // TODO: We could also add support for fixed-width interleave factors of 4
1982 // and above, but currently for symmetric operations the interleaves and
1983 // deinterleaves are already removed by VectorCombine. If we extend this to
1984 // permit complex multiplications, reductions, etc. then we should also add
1985 // support for fixed-width here.
1986 if (Factor != 2)
1987 return nullptr;
1988
1989 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1990 if (!SVI)
1991 return nullptr;
1992
1993 // Look for a shufflevector that takes separate vectors of the real and
1994 // imaginary components and recombines them into a single vector.
1995 if (!isInterleavingMask(SVI->getShuffleMask()))
1996 return nullptr;
1997
1998 Instruction *Real;
1999 Instruction *Imag;
2000 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
2001 return nullptr;
2002
2003 return identifyNode(Real, Imag);
2004}
2005
2006ComplexDeinterleavingGraph::CompositeNode *
2007ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2008 Instruction *II = nullptr;
2009
2010 // Must be at least one complex value.
2011 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2012 Instruction *ExpectedInsn) -> ExtractValueInst * {
2013 auto *EVI = dyn_cast<ExtractValueInst>(V);
2014 if (!EVI || EVI->getNumIndices() != 1 ||
2015 EVI->getIndices()[0] != ExpectedIdx ||
2016 !isa<Instruction>(EVI->getAggregateOperand()) ||
2017 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2018 return nullptr;
2019 return EVI;
2020 };
2021
2022 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2023 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2024 if (RealEVI && Idx == 0)
2026 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2027 II = nullptr;
2028 break;
2029 }
2030 }
2031
2032 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(II)) {
2033 if (IntrinsicII->getIntrinsicID() !=
2035 return nullptr;
2036
2037 // The remaining should match too.
2038 CompositeNode *PlaceholderNode = prepareCompositeNode(
2040 PlaceholderNode->ReplacementNode = II->getOperand(0);
2041 for (auto &V : Vals) {
2042 FinalInstructions.insert(cast<Instruction>(V.Real));
2043 FinalInstructions.insert(cast<Instruction>(V.Imag));
2044 }
2045 return submitCompositeNode(PlaceholderNode);
2046 }
2047
2048 if (Vals.size() != 1)
2049 return nullptr;
2050
2051 Value *Real = Vals[0].Real;
2052 Value *Imag = Vals[0].Imag;
2053 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
2054 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
2055 if (!RealShuffle || !ImagShuffle) {
2056 if (RealShuffle || ImagShuffle)
2057 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2058 return nullptr;
2059 }
2060
2061 Value *RealOp1 = RealShuffle->getOperand(1);
2062 if (!isa<UndefValue>(RealOp1) && !match(RealOp1, m_Zero())) {
2063 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2064 return nullptr;
2065 }
2066 Value *ImagOp1 = ImagShuffle->getOperand(1);
2067 if (!isa<UndefValue>(ImagOp1) && !match(ImagOp1, m_Zero())) {
2068 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2069 return nullptr;
2070 }
2071
2072 Value *RealOp0 = RealShuffle->getOperand(0);
2073 Value *ImagOp0 = ImagShuffle->getOperand(0);
2074
2075 if (RealOp0 != ImagOp0) {
2076 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2077 return nullptr;
2078 }
2079
2080 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2081 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2082 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
2083 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2084 return nullptr;
2085 }
2086
2087 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2088 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2089 return nullptr;
2090 }
2091
2092 // Type checking, the shuffle type should be a vector type of the same
2093 // scalar type, but half the size
2094 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2095 Value *Op = Shuffle->getOperand(0);
2096 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
2097 auto *OpTy = cast<FixedVectorType>(Op->getType());
2098
2099 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2100 return false;
2101 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2102 return false;
2103
2104 return true;
2105 };
2106
2107 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2108 if (!CheckType(Shuffle))
2109 return false;
2110
2111 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2112 int Last = *Mask.rbegin();
2113
2114 Value *Op = Shuffle->getOperand(0);
2115 auto *OpTy = cast<FixedVectorType>(Op->getType());
2116 int NumElements = OpTy->getNumElements();
2117
2118 // Ensure that the deinterleaving shuffle only pulls from the first
2119 // shuffle operand.
2120 return Last < NumElements;
2121 };
2122
2123 if (RealShuffle->getType() != ImagShuffle->getType()) {
2124 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2125 return nullptr;
2126 }
2127 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2128 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2129 return nullptr;
2130 }
2131 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2132 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2133 return nullptr;
2134 }
2135
2136 CompositeNode *PlaceholderNode =
2138 RealShuffle, ImagShuffle);
2139 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2140 FinalInstructions.insert(RealShuffle);
2141 FinalInstructions.insert(ImagShuffle);
2142 return submitCompositeNode(PlaceholderNode);
2143}
2144
2145ComplexDeinterleavingGraph::CompositeNode *
2146ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2147 auto IsSplat = [](Value *V) -> bool {
2148 // Fixed-width vector with constants
2150 return true;
2151
2152 if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2153 return isa<VectorType>(V->getType());
2154
2155 VectorType *VTy;
2156 ArrayRef<int> Mask;
2157 // Splats are represented differently depending on whether the repeated
2158 // value is a constant or an Instruction
2159 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2160 if (Const->getOpcode() != Instruction::ShuffleVector)
2161 return false;
2162 VTy = cast<VectorType>(Const->getType());
2163 Mask = Const->getShuffleMask();
2164 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2165 VTy = Shuf->getType();
2166 Mask = Shuf->getShuffleMask();
2167 } else {
2168 return false;
2169 }
2170
2171 // When the data type is <1 x Type>, it's not possible to differentiate
2172 // between the ComplexDeinterleaving::Deinterleave and
2173 // ComplexDeinterleaving::Splat operations.
2174 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2175 return false;
2176
2177 return all_equal(Mask) && Mask[0] == 0;
2178 };
2179
2180 // The splats must meet the following requirements:
2181 // 1. Must either be all instructions or all values.
2182 // 2. Non-constant splats must live in the same block.
2183 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Vals[0].Real)) {
2184 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2185 for (auto &V : Vals) {
2186 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2187 return nullptr;
2188
2189 auto *Real = dyn_cast<Instruction>(V.Real);
2190 auto *Imag = dyn_cast<Instruction>(V.Imag);
2191 if (!Real || !Imag || Real->getParent() != FirstBB ||
2192 Imag->getParent() != FirstBB)
2193 return nullptr;
2194 }
2195 } else {
2196 for (auto &V : Vals) {
2197 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(V.Real) ||
2198 isa<Instruction>(V.Imag))
2199 return nullptr;
2200 }
2201 }
2202
2203 for (auto &V : Vals) {
2204 auto *Real = dyn_cast<Instruction>(V.Real);
2205 auto *Imag = dyn_cast<Instruction>(V.Imag);
2206 if (Real && Imag) {
2207 FinalInstructions.insert(Real);
2208 FinalInstructions.insert(Imag);
2209 }
2210 }
2211 CompositeNode *PlaceholderNode =
2212 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2213 return submitCompositeNode(PlaceholderNode);
2214}
2215
2216ComplexDeinterleavingGraph::CompositeNode *
2217ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2218 Instruction *Imag) {
2219 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2220 return nullptr;
2221
2222 PHIsFound = true;
2223 CompositeNode *PlaceholderNode = prepareCompositeNode(
2224 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2225 return submitCompositeNode(PlaceholderNode);
2226}
2227
2228ComplexDeinterleavingGraph::CompositeNode *
2229ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2230 Instruction *Imag) {
2231 auto *SelectReal = dyn_cast<SelectInst>(Real);
2232 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2233 if (!SelectReal || !SelectImag)
2234 return nullptr;
2235
2236 Instruction *MaskA, *MaskB;
2237 Instruction *AR, *AI, *RA, *BI;
2238 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2239 m_Instruction(RA))) ||
2240 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2241 m_Instruction(BI))))
2242 return nullptr;
2243
2244 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2245 return nullptr;
2246
2247 if (!MaskA->getType()->isVectorTy())
2248 return nullptr;
2249
2250 auto NodeA = identifyNode(AR, AI);
2251 if (!NodeA)
2252 return nullptr;
2253
2254 auto NodeB = identifyNode(RA, BI);
2255 if (!NodeB)
2256 return nullptr;
2257
2258 CompositeNode *PlaceholderNode = prepareCompositeNode(
2259 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2260 PlaceholderNode->addOperand(NodeA);
2261 PlaceholderNode->addOperand(NodeB);
2262 FinalInstructions.insert(MaskA);
2263 FinalInstructions.insert(MaskB);
2264 return submitCompositeNode(PlaceholderNode);
2265}
2266
2267static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2268 std::optional<FastMathFlags> Flags,
2269 Value *InputA, Value *InputB) {
2270 Value *I;
2271 switch (Opcode) {
2272 case Instruction::FNeg:
2273 I = B.CreateFNeg(InputA);
2274 break;
2275 case Instruction::FAdd:
2276 I = B.CreateFAdd(InputA, InputB);
2277 break;
2278 case Instruction::Add:
2279 I = B.CreateAdd(InputA, InputB);
2280 break;
2281 case Instruction::FSub:
2282 I = B.CreateFSub(InputA, InputB);
2283 break;
2284 case Instruction::Sub:
2285 I = B.CreateSub(InputA, InputB);
2286 break;
2287 case Instruction::FMul:
2288 I = B.CreateFMul(InputA, InputB);
2289 break;
2290 case Instruction::Mul:
2291 I = B.CreateMul(InputA, InputB);
2292 break;
2293 default:
2294 llvm_unreachable("Incorrect symmetric opcode");
2295 }
2296 if (Flags)
2297 cast<Instruction>(I)->setFastMathFlags(*Flags);
2298 return I;
2299}
2300
2301Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2302 CompositeNode *Node) {
2303 if (Node->ReplacementNode)
2304 return Node->ReplacementNode;
2305
2306 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2307 unsigned Idx) -> Value * {
2308 return Node->Operands.size() > Idx
2309 ? replaceNode(Builder, Node->Operands[Idx])
2310 : nullptr;
2311 };
2312
2313 Value *ReplacementNode = nullptr;
2314 switch (Node->Operation) {
2315 case ComplexDeinterleavingOperation::CDot: {
2316 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2317 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2318 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2319 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2320 "Node inputs need to be of the same type"));
2321 ReplacementNode = TL->createComplexDeinterleavingIR(
2322 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2323 break;
2324 }
2325 case ComplexDeinterleavingOperation::CAdd:
2326 case ComplexDeinterleavingOperation::CMulPartial:
2327 case ComplexDeinterleavingOperation::Symmetric: {
2328 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2329 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2330 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2331 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2332 "Node inputs need to be of the same type"));
2334 (Input0->getType() == Accumulator->getType() &&
2335 "Accumulator and input need to be of the same type"));
2336 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2337 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2338 Input0, Input1);
2339 else
2340 ReplacementNode = TL->createComplexDeinterleavingIR(
2341 Builder, Node->Operation, Node->Rotation, Input0, Input1,
2342 Accumulator);
2343 break;
2344 }
2345 case ComplexDeinterleavingOperation::Deinterleave:
2346 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2347 break;
2348 case ComplexDeinterleavingOperation::Splat: {
2350 for (auto &V : Node->Vals) {
2351 Ops.push_back(V.Real);
2352 Ops.push_back(V.Imag);
2353 }
2354 auto *R = dyn_cast<Instruction>(Node->Vals[0].Real);
2355 auto *I = dyn_cast<Instruction>(Node->Vals[0].Imag);
2356 if (R && I) {
2357 // Splats that are not constant are interleaved where they are located
2358 Instruction *InsertPoint = R;
2359 for (auto V : Node->Vals) {
2360 if (InsertPoint->comesBefore(cast<Instruction>(V.Real)))
2361 InsertPoint = cast<Instruction>(V.Real);
2362 if (InsertPoint->comesBefore(cast<Instruction>(V.Imag)))
2363 InsertPoint = cast<Instruction>(V.Imag);
2364 }
2365 InsertPoint = InsertPoint->getNextNode();
2366 IRBuilder<> IRB(InsertPoint);
2367 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2368 } else {
2369 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2370 }
2371 break;
2372 }
2373 case ComplexDeinterleavingOperation::ReductionPHI: {
2374 // If Operation is ReductionPHI, a new empty PHINode is created.
2375 // It is filled later when the ReductionOperation is processed.
2376 auto *OldPHI = cast<PHINode>(Node->Vals[0].Real);
2377 auto *VTy = cast<VectorType>(Node->Vals[0].Real->getType());
2378 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2379 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2380 OldToNewPHI[OldPHI] = NewPHI;
2381 ReplacementNode = NewPHI;
2382 break;
2383 }
2384 case ComplexDeinterleavingOperation::ReductionSingle:
2385 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2386 processReductionSingle(ReplacementNode, Node);
2387 break;
2388 case ComplexDeinterleavingOperation::ReductionOperation:
2389 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2390 processReductionOperation(ReplacementNode, Node);
2391 break;
2392 case ComplexDeinterleavingOperation::ReductionSelect: {
2393 auto *MaskReal = cast<Instruction>(Node->Vals[0].Real)->getOperand(0);
2394 auto *MaskImag = cast<Instruction>(Node->Vals[0].Imag)->getOperand(0);
2395 auto *A = replaceNode(Builder, Node->Operands[0]);
2396 auto *B = replaceNode(Builder, Node->Operands[1]);
2397 auto *NewMask = Builder.CreateVectorInterleave({MaskReal, MaskImag});
2398 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2399 break;
2400 }
2401 }
2402
2403 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2404 NumComplexTransformations += 1;
2405 Node->ReplacementNode = ReplacementNode;
2406 return ReplacementNode;
2407}
2408
2409void ComplexDeinterleavingGraph::processReductionSingle(
2410 Value *OperationReplacement, CompositeNode *Node) {
2411 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2412 auto *OldPHI = ReductionInfo[Real].first;
2413 auto *NewPHI = OldToNewPHI[OldPHI];
2414 auto *VTy = cast<VectorType>(Real->getType());
2415 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2416
2417 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2418
2419 IRBuilder<> Builder(Incoming->getTerminator());
2420
2421 Value *NewInit = nullptr;
2422 if (auto *C = dyn_cast<Constant>(Init)) {
2423 if (C->isNullValue())
2424 NewInit = Constant::getNullValue(NewVTy);
2425 }
2426
2427 if (!NewInit)
2428 NewInit =
2429 Builder.CreateVectorInterleave({Init, Constant::getNullValue(VTy)});
2430
2431 NewPHI->addIncoming(NewInit, Incoming);
2432 NewPHI->addIncoming(OperationReplacement, BackEdge);
2433
2434 auto *FinalReduction = ReductionInfo[Real].second;
2435 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2436
2437 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2438 FinalReduction->replaceAllUsesWith(AddReduce);
2439}
2440
2441void ComplexDeinterleavingGraph::processReductionOperation(
2442 Value *OperationReplacement, CompositeNode *Node) {
2443 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2444 auto *Imag = cast<Instruction>(Node->Vals[0].Imag);
2445 auto *OldPHIReal = ReductionInfo[Real].first;
2446 auto *OldPHIImag = ReductionInfo[Imag].first;
2447 auto *NewPHI = OldToNewPHI[OldPHIReal];
2448
2449 // We have to interleave initial origin values coming from IncomingBlock
2450 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2451 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2452
2453 IRBuilder<> Builder(Incoming->getTerminator());
2454 auto *NewInit = Builder.CreateVectorInterleave({InitReal, InitImag});
2455
2456 NewPHI->addIncoming(NewInit, Incoming);
2457 NewPHI->addIncoming(OperationReplacement, BackEdge);
2458
2459 // Deinterleave complex vector outside of loop so that it can be finally
2460 // reduced
2461 auto *FinalReductionReal = ReductionInfo[Real].second;
2462 auto *FinalReductionImag = ReductionInfo[Imag].second;
2463
2464 auto *Br = cast<CondBrInst>(BackEdge->getTerminator());
2465 BasicBlock *ExitBB = Br->getSuccessor(Br->getSuccessor(0) == BackEdge);
2466 Builder.SetInsertPoint(&*ExitBB->getFirstInsertionPt());
2467
2468 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2469 OperationReplacement->getType(),
2470 OperationReplacement);
2471
2472 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2473 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2474
2475 Builder.SetInsertPoint(FinalReductionImag);
2476 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2477 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2478}
2479
2480void ComplexDeinterleavingGraph::replaceNodes() {
2481 SmallVector<Instruction *, 16> DeadInstrRoots;
2482 for (auto *RootInstruction : OrderedRoots) {
2483 // Check if this potential root went through check process and we can
2484 // deinterleave it
2485 if (!RootToNode.count(RootInstruction))
2486 continue;
2487
2488 IRBuilder<> Builder(RootInstruction);
2489 auto RootNode = RootToNode[RootInstruction];
2490 Value *R = replaceNode(Builder, RootNode);
2491
2492 if (RootNode->Operation ==
2493 ComplexDeinterleavingOperation::ReductionOperation) {
2494 auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
2495 auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
2496 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2497 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2498 DeadInstrRoots.push_back(RootReal);
2499 DeadInstrRoots.push_back(RootImag);
2500 } else if (RootNode->Operation ==
2501 ComplexDeinterleavingOperation::ReductionSingle) {
2502 auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
2503 auto &Info = ReductionInfo[RootInst];
2504 Info.first->removeIncomingValue(BackEdge);
2505 DeadInstrRoots.push_back(Info.second);
2506 } else {
2507 assert(R && "Unable to find replacement for RootInstruction");
2508 DeadInstrRoots.push_back(RootInstruction);
2509 RootInstruction->replaceAllUsesWith(R);
2510 }
2511 }
2512
2513 for (auto *I : DeadInstrRoots)
2515}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
#define DEBUG_TYPE
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
This file implements a map that provides insertion order iteration.
#define T
uint64_t IntrinsicInst * II
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:119
This file describes how to lower LLVM code to machine code.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:275
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
size_t size() const
Get the array size.
Definition ArrayRef.h:141
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:225
iterator end()
Definition DenseMap.h:143
bool allowContract() const
Definition FMF.h:69
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2664
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI Value * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI Value * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={}, function_ref< void(CallInst *)> SetFn=[](CallInst *) {})
Variant to create a possibly constant-folded intrinsic.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:181
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
size_type size() const
Definition MapVector.h:58
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:288
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
An opaque object representing a hash code.
Definition Hashing.h:77
const ParentTy * getParent() const
Definition ilist_node.h:34
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BR
Control flow instructions. These all have token chains.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition DWP.cpp:558
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:535
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
@ Other
Any other memory.
Definition ModRef.h:68
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1772
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
Definition STLExtras.h:2166
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:305
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:862
#define N
ComplexDeinterleavingPass(const TargetMachine &TM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...