LLVM 20.0.0git
RISCVOptWInstrs.cpp
Go to the documentation of this file.
1//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
9// This pass does some optimizations for *W instructions at the MI level.
10//
11// First it removes unneeded sext.w instructions. Either because the sign
12// extended bits aren't consumed or because the input was already sign extended
13// by an earlier instruction.
14//
15// Then:
16// 1. Unless explicit disabled or the target prefers instructions with W suffix,
17// it removes the -w suffix from opw instructions whenever all users are
18// dependent only on the lower word of the result of the instruction.
19// The cases handled are:
20// * addw because c.add has a larger register encoding than c.addw.
21// * addiw because it helps reduce test differences between RV32 and RV64
22// w/o being a pessimization.
23// * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
24// * slliw because c.slliw doesn't exist and c.slli does
25//
26// 2. Or if explicit enabled or the target prefers instructions with W suffix,
27// it adds the W suffix to the instruction whenever all users are dependent
28// only on the lower word of the result of the instruction.
29// The cases handled are:
30// * add/addi/sub/mul.
31// * slli with imm < 32.
32// * ld/lwu.
33//===---------------------------------------------------------------------===//
34
35#include "RISCV.h"
37#include "RISCVSubtarget.h"
38#include "llvm/ADT/SmallSet.h"
39#include "llvm/ADT/Statistic.h"
42
43using namespace llvm;
44
45#define DEBUG_TYPE "riscv-opt-w-instrs"
46#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
47
48STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
49STATISTIC(NumTransformedToWInstrs,
50 "Number of instructions transformed to W-ops");
51
52static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
53 cl::desc("Disable removal of sext.w"),
54 cl::init(false), cl::Hidden);
55static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
56 cl::desc("Disable strip W suffix"),
57 cl::init(false), cl::Hidden);
58
59namespace {
60
// RV64-only MachineFunction pass (runOnMachineFunction bails out on RV32).
// It removes redundant sext.w instructions. It then either strips W suffixes
// or appends them, depending on the subtarget's preferWInst() and the
// riscv-disable-strip-w-suffix option.
61class RISCVOptWInstrs : public MachineFunctionPass {
62public:
63 static char ID;
64
65 RISCVOptWInstrs() : MachineFunctionPass(ID) {}
66
67 bool runOnMachineFunction(MachineFunction &MF) override;
 // The individual transformations driven by runOnMachineFunction; each one
 // returns true if it modified MF.
68 bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
70 bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
72 bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
74
 // Declares that this pass does not modify the CFG.
75 void getAnalysisUsage(AnalysisUsage &AU) const override {
76 AU.setPreservesCFG();
78 }
79
80 StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
81};
82
83} // end anonymous namespace
84
// Storage for the pass ID declared in class RISCVOptWInstrs.
85char RISCVOptWInstrs::ID = 0;
87 false)
// NOTE(review): the statement ending above is presumably the pass
// registration macro (INITIALIZE_PASS); its opening lines are not visible.
88
 // Factory body (createRISCVOptWInstrsPass): returns a new pass instance.
90 return new RISCVOptWInstrs();
91}
92
94 unsigned Bits) {
 // Returns true if UserOp is a scalar operand of an RVV pseudo that reads at
 // most the low Bits bits of it. The demanded width comes from
 // RISCV::getVectorLowDemandedScalarBits for the pseudo's MC opcode and SEW.
 // Returns false for non-RVV instructions, for pseudos without an SEW operand,
 // for the VL operand, and when no demanded width is known.
95 const MachineInstr &MI = *UserOp.getParent();
96 unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
97
 // A zero result presumably means MI is not an RVV pseudo; bail out.
98 if (!MCOpcode)
99 return false;
100
101 const MCInstrDesc &MCID = MI.getDesc();
102 const uint64_t TSFlags = MCID.TSFlags;
103 if (!RISCVII::hasSEWOp(TSFlags))
104 return false;
105 assert(RISCVII::hasVLOp(TSFlags));
106 const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
107
 // The VL operand is always rejected, regardless of Bits.
108 if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
109 return false;
110
111 auto NumDemandedBits =
112 RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
113 return NumDemandedBits && Bits >= *NumDemandedBits;
114}
115
116// Checks if all users only demand the lower \p OrigBits of the original
117// instruction's result.
118// TODO: handle multiple interdependent transformations
119static bool hasAllNBitUsers(const MachineInstr &OrigMI,
120 const RISCVSubtarget &ST,
121 const MachineRegisterInfo &MRI, unsigned OrigBits) {
122
125
126 Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
127
128 while (!Worklist.empty()) {
129 auto P = Worklist.pop_back_val();
130 const MachineInstr *MI = P.first;
131 unsigned Bits = P.second;
132
133 if (!Visited.insert(P).second)
134 continue;
135
136 // Only handle instructions with one def.
137 if (MI->getNumExplicitDefs() != 1)
138 return false;
139
140 Register DestReg = MI->getOperand(0).getReg();
141 if (!DestReg.isVirtual())
142 return false;
143
144 for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
145 const MachineInstr *UserMI = UserOp.getParent();
146 unsigned OpIdx = UserOp.getOperandNo();
147
148 switch (UserMI->getOpcode()) {
149 default:
150 if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
151 break;
152 return false;
153
154 case RISCV::ADDIW:
155 case RISCV::ADDW:
156 case RISCV::DIVUW:
157 case RISCV::DIVW:
158 case RISCV::MULW:
159 case RISCV::REMUW:
160 case RISCV::REMW:
161 case RISCV::SLLIW:
162 case RISCV::SLLW:
163 case RISCV::SRAIW:
164 case RISCV::SRAW:
165 case RISCV::SRLIW:
166 case RISCV::SRLW:
167 case RISCV::SUBW:
168 case RISCV::ROLW:
169 case RISCV::RORW:
170 case RISCV::RORIW:
171 case RISCV::CLZW:
172 case RISCV::CTZW:
173 case RISCV::CPOPW:
174 case RISCV::SLLI_UW:
175 case RISCV::FMV_W_X:
176 case RISCV::FCVT_H_W:
177 case RISCV::FCVT_H_W_INX:
178 case RISCV::FCVT_H_WU:
179 case RISCV::FCVT_H_WU_INX:
180 case RISCV::FCVT_S_W:
181 case RISCV::FCVT_S_W_INX:
182 case RISCV::FCVT_S_WU:
183 case RISCV::FCVT_S_WU_INX:
184 case RISCV::FCVT_D_W:
185 case RISCV::FCVT_D_W_INX:
186 case RISCV::FCVT_D_WU:
187 case RISCV::FCVT_D_WU_INX:
188 if (Bits >= 32)
189 break;
190 return false;
191 case RISCV::SEXT_B:
192 case RISCV::PACKH:
193 if (Bits >= 8)
194 break;
195 return false;
196 case RISCV::SEXT_H:
197 case RISCV::FMV_H_X:
198 case RISCV::ZEXT_H_RV32:
199 case RISCV::ZEXT_H_RV64:
200 case RISCV::PACKW:
201 if (Bits >= 16)
202 break;
203 return false;
204
205 case RISCV::PACK:
206 if (Bits >= (ST.getXLen() / 2))
207 break;
208 return false;
209
210 case RISCV::SRLI: {
211 // If we are shifting right by less than Bits, and users don't demand
212 // any bits that were shifted into [Bits-1:0], then we can consider this
213 // as an N-Bit user.
214 unsigned ShAmt = UserMI->getOperand(2).getImm();
215 if (Bits > ShAmt) {
216 Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
217 break;
218 }
219 return false;
220 }
221
222 // these overwrite higher input bits, otherwise the lower word of output
223 // depends only on the lower word of input. So check their uses read W.
224 case RISCV::SLLI: {
225 unsigned ShAmt = UserMI->getOperand(2).getImm();
226 if (Bits >= (ST.getXLen() - ShAmt))
227 break;
228 Worklist.push_back(std::make_pair(UserMI, Bits + ShAmt));
229 break;
230 }
231 case RISCV::ANDI: {
232 uint64_t Imm = UserMI->getOperand(2).getImm();
233 if (Bits >= (unsigned)llvm::bit_width(Imm))
234 break;
235 Worklist.push_back(std::make_pair(UserMI, Bits));
236 break;
237 }
238 case RISCV::ORI: {
239 uint64_t Imm = UserMI->getOperand(2).getImm();
240 if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
241 break;
242 Worklist.push_back(std::make_pair(UserMI, Bits));
243 break;
244 }
245
246 case RISCV::SLL:
247 case RISCV::BSET:
248 case RISCV::BCLR:
249 case RISCV::BINV:
250 // Operand 2 is the shift amount which uses log2(xlen) bits.
251 if (OpIdx == 2) {
252 if (Bits >= Log2_32(ST.getXLen()))
253 break;
254 return false;
255 }
256 Worklist.push_back(std::make_pair(UserMI, Bits));
257 break;
258
259 case RISCV::SRA:
260 case RISCV::SRL:
261 case RISCV::ROL:
262 case RISCV::ROR:
263 // Operand 2 is the shift amount which uses 6 bits.
264 if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
265 break;
266 return false;
267
268 case RISCV::ADD_UW:
269 case RISCV::SH1ADD_UW:
270 case RISCV::SH2ADD_UW:
271 case RISCV::SH3ADD_UW:
272 // Operand 1 is implicitly zero extended.
273 if (OpIdx == 1 && Bits >= 32)
274 break;
275 Worklist.push_back(std::make_pair(UserMI, Bits));
276 break;
277
278 case RISCV::BEXTI:
279 if (UserMI->getOperand(2).getImm() >= Bits)
280 return false;
281 break;
282
283 case RISCV::SB:
284 // The first argument is the value to store.
285 if (OpIdx == 0 && Bits >= 8)
286 break;
287 return false;
288 case RISCV::SH:
289 // The first argument is the value to store.
290 if (OpIdx == 0 && Bits >= 16)
291 break;
292 return false;
293 case RISCV::SW:
294 // The first argument is the value to store.
295 if (OpIdx == 0 && Bits >= 32)
296 break;
297 return false;
298
299 // For these, lower word of output in these operations, depends only on
300 // the lower word of input. So, we check all uses only read lower word.
301 case RISCV::COPY:
302 case RISCV::PHI:
303
304 case RISCV::ADD:
305 case RISCV::ADDI:
306 case RISCV::AND:
307 case RISCV::MUL:
308 case RISCV::OR:
309 case RISCV::SUB:
310 case RISCV::XOR:
311 case RISCV::XORI:
312
313 case RISCV::ANDN:
314 case RISCV::BREV8:
315 case RISCV::CLMUL:
316 case RISCV::ORC_B:
317 case RISCV::ORN:
318 case RISCV::SH1ADD:
319 case RISCV::SH2ADD:
320 case RISCV::SH3ADD:
321 case RISCV::XNOR:
322 case RISCV::BSETI:
323 case RISCV::BCLRI:
324 case RISCV::BINVI:
325 Worklist.push_back(std::make_pair(UserMI, Bits));
326 break;
327
328 case RISCV::PseudoCCMOVGPR:
329 // Either operand 4 or operand 5 is returned by this instruction. If
330 // only the lower word of the result is used, then only the lower word
331 // of operand 4 and 5 is used.
332 if (OpIdx != 4 && OpIdx != 5)
333 return false;
334 Worklist.push_back(std::make_pair(UserMI, Bits));
335 break;
336
337 case RISCV::CZERO_EQZ:
338 case RISCV::CZERO_NEZ:
339 case RISCV::VT_MASKC:
340 case RISCV::VT_MASKCN:
341 if (OpIdx != 1)
342 return false;
343 Worklist.push_back(std::make_pair(UserMI, Bits));
344 break;
345 }
346 }
347 }
348
349 return true;
350}
351
352static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
353 const MachineRegisterInfo &MRI) {
354 return hasAllNBitUsers(OrigMI, ST, MRI, 32);
355}
356
357// This function returns true if the machine instruction always outputs a value
358// where bits 63:32 match bit 31.
// OpNo is the index of the def operand being asked about; it is only
// consulted for PseudoAtomicLoadNand32. Opcodes that neither carry the
// tablegen flag nor match a case below return false (not known to be
// sign-extended).
359static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo) {
360 uint64_t TSFlags = MI.getDesc().TSFlags;
361
362 // Instructions that can be determined from opcode are marked in tablegen.
364 return true;
365
366 // Special cases that require checking operands.
367 switch (MI.getOpcode()) {
368 // shifting right sufficiently makes the value 32-bit sign-extended
369 case RISCV::SRAI:
370 return MI.getOperand(2).getImm() >= 32;
 // A logical shift by exactly 32 can leave bit 31 set while bits 63:32 are
 // zero, so SRLI needs a shift amount strictly greater than 32.
371 case RISCV::SRLI:
372 return MI.getOperand(2).getImm() > 32;
373 // The LI pattern ADDI rd, X0, imm is sign extended.
374 case RISCV::ADDI:
375 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
376 // An ANDI with an 11 bit immediate will zero bits 63:11.
377 case RISCV::ANDI:
378 return isUInt<11>(MI.getOperand(2).getImm());
379 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
380 case RISCV::ORI:
381 return !isUInt<11>(MI.getOperand(2).getImm());
382 // A bseti with X0 is sign extended if the immediate is less than 31.
383 case RISCV::BSETI:
384 return MI.getOperand(2).getImm() < 31 &&
385 MI.getOperand(1).getReg() == RISCV::X0;
386 // Copying from X0 produces zero.
387 case RISCV::COPY:
388 return MI.getOperand(1).getReg() == RISCV::X0;
389 // Ignore the scratch register destination.
390 case RISCV::PseudoAtomicLoadNand32:
391 return OpNo == 0;
392 case RISCV::PseudoVMV_X_S: {
393 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
394 int64_t Log2SEW = MI.getOperand(2).getImm();
395 assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
396 return Log2SEW <= 5;
397 }
398 }
399
400 return false;
401}
402
// Returns true if the value in SrcReg is known to have bits 63:32 equal to
// bit 31. The def chain is walked with a worklist of virtual registers. Each
// def reached must do one of the following:
//  - sign-extend on its own (isSignExtendingOpW);
//  - propagate sign extension from inputs, which are queued in turn;
//  - be a "fixable" opcode (SLLI with shamt < 32, ADDI, ADD, LD, LWU, MUL,
//    SUB) whose users all read only the low 32 bits. Such defs are recorded
//    in FixableDef; they are only sign-extending once the caller rewrites
//    them to their W form.
// Reaching a non-virtual register (other than the handled call-result COPY
// from X10), an unknown opcode, or an unprovable def returns false.
// FixableDef may already be partially filled when false is returned.
403static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
406 SmallSet<Register, 4> Visited;
408
409 auto AddRegToWorkList = [&](Register SrcReg) {
410 if (!SrcReg.isVirtual())
411 return false;
412 Worklist.push_back(SrcReg);
413 return true;
414 };
415
416 if (!AddRegToWorkList(SrcReg))
417 return false;
418
419 while (!Worklist.empty()) {
420 Register Reg = Worklist.pop_back_val();
421
422 // If we already visited this register, we don't need to check it again.
423 if (!Visited.insert(Reg).second)
424 continue;
425
426 MachineInstr *MI = MRI.getVRegDef(Reg);
 // NOTE(review): a register for which getVRegDef returns null is not
 // examined further and does not cause a false result — confirm intended.
427 if (!MI)
428 continue;
429
430 int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr);
431 assert(OpNo != -1 && "Couldn't find register");
432
433 // If this is a sign extending operation we don't need to look any further.
434 if (isSignExtendingOpW(*MI, OpNo))
435 continue;
436
437 // Is this an instruction that propagates sign extend?
438 switch (MI->getOpcode()) {
439 default:
440 // Unknown opcode, give up.
441 return false;
 // COPY is accepted in three ways: an entry-block live-in that the function
 // info marks as sign-extended; a call result in X10 whose callee's return
 // attributes guarantee 32-bit sign extension; or a virtual source register
 // that is queued for checking.
442 case RISCV::COPY: {
443 const MachineFunction *MF = MI->getMF();
444 const RISCVMachineFunctionInfo *RVFI =
446
447 // If this is the entry block and the register is livein, see if we know
448 // it is sign extended.
449 if (MI->getParent() == &MF->front()) {
450 Register VReg = MI->getOperand(0).getReg();
451 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
452 continue;
453 }
454
455 Register CopySrcReg = MI->getOperand(1).getReg();
456 if (CopySrcReg == RISCV::X10) {
457 // For a method return value, we check the ZExt/SExt flags in attribute.
458 // We assume the following code sequence for method call.
459 // PseudoCALL @bar, ...
460 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
461 // %0:gpr = COPY $x10
462 //
463 // We use the PseudoCall to look up the IR function being called to find
464 // its return attributes.
465 const MachineBasicBlock *MBB = MI->getParent();
466 auto II = MI->getIterator();
467 if (II == MBB->instr_begin() ||
468 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
469 return false;
470
471 const MachineInstr &CallMI = *(--II);
472 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
473 return false;
474
475 auto *CalleeFn =
476 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
477 if (!CalleeFn)
478 return false;
479
480 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
481 if (!IntTy)
482 return false;
483
484 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
485 unsigned BitWidth = IntTy->getBitWidth();
 // signext on a type of at most 32 bits, or zeroext on a type narrower
 // than 32 bits, both leave bits 63:32 equal to bit 31.
486 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
487 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
488 continue;
489 }
490
491 if (!AddRegToWorkList(CopySrcReg))
492 return false;
493
494 break;
495 }
496
497 // For these, we just need to check if the 1st operand is sign extended.
498 case RISCV::BCLRI:
499 case RISCV::BINVI:
500 case RISCV::BSETI:
501 if (MI->getOperand(2).getImm() >= 31)
502 return false;
503 [[fallthrough]];
504 case RISCV::REM:
505 case RISCV::ANDI:
506 case RISCV::ORI:
507 case RISCV::XORI:
508 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
509 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
510 // Logical operations use a sign extended 12-bit immediate.
511 if (!AddRegToWorkList(MI->getOperand(1).getReg()))
512 return false;
513
514 break;
515 case RISCV::PseudoCCADDW:
516 case RISCV::PseudoCCADDIW:
517 case RISCV::PseudoCCSUBW:
518 case RISCV::PseudoCCSLLW:
519 case RISCV::PseudoCCSRLW:
520 case RISCV::PseudoCCSRAW:
521 case RISCV::PseudoCCSLLIW:
522 case RISCV::PseudoCCSRLIW:
523 case RISCV::PseudoCCSRAIW:
524 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
525 // need to check if operand 4 is sign extended.
526 if (!AddRegToWorkList(MI->getOperand(4).getReg()))
527 return false;
528 break;
529 case RISCV::REMU:
530 case RISCV::AND:
531 case RISCV::OR:
532 case RISCV::XOR:
533 case RISCV::ANDN:
534 case RISCV::ORN:
535 case RISCV::XNOR:
536 case RISCV::MAX:
537 case RISCV::MAXU:
538 case RISCV::MIN:
539 case RISCV::MINU:
540 case RISCV::PseudoCCMOVGPR:
541 case RISCV::PseudoCCAND:
542 case RISCV::PseudoCCOR:
543 case RISCV::PseudoCCXOR:
544 case RISCV::PHI: {
545 // If all incoming values are sign-extended, the output of AND, OR, XOR,
546 // MIN, MAX, or PHI is also sign-extended.
547
548 // The input registers for PHI are operand 1, 3, ...
549 // The input registers for PseudoCCMOVGPR are 4 and 5.
550 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
551 // The input registers for others are operand 1 and 2.
 // Operands [B, E) are visited with stride D.
552 unsigned B = 1, E = 3, D = 1;
553 switch (MI->getOpcode()) {
554 case RISCV::PHI:
555 E = MI->getNumOperands();
556 D = 2;
557 break;
558 case RISCV::PseudoCCMOVGPR:
559 B = 4;
560 E = 6;
561 break;
562 case RISCV::PseudoCCAND:
563 case RISCV::PseudoCCOR:
564 case RISCV::PseudoCCXOR:
565 B = 4;
566 E = 7;
567 break;
568 }
569
570 for (unsigned I = B; I != E; I += D) {
571 if (!MI->getOperand(I).isReg())
572 return false;
573
574 if (!AddRegToWorkList(MI->getOperand(I).getReg()))
575 return false;
576 }
577
578 break;
579 }
580
581 case RISCV::CZERO_EQZ:
582 case RISCV::CZERO_NEZ:
583 case RISCV::VT_MASKC:
584 case RISCV::VT_MASKCN:
585 // Instructions return zero or operand 1. Result is sign extended if
586 // operand 1 is sign extended.
587 if (!AddRegToWorkList(MI->getOperand(1).getReg()))
588 return false;
589 break;
590
591 // With these opcode, we can "fix" them with the W-version
592 // if we know all users of the result only rely on bits 31:0
593 case RISCV::SLLI:
594 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
595 if (MI->getOperand(2).getImm() >= 32)
596 return false;
597 [[fallthrough]];
598 case RISCV::ADDI:
599 case RISCV::ADD:
600 case RISCV::LD:
601 case RISCV::LWU:
602 case RISCV::MUL:
603 case RISCV::SUB:
604 if (hasAllWUsers(*MI, ST, MRI)) {
605 FixableDef.insert(MI);
606 break;
607 }
608 return false;
609 }
610 }
611
612 // If we get here, then every node we visited produces a sign extended value
613 // or propagated sign extended values. So the result must be sign extended.
614 return true;
615}
616
617static unsigned getWOp(unsigned Opcode) {
618 switch (Opcode) {
619 case RISCV::ADDI:
620 return RISCV::ADDIW;
621 case RISCV::ADD:
622 return RISCV::ADDW;
623 case RISCV::LD:
624 case RISCV::LWU:
625 return RISCV::LW;
626 case RISCV::MUL:
627 return RISCV::MULW;
628 case RISCV::SLLI:
629 return RISCV::SLLIW;
630 case RISCV::SUB:
631 return RISCV::SUBW;
632 default:
633 llvm_unreachable("Unexpected opcode for replacement with W variant");
634 }
635}
636
// Erases sext.w instructions (RISCV::isSEXT_W: ADDIW rd, rs1, 0) that are
// redundant for one of two reasons:
//  - every user reads only the low 32 bits of rd (hasAllWUsers), or
//  - rs1 is already sign-extended (isSignExtendedW).
// In the second case, the defs collected in FixableDefs are first rewritten
// to their W form via getWOp, and their nsw/nuw/exact flags are cleared.
// All uses of rd are then redirected to rs1. Returns true if any sext.w was
// removed.
// NOTE(review): MI is erased while iterating MBB, so the inner loop header
// (not visible here) must use early-increment iteration, presumably
// make_early_inc_range — confirm.
637bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
638 const RISCVInstrInfo &TII,
639 const RISCVSubtarget &ST,
642 return false;
643
644 bool MadeChange = false;
645 for (MachineBasicBlock &MBB : MF) {
647 // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
648 if (!RISCV::isSEXT_W(MI))
649 continue;
650
651 Register SrcReg = MI.getOperand(1).getReg();
652
654
655 // If all users only use the lower bits, this sext.w is redundant.
656 // Or if all definitions reaching MI sign-extend their output,
657 // then sext.w is redundant.
658 if (!hasAllWUsers(MI, ST, MRI) &&
659 !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
660 continue;
661
662 Register DstReg = MI.getOperand(0).getReg();
 // SrcReg is about to replace every use of DstReg, so it must fit DstReg's
 // register class; keep the sext.w if it cannot be constrained.
663 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
664 continue;
665
666 // Convert Fixable instructions to their W versions.
667 for (MachineInstr *Fixable : FixableDefs) {
668 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
669 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
670 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
671 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
672 Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
673 LLVM_DEBUG(dbgs() << " with " << *Fixable);
674 ++NumTransformedToWInstrs;
675 }
676
677 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
678 MRI.replaceRegWith(DstReg, SrcReg);
 // SrcReg gained the former uses of DstReg, so its existing kill flags may
 // no longer be accurate.
679 MRI.clearKillFlags(SrcReg);
680 MI.eraseFromParent();
681 ++NumRemovedSExtW;
682 MadeChange = true;
683 }
684 }
685
686 return MadeChange;
687}
688
689bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
690 const RISCVInstrInfo &TII,
691 const RISCVSubtarget &ST,
693 bool MadeChange = false;
694 for (MachineBasicBlock &MBB : MF) {
695 for (MachineInstr &MI : MBB) {
696 unsigned Opc;
697 switch (MI.getOpcode()) {
698 default:
699 continue;
700 case RISCV::ADDW: Opc = RISCV::ADD; break;
701 case RISCV::ADDIW: Opc = RISCV::ADDI; break;
702 case RISCV::MULW: Opc = RISCV::MUL; break;
703 case RISCV::SLLIW: Opc = RISCV::SLLI; break;
704 }
705
706 if (hasAllWUsers(MI, ST, MRI)) {
707 MI.setDesc(TII.get(Opc));
708 MadeChange = true;
709 }
710 }
711 }
712
713 return MadeChange;
714}
715
// Run when the subtarget prefers W instructions. Rewrites ADD, ADDI, SUB, MUL,
// SLLI (shift amount < 32), LD and LWU to ADDW, ADDIW, SUBW, MULW, SLLIW and
// LW respectively. This happens only when every user reads just the low
// 32 bits of the result. The nsw/nuw/exact flags of rewritten instructions
// are cleared. Returns true if any instruction was rewritten.
716bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF,
717 const RISCVInstrInfo &TII,
718 const RISCVSubtarget &ST,
720 bool MadeChange = false;
721 for (MachineBasicBlock &MBB : MF) {
722 for (MachineInstr &MI : MBB) {
723 unsigned WOpc;
724 // TODO: Add more?
725 switch (MI.getOpcode()) {
726 default:
727 continue;
728 case RISCV::ADD:
729 WOpc = RISCV::ADDW;
730 break;
731 case RISCV::ADDI:
732 WOpc = RISCV::ADDIW;
733 break;
734 case RISCV::SUB:
735 WOpc = RISCV::SUBW;
736 break;
737 case RISCV::MUL:
738 WOpc = RISCV::MULW;
739 break;
740 case RISCV::SLLI:
741 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
742 if (MI.getOperand(2).getImm() >= 32)
743 continue;
744 WOpc = RISCV::SLLIW;
745 break;
746 case RISCV::LD:
747 case RISCV::LWU:
748 WOpc = RISCV::LW;
749 break;
750 }
751
752 if (hasAllWUsers(MI, ST, MRI)) {
753 LLVM_DEBUG(dbgs() << "Replacing " << MI);
754 MI.setDesc(TII.get(WOpc));
755 MI.clearFlag(MachineInstr::MIFlag::NoSWrap);
756 MI.clearFlag(MachineInstr::MIFlag::NoUWrap);
757 MI.clearFlag(MachineInstr::MIFlag::IsExact);
758 LLVM_DEBUG(dbgs() << " with " << MI);
759 ++NumTransformedToWInstrs;
760 MadeChange = true;
761 }
762 }
763 }
764
765 return MadeChange;
766}
767
// Pass entry point. Returns false without changes for functions rejected by
// skipFunction() and on non-64-bit subtargets. Otherwise it runs sext.w
// removal first, then one of two suffix transformations:
//  - strips W suffixes, unless riscv-disable-strip-w-suffix is set or the
//    subtarget prefers W instructions;
//  - appends W suffixes, when the subtarget prefers W instructions.
768bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
769 if (skipFunction(MF.getFunction()))
770 return false;
771
774 const RISCVInstrInfo &TII = *ST.getInstrInfo();
775
776 if (!ST.is64Bit())
777 return false;
778
779 bool MadeChange = false;
780 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
781
 // Stripping and appending are mutually exclusive: preferWInst() selects
 // appending, otherwise stripping runs unless explicitly disabled.
782 if (!(DisableStripWSuffix || ST.preferWInst()))
783 MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
784
785 if (ST.preferWInst())
786 MadeChange |= appendWSuffixes(MF, TII, ST, MRI);
787
788 return MadeChange;
789}
unsigned const MachineRegisterInfo * MRI
MachineBasicBlock & MBB
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_DEBUG(X)
Definition: Debug.h:101
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition: MD5.cpp:58
uint64_t IntrinsicInst * II
#define P(N)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, const MachineRegisterInfo &MRI, SmallPtrSetImpl< MachineInstr * > &FixableDef)
static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, const MachineRegisterInfo &MRI)
static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo)
static cl::opt< bool > DisableStripWSuffix("riscv-disable-strip-w-suffix", cl::desc("Disable strip W suffix"), cl::init(false), cl::Hidden)
static bool hasAllNBitUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, const MachineRegisterInfo &MRI, unsigned OrigBits)
#define RISCV_OPT_W_INSTRS_NAME
static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp, unsigned Bits)
static cl::opt< bool > DisableSExtWRemoval("riscv-disable-sextw-removal", cl::desc("Disable removal of sext.w"), cl::init(false), cl::Hidden)
#define DEBUG_TYPE
static unsigned getWOp(unsigned Opcode)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallSet class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Represent the analysis usage information of a pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:256
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
Describe properties that are true of each instruction in the target description file.
Definition: MCInstrDesc.h:198
instr_iterator instr_begin()
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
Ty * getInfo()
getInfo - Keep track of various per-function pieces of information for backends that would like to do...
const MachineBasicBlock & front() const
Representation of each machine instruction.
Definition: MachineInstr.h:69
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
Definition: MachineInstr.h:569
const MachineBasicBlock * getParent() const
Definition: MachineInstr.h:346
bool isCall(QueryType Type=AnyInBundle) const
Definition: MachineInstr.h:950
const MachineOperand & getOperand(unsigned i) const
Definition: MachineInstr.h:579
MachineOperand class - Representation of each machine instruction operand.
unsigned getOperandNo() const
Returns the index of this operand in the instruction that it belongs to.
const GlobalValue * getGlobal() const
int64_t getImm() const
MachineInstr * getParent()
getParent - Return the instruction that this operand belongs to.
bool isGlobal() const
isGlobal - Tests if this is a MO_GlobalAddress operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
bool isLiveIn(Register Reg) const
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
RISCVMachineFunctionInfo - This class is derived from MachineFunctionInfo and contains private RISCV-...
bool isSExt32Register(Register Reg) const
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
Definition: Register.h:91
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:346
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:367
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:502
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
bool empty() const
Definition: SmallVector.h:94
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
static unsigned getVLOpNum(const MCInstrDesc &Desc)
static bool hasVLOp(uint64_t TSFlags)
static unsigned getSEWOpNum(const MCInstrDesc &Desc)
static bool hasSEWOp(uint64_t TSFlags)
std::optional< unsigned > getVectorLowDemandedScalarBits(uint16_t Opcode, unsigned Log2SEW)
unsigned getRVVMCOpcode(unsigned RVVPseudoOpcode)
bool isSEXT_W(const MachineInstr &MI)
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
int bit_width(T Value)
Returns the number of bits needed to represent Value if Value is nonzero.
Definition: bit.h:317
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:656
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
Definition: MathExtras.h:340
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
FunctionPass * createRISCVOptWInstrsPass()
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:191