LLVM 17.0.0git
RISCVSExtWRemoval.cpp
Go to the documentation of this file.
1//===-------------- RISCVSExtWRemoval.cpp - MI sext.w Removal -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
9// This pass removes unneeded sext.w instructions at the MI level. Either
10// because the sign extended bits aren't consumed or because the input was
11// already sign extended by an earlier instruction.
12//
13//===---------------------------------------------------------------------===//
14
#include "RISCV.h"
#include "RISCVMachineFunctionInfo.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
21
22using namespace llvm;
23
24#define DEBUG_TYPE "riscv-sextw-removal"
25
26STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
27STATISTIC(NumTransformedToWInstrs,
28 "Number of instructions transformed to W-ops");
29
30static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
31 cl::desc("Disable removal of sext.w"),
32 cl::init(false), cl::Hidden);
33namespace {
34
35class RISCVSExtWRemoval : public MachineFunctionPass {
36public:
37 static char ID;
38
39 RISCVSExtWRemoval() : MachineFunctionPass(ID) {
41 }
42
43 bool runOnMachineFunction(MachineFunction &MF) override;
44
45 void getAnalysisUsage(AnalysisUsage &AU) const override {
46 AU.setPreservesCFG();
48 }
49
50 StringRef getPassName() const override { return "RISCV sext.w Removal"; }
51};
52
53} // end anonymous namespace
54
55char RISCVSExtWRemoval::ID = 0;
56INITIALIZE_PASS(RISCVSExtWRemoval, DEBUG_TYPE, "RISCV sext.w Removal", false,
57 false)
58
60 return new RISCVSExtWRemoval();
61}
62
63// This function returns true if the machine instruction always outputs a value
64// where bits 63:32 match bit 31.
66 const MachineRegisterInfo &MRI) {
67 uint64_t TSFlags = MI.getDesc().TSFlags;
68
69 // Instructions that can be determined from opcode are marked in tablegen.
71 return true;
72
73 // Special cases that require checking operands.
74 switch (MI.getOpcode()) {
75 // shifting right sufficiently makes the value 32-bit sign-extended
76 case RISCV::SRAI:
77 return MI.getOperand(2).getImm() >= 32;
78 case RISCV::SRLI:
79 return MI.getOperand(2).getImm() > 32;
80 // The LI pattern ADDI rd, X0, imm is sign extended.
81 case RISCV::ADDI:
82 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
83 // An ANDI with an 11 bit immediate will zero bits 63:11.
84 case RISCV::ANDI:
85 return isUInt<11>(MI.getOperand(2).getImm());
86 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
87 case RISCV::ORI:
88 return !isUInt<11>(MI.getOperand(2).getImm());
89 // Copying from X0 produces zero.
90 case RISCV::COPY:
91 return MI.getOperand(1).getReg() == RISCV::X0;
92 }
93
94 return false;
95}
96
98 const RISCVInstrInfo &TII,
100
103
104 auto AddRegDefToWorkList = [&](Register SrcReg) {
105 if (!SrcReg.isVirtual())
106 return false;
107 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
108 if (!SrcMI)
109 return false;
110 // Add SrcMI to the worklist.
111 Worklist.push_back(SrcMI);
112 return true;
113 };
114
115 if (!AddRegDefToWorkList(SrcReg))
116 return false;
117
118 while (!Worklist.empty()) {
119 MachineInstr *MI = Worklist.pop_back_val();
120
121 // If we already visited this instruction, we don't need to check it again.
122 if (!Visited.insert(MI).second)
123 continue;
124
125 // If this is a sign extending operation we don't need to look any further.
127 continue;
128
129 // Is this an instruction that propagates sign extend?
130 switch (MI->getOpcode()) {
131 default:
132 // Unknown opcode, give up.
133 return false;
134 case RISCV::COPY: {
135 const MachineFunction *MF = MI->getMF();
136 const RISCVMachineFunctionInfo *RVFI =
138
139 // If this is the entry block and the register is livein, see if we know
140 // it is sign extended.
141 if (MI->getParent() == &MF->front()) {
142 Register VReg = MI->getOperand(0).getReg();
143 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
144 continue;
145 }
146
147 Register CopySrcReg = MI->getOperand(1).getReg();
148 if (CopySrcReg == RISCV::X10) {
149 // For a method return value, we check the ZExt/SExt flags in attribute.
150 // We assume the following code sequence for method call.
151 // PseudoCALL @bar, ...
152 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
153 // %0:gpr = COPY $x10
154 //
155 // We use the PseudoCall to look up the IR function being called to find
156 // its return attributes.
157 const MachineBasicBlock *MBB = MI->getParent();
158 auto II = MI->getIterator();
159 if (II == MBB->instr_begin() ||
160 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
161 return false;
162
163 const MachineInstr &CallMI = *(--II);
164 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
165 return false;
166
167 auto *CalleeFn =
168 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
169 if (!CalleeFn)
170 return false;
171
172 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
173 if (!IntTy)
174 return false;
175
176 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
177 unsigned BitWidth = IntTy->getBitWidth();
178 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
179 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
180 continue;
181 }
182
183 if (!AddRegDefToWorkList(CopySrcReg))
184 return false;
185
186 break;
187 }
188
189 // For these, we just need to check if the 1st operand is sign extended.
190 case RISCV::BCLRI:
191 case RISCV::BINVI:
192 case RISCV::BSETI:
193 if (MI->getOperand(2).getImm() >= 31)
194 return false;
195 [[fallthrough]];
196 case RISCV::REM:
197 case RISCV::ANDI:
198 case RISCV::ORI:
199 case RISCV::XORI:
200 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
201 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
202 // Logical operations use a sign extended 12-bit immediate.
203 if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
204 return false;
205
206 break;
207 case RISCV::PseudoCCADDW:
208 case RISCV::PseudoCCSUBW:
209 // Returns operand 4 or an ADDW/SUBW of operands 5 and 6. We only need to
210 // check if operand 4 is sign extended.
211 if (!AddRegDefToWorkList(MI->getOperand(4).getReg()))
212 return false;
213 break;
214 case RISCV::REMU:
215 case RISCV::AND:
216 case RISCV::OR:
217 case RISCV::XOR:
218 case RISCV::ANDN:
219 case RISCV::ORN:
220 case RISCV::XNOR:
221 case RISCV::MAX:
222 case RISCV::MAXU:
223 case RISCV::MIN:
224 case RISCV::MINU:
225 case RISCV::PseudoCCMOVGPR:
226 case RISCV::PseudoCCAND:
227 case RISCV::PseudoCCOR:
228 case RISCV::PseudoCCXOR:
229 case RISCV::PHI: {
230 // If all incoming values are sign-extended, the output of AND, OR, XOR,
231 // MIN, MAX, or PHI is also sign-extended.
232
233 // The input registers for PHI are operand 1, 3, ...
234 // The input registers for PseudoCCMOVGPR are 4 and 5.
235 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
236 // The input registers for others are operand 1 and 2.
237 unsigned B = 1, E = 3, D = 1;
238 switch (MI->getOpcode()) {
239 case RISCV::PHI:
240 E = MI->getNumOperands();
241 D = 2;
242 break;
243 case RISCV::PseudoCCMOVGPR:
244 B = 4;
245 E = 6;
246 break;
247 case RISCV::PseudoCCAND:
248 case RISCV::PseudoCCOR:
249 case RISCV::PseudoCCXOR:
250 B = 4;
251 E = 7;
252 break;
253 }
254
255 for (unsigned I = B; I != E; I += D) {
256 if (!MI->getOperand(I).isReg())
257 return false;
258
259 if (!AddRegDefToWorkList(MI->getOperand(I).getReg()))
260 return false;
261 }
262
263 break;
264 }
265
266 case RISCV::VT_MASKC:
267 case RISCV::VT_MASKCN:
268 // Instructions return zero or operand 1. Result is sign extended if
269 // operand 1 is sign extended.
270 if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
271 return false;
272 break;
273
274 // With these opcode, we can "fix" them with the W-version
275 // if we know all users of the result only rely on bits 31:0
276 case RISCV::SLLI:
277 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
278 if (MI->getOperand(2).getImm() >= 32)
279 return false;
280 [[fallthrough]];
281 case RISCV::ADDI:
282 case RISCV::ADD:
283 case RISCV::LD:
284 case RISCV::LWU:
285 case RISCV::MUL:
286 case RISCV::SUB:
287 if (TII.hasAllWUsers(*MI, MRI)) {
288 FixableDef.insert(MI);
289 break;
290 }
291 return false;
292 }
293 }
294
295 // If we get here, then every node we visited produces a sign extended value
296 // or propagated sign extended values. So the result must be sign extended.
297 return true;
298}
299
300static unsigned getWOp(unsigned Opcode) {
301 switch (Opcode) {
302 case RISCV::ADDI:
303 return RISCV::ADDIW;
304 case RISCV::ADD:
305 return RISCV::ADDW;
306 case RISCV::LD:
307 case RISCV::LWU:
308 return RISCV::LW;
309 case RISCV::MUL:
310 return RISCV::MULW;
311 case RISCV::SLLI:
312 return RISCV::SLLIW;
313 case RISCV::SUB:
314 return RISCV::SUBW;
315 default:
316 llvm_unreachable("Unexpected opcode for replacement with W variant");
317 }
318}
319
320bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) {
321 if (skipFunction(MF.getFunction()) || DisableSExtWRemoval)
322 return false;
323
326 const RISCVInstrInfo &TII = *ST.getInstrInfo();
327
328 if (!ST.is64Bit())
329 return false;
330
331 bool MadeChange = false;
332
333 for (MachineBasicBlock &MBB : MF) {
334 for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) {
335 MachineInstr *MI = &*I++;
336
337 // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
338 if (!RISCV::isSEXT_W(*MI))
339 continue;
340
341 Register SrcReg = MI->getOperand(1).getReg();
342
344
345 // If all users only use the lower bits, this sext.w is redundant.
346 // Or if all definitions reaching MI sign-extend their output,
347 // then sext.w is redundant.
348 if (!TII.hasAllWUsers(*MI, MRI) &&
349 !isSignExtendedW(SrcReg, MRI, TII, FixableDefs))
350 continue;
351
352 Register DstReg = MI->getOperand(0).getReg();
353 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
354 continue;
355
356 // Convert Fixable instructions to their W versions.
357 for (MachineInstr *Fixable : FixableDefs) {
358 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
359 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
360 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
361 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
362 Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
363 LLVM_DEBUG(dbgs() << " with " << *Fixable);
364 ++NumTransformedToWInstrs;
365 }
366
367 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
368 MRI.replaceRegWith(DstReg, SrcReg);
369 MRI.clearKillFlags(SrcReg);
370 MI->eraseFromParent();
371 ++NumRemovedSExtW;
372 MadeChange = true;
373 }
374 }
375
376 return MadeChange;
377}
unsigned const MachineRegisterInfo * MRI
MachineBasicBlock & MBB
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define LLVM_DEBUG(X)
Definition: Debug.h:101
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition: MD5.cpp:58
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
uint64_t TSFlags
static bool isSignExtendingOpW(const MachineInstr &MI, const MachineRegisterInfo &MRI)
static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI, const RISCVInstrInfo &TII, SmallPtrSetImpl< MachineInstr * > &FixableDef)
static cl::opt< bool > DisableSExtWRemoval("riscv-disable-sextw-removal", cl::desc("Disable removal of sext.w"), cl::init(false), cl::Hidden)
#define DEBUG_TYPE
static unsigned getWOp(unsigned Opcode)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Represent the analysis usage information of a pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:265
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:308
instr_iterator instr_begin()
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
Ty * getInfo()
getInfo - Keep track of various per-function pieces of information for backends that would like to do...
const MachineBasicBlock & front() const
Representation of each machine instruction.
Definition: MachineInstr.h:68
bool isCall(QueryType Type=AnyInBundle) const
Definition: MachineInstr.h:872
const MachineOperand & getOperand(unsigned i) const
Definition: MachineInstr.h:526
const GlobalValue * getGlobal() const
bool isGlobal() const
isGlobal - Tests if this is a MO_GlobalAddress operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
bool isLiveIn(Register Reg) const
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
RISCVMachineFunctionInfo - This class is derived from MachineFunctionInfo and contains private RISCV-...
bool isSExt32Register(Register Reg) const
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
Definition: Register.h:91
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:344
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
bool empty() const
Definition: SmallVector.h:94
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
bool isSEXT_W(const MachineInstr &MI)
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:445
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void initializeRISCVSExtWRemovalPass(PassRegistry &)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:184
FunctionPass * createRISCVSExtWRemovalPass()