82#define DEBUG_TYPE "complex-deinterleaving"
84STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
87 "enable-complex-deinterleaving",
115 Value *Real =
nullptr;
116 Value *Imag =
nullptr;
119 return Real ==
Other.Real && Imag ==
Other.Imag;
143 static bool isEqual(
const ComplexValue &LHS,
const ComplexValue &RHS) {
144 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
150template <
typename T,
typename IterT>
151std::optional<T> findCommonBetweenCollections(IterT
A, IterT
B) {
153 if (Common !=
A.end())
154 return std::make_optional(*Common);
158class ComplexDeinterleavingLegacyPass :
public FunctionPass {
162 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
163 : FunctionPass(ID), TM(TM) {
168 StringRef getPassName()
const override {
169 return "Complex Deinterleaving Pass";
173 void getAnalysisUsage(AnalysisUsage &AU)
const override {
179 const TargetMachine *TM;
182class ComplexDeinterleavingGraph;
183struct ComplexDeinterleavingCompositeNode {
188 Vals.push_back({
R,
I});
193 : Operation(
Op), Vals(
Other) {}
196 friend class ComplexDeinterleavingGraph;
197 using CompositeNode = ComplexDeinterleavingCompositeNode;
198 bool OperandsValid =
true;
207 std::optional<FastMathFlags> Flags;
210 ComplexDeinterleavingRotation::Rotation_0;
212 Value *ReplacementNode =
nullptr;
216 OperandsValid =
false;
217 Operands.push_back(Node);
221 void dump(raw_ostream &OS) {
222 auto PrintValue = [&](
Value *
V) {
230 auto PrintNodeRef = [&](CompositeNode *
Ptr) {
237 OS <<
"- CompositeNode: " <<
this <<
"\n";
238 for (
unsigned I = 0;
I < Vals.size();
I++) {
239 OS <<
" Real(" <<
I <<
") : ";
240 PrintValue(Vals[
I].Real);
241 OS <<
" Imag(" <<
I <<
") : ";
242 PrintValue(Vals[
I].Imag);
244 OS <<
" ReplacementNode: ";
245 PrintValue(ReplacementNode);
246 OS <<
" Operation: " << (int)Operation <<
"\n";
247 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
248 OS <<
" Operands: \n";
249 for (
const auto &
Op : Operands) {
255 bool areOperandsValid() {
return OperandsValid; }
258class ComplexDeinterleavingGraph {
266 using Addend = std::pair<Value *, bool>;
268 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
272 struct PartialMulCandidate {
280 explicit ComplexDeinterleavingGraph(
const TargetLowering *TL,
281 const TargetLibraryInfo *TLI,
283 : TL(TL), TLI(TLI), Factor(Factor) {}
286 const TargetLowering *TL =
nullptr;
287 const TargetLibraryInfo *TLI =
nullptr;
290 DenseMap<ComplexValues, CompositeNode *> CachedResult;
291 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
293 SmallPtrSet<Instruction *, 16> FinalInstructions;
296 DenseMap<Instruction *, CompositeNode *> RootToNode;
323 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
331 PHINode *RealPHI =
nullptr;
332 PHINode *ImagPHI =
nullptr;
336 bool PHIsFound =
false;
344 DenseMap<PHINode *, PHINode *> OldToNewPHI;
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
351 "Reduction related nodes must have Real and Imaginary parts");
352 return new (Allocator.Allocate())
353 ComplexDeinterleavingCompositeNode(
Operation, R,
I);
359 for (
auto &V : Vals) {
361 ((
Operation != ComplexDeinterleavingOperation::ReductionPHI &&
362 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
363 (
V.Real &&
V.Imag)) &&
364 "Reduction related nodes must have Real and Imaginary parts");
367 return new (Allocator.Allocate())
368 ComplexDeinterleavingCompositeNode(
Operation, Vals);
371 CompositeNode *submitCompositeNode(CompositeNode *Node) {
372 CompositeNodes.push_back(Node);
373 if (
Node->Vals[0].Real)
389 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
395 identifyNodeWithImplicitAdd(Instruction *
I, Instruction *J,
396 std::pair<Value *, Value *> &CommonOperandI);
405 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
406 CompositeNode *identifySymmetricOperation(
ComplexValues &Vals);
407 CompositeNode *identifyPartialReduction(
Value *R,
Value *
I);
408 CompositeNode *identifyDotProduct(
Value *Inst);
415 return identifyNode(Vals);
422 CompositeNode *identifyAdditions(AddendList &RealAddends,
423 AddendList &ImagAddends,
424 std::optional<FastMathFlags> Flags,
428 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
429 AddendList &ImagAddends);
434 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
435 SmallVectorImpl<Product> &ImagMuls,
443 SmallVectorImpl<PartialMulCandidate> &Candidates);
451 CompositeNode *identifyReassocNodes(Instruction *
I, Instruction *J);
453 CompositeNode *identifyRoot(Instruction *
I);
471 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
475 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
477 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
484 void processReductionOperation(
Value *OperationReplacement,
485 CompositeNode *Node);
486 void processReductionSingle(
Value *OperationReplacement, CompositeNode *Node);
490 void dump(raw_ostream &OS) {
491 for (
const auto &Node : CompositeNodes)
497 bool identifyNodes(Instruction *RootI);
502 bool collectPotentialReductions(BasicBlock *
B);
504 void identifyReductionNodes();
514class ComplexDeinterleaving {
516 ComplexDeinterleaving(
const TargetLowering *tl,
const TargetLibraryInfo *tli)
517 : TL(tl), TLI(tli) {}
521 bool evaluateBasicBlock(BasicBlock *
B,
unsigned Factor);
523 const TargetLowering *TL =
nullptr;
524 const TargetLibraryInfo *TLI =
nullptr;
529char ComplexDeinterleavingLegacyPass::ID = 0;
532 "Complex Deinterleaving",
false,
false)
538 const TargetLowering *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
549 return new ComplexDeinterleavingLegacyPass(TM);
552bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
554 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
555 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
558bool ComplexDeinterleaving::runOnFunction(Function &
F) {
561 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
567 dbgs() <<
"Complex deinterleaving has been disabled, target does "
568 "not support lowering of complex number operations.\n");
574 Changed |= evaluateBasicBlock(&
B, 2);
579 Changed |= evaluateBasicBlock(&
B, 4);
589 if ((Mask.size() & 1))
592 int HalfNumElements = Mask.size() / 2;
593 for (
int Idx = 0; Idx < HalfNumElements; ++Idx) {
594 int MaskIdx = Idx * 2;
595 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
604 int HalfNumElements = Mask.size() / 2;
606 for (
int Idx = 1; Idx < HalfNumElements; ++Idx) {
607 if (Mask[Idx] != (Idx * 2) +
Offset)
621 if (
I->getOpcode() == Instruction::FNeg)
622 return I->getOperand(0);
624 return I->getOperand(1);
627bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *
B,
unsigned Factor) {
628 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
629 if (Graph.collectPotentialReductions(
B))
630 Graph.identifyReductionNodes();
633 Graph.identifyNodes(&
I);
635 if (Graph.checkNodes()) {
636 Graph.replaceNodes();
643ComplexDeinterleavingGraph::CompositeNode *
644ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
645 Instruction *Real, Instruction *Imag,
646 std::pair<Value *, Value *> &PartialMatch) {
647 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
655 if ((Real->
getOpcode() != Instruction::FMul &&
656 Real->
getOpcode() != Instruction::Mul) ||
657 (Imag->
getOpcode() != Instruction::FMul &&
658 Imag->
getOpcode() != Instruction::Mul)) {
660 dbgs() <<
" - Real or imaginary instruction is not fmul or mul\n");
693 Value *CommonOperand;
694 Value *UncommonRealOp;
695 Value *UncommonImagOp;
697 if (R0 == I0 || R0 == I1) {
700 }
else if (R1 == I0 || R1 == I1) {
708 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
709 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
710 Rotation == ComplexDeinterleavingRotation::Rotation_270)
711 std::swap(UncommonRealOp, UncommonImagOp);
715 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
716 Rotation == ComplexDeinterleavingRotation::Rotation_180)
717 PartialMatch.first = CommonOperand;
719 PartialMatch.second = CommonOperand;
721 if (!PartialMatch.first || !PartialMatch.second) {
726 CompositeNode *CommonNode =
727 identifyNode(PartialMatch.first, PartialMatch.second);
733 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
739 CompositeNode *
Node = prepareCompositeNode(
740 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
741 Node->Rotation = Rotation;
742 Node->addOperand(CommonNode);
743 Node->addOperand(UncommonNode);
744 return submitCompositeNode(Node);
747ComplexDeinterleavingGraph::CompositeNode *
748ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
750 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
754 auto IsAdd = [](
unsigned Op) {
755 return Op == Instruction::FAdd ||
Op == Instruction::Add;
757 auto IsSub = [](
unsigned Op) {
758 return Op == Instruction::FSub ||
Op == Instruction::Sub;
762 Rotation = ComplexDeinterleavingRotation::Rotation_0;
764 Rotation = ComplexDeinterleavingRotation::Rotation_90;
766 Rotation = ComplexDeinterleavingRotation::Rotation_180;
768 Rotation = ComplexDeinterleavingRotation::Rotation_270;
777 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
800 Value *CommonOperand;
801 Value *UncommonRealOp;
802 Value *UncommonImagOp;
804 if (R0 == I0 || R0 == I1) {
807 }
else if (R1 == I0 || R1 == I1) {
815 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
816 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_270)
818 std::swap(UncommonRealOp, UncommonImagOp);
820 std::pair<Value *, Value *> PartialMatch(
821 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
822 Rotation == ComplexDeinterleavingRotation::Rotation_180)
825 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
826 Rotation == ComplexDeinterleavingRotation::Rotation_270)
833 if (!CRInst || !CIInst) {
834 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
838 CompositeNode *CNode =
839 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
845 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
851 assert(PartialMatch.first && PartialMatch.second);
852 CompositeNode *CommonRes =
853 identifyNode(PartialMatch.first, PartialMatch.second);
859 CompositeNode *
Node = prepareCompositeNode(
860 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
861 Node->Rotation = Rotation;
862 Node->addOperand(CommonRes);
863 Node->addOperand(UncommonRes);
864 Node->addOperand(CNode);
865 return submitCompositeNode(Node);
868ComplexDeinterleavingGraph::CompositeNode *
869ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
870 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
874 if ((Real->
getOpcode() == Instruction::FSub &&
875 Imag->
getOpcode() == Instruction::FAdd) ||
876 (Real->
getOpcode() == Instruction::Sub &&
878 Rotation = ComplexDeinterleavingRotation::Rotation_90;
879 else if ((Real->
getOpcode() == Instruction::FAdd &&
880 Imag->
getOpcode() == Instruction::FSub) ||
881 (Real->
getOpcode() == Instruction::Add &&
883 Rotation = ComplexDeinterleavingRotation::Rotation_270;
885 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
894 if (!AR || !AI || !BR || !BI) {
899 CompositeNode *ResA = identifyNode(AR, AI);
901 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
904 CompositeNode *ResB = identifyNode(BR, BI);
906 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
910 CompositeNode *
Node =
911 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
912 Node->Rotation = Rotation;
913 Node->addOperand(ResA);
914 Node->addOperand(ResB);
915 return submitCompositeNode(Node);
919 unsigned OpcA =
A->getOpcode();
920 unsigned OpcB =
B->getOpcode();
922 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
923 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
924 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
925 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
936 switch (
I->getOpcode()) {
937 case Instruction::FAdd:
938 case Instruction::FSub:
939 case Instruction::FMul:
940 case Instruction::FNeg:
941 case Instruction::Add:
942 case Instruction::Sub:
943 case Instruction::Mul:
950ComplexDeinterleavingGraph::CompositeNode *
951ComplexDeinterleavingGraph::identifySymmetricOperation(
ComplexValues &Vals) {
953 unsigned FirstOpc = FirstReal->getOpcode();
954 for (
auto &V : Vals) {
971 for (
auto &V : Vals) {
977 CompositeNode *Op0 = identifyNode(OpVals);
978 CompositeNode *Op1 =
nullptr;
982 if (FirstReal->isBinaryOp()) {
984 for (
auto &V : Vals) {
989 Op1 = identifyNode(OpVals);
995 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
996 Node->Opcode = FirstReal->getOpcode();
998 Node->Flags = FirstReal->getFastMathFlags();
1000 Node->addOperand(Op0);
1001 if (FirstReal->isBinaryOp())
1002 Node->addOperand(Op1);
1004 return submitCompositeNode(Node);
1007ComplexDeinterleavingGraph::CompositeNode *
1008ComplexDeinterleavingGraph::identifyDotProduct(
Value *V) {
1010 ComplexDeinterleavingOperation::CDot,
V->getType())) {
1011 LLVM_DEBUG(
dbgs() <<
"Target doesn't support complex deinterleaving "
1012 "operation CDot with the type "
1013 << *
V->getType() <<
"\n");
1021 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst,
nullptr);
1023 CompositeNode *ANode =
nullptr;
1025 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1027 Value *AReal =
nullptr;
1028 Value *AImag =
nullptr;
1029 Value *BReal =
nullptr;
1030 Value *BImag =
nullptr;
1035 return CI->getOperand(0);
1049 if (
match(Inst, PatternRot0)) {
1050 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1051 }
else if (
match(Inst, PatternRot270)) {
1052 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1063 if (!
match(Inst, PatternRot90Rot180))
1066 A0 = UnwrapCast(A0);
1067 A1 = UnwrapCast(A1);
1070 ANode = identifyNode(A0, A1);
1073 ANode = identifyNode(A1, A0);
1077 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1083 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1087 AReal = UnwrapCast(AReal);
1088 AImag = UnwrapCast(AImag);
1089 BReal = UnwrapCast(BReal);
1090 BImag = UnwrapCast(BImag);
1093 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1094 if (AReal->
getType() != ExpectedOperandTy)
1096 if (AImag->
getType() != ExpectedOperandTy)
1098 if (BReal->
getType() != ExpectedOperandTy)
1100 if (BImag->
getType() != ExpectedOperandTy)
1103 if (
Phi->getType() != VTy && RealUser->getType() != VTy)
1106 CompositeNode *
Node = identifyNode(AReal, AImag);
1111 if (ANode && Node != ANode) {
1114 <<
"Identified node is different from previously identified node. "
1115 "Unable to confidently generate a complex operation node\n");
1119 CN->addOperand(Node);
1120 CN->addOperand(identifyNode(BReal, BImag));
1121 CN->addOperand(identifyNode(Phi, RealUser));
1123 return submitCompositeNode(CN);
1126ComplexDeinterleavingGraph::CompositeNode *
1127ComplexDeinterleavingGraph::identifyPartialReduction(
Value *R,
Value *
I) {
1132 if (!
R->hasUseList() || !
I->hasUseList())
1136 findCommonBetweenCollections<Value *>(
R->users(),
I->users());
1141 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1144 if (CompositeNode *CN = identifyDotProduct(IInst))
1150ComplexDeinterleavingGraph::CompositeNode *
1151ComplexDeinterleavingGraph::identifyNode(
ComplexValues &Vals) {
1152 auto It = CachedResult.
find(Vals);
1153 if (It != CachedResult.
end()) {
1158 if (Vals.
size() == 1) {
1159 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1162 if (CompositeNode *CN = identifyPartialReduction(R,
I))
1164 bool IsReduction = RealPHI ==
R && (!ImagPHI || ImagPHI ==
I);
1165 if (!IsReduction &&
R->getType() !=
I->getType())
1169 if (CompositeNode *CN = identifySplat(Vals))
1172 for (
auto &V : Vals) {
1179 if (CompositeNode *CN = identifyDeinterleave(Vals))
1182 if (Vals.size() == 1) {
1183 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1186 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1189 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1193 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1196 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1198 ComplexDeinterleavingOperation::CAdd, NewVTy);
1201 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1206 if (CompositeNode *CN = identifyAdd(Real, Imag))
1210 if (HasCMulSupport && HasCAddSupport) {
1211 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1217 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1221 CachedResult[Vals] =
nullptr;
1225ComplexDeinterleavingGraph::CompositeNode *
1226ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1227 Instruction *Imag) {
1228 auto IsOperationSupported = [](
unsigned Opcode) ->
bool {
1229 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1230 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1231 Opcode == Instruction::Sub;
1234 if (!IsOperationSupported(Real->
getOpcode()) ||
1235 !IsOperationSupported(Imag->
getOpcode()))
1238 std::optional<FastMathFlags>
Flags;
1241 LLVM_DEBUG(
dbgs() <<
"The flags in Real and Imaginary instructions are "
1247 if (!
Flags->allowReassoc()) {
1250 <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
1259 AddendList &Addends) ->
bool {
1261 SmallPtrSet<Value *, 8> Visited;
1262 while (!Worklist.
empty()) {
1264 if (!Visited.
insert(V).second)
1269 Addends.emplace_back(V, IsPositive);
1279 if (
I != Insn &&
I->hasNUsesOrMore(2)) {
1280 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
1281 Addends.emplace_back(
I, IsPositive);
1284 switch (
I->getOpcode()) {
1285 case Instruction::FAdd:
1286 case Instruction::Add:
1290 case Instruction::FSub:
1294 case Instruction::Sub:
1302 case Instruction::FMul:
1303 case Instruction::Mul: {
1305 if (
isNeg(
I->getOperand(0))) {
1307 IsPositive = !IsPositive;
1309 A =
I->getOperand(0);
1312 if (
isNeg(
I->getOperand(1))) {
1314 IsPositive = !IsPositive;
1316 B =
I->getOperand(1);
1318 Muls.push_back(Product{
A,
B, IsPositive});
1321 case Instruction::FNeg:
1325 Addends.emplace_back(
I, IsPositive);
1329 if (Flags &&
I->getFastMathFlags() != *Flags) {
1331 "inconsistent with the root instructions' flags: "
1340 AddendList RealAddends, ImagAddends;
1341 if (!Collect(Real, RealMuls, RealAddends) ||
1342 !Collect(Imag, ImagMuls, ImagAddends))
1345 if (RealAddends.size() != ImagAddends.size())
1348 CompositeNode *FinalNode =
nullptr;
1349 if (!RealMuls.
empty() || !ImagMuls.
empty()) {
1352 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1353 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1359 if (!RealAddends.empty() || !ImagAddends.empty()) {
1360 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1364 assert(FinalNode &&
"FinalNode can not be nullptr here");
1365 assert(FinalNode->Vals.size() == 1);
1367 FinalNode->Vals[0].Real = Real;
1368 FinalNode->Vals[0].Imag = Imag;
1369 submitCompositeNode(FinalNode);
1373bool ComplexDeinterleavingGraph::collectPartialMuls(
1375 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1377 auto FindCommonInstruction = [](
const Product &Real,
1378 const Product &Imag) ->
Value * {
1379 if (Real.Multiplicand == Imag.Multiplicand ||
1380 Real.Multiplicand == Imag.Multiplier)
1381 return Real.Multiplicand;
1383 if (Real.Multiplier == Imag.Multiplicand ||
1384 Real.Multiplier == Imag.Multiplier)
1385 return Real.Multiplier;
1394 for (
unsigned i = 0; i < RealMuls.
size(); ++i) {
1395 bool FoundCommon =
false;
1396 for (
unsigned j = 0;
j < ImagMuls.
size(); ++
j) {
1397 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1401 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1402 : RealMuls[i].Multiplicand;
1403 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1404 : ImagMuls[
j].Multiplicand;
1406 auto Node = identifyNode(
A,
B);
1412 Node = identifyNode(
B,
A);
1424ComplexDeinterleavingGraph::CompositeNode *
1425ComplexDeinterleavingGraph::identifyMultiplications(
1426 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1428 if (RealMuls.
size() != ImagMuls.
size())
1432 if (!collectPartialMuls(RealMuls, ImagMuls,
Info))
1436 DenseMap<Value *, CompositeNode *> CommonToNode;
1437 SmallVector<bool> Processed(
Info.size(),
false);
1438 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1442 PartialMulCandidate &InfoA =
Info[
I];
1443 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1447 PartialMulCandidate &InfoB =
Info[J];
1448 auto *InfoReal = &InfoA;
1449 auto *InfoImag = &InfoB;
1451 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1452 if (!NodeFromCommon) {
1454 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1456 if (!NodeFromCommon)
1459 CommonToNode[InfoReal->Common] = NodeFromCommon;
1460 CommonToNode[InfoImag->Common] = NodeFromCommon;
1461 Processed[
I] =
true;
1462 Processed[J] =
true;
1466 SmallVector<bool> ProcessedReal(RealMuls.
size(),
false);
1467 SmallVector<bool> ProcessedImag(ImagMuls.
size(),
false);
1469 for (
auto &PMI :
Info) {
1470 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1473 auto It = CommonToNode.
find(PMI.Common);
1476 if (It == CommonToNode.
end()) {
1478 dbgs() <<
"Unprocessed independent partial multiplication:\n";
1479 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1481 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1486 auto &RealMul = RealMuls[PMI.RealIdx];
1487 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1489 auto NodeA = It->second;
1490 auto NodeB = PMI.Node;
1491 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1506 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1507 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1512 if (IsMultiplicandReal) {
1514 if (RealMul.IsPositive && ImagMul.IsPositive)
1516 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1523 if (!RealMul.IsPositive && ImagMul.IsPositive)
1525 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1532 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1533 dbgs().
indent(4) <<
"X: " << *NodeA->Vals[0].Real <<
"\n";
1534 dbgs().
indent(4) <<
"Y: " << *NodeA->Vals[0].Imag <<
"\n";
1535 dbgs().
indent(4) <<
"U: " << *NodeB->Vals[0].Real <<
"\n";
1536 dbgs().
indent(4) <<
"V: " << *NodeB->Vals[0].Imag <<
"\n";
1537 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1540 CompositeNode *NodeMul = prepareCompositeNode(
1541 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1542 NodeMul->Rotation = Rotation;
1543 NodeMul->addOperand(NodeA);
1544 NodeMul->addOperand(NodeB);
1546 NodeMul->addOperand(Result);
1547 submitCompositeNode(NodeMul);
1549 ProcessedReal[PMI.RealIdx] =
true;
1550 ProcessedImag[PMI.ImagIdx] =
true;
1554 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1555 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1560 dbgs() <<
"Unprocessed products (Real):\n";
1561 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1562 if (!ProcessedReal[i])
1563 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1564 << *RealMuls[i].Multiplier <<
" multiplied by "
1565 << *RealMuls[i].Multiplicand <<
"\n";
1567 dbgs() <<
"Unprocessed products (Imag):\n";
1568 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1569 if (!ProcessedImag[i])
1570 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1571 << *ImagMuls[i].Multiplier <<
" multiplied by "
1572 << *ImagMuls[i].Multiplicand <<
"\n";
1581ComplexDeinterleavingGraph::CompositeNode *
1582ComplexDeinterleavingGraph::identifyAdditions(
1583 AddendList &RealAddends, AddendList &ImagAddends,
1584 std::optional<FastMathFlags> Flags, CompositeNode *
Accumulator =
nullptr) {
1585 if (RealAddends.size() != ImagAddends.size())
1588 CompositeNode *
Result =
nullptr;
1594 Result = extractPositiveAddend(RealAddends, ImagAddends);
1599 while (!RealAddends.empty()) {
1600 auto ItR = RealAddends.begin();
1601 auto [
R, IsPositiveR] = *ItR;
1603 bool FoundImag =
false;
1604 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1605 auto [
I, IsPositiveI] = *ItI;
1607 if (IsPositiveR && IsPositiveI)
1608 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1609 else if (!IsPositiveR && IsPositiveI)
1610 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1611 else if (!IsPositiveR && !IsPositiveI)
1612 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1614 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1616 CompositeNode *AddNode =
nullptr;
1617 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1618 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1619 AddNode = identifyNode(R,
I);
1621 AddNode = identifyNode(
I, R);
1625 dbgs() <<
"Identified addition:\n";
1628 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1631 CompositeNode *TmpNode =
nullptr;
1633 TmpNode = prepareCompositeNode(
1634 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1636 TmpNode->Opcode = Instruction::FAdd;
1637 TmpNode->Flags = *
Flags;
1639 TmpNode->Opcode = Instruction::Add;
1641 }
else if (Rotation ==
1643 TmpNode = prepareCompositeNode(
1644 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1646 TmpNode->Opcode = Instruction::FSub;
1647 TmpNode->Flags = *
Flags;
1649 TmpNode->Opcode = Instruction::Sub;
1652 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1654 TmpNode->Rotation = Rotation;
1657 TmpNode->addOperand(Result);
1658 TmpNode->addOperand(AddNode);
1659 submitCompositeNode(TmpNode);
1661 RealAddends.erase(ItR);
1662 ImagAddends.erase(ItI);
1673ComplexDeinterleavingGraph::CompositeNode *
1674ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1675 AddendList &ImagAddends) {
1676 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1677 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1678 auto [
R, IsPositiveR] = *ItR;
1679 auto [
I, IsPositiveI] = *ItI;
1680 if (IsPositiveR && IsPositiveI) {
1681 auto Result = identifyNode(R,
I);
1683 RealAddends.erase(ItR);
1684 ImagAddends.erase(ItI);
1693bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1698 auto It = RootToNode.
find(RootI);
1699 if (It != RootToNode.
end()) {
1700 auto RootNode = It->second;
1701 assert(RootNode->Operation ==
1702 ComplexDeinterleavingOperation::ReductionOperation ||
1703 RootNode->Operation ==
1704 ComplexDeinterleavingOperation::ReductionSingle);
1705 assert(RootNode->Vals.size() == 1 &&
1706 "Cannot handle reductions involving multiple complex values");
1715 ReplacementAnchor =
R->comesBefore(
I) ?
I :
R;
1717 ReplacementAnchor =
R;
1719 if (ReplacementAnchor != RootI)
1725 auto RootNode = identifyRoot(RootI);
1732 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1733 <<
"::" <<
B->getName() <<
".\n";
1737 RootToNode[RootI] = RootNode;
1742bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *
B) {
1743 bool FoundPotentialReduction =
false;
1748 if (!Br || Br->getNumSuccessors() != 2)
1752 if (Br->getSuccessor(0) !=
B && Br->getSuccessor(1) !=
B)
1755 for (
auto &
PHI :
B->phis()) {
1756 if (
PHI.getNumIncomingValues() != 2)
1759 if (!
PHI.getType()->isVectorTy())
1769 for (
auto *U : ReductionOp->users()) {
1776 if (NumUsers != 2 || !FinalReduction || FinalReduction->
getParent() ==
B ||
1780 ReductionInfo[ReductionOp] = {&
PHI, FinalReduction};
1782 auto BackEdgeIdx =
PHI.getBasicBlockIndex(
B);
1783 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1784 Incoming =
PHI.getIncomingBlock(IncomingIdx);
1785 FoundPotentialReduction =
true;
1791 FinalInstructions.
insert(InitPHI);
1793 return FoundPotentialReduction;
1796void ComplexDeinterleavingGraph::identifyReductionNodes() {
1797 assert(Factor == 2 &&
"Cannot handle multiple complex values");
1799 SmallVector<bool> Processed(ReductionInfo.
size(),
false);
1801 for (
auto &
P : ReductionInfo)
1806 for (
size_t i = 0; i < OperationInstruction.
size(); ++i) {
1809 for (
size_t j = i + 1;
j < OperationInstruction.
size(); ++
j) {
1812 auto *Real = OperationInstruction[i];
1813 auto *Imag = OperationInstruction[
j];
1814 if (Real->getType() != Imag->
getType())
1817 RealPHI = ReductionInfo[Real].first;
1818 ImagPHI = ReductionInfo[Imag].first;
1820 auto Node = identifyNode(Real, Imag);
1824 Node = identifyNode(Real, Imag);
1830 if (Node && PHIsFound) {
1831 LLVM_DEBUG(
dbgs() <<
"Identified reduction starting from instructions: "
1832 << *Real <<
" / " << *Imag <<
"\n");
1833 Processed[i] =
true;
1834 Processed[
j] =
true;
1835 auto RootNode = prepareCompositeNode(
1836 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1837 RootNode->addOperand(Node);
1838 RootToNode[Real] = RootNode;
1839 RootToNode[Imag] = RootNode;
1840 submitCompositeNode(RootNode);
1845 auto *Real = OperationInstruction[i];
1848 if (Processed[i] || Real->getNumOperands() < 2)
1852 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1855 RealPHI = ReductionInfo[Real].first;
1858 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1859 if (Node && PHIsFound) {
1861 dbgs() <<
"Identified single reduction starting from instruction: "
1862 << *Real <<
"/" << *ReductionInfo[Real].second <<
"\n");
1871 if (ReductionInfo[Real].second->getType()->isVectorTy())
1874 Processed[i] =
true;
1875 auto RootNode = prepareCompositeNode(
1876 ComplexDeinterleavingOperation::ReductionSingle, Real,
nullptr);
1877 RootNode->addOperand(Node);
1878 RootToNode[Real] = RootNode;
1879 submitCompositeNode(RootNode);
1887bool ComplexDeinterleavingGraph::checkNodes() {
1888 bool FoundDeinterleaveNode =
false;
1889 for (CompositeNode *
N : CompositeNodes) {
1890 if (!
N->areOperandsValid())
1893 if (
N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1894 FoundDeinterleaveNode =
true;
1899 if (!FoundDeinterleaveNode) {
1901 dbgs() <<
"Couldn't find a deinterleave node within the graph, cannot "
1902 "guarantee safety during graph transformation.\n");
1907 SmallPtrSet<Instruction *, 16> AllInstructions;
1908 SmallVector<Instruction *, 8> Worklist;
1909 for (
auto &Pair : RootToNode)
1914 while (!Worklist.
empty()) {
1917 if (!AllInstructions.
insert(
I).second)
1922 if (!FinalInstructions.
count(
I))
1929 for (
auto *
I : AllInstructions) {
1931 if (RootToNode.count(
I))
1934 for (User *U :
I->users()) {
1946 SmallPtrSet<Instruction *, 16> Visited;
1947 while (!Worklist.
empty()) {
1949 if (!Visited.
insert(
I).second)
1954 if (RootToNode.count(
I)) {
1956 <<
" could be deinterleaved but its chain of complex "
1957 "operations have an outside user\n");
1958 RootToNode.erase(
I);
1961 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1964 for (User *U :
I->users())
1972 return !RootToNode.
empty();
1975ComplexDeinterleavingGraph::CompositeNode *
1976ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1983 for (
unsigned I = 0;
I < Factor;
I += 2) {
1991 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
2019 return identifyNode(Real, Imag);
2022ComplexDeinterleavingGraph::CompositeNode *
2023ComplexDeinterleavingGraph::identifyDeinterleave(
ComplexValues &Vals) {
2027 auto CheckExtract = [&](
Value *
V,
unsigned ExpectedIdx,
2028 Instruction *ExpectedInsn) -> ExtractValueInst * {
2030 if (!EVI || EVI->getNumIndices() != 1 ||
2031 EVI->getIndices()[0] != ExpectedIdx ||
2033 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2038 for (
unsigned Idx = 0; Idx < Vals.
size(); Idx++) {
2039 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2,
II);
2040 if (RealEVI && Idx == 0)
2042 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1,
II)) {
2049 if (IntrinsicII->getIntrinsicID() !=
2054 CompositeNode *PlaceholderNode = prepareCompositeNode(
2056 PlaceholderNode->ReplacementNode =
II->getOperand(0);
2057 for (
auto &V : Vals) {
2061 return submitCompositeNode(PlaceholderNode);
2064 if (Vals.size() != 1)
2067 Value *Real = Vals[0].Real;
2068 Value *Imag = Vals[0].Imag;
2071 if (!RealShuffle || !ImagShuffle) {
2072 if (RealShuffle || ImagShuffle)
2073 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
2077 Value *RealOp1 = RealShuffle->getOperand(1);
2082 Value *ImagOp1 = ImagShuffle->getOperand(1);
2088 Value *RealOp0 = RealShuffle->getOperand(0);
2089 Value *ImagOp0 = ImagShuffle->getOperand(0);
2091 if (RealOp0 != ImagOp0) {
2096 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2097 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2103 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2104 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
2110 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2111 Value *
Op = Shuffle->getOperand(0);
2115 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2117 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2123 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) ->
bool {
2127 ArrayRef<int>
Mask = Shuffle->getShuffleMask();
2130 Value *
Op = Shuffle->getOperand(0);
2132 int NumElements = OpTy->getNumElements();
2136 return Last < NumElements;
2139 if (RealShuffle->getType() != ImagShuffle->getType()) {
2143 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2147 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2152 CompositeNode *PlaceholderNode =
2154 RealShuffle, ImagShuffle);
2155 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2156 FinalInstructions.
insert(RealShuffle);
2157 FinalInstructions.
insert(ImagShuffle);
2158 return submitCompositeNode(PlaceholderNode);
2161ComplexDeinterleavingGraph::CompositeNode *
2162ComplexDeinterleavingGraph::identifySplat(
ComplexValues &Vals) {
2163 auto IsSplat = [](
Value *
V) ->
bool {
2176 if (
Const->getOpcode() != Instruction::ShuffleVector)
2181 VTy = Shuf->getType();
2182 Mask = Shuf->getShuffleMask();
2190 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2200 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2201 for (
auto &V : Vals) {
2202 if (!IsSplat(
V.Real) || !IsSplat(
V.Imag))
2207 if (!Real || !Imag || Real->getParent() != FirstBB ||
2208 Imag->getParent() != FirstBB)
2212 for (
auto &V : Vals) {
2219 for (
auto &V : Vals) {
2223 FinalInstructions.
insert(Real);
2224 FinalInstructions.
insert(Imag);
2227 CompositeNode *PlaceholderNode =
2228 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2229 return submitCompositeNode(PlaceholderNode);
2232ComplexDeinterleavingGraph::CompositeNode *
2233ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2234 Instruction *Imag) {
2235 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2239 CompositeNode *PlaceholderNode = prepareCompositeNode(
2240 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2241 return submitCompositeNode(PlaceholderNode);
2244ComplexDeinterleavingGraph::CompositeNode *
2245ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2246 Instruction *Imag) {
2249 if (!SelectReal || !SelectImag)
2266 auto NodeA = identifyNode(AR, AI);
2270 auto NodeB = identifyNode(
RA, BI);
2274 CompositeNode *PlaceholderNode = prepareCompositeNode(
2275 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2276 PlaceholderNode->addOperand(NodeA);
2277 PlaceholderNode->addOperand(NodeB);
2278 FinalInstructions.
insert(MaskA);
2279 FinalInstructions.
insert(MaskB);
2280 return submitCompositeNode(PlaceholderNode);
2284 std::optional<FastMathFlags> Flags,
2288 case Instruction::FNeg:
2289 I =
B.CreateFNeg(InputA);
2291 case Instruction::FAdd:
2292 I =
B.CreateFAdd(InputA, InputB);
2294 case Instruction::Add:
2295 I =
B.CreateAdd(InputA, InputB);
2297 case Instruction::FSub:
2298 I =
B.CreateFSub(InputA, InputB);
2300 case Instruction::Sub:
2301 I =
B.CreateSub(InputA, InputB);
2303 case Instruction::FMul:
2304 I =
B.CreateFMul(InputA, InputB);
2306 case Instruction::Mul:
2307 I =
B.CreateMul(InputA, InputB);
2317Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2318 CompositeNode *Node) {
2319 if (
Node->ReplacementNode)
2320 return Node->ReplacementNode;
2322 auto ReplaceOperandIfExist = [&](CompositeNode *
Node,
2323 unsigned Idx) ->
Value * {
2324 return Node->Operands.size() > Idx
2325 ? replaceNode(Builder,
Node->Operands[Idx])
2329 Value *ReplacementNode =
nullptr;
2330 switch (
Node->Operation) {
2331 case ComplexDeinterleavingOperation::CDot: {
2332 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2333 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2336 "Node inputs need to be of the same type"));
2341 case ComplexDeinterleavingOperation::CAdd:
2342 case ComplexDeinterleavingOperation::CMulPartial:
2343 case ComplexDeinterleavingOperation::Symmetric: {
2344 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2345 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2348 "Node inputs need to be of the same type"));
2351 "Accumulator and input need to be of the same type"));
2352 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2357 Builder,
Node->Operation,
Node->Rotation, Input0, Input1,
2361 case ComplexDeinterleavingOperation::Deinterleave:
2364 case ComplexDeinterleavingOperation::Splat: {
2366 for (
auto &V :
Node->Vals) {
2367 Ops.push_back(
V.Real);
2368 Ops.push_back(
V.Imag);
2375 for (
auto V :
Node->Vals) {
2383 ReplacementNode = IRB.CreateVectorInterleave(
Ops);
2389 case ComplexDeinterleavingOperation::ReductionPHI: {
2394 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2396 OldToNewPHI[OldPHI] = NewPHI;
2397 ReplacementNode = NewPHI;
2400 case ComplexDeinterleavingOperation::ReductionSingle:
2401 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2402 processReductionSingle(ReplacementNode, Node);
2404 case ComplexDeinterleavingOperation::ReductionOperation:
2405 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2406 processReductionOperation(ReplacementNode, Node);
2408 case ComplexDeinterleavingOperation::ReductionSelect: {
2411 auto *
A = replaceNode(Builder,
Node->Operands[0]);
2412 auto *
B = replaceNode(Builder,
Node->Operands[1]);
2419 assert(ReplacementNode &&
"Target failed to create Intrinsic call.");
2420 NumComplexTransformations += 1;
2421 Node->ReplacementNode = ReplacementNode;
2422 return ReplacementNode;
2425void ComplexDeinterleavingGraph::processReductionSingle(
2426 Value *OperationReplacement, CompositeNode *Node) {
2428 auto *OldPHI = ReductionInfo[Real].first;
2429 auto *NewPHI = OldToNewPHI[OldPHI];
2431 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2433 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2437 Value *NewInit =
nullptr;
2439 if (
C->isZeroValue())
2447 NewPHI->addIncoming(NewInit, Incoming);
2448 NewPHI->addIncoming(OperationReplacement, BackEdge);
2450 auto *FinalReduction = ReductionInfo[Real].second;
2457void ComplexDeinterleavingGraph::processReductionOperation(
2458 Value *OperationReplacement, CompositeNode *Node) {
2461 auto *OldPHIReal = ReductionInfo[Real].first;
2462 auto *OldPHIImag = ReductionInfo[Imag].first;
2463 auto *NewPHI = OldToNewPHI[OldPHIReal];
2466 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2467 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2472 NewPHI->addIncoming(NewInit, Incoming);
2473 NewPHI->addIncoming(OperationReplacement, BackEdge);
2477 auto *FinalReductionReal = ReductionInfo[Real].second;
2478 auto *FinalReductionImag = ReductionInfo[Imag].second;
2481 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2483 OperationReplacement->
getType(),
2484 OperationReplacement);
2487 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2491 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2494void ComplexDeinterleavingGraph::replaceNodes() {
2495 SmallVector<Instruction *, 16> DeadInstrRoots;
2496 for (
auto *RootInstruction : OrderedRoots) {
2499 if (!RootToNode.count(RootInstruction))
2503 auto RootNode = RootToNode[RootInstruction];
2504 Value *
R = replaceNode(Builder, RootNode);
2506 if (RootNode->Operation ==
2507 ComplexDeinterleavingOperation::ReductionOperation) {
2510 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2511 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2514 }
else if (RootNode->Operation ==
2515 ComplexDeinterleavingOperation::ReductionSingle) {
2517 auto &
Info = ReductionInfo[RootInst];
2518 Info.first->removeIncomingValue(BackEdge);
2521 assert(R &&
"Unable to find replacement for RootInstruction");
2522 DeadInstrRoots.
push_back(RootInstruction);
2523 RootInstruction->replaceAllUsesWith(R);
2527 for (
auto *
I : DeadInstrRoots)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Analysis containing CSE Info
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
This file implements a map that provides insertion order iteration.
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metrics from passes.
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
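This function should be called by the pass, iff they do not: (1) add or remove basic blocks from the function, or (2) modify terminator instructions in any way.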
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory), i.e. a start pointer and a length.
size_t size() const
size - Get the array size.
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instruction comes before Other.
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports these flags.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will have (use 0 if you really have no idea).
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at application launch and destroyed by llvm_shutdown.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the target instruction selector can accept natively.
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInfo-derived member variable.
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
An opaque object representing a hash code.
const ParentTy * getParent() const
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
@ BasicBlock
Various leaf nodes.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
ComplexDeinterleavingOperation
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number arithmetic.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LLVM_ABI void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference sizeof(SmallVector<T, 0>).
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
ComplexDeinterleavingRotation
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list is empty.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
ComplexDeinterleavingPass(TargetMachine *TM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static ComplexValue getEmptyKey()
static unsigned getHashValue(const ComplexValue &Val)
static ComplexValue getTombstoneKey()
An information struct used to provide DenseMap with the various necessary components for a given value type T.