LLVM  17.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
21 //
22 // Replacement:
23 // This step traverses the graph built up by identification, delegating to the
24 // target to validate and generate the correct intrinsics, and plumbs them
25 // together connecting each end of the new intrinsics graph to the existing
26 // use-def chain. This step is assumed to finish successfully, as all
27 // information is expected to be correct by this point.
28 //
29 //
30 // Internal data structure:
31 // ComplexDeinterleavingGraph:
32 // Keeps references to all the valid CompositeNodes formed as part of the
33 // transformation, and every Instruction contained within said nodes. It also
34 // holds onto a reference to the root Instruction, and the root node that should
35 // replace it.
36 //
37 // ComplexDeinterleavingCompositeNode:
38 // A CompositeNode represents a single transformation point; each node should
39 // transform into a single complex instruction (ignoring vector splitting, which
40 // would generate more instructions per node). They are identified in a
41 // depth-first manner, traversing and identifying the operands of each
42 // instruction in the order they appear in the IR.
43 // Each node maintains a reference to its Real and Imaginary instructions,
44 // as well as any additional instructions that make up the identified operation
45 // (Internal instructions should only have uses within their containing node).
46 // A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imag, and
// ReplacementNode fields of that Node are relevant, and its ReplacementNode
// is expected to be pre-populated.
54 //
55 //===----------------------------------------------------------------------===//
56 
58 #include "llvm/ADT/Statistic.h"
64 #include "llvm/IR/IRBuilder.h"
65 #include "llvm/InitializePasses.h"
68 #include <algorithm>
69 
70 using namespace llvm;
71 using namespace PatternMatch;
72 
73 #define DEBUG_TYPE "complex-deinterleaving"
74 
75 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
76 
78  "enable-complex-deinterleaving",
79  cl::desc("Enable generation of complex instructions"), cl::init(true),
80  cl::Hidden);
81 
82 /// Checks the given mask, and determines whether said mask is interleaving.
83 ///
84 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
85 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
86 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
88 
89 /// Checks the given mask, and determines whether said mask is deinterleaving.
90 ///
/// To be deinterleaving, a mask must increment in steps of 2 from its first
/// element; callers check whether it starts at 0 (real) or 1 (imaginary).
93 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
94 /// <1, 3, 5, 7>).
96 
97 namespace {
98 
99 class ComplexDeinterleavingLegacyPass : public FunctionPass {
100 public:
101  static char ID;
102 
103  ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
104  : FunctionPass(ID), TM(TM) {
107  }
108 
109  StringRef getPassName() const override {
110  return "Complex Deinterleaving Pass";
111  }
112 
113  bool runOnFunction(Function &F) override;
114  void getAnalysisUsage(AnalysisUsage &AU) const override {
116  AU.setPreservesCFG();
117  }
118 
119 private:
120  const TargetMachine *TM;
121 };
122 
123 class ComplexDeinterleavingGraph;
124 struct ComplexDeinterleavingCompositeNode {
125 
126  ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
128  : Operation(Op), Real(R), Imag(I) {}
129 
130 private:
131  friend class ComplexDeinterleavingGraph;
132  using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
133  using RawNodePtr = ComplexDeinterleavingCompositeNode *;
134 
135 public:
137  Instruction *Real;
138  Instruction *Imag;
139 
140  // Instructions that should only exist within this node, there should be no
141  // users of these instructions outside the node. An example of these would be
142  // the multiply instructions of a partial multiply operation.
143  SmallVector<Instruction *> InternalInstructions;
146  Value *ReplacementNode = nullptr;
147 
148  void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
149  void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
150 
151  bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
152 
153  void dump() { dump(dbgs()); }
154  void dump(raw_ostream &OS) {
155  auto PrintValue = [&](Value *V) {
156  if (V) {
157  OS << "\"";
158  V->print(OS, true);
159  OS << "\"\n";
160  } else
161  OS << "nullptr\n";
162  };
163  auto PrintNodeRef = [&](RawNodePtr Ptr) {
164  if (Ptr)
165  OS << Ptr << "\n";
166  else
167  OS << "nullptr\n";
168  };
169 
170  OS << "- CompositeNode: " << this << "\n";
171  OS << " Real: ";
172  PrintValue(Real);
173  OS << " Imag: ";
174  PrintValue(Imag);
175  OS << " ReplacementNode: ";
176  PrintValue(ReplacementNode);
177  OS << " Operation: " << (int)Operation << "\n";
178  OS << " Rotation: " << ((int)Rotation * 90) << "\n";
179  OS << " Operands: \n";
180  for (const auto &Op : Operands) {
181  OS << " - ";
182  PrintNodeRef(Op);
183  }
184  OS << " InternalInstructions:\n";
185  for (const auto &I : InternalInstructions) {
186  OS << " - \"";
187  I->print(OS, true);
188  OS << "\"\n";
189  }
190  }
191 };
192 
193 class ComplexDeinterleavingGraph {
194 public:
195  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
196  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
197  explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
198 
199 private:
200  const TargetLowering *TL;
201  Instruction *RootValue;
202  NodePtr RootNode;
203  SmallVector<NodePtr> CompositeNodes;
204  SmallPtrSet<Instruction *, 16> AllInstructions;
205 
206  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
207  Instruction *R, Instruction *I) {
208  return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
209  I);
210  }
211 
212  NodePtr submitCompositeNode(NodePtr Node) {
213  CompositeNodes.push_back(Node);
214  AllInstructions.insert(Node->Real);
215  AllInstructions.insert(Node->Imag);
216  for (auto *I : Node->InternalInstructions)
217  AllInstructions.insert(I);
218  return Node;
219  }
220 
221  NodePtr getContainingComposite(Value *R, Value *I) {
222  for (const auto &CN : CompositeNodes) {
223  if (CN->Real == R && CN->Imag == I)
224  return CN;
225  }
226  return nullptr;
227  }
228 
229  /// Identifies a complex partial multiply pattern and its rotation, based on
230  /// the following patterns
231  ///
232  /// 0: r: cr + ar * br
233  /// i: ci + ar * bi
234  /// 90: r: cr - ai * bi
235  /// i: ci + ai * br
236  /// 180: r: cr - ar * br
237  /// i: ci - ar * bi
238  /// 270: r: cr + ai * bi
239  /// i: ci - ai * br
240  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
241 
242  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
243  /// is partially known from identifyPartialMul, filling in the other half of
244  /// the complex pair.
245  NodePtr identifyNodeWithImplicitAdd(
247  std::pair<Instruction *, Instruction *> &CommonOperandI);
248 
249  /// Identifies a complex add pattern and its rotation, based on the following
250  /// patterns.
251  ///
252  /// 90: r: ar - bi
253  /// i: ai + br
254  /// 270: r: ar + bi
255  /// i: ai - br
256  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
257 
258  NodePtr identifyNode(Instruction *I, Instruction *J);
259 
260  Value *replaceNode(RawNodePtr Node);
261 
262 public:
263  void dump() { dump(dbgs()); }
264  void dump(raw_ostream &OS) {
265  for (const auto &Node : CompositeNodes)
266  Node->dump(OS);
267  }
268 
269  /// Returns false if the deinterleaving operation should be cancelled for the
270  /// current graph.
271  bool identifyNodes(Instruction *RootI);
272 
273  /// Perform the actual replacement of the underlying instruction graph.
274  /// Returns false if the deinterleaving operation should be cancelled for the
275  /// current graph.
276  void replaceNodes();
277 };
278 
279 class ComplexDeinterleaving {
280 public:
281  ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
282  : TL(tl), TLI(tli) {}
283  bool runOnFunction(Function &F);
284 
285 private:
286  bool evaluateBasicBlock(BasicBlock *B);
287 
288  const TargetLowering *TL = nullptr;
289  const TargetLibraryInfo *TLI = nullptr;
290 };
291 
292 } // namespace
293 
295 
296 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
297  "Complex Deinterleaving", false, false)
298 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
300 
303  const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
304  auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
305  if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
306  return PreservedAnalyses::all();
307 
310  return PA;
311 }
312 
314  return new ComplexDeinterleavingLegacyPass(TM);
315 }
316 
318  const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
319  auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
320  return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
321 }
322 
325  LLVM_DEBUG(
326  dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
327  return false;
328  }
329 
331  LLVM_DEBUG(
332  dbgs() << "Complex deinterleaving has been disabled, target does "
333  "not support lowering of complex number operations.\n");
334  return false;
335  }
336 
337  bool Changed = false;
338  for (auto &B : F)
339  Changed |= evaluateBasicBlock(&B);
340 
341  return Changed;
342 }
343 
345  // If the size is not even, it's not an interleaving mask
346  if ((Mask.size() & 1))
347  return false;
348 
349  int HalfNumElements = Mask.size() / 2;
350  for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
351  int MaskIdx = Idx * 2;
352  if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
353  return false;
354  }
355 
356  return true;
357 }
358 
360  int Offset = Mask[0];
361  int HalfNumElements = Mask.size() / 2;
362 
363  for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
364  if (Mask[Idx] != (Idx * 2) + Offset)
365  return false;
366  }
367 
368  return true;
369 }
370 
371 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
372  bool Changed = false;
373 
374  SmallVector<Instruction *> DeadInstrRoots;
375 
376  for (auto &I : *B) {
377  auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
378  if (!SVI)
379  continue;
380 
381  // Look for a shufflevector that takes separate vectors of the real and
382  // imaginary components and recombines them into a single vector.
383  if (!isInterleavingMask(SVI->getShuffleMask()))
384  continue;
385 
386  ComplexDeinterleavingGraph Graph(TL);
387  if (!Graph.identifyNodes(SVI))
388  continue;
389 
390  Graph.replaceNodes();
391  DeadInstrRoots.push_back(SVI);
392  Changed = true;
393  }
394 
395  for (const auto &I : DeadInstrRoots) {
396  if (!I || I->getParent() == nullptr)
397  continue;
399  }
400 
401  return Changed;
402 }
403 
404 ComplexDeinterleavingGraph::NodePtr
405 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
406  Instruction *Real, Instruction *Imag,
407  std::pair<Instruction *, Instruction *> &PartialMatch) {
408  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
409  << "\n");
410 
411  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
412  LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
413  return nullptr;
414  }
415 
416  if (Real->getOpcode() != Instruction::FMul ||
417  Imag->getOpcode() != Instruction::FMul) {
418  LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
419  return nullptr;
420  }
421 
422  Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
423  Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
424  Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
425  Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
426  if (!R0 || !R1 || !I0 || !I1) {
427  LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
428  return nullptr;
429  }
430 
431  // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
432  // rotations and use the operand.
433  unsigned Negs = 0;
435  if (R0->getOpcode() == Instruction::FNeg ||
436  R1->getOpcode() == Instruction::FNeg) {
437  Negs |= 1;
438  if (R0->getOpcode() == Instruction::FNeg) {
439  FNegs.push_back(R0);
440  R0 = dyn_cast<Instruction>(R0->getOperand(0));
441  } else {
442  FNegs.push_back(R1);
443  R1 = dyn_cast<Instruction>(R1->getOperand(0));
444  }
445  if (!R0 || !R1)
446  return nullptr;
447  }
448  if (I0->getOpcode() == Instruction::FNeg ||
449  I1->getOpcode() == Instruction::FNeg) {
450  Negs |= 2;
451  Negs ^= 1;
452  if (I0->getOpcode() == Instruction::FNeg) {
453  FNegs.push_back(I0);
454  I0 = dyn_cast<Instruction>(I0->getOperand(0));
455  } else {
456  FNegs.push_back(I1);
457  I1 = dyn_cast<Instruction>(I1->getOperand(0));
458  }
459  if (!I0 || !I1)
460  return nullptr;
461  }
462 
464 
465  Instruction *CommonOperand;
466  Instruction *UncommonRealOp;
467  Instruction *UncommonImagOp;
468 
469  if (R0 == I0 || R0 == I1) {
470  CommonOperand = R0;
471  UncommonRealOp = R1;
472  } else if (R1 == I0 || R1 == I1) {
473  CommonOperand = R1;
474  UncommonRealOp = R0;
475  } else {
476  LLVM_DEBUG(dbgs() << " - No equal operand\n");
477  return nullptr;
478  }
479 
480  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
483  std::swap(UncommonRealOp, UncommonImagOp);
484 
485  // Between identifyPartialMul and here we need to have found a complete valid
486  // pair from the CommonOperand of each part.
489  PartialMatch.first = CommonOperand;
490  else
491  PartialMatch.second = CommonOperand;
492 
493  if (!PartialMatch.first || !PartialMatch.second) {
494  LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
495  return nullptr;
496  }
497 
498  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
499  if (!CommonNode) {
500  LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
501  return nullptr;
502  }
503 
504  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
505  if (!UncommonNode) {
506  LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
507  return nullptr;
508  }
509 
510  NodePtr Node = prepareCompositeNode(
512  Node->Rotation = Rotation;
513  Node->addOperand(CommonNode);
514  Node->addOperand(UncommonNode);
515  Node->InternalInstructions.append(FNegs);
516  return submitCompositeNode(Node);
517 }
518 
519 ComplexDeinterleavingGraph::NodePtr
520 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
521  Instruction *Imag) {
522  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
523  << "\n");
524  // Determine rotation
526  if (Real->getOpcode() == Instruction::FAdd &&
527  Imag->getOpcode() == Instruction::FAdd)
529  else if (Real->getOpcode() == Instruction::FSub &&
530  Imag->getOpcode() == Instruction::FAdd)
532  else if (Real->getOpcode() == Instruction::FSub &&
533  Imag->getOpcode() == Instruction::FSub)
535  else if (Real->getOpcode() == Instruction::FAdd &&
536  Imag->getOpcode() == Instruction::FSub)
538  else {
539  LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
540  return nullptr;
541  }
542 
543  if (!Real->getFastMathFlags().allowContract() ||
544  !Imag->getFastMathFlags().allowContract()) {
545  LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
546  return nullptr;
547  }
548 
549  Value *CR = Real->getOperand(0);
550  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
551  if (!RealMulI)
552  return nullptr;
553  Value *CI = Imag->getOperand(0);
554  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
555  if (!ImagMulI)
556  return nullptr;
557 
558  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
559  LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
560  return nullptr;
561  }
562 
563  Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
564  Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
565  Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
566  Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
567  if (!R0 || !R1 || !I0 || !I1) {
568  LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
569  return nullptr;
570  }
571 
572  Instruction *CommonOperand;
573  Instruction *UncommonRealOp;
574  Instruction *UncommonImagOp;
575 
576  if (R0 == I0 || R0 == I1) {
577  CommonOperand = R0;
578  UncommonRealOp = R1;
579  } else if (R1 == I0 || R1 == I1) {
580  CommonOperand = R1;
581  UncommonRealOp = R0;
582  } else {
583  LLVM_DEBUG(dbgs() << " - No equal operand\n");
584  return nullptr;
585  }
586 
587  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
590  std::swap(UncommonRealOp, UncommonImagOp);
591 
592  std::pair<Instruction *, Instruction *> PartialMatch(
595  ? CommonOperand
596  : nullptr,
599  ? CommonOperand
600  : nullptr);
601  NodePtr CNode = identifyNodeWithImplicitAdd(
602  cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch);
603  if (!CNode) {
604  LLVM_DEBUG(dbgs() << " - No cnode identified\n");
605  return nullptr;
606  }
607 
608  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
609  if (!UncommonRes) {
610  LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
611  return nullptr;
612  }
613 
614  assert(PartialMatch.first && PartialMatch.second);
615  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
616  if (!CommonRes) {
617  LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
618  return nullptr;
619  }
620 
621  NodePtr Node = prepareCompositeNode(
623  Node->addInstruction(RealMulI);
624  Node->addInstruction(ImagMulI);
625  Node->Rotation = Rotation;
626  Node->addOperand(CommonRes);
627  Node->addOperand(UncommonRes);
628  Node->addOperand(CNode);
629  return submitCompositeNode(Node);
630 }
631 
632 ComplexDeinterleavingGraph::NodePtr
633 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
634  LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
635 
636  // Determine rotation
638  if ((Real->getOpcode() == Instruction::FSub &&
639  Imag->getOpcode() == Instruction::FAdd) ||
640  (Real->getOpcode() == Instruction::Sub &&
641  Imag->getOpcode() == Instruction::Add))
643  else if ((Real->getOpcode() == Instruction::FAdd &&
644  Imag->getOpcode() == Instruction::FSub) ||
645  (Real->getOpcode() == Instruction::Add &&
646  Imag->getOpcode() == Instruction::Sub))
648  else {
649  LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
650  return nullptr;
651  }
652 
653  auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
654  auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
655  auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
656  auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
657 
658  if (!AR || !AI || !BR || !BI) {
659  LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
660  return nullptr;
661  }
662 
663  NodePtr ResA = identifyNode(AR, AI);
664  if (!ResA) {
665  LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
666  return nullptr;
667  }
668  NodePtr ResB = identifyNode(BR, BI);
669  if (!ResB) {
670  LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
671  return nullptr;
672  }
673 
674  NodePtr Node =
675  prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
676  Node->Rotation = Rotation;
677  Node->addOperand(ResA);
678  Node->addOperand(ResB);
679  return submitCompositeNode(Node);
680 }
681 
683  unsigned OpcA = A->getOpcode();
684  unsigned OpcB = B->getOpcode();
685 
686  return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
687  (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
688  (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
689  (OpcA == Instruction::Add && OpcB == Instruction::Sub);
690 }
691 
693  auto Pattern =
695 
696  return match(A, Pattern) && match(B, Pattern);
697 }
698 
699 ComplexDeinterleavingGraph::NodePtr
700 ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
701  LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
702  if (NodePtr CN = getContainingComposite(Real, Imag)) {
703  LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
704  return CN;
705  }
706 
707  auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
708  auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
709  if (RealShuffle && ImagShuffle) {
710  Value *RealOp1 = RealShuffle->getOperand(1);
711  if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
712  LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
713  return nullptr;
714  }
715  Value *ImagOp1 = ImagShuffle->getOperand(1);
716  if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
717  LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
718  return nullptr;
719  }
720 
721  Value *RealOp0 = RealShuffle->getOperand(0);
722  Value *ImagOp0 = ImagShuffle->getOperand(0);
723 
724  if (RealOp0 != ImagOp0) {
725  LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
726  return nullptr;
727  }
728 
729  ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
730  ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
731  if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
732  LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
733  return nullptr;
734  }
735 
736  if (RealMask[0] != 0 || ImagMask[0] != 1) {
737  LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
738  return nullptr;
739  }
740 
741  // Type checking, the shuffle type should be a vector type of the same
742  // scalar type, but half the size
743  auto CheckType = [&](ShuffleVectorInst *Shuffle) {
744  Value *Op = Shuffle->getOperand(0);
745  auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
746  auto *OpTy = cast<FixedVectorType>(Op->getType());
747 
748  if (OpTy->getScalarType() != ShuffleTy->getScalarType())
749  return false;
750  if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
751  return false;
752 
753  return true;
754  };
755 
756  auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
757  if (!CheckType(Shuffle))
758  return false;
759 
760  ArrayRef<int> Mask = Shuffle->getShuffleMask();
761  int Last = *Mask.rbegin();
762 
763  Value *Op = Shuffle->getOperand(0);
764  auto *OpTy = cast<FixedVectorType>(Op->getType());
765  int NumElements = OpTy->getNumElements();
766 
767  // Ensure that the deinterleaving shuffle only pulls from the first
768  // shuffle operand.
769  return Last < NumElements;
770  };
771 
772  if (RealShuffle->getType() != ImagShuffle->getType()) {
773  LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
774  return nullptr;
775  }
776  if (!CheckDeinterleavingShuffle(RealShuffle)) {
777  LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
778  return nullptr;
779  }
780  if (!CheckDeinterleavingShuffle(ImagShuffle)) {
781  LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
782  return nullptr;
783  }
784 
785  NodePtr PlaceholderNode =
787  RealShuffle, ImagShuffle);
788  PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
789  return submitCompositeNode(PlaceholderNode);
790  }
791  if (RealShuffle || ImagShuffle)
792  return nullptr;
793 
794  auto *VTy = cast<FixedVectorType>(Real->getType());
795  auto *NewVTy =
796  FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2);
797 
800  isInstructionPairMul(Real, Imag)) {
801  return identifyPartialMul(Real, Imag);
802  }
803 
806  isInstructionPairAdd(Real, Imag)) {
807  return identifyAdd(Real, Imag);
808  }
809 
810  return nullptr;
811 }
812 
813 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
814  Instruction *Real;
815  Instruction *Imag;
816  if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
817  return false;
818 
819  RootValue = RootI;
820  AllInstructions.insert(RootI);
821  RootNode = identifyNode(Real, Imag);
822 
823  LLVM_DEBUG({
824  Function *F = RootI->getFunction();
825  BasicBlock *B = RootI->getParent();
826  dbgs() << "Complex deinterleaving graph for " << F->getName()
827  << "::" << B->getName() << ".\n";
828  dump(dbgs());
829  dbgs() << "\n";
830  });
831 
832  // Check all instructions have internal uses
833  for (const auto &Node : CompositeNodes) {
834  if (!Node->hasAllInternalUses(AllInstructions)) {
835  LLVM_DEBUG(dbgs() << " - Invalid internal uses\n");
836  return false;
837  }
838  }
839  return RootNode != nullptr;
840 }
841 
842 Value *ComplexDeinterleavingGraph::replaceNode(
843  ComplexDeinterleavingGraph::RawNodePtr Node) {
844  if (Node->ReplacementNode)
845  return Node->ReplacementNode;
846 
847  Value *Input0 = replaceNode(Node->Operands[0]);
848  Value *Input1 = replaceNode(Node->Operands[1]);
849  Value *Accumulator =
850  Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
851 
852  assert(Input0->getType() == Input1->getType() &&
853  "Node inputs need to be of the same type");
854 
855  Node->ReplacementNode = TL->createComplexDeinterleavingIR(
856  Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
857 
858  assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
859  NumComplexTransformations += 1;
860  return Node->ReplacementNode;
861 }
862 
863 void ComplexDeinterleavingGraph::replaceNodes() {
864  Value *R = replaceNode(RootNode.get());
865  assert(R && "Unable to find replacement for RootValue");
866  RootValue->replaceAllUsesWith(R);
867 }
868 
869 bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
870  SmallPtrSet<Instruction *, 16> &AllInstructions) {
872  return true;
873 
874  for (auto *User : Real->users()) {
875  if (!AllInstructions.contains(cast<Instruction>(User)))
876  return false;
877  }
878  for (auto *User : Imag->users()) {
879  if (!AllInstructions.contains(cast<Instruction>(User)))
880  return false;
881  }
882  for (auto *I : InternalInstructions) {
883  for (auto *User : I->users()) {
884  if (!AllInstructions.contains(cast<Instruction>(User)))
885  return false;
886  }
887  }
888  return true;
889 }
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
llvm::RecursivelyDeleteTriviallyDeadInstructions
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:537
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
llvm::AArch64PACKey::ID
ID
Definition: AArch64BaseInfo.h:824
llvm::dxil::ParameterKind::I1
@ I1
llvm::Value::hasOneUse
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
llvm::Function
Definition: Function.h:59
llvm::ComplexDeinterleavingRotation::Rotation_0
@ Rotation_0
llvm::PseudoProbeReservedId::Last
@ Last
llvm::ComplexDeinterleavingRotation::Rotation_90
@ Rotation_90
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1199
Statistic.h
CheckType
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(const unsigned char *MatcherTable, unsigned &MatcherIndex, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
Definition: SelectionDAGISel.cpp:2591
llvm::initializeComplexDeinterleavingLegacyPassPass
void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
Local.h
llvm::cl::Hidden
@ Hidden
Definition: CommandLine.h:138
FMAInstKind::Accumulator
@ Accumulator
llvm::PatternMatch::m_BinOp
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:84
llvm::SmallPtrSet
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
isInstructionPairMul
static bool isInstructionPairMul(Instruction *A, Instruction *B)
Definition: ComplexDeinterleavingPass.cpp:692
llvm::dump
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
Definition: SparseBitVector.h:877
isInterleavingMask
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
Definition: ComplexDeinterleavingPass.cpp:344
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:101
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::createComplexDeinterleavingPass
FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
Definition: ComplexDeinterleavingPass.cpp:313
llvm::RISCVFenceField::R
@ R
Definition: RISCVBaseInfo.h:275
llvm::ComplexDeinterleavingOperation
ComplexDeinterleavingOperation
Definition: ComplexDeinterleavingPass.h:36
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:55
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
llvm::BitmaskEnumDetail::Mask
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:80
TargetLowering.h
llvm::Instruction::getOpcode
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:168
llvm::ComplexDeinterleavingPass
Definition: ComplexDeinterleavingPass.h:25
TargetMachine.h
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:24
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
Operation
PowerPC Reduce CR logical Operation
Definition: PPCReduceCRLogicals.cpp:735
llvm::User
Definition: User.h:44
isDeinterleavingMask
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
Definition: ComplexDeinterleavingPass.cpp:359
int
Clang compiles this i1 i64 store i64 i64 store i64 i64 store i64 i64 store i64 align Which gets codegen d xmm0 movaps rbp movaps rbp movaps rbp movaps rbp rbp rbp rbp rbp It would be better to have movq s of instead of the movaps s LLVM produces ret int
Definition: README.txt:536
llvm::TargetLowering
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Definition: TargetLowering.h:3510
llvm::TargetLoweringBase::isComplexDeinterleavingOperationSupported
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
Definition: TargetLowering.h:3143
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
TargetLibraryInfo.h
llvm::PatternMatch::m_Instruction
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
Definition: PatternMatch.h:716
false
Definition: StackSlotColoring.cpp:141
B
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
isInstructionPairAdd
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
Definition: ComplexDeinterleavingPass.cpp:682
llvm::Instruction
Definition: Instruction.h:41
llvm::STATISTIC
STATISTIC(NumFunctions, "Total number of functions")
llvm::raw_ostream
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, "Complex Deinterleaving", false, false) INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:686
llvm::CallingConv::ID
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
llvm::dxil::PointerTypeAnalysis::run
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
Definition: PointerTypeAnalysis.cpp:189
llvm::ComplexDeinterleavingRotation::Rotation_270
@ Rotation_270
Operands
mir Rename Register Operands
Definition: MIRNamerPass.cpp:74
llvm::ComplexDeinterleavingRotation
ComplexDeinterleavingRotation
Definition: ComplexDeinterleavingPass.h:44
llvm::cl::opt< bool >
ComplexDeinterleavingEnabled
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
ComplexDeinterleavingPass.h
llvm::TargetLibraryInfoWrapperPass
Definition: TargetLibraryInfo.h:565
llvm::PreservedAnalyses::preserve
void preserve()
Mark an analysis as preserved.
Definition: PassManager.h:173
DEBUG_TYPE
#define DEBUG_TYPE
Definition: ComplexDeinterleavingPass.cpp:73
I
#define I(x, y, z)
Definition: MD5.cpp:58
llvm::cl::init
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:445
llvm::TargetLoweringBase::createComplexDeinterleavingIR
virtual Value * createComplexDeinterleavingIR(Instruction *I, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
Definition: TargetLowering.h:3151
TargetPassConfig.h
IRBuilder.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::TargetMachine
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:78
addOperand
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Definition: AMDGPUDisassembler.cpp:59
std::swap
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:853
llvm::Instruction::getFastMathFlags
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
Definition: Instruction.cpp:339
Ptr
@ Ptr
Definition: TargetLibraryInfo.cpp:62
llvm::PatternMatch::m_Value
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
llvm::ArrayRef< int >
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:265
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::Instruction::getFunction
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:74
llvm::PatternMatch::m_Shuffle
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
Definition: PatternMatch.h:1551
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:532
TargetSubtargetInfo.h
llvm::ComplexDeinterleavingOperation::Shuffle
@ Shuffle
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:85
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:354
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
llvm::ISD::BR
@ BR
Control flow instructions. These all have token chains.
Definition: ISDOpcodes.h:981
llvm::TargetLibraryInfo
Provides information about what library functions are available for the current target.
Definition: TargetLibraryInfo.h:234
llvm::TargetLoweringBase::isComplexDeinterleavingSupported
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
Definition: TargetLowering.h:3139
llvm::ComplexDeinterleavingRotation::Rotation_180
@ Rotation_180
llvm::MCID::Add
@ Add
Definition: MCInstrDesc.h:186
Deinterleaving
Complex Deinterleaving
Definition: ComplexDeinterleavingPass.cpp:299
llvm::ShuffleVectorInst
This instruction constructs a fixed permutation of two input vectors.
Definition: Instructions.h:2017
llvm::Pattern
Definition: FileCheckImpl.h:614
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:90
llvm::ComplexDeinterleavingOperation::CAdd
@ CAdd
TargetTransformInfo.h
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:42
TM
const char LLVMTargetMachineRef TM
Definition: PassBuilderBindings.cpp:47
llvm::InnerAnalysisManagerProxy
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:931
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:308
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
llvm::FastMathFlags::allowContract
bool allowContract() const
Definition: FMF.h:71
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition: User.h:169
llvm::cl::desc
Definition: CommandLine.h:411
llvm::SmallPtrSetImpl::contains
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:389
llvm::PatternMatch::m_FMul
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
Definition: PatternMatch.h:1051
InitializePasses.h
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
llvm::TargetLibraryAnalysis
Analysis pass providing the TargetLibraryInfo.
Definition: TargetLibraryInfo.h:540
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:421
llvm::ComplexDeinterleavingOperation::CMulPartial
@ CMulPartial
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365