82#define DEBUG_TYPE "complex-deinterleaving"
84STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
87 "enable-complex-deinterleaving",
115 Value *Real =
nullptr;
116 Value *Imag =
nullptr;
119 return Real ==
Other.Real && Imag ==
Other.Imag;
134 static bool isEqual(
const ComplexValue &LHS,
const ComplexValue &RHS) {
135 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
140template <
typename T,
typename IterT>
141std::optional<T> findCommonBetweenCollections(IterT
A, IterT
B) {
143 if (Common !=
A.end())
144 return std::make_optional(*Common);
148class ComplexDeinterleavingLegacyPass :
public FunctionPass {
152 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
153 : FunctionPass(ID), TM(TM) {}
155 StringRef getPassName()
const override {
156 return "Complex Deinterleaving Pass";
160 void getAnalysisUsage(AnalysisUsage &AU)
const override {
166 const TargetMachine *TM;
169class ComplexDeinterleavingGraph;
170struct ComplexDeinterleavingCompositeNode {
175 Vals.push_back({
R,
I});
180 : Operation(
Op), Vals(
Other) {}
183 friend class ComplexDeinterleavingGraph;
184 using CompositeNode = ComplexDeinterleavingCompositeNode;
185 bool OperandsValid =
true;
194 std::optional<FastMathFlags> Flags;
197 ComplexDeinterleavingRotation::Rotation_0;
199 Value *ReplacementNode =
nullptr;
203 OperandsValid =
false;
204 Operands.push_back(Node);
208 void dump(raw_ostream &OS) {
209 auto PrintValue = [&](
Value *
V) {
217 auto PrintNodeRef = [&](CompositeNode *Ptr) {
224 OS <<
"- CompositeNode: " <<
this <<
"\n";
225 for (
unsigned I = 0;
I < Vals.size();
I++) {
226 OS <<
" Real(" <<
I <<
") : ";
227 PrintValue(Vals[
I].Real);
228 OS <<
" Imag(" <<
I <<
") : ";
229 PrintValue(Vals[
I].Imag);
231 OS <<
" ReplacementNode: ";
232 PrintValue(ReplacementNode);
233 OS <<
" Operation: " << (int)Operation <<
"\n";
234 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
235 OS <<
" Operands: \n";
236 for (
const auto &
Op : Operands) {
242 bool areOperandsValid() {
return OperandsValid; }
245class ComplexDeinterleavingGraph {
253 using Addend = std::pair<Value *, bool>;
255 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
259 struct PartialMulCandidate {
267 explicit ComplexDeinterleavingGraph(
const TargetLowering *TL,
268 const TargetLibraryInfo *TLI,
270 : TL(TL), TLI(TLI), Factor(Factor) {}
273 const TargetLowering *TL =
nullptr;
274 const TargetLibraryInfo *TLI =
nullptr;
277 DenseMap<ComplexValues, CompositeNode *> CachedResult;
278 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
280 SmallPtrSet<Instruction *, 16> FinalInstructions;
283 DenseMap<Instruction *, CompositeNode *> RootToNode;
310 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
318 PHINode *RealPHI =
nullptr;
319 PHINode *ImagPHI =
nullptr;
323 bool PHIsFound =
false;
331 DenseMap<PHINode *, PHINode *> OldToNewPHI;
336 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
338 "Reduction related nodes must have Real and Imaginary parts");
339 return new (Allocator.Allocate())
340 ComplexDeinterleavingCompositeNode(
Operation, R,
I);
346 for (
auto &V : Vals) {
348 ((
Operation != ComplexDeinterleavingOperation::ReductionPHI &&
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
350 (
V.Real &&
V.Imag)) &&
351 "Reduction related nodes must have Real and Imaginary parts");
354 return new (Allocator.Allocate())
355 ComplexDeinterleavingCompositeNode(
Operation, Vals);
358 CompositeNode *submitCompositeNode(CompositeNode *Node) {
359 CompositeNodes.push_back(Node);
360 if (
Node->Vals[0].Real)
376 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
382 identifyNodeWithImplicitAdd(Instruction *
I, Instruction *J,
383 std::pair<Value *, Value *> &CommonOperandI);
392 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
393 CompositeNode *identifySymmetricOperation(
ComplexValues &Vals);
394 CompositeNode *identifyPartialReduction(
Value *R,
Value *
I);
395 CompositeNode *identifyDotProduct(
Value *Inst);
402 return identifyNode(Vals);
409 CompositeNode *identifyAdditions(AddendList &RealAddends,
410 AddendList &ImagAddends,
411 std::optional<FastMathFlags> Flags,
415 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
416 AddendList &ImagAddends);
421 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
422 SmallVectorImpl<Product> &ImagMuls,
430 SmallVectorImpl<PartialMulCandidate> &Candidates);
438 CompositeNode *identifyReassocNodes(Instruction *
I, Instruction *J);
440 CompositeNode *identifyRoot(Instruction *
I);
458 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
462 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
464 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
471 void processReductionOperation(
Value *OperationReplacement,
472 CompositeNode *Node);
473 void processReductionSingle(
Value *OperationReplacement, CompositeNode *Node);
477 void dump(raw_ostream &OS) {
478 for (
const auto &Node : CompositeNodes)
484 bool identifyNodes(Instruction *RootI);
489 bool collectPotentialReductions(BasicBlock *
B);
491 void identifyReductionNodes();
501class ComplexDeinterleaving {
503 ComplexDeinterleaving(
const TargetLowering *tl,
const TargetLibraryInfo *tli)
504 : TL(tl), TLI(tli) {}
508 bool evaluateBasicBlock(BasicBlock *
B,
unsigned Factor);
510 const TargetLowering *TL =
nullptr;
511 const TargetLibraryInfo *TLI =
nullptr;
516char ComplexDeinterleavingLegacyPass::ID = 0;
519 "Complex Deinterleaving",
false,
false)
525 const TargetLowering *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
536 return new ComplexDeinterleavingLegacyPass(TM);
539bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
541 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
542 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
545bool ComplexDeinterleaving::runOnFunction(Function &
F) {
548 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
554 dbgs() <<
"Complex deinterleaving has been disabled, target does "
555 "not support lowering of complex number operations.\n");
561 Changed |= evaluateBasicBlock(&
B, 2);
566 Changed |= evaluateBasicBlock(&
B, 4);
576 if ((Mask.size() & 1))
579 int HalfNumElements = Mask.size() / 2;
580 for (
int Idx = 0; Idx < HalfNumElements; ++Idx) {
581 int MaskIdx = Idx * 2;
582 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
591 int HalfNumElements = Mask.size() / 2;
593 for (
int Idx = 1; Idx < HalfNumElements; ++Idx) {
594 if (Mask[Idx] != (Idx * 2) +
Offset)
608 if (
I->getOpcode() == Instruction::FNeg)
609 return I->getOperand(0);
611 return I->getOperand(1);
614bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *
B,
unsigned Factor) {
615 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
616 if (Graph.collectPotentialReductions(
B))
617 Graph.identifyReductionNodes();
620 Graph.identifyNodes(&
I);
622 if (Graph.checkNodes()) {
623 Graph.replaceNodes();
630ComplexDeinterleavingGraph::CompositeNode *
631ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
632 Instruction *Real, Instruction *Imag,
633 std::pair<Value *, Value *> &PartialMatch) {
634 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
642 if ((Real->
getOpcode() != Instruction::FMul &&
643 Real->
getOpcode() != Instruction::Mul) ||
644 (Imag->
getOpcode() != Instruction::FMul &&
645 Imag->
getOpcode() != Instruction::Mul)) {
647 dbgs() <<
" - Real or imaginary instruction is not fmul or mul\n");
680 Value *CommonOperand;
681 Value *UncommonRealOp;
682 Value *UncommonImagOp;
684 if (R0 == I0 || R0 == I1) {
687 }
else if (R1 == I0 || R1 == I1) {
695 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
696 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
697 Rotation == ComplexDeinterleavingRotation::Rotation_270)
698 std::swap(UncommonRealOp, UncommonImagOp);
702 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
703 Rotation == ComplexDeinterleavingRotation::Rotation_180)
704 PartialMatch.first = CommonOperand;
706 PartialMatch.second = CommonOperand;
708 if (!PartialMatch.first || !PartialMatch.second) {
713 CompositeNode *CommonNode =
714 identifyNode(PartialMatch.first, PartialMatch.second);
720 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
726 CompositeNode *
Node = prepareCompositeNode(
727 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
728 Node->Rotation = Rotation;
729 Node->addOperand(CommonNode);
730 Node->addOperand(UncommonNode);
731 return submitCompositeNode(Node);
734ComplexDeinterleavingGraph::CompositeNode *
735ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
737 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
741 auto IsAdd = [](
unsigned Op) {
742 return Op == Instruction::FAdd ||
Op == Instruction::Add;
744 auto IsSub = [](
unsigned Op) {
745 return Op == Instruction::FSub ||
Op == Instruction::Sub;
749 Rotation = ComplexDeinterleavingRotation::Rotation_0;
751 Rotation = ComplexDeinterleavingRotation::Rotation_90;
753 Rotation = ComplexDeinterleavingRotation::Rotation_180;
755 Rotation = ComplexDeinterleavingRotation::Rotation_270;
764 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
787 Value *CommonOperand;
788 Value *UncommonRealOp;
789 Value *UncommonImagOp;
791 if (R0 == I0 || R0 == I1) {
794 }
else if (R1 == I0 || R1 == I1) {
802 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
803 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
804 Rotation == ComplexDeinterleavingRotation::Rotation_270)
805 std::swap(UncommonRealOp, UncommonImagOp);
807 std::pair<Value *, Value *> PartialMatch(
808 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
809 Rotation == ComplexDeinterleavingRotation::Rotation_180)
812 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
813 Rotation == ComplexDeinterleavingRotation::Rotation_270)
820 if (!CRInst || !CIInst) {
821 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
825 CompositeNode *CNode =
826 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
832 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
838 assert(PartialMatch.first && PartialMatch.second);
839 CompositeNode *CommonRes =
840 identifyNode(PartialMatch.first, PartialMatch.second);
846 CompositeNode *
Node = prepareCompositeNode(
847 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
848 Node->Rotation = Rotation;
849 Node->addOperand(CommonRes);
850 Node->addOperand(UncommonRes);
851 Node->addOperand(CNode);
852 return submitCompositeNode(Node);
855ComplexDeinterleavingGraph::CompositeNode *
856ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
857 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
861 if ((Real->
getOpcode() == Instruction::FSub &&
862 Imag->
getOpcode() == Instruction::FAdd) ||
863 (Real->
getOpcode() == Instruction::Sub &&
865 Rotation = ComplexDeinterleavingRotation::Rotation_90;
866 else if ((Real->
getOpcode() == Instruction::FAdd &&
867 Imag->
getOpcode() == Instruction::FSub) ||
868 (Real->
getOpcode() == Instruction::Add &&
870 Rotation = ComplexDeinterleavingRotation::Rotation_270;
872 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
881 if (!AR || !AI || !BR || !BI) {
886 CompositeNode *ResA = identifyNode(AR, AI);
888 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
891 CompositeNode *ResB = identifyNode(BR, BI);
893 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
897 CompositeNode *
Node =
898 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
899 Node->Rotation = Rotation;
900 Node->addOperand(ResA);
901 Node->addOperand(ResB);
902 return submitCompositeNode(Node);
906 unsigned OpcA =
A->getOpcode();
907 unsigned OpcB =
B->getOpcode();
909 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
910 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
911 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
912 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
923 switch (
I->getOpcode()) {
924 case Instruction::FAdd:
925 case Instruction::FSub:
926 case Instruction::FMul:
927 case Instruction::FNeg:
928 case Instruction::Add:
929 case Instruction::Sub:
930 case Instruction::Mul:
937ComplexDeinterleavingGraph::CompositeNode *
938ComplexDeinterleavingGraph::identifySymmetricOperation(
ComplexValues &Vals) {
940 unsigned FirstOpc = FirstReal->getOpcode();
941 for (
auto &V : Vals) {
958 for (
auto &V : Vals) {
964 CompositeNode *Op0 = identifyNode(OpVals);
965 CompositeNode *Op1 =
nullptr;
969 if (FirstReal->isBinaryOp()) {
971 for (
auto &V : Vals) {
976 Op1 = identifyNode(OpVals);
982 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
983 Node->Opcode = FirstReal->getOpcode();
985 Node->Flags = FirstReal->getFastMathFlags();
987 Node->addOperand(Op0);
988 if (FirstReal->isBinaryOp())
989 Node->addOperand(Op1);
991 return submitCompositeNode(Node);
994ComplexDeinterleavingGraph::CompositeNode *
995ComplexDeinterleavingGraph::identifyDotProduct(
Value *V) {
997 ComplexDeinterleavingOperation::CDot,
V->getType())) {
998 LLVM_DEBUG(
dbgs() <<
"Target doesn't support complex deinterleaving "
999 "operation CDot with the type "
1000 << *
V->getType() <<
"\n");
1008 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst,
nullptr);
1010 CompositeNode *ANode =
nullptr;
1012 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1014 Value *AReal =
nullptr;
1015 Value *AImag =
nullptr;
1016 Value *BReal =
nullptr;
1017 Value *BImag =
nullptr;
1022 return CI->getOperand(0);
1036 if (
match(Inst, PatternRot0)) {
1037 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1038 }
else if (
match(Inst, PatternRot270)) {
1039 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1050 if (!
match(Inst, PatternRot90Rot180))
1053 A0 = UnwrapCast(A0);
1054 A1 = UnwrapCast(A1);
1057 ANode = identifyNode(A0, A1);
1060 ANode = identifyNode(A1, A0);
1064 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1070 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1074 AReal = UnwrapCast(AReal);
1075 AImag = UnwrapCast(AImag);
1076 BReal = UnwrapCast(BReal);
1077 BImag = UnwrapCast(BImag);
1080 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1081 if (AReal->
getType() != ExpectedOperandTy)
1083 if (AImag->
getType() != ExpectedOperandTy)
1085 if (BReal->
getType() != ExpectedOperandTy)
1087 if (BImag->
getType() != ExpectedOperandTy)
1090 if (
Phi->getType() != VTy && RealUser->getType() != VTy)
1093 CompositeNode *
Node = identifyNode(AReal, AImag);
1098 if (ANode && Node != ANode) {
1101 <<
"Identified node is different from previously identified node. "
1102 "Unable to confidently generate a complex operation node\n");
1106 CN->addOperand(Node);
1107 CN->addOperand(identifyNode(BReal, BImag));
1108 CN->addOperand(identifyNode(Phi, RealUser));
1110 return submitCompositeNode(CN);
1113ComplexDeinterleavingGraph::CompositeNode *
1114ComplexDeinterleavingGraph::identifyPartialReduction(
Value *R,
Value *
I) {
1119 if (!
R->hasUseList() || !
I->hasUseList())
1123 findCommonBetweenCollections<Value *>(
R->users(),
I->users());
1128 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1131 if (CompositeNode *CN = identifyDotProduct(IInst))
1137ComplexDeinterleavingGraph::CompositeNode *
1138ComplexDeinterleavingGraph::identifyNode(
ComplexValues &Vals) {
1139 auto It = CachedResult.
find(Vals);
1140 if (It != CachedResult.
end()) {
1145 if (Vals.
size() == 1) {
1146 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1149 if (CompositeNode *CN = identifyPartialReduction(R,
I))
1151 bool IsReduction = RealPHI ==
R && (!ImagPHI || ImagPHI ==
I);
1152 if (!IsReduction &&
R->getType() !=
I->getType())
1156 if (CompositeNode *CN = identifySplat(Vals))
1159 for (
auto &V : Vals) {
1166 if (CompositeNode *CN = identifyDeinterleave(Vals))
1169 if (Vals.size() == 1) {
1170 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1173 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1176 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1180 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1183 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1185 ComplexDeinterleavingOperation::CAdd, NewVTy);
1188 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1193 if (CompositeNode *CN = identifyAdd(Real, Imag))
1197 if (HasCMulSupport && HasCAddSupport) {
1198 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1204 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1208 CachedResult[Vals] =
nullptr;
1212ComplexDeinterleavingGraph::CompositeNode *
1213ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1214 Instruction *Imag) {
1215 auto IsOperationSupported = [](
unsigned Opcode) ->
bool {
1216 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1217 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1218 Opcode == Instruction::Sub;
1221 if (!IsOperationSupported(Real->
getOpcode()) ||
1222 !IsOperationSupported(Imag->
getOpcode()))
1225 std::optional<FastMathFlags>
Flags;
1228 LLVM_DEBUG(
dbgs() <<
"The flags in Real and Imaginary instructions are "
1234 if (!
Flags->allowReassoc()) {
1237 <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
1246 AddendList &Addends) ->
bool {
1248 SmallPtrSet<Value *, 8> Visited;
1249 while (!Worklist.
empty()) {
1251 if (!Visited.
insert(V).second)
1256 Addends.emplace_back(V, IsPositive);
1266 if (
I != Insn &&
I->hasNUsesOrMore(2)) {
1267 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
1268 Addends.emplace_back(
I, IsPositive);
1271 switch (
I->getOpcode()) {
1272 case Instruction::FAdd:
1273 case Instruction::Add:
1277 case Instruction::FSub:
1281 case Instruction::Sub:
1289 case Instruction::FMul:
1290 case Instruction::Mul: {
1292 if (
isNeg(
I->getOperand(0))) {
1294 IsPositive = !IsPositive;
1296 A =
I->getOperand(0);
1299 if (
isNeg(
I->getOperand(1))) {
1301 IsPositive = !IsPositive;
1303 B =
I->getOperand(1);
1305 Muls.push_back(Product{
A,
B, IsPositive});
1308 case Instruction::FNeg:
1312 Addends.emplace_back(
I, IsPositive);
1316 if (Flags &&
I->getFastMathFlags() != *Flags) {
1318 "inconsistent with the root instructions' flags: "
1327 AddendList RealAddends, ImagAddends;
1328 if (!Collect(Real, RealMuls, RealAddends) ||
1329 !Collect(Imag, ImagMuls, ImagAddends))
1332 if (RealAddends.size() != ImagAddends.size())
1335 CompositeNode *FinalNode =
nullptr;
1336 if (!RealMuls.
empty() || !ImagMuls.
empty()) {
1339 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1340 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1346 if (!RealAddends.empty() || !ImagAddends.empty()) {
1347 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1351 assert(FinalNode &&
"FinalNode can not be nullptr here");
1352 assert(FinalNode->Vals.size() == 1);
1354 FinalNode->Vals[0].Real = Real;
1355 FinalNode->Vals[0].Imag = Imag;
1356 submitCompositeNode(FinalNode);
1360bool ComplexDeinterleavingGraph::collectPartialMuls(
1362 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1364 auto FindCommonInstruction = [](
const Product &Real,
1365 const Product &Imag) ->
Value * {
1366 if (Real.Multiplicand == Imag.Multiplicand ||
1367 Real.Multiplicand == Imag.Multiplier)
1368 return Real.Multiplicand;
1370 if (Real.Multiplier == Imag.Multiplicand ||
1371 Real.Multiplier == Imag.Multiplier)
1372 return Real.Multiplier;
1381 for (
unsigned i = 0; i < RealMuls.
size(); ++i) {
1382 bool FoundCommon =
false;
1383 for (
unsigned j = 0;
j < ImagMuls.
size(); ++
j) {
1384 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1388 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1389 : RealMuls[i].Multiplicand;
1390 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1391 : ImagMuls[
j].Multiplicand;
1393 auto Node = identifyNode(
A,
B);
1399 Node = identifyNode(
B,
A);
1411ComplexDeinterleavingGraph::CompositeNode *
1412ComplexDeinterleavingGraph::identifyMultiplications(
1413 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1415 if (RealMuls.
size() != ImagMuls.
size())
1419 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1423 DenseMap<Value *, CompositeNode *> CommonToNode;
1424 SmallVector<bool> Processed(
Info.size(),
false);
1425 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1429 PartialMulCandidate &InfoA =
Info[
I];
1430 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1434 PartialMulCandidate &InfoB =
Info[J];
1435 auto *InfoReal = &InfoA;
1436 auto *InfoImag = &InfoB;
1438 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1439 if (!NodeFromCommon) {
1441 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1443 if (!NodeFromCommon)
1446 CommonToNode[InfoReal->Common] = NodeFromCommon;
1447 CommonToNode[InfoImag->Common] = NodeFromCommon;
1448 Processed[
I] =
true;
1449 Processed[J] =
true;
1453 SmallVector<bool> ProcessedReal(RealMuls.
size(),
false);
1454 SmallVector<bool> ProcessedImag(ImagMuls.
size(),
false);
1456 for (
auto &PMI : Info) {
1457 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1460 auto It = CommonToNode.
find(PMI.Common);
1463 if (It == CommonToNode.
end()) {
1465 dbgs() <<
"Unprocessed independent partial multiplication:\n";
1466 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1468 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1473 auto &RealMul = RealMuls[PMI.RealIdx];
1474 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1476 auto NodeA = It->second;
1477 auto NodeB = PMI.Node;
1478 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1493 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1494 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1499 if (IsMultiplicandReal) {
1501 if (RealMul.IsPositive && ImagMul.IsPositive)
1503 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1510 if (!RealMul.IsPositive && ImagMul.IsPositive)
1512 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1519 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1520 dbgs().
indent(4) <<
"X: " << *NodeA->Vals[0].Real <<
"\n";
1521 dbgs().
indent(4) <<
"Y: " << *NodeA->Vals[0].Imag <<
"\n";
1522 dbgs().
indent(4) <<
"U: " << *NodeB->Vals[0].Real <<
"\n";
1523 dbgs().
indent(4) <<
"V: " << *NodeB->Vals[0].Imag <<
"\n";
1524 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1527 CompositeNode *NodeMul = prepareCompositeNode(
1528 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1529 NodeMul->Rotation = Rotation;
1530 NodeMul->addOperand(NodeA);
1531 NodeMul->addOperand(NodeB);
1533 NodeMul->addOperand(Result);
1534 submitCompositeNode(NodeMul);
1536 ProcessedReal[PMI.RealIdx] =
true;
1537 ProcessedImag[PMI.ImagIdx] =
true;
1541 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1542 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1547 dbgs() <<
"Unprocessed products (Real):\n";
1548 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1549 if (!ProcessedReal[i])
1550 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1551 << *RealMuls[i].Multiplier <<
" multiplied by "
1552 << *RealMuls[i].Multiplicand <<
"\n";
1554 dbgs() <<
"Unprocessed products (Imag):\n";
1555 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1556 if (!ProcessedImag[i])
1557 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1558 << *ImagMuls[i].Multiplier <<
" multiplied by "
1559 << *ImagMuls[i].Multiplicand <<
"\n";
1568ComplexDeinterleavingGraph::CompositeNode *
1569ComplexDeinterleavingGraph::identifyAdditions(
1570 AddendList &RealAddends, AddendList &ImagAddends,
1571 std::optional<FastMathFlags> Flags, CompositeNode *
Accumulator =
nullptr) {
1572 if (RealAddends.size() != ImagAddends.size())
1575 CompositeNode *
Result =
nullptr;
1581 Result = extractPositiveAddend(RealAddends, ImagAddends);
1586 while (!RealAddends.empty()) {
1587 auto ItR = RealAddends.begin();
1588 auto [
R, IsPositiveR] = *ItR;
1590 bool FoundImag =
false;
1591 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1592 auto [
I, IsPositiveI] = *ItI;
1594 if (IsPositiveR && IsPositiveI)
1595 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1596 else if (!IsPositiveR && IsPositiveI)
1597 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1598 else if (!IsPositiveR && !IsPositiveI)
1599 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1601 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1603 CompositeNode *AddNode =
nullptr;
1604 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1605 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1606 AddNode = identifyNode(R,
I);
1608 AddNode = identifyNode(
I, R);
1612 dbgs() <<
"Identified addition:\n";
1615 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1618 CompositeNode *TmpNode =
nullptr;
1620 TmpNode = prepareCompositeNode(
1621 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1623 TmpNode->Opcode = Instruction::FAdd;
1624 TmpNode->Flags = *
Flags;
1626 TmpNode->Opcode = Instruction::Add;
1628 }
else if (Rotation ==
1630 TmpNode = prepareCompositeNode(
1631 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1633 TmpNode->Opcode = Instruction::FSub;
1634 TmpNode->Flags = *
Flags;
1636 TmpNode->Opcode = Instruction::Sub;
1639 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1641 TmpNode->Rotation = Rotation;
1644 TmpNode->addOperand(Result);
1645 TmpNode->addOperand(AddNode);
1646 submitCompositeNode(TmpNode);
1648 RealAddends.erase(ItR);
1649 ImagAddends.erase(ItI);
1660ComplexDeinterleavingGraph::CompositeNode *
1661ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1662 AddendList &ImagAddends) {
1663 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1664 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1665 auto [
R, IsPositiveR] = *ItR;
1666 auto [
I, IsPositiveI] = *ItI;
1667 if (IsPositiveR && IsPositiveI) {
1668 auto Result = identifyNode(R,
I);
1670 RealAddends.erase(ItR);
1671 ImagAddends.erase(ItI);
1680bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1685 auto It = RootToNode.
find(RootI);
1686 if (It != RootToNode.
end()) {
1687 auto RootNode = It->second;
1688 assert(RootNode->Operation ==
1689 ComplexDeinterleavingOperation::ReductionOperation ||
1690 RootNode->Operation ==
1691 ComplexDeinterleavingOperation::ReductionSingle);
1692 assert(RootNode->Vals.size() == 1 &&
1693 "Cannot handle reductions involving multiple complex values");
1702 ReplacementAnchor =
R->comesBefore(
I) ?
I :
R;
1704 ReplacementAnchor =
R;
1706 if (ReplacementAnchor != RootI)
1712 auto RootNode = identifyRoot(RootI);
1719 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1720 <<
"::" <<
B->getName() <<
".\n";
1724 RootToNode[RootI] = RootNode;
1729bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *
B) {
1730 bool FoundPotentialReduction =
false;
1739 if (Br->getSuccessor(0) !=
B && Br->getSuccessor(1) !=
B)
1742 for (
auto &
PHI :
B->phis()) {
1743 if (
PHI.getNumIncomingValues() != 2)
1746 if (!
PHI.getType()->isVectorTy())
1756 for (
auto *U : ReductionOp->users()) {
1763 if (NumUsers != 2 || !FinalReduction || FinalReduction->
getParent() ==
B ||
1767 ReductionInfo[ReductionOp] = {&
PHI, FinalReduction};
1769 auto BackEdgeIdx =
PHI.getBasicBlockIndex(
B);
1770 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1771 Incoming =
PHI.getIncomingBlock(IncomingIdx);
1772 FoundPotentialReduction =
true;
1778 FinalInstructions.
insert(InitPHI);
1780 return FoundPotentialReduction;
1783void ComplexDeinterleavingGraph::identifyReductionNodes() {
1784 assert(Factor == 2 &&
"Cannot handle multiple complex values");
1786 SmallVector<bool> Processed(ReductionInfo.
size(),
false);
1788 for (
auto &
P : ReductionInfo)
1793 for (
size_t i = 0; i < OperationInstruction.
size(); ++i) {
1796 for (
size_t j = i + 1;
j < OperationInstruction.
size(); ++
j) {
1799 auto *Real = OperationInstruction[i];
1800 auto *Imag = OperationInstruction[
j];
1801 if (Real->getType() != Imag->
getType())
1804 RealPHI = ReductionInfo[Real].first;
1805 ImagPHI = ReductionInfo[Imag].first;
1807 auto Node = identifyNode(Real, Imag);
1811 Node = identifyNode(Real, Imag);
1817 if (Node && PHIsFound) {
1818 LLVM_DEBUG(
dbgs() <<
"Identified reduction starting from instructions: "
1819 << *Real <<
" / " << *Imag <<
"\n");
1820 Processed[i] =
true;
1821 Processed[
j] =
true;
1822 auto RootNode = prepareCompositeNode(
1823 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1824 RootNode->addOperand(Node);
1825 RootToNode[Real] = RootNode;
1826 RootToNode[Imag] = RootNode;
1827 submitCompositeNode(RootNode);
1832 auto *Real = OperationInstruction[i];
1835 if (Processed[i] || Real->getNumOperands() < 2)
1839 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1842 RealPHI = ReductionInfo[Real].first;
1845 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1846 if (Node && PHIsFound) {
1848 dbgs() <<
"Identified single reduction starting from instruction: "
1849 << *Real <<
"/" << *ReductionInfo[Real].second <<
"\n");
1858 if (ReductionInfo[Real].second->getType()->isVectorTy())
1861 Processed[i] =
true;
1862 auto RootNode = prepareCompositeNode(
1863 ComplexDeinterleavingOperation::ReductionSingle, Real,
nullptr);
1864 RootNode->addOperand(Node);
1865 RootToNode[Real] = RootNode;
1866 submitCompositeNode(RootNode);
1874bool ComplexDeinterleavingGraph::checkNodes() {
1875 bool FoundDeinterleaveNode =
false;
1876 for (CompositeNode *
N : CompositeNodes) {
1877 if (!
N->areOperandsValid())
1880 if (
N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1881 FoundDeinterleaveNode =
true;
1886 if (!FoundDeinterleaveNode) {
1888 dbgs() <<
"Couldn't find a deinterleave node within the graph, cannot "
1889 "guarantee safety during graph transformation.\n");
1894 SmallPtrSet<Instruction *, 16> AllInstructions;
1895 SmallVector<Instruction *, 8> Worklist;
1896 for (
auto &Pair : RootToNode)
1901 while (!Worklist.
empty()) {
1904 if (!AllInstructions.
insert(
I).second)
1909 if (!FinalInstructions.
count(
I))
1916 for (
auto *
I : AllInstructions) {
1918 if (RootToNode.count(
I))
1921 for (User *U :
I->users()) {
1933 SmallPtrSet<Instruction *, 16> Visited;
1934 while (!Worklist.
empty()) {
1936 if (!Visited.
insert(
I).second)
1941 if (RootToNode.count(
I)) {
1943 <<
" could be deinterleaved but its chain of complex "
1944 "operations have an outside user\n");
1945 RootToNode.erase(
I);
1948 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1951 for (User *U :
I->users())
1959 return !RootToNode.
empty();
1962ComplexDeinterleavingGraph::CompositeNode *
1963ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1970 for (
unsigned I = 0;
I < Factor;
I += 2) {
1978 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
2006 return identifyNode(Real, Imag);
2009ComplexDeinterleavingGraph::CompositeNode *
2010ComplexDeinterleavingGraph::identifyDeinterleave(
ComplexValues &Vals) {
2014 auto CheckExtract = [&](
Value *
V,
unsigned ExpectedIdx,
2015 Instruction *ExpectedInsn) -> ExtractValueInst * {
2017 if (!EVI || EVI->getNumIndices() != 1 ||
2018 EVI->getIndices()[0] != ExpectedIdx ||
2020 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2025 for (
unsigned Idx = 0; Idx < Vals.
size(); Idx++) {
2026 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2,
II);
2027 if (RealEVI && Idx == 0)
2029 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1,
II)) {
2036 if (IntrinsicII->getIntrinsicID() !=
2041 CompositeNode *PlaceholderNode = prepareCompositeNode(
2043 PlaceholderNode->ReplacementNode =
II->getOperand(0);
2044 for (
auto &V : Vals) {
2048 return submitCompositeNode(PlaceholderNode);
2051 if (Vals.size() != 1)
2054 Value *Real = Vals[0].Real;
2055 Value *Imag = Vals[0].Imag;
2058 if (!RealShuffle || !ImagShuffle) {
2059 if (RealShuffle || ImagShuffle)
2060 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
2064 Value *RealOp1 = RealShuffle->getOperand(1);
2069 Value *ImagOp1 = ImagShuffle->getOperand(1);
2075 Value *RealOp0 = RealShuffle->getOperand(0);
2076 Value *ImagOp0 = ImagShuffle->getOperand(0);
2078 if (RealOp0 != ImagOp0) {
2083 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2084 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2090 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2091 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
2097 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2098 Value *
Op = Shuffle->getOperand(0);
2102 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2104 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2110 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) ->
bool {
2114 ArrayRef<int>
Mask = Shuffle->getShuffleMask();
2117 Value *
Op = Shuffle->getOperand(0);
2119 int NumElements = OpTy->getNumElements();
2123 return Last < NumElements;
2126 if (RealShuffle->getType() != ImagShuffle->getType()) {
2130 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2134 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2139 CompositeNode *PlaceholderNode =
2141 RealShuffle, ImagShuffle);
2142 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2143 FinalInstructions.
insert(RealShuffle);
2144 FinalInstructions.
insert(ImagShuffle);
2145 return submitCompositeNode(PlaceholderNode);
2148ComplexDeinterleavingGraph::CompositeNode *
2149ComplexDeinterleavingGraph::identifySplat(
ComplexValues &Vals) {
2150 auto IsSplat = [](
Value *
V) ->
bool {
2163 if (
Const->getOpcode() != Instruction::ShuffleVector)
2168 VTy = Shuf->getType();
2169 Mask = Shuf->getShuffleMask();
2177 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2187 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2188 for (
auto &V : Vals) {
2189 if (!IsSplat(
V.Real) || !IsSplat(
V.Imag))
2194 if (!Real || !Imag || Real->getParent() != FirstBB ||
2195 Imag->getParent() != FirstBB)
2199 for (
auto &V : Vals) {
2206 for (
auto &V : Vals) {
2210 FinalInstructions.
insert(Real);
2211 FinalInstructions.
insert(Imag);
2214 CompositeNode *PlaceholderNode =
2215 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2216 return submitCompositeNode(PlaceholderNode);
2219ComplexDeinterleavingGraph::CompositeNode *
2220ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2221 Instruction *Imag) {
2222 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2226 CompositeNode *PlaceholderNode = prepareCompositeNode(
2227 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2228 return submitCompositeNode(PlaceholderNode);
2231ComplexDeinterleavingGraph::CompositeNode *
2232ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2233 Instruction *Imag) {
2236 if (!SelectReal || !SelectImag)
2253 auto NodeA = identifyNode(AR, AI);
2257 auto NodeB = identifyNode(
RA, BI);
2261 CompositeNode *PlaceholderNode = prepareCompositeNode(
2262 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2263 PlaceholderNode->addOperand(NodeA);
2264 PlaceholderNode->addOperand(NodeB);
2265 FinalInstructions.
insert(MaskA);
2266 FinalInstructions.
insert(MaskB);
2267 return submitCompositeNode(PlaceholderNode);
2271 std::optional<FastMathFlags> Flags,
2275 case Instruction::FNeg:
2276 I =
B.CreateFNeg(InputA);
2278 case Instruction::FAdd:
2279 I =
B.CreateFAdd(InputA, InputB);
2281 case Instruction::Add:
2282 I =
B.CreateAdd(InputA, InputB);
2284 case Instruction::FSub:
2285 I =
B.CreateFSub(InputA, InputB);
2287 case Instruction::Sub:
2288 I =
B.CreateSub(InputA, InputB);
2290 case Instruction::FMul:
2291 I =
B.CreateFMul(InputA, InputB);
2293 case Instruction::Mul:
2294 I =
B.CreateMul(InputA, InputB);
2304Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2305 CompositeNode *Node) {
2306 if (
Node->ReplacementNode)
2307 return Node->ReplacementNode;
2309 auto ReplaceOperandIfExist = [&](CompositeNode *
Node,
2310 unsigned Idx) ->
Value * {
2311 return Node->Operands.size() > Idx
2312 ? replaceNode(Builder,
Node->Operands[Idx])
2316 Value *ReplacementNode =
nullptr;
2317 switch (
Node->Operation) {
2318 case ComplexDeinterleavingOperation::CDot: {
2319 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2320 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2323 "Node inputs need to be of the same type"));
2328 case ComplexDeinterleavingOperation::CAdd:
2329 case ComplexDeinterleavingOperation::CMulPartial:
2330 case ComplexDeinterleavingOperation::Symmetric: {
2331 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2332 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2335 "Node inputs need to be of the same type"));
2338 "Accumulator and input need to be of the same type"));
2339 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2344 Builder,
Node->Operation,
Node->Rotation, Input0, Input1,
2348 case ComplexDeinterleavingOperation::Deinterleave:
2351 case ComplexDeinterleavingOperation::Splat: {
2353 for (
auto &V :
Node->Vals) {
2354 Ops.push_back(
V.Real);
2355 Ops.push_back(
V.Imag);
2362 for (
auto V :
Node->Vals) {
2370 ReplacementNode = IRB.CreateVectorInterleave(
Ops);
2376 case ComplexDeinterleavingOperation::ReductionPHI: {
2381 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2383 OldToNewPHI[OldPHI] = NewPHI;
2384 ReplacementNode = NewPHI;
2387 case ComplexDeinterleavingOperation::ReductionSingle:
2388 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2389 processReductionSingle(ReplacementNode, Node);
2391 case ComplexDeinterleavingOperation::ReductionOperation:
2392 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2393 processReductionOperation(ReplacementNode, Node);
2395 case ComplexDeinterleavingOperation::ReductionSelect: {
2398 auto *
A = replaceNode(Builder,
Node->Operands[0]);
2399 auto *
B = replaceNode(Builder,
Node->Operands[1]);
2406 assert(ReplacementNode &&
"Target failed to create Intrinsic call.");
2407 NumComplexTransformations += 1;
2408 Node->ReplacementNode = ReplacementNode;
2409 return ReplacementNode;
2412void ComplexDeinterleavingGraph::processReductionSingle(
2413 Value *OperationReplacement, CompositeNode *Node) {
2415 auto *OldPHI = ReductionInfo[Real].first;
2416 auto *NewPHI = OldToNewPHI[OldPHI];
2418 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2420 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2424 Value *NewInit =
nullptr;
2426 if (
C->isNullValue())
2434 NewPHI->addIncoming(NewInit, Incoming);
2435 NewPHI->addIncoming(OperationReplacement, BackEdge);
2437 auto *FinalReduction = ReductionInfo[Real].second;
2444void ComplexDeinterleavingGraph::processReductionOperation(
2445 Value *OperationReplacement, CompositeNode *Node) {
2448 auto *OldPHIReal = ReductionInfo[Real].first;
2449 auto *OldPHIImag = ReductionInfo[Imag].first;
2450 auto *NewPHI = OldToNewPHI[OldPHIReal];
2453 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2454 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2459 NewPHI->addIncoming(NewInit, Incoming);
2460 NewPHI->addIncoming(OperationReplacement, BackEdge);
2464 auto *FinalReductionReal = ReductionInfo[Real].second;
2465 auto *FinalReductionImag = ReductionInfo[Imag].second;
2468 BasicBlock *ExitBB = Br->getSuccessor(Br->getSuccessor(0) == BackEdge);
2472 OperationReplacement->
getType(),
2473 OperationReplacement);
2476 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2480 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2483void ComplexDeinterleavingGraph::replaceNodes() {
2484 SmallVector<Instruction *, 16> DeadInstrRoots;
2485 for (
auto *RootInstruction : OrderedRoots) {
2488 if (!RootToNode.count(RootInstruction))
2492 auto RootNode = RootToNode[RootInstruction];
2493 Value *
R = replaceNode(Builder, RootNode);
2495 if (RootNode->Operation ==
2496 ComplexDeinterleavingOperation::ReductionOperation) {
2499 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2500 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2503 }
else if (RootNode->Operation ==
2504 ComplexDeinterleavingOperation::ReductionSingle) {
2506 auto &
Info = ReductionInfo[RootInst];
2507 Info.first->removeIncomingValue(BackEdge);
2510 assert(R &&
"Unable to find replacement for RootInstruction");
2511 DeadInstrRoots.
push_back(RootInstruction);
2512 RootInstruction->replaceAllUsesWith(R);
2516 for (
auto *
I : DeadInstrRoots)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
This file implements a map that provides insertion order iteration.
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
Get the array size.
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={})
Create a call to intrinsic ID with Args, mangled using OverloadTypes.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
An opaque object representing a hash code.
const ParentTy * getParent() const
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
@ BR
Control flow instructions. These all have token chains.
@ BasicBlock
Various leaf nodes.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock.
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
ComplexDeinterleavingOperation
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
ComplexDeinterleavingRotation
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
ComplexDeinterleavingPass(const TargetMachine &TM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...