71 using namespace PatternMatch;
73 #define DEBUG_TYPE "complex-deinterleaving"
75 STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
78 "enable-complex-deinterleaving",
99 class ComplexDeinterleavingLegacyPass :
public FunctionPass {
110 return "Complex Deinterleaving Pass";
123 class ComplexDeinterleavingGraph;
124 struct ComplexDeinterleavingCompositeNode {
131 friend class ComplexDeinterleavingGraph;
132 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
133 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
146 Value *ReplacementNode =
nullptr;
148 void addInstruction(
Instruction *
I) { InternalInstructions.push_back(
I); }
155 auto PrintValue = [&](
Value *V) {
163 auto PrintNodeRef = [&](RawNodePtr
Ptr) {
170 OS <<
"- CompositeNode: " <<
this <<
"\n";
175 OS <<
" ReplacementNode: ";
176 PrintValue(ReplacementNode);
178 OS <<
" Rotation: " << ((
int)Rotation * 90) <<
"\n";
179 OS <<
" Operands: \n";
184 OS <<
" InternalInstructions:\n";
185 for (
const auto &
I : InternalInstructions) {
193 class ComplexDeinterleavingGraph {
195 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
196 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
197 explicit ComplexDeinterleavingGraph(
const TargetLowering *tl) : TL(tl) {}
208 return std::make_shared<ComplexDeinterleavingCompositeNode>(
Operation, R,
212 NodePtr submitCompositeNode(NodePtr Node) {
213 CompositeNodes.push_back(Node);
214 AllInstructions.
insert(Node->Real);
215 AllInstructions.
insert(Node->Imag);
216 for (
auto *
I : Node->InternalInstructions)
221 NodePtr getContainingComposite(
Value *R,
Value *
I) {
222 for (
const auto &CN : CompositeNodes) {
223 if (CN->Real == R && CN->Imag ==
I)
245 NodePtr identifyNodeWithImplicitAdd(
247 std::pair<Instruction *, Instruction *> &CommonOperandI);
260 Value *replaceNode(RawNodePtr Node);
265 for (
const auto &Node : CompositeNodes)
279 class ComplexDeinterleaving {
282 : TL(tl), TLI(tli) {}
297 "Complex Deinterleaving",
false,
false)
305 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(
F))
314 return new ComplexDeinterleavingLegacyPass(
TM);
318 const auto *TL =
TM->getSubtargetImpl(
F)->getTargetLowering();
319 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
320 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
326 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
332 dbgs() <<
"Complex deinterleaving has been disabled, target does "
333 "not support lowering of complex number operations.\n");
337 bool Changed =
false;
339 Changed |= evaluateBasicBlock(&
B);
346 if ((
Mask.size() & 1))
349 int HalfNumElements =
Mask.size() / 2;
350 for (
int Idx = 0; Idx < HalfNumElements; ++Idx) {
351 int MaskIdx = Idx * 2;
352 if (
Mask[MaskIdx] != Idx ||
Mask[MaskIdx + 1] != (Idx + HalfNumElements))
360 int Offset =
Mask[0];
361 int HalfNumElements =
Mask.size() / 2;
363 for (
int Idx = 1; Idx < HalfNumElements; ++Idx) {
364 if (
Mask[Idx] != (Idx * 2) + Offset)
371 bool ComplexDeinterleaving::evaluateBasicBlock(
BasicBlock *
B) {
372 bool Changed =
false;
377 auto *SVI = dyn_cast<ShuffleVectorInst>(&
I);
386 ComplexDeinterleavingGraph Graph(TL);
387 if (!Graph.identifyNodes(SVI))
390 Graph.replaceNodes();
391 DeadInstrRoots.push_back(SVI);
395 for (
const auto &
I : DeadInstrRoots) {
396 if (!
I ||
I->getParent() ==
nullptr)
404 ComplexDeinterleavingGraph::NodePtr
405 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
407 std::pair<Instruction *, Instruction *> &PartialMatch) {
408 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
416 if (Real->
getOpcode() != Instruction::FMul ||
417 Imag->
getOpcode() != Instruction::FMul) {
418 LLVM_DEBUG(
dbgs() <<
" - Real or imaginary instruction is not fmul\n");
426 if (!R0 || !R1 || !I0 || !I1) {
435 if (R0->
getOpcode() == Instruction::FNeg ||
438 if (R0->
getOpcode() == Instruction::FNeg) {
440 R0 = dyn_cast<Instruction>(R0->
getOperand(0));
443 R1 = dyn_cast<Instruction>(R1->
getOperand(0));
448 if (I0->
getOpcode() == Instruction::FNeg ||
449 I1->getOpcode() == Instruction::FNeg) {
452 if (I0->
getOpcode() == Instruction::FNeg) {
454 I0 = dyn_cast<Instruction>(I0->
getOperand(0));
457 I1 = dyn_cast<Instruction>(
I1->getOperand(0));
469 if (R0 == I0 || R0 == I1) {
472 }
else if (R1 == I0 || R1 == I1) {
480 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
483 std::swap(UncommonRealOp, UncommonImagOp);
489 PartialMatch.first = CommonOperand;
491 PartialMatch.second = CommonOperand;
493 if (!PartialMatch.first || !PartialMatch.second) {
498 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
504 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
510 NodePtr Node = prepareCompositeNode(
512 Node->Rotation = Rotation;
513 Node->addOperand(CommonNode);
514 Node->addOperand(UncommonNode);
515 Node->InternalInstructions.append(FNegs);
516 return submitCompositeNode(Node);
519 ComplexDeinterleavingGraph::NodePtr
520 ComplexDeinterleavingGraph::identifyPartialMul(
Instruction *Real,
522 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
526 if (Real->
getOpcode() == Instruction::FAdd &&
529 else if (Real->
getOpcode() == Instruction::FSub &&
532 else if (Real->
getOpcode() == Instruction::FSub &&
535 else if (Real->
getOpcode() == Instruction::FAdd &&
545 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
567 if (!R0 || !R1 || !I0 || !I1) {
576 if (R0 == I0 || R0 == I1) {
579 }
else if (R1 == I0 || R1 == I1) {
587 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
590 std::swap(UncommonRealOp, UncommonImagOp);
592 std::pair<Instruction *, Instruction *> PartialMatch(
601 NodePtr CNode = identifyNodeWithImplicitAdd(
602 cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch);
608 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
614 assert(PartialMatch.first && PartialMatch.second);
615 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
621 NodePtr Node = prepareCompositeNode(
623 Node->addInstruction(RealMulI);
624 Node->addInstruction(ImagMulI);
625 Node->Rotation = Rotation;
626 Node->addOperand(CommonRes);
627 Node->addOperand(UncommonRes);
628 Node->addOperand(CNode);
629 return submitCompositeNode(Node);
632 ComplexDeinterleavingGraph::NodePtr
634 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
638 if ((Real->
getOpcode() == Instruction::FSub &&
639 Imag->
getOpcode() == Instruction::FAdd) ||
640 (Real->
getOpcode() == Instruction::Sub &&
643 else if ((Real->
getOpcode() == Instruction::FAdd &&
644 Imag->
getOpcode() == Instruction::FSub) ||
649 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
653 auto *AR = dyn_cast<Instruction>(Real->
getOperand(0));
654 auto *BI = dyn_cast<Instruction>(Real->
getOperand(1));
655 auto *AI = dyn_cast<Instruction>(Imag->
getOperand(0));
658 if (!AR || !AI || !
BR || !BI) {
663 NodePtr ResA = identifyNode(AR, AI);
665 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
668 NodePtr ResB = identifyNode(
BR, BI);
670 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
676 Node->Rotation = Rotation;
677 Node->addOperand(ResA);
678 Node->addOperand(ResB);
679 return submitCompositeNode(Node);
683 unsigned OpcA = A->getOpcode();
684 unsigned OpcB =
B->getOpcode();
686 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
687 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
699 ComplexDeinterleavingGraph::NodePtr
701 LLVM_DEBUG(
dbgs() <<
"identifyNode on " << *Real <<
" / " << *Imag <<
"\n");
702 if (NodePtr CN = getContainingComposite(Real, Imag)) {
707 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
708 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
709 if (RealShuffle && ImagShuffle) {
710 Value *RealOp1 = RealShuffle->getOperand(1);
711 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
715 Value *ImagOp1 = ImagShuffle->getOperand(1);
716 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
721 Value *RealOp0 = RealShuffle->getOperand(0);
722 Value *ImagOp0 = ImagShuffle->getOperand(0);
724 if (RealOp0 != ImagOp0) {
736 if (RealMask[0] != 0 || ImagMask[0] != 1) {
737 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
745 auto *ShuffleTy = cast<FixedVectorType>(
Shuffle->getType());
746 auto *OpTy = cast<FixedVectorType>(
Op->getType());
748 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
750 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
764 auto *OpTy = cast<FixedVectorType>(
Op->getType());
765 int NumElements = OpTy->getNumElements();
769 return Last < NumElements;
772 if (RealShuffle->getType() != ImagShuffle->getType()) {
776 if (!CheckDeinterleavingShuffle(RealShuffle)) {
780 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
785 NodePtr PlaceholderNode =
787 RealShuffle, ImagShuffle);
788 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
789 return submitCompositeNode(PlaceholderNode);
791 if (RealShuffle || ImagShuffle)
794 auto *VTy = cast<FixedVectorType>(Real->
getType());
801 return identifyPartialMul(Real, Imag);
807 return identifyAdd(Real, Imag);
813 bool ComplexDeinterleavingGraph::identifyNodes(
Instruction *RootI) {
820 AllInstructions.
insert(RootI);
821 RootNode = identifyNode(Real, Imag);
826 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
827 <<
"::" <<
B->getName() <<
".\n";
833 for (
const auto &Node : CompositeNodes) {
834 if (!Node->hasAllInternalUses(AllInstructions)) {
839 return RootNode !=
nullptr;
842 Value *ComplexDeinterleavingGraph::replaceNode(
843 ComplexDeinterleavingGraph::RawNodePtr Node) {
844 if (Node->ReplacementNode)
845 return Node->ReplacementNode;
847 Value *Input0 = replaceNode(Node->Operands[0]);
848 Value *Input1 = replaceNode(Node->Operands[1]);
850 Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
853 "Node inputs need to be of the same type");
856 Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
858 assert(Node->ReplacementNode &&
"Target failed to create Intrinsic call.");
859 NumComplexTransformations += 1;
860 return Node->ReplacementNode;
863 void ComplexDeinterleavingGraph::replaceNodes() {
864 Value *
R = replaceNode(RootNode.get());
865 assert(R &&
"Unable to find replacement for RootValue");
869 bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
882 for (
auto *
I : InternalInstructions) {
883 for (
auto *
User :
I->users()) {