LLVM 23.0.0git
AMDGPUIGroupLP.cpp
Go to the documentation of this file.
1//===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file This file defines a set of schedule DAG mutations that can be used to
10// override default scheduler behavior to enforce specific scheduling patterns.
11// They should be used in cases where runtime performance considerations such as
12// inter-wavefront interactions, mean that compile-time heuristics cannot
13// predict the optimal instruction ordering, or in kernels where optimum
14// instruction scheduling is important enough to warrant manual intervention.
15//
16//===----------------------------------------------------------------------===//
17
18#include "AMDGPUIGroupLP.h"
20#include "SIInstrInfo.h"
23#include "llvm/ADT/DenseMap.h"
26
27#include <type_traits>
28
29using namespace llvm;
30using namespace llvm::AMDGPU;
31
32#define DEBUG_TYPE "igrouplp"
33
34namespace {
35
36static cl::opt<bool> EnableExactSolver(
37 "amdgpu-igrouplp-exact-solver", cl::Hidden,
38 cl::desc("Whether to use the exponential time solver to fit "
39 "the instructions to the pipeline as closely as "
40 "possible."),
41 cl::init(false));
42
43static cl::opt<unsigned> CutoffForExact(
44 "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
45 cl::desc("The maximum number of scheduling group conflicts "
46 "which we attempt to solve with the exponential time "
47 "exact solver. Problem sizes greater than this will"
48 "be solved by the less accurate greedy algorithm. Selecting "
49 "solver by size is superseded by manually selecting "
50 "the solver (e.g. by amdgpu-igrouplp-exact-solver"));
51
52static cl::opt<uint64_t> MaxBranchesExplored(
53 "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
54 cl::desc("The amount of branches that we are willing to explore with"
55 "the exact algorithm before giving up."));
56
57static cl::opt<bool> UseCostHeur(
58 "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
59 cl::desc("Whether to use the cost heuristic to make choices as we "
60 "traverse the search space using the exact solver. Defaulted "
61 "to on, and if turned off, we will use the node order -- "
62 "attempting to put the later nodes in the later sched groups. "
63 "Experimentally, results are mixed, so this should be set on a "
64 "case-by-case basis."));
65
66// Components of the mask that determines which instruction types may be
67// classified into a SchedGroup.
// Every named category occupies its own bit (1u << k), NONE is the empty mask,
// and ALL is the bitwise OR of every category listed here.
68enum class SchedGroupMask {
69 NONE = 0u,
70 ALU = 1u << 0,
71 VALU = 1u << 1,
72 SALU = 1u << 2,
73 MFMA = 1u << 3,
74 VMEM = 1u << 4,
75 VMEM_READ = 1u << 5,
76 VMEM_WRITE = 1u << 6,
77 DS = 1u << 7,
78 DS_READ = 1u << 8,
79 DS_WRITE = 1u << 9,
80 TRANS = 1u << 10,
81 LDSDMA = 1u << 11,
82 ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
83 DS_READ | DS_WRITE | TRANS | LDSDMA,
84 LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
85};
86
87class SchedGroup;
88
89// InstructionRule class is used to enact a filter which determines whether or
90// not an SU maps to a given SchedGroup. It contains complementary data
91// structures (e.g Cache) to help those filters.
92class InstructionRule {
93protected:
94 const SIInstrInfo *TII;
95 unsigned SGID;
96 // A cache made available to the Filter to store SUnits for subsequent
97 // invocations of the Filter
98 std::optional<SmallVector<SUnit *, 4>> Cache;
99
100public:
101 virtual bool
102 apply(const SUnit *, const ArrayRef<SUnit *>,
104 return true;
105 };
106
107 InstructionRule(const SIInstrInfo *TII, unsigned SGID,
108 bool NeedsCache = false)
109 : TII(TII), SGID(SGID) {
110 if (NeedsCache) {
111 Cache = SmallVector<SUnit *, 4>();
112 }
113 }
114
115 virtual ~InstructionRule() = default;
116};
117
118using SUnitsToCandidateSGsMap = DenseMap<SUnit *, SmallVector<int, 4>>;
119
120// Classify instructions into groups to enable fine tuned control over the
121// scheduler. These groups may be more specific than current SchedModel
122// instruction classes.
123class SchedGroup {
124private:
125 // Mask that defines which instruction types can be classified into this
126 // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
127 // and SCHED_GROUP_BARRIER.
128 SchedGroupMask SGMask;
129
130 // Maximum number of SUnits that can be added to this group.
131 std::optional<unsigned> MaxSize;
132
133 // SchedGroups will only synchronize with other SchedGroups that have the same
134 // SyncID.
135 int SyncID = 0;
136
137 // SGID is used to map instructions to candidate SchedGroups
138 unsigned SGID;
139
140 // The different rules each instruction in this SchedGroup must conform to
142
143 // Count of the number of created SchedGroups, used to initialize SGID.
144 static unsigned NumSchedGroups;
145
146 // Use SGMask to determine whether we can classify MI as a member of this
147 // SchedGroup object.
148 bool canAddMI(const MachineInstr &MI) const;
149
150public:
151 // Collection of SUnits that are classified as members of this group.
152 SmallVector<SUnit *, 32> Collection;
153
155 const SIInstrInfo *TII;
156
157 // Try to add an edge from SU A to SU B.
158 bool tryAddEdge(SUnit *A, SUnit *B);
159
160 // Returns true if SU can be added to this SchedGroup.
161 bool canAddSU(SUnit &SU) const;
162
163 // Add DAG dependencies from all SUnits in this SchedGroup and this SU. If
164 // MakePred is true, SU will be a predecessor of the SUnits in this
165 // SchedGroup, otherwise SU will be a successor.
166 void link(SUnit &SU, bool MakePred = false);
167
168 // Add DAG dependencies and track which edges are added, and the count of
169 // missed edges
170 int link(SUnit &SU, bool MakePred,
171 std::list<std::pair<SUnit *, SUnit *>> &AddedEdges);
172
173 // Add DAG dependencies from all SUnits in this SchedGroup and this SU.
174 // Use the predicate to determine whether SU should be a predecessor (P =
175 // true) or a successor (P = false) of this SchedGroup.
176 void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
177
178 // Add DAG dependencies such that SUnits in this group shall be ordered
179 // before SUnits in OtherGroup.
180 void link(SchedGroup &OtherGroup);
181
182 // Returns true if no more instructions may be added to this group.
183 bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
184
185 // Append a constraint that SUs must meet in order to fit into this
186 // SchedGroup. Since many rules involve the relationship between a SchedGroup
187 // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
188 // time (rather than SchedGroup init time.)
189 void addRule(std::shared_ptr<InstructionRule> NewRule) {
190 Rules.push_back(NewRule);
191 }
192
193 // Returns true if the SU matches all rules
194 bool allowedByRules(const SUnit *SU,
195 SmallVectorImpl<SchedGroup> &SyncPipe) const {
196 for (auto &Rule : Rules) {
197 if (!Rule->apply(SU, Collection, SyncPipe))
198 return false;
199 }
200 return true;
201 }
202
203 // Add SU to the SchedGroup.
204 void add(SUnit &SU) {
205 LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
206 << format_hex((int)SGMask, 10, true) << " adding "
207 << *SU.getInstr());
208 Collection.push_back(&SU);
209 }
210
211 // Remove last element in the SchedGroup
212 void pop() { Collection.pop_back(); }
213
214 template <class T>
215 void findCandidateSUnits(T Begin, T End,
216 SUnitsToCandidateSGsMap &SyncedInstrs);
217
218 /// Find each SUnit in the DAG that could potentially be added to
219 /// this SchedGroup and add the SGID to the candidate SchedGroups
220 /// for SU in \p SyncedInstrs.
221 void findCandidateSUnits(SUnitsToCandidateSGsMap &SyncedInstrs);
222
223 int getSyncID() { return SyncID; }
224
225 int getSGID() { return SGID; }
226
227 SchedGroupMask getMask() { return SGMask; }
228
229 SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
230 ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
231 : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
232 SGID = NumSchedGroups++;
233 }
234
235 SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
236 ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
237 : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
238 SGID = NumSchedGroups++;
239 }
240};
241
242using SUToCandSGsPair = std::pair<SUnit *, SmallVector<int, 4>>;
243using SUsToCandSGsVec = SmallVector<SUToCandSGsPair, 4>;
244
245// The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
246// in non-trivial cases. For example, if the requested pipeline is
247// {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
248// in the DAG, then we will have an instruction that can not be trivially
249// assigned to a SchedGroup. The PipelineSolver class implements two algorithms
250// to find a good solution to the pipeline -- a greedy algorithm and an exact
251// algorithm. The exact algorithm has an exponential time complexity and should
252// only be used for small sized problems or medium sized problems where an exact
253// solution is highly desired.
254class PipelineSolver {
255 [[maybe_unused]] ScheduleDAGMI *DAG;
256
257 // Instructions that can be assigned to multiple SchedGroups
259 SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
261 // The current working pipeline
263 // The pipeline that has the best solution found so far
265
266 // Whether or not we actually have any SyncedInstrs to try to solve.
267 bool NeedsSolver = false;
268
269 // Compute an estimate of the size of search tree -- the true size is
270 // the product of each conflictedInst.Matches.size() across all SyncPipelines
271 unsigned computeProblemSize();
272
273 // The cost penalty of not assigning a SU to a SchedGroup
274 int MissPenalty = 0;
275
276 // Costs in terms of the number of edges we are unable to add
277 int BestCost = -1;
278 int CurrCost = 0;
279
280 // Index pointing to the conflicting instruction that is currently being
281 // fitted
282 int CurrConflInstNo = 0;
283 // Index to the pipeline that is currently being fitted
284 int CurrSyncGroupIdx = 0;
285 // The first non trivial pipeline
286 int BeginSyncGroupIdx = 0;
287
288 // How many branches we have explored
289 uint64_t BranchesExplored = 0;
290
291 // The direction in which we process the candidate SchedGroups per SU
292 bool IsBottomUp = true;
293
294 // Update indices to fit next conflicting instruction
295 void advancePosition();
296 // Recede indices to attempt to find better fit for previous conflicting
297 // instruction
298 void retreatPosition();
299
300 // The exponential time algorithm which finds the provably best fit
301 bool solveExact();
302 // The polynomial time algorithm which attempts to find a good fit
303 bool solveGreedy();
304 // Find the best SchedGroup for the current SU using the heuristic given all
305 // current information. One step in the greedy algorithm. Templated against
306 // the SchedGroup iterator (either reverse or forward).
307 template <typename T>
308 void greedyFind(std::list<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
309 // Whether or not the current solution is optimal
310 bool checkOptimal();
311 // Populate the ready list, prioritizing fewest missed edges first
312 // Templated against the SchedGroup iterator (either reverse or forward).
313 template <typename T>
314 void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
315 T E);
316 // Add edges corresponding to the SchedGroups as assigned by solver
317 void makePipeline();
318 // Link the SchedGroups in the best found pipeline.
319 // Templated against the SchedGroup iterator (either reverse or forward).
320 template <typename T> void linkSchedGroups(T I, T E);
321 // Add the edges from the SU to the other SchedGroups in pipeline, and
322 // return the number of edges missed.
323 int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
324 std::list<std::pair<SUnit *, SUnit *>> &AddedEdges);
325
326 /// This class is used to build the edge set implied by an
327 /// assignment of an SUnit to a SchedGroup and to compute the cost
328 /// (edges that cannot be assigned without introducing cycles) of
329 /// the assignment.
330 class EdgeSetBuilder {
331 SUnit *SU;
332 SmallVectorImpl<SchedGroup> &SyncPipeline;
333 bool IsBottomUp;
334 DenseSet<SUnit *> InitialPreds;
335 DenseSet<SUnit *> Succs;
336 bool Initialized = false;
337
338 /// Compute reachability via DFS. If ComputePreds is true, follows
339 /// predecessor edges; otherwise follows successor edges.
340 template <bool ComputePreds>
341 static void computeReachable(DenseSet<SUnit *> &Reachable, SUnit *Start);
342
343 /// Compute all nodes that can reach Start via predecessor edges, including
344 /// Start itself.
345 static void computePreds(DenseSet<SUnit *> &Preds, SUnit *Start);
346
347 /// Compute all nodes reachable from Start via successor edges, including
348 /// Start itself.
349 static void computeSuccs(DenseSet<SUnit *> &Succs, SUnit *Start);
350
351 public:
352 EdgeSetBuilder(SUnit *SU, SmallVectorImpl<SchedGroup> &SyncPipeline,
353 bool IsBottomUp)
354 : SU(SU), SyncPipeline(SyncPipeline), IsBottomUp(IsBottomUp) {}
355
356 /// Determine the edges implied by assigning SU to the SchedGroup
357 /// with ID SGID. Edges are added to NewEdges unless they
358 /// introduce cycles. Return the number of edges that cannot be
359 /// added.
360 int build(int SGID, std::list<std::pair<SUnit *, SUnit *>> &NewEdges);
361
362 private:
363 template <typename T>
364 int buildImpl(int SGID, const iterator_range<T> SchedGroups,
365 std::list<std::pair<SUnit *, SUnit *>> &NewEdges);
366 };
367
368 /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
369 /// returns the cost (in terms of missed pipeline edges), and tracks the edges
370 /// added in \p AddedEdges
371 template <typename T>
372 int linkSUnit(SUnit *SU, int SGID,
373 std::list<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
374 /// Remove the edges passed via \p AddedEdges
375 void removeEdges(const std::list<std::pair<SUnit *, SUnit *>> &AddedEdges);
376 // Convert the passed in maps to arrays for bidirectional iterators
377 void convertSyncMapsToArrays();
378
379 void reset();
380
381public:
382 // Invoke the solver to map instructions to instruction groups. Heuristic &&
383 // command-line-option determines to use exact or greedy algorithm.
384 void solve();
385
386 PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
388 ScheduleDAGMI *DAG, bool IsBottomUp = true)
389 : DAG(DAG), SyncedInstrs(SyncedInstrs),
390 SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
391
392 for (auto &PipelineInstrs : SyncedInstrs) {
393 if (!PipelineInstrs.second.empty()) {
394 NeedsSolver = true;
395 break;
396 }
397 }
398
399 if (!NeedsSolver)
400 return;
401
402 convertSyncMapsToArrays();
403
404 CurrPipeline = BestPipeline;
405
406 while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
407 PipelineInstrs[BeginSyncGroupIdx].empty())
408 ++BeginSyncGroupIdx;
409
410 if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
411 return;
412 }
413};
414
415void PipelineSolver::reset() {
416
417 for (auto &SyncPipeline : CurrPipeline) {
418 for (auto &SG : SyncPipeline) {
419 SmallVector<SUnit *, 32> TempCollection = SG.Collection;
420 SG.Collection.clear();
421 auto *SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
422 return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
423 });
424 if (SchedBarr != TempCollection.end())
425 SG.Collection.push_back(*SchedBarr);
426 }
427 }
428
429 CurrSyncGroupIdx = BeginSyncGroupIdx;
430 CurrConflInstNo = 0;
431 CurrCost = 0;
432}
433
434void PipelineSolver::convertSyncMapsToArrays() {
435 for (auto &SyncPipe : SyncedSchedGroups) {
436 BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
437 }
438
439 int PipelineIDx = SyncedInstrs.size() - 1;
440 PipelineInstrs.resize(SyncedInstrs.size());
441 for (auto &SyncInstrMap : SyncedInstrs) {
442 for (auto &SUsToCandSGs : SyncInstrMap.second) {
443 if (PipelineInstrs[PipelineIDx].empty()) {
444 PipelineInstrs[PipelineIDx].push_back(
445 std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
446 continue;
447 }
448 auto *SortPosition = PipelineInstrs[PipelineIDx].begin();
449 // Insert them in sorted order -- this allows for good parsing order in
450 // the greedy algorithm
451 while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
452 SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
453 ++SortPosition;
454 PipelineInstrs[PipelineIDx].insert(
455 SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
456 }
457 --PipelineIDx;
458 }
459}
460
461template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
462 for (; I != E; ++I) {
463 auto &GroupA = *I;
464 for (auto J = std::next(I); J != E; ++J) {
465 auto &GroupB = *J;
466 GroupA.link(GroupB);
467 }
468 }
469}
470
471void PipelineSolver::makePipeline() {
472 // Preserve the order of barrier for subsequent SchedGroupBarrier mutations
473 for (auto &SyncPipeline : BestPipeline) {
474 LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
475 for (auto &SG : SyncPipeline) {
476 LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
477 << " has: \n");
478 SUnit *SGBarr = nullptr;
479 for (auto &SU : SG.Collection) {
480 if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
481 SGBarr = SU;
482 LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
483 }
484 // Command line requested IGroupLP doesn't have SGBarr
485 if (!SGBarr)
486 continue;
487 SG.link(*SGBarr, false);
488 }
489 }
490
491 for (auto &SyncPipeline : BestPipeline) {
492 IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
493 : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
494 }
495}
496
497template <typename T>
498int PipelineSolver::linkSUnit(
499 SUnit *SU, int SGID, std::list<std::pair<SUnit *, SUnit *>> &AddedEdges,
500 T I, T E) {
501 bool MakePred = false;
502 int AddedCost = 0;
503 for (; I < E; ++I) {
504 if (I->getSGID() == SGID) {
505 MakePred = true;
506 continue;
507 }
508 auto Group = *I;
509 AddedCost += Group.link(*SU, MakePred, AddedEdges);
510 assert(AddedCost >= 0);
511 }
512 return AddedCost;
513}
514
515template <bool ComputePreds>
516void PipelineSolver::EdgeSetBuilder::computeReachable(
517 DenseSet<SUnit *> &Reachable, SUnit *Start) {
518 if (!Reachable.insert(Start).second)
519 return;
520
521 SmallVector<SUnit *, 32> WorkList = {Start};
522
523 while (!WorkList.empty()) {
524 SUnit *Current = WorkList.pop_back_val();
525
526 for (const SDep &Dep : ComputePreds ? Current->Preds : Current->Succs) {
527 if (Reachable.insert(Dep.getSUnit()).second)
528 WorkList.push_back(Dep.getSUnit());
529 }
530 }
531}
532
533void PipelineSolver::EdgeSetBuilder::computePreds(DenseSet<SUnit *> &Preds,
534 SUnit *Start) {
535 computeReachable</*ComputePreds*/ true>(Preds, Start);
536}
537
538void PipelineSolver::EdgeSetBuilder::computeSuccs(DenseSet<SUnit *> &Succs,
539 SUnit *Start) {
540 computeReachable</*ComputePreds*/ false>(Succs, Start);
541}
542
543int PipelineSolver::EdgeSetBuilder::build(
544 int SGID, std::list<std::pair<SUnit *, SUnit *>> &NewEdges) {
545 if (!Initialized) {
546 computePreds(InitialPreds, SU);
547 computeSuccs(Succs, SU);
548 Initialized = true;
549 }
550
551 // See comment in addEdges concerning the iterator direction.
552 return IsBottomUp ? buildImpl(SGID, reverse(SyncPipeline), NewEdges)
553 : buildImpl(SGID,
554 llvm::make_range(SyncPipeline.begin(),
555 SyncPipeline.end()),
556 NewEdges);
557}
558
559template <typename T>
560int PipelineSolver::EdgeSetBuilder::buildImpl(
561 int SGID, iterator_range<T> SchedGroups,
562 std::list<std::pair<SUnit *, SUnit *>> &NewEdges) {
563
564 // Determine the edges that will be added to the DAG if SU is
565 // assigned to the SchedGroup SG with the given SGID. It might be
566 // impossible to add some edges because they would introduce
567 // cycles. The number of such edges is counted and returned, all
568 // other edges are added to NewEdges.
569 //
570 // SU is made a successor of SUnits in SchedGroups before SG, and a
571 // predecessor of SUnits after SG. In each case, the cycle check
572 // requires reachability information for the opposing direction.
573
574 // Nodes U that can reach SU (U ~> SU).
575 // Will be extended as new edges are added and hence cannot be
576 // shared between calls to this function, in contrast to Succs.
577 DenseSet<SUnit *> Preds = InitialPreds;
578
579 int MissedEdges = 0;
580 bool MakePred = false;
581 for (SchedGroup &SG : SchedGroups) {
582 if (SG.getSGID() == SGID) {
583 MakePred = true;
584 continue;
585 }
586
587 for (SUnit *A : SG.Collection) {
588 if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
589 continue;
590
591 if (MakePred) {
592 // Try add SU -> A.
593 if (Preds.contains(A)) { // Would add cycle since A ~> SU.
594 ++MissedEdges;
595 continue;
596 }
597 // Succs does not need to be updated, since it will not be
598 // queried after entering the MakePred case.
599 NewEdges.emplace_back(SU, A);
600 continue;
601 }
602
603 // Try add A -> SU.
604 if (Succs.contains(A)) { // Would add cycle since SU ~> A.
605 ++MissedEdges;
606 continue;
607 }
608 NewEdges.emplace_back(A, SU);
609 computePreds(Preds, A);
610 }
611 }
612
613 return MissedEdges;
614}
615
616int PipelineSolver::addEdges(
617 SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
618 std::list<std::pair<SUnit *, SUnit *>> &AddedEdges) {
619
620 // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
621 // instructions that are the ultimate successors in the resultant mutation.
622 // Therefore, in such a configuration, the SchedGroups occurring before the
623 // candidate SGID are successors of the candidate SchedGroup, thus the current
624 // SU should be linked as a predecessor to SUs in those SchedGroups. The
625 // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
626 // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
627 // IsBottomUp (in reverse).
628 return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
629 SyncPipeline.rend())
630 : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
631 SyncPipeline.end());
632}
633
634void PipelineSolver::removeEdges(
635 const std::list<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
636 // Only remove the edges that we have added when testing
637 // the fit.
638 for (auto &PredSuccPair : EdgesToRemove) {
639 SUnit *Pred = PredSuccPair.first;
640 SUnit *Succ = PredSuccPair.second;
641
642 auto *Match = llvm::find_if(Succ->Preds, [&Pred](SDep &P) {
643 return P.getSUnit() == Pred && P.isArtificial();
644 });
645 if (Match != Succ->Preds.end())
646 Succ->removePred(*Match);
647 }
648}
649
650void PipelineSolver::advancePosition() {
651 ++CurrConflInstNo;
652
653 if (static_cast<size_t>(CurrConflInstNo) >=
654 PipelineInstrs[CurrSyncGroupIdx].size()) {
655 CurrConflInstNo = 0;
656 ++CurrSyncGroupIdx;
657 // Advance to next non-trivial pipeline
658 while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
659 PipelineInstrs[CurrSyncGroupIdx].empty())
660 ++CurrSyncGroupIdx;
661 }
662}
663
664void PipelineSolver::retreatPosition() {
665 assert(CurrConflInstNo >= 0);
666 assert(CurrSyncGroupIdx >= 0);
667
668 if (CurrConflInstNo > 0) {
669 --CurrConflInstNo;
670 return;
671 }
672
673 if (CurrConflInstNo == 0) {
674 // If we return to the starting position, we have explored
675 // the entire tree
676 if (CurrSyncGroupIdx == BeginSyncGroupIdx)
677 return;
678
679 --CurrSyncGroupIdx;
680 // Go to previous non-trivial pipeline
681 while (PipelineInstrs[CurrSyncGroupIdx].empty())
682 --CurrSyncGroupIdx;
683
684 CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
685 }
686}
687
688bool PipelineSolver::checkOptimal() {
689 if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
690 if (BestCost == -1 || CurrCost < BestCost) {
691 BestPipeline = CurrPipeline;
692 BestCost = CurrCost;
693 LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
694 }
695 assert(BestCost >= 0);
696 }
697
698 bool DoneExploring = false;
699 if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
700 DoneExploring = true;
701
702 return (DoneExploring || BestCost == 0);
703}
704
705template <typename T>
706void PipelineSolver::populateReadyList(
707 SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
708 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
709 auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
710 assert(CurrSU.second.size() >= 1);
711
712 for (; I != E; ++I) {
713 std::list<std::pair<SUnit *, SUnit *>> AddedEdges;
714 int CandSGID = *I;
715 SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
716 return SG.getSGID() == CandSGID;
717 });
718 assert(Match);
719
720 if (UseCostHeur) {
721 if (Match->isFull()) {
722 ReadyList.push_back(std::pair(*I, MissPenalty));
723 continue;
724 }
725
726 int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
727 ReadyList.push_back(std::pair(*I, TempCost));
728 removeEdges(AddedEdges);
729 } else
730 ReadyList.push_back(std::pair(*I, -1));
731 }
732
733 if (UseCostHeur)
734 std::sort(ReadyList.begin(), ReadyList.end(), llvm::less_second());
735
736 assert(ReadyList.size() == CurrSU.second.size());
737}
738
739bool PipelineSolver::solveExact() {
740 if (checkOptimal())
741 return true;
742
743 if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
744 return false;
745
746 assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
747 assert(static_cast<size_t>(CurrConflInstNo) <
748 PipelineInstrs[CurrSyncGroupIdx].size());
749 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
750 LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
751 << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
752
753 // SchedGroup -> Cost pairs
755 // Prioritize the candidate sched groups in terms of lowest cost first
756 IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
757 CurrSU.second.rend())
758 : populateReadyList(ReadyList, CurrSU.second.begin(),
759 CurrSU.second.end());
760
761 auto *I = ReadyList.begin();
762 auto *E = ReadyList.end();
763 for (; I != E; ++I) {
764 // If we are trying SGs in least cost order, and the current SG is cost
765 // infeasible, then all subsequent SGs will also be cost infeasible, so we
766 // can prune.
767 if (BestCost != -1 && (CurrCost + I->second > BestCost))
768 return false;
769
770 int CandSGID = I->first;
771 int AddedCost = 0;
772 std::list<std::pair<SUnit *, SUnit *>> AddedEdges;
773 auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
774 SchedGroup *Match;
775 for (auto &SG : SyncPipeline) {
776 if (SG.getSGID() == CandSGID)
777 Match = &SG;
778 }
779
780 if (Match->isFull())
781 continue;
782
783 if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
784 continue;
785
786 LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
787 << (int)Match->getMask() << "and ID " << CandSGID
788 << "\n");
789 Match->add(*CurrSU.first);
790 AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
791 LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
792 CurrCost += AddedCost;
793 advancePosition();
794 ++BranchesExplored;
795 bool FinishedExploring = false;
796 // If the Cost after adding edges is greater than a known solution,
797 // backtrack
798 if (CurrCost < BestCost || BestCost == -1) {
799 if (solveExact()) {
800 FinishedExploring = BestCost != 0;
801 if (!FinishedExploring)
802 return true;
803 }
804 }
805
806 retreatPosition();
807 CurrCost -= AddedCost;
808 removeEdges(AddedEdges);
809 Match->pop();
810 CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
811 if (FinishedExploring)
812 return true;
813 }
814
815 // Try the pipeline where the current instruction is omitted
816 // Potentially if we omit a problematic instruction from the pipeline,
817 // all the other instructions can nicely fit.
818 CurrCost += MissPenalty;
819 advancePosition();
820
821 LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
822
823 bool FinishedExploring = false;
824 if (CurrCost < BestCost || BestCost == -1) {
825 if (solveExact()) {
826 bool FinishedExploring = BestCost != 0;
827 if (!FinishedExploring)
828 return true;
829 }
830 }
831
832 retreatPosition();
833 CurrCost -= MissPenalty;
834 return FinishedExploring;
835}
836
837template <typename T>
838void PipelineSolver::greedyFind(
839 std::list<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
840 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
841
842 struct GroupInfo {
843 SchedGroup *SG;
844 std::list<std::pair<SUnit *, SUnit *>> Edges;
845 int Cost = 0;
846 };
847 std::optional<GroupInfo> Best;
848
849 auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
850 LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
851 << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
852
853 EdgeSetBuilder Builder(CurrSU.first, SyncPipeline, IsBottomUp);
854
855 // Since we have added the potential SchedGroups from bottom up, but
856 // traversed the DAG from top down, parse over the groups from last to
857 // first. If we fail to do this for the greedy algorithm, the solution will
858 // likely not be good in more complex cases.
859 for (; I != E; ++I) {
860 int CandSGID = *I;
861 SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
862 return SG.getSGID() == CandSGID;
863 });
864 assert(Match);
865
866 LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
867 << (int)Match->getMask() << "\n");
868
869 if (Match->isFull()) {
870 LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
871 continue;
872 }
873 if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
874 LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
875 continue;
876 }
877
878 std::list<std::pair<SUnit *, SUnit *>> TempEdges;
879 int TempCost = Builder.build(CandSGID, TempEdges);
880 LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
881
882 if (!Best || TempCost < Best->Cost) {
883 Best = {Match, TempEdges, TempCost};
884 if (Best->Cost == 0)
885 break;
886 }
887 }
888
889 if (Best) {
890 SchedGroup *SG = Best->SG;
891 std::list<std::pair<SUnit *, SUnit *>> &Edges = Best->Edges;
892
893 SG->add(*CurrSU.first);
894 if (AddedEdges.empty())
895 AddedEdges = Edges;
896 else
897 AddedEdges.splice(std::prev(AddedEdges.cend()), Edges);
898
899 for (const std::pair<SUnit *, SUnit *> &E : Edges) {
900 if (!SG->tryAddEdge(E.first, E.second))
901 llvm_unreachable("Edges known to be insertable.");
902 }
903
904 LLVM_DEBUG(dbgs() << "Best Group has ID: " << SG->getSGID() << " and Mask"
905 << (int)SG->getMask() << "\n");
906 BestCost += Best->Cost;
907 } else
908 BestCost += MissPenalty;
909}
910
911bool PipelineSolver::solveGreedy() {
912 BestCost = 0;
913 std::list<std::pair<SUnit *, SUnit *>> AddedEdges;
914
915 while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
916 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
917 IsBottomUp
918 ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
919 : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
920 advancePosition();
921 }
922 BestPipeline = CurrPipeline;
923 removeEdges(AddedEdges);
924 return false;
925}
926
927unsigned PipelineSolver::computeProblemSize() {
928 unsigned ProblemSize = 0;
929 for (auto &PipeConflicts : PipelineInstrs) {
930 ProblemSize += PipeConflicts.size();
931 }
932
933 return ProblemSize;
934}
935
936void PipelineSolver::solve() {
937 if (!NeedsSolver)
938 return;
939
940 unsigned ProblemSize = computeProblemSize();
941 assert(ProblemSize > 0);
942
943 bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
944 MissPenalty = (ProblemSize / 2) + 1;
945
946 LLVM_DEBUG(DAG->dump());
947 if (EnableExactSolver || BelowCutoff) {
948 LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
949 solveGreedy();
950 reset();
951 LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
952 if (BestCost > 0) {
953 LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
954 solveExact();
955 LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
956 }
957 } else { // Use the Greedy Algorithm by default
958 LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
959 solveGreedy();
960 LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
961 }
962
963 makePipeline();
964 LLVM_DEBUG(dbgs() << "After applying mutation\n");
965 LLVM_DEBUG(DAG->dump());
966}
967
// NOTE(review): this listing has dropped several source lines (see the gaps
// in the embedded numbering): the declaration of the DAG member that the
// constructor initializes, and the tails of both virtual signatures below.
// Restore them from the upstream file before attempting to build.
968 // Implement an IGLP scheduling strategy.
// Abstract base: concrete strategies derive from this class and override the
// two virtual hooks.
969 class IGLPStrategy {
970 protected:
972
// Target instruction info supplied through the constructor.
973 const SIInstrInfo *TII;
974
975 public:
976 /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
977 virtual bool applyIGLPStrategy(
979 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
981
982 // Returns true if this strategy should be applied to a ScheduleDAG.
983 virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
985
// Defaults to true; derived strategies may overwrite it in their
// constructors (MFMAExpInterleaveOpt sets it to false).
986 bool IsBottomUp = true;
987
// Stores the DAG and instruction info for use by derived strategies.
988 IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
989 : DAG(DAG), TII(TII) {}
990
// Virtual destructor so strategies can be destroyed through a base pointer.
991 virtual ~IGLPStrategy() = default;
992};
993
// IGLP strategy whose applyIGLPStrategy() (defined out of line) builds an
// alternating sequence of DS and MFMA SchedGroups.
// NOTE(review): the applyIGLPStrategy declaration below is missing listing
// lines (embedded numbering skips 998 and 1000); restore from upstream.
994 class MFMASmallGemmOpt final : public IGLPStrategy {
995 private:
996 public:
997 bool applyIGLPStrategy(
999 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1001
// This strategy is applicable unconditionally, regardless of DAG or phase.
1002 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1003 AMDGPU::SchedulingPhase Phase) override {
1004 return true;
1005 }
1006
1007 MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1008 : IGLPStrategy(DAG, TII) {
// Same value as the base-class default; stated explicitly here.
1009 IsBottomUp = true;
1010 }
1011};
1012
// Builds the pipeline for this strategy: for every MFMA/WMMA instruction in
// the DAG, three (DS group of size 2, MFMA group of size 1) pairs are appended
// to sync pipeline 0, and each new group immediately collects its candidate
// SUnits. Always returns true.
// NOTE(review): the parameter list is incomplete in this listing (embedded
// numbering skips 1014 and 1016); SyncedInstrs used below is presumably one of
// the missing parameters — restore from upstream.
1013bool MFMASmallGemmOpt::applyIGLPStrategy(
1015 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1017 // Count the number of MFMA instructions.
1018 unsigned MFMACount = 0;
1019 for (const MachineInstr &I : *DAG)
1020 if (TII->isMFMAorWMMA(I))
1021 ++MFMACount;
1022
// All groups created here share a single sync pipeline.
1023 const unsigned PipelineSyncID = 0;
1024 SchedGroup *SG = nullptr;
// Emit 3 * MFMACount repetitions of the {2 x DS, 1 x MFMA} pattern.
1025 for (unsigned I = 0; I < MFMACount * 3; ++I) {
1026 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1027 SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
1028 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1029
1030 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1031 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1032 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1033 }
1034
1035 return true;
1036}
1037
1038class MFMAExpInterleaveOpt final : public IGLPStrategy {
1039private:
1040 // The count of TRANS SUs involved in the interleaved pipeline
1041 static unsigned TransPipeCount;
1042 // The count of MFMA SUs involved in the interleaved pipeline
1043 static unsigned MFMAPipeCount;
1044 // The count of Add SUs involved in the interleaved pipeline
1045 static unsigned AddPipeCount;
1046 // The number of transitive MFMA successors for each TRANS SU
1047 static unsigned MFMAEnablement;
1048 // The number of transitive TRANS predecessors for each MFMA SU
1049 static unsigned ExpRequirement;
1050 // The count of independent "chains" of MFMA instructions in the pipeline
1051 static unsigned MFMAChains;
1052 // Whether or not the pipeline has V_CVT instructions
1053 static bool HasCvt;
1054 // Whether or not there are instructions between the TRANS instruction and
1055 // V_CVT
1056 static bool HasChainBetweenCvt;
1057 // The first occurring DS_READ which feeds an MFMA chain
1058 static std::optional<unsigned> FirstPipeDSR;
1059 // The MFMAPipe SUs with no MFMA predecessors
1060 SmallVector<SUnit *, 4> MFMAChainSeeds;
1061 // Compute the heuristics for the pipeline, returning whether or not the DAG
1062 // is well formatted for the mutation
1063 bool analyzeDAG(const SIInstrInfo *TII);
1064
1065 /// Whether or not the instruction is a transitive predecessor of an MFMA
1066 /// instruction
1067 class IsPipeExp final : public InstructionRule {
1068 public:
1069 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1070 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1071
1072 auto *DAG = SyncPipe[0].DAG;
1073
1074 if (Cache->empty()) {
1075 auto I = DAG->SUnits.rbegin();
1076 auto E = DAG->SUnits.rend();
1077 for (; I != E; I++) {
1078 if (TII->isMFMAorWMMA(*I->getInstr()))
1079 Cache->push_back(&*I);
1080 }
1081 if (Cache->empty())
1082 return false;
1083 }
1084
1085 auto Reaches = any_of(*Cache, [&SU, &DAG](SUnit *TargetSU) {
1086 return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));
1087 });
1088
1089 return Reaches;
1090 }
1091 IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1092 : InstructionRule(TII, SGID, NeedsCache) {}
1093 };
1094
1095 /// Whether or not the instruction is a transitive predecessor of the
1096 /// \p Number th MFMA of the MFMAs occuring after a TRANS instruction
1097 class EnablesNthMFMA final : public InstructionRule {
1098 private:
1099 unsigned Number = 1;
1100
1101 public:
1102 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1103 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1104 bool FoundTrans = false;
1105 unsigned Counter = 1;
1106 auto *DAG = SyncPipe[0].DAG;
1107
1108 if (Cache->empty()) {
1109 auto I = DAG->SUnits.begin();
1110 auto E = DAG->SUnits.end();
1111 for (; I != E; I++) {
1112 if (FoundTrans && TII->isMFMAorWMMA(*I->getInstr())) {
1113 if (Counter == Number) {
1114 Cache->push_back(&*I);
1115 break;
1116 }
1117 ++Counter;
1118 }
1119 if (!FoundTrans && TII->isTRANS(I->getInstr()->getOpcode()))
1120 FoundTrans = true;
1121 }
1122 if (Cache->empty())
1123 return false;
1124 }
1125
1126 return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
1127 }
1128
1129 EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1130 bool NeedsCache = false)
1131 : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1132 };
1133
1134 /// Whether or not the instruction enables the exact MFMA that is the \p
1135 /// Number th MFMA in the chain starting with \p ChainSeed
1136 class EnablesNthMFMAInChain final : public InstructionRule {
1137 private:
1138 unsigned Number = 1;
1139 SUnit *ChainSeed;
1140
1141 public:
1142 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1143 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1144 auto *DAG = SyncPipe[0].DAG;
1145
1146 if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1147 return false;
1148
1149 if (Cache->empty()) {
1150 auto *TempSU = ChainSeed;
1151 auto Depth = Number;
1152 while (Depth > 0) {
1153 --Depth;
1154 bool Found = false;
1155 for (auto &Succ : TempSU->Succs) {
1156 if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1157 TempSU = Succ.getSUnit();
1158 Found = true;
1159 break;
1160 }
1161 }
1162 if (!Found)
1163 return false;
1164 }
1165
1166 Cache->push_back(TempSU);
1167 }
1168 // If we failed to find the instruction to be placed into the cache, we
1169 // would have already exited.
1170 assert(!Cache->empty());
1171
1172 return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
1173 }
1174
1175 EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
1176 const SIInstrInfo *TII, unsigned SGID,
1177 bool NeedsCache = false)
1178 : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1179 ChainSeed(ChainSeed) {}
1180 };
1181
1182 /// Whether or not the instruction has less than \p Size immediate successors.
1183 /// If \p HasIntermediary is true, this tests also whether all successors of
1184 /// the SUnit have less than \p Size successors.
1185 class LessThanNSuccs final : public InstructionRule {
1186 private:
1187 unsigned Size = 1;
1188 bool HasIntermediary = false;
1189
1190 public:
1191 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1192 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1193 if (!SyncPipe.size())
1194 return false;
1195
1196 unsigned SuccSize = llvm::count_if(SU->Succs, [](const SDep &Succ) {
1197 return Succ.getKind() == SDep::Data;
1198 });
1199 if (SuccSize >= Size)
1200 return false;
1201
1202 if (HasIntermediary) {
1203 for (auto Succ : SU->Succs) {
1204 unsigned SuccSize =
1205 llvm::count_if(Succ.getSUnit()->Succs, [](const SDep &SuccSucc) {
1206 return SuccSucc.getKind() == SDep::Data;
1207 });
1208 if (SuccSize >= Size)
1209 return false;
1210 }
1211 }
1212
1213 return true;
1214 }
1215 LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
1216 bool HasIntermediary = false, bool NeedsCache = false)
1217 : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1218 HasIntermediary(HasIntermediary) {}
1219 };
1220
1221 /// Whether or not the instruction has greater than or equal to \p Size
1222 /// immediate successors. If \p HasIntermediary is true, this tests also
1223 /// whether all successors of the SUnit have greater than or equal to \p Size
1224 /// successors.
1225 class GreaterThanOrEqualToNSuccs final : public InstructionRule {
1226 private:
1227 unsigned Size = 1;
1228 bool HasIntermediary = false;
1229
1230 public:
1231 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1232 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1233 if (!SyncPipe.size())
1234 return false;
1235
1236 unsigned SuccSize = llvm::count_if(SU->Succs, [](const SDep &Succ) {
1237 return Succ.getKind() == SDep::Data;
1238 });
1239 if (SuccSize >= Size)
1240 return true;
1241
1242 if (HasIntermediary) {
1243 for (auto Succ : SU->Succs) {
1244 unsigned SuccSize =
1245 llvm::count_if(Succ.getSUnit()->Succs, [](const SDep &SuccSucc) {
1246 return SuccSucc.getKind() == SDep::Data;
1247 });
1248 if (SuccSize >= Size)
1249 return true;
1250 }
1251 }
1252
1253 return false;
1254 }
1255 GreaterThanOrEqualToNSuccs(unsigned Size, const SIInstrInfo *TII,
1256 unsigned SGID, bool HasIntermediary = false,
1257 bool NeedsCache = false)
1258 : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1259 HasIntermediary(HasIntermediary) {}
1260 };
1261
1262 // Whether or not the instruction is a relevant V_CVT instruction.
1263 class IsCvt final : public InstructionRule {
1264 public:
1265 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1266 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1267 auto Opc = SU->getInstr()->getOpcode();
1268 return Opc == AMDGPU::V_CVT_F16_F32_e32 ||
1269 Opc == AMDGPU::V_CVT_I32_F32_e32;
1270 }
1271 IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1272 : InstructionRule(TII, SGID, NeedsCache) {}
1273 };
1274
1275 // Whether or not the instruction is FMA_F32.
1276 class IsFMA final : public InstructionRule {
1277 public:
1278 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1279 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1280 return SU->getInstr()->getOpcode() == AMDGPU::V_FMA_F32_e64 ||
1281 SU->getInstr()->getOpcode() == AMDGPU::V_PK_FMA_F32;
1282 }
1283 IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1284 : InstructionRule(TII, SGID, NeedsCache) {}
1285 };
1286
1287 // Whether or not the instruction is a V_ADD_F32 instruction.
1288 class IsPipeAdd final : public InstructionRule {
1289 public:
1290 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1291 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1292 return SU->getInstr()->getOpcode() == AMDGPU::V_ADD_F32_e32;
1293 }
1294 IsPipeAdd(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1295 : InstructionRule(TII, SGID, NeedsCache) {}
1296 };
1297
1298 /// Whether or not the instruction is an immediate RAW successor
1299 /// of the SchedGroup \p Distance steps before.
1300 class IsSuccOfPrevNthGroup final : public InstructionRule {
1301 private:
1302 unsigned Distance = 1;
1303
1304 public:
1305 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1306 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1307 SchedGroup *OtherGroup = nullptr;
1308 if (!SyncPipe.size())
1309 return false;
1310
1311 for (auto &PipeSG : SyncPipe) {
1312 if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1313 OtherGroup = &PipeSG;
1314 }
1315
1316 if (!OtherGroup)
1317 return false;
1318 if (!OtherGroup->Collection.size())
1319 return true;
1320
1321 for (auto &OtherEle : OtherGroup->Collection) {
1322 for (auto &Succ : OtherEle->Succs) {
1323 if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)
1324 return true;
1325 }
1326 }
1327
1328 return false;
1329 }
1330 IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1331 unsigned SGID, bool NeedsCache = false)
1332 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1333 };
1334
1335 /// Whether or not the instruction is a transitive successor of any
1336 /// instruction the the SchedGroup \p Distance steps before.
1337 class IsReachableFromPrevNthGroup final : public InstructionRule {
1338 private:
1339 unsigned Distance = 1;
1340
1341 public:
1342 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1343 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1344 SchedGroup *OtherGroup = nullptr;
1345 if (!SyncPipe.size())
1346 return false;
1347
1348 for (auto &PipeSG : SyncPipe) {
1349 if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1350 OtherGroup = &PipeSG;
1351 }
1352
1353 if (!OtherGroup)
1354 return false;
1355 if (!OtherGroup->Collection.size())
1356 return true;
1357
1358 auto *DAG = SyncPipe[0].DAG;
1359
1360 for (auto &OtherEle : OtherGroup->Collection)
1361 if (DAG->IsReachable(const_cast<SUnit *>(SU), OtherEle))
1362 return true;
1363
1364 return false;
1365 }
1366 IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1367 unsigned SGID, bool NeedsCache = false)
1368 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1369 };
1370
1371 /// Whether or not the instruction occurs after the SU with NodeNUm \p Number
1372 class OccursAtOrAfterNode final : public InstructionRule {
1373 private:
1374 unsigned Number = 1;
1375
1376 public:
1377 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1378 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1379
1380 return SU->NodeNum >= Number;
1381 }
1382 OccursAtOrAfterNode(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1383 bool NeedsCache = false)
1384 : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1385 };
1386
1387 /// Whether or not the SU is exactly the \p Number th MFMA in the chain
1388 /// starting with \p ChainSeed
1389 class IsExactMFMA final : public InstructionRule {
1390 private:
1391 unsigned Number = 1;
1392 SUnit *ChainSeed;
1393
1394 public:
1395 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1396 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1397 if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1398 return false;
1399
1400 if (Cache->empty()) {
1401 auto *TempSU = ChainSeed;
1402 auto Depth = Number;
1403 while (Depth > 0) {
1404 --Depth;
1405 bool Found = false;
1406 for (auto &Succ : TempSU->Succs) {
1407 if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1408 TempSU = Succ.getSUnit();
1409 Found = true;
1410 break;
1411 }
1412 }
1413 if (!Found) {
1414 return false;
1415 }
1416 }
1417 Cache->push_back(TempSU);
1418 }
1419 // If we failed to find the instruction to be placed into the cache, we
1420 // would have already exited.
1421 assert(!Cache->empty());
1422
1423 return (*Cache)[0] == SU;
1424 }
1425
1426 IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
1427 unsigned SGID, bool NeedsCache = false)
1428 : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1429 ChainSeed(ChainSeed) {}
1430 };
1431
1432 // Whether the instruction occurs after the first TRANS instruction. This
1433 // implies the instruction can not be a predecessor of the first TRANS
1434 // insruction
1435 class OccursAfterExp final : public InstructionRule {
1436 public:
1437 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1438 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1439
1440 auto *DAG = SyncPipe[0].DAG;
1441 if (Cache->empty()) {
1442 for (auto &SU : DAG->SUnits)
1443 if (TII->isTRANS(SU.getInstr()->getOpcode())) {
1444 Cache->push_back(&SU);
1445 break;
1446 }
1447 if (Cache->empty())
1448 return false;
1449 }
1450
1451 return SU->NodeNum > (*Cache)[0]->NodeNum;
1452 }
1453
1454 OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
1455 bool NeedsCache = false)
1456 : InstructionRule(TII, SGID, NeedsCache) {}
1457 };
1458
1459public:
1460 bool applyIGLPStrategy(
1462 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1464
1465 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1467
1468 MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1469 : IGLPStrategy(DAG, TII) {
1470 IsBottomUp = false;
1471 }
1472};
1473
// Out-of-class definitions of MFMAExpInterleaveOpt's static heuristic state.
// analyzeDAG() writes these values and applyIGLPStrategy() reads them. Being
// static, they outlive any single strategy object: shouldApplyStrategy()
// skips analyzeDAG() in the PostRA phase, so that phase works with whatever
// an earlier analysis stored.
1474unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;
1475unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;
1476unsigned MFMAExpInterleaveOpt::AddPipeCount = 0;
1477unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
1478unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
1479unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
1480bool MFMAExpInterleaveOpt::HasCvt = false;
1481bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;
1482std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
1483
// Computes the static pipeline heuristics (TransPipeCount, MFMAPipeCount,
// AddPipeCount, MFMAEnablement, ExpRequirement, MFMAChains, HasCvt,
// HasChainBetweenCvt, FirstPipeDSR) and fills MFMAChainSeeds from the DAG.
// Returns false as soon as the DAG turns out not to have the expected shape
// (no pack / MFMA / TRANS candidates, no TRANS reaching an MFMA, no chain
// seed, or no bit-pack predecessor on the reference MFMA). Statics assigned
// before such an early return keep their new values.
// NOTE(review): PackSUs and CvtSUs are used below but their declarations are
// not present in this listing (embedded numbering jumps from 1487 to 1490);
// restore them from upstream.
1484bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
1485 SmallVector<SUnit *, 10> ExpPipeCands;
1486 SmallVector<SUnit *, 10> MFMAPipeCands;
1487 SmallVector<SUnit *, 10> MFMAPipeSUs;
1490
// Opcode classifiers used throughout the analysis.
1491 auto isBitPack = [](unsigned Opc) {
1492 return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
1493 };
1494
1495 auto isCvt = [](unsigned Opc) {
1496 return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
1497 };
1498
1499 auto isAdd = [](unsigned Opc) { return Opc == AMDGPU::V_ADD_F32_e32; };
1500
// Single pass over the DAG: bucket TRANS, MFMA, bit-pack and CVT SUnits and
// count the V_ADD_F32 instructions.
1501 AddPipeCount = 0;
1502 for (SUnit &SU : DAG->SUnits) {
1503 auto Opc = SU.getInstr()->getOpcode();
1504 if (TII->isTRANS(Opc)) {
1505 // Avoid counting a potential bonus V_EXP which all the MFMA depend on
1506 if (SU.Succs.size() >= 7)
1507 continue;
// NOTE(review): this inner loop has no effect — `continue` applies to the
// inner for, so SU is pushed below regardless of its successors' fan-out.
// If the intent was to skip SU when a successor has >= 7 successors, the
// outer iteration must be skipped instead — confirm before changing, since
// that would alter the computed heuristics.
1508 for (auto &Succ : SU.Succs) {
1509 if (Succ.getSUnit()->Succs.size() >= 7)
1510 continue;
1511 }
1512 ExpPipeCands.push_back(&SU);
1513 }
1514
1515 if (TII->isMFMAorWMMA(*SU.getInstr()))
1516 MFMAPipeCands.push_back(&SU);
1517
1518 if (isBitPack(Opc))
1519 PackSUs.push_back(&SU);
1520
1521 if (isCvt(Opc))
1522 CvtSUs.push_back(&SU);
1523
1524 if (isAdd(Opc))
1525 ++AddPipeCount;
1526 }
1527
// The pipeline needs at least one of each: bit pack, MFMA and TRANS.
1528 if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
1529 return false;
1530
1531 TransPipeCount = 0;
1532
// TempExp / TempMFMA remember the first (TRANS, MFMA) pair found below; they
// serve as the reference instructions for the remaining heuristics.
1533 std::optional<SUnit *> TempMFMA;
1534 std::optional<SUnit *> TempExp;
1535 // Count the number of EXPs that reach an MFMA
1536 for (auto &PredSU : ExpPipeCands) {
1537 for (auto &SuccSU : MFMAPipeCands) {
1538 if (DAG->IsReachable(SuccSU, PredSU)) {
1539 if (!TempExp) {
1540 TempExp = PredSU;
1541 TempMFMA = SuccSU;
1542 }
// Only the first matching MFMA per TRANS is recorded (note the break), so
// the same MFMA may be pushed more than once across different TRANS SUs.
1543 MFMAPipeSUs.push_back(SuccSU);
1544 ++TransPipeCount;
1545 break;
1546 }
1547 }
1548 }
1549
1550 if (!(TempExp && TempMFMA))
1551 return false;
1552
// True when no direct successor of the reference TRANS is a CVT.
1553 HasChainBetweenCvt = none_of((*TempExp)->Succs, [&isCvt](SDep &Succ) {
1554 return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
1555 });
1556
1557 // Count the number of MFMAs that are reached by an EXP
1558 for (auto &SuccSU : MFMAPipeCands) {
// Skip MFMAs already recorded by the previous loop (matched by NodeNum).
1559 if (MFMAPipeSUs.size() &&
1560 any_of(MFMAPipeSUs, [&SuccSU](SUnit *PotentialMatch) {
1561 return PotentialMatch->NodeNum == SuccSU->NodeNum;
1562 }))
1563 continue;
1564
1565 for (auto &PredSU : ExpPipeCands) {
1566 if (DAG->IsReachable(SuccSU, PredSU)) {
1567 MFMAPipeSUs.push_back(SuccSU);
1568 break;
1569 }
1570 }
1571 }
1572
1573 MFMAPipeCount = MFMAPipeSUs.size();
1574
1575 assert(TempExp && TempMFMA);
1576 assert(MFMAPipeCount > 0);
1577
// Find the first CVT for which IsReachable(CVT, reference TRANS) holds.
1578 std::optional<SUnit *> TempCvt;
1579 for (auto &SuccSU : CvtSUs) {
1580 if (DAG->IsReachable(SuccSU, *TempExp)) {
1581 TempCvt = SuccSU;
1582 break;
1583 }
1584 }
1585
// HasCvt is set only if that CVT in turn leads to one of the pipeline MFMAs.
1586 HasCvt = false;
1587 if (TempCvt.has_value()) {
1588 for (auto &SuccSU : MFMAPipeSUs) {
1589 if (DAG->IsReachable(SuccSU, *TempCvt)) {
1590 HasCvt = true;
1591 break;
1592 }
1593 }
1594 }
1595
// A chain seed is a pipeline MFMA with no MFMA/WMMA among its predecessors;
// each distinct seed counts as one chain.
1596 MFMAChains = 0;
1597 for (auto &MFMAPipeSU : MFMAPipeSUs) {
1598 if (is_contained(MFMAChainSeeds, MFMAPipeSU))
1599 continue;
1600 if (none_of(MFMAPipeSU->Preds, [&TII](SDep &Succ) {
1601 return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1602 })) {
1603 MFMAChainSeeds.push_back(MFMAPipeSU);
1604 ++MFMAChains;
1605 }
1606 }
1607
1608 if (!MFMAChains)
1609 return false;
1610
// Record the NodeNum of a DS load feeding the first chain seed; if several
// predecessors qualify, the last one visited wins.
// NOTE(review): FirstPipeDSR is static and is not cleared when no predecessor
// qualifies, so a value from an earlier call can persist — confirm intended.
1611 for (auto Pred : MFMAChainSeeds[0]->Preds) {
1612 if (TII->isDS(Pred.getSUnit()->getInstr()->getOpcode()) &&
1613 Pred.getSUnit()->getInstr()->mayLoad())
1614 FirstPipeDSR = Pred.getSUnit()->NodeNum;
1615 }
1616
1617 // The number of bit pack operations that depend on a single V_EXP
1618 unsigned PackSuccCount =
1619 llvm::count_if(PackSUs, [this, &TempExp](SUnit *VPack) {
1620 return DAG->IsReachable(VPack, *TempExp);
1621 });
1622
1623 // The number of bit pack operations an MFMA depends on
1624 unsigned PackPredCount =
1625 llvm::count_if((*TempMFMA)->Preds, [&isBitPack](SDep &Pred) {
1626 auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1627 return isBitPack(Opc);
1628 });
1629
// First bit-pack predecessor of the reference MFMA; used as the reference
// pack instruction below.
1630 auto *PackPred = llvm::find_if((*TempMFMA)->Preds, [&isBitPack](SDep &Pred) {
1631 auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1632 return isBitPack(Opc);
1633 });
1634
1635 if (PackPred == (*TempMFMA)->Preds.end())
1636 return false;
1637
1638 MFMAEnablement = 0;
1639 ExpRequirement = 0;
1640 // How many MFMAs depend on a single bit pack operation
1641 MFMAEnablement =
1642 llvm::count_if(PackPred->getSUnit()->Succs, [&TII](SDep &Succ) {
1643 return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1644 });
1645
1646 // The number of MFMAs that depend on a single V_EXP
1647 MFMAEnablement *= PackSuccCount;
1648
1649 // The number of V_EXPs required to resolve all dependencies for an MFMA
1650 ExpRequirement =
1651 llvm::count_if(ExpPipeCands, [this, &PackPred](SUnit *ExpBase) {
1652 return DAG->IsReachable(PackPred->getSUnit(), ExpBase);
1653 });
1654
1655 ExpRequirement *= PackPredCount;
1656 return true;
1657}
1658
// Decides whether this strategy applies to the given DAG. Outside the PostRA
// phase the DAG is analyzed and the strategy is rejected if analyzeDAG()
// fails; in the PostRA phase no analysis runs and the strategy is always
// accepted. The chain seeds are cleared on every call.
// NOTE(review): the line declaring the Phase parameter is missing from this
// listing (embedded numbering skips 1660, and 1664 is also absent); restore
// from upstream.
1659bool MFMAExpInterleaveOpt::shouldApplyStrategy(ScheduleDAGInstrs *DAG,
// TII is looked up from the subtarget of the DAG passed in; this local
// shadows the TII member inherited from IGLPStrategy.
1661 const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1662 const SIInstrInfo *TII = ST.getInstrInfo();
1663
// Drop seeds from any previous analysis before (possibly) recomputing them.
1665 MFMAChainSeeds.clear();
1666 if (Phase != AMDGPU::SchedulingPhase::PostRA && !analyzeDAG(TII))
1667 return false;
1668
1669 return true;
1670}
1671
1672bool MFMAExpInterleaveOpt::applyIGLPStrategy(
1674 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1676
1677 bool IsSmallKernelType =
1678 MFMAEnablement == 2 && ExpRequirement == 4 && TransPipeCount == 32;
1679 bool IsLargeKernelType =
1680 MFMAEnablement == 4 && ExpRequirement == 4 && TransPipeCount == 64;
1681
1682 if (!(IsSmallKernelType || IsLargeKernelType))
1683 return false;
1684
1685 const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1686 const SIInstrInfo *TII = ST.getInstrInfo();
1687
1688 unsigned PipelineSyncID = 0;
1689 SchedGroup *SG = nullptr;
1690
1691 unsigned MFMAChain = 0;
1692 unsigned PositionInChain = 0;
1693 unsigned CurrMFMAForTransPosition = 0;
1694
1695 auto incrementTransPosition = [&MFMAChain, &PositionInChain,
1696 &CurrMFMAForTransPosition]() {
1697 CurrMFMAForTransPosition += MFMAEnablement;
1698 PositionInChain = (CurrMFMAForTransPosition / MFMAChains);
1699 MFMAChain = CurrMFMAForTransPosition % MFMAChains;
1700 };
1701
1702 auto getNextTransPositionInChain = [&CurrMFMAForTransPosition]() {
1703 auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1704 return (TempMFMAForTrans / MFMAChains);
1705 };
1706
1707 auto getNextTransMFMAChain = [&CurrMFMAForTransPosition]() {
1708 auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1709 return TempMFMAForTrans % MFMAChains;
1710 };
1711
1712 unsigned CurrMFMAPosition = 0;
1713 unsigned MFMAChainForMFMA = 0;
1714 unsigned PositionInChainForMFMA = 0;
1715
1716 auto incrementMFMAPosition = [&CurrMFMAPosition, &MFMAChainForMFMA,
1717 &PositionInChainForMFMA]() {
1718 ++CurrMFMAPosition;
1719 MFMAChainForMFMA = CurrMFMAPosition % MFMAChains;
1720 PositionInChainForMFMA = CurrMFMAPosition / MFMAChains;
1721 };
1722
1723 bool IsPostRA = Phase == AMDGPU::SchedulingPhase::PostRA;
1724 assert(IsPostRA || MFMAChainSeeds.size() == MFMAChains);
1725
1726 bool UsesFMA = IsSmallKernelType || !IsPostRA;
1727 bool UsesDSRead = IsLargeKernelType && !IsPostRA && FirstPipeDSR;
1728 bool UsesCvt = HasCvt && (IsSmallKernelType || !IsPostRA);
1729 bool UsesVALU = IsSmallKernelType;
1730
1731 // PHASE 1: "Prefetch"
1732 if (UsesFMA) {
1733 // First Round FMA
1734 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1735 SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1736 if (!IsPostRA && MFMAChains) {
1737 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1738 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1739 true));
1740 } else
1741 SG->addRule(
1742 std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1743 SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1744 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1745
1746 // Second Round FMA
1747 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1748 SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1749 if (!IsPostRA && MFMAChains) {
1750 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1751 getNextTransPositionInChain(),
1752 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1753 } else
1754 SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1755 SG->getSGID(), true));
1756 SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1757 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1758 }
1759
1760 if (UsesDSRead) {
1761 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1762 SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1763 SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1764 SG->getSGID()));
1765 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1766 }
1767
1768 // First Round EXP
1769 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1770 SchedGroupMask::TRANS, ExpRequirement, PipelineSyncID, DAG, TII);
1771 if (!IsPostRA && MFMAChains)
1772 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1773 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(), true));
1774 else
1775 SG->addRule(std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1776 SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1777 SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1778 HasChainBetweenCvt));
1779 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1780
1781 incrementTransPosition();
1782
1783 // First Round CVT, Third Round FMA, Second Round EXP; interleaved
1784 for (unsigned I = 0; I < ExpRequirement; I++) {
1785 // First Round CVT
1786 if (UsesCvt) {
1787 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1788 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1789 SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1790 if (HasChainBetweenCvt)
1791 SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1792 1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1793 else
1794 SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(
1795 1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1796 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1797 }
1798
1799 // Third Round FMA
1800 if (UsesFMA) {
1801 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1802 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1803 if (!IsPostRA && MFMAChains) {
1804 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1805 getNextTransPositionInChain(),
1806 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1807 } else
1808 SG->addRule(std::make_shared<EnablesNthMFMA>(2 * MFMAEnablement + 1,
1809 TII, SG->getSGID(), true));
1810 SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1811 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1812 }
1813
1814 // Second Round EXP
1815 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1816 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1817 if (!IsPostRA && MFMAChains)
1818 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1819 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1820 true));
1821 else
1822 SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1823 SG->getSGID(), true));
1824 SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1825 SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1826 HasChainBetweenCvt));
1827 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1828 }
1829
1830 // The "extra" EXP which enables all MFMA
1831 // TODO: UsesExtraExp
1832 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1833 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1834 SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1835 SG->addRule(std::make_shared<GreaterThanOrEqualToNSuccs>(
1836 8, TII, SG->getSGID(), HasChainBetweenCvt));
1837 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1838
1839 // PHASE 2: Main Interleave Loop
1840
1841 // The number of MFMAs per iteration
1842 unsigned MFMARatio =
1843 MFMAEnablement > ExpRequirement ? MFMAEnablement / ExpRequirement : 1;
1844 // The number of Exps per iteration
1845 unsigned ExpRatio =
1846 MFMAEnablement > ExpRequirement ? 1 : ExpRequirement / MFMAEnablement;
1847 // The remaining Exps
1848 unsigned RemainingExp = TransPipeCount > (2 * ExpRequirement)
1849 ? TransPipeCount - (2 * ExpRequirement)
1850 : 0;
1851 unsigned ExpLoopCount = RemainingExp / ExpRatio;
1852 // In loop MFMAs
1853 unsigned MFMAInLoop = MFMAPipeCount > (MFMAEnablement * 2)
1854 ? MFMAPipeCount - (MFMAEnablement * 2)
1855 : 0;
1856 unsigned MFMALoopCount = MFMAInLoop / MFMARatio;
1857 unsigned VALUOps =
1858 AddPipeCount < MFMAPipeCount ? 1 : AddPipeCount / MFMAPipeCount;
1859 unsigned LoopSize = std::min(ExpLoopCount, MFMALoopCount);
1860
1861 for (unsigned I = 0; I < LoopSize; I++) {
1862 if (!(I * ExpRatio % ExpRequirement))
1863 incrementTransPosition();
1864
1865 // Round N MFMA
1866 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1867 SchedGroupMask::MFMA, MFMARatio, PipelineSyncID, DAG, TII);
1868 if (!IsPostRA && MFMAChains)
1869 SG->addRule(std::make_shared<IsExactMFMA>(
1870 PositionInChainForMFMA, MFMAChainSeeds[MFMAChainForMFMA], TII,
1871 SG->getSGID(), true));
1872 else
1873 SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1874 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1875 incrementMFMAPosition();
1876
1877 if (UsesVALU) {
1878 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1879 SchedGroupMask::VALU, VALUOps, PipelineSyncID, DAG, TII);
1880 SG->addRule(std::make_shared<IsPipeAdd>(TII, SG->getSGID()));
1881 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1882 }
1883
1884 if (UsesDSRead && !(I % 4)) {
1885 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1886 SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1887 SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1888 SG->getSGID()));
1889 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1890 }
1891
1892 // CVT, EXP, FMA Interleaving
1893 for (unsigned J = 0; J < ExpRatio; J++) {
1894 auto MFMAOffset = (1 + UsesVALU) * MFMARatio * (I + 1);
1895 auto MaxMFMAOffset =
1896 (1 + UsesVALU) * ExpRequirement * MFMARatio / ExpRatio;
1897
1898 // Round N + 1 CVT
1899 if (UsesCvt) {
1900 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1901 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1902 SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1903 auto BaseDiff = (2 + UsesFMA) * (ExpRequirement - 1) + 1;
1904 auto DSROffset = I / 4 + 1;
1905 auto MaxDSROffset = MaxMFMAOffset / 4;
1906 // TODO: UsesExtraExp
1907 auto ExpOffset = I * ExpRatio + J >= ExpRequirement ? 0 : 1;
1908 auto CurrentOffset = UsesDSRead * std::min(MaxDSROffset, DSROffset) +
1909 std::min(MaxMFMAOffset, MFMAOffset) + BaseDiff +
1910 ExpOffset;
1911 if (HasChainBetweenCvt)
1912 SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1913 CurrentOffset, TII, SG->getSGID()));
1914 else
1915 SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(CurrentOffset, TII,
1916 SG->getSGID()));
1917 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1918 }
1919
1920 // Round N + 3 FMA
1921 if (UsesFMA) {
1922 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1923 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1924 if (!IsPostRA && MFMAChains)
1925 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1926 getNextTransPositionInChain(),
1927 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(),
1928 true));
1929 else
1930 SG->addRule(std::make_shared<EnablesNthMFMA>(
1931 (((I * ExpRatio + J) / ExpRequirement) + 3) * MFMAEnablement + 1,
1932 TII, SG->getSGID(), true));
1933 SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1934 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1935 }
1936
1937 // Round N + 2 Exp
1938 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1939 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1940 if (!IsPostRA && MFMAChains)
1941 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1942 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1943 true));
1944 else
1945 SG->addRule(std::make_shared<EnablesNthMFMA>(
1946 (((I * ExpRatio + J) / ExpRequirement) + 2) * MFMAEnablement + 1,
1947 TII, SG->getSGID(), true));
1948 SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1949 SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1950 HasChainBetweenCvt));
1951 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1952 }
1953 }
1954
1955 // PHASE 3: Remaining MFMAs
1956 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1957 SchedGroupMask::MFMA, MFMAEnablement * 2, PipelineSyncID, DAG, TII);
1958 SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1959 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1960 return true;
1961}
1962
1963class MFMAExpSimpleInterleaveOpt final : public IGLPStrategy {
1964public:
1965 bool applyIGLPStrategy(
1967 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1969
1970 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1971 AMDGPU::SchedulingPhase Phase) override {
1972 return true;
1973 }
1974
1975 MFMAExpSimpleInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1976 : IGLPStrategy(DAG, TII) {
1977 IsBottomUp = true;
1978 }
1979};
1980
1981bool MFMAExpSimpleInterleaveOpt::applyIGLPStrategy(
1983 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1985 // Count the number of MFMA instructions.
1986 unsigned MFMACount = 0;
1987 for (const MachineInstr &I : *DAG)
1988 if (TII->isMFMAorWMMA(I))
1989 ++MFMACount;
1990
1991 const unsigned PipelineSyncID = 0;
1992 for (unsigned I = 0; I < MFMACount * 3; ++I) {
1993 SchedGroup *SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1994 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1995 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
1996
1997 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1998 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1999 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2000 }
2001
2002 return true;
2003}
2004
2005class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
2006private:
2007 // Whether the DS_READ is a predecessor of first four MFMA in region
2008 class EnablesInitialMFMA final : public InstructionRule {
2009 public:
2010 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2011 SmallVectorImpl<SchedGroup> &SyncPipe) override {
2012 if (!SyncPipe.size())
2013 return false;
2014 int MFMAsFound = 0;
2015 if (!Cache->size()) {
2016 for (auto &Elt : SyncPipe[0].DAG->SUnits) {
2017 if (TII->isMFMAorWMMA(*Elt.getInstr())) {
2018 ++MFMAsFound;
2019 if (MFMAsFound > 4)
2020 break;
2021 Cache->push_back(&Elt);
2022 }
2023 }
2024 }
2025
2026 auto *DAG = SyncPipe[0].DAG;
2027 for (auto &Elt : *Cache) {
2028 if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
2029 return true;
2030 }
2031 return false;
2032 }
2033
2034 EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
2035 bool NeedsCache = false)
2036 : InstructionRule(TII, SGID, NeedsCache) {}
2037 };
2038
2039 // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
2040 class IsPermForDSW final : public InstructionRule {
2041 public:
2042 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2043 SmallVectorImpl<SchedGroup> &SyncPipe) override {
2044 auto *MI = SU->getInstr();
2045 if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
2046 return false;
2047
2048 bool FitsInGroup = false;
2049 // Does the VALU have a DS_WRITE successor
2050 if (!Collection.size()) {
2051 for (auto &Succ : SU->Succs) {
2052 SUnit *SuccUnit = Succ.getSUnit();
2053 if (TII->isDS(*SuccUnit->getInstr()) &&
2054 SuccUnit->getInstr()->mayStore()) {
2055 Cache->push_back(SuccUnit);
2056 FitsInGroup = true;
2057 }
2058 }
2059 return FitsInGroup;
2060 }
2061
2062 // Does the VALU have a DS_WRITE successor that is the same as other
2063 // VALU already in the group. The V_PERMs will all share 1 DS_W succ
2064 return llvm::any_of(*Cache, [&SU](SUnit *Elt) {
2065 return llvm::any_of(SU->Succs, [&Elt](const SDep &ThisSucc) {
2066 return ThisSucc.getSUnit() == Elt;
2067 });
2068 });
2069 }
2070
2071 IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
2072 : InstructionRule(TII, SGID, NeedsCache) {}
2073 };
2074
2075 // Whether the SU is a successor of any element in previous SchedGroup
2076 class IsSuccOfPrevGroup final : public InstructionRule {
2077 public:
2078 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2079 SmallVectorImpl<SchedGroup> &SyncPipe) override {
2080 SchedGroup *OtherGroup = nullptr;
2081 for (auto &PipeSG : SyncPipe) {
2082 if ((unsigned)PipeSG.getSGID() == SGID - 1) {
2083 OtherGroup = &PipeSG;
2084 }
2085 }
2086
2087 if (!OtherGroup)
2088 return false;
2089 if (!OtherGroup->Collection.size())
2090 return true;
2091
2092 // Does the previous VALU have this DS_Write as a successor
2093 return any_of(OtherGroup->Collection, [&SU](SUnit *Elt) {
2094 return any_of(Elt->Succs,
2095 [&SU](SDep &Succ) { return Succ.getSUnit() == SU; });
2096 });
2097 }
2098 IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
2099 bool NeedsCache = false)
2100 : InstructionRule(TII, SGID, NeedsCache) {}
2101 };
2102
2103 // Whether the combined load width of group is 128 bits
2104 class VMEMSize final : public InstructionRule {
2105 public:
2106 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2107 SmallVectorImpl<SchedGroup> &SyncPipe) override {
2108 auto *MI = SU->getInstr();
2109 if (MI->getOpcode() == TargetOpcode::BUNDLE)
2110 return false;
2111 if (!Collection.size())
2112 return true;
2113
2114 int NumBits = 0;
2115
2116 auto TRI = TII->getRegisterInfo();
2117 auto &MRI = MI->getMF()->getRegInfo();
2118 for (auto &Elt : Collection) {
2119 auto Op = Elt->getInstr()->getOperand(0);
2120 auto Size =
2121 TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
2122 NumBits += Size;
2123 }
2124
2125 if (NumBits < 128) {
2126 assert(TII->isVMEM(*MI) && MI->mayLoad());
2127 if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
2128 MRI, MI->getOperand(0))) <=
2129 128)
2130 return true;
2131 }
2132
2133 return false;
2134 }
2135
2136 VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
2137 : InstructionRule(TII, SGID, NeedsCache) {}
2138 };
2139
2140 /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
2141 /// that is \p Distance steps away
2142 class SharesPredWithPrevNthGroup final : public InstructionRule {
2143 private:
2144 unsigned Distance = 1;
2145
2146 public:
2147 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2148 SmallVectorImpl<SchedGroup> &SyncPipe) override {
2149 SchedGroup *OtherGroup = nullptr;
2150 if (!SyncPipe.size())
2151 return false;
2152
2153 if (!Cache->size()) {
2154
2155 for (auto &PipeSG : SyncPipe) {
2156 if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
2157 OtherGroup = &PipeSG;
2158 }
2159 }
2160
2161 if (!OtherGroup)
2162 return false;
2163 if (!OtherGroup->Collection.size())
2164 return true;
2165
2166 for (auto &OtherEle : OtherGroup->Collection) {
2167 for (auto &Pred : OtherEle->Preds) {
2168 if (Pred.getSUnit()->getInstr()->getOpcode() ==
2169 AMDGPU::V_PERM_B32_e64)
2170 Cache->push_back(Pred.getSUnit());
2171 }
2172 }
2173
2174 // If the other group has no PERM preds, then this group won't share any
2175 if (!Cache->size())
2176 return false;
2177 }
2178
2179 auto *DAG = SyncPipe[0].DAG;
2180 // Does the previous DS_WRITE share a V_PERM predecessor with this
2181 // VMEM_READ
2182 return llvm::any_of(*Cache, [&SU, &DAG](SUnit *Elt) {
2183 return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
2184 });
2185 }
2186 SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
2187 unsigned SGID, bool NeedsCache = false)
2188 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
2189 };
2190
2191public:
2192 bool applyIGLPStrategy(
2194 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2196
2197 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
2198 AMDGPU::SchedulingPhase Phase) override {
2199 return true;
2200 }
2201
2202 MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
2203 : IGLPStrategy(DAG, TII) {
2204 IsBottomUp = false;
2205 }
2206};
2207
2208static unsigned DSWCount = 0;
2209static unsigned DSWWithPermCount = 0;
2210static unsigned DSWWithSharedVMEMCount = 0;
2211
2212bool MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
2213 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2214 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2216 unsigned MFMACount = 0;
2217 unsigned DSRCount = 0;
2218
2219 bool IsInitial = Phase == AMDGPU::SchedulingPhase::Initial;
2220
2221 assert((!IsInitial || (DSWCount == 0 && DSWWithPermCount == 0 &&
2222 DSWWithSharedVMEMCount == 0)) &&
2223 "DSWCounters should be zero in pre-RA scheduling!");
2224 SmallVector<SUnit *, 6> DSWithPerms;
2225 for (auto &SU : DAG->SUnits) {
2226 auto *I = SU.getInstr();
2227 if (TII->isMFMAorWMMA(*I))
2228 ++MFMACount;
2229 else if (TII->isDS(*I)) {
2230 if (I->mayLoad())
2231 ++DSRCount;
2232 else if (I->mayStore() && IsInitial) {
2233 ++DSWCount;
2234 for (auto Pred : SU.Preds) {
2235 if (Pred.getSUnit()->getInstr()->getOpcode() ==
2236 AMDGPU::V_PERM_B32_e64) {
2237 DSWithPerms.push_back(&SU);
2238 break;
2239 }
2240 }
2241 }
2242 }
2243 }
2244
2245 if (IsInitial) {
2246 DSWWithPermCount = DSWithPerms.size();
2247 auto *I = DSWithPerms.begin();
2248 auto *E = DSWithPerms.end();
2249
2250 // Get the count of DS_WRITES with V_PERM predecessors which
2251 // have loop carried dependencies (WAR) on the same VMEM_READs.
2252 // We consider partial overlap as a miss -- in other words,
2253 // for a given DS_W, we only consider another DS_W as matching
2254 // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
2255 // for every V_PERM pred of this DS_W.
2256 DenseMap<MachineInstr *, SUnit *> VMEMLookup;
2258 for (; I != E; I++) {
2259 SUnit *Cand = nullptr;
2260 bool MissedAny = false;
2261 for (auto &Pred : (*I)->Preds) {
2262 if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
2263 continue;
2264
2265 if (Cand && llvm::is_contained(Counted, Cand))
2266 break;
2267
2268 for (auto &Succ : Pred.getSUnit()->Succs) {
2269 auto *MI = Succ.getSUnit()->getInstr();
2270 if (!TII->isVMEM(*MI) || !MI->mayLoad())
2271 continue;
2272
2273 if (MissedAny || !VMEMLookup.size()) {
2274 MissedAny = true;
2275 VMEMLookup[MI] = *I;
2276 continue;
2277 }
2278
2279 auto [It, Inserted] = VMEMLookup.try_emplace(MI, *I);
2280 if (Inserted) {
2281 MissedAny = true;
2282 continue;
2283 }
2284
2285 Cand = It->second;
2286 if (llvm::is_contained(Counted, Cand)) {
2287 MissedAny = true;
2288 break;
2289 }
2290 }
2291 }
2292 if (!MissedAny && Cand) {
2293 DSWWithSharedVMEMCount += 2;
2294 Counted.push_back(Cand);
2295 Counted.push_back(*I);
2296 }
2297 }
2298 }
2299
2300 assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
2301 SchedGroup *SG;
2302 unsigned PipelineSyncID = 0;
2303 // For kernels with V_PERM, there are enough VALU to mix in between MFMAs
2304 if (DSWWithPermCount) {
2305 for (unsigned I = 0; I < MFMACount; I++) {
2306 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2307 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2308 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2309
2310 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2311 SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
2312 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2313 }
2314 }
2315
2316 PipelineSyncID = 1;
2317 // Phase 1: Break up DS_READ and MFMA clusters.
2318 // First DS_READ to make ready initial MFMA, then interleave MFMA with DS_READ
2319 // prefetch
2320
2321 // Make ready initial MFMA
2322 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2323 SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
2324 SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
2325 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2326
2327 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2328 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2329 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2330
2331 // Interleave MFMA with DS_READ prefetch
2332 for (unsigned I = 4; I < DSRCount; ++I) {
2333 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2334 SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
2335 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2336
2337 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2338 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2339 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2340 }
2341
2342 // Phase 2a: Loop carried dependency with V_PERM
2343 // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
2344 // depend on. Interleave MFMA to keep XDL unit busy throughout.
2345 for (unsigned I = DSWWithSharedVMEMCount; I < DSWWithPermCount; ++I) {
2346 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2347 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2348 SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2349 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2350
2351 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2352 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2353 SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2354 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2355
2356 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2357 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2358 SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2359 1, TII, SG->getSGID(), true));
2360 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2361 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2362
2363 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2364 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2365 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2366
2367 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2368 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2369 SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2370 3, TII, SG->getSGID(), true));
2371 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2372 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2373
2374 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2375 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2376 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2377 }
2378
2379 // Phase 2b: Loop carried dependency without V_PERM
2380 // Schedule DS_WRITE as closely as possible to the VMEM_READ they depend on.
2381 // Interleave MFMA to keep XDL unit busy throughout.
2382 for (unsigned I = DSWWithPermCount; I < DSWCount; I++) {
2383 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2384 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2385 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2386
2387 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2388 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2389 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2390 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2391
2392 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2393 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2394 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2395 }
2396
2397 // Phase 2c: Loop carried dependency with V_PERM, VMEM_READs are
2398 // ultimately used by two DS_WRITE
2399 // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
2400 // depend on. Interleave MFMA to keep XDL unit busy throughout.
2401
2402 for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
2403 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2404 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2405 SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2406 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2407
2408 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2409 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2410 SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2411 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2412
2413 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2414 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2415 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2416
2417 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2418 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2419 SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2420 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2421
2422 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2423 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2424 SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2425 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2426
2427 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2428 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2429 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2430
2431 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2432 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2433 SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2434 2, TII, SG->getSGID(), true));
2435 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2436 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2437
2438 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2439 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2440 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2441
2442 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2443 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2444 SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2445 4, TII, SG->getSGID(), true));
2446 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2447 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2448
2449 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2450 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2451 SG->findCandidateSUnits(SyncedInstrs[SG->getSyncID()]);
2452 }
2453
2454 return true;
2455}
2456
2457static std::unique_ptr<IGLPStrategy>
2458createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
2459 const SIInstrInfo *TII) {
2460 switch (ID) {
2461 case MFMASmallGemmOptID:
2462 return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
2464 return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
2466 return std::make_unique<MFMAExpInterleaveOpt>(DAG, TII);
2468 return std::make_unique<MFMAExpSimpleInterleaveOpt>(DAG, TII);
2469 }
2470
2471 llvm_unreachable("Unknown IGLPStrategyID");
2472}
2473
2474class IGroupLPDAGMutation : public ScheduleDAGMutation {
2475private:
2476 const SIInstrInfo *TII;
2477
2478 ScheduleDAGMI *DAG;
2479
2480 // Organize lists of SchedGroups by their SyncID. SchedGroups /
2481 // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
2482 // between then.
2483 DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
2484
2485 // Used to track instructions that can be mapped to multiple sched groups
2486 DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
2487
2488 // Add DAG edges that enforce SCHED_BARRIER ordering.
2489 void addSchedBarrierEdges(SUnit &SU);
2490
2491 // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
2492 // not be reordered accross the SCHED_BARRIER. This is used for the base
2493 // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
2494 // SCHED_BARRIER will always block all instructions that can be classified
2495 // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
2496 // and may only synchronize with some SchedGroups. Returns the inverse of
2497 // Mask. SCHED_BARRIER's mask describes which instruction types should be
2498 // allowed to be scheduled across it. Invert the mask to get the
2499 // SchedGroupMask of instructions that should be barred.
2500 SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
2501
2502 // Create SchedGroups for a SCHED_GROUP_BARRIER.
2503 void initSchedGroupBarrierPipelineStage(
2504 std::vector<SUnit>::reverse_iterator RIter);
2505
2506 bool initIGLPOpt(SUnit &SU);
2507
2508public:
2509 void apply(ScheduleDAGInstrs *DAGInstrs) override;
2510
2511 // The order in which the PipelineSolver should process the candidate
2512 // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
2513 // created SchedGroup first, and will consider that as the ultimate
2514 // predecessor group when linking. TOP_DOWN instead links and processes the
2515 // first created SchedGroup first.
2516 bool IsBottomUp = true;
2517
2518 // The scheduling phase this application of IGLP corresponds with.
2519 AMDGPU::SchedulingPhase Phase = AMDGPU::SchedulingPhase::Initial;
2520
2521 IGroupLPDAGMutation() = default;
2522 IGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) : Phase(Phase) {}
2523};
2524
2525unsigned SchedGroup::NumSchedGroups = 0;
2526
2527bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
2528 return A != B && DAG->addEdge(B, SDep(A, SDep::Artificial));
2529}
2530
2531bool SchedGroup::canAddMI(const MachineInstr &MI) const {
2532 bool Result = false;
2533 if (MI.isMetaInstruction())
2534 Result = false;
2535
2536 else if (MI.isInlineAsm()) {
2537 const SIRegisterInfo &TRI = TII->getRegisterInfo();
2538 auto &MRI = MI.getParent()->getParent()->getRegInfo();
2539 bool SGPR_used = false, SGPR_big_def = false, VGPR_used = false,
2540 VMFMA_used = false, VReg32_used = false, MayLoad = MI.mayLoad(),
2541 MayStore = MI.mayStore();
2542 for (const MachineOperand &Operand : MI.operands())
2543 if (Operand.isReg()) {
2544 const TargetRegisterClass &RegClass =
2545 *TRI.getRegClassForOperandReg(MRI, Operand);
2546 if (TRI.hasVGPRs(&RegClass)) {
2547 VGPR_used = true;
2548 if (Operand.isUse() && TRI.getRegSizeInBits(RegClass) == 32)
2549 VReg32_used = true;
2550 }
2551 // > 128 bit registers are usually only used by MFMA instructions, so
2552 // we're using that as a heuristic to guess the schedule group mask of
2553 // the inline asm.
2554 if (TRI.hasAGPRs(&RegClass) || TRI.getRegSizeInBits(RegClass) > 128)
2555 VMFMA_used = true;
2556 if (TRI.hasSGPRs(&RegClass))
2557 SGPR_used = true;
2558 if (TRI.getRegSizeInBits(RegClass) > 64 && Operand.isDef())
2559 SGPR_big_def = true;
2560 }
2561
2562 typedef std::underlying_type_t<SchedGroupMask> SGMask_t;
2563 SGMask_t InlineAsmMask = 0;
2564 if (VGPR_used && !VMFMA_used && !MayLoad && !MayStore)
2565 InlineAsmMask |= (SGMask_t)SchedGroupMask::VALU;
2566 if (SGPR_used && !VGPR_used && !MayLoad && !MayStore)
2567 InlineAsmMask |= (SGMask_t)SchedGroupMask::SALU;
2568 if (VMFMA_used)
2569 InlineAsmMask |= (SGMask_t)SchedGroupMask::MFMA;
2570 if (VGPR_used && MayLoad)
2571 InlineAsmMask |= (SGMask_t)(VReg32_used ? SchedGroupMask::DS_READ
2572 : SchedGroupMask::VMEM_READ);
2573 if (VGPR_used && MayStore)
2574 InlineAsmMask |= (SGMask_t)(VReg32_used ? SchedGroupMask::DS_WRITE
2575 : SchedGroupMask::VMEM_WRITE);
2576 if (SGPR_big_def)
2577 InlineAsmMask |= (SGMask_t)SchedGroupMask::DS_READ;
2578 if (InlineAsmMask & (SGMask_t)SchedGroupMask::VALU ||
2579 InlineAsmMask & (SGMask_t)SchedGroupMask::SALU)
2580 InlineAsmMask |= (SGMask_t)SchedGroupMask::ALU;
2581 if (InlineAsmMask & (SGMask_t)SchedGroupMask::DS_READ ||
2582 InlineAsmMask & (SGMask_t)SchedGroupMask::DS_WRITE)
2583 InlineAsmMask |= (SGMask_t)SchedGroupMask::DS;
2584 if (InlineAsmMask & (SGMask_t)SchedGroupMask::VMEM_READ ||
2585 InlineAsmMask & (SGMask_t)SchedGroupMask::VMEM_WRITE)
2586 InlineAsmMask |= (SGMask_t)SchedGroupMask::VMEM;
2587
2588 Result = ((SGMask_t)SGMask & InlineAsmMask) != 0;
2589 }
2590
2591 else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
2592 (TII->isVALU(MI, /*AllowLDSDMA=*/true) || TII->isMFMAorWMMA(MI) ||
2593 TII->isSALU(MI) || TII->isTRANS(MI)))
2594 Result = !MI.mayLoadOrStore();
2595
2596 else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
2597 TII->isVALU(MI, /*AllowLDSDMA=*/true) && !TII->isMFMAorWMMA(MI) &&
2598 !TII->isTRANS(MI) && !TII->isLDSDMA(MI)) {
2599 // Some memory instructions may be marked as VALU (e.g. BUFFER_LOAD_*_LDS).
2600 // For our purposes, these shall not be classified as VALU as this results
2601 // in unexpected behavior.
2602 Result = !MI.mayLoadOrStore();
2603 }
2604
2605 else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
2606 TII->isSALU(MI))
2607 Result = !MI.mayLoadOrStore();
2608
2609 else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
2610 TII->isMFMAorWMMA(MI))
2611 Result = true;
2612
2613 else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
2614 (TII->isVMEM(MI) || TII->isLDSDMA(MI)))
2615 Result = true;
2616
2617 else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
2618 MI.mayLoad() && TII->isVMEM(MI))
2619 Result = true;
2620
2621 else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
2622 MI.mayStore() && TII->isVMEM(MI))
2623 Result = true;
2624
2625 else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
2626 (TII->isDS(MI) || TII->isLDSDMA(MI)))
2627 Result = true;
2628
2629 else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
2630 MI.mayLoad() && TII->isDS(MI))
2631 Result = true;
2632
2633 else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
2634 MI.mayStore() && TII->isDS(MI))
2635 Result = true;
2636
2637 else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
2638 TII->isTRANS(MI))
2639 Result = true;
2640
2641 else if (((SGMask & SchedGroupMask::LDSDMA) != SchedGroupMask::NONE) &&
2642 TII->isLDSDMA(MI))
2643 Result = true;
2644
2645 LLVM_DEBUG(
2646 dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
2647 << (Result ? " could classify " : " unable to classify ") << MI);
2648
2649 return Result;
2650}
2651
2652int SchedGroup::link(SUnit &SU, bool MakePred,
2653 std::list<std::pair<SUnit *, SUnit *>> &AddedEdges) {
2654 int MissedEdges = 0;
2655 for (auto *A : Collection) {
2656 SUnit *B = &SU;
2657 if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2658 continue;
2659 if (MakePred)
2660 std::swap(A, B);
2661
2662 if (DAG->IsReachable(B, A))
2663 continue;
2664
2665 // tryAddEdge returns false if there is a dependency that makes adding
2666 // the A->B edge impossible, otherwise it returns true;
2667 bool Added = tryAddEdge(A, B);
2668 if (Added)
2669 AddedEdges.emplace_back(A, B);
2670 else
2671 ++MissedEdges;
2672 }
2673
2674 return MissedEdges;
2675}
2676
2677void SchedGroup::link(SUnit &SU, bool MakePred) {
2678 for (auto *A : Collection) {
2679 SUnit *B = &SU;
2680 if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2681 continue;
2682 if (MakePred)
2683 std::swap(A, B);
2684
2685 tryAddEdge(A, B);
2686 }
2687}
2688
2689void SchedGroup::link(SUnit &SU,
2690 function_ref<bool(const SUnit *A, const SUnit *B)> P) {
2691 for (auto *A : Collection) {
2692 SUnit *B = &SU;
2693 if (P(A, B))
2694 std::swap(A, B);
2695
2696 tryAddEdge(A, B);
2697 }
2698}
2699
2700void SchedGroup::link(SchedGroup &OtherGroup) {
2701 for (auto *B : OtherGroup.Collection)
2702 link(*B);
2703}
2704
2705bool SchedGroup::canAddSU(SUnit &SU) const {
2706 MachineInstr &MI = *SU.getInstr();
2707 if (MI.getOpcode() != TargetOpcode::BUNDLE)
2708 return canAddMI(MI);
2709
2710 // Special case for bundled MIs.
2711 const MachineBasicBlock *MBB = MI.getParent();
2712 MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
2713 while (E != MBB->end() && E->isBundledWithPred())
2714 ++E;
2715
2716 // Return true if all of the bundled MIs can be added to this group.
2717 return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
2718}
2719
2720template <class T>
2721void SchedGroup::findCandidateSUnits(T Begin, T End,
2722 SUnitsToCandidateSGsMap &SyncedInstrs) {
2723 for (SUnit &SU : make_range(Begin, End)) {
2724 if (canAddSU(SU))
2725 SyncedInstrs[&SU].push_back(SGID);
2726 }
2727}
2728
2729void SchedGroup::findCandidateSUnits(SUnitsToCandidateSGsMap &SyncedInstrs) {
2730 findCandidateSUnits(DAG->SUnits.rbegin(), DAG->SUnits.rend(), SyncedInstrs);
2731}
2732
2733void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
2734 const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
2735 if (!TSchedModel || DAGInstrs->SUnits.empty())
2736 return;
2737
2738 LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
2739 const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
2740 TII = ST.getInstrInfo();
2741 DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
2742 SyncedSchedGroups.clear();
2743 SyncedInstrs.clear();
2744 bool FoundSB = false;
2745 bool FoundIGLP = false;
2746 bool ShouldApplyIGLP = false;
2747 for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
2748 unsigned Opc = R->getInstr()->getOpcode();
2749 // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
2750 if (Opc == AMDGPU::SCHED_BARRIER) {
2751 addSchedBarrierEdges(*R);
2752 FoundSB = true;
2753 } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
2754 initSchedGroupBarrierPipelineStage(R);
2755 FoundSB = true;
2756 } else if (Opc == AMDGPU::IGLP_OPT) {
2757 if (!FoundSB && !FoundIGLP) {
2758 FoundIGLP = true;
2759 ShouldApplyIGLP = initIGLPOpt(*R);
2760 }
2761 }
2762 }
2763
2764 if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {
2765 PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
2766 // PipelineSolver performs the mutation by adding the edges it
2767 // determined as the best
2768 PS.solve();
2769 return;
2770 }
2771}
2772
2773void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
2774 MachineInstr &MI = *SchedBarrier.getInstr();
2775 assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
2776 LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "
2777 << MI.getOperand(0).getImm() << "\n");
2778 auto InvertedMask =
2779 invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
2780 SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
2781
2782 for (SUnit &SU : DAG->SUnits)
2783 if (SG.canAddSU(SU))
2784 SG.add(SU);
2785
2786 // Preserve original instruction ordering relative to the SCHED_BARRIER.
2787 SG.link(
2788 SchedBarrier,
2789 (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
2790 const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
2791}
2792
2793SchedGroupMask
2794IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
2795 // Invert mask and erase bits for types of instructions that are implied to be
2796 // allowed past the SCHED_BARRIER.
2797 SchedGroupMask InvertedMask = ~Mask;
2798
2799 // ALU implies VALU, SALU, MFMA, TRANS.
2800 if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
2801 InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
2802 ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;
2803 // VALU, SALU, MFMA, TRANS implies ALU.
2804 else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
2805 (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
2806 (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
2807 (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
2808 InvertedMask &= ~SchedGroupMask::ALU;
2809
2810 // VMEM implies VMEM_READ, VMEM_WRITE, LDSDMA.
2811 if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
2812 InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE &
2813 ~SchedGroupMask::LDSDMA;
2814 // VMEM_READ, VMEM_WRITE, LDSDMA implies VMEM.
2815 else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
2816 (InvertedMask & SchedGroupMask::VMEM_WRITE) ==
2817 SchedGroupMask::NONE ||
2818 (InvertedMask & SchedGroupMask::LDSDMA) == SchedGroupMask::NONE)
2819 InvertedMask &= ~SchedGroupMask::VMEM;
2820
2821 // DS implies DS_READ, DS_WRITE, LDSDMA.
2822 if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
2823 InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE &
2824 ~SchedGroupMask::LDSDMA;
2825 // DS_READ, DS_WRITE implies DS.
2826 else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
2827 (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
2828 InvertedMask &= ~SchedGroupMask::DS;
2829
2830 LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask
2831 << "\n");
2832
2833 return InvertedMask;
2834}
2835
2836void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
2837 std::vector<SUnit>::reverse_iterator RIter) {
2838 MachineInstr &SGB = *RIter->getInstr();
2839 assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
2840 int32_t SGMask = SGB.getOperand(0).getImm();
2841 int32_t Size = SGB.getOperand(1).getImm();
2842 int32_t SyncID = SGB.getOperand(2).getImm();
2843
2844 Size++; // Make room for the SCHED_GROUP_BARRIER instruction
2845 auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
2846 Size, SyncID, DAG, TII);
2847 SG.add(*RIter);
2848 SG.findCandidateSUnits(RIter, SG.DAG->SUnits.rend(),
2849 SyncedInstrs[SG.getSyncID()]);
2850}
2851
2852bool IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
2853 IGLPStrategyID StrategyID =
2855 auto S = createIGLPStrategy(StrategyID, DAG, TII);
2856 if (!S->shouldApplyStrategy(DAG, Phase))
2857 return false;
2858
2859 IsBottomUp = S->IsBottomUp;
2860 return S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, Phase);
2861}
2862
2863} // namespace
2864
2865/// \p Phase specifes whether or not this is a reentry into the
2866/// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
2867/// same scheduling region (e.g. pre and post-RA scheduling / multiple
2868/// scheduling "phases"), we can reenter this mutation framework more than once
2869/// for a given region.
2870std::unique_ptr<ScheduleDAGMutation>
2872 return std::make_unique<IGroupLPDAGMutation>(Phase);
2873}
aarch64 falkor hwpf fix Falkor HW Prefetch Fix Late Phase
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Provides AMDGPU specific target descriptions.
AMDGPU Rewrite AGPR Copy MFMA
MachineBasicBlock & MBB
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
This file defines the DenseMap class.
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
static std::pair< Value *, APInt > getMask(Value *WideMask, unsigned Factor, ElementCount LeafValueEC)
#define I(x, y, z)
Definition MD5.cpp:57
Register const TargetRegisterInfo * TRI
#define T
#define P(N)
Interface definition for SIInstrInfo.
#define LLVM_DEBUG(...)
Definition Debug.h:119
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
size_t size() const
Get the array size.
Definition ArrayRef.h:141
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:299
unsigned size() const
Definition DenseMap.h:172
Implements a dense probed hash-table based set.
Definition DenseSet.h:281
const HexagonRegisterInfo & getRegisterInfo() const
Instructions::iterator instr_iterator
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
bool mayStore(QueryType Type=AnyInBundle) const
Return true if this instruction could possibly modify memory.
const MachineOperand & getOperand(unsigned i) const
int64_t getImm() const
Scheduling dependency.
Definition ScheduleDAG.h:52
SUnit * getSUnit() const
@ Data
Regular data dependence (aka true-dependence).
Definition ScheduleDAG.h:56
@ Artificial
Arbitrary strong DAG edge (no real dependence).
Definition ScheduleDAG.h:75
Scheduling unit. This is a node in the scheduling DAG.
unsigned NodeNum
Entry # of node in the node vector.
LLVM_ABI void removePred(const SDep &D)
Removes the specified edge as a pred of the current node if it exists.
SmallVector< SDep, 4 > Succs
All sunit successors.
SmallVector< SDep, 4 > Preds
All sunit predecessors.
MachineInstr * getInstr() const
Returns the representative MachineInstr for this SUnit.
A ScheduleDAG for scheduling lists of MachineInstr.
const TargetSchedModel * getSchedModel() const
Gets the machine model for instruction scheduling.
bool addEdge(SUnit *SuccSU, const SDep &PredDep)
Add a DAG edge to the given SU with the given predecessor dependence data.
bool IsReachable(SUnit *SU, SUnit *TargetSU)
IsReachable - Checks if SU is reachable from TargetSU.
void dump() const override
ScheduleDAGMI is an implementation of ScheduleDAGInstrs that simply schedules machine instructions ac...
std::vector< SUnit > SUnits
The scheduling units.
MachineFunction & MF
Machine function.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool contains(const_arg_type_t< ValueT > V) const
Check if the set contains the given element.
Definition DenseSet.h:182
An efficient, type-erasing, non-owning reference to a callable.
A range adaptor for a pair of iterators.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
IGLPStrategyID
Operand 0 immediate for IGLP_OPT pseudo instructions.
@ MFMASmallGemmSingleWaveOptID
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
void apply(Opt *O, const Mod &M, const Mods &... Ms)
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1669
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
std::unique_ptr< ScheduleDAGMutation > createIGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase)
Phase specifies whether or not this is a reentry into the IGroupLPDAGMutation.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1746
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
bool none_of(R &&Range, UnaryPredicate P)
Provide wrappers to std::none_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1753
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
FormattedNumber format_hex(uint64_t N, unsigned Width, bool Upper=false)
format_hex - Output N as a fixed width hexadecimal.
Definition Format.h:156
DWARFExpression::Operation Op
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2019
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1772
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
MCRegisterClass TargetRegisterClass
Definition FastISel.h:58
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:862
Function object to check whether the second component of a container supported by std::get (like std:...
Definition STLExtras.h:1448