82#define DEBUG_TYPE "complex-deinterleaving"
84STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
87 "enable-complex-deinterleaving",
115 Value *Real =
nullptr;
116 Value *Imag =
nullptr;
119 return Real ==
Other.Real && Imag ==
Other.Imag;
134 static bool isEqual(
const ComplexValue &LHS,
const ComplexValue &RHS) {
135 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
140template <
typename T,
typename IterT>
141std::optional<T> findCommonBetweenCollections(IterT
A, IterT
B) {
143 if (Common !=
A.end())
144 return std::make_optional(*Common);
148class ComplexDeinterleavingLegacyPass :
public FunctionPass {
152 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
153 : FunctionPass(ID), TM(TM) {}
155 StringRef getPassName()
const override {
156 return "Complex Deinterleaving Pass";
160 void getAnalysisUsage(AnalysisUsage &AU)
const override {
166 const TargetMachine *TM;
169class ComplexDeinterleavingGraph;
170struct ComplexDeinterleavingCompositeNode {
175 Vals.push_back({
R,
I});
180 : Operation(
Op), Vals(
Other) {}
183 friend class ComplexDeinterleavingGraph;
184 using CompositeNode = ComplexDeinterleavingCompositeNode;
185 bool OperandsValid =
true;
194 std::optional<FastMathFlags> Flags;
197 ComplexDeinterleavingRotation::Rotation_0;
199 Value *ReplacementNode =
nullptr;
203 OperandsValid =
false;
204 Operands.push_back(Node);
208 void dump(raw_ostream &OS) {
209 auto PrintValue = [&](
Value *
V) {
217 auto PrintNodeRef = [&](CompositeNode *Ptr) {
224 OS <<
"- CompositeNode: " <<
this <<
"\n";
225 for (
unsigned I = 0;
I < Vals.size();
I++) {
226 OS <<
" Real(" <<
I <<
") : ";
227 PrintValue(Vals[
I].Real);
228 OS <<
" Imag(" <<
I <<
") : ";
229 PrintValue(Vals[
I].Imag);
231 OS <<
" ReplacementNode: ";
232 PrintValue(ReplacementNode);
233 OS <<
" Operation: " << (int)Operation <<
"\n";
234 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
235 OS <<
" Operands: \n";
236 for (
const auto &
Op : Operands) {
// Returns the OperandsValid flag. The flag is initialised to true and is
// cleared elsewhere in this struct before an operand is pushed onto
// Operands; the graph's checkNodes() consults it for every composite node.
242 bool areOperandsValid() {
return OperandsValid; }
245class ComplexDeinterleavingGraph {
253 using Addend = std::pair<Value *, bool>;
255 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
259 struct PartialMulCandidate {
267 explicit ComplexDeinterleavingGraph(
const TargetLowering *TL,
268 const TargetLibraryInfo *TLI,
270 : TL(TL), TLI(TLI), Factor(Factor) {}
273 const TargetLowering *TL =
nullptr;
274 const TargetLibraryInfo *TLI =
nullptr;
277 DenseMap<ComplexValues, CompositeNode *> CachedResult;
278 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
280 SmallPtrSet<Instruction *, 16> FinalInstructions;
283 DenseMap<Instruction *, CompositeNode *> RootToNode;
310 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
318 PHINode *RealPHI =
nullptr;
319 PHINode *ImagPHI =
nullptr;
323 bool PHIsFound =
false;
331 DenseMap<PHINode *, PHINode *> OldToNewPHI;
336 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
338 "Reduction related nodes must have Real and Imaginary parts");
339 return new (Allocator.Allocate())
340 ComplexDeinterleavingCompositeNode(
Operation, R,
I);
346 for (
auto &V : Vals) {
348 ((
Operation != ComplexDeinterleavingOperation::ReductionPHI &&
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
350 (
V.Real &&
V.Imag)) &&
351 "Reduction related nodes must have Real and Imaginary parts");
354 return new (Allocator.Allocate())
355 ComplexDeinterleavingCompositeNode(
Operation, Vals);
358 CompositeNode *submitCompositeNode(CompositeNode *Node) {
359 CompositeNodes.push_back(Node);
360 if (
Node->Vals[0].Real)
376 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
382 identifyNodeWithImplicitAdd(Instruction *
I, Instruction *J,
383 std::pair<Value *, Value *> &CommonOperandI);
392 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
393 CompositeNode *identifySymmetricOperation(
ComplexValues &Vals);
394 CompositeNode *identifyPartialReduction(
Value *R,
Value *
I);
395 CompositeNode *identifyDotProduct(
Value *Inst);
402 return identifyNode(Vals);
409 CompositeNode *identifyAdditions(AddendList &RealAddends,
410 AddendList &ImagAddends,
411 std::optional<FastMathFlags> Flags,
415 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
416 AddendList &ImagAddends);
421 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
422 SmallVectorImpl<Product> &ImagMuls,
430 SmallVectorImpl<PartialMulCandidate> &Candidates);
438 CompositeNode *identifyReassocNodes(Instruction *
I, Instruction *J);
440 CompositeNode *identifyRoot(Instruction *
I);
458 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
462 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
464 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
471 void processReductionOperation(
Value *OperationReplacement,
472 CompositeNode *Node);
473 void processReductionSingle(
Value *OperationReplacement, CompositeNode *Node);
477 void dump(raw_ostream &OS) {
478 for (
const auto &Node : CompositeNodes)
484 bool identifyNodes(Instruction *RootI);
489 bool collectPotentialReductions(BasicBlock *
B);
491 void identifyReductionNodes();
501class ComplexDeinterleaving {
503 ComplexDeinterleaving(
const TargetLowering *tl,
const TargetLibraryInfo *tli)
504 : TL(tl), TLI(tli) {}
508 bool evaluateBasicBlock(BasicBlock *
B,
unsigned Factor);
510 const TargetLowering *TL =
nullptr;
511 const TargetLibraryInfo *TLI =
nullptr;
516char ComplexDeinterleavingLegacyPass::ID = 0;
519 "Complex Deinterleaving",
false,
false)
525 const TargetLowering *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
536 return new ComplexDeinterleavingLegacyPass(TM);
539bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
541 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
542 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
545bool ComplexDeinterleaving::runOnFunction(Function &
F) {
548 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
554 dbgs() <<
"Complex deinterleaving has been disabled, target does "
555 "not support lowering of complex number operations.\n");
561 Changed |= evaluateBasicBlock(&
B, 2);
566 Changed |= evaluateBasicBlock(&
B, 4);
576 if ((Mask.size() & 1))
579 int HalfNumElements = Mask.size() / 2;
580 for (
int Idx = 0; Idx < HalfNumElements; ++Idx) {
581 int MaskIdx = Idx * 2;
582 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
591 int HalfNumElements = Mask.size() / 2;
593 for (
int Idx = 1; Idx < HalfNumElements; ++Idx) {
594 if (Mask[Idx] != (Idx * 2) +
Offset)
608 if (
I->getOpcode() == Instruction::FNeg)
609 return I->getOperand(0);
611 return I->getOperand(1);
614bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *
B,
unsigned Factor) {
615 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
616 if (Graph.collectPotentialReductions(
B))
617 Graph.identifyReductionNodes();
620 Graph.identifyNodes(&
I);
622 if (Graph.checkNodes()) {
623 Graph.replaceNodes();
630ComplexDeinterleavingGraph::CompositeNode *
631ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
632 Instruction *Real, Instruction *Imag,
633 std::pair<Value *, Value *> &PartialMatch) {
634 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
642 if ((Real->
getOpcode() != Instruction::FMul &&
643 Real->
getOpcode() != Instruction::Mul) ||
644 (Imag->
getOpcode() != Instruction::FMul &&
645 Imag->
getOpcode() != Instruction::Mul)) {
647 dbgs() <<
" - Real or imaginary instruction is not fmul or mul\n");
680 Value *CommonOperand;
681 Value *UncommonRealOp;
682 Value *UncommonImagOp;
684 if (R0 == I0 || R0 == I1) {
687 }
else if (R1 == I0 || R1 == I1) {
695 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
696 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
697 Rotation == ComplexDeinterleavingRotation::Rotation_270)
698 std::swap(UncommonRealOp, UncommonImagOp);
702 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
703 Rotation == ComplexDeinterleavingRotation::Rotation_180)
704 PartialMatch.first = CommonOperand;
706 PartialMatch.second = CommonOperand;
708 if (!PartialMatch.first || !PartialMatch.second) {
713 CompositeNode *CommonNode =
714 identifyNode(PartialMatch.first, PartialMatch.second);
720 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
726 CompositeNode *
Node = prepareCompositeNode(
727 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
728 Node->Rotation = Rotation;
729 Node->addOperand(CommonNode);
730 Node->addOperand(UncommonNode);
731 return submitCompositeNode(Node);
734ComplexDeinterleavingGraph::CompositeNode *
735ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
737 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
741 auto IsAdd = [](
unsigned Op) {
742 return Op == Instruction::FAdd ||
Op == Instruction::Add;
744 auto IsSub = [](
unsigned Op) {
745 return Op == Instruction::FSub ||
Op == Instruction::Sub;
749 Rotation = ComplexDeinterleavingRotation::Rotation_0;
751 Rotation = ComplexDeinterleavingRotation::Rotation_90;
753 Rotation = ComplexDeinterleavingRotation::Rotation_180;
755 Rotation = ComplexDeinterleavingRotation::Rotation_270;
764 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
787 Value *CommonOperand;
788 Value *UncommonRealOp;
789 Value *UncommonImagOp;
791 if (R0 == I0 || R0 == I1) {
794 }
else if (R1 == I0 || R1 == I1) {
802 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
803 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
804 Rotation == ComplexDeinterleavingRotation::Rotation_270)
805 std::swap(UncommonRealOp, UncommonImagOp);
807 std::pair<Value *, Value *> PartialMatch(
808 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
809 Rotation == ComplexDeinterleavingRotation::Rotation_180)
812 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
813 Rotation == ComplexDeinterleavingRotation::Rotation_270)
820 if (!CRInst || !CIInst) {
821 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
825 CompositeNode *CNode =
826 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
832 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
838 assert(PartialMatch.first && PartialMatch.second);
839 CompositeNode *CommonRes =
840 identifyNode(PartialMatch.first, PartialMatch.second);
846 CompositeNode *
Node = prepareCompositeNode(
847 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
848 Node->Rotation = Rotation;
849 Node->addOperand(CommonRes);
850 Node->addOperand(UncommonRes);
851 Node->addOperand(CNode);
852 return submitCompositeNode(Node);
855ComplexDeinterleavingGraph::CompositeNode *
856ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
857 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
861 if ((Real->
getOpcode() == Instruction::FSub &&
862 Imag->
getOpcode() == Instruction::FAdd) ||
863 (Real->
getOpcode() == Instruction::Sub &&
865 Rotation = ComplexDeinterleavingRotation::Rotation_90;
866 else if ((Real->
getOpcode() == Instruction::FAdd &&
867 Imag->
getOpcode() == Instruction::FSub) ||
868 (Real->
getOpcode() == Instruction::Add &&
870 Rotation = ComplexDeinterleavingRotation::Rotation_270;
872 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
881 if (!AR || !AI || !BR || !BI) {
886 CompositeNode *ResA = identifyNode(AR, AI);
888 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
891 CompositeNode *ResB = identifyNode(BR, BI);
893 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
897 CompositeNode *
Node =
898 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
899 Node->Rotation = Rotation;
900 Node->addOperand(ResA);
901 Node->addOperand(ResB);
902 return submitCompositeNode(Node);
906 unsigned OpcA =
A->getOpcode();
907 unsigned OpcB =
B->getOpcode();
909 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
910 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
911 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
912 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
923 switch (
I->getOpcode()) {
924 case Instruction::FAdd:
925 case Instruction::FSub:
926 case Instruction::FMul:
927 case Instruction::FNeg:
928 case Instruction::Add:
929 case Instruction::Sub:
930 case Instruction::Mul:
937ComplexDeinterleavingGraph::CompositeNode *
938ComplexDeinterleavingGraph::identifySymmetricOperation(
ComplexValues &Vals) {
940 unsigned FirstOpc = FirstReal->getOpcode();
941 for (
auto &V : Vals) {
958 for (
auto &V : Vals) {
964 CompositeNode *Op0 = identifyNode(OpVals);
965 CompositeNode *Op1 =
nullptr;
969 if (FirstReal->isBinaryOp()) {
971 for (
auto &V : Vals) {
976 Op1 = identifyNode(OpVals);
982 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
983 Node->Opcode = FirstReal->getOpcode();
985 Node->Flags = FirstReal->getFastMathFlags();
987 Node->addOperand(Op0);
988 if (FirstReal->isBinaryOp())
989 Node->addOperand(Op1);
991 return submitCompositeNode(Node);
994ComplexDeinterleavingGraph::CompositeNode *
995ComplexDeinterleavingGraph::identifyDotProduct(
Value *V) {
997 ComplexDeinterleavingOperation::CDot,
V->getType())) {
998 LLVM_DEBUG(
dbgs() <<
"Target doesn't support complex deinterleaving "
999 "operation CDot with the type "
1000 << *
V->getType() <<
"\n");
1008 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst,
nullptr);
1010 CompositeNode *ANode =
nullptr;
1012 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1014 Value *AReal =
nullptr;
1015 Value *AImag =
nullptr;
1016 Value *BReal =
nullptr;
1017 Value *BImag =
nullptr;
1022 return CI->getOperand(0);
1036 if (
match(Inst, PatternRot0)) {
1037 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1038 }
else if (
match(Inst, PatternRot270)) {
1039 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1050 if (!
match(Inst, PatternRot90Rot180))
1053 A0 = UnwrapCast(A0);
1054 A1 = UnwrapCast(A1);
1057 ANode = identifyNode(A0, A1);
1060 ANode = identifyNode(A1, A0);
1064 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1070 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1074 AReal = UnwrapCast(AReal);
1075 AImag = UnwrapCast(AImag);
1076 BReal = UnwrapCast(BReal);
1077 BImag = UnwrapCast(BImag);
1080 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1081 if (AReal->
getType() != ExpectedOperandTy)
1083 if (AImag->
getType() != ExpectedOperandTy)
1085 if (BReal->
getType() != ExpectedOperandTy)
1087 if (BImag->
getType() != ExpectedOperandTy)
1090 if (
Phi->getType() != VTy && RealUser->getType() != VTy)
1093 CompositeNode *
Node = identifyNode(AReal, AImag);
1098 if (ANode && Node != ANode) {
1101 <<
"Identified node is different from previously identified node. "
1102 "Unable to confidently generate a complex operation node\n");
1106 CN->addOperand(Node);
1107 CN->addOperand(identifyNode(BReal, BImag));
1108 CN->addOperand(identifyNode(Phi, RealUser));
1110 return submitCompositeNode(CN);
1113ComplexDeinterleavingGraph::CompositeNode *
1114ComplexDeinterleavingGraph::identifyPartialReduction(
Value *R,
Value *
I) {
1119 if (!
R->hasUseList() || !
I->hasUseList())
1123 findCommonBetweenCollections<Value *>(
R->users(),
I->users());
1128 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1131 if (CompositeNode *CN = identifyDotProduct(IInst))
1137ComplexDeinterleavingGraph::CompositeNode *
1138ComplexDeinterleavingGraph::identifyNode(
ComplexValues &Vals) {
1139 auto It = CachedResult.
find(Vals);
1140 if (It != CachedResult.
end()) {
1145 if (Vals.
size() == 1) {
1146 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1149 if (CompositeNode *CN = identifyPartialReduction(R,
I))
1151 bool IsReduction = RealPHI ==
R && (!ImagPHI || ImagPHI ==
I);
1152 if (!IsReduction &&
R->getType() !=
I->getType())
1156 if (CompositeNode *CN = identifySplat(Vals))
1159 for (
auto &V : Vals) {
1166 if (CompositeNode *CN = identifyDeinterleave(Vals))
1169 if (Vals.size() == 1) {
1170 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1173 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1176 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1180 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1183 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1185 ComplexDeinterleavingOperation::CAdd, NewVTy);
1188 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1193 if (CompositeNode *CN = identifyAdd(Real, Imag))
1197 if (HasCMulSupport && HasCAddSupport) {
1198 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1204 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1208 CachedResult[Vals] =
nullptr;
1212ComplexDeinterleavingGraph::CompositeNode *
1213ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1214 Instruction *Imag) {
1215 auto IsOperationSupported = [](
unsigned Opcode) ->
bool {
1216 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1217 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1218 Opcode == Instruction::Sub;
1221 if (!IsOperationSupported(Real->
getOpcode()) ||
1222 !IsOperationSupported(Imag->
getOpcode()))
1225 std::optional<FastMathFlags>
Flags;
1228 LLVM_DEBUG(
dbgs() <<
"The flags in Real and Imaginary instructions are "
1234 if (!
Flags->allowReassoc()) {
1237 <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
1246 AddendList &Addends) ->
bool {
1248 while (!Worklist.
empty()) {
1253 Addends.emplace_back(V, IsPositive);
1263 if (
I != Insn &&
I->hasNUsesOrMore(2)) {
1264 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
1265 Addends.emplace_back(
I, IsPositive);
1268 switch (
I->getOpcode()) {
1269 case Instruction::FAdd:
1270 case Instruction::Add:
1274 case Instruction::FSub:
1278 case Instruction::Sub:
1286 case Instruction::FMul:
1287 case Instruction::Mul: {
1289 if (
isNeg(
I->getOperand(0))) {
1291 IsPositive = !IsPositive;
1293 A =
I->getOperand(0);
1296 if (
isNeg(
I->getOperand(1))) {
1298 IsPositive = !IsPositive;
1300 B =
I->getOperand(1);
1302 Muls.push_back(Product{
A,
B, IsPositive});
1305 case Instruction::FNeg:
1309 Addends.emplace_back(
I, IsPositive);
1313 if (Flags &&
I->getFastMathFlags() != *Flags) {
1315 "inconsistent with the root instructions' flags: "
1324 AddendList RealAddends, ImagAddends;
1325 if (!Collect(Real, RealMuls, RealAddends) ||
1326 !Collect(Imag, ImagMuls, ImagAddends))
1329 if (RealAddends.size() != ImagAddends.size())
1332 CompositeNode *FinalNode =
nullptr;
1333 if (!RealMuls.
empty() || !ImagMuls.
empty()) {
1336 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1337 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1343 if (!RealAddends.empty() || !ImagAddends.empty()) {
1344 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1348 assert(FinalNode &&
"FinalNode can not be nullptr here");
1349 assert(FinalNode->Vals.size() == 1);
1351 FinalNode->Vals[0].Real = Real;
1352 FinalNode->Vals[0].Imag = Imag;
1353 submitCompositeNode(FinalNode);
1357bool ComplexDeinterleavingGraph::collectPartialMuls(
1359 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1361 auto FindCommonInstruction = [](
const Product &Real,
1362 const Product &Imag) ->
Value * {
1363 if (Real.Multiplicand == Imag.Multiplicand ||
1364 Real.Multiplicand == Imag.Multiplier)
1365 return Real.Multiplicand;
1367 if (Real.Multiplier == Imag.Multiplicand ||
1368 Real.Multiplier == Imag.Multiplier)
1369 return Real.Multiplier;
1378 for (
unsigned i = 0; i < RealMuls.
size(); ++i) {
1379 bool FoundCommon =
false;
1380 for (
unsigned j = 0;
j < ImagMuls.
size(); ++
j) {
1381 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1385 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1386 : RealMuls[i].Multiplicand;
1387 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1388 : ImagMuls[
j].Multiplicand;
1390 auto Node = identifyNode(
A,
B);
1396 Node = identifyNode(
B,
A);
1408ComplexDeinterleavingGraph::CompositeNode *
1409ComplexDeinterleavingGraph::identifyMultiplications(
1410 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1412 if (RealMuls.
size() != ImagMuls.
size())
1416 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1420 DenseMap<Value *, CompositeNode *> CommonToNode;
1421 SmallVector<bool> Processed(
Info.size(),
false);
1422 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1426 PartialMulCandidate &InfoA =
Info[
I];
1427 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1431 PartialMulCandidate &InfoB =
Info[J];
1432 auto *InfoReal = &InfoA;
1433 auto *InfoImag = &InfoB;
1435 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1436 if (!NodeFromCommon) {
1438 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1440 if (!NodeFromCommon)
1443 CommonToNode[InfoReal->Common] = NodeFromCommon;
1444 CommonToNode[InfoImag->Common] = NodeFromCommon;
1445 Processed[
I] =
true;
1446 Processed[J] =
true;
1450 SmallVector<bool> ProcessedReal(RealMuls.
size(),
false);
1451 SmallVector<bool> ProcessedImag(ImagMuls.
size(),
false);
1453 for (
auto &PMI : Info) {
1454 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1457 auto It = CommonToNode.
find(PMI.Common);
1460 if (It == CommonToNode.
end()) {
1462 dbgs() <<
"Unprocessed independent partial multiplication:\n";
// Dump both halves of the unmatched candidate: the real product and the
// imaginary product it was paired with (previously the real product was
// listed twice and the imaginary one was never printed).
1463 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &ImagMuls[PMI.ImagIdx]})
1465 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1470 auto &RealMul = RealMuls[PMI.RealIdx];
1471 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1473 auto NodeA = It->second;
1474 auto NodeB = PMI.Node;
1475 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1490 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1491 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1496 if (IsMultiplicandReal) {
1498 if (RealMul.IsPositive && ImagMul.IsPositive)
1500 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1507 if (!RealMul.IsPositive && ImagMul.IsPositive)
1509 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1516 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1517 dbgs().
indent(4) <<
"X: " << *NodeA->Vals[0].Real <<
"\n";
1518 dbgs().
indent(4) <<
"Y: " << *NodeA->Vals[0].Imag <<
"\n";
1519 dbgs().
indent(4) <<
"U: " << *NodeB->Vals[0].Real <<
"\n";
1520 dbgs().
indent(4) <<
"V: " << *NodeB->Vals[0].Imag <<
"\n";
1521 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1524 CompositeNode *NodeMul = prepareCompositeNode(
1525 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1526 NodeMul->Rotation = Rotation;
1527 NodeMul->addOperand(NodeA);
1528 NodeMul->addOperand(NodeB);
1530 NodeMul->addOperand(Result);
1531 submitCompositeNode(NodeMul);
1533 ProcessedReal[PMI.RealIdx] =
true;
1534 ProcessedImag[PMI.ImagIdx] =
true;
1538 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1539 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1544 dbgs() <<
"Unprocessed products (Real):\n";
1545 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1546 if (!ProcessedReal[i])
1547 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1548 << *RealMuls[i].Multiplier <<
" multiplied by "
1549 << *RealMuls[i].Multiplicand <<
"\n";
1551 dbgs() <<
"Unprocessed products (Imag):\n";
1552 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1553 if (!ProcessedImag[i])
1554 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1555 << *ImagMuls[i].Multiplier <<
" multiplied by "
1556 << *ImagMuls[i].Multiplicand <<
"\n";
1565ComplexDeinterleavingGraph::CompositeNode *
1566ComplexDeinterleavingGraph::identifyAdditions(
1567 AddendList &RealAddends, AddendList &ImagAddends,
1568 std::optional<FastMathFlags> Flags, CompositeNode *
Accumulator =
nullptr) {
1569 if (RealAddends.size() != ImagAddends.size())
1572 CompositeNode *
Result =
nullptr;
1578 Result = extractPositiveAddend(RealAddends, ImagAddends);
1583 while (!RealAddends.empty()) {
1584 auto ItR = RealAddends.begin();
1585 auto [
R, IsPositiveR] = *ItR;
1587 bool FoundImag =
false;
1588 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1589 auto [
I, IsPositiveI] = *ItI;
1591 if (IsPositiveR && IsPositiveI)
1592 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1593 else if (!IsPositiveR && IsPositiveI)
1594 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1595 else if (!IsPositiveR && !IsPositiveI)
1596 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1598 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1600 CompositeNode *AddNode =
nullptr;
1601 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1602 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1603 AddNode = identifyNode(R,
I);
1605 AddNode = identifyNode(
I, R);
1609 dbgs() <<
"Identified addition:\n";
1612 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1615 CompositeNode *TmpNode =
nullptr;
1617 TmpNode = prepareCompositeNode(
1618 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1620 TmpNode->Opcode = Instruction::FAdd;
1621 TmpNode->Flags = *
Flags;
1623 TmpNode->Opcode = Instruction::Add;
1625 }
else if (Rotation ==
1627 TmpNode = prepareCompositeNode(
1628 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1630 TmpNode->Opcode = Instruction::FSub;
1631 TmpNode->Flags = *
Flags;
1633 TmpNode->Opcode = Instruction::Sub;
1636 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1638 TmpNode->Rotation = Rotation;
1641 TmpNode->addOperand(Result);
1642 TmpNode->addOperand(AddNode);
1643 submitCompositeNode(TmpNode);
1645 RealAddends.erase(ItR);
1646 ImagAddends.erase(ItI);
1657ComplexDeinterleavingGraph::CompositeNode *
1658ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1659 AddendList &ImagAddends) {
1660 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1661 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1662 auto [
R, IsPositiveR] = *ItR;
1663 auto [
I, IsPositiveI] = *ItI;
1664 if (IsPositiveR && IsPositiveI) {
1665 auto Result = identifyNode(R,
I);
1667 RealAddends.erase(ItR);
1668 ImagAddends.erase(ItI);
1677bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1682 auto It = RootToNode.
find(RootI);
1683 if (It != RootToNode.
end()) {
1684 auto RootNode = It->second;
1685 assert(RootNode->Operation ==
1686 ComplexDeinterleavingOperation::ReductionOperation ||
1687 RootNode->Operation ==
1688 ComplexDeinterleavingOperation::ReductionSingle);
1689 assert(RootNode->Vals.size() == 1 &&
1690 "Cannot handle reductions involving multiple complex values");
1699 ReplacementAnchor =
R->comesBefore(
I) ?
I :
R;
1701 ReplacementAnchor =
R;
1703 if (ReplacementAnchor != RootI)
1709 auto RootNode = identifyRoot(RootI);
1716 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1717 <<
"::" <<
B->getName() <<
".\n";
1721 RootToNode[RootI] = RootNode;
1726bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *
B) {
1727 bool FoundPotentialReduction =
false;
1736 if (Br->getSuccessor(0) !=
B && Br->getSuccessor(1) !=
B)
1739 for (
auto &
PHI :
B->phis()) {
1740 if (
PHI.getNumIncomingValues() != 2)
1743 if (!
PHI.getType()->isVectorTy())
1753 for (
auto *U : ReductionOp->users()) {
1760 if (NumUsers != 2 || !FinalReduction || FinalReduction->
getParent() ==
B ||
1764 ReductionInfo[ReductionOp] = {&
PHI, FinalReduction};
1766 auto BackEdgeIdx =
PHI.getBasicBlockIndex(
B);
1767 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1768 Incoming =
PHI.getIncomingBlock(IncomingIdx);
1769 FoundPotentialReduction =
true;
1775 FinalInstructions.
insert(InitPHI);
1777 return FoundPotentialReduction;
1780void ComplexDeinterleavingGraph::identifyReductionNodes() {
1781 assert(Factor == 2 &&
"Cannot handle multiple complex values");
1783 SmallVector<bool> Processed(ReductionInfo.
size(),
false);
1785 for (
auto &
P : ReductionInfo)
1790 for (
size_t i = 0; i < OperationInstruction.
size(); ++i) {
1793 for (
size_t j = i + 1;
j < OperationInstruction.
size(); ++
j) {
1796 auto *Real = OperationInstruction[i];
1797 auto *Imag = OperationInstruction[
j];
1798 if (Real->getType() != Imag->
getType())
1801 RealPHI = ReductionInfo[Real].first;
1802 ImagPHI = ReductionInfo[Imag].first;
1804 auto Node = identifyNode(Real, Imag);
1808 Node = identifyNode(Real, Imag);
1814 if (Node && PHIsFound) {
1815 LLVM_DEBUG(
dbgs() <<
"Identified reduction starting from instructions: "
1816 << *Real <<
" / " << *Imag <<
"\n");
1817 Processed[i] =
true;
1818 Processed[
j] =
true;
1819 auto RootNode = prepareCompositeNode(
1820 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1821 RootNode->addOperand(Node);
1822 RootToNode[Real] = RootNode;
1823 RootToNode[Imag] = RootNode;
1824 submitCompositeNode(RootNode);
1829 auto *Real = OperationInstruction[i];
1832 if (Processed[i] || Real->getNumOperands() < 2)
1836 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1839 RealPHI = ReductionInfo[Real].first;
1842 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1843 if (Node && PHIsFound) {
1845 dbgs() <<
"Identified single reduction starting from instruction: "
1846 << *Real <<
"/" << *ReductionInfo[Real].second <<
"\n");
1855 if (ReductionInfo[Real].second->getType()->isVectorTy())
1858 Processed[i] =
true;
1859 auto RootNode = prepareCompositeNode(
1860 ComplexDeinterleavingOperation::ReductionSingle, Real,
nullptr);
1861 RootNode->addOperand(Node);
1862 RootToNode[Real] = RootNode;
1863 submitCompositeNode(RootNode);
1871bool ComplexDeinterleavingGraph::checkNodes() {
1872 bool FoundDeinterleaveNode =
false;
1873 for (CompositeNode *
N : CompositeNodes) {
1874 if (!
N->areOperandsValid())
1877 if (
N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1878 FoundDeinterleaveNode =
true;
1883 if (!FoundDeinterleaveNode) {
1885 dbgs() <<
"Couldn't find a deinterleave node within the graph, cannot "
1886 "guarantee safety during graph transformation.\n");
1891 SmallPtrSet<Instruction *, 16> AllInstructions;
1892 SmallVector<Instruction *, 8> Worklist;
1893 for (
auto &Pair : RootToNode)
1898 while (!Worklist.
empty()) {
1901 if (!AllInstructions.
insert(
I).second)
1906 if (!FinalInstructions.
count(
I))
1913 for (
auto *
I : AllInstructions) {
1915 if (RootToNode.count(
I))
1918 for (User *U :
I->users()) {
1930 SmallPtrSet<Instruction *, 16> Visited;
1931 while (!Worklist.
empty()) {
1933 if (!Visited.
insert(
I).second)
1938 if (RootToNode.count(
I)) {
1940 <<
" could be deinterleaved but its chain of complex "
1941 "operations have an outside user\n");
1942 RootToNode.erase(
I);
1945 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1948 for (User *U :
I->users())
1956 return !RootToNode.empty();
1959ComplexDeinterleavingGraph::CompositeNode *
1960ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1967 for (
unsigned I = 0;
I < Factor;
I += 2) {
1975 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
2003 return identifyNode(Real, Imag);
2006ComplexDeinterleavingGraph::CompositeNode *
2007ComplexDeinterleavingGraph::identifyDeinterleave(
ComplexValues &Vals) {
2011 auto CheckExtract = [&](
Value *
V,
unsigned ExpectedIdx,
2012 Instruction *ExpectedInsn) -> ExtractValueInst * {
2014 if (!EVI || EVI->getNumIndices() != 1 ||
2015 EVI->getIndices()[0] != ExpectedIdx ||
2017 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2022 for (
unsigned Idx = 0; Idx < Vals.
size(); Idx++) {
2023 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2,
II);
2024 if (RealEVI && Idx == 0)
2026 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1,
II)) {
2033 if (IntrinsicII->getIntrinsicID() !=
2038 CompositeNode *PlaceholderNode = prepareCompositeNode(
2040 PlaceholderNode->ReplacementNode =
II->getOperand(0);
2041 for (
auto &V : Vals) {
2045 return submitCompositeNode(PlaceholderNode);
2048 if (Vals.size() != 1)
2051 Value *Real = Vals[0].Real;
2052 Value *Imag = Vals[0].Imag;
2055 if (!RealShuffle || !ImagShuffle) {
2056 if (RealShuffle || ImagShuffle)
2057 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
2061 Value *RealOp1 = RealShuffle->getOperand(1);
2066 Value *ImagOp1 = ImagShuffle->getOperand(1);
2072 Value *RealOp0 = RealShuffle->getOperand(0);
2073 Value *ImagOp0 = ImagShuffle->getOperand(0);
2075 if (RealOp0 != ImagOp0) {
2080 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2081 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2087 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2088 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
2094 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2095 Value *
Op = Shuffle->getOperand(0);
2099 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2101 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2107 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) ->
bool {
2111 ArrayRef<int>
Mask = Shuffle->getShuffleMask();
2114 Value *
Op = Shuffle->getOperand(0);
2116 int NumElements = OpTy->getNumElements();
2120 return Last < NumElements;
2123 if (RealShuffle->getType() != ImagShuffle->getType()) {
2127 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2131 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2136 CompositeNode *PlaceholderNode =
2138 RealShuffle, ImagShuffle);
2139 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2140 FinalInstructions.
insert(RealShuffle);
2141 FinalInstructions.
insert(ImagShuffle);
2142 return submitCompositeNode(PlaceholderNode);
2145ComplexDeinterleavingGraph::CompositeNode *
2146ComplexDeinterleavingGraph::identifySplat(
ComplexValues &Vals) {
2147 auto IsSplat = [](
Value *
V) ->
bool {
2160 if (
Const->getOpcode() != Instruction::ShuffleVector)
2165 VTy = Shuf->getType();
2166 Mask = Shuf->getShuffleMask();
2174 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2184 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2185 for (
auto &V : Vals) {
2186 if (!IsSplat(
V.Real) || !IsSplat(
V.Imag))
2191 if (!Real || !Imag || Real->getParent() != FirstBB ||
2192 Imag->getParent() != FirstBB)
2196 for (
auto &V : Vals) {
2203 for (
auto &V : Vals) {
2207 FinalInstructions.
insert(Real);
2208 FinalInstructions.
insert(Imag);
2211 CompositeNode *PlaceholderNode =
2212 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2213 return submitCompositeNode(PlaceholderNode);
2216ComplexDeinterleavingGraph::CompositeNode *
2217ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2218 Instruction *Imag) {
2219 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2223 CompositeNode *PlaceholderNode = prepareCompositeNode(
2224 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2225 return submitCompositeNode(PlaceholderNode);
2228ComplexDeinterleavingGraph::CompositeNode *
2229ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2230 Instruction *Imag) {
2233 if (!SelectReal || !SelectImag)
2250 auto NodeA = identifyNode(AR, AI);
2254 auto NodeB = identifyNode(
RA, BI);
2258 CompositeNode *PlaceholderNode = prepareCompositeNode(
2259 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2260 PlaceholderNode->addOperand(NodeA);
2261 PlaceholderNode->addOperand(NodeB);
2262 FinalInstructions.
insert(MaskA);
2263 FinalInstructions.
insert(MaskB);
2264 return submitCompositeNode(PlaceholderNode);
2268 std::optional<FastMathFlags> Flags,
2272 case Instruction::FNeg:
2273 I =
B.CreateFNeg(InputA);
2275 case Instruction::FAdd:
2276 I =
B.CreateFAdd(InputA, InputB);
2278 case Instruction::Add:
2279 I =
B.CreateAdd(InputA, InputB);
2281 case Instruction::FSub:
2282 I =
B.CreateFSub(InputA, InputB);
2284 case Instruction::Sub:
2285 I =
B.CreateSub(InputA, InputB);
2287 case Instruction::FMul:
2288 I =
B.CreateFMul(InputA, InputB);
2290 case Instruction::Mul:
2291 I =
B.CreateMul(InputA, InputB);
2301Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2302 CompositeNode *Node) {
2303 if (
Node->ReplacementNode)
2304 return Node->ReplacementNode;
2306 auto ReplaceOperandIfExist = [&](CompositeNode *
Node,
2307 unsigned Idx) ->
Value * {
2308 return Node->Operands.size() > Idx
2309 ? replaceNode(Builder,
Node->Operands[Idx])
2313 Value *ReplacementNode =
nullptr;
2314 switch (
Node->Operation) {
2315 case ComplexDeinterleavingOperation::CDot: {
2316 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2317 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2320 "Node inputs need to be of the same type"));
2325 case ComplexDeinterleavingOperation::CAdd:
2326 case ComplexDeinterleavingOperation::CMulPartial:
2327 case ComplexDeinterleavingOperation::Symmetric: {
2328 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2329 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2332 "Node inputs need to be of the same type"));
2335 "Accumulator and input need to be of the same type"));
2336 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2341 Builder,
Node->Operation,
Node->Rotation, Input0, Input1,
2345 case ComplexDeinterleavingOperation::Deinterleave:
2348 case ComplexDeinterleavingOperation::Splat: {
2350 for (
auto &V :
Node->Vals) {
2351 Ops.push_back(
V.Real);
2352 Ops.push_back(
V.Imag);
2359 for (
auto V :
Node->Vals) {
2367 ReplacementNode = IRB.CreateVectorInterleave(
Ops);
2373 case ComplexDeinterleavingOperation::ReductionPHI: {
2378 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2380 OldToNewPHI[OldPHI] = NewPHI;
2381 ReplacementNode = NewPHI;
2384 case ComplexDeinterleavingOperation::ReductionSingle:
2385 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2386 processReductionSingle(ReplacementNode, Node);
2388 case ComplexDeinterleavingOperation::ReductionOperation:
2389 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2390 processReductionOperation(ReplacementNode, Node);
2392 case ComplexDeinterleavingOperation::ReductionSelect: {
2395 auto *
A = replaceNode(Builder,
Node->Operands[0]);
2396 auto *
B = replaceNode(Builder,
Node->Operands[1]);
2403 assert(ReplacementNode &&
"Target failed to create Intrinsic call.");
2404 NumComplexTransformations += 1;
2405 Node->ReplacementNode = ReplacementNode;
2406 return ReplacementNode;
2409void ComplexDeinterleavingGraph::processReductionSingle(
2410 Value *OperationReplacement, CompositeNode *Node) {
2412 auto *OldPHI = ReductionInfo[Real].first;
2413 auto *NewPHI = OldToNewPHI[OldPHI];
2415 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2417 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2421 Value *NewInit =
nullptr;
2423 if (
C->isNullValue())
2431 NewPHI->addIncoming(NewInit, Incoming);
2432 NewPHI->addIncoming(OperationReplacement, BackEdge);
2434 auto *FinalReduction = ReductionInfo[Real].second;
2441void ComplexDeinterleavingGraph::processReductionOperation(
2442 Value *OperationReplacement, CompositeNode *Node) {
2445 auto *OldPHIReal = ReductionInfo[Real].first;
2446 auto *OldPHIImag = ReductionInfo[Imag].first;
2447 auto *NewPHI = OldToNewPHI[OldPHIReal];
2450 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2451 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2456 NewPHI->addIncoming(NewInit, Incoming);
2457 NewPHI->addIncoming(OperationReplacement, BackEdge);
2461 auto *FinalReductionReal = ReductionInfo[Real].second;
2462 auto *FinalReductionImag = ReductionInfo[Imag].second;
2465 BasicBlock *ExitBB = Br->getSuccessor(Br->getSuccessor(0) == BackEdge);
2469 OperationReplacement->
getType(),
2470 OperationReplacement);
2473 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2477 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2480void ComplexDeinterleavingGraph::replaceNodes() {
2481 SmallVector<Instruction *, 16> DeadInstrRoots;
2482 for (
auto *RootInstruction : OrderedRoots) {
2485 if (!RootToNode.count(RootInstruction))
2489 auto RootNode = RootToNode[RootInstruction];
2490 Value *
R = replaceNode(Builder, RootNode);
2492 if (RootNode->Operation ==
2493 ComplexDeinterleavingOperation::ReductionOperation) {
2496 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2497 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2500 }
else if (RootNode->Operation ==
2501 ComplexDeinterleavingOperation::ReductionSingle) {
2503 auto &
Info = ReductionInfo[RootInst];
2504 Info.first->removeIncomingValue(BackEdge);
2507 assert(R &&
"Unable to find replacement for RootInstruction");
2508 DeadInstrRoots.
push_back(RootInstruction);
2509 RootInstruction->replaceAllUsesWith(R);
2513 for (
auto *
I : DeadInstrRoots)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand of a negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
This file implements a map that provides insertion order iteration.
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass iff it does not: (1) add or remove basic blocks from the function, or (2) modify terminator instructions in any way.
Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
Get the array size.
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI Value * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI Value * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={}, function_ref< void(CallInst *)> SetFn=[](CallInst *) {})
Variant to create a possibly constant-folded intrinsic.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
An opaque object representing a hash code.
const ParentTy * getParent() const
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
@ BR
Control flow instructions. These all have token chains.
@ BasicBlock
Various leaf nodes.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
ComplexDeinterleavingOperation
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
ComplexDeinterleavingRotation
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer list are equal or the list is empty.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
ComplexDeinterleavingPass(const TargetMachine &TM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...