LLVM 23.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas fixed-width vectors are recognized in both forms:
25// as shufflevector instructions and as intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(true),
90
91/// Checks the given mask, and determines whether said mask is interleaving.
92///
93/// To be interleaving, a mask must alternate between `i` and `i + (Length /
94/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
95/// 4x vector interleaving mask would be <0, 2, 1, 3>).
96static bool isInterleavingMask(ArrayRef<int> Mask);
97
98/// Checks the given mask, and determines whether said mask is deinterleaving.
99///
100/// To be deinterleaving, a mask must increment in steps of 2, and either start
101/// with 0 or 1.
102/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
103/// <1, 3, 5, 7>).
104static bool isDeinterleavingMask(ArrayRef<int> Mask);
105
106/// Returns true if \p V is a negation, either floating-point (fneg) or
107/// integer (sub 0, X).
108static bool isNeg(Value *V);
109
110/// Returns the value negated by \p V, which must satisfy isNeg().
111static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
122hash_code hash_value(const ComplexValue &Arg) {
125}
126} // end namespace
128
129template <> struct llvm::DenseMapInfo<ComplexValue> {
130 static unsigned getHashValue(const ComplexValue &Val) {
133 }
134 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
135 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
136 }
137};
138
139namespace {
140template <typename T, typename IterT>
141std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
142 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
143 if (Common != A.end())
144 return std::make_optional(*Common);
145 return std::nullopt;
146}
147
148class ComplexDeinterleavingLegacyPass : public FunctionPass {
149public:
150 static char ID;
151
152 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
153 : FunctionPass(ID), TM(TM) {}
154
155 StringRef getPassName() const override {
156 return "Complex Deinterleaving Pass";
157 }
158
159 bool runOnFunction(Function &F) override;
160 void getAnalysisUsage(AnalysisUsage &AU) const override {
161 AU.addRequired<TargetLibraryInfoWrapperPass>();
162 AU.setPreservesCFG();
163 }
164
165private:
166 const TargetMachine *TM;
167};
168
169class ComplexDeinterleavingGraph;
170struct ComplexDeinterleavingCompositeNode {
171
172 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
173 Value *R, Value *I)
174 : Operation(Op) {
175 Vals.push_back({R, I});
176 }
177
178 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
180 : Operation(Op), Vals(Other) {}
181
182private:
183 friend class ComplexDeinterleavingGraph;
184 using CompositeNode = ComplexDeinterleavingCompositeNode;
185 bool OperandsValid = true;
186
187public:
189 ComplexValues Vals;
190
191 // These two members are required exclusively for generating
192 // ComplexDeinterleavingOperation::Symmetric operations.
193 unsigned Opcode;
194 std::optional<FastMathFlags> Flags;
195
197 ComplexDeinterleavingRotation::Rotation_0;
199 Value *ReplacementNode = nullptr;
200
201 void addOperand(CompositeNode *Node) {
202 if (!Node)
203 OperandsValid = false;
204 Operands.push_back(Node);
205 }
206
207 void dump() { dump(dbgs()); }
208 void dump(raw_ostream &OS) {
209 auto PrintValue = [&](Value *V) {
210 if (V) {
211 OS << "\"";
212 V->print(OS, true);
213 OS << "\"\n";
214 } else
215 OS << "nullptr\n";
216 };
217 auto PrintNodeRef = [&](CompositeNode *Ptr) {
218 if (Ptr)
219 OS << Ptr << "\n";
220 else
221 OS << "nullptr\n";
222 };
223
224 OS << "- CompositeNode: " << this << "\n";
225 for (unsigned I = 0; I < Vals.size(); I++) {
226 OS << " Real(" << I << ") : ";
227 PrintValue(Vals[I].Real);
228 OS << " Imag(" << I << ") : ";
229 PrintValue(Vals[I].Imag);
230 }
231 OS << " ReplacementNode: ";
232 PrintValue(ReplacementNode);
233 OS << " Operation: " << (int)Operation << "\n";
234 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
235 OS << " Operands: \n";
236 for (const auto &Op : Operands) {
237 OS << " - ";
238 PrintNodeRef(Op);
239 }
240 }
241
242 bool areOperandsValid() { return OperandsValid; }
243};
244
245class ComplexDeinterleavingGraph {
246public:
247 struct Product {
248 Value *Multiplier;
249 Value *Multiplicand;
250 bool IsPositive;
251 };
252
253 using Addend = std::pair<Value *, bool>;
254 using AddendList = BumpPtrList<Addend>;
255 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
256
257 // Helper struct for holding info about potential partial multiplication
258 // candidates
259 struct PartialMulCandidate {
260 Value *Common;
261 CompositeNode *Node;
262 unsigned RealIdx;
263 unsigned ImagIdx;
264 bool IsNodeInverted;
265 };
266
267 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
268 const TargetLibraryInfo *TLI,
269 unsigned Factor)
270 : TL(TL), TLI(TLI), Factor(Factor) {}
271
272private:
273 const TargetLowering *TL = nullptr;
274 const TargetLibraryInfo *TLI = nullptr;
275 unsigned Factor;
276 SmallVector<CompositeNode *> CompositeNodes;
277 DenseMap<ComplexValues, CompositeNode *> CachedResult;
278 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
279
280 SmallPtrSet<Instruction *, 16> FinalInstructions;
281
282 /// Root instructions are instructions from which complex computation starts
283 DenseMap<Instruction *, CompositeNode *> RootToNode;
284
285 /// Topologically sorted root instructions
287
288 /// When examining a basic block for complex deinterleaving, if it is a simple
289 /// one-block loop, then the only incoming block is 'Incoming' and the
290 /// 'BackEdge' block is the block itself."
291 BasicBlock *BackEdge = nullptr;
292 BasicBlock *Incoming = nullptr;
293
294 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
295 /// %OutsideUser as it is shown in the IR:
296 ///
297 /// vector.body:
298 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
299 /// [ %ReductionOp, %vector.body ]
300 /// ...
301 /// %ReductionOp = fadd i64 ...
302 /// ...
303 /// br i1 %condition, label %vector.body, %middle.block
304 ///
305 /// middle.block:
306 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
307 ///
308 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
309 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
310 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
311
312 /// In the process of detecting a reduction, we consider a pair of
313 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
314 /// traverse the use-tree to detect complex operations. As this is a reduction
315 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
316 /// to the %ReductionOPs that we suspect to be complex.
317 /// RealPHI and ImagPHI are used by the identifyPHINode method.
318 PHINode *RealPHI = nullptr;
319 PHINode *ImagPHI = nullptr;
320
321 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
322 /// detection.
323 bool PHIsFound = false;
324
325 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
326 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
327 /// This mapping is populated during
328 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
329 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
330 /// replacement process.
331 DenseMap<PHINode *, PHINode *> OldToNewPHI;
332
333 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
334 Value *R, Value *I) {
335 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
336 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
337 (R && I)) &&
338 "Reduction related nodes must have Real and Imaginary parts");
339 return new (Allocator.Allocate())
340 ComplexDeinterleavingCompositeNode(Operation, R, I);
341 }
342
343 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
344 ComplexValues &Vals) {
345#ifndef NDEBUG
346 for (auto &V : Vals) {
347 assert(
348 ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
350 (V.Real && V.Imag)) &&
351 "Reduction related nodes must have Real and Imaginary parts");
352 }
353#endif
354 return new (Allocator.Allocate())
355 ComplexDeinterleavingCompositeNode(Operation, Vals);
356 }
357
358 CompositeNode *submitCompositeNode(CompositeNode *Node) {
359 CompositeNodes.push_back(Node);
360 if (Node->Vals[0].Real)
361 CachedResult[Node->Vals] = Node;
362 return Node;
363 }
364
365 /// Identifies a complex partial multiply pattern and its rotation, based on
366 /// the following patterns
367 ///
368 /// 0: r: cr + ar * br
369 /// i: ci + ar * bi
370 /// 90: r: cr - ai * bi
371 /// i: ci + ai * br
372 /// 180: r: cr - ar * br
373 /// i: ci - ar * bi
374 /// 270: r: cr + ai * bi
375 /// i: ci - ai * br
376 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
377
378 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
379 /// is partially known from identifyPartialMul, filling in the other half of
380 /// the complex pair.
381 CompositeNode *
382 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
383 std::pair<Value *, Value *> &CommonOperandI);
384
385 /// Identifies a complex add pattern and its rotation, based on the following
386 /// patterns.
387 ///
388 /// 90: r: ar - bi
389 /// i: ai + br
390 /// 270: r: ar + bi
391 /// i: ai - br
392 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
393 CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
394 CompositeNode *identifyPartialReduction(Value *R, Value *I);
395 CompositeNode *identifyDotProduct(Value *Inst);
396
397 CompositeNode *identifyNode(ComplexValues &Vals);
398
399 CompositeNode *identifyNode(Value *R, Value *I) {
400 ComplexValues Vals;
401 Vals.push_back({R, I});
402 return identifyNode(Vals);
403 }
404
405 /// Determine if a sum of complex numbers can be formed from \p RealAddends
406 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
407 /// Return nullptr if it is not possible to construct a complex number.
408 /// \p Flags are needed to generate symmetric Add and Sub operations.
409 CompositeNode *identifyAdditions(AddendList &RealAddends,
410 AddendList &ImagAddends,
411 std::optional<FastMathFlags> Flags,
412 CompositeNode *Accumulator);
413
414 /// Extract one addend that have both real and imaginary parts positive.
415 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
416 AddendList &ImagAddends);
417
418 /// Determine if sum of multiplications of complex numbers can be formed from
419 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
420 /// to it. Return nullptr if it is not possible to construct a complex number.
421 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
422 SmallVectorImpl<Product> &ImagMuls,
423 CompositeNode *Accumulator);
424
425 /// Go through pairs of multiplication (one Real and one Imag) and find all
426 /// possible candidates for partial multiplication and put them into \p
427 /// Candidates. Returns true if all Product has pair with common operand
428 bool collectPartialMuls(ArrayRef<Product> RealMuls,
429 ArrayRef<Product> ImagMuls,
430 SmallVectorImpl<PartialMulCandidate> &Candidates);
431
432 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
433 /// the order of complex computation operations may be significantly altered,
434 /// and the real and imaginary parts may not be executed in parallel. This
435 /// function takes this into consideration and employs a more general approach
436 /// to identify complex computations. Initially, it gathers all the addends
437 /// and multiplicands and then constructs a complex expression from them.
438 CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
439
440 CompositeNode *identifyRoot(Instruction *I);
441
442 /// Identifies the Deinterleave operation applied to a vector containing
443 /// complex numbers. There are two ways to represent the Deinterleave
444 /// operation:
445 /// * Using two shufflevectors with even indices for /pReal instruction and
446 /// odd indices for /pImag instructions (only for fixed-width vectors)
447 /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
448 /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
449 /// 2.
450 CompositeNode *identifyDeinterleave(ComplexValues &Vals);
451
452 /// identifying the operation that represents a complex number repeated in a
453 /// Splat vector. There are two possible types of splats: ConstantExpr with
454 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
455 /// initialization mask with all values set to zero.
456 CompositeNode *identifySplat(ComplexValues &Vals);
457
458 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
459
460 /// Identifies SelectInsts in a loop that has reduction with predication masks
461 /// and/or predicated tail folding
462 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
463
464 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
465
466 /// Complete IR modifications after producing new reduction operation:
467 /// * Populate the PHINode generated for
468 /// ComplexDeinterleavingOperation::ReductionPHI
469 /// * Deinterleave the final value outside of the loop and repurpose original
470 /// reduction users
471 void processReductionOperation(Value *OperationReplacement,
472 CompositeNode *Node);
473 void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
474
475public:
476 void dump() { dump(dbgs()); }
477 void dump(raw_ostream &OS) {
478 for (const auto &Node : CompositeNodes)
479 Node->dump(OS);
480 }
481
482 /// Returns false if the deinterleaving operation should be cancelled for the
483 /// current graph.
484 bool identifyNodes(Instruction *RootI);
485
486 /// In case \pB is one-block loop, this function seeks potential reductions
487 /// and populates ReductionInfo. Returns true if any reductions were
488 /// identified.
489 bool collectPotentialReductions(BasicBlock *B);
490
491 void identifyReductionNodes();
492
493 /// Check that every instruction, from the roots to the leaves, has internal
494 /// uses.
495 bool checkNodes();
496
497 /// Perform the actual replacement of the underlying instruction graph.
498 void replaceNodes();
499};
500
501class ComplexDeinterleaving {
502public:
503 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
504 : TL(tl), TLI(tli) {}
505 bool runOnFunction(Function &F);
506
507private:
508 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
509
510 const TargetLowering *TL = nullptr;
511 const TargetLibraryInfo *TLI = nullptr;
512};
513
514} // namespace
515
// Out-of-line definition of the `static char ID` member declared in
// ComplexDeinterleavingLegacyPass; it is passed to the FunctionPass base.
516char ComplexDeinterleavingLegacyPass::ID = 0;
517
// Register the legacy pass under DEBUG_TYPE ("complex-deinterleaving") with the
// description "Complex Deinterleaving".
// NOTE(review): getAnalysisUsage() requires TargetLibraryInfoWrapperPass, but no
// INITIALIZE_PASS_DEPENDENCY is listed between BEGIN and END; confirm whether
// that dependency should be declared here.
518INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
519 "Complex Deinterleaving", false, false)
520INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
521 "Complex Deinterleaving", false, false)
522
525 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
526 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
527 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
528 return PreservedAnalyses::all();
529
532 return PA;
533}
534
536 return new ComplexDeinterleavingLegacyPass(TM);
537}
538
539bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
540 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
541 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
542 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
543}
544
545bool ComplexDeinterleaving::runOnFunction(Function &F) {
548 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
549 return false;
550 }
551
554 dbgs() << "Complex deinterleaving has been disabled, target does "
555 "not support lowering of complex number operations.\n");
556 return false;
557 }
558
559 bool Changed = false;
560 for (auto &B : F)
561 Changed |= evaluateBasicBlock(&B, 2);
562
563 // TODO: Permit changes for both interleave factors in the same function.
564 if (!Changed) {
565 for (auto &B : F)
566 Changed |= evaluateBasicBlock(&B, 4);
567 }
568
569 // TODO: We can also support interleave factors of 6 and 8 if needed.
570
571 return Changed;
572}
573
575 // If the size is not even, it's not an interleaving mask
576 if ((Mask.size() & 1))
577 return false;
578
579 int HalfNumElements = Mask.size() / 2;
580 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
581 int MaskIdx = Idx * 2;
582 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
583 return false;
584 }
585
586 return true;
587}
588
590 int Offset = Mask[0];
591 int HalfNumElements = Mask.size() / 2;
592
593 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
594 if (Mask[Idx] != (Idx * 2) + Offset)
595 return false;
596 }
597
598 return true;
599}
600
601bool isNeg(Value *V) {
602 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
603}
604
606 assert(isNeg(V));
607 auto *I = cast<Instruction>(V);
608 if (I->getOpcode() == Instruction::FNeg)
609 return I->getOperand(0);
610
611 return I->getOperand(1);
612}
613
614bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
615 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
616 if (Graph.collectPotentialReductions(B))
617 Graph.identifyReductionNodes();
618
619 for (auto &I : *B)
620 Graph.identifyNodes(&I);
621
622 if (Graph.checkNodes()) {
623 Graph.replaceNodes();
624 return true;
625 }
626
627 return false;
628}
629
630ComplexDeinterleavingGraph::CompositeNode *
631ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
632 Instruction *Real, Instruction *Imag,
633 std::pair<Value *, Value *> &PartialMatch) {
634 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
635 << "\n");
636
637 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
638 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
639 return nullptr;
640 }
641
642 if ((Real->getOpcode() != Instruction::FMul &&
643 Real->getOpcode() != Instruction::Mul) ||
644 (Imag->getOpcode() != Instruction::FMul &&
645 Imag->getOpcode() != Instruction::Mul)) {
647 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
648 return nullptr;
649 }
650
651 Value *R0 = Real->getOperand(0);
652 Value *R1 = Real->getOperand(1);
653 Value *I0 = Imag->getOperand(0);
654 Value *I1 = Imag->getOperand(1);
655
656 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
657 // rotations and use the operand.
658 unsigned Negs = 0;
659 Value *Op;
660 if (match(R0, m_Neg(m_Value(Op)))) {
661 Negs |= 1;
662 R0 = Op;
663 } else if (match(R1, m_Neg(m_Value(Op)))) {
664 Negs |= 1;
665 R1 = Op;
666 }
667
668 if (isNeg(I0)) {
669 Negs |= 2;
670 Negs ^= 1;
671 I0 = Op;
672 } else if (match(I1, m_Neg(m_Value(Op)))) {
673 Negs |= 2;
674 Negs ^= 1;
675 I1 = Op;
676 }
677
679
680 Value *CommonOperand;
681 Value *UncommonRealOp;
682 Value *UncommonImagOp;
683
684 if (R0 == I0 || R0 == I1) {
685 CommonOperand = R0;
686 UncommonRealOp = R1;
687 } else if (R1 == I0 || R1 == I1) {
688 CommonOperand = R1;
689 UncommonRealOp = R0;
690 } else {
691 LLVM_DEBUG(dbgs() << " - No equal operand\n");
692 return nullptr;
693 }
694
695 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
696 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
697 Rotation == ComplexDeinterleavingRotation::Rotation_270)
698 std::swap(UncommonRealOp, UncommonImagOp);
699
700 // Between identifyPartialMul and here we need to have found a complete valid
701 // pair from the CommonOperand of each part.
702 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
703 Rotation == ComplexDeinterleavingRotation::Rotation_180)
704 PartialMatch.first = CommonOperand;
705 else
706 PartialMatch.second = CommonOperand;
707
708 if (!PartialMatch.first || !PartialMatch.second) {
709 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
710 return nullptr;
711 }
712
713 CompositeNode *CommonNode =
714 identifyNode(PartialMatch.first, PartialMatch.second);
715 if (!CommonNode) {
716 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
717 return nullptr;
718 }
719
720 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
721 if (!UncommonNode) {
722 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
723 return nullptr;
724 }
725
726 CompositeNode *Node = prepareCompositeNode(
727 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
728 Node->Rotation = Rotation;
729 Node->addOperand(CommonNode);
730 Node->addOperand(UncommonNode);
731 return submitCompositeNode(Node);
732}
733
734ComplexDeinterleavingGraph::CompositeNode *
735ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
736 Instruction *Imag) {
737 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
738 << "\n");
739
740 // Determine rotation
741 auto IsAdd = [](unsigned Op) {
742 return Op == Instruction::FAdd || Op == Instruction::Add;
743 };
744 auto IsSub = [](unsigned Op) {
745 return Op == Instruction::FSub || Op == Instruction::Sub;
746 };
748 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
749 Rotation = ComplexDeinterleavingRotation::Rotation_0;
750 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
751 Rotation = ComplexDeinterleavingRotation::Rotation_90;
752 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
753 Rotation = ComplexDeinterleavingRotation::Rotation_180;
754 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
755 Rotation = ComplexDeinterleavingRotation::Rotation_270;
756 else {
757 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
758 return nullptr;
759 }
760
761 if (isa<FPMathOperator>(Real) &&
762 (!Real->getFastMathFlags().allowContract() ||
763 !Imag->getFastMathFlags().allowContract())) {
764 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
765 return nullptr;
766 }
767
768 Value *CR = Real->getOperand(0);
769 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
770 if (!RealMulI)
771 return nullptr;
772 Value *CI = Imag->getOperand(0);
773 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
774 if (!ImagMulI)
775 return nullptr;
776
777 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
778 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
779 return nullptr;
780 }
781
782 Value *R0 = RealMulI->getOperand(0);
783 Value *R1 = RealMulI->getOperand(1);
784 Value *I0 = ImagMulI->getOperand(0);
785 Value *I1 = ImagMulI->getOperand(1);
786
787 Value *CommonOperand;
788 Value *UncommonRealOp;
789 Value *UncommonImagOp;
790
791 if (R0 == I0 || R0 == I1) {
792 CommonOperand = R0;
793 UncommonRealOp = R1;
794 } else if (R1 == I0 || R1 == I1) {
795 CommonOperand = R1;
796 UncommonRealOp = R0;
797 } else {
798 LLVM_DEBUG(dbgs() << " - No equal operand\n");
799 return nullptr;
800 }
801
802 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
803 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
804 Rotation == ComplexDeinterleavingRotation::Rotation_270)
805 std::swap(UncommonRealOp, UncommonImagOp);
806
807 std::pair<Value *, Value *> PartialMatch(
808 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
809 Rotation == ComplexDeinterleavingRotation::Rotation_180)
810 ? CommonOperand
811 : nullptr,
812 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
813 Rotation == ComplexDeinterleavingRotation::Rotation_270)
814 ? CommonOperand
815 : nullptr);
816
817 auto *CRInst = dyn_cast<Instruction>(CR);
818 auto *CIInst = dyn_cast<Instruction>(CI);
819
820 if (!CRInst || !CIInst) {
821 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
822 return nullptr;
823 }
824
825 CompositeNode *CNode =
826 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
827 if (!CNode) {
828 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
829 return nullptr;
830 }
831
832 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
833 if (!UncommonRes) {
834 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
835 return nullptr;
836 }
837
838 assert(PartialMatch.first && PartialMatch.second);
839 CompositeNode *CommonRes =
840 identifyNode(PartialMatch.first, PartialMatch.second);
841 if (!CommonRes) {
842 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
843 return nullptr;
844 }
845
846 CompositeNode *Node = prepareCompositeNode(
847 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
848 Node->Rotation = Rotation;
849 Node->addOperand(CommonRes);
850 Node->addOperand(UncommonRes);
851 Node->addOperand(CNode);
852 return submitCompositeNode(Node);
853}
854
855ComplexDeinterleavingGraph::CompositeNode *
856ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
857 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
858
859 // Determine rotation
861 if ((Real->getOpcode() == Instruction::FSub &&
862 Imag->getOpcode() == Instruction::FAdd) ||
863 (Real->getOpcode() == Instruction::Sub &&
864 Imag->getOpcode() == Instruction::Add))
865 Rotation = ComplexDeinterleavingRotation::Rotation_90;
866 else if ((Real->getOpcode() == Instruction::FAdd &&
867 Imag->getOpcode() == Instruction::FSub) ||
868 (Real->getOpcode() == Instruction::Add &&
869 Imag->getOpcode() == Instruction::Sub))
870 Rotation = ComplexDeinterleavingRotation::Rotation_270;
871 else {
872 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
873 return nullptr;
874 }
875
876 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
877 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
878 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
879 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
880
881 if (!AR || !AI || !BR || !BI) {
882 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
883 return nullptr;
884 }
885
886 CompositeNode *ResA = identifyNode(AR, AI);
887 if (!ResA) {
888 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
889 return nullptr;
890 }
891 CompositeNode *ResB = identifyNode(BR, BI);
892 if (!ResB) {
893 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
894 return nullptr;
895 }
896
897 CompositeNode *Node =
898 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
899 Node->Rotation = Rotation;
900 Node->addOperand(ResA);
901 Node->addOperand(ResB);
902 return submitCompositeNode(Node);
903}
904
906 unsigned OpcA = A->getOpcode();
907 unsigned OpcB = B->getOpcode();
908
909 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
910 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
911 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
912 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
913}
914
916 auto Pattern =
918
919 return match(A, Pattern) && match(B, Pattern);
920}
921
923 switch (I->getOpcode()) {
924 case Instruction::FAdd:
925 case Instruction::FSub:
926 case Instruction::FMul:
927 case Instruction::FNeg:
928 case Instruction::Add:
929 case Instruction::Sub:
930 case Instruction::Mul:
931 return true;
932 default:
933 return false;
934 }
935}
936
937ComplexDeinterleavingGraph::CompositeNode *
938ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
939 auto *FirstReal = cast<Instruction>(Vals[0].Real);
940 unsigned FirstOpc = FirstReal->getOpcode();
941 for (auto &V : Vals) {
942 auto *Real = cast<Instruction>(V.Real);
943 auto *Imag = cast<Instruction>(V.Imag);
944 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
945 return nullptr;
946
949 return nullptr;
950
951 if (isa<FPMathOperator>(FirstReal))
952 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
953 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
954 return nullptr;
955 }
956
957 ComplexValues OpVals;
958 for (auto &V : Vals) {
959 auto *R0 = cast<Instruction>(V.Real)->getOperand(0);
960 auto *I0 = cast<Instruction>(V.Imag)->getOperand(0);
961 OpVals.push_back({R0, I0});
962 }
963
964 CompositeNode *Op0 = identifyNode(OpVals);
965 CompositeNode *Op1 = nullptr;
966 if (Op0 == nullptr)
967 return nullptr;
968
969 if (FirstReal->isBinaryOp()) {
970 OpVals.clear();
971 for (auto &V : Vals) {
972 auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
973 auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
974 OpVals.push_back({R1, I1});
975 }
976 Op1 = identifyNode(OpVals);
977 if (Op1 == nullptr)
978 return nullptr;
979 }
980
981 auto Node =
982 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
983 Node->Opcode = FirstReal->getOpcode();
984 if (isa<FPMathOperator>(FirstReal))
985 Node->Flags = FirstReal->getFastMathFlags();
986
987 Node->addOperand(Op0);
988 if (FirstReal->isBinaryOp())
989 Node->addOperand(Op1);
990
991 return submitCompositeNode(Node);
992}
993
994ComplexDeinterleavingGraph::CompositeNode *
995ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
997 ComplexDeinterleavingOperation::CDot, V->getType())) {
998 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
999 "operation CDot with the type "
1000 << *V->getType() << "\n");
1001 return nullptr;
1002 }
1003
1004 auto *Inst = cast<Instruction>(V);
1005 auto *RealUser = cast<Instruction>(*Inst->user_begin());
1006
1007 CompositeNode *CN =
1008 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
1009
1010 CompositeNode *ANode = nullptr;
1011
1012 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1013
1014 Value *AReal = nullptr;
1015 Value *AImag = nullptr;
1016 Value *BReal = nullptr;
1017 Value *BImag = nullptr;
1018 Value *Phi = nullptr;
1019
1020 auto UnwrapCast = [](Value *V) -> Value * {
1021 if (auto *CI = dyn_cast<CastInst>(V))
1022 return CI->getOperand(0);
1023 return V;
1024 };
1025
1026 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1028 m_Mul(m_Value(BReal), m_Value(AReal))),
1029 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
1030
1031 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1033 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
1034 m_Mul(m_Value(BImag), m_Value(AReal)));
1035
1036 if (match(Inst, PatternRot0)) {
1037 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1038 } else if (match(Inst, PatternRot270)) {
1039 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1040 } else {
1041 Value *A0, *A1;
1042 // The rotations 90 and 180 share the same operation pattern, so inspect the
1043 // order of the operands, identifying where the real and imaginary
1044 // components of A go, to discern between the aforementioned rotations.
1045 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1047 m_Mul(m_Value(BReal), m_Value(A0))),
1048 m_Mul(m_Value(BImag), m_Value(A1)));
1049
1050 if (!match(Inst, PatternRot90Rot180))
1051 return nullptr;
1052
1053 A0 = UnwrapCast(A0);
1054 A1 = UnwrapCast(A1);
1055
1056 // Test if A0 is real/A1 is imag
1057 ANode = identifyNode(A0, A1);
1058 if (!ANode) {
1059 // Test if A0 is imag/A1 is real
1060 ANode = identifyNode(A1, A0);
1061 // Unable to identify operand components, thus unable to identify rotation
1062 if (!ANode)
1063 return nullptr;
1064 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1065 AReal = A1;
1066 AImag = A0;
1067 } else {
1068 AReal = A0;
1069 AImag = A1;
1070 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1071 }
1072 }
1073
1074 AReal = UnwrapCast(AReal);
1075 AImag = UnwrapCast(AImag);
1076 BReal = UnwrapCast(BReal);
1077 BImag = UnwrapCast(BImag);
1078
1079 VectorType *VTy = cast<VectorType>(V->getType());
1080 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1081 if (AReal->getType() != ExpectedOperandTy)
1082 return nullptr;
1083 if (AImag->getType() != ExpectedOperandTy)
1084 return nullptr;
1085 if (BReal->getType() != ExpectedOperandTy)
1086 return nullptr;
1087 if (BImag->getType() != ExpectedOperandTy)
1088 return nullptr;
1089
1090 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1091 return nullptr;
1092
1093 CompositeNode *Node = identifyNode(AReal, AImag);
1094
1095 // In the case that a node was identified to figure out the rotation, ensure
1096 // that trying to identify a node with AReal and AImag post-unwrap results in
1097 // the same node
1098 if (ANode && Node != ANode) {
1099 LLVM_DEBUG(
1100 dbgs()
1101 << "Identified node is different from previously identified node. "
1102 "Unable to confidently generate a complex operation node\n");
1103 return nullptr;
1104 }
1105
1106 CN->addOperand(Node);
1107 CN->addOperand(identifyNode(BReal, BImag));
1108 CN->addOperand(identifyNode(Phi, RealUser));
1109
1110 return submitCompositeNode(CN);
1111}
1112
1113ComplexDeinterleavingGraph::CompositeNode *
1114ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1115 // Partial reductions don't support non-vector types, so check these first
1116 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1117 return nullptr;
1118
1119 if (!R->hasUseList() || !I->hasUseList())
1120 return nullptr;
1121
1122 auto CommonUser =
1123 findCommonBetweenCollections<Value *>(R->users(), I->users());
1124 if (!CommonUser)
1125 return nullptr;
1126
1127 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1128 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1129 return nullptr;
1130
1131 if (CompositeNode *CN = identifyDotProduct(IInst))
1132 return CN;
1133
1134 return nullptr;
1135}
1136
1137ComplexDeinterleavingGraph::CompositeNode *
1138ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1139 auto It = CachedResult.find(Vals);
1140 if (It != CachedResult.end()) {
1141 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1142 return It->second;
1143 }
1144
1145 if (Vals.size() == 1) {
1146 assert(Factor == 2 && "Can only handle interleave factors of 2");
1147 Value *R = Vals[0].Real;
1148 Value *I = Vals[0].Imag;
1149 if (CompositeNode *CN = identifyPartialReduction(R, I))
1150 return CN;
1151 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1152 if (!IsReduction && R->getType() != I->getType())
1153 return nullptr;
1154 }
1155
1156 if (CompositeNode *CN = identifySplat(Vals))
1157 return CN;
1158
1159 for (auto &V : Vals) {
1160 auto *Real = dyn_cast<Instruction>(V.Real);
1161 auto *Imag = dyn_cast<Instruction>(V.Imag);
1162 if (!Real || !Imag)
1163 return nullptr;
1164 }
1165
1166 if (CompositeNode *CN = identifyDeinterleave(Vals))
1167 return CN;
1168
1169 if (Vals.size() == 1) {
1170 assert(Factor == 2 && "Can only handle interleave factors of 2");
1171 auto *Real = dyn_cast<Instruction>(Vals[0].Real);
1172 auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
1173 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1174 return CN;
1175
1176 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1177 return CN;
1178
1179 auto *VTy = cast<VectorType>(Real->getType());
1180 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1181
1182 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1183 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1184 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1185 ComplexDeinterleavingOperation::CAdd, NewVTy);
1186
1187 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1188 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1189 return CN;
1190 }
1191
1192 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1193 if (CompositeNode *CN = identifyAdd(Real, Imag))
1194 return CN;
1195 }
1196
1197 if (HasCMulSupport && HasCAddSupport) {
1198 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1199 return CN;
1200 }
1201 }
1202 }
1203
1204 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1205 return CN;
1206
1207 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1208 CachedResult[Vals] = nullptr;
1209 return nullptr;
1210}
1211
1212ComplexDeinterleavingGraph::CompositeNode *
1213ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1214 Instruction *Imag) {
1215 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1216 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1217 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1218 Opcode == Instruction::Sub;
1219 };
1220
1221 if (!IsOperationSupported(Real->getOpcode()) ||
1222 !IsOperationSupported(Imag->getOpcode()))
1223 return nullptr;
1224
1225 std::optional<FastMathFlags> Flags;
1226 if (isa<FPMathOperator>(Real)) {
1227 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1228 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1229 "not identical\n");
1230 return nullptr;
1231 }
1232
1233 Flags = Real->getFastMathFlags();
1234 if (!Flags->allowReassoc()) {
1235 LLVM_DEBUG(
1236 dbgs()
1237 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1238 return nullptr;
1239 }
1240 }
1241
1242 // Collect multiplications and addend instructions from the given instruction
1243 // while traversing it operands. Additionally, verify that all instructions
1244 // have the same fast math flags.
1245 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1246 AddendList &Addends) -> bool {
1247 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1248 SmallPtrSet<Value *, 8> Visited;
1249 while (!Worklist.empty()) {
1250 auto [V, IsPositive] = Worklist.pop_back_val();
1251 if (!Visited.insert(V).second)
1252 continue;
1253
1255 if (!I) {
1256 Addends.emplace_back(V, IsPositive);
1257 continue;
1258 }
1259
1260 // If an instruction has more than one user, it indicates that it either
1261 // has an external user, which will be later checked by the checkNodes
1262 // function, or it is a subexpression utilized by multiple expressions. In
1263 // the latter case, we will attempt to separately identify the complex
1264 // operation from here in order to create a shared
1265 // ComplexDeinterleavingCompositeNode.
1266 if (I != Insn && I->hasNUsesOrMore(2)) {
1267 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1268 Addends.emplace_back(I, IsPositive);
1269 continue;
1270 }
1271 switch (I->getOpcode()) {
1272 case Instruction::FAdd:
1273 case Instruction::Add:
1274 Worklist.emplace_back(I->getOperand(1), IsPositive);
1275 Worklist.emplace_back(I->getOperand(0), IsPositive);
1276 break;
1277 case Instruction::FSub:
1278 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1279 Worklist.emplace_back(I->getOperand(0), IsPositive);
1280 break;
1281 case Instruction::Sub:
1282 if (isNeg(I)) {
1283 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1284 } else {
1285 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1286 Worklist.emplace_back(I->getOperand(0), IsPositive);
1287 }
1288 break;
1289 case Instruction::FMul:
1290 case Instruction::Mul: {
1291 Value *A, *B;
1292 if (isNeg(I->getOperand(0))) {
1293 A = getNegOperand(I->getOperand(0));
1294 IsPositive = !IsPositive;
1295 } else {
1296 A = I->getOperand(0);
1297 }
1298
1299 if (isNeg(I->getOperand(1))) {
1300 B = getNegOperand(I->getOperand(1));
1301 IsPositive = !IsPositive;
1302 } else {
1303 B = I->getOperand(1);
1304 }
1305 Muls.push_back(Product{A, B, IsPositive});
1306 break;
1307 }
1308 case Instruction::FNeg:
1309 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1310 break;
1311 default:
1312 Addends.emplace_back(I, IsPositive);
1313 continue;
1314 }
1315
1316 if (Flags && I->getFastMathFlags() != *Flags) {
1317 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1318 "inconsistent with the root instructions' flags: "
1319 << *I << "\n");
1320 return false;
1321 }
1322 }
1323 return true;
1324 };
1325
1326 SmallVector<Product> RealMuls, ImagMuls;
1327 AddendList RealAddends, ImagAddends;
1328 if (!Collect(Real, RealMuls, RealAddends) ||
1329 !Collect(Imag, ImagMuls, ImagAddends))
1330 return nullptr;
1331
1332 if (RealAddends.size() != ImagAddends.size())
1333 return nullptr;
1334
1335 CompositeNode *FinalNode = nullptr;
1336 if (!RealMuls.empty() || !ImagMuls.empty()) {
1337 // If there are multiplicands, extract positive addend and use it as an
1338 // accumulator
1339 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1340 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1341 if (!FinalNode)
1342 return nullptr;
1343 }
1344
1345 // Identify and process remaining additions
1346 if (!RealAddends.empty() || !ImagAddends.empty()) {
1347 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1348 if (!FinalNode)
1349 return nullptr;
1350 }
1351 assert(FinalNode && "FinalNode can not be nullptr here");
1352 assert(FinalNode->Vals.size() == 1);
1353 // Set the Real and Imag fields of the final node and submit it
1354 FinalNode->Vals[0].Real = Real;
1355 FinalNode->Vals[0].Imag = Imag;
1356 submitCompositeNode(FinalNode);
1357 return FinalNode;
1358}
1359
1360bool ComplexDeinterleavingGraph::collectPartialMuls(
1361 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1362 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1363 // Helper function to extract a common operand from two products
1364 auto FindCommonInstruction = [](const Product &Real,
1365 const Product &Imag) -> Value * {
1366 if (Real.Multiplicand == Imag.Multiplicand ||
1367 Real.Multiplicand == Imag.Multiplier)
1368 return Real.Multiplicand;
1369
1370 if (Real.Multiplier == Imag.Multiplicand ||
1371 Real.Multiplier == Imag.Multiplier)
1372 return Real.Multiplier;
1373
1374 return nullptr;
1375 };
1376
1377 // Iterating over real and imaginary multiplications to find common operands
1378 // If a common operand is found, a partial multiplication candidate is created
1379 // and added to the candidates vector The function returns false if no common
1380 // operands are found for any product
1381 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1382 bool FoundCommon = false;
1383 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1384 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1385 if (!Common)
1386 continue;
1387
1388 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1389 : RealMuls[i].Multiplicand;
1390 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1391 : ImagMuls[j].Multiplicand;
1392
1393 auto Node = identifyNode(A, B);
1394 if (Node) {
1395 FoundCommon = true;
1396 PartialMulCandidates.push_back({Common, Node, i, j, false});
1397 }
1398
1399 Node = identifyNode(B, A);
1400 if (Node) {
1401 FoundCommon = true;
1402 PartialMulCandidates.push_back({Common, Node, i, j, true});
1403 }
1404 }
1405 if (!FoundCommon)
1406 return false;
1407 }
1408 return true;
1409}
1410
1411ComplexDeinterleavingGraph::CompositeNode *
1412ComplexDeinterleavingGraph::identifyMultiplications(
1413 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1414 CompositeNode *Accumulator = nullptr) {
1415 if (RealMuls.size() != ImagMuls.size())
1416 return nullptr;
1417
1419 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1420 return nullptr;
1421
1422 // Map to store common instruction to node pointers
1423 DenseMap<Value *, CompositeNode *> CommonToNode;
1424 SmallVector<bool> Processed(Info.size(), false);
1425 for (unsigned I = 0; I < Info.size(); ++I) {
1426 if (Processed[I])
1427 continue;
1428
1429 PartialMulCandidate &InfoA = Info[I];
1430 for (unsigned J = I + 1; J < Info.size(); ++J) {
1431 if (Processed[J])
1432 continue;
1433
1434 PartialMulCandidate &InfoB = Info[J];
1435 auto *InfoReal = &InfoA;
1436 auto *InfoImag = &InfoB;
1437
1438 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1439 if (!NodeFromCommon) {
1440 std::swap(InfoReal, InfoImag);
1441 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1442 }
1443 if (!NodeFromCommon)
1444 continue;
1445
1446 CommonToNode[InfoReal->Common] = NodeFromCommon;
1447 CommonToNode[InfoImag->Common] = NodeFromCommon;
1448 Processed[I] = true;
1449 Processed[J] = true;
1450 }
1451 }
1452
1453 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1454 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1455 CompositeNode *Result = Accumulator;
1456 for (auto &PMI : Info) {
1457 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1458 continue;
1459
1460 auto It = CommonToNode.find(PMI.Common);
1461 // TODO: Process independent complex multiplications. Cases like this:
1462 // A.real() * B where both A and B are complex numbers.
1463 if (It == CommonToNode.end()) {
1464 LLVM_DEBUG({
1465 dbgs() << "Unprocessed independent partial multiplication:\n";
1466 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1467 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1468 << " multiplied by " << *Mul->Multiplicand << "\n";
1469 });
1470 return nullptr;
1471 }
1472
1473 auto &RealMul = RealMuls[PMI.RealIdx];
1474 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1475
1476 auto NodeA = It->second;
1477 auto NodeB = PMI.Node;
1478 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1479 // The following table illustrates the relationship between multiplications
1480 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1481 // can see:
1482 //
1483 // Rotation | Real | Imag |
1484 // ---------+--------+--------+
1485 // 0 | x * u | x * v |
1486 // 90 | -y * v | y * u |
1487 // 180 | -x * u | -x * v |
1488 // 270 | y * v | -y * u |
1489 //
1490 // Check if the candidate can indeed be represented by partial
1491 // multiplication
1492 // TODO: Add support for multiplication by complex one
1493 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1494 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1495 continue;
1496
1497 // Determine the rotation based on the multiplications
1499 if (IsMultiplicandReal) {
1500 // Detect 0 and 180 degrees rotation
1501 if (RealMul.IsPositive && ImagMul.IsPositive)
1503 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1505 else
1506 continue;
1507
1508 } else {
1509 // Detect 90 and 270 degrees rotation
1510 if (!RealMul.IsPositive && ImagMul.IsPositive)
1512 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1514 else
1515 continue;
1516 }
1517
1518 LLVM_DEBUG({
1519 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1520 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1521 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1522 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1523 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1524 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1525 });
1526
1527 CompositeNode *NodeMul = prepareCompositeNode(
1528 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1529 NodeMul->Rotation = Rotation;
1530 NodeMul->addOperand(NodeA);
1531 NodeMul->addOperand(NodeB);
1532 if (Result)
1533 NodeMul->addOperand(Result);
1534 submitCompositeNode(NodeMul);
1535 Result = NodeMul;
1536 ProcessedReal[PMI.RealIdx] = true;
1537 ProcessedImag[PMI.ImagIdx] = true;
1538 }
1539
1540 // Ensure all products have been processed, if not return nullptr.
1541 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1542 !all_of(ProcessedImag, [](bool V) { return V; })) {
1543
1544 // Dump debug information about which partial multiplications are not
1545 // processed.
1546 LLVM_DEBUG({
1547 dbgs() << "Unprocessed products (Real):\n";
1548 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1549 if (!ProcessedReal[i])
1550 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1551 << *RealMuls[i].Multiplier << " multiplied by "
1552 << *RealMuls[i].Multiplicand << "\n";
1553 }
1554 dbgs() << "Unprocessed products (Imag):\n";
1555 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1556 if (!ProcessedImag[i])
1557 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1558 << *ImagMuls[i].Multiplier << " multiplied by "
1559 << *ImagMuls[i].Multiplicand << "\n";
1560 }
1561 });
1562 return nullptr;
1563 }
1564
1565 return Result;
1566}
1567
1568ComplexDeinterleavingGraph::CompositeNode *
1569ComplexDeinterleavingGraph::identifyAdditions(
1570 AddendList &RealAddends, AddendList &ImagAddends,
1571 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1572 if (RealAddends.size() != ImagAddends.size())
1573 return nullptr;
1574
1575 CompositeNode *Result = nullptr;
1576 // If we have accumulator use it as first addend
1577 if (Accumulator)
1579 // Otherwise find an element with both positive real and imaginary parts.
1580 else
1581 Result = extractPositiveAddend(RealAddends, ImagAddends);
1582
1583 if (!Result)
1584 return nullptr;
1585
1586 while (!RealAddends.empty()) {
1587 auto ItR = RealAddends.begin();
1588 auto [R, IsPositiveR] = *ItR;
1589
1590 bool FoundImag = false;
1591 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1592 auto [I, IsPositiveI] = *ItI;
1594 if (IsPositiveR && IsPositiveI)
1595 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1596 else if (!IsPositiveR && IsPositiveI)
1597 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1598 else if (!IsPositiveR && !IsPositiveI)
1599 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1600 else
1601 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1602
1603 CompositeNode *AddNode = nullptr;
1604 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1605 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1606 AddNode = identifyNode(R, I);
1607 } else {
1608 AddNode = identifyNode(I, R);
1609 }
1610 if (AddNode) {
1611 LLVM_DEBUG({
1612 dbgs() << "Identified addition:\n";
1613 dbgs().indent(4) << "X: " << *R << "\n";
1614 dbgs().indent(4) << "Y: " << *I << "\n";
1615 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1616 });
1617
1618 CompositeNode *TmpNode = nullptr;
1620 TmpNode = prepareCompositeNode(
1621 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1622 if (Flags) {
1623 TmpNode->Opcode = Instruction::FAdd;
1624 TmpNode->Flags = *Flags;
1625 } else {
1626 TmpNode->Opcode = Instruction::Add;
1627 }
1628 } else if (Rotation ==
1630 TmpNode = prepareCompositeNode(
1631 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1632 if (Flags) {
1633 TmpNode->Opcode = Instruction::FSub;
1634 TmpNode->Flags = *Flags;
1635 } else {
1636 TmpNode->Opcode = Instruction::Sub;
1637 }
1638 } else {
1639 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1640 nullptr, nullptr);
1641 TmpNode->Rotation = Rotation;
1642 }
1643
1644 TmpNode->addOperand(Result);
1645 TmpNode->addOperand(AddNode);
1646 submitCompositeNode(TmpNode);
1647 Result = TmpNode;
1648 RealAddends.erase(ItR);
1649 ImagAddends.erase(ItI);
1650 FoundImag = true;
1651 break;
1652 }
1653 }
1654 if (!FoundImag)
1655 return nullptr;
1656 }
1657 return Result;
1658}
1659
1660ComplexDeinterleavingGraph::CompositeNode *
1661ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1662 AddendList &ImagAddends) {
1663 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1664 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1665 auto [R, IsPositiveR] = *ItR;
1666 auto [I, IsPositiveI] = *ItI;
1667 if (IsPositiveR && IsPositiveI) {
1668 auto Result = identifyNode(R, I);
1669 if (Result) {
1670 RealAddends.erase(ItR);
1671 ImagAddends.erase(ItI);
1672 return Result;
1673 }
1674 }
1675 }
1676 }
1677 return nullptr;
1678}
1679
1680bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1681 // This potential root instruction might already have been recognized as
1682 // reduction. Because RootToNode maps both Real and Imaginary parts to
1683 // CompositeNode we should choose only one either Real or Imag instruction to
1684 // use as an anchor for generating complex instruction.
1685 auto It = RootToNode.find(RootI);
1686 if (It != RootToNode.end()) {
1687 auto RootNode = It->second;
1688 assert(RootNode->Operation ==
1689 ComplexDeinterleavingOperation::ReductionOperation ||
1690 RootNode->Operation ==
1691 ComplexDeinterleavingOperation::ReductionSingle);
1692 assert(RootNode->Vals.size() == 1 &&
1693 "Cannot handle reductions involving multiple complex values");
1694 // Find out which part, Real or Imag, comes later, and only if we come to
1695 // the latest part, add it to OrderedRoots.
1696 auto *R = cast<Instruction>(RootNode->Vals[0].Real);
1697 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
1698 : nullptr;
1699
1700 Instruction *ReplacementAnchor;
1701 if (I)
1702 ReplacementAnchor = R->comesBefore(I) ? I : R;
1703 else
1704 ReplacementAnchor = R;
1705
1706 if (ReplacementAnchor != RootI)
1707 return false;
1708 OrderedRoots.push_back(RootI);
1709 return true;
1710 }
1711
1712 auto RootNode = identifyRoot(RootI);
1713 if (!RootNode)
1714 return false;
1715
1716 LLVM_DEBUG({
1717 Function *F = RootI->getFunction();
1718 BasicBlock *B = RootI->getParent();
1719 dbgs() << "Complex deinterleaving graph for " << F->getName()
1720 << "::" << B->getName() << ".\n";
1721 dump(dbgs());
1722 dbgs() << "\n";
1723 });
1724 RootToNode[RootI] = RootNode;
1725 OrderedRoots.push_back(RootI);
1726 return true;
1727}
1728
1729bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1730 bool FoundPotentialReduction = false;
1731 if (Factor != 2)
1732 return false;
1733
1734 auto *Br = dyn_cast<CondBrInst>(B->getTerminator());
1735 if (!Br)
1736 return false;
1737
1738 // Identify simple one-block loop
1739 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1740 return false;
1741
1742 for (auto &PHI : B->phis()) {
1743 if (PHI.getNumIncomingValues() != 2)
1744 continue;
1745
1746 if (!PHI.getType()->isVectorTy())
1747 continue;
1748
1749 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1750 if (!ReductionOp)
1751 continue;
1752
1753 // Check if final instruction is reduced outside of current block
1754 Instruction *FinalReduction = nullptr;
1755 auto NumUsers = 0u;
1756 for (auto *U : ReductionOp->users()) {
1757 ++NumUsers;
1758 if (U == &PHI)
1759 continue;
1760 FinalReduction = dyn_cast<Instruction>(U);
1761 }
1762
1763 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1764 isa<PHINode>(FinalReduction))
1765 continue;
1766
1767 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1768 BackEdge = B;
1769 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1770 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1771 Incoming = PHI.getIncomingBlock(IncomingIdx);
1772 FoundPotentialReduction = true;
1773
1774 // If the initial value of PHINode is an Instruction, consider it a leaf
1775 // value of a complex deinterleaving graph.
1776 if (auto *InitPHI =
1777 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1778 FinalInstructions.insert(InitPHI);
1779 }
1780 return FoundPotentialReduction;
1781}
1782
/// Identify complex reductions among the candidates recorded in ReductionInfo.
///
/// First, pairs of reduction operations with the same type are tried as the
/// Real and Imag halves of one complex computation, in both orders. A match
/// becomes a ReductionOperation root mapped from both instructions in
/// RootToNode. Then each still-unprocessed candidate whose final reduction has
/// an integer type is tried as a single reduction. Its first two operands are
/// matched as the complex value, producing a ReductionSingle root.
///
/// RealPHI/ImagPHI are set before each match attempt because identifyNode
/// reads them, and both are cleared on exit. A match counts only if PHIsFound
/// is true afterwards; it is reset before every attempt.
void ComplexDeinterleavingGraph::identifyReductionNodes() {
  assert(Factor == 2 && "Cannot handle multiple complex values");

  SmallVector<bool> Processed(ReductionInfo.size(), false);
  SmallVector<Instruction *> OperationInstruction;
  for (auto &P : ReductionInfo)
    OperationInstruction.push_back(P.first);

  // Identify a complex computation by evaluating two reduction operations that
  // potentially could be involved
  for (size_t i = 0; i < OperationInstruction.size(); ++i) {
    if (Processed[i])
      continue;
    for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
      if (Processed[j])
        continue;
      auto *Real = OperationInstruction[i];
      auto *Imag = OperationInstruction[j];
      if (Real->getType() != Imag->getType())
        continue;

      RealPHI = ReductionInfo[Real].first;
      ImagPHI = ReductionInfo[Imag].first;
      PHIsFound = false;
      auto Node = identifyNode(Real, Imag);
      if (!Node) {
        // Retry with the Real/Imag roles (and their PHIs) swapped.
        std::swap(Real, Imag);
        std::swap(RealPHI, ImagPHI);
        Node = identifyNode(Real, Imag);
      }

      // If a node is identified and reduction PHINode is used in the chain of
      // operations, mark its operation instructions as used to prevent
      // re-identification and attach the node to the real part
      if (Node && PHIsFound) {
        LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
                          << *Real << " / " << *Imag << "\n");
        Processed[i] = true;
        Processed[j] = true;
        auto RootNode = prepareCompositeNode(
            ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
        RootNode->addOperand(Node);
        RootToNode[Real] = RootNode;
        RootToNode[Imag] = RootNode;
        submitCompositeNode(RootNode);
        break;
      }
    }

    // Fall back to treating candidate i on its own as a single reduction.
    auto *Real = OperationInstruction[i];
    // We want to check that we have 2 operands, but the function attributes
    // being counted as operands bloats this value.
    if (Processed[i] || Real->getNumOperands() < 2)
      continue;

    // Can only combine integer reductions at the moment.
    if (!ReductionInfo[Real].second->getType()->isIntegerTy())
      continue;

    RealPHI = ReductionInfo[Real].first;
    ImagPHI = nullptr;
    PHIsFound = false;
    auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
    if (Node && PHIsFound) {
      LLVM_DEBUG(
          dbgs() << "Identified single reduction starting from instruction: "
                 << *Real << "/" << *ReductionInfo[Real].second << "\n");

      // Reducing to a single vector is not supported, only permit reducing down
      // to scalar values.
      // Doing this here will leave the prior node in the graph,
      // however with no uses the node will be unreachable by the replacement
      // process. That along with the usage outside the graph should prevent the
      // replacement process from kicking off at all for this graph.
      // NOTE(review): the isIntegerTy() check above already rejects
      // vector-typed final reductions, so this branch appears unreachable as
      // written; it would only matter if that check were relaxed.
      // TODO Add support for reducing to a single vector value
      if (ReductionInfo[Real].second->getType()->isVectorTy())
        continue;

      Processed[i] = true;
      auto RootNode = prepareCompositeNode(
          ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
      RootNode->addOperand(Node);
      RootToNode[Real] = RootNode;
      submitCompositeNode(RootNode);
    }
  }

  RealPHI = nullptr;
  ImagPHI = nullptr;
}
1873
1874bool ComplexDeinterleavingGraph::checkNodes() {
1875 bool FoundDeinterleaveNode = false;
1876 for (CompositeNode *N : CompositeNodes) {
1877 if (!N->areOperandsValid())
1878 return false;
1879
1880 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1881 FoundDeinterleaveNode = true;
1882 }
1883
1884 // We need a deinterleave node in order to guarantee that we're working with
1885 // complex numbers.
1886 if (!FoundDeinterleaveNode) {
1887 LLVM_DEBUG(
1888 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1889 "guarantee safety during graph transformation.\n");
1890 return false;
1891 }
1892
1893 // Collect all instructions from roots to leaves
1894 SmallPtrSet<Instruction *, 16> AllInstructions;
1895 SmallVector<Instruction *, 8> Worklist;
1896 for (auto &Pair : RootToNode)
1897 Worklist.push_back(Pair.first);
1898
1899 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1900 // chains
1901 while (!Worklist.empty()) {
1902 auto *I = Worklist.pop_back_val();
1903
1904 if (!AllInstructions.insert(I).second)
1905 continue;
1906
1907 for (Value *Op : I->operands()) {
1908 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1909 if (!FinalInstructions.count(I))
1910 Worklist.emplace_back(OpI);
1911 }
1912 }
1913 }
1914
1915 // Find instructions that have users outside of chain
1916 for (auto *I : AllInstructions) {
1917 // Skip root nodes
1918 if (RootToNode.count(I))
1919 continue;
1920
1921 for (User *U : I->users()) {
1922 if (AllInstructions.count(cast<Instruction>(U)))
1923 continue;
1924
1925 // Found an instruction that is not used by XCMLA/XCADD chain
1926 Worklist.emplace_back(I);
1927 break;
1928 }
1929 }
1930
1931 // If any instructions are found to be used outside, find and remove roots
1932 // that somehow connect to those instructions.
1933 SmallPtrSet<Instruction *, 16> Visited;
1934 while (!Worklist.empty()) {
1935 auto *I = Worklist.pop_back_val();
1936 if (!Visited.insert(I).second)
1937 continue;
1938
1939 // Found an impacted root node. Removing it from the nodes to be
1940 // deinterleaved
1941 if (RootToNode.count(I)) {
1942 LLVM_DEBUG(dbgs() << "Instruction " << *I
1943 << " could be deinterleaved but its chain of complex "
1944 "operations have an outside user\n");
1945 RootToNode.erase(I);
1946 }
1947
1948 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1949 continue;
1950
1951 for (User *U : I->users())
1952 Worklist.emplace_back(cast<Instruction>(U));
1953
1954 for (Value *Op : I->operands()) {
1955 if (auto *OpI = dyn_cast<Instruction>(Op))
1956 Worklist.emplace_back(OpI);
1957 }
1958 }
1959 return !RootToNode.empty();
1960}
1961
1962ComplexDeinterleavingGraph::CompositeNode *
1963ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1964 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1966 Intrinsic->getIntrinsicID())
1967 return nullptr;
1968
1969 ComplexValues Vals;
1970 for (unsigned I = 0; I < Factor; I += 2) {
1971 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(I));
1972 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(I + 1));
1973 if (!Real || !Imag)
1974 return nullptr;
1975 Vals.push_back({Real, Imag});
1976 }
1977
1978 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1979 if (!Node1)
1980 return nullptr;
1981 return Node1;
1982 }
1983
1984 // TODO: We could also add support for fixed-width interleave factors of 4
1985 // and above, but currently for symmetric operations the interleaves and
1986 // deinterleaves are already removed by VectorCombine. If we extend this to
1987 // permit complex multiplications, reductions, etc. then we should also add
1988 // support for fixed-width here.
1989 if (Factor != 2)
1990 return nullptr;
1991
1992 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1993 if (!SVI)
1994 return nullptr;
1995
1996 // Look for a shufflevector that takes separate vectors of the real and
1997 // imaginary components and recombines them into a single vector.
1998 if (!isInterleavingMask(SVI->getShuffleMask()))
1999 return nullptr;
2000
2001 Instruction *Real;
2002 Instruction *Imag;
2003 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
2004 return nullptr;
2005
2006 return identifyNode(Real, Imag);
2007}
2008
2009ComplexDeinterleavingGraph::CompositeNode *
2010ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2011 Instruction *II = nullptr;
2012
2013 // Must be at least one complex value.
2014 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2015 Instruction *ExpectedInsn) -> ExtractValueInst * {
2016 auto *EVI = dyn_cast<ExtractValueInst>(V);
2017 if (!EVI || EVI->getNumIndices() != 1 ||
2018 EVI->getIndices()[0] != ExpectedIdx ||
2019 !isa<Instruction>(EVI->getAggregateOperand()) ||
2020 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2021 return nullptr;
2022 return EVI;
2023 };
2024
2025 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2026 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2027 if (RealEVI && Idx == 0)
2029 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2030 II = nullptr;
2031 break;
2032 }
2033 }
2034
2035 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(II)) {
2036 if (IntrinsicII->getIntrinsicID() !=
2038 return nullptr;
2039
2040 // The remaining should match too.
2041 CompositeNode *PlaceholderNode = prepareCompositeNode(
2043 PlaceholderNode->ReplacementNode = II->getOperand(0);
2044 for (auto &V : Vals) {
2045 FinalInstructions.insert(cast<Instruction>(V.Real));
2046 FinalInstructions.insert(cast<Instruction>(V.Imag));
2047 }
2048 return submitCompositeNode(PlaceholderNode);
2049 }
2050
2051 if (Vals.size() != 1)
2052 return nullptr;
2053
2054 Value *Real = Vals[0].Real;
2055 Value *Imag = Vals[0].Imag;
2056 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
2057 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
2058 if (!RealShuffle || !ImagShuffle) {
2059 if (RealShuffle || ImagShuffle)
2060 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2061 return nullptr;
2062 }
2063
2064 Value *RealOp1 = RealShuffle->getOperand(1);
2065 if (!isa<UndefValue>(RealOp1) && !match(RealOp1, m_Zero())) {
2066 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2067 return nullptr;
2068 }
2069 Value *ImagOp1 = ImagShuffle->getOperand(1);
2070 if (!isa<UndefValue>(ImagOp1) && !match(ImagOp1, m_Zero())) {
2071 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2072 return nullptr;
2073 }
2074
2075 Value *RealOp0 = RealShuffle->getOperand(0);
2076 Value *ImagOp0 = ImagShuffle->getOperand(0);
2077
2078 if (RealOp0 != ImagOp0) {
2079 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2080 return nullptr;
2081 }
2082
2083 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2084 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2085 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
2086 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2087 return nullptr;
2088 }
2089
2090 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2091 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2092 return nullptr;
2093 }
2094
2095 // Type checking, the shuffle type should be a vector type of the same
2096 // scalar type, but half the size
2097 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2098 Value *Op = Shuffle->getOperand(0);
2099 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
2100 auto *OpTy = cast<FixedVectorType>(Op->getType());
2101
2102 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2103 return false;
2104 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2105 return false;
2106
2107 return true;
2108 };
2109
2110 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2111 if (!CheckType(Shuffle))
2112 return false;
2113
2114 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2115 int Last = *Mask.rbegin();
2116
2117 Value *Op = Shuffle->getOperand(0);
2118 auto *OpTy = cast<FixedVectorType>(Op->getType());
2119 int NumElements = OpTy->getNumElements();
2120
2121 // Ensure that the deinterleaving shuffle only pulls from the first
2122 // shuffle operand.
2123 return Last < NumElements;
2124 };
2125
2126 if (RealShuffle->getType() != ImagShuffle->getType()) {
2127 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2128 return nullptr;
2129 }
2130 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2131 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2132 return nullptr;
2133 }
2134 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2135 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2136 return nullptr;
2137 }
2138
2139 CompositeNode *PlaceholderNode =
2141 RealShuffle, ImagShuffle);
2142 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2143 FinalInstructions.insert(RealShuffle);
2144 FinalInstructions.insert(ImagShuffle);
2145 return submitCompositeNode(PlaceholderNode);
2146}
2147
2148ComplexDeinterleavingGraph::CompositeNode *
2149ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2150 auto IsSplat = [](Value *V) -> bool {
2151 // Fixed-width vector with constants
2153 return true;
2154
2155 if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2156 return isa<VectorType>(V->getType());
2157
2158 VectorType *VTy;
2159 ArrayRef<int> Mask;
2160 // Splats are represented differently depending on whether the repeated
2161 // value is a constant or an Instruction
2162 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2163 if (Const->getOpcode() != Instruction::ShuffleVector)
2164 return false;
2165 VTy = cast<VectorType>(Const->getType());
2166 Mask = Const->getShuffleMask();
2167 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2168 VTy = Shuf->getType();
2169 Mask = Shuf->getShuffleMask();
2170 } else {
2171 return false;
2172 }
2173
2174 // When the data type is <1 x Type>, it's not possible to differentiate
2175 // between the ComplexDeinterleaving::Deinterleave and
2176 // ComplexDeinterleaving::Splat operations.
2177 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2178 return false;
2179
2180 return all_equal(Mask) && Mask[0] == 0;
2181 };
2182
2183 // The splats must meet the following requirements:
2184 // 1. Must either be all instructions or all values.
2185 // 2. Non-constant splats must live in the same block.
2186 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Vals[0].Real)) {
2187 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2188 for (auto &V : Vals) {
2189 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2190 return nullptr;
2191
2192 auto *Real = dyn_cast<Instruction>(V.Real);
2193 auto *Imag = dyn_cast<Instruction>(V.Imag);
2194 if (!Real || !Imag || Real->getParent() != FirstBB ||
2195 Imag->getParent() != FirstBB)
2196 return nullptr;
2197 }
2198 } else {
2199 for (auto &V : Vals) {
2200 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(V.Real) ||
2201 isa<Instruction>(V.Imag))
2202 return nullptr;
2203 }
2204 }
2205
2206 for (auto &V : Vals) {
2207 auto *Real = dyn_cast<Instruction>(V.Real);
2208 auto *Imag = dyn_cast<Instruction>(V.Imag);
2209 if (Real && Imag) {
2210 FinalInstructions.insert(Real);
2211 FinalInstructions.insert(Imag);
2212 }
2213 }
2214 CompositeNode *PlaceholderNode =
2215 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2216 return submitCompositeNode(PlaceholderNode);
2217}
2218
2219ComplexDeinterleavingGraph::CompositeNode *
2220ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2221 Instruction *Imag) {
2222 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2223 return nullptr;
2224
2225 PHIsFound = true;
2226 CompositeNode *PlaceholderNode = prepareCompositeNode(
2227 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2228 return submitCompositeNode(PlaceholderNode);
2229}
2230
2231ComplexDeinterleavingGraph::CompositeNode *
2232ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2233 Instruction *Imag) {
2234 auto *SelectReal = dyn_cast<SelectInst>(Real);
2235 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2236 if (!SelectReal || !SelectImag)
2237 return nullptr;
2238
2239 Instruction *MaskA, *MaskB;
2240 Instruction *AR, *AI, *RA, *BI;
2241 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2242 m_Instruction(RA))) ||
2243 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2244 m_Instruction(BI))))
2245 return nullptr;
2246
2247 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2248 return nullptr;
2249
2250 if (!MaskA->getType()->isVectorTy())
2251 return nullptr;
2252
2253 auto NodeA = identifyNode(AR, AI);
2254 if (!NodeA)
2255 return nullptr;
2256
2257 auto NodeB = identifyNode(RA, BI);
2258 if (!NodeB)
2259 return nullptr;
2260
2261 CompositeNode *PlaceholderNode = prepareCompositeNode(
2262 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2263 PlaceholderNode->addOperand(NodeA);
2264 PlaceholderNode->addOperand(NodeB);
2265 FinalInstructions.insert(MaskA);
2266 FinalInstructions.insert(MaskB);
2267 return submitCompositeNode(PlaceholderNode);
2268}
2269
2270static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2271 std::optional<FastMathFlags> Flags,
2272 Value *InputA, Value *InputB) {
2273 Value *I;
2274 switch (Opcode) {
2275 case Instruction::FNeg:
2276 I = B.CreateFNeg(InputA);
2277 break;
2278 case Instruction::FAdd:
2279 I = B.CreateFAdd(InputA, InputB);
2280 break;
2281 case Instruction::Add:
2282 I = B.CreateAdd(InputA, InputB);
2283 break;
2284 case Instruction::FSub:
2285 I = B.CreateFSub(InputA, InputB);
2286 break;
2287 case Instruction::Sub:
2288 I = B.CreateSub(InputA, InputB);
2289 break;
2290 case Instruction::FMul:
2291 I = B.CreateFMul(InputA, InputB);
2292 break;
2293 case Instruction::Mul:
2294 I = B.CreateMul(InputA, InputB);
2295 break;
2296 default:
2297 llvm_unreachable("Incorrect symmetric opcode");
2298 }
2299 if (Flags)
2300 cast<Instruction>(I)->setFastMathFlags(*Flags);
2301 return I;
2302}
2303
2304Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2305 CompositeNode *Node) {
2306 if (Node->ReplacementNode)
2307 return Node->ReplacementNode;
2308
2309 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2310 unsigned Idx) -> Value * {
2311 return Node->Operands.size() > Idx
2312 ? replaceNode(Builder, Node->Operands[Idx])
2313 : nullptr;
2314 };
2315
2316 Value *ReplacementNode = nullptr;
2317 switch (Node->Operation) {
2318 case ComplexDeinterleavingOperation::CDot: {
2319 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2320 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2321 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2322 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2323 "Node inputs need to be of the same type"));
2324 ReplacementNode = TL->createComplexDeinterleavingIR(
2325 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2326 break;
2327 }
2328 case ComplexDeinterleavingOperation::CAdd:
2329 case ComplexDeinterleavingOperation::CMulPartial:
2330 case ComplexDeinterleavingOperation::Symmetric: {
2331 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2332 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2333 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2334 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2335 "Node inputs need to be of the same type"));
2337 (Input0->getType() == Accumulator->getType() &&
2338 "Accumulator and input need to be of the same type"));
2339 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2340 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2341 Input0, Input1);
2342 else
2343 ReplacementNode = TL->createComplexDeinterleavingIR(
2344 Builder, Node->Operation, Node->Rotation, Input0, Input1,
2345 Accumulator);
2346 break;
2347 }
2348 case ComplexDeinterleavingOperation::Deinterleave:
2349 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2350 break;
2351 case ComplexDeinterleavingOperation::Splat: {
2353 for (auto &V : Node->Vals) {
2354 Ops.push_back(V.Real);
2355 Ops.push_back(V.Imag);
2356 }
2357 auto *R = dyn_cast<Instruction>(Node->Vals[0].Real);
2358 auto *I = dyn_cast<Instruction>(Node->Vals[0].Imag);
2359 if (R && I) {
2360 // Splats that are not constant are interleaved where they are located
2361 Instruction *InsertPoint = R;
2362 for (auto V : Node->Vals) {
2363 if (InsertPoint->comesBefore(cast<Instruction>(V.Real)))
2364 InsertPoint = cast<Instruction>(V.Real);
2365 if (InsertPoint->comesBefore(cast<Instruction>(V.Imag)))
2366 InsertPoint = cast<Instruction>(V.Imag);
2367 }
2368 InsertPoint = InsertPoint->getNextNode();
2369 IRBuilder<> IRB(InsertPoint);
2370 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2371 } else {
2372 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2373 }
2374 break;
2375 }
2376 case ComplexDeinterleavingOperation::ReductionPHI: {
2377 // If Operation is ReductionPHI, a new empty PHINode is created.
2378 // It is filled later when the ReductionOperation is processed.
2379 auto *OldPHI = cast<PHINode>(Node->Vals[0].Real);
2380 auto *VTy = cast<VectorType>(Node->Vals[0].Real->getType());
2381 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2382 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2383 OldToNewPHI[OldPHI] = NewPHI;
2384 ReplacementNode = NewPHI;
2385 break;
2386 }
2387 case ComplexDeinterleavingOperation::ReductionSingle:
2388 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2389 processReductionSingle(ReplacementNode, Node);
2390 break;
2391 case ComplexDeinterleavingOperation::ReductionOperation:
2392 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2393 processReductionOperation(ReplacementNode, Node);
2394 break;
2395 case ComplexDeinterleavingOperation::ReductionSelect: {
2396 auto *MaskReal = cast<Instruction>(Node->Vals[0].Real)->getOperand(0);
2397 auto *MaskImag = cast<Instruction>(Node->Vals[0].Imag)->getOperand(0);
2398 auto *A = replaceNode(Builder, Node->Operands[0]);
2399 auto *B = replaceNode(Builder, Node->Operands[1]);
2400 auto *NewMask = Builder.CreateVectorInterleave({MaskReal, MaskImag});
2401 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2402 break;
2403 }
2404 }
2405
2406 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2407 NumComplexTransformations += 1;
2408 Node->ReplacementNode = ReplacementNode;
2409 return ReplacementNode;
2410}
2411
2412void ComplexDeinterleavingGraph::processReductionSingle(
2413 Value *OperationReplacement, CompositeNode *Node) {
2414 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2415 auto *OldPHI = ReductionInfo[Real].first;
2416 auto *NewPHI = OldToNewPHI[OldPHI];
2417 auto *VTy = cast<VectorType>(Real->getType());
2418 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2419
2420 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2421
2422 IRBuilder<> Builder(Incoming->getTerminator());
2423
2424 Value *NewInit = nullptr;
2425 if (auto *C = dyn_cast<Constant>(Init)) {
2426 if (C->isNullValue())
2427 NewInit = Constant::getNullValue(NewVTy);
2428 }
2429
2430 if (!NewInit)
2431 NewInit =
2432 Builder.CreateVectorInterleave({Init, Constant::getNullValue(VTy)});
2433
2434 NewPHI->addIncoming(NewInit, Incoming);
2435 NewPHI->addIncoming(OperationReplacement, BackEdge);
2436
2437 auto *FinalReduction = ReductionInfo[Real].second;
2438 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2439
2440 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2441 FinalReduction->replaceAllUsesWith(AddReduce);
2442}
2443
2444void ComplexDeinterleavingGraph::processReductionOperation(
2445 Value *OperationReplacement, CompositeNode *Node) {
2446 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2447 auto *Imag = cast<Instruction>(Node->Vals[0].Imag);
2448 auto *OldPHIReal = ReductionInfo[Real].first;
2449 auto *OldPHIImag = ReductionInfo[Imag].first;
2450 auto *NewPHI = OldToNewPHI[OldPHIReal];
2451
2452 // We have to interleave initial origin values coming from IncomingBlock
2453 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2454 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2455
2456 IRBuilder<> Builder(Incoming->getTerminator());
2457 auto *NewInit = Builder.CreateVectorInterleave({InitReal, InitImag});
2458
2459 NewPHI->addIncoming(NewInit, Incoming);
2460 NewPHI->addIncoming(OperationReplacement, BackEdge);
2461
2462 // Deinterleave complex vector outside of loop so that it can be finally
2463 // reduced
2464 auto *FinalReductionReal = ReductionInfo[Real].second;
2465 auto *FinalReductionImag = ReductionInfo[Imag].second;
2466
2467 auto *Br = cast<CondBrInst>(BackEdge->getTerminator());
2468 BasicBlock *ExitBB = Br->getSuccessor(Br->getSuccessor(0) == BackEdge);
2469 Builder.SetInsertPoint(&*ExitBB->getFirstInsertionPt());
2470
2471 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2472 OperationReplacement->getType(),
2473 OperationReplacement);
2474
2475 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2476 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2477
2478 Builder.SetInsertPoint(FinalReductionImag);
2479 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2480 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2481}
2482
2483void ComplexDeinterleavingGraph::replaceNodes() {
2484 SmallVector<Instruction *, 16> DeadInstrRoots;
2485 for (auto *RootInstruction : OrderedRoots) {
2486 // Check if this potential root went through check process and we can
2487 // deinterleave it
2488 if (!RootToNode.count(RootInstruction))
2489 continue;
2490
2491 IRBuilder<> Builder(RootInstruction);
2492 auto RootNode = RootToNode[RootInstruction];
2493 Value *R = replaceNode(Builder, RootNode);
2494
2495 if (RootNode->Operation ==
2496 ComplexDeinterleavingOperation::ReductionOperation) {
2497 auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
2498 auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
2499 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2500 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2501 DeadInstrRoots.push_back(RootReal);
2502 DeadInstrRoots.push_back(RootImag);
2503 } else if (RootNode->Operation ==
2504 ComplexDeinterleavingOperation::ReductionSingle) {
2505 auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
2506 auto &Info = ReductionInfo[RootInst];
2507 Info.first->removeIncomingValue(BackEdge);
2508 DeadInstrRoots.push_back(Info.second);
2509 } else {
2510 assert(R && "Unable to find replacement for RootInstruction");
2511 DeadInstrRoots.push_back(RootInstruction);
2512 RootInstruction->replaceAllUsesWith(R);
2513 }
2514 }
2515
2516 for (auto *I : DeadInstrRoots)
2518}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
#define DEBUG_TYPE
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
This file implements a map that provides insertion order iteration.
#define T
uint64_t IntrinsicInst * II
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:119
This file describes how to lower LLVM code to machine code.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
size_t size() const
Get the array size.
Definition ArrayRef.h:141
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:225
iterator end()
Definition DenseMap.h:143
bool allowContract() const
Definition FMF.h:69
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2684
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={})
Create a call to intrinsic ID with Args, mangled using OverloadTypes.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
size_type size() const
Definition MapVector.h:58
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:288
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:552
An opaque object representing a hash code.
Definition Hashing.h:78
const ParentTy * getParent() const
Definition ilist_node.h:34
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BR
Control flow instructions. These all have token chains.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition DWP.cpp:558
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:535
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
@ Other
Any other memory.
Definition ModRef.h:68
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1771
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
Definition STLExtras.h:2165
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:325
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:862
#define N
ComplexDeinterleavingPass(const TargetMachine &TM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...