82#define DEBUG_TYPE "complex-deinterleaving"
84STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
87 "enable-complex-deinterleaving",
115 Value *Real =
nullptr;
116 Value *Imag =
nullptr;
119 return Real ==
Other.Real && Imag ==
Other.Imag;
142 static bool isEqual(
const ComplexValue &LHS,
const ComplexValue &RHS) {
143 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
148template <
typename T,
typename IterT>
149std::optional<T> findCommonBetweenCollections(IterT
A, IterT
B) {
151 if (Common !=
A.end())
152 return std::make_optional(*Common);
156class ComplexDeinterleavingLegacyPass :
public FunctionPass {
160 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
161 : FunctionPass(ID), TM(TM) {}
163 StringRef getPassName()
const override {
164 return "Complex Deinterleaving Pass";
168 void getAnalysisUsage(AnalysisUsage &AU)
const override {
174 const TargetMachine *TM;
177class ComplexDeinterleavingGraph;
178struct ComplexDeinterleavingCompositeNode {
183 Vals.push_back({
R,
I});
188 : Operation(
Op), Vals(
Other) {}
191 friend class ComplexDeinterleavingGraph;
192 using CompositeNode = ComplexDeinterleavingCompositeNode;
193 bool OperandsValid =
true;
202 std::optional<FastMathFlags> Flags;
205 ComplexDeinterleavingRotation::Rotation_0;
207 Value *ReplacementNode =
nullptr;
211 OperandsValid =
false;
212 Operands.push_back(Node);
216 void dump(raw_ostream &OS) {
217 auto PrintValue = [&](
Value *
V) {
225 auto PrintNodeRef = [&](CompositeNode *Ptr) {
232 OS <<
"- CompositeNode: " <<
this <<
"\n";
233 for (
unsigned I = 0;
I < Vals.size();
I++) {
234 OS <<
" Real(" <<
I <<
") : ";
235 PrintValue(Vals[
I].Real);
236 OS <<
" Imag(" <<
I <<
") : ";
237 PrintValue(Vals[
I].Imag);
239 OS <<
" ReplacementNode: ";
240 PrintValue(ReplacementNode);
241 OS <<
" Operation: " << (int)Operation <<
"\n";
242 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
243 OS <<
" Operands: \n";
244 for (
const auto &
Op : Operands) {
250 bool areOperandsValid() {
return OperandsValid; }
253class ComplexDeinterleavingGraph {
261 using Addend = std::pair<Value *, bool>;
263 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
267 struct PartialMulCandidate {
275 explicit ComplexDeinterleavingGraph(
const TargetLowering *TL,
276 const TargetLibraryInfo *TLI,
278 : TL(TL), TLI(TLI), Factor(Factor) {}
281 const TargetLowering *TL =
nullptr;
282 const TargetLibraryInfo *TLI =
nullptr;
285 DenseMap<ComplexValues, CompositeNode *> CachedResult;
286 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
288 SmallPtrSet<Instruction *, 16> FinalInstructions;
291 DenseMap<Instruction *, CompositeNode *> RootToNode;
318 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
326 PHINode *RealPHI =
nullptr;
327 PHINode *ImagPHI =
nullptr;
331 bool PHIsFound =
false;
339 DenseMap<PHINode *, PHINode *> OldToNewPHI;
344 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
346 "Reduction related nodes must have Real and Imaginary parts");
347 return new (Allocator.Allocate())
348 ComplexDeinterleavingCompositeNode(
Operation, R,
I);
354 for (
auto &V : Vals) {
356 ((
Operation != ComplexDeinterleavingOperation::ReductionPHI &&
357 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
358 (
V.Real &&
V.Imag)) &&
359 "Reduction related nodes must have Real and Imaginary parts");
362 return new (Allocator.Allocate())
363 ComplexDeinterleavingCompositeNode(
Operation, Vals);
366 CompositeNode *submitCompositeNode(CompositeNode *Node) {
367 CompositeNodes.push_back(Node);
368 if (
Node->Vals[0].Real)
384 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
390 identifyNodeWithImplicitAdd(Instruction *
I, Instruction *J,
391 std::pair<Value *, Value *> &CommonOperandI);
400 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
401 CompositeNode *identifySymmetricOperation(
ComplexValues &Vals);
402 CompositeNode *identifyPartialReduction(
Value *R,
Value *
I);
403 CompositeNode *identifyDotProduct(
Value *Inst);
410 return identifyNode(Vals);
417 CompositeNode *identifyAdditions(AddendList &RealAddends,
418 AddendList &ImagAddends,
419 std::optional<FastMathFlags> Flags,
423 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
424 AddendList &ImagAddends);
429 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
430 SmallVectorImpl<Product> &ImagMuls,
438 SmallVectorImpl<PartialMulCandidate> &Candidates);
446 CompositeNode *identifyReassocNodes(Instruction *
I, Instruction *J);
448 CompositeNode *identifyRoot(Instruction *
I);
466 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
470 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
472 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
479 void processReductionOperation(
Value *OperationReplacement,
480 CompositeNode *Node);
481 void processReductionSingle(
Value *OperationReplacement, CompositeNode *Node);
485 void dump(raw_ostream &OS) {
486 for (
const auto &Node : CompositeNodes)
492 bool identifyNodes(Instruction *RootI);
497 bool collectPotentialReductions(BasicBlock *
B);
499 void identifyReductionNodes();
509class ComplexDeinterleaving {
511 ComplexDeinterleaving(
const TargetLowering *tl,
const TargetLibraryInfo *tli)
512 : TL(tl), TLI(tli) {}
516 bool evaluateBasicBlock(BasicBlock *
B,
unsigned Factor);
518 const TargetLowering *TL =
nullptr;
519 const TargetLibraryInfo *TLI =
nullptr;
524char ComplexDeinterleavingLegacyPass::ID = 0;
527 "Complex Deinterleaving",
false,
false)
533 const TargetLowering *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
544 return new ComplexDeinterleavingLegacyPass(TM);
547bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
549 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
550 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
553bool ComplexDeinterleaving::runOnFunction(Function &
F) {
556 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
562 dbgs() <<
"Complex deinterleaving has been disabled, target does "
563 "not support lowering of complex number operations.\n");
569 Changed |= evaluateBasicBlock(&
B, 2);
574 Changed |= evaluateBasicBlock(&
B, 4);
584 if ((Mask.size() & 1))
587 int HalfNumElements = Mask.size() / 2;
588 for (
int Idx = 0; Idx < HalfNumElements; ++Idx) {
589 int MaskIdx = Idx * 2;
590 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
599 int HalfNumElements = Mask.size() / 2;
601 for (
int Idx = 1; Idx < HalfNumElements; ++Idx) {
602 if (Mask[Idx] != (Idx * 2) +
Offset)
616 if (
I->getOpcode() == Instruction::FNeg)
617 return I->getOperand(0);
619 return I->getOperand(1);
622bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *
B,
unsigned Factor) {
623 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
624 if (Graph.collectPotentialReductions(
B))
625 Graph.identifyReductionNodes();
628 Graph.identifyNodes(&
I);
630 if (Graph.checkNodes()) {
631 Graph.replaceNodes();
638ComplexDeinterleavingGraph::CompositeNode *
639ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
640 Instruction *Real, Instruction *Imag,
641 std::pair<Value *, Value *> &PartialMatch) {
642 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
650 if ((Real->
getOpcode() != Instruction::FMul &&
651 Real->
getOpcode() != Instruction::Mul) ||
652 (Imag->
getOpcode() != Instruction::FMul &&
653 Imag->
getOpcode() != Instruction::Mul)) {
655 dbgs() <<
" - Real or imaginary instruction is not fmul or mul\n");
688 Value *CommonOperand;
689 Value *UncommonRealOp;
690 Value *UncommonImagOp;
692 if (R0 == I0 || R0 == I1) {
695 }
else if (R1 == I0 || R1 == I1) {
703 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
704 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
705 Rotation == ComplexDeinterleavingRotation::Rotation_270)
706 std::swap(UncommonRealOp, UncommonImagOp);
710 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
711 Rotation == ComplexDeinterleavingRotation::Rotation_180)
712 PartialMatch.first = CommonOperand;
714 PartialMatch.second = CommonOperand;
716 if (!PartialMatch.first || !PartialMatch.second) {
721 CompositeNode *CommonNode =
722 identifyNode(PartialMatch.first, PartialMatch.second);
728 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
734 CompositeNode *
Node = prepareCompositeNode(
735 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
736 Node->Rotation = Rotation;
737 Node->addOperand(CommonNode);
738 Node->addOperand(UncommonNode);
739 return submitCompositeNode(Node);
742ComplexDeinterleavingGraph::CompositeNode *
743ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
745 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
749 auto IsAdd = [](
unsigned Op) {
750 return Op == Instruction::FAdd ||
Op == Instruction::Add;
752 auto IsSub = [](
unsigned Op) {
753 return Op == Instruction::FSub ||
Op == Instruction::Sub;
757 Rotation = ComplexDeinterleavingRotation::Rotation_0;
759 Rotation = ComplexDeinterleavingRotation::Rotation_90;
761 Rotation = ComplexDeinterleavingRotation::Rotation_180;
763 Rotation = ComplexDeinterleavingRotation::Rotation_270;
772 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
795 Value *CommonOperand;
796 Value *UncommonRealOp;
797 Value *UncommonImagOp;
799 if (R0 == I0 || R0 == I1) {
802 }
else if (R1 == I0 || R1 == I1) {
810 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
811 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
812 Rotation == ComplexDeinterleavingRotation::Rotation_270)
813 std::swap(UncommonRealOp, UncommonImagOp);
815 std::pair<Value *, Value *> PartialMatch(
816 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_180)
820 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
821 Rotation == ComplexDeinterleavingRotation::Rotation_270)
828 if (!CRInst || !CIInst) {
829 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
833 CompositeNode *CNode =
834 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
840 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
846 assert(PartialMatch.first && PartialMatch.second);
847 CompositeNode *CommonRes =
848 identifyNode(PartialMatch.first, PartialMatch.second);
854 CompositeNode *
Node = prepareCompositeNode(
855 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
856 Node->Rotation = Rotation;
857 Node->addOperand(CommonRes);
858 Node->addOperand(UncommonRes);
859 Node->addOperand(CNode);
860 return submitCompositeNode(Node);
863ComplexDeinterleavingGraph::CompositeNode *
864ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
865 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
869 if ((Real->
getOpcode() == Instruction::FSub &&
870 Imag->
getOpcode() == Instruction::FAdd) ||
871 (Real->
getOpcode() == Instruction::Sub &&
873 Rotation = ComplexDeinterleavingRotation::Rotation_90;
874 else if ((Real->
getOpcode() == Instruction::FAdd &&
875 Imag->
getOpcode() == Instruction::FSub) ||
876 (Real->
getOpcode() == Instruction::Add &&
878 Rotation = ComplexDeinterleavingRotation::Rotation_270;
880 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
889 if (!AR || !AI || !BR || !BI) {
894 CompositeNode *ResA = identifyNode(AR, AI);
896 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
899 CompositeNode *ResB = identifyNode(BR, BI);
901 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
905 CompositeNode *
Node =
906 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
907 Node->Rotation = Rotation;
908 Node->addOperand(ResA);
909 Node->addOperand(ResB);
910 return submitCompositeNode(Node);
914 unsigned OpcA =
A->getOpcode();
915 unsigned OpcB =
B->getOpcode();
917 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
918 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
919 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
920 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
931 switch (
I->getOpcode()) {
932 case Instruction::FAdd:
933 case Instruction::FSub:
934 case Instruction::FMul:
935 case Instruction::FNeg:
936 case Instruction::Add:
937 case Instruction::Sub:
938 case Instruction::Mul:
945ComplexDeinterleavingGraph::CompositeNode *
946ComplexDeinterleavingGraph::identifySymmetricOperation(
ComplexValues &Vals) {
948 unsigned FirstOpc = FirstReal->getOpcode();
949 for (
auto &V : Vals) {
966 for (
auto &V : Vals) {
972 CompositeNode *Op0 = identifyNode(OpVals);
973 CompositeNode *Op1 =
nullptr;
977 if (FirstReal->isBinaryOp()) {
979 for (
auto &V : Vals) {
984 Op1 = identifyNode(OpVals);
990 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
991 Node->Opcode = FirstReal->getOpcode();
993 Node->Flags = FirstReal->getFastMathFlags();
995 Node->addOperand(Op0);
996 if (FirstReal->isBinaryOp())
997 Node->addOperand(Op1);
999 return submitCompositeNode(Node);
1002ComplexDeinterleavingGraph::CompositeNode *
1003ComplexDeinterleavingGraph::identifyDotProduct(
Value *V) {
1005 ComplexDeinterleavingOperation::CDot,
V->getType())) {
1006 LLVM_DEBUG(
dbgs() <<
"Target doesn't support complex deinterleaving "
1007 "operation CDot with the type "
1008 << *
V->getType() <<
"\n");
1016 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst,
nullptr);
1018 CompositeNode *ANode =
nullptr;
1020 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1022 Value *AReal =
nullptr;
1023 Value *AImag =
nullptr;
1024 Value *BReal =
nullptr;
1025 Value *BImag =
nullptr;
1030 return CI->getOperand(0);
1044 if (
match(Inst, PatternRot0)) {
1045 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1046 }
else if (
match(Inst, PatternRot270)) {
1047 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1058 if (!
match(Inst, PatternRot90Rot180))
1061 A0 = UnwrapCast(A0);
1062 A1 = UnwrapCast(A1);
1065 ANode = identifyNode(A0, A1);
1068 ANode = identifyNode(A1, A0);
1072 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1078 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1082 AReal = UnwrapCast(AReal);
1083 AImag = UnwrapCast(AImag);
1084 BReal = UnwrapCast(BReal);
1085 BImag = UnwrapCast(BImag);
1088 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1089 if (AReal->
getType() != ExpectedOperandTy)
1091 if (AImag->
getType() != ExpectedOperandTy)
1093 if (BReal->
getType() != ExpectedOperandTy)
1095 if (BImag->
getType() != ExpectedOperandTy)
1098 if (
Phi->getType() != VTy && RealUser->getType() != VTy)
1101 CompositeNode *
Node = identifyNode(AReal, AImag);
1106 if (ANode && Node != ANode) {
1109 <<
"Identified node is different from previously identified node. "
1110 "Unable to confidently generate a complex operation node\n");
1114 CN->addOperand(Node);
1115 CN->addOperand(identifyNode(BReal, BImag));
1116 CN->addOperand(identifyNode(Phi, RealUser));
1118 return submitCompositeNode(CN);
1121ComplexDeinterleavingGraph::CompositeNode *
1122ComplexDeinterleavingGraph::identifyPartialReduction(
Value *R,
Value *
I) {
1127 if (!
R->hasUseList() || !
I->hasUseList())
1131 findCommonBetweenCollections<Value *>(
R->users(),
I->users());
1136 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1139 if (CompositeNode *CN = identifyDotProduct(IInst))
1145ComplexDeinterleavingGraph::CompositeNode *
1146ComplexDeinterleavingGraph::identifyNode(
ComplexValues &Vals) {
1147 auto It = CachedResult.
find(Vals);
1148 if (It != CachedResult.
end()) {
1153 if (Vals.
size() == 1) {
1154 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1157 if (CompositeNode *CN = identifyPartialReduction(R,
I))
1159 bool IsReduction = RealPHI ==
R && (!ImagPHI || ImagPHI ==
I);
1160 if (!IsReduction &&
R->getType() !=
I->getType())
1164 if (CompositeNode *CN = identifySplat(Vals))
1167 for (
auto &V : Vals) {
1174 if (CompositeNode *CN = identifyDeinterleave(Vals))
1177 if (Vals.size() == 1) {
1178 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1181 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1184 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1188 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1191 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1193 ComplexDeinterleavingOperation::CAdd, NewVTy);
1196 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1201 if (CompositeNode *CN = identifyAdd(Real, Imag))
1205 if (HasCMulSupport && HasCAddSupport) {
1206 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1212 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1216 CachedResult[Vals] =
nullptr;
1220ComplexDeinterleavingGraph::CompositeNode *
1221ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1222 Instruction *Imag) {
1223 auto IsOperationSupported = [](
unsigned Opcode) ->
bool {
1224 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1225 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1226 Opcode == Instruction::Sub;
1229 if (!IsOperationSupported(Real->
getOpcode()) ||
1230 !IsOperationSupported(Imag->
getOpcode()))
1233 std::optional<FastMathFlags>
Flags;
1236 LLVM_DEBUG(
dbgs() <<
"The flags in Real and Imaginary instructions are "
1242 if (!
Flags->allowReassoc()) {
1245 <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
1254 AddendList &Addends) ->
bool {
1256 SmallPtrSet<Value *, 8> Visited;
1257 while (!Worklist.
empty()) {
1259 if (!Visited.
insert(V).second)
1264 Addends.emplace_back(V, IsPositive);
1274 if (
I != Insn &&
I->hasNUsesOrMore(2)) {
1275 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
1276 Addends.emplace_back(
I, IsPositive);
1279 switch (
I->getOpcode()) {
1280 case Instruction::FAdd:
1281 case Instruction::Add:
1285 case Instruction::FSub:
1289 case Instruction::Sub:
1297 case Instruction::FMul:
1298 case Instruction::Mul: {
1300 if (
isNeg(
I->getOperand(0))) {
1302 IsPositive = !IsPositive;
1304 A =
I->getOperand(0);
1307 if (
isNeg(
I->getOperand(1))) {
1309 IsPositive = !IsPositive;
1311 B =
I->getOperand(1);
1313 Muls.push_back(Product{
A,
B, IsPositive});
1316 case Instruction::FNeg:
1320 Addends.emplace_back(
I, IsPositive);
1324 if (Flags &&
I->getFastMathFlags() != *Flags) {
1326 "inconsistent with the root instructions' flags: "
1335 AddendList RealAddends, ImagAddends;
1336 if (!Collect(Real, RealMuls, RealAddends) ||
1337 !Collect(Imag, ImagMuls, ImagAddends))
1340 if (RealAddends.size() != ImagAddends.size())
1343 CompositeNode *FinalNode =
nullptr;
1344 if (!RealMuls.
empty() || !ImagMuls.
empty()) {
1347 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1348 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1354 if (!RealAddends.empty() || !ImagAddends.empty()) {
1355 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1359 assert(FinalNode &&
"FinalNode can not be nullptr here");
1360 assert(FinalNode->Vals.size() == 1);
1362 FinalNode->Vals[0].Real = Real;
1363 FinalNode->Vals[0].Imag = Imag;
1364 submitCompositeNode(FinalNode);
1368bool ComplexDeinterleavingGraph::collectPartialMuls(
1370 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1372 auto FindCommonInstruction = [](
const Product &Real,
1373 const Product &Imag) ->
Value * {
1374 if (Real.Multiplicand == Imag.Multiplicand ||
1375 Real.Multiplicand == Imag.Multiplier)
1376 return Real.Multiplicand;
1378 if (Real.Multiplier == Imag.Multiplicand ||
1379 Real.Multiplier == Imag.Multiplier)
1380 return Real.Multiplier;
1389 for (
unsigned i = 0; i < RealMuls.
size(); ++i) {
1390 bool FoundCommon =
false;
1391 for (
unsigned j = 0;
j < ImagMuls.
size(); ++
j) {
1392 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1396 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1397 : RealMuls[i].Multiplicand;
1398 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1399 : ImagMuls[
j].Multiplicand;
1401 auto Node = identifyNode(
A,
B);
1407 Node = identifyNode(
B,
A);
1419ComplexDeinterleavingGraph::CompositeNode *
1420ComplexDeinterleavingGraph::identifyMultiplications(
1421 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1423 if (RealMuls.
size() != ImagMuls.
size())
1427 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1431 DenseMap<Value *, CompositeNode *> CommonToNode;
1432 SmallVector<bool> Processed(
Info.size(),
false);
1433 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1437 PartialMulCandidate &InfoA =
Info[
I];
1438 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1442 PartialMulCandidate &InfoB =
Info[J];
1443 auto *InfoReal = &InfoA;
1444 auto *InfoImag = &InfoB;
1446 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1447 if (!NodeFromCommon) {
1449 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1451 if (!NodeFromCommon)
1454 CommonToNode[InfoReal->Common] = NodeFromCommon;
1455 CommonToNode[InfoImag->Common] = NodeFromCommon;
1456 Processed[
I] =
true;
1457 Processed[J] =
true;
1461 SmallVector<bool> ProcessedReal(RealMuls.
size(),
false);
1462 SmallVector<bool> ProcessedImag(ImagMuls.
size(),
false);
1464 for (
auto &PMI : Info) {
1465 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1468 auto It = CommonToNode.
find(PMI.Common);
1471 if (It == CommonToNode.
end()) {
1473 dbgs() <<
"Unprocessed independent partial multiplication:\n";
// Print both products of this unmatched candidate: the real-side product
// RealMuls[PMI.RealIdx] and the imaginary-side product ImagMuls[PMI.ImagIdx].
1474 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &ImagMuls[PMI.ImagIdx]})
1476 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1481 auto &RealMul = RealMuls[PMI.RealIdx];
1482 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1484 auto NodeA = It->second;
1485 auto NodeB = PMI.Node;
1486 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1501 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1502 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1507 if (IsMultiplicandReal) {
1509 if (RealMul.IsPositive && ImagMul.IsPositive)
1511 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1518 if (!RealMul.IsPositive && ImagMul.IsPositive)
1520 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1527 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1528 dbgs().
indent(4) <<
"X: " << *NodeA->Vals[0].Real <<
"\n";
1529 dbgs().
indent(4) <<
"Y: " << *NodeA->Vals[0].Imag <<
"\n";
1530 dbgs().
indent(4) <<
"U: " << *NodeB->Vals[0].Real <<
"\n";
1531 dbgs().
indent(4) <<
"V: " << *NodeB->Vals[0].Imag <<
"\n";
1532 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1535 CompositeNode *NodeMul = prepareCompositeNode(
1536 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1537 NodeMul->Rotation = Rotation;
1538 NodeMul->addOperand(NodeA);
1539 NodeMul->addOperand(NodeB);
1541 NodeMul->addOperand(Result);
1542 submitCompositeNode(NodeMul);
1544 ProcessedReal[PMI.RealIdx] =
true;
1545 ProcessedImag[PMI.ImagIdx] =
true;
1549 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1550 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1555 dbgs() <<
"Unprocessed products (Real):\n";
1556 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1557 if (!ProcessedReal[i])
1558 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1559 << *RealMuls[i].Multiplier <<
" multiplied by "
1560 << *RealMuls[i].Multiplicand <<
"\n";
1562 dbgs() <<
"Unprocessed products (Imag):\n";
1563 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1564 if (!ProcessedImag[i])
1565 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1566 << *ImagMuls[i].Multiplier <<
" multiplied by "
1567 << *ImagMuls[i].Multiplicand <<
"\n";
1576ComplexDeinterleavingGraph::CompositeNode *
1577ComplexDeinterleavingGraph::identifyAdditions(
1578 AddendList &RealAddends, AddendList &ImagAddends,
1579 std::optional<FastMathFlags> Flags, CompositeNode *
Accumulator =
nullptr) {
1580 if (RealAddends.size() != ImagAddends.size())
1583 CompositeNode *
Result =
nullptr;
1589 Result = extractPositiveAddend(RealAddends, ImagAddends);
1594 while (!RealAddends.empty()) {
1595 auto ItR = RealAddends.begin();
1596 auto [
R, IsPositiveR] = *ItR;
1598 bool FoundImag =
false;
1599 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1600 auto [
I, IsPositiveI] = *ItI;
1602 if (IsPositiveR && IsPositiveI)
1603 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1604 else if (!IsPositiveR && IsPositiveI)
1605 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1606 else if (!IsPositiveR && !IsPositiveI)
1607 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1609 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1611 CompositeNode *AddNode =
nullptr;
1612 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1613 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1614 AddNode = identifyNode(R,
I);
1616 AddNode = identifyNode(
I, R);
1620 dbgs() <<
"Identified addition:\n";
1623 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1626 CompositeNode *TmpNode =
nullptr;
1628 TmpNode = prepareCompositeNode(
1629 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1631 TmpNode->Opcode = Instruction::FAdd;
1632 TmpNode->Flags = *
Flags;
1634 TmpNode->Opcode = Instruction::Add;
1636 }
else if (Rotation ==
1638 TmpNode = prepareCompositeNode(
1639 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1641 TmpNode->Opcode = Instruction::FSub;
1642 TmpNode->Flags = *
Flags;
1644 TmpNode->Opcode = Instruction::Sub;
1647 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1649 TmpNode->Rotation = Rotation;
1652 TmpNode->addOperand(Result);
1653 TmpNode->addOperand(AddNode);
1654 submitCompositeNode(TmpNode);
1656 RealAddends.erase(ItR);
1657 ImagAddends.erase(ItI);
1668ComplexDeinterleavingGraph::CompositeNode *
1669ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1670 AddendList &ImagAddends) {
1671 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1672 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1673 auto [
R, IsPositiveR] = *ItR;
1674 auto [
I, IsPositiveI] = *ItI;
1675 if (IsPositiveR && IsPositiveI) {
1676 auto Result = identifyNode(R,
I);
1678 RealAddends.erase(ItR);
1679 ImagAddends.erase(ItI);
1688bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1693 auto It = RootToNode.
find(RootI);
1694 if (It != RootToNode.
end()) {
1695 auto RootNode = It->second;
1696 assert(RootNode->Operation ==
1697 ComplexDeinterleavingOperation::ReductionOperation ||
1698 RootNode->Operation ==
1699 ComplexDeinterleavingOperation::ReductionSingle);
1700 assert(RootNode->Vals.size() == 1 &&
1701 "Cannot handle reductions involving multiple complex values");
1710 ReplacementAnchor =
R->comesBefore(
I) ?
I :
R;
1712 ReplacementAnchor =
R;
1714 if (ReplacementAnchor != RootI)
1720 auto RootNode = identifyRoot(RootI);
1727 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1728 <<
"::" <<
B->getName() <<
".\n";
1732 RootToNode[RootI] = RootNode;
1737bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *
B) {
1738 bool FoundPotentialReduction =
false;
1743 if (!Br || Br->getNumSuccessors() != 2)
1747 if (Br->getSuccessor(0) !=
B && Br->getSuccessor(1) !=
B)
1750 for (
auto &
PHI :
B->phis()) {
1751 if (
PHI.getNumIncomingValues() != 2)
1754 if (!
PHI.getType()->isVectorTy())
1764 for (
auto *U : ReductionOp->users()) {
1771 if (NumUsers != 2 || !FinalReduction || FinalReduction->
getParent() ==
B ||
1775 ReductionInfo[ReductionOp] = {&
PHI, FinalReduction};
1777 auto BackEdgeIdx =
PHI.getBasicBlockIndex(
B);
1778 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1779 Incoming =
PHI.getIncomingBlock(IncomingIdx);
1780 FoundPotentialReduction =
true;
1786 FinalInstructions.
insert(InitPHI);
1788 return FoundPotentialReduction;
1791void ComplexDeinterleavingGraph::identifyReductionNodes() {
1792 assert(Factor == 2 &&
"Cannot handle multiple complex values");
1794 SmallVector<bool> Processed(ReductionInfo.
size(),
false);
1796 for (
auto &
P : ReductionInfo)
1801 for (
size_t i = 0; i < OperationInstruction.
size(); ++i) {
1804 for (
size_t j = i + 1;
j < OperationInstruction.
size(); ++
j) {
1807 auto *Real = OperationInstruction[i];
1808 auto *Imag = OperationInstruction[
j];
1809 if (Real->getType() != Imag->
getType())
1812 RealPHI = ReductionInfo[Real].first;
1813 ImagPHI = ReductionInfo[Imag].first;
1815 auto Node = identifyNode(Real, Imag);
1819 Node = identifyNode(Real, Imag);
1825 if (Node && PHIsFound) {
1826 LLVM_DEBUG(
dbgs() <<
"Identified reduction starting from instructions: "
1827 << *Real <<
" / " << *Imag <<
"\n");
1828 Processed[i] =
true;
1829 Processed[
j] =
true;
1830 auto RootNode = prepareCompositeNode(
1831 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1832 RootNode->addOperand(Node);
1833 RootToNode[Real] = RootNode;
1834 RootToNode[Imag] = RootNode;
1835 submitCompositeNode(RootNode);
1840 auto *Real = OperationInstruction[i];
1843 if (Processed[i] || Real->getNumOperands() < 2)
1847 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1850 RealPHI = ReductionInfo[Real].first;
1853 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1854 if (Node && PHIsFound) {
1856 dbgs() <<
"Identified single reduction starting from instruction: "
1857 << *Real <<
"/" << *ReductionInfo[Real].second <<
"\n");
1866 if (ReductionInfo[Real].second->getType()->isVectorTy())
1869 Processed[i] =
true;
1870 auto RootNode = prepareCompositeNode(
1871 ComplexDeinterleavingOperation::ReductionSingle, Real,
nullptr);
1872 RootNode->addOperand(Node);
1873 RootToNode[Real] = RootNode;
1874 submitCompositeNode(RootNode);
1882bool ComplexDeinterleavingGraph::checkNodes() {
1883 bool FoundDeinterleaveNode =
false;
1884 for (CompositeNode *
N : CompositeNodes) {
1885 if (!
N->areOperandsValid())
1888 if (
N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1889 FoundDeinterleaveNode =
true;
1894 if (!FoundDeinterleaveNode) {
1896 dbgs() <<
"Couldn't find a deinterleave node within the graph, cannot "
1897 "guarantee safety during graph transformation.\n");
1902 SmallPtrSet<Instruction *, 16> AllInstructions;
1903 SmallVector<Instruction *, 8> Worklist;
1904 for (
auto &Pair : RootToNode)
1909 while (!Worklist.
empty()) {
1912 if (!AllInstructions.
insert(
I).second)
1917 if (!FinalInstructions.
count(
I))
1924 for (
auto *
I : AllInstructions) {
1926 if (RootToNode.count(
I))
1929 for (User *U :
I->users()) {
1941 SmallPtrSet<Instruction *, 16> Visited;
1942 while (!Worklist.
empty()) {
1944 if (!Visited.
insert(
I).second)
1949 if (RootToNode.count(
I)) {
1951 <<
" could be deinterleaved but its chain of complex "
1952 "operations have an outside user\n");
1953 RootToNode.erase(
I);
1956 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1959 for (User *U :
I->users())
1967 return !RootToNode.
empty();
1970ComplexDeinterleavingGraph::CompositeNode *
1971ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1978 for (
unsigned I = 0;
I < Factor;
I += 2) {
1986 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
2014 return identifyNode(Real, Imag);
2017ComplexDeinterleavingGraph::CompositeNode *
2018ComplexDeinterleavingGraph::identifyDeinterleave(
ComplexValues &Vals) {
2022 auto CheckExtract = [&](
Value *
V,
unsigned ExpectedIdx,
2023 Instruction *ExpectedInsn) -> ExtractValueInst * {
2025 if (!EVI || EVI->getNumIndices() != 1 ||
2026 EVI->getIndices()[0] != ExpectedIdx ||
2028 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2033 for (
unsigned Idx = 0; Idx < Vals.
size(); Idx++) {
2034 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2,
II);
2035 if (RealEVI && Idx == 0)
2037 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1,
II)) {
2044 if (IntrinsicII->getIntrinsicID() !=
2049 CompositeNode *PlaceholderNode = prepareCompositeNode(
2051 PlaceholderNode->ReplacementNode =
II->getOperand(0);
2052 for (
auto &V : Vals) {
2056 return submitCompositeNode(PlaceholderNode);
2059 if (Vals.size() != 1)
2062 Value *Real = Vals[0].Real;
2063 Value *Imag = Vals[0].Imag;
2066 if (!RealShuffle || !ImagShuffle) {
2067 if (RealShuffle || ImagShuffle)
2068 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
2072 Value *RealOp1 = RealShuffle->getOperand(1);
2077 Value *ImagOp1 = ImagShuffle->getOperand(1);
2083 Value *RealOp0 = RealShuffle->getOperand(0);
2084 Value *ImagOp0 = ImagShuffle->getOperand(0);
2086 if (RealOp0 != ImagOp0) {
2091 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2092 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2098 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2099 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
2105 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2106 Value *
Op = Shuffle->getOperand(0);
2110 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2112 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2118 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) ->
bool {
2122 ArrayRef<int>
Mask = Shuffle->getShuffleMask();
2125 Value *
Op = Shuffle->getOperand(0);
2127 int NumElements = OpTy->getNumElements();
2131 return Last < NumElements;
2134 if (RealShuffle->getType() != ImagShuffle->getType()) {
2138 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2142 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2147 CompositeNode *PlaceholderNode =
2149 RealShuffle, ImagShuffle);
2150 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2151 FinalInstructions.
insert(RealShuffle);
2152 FinalInstructions.
insert(ImagShuffle);
2153 return submitCompositeNode(PlaceholderNode);
2156ComplexDeinterleavingGraph::CompositeNode *
2157ComplexDeinterleavingGraph::identifySplat(
ComplexValues &Vals) {
2158 auto IsSplat = [](
Value *
V) ->
bool {
2171 if (
Const->getOpcode() != Instruction::ShuffleVector)
2176 VTy = Shuf->getType();
2177 Mask = Shuf->getShuffleMask();
2185 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2195 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2196 for (
auto &V : Vals) {
2197 if (!IsSplat(
V.Real) || !IsSplat(
V.Imag))
2202 if (!Real || !Imag || Real->getParent() != FirstBB ||
2203 Imag->getParent() != FirstBB)
2207 for (
auto &V : Vals) {
2214 for (
auto &V : Vals) {
2218 FinalInstructions.
insert(Real);
2219 FinalInstructions.
insert(Imag);
2222 CompositeNode *PlaceholderNode =
2223 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2224 return submitCompositeNode(PlaceholderNode);
2227ComplexDeinterleavingGraph::CompositeNode *
2228ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2229 Instruction *Imag) {
2230 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2234 CompositeNode *PlaceholderNode = prepareCompositeNode(
2235 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2236 return submitCompositeNode(PlaceholderNode);
2239ComplexDeinterleavingGraph::CompositeNode *
2240ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2241 Instruction *Imag) {
2244 if (!SelectReal || !SelectImag)
2261 auto NodeA = identifyNode(AR, AI);
2265 auto NodeB = identifyNode(
RA, BI);
2269 CompositeNode *PlaceholderNode = prepareCompositeNode(
2270 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2271 PlaceholderNode->addOperand(NodeA);
2272 PlaceholderNode->addOperand(NodeB);
2273 FinalInstructions.
insert(MaskA);
2274 FinalInstructions.
insert(MaskB);
2275 return submitCompositeNode(PlaceholderNode);
2279 std::optional<FastMathFlags> Flags,
2283 case Instruction::FNeg:
2284 I =
B.CreateFNeg(InputA);
2286 case Instruction::FAdd:
2287 I =
B.CreateFAdd(InputA, InputB);
2289 case Instruction::Add:
2290 I =
B.CreateAdd(InputA, InputB);
2292 case Instruction::FSub:
2293 I =
B.CreateFSub(InputA, InputB);
2295 case Instruction::Sub:
2296 I =
B.CreateSub(InputA, InputB);
2298 case Instruction::FMul:
2299 I =
B.CreateFMul(InputA, InputB);
2301 case Instruction::Mul:
2302 I =
B.CreateMul(InputA, InputB);
2312Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2313 CompositeNode *Node) {
2314 if (
Node->ReplacementNode)
2315 return Node->ReplacementNode;
2317 auto ReplaceOperandIfExist = [&](CompositeNode *
Node,
2318 unsigned Idx) ->
Value * {
2319 return Node->Operands.size() > Idx
2320 ? replaceNode(Builder,
Node->Operands[Idx])
2324 Value *ReplacementNode =
nullptr;
2325 switch (
Node->Operation) {
2326 case ComplexDeinterleavingOperation::CDot: {
2327 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2328 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2331 "Node inputs need to be of the same type"));
2336 case ComplexDeinterleavingOperation::CAdd:
2337 case ComplexDeinterleavingOperation::CMulPartial:
2338 case ComplexDeinterleavingOperation::Symmetric: {
2339 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2340 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2343 "Node inputs need to be of the same type"));
2346 "Accumulator and input need to be of the same type"));
2347 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2352 Builder,
Node->Operation,
Node->Rotation, Input0, Input1,
2356 case ComplexDeinterleavingOperation::Deinterleave:
2359 case ComplexDeinterleavingOperation::Splat: {
2361 for (
auto &V :
Node->Vals) {
2362 Ops.push_back(
V.Real);
2363 Ops.push_back(
V.Imag);
2370 for (
auto V :
Node->Vals) {
2378 ReplacementNode = IRB.CreateVectorInterleave(
Ops);
2384 case ComplexDeinterleavingOperation::ReductionPHI: {
2389 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2391 OldToNewPHI[OldPHI] = NewPHI;
2392 ReplacementNode = NewPHI;
2395 case ComplexDeinterleavingOperation::ReductionSingle:
2396 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2397 processReductionSingle(ReplacementNode, Node);
2399 case ComplexDeinterleavingOperation::ReductionOperation:
2400 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2401 processReductionOperation(ReplacementNode, Node);
2403 case ComplexDeinterleavingOperation::ReductionSelect: {
2406 auto *
A = replaceNode(Builder,
Node->Operands[0]);
2407 auto *
B = replaceNode(Builder,
Node->Operands[1]);
2414 assert(ReplacementNode &&
"Target failed to create Intrinsic call.");
2415 NumComplexTransformations += 1;
2416 Node->ReplacementNode = ReplacementNode;
2417 return ReplacementNode;
2420void ComplexDeinterleavingGraph::processReductionSingle(
2421 Value *OperationReplacement, CompositeNode *Node) {
2423 auto *OldPHI = ReductionInfo[Real].first;
2424 auto *NewPHI = OldToNewPHI[OldPHI];
2426 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2428 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2432 Value *NewInit =
nullptr;
2434 if (
C->isZeroValue())
2442 NewPHI->addIncoming(NewInit, Incoming);
2443 NewPHI->addIncoming(OperationReplacement, BackEdge);
2445 auto *FinalReduction = ReductionInfo[Real].second;
2452void ComplexDeinterleavingGraph::processReductionOperation(
2453 Value *OperationReplacement, CompositeNode *Node) {
2456 auto *OldPHIReal = ReductionInfo[Real].first;
2457 auto *OldPHIImag = ReductionInfo[Imag].first;
2458 auto *NewPHI = OldToNewPHI[OldPHIReal];
2461 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2462 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2467 NewPHI->addIncoming(NewInit, Incoming);
2468 NewPHI->addIncoming(OperationReplacement, BackEdge);
2472 auto *FinalReductionReal = ReductionInfo[Real].second;
2473 auto *FinalReductionImag = ReductionInfo[Imag].second;
2476 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2478 OperationReplacement->
getType(),
2479 OperationReplacement);
2482 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2486 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2489void ComplexDeinterleavingGraph::replaceNodes() {
2490 SmallVector<Instruction *, 16> DeadInstrRoots;
2491 for (
auto *RootInstruction : OrderedRoots) {
2494 if (!RootToNode.count(RootInstruction))
2498 auto RootNode = RootToNode[RootInstruction];
2499 Value *
R = replaceNode(Builder, RootNode);
2501 if (RootNode->Operation ==
2502 ComplexDeinterleavingOperation::ReductionOperation) {
2505 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2506 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2509 }
else if (RootNode->Operation ==
2510 ComplexDeinterleavingOperation::ReductionSingle) {
2512 auto &
Info = ReductionInfo[RootInst];
2513 Info.first->removeIncomingValue(BackEdge);
2516 assert(R &&
"Unable to find replacement for RootInstruction");
2517 DeadInstrRoots.
push_back(RootInstruction);
2518 RootInstruction->replaceAllUsesWith(R);
2522 for (
auto *
I : DeadInstrRoots)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
This file implements a map that provides insertion order iteration.
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
An opaque object representing a hash code.
const ParentTy * getParent() const
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
@ BR
Control flow instructions. These all have token chains.
@ BasicBlock
Various leaf nodes.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock.
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
ComplexDeinterleavingOperation
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
ComplexDeinterleavingRotation
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
ComplexDeinterleavingPass(const TargetMachine &TM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static ComplexValue getEmptyKey()
static unsigned getHashValue(const ComplexValue &Val)
static ComplexValue getTombstoneKey()
An information struct used to provide DenseMap with the various necessary components for a given valu...