LLVM 19.0.0git
RISCVFoldMasks.cpp
Go to the documentation of this file.
1//===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
9// This pass performs various peephole optimisations that fold masks into vector
10// pseudo instructions after instruction selection.
11//
12// Currently it converts
13// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14// ->
15// PseudoVMV_V_V %false, %true, %vl, %sew
16//
17//===---------------------------------------------------------------------===//
18
19#include "RISCV.h"
20#include "RISCVISelDAGToDAG.h"
21#include "RISCVSubtarget.h"
26
27using namespace llvm;
28
29#define DEBUG_TYPE "riscv-fold-masks"
30
31namespace {
32
/// Machine-function pass that folds all-ones V0 masks into RISC-V vector
/// pseudos: masked pseudos are rewritten to their unmasked forms
/// (convertToUnmasked), and PseudoVMERGE_VVM whose merge operand is undef or
/// equal to its false operand is rewritten to PseudoVMV_V_V
/// (convertVMergeToVMv).
33class RISCVFoldMasks : public MachineFunctionPass {
34public:
35 static char ID;
 // Set from the current function's subtarget in runOnMachineFunction.
36 const TargetInstrInfo *TII;
39 RISCVFoldMasks() : MachineFunctionPass(ID) {}
40
41 bool runOnMachineFunction(MachineFunction &MF) override;
44 MachineFunctionProperties::Property::IsSSA);
45 }
46
47 StringRef getPassName() const override { return "RISC-V Fold Masks"; }
48
49private:
 // Rewrite a masked pseudo into its unmasked counterpart when its V0 mask is
 // all ones. Returns true if MI was changed.
50 bool convertToUnmasked(MachineInstr &MI) const;
 // Rewrite an all-ones-mask PseudoVMERGE_VVM_<LMUL> whose merge operand is
 // undef or matches its false operand into PseudoVMV_V_V_<LMUL>. Returns true
 // if MI was changed.
51 bool convertVMergeToVMv(MachineInstr &MI) const;
52
 // Returns true if MaskDef (a COPY into V0) has a source that, looked through
 // copy-like instructions, is defined by a PseudoVMSET_M_B* pseudo.
53 bool isAllOnesMask(const MachineInstr *MaskDef) const;
54
55 /// Maps uses of V0 to the corresponding def of V0.
57};
58
59} // namespace
60
61char RISCVFoldMasks::ID = 0;
62
63INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
64
65bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
66 assert(MaskDef && MaskDef->isCopy() &&
67 MaskDef->getOperand(0).getReg() == RISCV::V0);
68 Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
69 if (!SrcReg.isVirtual())
70 return false;
71 MaskDef = MRI->getVRegDef(SrcReg);
72 if (!MaskDef)
73 return false;
74
75 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
76 // undefined behaviour if it's the wrong bitwidth, so we could choose to
77 // assume that it's all-ones? Same applies to its VL.
78 switch (MaskDef->getOpcode()) {
79 case RISCV::PseudoVMSET_M_B1:
80 case RISCV::PseudoVMSET_M_B2:
81 case RISCV::PseudoVMSET_M_B4:
82 case RISCV::PseudoVMSET_M_B8:
83 case RISCV::PseudoVMSET_M_B16:
84 case RISCV::PseudoVMSET_M_B32:
85 case RISCV::PseudoVMSET_M_B64:
86 return true;
87 default:
88 return false;
89 }
90}
91
92// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
93// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
// Rewrites MI in place and returns true when its opcode matches one of the
// CASE_VMERGE_TO_VMV cases (PseudoVMERGE_VVM_<LMUL>), its merge operand is
// NoRegister or resolves (through copy-like instructions) to the same register
// as its false operand, and its V0 mask is all ones. Otherwise returns false
// without modifying MI.
94bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
95#define CASE_VMERGE_TO_VMV(lmul) \
96 case RISCV::PseudoVMERGE_VVM_##lmul: \
97 NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
98 break;
 // NewOpc is assigned by the CASE_VMERGE_TO_VMV cases of the switch below;
 // the default case returns before NewOpc is ever read.
99 unsigned NewOpc;
100 switch (MI.getOpcode()) {
101 default:
102 return false;
110 }
111
112 Register MergeReg = MI.getOperand(1).getReg();
113 Register FalseReg = MI.getOperand(2).getReg();
114 // Check merge == false (or merge == undef)
115 if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) !=
116 TRI->lookThruCopyLike(FalseReg, MRI))
117 return false;
118
119 assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
120 if (!isAllOnesMask(V0Defs.lookup(&MI)))
121 return false;
122
 // Rewrite in place. Removing the merge operand (index 1) shifts every later
 // operand down by one: the false operand becomes index 1 and is tied to the
 // def at index 0, and the V0 mask (asserted above to be at index 4) becomes
 // index 3 and is then removed.
 // NOTE(review): confirm which immediate addOperand appends (presumably the
 // vmv.v.v policy operand).
123 MI.setDesc(TII->get(NewOpc));
124 MI.removeOperand(1); // Merge operand
125 MI.tieOperands(0, 1); // Tie false to dest
126 MI.removeOperand(3); // Mask operand
127 MI.addOperand(
129
130 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
131 // register class for the destination and merge operands e.g. VRNoV0 -> VR
132 MRI->recomputeRegClass(MI.getOperand(0).getReg());
133 MRI->recomputeRegClass(MI.getOperand(1).getReg());
134 return true;
135}
136
// Replace a masked pseudo with its unmasked counterpart when its V0 mask is
// all ones. Returns false without modifying MI if MI's opcode has no entry in
// the masked-pseudo table (getMaskedPseudoInfo) or the mask is not all ones;
// otherwise rewrites MI in place (new descriptor, mask operand removed,
// passthru operand removed if the unmasked form lacks one) and returns true.
137bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
139 RISCV::getMaskedPseudoInfo(MI.getOpcode());
140 if (!I)
141 return false;
142
143 if (!isAllOnesMask(V0Defs.lookup(&MI)))
144 return false;
145
146 // There are two classes of pseudos in the table - compares and
147 // everything else. See the comment on RISCVMaskedPseudo for details.
148 const unsigned Opc = I->UnmaskedPseudo;
149 const MCInstrDesc &MCID = TII->get(Opc);
150 [[maybe_unused]] const bool HasPolicyOp =
152 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
153#ifndef NDEBUG
154 const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
157 "Masked and unmasked pseudos are inconsistent");
158 assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
159#endif
 // NOTE(review): redundant with the [[maybe_unused]] on HasPolicyOp; harmless.
160 (void)HasPolicyOp;
161
162 MI.setDesc(MCID);
163
164 // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
165 unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
166 MI.removeOperand(MaskOpIdx);
167
168 // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
169 // so try and relax it to vr.
170 MRI->recomputeRegClass(MI.getOperand(0).getReg());
 // The passthru slot is the first operand after the explicit defs. This
 // assumes the mask operand came after it, so removing the mask above did not
 // shift this index -- confirm against the table's MaskOpIdx values. When the
 // unmasked form has no passthru, the masked form's operand in that slot is
 // dropped.
171 unsigned PassthruOpIdx = MI.getNumExplicitDefs();
172 if (HasPassthru) {
173 if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
174 MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
175 } else
176 MI.removeOperand(PassthruOpIdx);
177
178 return true;
179}
180
181bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
182 if (skipFunction(MF.getFunction()))
183 return false;
184
185 // Skip if the vector extension is not enabled.
187 if (!ST.hasVInstructions())
188 return false;
189
190 TII = ST.getInstrInfo();
191 MRI = &MF.getRegInfo();
192 TRI = MRI->getTargetRegisterInfo();
193
194 bool Changed = false;
195
196 // Masked pseudos coming out of isel will have their mask operand in the form:
197 //
198 // $v0:vr = COPY %mask:vr
199 // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
200 //
201 // Because $v0 isn't in SSA, keep track of its definition at each use so we
202 // can check mask operands.
203 for (const MachineBasicBlock &MBB : MF) {
204 const MachineInstr *CurrentV0Def = nullptr;
205 for (const MachineInstr &MI : MBB) {
206 if (MI.readsRegister(RISCV::V0, TRI))
207 V0Defs[&MI] = CurrentV0Def;
208
209 if (MI.definesRegister(RISCV::V0, TRI))
210 CurrentV0Def = &MI;
211 }
212 }
213
214 for (MachineBasicBlock &MBB : MF) {
215 for (MachineInstr &MI : MBB) {
216 Changed |= convertToUnmasked(MI);
217 Changed |= convertVMergeToVMv(MI);
218 }
219 }
220
221 return Changed;
222}
223
224FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); }
unsigned const MachineRegisterInfo * MRI
aarch64 promote const
MachineBasicBlock & MBB
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition: MD5.cpp:58
unsigned const TargetRegisterInfo * TRI
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
#define CASE_VMERGE_TO_VMV(lmul)
#define DEBUG_TYPE
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
Describe properties that are true of each instruction in the target description file.
Definition: MCInstrDesc.h:198
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
virtual MachineFunctionProperties getRequiredProperties() const
Properties which a MachineFunction may have at a given point in time.
MachineFunctionProperties & set(Property P)
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
Representation of each machine instruction.
Definition: MachineInstr.h:69
static MachineOperand CreateImm(int64_t Val)
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
Definition: Register.h:91
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
TargetInstrInfo - Interface to description of machine instruction set.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ TAIL_UNDISTURBED_MASK_UNDISTURBED
static bool hasVecPolicyOp(uint64_t TSFlags)
static bool isFirstDefTiedToFirstUse(const MCInstrDesc &Desc)
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
unsigned M1(unsigned Val)
Definition: VE.h:376
FunctionPass * createRISCVFoldMasksPass()