LLVM 18.0.0git
X86FixupVectorConstants.cpp
Go to the documentation of this file.
1//===-- X86FixupVectorConstants.cpp - optimize constant generation -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file examines all full size vector constant pool loads and attempts to
10// replace them with smaller constant pool entries, including:
11// * Converting AVX512 memory-fold instructions to their broadcast-fold form
12// * Broadcasting of full width loads.
13// * TODO: Sign/Zero extension of full width loads.
14//
15//===----------------------------------------------------------------------===//
16
17#include "X86.h"
18#include "X86InstrFoldTables.h"
19#include "X86InstrInfo.h"
20#include "X86Subtarget.h"
21#include "llvm/ADT/Statistic.h"
23
24using namespace llvm;
25
26#define DEBUG_TYPE "x86-fixup-vector-constants"
27
28STATISTIC(NumInstChanges, "Number of instructions changes");
29
30namespace {
31class X86FixupVectorConstantsPass : public MachineFunctionPass {
32public:
33 static char ID;
34
35 X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}
36
37 StringRef getPassName() const override {
38 return "X86 Fixup Vector Constants";
39 }
40
41 bool runOnMachineFunction(MachineFunction &MF) override;
42 bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
44
45 // This pass runs after regalloc and doesn't support VReg operands.
48 MachineFunctionProperties::Property::NoVRegs);
49 }
50
51private:
52 const X86InstrInfo *TII = nullptr;
53 const X86Subtarget *ST = nullptr;
54 const MCSchedModel *SM = nullptr;
55};
56} // end anonymous namespace
57
58char X86FixupVectorConstantsPass::ID = 0;
59
60INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false, false)
61
63 return new X86FixupVectorConstantsPass();
64}
65
67 const MachineOperand &Op) {
68 if (!Op.isCPI() || Op.getOffset() != 0)
69 return nullptr;
70
72 MI.getParent()->getParent()->getConstantPool()->getConstants();
73 const MachineConstantPoolEntry &ConstantEntry = Constants[Op.getIndex()];
74
75 // Bail if this is a machine constant pool entry, we won't be able to dig out
76 // anything useful.
77 if (ConstantEntry.isMachineConstantPoolEntry())
78 return nullptr;
79
80 return ConstantEntry.Val.ConstVal;
81}
82
83// Attempt to extract the full width of bits data from the constant.
84static std::optional<APInt> extractConstantBits(const Constant *C) {
85 unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
86
87 if (auto *CInt = dyn_cast<ConstantInt>(C))
88 return CInt->getValue();
89
90 if (auto *CFP = dyn_cast<ConstantFP>(C))
91 return CFP->getValue().bitcastToAPInt();
92
93 if (auto *CV = dyn_cast<ConstantVector>(C)) {
94 if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
95 if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
96 assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
97 return APInt::getSplat(NumBits, *Bits);
98 }
99 }
100 }
101
102 if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
103 bool IsInteger = CDS->getElementType()->isIntegerTy();
104 bool IsFloat = CDS->getElementType()->isHalfTy() ||
105 CDS->getElementType()->isBFloatTy() ||
106 CDS->getElementType()->isFloatTy() ||
107 CDS->getElementType()->isDoubleTy();
108 if (IsInteger || IsFloat) {
109 APInt Bits = APInt::getZero(NumBits);
110 unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
111 for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
112 if (IsInteger)
113 Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
114 else
115 Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
116 I * EltBits);
117 }
118 return Bits;
119 }
120 }
121
122 return std::nullopt;
123}
124
125// Attempt to compute the splat width of bits data by normalizing the splat to
126// remove undefs.
127static std::optional<APInt> getSplatableConstant(const Constant *C,
128 unsigned SplatBitWidth) {
129 const Type *Ty = C->getType();
130 assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
131 "Illegal splat width");
132
133 if (std::optional<APInt> Bits = extractConstantBits(C))
134 if (Bits->isSplat(SplatBitWidth))
135 return Bits->trunc(SplatBitWidth);
136
137 // Detect general splats with undefs.
138 // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
139 if (auto *CV = dyn_cast<ConstantVector>(C)) {
140 unsigned NumOps = CV->getNumOperands();
141 unsigned NumEltsBits = Ty->getScalarSizeInBits();
142 unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
143 if ((SplatBitWidth % NumEltsBits) == 0) {
144 // Collect the elements and ensure that within the repeated splat sequence
145 // they either match or are undef.
146 SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
147 for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
148 if (Constant *Elt = CV->getAggregateElement(Idx)) {
149 if (isa<UndefValue>(Elt))
150 continue;
151 unsigned SplatIdx = Idx % NumScaleOps;
152 if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
153 Sequence[SplatIdx] = Elt;
154 continue;
155 }
156 }
157 return std::nullopt;
158 }
159 // Extract the constant bits forming the splat and insert into the bits
160 // data, leave undef as zero.
161 APInt SplatBits = APInt::getZero(SplatBitWidth);
162 for (unsigned I = 0; I != NumScaleOps; ++I) {
163 if (!Sequence[I])
164 continue;
165 if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
166 SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
167 continue;
168 }
169 return std::nullopt;
170 }
171 return SplatBits;
172 }
173 }
174
175 return std::nullopt;
176}
177
178// Attempt to rebuild a normalized splat vector constant of the requested splat
179// width, built up of potentially smaller scalar values.
180// NOTE: We don't always bother converting to scalars if the vector length is 1.
182 unsigned SplatBitWidth) {
183 std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
184 if (!Splat)
185 return nullptr;
186
187 // Determine scalar size to use for the constant splat vector, clamping as we
188 // might have found a splat smaller than the original constant data.
189 const Type *OriginalType = C->getType();
190 Type *SclTy = OriginalType->getScalarType();
191 unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
192 NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
193
194 if (NumSclBits == 8) {
195 SmallVector<uint8_t> RawBits;
196 for (unsigned I = 0; I != SplatBitWidth; I += 8)
197 RawBits.push_back(Splat->extractBits(8, I).getZExtValue());
198 return ConstantDataVector::get(OriginalType->getContext(), RawBits);
199 }
200
201 if (NumSclBits == 16) {
202 SmallVector<uint16_t> RawBits;
203 for (unsigned I = 0; I != SplatBitWidth; I += 16)
204 RawBits.push_back(Splat->extractBits(16, I).getZExtValue());
205 if (SclTy->is16bitFPTy())
206 return ConstantDataVector::getFP(SclTy, RawBits);
207 return ConstantDataVector::get(OriginalType->getContext(), RawBits);
208 }
209
210 if (NumSclBits == 32) {
211 SmallVector<uint32_t> RawBits;
212 for (unsigned I = 0; I != SplatBitWidth; I += 32)
213 RawBits.push_back(Splat->extractBits(32, I).getZExtValue());
214 if (SclTy->isFloatTy())
215 return ConstantDataVector::getFP(SclTy, RawBits);
216 return ConstantDataVector::get(OriginalType->getContext(), RawBits);
217 }
218
219 // Fallback to i64 / double.
220 SmallVector<uint64_t> RawBits;
221 for (unsigned I = 0; I != SplatBitWidth; I += 64)
222 RawBits.push_back(Splat->extractBits(64, I).getZExtValue());
223 if (SclTy->isDoubleTy())
224 return ConstantDataVector::getFP(SclTy, RawBits);
225 return ConstantDataVector::get(OriginalType->getContext(), RawBits);
226}
227
228bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
230 MachineInstr &MI) {
231 unsigned Opc = MI.getOpcode();
232 MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
233 bool HasDQI = ST->hasDQI();
234 bool HasBWI = ST->hasBWI();
235
236 auto ConvertToBroadcast = [&](unsigned OpBcst256, unsigned OpBcst128,
237 unsigned OpBcst64, unsigned OpBcst32,
238 unsigned OpBcst16, unsigned OpBcst8,
239 unsigned OperandNo) {
240 assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
241 "Unexpected number of operands!");
242
243 MachineOperand &CstOp = MI.getOperand(OperandNo + X86::AddrDisp);
244 if (auto *C = getConstantFromPool(MI, CstOp)) {
245 // Attempt to detect a suitable splat from increasing splat widths.
246 std::pair<unsigned, unsigned> Broadcasts[] = {
247 {8, OpBcst8}, {16, OpBcst16}, {32, OpBcst32},
248 {64, OpBcst64}, {128, OpBcst128}, {256, OpBcst256},
249 };
250 for (auto [BitWidth, OpBcst] : Broadcasts) {
251 if (OpBcst) {
252 // Construct a suitable splat constant and adjust the MI to
253 // use the new constant pool entry.
255 unsigned NewCPI =
256 CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
257 MI.setDesc(TII->get(OpBcst));
258 CstOp.setIndex(NewCPI);
259 return true;
260 }
261 }
262 }
263 }
264 return false;
265 };
266
267 // Attempt to convert full width vector loads into broadcast loads.
268 switch (Opc) {
269 /* FP Loads */
270 case X86::MOVAPDrm:
271 case X86::MOVAPSrm:
272 case X86::MOVUPDrm:
273 case X86::MOVUPSrm:
274 // TODO: SSE3 MOVDDUP Handling
275 return false;
276 case X86::VMOVAPDrm:
277 case X86::VMOVAPSrm:
278 case X86::VMOVUPDrm:
279 case X86::VMOVUPSrm:
280 return ConvertToBroadcast(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
281 1);
282 case X86::VMOVAPDYrm:
283 case X86::VMOVAPSYrm:
284 case X86::VMOVUPDYrm:
285 case X86::VMOVUPSYrm:
286 return ConvertToBroadcast(0, X86::VBROADCASTF128, X86::VBROADCASTSDYrm,
287 X86::VBROADCASTSSYrm, 0, 0, 1);
288 case X86::VMOVAPDZ128rm:
289 case X86::VMOVAPSZ128rm:
290 case X86::VMOVUPDZ128rm:
291 case X86::VMOVUPSZ128rm:
292 return ConvertToBroadcast(0, 0, X86::VMOVDDUPZ128rm,
293 X86::VBROADCASTSSZ128rm, 0, 0, 1);
294 case X86::VMOVAPDZ256rm:
295 case X86::VMOVAPSZ256rm:
296 case X86::VMOVUPDZ256rm:
297 case X86::VMOVUPSZ256rm:
298 return ConvertToBroadcast(
299 0, HasDQI ? X86::VBROADCASTF64X2Z128rm : X86::VBROADCASTF32X4Z256rm,
300 X86::VBROADCASTSDZ256rm, X86::VBROADCASTSSZ256rm, 0, 0, 1);
301 case X86::VMOVAPDZrm:
302 case X86::VMOVAPSZrm:
303 case X86::VMOVUPDZrm:
304 case X86::VMOVUPSZrm:
305 return ConvertToBroadcast(
306 HasDQI ? X86::VBROADCASTF32X8rm : X86::VBROADCASTF64X4rm,
307 HasDQI ? X86::VBROADCASTF64X2rm : X86::VBROADCASTF32X4rm,
308 X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0, 1);
309 /* Integer Loads */
310 case X86::VMOVDQArm:
311 case X86::VMOVDQUrm:
312 if (ST->hasAVX2())
313 return ConvertToBroadcast(0, 0, X86::VPBROADCASTQrm, X86::VPBROADCASTDrm,
314 X86::VPBROADCASTWrm, X86::VPBROADCASTBrm, 1);
315 return ConvertToBroadcast(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
316 1);
317 case X86::VMOVDQAYrm:
318 case X86::VMOVDQUYrm:
319 if (ST->hasAVX2())
320 return ConvertToBroadcast(0, X86::VBROADCASTI128, X86::VPBROADCASTQYrm,
321 X86::VPBROADCASTDYrm, X86::VPBROADCASTWYrm,
322 X86::VPBROADCASTBYrm, 1);
323 return ConvertToBroadcast(0, X86::VBROADCASTF128, X86::VBROADCASTSDYrm,
324 X86::VBROADCASTSSYrm, 0, 0, 1);
325 case X86::VMOVDQA32Z128rm:
326 case X86::VMOVDQA64Z128rm:
327 case X86::VMOVDQU32Z128rm:
328 case X86::VMOVDQU64Z128rm:
329 return ConvertToBroadcast(0, 0, X86::VPBROADCASTQZ128rm,
330 X86::VPBROADCASTDZ128rm,
331 HasBWI ? X86::VPBROADCASTWZ128rm : 0,
332 HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1);
333 case X86::VMOVDQA32Z256rm:
334 case X86::VMOVDQA64Z256rm:
335 case X86::VMOVDQU32Z256rm:
336 case X86::VMOVDQU64Z256rm:
337 return ConvertToBroadcast(
338 0, HasDQI ? X86::VBROADCASTI64X2Z128rm : X86::VBROADCASTI32X4Z256rm,
339 X86::VPBROADCASTQZ256rm, X86::VPBROADCASTDZ256rm,
340 HasBWI ? X86::VPBROADCASTWZ256rm : 0,
341 HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1);
342 case X86::VMOVDQA32Zrm:
343 case X86::VMOVDQA64Zrm:
344 case X86::VMOVDQU32Zrm:
345 case X86::VMOVDQU64Zrm:
346 return ConvertToBroadcast(
347 HasDQI ? X86::VBROADCASTI32X8rm : X86::VBROADCASTI64X4rm,
348 HasDQI ? X86::VBROADCASTI64X2rm : X86::VBROADCASTI32X4rm,
349 X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
350 HasBWI ? X86::VPBROADCASTWZrm : 0, HasBWI ? X86::VPBROADCASTBZrm : 0,
351 1);
352 }
353
354 // Attempt to find a AVX512 mapping from a full width memory-fold instruction
355 // to a broadcast-fold instruction variant.
356 if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX) {
357 unsigned OpBcst32 = 0, OpBcst64 = 0;
358 unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
359 if (const X86MemoryFoldTableEntry *Mem2Bcst =
361 OpBcst32 = Mem2Bcst->DstOp;
362 OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
363 }
364 if (const X86MemoryFoldTableEntry *Mem2Bcst =
366 OpBcst64 = Mem2Bcst->DstOp;
367 OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
368 }
369 assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) &&
370 "OperandNo mismatch");
371
372 if (OpBcst32 || OpBcst64) {
373 unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
374 return ConvertToBroadcast(0, 0, OpBcst64, OpBcst32, 0, 0, OpNo);
375 }
376 }
377
378 return false;
379}
380
381bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
382 LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
383 bool Changed = false;
385 TII = ST->getInstrInfo();
386 SM = &ST->getSchedModel();
387
388 for (MachineBasicBlock &MBB : MF) {
389 for (MachineInstr &MI : MBB) {
390 if (processInstruction(MF, MBB, MI)) {
391 ++NumInstChanges;
392 Changed = true;
393 }
394 }
395 }
396 LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
397 return Changed;
398}
MachineBasicBlock & MBB
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition: MD5.cpp:58
This file declares the MachineConstantPool class which is an abstract constant pool to keep track of ...
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
static std::optional< APInt > getSplatableConstant(const Constant *C, unsigned SplatBitWidth)
static std::optional< APInt > extractConstantBits(const Constant *C)
static const Constant * getConstantFromPool(const MachineInstr &MI, const MachineOperand &Op)
static Constant * rebuildSplatableConstant(const Constant *C, unsigned SplatBitWidth)
#define DEBUG_TYPE
Class for arbitrary precision integers.
Definition: APInt.h:76
static APInt getSplat(unsigned NewLen, const APInt &V)
Return a value containing V broadcasted over NewLen bits.
Definition: APInt.cpp:620
void insertBits(const APInt &SubBits, unsigned bitPosition)
Insert the bits from a smaller APInt starting at bitPosition.
Definition: APInt.cpp:368
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:178
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
static Constant * get(LLVMContext &Context, ArrayRef< uint8_t > Elts)
get() constructors - Return a constant with vector type with an element count and element type matchi...
Definition: Constants.cpp:2906
static Constant * getFP(Type *ElementType, ArrayRef< uint16_t > Elts)
getFP() constructors - Return a constant of vector type with a float element type taken from argument...
Definition: Constants.cpp:2943
This is an important base class in LLVM.
Definition: Constant.h:41
Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
Definition: Constants.cpp:418
This class represents an Operation in the Expression.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
This class is a data container for one entry in a MachineConstantPool.
bool isMachineConstantPoolEntry() const
isMachineConstantPoolEntry - Return true if the MachineConstantPoolEntry is indeed a target specific ...
union llvm::MachineConstantPoolEntry::@195 Val
The constant itself.
The MachineConstantPool class keeps track of constants referenced by a function which must be spilled...
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
virtual MachineFunctionProperties getRequiredProperties() const
Properties which a MachineFunction may have at a given point in time.
MachineFunctionProperties & set(Property P)
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
Representation of each machine instruction.
Definition: MachineInstr.h:68
MachineOperand class - Representation of each machine instruction operand.
void setIndex(int Idx)
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isFloatTy() const
Return true if this is 'float', a 32-bit IEEE fp type.
Definition: Type.h:154
bool is16bitFPTy() const
Return true if this is a 16-bit float type.
Definition: Type.h:149
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition: Type.h:129
bool isDoubleTy() const
Return true if this is 'double', a 64-bit IEEE fp type.
Definition: Type.h:157
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:348
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
@ AddrNumOperands
AddrNumOperands - Total number of operands in a memory reference.
Definition: X86BaseInfo.h:41
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
const X86MemoryFoldTableEntry * lookupBroadcastFoldTable(unsigned MemOp, unsigned BroadcastBits)
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:184
FunctionPass * createX86FixupVectorConstants()
Return a pass that reduces the size of vector constant pool loads.
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
Machine model for scheduling, bundling, and heuristics.
Definition: MCSchedule.h:253