77using namespace PatternMatch;
79#define DEBUG_TYPE "complex-deinterleaving"
81STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
84 "enable-complex-deinterleaving",
111template <
typename T,
typename IterT>
112std::optional<T> findCommonBetweenCollections(IterT
A, IterT
B) {
114 if (Common !=
A.end())
115 return std::make_optional(*Common);
119class ComplexDeinterleavingLegacyPass :
public FunctionPass {
123 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
130 return "Complex Deinterleaving Pass";
143class ComplexDeinterleavingGraph;
144struct ComplexDeinterleavingCompositeNode {
151 friend class ComplexDeinterleavingGraph;
152 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
153 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
154 bool OperandsValid =
true;
164 std::optional<FastMathFlags>
Flags;
167 ComplexDeinterleavingRotation::Rotation_0;
169 Value *ReplacementNode =
nullptr;
173 OperandsValid =
false;
179 auto PrintValue = [&](
Value *
V) {
187 auto PrintNodeRef = [&](RawNodePtr
Ptr) {
194 OS <<
"- CompositeNode: " <<
this <<
"\n";
199 OS <<
" ReplacementNode: ";
200 PrintValue(ReplacementNode);
202 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
203 OS <<
" Operands: \n";
210 bool areOperandsValid() {
return OperandsValid; }
213class ComplexDeinterleavingGraph {
221 using Addend = std::pair<Value *, bool>;
222 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
223 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
227 struct PartialMulCandidate {
237 : TL(TL), TLI(TLI) {}
248 std::map<Instruction *, NodePtr> RootToNode;
288 bool PHIsFound =
false;
296 std::map<PHINode *, PHINode *> OldToNewPHI;
301 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
303 "Reduction related nodes must have Real and Imaginary parts");
304 return std::make_shared<ComplexDeinterleavingCompositeNode>(
Operation, R,
308 NodePtr submitCompositeNode(NodePtr
Node) {
333 std::pair<Value *, Value *> &CommonOperandI);
344 NodePtr identifyPartialReduction(
Value *R,
Value *
I);
345 NodePtr identifyDotProduct(
Value *Inst);
353 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
354 std::list<Addend> &ImagAddends,
355 std::optional<FastMathFlags> Flags,
359 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
360 std::list<Addend> &ImagAddends);
365 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
366 std::vector<Product> &ImagMuls,
372 bool collectPartialMuls(
const std::vector<Product> &RealMuls,
373 const std::vector<Product> &ImagMuls,
374 std::vector<PartialMulCandidate> &Candidates);
399 NodePtr identifySplat(
Value *Real,
Value *Imag);
414 void processReductionOperation(
Value *OperationReplacement, RawNodePtr
Node);
415 void processReductionSingle(
Value *OperationReplacement, RawNodePtr
Node);
420 for (
const auto &
Node : CompositeNodes)
433 void identifyReductionNodes();
443class ComplexDeinterleaving {
446 : TL(tl), TLI(tli) {}
458char ComplexDeinterleavingLegacyPass::ID = 0;
461 "Complex Deinterleaving",
false,
false)
467 const TargetLowering *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
469 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(
F))
478 return new ComplexDeinterleavingLegacyPass(TM);
481bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
482 const auto *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
483 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
484 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
487bool ComplexDeinterleaving::runOnFunction(
Function &
F) {
490 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
496 dbgs() <<
"Complex deinterleaving has been disabled, target does "
497 "not support lowering of complex number operations.\n");
501 bool Changed =
false;
503 Changed |= evaluateBasicBlock(&
B);
510 if ((Mask.size() & 1))
513 int HalfNumElements = Mask.size() / 2;
514 for (
int Idx = 0;
Idx < HalfNumElements; ++
Idx) {
515 int MaskIdx =
Idx * 2;
516 if (Mask[MaskIdx] !=
Idx || Mask[MaskIdx + 1] != (
Idx + HalfNumElements))
525 int HalfNumElements = Mask.size() / 2;
527 for (
int Idx = 1;
Idx < HalfNumElements; ++
Idx) {
541 auto *
I = cast<Instruction>(V);
542 if (
I->getOpcode() == Instruction::FNeg)
543 return I->getOperand(0);
545 return I->getOperand(1);
548bool ComplexDeinterleaving::evaluateBasicBlock(
BasicBlock *
B) {
549 ComplexDeinterleavingGraph Graph(TL, TLI);
550 if (Graph.collectPotentialReductions(
B))
551 Graph.identifyReductionNodes();
554 Graph.identifyNodes(&
I);
556 if (Graph.checkNodes()) {
557 Graph.replaceNodes();
564ComplexDeinterleavingGraph::NodePtr
565ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
567 std::pair<Value *, Value *> &PartialMatch) {
568 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
576 if ((Real->
getOpcode() != Instruction::FMul &&
577 Real->
getOpcode() != Instruction::Mul) ||
578 (Imag->
getOpcode() != Instruction::FMul &&
579 Imag->
getOpcode() != Instruction::Mul)) {
581 dbgs() <<
" - Real or imaginary instruction is not fmul or mul\n");
614 Value *CommonOperand;
615 Value *UncommonRealOp;
616 Value *UncommonImagOp;
618 if (R0 == I0 || R0 == I1) {
621 }
else if (R1 == I0 || R1 == I1) {
629 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
630 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
631 Rotation == ComplexDeinterleavingRotation::Rotation_270)
632 std::swap(UncommonRealOp, UncommonImagOp);
636 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
637 Rotation == ComplexDeinterleavingRotation::Rotation_180)
638 PartialMatch.first = CommonOperand;
640 PartialMatch.second = CommonOperand;
642 if (!PartialMatch.first || !PartialMatch.second) {
647 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
653 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
659 NodePtr
Node = prepareCompositeNode(
660 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
661 Node->Rotation = Rotation;
662 Node->addOperand(CommonNode);
663 Node->addOperand(UncommonNode);
664 return submitCompositeNode(
Node);
667ComplexDeinterleavingGraph::NodePtr
668ComplexDeinterleavingGraph::identifyPartialMul(
Instruction *Real,
670 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
673 auto IsAdd = [](
unsigned Op) {
674 return Op == Instruction::FAdd ||
Op == Instruction::Add;
676 auto IsSub = [](
unsigned Op) {
677 return Op == Instruction::FSub ||
Op == Instruction::Sub;
681 Rotation = ComplexDeinterleavingRotation::Rotation_0;
683 Rotation = ComplexDeinterleavingRotation::Rotation_90;
685 Rotation = ComplexDeinterleavingRotation::Rotation_180;
687 Rotation = ComplexDeinterleavingRotation::Rotation_270;
693 if (isa<FPMathOperator>(Real) &&
696 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
719 Value *CommonOperand;
720 Value *UncommonRealOp;
721 Value *UncommonImagOp;
723 if (R0 == I0 || R0 == I1) {
726 }
else if (R1 == I0 || R1 == I1) {
734 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
735 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
736 Rotation == ComplexDeinterleavingRotation::Rotation_270)
737 std::swap(UncommonRealOp, UncommonImagOp);
739 std::pair<Value *, Value *> PartialMatch(
740 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
741 Rotation == ComplexDeinterleavingRotation::Rotation_180)
744 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
745 Rotation == ComplexDeinterleavingRotation::Rotation_270)
749 auto *CRInst = dyn_cast<Instruction>(CR);
750 auto *CIInst = dyn_cast<Instruction>(CI);
752 if (!CRInst || !CIInst) {
753 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
757 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
763 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
769 assert(PartialMatch.first && PartialMatch.second);
770 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
776 NodePtr
Node = prepareCompositeNode(
777 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
778 Node->Rotation = Rotation;
779 Node->addOperand(CommonRes);
780 Node->addOperand(UncommonRes);
781 Node->addOperand(CNode);
782 return submitCompositeNode(
Node);
785ComplexDeinterleavingGraph::NodePtr
787 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
791 if ((Real->
getOpcode() == Instruction::FSub &&
792 Imag->
getOpcode() == Instruction::FAdd) ||
793 (Real->
getOpcode() == Instruction::Sub &&
795 Rotation = ComplexDeinterleavingRotation::Rotation_90;
796 else if ((Real->
getOpcode() == Instruction::FAdd &&
797 Imag->
getOpcode() == Instruction::FSub) ||
798 (Real->
getOpcode() == Instruction::Add &&
800 Rotation = ComplexDeinterleavingRotation::Rotation_270;
802 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
806 auto *AR = dyn_cast<Instruction>(Real->
getOperand(0));
807 auto *BI = dyn_cast<Instruction>(Real->
getOperand(1));
808 auto *AI = dyn_cast<Instruction>(Imag->
getOperand(0));
811 if (!AR || !AI || !BR || !BI) {
816 NodePtr ResA = identifyNode(AR, AI);
818 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
821 NodePtr ResB = identifyNode(BR, BI);
823 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
828 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
829 Node->Rotation = Rotation;
830 Node->addOperand(ResA);
831 Node->addOperand(ResB);
832 return submitCompositeNode(
Node);
836 unsigned OpcA =
A->getOpcode();
837 unsigned OpcB =
B->getOpcode();
839 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
840 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
841 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
842 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
853 switch (
I->getOpcode()) {
854 case Instruction::FAdd:
855 case Instruction::FSub:
856 case Instruction::FMul:
857 case Instruction::FNeg:
858 case Instruction::Add:
859 case Instruction::Sub:
860 case Instruction::Mul:
867ComplexDeinterleavingGraph::NodePtr
868ComplexDeinterleavingGraph::identifySymmetricOperation(
Instruction *Real,
880 NodePtr Op0 = identifyNode(R0, I0);
881 NodePtr Op1 =
nullptr;
888 Op1 = identifyNode(R1, I1);
893 if (isa<FPMathOperator>(Real) &&
897 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
900 if (isa<FPMathOperator>(Real))
903 Node->addOperand(Op0);
905 Node->addOperand(Op1);
907 return submitCompositeNode(
Node);
910ComplexDeinterleavingGraph::NodePtr
911ComplexDeinterleavingGraph::identifyDotProduct(
Value *V) {
914 ComplexDeinterleavingOperation::CDot,
V->getType())) {
915 LLVM_DEBUG(
dbgs() <<
"Target doesn't support complex deinterleaving "
916 "operation CDot with the type "
917 << *
V->getType() <<
"\n");
921 auto *Inst = cast<Instruction>(V);
922 auto *RealUser = cast<Instruction>(*Inst->user_begin());
925 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst,
nullptr);
930 Intrinsic::experimental_vector_partial_reduce_add;
932 Value *AReal =
nullptr;
933 Value *AImag =
nullptr;
934 Value *BReal =
nullptr;
935 Value *BImag =
nullptr;
939 if (
auto *CI = dyn_cast<CastInst>(V))
940 return CI->getOperand(0);
944 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
945 m_Intrinsic<PartialReduceInt>(
m_Value(Phi),
949 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
950 m_Intrinsic<PartialReduceInt>(
954 if (
match(Inst, PatternRot0)) {
955 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
956 }
else if (
match(Inst, PatternRot270)) {
957 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
963 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
964 m_Intrinsic<PartialReduceInt>(
m_Value(Phi),
968 if (!
match(Inst, PatternRot90Rot180))
975 ANode = identifyNode(A0, A1);
978 ANode = identifyNode(A1, A0);
982 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
988 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
992 AReal = UnwrapCast(AReal);
993 AImag = UnwrapCast(AImag);
994 BReal = UnwrapCast(BReal);
995 BImag = UnwrapCast(BImag);
998 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
999 if (AReal->
getType() != ExpectedOperandTy)
1001 if (AImag->
getType() != ExpectedOperandTy)
1003 if (BReal->
getType() != ExpectedOperandTy)
1005 if (BImag->
getType() != ExpectedOperandTy)
1008 if (
Phi->getType() != VTy && RealUser->getType() != VTy)
1011 NodePtr
Node = identifyNode(AReal, AImag);
1016 if (ANode &&
Node != ANode) {
1019 <<
"Identified node is different from previously identified node. "
1020 "Unable to confidently generate a complex operation node\n");
1024 CN->addOperand(
Node);
1025 CN->addOperand(identifyNode(BReal, BImag));
1026 CN->addOperand(identifyNode(Phi, RealUser));
1028 return submitCompositeNode(CN);
1031ComplexDeinterleavingGraph::NodePtr
1032ComplexDeinterleavingGraph::identifyPartialReduction(
Value *R,
Value *
I) {
1034 if (!isa<VectorType>(
R->getType()) || !isa<VectorType>(
I->getType()))
1038 findCommonBetweenCollections<Value *>(
R->users(),
I->users());
1042 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1043 if (!IInst || IInst->getIntrinsicID() !=
1044 Intrinsic::experimental_vector_partial_reduce_add)
1047 if (NodePtr CN = identifyDotProduct(IInst))
1053ComplexDeinterleavingGraph::NodePtr
1054ComplexDeinterleavingGraph::identifyNode(
Value *R,
Value *
I) {
1055 auto It = CachedResult.
find({
R,
I});
1056 if (It != CachedResult.
end()) {
1061 if (NodePtr CN = identifyPartialReduction(R,
I))
1064 bool IsReduction = RealPHI ==
R && (!ImagPHI || ImagPHI ==
I);
1065 if (!IsReduction &&
R->getType() !=
I->getType())
1068 if (NodePtr CN = identifySplat(R,
I))
1071 auto *Real = dyn_cast<Instruction>(R);
1072 auto *Imag = dyn_cast<Instruction>(
I);
1076 if (NodePtr CN = identifyDeinterleave(Real, Imag))
1079 if (NodePtr CN = identifyPHINode(Real, Imag))
1082 if (NodePtr CN = identifySelectNode(Real, Imag))
1085 auto *VTy = cast<VectorType>(Real->
getType());
1086 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1089 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1091 ComplexDeinterleavingOperation::CAdd, NewVTy);
1094 if (NodePtr CN = identifyPartialMul(Real, Imag))
1099 if (NodePtr CN = identifyAdd(Real, Imag))
1103 if (HasCMulSupport && HasCAddSupport) {
1104 if (NodePtr CN = identifyReassocNodes(Real, Imag))
1108 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
1112 CachedResult[{
R,
I}] =
nullptr;
1116ComplexDeinterleavingGraph::NodePtr
1117ComplexDeinterleavingGraph::identifyReassocNodes(
Instruction *Real,
1119 auto IsOperationSupported = [](
unsigned Opcode) ->
bool {
1120 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1121 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1122 Opcode == Instruction::Sub;
1125 if (!IsOperationSupported(Real->
getOpcode()) ||
1126 !IsOperationSupported(Imag->
getOpcode()))
1129 std::optional<FastMathFlags>
Flags;
1130 if (isa<FPMathOperator>(Real)) {
1132 LLVM_DEBUG(
dbgs() <<
"The flags in Real and Imaginary instructions are "
1138 if (!
Flags->allowReassoc()) {
1141 <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
1150 std::list<Addend> &Addends) ->
bool {
1153 while (!Worklist.
empty()) {
1154 auto [
V, IsPositive] = Worklist.
back();
1156 if (!Visited.
insert(V).second)
1161 Addends.emplace_back(V, IsPositive);
1171 if (
I !=
Insn &&
I->getNumUses() > 1) {
1172 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
1173 Addends.emplace_back(
I, IsPositive);
1176 switch (
I->getOpcode()) {
1177 case Instruction::FAdd:
1178 case Instruction::Add:
1182 case Instruction::FSub:
1186 case Instruction::Sub:
1194 case Instruction::FMul:
1195 case Instruction::Mul: {
1197 if (
isNeg(
I->getOperand(0))) {
1199 IsPositive = !IsPositive;
1201 A =
I->getOperand(0);
1204 if (
isNeg(
I->getOperand(1))) {
1206 IsPositive = !IsPositive;
1208 B =
I->getOperand(1);
1210 Muls.push_back(Product{
A,
B, IsPositive});
1213 case Instruction::FNeg:
1217 Addends.emplace_back(
I, IsPositive);
1221 if (Flags &&
I->getFastMathFlags() != *Flags) {
1223 "inconsistent with the root instructions' flags: "
1231 std::vector<Product> RealMuls, ImagMuls;
1232 std::list<Addend> RealAddends, ImagAddends;
1233 if (!Collect(Real, RealMuls, RealAddends) ||
1234 !Collect(Imag, ImagMuls, ImagAddends))
1237 if (RealAddends.size() != ImagAddends.size())
1241 if (!RealMuls.empty() || !ImagMuls.empty()) {
1244 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1245 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1251 if (!RealAddends.empty() || !ImagAddends.empty()) {
1252 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1256 assert(FinalNode &&
"FinalNode can not be nullptr here");
1258 FinalNode->Real = Real;
1259 FinalNode->Imag = Imag;
1260 submitCompositeNode(FinalNode);
1264bool ComplexDeinterleavingGraph::collectPartialMuls(
1265 const std::vector<Product> &RealMuls,
const std::vector<Product> &ImagMuls,
1266 std::vector<PartialMulCandidate> &PartialMulCandidates) {
1268 auto FindCommonInstruction = [](
const Product &Real,
1269 const Product &Imag) ->
Value * {
1270 if (Real.Multiplicand == Imag.Multiplicand ||
1271 Real.Multiplicand == Imag.Multiplier)
1272 return Real.Multiplicand;
1274 if (Real.Multiplier == Imag.Multiplicand ||
1275 Real.Multiplier == Imag.Multiplier)
1276 return Real.Multiplier;
1285 for (
unsigned i = 0; i < RealMuls.size(); ++i) {
1286 bool FoundCommon =
false;
1287 for (
unsigned j = 0;
j < ImagMuls.size(); ++
j) {
1288 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1292 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1293 : RealMuls[i].Multiplicand;
1294 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1295 : ImagMuls[
j].Multiplicand;
1297 auto Node = identifyNode(
A,
B);
1300 PartialMulCandidates.push_back({Common,
Node, i,
j,
false});
1303 Node = identifyNode(
B,
A);
1306 PartialMulCandidates.push_back({Common,
Node, i,
j,
true});
1315ComplexDeinterleavingGraph::NodePtr
1316ComplexDeinterleavingGraph::identifyMultiplications(
1317 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1319 if (RealMuls.size() != ImagMuls.size())
1322 std::vector<PartialMulCandidate>
Info;
1323 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1327 std::map<Value *, NodePtr> CommonToNode;
1328 std::vector<bool> Processed(
Info.size(),
false);
1329 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1333 PartialMulCandidate &InfoA =
Info[
I];
1334 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1338 PartialMulCandidate &InfoB =
Info[J];
1339 auto *InfoReal = &InfoA;
1340 auto *InfoImag = &InfoB;
1342 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1343 if (!NodeFromCommon) {
1345 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1347 if (!NodeFromCommon)
1350 CommonToNode[InfoReal->Common] = NodeFromCommon;
1351 CommonToNode[InfoImag->Common] = NodeFromCommon;
1352 Processed[
I] =
true;
1353 Processed[J] =
true;
1357 std::vector<bool> ProcessedReal(RealMuls.size(),
false);
1358 std::vector<bool> ProcessedImag(ImagMuls.size(),
false);
1360 for (
auto &PMI : Info) {
1361 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1364 auto It = CommonToNode.find(PMI.Common);
1367 if (It == CommonToNode.end()) {
1369 dbgs() <<
"Unprocessed independent partial multiplication:\n";
1370 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1372 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1377 auto &RealMul = RealMuls[PMI.RealIdx];
1378 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1380 auto NodeA = It->second;
1381 auto NodeB = PMI.Node;
1382 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1397 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1398 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1403 if (IsMultiplicandReal) {
1405 if (RealMul.IsPositive && ImagMul.IsPositive)
1407 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1414 if (!RealMul.IsPositive && ImagMul.IsPositive)
1416 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1423 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1424 dbgs().
indent(4) <<
"X: " << *NodeA->Real <<
"\n";
1425 dbgs().
indent(4) <<
"Y: " << *NodeA->Imag <<
"\n";
1426 dbgs().
indent(4) <<
"U: " << *NodeB->Real <<
"\n";
1427 dbgs().
indent(4) <<
"V: " << *NodeB->Imag <<
"\n";
1428 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1431 NodePtr NodeMul = prepareCompositeNode(
1432 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1433 NodeMul->Rotation = Rotation;
1434 NodeMul->addOperand(NodeA);
1435 NodeMul->addOperand(NodeB);
1437 NodeMul->addOperand(Result);
1438 submitCompositeNode(NodeMul);
1440 ProcessedReal[PMI.RealIdx] =
true;
1441 ProcessedImag[PMI.ImagIdx] =
true;
1445 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1446 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1451 dbgs() <<
"Unprocessed products (Real):\n";
1452 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1453 if (!ProcessedReal[i])
1454 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1455 << *RealMuls[i].Multiplier <<
" multiplied by "
1456 << *RealMuls[i].Multiplicand <<
"\n";
1458 dbgs() <<
"Unprocessed products (Imag):\n";
1459 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1460 if (!ProcessedImag[i])
1461 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1462 << *ImagMuls[i].Multiplier <<
" multiplied by "
1463 << *ImagMuls[i].Multiplicand <<
"\n";
1472ComplexDeinterleavingGraph::NodePtr
1473ComplexDeinterleavingGraph::identifyAdditions(
1474 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1475 std::optional<FastMathFlags> Flags, NodePtr
Accumulator =
nullptr) {
1476 if (RealAddends.size() != ImagAddends.size())
1485 Result = extractPositiveAddend(RealAddends, ImagAddends);
1490 while (!RealAddends.empty()) {
1491 auto ItR = RealAddends.begin();
1492 auto [
R, IsPositiveR] = *ItR;
1494 bool FoundImag =
false;
1495 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1496 auto [
I, IsPositiveI] = *ItI;
1498 if (IsPositiveR && IsPositiveI)
1499 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1500 else if (!IsPositiveR && IsPositiveI)
1501 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1502 else if (!IsPositiveR && !IsPositiveI)
1503 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1505 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1508 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1509 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1510 AddNode = identifyNode(R,
I);
1512 AddNode = identifyNode(
I, R);
1516 dbgs() <<
"Identified addition:\n";
1519 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1524 TmpNode = prepareCompositeNode(
1525 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1527 TmpNode->Opcode = Instruction::FAdd;
1528 TmpNode->Flags = *
Flags;
1530 TmpNode->Opcode = Instruction::Add;
1532 }
else if (Rotation ==
1534 TmpNode = prepareCompositeNode(
1535 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1537 TmpNode->Opcode = Instruction::FSub;
1538 TmpNode->Flags = *
Flags;
1540 TmpNode->Opcode = Instruction::Sub;
1543 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1545 TmpNode->Rotation = Rotation;
1548 TmpNode->addOperand(Result);
1549 TmpNode->addOperand(AddNode);
1550 submitCompositeNode(TmpNode);
1552 RealAddends.erase(ItR);
1553 ImagAddends.erase(ItI);
1564ComplexDeinterleavingGraph::NodePtr
1565ComplexDeinterleavingGraph::extractPositiveAddend(
1566 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1567 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1568 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1569 auto [
R, IsPositiveR] = *ItR;
1570 auto [
I, IsPositiveI] = *ItI;
1571 if (IsPositiveR && IsPositiveI) {
1572 auto Result = identifyNode(R,
I);
1574 RealAddends.erase(ItR);
1575 ImagAddends.erase(ItI);
1584bool ComplexDeinterleavingGraph::identifyNodes(
Instruction *RootI) {
1589 auto It = RootToNode.find(RootI);
1590 if (It != RootToNode.end()) {
1591 auto RootNode = It->second;
1592 assert(RootNode->Operation ==
1593 ComplexDeinterleavingOperation::ReductionOperation ||
1594 RootNode->Operation ==
1595 ComplexDeinterleavingOperation::ReductionSingle);
1598 auto *
R = cast<Instruction>(RootNode->Real);
1599 auto *
I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
1603 ReplacementAnchor =
R->comesBefore(
I) ?
I :
R;
1605 ReplacementAnchor =
R;
1607 if (ReplacementAnchor != RootI)
1613 auto RootNode = identifyRoot(RootI);
1620 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1621 <<
"::" <<
B->getName() <<
".\n";
1625 RootToNode[RootI] = RootNode;
1630bool ComplexDeinterleavingGraph::collectPotentialReductions(
BasicBlock *
B) {
1631 bool FoundPotentialReduction =
false;
1633 auto *Br = dyn_cast<BranchInst>(
B->getTerminator());
1634 if (!Br || Br->getNumSuccessors() != 2)
1638 if (Br->getSuccessor(0) !=
B && Br->getSuccessor(1) !=
B)
1642 for (
auto &
PHI :
B->phis()) {
1643 if (
PHI.getNumIncomingValues() != 2)
1646 if (!
PHI.getType()->isVectorTy())
1649 auto *ReductionOp = dyn_cast<Instruction>(
PHI.getIncomingValueForBlock(
B));
1656 for (
auto *U : ReductionOp->users()) {
1660 FinalReduction = dyn_cast<Instruction>(U);
1663 if (NumUsers != 2 || !FinalReduction || FinalReduction->
getParent() ==
B ||
1664 isa<PHINode>(FinalReduction))
1667 ReductionInfo[ReductionOp] = {&
PHI, FinalReduction};
1669 auto BackEdgeIdx =
PHI.getBasicBlockIndex(
B);
1670 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1672 FoundPotentialReduction =
true;
1677 dyn_cast<Instruction>(
PHI.getIncomingValueForBlock(
Incoming)))
1678 FinalInstructions.
insert(InitPHI);
1680 return FoundPotentialReduction;
1683void ComplexDeinterleavingGraph::identifyReductionNodes() {
1686 for (
auto &
P : ReductionInfo)
1691 for (
size_t i = 0; i < OperationInstruction.
size(); ++i) {
1694 for (
size_t j = i + 1;
j < OperationInstruction.
size(); ++
j) {
1697 auto *Real = OperationInstruction[i];
1698 auto *Imag = OperationInstruction[
j];
1699 if (Real->getType() != Imag->
getType())
1702 RealPHI = ReductionInfo[Real].first;
1703 ImagPHI = ReductionInfo[Imag].first;
1705 auto Node = identifyNode(Real, Imag);
1709 Node = identifyNode(Real, Imag);
1715 if (
Node && PHIsFound) {
1716 LLVM_DEBUG(
dbgs() <<
"Identified reduction starting from instructions: "
1717 << *Real <<
" / " << *Imag <<
"\n");
1718 Processed[i] =
true;
1719 Processed[
j] =
true;
1720 auto RootNode = prepareCompositeNode(
1721 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1722 RootNode->addOperand(
Node);
1723 RootToNode[Real] = RootNode;
1724 RootToNode[Imag] = RootNode;
1725 submitCompositeNode(RootNode);
1730 auto *Real = OperationInstruction[i];
1733 if (Real->getNumOperands() < 2)
1736 RealPHI = ReductionInfo[Real].first;
1739 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1740 if (
Node && PHIsFound) {
1742 dbgs() <<
"Identified single reduction starting from instruction: "
1743 << *Real <<
"/" << *ReductionInfo[Real].second <<
"\n");
1744 Processed[i] =
true;
1745 auto RootNode = prepareCompositeNode(
1746 ComplexDeinterleavingOperation::ReductionSingle, Real,
nullptr);
1747 RootNode->addOperand(
Node);
1748 RootToNode[Real] = RootNode;
1749 submitCompositeNode(RootNode);
1757bool ComplexDeinterleavingGraph::checkNodes() {
1759 bool FoundDeinterleaveNode =
false;
1760 for (NodePtr
N : CompositeNodes) {
1761 if (!
N->areOperandsValid())
1763 if (
N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1764 FoundDeinterleaveNode =
true;
1769 if (!FoundDeinterleaveNode) {
1771 dbgs() <<
"Couldn't find a deinterleave node within the graph, cannot "
1772 "guarantee safety during graph transformation.\n");
1779 for (
auto &Pair : RootToNode)
1784 while (!Worklist.
empty()) {
1785 auto *
I = Worklist.
back();
1788 if (!AllInstructions.
insert(
I).second)
1792 if (
auto *OpI = dyn_cast<Instruction>(
Op)) {
1793 if (!FinalInstructions.
count(
I))
1801 for (
auto *
I : AllInstructions) {
1803 if (RootToNode.count(
I))
1806 for (
User *U :
I->users()) {
1807 if (AllInstructions.count(cast<Instruction>(U)))
1819 while (!Worklist.
empty()) {
1820 auto *
I = Worklist.
back();
1822 if (!Visited.
insert(
I).second)
1827 if (RootToNode.count(
I)) {
1829 <<
" could be deinterleaved but its chain of complex "
1830 "operations have an outside user\n");
1831 RootToNode.erase(
I);
1834 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1837 for (
User *U :
I->users())
1841 if (
auto *OpI = dyn_cast<Instruction>(
Op))
1845 return !RootToNode.empty();
1848ComplexDeinterleavingGraph::NodePtr
1849ComplexDeinterleavingGraph::identifyRoot(
Instruction *RootI) {
1850 if (
auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1851 if (
Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
1854 auto *Real = dyn_cast<Instruction>(
Intrinsic->getOperand(0));
1855 auto *Imag = dyn_cast<Instruction>(
Intrinsic->getOperand(1));
1859 return identifyNode(Real, Imag);
1862 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1876 return identifyNode(Real, Imag);
1879ComplexDeinterleavingGraph::NodePtr
1880ComplexDeinterleavingGraph::identifyDeinterleave(
Instruction *Real,
1883 Value *FinalValue =
nullptr;
1886 match(
I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
1888 NodePtr PlaceholderNode = prepareCompositeNode(
1890 PlaceholderNode->ReplacementNode = FinalValue;
1891 FinalInstructions.
insert(Real);
1892 FinalInstructions.
insert(Imag);
1893 return submitCompositeNode(PlaceholderNode);
1896 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1897 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1898 if (!RealShuffle || !ImagShuffle) {
1899 if (RealShuffle || ImagShuffle)
1900 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
1904 Value *RealOp1 = RealShuffle->getOperand(1);
1905 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1909 Value *ImagOp1 = ImagShuffle->getOperand(1);
1910 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1915 Value *RealOp0 = RealShuffle->getOperand(0);
1916 Value *ImagOp0 = ImagShuffle->getOperand(0);
1918 if (RealOp0 != ImagOp0) {
1930 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1931 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
1938 Value *
Op = Shuffle->getOperand(0);
1939 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1940 auto *OpTy = cast<FixedVectorType>(
Op->getType());
1942 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1944 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1957 Value *
Op = Shuffle->getOperand(0);
1958 auto *OpTy = cast<FixedVectorType>(
Op->getType());
1959 int NumElements = OpTy->getNumElements();
1963 return Last < NumElements;
1966 if (RealShuffle->getType() != ImagShuffle->getType()) {
1970 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1974 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1979 NodePtr PlaceholderNode =
1981 RealShuffle, ImagShuffle);
1982 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1983 FinalInstructions.
insert(RealShuffle);
1984 FinalInstructions.
insert(ImagShuffle);
1985 return submitCompositeNode(PlaceholderNode);
1988ComplexDeinterleavingGraph::NodePtr
1989ComplexDeinterleavingGraph::identifySplat(
Value *R,
Value *
I) {
1990 auto IsSplat = [](
Value *
V) ->
bool {
1992 if (isa<ConstantDataVector>(V))
1999 if (
auto *Const = dyn_cast<ConstantExpr>(V)) {
2000 if (
Const->getOpcode() != Instruction::ShuffleVector)
2002 VTy = cast<VectorType>(
Const->getType());
2004 }
else if (
auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2005 VTy = Shuf->getType();
2006 Mask = Shuf->getShuffleMask();
2014 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2020 if (!IsSplat(R) || !IsSplat(
I))
2023 auto *Real = dyn_cast<Instruction>(R);
2024 auto *Imag = dyn_cast<Instruction>(
I);
2025 if ((!Real && Imag) || (Real && !Imag))
2033 FinalInstructions.
insert(Real);
2034 FinalInstructions.
insert(Imag);
2036 NodePtr PlaceholderNode =
2037 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R,
I);
2038 return submitCompositeNode(PlaceholderNode);
2041ComplexDeinterleavingGraph::NodePtr
2042ComplexDeinterleavingGraph::identifyPHINode(
Instruction *Real,
2044 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2048 NodePtr PlaceholderNode = prepareCompositeNode(
2049 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2050 return submitCompositeNode(PlaceholderNode);
2053ComplexDeinterleavingGraph::NodePtr
2054ComplexDeinterleavingGraph::identifySelectNode(
Instruction *Real,
2056 auto *SelectReal = dyn_cast<SelectInst>(Real);
2057 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2058 if (!SelectReal || !SelectImag)
2075 auto NodeA = identifyNode(AR, AI);
2079 auto NodeB = identifyNode(
RA, BI);
2083 NodePtr PlaceholderNode = prepareCompositeNode(
2084 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2085 PlaceholderNode->addOperand(NodeA);
2086 PlaceholderNode->addOperand(NodeB);
2087 FinalInstructions.
insert(MaskA);
2088 FinalInstructions.
insert(MaskB);
2089 return submitCompositeNode(PlaceholderNode);
2093 std::optional<FastMathFlags> Flags,
2097 case Instruction::FNeg:
2098 I =
B.CreateFNeg(InputA);
2100 case Instruction::FAdd:
2101 I =
B.CreateFAdd(InputA, InputB);
2103 case Instruction::Add:
2104 I =
B.CreateAdd(InputA, InputB);
2106 case Instruction::FSub:
2107 I =
B.CreateFSub(InputA, InputB);
2109 case Instruction::Sub:
2110 I =
B.CreateSub(InputA, InputB);
2112 case Instruction::FMul:
2113 I =
B.CreateFMul(InputA, InputB);
2115 case Instruction::Mul:
2116 I =
B.CreateMul(InputA, InputB);
2122 cast<Instruction>(
I)->setFastMathFlags(*Flags);
2128 if (
Node->ReplacementNode)
2129 return Node->ReplacementNode;
2131 auto ReplaceOperandIfExist = [&](RawNodePtr &
Node,
unsigned Idx) ->
Value * {
2132 return Node->Operands.size() >
Idx
2133 ? replaceNode(Builder,
Node->Operands[
Idx])
2137 Value *ReplacementNode;
2138 switch (
Node->Operation) {
2139 case ComplexDeinterleavingOperation::CDot: {
2140 Value *Input0 = ReplaceOperandIfExist(
Node, 0);
2141 Value *Input1 = ReplaceOperandIfExist(
Node, 1);
2144 "Node inputs need to be of the same type"));
2149 case ComplexDeinterleavingOperation::CAdd:
2150 case ComplexDeinterleavingOperation::CMulPartial:
2151 case ComplexDeinterleavingOperation::Symmetric: {
2152 Value *Input0 = ReplaceOperandIfExist(
Node, 0);
2153 Value *Input1 = ReplaceOperandIfExist(
Node, 1);
2156 "Node inputs need to be of the same type"));
2159 "Accumulator and input need to be of the same type"));
2160 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2165 Builder,
Node->Operation,
Node->Rotation, Input0, Input1,
2169 case ComplexDeinterleavingOperation::Deinterleave:
2172 case ComplexDeinterleavingOperation::Splat: {
2173 auto *NewTy = VectorType::getDoubleElementsVectorType(
2174 cast<VectorType>(
Node->Real->getType()));
2175 auto *
R = dyn_cast<Instruction>(
Node->Real);
2176 auto *
I = dyn_cast<Instruction>(
Node->Imag);
2181 ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
2185 Intrinsic::vector_interleave2, NewTy, {
Node->Real,
Node->Imag});
2189 case ComplexDeinterleavingOperation::ReductionPHI: {
2192 auto *OldPHI = cast<PHINode>(
Node->Real);
2193 auto *VTy = cast<VectorType>(
Node->Real->getType());
2194 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2196 OldToNewPHI[OldPHI] = NewPHI;
2197 ReplacementNode = NewPHI;
2200 case ComplexDeinterleavingOperation::ReductionSingle:
2201 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2202 processReductionSingle(ReplacementNode,
Node);
2204 case ComplexDeinterleavingOperation::ReductionOperation:
2205 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2206 processReductionOperation(ReplacementNode,
Node);
2208 case ComplexDeinterleavingOperation::ReductionSelect: {
2209 auto *MaskReal = cast<Instruction>(
Node->Real)->getOperand(0);
2210 auto *MaskImag = cast<Instruction>(
Node->Imag)->getOperand(0);
2211 auto *
A = replaceNode(Builder,
Node->Operands[0]);
2212 auto *
B = replaceNode(Builder,
Node->Operands[1]);
2213 auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
2214 cast<VectorType>(MaskReal->getType()));
2215 auto *NewMask = Builder.
CreateIntrinsic(Intrinsic::vector_interleave2,
2216 NewMaskTy, {MaskReal, MaskImag});
2222 assert(ReplacementNode &&
"Target failed to create Intrinsic call.");
2223 NumComplexTransformations += 1;
2224 Node->ReplacementNode = ReplacementNode;
2225 return ReplacementNode;
2228void ComplexDeinterleavingGraph::processReductionSingle(
2229 Value *OperationReplacement, RawNodePtr
Node) {
2230 auto *Real = cast<Instruction>(
Node->Real);
2231 auto *OldPHI = ReductionInfo[Real].first;
2232 auto *NewPHI = OldToNewPHI[OldPHI];
2233 auto *VTy = cast<VectorType>(Real->
getType());
2234 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2240 Value *NewInit =
nullptr;
2241 if (
auto *
C = dyn_cast<Constant>(
Init)) {
2242 if (
C->isZeroValue())
2247 NewInit = Builder.
CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2250 NewPHI->addIncoming(NewInit,
Incoming);
2251 NewPHI->addIncoming(OperationReplacement, BackEdge);
2253 auto *FinalReduction = ReductionInfo[Real].second;
2260void ComplexDeinterleavingGraph::processReductionOperation(
2261 Value *OperationReplacement, RawNodePtr
Node) {
2262 auto *Real = cast<Instruction>(
Node->Real);
2263 auto *Imag = cast<Instruction>(
Node->Imag);
2264 auto *OldPHIReal = ReductionInfo[Real].first;
2265 auto *OldPHIImag = ReductionInfo[Imag].first;
2266 auto *NewPHI = OldToNewPHI[OldPHIReal];
2268 auto *VTy = cast<VectorType>(Real->
getType());
2269 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2272 Value *InitReal = OldPHIReal->getIncomingValueForBlock(
Incoming);
2273 Value *InitImag = OldPHIImag->getIncomingValueForBlock(
Incoming);
2276 auto *NewInit = Builder.
CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2277 {InitReal, InitImag});
2279 NewPHI->addIncoming(NewInit,
Incoming);
2280 NewPHI->addIncoming(OperationReplacement, BackEdge);
2284 auto *FinalReductionReal = ReductionInfo[Real].second;
2285 auto *FinalReductionImag = ReductionInfo[Imag].second;
2288 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2290 OperationReplacement->
getType(),
2291 OperationReplacement);
2294 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2298 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2301void ComplexDeinterleavingGraph::replaceNodes() {
2303 for (
auto *RootInstruction : OrderedRoots) {
2306 if (!RootToNode.count(RootInstruction))
2310 auto RootNode = RootToNode[RootInstruction];
2311 Value *
R = replaceNode(Builder, RootNode.get());
2313 if (RootNode->Operation ==
2314 ComplexDeinterleavingOperation::ReductionOperation) {
2315 auto *RootReal = cast<Instruction>(RootNode->Real);
2316 auto *RootImag = cast<Instruction>(RootNode->Imag);
2317 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2318 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2321 }
else if (RootNode->Operation ==
2322 ComplexDeinterleavingOperation::ReductionSingle) {
2323 auto *RootInst = cast<Instruction>(RootNode->Real);
2324 ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
2325 DeadInstrRoots.
push_back(ReductionInfo[RootInst].second);
2327 assert(R &&
"Unable to find replacement for RootInstruction");
2328 DeadInstrRoots.
push_back(RootInstruction);
2329 RootInstruction->replaceAllUsesWith(R);
2333 for (
auto *
I : DeadInstrRoots)
SmallVector< AArch64_IMM::ImmInsnModel, 4 > Insn
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
VarLocInsertPt getNextNode(const DbgRecord *DVR)
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
static bool isInstructionPairMul(Instruction *A, Instruction *B)
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
static bool runOnFunction(Function &F, bool PostInlining)
mir Rename Register Operands
This file implements a map that provides insertion order iteration.
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
DEMANGLE_DUMP_METHOD void dump() const
A container for analyses that lazily runs them and caches their results.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
LLVM Basic Block Representation.
InstListType::const_iterator getFirstNonPHIIt() const
Iterator returning form of getFirstNonPHI.
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
This class represents an Operation in the Expression.
iterator find(const_arg_type_t< KeyT > Val)
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
const Function * getFunction() const
Return the function this instruction belongs to.
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
This class implements a map that also provides access to all stored values in a deterministic order.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
This instruction constructs a fixed permutation of two input vectors.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isVectorTy() const
True if this is an instance of VectorType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
const ParentTy * getParent() const
This class implements an extremely fast bulk output stream that can only output to a stream.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ BR
Control flow instructions. These all have token chains.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
ComplexDeinterleavingOperation
FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
ComplexDeinterleavingRotation
DWARFExpression::Operation Op
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Incoming for lane maks phi as machine instruction, incoming register Reg and incoming block Block are...