//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#include "NVPTXISelLowering.h"
#include "NVPTX.h"
#include "NVPTXSubtarget.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/FPEnv.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#define DEBUG_TYPE "nvptx-lower"

using namespace llvm;
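
// Gives each lowered call site a unique, monotonically increasing number,
// used to name the .callprototype declaration emitted for indirect calls
// (see getPrototype() below).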
static std::atomic<unsigned> GlobalUniqueCallSite;

77 "nvptx-sched4reg",
78 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
79
81 "nvptx-fma-level", cl::Hidden,
82 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
83 " 1: do it 2: do it aggressively"),
84 cl::init(2));
85
87 "nvptx-prec-divf32", cl::Hidden,
88 cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
89 " IEEE Compliant F32 div.rnd if available."),
90 cl::init(2));
91
93 "nvptx-prec-sqrtf32", cl::Hidden,
94 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
95 cl::init(true));
96
97/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
98/// does NOT use lg2.approx for log2, so this is disabled by default.
100 "nvptx-approx-log2f32",
101 cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
102 cl::init(false));
103
105 "nvptx-force-min-byval-param-align", cl::Hidden,
106 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
107 " params of device functions."),
108 cl::init(false));

int NVPTXTargetLowering::getDivF32Level() const {
  if (UsePrecDivF32.getNumOccurrences() > 0) {
    // If nvptx-prec-divf32=N is used on the command-line, always honor it
    return UsePrecDivF32;
  } else {
    // Otherwise, use div.approx if fast math is enabled
    if (getTargetMachine().Options.UnsafeFPMath)
      return 0;
    else
      return 2;
  }
}

bool NVPTXTargetLowering::usePrecSqrtF32() const {
  if (UsePrecSqrtF32.getNumOccurrences() > 0) {
    // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
    return UsePrecSqrtF32;
  } else {
    // Otherwise, use sqrt.approx if fast math is enabled
    return !getTargetMachine().Options.UnsafeFPMath;
  }
}

bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
  return MF.getDenormalMode(APFloat::IEEEsingle()).Output ==
         DenormalMode::PreserveSign;
}

static bool IsPTXVectorType(MVT VT) {
  switch (VT.SimpleTy) {
  default:
    return false;
  case MVT::v2i1:
  case MVT::v4i1:
  case MVT::v2i8:
  case MVT::v4i8:
  case MVT::v8i8:   // <2 x i8x4>
  case MVT::v16i8:  // <4 x i8x4>
  case MVT::v2i16:
  case MVT::v4i16:
  case MVT::v8i16:  // <4 x i16x2>
  case MVT::v2i32:
  case MVT::v4i32:
  case MVT::v2i64:
  case MVT::v2f16:
  case MVT::v4f16:
  case MVT::v8f16:  // <4 x f16x2>
  case MVT::v2bf16:
  case MVT::v4bf16:
  case MVT::v8bf16: // <4 x bf16x2>
  case MVT::v2f32:
  case MVT::v4f32:
  case MVT::v2f64:
    return true;
  }
}

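// f16, bf16 and i16 all occupy 16 bits of storage on NVPTX and can be packed
// in pairs into a single 32-bit (b32) register.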
static bool Is16bitsType(MVT VT) {
  return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
          VT.SimpleTy == MVT::i16);
}

// When legalizing vector loads/stores, this function is called, which does two
// things:
// 1. Determines whether the vector is something we want to custom lower,
//    std::nullopt is returned if we do not want to custom lower it.
// 2. If we do want to handle it, returns two parameters:
//    - unsigned int NumElts - The number of elements in the final vector
//    - EVT EltVT - The type of the elements in the final vector
static std::optional<std::pair<unsigned int, EVT>>
getVectorLoweringShape(EVT VectorVT) {
  if (!VectorVT.isVector() || !VectorVT.isSimple())
    return std::nullopt;

  EVT EltVT = VectorVT.getVectorElementType();
  unsigned NumElts = VectorVT.getVectorNumElements();

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal. We can (and should) split that into 2 stores of <2 x double> here
  // but I'm leaving that as a TODO for now.
  switch (VectorVT.getSimpleVT().SimpleTy) {
  default:
    return std::nullopt;
  case MVT::v2i8:
  case MVT::v2i16:
  case MVT::v2i32:
  case MVT::v2i64:
  case MVT::v2f16:
  case MVT::v2bf16:
  case MVT::v2f32:
  case MVT::v2f64:
  case MVT::v4i8:
  case MVT::v4i16:
  case MVT::v4i32:
  case MVT::v4f16:
  case MVT::v4bf16:
  case MVT::v4f32:
    // This is a "native" vector type
    return std::pair(NumElts, EltVT);
  case MVT::v8i8:   // <2 x i8x4>
  case MVT::v8f16:  // <4 x f16x2>
  case MVT::v8bf16: // <4 x bf16x2>
  case MVT::v8i16:  // <4 x i16x2>
  case MVT::v16i8:  // <4 x i8x4>
    // This can be upsized into a "native" vector type.
    // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
    // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
    // vectorized loads/stores with the actual element type for i8/i16 as that
    // would require v8/v16 variants that do not exist.
    // In order to load/store such vectors efficiently, here in Type
    // Legalization, we split the vector into word-sized chunks (v2x16/v4i8).
    // Later, we will lower to PTX as vectors of b32.
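    // For example, v8f16 is re-shaped as four v2f16 elements, i.e. four b32
    // words, each packing a pair of f16 values.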

    // Number of elements to pack in one word.
    unsigned NPerWord = 32 / EltVT.getSizeInBits();

    return std::pair(NumElts / NPerWord,
                     MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord));
  }

  llvm_unreachable("All cases in switch should return.");
}

/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
/// LowerCall, and LowerReturn.
static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
                               Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
                               SmallVectorImpl<uint64_t> *Offsets = nullptr,
                               uint64_t StartingOffset = 0) {
  SmallVector<EVT, 16> TempVTs;
  SmallVector<uint64_t, 16> TempOffsets;

  // Special case for i128 - decompose to (i64, i64)
  if (Ty->isIntegerTy(128)) {
    ValueVTs.push_back(EVT(MVT::i64));
    ValueVTs.push_back(EVT(MVT::i64));

    if (Offsets) {
      Offsets->push_back(StartingOffset + 0);
      Offsets->push_back(StartingOffset + 8);
    }

    return;
  }

  // Given a struct type, recursively traverse the elements with custom
  // ComputePTXValueVTs.
  if (StructType *STy = dyn_cast<StructType>(Ty)) {
    auto const *SL = DL.getStructLayout(STy);
    auto ElementNum = 0;
    for (auto *EI : STy->elements()) {
      ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
                         StartingOffset + SL->getElementOffset(ElementNum));
      ++ElementNum;
    }
    return;
  }

  // Given an array type, recursively traverse the elements with custom
  // ComputePTXValueVTs.
  if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
    Type *EltTy = ATy->getElementType();
    uint64_t EltSize = DL.getTypeAllocSize(EltTy);
    for (int I : llvm::seq<int>(ATy->getNumElements()))
      ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets,
                         StartingOffset + I * EltSize);
    return;
  }

  ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
  for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
    EVT VT = TempVTs[i];
    uint64_t Off = TempOffsets[i];
    // Split vectors into individual elements, except for v2f16, which
    // we will pass as a single scalar.
    if (VT.isVector()) {
      unsigned NumElts = VT.getVectorNumElements();
      EVT EltVT = VT.getVectorElementType();
      // We require power-of-2 sized vectors because
      // TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
      // ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
      // vectors.
      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
          isPowerOf2_32(NumElts)) {
        // Vectors with an even number of f16 elements will be passed to
        // us as an array of v2f16/v2bf16 elements. We must match this so we
        // stay in sync with Ins/Outs.
        switch (EltVT.getSimpleVT().SimpleTy) {
        case MVT::f16:
          EltVT = MVT::v2f16;
          break;
        case MVT::bf16:
          EltVT = MVT::v2bf16;
          break;
        case MVT::i16:
          EltVT = MVT::v2i16;
          break;
        default:
          llvm_unreachable("Unexpected type");
        }
        NumElts /= 2;
      } else if (EltVT.getSimpleVT() == MVT::i8 &&
                 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
                  NumElts == 3)) {
        // v*i8 are formally lowered as v4i8
        EltVT = MVT::v4i8;
        NumElts = (NumElts + 3) / 4;
      } else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
        // v2i8 is promoted to v2i16
        NumElts = 1;
        EltVT = MVT::v2i16;
      }
      for (unsigned j = 0; j != NumElts; ++j) {
        ValueVTs.push_back(EltVT);
        if (Offsets)
          Offsets->push_back(Off + j * EltVT.getStoreSize());
      }
    } else {
      ValueVTs.push_back(VT);
      if (Offsets)
        Offsets->push_back(Off);
    }
  }
}

/// PromoteScalarIntegerPTX
/// Used to make sure the arguments/returns are suitable for passing
/// and promote them to a larger size if they're not.
///
/// The promoted type is placed in \p PromotedVT if the function returns true.
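/// For example, an i24 is promoted to i32 (PowerOf2Ceil(24) == 32), while
/// integer sizes from 2 to 8 bits are all promoted to i8.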
static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
  if (VT.isScalarInteger()) {
    switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
    default:
      llvm_unreachable(
          "Promotion is not suitable for scalars of size larger than 64-bits");
    case 1:
      *PromotedVT = MVT::i1;
      break;
    case 2:
    case 4:
    case 8:
      *PromotedVT = MVT::i8;
      break;
    case 16:
      *PromotedVT = MVT::i16;
      break;
    case 32:
      *PromotedVT = MVT::i32;
      break;
    case 64:
      *PromotedVT = MVT::i64;
      break;
    }
    return EVT(*PromotedVT) != VT;
  }
  return false;
}

// Check whether we can merge loads/stores of some of the pieces of a
// flattened function parameter or return value into a single vector
// load/store.
//
// The flattened parameter is represented as a list of EVTs and
// offsets, and the whole structure is aligned to ParamAlignment. This
// function determines whether we can load/store pieces of the
// parameter starting at index Idx using a single vectorized op of
// size AccessSize. If so, it returns the number of param pieces
// covered by the vector op. Otherwise, it returns 1.
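// For example, four f32 pieces at offsets 0/4/8/12 of a 16-byte-aligned
// parameter can be merged into a single 16-byte (v4) access, so 4 is returned.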
static unsigned CanMergeParamLoadStoresStartingAt(
    unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
    const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {

  // Can't vectorize if param alignment is not sufficient.
  if (ParamAlignment < AccessSize)
    return 1;
  // Can't vectorize if offset is not aligned.
  if (Offsets[Idx] & (AccessSize - 1))
    return 1;

  EVT EltVT = ValueVTs[Idx];
  unsigned EltSize = EltVT.getStoreSize();

  // Element is too large to vectorize.
  if (EltSize >= AccessSize)
    return 1;

  unsigned NumElts = AccessSize / EltSize;
  // Can't vectorize if AccessSize is not a multiple of EltSize.
  if (AccessSize != EltSize * NumElts)
    return 1;

  // We don't have enough elements to vectorize.
  if (Idx + NumElts > ValueVTs.size())
    return 1;

  // PTX ISA can only deal with 2- and 4-element vector ops.
  if (NumElts != 4 && NumElts != 2)
    return 1;

  for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
    // Types do not match.
    if (ValueVTs[j] != EltVT)
      return 1;

    // Elements are not contiguous.
    if (Offsets[j] - Offsets[j - 1] != EltSize)
      return 1;
  }
  // OK. We can vectorize ValueVTs[i..i+NumElts)
  return NumElts;
}

// Flags for tracking per-element vectorization state of loads/stores
// of a flattened function parameter or return value.
enum ParamVectorizationFlags {
  PVF_INNER = 0x0,  // Middle elements of a vector.
  PVF_FIRST = 0x1,  // First element of the vector.
  PVF_LAST = 0x2,   // Last element of the vector.
  // Scalar is effectively a 1-element vector.
  PVF_SCALAR = PVF_FIRST | PVF_LAST,
};

// Computes whether and how we can vectorize the loads/stores of a
// flattened function parameter or return value.
//
// The flattened parameter is represented as the list of ValueVTs and
// Offsets, and is aligned to ParamAlignment bytes. We return a vector
// of the same size as ValueVTs indicating how each piece should be
// loaded/stored (i.e. as a scalar, or as part of a vector
// load/store).
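// For example, eight contiguous f32 pieces of a 16-byte-aligned parameter are
// marked {FIRST, INNER, INNER, LAST, FIRST, INNER, INNER, LAST}: two v4 ops.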
static SmallVector<ParamVectorizationFlags, 16>
VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
                     const SmallVectorImpl<uint64_t> &Offsets,
                     Align ParamAlignment, bool IsVAArg = false) {
  // Set vector size to match ValueVTs and mark all elements as
  // scalars by default.
  SmallVector<ParamVectorizationFlags, 16> VectorInfo;
  VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);

  if (IsVAArg)
    return VectorInfo;

  // Check what we can vectorize using 128/64/32/16-bit accesses.
  for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
    // Skip elements we've already processed.
    assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
    for (unsigned AccessSize : {16, 8, 4, 2}) {
      unsigned NumElts = CanMergeParamLoadStoresStartingAt(
          I, AccessSize, ValueVTs, Offsets, ParamAlignment);
      // Mark vectorized elements.
      switch (NumElts) {
      default:
        llvm_unreachable("Unexpected return value");
      case 1:
        // Can't vectorize using this size, try next smaller size.
        continue;
      case 2:
        assert(I + 1 < E && "Not enough elements.");
        VectorInfo[I] = PVF_FIRST;
        VectorInfo[I + 1] = PVF_LAST;
        I += 1;
        break;
      case 4:
        assert(I + 3 < E && "Not enough elements.");
        VectorInfo[I] = PVF_FIRST;
        VectorInfo[I + 1] = PVF_INNER;
        VectorInfo[I + 2] = PVF_INNER;
        VectorInfo[I + 3] = PVF_LAST;
        I += 3;
        break;
      }
      // Break out of the inner loop because we've already succeeded
      // using largest possible AccessSize.
      break;
    }
  }
  return VectorInfo;
}
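
// Bitcast Value to VT if it doesn't already have that type.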
static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
                            SDValue Value) {
  if (Value->getValueType(0) == VT)
    return Value;
  return DAG.getNode(ISD::BITCAST, DL, VT, Value);
}

// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                                         const NVPTXSubtarget &STI)
    : TargetLowering(TM), nvTM(&TM), STI(STI) {
  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy or memmove.
  MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
  MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned)0xFFFFFFFF;
  MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned)0xFFFFFFFF;

  setBooleanContents(ZeroOrNegativeOneBooleanContent);
  setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);

  // Jump is Expensive. Don't create extra control flow for 'and', 'or'
  // condition branches.
  setJumpIsExpensive(true);

  // Wide divides are _very_ slow. Try to reduce the width of the divide if
  // possible.
  addBypassSlowDiv(64, 32);

  // By default, use the Source scheduling
  if (sched4reg)
    setSchedulingPreference(Sched::RegPressure);
  else
    setSchedulingPreference(Sched::Source);
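
  // Helper lambdas: register the requested legalize Action for Op on VT when
  // the subtarget supports the operation natively, and the fallback action
  // otherwise.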
  auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
                                    LegalizeAction NoF16Action) {
    bool IsOpSupported = STI.allowFP16Math();
    switch (Op) {
    // Several FP16 instructions are available on sm_80 only.
    case ISD::FMINNUM:
    case ISD::FMAXNUM:
    case ISD::FMAXNUM_IEEE:
    case ISD::FMINNUM_IEEE:
    case ISD::FMAXIMUM:
    case ISD::FMINIMUM:
      IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
      break;
    case ISD::FEXP2:
      IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
      break;
    }
    setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
  };

  auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
                                    LegalizeAction NoBF16Action) {
    bool IsOpSupported = STI.hasNativeBF16Support(Op);
    setOperationAction(Op, VT, IsOpSupported ? Action : NoBF16Action);
  };

  auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
                                     LegalizeAction NoI16x2Action) {
    bool IsOpSupported = false;
    // i16x2 instructions are available on sm_90 only.
    switch (Op) {
    case ISD::ADD:
    case ISD::SMAX:
    case ISD::SMIN:
    case ISD::UMIN:
    case ISD::UMAX:
      IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
      break;
    }
    setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
  };

  addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
  addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
  addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
  addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
  addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
  addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
  addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
  addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
  addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
  addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
  addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
  addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);

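  // Note that f16/bf16 scalars live in 16-bit integer registers (.b16), and
  // the packed x2 variants live in 32-bit integer registers (.b32).
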
  // Conversion to/from FP16/FP16x2 is always legal.
  setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom);
  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f16, Expand);
  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f16, Expand);

  setOperationAction(ISD::READCYCLECOUNTER, MVT::i64, Legal);
  if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
    setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64, Legal);

  setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
  setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);

  // Conversion to/from BF16/BF16x2 is always legal.
  setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom);
  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom);
  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand);
  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand);

  setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
  setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
  if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
    AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);

  // Conversion to/from i16/i16x2 is always legal.
  setOperationAction(ISD::BUILD_VECTOR, MVT::v2i16, Custom);
  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2i16, Custom);
  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);

  setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);

  // Custom conversions to/from v2i8.
  setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);

  // Only logical ops can be done on v4i8 directly, others must be done
  // elementwise.
  setOperationAction(
      {ISD::ABS,         ISD::ADD,        ISD::ADDC,        ISD::ADDE,
       ISD::BITREVERSE,  ISD::CTLZ,       ISD::CTPOP,       ISD::CTTZ,
       ISD::FP_TO_SINT,  ISD::FP_TO_UINT, ISD::FSHL,        ISD::FSHR,
       ISD::MUL,         ISD::MULHS,      ISD::MULHU,       ISD::PARITY,
       ISD::ROTL,        ISD::ROTR,       ISD::SADDO,       ISD::SADDO_CARRY,
       ISD::SADDSAT,     ISD::SDIV,       ISD::SDIVREM,     ISD::SELECT_CC,
       ISD::SETCC,       ISD::SHL,        ISD::SINT_TO_FP,  ISD::SMAX,
       ISD::SMIN,        ISD::SMULO,      ISD::SMUL_LOHI,   ISD::SRA,
       ISD::SREM,        ISD::SRL,        ISD::SSHLSAT,     ISD::SSUBO,
       ISD::SSUBO_CARRY, ISD::SSUBSAT,    ISD::SUB,         ISD::SUBC,
       ISD::SUBE,        ISD::UADDO,      ISD::UADDO_CARRY, ISD::UADDSAT,
       ISD::UDIV,        ISD::UDIVREM,    ISD::UINT_TO_FP,  ISD::UMAX,
       ISD::UMIN,        ISD::UMULO,      ISD::UMUL_LOHI,   ISD::UREM,
       ISD::USHLSAT,     ISD::USUBO,      ISD::USUBO_CARRY, ISD::USUBSAT,
       ISD::VSELECT},
      MVT::v4i8, Expand);

  // Operations not directly supported by NVPTX.
  for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
                 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
                 MVT::i32, MVT::i64}) {
    setOperationAction(ISD::SELECT_CC, VT, Expand);
    setOperationAction(ISD::BR_CC, VT, Expand);
  }

  // Some SIGN_EXTEND_INREG can be done using cvt instruction.
  // For others we will expand to a SHL/SRA pair.
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::v2i16, Expand);

  setOperationAction({ISD::SHL_PARTS, ISD::SRA_PARTS, ISD::SRL_PARTS},
                     {MVT::i32, MVT::i64}, Custom);

  setOperationAction({ISD::ROTL, ISD::ROTR},
                     {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
                     Expand);

  if (STI.hasHWROT32())
    setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal);

  setOperationAction(ISD::BSWAP, MVT::v2i16, Expand);

  setOperationAction(ISD::BR_JT, MVT::Other, Custom);
  setOperationAction(ISD::BRIND, MVT::Other, Expand);

  setOperationAction(ISD::GlobalAddress, {MVT::i32, MVT::i64}, Custom);

  // We want to legalize constant related memmove and memcopy
  // intrinsics.
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i64, Custom);

  // Turn FP extload into load/fpextend
  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
  // Turn FP truncstore into trunc + store.
  // FIXME: vector types should also be expanded
  setTruncStoreAction(MVT::f32, MVT::f16, Expand);
  setTruncStoreAction(MVT::f64, MVT::f16, Expand);
  setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
  setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
  setTruncStoreAction(MVT::f64, MVT::f32, Expand);

  // PTX does not support load / store predicate registers
  setOperationAction(ISD::LOAD, MVT::i1, Custom);
  setOperationAction(ISD::STORE, MVT::i1, Custom);

  for (MVT VT : MVT::integer_valuetypes()) {
    setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, MVT::i1,
                     Promote);
    setTruncStoreAction(VT, MVT::i1, Expand);
  }

  setCondCodeAction({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
                     ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
                     ISD::SETGE, ISD::SETLE},
                    MVT::i1, Expand);

  // expand extload of vector of integers.
  setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
                   MVT::v2i8, Expand);
  setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);

  // This is legal in NVPTX
  setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
  setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
  setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
  setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);

  setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
  setOperationAction({ISD::STACKRESTORE, ISD::STACKSAVE}, MVT::Other, Custom);

  // TRAP can be lowered to PTX trap
  setOperationAction(ISD::TRAP, MVT::Other, Legal);
  // DEBUGTRAP can be lowered to PTX brkpt
  setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);

  // Register custom handling for vector loads/stores
  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
    if (IsPTXVectorType(VT)) {
      setOperationAction(ISD::LOAD, VT, Custom);
      setOperationAction(ISD::STORE, VT, Custom);
      setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
    }
  }

  // Support varargs.
  setOperationAction(ISD::VASTART, MVT::Other, Custom);
  setOperationAction(ISD::VAARG, MVT::Other, Custom);
  setOperationAction(ISD::VACOPY, MVT::Other, Expand);
  setOperationAction(ISD::VAEND, MVT::Other, Expand);

  // Custom handling for i8 intrinsics
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);

  for (const auto &Ty : {MVT::i16, MVT::i32, MVT::i64}) {
    setOperationAction(ISD::ABS, Ty, Legal);
    setOperationAction(ISD::SMIN, Ty, Legal);
    setOperationAction(ISD::SMAX, Ty, Legal);
    setOperationAction(ISD::UMIN, Ty, Legal);
    setOperationAction(ISD::UMAX, Ty, Legal);

    setOperationAction(ISD::CTPOP, Ty, Legal);
    setOperationAction(ISD::CTLZ, Ty, Legal);
  }

  setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
  setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);

  setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
  setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);

  // Other arithmetic and logic ops are unsupported.
  setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
                      ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                      ISD::SINT_TO_FP, ISD::UINT_TO_FP},
                     MVT::v2i16, Expand);

  setOperationAction(ISD::ADDC, MVT::i32, Legal);
  setOperationAction(ISD::ADDE, MVT::i32, Legal);
  setOperationAction(ISD::SUBC, MVT::i32, Legal);
  setOperationAction(ISD::SUBE, MVT::i32, Legal);
  if (STI.getPTXVersion() >= 43) {
    setOperationAction(ISD::ADDC, MVT::i64, Legal);
    setOperationAction(ISD::ADDE, MVT::i64, Legal);
    setOperationAction(ISD::SUBC, MVT::i64, Legal);
    setOperationAction(ISD::SUBE, MVT::i64, Legal);
  }

  setOperationAction(ISD::CTTZ, MVT::i16, Expand);
  setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
  setOperationAction(ISD::CTTZ, MVT::i32, Expand);
  setOperationAction(ISD::CTTZ, MVT::i64, Expand);

  // PTX does not directly support SELP of i1, so promote to i32 first
  setOperationAction(ISD::SELECT, MVT::i1, Custom);

  // PTX cannot multiply two i64s in a single instruction.
  setOperationAction(ISD::SMUL_LOHI, MVT::i64, Expand);
  setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);

  // We have some custom DAG combine patterns for these nodes
  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
                       ISD::VSELECT});

  // setcc for f16x2 and bf16x2 needs special handling to prevent
  // legalizer's attempt to scalarize it due to v2i1 not being legal.
  if (STI.allowFP16Math() || STI.hasBF16Math())
    setTargetDAGCombine(ISD::SETCC);

  // Promote fp16 arithmetic if fp16 hardware isn't available or the
  // user passed --nvptx-no-fp16-math. The flag is useful because,
  // although sm_53+ GPUs have some sort of FP16 support in
  // hardware, only sm_53 and sm_60 have full implementation. Others
  // only have token amount of hardware and are likely to run faster
  // by using fp32 units instead.
  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
    setFP16OperationAction(Op, MVT::f16, Legal, Promote);
    setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
    setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
    // bf16 must be promoted to f32.
    setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
    if (getOperationAction(Op, MVT::bf16) == Promote)
      AddPromotedToType(Op, MVT::bf16, MVT::f32);
  }

  // On SM80, we select add/mul/sub as fma to avoid promotion to float
  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
    for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
      if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) {
        setOperationAction(Op, VT, Custom);
      }
    }
  }

  // f16/f16x2 neg was introduced in PTX 60, SM_53.
  const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
                                        STI.getPTXVersion() >= 60 &&
                                        STI.allowFP16Math();
  for (const auto &VT : {MVT::f16, MVT::v2f16})
    setOperationAction(ISD::FNEG, VT,
                       IsFP16FP16x2NegAvailable ? Legal : Expand);

  setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
  setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
  // (would be) Library functions.

  // These map to conversion instructions for scalar FP types.
  for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
                         ISD::FROUNDEVEN, ISD::FTRUNC}) {
    setOperationAction(Op, MVT::f16, Legal);
    setOperationAction(Op, MVT::f32, Legal);
    setOperationAction(Op, MVT::f64, Legal);
    setOperationAction(Op, MVT::v2f16, Expand);
    setOperationAction(Op, MVT::v2bf16, Expand);
    setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
    if (getOperationAction(Op, MVT::bf16) == Promote)
      AddPromotedToType(Op, MVT::bf16, MVT::f32);
  }

  if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
    setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
  }
  if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
    for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
      setOperationAction(ISD::FP_EXTEND, VT, Custom);
      setOperationAction(ISD::FP_ROUND, VT, Custom);
    }
  }

  // sm_80 only has conversions between f32 and bf16. Custom lower all other
  // bf16 conversions.
  if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
    for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
      setOperationAction(
          {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
          VT, Custom);
    }
    setOperationAction(
        {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
        MVT::bf16, Custom);
  }

  setOperationAction(ISD::FROUND, MVT::f16, Promote);
  setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
  setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
  setOperationAction(ISD::FROUND, MVT::f32, Custom);
  setOperationAction(ISD::FROUND, MVT::f64, Custom);
  setOperationAction(ISD::FROUND, MVT::bf16, Promote);
  AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);

  // 'Expand' implements FCOPYSIGN without calling an external library.
  setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
  setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
  setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
  setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
  setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
  setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);

  // These map to corresponding instructions for f32/f64. f16 must be
  // promoted to f32. v2f16 is expanded to f16, which is then promoted
  // to f32.
  for (const auto &Op :
       {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) {
    setOperationAction(Op, MVT::f16, Promote);
    setOperationAction(Op, MVT::f32, Legal);
    setOperationAction(Op, MVT::f64, Legal);
    setOperationAction(Op, MVT::v2f16, Expand);
    setOperationAction(Op, MVT::v2bf16, Expand);
    setOperationAction(Op, MVT::bf16, Promote);
    AddPromotedToType(Op, MVT::bf16, MVT::f32);
  }

  setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
  if (STI.getPTXVersion() >= 65) {
    setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
    setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
  } else {
    setOperationAction(ISD::FABS, MVT::f16, Promote);
    setOperationAction(ISD::FABS, MVT::v2f16, Expand);
  }
  setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
  setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
  if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
    AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);

  for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
    setOperationAction(Op, MVT::f32, Legal);
    setOperationAction(Op, MVT::f64, Legal);
    setFP16OperationAction(Op, MVT::f16, Legal, Promote);
    setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
    setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
    setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
    if (getOperationAction(Op, MVT::bf16) == Promote)
      AddPromotedToType(Op, MVT::bf16, MVT::f32);
  }
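
  // PTX's NaN-propagating min/max (min.NaN/max.NaN) require sm_80 and PTX 7.0;
  // on older targets FMINIMUM/FMAXIMUM for f32 is expanded below.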
  bool SupportsF32MinMaxNaN =
      STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
  for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
    setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand);
    setFP16OperationAction(Op, MVT::f16, Legal, Expand);
    setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
    setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
    setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
  }

  // Custom lowering for inline asm with 128-bit operands
  setOperationAction(ISD::CopyToReg, MVT::i128, Custom);
  setOperationAction(ISD::CopyFromReg, MVT::i128, Custom);

  // FEXP2 support:
  // - f32
  // - f16/f16x2 (sm_70+, PTX 7.0+)
  // - bf16/bf16x2 (sm_90+, PTX 7.8+)
  // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
  setOperationAction(ISD::FEXP2, MVT::f32, Legal);
  setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
  setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
  setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
  setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);

  // FLOG2 supports f32 only.
  // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
  if (UseApproxLog2F32) {
    setOperationAction(ISD::FLOG2, MVT::f32, Legal);
    setOperationPromotedToType(ISD::FLOG2, MVT::f16, MVT::f32);
    setOperationPromotedToType(ISD::FLOG2, MVT::bf16, MVT::f32);
    setOperationAction(ISD::FLOG2, {MVT::v2f16, MVT::v2bf16}, Expand);
  }

  // No FPOW or FREM in PTX.

  // Now deduce the information based on the above mentioned
  // actions
  computeRegisterProperties(STI.getRegisterInfo());

  setMinCmpXchgSizeInBits(STI.hasAtomCas16() ? 16 : 32);
  setMaxAtomicSizeInBitsSupported(64);
  setMaxDivRemBitWidthSupported(64);
}

const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {

#define MAKE_CASE(V)                                                           \
  case V:                                                                      \
    return #V;

  switch ((NVPTXISD::NodeType)Opcode) {
  case NVPTXISD::FIRST_NUMBER:
    break;

    MAKE_CASE(NVPTXISD::CALL)
    MAKE_CASE(NVPTXISD::RET_GLUE)
    MAKE_CASE(NVPTXISD::Wrapper)
    MAKE_CASE(NVPTXISD::DeclareParam)
    MAKE_CASE(NVPTXISD::DeclareScalarParam)
    MAKE_CASE(NVPTXISD::DeclareRet)
    MAKE_CASE(NVPTXISD::DeclareRetParam)
    MAKE_CASE(NVPTXISD::PrintCall)
    MAKE_CASE(NVPTXISD::PrintConvergentCall)
    MAKE_CASE(NVPTXISD::PrintCallUni)
    MAKE_CASE(NVPTXISD::PrintConvergentCallUni)
    MAKE_CASE(NVPTXISD::LoadParam)
    MAKE_CASE(NVPTXISD::LoadParamV2)
    MAKE_CASE(NVPTXISD::LoadParamV4)
    MAKE_CASE(NVPTXISD::StoreParam)
    MAKE_CASE(NVPTXISD::StoreParamV2)
    MAKE_CASE(NVPTXISD::StoreParamV4)
    MAKE_CASE(NVPTXISD::MoveParam)
    MAKE_CASE(NVPTXISD::CallArgBegin)
    MAKE_CASE(NVPTXISD::CallArg)
    MAKE_CASE(NVPTXISD::LastCallArg)
    MAKE_CASE(NVPTXISD::CallArgEnd)
    MAKE_CASE(NVPTXISD::CallVoid)
    MAKE_CASE(NVPTXISD::CallPrototype)
    MAKE_CASE(NVPTXISD::Prototype)
    MAKE_CASE(NVPTXISD::ProxyReg)
  }
  return nullptr;

#undef MAKE_CASE
}

TargetLoweringBase::LegalizeTypeAction
NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
  if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
      VT.getScalarType() == MVT::i1)
    return TypeSplitVector;
  return TargetLoweringBase::getPreferredVectorAction(VT);
}

SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
                                             int Enabled, int &ExtraSteps,
                                             bool &UseOneConst,
                                             bool Reciprocal) const {
  if (!(Enabled == ReciprocalEstimate::Enabled ||
        (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
    return SDValue();

  if (ExtraSteps == ReciprocalEstimate::Unspecified)
    ExtraSteps = 0;

  SDLoc DL(Operand);
  EVT VT = Operand.getValueType();
  bool Ftz = useF32FTZ(DAG.getMachineFunction());

  auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
                       DAG.getConstant(IID, DL, MVT::i32), Operand);
  };

  // The sqrt and rsqrt refinement processes assume we always start out with an
  // approximation of the rsqrt. Therefore, if we're going to do any refinement
  // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
  // any refinement, we must return a regular sqrt.
  if (Reciprocal || ExtraSteps > 0) {
    if (VT == MVT::f32)
      return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
                                   : Intrinsic::nvvm_rsqrt_approx_f);
    else if (VT == MVT::f64)
      return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
    else
      return SDValue();
  } else {
    if (VT == MVT::f32)
      return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
                                   : Intrinsic::nvvm_sqrt_approx_f);
    else {
      // There's no sqrt.approx.f64 instruction, so we emit
      // reciprocal(rsqrt(x)). This is faster than
      // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
      // x * rsqrt(x).)
      return DAG.getNode(
          ISD::INTRINSIC_WO_CHAIN, DL, VT,
          DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
          MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
    }
  }
}

SDValue
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
  SDLoc dl(Op);
  const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
  auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
  Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
  return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
}

static bool IsTypePassedAsArray(const Type *Ty) {
  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
         Ty->isHalfTy() || Ty->isBFloatTy();
}
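
// Builds the PTX ".callprototype" declaration needed for indirect calls, e.g.:
//   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _);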
std::string NVPTXTargetLowering::getPrototype(
    const DataLayout &DL, Type *retTy, const ArgListTy &Args,
    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
    std::optional<std::pair<unsigned, const APInt &>> VAInfo,
    const CallBase &CB, unsigned UniqueCallSite) const {
  auto PtrVT = getPointerTy(DL);

  bool isABI = (STI.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return "";

  std::string Prototype;
  raw_string_ostream O(Prototype);
  O << "prototype_" << UniqueCallSite << " : .callprototype ";

  if (retTy->getTypeID() == Type::VoidTyID) {
    O << "()";
  } else {
    O << "(";
    if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
        !IsTypePassedAsArray(retTy)) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
        size = ITy->getBitWidth();
      } else {
        assert(retTy->isFloatingPointTy() &&
               "Floating point type expected here");
        size = retTy->getPrimitiveSizeInBits();
      }
      // PTX ABI requires all scalar return values to be at least 32
      // bits in size. fp16 normally uses .b16 as its storage type in
      // PTX, so its size must be adjusted here, too.
      size = promoteScalarArgumentSize(size);

      O << ".param .b" << size << " _";
    } else if (isa<PointerType>(retTy)) {
      O << ".param .b" << PtrVT.getSizeInBits() << " _";
    } else if (IsTypePassedAsArray(retTy)) {
      O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
        << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
    } else {
      llvm_unreachable("Unknown return type");
    }
    O << ") ";
  }
  O << "_ (";

  bool first = true;

  unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
  for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
    Type *Ty = Args[i].Ty;
    if (!first) {
      O << ", ";
    }
    first = false;

    if (!Outs[OIdx].Flags.isByVal()) {
      if (IsTypePassedAsArray(Ty)) {
        Align ParamAlign =
            getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
        O << ".param .align " << ParamAlign.value() << " .b8 ";
        O << "_";
        O << "[" << DL.getTypeAllocSize(Ty) << "]";
        // update the index for Outs
        SmallVector<EVT, 16> vtparts;
        ComputeValueVTs(*this, DL, Ty, vtparts);
        if (unsigned len = vtparts.size())
          OIdx += len - 1;
        continue;
      }
      // i8 types in IR will be i16 types in SDAG
      assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
              (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
             "type mismatch between callee prototype and arguments");
      // scalar type
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        sz = promoteScalarArgumentSize(sz);
      } else if (isa<PointerType>(Ty)) {
        sz = PtrVT.getSizeInBits();
      } else {
        sz = Ty->getPrimitiveSizeInBits();
      }
      O << ".param .b" << sz << " ";
      O << "_";
      continue;
    }

    // Indirect calls need strict ABI alignment so we disable optimizations by
    // not providing a function to optimize.
    Type *ETy = Args[i].IndirectType;
    Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
    Align ParamByValAlign =
        getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);

    O << ".param .align " << ParamByValAlign.value() << " .b8 ";
    O << "_";
    O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
  }

  if (VAInfo)
    O << (first ? "" : ",") << " .param .align " << VAInfo->second
      << " .b8 _[]\n";
  O << ")";
  if (shouldEmitPTXNoReturn(&CB, *nvTM))
    O << " .noreturn";
  O << ";";

  return Prototype;
}

Align NVPTXTargetLowering::getFunctionArgumentAlignment(
    const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
  return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
}

Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
                                                unsigned Idx,
                                                const DataLayout &DL) const {
  if (!CB) {
    // CallSite is zero, fallback to ABI type alignment
    return DL.getABITypeAlign(Ty);
  }

  const Function *DirectCallee = CB->getCalledFunction();

  if (!DirectCallee) {
    // We don't have a direct function symbol, but that may be because of
    // constant cast instructions in the call.

    // With bitcast'd call targets, the instruction will be the call
    if (const auto *CI = dyn_cast<CallInst>(CB)) {
      // Check if we have call alignment metadata
      if (MaybeAlign StackAlign = getAlign(*CI, Idx))
        return StackAlign.value();
    }
    DirectCallee = getMaybeBitcastedCallee(CB);
  }

  // Check for function alignment information if we found that the
  // ultimate target is a Function
  if (DirectCallee)
    return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);

  // Call is indirect, fall back to the ABI type alignment
  return DL.getABITypeAlign(Ty);
}
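
// Map a floating-point element type to an integer type of the same width so
// the byte-wise shift/mask logic in the helpers below can operate on its bits.
// Returns true if ElementType was changed.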
static bool adjustElementType(EVT &ElementType) {
  switch (ElementType.getSimpleVT().SimpleTy) {
  default:
    return false;
  case MVT::f16:
  case MVT::bf16:
    ElementType = MVT::i16;
    return true;
  case MVT::f32:
  case MVT::v2f16:
  case MVT::v2bf16:
    ElementType = MVT::i32;
    return true;
  case MVT::f64:
    ElementType = MVT::i64;
    return true;
  }
}

// Use byte-store when the param address of the argument value is unaligned.
// This may happen when the return value is a field of a packed structure.
//
// This is called in LowerCall() when passing the param values.
static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
                                        uint64_t Offset, EVT ElementType,
                                        SDValue StVal, SDValue &InGlue,
                                        unsigned ArgID, const SDLoc &dl) {
  // Bit logic only works on integer types
  if (adjustElementType(ElementType))
    StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);

  // Store each byte
  SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
    // Shift the byte to the last byte position
    SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
                                   DAG.getConstant(i * 8, dl, MVT::i32));
    SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
                               DAG.getConstant(Offset + i, dl, MVT::i32),
                               ShiftVal, InGlue};
    // Trunc store only the last byte by using
    //         st.param.b8
    // The register type can be larger than b8.
    Chain = DAG.getMemIntrinsicNode(
        NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
        MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
    InGlue = Chain.getValue(1);
  }
  return Chain;
}

// Use byte-load when the param address of the returned value is unaligned.
// This may happen when the returned value is a field of a packed structure.
static SDValue
LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
                           EVT ElementType, SDValue &InGlue,
                           SmallVectorImpl<SDValue> &TempProxyRegOps,
                           const SDLoc &dl) {
  // Bit logic only works on integer types
  EVT MergedType = ElementType;
  adjustElementType(MergedType);

  // Load each byte and construct the whole value. Initial value to 0
  SDValue RetVal = DAG.getConstant(0, dl, MergedType);
  // LoadParamMemI8 loads into i16 register only
  SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
    SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
                              DAG.getConstant(Offset + i, dl, MVT::i32),
                              InGlue};
    // This will be selected to LoadParamMemI8
    SDValue LdVal =
        DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
                                MVT::i8, MachinePointerInfo(), Align(1));
    SDValue TmpLdVal = LdVal.getValue(0);
    Chain = LdVal.getValue(1);
    InGlue = LdVal.getValue(2);

    TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
                           TmpLdVal.getSimpleValueType(), TmpLdVal);
    TempProxyRegOps.push_back(TmpLdVal);

    SDValue CMask = DAG.getConstant(255, dl, MergedType);
    SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
    // Need to extend the i16 register to the whole width.
    TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
    // Mask off the high bits. Leave only the lower 8 bits.
    // Do this because we are using loadparam.b8.
    TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
    // Shift and merge
    TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
    RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
  }
  if (ElementType != MergedType)
    RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);

  return RetVal;
}

static bool shouldConvertToIndirectCall(const CallBase *CB,
                                        const GlobalAddressSDNode *Func) {
  if (!Func)
    return false;
  if (auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal()))
    return CB->getFunctionType() != CalleeFunc->getFunctionType();
  return false;
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                       SmallVectorImpl<SDValue> &InVals) const {

  if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
    report_fatal_error(
        "Support for variadic functions (unsized array parameter) introduced "
        "in PTX ISA version 6.0 and requires target sm_30.");

  SelectionDAG &DAG = CLI.DAG;
  SDLoc dl = CLI.DL;
  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
  SDValue Chain = CLI.Chain;
  SDValue Callee = CLI.Callee;
  bool &isTailCall = CLI.IsTailCall;
  ArgListTy &Args = CLI.getArgs();
  Type *RetTy = CLI.RetTy;
  const CallBase *CB = CLI.CB;
  const DataLayout &DL = DAG.getDataLayout();

  bool isABI = (STI.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;

  // Variadic arguments.
  //
  // Normally, for each argument, we declare a param scalar or a param
  // byte array in the .param space, and store the argument value to that
  // param scalar or array starting at offset 0.
  //
  // In the case of the first variadic argument, we declare a vararg byte array
  // with size 0. The exact size of this array isn't known at this point, so
  // it'll be patched later. All the variadic arguments will be stored to this
  // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
  // initially set to 0, so it can be used for non-variadic arguments (which use
  // 0 offset) to simplify the code.
  //
  // After all varargs are processed, 'VAOffset' holds the size of the
  // vararg byte array.

  SDValue VADeclareParam;                 // vararg byte array
  unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
  unsigned VAOffset = 0;                  // current offset in the param array

  unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
  SDValue TempChain = Chain;
  Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
  SDValue InGlue = Chain.getValue(1);

  unsigned ParamCount = 0;
  // Args.size() and Outs.size() need not match.
  // Outs.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     showing up separately in Outs)
  //   * if there is a vector argument with more than typical vector-length
  //     elements (generally if more than 4) where each vector element is
  //     individually present in Outs.
  // So a different index should be used for indexing into Outs/OutVals.
  // See similar issue in LowerFormalArguments.
  unsigned OIdx = 0;
  // Declare the .params or .reg need to pass values
  // to the function
  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
    EVT VT = Outs[OIdx].VT;
    Type *Ty = Args[i].Ty;
    bool IsVAArg = (i >= CLI.NumFixedArgs);
    bool IsByVal = Outs[OIdx].Flags.isByVal();

    SmallVector<EVT, 16> VTs;
    SmallVector<uint64_t, 16> Offsets;

    assert((!IsByVal || Args[i].IndirectType) &&
           "byval arg must have indirect type");
    Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);

    Align ArgAlign;
    if (IsByVal) {
      // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
      // so we don't need to worry whether it's naturally aligned or not.
      // See TargetLowering::LowerCallTo().
      Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
      ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
                                            InitialAlign, DL);
      if (IsVAArg)
        VAOffset = alignTo(VAOffset, ArgAlign);
    } else {
      ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
    }

    unsigned TypeSize =
        (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);

    bool NeedAlign; // Does argument declaration specify alignment?
    bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
    if (IsVAArg) {
      if (ParamCount == FirstVAArg) {
        SDValue DeclareParamOps[] = {
            Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
            DAG.getConstant(ParamCount, dl, MVT::i32),
            DAG.getConstant(1, dl, MVT::i32), InGlue};
        VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
                                             DeclareParamVTs, DeclareParamOps);
      }
      NeedAlign = PassAsArray;
    } else if (PassAsArray) {
      // declare .param .align <align> .b8 .param<n>[<size>];
      SDValue DeclareParamOps[] = {
          Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
          DAG.getConstant(ParamCount, dl, MVT::i32),
          DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                          DeclareParamOps);
      NeedAlign = true;
    } else {
      // declare .param .b<size> .param<n>;
      if (VT.isInteger() || VT.isFloatingPoint()) {
        // PTX ABI requires integral types to be at least 32 bits in
        // size. FP16 is loaded/stored using i16, so it's handled
        // here as well.
        TypeSize = promoteScalarArgumentSize(TypeSize * 8) / 8;
      }
      SDValue DeclareScalarParamOps[] = {
          Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
          DAG.getConstant(TypeSize * 8, dl, MVT::i32),
          DAG.getConstant(0, dl, MVT::i32), InGlue};
      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
                          DeclareScalarParamOps);
      NeedAlign = false;
    }
    InGlue = Chain.getValue(1);

    // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
    // than 32-bits are sign extended or zero extended, depending on
    // whether they are signed or unsigned types. This case applies
    // only to scalar parameters and not to aggregate values.
    bool ExtendIntegerParam =
        Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;

    auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
    SmallVector<SDValue, 6> StoreOperands;
    for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
      EVT EltVT = VTs[j];
      int CurOffset = Offsets[j];
      MaybeAlign PartAlign;
      if (NeedAlign)
        PartAlign = commonAlignment(ArgAlign, CurOffset);

      SDValue StVal = OutVals[OIdx];

      MVT PromotedVT;
      if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
        EltVT = EVT(PromotedVT);
      }
      if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
        llvm::ISD::NodeType Ext =
            Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
        StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
      }

      if (IsByVal) {
        auto PtrVT = getPointerTy(DL);
        SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
                                      DAG.getConstant(CurOffset, dl, PtrVT));
        StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
                            PartAlign);
      } else if (ExtendIntegerParam) {
        assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
        // zext/sext to i32
        StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
                                                      : ISD::ZERO_EXTEND,
                            dl, MVT::i32, StVal);
      }

      if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
        // Use 16-bit registers for small stores as it's the
        // smallest general purpose register size supported by NVPTX.
        StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
      }

      // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
      // scalar store. In such cases, fall back to byte stores.
      if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
          PartAlign.value() <
              DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
        assert(StoreOperands.empty() && "Unfinished preceding store.");
        Chain = LowerUnalignedStoreParam(
            DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
            StVal, InGlue, ParamCount, dl);

        // LowerUnalignedStoreParam took care of inserting the necessary nodes
        // into the SDAG, so just move on to the next element.
        if (!IsByVal)
          ++OIdx;
        continue;
      }

      // New store.
      if (VectorInfo[j] & PVF_FIRST) {
        assert(StoreOperands.empty() && "Unfinished preceding store.");
        StoreOperands.push_back(Chain);
        StoreOperands.push_back(
            DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));

        StoreOperands.push_back(DAG.getConstant(
            IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
            dl, MVT::i32));
      }

      // Record the value to store.
      StoreOperands.push_back(StVal);

      if (VectorInfo[j] & PVF_LAST) {
        unsigned NumElts = StoreOperands.size() - 3;
        NVPTXISD::NodeType Op;
        switch (NumElts) {
        case 1:
          Op = NVPTXISD::StoreParam;
          break;
        case 2:
          Op = NVPTXISD::StoreParamV2;
          break;
        case 4:
          Op = NVPTXISD::StoreParamV4;
          break;
        default:
          llvm_unreachable("Invalid vector info.");
        }

        StoreOperands.push_back(InGlue);

        // Adjust type of the store op if we've extended the scalar
        // return value.
        EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;

        Chain = DAG.getMemIntrinsicNode(
            Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
            TheStoreType, MachinePointerInfo(), PartAlign,
            MachineMemOperand::MOStore);
        InGlue = Chain.getValue(1);

        // Cleanup.
        StoreOperands.clear();

        // TODO: We may need to support vector types that can be passed
        // as scalars in variadic arguments.
        if (!IsByVal && IsVAArg) {
          assert(NumElts == 1 &&
                 "Vectorization is expected to be disabled for variadics.");
          VAOffset += DL.getTypeAllocSize(
              TheStoreType.getTypeForEVT(*DAG.getContext()));
        }
      }
      if (!IsByVal)
        ++OIdx;
    }
    assert(StoreOperands.empty() && "Unfinished parameter store.");
    if (!IsByVal && VTs.size() > 0)
      --OIdx;
    ++ParamCount;
    if (IsByVal && IsVAArg)
      VAOffset += TypeSize;
  }

  GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
  MaybeAlign retAlignment = std::nullopt;

  // Handle Result
  if (Ins.size() > 0) {
    SmallVector<EVT, 16> resvtparts;
    ComputeValueVTs(*this, DL, RetTy, resvtparts);

    // Declare
    //  .param .align N .b8 retval0[<size-in-bytes>], or
    //  .param .b<size-in-bits> retval0
    unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
    if (!IsTypePassedAsArray(RetTy)) {
      resultsz = promoteScalarArgumentSize(resultsz);
      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
                                 DAG.getConstant(resultsz, dl, MVT::i32),
                                 DAG.getConstant(0, dl, MVT::i32), InGlue};
      Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
                          DeclareRetOps);
      InGlue = Chain.getValue(1);
    } else {
      retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
      assert(retAlignment && "retAlignment is guaranteed to be set");
      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareRetOps[] = {
          Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32),
          DAG.getConstant(resultsz / 8, dl, MVT::i32),
          DAG.getConstant(0, dl, MVT::i32), InGlue};
      Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
                          DeclareRetOps);
      InGlue = Chain.getValue(1);
    }
  }

  bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
  // Set the size of the vararg param byte array if the callee is a variadic
  // function and the variadic part is not empty.
  if (HasVAArgs) {
    SDValue DeclareParamOps[] = {
        VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
        VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
        VADeclareParam.getOperand(4)};
    DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
                    VADeclareParam->getVTList(), DeclareParamOps);
  }

  // If the type of the callsite does not match that of the function, convert
  // the callsite to an indirect call.
  bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);

  // Both indirect calls and libcalls have nullptr Func. In order to distinguish
  // between them we must rely on the call site value which is valid for
  // indirect calls but is always null for libcalls.
  bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;

  if (isa<ExternalSymbolSDNode>(Callee)) {
    Function *CalleeFunc = nullptr;

    // Try to find the callee in the current module.
    Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc);
    assert(CalleeFunc != nullptr && "Libcall callee must be set.");

    // Set the "libcall callee" attribute to indicate that the function
    // must always have a declaration.
    CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
  }

  if (isIndirectCall) {
    // This is indirect function call case : PTX requires a prototype of the
    // form
    // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
    // to be emitted, and the label has to be used as the last arg of call
    // instruction.
    // The prototype is embedded in a string and put as the operand for a
    // CallPrototype SDNode which will print out to the value of the string.
    SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    std::string Proto = getPrototype(
        DL, RetTy, Args, Outs, retAlignment,
        HasVAArgs
            ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
                  CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
            : std::nullopt,
        *CB, UniqueCallSite);
    const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
    SDValue ProtoOps[] = {
        Chain,
        DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
        InGlue,
    };
    Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
    InGlue = Chain.getValue(1);
  }
  // Op to just print "call"
  SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue PrintCallOps[] = {
      Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue};
  // We model convergent calls as separate opcodes.
  unsigned Opcode =
      isIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
  if (CLI.IsConvergent)
    Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
                                              : NVPTXISD::PrintConvergentCall;
  Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
  InGlue = Chain.getValue(1);

  if (ConvertToIndirectCall) {
    // Copy the function ptr to a ptx register and use the register to call the
    // function.
    EVT DestVT = Callee.getValueType();
    MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
    unsigned DestReg =
        RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
    auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
    Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
  }

  // Ops to print out the function name
  SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallVoidOps[] = {Chain, Callee, InGlue};
  Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
  InGlue = Chain.getValue(1);

  // Ops to print out the param list
  SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgBeginOps[] = {Chain, InGlue};
  Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
                      CallArgBeginOps);
  InGlue = Chain.getValue(1);

  for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
       ++i) {
    unsigned opcode;
    if (i == (e - 1))
      opcode = NVPTXISD::LastCallArg;
    else
      opcode = NVPTXISD::CallArg;
    SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue CallArgOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
                            DAG.getConstant(i, dl, MVT::i32), InGlue};
    Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
    InGlue = Chain.getValue(1);
  }
  SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgEndOps[] = {
      Chain, DAG.getConstant(isIndirectCall ? 0 : 1, dl, MVT::i32), InGlue};
  Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
  InGlue = Chain.getValue(1);

  if (isIndirectCall) {
    SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue PrototypeOps[] = {
        Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue};
    Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
    InGlue = Chain.getValue(1);
  }

  SmallVector<SDValue, 16> ProxyRegOps;
  SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
  // An item of the vector is filled if the element does not need a ProxyReg
  // operation on it and should be added to InVals as is. ProxyRegOps and
  // ProxyRegTruncates contain empty/none items at the same index.
  SmallVector<SDValue, 16> RetElts;
  // Temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()`
  // to use the values of `LoadParam`s and to be replaced later when
  // `CALLSEQ_END` is added.
  SmallVector<SDValue, 16> TempProxyRegOps;

1848 // Generate loads from param memory/moves from registers for result
1849   if (Ins.size() > 0) {
1850     SmallVector<EVT, 16> VTs;
1851     SmallVector<uint64_t, 16> Offsets;
1852     ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
1853 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1854
1855 Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1856 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
1857
1858 SmallVector<EVT, 6> LoadVTs;
1859 int VecIdx = -1; // Index of the first element of the vector.
1860
1861 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1862 // 32-bits are sign extended or zero extended, depending on whether
1863 // they are signed or unsigned types.
1864 bool ExtendIntegerRetVal =
1865 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
1866
1867 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
1868 bool needTruncate = false;
1869 EVT TheLoadType = VTs[i];
1870 EVT EltType = Ins[i].VT;
1871 Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
1872 MVT PromotedVT;
1873
1874 if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
1875 TheLoadType = EVT(PromotedVT);
1876 EltType = EVT(PromotedVT);
1877 needTruncate = true;
1878 }
1879
1880 if (ExtendIntegerRetVal) {
1881 TheLoadType = MVT::i32;
1882 EltType = MVT::i32;
1883 needTruncate = true;
1884 } else if (TheLoadType.getSizeInBits() < 16) {
1885 if (VTs[i].isInteger())
1886 needTruncate = true;
1887 EltType = MVT::i16;
1888 }
1889
1890 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1891 // scalar load. In such cases, fall back to byte loads.
1892 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
1893 EltAlign < DL.getABITypeAlign(
1894 TheLoadType.getTypeForEVT(*DAG.getContext()))) {
1895         assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
1896         SDValue Ret = LowerUnalignedLoadRetParam(
1897             DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
1898 ProxyRegOps.push_back(SDValue());
1899 ProxyRegTruncates.push_back(std::optional<MVT>());
1900 RetElts.resize(i);
1901 RetElts.push_back(Ret);
1902
1903 continue;
1904 }
1905
1906 // Record index of the very first element of the vector.
1907 if (VectorInfo[i] & PVF_FIRST) {
1908 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
1909 VecIdx = i;
1910 }
1911
1912 LoadVTs.push_back(EltType);
1913
1914 if (VectorInfo[i] & PVF_LAST) {
1915 unsigned NumElts = LoadVTs.size();
1916 LoadVTs.push_back(MVT::Other);
1917 LoadVTs.push_back(MVT::Glue);
1918         unsigned Op;
1919         switch (NumElts) {
1920         case 1:
1921           Op = NVPTXISD::LoadParam;
1922           break;
1923         case 2:
1924           Op = NVPTXISD::LoadParamV2;
1925           break;
1926         case 4:
1927           Op = NVPTXISD::LoadParamV4;
1928           break;
1929         default:
1930           llvm_unreachable("Invalid vector info.");
1931         }
1932
1933 SDValue LoadOperands[] = {
1934 Chain, DAG.getConstant(1, dl, MVT::i32),
1935 DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue};
1936 SDValue RetVal = DAG.getMemIntrinsicNode(
1937 Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
1938             MachinePointerInfo(), EltAlign,
1939             MachineMemOperand::MOLoad);
1940 
1941 for (unsigned j = 0; j < NumElts; ++j) {
1942 ProxyRegOps.push_back(RetVal.getValue(j));
1943
1944 if (needTruncate)
1945 ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT));
1946 else
1947 ProxyRegTruncates.push_back(std::optional<MVT>());
1948 }
1949
1950 Chain = RetVal.getValue(NumElts);
1951 InGlue = RetVal.getValue(NumElts + 1);
1952
1953 // Cleanup
1954 VecIdx = -1;
1955 LoadVTs.clear();
1956 }
1957 }
1958 }
1959
1960 Chain =
1961 DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl);
1962 InGlue = Chain.getValue(1);
1963
1964 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1965 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1966 // dangling.
1967 for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
1968 if (i < RetElts.size() && RetElts[i]) {
1969 InVals.push_back(RetElts[i]);
1970 continue;
1971 }
1972
1973     SDValue Ret = DAG.getNode(
1974       NVPTXISD::ProxyReg, dl,
1975       DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
1976 { Chain, ProxyRegOps[i], InGlue }
1977 );
1978
1979 Chain = Ret.getValue(1);
1980 InGlue = Ret.getValue(2);
1981
1982 if (ProxyRegTruncates[i]) {
1983 Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret);
1984 }
1985
1986 InVals.push_back(Ret);
1987 }
1988
1989 for (SDValue &T : TempProxyRegOps) {
1990     SDValue Repl = DAG.getNode(
1991         NVPTXISD::ProxyReg, dl,
1992         DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
1993 {Chain, T.getOperand(0), InGlue});
1994 DAG.ReplaceAllUsesWith(T, Repl);
1995 DAG.RemoveDeadNode(T.getNode());
1996
1997 Chain = Repl.getValue(1);
1998 InGlue = Repl.getValue(2);
1999 }
2000
2001   // Set isTailCall to false for now, until we figure out how to express
2002   // tail-call optimization in PTX.
2003 isTailCall = false;
2004 return Chain;
2005}
2006
2007SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
2008                                                     SelectionDAG &DAG) const {
2009
2010 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2011 const Function &Fn = DAG.getMachineFunction().getFunction();
2012
2013 DiagnosticInfoUnsupported NoDynamicAlloca(
2014 Fn,
2015         "Support for dynamic alloca was introduced in PTX ISA version 7.3 "
2016         "and requires target sm_52.",
2017 SDLoc(Op).getDebugLoc());
2018 DAG.getContext()->diagnose(NoDynamicAlloca);
2019 auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()),
2020 Op.getOperand(0)};
2021 return DAG.getMergeValues(Ops, SDLoc());
2022 }
2023
2024 SDValue Chain = Op.getOperand(0);
2025 SDValue Size = Op.getOperand(1);
2026 uint64_t Align = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
2027 SDLoc DL(Op.getNode());
2028
2029   // The size operand of the PTX alloca instruction is 64-bit for m64 and 32-bit for m32.
2030 MVT ValueSizeTy = nvTM->is64Bit() ? MVT::i64 : MVT::i32;
2031
2032 SDValue AllocOps[] = {Chain, DAG.getZExtOrTrunc(Size, DL, ValueSizeTy),
2033 DAG.getTargetConstant(Align, DL, MVT::i32)};
2034 EVT RetTypes[] = {ValueSizeTy, MVT::Other};
2035 return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
2036}
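// Editorial note: the DYNAMIC_STACKALLOC node above is selected to the PTX
// `alloca` instruction (PTX ISA 7.3+), roughly (illustrative, for m64):
//   alloca.u64  %rd2, %rd1, 8;   // %rd1 = size in bytes, 8 = alignment
// The result is a local-space address, converted to a generic pointer
// before use.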
2037
2038SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
2039                                               SelectionDAG &DAG) const {
2040 SDLoc DL(Op.getNode());
2041 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2042 const Function &Fn = DAG.getMachineFunction().getFunction();
2043
2044 DiagnosticInfoUnsupported NoStackRestore(
2045 Fn,
2046 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
2047 ">= sm_52.",
2048 DL.getDebugLoc());
2049 DAG.getContext()->diagnose(NoStackRestore);
2050 return Op.getOperand(0);
2051 }
2052
2053 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2054 SDValue Chain = Op.getOperand(0);
2055   SDValue Ptr = Op.getOperand(1);
2056   SDValue ASC = DAG.getAddrSpaceCast(DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC,
2057                                      ADDRESS_SPACE_LOCAL);
2058   return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC});
2059}
2060
2061SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
2062                                            SelectionDAG &DAG) const {
2063 SDLoc DL(Op.getNode());
2064 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2065 const Function &Fn = DAG.getMachineFunction().getFunction();
2066
2067 DiagnosticInfoUnsupported NoStackSave(
2068 Fn,
2069 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
2070 "sm_52.",
2071 DL.getDebugLoc());
2072 DAG.getContext()->diagnose(NoStackSave);
2073 auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)};
2074 return DAG.getMergeValues(Ops, DL);
2075 }
2076
2077 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2078 SDValue Chain = Op.getOperand(0);
2079 SDValue SS =
2080 DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
2081 SDValue ASC = DAG.getAddrSpaceCast(
2082 DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
2083 return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL);
2084}
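// Editorial note: STACKSAVE/STACKRESTORE select to the PTX stack
// manipulation instructions (illustrative):
//   stacksave.u64     %rd1;
//   stackrestore.u64  %rd1;
// The address-space casts above are needed because PTX saves/restores
// local-space pointers while LLVM's stacksave/stackrestore use generic ones.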
2085
2086// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
2087// (see LegalizeDAG.cpp). This is slow and uses local memory.
2088// We use extract/insert/build vector instead, just as LegalizeOp() did in LLVM 2.5.
2089SDValue
2090NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2091 SDNode *Node = Op.getNode();
2092   SDLoc dl(Node);
2093   SmallVector<SDValue, 8> Ops;
2094   unsigned NumOperands = Node->getNumOperands();
2095 for (unsigned i = 0; i < NumOperands; ++i) {
2096 SDValue SubOp = Node->getOperand(i);
2097 EVT VVT = SubOp.getNode()->getValueType(0);
2098 EVT EltVT = VVT.getVectorElementType();
2099 unsigned NumSubElem = VVT.getVectorNumElements();
2100 for (unsigned j = 0; j < NumSubElem; ++j) {
2101 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
2102 DAG.getIntPtrConstant(j, dl)));
2103 }
2104 }
2105 return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
2106}
2107
2108SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2109 // Handle bitcasting from v2i8 without hitting the default promotion
2110 // strategy which goes through stack memory.
2111 EVT FromVT = Op->getOperand(0)->getValueType(0);
2112 if (FromVT != MVT::v2i8) {
2113 return Op;
2114 }
2115
2116 // Pack vector elements into i16 and bitcast to final type
2117 SDLoc DL(Op);
2118 SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2119 Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
2120 SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2121 Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
2122 SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
2123 SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
2124 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
2125 SDValue AsInt = DAG.getNode(
2126 ISD::OR, DL, MVT::i16,
2127 {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
2128 EVT ToVT = Op->getValueType(0);
2129 return MaybeBitcast(DAG, DL, ToVT, AsInt);
2130}
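// Worked example (editorial): bitcast <2 x i8> <i8 0x34, i8 0x12> to i16
// packs as (zext(0x12) << 8) | zext(0x34) = 0x1234, i.e. element 0 lands in
// the low byte of the i16, matching the in-register vector layout.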
2131
2132// We can initialize a constant f16x2/v2i16/v4i8 with a single .b32 move.
2133// Normally it would get lowered as two constant loads and a vector-packing
2134// move. Instead we want just a constant move:
2135// mov.b32 %r2, 0x40003C00
2136SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2137 SelectionDAG &DAG) const {
2138 EVT VT = Op->getValueType(0);
2139 if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
2140 return Op;
2141 SDLoc DL(Op);
2142
2143 if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
2144 return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
2145 isa<ConstantFPSDNode>(Operand);
2146 })) {
2147 if (VT != MVT::v4i8)
2148 return Op;
2149 // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
2150 // to optimize calculation of constant parts.
2151 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2152 uint64_t SelectionValue) -> SDValue {
2153 SDValue L = Left;
2154 SDValue R = Right;
2155 if (Cast) {
2156 L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
2157 R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
2158 }
2159 return DAG.getNode(
2160 NVPTXISD::PRMT, DL, MVT::v4i8,
2161 {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
2162 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2163 };
2164 auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
2165 auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
2166 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2167 return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
2168 }
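  // Editorial note on the selectors above: PRMT reads its two b32 sources as
  // an 8-byte array {a0..a3, b0..b3}; each selector nibble, low to high,
  // picks the byte for result lanes 0..3. So 0x3340 = {0,4,3,3} puts byte 0
  // of the left operand in lane 0 and byte 0 of the right operand in lane 1
  // (index 4 = b0); the top two lanes are don't-cares. The final
  // 0x5410 = {0,1,4,5} packs {op0,op1,op2,op3} into one i32, so the
  // don't-care bytes never reach the result.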
2169
2170   // Get the value of the Nth operand as an APInt(32). Undef values are treated as 0.
2171 auto GetOperand = [](SDValue Op, int N) -> APInt {
2172 const SDValue &Operand = Op->getOperand(N);
2173 EVT VT = Op->getValueType(0);
2174 if (Operand->isUndef())
2175 return APInt(32, 0);
2176 APInt Value;
2177 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2178 Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
2179 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2180 Value = Operand->getAsAPIntVal();
2181 else
2182 llvm_unreachable("Unsupported type");
2183     // i8 values are carried around as i16, so we need to zero out the upper
2184     // bits so they do not get in the way of combining individual byte values.
2185 if (VT == MVT::v4i8)
2186 Value = Value.trunc(8);
2187 return Value.zext(32);
2188 };
2189 APInt Value;
2190 if (Isv2x16VT(VT)) {
2191 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
2192 } else if (VT == MVT::v4i8) {
2193 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
2194 GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
2195 } else {
2196 llvm_unreachable("Unsupported type");
2197 }
2198 SDValue Const = DAG.getConstant(Value, DL, MVT::i32);
2199 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const);
2200}
2201
2202SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2203 SelectionDAG &DAG) const {
2204 SDValue Index = Op->getOperand(1);
2205 SDValue Vector = Op->getOperand(0);
2206 SDLoc DL(Op);
2207 EVT VectorVT = Vector.getValueType();
2208
2209 if (VectorVT == MVT::v4i8) {
2210 SDValue BFE =
2211 DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
2212 {Vector,
2213 DAG.getNode(ISD::MUL, DL, MVT::i32,
2214 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2215 DAG.getConstant(8, DL, MVT::i32)),
2216 DAG.getConstant(8, DL, MVT::i32)});
2217 return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
2218 }
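  // E.g. extracting element 2 of a v4i8 becomes BFE(v, 16, 8): an 8-bit
  // field starting at bit offset 2 * 8 of the i32 that carries the vector.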
2219
2220 // Constant index will be matched by tablegen.
2221 if (isa<ConstantSDNode>(Index.getNode()))
2222 return Op;
2223
2224 // Extract individual elements and select one of them.
2225 assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
2226 EVT EltVT = VectorVT.getVectorElementType();
2227
2228 SDLoc dl(Op.getNode());
2229 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2230 DAG.getIntPtrConstant(0, dl));
2231 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2232 DAG.getIntPtrConstant(1, dl));
2233   return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
2234                          ISD::SETEQ);
2235}
2236
2237SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2238 SelectionDAG &DAG) const {
2239 SDValue Vector = Op->getOperand(0);
2240 EVT VectorVT = Vector.getValueType();
2241
2242 if (VectorVT != MVT::v4i8)
2243 return Op;
2244 SDLoc DL(Op);
2245 SDValue Value = Op->getOperand(1);
2246 if (Value->isUndef())
2247 return Vector;
2248
2249 SDValue Index = Op->getOperand(2);
2250
2251 SDValue BFI =
2252 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2253 {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
2254 DAG.getNode(ISD::MUL, DL, MVT::i32,
2255 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2256 DAG.getConstant(8, DL, MVT::i32)),
2257 DAG.getConstant(8, DL, MVT::i32)});
2258 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
2259}
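// E.g. inserting into element 1 of a v4i8 becomes BFI(val, vec, 8, 8):
// deposit the low 8 bits of val into bits [15:8] of the carrying i32.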
2260
2261SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2262 SelectionDAG &DAG) const {
2263 SDValue V1 = Op.getOperand(0);
2264 EVT VectorVT = V1.getValueType();
2265 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2266 return Op;
2267
2268 // Lower shuffle to PRMT instruction.
2269 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
2270 SDValue V2 = Op.getOperand(1);
2271 uint32_t Selector = 0;
2272 for (auto I : llvm::enumerate(SVN->getMask())) {
2273 if (I.value() != -1) // -1 is a placeholder for undef.
2274 Selector |= (I.value() << (I.index() * 4));
2275 }
2276
2277 SDLoc DL(Op);
2278 return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2279 DAG.getConstant(Selector, DL, MVT::i32),
2280 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
2281}
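// Worked example (editorial): shufflevector <4 x i8> %a, %b,
// <i32 0, i32 4, i32 1, i32 5> gives Selector = 0x5140 (one nibble per
// result lane, lane 0 in the lowest nibble; indices 4-7 select from %b),
// so the whole shuffle becomes a single prmt.b32.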
2282/// LowerShiftRightParts - Lower SRL_PARTS and SRA_PARTS, which either
2283/// 1) return two i32 values and take a 2 x i32 value to shift plus a shift
2284/// amount, or
2285/// 2) return two i64 values and take a 2 x i64 value to shift plus a shift
2286/// amount.
2287SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2288 SelectionDAG &DAG) const {
2289 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2290 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2291
2292 EVT VT = Op.getValueType();
2293 unsigned VTBits = VT.getSizeInBits();
2294 SDLoc dl(Op);
2295 SDValue ShOpLo = Op.getOperand(0);
2296 SDValue ShOpHi = Op.getOperand(1);
2297 SDValue ShAmt = Op.getOperand(2);
2298 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2299
2300 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2301 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2302 // {dHi, dLo} = {aHi, aLo} >> Amt
2303 // dHi = aHi >> Amt
2304 // dLo = shf.r.clamp aLo, aHi, Amt
2305
2306 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2307 SDValue Lo =
2308 DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2309
2310 SDValue Ops[2] = { Lo, Hi };
2311 return DAG.getMergeValues(Ops, dl);
2312 }
2313 else {
2314 // {dHi, dLo} = {aHi, aLo} >> Amt
2315 // - if (Amt>=size) then
2316 // dLo = aHi >> (Amt-size)
2317 // dHi = aHi >> Amt (this is either all 0 or all 1)
2318 // else
2319 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2320 // dHi = aHi >> Amt
2321
2322 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2323 DAG.getConstant(VTBits, dl, MVT::i32),
2324 ShAmt);
2325 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
2326 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2327 DAG.getConstant(VTBits, dl, MVT::i32));
2328 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
2329 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2330 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
2331
2332 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2333 DAG.getConstant(VTBits, dl, MVT::i32),
2334 ISD::SETGE);
2335 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2336 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2337
2338 SDValue Ops[2] = { Lo, Hi };
2339 return DAG.getMergeValues(Ops, dl);
2340 }
2341}
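// Worked example (editorial) of the sm_35+ path for i32 parts:
// {aHi, aLo} = {0x00000001, 0x80000000}, Amt = 4:
//   dLo = shf.r.clamp(aLo, aHi, 4) = (aHi << 28) | (aLo >> 4) = 0x18000000
//   dHi = aHi >> 4 = 0x00000000
// i.e. 0x0000000180000000 >> 4 = 0x0000000018000000.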
2342
2343/// LowerShiftLeftParts - Lower SHL_PARTS, which either
2344/// 1) returns two i32 values and takes a 2 x i32 value to shift plus a shift
2345/// amount, or
2346/// 2) returns two i64 values and takes a 2 x i64 value to shift plus a shift
2347/// amount.
2348SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2349 SelectionDAG &DAG) const {
2350 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2351 assert(Op.getOpcode() == ISD::SHL_PARTS);
2352
2353 EVT VT = Op.getValueType();
2354 unsigned VTBits = VT.getSizeInBits();
2355 SDLoc dl(Op);
2356 SDValue ShOpLo = Op.getOperand(0);
2357 SDValue ShOpHi = Op.getOperand(1);
2358 SDValue ShAmt = Op.getOperand(2);
2359
2360 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2361 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2362 // {dHi, dLo} = {aHi, aLo} << Amt
2363 // dHi = shf.l.clamp aLo, aHi, Amt
2364 // dLo = aLo << Amt
2365
2366 SDValue Hi =
2367 DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2368 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2369
2370 SDValue Ops[2] = { Lo, Hi };
2371 return DAG.getMergeValues(Ops, dl);
2372 }
2373 else {
2374 // {dHi, dLo} = {aHi, aLo} << Amt
2375 // - if (Amt>=size) then
2376 // dLo = aLo << Amt (all 0)
2377   //        dHi = aLo << (Amt-size)
2378 // else
2379 // dLo = aLo << Amt
2380 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2381
2382 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2383 DAG.getConstant(VTBits, dl, MVT::i32),
2384 ShAmt);
2385 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
2386 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2387 DAG.getConstant(VTBits, dl, MVT::i32));
2388 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
2389 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2390 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
2391
2392 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2393 DAG.getConstant(VTBits, dl, MVT::i32),
2394 ISD::SETGE);
2395 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2396 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2397
2398 SDValue Ops[2] = { Lo, Hi };
2399 return DAG.getMergeValues(Ops, dl);
2400 }
2401}
2402
2403/// If the types match, convert the generic copysign to the NVPTXISD version,
2404/// otherwise bail, ensuring that mismatched cases are properly expanded.
2405SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2406 SelectionDAG &DAG) const {
2407 EVT VT = Op.getValueType();
2408 SDLoc DL(Op);
2409
2410 SDValue In1 = Op.getOperand(0);
2411 SDValue In2 = Op.getOperand(1);
2412 EVT SrcVT = In2.getValueType();
2413
2414 if (!SrcVT.bitsEq(VT))
2415 return SDValue();
2416
2417 return DAG.getNode(NVPTXISD::FCOPYSIGN, DL, VT, In1, In2);
2418}
2419
2420SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2421 EVT VT = Op.getValueType();
2422
2423 if (VT == MVT::f32)
2424 return LowerFROUND32(Op, DAG);
2425
2426 if (VT == MVT::f64)
2427 return LowerFROUND64(Op, DAG);
2428
2429 llvm_unreachable("unhandled type");
2430}
2431
2432// This is the rounding method used in CUDA libdevice, shown here as C-like code:
2433// float roundf(float A)
2434// {
2435// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2436// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2437// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2438// }
2439SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2440 SelectionDAG &DAG) const {
2441 SDLoc SL(Op);
2442 SDValue A = Op.getOperand(0);
2443 EVT VT = Op.getValueType();
2444
2445 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2446
2447 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2448 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2449 const unsigned SignBitMask = 0x80000000;
2450 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2451 DAG.getConstant(SignBitMask, SL, MVT::i32));
2452 const unsigned PointFiveInBits = 0x3F000000;
2453 SDValue PointFiveWithSignRaw =
2454 DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2455 DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2456 SDValue PointFiveWithSign =
2457 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2458 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2459 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2460
2461 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2462 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2463 SDValue IsLarge =
2464 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2465 ISD::SETOGT);
2466 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2467
2468 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2469 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2470 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2471 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2472 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2473}
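// Worked examples (editorial): round(2.5f) -> trunc(2.5 + 0.5) = 3.0;
// round(-1.5f) -> trunc(-1.5 - 0.5) = -2.0; round(0.3f) -> |A| < 0.5, so
// the result is trunc(0.3f) = 0.0. Values with |A| > 2^23 are already
// integral in f32 and are returned unchanged.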
2474
2475// The implementation of round(double) is similar to that of round(float) in
2476// that they both separate the value range into three regions and use a method
2477// specific to the region to round the values. However, round(double) first
2478// calculates the round of the absolute value and then adds the sign back while
2479// round(float) directly rounds the value with sign.
2480SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2481 SelectionDAG &DAG) const {
2482 SDLoc SL(Op);
2483 SDValue A = Op.getOperand(0);
2484 EVT VT = Op.getValueType();
2485
2486 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2487
2488 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2489 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2490 DAG.getConstantFP(0.5, SL, VT));
2491 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2492
2493 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2494 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2495 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2496 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2497 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2498 DAG.getConstantFP(0, SL, VT),
2499 RoundedA);
2500
2501   // Add the sign back to RoundedA.
2502   RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2504
2505 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2506 SDValue IsLarge =
2507 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2508 ISD::SETOGT);
2509 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2510}
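// Editorial trace for round(-0.3): RoundedA = trunc(0.3 + 0.5) = 0.0, the
// |A| < 0.5 select keeps 0.0, and FCOPYSIGN restores the sign, giving -0.0
// as IEEE round() requires.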
2511
2512static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2513  EVT VT = N->getValueType(0);
2514 EVT NVT = MVT::f32;
2515 if (VT.isVector()) {
2516 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2517 }
2518 SDLoc DL(N);
2519 SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
2520 SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
2521 SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
2522 return DAG.getFPExtendOrRound(Res, DL, VT);
2523}
2524
2525SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2526 SelectionDAG &DAG) const {
2527 if (useF32FTZ(DAG.getMachineFunction())) {
2528 return PromoteBinOpToF32(Op.getNode(), DAG);
2529 }
2530 return Op;
2531}
2532
2533SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2534 SelectionDAG &DAG) const {
2535 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2536
2537 if (Op.getValueType() == MVT::bf16) {
2538 SDLoc Loc(Op);
2539 return DAG.getNode(
2540 ISD::FP_ROUND, Loc, MVT::bf16,
2541 DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
2542 DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
2543 }
2544
2545 // Everything else is considered legal.
2546 return Op;
2547}
2548
2549SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2550 SelectionDAG &DAG) const {
2551 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2552
2553 if (Op.getOperand(0).getValueType() == MVT::bf16) {
2554 SDLoc Loc(Op);
2555 return DAG.getNode(
2556 Op.getOpcode(), Loc, Op.getValueType(),
2557 DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
2558 }
2559
2560 // Everything else is considered legal.
2561 return Op;
2562}
2563
2564SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2565 SelectionDAG &DAG) const {
2566 EVT NarrowVT = Op.getValueType();
2567 SDValue Wide = Op.getOperand(0);
2568 EVT WideVT = Wide.getValueType();
2569 if (NarrowVT.getScalarType() == MVT::bf16) {
2570 const TargetLowering *TLI = STI.getTargetLowering();
2571 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2572 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2573 }
2574 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2575 // This combination was the first to support f32 -> bf16.
2576 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2577 if (WideVT.getScalarType() == MVT::f32) {
2578 return Op;
2579 }
2580 if (WideVT.getScalarType() == MVT::f64) {
2581 SDLoc Loc(Op);
2582 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2583           // the hardware f32 -> bf16 instruction.
2584           SDValue rod = TLI->expandRoundInexactToOdd(
2585               WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32)
2586 : MVT::f32,
2587 Wide, Loc, DAG);
2588 return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2589 }
2590 }
2591 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2592 }
2593 }
2594
2595 // Everything else is considered legal.
2596 return Op;
2597}
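// Editorial note: the legal f32 -> bf16 case above selects to PTX
// cvt.rn.bf16.f32 (sm_80 / PTX 7.0+). An f64 source is first narrowed with
// round-inexact-to-odd so that the two-step f64 -> f32 -> bf16 conversion
// cannot double-round.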
2598
2599SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2600 SelectionDAG &DAG) const {
2601 SDValue Narrow = Op.getOperand(0);
2602 EVT NarrowVT = Narrow.getValueType();
2603 EVT WideVT = Op.getValueType();
2604 if (NarrowVT.getScalarType() == MVT::bf16) {
2605 if (WideVT.getScalarType() == MVT::f32 &&
2606 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2607 SDLoc Loc(Op);
2608 return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2609 }
2610 if (WideVT.getScalarType() == MVT::f64 &&
2611 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2612 EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32)
2613 : MVT::f32;
2614 SDLoc Loc(Op);
2615 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2616 Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2617 } else {
2618 Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2619 }
2620 return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op);
2621 }
2622 }
2623
2624 // Everything else is considered legal.
2625 return Op;
2626}
2627
2628static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2629  SDLoc DL(Op);
2630 if (Op.getValueType() != MVT::v2i16)
2631 return Op;
2632 EVT EltVT = Op.getValueType().getVectorElementType();
2633 SmallVector<SDValue> VecElements;
2634 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2635 SmallVector<SDValue> ScalarArgs;
2636 llvm::transform(Op->ops(), std::back_inserter(ScalarArgs),
2637 [&](const SDUse &O) {
2638 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
2639 O.get(), DAG.getIntPtrConstant(I, DL));
2640 });
2641 VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs));
2642 }
2643 SDValue V =
2644 DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements);
2645 return V;
2646}
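// E.g. (add v2i16 %a, %b) is rebuilt as two scalar i16 adds on the
// extracted lanes plus a BUILD_VECTOR; the ops routed here from
// LowerOperation below have no packed 16x2 form.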
2647
2648SDValue
2649NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2650  switch (Op.getOpcode()) {
2651 case ISD::RETURNADDR:
2652 return SDValue();
2653 case ISD::FRAMEADDR:
2654 return SDValue();
2655 case ISD::GlobalAddress:
2656 return LowerGlobalAddress(Op, DAG);
2657  case ISD::INTRINSIC_W_CHAIN:
2658    return Op;
2659 case ISD::BUILD_VECTOR:
2660 return LowerBUILD_VECTOR(Op, DAG);
2661 case ISD::BITCAST:
2662 return LowerBITCAST(Op, DAG);
2663  case ISD::EXTRACT_SUBVECTOR:
2664    return Op;
2665  case ISD::EXTRACT_VECTOR_ELT:
2666    return LowerEXTRACT_VECTOR_ELT(Op, DAG);
2667  case ISD::INSERT_VECTOR_ELT:
2668    return LowerINSERT_VECTOR_ELT(Op, DAG);
2669  case ISD::VECTOR_SHUFFLE:
2670    return LowerVECTOR_SHUFFLE(Op, DAG);
2671  case ISD::CONCAT_VECTORS:
2672    return LowerCONCAT_VECTORS(Op, DAG);
2673 case ISD::STORE:
2674 return LowerSTORE(Op, DAG);
2675 case ISD::LOAD:
2676 return LowerLOAD(Op, DAG);
2677 case ISD::SHL_PARTS:
2678 return LowerShiftLeftParts(Op, DAG);
2679 case ISD::SRA_PARTS:
2680 case ISD::SRL_PARTS:
2681 return LowerShiftRightParts(Op, DAG);
2682 case ISD::SELECT:
2683 return LowerSelect(Op, DAG);
2684 case ISD::FROUND:
2685 return LowerFROUND(Op, DAG);
2686 case ISD::FCOPYSIGN:
2687 return LowerFCOPYSIGN(Op, DAG);
2688 case ISD::SINT_TO_FP:
2689 case ISD::UINT_TO_FP:
2690 return LowerINT_TO_FP(Op, DAG);
2691 case ISD::FP_TO_SINT:
2692 case ISD::FP_TO_UINT:
2693 return LowerFP_TO_INT(Op, DAG);
2694 case ISD::FP_ROUND:
2695 return LowerFP_ROUND(Op, DAG);
2696 case ISD::FP_EXTEND:
2697 return LowerFP_EXTEND(Op, DAG);
2698 case ISD::BR_JT:
2699 return LowerBR_JT(Op, DAG);
2700 case ISD::VAARG:
2701 return LowerVAARG(Op, DAG);
2702 case ISD::VASTART:
2703 return LowerVASTART(Op, DAG);
2704 case ISD::ABS:
2705 case ISD::SMIN:
2706 case ISD::SMAX:
2707 case ISD::UMIN:
2708 case ISD::UMAX:
2709 case ISD::ADD:
2710 case ISD::SUB:
2711 case ISD::MUL:
2712 case ISD::SHL:
2713 case ISD::SREM:
2714 case ISD::UREM:
2715 return LowerVectorArith(Op, DAG);
2716  case ISD::DYNAMIC_STACKALLOC:
2717    return LowerDYNAMIC_STACKALLOC(Op, DAG);
2718 case ISD::STACKRESTORE:
2719 return LowerSTACKRESTORE(Op, DAG);
2720 case ISD::STACKSAVE:
2721 return LowerSTACKSAVE(Op, DAG);
2722 case ISD::CopyToReg:
2723 return LowerCopyToReg_128(Op, DAG);
2724 case ISD::FADD:
2725 case ISD::FSUB:
2726 case ISD::FMUL:
2727 // Used only for bf16 on SM80, where we select fma for non-ftz operation
2728 return PromoteBinOpIfF32FTZ(Op, DAG);
2729
2730 default:
2731 llvm_unreachable("Custom lowering not defined for operation");
2732 }
2733}
2734
2735SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
2736 SDLoc DL(Op);
2737 SDValue Chain = Op.getOperand(0);
2738 const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
2739 SDValue Index = Op.getOperand(2);
2740
2741  unsigned JId = JT->getIndex();
2742  MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
2743  ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
2744
2745 SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
2746
2747 // Generate BrxStart node
2748 SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
2749 Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
2750
2751 // Generate BrxItem nodes
2752 assert(!MBBs.empty());
2753 for (MachineBasicBlock *MBB : MBBs.drop_back())
2754 Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
2755 DAG.getBasicBlock(MBB), Chain.getValue(1));
2756
2757 // Generate BrxEnd nodes
2758 SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
2759 IdV, Chain.getValue(1)};
2760 SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
2761
2762 return BrxEnd;
2763}
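// The BrxStart/BrxItem/BrxEnd chain prints an inline branch-target list
// followed by an indexed branch, roughly (illustrative):
//   $L_brx_0: .branchtargets $L__BB0_2, $L__BB0_3, $L__BB0_4;
//   brx.idx %r1, $L_brx_0;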
2764
2765// This will prevent AsmPrinter from trying to print the jump tables itself.
2766unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
2767  return MachineJumpTableInfo::EK_Inline;
2768}
2769
2770// This function is almost a copy of SelectionDAG::expandVAArg().
2771// The only difference is that this one produces loads from the local address space.
2772SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
2773 const TargetLowering *TLI = STI.getTargetLowering();
2774 SDLoc DL(Op);
2775
2776 SDNode *Node = Op.getNode();
2777 const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
2778 EVT VT = Node->getValueType(0);
2779 auto *Ty = VT.getTypeForEVT(*DAG.getContext());
2780 SDValue Tmp1 = Node->getOperand(0);
2781 SDValue Tmp2 = Node->getOperand(1);
2782 const MaybeAlign MA(Node->getConstantOperandVal(3));
2783
2784 SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
2785 Tmp1, Tmp2, MachinePointerInfo(V));
2786 SDValue VAList = VAListLoad;
2787
2788 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
2789 VAList = DAG.getNode(
2790 ISD::ADD, DL, VAList.getValueType(), VAList,
2791 DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
2792
2793 VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList,
2794 DAG.getSignedConstant(-(int64_t)MA->value(), DL,
2795 VAList.getValueType()));
2796 }
2797
2798 // Increment the pointer, VAList, to the next vaarg
2799  Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
2800                     DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty),
2801                                     DL, VAList.getValueType()));
2802
2803 // Store the incremented VAList to the legalized pointer
2804  Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
2805                      MachinePointerInfo(V));
2806
2807  const Value *SrcV =
2808      Constant::getNullValue(PointerType::get(Ty, ADDRESS_SPACE_LOCAL));
2809
2810 // Load the actual argument out of the pointer VAList
2811 return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
2812}
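// Editorial trace: for a vaarg with alignment MA = 8, VAList is rounded up
// via (VAList + 7) & ~7, the argument is loaded from that address, and the
// stored cursor advances by the argument's alloc size.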
2813
2814SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
2815 const TargetLowering *TLI = STI.getTargetLowering();
2816 SDLoc DL(Op);
2817 EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
2818
2819 // Store the address of unsized array <function>_vararg[] in the ap object.
2820 SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
2821 SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
2822
2823 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
2824 return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
2825 MachinePointerInfo(SV));
2826}
2827
2828SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
2829 SDValue Op0 = Op->getOperand(0);
2830 SDValue Op1 = Op->getOperand(1);
2831 SDValue Op2 = Op->getOperand(2);
2832 SDLoc DL(Op.getNode());
2833
2834 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
2835
2836 Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
2837 Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
2838 SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
2839 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
2840
2841 return Trunc;
2842}
2843
2844SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
2845 if (Op.getValueType() == MVT::i1)
2846 return LowerLOADi1(Op, DAG);
2847
2848  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on the legalizer to
2849  // handle unaligned loads and have to handle them here.
2850 EVT VT = Op.getValueType();
2851 if (Isv2x16VT(VT) || VT == MVT::v4i8) {
2852 LoadSDNode *Load = cast<LoadSDNode>(Op);
2853 EVT MemVT = Load->getMemoryVT();
2854    if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
2855                                        MemVT, *Load->getMemOperand())) {
2856 SDValue Ops[2];
2857 std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
2858 return DAG.getMergeValues(Ops, SDLoc(Op));
2859 }
2860 }
2861
2862 return SDValue();
2863}
2864
2865// v = ld i1* addr
2866// =>
2867// v1 = ld i8* addr (-> i16)
2868// v = trunc i16 to i1
2869SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
2870 SDNode *Node = Op.getNode();
2871 LoadSDNode *LD = cast<LoadSDNode>(Node);
2872 SDLoc dl(Node);
2873 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
2874 assert(Node->getValueType(0) == MVT::i1 &&
2875 "Custom lowering for i1 load only");
2876 SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
2877 LD->getBasePtr(), LD->getPointerInfo(),
2878 MVT::i8, LD->getAlign(),
2879 LD->getMemOperand()->getFlags());
2880 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
2881 // The legalizer (the caller) is expecting two values from the legalized
2882 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
2883 // in LegalizeDAG.cpp which also uses MergeValues.
2884 SDValue Ops[] = { result, LD->getChain() };
2885 return DAG.getMergeValues(Ops, dl);
2886}
2887
2888SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
2889 StoreSDNode *Store = cast<StoreSDNode>(Op);
2890 EVT VT = Store->getMemoryVT();
2891
2892 if (VT == MVT::i1)
2893 return LowerSTOREi1(Op, DAG);
2894
2895  // v2f16 is legal, so we can't rely on the legalizer to handle unaligned
2896  // stores and have to handle them here.
2897 if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
2898      !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
2899                                      VT, *Store->getMemOperand()))
2900 return expandUnalignedStore(Store, DAG);
2901
2902 // v2f16, v2bf16 and v2i16 don't need special handling.
2903 if (Isv2x16VT(VT) || VT == MVT::v4i8)
2904 return SDValue();
2905
2906 if (VT.isVector())
2907 return LowerSTOREVector(Op, DAG);
2908
2909 return SDValue();
2910}
2911
2912SDValue
2913NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2914 SDNode *N = Op.getNode();
2915 SDValue Val = N->getOperand(1);
2916 SDLoc DL(N);
2917 EVT ValVT = Val.getValueType();
2918
2919 auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
2920 if (!NumEltsAndEltVT)
2921 return SDValue();
2922 auto [NumElts, EltVT] = NumEltsAndEltVT.value();
2923
2924 MemSDNode *MemSD = cast<MemSDNode>(N);
2925 const DataLayout &TD = DAG.getDataLayout();
2926
2927 Align Alignment = MemSD->getAlign();
2928 Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
2929 if (Alignment < PrefAlign) {
2930 // This store is not sufficiently aligned, so bail out and let this vector
2931 // store be scalarized. Note that we may still be able to emit smaller
2932 // vector stores. For example, if we are storing a <4 x float> with an
2933 // alignment of 8, this check will fail but the legalizer will try again
2934 // with 2 x <2 x float>, which will succeed with an alignment of 8.
2935 return SDValue();
2936 }
2937
2938 // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
2939 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
2940 // stored type to i16 and propagate the "real" type as the memory type.
2941 bool NeedExt = false;
2942 if (EltVT.getSizeInBits() < 16)
2943 NeedExt = true;
2944
2945 unsigned Opcode = 0;
2946 switch (NumElts) {
2947 default:
2948 return SDValue();
2949 case 2:
2950 Opcode = NVPTXISD::StoreV2;
2951 break;
2952 case 4:
2953 Opcode = NVPTXISD::StoreV4;
2954 break;
2955 }
2956
2957  SmallVector<SDValue, 8> Ops;
2958
2959 // First is the chain
2960 Ops.push_back(N->getOperand(0));
2961
2962 // Then the split values
2963 assert(NumElts <= ValVT.getVectorNumElements() &&
2964 "NumElts should not increase, only decrease or stay the same.");
2965 if (NumElts < ValVT.getVectorNumElements()) {
2966 // If the number of elements has decreased, getVectorLoweringShape has
2967 // upsized the element types
2968 assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
2969 EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
2970 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2971 // stored as b32s
2972 unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
2973 for (unsigned i = 0; i < NumElts; ++i) {
2974 SmallVector<SDValue, 4> SubVectorElts;
2975 DAG.ExtractVectorElements(Val, SubVectorElts, i * NumEltsPerSubVector,
2976 NumEltsPerSubVector);
2977 SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
2978 Ops.push_back(SubVector);
2979 }
2980 } else {
2981 for (unsigned i = 0; i < NumElts; ++i) {
2982 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2983 DAG.getIntPtrConstant(i, DL));
2984 if (NeedExt)
2985 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
2986 Ops.push_back(ExtVal);
2987 }
2988 }
2989
2990 // Then any remaining arguments
2991 Ops.append(N->op_begin() + 2, N->op_end());
2992
2993 SDValue NewSt =
2994 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
2995 MemSD->getMemoryVT(), MemSD->getMemOperand());
2996
2997 // return DCI.CombineTo(N, NewSt, true);
2998 return NewSt;
2999}
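// E.g. a sufficiently aligned store <4 x float> becomes NVPTXISD::StoreV4
// and selects to a single st.v4.f32; an <8 x half> store is first regrouped
// into four v2f16 lanes stored as b32 values (illustrative).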
3000
3001// st i1 v, addr
3002// =>
3003// v1 = zxt v to i16
3004// st.u8 i16, addr
3005SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3006 SDNode *Node = Op.getNode();
3007 SDLoc dl(Node);
3008 StoreSDNode *ST = cast<StoreSDNode>(Node);
3009 SDValue Tmp1 = ST->getChain();
3010 SDValue Tmp2 = ST->getBasePtr();
3011 SDValue Tmp3 = ST->getValue();
3012 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3013 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
3014 SDValue Result =
3015 DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8,
3016 ST->getAlign(), ST->getMemOperand()->getFlags());
3017 return Result;
3018}
3019
3020SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3021 SelectionDAG &DAG) const {
3022 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3023 // operand so that it can pass the legalization.
3024
3025 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3026 "Custom lowering for 128-bit CopyToReg only");
3027
3028 SDNode *Node = Op.getNode();
3029 SDLoc DL(Node);
3030
3031 SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
3032 SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3033 DAG.getIntPtrConstant(0, DL));
3034 SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3035 DAG.getIntPtrConstant(1, DL));
3036
3037  SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
3038  SmallVector<EVT, 3> ResultsType(Node->values());
3039
3040 NewOps[0] = Op->getOperand(0); // Chain
3041 NewOps[1] = Op->getOperand(1); // Dst Reg
3042 NewOps[2] = Lo; // Lower 64-bit
3043 NewOps[3] = Hi; // Higher 64-bit
3044 if (Op.getNumOperands() == 4)
3045 NewOps[4] = Op->getOperand(3); // Glue if exists
3046
3047 return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
3048}
3049
3050unsigned NVPTXTargetLowering::getNumRegisters(
3051 LLVMContext &Context, EVT VT,
3052 std::optional<MVT> RegisterVT = std::nullopt) const {
3053 if (VT == MVT::i128 && RegisterVT == MVT::i128)
3054 return 1;
3055 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
3056}
3057
3058bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3059 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3060 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3061 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
3062 Parts[0] = Val;
3063 return true;
3064 }
3065 return false;
3066}
3067
3068// This creates a target external symbol for a function parameter.
3069// The name of the symbol is composed from its index and the function name.
3070// A negative index corresponds to the special parameter (unsized array) used
3071// for passing variable arguments.
3072SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
3073 EVT v) const {
3074  StringRef SavedStr = nvTM->getStrPool().save(
3075      getParamName(&DAG.getMachineFunction().getFunction(), idx));
3076  return DAG.getTargetExternalSymbol(SavedStr.data(), v);
3077}
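// E.g. parameter 0 of function foo yields the symbol "foo_param_0", and the
// vararg pseudo-parameter (idx = -1) yields "foo_vararg" (illustrative).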
3078
3079SDValue NVPTXTargetLowering::LowerFormalArguments(
3080    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
3081 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
3082    SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
3083  MachineFunction &MF = DAG.getMachineFunction();
3084  const DataLayout &DL = DAG.getDataLayout();
3085 auto PtrVT = getPointerTy(DAG.getDataLayout());
3086
3087 const Function *F = &MF.getFunction();
3088 const AttributeList &PAL = F->getAttributes();
3089 const TargetLowering *TLI = STI.getTargetLowering();
3090
3091 SDValue Root = DAG.getRoot();
3092 std::vector<SDValue> OutChains;
3093
3094 bool isABI = (STI.getSmVersion() >= 20);
3095 assert(isABI && "Non-ABI compilation is not supported");
3096 if (!isABI)
3097 return Chain;
3098
3099 std::vector<Type *> argTypes;
3100 std::vector<const Argument *> theArgs;
3101 for (const Argument &I : F->args()) {
3102 theArgs.push_back(&I);
3103 argTypes.push_back(I.getType());
3104 }
3105 // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
3106 // Ins.size() will be larger
3107 // * if there is an aggregate argument with multiple fields (each field
3108 // showing up separately in Ins)
3109 // * if there is a vector argument with more than typical vector-length
3110 // elements (generally if more than 4) where each vector element is
3111 // individually present in Ins.
3112 // So a different index should be used for indexing into Ins.
3113 // See similar issue in LowerCall.
3114 unsigned InsIdx = 0;
3115
3116 for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
3117 Type *Ty = argTypes[i];
3118
3119 if (theArgs[i]->use_empty()) {
3120 // argument is dead
3121 if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
3122 SmallVector<EVT, 16> vtparts;
3123
3124 ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
3125 if (vtparts.empty())
3126 report_fatal_error("Empty parameter types are not supported");
3127
3128 for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
3129 ++parti) {
3130 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3131 ++InsIdx;
3132 }
3133 if (vtparts.size() > 0)
3134 --InsIdx;
3135 continue;
3136 }
3137 if (Ty->isVectorTy()) {
3138 EVT ObjectVT = getValueType(DL, Ty);
3139 unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
3140 for (unsigned parti = 0; parti < NumRegs; ++parti) {
3141 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3142 ++InsIdx;
3143 }
3144 if (NumRegs > 0)
3145 --InsIdx;
3146 continue;
3147 }
3148 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3149 continue;
3150 }
3151
3152 // In the following cases, assign a node order of "i+1"
3153 // to newly created nodes. The SDNodes for params have to
3154 // appear in the same order as their order of appearance
3155 // in the original function. "i+1" holds that order.
3156 if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
3157 bool aggregateIsPacked = false;
3158 if (StructType *STy = dyn_cast<StructType>(Ty))
3159 aggregateIsPacked = STy->isPacked();
3160
3161      SmallVector<EVT, 16> VTs;
3162      SmallVector<uint64_t, 16> Offsets;
3163      ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
3164 if (VTs.empty())
3165 report_fatal_error("Empty parameter types are not supported");
3166
3167      Align ArgAlign = getFunctionArgumentAlignment(
3168          F, Ty, i + AttributeList::FirstArgIndex, DL);
3169      auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
3170
3171 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3172 int VecIdx = -1; // Index of the first element of the current vector.
3173 for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
3174 if (VectorInfo[parti] & PVF_FIRST) {
3175 assert(VecIdx == -1 && "Orphaned vector.");
3176 VecIdx = parti;
3177 }
3178
3179 // That's the last element of this store op.
3180 if (VectorInfo[parti] & PVF_LAST) {
3181 unsigned NumElts = parti - VecIdx + 1;
3182 EVT EltVT = VTs[parti];
3183 // i1 is loaded/stored as i8.
3184 EVT LoadVT = EltVT;
3185 if (EltVT == MVT::i1)
3186 LoadVT = MVT::i8;
3187 else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
3188 // getLoad needs a vector type, but it can't handle
3189 // vectors which contain v2f16 or v2bf16 elements. So we must load
3190 // using i32 here and then bitcast back.
3191 LoadVT = MVT::i32;
3192
3193 EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
3194 SDValue VecAddr =
3195 DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
3196 DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
3197          Value *srcValue = Constant::getNullValue(PointerType::get(
3198              EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
3199
3200 const MaybeAlign PartAlign = [&]() -> MaybeAlign {
3201 if (aggregateIsPacked)
3202 return Align(1);
3203 if (NumElts != 1)
3204 return std::nullopt;
3205 Align PartAlign =
3206 DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
3207 return commonAlignment(PartAlign, Offsets[parti]);
3208 }();
3209 SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
3210                                  MachinePointerInfo(srcValue), PartAlign,
3211                                  MachineMemOperand::MODereferenceable |
3212                                      MachineMemOperand::MOInvariant);
3213          if (P.getNode())
3214 P.getNode()->setIROrder(i + 1);
3215 for (unsigned j = 0; j < NumElts; ++j) {
3216 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
3217 DAG.getIntPtrConstant(j, dl));
3218 // We've loaded i1 as an i8 and now must truncate it back to i1
3219 if (EltVT == MVT::i1)
3220 Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
3221 // v2f16 was loaded as an i32. Now we must bitcast it back.
3222 else if (EltVT != LoadVT)
3223 Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
3224
3225 // If a promoted integer type is used, truncate down to the original
3226 MVT PromotedVT;
3227 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
3228 Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
3229 }
3230
3231 // Extend the element if necessary (e.g. an i8 is loaded
3232 // into an i16 register)
3233 if (Ins[InsIdx].VT.isInteger() &&
3234 Ins[InsIdx].VT.getFixedSizeInBits() >
3235 LoadVT.getFixedSizeInBits()) {
3236            unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
3237                                                         : ISD::ZERO_EXTEND;
3238 Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
3239 }
3240 InVals.push_back(Elt);
3241 }
3242
3243 // Reset vector tracking state.
3244 VecIdx = -1;
3245 }
3246 ++InsIdx;
3247 }
3248 if (VTs.size() > 0)
3249 --InsIdx;
3250 continue;
3251 }
3252
3253 // Param has ByVal attribute
3254 // Return MoveParam(param symbol).
3255    // Ideally, the param symbol could be returned directly, but when the
3256    // SDNode builder decides to use it in a CopyToReg(), the machine
3257    // instruction fails because a TargetExternalSymbol (which is never
3258    // lowered) is target dependent, and CopyToReg assumes the source is
3259    // lowered.
3260 EVT ObjectVT = getValueType(DL, Ty);
3261 assert(ObjectVT == Ins[InsIdx].VT &&
3262 "Ins type did not match function type");
3263 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3264 SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3265 if (p.getNode())
3266 p.getNode()->setIROrder(i + 1);
3267 InVals.push_back(p);
3268 }
3269
3270 if (!OutChains.empty())
3271 DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
3272
3273 return Chain;
3274}
3275
3276// Use byte-stores when the param address of the return value is unaligned.
3277// This may happen when the return value is a field of a packed structure.
3278static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
3279                                      uint64_t Offset, EVT ElementType,
3280 SDValue RetVal, const SDLoc &dl) {
3281 // Bit logic only works on integer types
3282 if (adjustElementType(ElementType))
3283 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
3284
3285 // Store each byte
3286 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
3287 // Shift the byte to the last byte position
3288 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
3289 DAG.getConstant(i * 8, dl, MVT::i32));
3290 SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
3291 ShiftVal};
3292 // Trunc store only the last byte by using
3293 // st.param.b8
3294    // The register type can be larger than b8.
3295    Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
3296                                    DAG.getVTList(MVT::Other), StoreOperands,
3297                                    MVT::i8, MachinePointerInfo(), std::nullopt,
3298                                    MachineMemOperand::MOStore);
3299  }
3300 return Chain;
3301}
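// E.g. an i32 return value at byte offset 1 of a packed struct becomes four
// st.param.b8 stores at offsets 1..4, each writing (RetVal >> (8 * i)) & 0xff
// (editorial illustration).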
3302
3303SDValue
3304NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
3305                                 bool isVarArg,
3306                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
3307                                 const SmallVectorImpl<SDValue> &OutVals,
3308 const SDLoc &dl, SelectionDAG &DAG) const {
3309 const MachineFunction &MF = DAG.getMachineFunction();
3310 const Function &F = MF.getFunction();
3311  Type *RetTy = MF.getFunction().getReturnType();
3312
3313 bool isABI = (STI.getSmVersion() >= 20);
3314 assert(isABI && "Non-ABI compilation is not supported");
3315 if (!isABI)
3316 return Chain;
3317
3318 const DataLayout &DL = DAG.getDataLayout();
3319 SmallVector<SDValue, 16> PromotedOutVals;
3320  SmallVector<EVT, 16> VTs;
3321  SmallVector<uint64_t, 16> Offsets;
3322  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
3323 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
3324
3325 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3326 SDValue PromotedOutVal = OutVals[i];
3327 MVT PromotedVT;
3328 if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
3329 VTs[i] = EVT(PromotedVT);
3330 }
3331 if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) {
3332      unsigned Ext =
3333          Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
3334 PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
3335 }
3336 PromotedOutVals.push_back(PromotedOutVal);
3337 }
3338
3339 auto VectorInfo = VectorizePTXValueVTs(
3340 VTs, Offsets,
3341      RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
3342                       : Align(1));
3343
3344 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
3345 // 32-bits are sign extended or zero extended, depending on whether
3346 // they are signed or unsigned types.
3347 bool ExtendIntegerRetVal =
3348 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
3349
3350 SmallVector<SDValue, 6> StoreOperands;
3351 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3352 SDValue OutVal = OutVals[i];
3353 SDValue RetVal = PromotedOutVals[i];
3354
3355 if (ExtendIntegerRetVal) {
3356 RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
3358 dl, MVT::i32, RetVal);
3359 } else if (OutVal.getValueSizeInBits() < 16) {
3360 // Use 16-bit registers for small load-stores as it's the
3361 // smallest general purpose register size supported by NVPTX.
3362 RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
3363 }
3364
3365 // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
3366 // for a scalar store. In such cases, fall back to byte stores.
3367 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
3368 EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3369 Align ElementTypeAlign =
3370 DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
3371 Align ElementAlign =
3372 commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
3373 if (ElementAlign < ElementTypeAlign) {
3374 assert(StoreOperands.empty() && "Orphaned operand list.");
3375 Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
3376 RetVal, dl);
3377
3378 // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3379 // into the graph, so just move on to the next element.
3380 continue;
3381 }
3382 }
3383
3384 // New load/store. Record chain and offset operands.
3385 if (VectorInfo[i] & PVF_FIRST) {
3386 assert(StoreOperands.empty() && "Orphaned operand list.");
3387 StoreOperands.push_back(Chain);
3388 StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
3389 }
3390
3391 // Record the value to return.
3392 StoreOperands.push_back(RetVal);
3393
3394 // That's the last element of this store op.
3395 if (VectorInfo[i] & PVF_LAST) {
3396      NVPTXISD::NodeType Op;
3397      unsigned NumElts = StoreOperands.size() - 2;
3398 switch (NumElts) {
3399      case 1:
3400        Op = NVPTXISD::StoreRetval;
3401        break;
3402      case 2:
3403        Op = NVPTXISD::StoreRetvalV2;
3404        break;
3405      case 4:
3406        Op = NVPTXISD::StoreRetvalV4;
3407        break;
3408 default:
3409 llvm_unreachable("Invalid vector info.");
3410 }
3411
3412 // Adjust type of load/store op if we've extended the scalar
3413 // return value.
3414 EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3415 Chain = DAG.getMemIntrinsicNode(
3416          Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
3417          MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
3418      // Cleanup vector state.
3419 StoreOperands.clear();
3420 }
3421 }
3422
3423 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
3424}
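// The StoreRetval nodes above print as st.param stores into the return
// slot, e.g. (illustrative):
//   st.param.b32 [func_retval0+0], %r1;
//   ret;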
3425
3426void NVPTXTargetLowering::LowerAsmOperandForConstraint(
3427    SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
3428 SelectionDAG &DAG) const {
3429 if (Constraint.size() > 1)
3430 return;
3431 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
3432}
3433
3434// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
3435// TgtMemIntrinsic
3436// because we need the information that is only available in the "Value" type
3437// of the destination
3438// pointer. In particular, the address space information.
3439bool NVPTXTargetLowering::getTgtMemIntrinsic(
3440    IntrinsicInfo &Info, const CallInst &I,
3441    MachineFunction &MF, unsigned Intrinsic) const {
3442 switch (Intrinsic) {
3443 default:
3444 return false;
3445 case Intrinsic::nvvm_match_all_sync_i32p:
3446 case Intrinsic::nvvm_match_all_sync_i64p:
3447    Info.opc = ISD::INTRINSIC_W_CHAIN;
3448    // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
3449 // in order to model data exchange with other threads, but perform no real
3450 // memory accesses.
3451 Info.memVT = MVT::i1;
3452
3453    // Our result depends on both our own and other threads' arguments.
3454    Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
3455    return true;
3456 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
3457 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
3458 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
3459 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
3460 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
3461 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
3462 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
3463 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
3464 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
3465 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
3466 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
3467 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
3468 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
3469 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
3470 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
3471 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
3472 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
3473 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
3474 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
3475 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
3476 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
3477 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
3478 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
3479 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
3480    Info.opc = ISD::INTRINSIC_W_CHAIN;
3481    Info.memVT = MVT::v8f16;
3482    Info.ptrVal = I.getArgOperand(0);
3483    Info.offset = 0;
3484    Info.flags = MachineMemOperand::MOLoad;
3485    Info.align = Align(16);
3486 return true;
3487 }
3488 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
3489 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
3490 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
3491 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
3492 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
3493 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
3494 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
3495 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
3496 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
3497 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
3498 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
3499 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
3500 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
3501 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
3502 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
3503 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
3504 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
3505 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
3506 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
3507 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
3508 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
3509 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
3510 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
3511 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
3512    Info.opc = ISD::INTRINSIC_W_CHAIN;
3513    Info.memVT = MVT::v2i32;
3514    Info.ptrVal = I.getArgOperand(0);
3515    Info.offset = 0;
3516    Info.flags = MachineMemOperand::MOLoad;
3517    Info.align = Align(8);
3518 return true;
3519 }
3520
3521 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
3522 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
3523 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
3524 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
3525 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
3526 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
3527 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
3528 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
3529 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
3530 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
3531 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
3532 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
3533 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
3534 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
3535 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
3536 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
3537
3538 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
3539 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
3540 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
3541 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
3542 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
3543 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
3544 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
3545 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
3546 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
3547 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
3548 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
3549 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
3550 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
3551 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
3552 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
3553 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
3554 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
3555 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
3556    Info.opc = ISD::INTRINSIC_W_CHAIN;
3557    Info.memVT = MVT::v4i32;
3558    Info.ptrVal = I.getArgOperand(0);
3559    Info.offset = 0;
3560    Info.flags = MachineMemOperand::MOLoad;
3561    Info.align = Align(16);
3562 return true;
3563 }
3564
3565 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
3566 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
3567 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
3568 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
3569 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
3570 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
3571 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
3572 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
3573
3574 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
3575 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
3576 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
3577 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
3578 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
3579 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
3580 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
3581 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
3582 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
3583 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
3584 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
3585 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
3586 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
3587 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
3588 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
3589 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
3590 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
3591 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
3592 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
3593 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
3594 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
3595 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
3596    Info.opc = ISD::INTRINSIC_W_CHAIN;
3597    Info.memVT = MVT::i32;
3598    Info.ptrVal = I.getArgOperand(0);
3599    Info.offset = 0;
3600    Info.flags = MachineMemOperand::MOLoad;
3601    Info.align = Align(4);
3602 return true;
3603 }
3604
3605 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
3606 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
3607 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
3608 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
3609 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
3610 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
3611 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
3612 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
3613 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
3614 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
3615 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
3616 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
3617    Info.opc = ISD::INTRINSIC_W_CHAIN;
3618    Info.memVT = MVT::v4f16;
3619    Info.ptrVal = I.getArgOperand(0);
3620    Info.offset = 0;
3621    Info.flags = MachineMemOperand::MOLoad;
3622    Info.align = Align(16);
3623 return true;
3624 }
3625
3626 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
3627 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
3628 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
3629 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
3630 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
3631 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
3632 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
3633 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
3634 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
3635 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
3636 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
3637 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
3638 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
3639 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
3640 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
3641 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
3642    Info.opc = ISD::INTRINSIC_W_CHAIN;
3643    Info.memVT = MVT::v8f32;
3644    Info.ptrVal = I.getArgOperand(0);
3645    Info.offset = 0;
3646    Info.flags = MachineMemOperand::MOLoad;
3647    Info.align = Align(16);
3648 return true;
3649 }
3650
3651 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
3652 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
3653 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
3654 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
3655
3656 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
3657 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
3658 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
3659 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
3660
3661 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
3662 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
3663 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
3664 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
3665 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
3666 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
3667 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
3668 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
3669 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
3670 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
3671 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
3672 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
3673    Info.opc = ISD::INTRINSIC_W_CHAIN;
3674    Info.memVT = MVT::v8i32;
3675    Info.ptrVal = I.getArgOperand(0);
3676    Info.offset = 0;
3677    Info.flags = MachineMemOperand::MOLoad;
3678    Info.align = Align(16);
3679 return true;
3680 }
3681
3682 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
3683 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
3684 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
3685 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
3686 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
3687 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
3688 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
3689 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
3690 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
3691 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
3692    Info.opc = ISD::INTRINSIC_W_CHAIN;
3693    Info.memVT = MVT::v2i32;
3694    Info.ptrVal = I.getArgOperand(0);
3695    Info.offset = 0;
3696    Info.flags = MachineMemOperand::MOLoad;
3697    Info.align = Align(8);
3698 return true;
3699 }
3700
3701 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
3702 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
3703 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
3704 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
3705
3706 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
3707 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
3708 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
3709 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
3710    Info.opc = ISD::INTRINSIC_W_CHAIN;
3711    Info.memVT = MVT::f64;
3712    Info.ptrVal = I.getArgOperand(0);
3713    Info.offset = 0;
3714    Info.flags = MachineMemOperand::MOLoad;
3715    Info.align = Align(8);
3716 return true;
3717 }
3718
3719 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
3720 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
3721 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
3722 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
3723    Info.opc = ISD::INTRINSIC_W_CHAIN;
3724    Info.memVT = MVT::v2f64;
3725    Info.ptrVal = I.getArgOperand(0);
3726    Info.offset = 0;
3727    Info.flags = MachineMemOperand::MOLoad;
3728    Info.align = Align(16);
3729 return true;
3730 }
3731
3732 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
3733 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
3734 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
3735 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
3736 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
3737 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
3738 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
3739 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
3740 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
3741 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
3742 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
3743 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
3744    Info.opc = ISD::INTRINSIC_VOID;
3745    Info.memVT = MVT::v4f16;
3746    Info.ptrVal = I.getArgOperand(0);
3747    Info.offset = 0;
3748    Info.flags = MachineMemOperand::MOStore;
3749    Info.align = Align(16);
3750 return true;
3751 }
3752
3753 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
3754 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
3755 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
3756 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
3757 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
3758 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
3759 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
3760 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
3761 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
3762 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
3763 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
3764 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
3765 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
3766 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
3767 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
3768 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
3769    Info.opc = ISD::INTRINSIC_VOID;
3770    Info.memVT = MVT::v8f32;
3771    Info.ptrVal = I.getArgOperand(0);
3772    Info.offset = 0;
3773    Info.flags = MachineMemOperand::MOStore;
3774    Info.align = Align(16);
3775 return true;
3776 }
3777
3778 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
3779 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
3780 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
3781 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
3782 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
3783 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
3784 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
3785 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
3786 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
3787 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
3788 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
3789 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
3790    Info.opc = ISD::INTRINSIC_VOID;
3791    Info.memVT = MVT::v8i32;
3792    Info.ptrVal = I.getArgOperand(0);
3793    Info.offset = 0;
3794    Info.flags = MachineMemOperand::MOStore;
3795    Info.align = Align(16);
3796 return true;
3797 }
3798
3799 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
3800 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
3801 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
3802 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
3803 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
3804 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
3805 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
3806 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
3807    Info.opc = ISD::INTRINSIC_VOID;
3808    Info.memVT = MVT::v2i32;
3809    Info.ptrVal = I.getArgOperand(0);
3810    Info.offset = 0;
3811    Info.flags = MachineMemOperand::MOStore;
3812    Info.align = Align(8);
3813 return true;
3814 }
3815
3816 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
3817 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
3818 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
3819 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
3820    Info.opc = ISD::INTRINSIC_VOID;
3821    Info.memVT = MVT::v2f64;
3822    Info.ptrVal = I.getArgOperand(0);
3823    Info.offset = 0;
3824    Info.flags = MachineMemOperand::MOStore;
3825    Info.align = Align(16);
3826 return true;
3827 }
3828
3829 case Intrinsic::nvvm_atomic_load_inc_32:
3830 case Intrinsic::nvvm_atomic_load_dec_32:
3831
3832 case Intrinsic::nvvm_atomic_add_gen_f_cta:
3833 case Intrinsic::nvvm_atomic_add_gen_f_sys:
3834 case Intrinsic::nvvm_atomic_add_gen_i_cta:
3835 case Intrinsic::nvvm_atomic_add_gen_i_sys:
3836 case Intrinsic::nvvm_atomic_and_gen_i_cta:
3837 case Intrinsic::nvvm_atomic_and_gen_i_sys:
3838 case Intrinsic::nvvm_atomic_cas_gen_i_cta:
3839 case Intrinsic::nvvm_atomic_cas_gen_i_sys:
3840 case Intrinsic::nvvm_atomic_dec_gen_i_cta:
3841 case Intrinsic::nvvm_atomic_dec_gen_i_sys:
3842 case Intrinsic::nvvm_atomic_inc_gen_i_cta:
3843 case Intrinsic::nvvm_atomic_inc_gen_i_sys:
3844 case Intrinsic::nvvm_atomic_max_gen_i_cta:
3845 case Intrinsic::nvvm_atomic_max_gen_i_sys:
3846 case Intrinsic::nvvm_atomic_min_gen_i_cta:
3847 case Intrinsic::nvvm_atomic_min_gen_i_sys:
3848 case Intrinsic::nvvm_atomic_or_gen_i_cta:
3849 case Intrinsic::nvvm_atomic_or_gen_i_sys:
3850 case Intrinsic::nvvm_atomic_exch_gen_i_cta:
3851 case Intrinsic::nvvm_atomic_exch_gen_i_sys:
3852 case Intrinsic::nvvm_atomic_xor_gen_i_cta:
3853 case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
3854 auto &DL = I.getDataLayout();
3855    Info.opc = ISD::INTRINSIC_W_CHAIN;
3856    Info.memVT = getValueType(DL, I.getType());
3857    Info.ptrVal = I.getArgOperand(0);
3858    Info.offset = 0;
3859    Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
3860    Info.align.reset();
3861 return true;
3862 }
3863
3864 case Intrinsic::nvvm_ldu_global_i:
3865 case Intrinsic::nvvm_ldu_global_f:
3866 case Intrinsic::nvvm_ldu_global_p: {
3867 auto &DL = I.getDataLayout();
3868    Info.opc = ISD::INTRINSIC_W_CHAIN;
3869    if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
3870      Info.memVT = getValueType(DL, I.getType());
3871    else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
3872      Info.memVT = getPointerTy(DL);
3873    else
3874      Info.memVT = getValueType(DL, I.getType());
3875    Info.ptrVal = I.getArgOperand(0);
3876    Info.offset = 0;
3877    Info.flags = MachineMemOperand::MOLoad;
3878    Info.align = cast<ConstantInt>(I.getArgOperand(1))->getMaybeAlignValue();
3879
3880 return true;
3881 }
3882 case Intrinsic::nvvm_tex_1d_v4f32_s32:
3883 case Intrinsic::nvvm_tex_1d_v4f32_f32:
3884 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
3885 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
3886 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
3887 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
3888 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
3889 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
3890 case Intrinsic::nvvm_tex_2d_v4f32_s32:
3891 case Intrinsic::nvvm_tex_2d_v4f32_f32:
3892 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
3893 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
3894 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
3895 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
3896 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
3897 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
3898 case Intrinsic::nvvm_tex_3d_v4f32_s32:
3899 case Intrinsic::nvvm_tex_3d_v4f32_f32:
3900 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
3901 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
3902 case Intrinsic::nvvm_tex_cube_v4f32_f32:
3903 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
3904 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
3905 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
3906 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
3907 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
3908 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
3909 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
3910 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
3911 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
3912 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
3913 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
3914 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
3915 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
3916 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
3917 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
3918 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
3919 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
3920 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
3921 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
3922 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
3923 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
3924 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
3925 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
3926 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
3927 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
3928 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
3929 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
3930 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
3931 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
3932 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
3933 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
3934 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
3935 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
3936 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
3937 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
3938 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
3939 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
3940    Info.opc = ISD::INTRINSIC_W_CHAIN;
3941    Info.memVT = MVT::v4f32;
3942    Info.ptrVal = nullptr;
3943    Info.offset = 0;
3944    Info.flags = MachineMemOperand::MOLoad;
3945    Info.align = Align(16);
3946 return true;
3947
3948 case Intrinsic::nvvm_tex_1d_v4s32_s32:
3949 case Intrinsic::nvvm_tex_1d_v4s32_f32:
3950 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
3951 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
3952 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
3953 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
3954 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
3955 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
3956 case Intrinsic::nvvm_tex_2d_v4s32_s32:
3957 case Intrinsic::nvvm_tex_2d_v4s32_f32:
3958 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
3959 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
3960 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
3961 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
3962 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
3963 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
3964 case Intrinsic::nvvm_tex_3d_v4s32_s32:
3965 case Intrinsic::nvvm_tex_3d_v4s32_f32:
3966 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
3967 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
3968 case Intrinsic::nvvm_tex_cube_v4s32_f32:
3969 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
3970 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
3971 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
3972 case Intrinsic::nvvm_tex_cube_v4u32_f32:
3973 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
3974 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
3975 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
3976 case Intrinsic::nvvm_tex_1d_v4u32_s32:
3977 case Intrinsic::nvvm_tex_1d_v4u32_f32:
3978 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
3979 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
3980 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
3981 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
3982 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
3983 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
3984 case Intrinsic::nvvm_tex_2d_v4u32_s32:
3985 case Intrinsic::nvvm_tex_2d_v4u32_f32:
3986 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
3987 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
3988 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
3989 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
3990 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
3991 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
3992 case Intrinsic::nvvm_tex_3d_v4u32_s32:
3993 case Intrinsic::nvvm_tex_3d_v4u32_f32:
3994 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
3995 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
3996 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
3997 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
3998 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
3999 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4000 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4001 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4002 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4003 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4004 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4005 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4006 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4007 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4008 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4009 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4010 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4011 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4012 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4013 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4014 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4015 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4016 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4017 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4018 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4019 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4020 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4021 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4022 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4023 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4024 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4025 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4026 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4027 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4028 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4029 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4030 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4031 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4032 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4033 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4034 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4035 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4036 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4037 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4038 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4039 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4040 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4041 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4042 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4043 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4044 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4045 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4046 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4047 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4048 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4049 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4050 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4051 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4052 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4053 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4054 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4055 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4056 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4057 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4058 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4059 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4060 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4061 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4062 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4063 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4064    Info.opc = ISD::INTRINSIC_W_CHAIN;
4065    Info.memVT = MVT::v4i32;
4066    Info.ptrVal = nullptr;
4067    Info.offset = 0;
4068    Info.flags = MachineMemOperand::MOLoad;
4069    Info.align = Align(16);
4070 return true;
4071
4072 case Intrinsic::nvvm_suld_1d_i8_clamp:
4073 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4074 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4075 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4076 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4077 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4078 case Intrinsic::nvvm_suld_2d_i8_clamp:
4079 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4080 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4081 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4082 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4083 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4084 case Intrinsic::nvvm_suld_3d_i8_clamp:
4085 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4086 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4087 case Intrinsic::nvvm_suld_1d_i8_trap:
4088 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4089 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4090 case Intrinsic::nvvm_suld_1d_array_i8_trap:
4091 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
4092 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
4093 case Intrinsic::nvvm_suld_2d_i8_trap:
4094 case Intrinsic::nvvm_suld_2d_v2i8_trap:
4095 case Intrinsic::nvvm_suld_2d_v4i8_trap:
4096 case Intrinsic::nvvm_suld_2d_array_i8_trap:
4097 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
4098 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
4099 case Intrinsic::nvvm_suld_3d_i8_trap:
4100 case Intrinsic::nvvm_suld_3d_v2i8_trap:
4101 case Intrinsic::nvvm_suld_3d_v4i8_trap:
4102 case Intrinsic::nvvm_suld_1d_i8_zero:
4103 case Intrinsic::nvvm_suld_1d_v2i8_zero:
4104 case Intrinsic::nvvm_suld_1d_v4i8_zero:
4105 case Intrinsic::nvvm_suld_1d_array_i8_zero:
4106 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
4107 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
4108 case Intrinsic::nvvm_suld_2d_i8_zero:
4109 case Intrinsic::nvvm_suld_2d_v2i8_zero:
4110 case Intrinsic::nvvm_suld_2d_v4i8_zero:
4111 case Intrinsic::nvvm_suld_2d_array_i8_zero:
4112 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
4113 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
4114 case Intrinsic::nvvm_suld_3d_i8_zero:
4115 case Intrinsic::nvvm_suld_3d_v2i8_zero:
4116 case Intrinsic::nvvm_suld_3d_v4i8_zero:
4117    Info.opc = ISD::INTRINSIC_W_CHAIN;
4118    Info.memVT = MVT::i8;
4119    Info.ptrVal = nullptr;
4120    Info.offset = 0;
4121    Info.flags = MachineMemOperand::MOLoad;
4122    Info.align = Align(16);
4123 return true;
4124
4125 case Intrinsic::nvvm_suld_1d_i16_clamp:
4126 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
4127 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
4128 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
4129 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
4130 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
4131 case Intrinsic::nvvm_suld_2d_i16_clamp:
4132 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
4133 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
4134 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
4135 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
4136 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
4137 case Intrinsic::nvvm_suld_3d_i16_clamp:
4138 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
4139 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
4140 case Intrinsic::nvvm_suld_1d_i16_trap:
4141 case Intrinsic::nvvm_suld_1d_v2i16_trap:
4142 case Intrinsic::nvvm_suld_1d_v4i16_trap:
4143 case Intrinsic::nvvm_suld_1d_array_i16_trap:
4144 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
4145 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
4146 case Intrinsic::nvvm_suld_2d_i16_trap:
4147 case Intrinsic::nvvm_suld_2d_v2i16_trap:
4148 case Intrinsic::nvvm_suld_2d_v4i16_trap:
4149 case Intrinsic::nvvm_suld_2d_array_i16_trap:
4150 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
4151 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
4152 case Intrinsic::nvvm_suld_3d_i16_trap:
4153 case Intrinsic::nvvm_suld_3d_v2i16_trap:
4154 case Intrinsic::nvvm_suld_3d_v4i16_trap:
4155 case Intrinsic::nvvm_suld_1d_i16_zero:
4156 case Intrinsic::nvvm_suld_1d_v2i16_zero:
4157 case Intrinsic::nvvm_suld_1d_v4i16_zero:
4158 case Intrinsic::nvvm_suld_1d_array_i16_zero:
4159 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
4160 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
4161 case Intrinsic::nvvm_suld_2d_i16_zero:
4162 case Intrinsic::nvvm_suld_2d_v2i16_zero:
4163 case Intrinsic::nvvm_suld_2d_v4i16_zero:
4164 case Intrinsic::nvvm_suld_2d_array_i16_zero:
4165 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
4166 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
4167 case Intrinsic::nvvm_suld_3d_i16_zero:
4168 case Intrinsic::nvvm_suld_3d_v2i16_zero:
4169 case Intrinsic::nvvm_suld_3d_v4i16_zero:
4170    Info.opc = ISD::INTRINSIC_W_CHAIN;
4171    Info.memVT = MVT::i16;
4172    Info.ptrVal = nullptr;
4173    Info.offset = 0;
4174    Info.flags = MachineMemOperand::MOLoad;
4175    Info.align = Align(16);
4176 return true;
4177
4178 case Intrinsic::nvvm_suld_1d_i32_clamp:
4179 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
4180 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
4181 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
4182 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
4183 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
4184 case Intrinsic::nvvm_suld_2d_i32_clamp:
4185 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
4186 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
4187 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
4188 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
4189 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
4190 case Intrinsic::nvvm_suld_3d_i32_clamp:
4191 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
4192 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
4193 case Intrinsic::nvvm_suld_1d_i32_trap:
4194 case Intrinsic::nvvm_suld_1d_v2i32_trap:
4195 case Intrinsic::nvvm_suld_1d_v4i32_trap:
4196 case Intrinsic::nvvm_suld_1d_array_i32_trap:
4197 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
4198 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
4199 case Intrinsic::nvvm_suld_2d_i32_trap:
4200 case Intrinsic::nvvm_suld_2d_v2i32_trap:
4201 case Intrinsic::nvvm_suld_2d_v4i32_trap:
4202 case Intrinsic::nvvm_suld_2d_array_i32_trap:
4203 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
4204 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
4205 case Intrinsic::nvvm_suld_3d_i32_trap:
4206 case Intrinsic::nvvm_suld_3d_v2i32_trap:
4207 case Intrinsic::nvvm_suld_3d_v4i32_trap:
4208 case Intrinsic::nvvm_suld_1d_i32_zero:
4209 case Intrinsic::nvvm_suld_1d_v2i32_zero:
4210 case Intrinsic::nvvm_suld_1d_v4i32_zero:
4211 case Intrinsic::nvvm_suld_1d_array_i32_zero:
4212 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
4213 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
4214 case Intrinsic::nvvm_suld_2d_i32_zero:
4215 case Intrinsic::nvvm_suld_2d_v2i32_zero:
4216 case Intrinsic::nvvm_suld_2d_v4i32_zero:
4217 case Intrinsic::nvvm_suld_2d_array_i32_zero:
4218 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
4219 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
4220 case Intrinsic::nvvm_suld_3d_i32_zero:
4221 case Intrinsic::nvvm_suld_3d_v2i32_zero:
4222 case Intrinsic::nvvm_suld_3d_v4i32_zero:
4223    Info.opc = ISD::INTRINSIC_W_CHAIN;
4224    Info.memVT = MVT::i32;
4225    Info.ptrVal = nullptr;
4226    Info.offset = 0;
4227    Info.flags = MachineMemOperand::MOLoad;
4228    Info.align = Align(16);
4229 return true;
4230
4231 case Intrinsic::nvvm_suld_1d_i64_clamp:
4232 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
4233 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
4234 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
4235 case Intrinsic::nvvm_suld_2d_i64_clamp:
4236 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
4237 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
4238 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
4239 case Intrinsic::nvvm_suld_3d_i64_clamp:
4240 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
4241 case Intrinsic::nvvm_suld_1d_i64_trap:
4242 case Intrinsic::nvvm_suld_1d_v2i64_trap:
4243 case Intrinsic::nvvm_suld_1d_array_i64_trap:
4244 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
4245 case Intrinsic::nvvm_suld_2d_i64_trap:
4246 case Intrinsic::nvvm_suld_2d_v2i64_trap:
4247 case Intrinsic::nvvm_suld_2d_array_i64_trap:
4248 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
4249 case Intrinsic::nvvm_suld_3d_i64_trap:
4250 case Intrinsic::nvvm_suld_3d_v2i64_trap:
4251 case Intrinsic::nvvm_suld_1d_i64_zero:
4252 case Intrinsic::nvvm_suld_1d_v2i64_zero:
4253 case Intrinsic::nvvm_suld_1d_array_i64_zero:
4254 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
4255 case Intrinsic::nvvm_suld_2d_i64_zero:
4256 case Intrinsic::nvvm_suld_2d_v2i64_zero:
4257 case Intrinsic::nvvm_suld_2d_array_i64_zero:
4258 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
4259 case Intrinsic::nvvm_suld_3d_i64_zero:
4260 case Intrinsic::nvvm_suld_3d_v2i64_zero:
4261    Info.opc = ISD::INTRINSIC_W_CHAIN;
4262    Info.memVT = MVT::i64;
4263    Info.ptrVal = nullptr;
4264    Info.offset = 0;
4265    Info.flags = MachineMemOperand::MOLoad;
4266    Info.align = Align(16);
4267 return true;
4268 }
4269 return false;
4270}
4271
4272/// getFunctionParamOptimizedAlign - since function arguments are passed via
4273/// .param space, we may want to increase their alignment in a way that
4274/// ensures that we can effectively vectorize their loads & stores. We can
4275/// increase alignment only if the function has internal or private linkage,
4276/// as for other linkage types callers may already rely on the default
4277/// alignment. To allow using 128-bit vectorized loads/stores, this function
4278/// ensures that alignment is 16 or greater.
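/// For example (illustrative), a local-linkage function taking a struct of
/// four floats (ABI alignment 4) gets its .param alignment raised to 16,
/// enabling a single 128-bit vectorized load of the argument.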
4279Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
4280    const Function *F, Type *ArgTy, const DataLayout &DL) const {
4281  // Cap the alignment at 128 bytes, as that is the maximum alignment
4282  // supported by PTX.
4283 const Align ABITypeAlign = std::min(Align(128), DL.getABITypeAlign(ArgTy));
4284
4285 // If a function has linkage different from internal or private, we
4286 // must use default ABI alignment as external users rely on it. Same
4287 // for a function that may be called from a function pointer.
4288 if (!F || !F->hasLocalLinkage() ||
4289 F->hasAddressTaken(/*Users=*/nullptr,
4290 /*IgnoreCallbackUses=*/false,
4291 /*IgnoreAssumeLikeCalls=*/true,
4292 /*IgnoreLLVMUsed=*/true))
4293 return ABITypeAlign;
4294
4295 assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
4296 return std::max(Align(16), ABITypeAlign);
4297}
4298
4299/// Helper for computing alignment of a device function byval parameter.
4300Align NVPTXTargetLowering::getFunctionByValParamAlign(
4301    const Function *F, Type *ArgTy, Align InitialAlign,
4302 const DataLayout &DL) const {
4303 Align ArgAlign = InitialAlign;
4304 // Try to increase alignment to enhance vectorization options.
4305 if (F)
4306 ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
4307
4308  // Old ptx versions have a bug. When PTX code takes the address of a
4309  // byval parameter with alignment < 4, ptxas generates code to
4310  // spill the argument into memory. Alas, on sm_50+ ptxas generates
4311  // SASS code that fails with a misaligned access. To work around
4312  // the problem, make sure that we align byval parameters by at
4313  // least 4. This bug seems to be fixed at least starting from
4314  // ptxas > 9.0.
4315  // TODO: remove this after verifying the bug is not reproduced
4316  // on non-deprecated ptxas versions.
4317  if (ForceMinByValParamAlign)
4318    ArgAlign = std::max(ArgAlign, Align(4));
4319
4320 return ArgAlign;
4321}
4322
4323// Helper for getting a function parameter name. The name is composed from
4324// its index and the function name. A negative index corresponds to the
4325// special parameter (unsized array) used for passing variable arguments.
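// Illustrative results, assuming a function whose symbol name is "foo":
//   getParamName(F, 0)  -> "foo_param_0"
//   getParamName(F, -1) -> "foo_vararg"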
4326std::string NVPTXTargetLowering::getParamName(const Function *F,
4327                                              int Idx) const {
4328 std::string ParamName;
4329 raw_string_ostream ParamStr(ParamName);
4330
4331 ParamStr << getTargetMachine().getSymbol(F)->getName();
4332 if (Idx < 0)
4333 ParamStr << "_vararg";
4334 else
4335 ParamStr << "_param_" << Idx;
4336
4337 return ParamName;
4338}
4339
4340/// isLegalAddressingMode - Return true if the addressing mode represented
4341/// by AM is legal for this target, for a load/store of the specified type.
4342/// Used to guide target specific optimizations, like loop strength reduction
4343/// (LoopStrengthReduce.cpp) and memory optimization for address mode
4344/// (CodeGenPrepare.cpp)
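/// Illustrative PTX forms of the accepted modes:
///   [avar]         ld.global.f32 %f0, [gvar];
///   [areg+immoff]  ld.global.f32 %f0, [%rd1+4];
/// PTX has no scaled-index form such as [%rd1+%rd2*4], so any mode with
/// Scale > 1, or with both a base register and a scaled register, is
/// rejected below.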
4345bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
4346                                                const AddrMode &AM, Type *Ty,
4347 unsigned AS, Instruction *I) const {
4348 // AddrMode - This represents an addressing mode of:
4349 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
4350 //
4351 // The legal address modes are
4352 // - [avar]
4353 // - [areg]
4354 // - [areg+immoff]
4355 // - [immAddr]
4356
4357 // immoff must fit in a signed 32-bit int
4358 if (!APInt(64, AM.BaseOffs).isSignedIntN(32))
4359 return false;
4360
4361 if (AM.BaseGV)
4362 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
4363
4364 switch (AM.Scale) {
4365 case 0: // "r", "r+i" or "i" is allowed
4366 break;
4367 case 1:
4368 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
4369 return false;
4370 // Otherwise we have r+i.
4371 break;
4372 default:
4373 // No scale > 1 is allowed
4374 return false;
4375 }
4376 return true;
4377}
4378
4379//===----------------------------------------------------------------------===//
4380// NVPTX Inline Assembly Support
4381//===----------------------------------------------------------------------===//
4382
4383/// getConstraintType - Given a constraint letter, return the type of
4384/// constraint it is for this target.
4385NVPTXTargetLowering::ConstraintType
4386NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
4387  if (Constraint.size() == 1) {
4388 switch (Constraint[0]) {
4389 default:
4390 break;
4391 case 'b':
4392 case 'r':
4393 case 'h':
4394 case 'c':
4395 case 'l':
4396 case 'f':
4397 case 'd':
4398 case 'q':
4399 case '0':
4400 case 'N':
4401 return C_RegisterClass;
4402 }
4403 }
4404 return TargetLowering::getConstraintType(Constraint);
4405}
4406
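// Maps single-letter inline asm constraints to NVPTX register classes. An
// illustrative CUDA use:
//   asm("add.s32 %0, %1, %2;" : "=r"(res) : "r"(a), "r"(b));
// binds the "r" operands to Int32RegsRegClass.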
4407std::pair<unsigned, const TargetRegisterClass *>
4408NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
4409                                                  StringRef Constraint,
4410 MVT VT) const {
4411 if (Constraint.size() == 1) {
4412 switch (Constraint[0]) {
4413 case 'b':
4414 return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
4415 case 'c':
4416 return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
4417 case 'h':
4418 return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
4419 case 'r':
4420 return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
4421 case 'l':
4422 case 'N':
4423 return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
4424 case 'q': {
4425 if (STI.getSmVersion() < 70)
4426 report_fatal_error("Inline asm with 128 bit operands is only "
4427 "supported for sm_70 and higher!");
4428 return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
4429 }
4430 case 'f':
4431 return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
4432 case 'd':
4433 return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
4434 }
4435 }
4436 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
4437}
4438
4439//===----------------------------------------------------------------------===//
4440// NVPTX DAG Combining
4441//===----------------------------------------------------------------------===//
4442
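// Whether (fmul a, b) followed by (fadd ..., c) may be contracted into fma.
// Illustrative triggers for the checks below: building with FP contraction
// enabled sets FPOpFusion::Fast, and a function-level
// "unsafe-fp-math"="true" attribute satisfies allowUnsafeFPMath().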
4443bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
4444                                   CodeGenOptLevel OptLevel) const {
4445 // Always honor command-line argument
4446 if (FMAContractLevelOpt.getNumOccurrences() > 0)
4447 return FMAContractLevelOpt > 0;
4448
4449 // Do not contract if we're not optimizing the code.
4450 if (OptLevel == CodeGenOptLevel::None)
4451 return false;
4452
4453 // Honor TargetOptions flags that explicitly say fusion is okay.
4454  if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
4455    return true;
4456
4457 return allowUnsafeFPMath(MF);
4458}
4459
4460bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
4461  // Honor TargetOptions flags that explicitly say unsafe math is okay.
4462  if (MF.getTarget().Options.UnsafeFPMath)
4463    return true;
4464
4465 // Allow unsafe math if unsafe-fp-math attribute explicitly says so.
4466 const Function &F = MF.getFunction();
4467 return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
4468}
4469
4470static bool isConstZero(const SDValue &Operand) {
4471 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
4472 return Const && Const->getZExtValue() == 0;
4473}
4474
4475/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
4476/// operands N0 and N1. This is a helper for PerformADDCombine that is
4477/// called with the default operands, and if that fails, with commuted
4478/// operands.
4479static SDValue
4480PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4481                              TargetLowering::DAGCombinerInfo &DCI) {
4482  EVT VT = N0.getValueType();
4483
4484 // Since integer multiply-add costs the same as integer multiply
4485 // but is more costly than integer add, do the fusion only when
4486 // the mul is only used in the add.
4487 // TODO: this may not be true for later architectures, consider relaxing this
4488 if (!N0.getNode()->hasOneUse())
4489 return SDValue();
4490
4491 // fold (add (select cond, 0, (mul a, b)), c)
4492 // -> (select cond, c, (add (mul a, b), c))
4493 //
4494 if (N0.getOpcode() == ISD::SELECT) {
4495 unsigned ZeroOpNum;
4496 if (isConstZero(N0->getOperand(1)))
4497 ZeroOpNum = 1;
4498 else if (isConstZero(N0->getOperand(2)))
4499 ZeroOpNum = 2;
4500 else
4501 return SDValue();
4502
4503 SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
4504 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
4505 return SDValue();
4506
4507 SDLoc DL(N);
4508 SDValue Mul =
4509 DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
4510 SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
4511 return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
4512 ((ZeroOpNum == 1) ? N1 : MAD),
4513 ((ZeroOpNum == 1) ? MAD : N1));
4514 }
4515
4516 return SDValue();
4517}
4518
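// Tries to fuse (fadd (fmul a, b), c) into (fma a, b, c), subject to the
// use-count and IR-order heuristics below; at the PTX level this typically
// turns a mul/add pair into a single fma.rn instruction.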
4519static SDValue
4520PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4521                               TargetLowering::DAGCombinerInfo &DCI,
4522                               CodeGenOptLevel OptLevel) {
4523 EVT VT = N0.getValueType();
4524 if (N0.getOpcode() == ISD::FMUL) {
4525 const auto *TLI = static_cast<const NVPTXTargetLowering *>(
4526 &DCI.DAG.getTargetLoweringInfo());
4527 if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
4528 return SDValue();
4529
4530    // For floating point:
4531    // Do the fusion only when the mul has fewer than 5 uses and all
4532    // of them are adds.
4533    // The heuristic is that if a use is not an add, then that use
4534    // cannot be fused into an fma, so the mul is still needed anyway.
4535    // If there are more than 4 uses, even if they are all adds, fusing
4536    // them will increase register pressure.
4537    //
4538 int numUses = 0;
4539 int nonAddCount = 0;
4540 for (const SDNode *User : N0.getNode()->users()) {
4541 numUses++;
4542 if (User->getOpcode() != ISD::FADD)
4543 ++nonAddCount;
4544 if (numUses >= 5)
4545 return SDValue();
4546 }
4547 if (nonAddCount) {
4548 int orderNo = N->getIROrder();
4549 int orderNo2 = N0.getNode()->getIROrder();
4550      // Simple heuristic for considering potential register pressure:
4551      // the difference in IR order is used to measure the distance
4552      // between def and use, and a longer distance is more likely to
4553      // cause register pressure.
4554 if (orderNo - orderNo2 < 500)
4555 return SDValue();
4556
4557 // Now, check if at least one of the FMUL's operands is live beyond the
4558 // node N, which guarantees that the FMA will not increase register
4559 // pressure at node N.
4560 bool opIsLive = false;
4561 const SDNode *left = N0.getOperand(0).getNode();
4562 const SDNode *right = N0.getOperand(1).getNode();
4563
4564 if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
4565 opIsLive = true;
4566
4567 if (!opIsLive)
4568 for (const SDNode *User : left->users()) {
4569 int orderNo3 = User->getIROrder();
4570 if (orderNo3 > orderNo) {
4571 opIsLive = true;
4572 break;
4573 }
4574 }
4575
4576 if (!opIsLive)
4577 for (const SDNode *User : right->users()) {
4578 int orderNo3 = User->getIROrder();
4579 if (orderNo3 > orderNo) {
4580 opIsLive = true;
4581 break;
4582 }
4583 }
4584
4585 if (!opIsLive)
4586 return SDValue();
4587 }
4588
4589 return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
4590 N0.getOperand(1), N1);
4591 }
4592
4593 return SDValue();
4594}
4595
4596static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
4597 std::size_t Back) {
4598 if (all_of(N->ops().drop_front(Front).drop_back(Back),
4599 [](const SDUse &U) { return U.get()->isUndef(); }))
4600 // Operand 0 is the previous value in the chain. Cannot return EntryToken
4601 // as the previous value will become unused and eliminated later.
4602 return N->getOperand(0);
4603
4604 return SDValue();
4605}
4606
4607static SDValue PerformStoreRetvalCombine(SDNode *N) {
4608  // Operands from the 3rd to the 2nd last one are the values to be stored.
4609 // {Chain, ArgID, Offset, Val, Glue}
4610 return PerformStoreCombineHelper(N, 3, 1);
4611}
4612
4613static SDValue PerformStoreParamCombine(SDNode *N) {
4614  // Operands from the 2nd to the last one are the values to be stored.
4615 return PerformStoreCombineHelper(N, 2, 0);
4616}
4617
4618/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
4619///
4620static SDValue PerformADDCombine(SDNode *N,
4621                                 TargetLowering::DAGCombinerInfo &DCI,
4622                                 CodeGenOptLevel OptLevel) {
4623 if (OptLevel == CodeGenOptLevel::None)
4624 return SDValue();
4625
4626 SDValue N0 = N->getOperand(0);
4627 SDValue N1 = N->getOperand(1);
4628
4629 // Skip non-integer, non-scalar case
4630 EVT VT = N0.getValueType();
4631 if (VT.isVector() || VT != MVT::i32)
4632 return SDValue();
4633
4634 // First try with the default operand order.
4635 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
4636 return Result;
4637
4638 // If that didn't work, try again with the operands commuted.
4639 return PerformADDCombineWithOperands(N, N1, N0, DCI);
4640}
4641
4642/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
4643///
4644static SDValue PerformFADDCombine(SDNode *N,
4645                                  TargetLowering::DAGCombinerInfo &DCI,
4646                                  CodeGenOptLevel OptLevel) {
4647 SDValue N0 = N->getOperand(0);
4648 SDValue N1 = N->getOperand(1);
4649
4650 EVT VT = N0.getValueType();
4651 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
4652 return SDValue();
4653
4654 // First try with the default operand order.
4655 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
4656 return Result;
4657
4658 // If that didn't work, try again with the operands commuted.
4659 return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
4660}
4661
4662static SDValue PerformANDCombine(SDNode *N,
4663                                 TargetLowering::DAGCombinerInfo &DCI) {
4664  // The type legalizer turns a vector load of i8 values into a zextload to i16
4665 // registers, optionally ANY_EXTENDs it (if target type is integer),
4666 // and ANDs off the high 8 bits. Since we turn this load into a
4667 // target-specific DAG node, the DAG combiner fails to eliminate these AND
4668 // nodes. Do that here.
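  // Illustrative pattern handled here:
  //   (and (any_extend (NVPTXISD::LoadV4 ..., v4i8)), 255)
  // The mask is redundant because the zextload already cleared the high
  // bits, so the AND is replaced with a zero_extend of the load.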
4669 SDValue Val = N->getOperand(0);
4670 SDValue Mask = N->getOperand(1);
4671
4672 if (isa<ConstantSDNode>(Val)) {
4673 std::swap(Val, Mask);
4674 }
4675
4676 SDValue AExt;
4677
4678  // Convert (BFE -> truncate i16 -> and 255)
4679  // to just (BFE -> truncate i16), as the value already has all the bits
4680  // in the right places.
4681 if (Val.getOpcode() == ISD::TRUNCATE) {
4682 SDValue BFE = Val.getOperand(0);
4683 if (BFE.getOpcode() != NVPTXISD::BFE)
4684 return SDValue();
4685
4686 ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0));
4687 if (!BFEBits)
4688 return SDValue();
4689 uint64_t BFEBitsVal = BFEBits->getZExtValue();
4690
4691 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
4692 if (!MaskCnst) {
4693 // Not an AND with a constant
4694 return SDValue();
4695 }
4696 uint64_t MaskVal = MaskCnst->getZExtValue();
4697
4698 if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
4699 return SDValue();
4700 // If we get here, the AND is unnecessary. Just replace it with the trunc
4701 DCI.CombineTo(N, Val, false);
4702 }
4703 // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
4704 if (Val.getOpcode() == ISD::ANY_EXTEND) {
4705 AExt = Val;
4706 Val = Val->getOperand(0);
4707 }
4708
4709 if (Val->getOpcode() == NVPTXISD::LoadV2 ||
4710 Val->getOpcode() == NVPTXISD::LoadV4) {
4711 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
4712 if (!MaskCnst) {
4713 // Not an AND with a constant
4714 return SDValue();
4715 }
4716
4717 uint64_t MaskVal = MaskCnst->getZExtValue();
4718 if (MaskVal != 0xff) {
4719 // Not an AND that chops off top 8 bits
4720 return SDValue();
4721 }
4722
4723 MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
4724 if (!Mem) {
4725 // Not a MemSDNode?!?
4726 return SDValue();
4727 }
4728
4729 EVT MemVT = Mem->getMemoryVT();
4730 if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
4731 // We only handle the i8 case
4732 return SDValue();
4733 }
4734
4735 unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
4736 if (ExtType == ISD::SEXTLOAD) {
4737 // If for some reason the load is a sextload, the and is needed to zero
4738 // out the high 8 bits
4739 return SDValue();
4740 }
4741
4742 bool AddTo = false;
4743 if (AExt.getNode() != nullptr) {
4744 // Re-insert the ext as a zext.
4745 Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
4746 AExt.getValueType(), Val);
4747 AddTo = true;
4748 }
4749
4750 // If we get here, the AND is unnecessary. Just replace it with the load
4751 DCI.CombineTo(N, Val, AddTo);
4752 }
4753
4754 return SDValue();
4755}
4756
4757static SDValue PerformREMCombine(SDNode *N,
4758                                 TargetLowering::DAGCombinerInfo &DCI,
4759                                 CodeGenOptLevel OptLevel) {
4760 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
4761
4762 // Don't do anything at less than -O2.
4763 if (OptLevel < CodeGenOptLevel::Default)
4764 return SDValue();
4765
4766 SelectionDAG &DAG = DCI.DAG;
4767 SDLoc DL(N);
4768 EVT VT = N->getValueType(0);
4769 bool IsSigned = N->getOpcode() == ISD::SREM;
4770 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
4771
4772 const SDValue &Num = N->getOperand(0);
4773 const SDValue &Den = N->getOperand(1);
4774
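  // This only fires when a matching division of the same operands already
  // exists, e.g. when the IR computes both (udiv %n, %d) and (urem %n, %d);
  // the rem is then re-expressed in terms of that existing division.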
4775 for (const SDNode *U : Num->users()) {
4776 if (U->getOpcode() == DivOpc && U->getOperand(0) == Num &&
4777 U->getOperand(1) == Den) {
4778 // Num % Den -> Num - (Num / Den) * Den
4779 return DAG.getNode(ISD::SUB, DL, VT, Num,
4780 DAG.getNode(ISD::MUL, DL, VT,
4781 DAG.getNode(DivOpc, DL, VT, Num, Den),
4782 Den));
4783 }
4784 }
4785 return SDValue();
4786}
4787
4788enum OperandSignedness {
4789  Signed = 0,
4790  Unsigned,
4791  Unknown
4792};
4793
4794/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
4795/// that can be demoted to \p OptSize bits without loss of information. The
4796/// signedness of the operand, if determinable, is placed in \p S.
4797static bool IsMulWideOperandDemotable(SDValue Op,
4798                                      unsigned OptSize,
4799 OperandSignedness &S) {
4800 S = Unknown;
4801
4802 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
4803 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4804 EVT OrigVT = Op.getOperand(0).getValueType();
4805 if (OrigVT.getFixedSizeInBits() <= OptSize) {
4806 S = Signed;
4807 return true;
4808 }
4809 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
4810 EVT OrigVT = Op.getOperand(0).getValueType();
4811 if (OrigVT.getFixedSizeInBits() <= OptSize) {
4812 S = Unsigned;
4813 return true;
4814 }
4815 }
4816
4817 return false;
4818}
4819
4820/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
4821/// be demoted to \p OptSize bits without loss of information. If the operands
4822/// contain a constant, it should appear as the RHS operand. The signedness of
4823/// the operands is placed in \p IsSigned.
4824static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
4825                                        unsigned OptSize,
4826 bool &IsSigned) {
4827 OperandSignedness LHSSign;
4828
4829 // The LHS operand must be a demotable op
4830 if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign))
4831 return false;
4832
4833 // We should have been able to determine the signedness from the LHS
4834 if (LHSSign == Unknown)
4835 return false;
4836
4837 IsSigned = (LHSSign == Signed);
4838
4839 // The RHS can be a demotable op or a constant
4840 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(RHS)) {
4841 const APInt &Val = CI->getAPIntValue();
4842 if (LHSSign == Unsigned) {
4843 return Val.isIntN(OptSize);
4844 } else {
4845 return Val.isSignedIntN(OptSize);
4846 }
4847 } else {
4848 OperandSignedness RHSSign;
4849 if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign))
4850 return false;
4851
4852 return LHSSign == RHSSign;
4853 }
4854}
4855
4856/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
4857/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
4858/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
4859/// amount.
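/// Illustrative folds (assuming the operands are provably narrow enough):
///   (mul i32 (sext i16 %a), (sext i16 %b)) -> (MUL_WIDE_SIGNED %a, %b)
///   (shl i64 (zext i32 %a), 4)             -> (MUL_WIDE_UNSIGNED %a, 16)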
4860static SDValue TryMULWIDECombine(SDNode *N,
4861                                 TargetLowering::DAGCombinerInfo &DCI) {
4862  EVT MulType = N->getValueType(0);
4863 if (MulType != MVT::i32 && MulType != MVT::i64) {
4864 return SDValue();
4865 }
4866
4867 SDLoc DL(N);
4868 unsigned OptSize = MulType.getSizeInBits() >> 1;
4869 SDValue LHS = N->getOperand(0);
4870 SDValue RHS = N->getOperand(1);
4871
4872 // Canonicalize the multiply so the constant (if any) is on the right
4873 if (N->getOpcode() == ISD::MUL) {
4874 if (isa<ConstantSDNode>(LHS)) {
4875 std::swap(LHS, RHS);
4876 }
4877 }
4878
4879 // If we have a SHL, determine the actual multiply amount
4880 if (N->getOpcode() == ISD::SHL) {
4881 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(RHS);
4882 if (!ShlRHS) {
4883 return SDValue();
4884 }
4885
4886 APInt ShiftAmt = ShlRHS->getAPIntValue();
4887 unsigned BitWidth = MulType.getSizeInBits();
4888 if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) {
4889 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
4890 RHS = DCI.DAG.getConstant(MulVal, DL, MulType);
4891 } else {
4892 return SDValue();
4893 }
4894 }
4895
4896 bool Signed;
4897 // Verify that our operands are demotable
4898 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) {
4899 return SDValue();
4900 }
4901
4902 EVT DemotedVT;
4903 if (MulType == MVT::i32) {
4904 DemotedVT = MVT::i16;
4905 } else {
4906 DemotedVT = MVT::i32;
4907 }
4908
4909 // Truncate the operands to the correct size. Note that these are just for
4910 // type consistency and will (likely) be eliminated in later phases.
4911 SDValue TruncLHS =
4912 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, LHS);
4913 SDValue TruncRHS =
4914 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, RHS);
4915
4916 unsigned Opc;
4917  if (Signed) {
4918    Opc = NVPTXISD::MUL_WIDE_SIGNED;
4919  } else {
4920    Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
4921  }
4922
4923 return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
4924}
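// A minimal before/after sketch of the combine above for an i32 multiply:
//   (mul (sext i16:a to i32), (sext i16:b to i32))
//     -> (NVPTXISD::MUL_WIDE_SIGNED (trunc a), (trunc b)) : i32
// which instruction selection emits as PTX mul.wide.s16. A left shift by a
// constant K takes the same path with RHS rewritten as the constant 1 << K.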
4925
4926static bool isConstOne(const SDValue &Operand) {
4927 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
4928 return Const && Const->getZExtValue() == 1;
4929}
4930
4930
4931static SDValue matchMADConstOnePattern(SDValue Add) {
4932 if (Add->getOpcode() != ISD::ADD)
4933 return SDValue();
4934
4935 if (isConstOne(Add->getOperand(0)))
4936 return Add->getOperand(1);
4937
4938 if (isConstOne(Add->getOperand(1)))
4939 return Add->getOperand(0);
4940
4941 return SDValue();
4942}
4943
4944static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
4945 TargetLowering::DAGCombinerInfo &DCI) {
4946
4947 if (SDValue Y = matchMADConstOnePattern(Add)) {
4948 SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
4949 return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X);
4950 }
4951
4952 return SDValue();
4953}
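// E.g. (mul %x, (add %y, 1)) becomes (add (mul %x, %y), %x), a shape that can
// later be selected as a single mad instruction.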
4954
4955static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
4956 SDLoc DL,
4957 TargetLowering::DAGCombinerInfo &DCI) {
4958 if (Select->getOpcode() != ISD::SELECT)
4959 return SDValue();
4960
4961 SDValue Cond = Select->getOperand(0);
4962
4963 unsigned ConstOpNo;
4964 if (isConstOne(Select->getOperand(1)))
4965 ConstOpNo = 1;
4966 else if (isConstOne(Select->getOperand(2)))
4967 ConstOpNo = 2;
4968 else
4969 return SDValue();
4970
4971 SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
4972
4973 // Do not combine if the resulting sequence is not obviously profitable.
4974 if (!matchMADConstOnePattern(Y))
4975 return SDValue();
4976
4977 SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
4978
4979 return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond,
4980 (ConstOpNo == 1) ? X : NewMul,
4981 (ConstOpNo == 1) ? NewMul : X);
4982}
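// E.g. (mul x, (select c, 1, y)) -> (select c, x, (mul x, y)); the guard
// above only allows this when y itself matches the (add z, 1) pattern, so
// the new multiply is in turn combinable into a mad.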
4983
4984static SDValue
4985PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4986 TargetLowering::DAGCombinerInfo &DCI) {
4987
4988 EVT VT = N0.getValueType();
4989 if (VT.isVector())
4990 return SDValue();
4991
4992 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
4993 return SDValue();
4994
4995 SDLoc DL(N);
4996
4997 // (mul x, (add y, 1)) -> (add (mul x, y), x)
4998 if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
4999 return Res;
5000 if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
5001 return Res;
5002
5003 // (mul x, (select c, 1, y)) -> (select c, x, (mul x, y))
5004 if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI))
5005 return Res;
5006 if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI))
5007 return Res;
5008
5009 return SDValue();
5010}
5011
5012/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
5013static SDValue PerformMULCombine(SDNode *N,
5014 TargetLowering::DAGCombinerInfo &DCI,
5015 CodeGenOptLevel OptLevel) {
5016 if (OptLevel == CodeGenOptLevel::None)
5017 return SDValue();
5018
5019 if (SDValue Ret = TryMULWIDECombine(N, DCI))
5020 return Ret;
5021
5022 SDValue N0 = N->getOperand(0);
5023 SDValue N1 = N->getOperand(1);
5024 return PerformMULCombineWithOperands(N, N0, N1, DCI);
5025}
5026
5027/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
5028static SDValue PerformSHLCombine(SDNode *N,
5029 TargetLowering::DAGCombinerInfo &DCI,
5030 CodeGenOptLevel OptLevel) {
5031 if (OptLevel > CodeGenOptLevel::None) {
5032 // Try mul.wide combining at OptLevel > 0
5033 if (SDValue Ret = TryMULWIDECombine(N, DCI))
5034 return Ret;
5035 }
5036
5037 return SDValue();
5038}
5039
5040static SDValue PerformSETCCCombine(SDNode *N,
5041 TargetLowering::DAGCombinerInfo &DCI,
5042 unsigned int SmVersion) {
5043 EVT CCType = N->getValueType(0);
5044 SDValue A = N->getOperand(0);
5045 SDValue B = N->getOperand(1);
5046
5047 EVT AType = A.getValueType();
5048 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
5049 return SDValue();
5050
5051 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
5052 return SDValue();
5053
5054 SDLoc DL(N);
5055 // setp.f16x2 returns two scalar predicates, which we need to
5056 // convert back to v2i1. The returned result will be scalarized by
5057 // the legalizer, but the comparison will remain a single vector
5058 // instruction.
5059 SDValue CCNode = DCI.DAG.getNode(
5060 A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
5061 : NVPTXISD::SETP_BF16X2,
5062 DL, DCI.DAG.getVTList(MVT::i1, MVT::i1), {A, B, N->getOperand(2)});
5063 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
5064 CCNode.getValue(1));
5065}
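// Illustrative result shape, e.g. for (setcc v2f16:a, v2f16:b, setolt):
//   CCNode = SETP_F16X2 a, b, setolt   (two i1 results)
//   -> (BUILD_VECTOR CCNode#0, CCNode#1) : v2i1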
5066
5067static SDValue PerformEXTRACTCombine(SDNode *N,
5068 TargetLowering::DAGCombinerInfo &DCI) {
5069 SDValue Vector = N->getOperand(0);
5070 if (Vector->getOpcode() == ISD::FREEZE)
5071 Vector = Vector->getOperand(0);
5072 SDLoc DL(N);
5073 EVT VectorVT = Vector.getValueType();
5074 if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
5075 IsPTXVectorType(VectorVT.getSimpleVT()))
5076 return SDValue(); // Native vector loads already combine nicely w/
5077 // extract_vector_elt.
5078 // Don't mess with singletons or v2*16, v4i8 and v8i8 types; we already
5079 // handle them OK.
5080 if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
5081 VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5082 return SDValue();
5083
5084 // Don't mess with undef values as sra may be simplified to 0, not undef.
5085 if (Vector->isUndef() || ISD::allOperandsUndef(Vector.getNode()))
5086 return SDValue();
5087
5088 uint64_t VectorBits = VectorVT.getSizeInBits();
5089 // We only handle the types we can extract in-register.
5090 if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
5091 return SDValue();
5092
5093 ConstantSDNode *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
5094 // Index == 0 is handled by generic DAG combiner.
5095 if (!Index || Index->getZExtValue() == 0)
5096 return SDValue();
5097
5098 MVT IVT = MVT::getIntegerVT(VectorBits);
5099 EVT EltVT = VectorVT.getVectorElementType();
5100 EVT EltIVT = EltVT.changeTypeToInteger();
5101 uint64_t EltBits = EltVT.getScalarSizeInBits();
5102
5103 SDValue Result = DCI.DAG.getNode(
5104 ISD::TRUNCATE, DL, EltIVT,
5105 DCI.DAG.getNode(
5106 ISD::SRA, DL, IVT, DCI.DAG.getNode(ISD::BITCAST, DL, IVT, Vector),
5107 DCI.DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT)));
5108
5109 // If element has non-integer type, bitcast it back to the expected type.
5110 if (EltVT != EltIVT)
5111 Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
5112 // Past legalizer, we may need to extend i8 -> i16 to match the register type.
5113 if (EltVT != N->getValueType(0))
5114 Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);
5115
5116 return Result;
5117}
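// E.g. extracting element 1 of a v2f16 value v becomes, roughly,
//   (bitcast (trunc (sra (bitcast v to i32), 16)) to f16)
// so the scalar is produced with plain integer shifts instead of a
// vector-element move.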
5118
5119static SDValue PerformVSELECTCombine(SDNode *N,
5120 TargetLowering::DAGCombinerInfo &DCI) {
5121 SDValue VA = N->getOperand(1);
5122 EVT VectorVT = VA.getValueType();
5123 if (VectorVT != MVT::v4i8)
5124 return SDValue();
5125
5126 // We need to split vselect into individual per-element operations. Because
5127 // we use BFE/BFI instructions for byte extraction/insertion, we end up with
5128 // 32-bit values, so we may as well do the comparison as i32 to avoid
5129 // conversions to/from i16 normally used for i8 values.
5130 SmallVector<SDValue, 4> E;
5131 SDLoc DL(N);
5132 SDValue VCond = N->getOperand(0);
5133 SDValue VB = N->getOperand(2);
5134 for (int I = 0; I < 4; ++I) {
5135 SDValue C = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i1, VCond,
5136 DCI.DAG.getConstant(I, DL, MVT::i32));
5137 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
5138 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VA,
5139 DCI.DAG.getConstant(I, DL, MVT::i32)),
5140 DL, MVT::i32);
5141 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
5142 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VB,
5143 DCI.DAG.getConstant(I, DL, MVT::i32)),
5144 DL, MVT::i32);
5145 E.push_back(DCI.DAG.getAnyExtOrTrunc(
5146 DCI.DAG.getNode(ISD::SELECT, DL, MVT::i32, C, EA, EB), DL, MVT::i8));
5147 }
5148 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
5149}
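// E.g. (vselect v4i1:c, v4i8:a, v4i8:b) is rewritten as four scalar i32
// selects whose results are packed back into a v4i8 with BUILD_VECTOR.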
5150
5151static SDValue
5152PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5153 auto VT = N->getValueType(0);
5154 if (!DCI.isAfterLegalizeDAG() || !Isv2x16VT(VT))
5155 return SDValue();
5156
5157 auto Op0 = N->getOperand(0);
5158 auto Op1 = N->getOperand(1);
5159
5160 // Start out by assuming we want to take the lower 2 bytes of each i32
5161 // operand.
5162 uint64_t Op0Bytes = 0x10;
5163 uint64_t Op1Bytes = 0x54;
5164
5165 std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
5166 {&Op1, &Op1Bytes}};
5167
5168 // Check that each operand is an i16, truncated from an i32 operand. We'll
5169 // select individual bytes from those original operands. Optionally, fold in a
5170 // shift right of that original operand.
5171 for (auto &[Op, OpBytes] : OpData) {
5172 // Eat up any bitcast
5173 if (Op->getOpcode() == ISD::BITCAST)
5174 *Op = Op->getOperand(0);
5175
5176 if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
5177 Op->getOperand(0).getValueType() == MVT::i32))
5178 return SDValue();
5179
5180 // If the truncate has multiple uses, this optimization can increase
5181 // register pressure.
5182 if (!Op->hasOneUse())
5183 return SDValue();
5184
5185 *Op = Op->getOperand(0);
5186
5187 // Optionally, fold in a shift-right of the original operand and let permute
5188 // pick the two higher bytes of the original value directly.
5189 if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
5190 if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
5191 // Shift the PRMT byte selector to pick upper bytes from each respective
5192 // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
5193 assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
5194 "PRMT selector values out of range");
5195 *OpBytes += 0x22;
5196 *Op = Op->getOperand(0);
5197 }
5198 }
5199 }
5200
5201 SDLoc DL(N);
5202 auto &DAG = DCI.DAG;
5203
5204 auto PRMT = DAG.getNode(
5205 NVPTXISD::PRMT, DL, MVT::v4i8,
5206 {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
5207 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
5208 return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
5209}
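// PRMT selector sketch: nibbles 0-3 address bytes of Op0 and 4-7 bytes of
// Op1, so the default selectors 0x10 and 0x54 pick the low two bytes of each
// i32; after folding an srl by 16 they become 0x32 and 0x76, the high bytes.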
5210
5211SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5212 DAGCombinerInfo &DCI) const {
5213 CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
5214 switch (N->getOpcode()) {
5215 default: break;
5216 case ISD::ADD:
5217 return PerformADDCombine(N, DCI, OptLevel);
5218 case ISD::FADD:
5219 return PerformFADDCombine(N, DCI, OptLevel);
5220 case ISD::MUL:
5221 return PerformMULCombine(N, DCI, OptLevel);
5222 case ISD::SHL:
5223 return PerformSHLCombine(N, DCI, OptLevel);
5224 case ISD::AND:
5225 return PerformANDCombine(N, DCI);
5226 case ISD::UREM:
5227 case ISD::SREM:
5228 return PerformREMCombine(N, DCI, OptLevel);
5229 case ISD::SETCC:
5230 return PerformSETCCCombine(N, DCI, STI.getSmVersion());
5231 case NVPTXISD::StoreRetval:
5232 case NVPTXISD::StoreRetvalV2:
5233 case NVPTXISD::StoreRetvalV4:
5234 return PerformStoreRetvalCombine(N);
5235 case NVPTXISD::StoreParam:
5236 case NVPTXISD::StoreParamV2:
5237 case NVPTXISD::StoreParamV4:
5238 return PerformStoreParamCombine(N);
5239 case ISD::EXTRACT_VECTOR_ELT:
5240 return PerformEXTRACTCombine(N, DCI);
5241 case ISD::VSELECT:
5242 return PerformVSELECTCombine(N, DCI);
5243 case ISD::BUILD_VECTOR:
5244 return PerformBUILD_VECTORCombine(N, DCI);
5245 }
5246 return SDValue();
5247}
5248
5249static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
5250 SmallVectorImpl<SDValue> &Results) {
5251 // Handle bitcasting to v2i8 without hitting the default promotion
5252 // strategy which goes through stack memory.
5253 SDValue Op(Node, 0);
5254 EVT ToVT = Op->getValueType(0);
5255 if (ToVT != MVT::v2i8) {
5256 return;
5257 }
5258
5259 // Bitcast to i16 and unpack elements into a vector
5260 SDLoc DL(Node);
5261 SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
5262 SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
5263 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
5264 SDValue Vec1 =
5265 DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
5266 DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
5267 Results.push_back(
5268 DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
5269}
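// E.g. (bitcast i16:x to v2i8) becomes, roughly,
//   (BUILD_VECTOR (trunc x):i8, (trunc (srl x, 8)):i8)
// with a bitcast to i16 inserted first when the source is an f16/bf16 value.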
5270
5271/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
5272static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5273 SmallVectorImpl<SDValue> &Results) {
5274 EVT ResVT = N->getValueType(0);
5275 SDLoc DL(N);
5276
5277 assert(ResVT.isVector() && "Vector load must have vector type");
5278
5279 auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
5280 if (!NumEltsAndEltVT)
5281 return;
5282 auto [NumElts, EltVT] = NumEltsAndEltVT.value();
5283
5284 LoadSDNode *LD = cast<LoadSDNode>(N);
5285
5286 Align Alignment = LD->getAlign();
5287 auto &TD = DAG.getDataLayout();
5288 Align PrefAlign =
5289 TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
5290 if (Alignment < PrefAlign) {
5291 // This load is not sufficiently aligned, so bail out and let this vector
5292 // load be scalarized. Note that we may still be able to emit smaller
5293 // vector loads. For example, if we are loading a <4 x float> with an
5294 // alignment of 8, this check will fail but the legalizer will try again
5295 // with 2 x <2 x float>, which will succeed with an alignment of 8.
5296 return;
5297 }
5298
5299 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
5300 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
5301 // loaded type to i16 and propagate the "real" type as the memory type.
5302 bool NeedTrunc = false;
5303 if (EltVT.getSizeInBits() < 16) {
5304 EltVT = MVT::i16;
5305 NeedTrunc = true;
5306 }
5307
5308 unsigned Opcode = 0;
5309 SDVTList LdResVTs;
5310
5311 switch (NumElts) {
5312 default:
5313 return;
5314 case 2:
5315 Opcode = NVPTXISD::LoadV2;
5316 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
5317 break;
5318 case 4: {
5319 Opcode = NVPTXISD::LoadV4;
5320 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
5321 LdResVTs = DAG.getVTList(ListVTs);
5322 break;
5323 }
5324 }
5325
5326 // Copy regular operands
5327 SmallVector<SDValue, 8> OtherOps(N->ops());
5328
5329 // The select routine does not have access to the LoadSDNode instance, so
5330 // pass along the extension information
5331 OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
5332
5333 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
5334 LD->getMemoryVT(),
5335 LD->getMemOperand());
5336
5337 SmallVector<SDValue> ScalarRes;
5338 assert(NumElts <= ResVT.getVectorNumElements() &&
5339 "NumElts should not increase, only decrease or stay the same.");
5340 if (NumElts < ResVT.getVectorNumElements()) {
5341 // If the number of elements has decreased, getVectorLoweringShape has
5342 // upsized the element types
5343 assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
5344 EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
5345 // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5346 // into individual elements.
5347 for (unsigned i = 0; i < NumElts; ++i) {
5348 SDValue SubVector = NewLD.getValue(i);
5349 DAG.ExtractVectorElements(SubVector, ScalarRes);
5350 }
5351 } else {
5352 for (unsigned i = 0; i < NumElts; ++i) {
5353 SDValue Res = NewLD.getValue(i);
5354 if (NeedTrunc)
5355 Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
5356 ScalarRes.push_back(Res);
5357 }
5358 }
5359
5360 SDValue LoadChain = NewLD.getValue(NumElts);
5361
5362 SDValue BuildVec = DAG.getBuildVector(ResVT, DL, ScalarRes);
5363
5364 Results.push_back(BuildVec);
5365 Results.push_back(LoadChain);
5366}
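// E.g. a sufficiently aligned (load v4f32, p) is rebuilt here as
//   (NVPTXISD::LoadV4 p) : f32, f32, f32, f32, ch
// plus a BUILD_VECTOR of the four scalars, which later selects to a single
// ld.v4.f32; a v8f16 load instead yields four v2f16 pieces that are re-split
// with EXTRACT_VECTOR_ELT.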
5367
5368static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
5369 SmallVectorImpl<SDValue> &Results) {
5370 SDValue Chain = N->getOperand(0);
5371 SDValue Intrin = N->getOperand(1);
5372 SDLoc DL(N);
5373
5374 // Get the intrinsic ID
5375 unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
5376 switch (IntrinNo) {
5377 default:
5378 return;
5379 case Intrinsic::nvvm_ldu_global_i:
5380 case Intrinsic::nvvm_ldu_global_f:
5381 case Intrinsic::nvvm_ldu_global_p: {
5382 EVT ResVT = N->getValueType(0);
5383
5384 if (ResVT.isVector()) {
5385 // Vector LDG/LDU
5386
5387 unsigned NumElts = ResVT.getVectorNumElements();
5388 EVT EltVT = ResVT.getVectorElementType();
5389
5390 // Since LDU/LDG are target nodes, we cannot rely on DAG type
5391 // legalization.
5392 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
5393 // loaded type to i16 and propagate the "real" type as the memory type.
5394 bool NeedTrunc = false;
5395 if (EltVT.getSizeInBits() < 16) {
5396 EltVT = MVT::i16;
5397 NeedTrunc = true;
5398 }
5399
5400 unsigned Opcode = 0;
5401 SDVTList LdResVTs;
5402
5403 switch (NumElts) {
5404 default:
5405 return;
5406 case 2:
5407 Opcode = NVPTXISD::LDUV2;
5408 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
5409 break;
5410 case 4: {
5411 Opcode = NVPTXISD::LDUV4;
5412 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
5413 LdResVTs = DAG.getVTList(ListVTs);
5414 break;
5415 }
5416 }
5417
5418 SmallVector<SDValue, 8> OtherOps;
5419
5420 // Copy regular operands
5421
5422 OtherOps.push_back(Chain); // Chain
5423 // Skip operand 1 (intrinsic ID)
5424 // Others
5425 OtherOps.append(N->op_begin() + 2, N->op_end());
5426
5427 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
5428
5429 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
5430 MemSD->getMemoryVT(),
5431 MemSD->getMemOperand());
5432
5433 SmallVector<SDValue, 4> ScalarRes;
5434
5435 for (unsigned i = 0; i < NumElts; ++i) {
5436 SDValue Res = NewLD.getValue(i);
5437 if (NeedTrunc)
5438 Res =
5439 DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
5440 ScalarRes.push_back(Res);
5441 }
5442
5443 SDValue LoadChain = NewLD.getValue(NumElts);
5444
5445 SDValue BuildVec =
5446 DAG.getBuildVector(ResVT, DL, ScalarRes);
5447
5448 Results.push_back(BuildVec);
5449 Results.push_back(LoadChain);
5450 } else {
5451 // i8 LDG/LDU
5452 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
5453 "Custom handling of non-i8 ldu/ldg?");
5454
5455 // Just copy all operands as-is
5456 SmallVector<SDValue, 4> Ops(N->ops());
5457
5458 // Force output to i16
5459 SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
5460
5461 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
5462
5463 // We make sure the memory type is i8, which will be used during isel
5464 // to select the proper instruction.
5465 SDValue NewLD =
5466 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, Ops,
5467 MVT::i8, MemSD->getMemOperand());
5468
5469 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
5470 NewLD.getValue(0)));
5471 Results.push_back(NewLD.getValue(1));
5472 }
5473 }
5474 }
5475}
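// E.g. a v2f32 ldu.global intrinsic becomes NVPTXISD::LDUV2 with two f32
// results plus a chain, while the scalar i8 form is widened to an i16 result
// and truncated back, keeping i8 as the memory type for instruction
// selection.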
5476
5477static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
5478 SmallVectorImpl<SDValue> &Results) {
5479 // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
5480 // result so that it can pass legalization.
5481 SDLoc DL(N);
5482 SDValue Chain = N->getOperand(0);
5483 SDValue Reg = N->getOperand(1);
5484 SDValue Glue = N->getOperand(2);
5485
5486 assert(Reg.getValueType() == MVT::i128 &&
5487 "Custom lowering for CopyFromReg with 128-bit reg only");
5488 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1),
5489 N->getValueType(2)};
5490 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
5491
5492 SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps);
5493 SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
5494 {NewValue.getValue(0), NewValue.getValue(1)});
5495
5496 Results.push_back(Pair);
5497 Results.push_back(NewValue.getValue(2));
5498 Results.push_back(NewValue.getValue(3));
5499}
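// E.g. (CopyFromReg ch, reg:i128, glue) is re-emitted as a CopyFromReg
// producing i64, i64, ch, glue, and the two halves are joined back together
// with (BUILD_PAIR lo, hi) : i128.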
5500
5501void NVPTXTargetLowering::ReplaceNodeResults(
5502 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
5503 switch (N->getOpcode()) {
5504 default:
5505 report_fatal_error("Unhandled custom legalization");
5506 case ISD::BITCAST:
5507 ReplaceBITCAST(N, DAG, Results);
5508 return;
5509 case ISD::LOAD:
5510 ReplaceLoadVector(N, DAG, Results);
5511 return;
5512 case ISD::INTRINSIC_W_CHAIN:
5513 ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
5514 return;
5515 case ISD::CopyFromReg:
5516 ReplaceCopyFromReg_128(N, DAG, Results);
5517 return;
5518 }
5519}
5520
5521NVPTXTargetLowering::AtomicExpansionKind
5522NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
5523 Type *Ty = AI->getValOperand()->getType();
5524
5525 if (AI->isFloatingPointOperation()) {
5526 if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
5527 if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
5528 STI.getPTXVersion() >= 63)
5529 return AtomicExpansionKind::None;
5530 if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
5531 STI.getPTXVersion() >= 78)
5532 return AtomicExpansionKind::None;
5533 if (Ty->isFloatTy())
5534 return AtomicExpansionKind::None;
5535 if (Ty->isDoubleTy() && STI.hasAtomAddF64())
5536 return AtomicExpansionKind::None;
5537 }
5538 return AtomicExpansionKind::CmpXChg;
5539 }
5540
5541 assert(Ty->isIntegerTy() && "Ty should be integer at this point");
5542 auto ITy = cast<llvm::IntegerType>(Ty);
5543
5544 switch (AI->getOperation()) {
5545 default:
5546 return AtomicExpansionKind::CmpXChg;
5547 case AtomicRMWInst::BinOp::And:
5548 case AtomicRMWInst::BinOp::Or:
5549 case AtomicRMWInst::BinOp::Xor:
5550 case AtomicRMWInst::BinOp::Xchg:
5551 switch (ITy->getBitWidth()) {
5552 case 8:
5553 case 16:
5554 return AtomicExpansionKind::CmpXChg;
5555 case 32:
5556 return AtomicExpansionKind::None;
5557 case 64:
5558 if (STI.hasAtomBitwise64())
5559 return AtomicExpansionKind::None;
5560 return AtomicExpansionKind::CmpXChg;
5561 default:
5562 llvm_unreachable("unsupported width encountered");
5563 }
5564 case AtomicRMWInst::BinOp::Add:
5565 case AtomicRMWInst::BinOp::Sub:
5566 case AtomicRMWInst::BinOp::Max:
5567 case AtomicRMWInst::BinOp::Min:
5568 case AtomicRMWInst::BinOp::UMax:
5569 case AtomicRMWInst::BinOp::UMin:
5570 switch (ITy->getBitWidth()) {
5571 case 8:
5572 case 16:
5573 return AtomicExpansionKind::CmpXChg;
5574 case 32:
5575 return AtomicExpansionKind::None;
5576 case 64:
5577 if (STI.hasAtomMinMax64())
5578 return AtomicExpansionKind::None;
5579 return AtomicExpansionKind::CmpXChg;
5580 default:
5581 llvm_unreachable("unsupported width encountered");
5582 }
5583 }
5584
5585 return AtomicExpansionKind::CmpXChg;
5586}
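// Illustrative decisions under this policy (assuming the subtarget checks
// pass):
//   atomicrmw add  ptr %p, i32 %v  -> None (kept as a native atomic)
//   atomicrmw and  ptr %p, i64 %v  -> None only if hasAtomBitwise64()
//   atomicrmw xor  ptr %p, i8  %v  -> CmpXChg expansion
//   atomicrmw fadd ptr %p, half %v -> None on sm_70+/PTX 6.3+, else CmpXChg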
5587
5588// Pin NVPTXTargetObjectFile's vtables to this file.
5589NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
5590
5591MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
5592 const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
5593 return getDataSection();
5594}
#define MAKE_CASE(V)
static const LLT F32
AMDGPU Register Bank Select
This file implements a class to represent arbitrary precision integral constant values and operations...
static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformADDCombineWithOperands - Try DAG combinations for an ADD with operands N0 and N1.
static SDValue PerformADDCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
static SDValue PerformVSELECTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformMULCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformFADDCombine(SDNode *N, SelectionDAG &DAG, const ARMSubtarget *Subtarget)
static SDValue PerformANDCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformBUILD_VECTORCombine - Target-specific dag combine xforms for ISD::BUILD_VECTOR.
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Function Alias Analysis Results
This file contains the simple types necessary to represent the attributes associated with functions a...
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
This file contains the declarations for the subclasses of Constant, which represent the different fla...
return RetTy
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
uint64_t Size
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
This file contains the declarations of entities that describe floating point environment and related ...
Module.h This file contains the declarations for the Module class.
static LVOptions Options
Definition: LVOptions.cpp:25
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
static DebugLoc getDebugLoc(MachineBasicBlock::instr_iterator FirstMI, MachineBasicBlock::instr_iterator LastMI)
Return the first found DebugLoc that has a DILocation, given a range of instructions.
unsigned const TargetRegisterInfo * TRI
NVPTX address space definition.
static bool shouldConvertToIndirectCall(const CallBase *CB, const GlobalAddressSDNode *Func)
static cl::opt< bool > sched4reg("nvptx-sched4reg", cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false))
static SDValue PerformEXTRACTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
static bool isConstOne(const SDValue &Operand)
static cl::opt< unsigned > FMAContractLevelOpt("nvptx-fma-level", cl::Hidden, cl::desc("NVPTX Specific: FMA contraction (0: don't do it" " 1: do it 2: do it aggressively"), cl::init(2))
static bool IsPTXVectorType(MVT VT)
static cl::opt< int > UsePrecDivF32("nvptx-prec-divf32", cl::Hidden, cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use" " IEEE Compliant F32 div.rnd if available."), cl::init(2))
static SDValue PerformStoreParamCombine(SDNode *N)
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static bool Is16bitsType(MVT VT)
static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL, TargetLowering::DAGCombinerInfo &DCI)
static bool IsTypePassedAsArray(const Type *Ty)
static SmallVector< ParamVectorizationFlags, 16 > VectorizePTXValueVTs(const SmallVectorImpl< EVT > &ValueVTs, const SmallVectorImpl< uint64_t > &Offsets, Align ParamAlignment, bool IsVAArg=false)
static unsigned CanMergeParamLoadStoresStartingAt(unsigned Idx, uint32_t AccessSize, const SmallVectorImpl< EVT > &ValueVTs, const SmallVectorImpl< uint64_t > &Offsets, Align ParamAlignment)
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static SDValue PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static bool isConstZero(const SDValue &Operand)
static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG)
static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl< EVT > &ValueVTs, SmallVectorImpl< uint64_t > *Offsets=nullptr, uint64_t StartingOffset=0)
ComputePTXValueVTs - For the given Type Ty, returns the set of primitive EVTs that compose it.
static bool IsMulWideOperandDemotable(SDValue Op, unsigned OptSize, OperandSignedness &S)
IsMulWideOperandDemotable - Checks if the provided DAG node is an operand that can be demoted to OptS...
static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain, uint64_t Offset, EVT ElementType, SDValue StVal, SDValue &InGlue, unsigned ArgID, const SDLoc &dl)
static SDValue PerformREMCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static std::optional< std::pair< unsigned int, EVT > > getVectorLoweringShape(EVT VectorVT)
static SDValue PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI)
static SDValue PerformStoreRetvalCombine(SDNode *N)
static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS, unsigned OptSize, bool &IsSigned)
AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can be demoted to OptSize bits...
static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front, std::size_t Back)
static bool adjustElementType(EVT &ElementType)
static SDValue TryMULWIDECombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply of M/2 bits that produces...
static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT, SDLoc DL, TargetLowering::DAGCombinerInfo &DCI)
static SDValue matchMADConstOnePattern(SDValue Add)
static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT, SDValue Value)
static cl::opt< bool > UsePrecSqrtF32("nvptx-prec-sqrtf32", cl::Hidden, cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."), cl::init(true))
ParamVectorizationFlags
@ PVF_FIRST
@ PVF_SCALAR
@ PVF_INNER
@ PVF_LAST
static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain, uint64_t Offset, EVT ElementType, SDValue RetVal, const SDLoc &dl)
static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG)
static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT)
PromoteScalarIntegerPTX Used to make sure the arguments/returns are suitable for passing and promote ...
OperandSignedness
static SDValue PerformSETCCCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, unsigned int SmVersion)
static SDValue LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset, EVT ElementType, SDValue &InGlue, SmallVectorImpl< SDValue > &TempProxyRegOps, const SDLoc &dl)
static std::atomic< unsigned > GlobalUniqueCallSite
static cl::opt< bool > ForceMinByValParamAlign("nvptx-force-min-byval-param-align", cl::Hidden, cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval" " params of device functions."), cl::init(false))
static cl::opt< bool > UseApproxLog2F32("nvptx-approx-log2f32", cl::desc("NVPTX Specific: whether to use lg2.approx for log2"), cl::init(false))
Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it does NOT use lg2....
static SDValue PerformSHLCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
unsigned SmVersion
Definition: NVVMReflect.cpp:79
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
if(PassOpts->AAPipeline)
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file contains some templates that are useful if you are working with the STL at all.
This file defines the SmallVector class.
static bool Enabled
Definition: Statistic.cpp:46
This file describes how to lower LLVM code to machine code.
Value * RHS
Value * LHS
Class for arbitrary precision integers.
Definition: APInt.h:78
bool isSignedIntN(unsigned N) const
Check if this APInt has an N-bits signed integer value.
Definition: APInt.h:435
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1130
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:432
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
Definition: APInt.h:1237
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
const T & back() const
back - Get the last element.
Definition: ArrayRef.h:177
ArrayRef< T > drop_back(size_t N=1) const
Drop the last N elements of the array.
Definition: ArrayRef.h:213
bool empty() const
empty - Check if the array is empty.
Definition: ArrayRef.h:163
an instruction that atomically reads a memory location, combines it with another value,...
Definition: Instructions.h:704
@ Add
*p = old + v
Definition: Instructions.h:720
@ FAdd
*p = old + v
Definition: Instructions.h:741
@ Min
*p = old <signed v ? old : v
Definition: Instructions.h:734
@ Or
*p = old | v
Definition: Instructions.h:728
@ Sub
*p = old - v
Definition: Instructions.h:722
@ And
*p = old & v
Definition: Instructions.h:724
@ Xor
*p = old ^ v
Definition: Instructions.h:730
@ Max
*p = old >signed v ? old : v
Definition: Instructions.h:732
@ UMin
*p = old <unsigned v ? old : v
Definition: Instructions.h:738
@ UMax
*p = old >unsigned v ? old : v
Definition: Instructions.h:736
bool isFloatingPointOperation() const
Definition: Instructions.h:882
BinOp getOperation() const
Definition: Instructions.h:805
Value * getValOperand()
Definition: Instructions.h:874
bool hasParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const
Return true if the attribute exists for the given argument.
Definition: Attributes.h:833
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1112
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1341
FunctionType * getFunctionType() const
Definition: InstrTypes.h:1199
This class represents a function call, abstracting a target machine's calling convention.
uint64_t getZExtValue() const
const APInt & getAPIntValue() const
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:373
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
TypeSize getTypeAllocSize(Type *Ty) const
Returns the offset in bytes between successive objects of the specified type, including alignment pad...
Definition: DataLayout.h:457
Align getPrefTypeAlign(Type *Ty) const
Returns the preferred stack/global alignment for the specified type.
Definition: DataLayout.cpp:847
Diagnostic information for unsupported feature in backend.
void addFnAttr(Attribute::AttrKind Kind)
Add function attributes to this function.
Definition: Function.cpp:641
Type * getReturnType() const
Returns the type of the ret val.
Definition: Function.h:221
unsigned getAddressSpace() const
const GlobalValue * getGlobal() const
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
void diagnose(const DiagnosticInfo &DI)
Report a message to the currently installed diagnostic handler.
This class is used to represent ISD::LOAD nodes.
MCSection * getDataSection() const
Instances of this class represent a uniqued identifier for a section in the current translation unit.
Definition: MCSection.h:36
StringRef getName() const
getName - Get the symbol name.
Definition: MCSymbol.h:205
Machine Value Type.
SimpleValueType SimpleTy
unsigned getVectorNumElements() const
bool isScalableVector() const
Return true if this is a vector value type where the runtime length is machine dependent.
static auto integer_valuetypes()
static auto fixedlen_vector_valuetypes()
static MVT getVectorVT(MVT VT, unsigned NumElements)
static MVT getIntegerVT(unsigned BitWidth)
MVT getScalarType() const
If this is a vector, return the element type, otherwise return this.
DenormalMode getDenormalMode(const fltSemantics &FPType) const
Returns the denormal handling type for the default rounding mode of the function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
const MachineJumpTableInfo * getJumpTableInfo() const
getJumpTableInfo - Return the jump table info object for the current function.
const TargetMachine & getTarget() const
getTarget - Return the target machine this machine code is compiled with
@ EK_Inline
EK_Inline - Jump table entries are emitted inline at their point of use.
const std::vector< MachineJumpTableEntry > & getJumpTables() const
@ MODereferenceable
The memory access is dereferenceable (i.e., doesn't trap).
@ MOLoad
The memory access reads data.
@ MOInvariant
The memory access always returns the same value (or traps).
@ MOStore
The memory access writes data.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Register createVirtualRegister(const TargetRegisterClass *RegClass, StringRef Name="")
createVirtualRegister - Create and return a new virtual register in the function with the specified r...
This SDNode is used for target intrinsics that touch memory and need an associated MachineMemOperand.
This is an abstract virtual class for memory operations.
Align getAlign() const
MachineMemOperand * getMemOperand() const
Return a MachineMemOperand object describing the memory reference performed by operation.
EVT getMemoryVT() const
Return the type of the in-memory value.
unsigned getMaxRequiredAlignment() const
bool hasAtomMinMax64() const
bool hasAtomAddF64() const
bool hasHWROT32() const
const NVPTXTargetLowering * getTargetLowering() const override
unsigned getPTXVersion() const
bool hasNativeBF16Support(int Opcode) const
const NVPTXRegisterInfo * getRegisterInfo() const override
unsigned int getSmVersion() const
bool hasAtomBitwise64() const
bool hasBF16Math() const
bool allowFP16Math() const
bool hasAtomCas16() const
ConstraintType getConstraintType(StringRef Constraint) const override
getConstraintType - Given a constraint letter, return the type of constraint it is for this target.
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override
This callback is invoked for operations that are unsupported by the target, which are registered to u...
const NVPTXTargetMachine * nvTM
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const
NVPTXTargetLowering(const NVPTXTargetMachine &TM, const NVPTXSubtarget &STI)
bool useF32FTZ(const MachineFunction &MF) const
SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const
Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const
SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps, bool &UseOneConst, bool Reciprocal) const override
Hooks for building estimates in place of slower divisions and square roots.
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl< ISD::OutputArg > &Outs, const SmallVectorImpl< SDValue > &OutVals, const SDLoc &dl, SelectionDAG &DAG) const override
This hook must be implemented to lower outgoing return values, described by the Outs array,...
SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl< ISD::InputArg > &Ins, const SDLoc &dl, SelectionDAG &DAG, SmallVectorImpl< SDValue > &InVals) const override
This hook must be implemented to lower the incoming (formal) arguments, described by the Ins array,...
void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, std::vector< SDValue > &Ops, SelectionDAG &DAG) const override
Lower the specified operand into the Ops vector.
SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const
std::string getParamName(const Function *F, int Idx) const
TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const override
Return the preferred vector type legalization action.
std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &, const SmallVectorImpl< ISD::OutputArg > &, MaybeAlign retAlignment, std::optional< std::pair< unsigned, const APInt & > > VAInfo, const CallBase &CB, unsigned UniqueCallSite) const
Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy, const DataLayout &DL) const
getFunctionParamOptimizedAlign - since function arguments are passed via .param space,...
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const
EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx, EVT VT) const override
Return the ValueType of the result of SETCC operations.
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty, unsigned AS, Instruction *I=nullptr) const override
isLegalAddressingMode - Return true if the addressing mode represented by AM is legal for this target...
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const override
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
Align getFunctionByValParamAlign(const Function *F, Type *ArgTy, Align InitialAlign, const DataLayout &DL) const
Helper for computing alignment of a device function byval parameter.
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
const char * getTargetNodeName(unsigned Opcode) const override
This method returns the name of a target specific DAG node.
bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const
unsigned getJumpTableEncoding() const override
Return the entry encoding for a jump table in the current function.
bool allowUnsafeFPMath(MachineFunction &MF) const
SDValue LowerCall(CallLoweringInfo &CLI, SmallVectorImpl< SDValue > &InVals) const override
This hook must be implemented to lower calls into the specified DAG.
UniqueStringSaver & getStrPool() const
MCSection * SelectSectionForGlobal(const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const override
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Wrapper class for IR location info (IR ordering and DebugLoc) to be passed into SDNode creation funct...
Represents one node in the SelectionDAG.
const APInt & getAsAPIntVal() const
Helper method returns the APInt value of a ConstantSDNode.
unsigned getOpcode() const
Return the SelectionDAG opcode value for this node.
bool hasOneUse() const
Return true if there is exactly one use of this node.
unsigned getIROrder() const
Return the node ordering.
uint64_t getAsZExtVal() const
Helper method returns the zero-extended integer value of a ConstantSDNode.
unsigned getNumOperands() const
Return the number of values used by this operation.
SDVTList getVTList() const
const SDValue & getOperand(unsigned Num) const
uint64_t getConstantOperandVal(unsigned Num) const
Helper method returns the integer value of a ConstantSDNode operand.
const APInt & getConstantOperandAPInt(unsigned Num) const
Helper method returns the APInt of a ConstantSDNode operand.
EVT getValueType(unsigned ResNo) const
Return the type of a specified result.
bool isUndef() const
Return true if the type of the node type undefined.
iterator_range< user_iterator > users()
Represents a use of a SDNode.
Unlike LLVM values, Selection DAG nodes may return multiple values as the result of a computation.
SDNode * getNode() const
get the SDNode which holds the desired result
SDValue getValue(unsigned R) const
EVT getValueType() const
Return the ValueType of the referenced return value.
TypeSize getValueSizeInBits() const
Returns the size of the value in bits.
const SDValue & getOperand(unsigned i) const
MVT getSimpleValueType() const
Return the simple ValueType of the referenced return value.
unsigned getOpcode() const
SectionKind - This is a simple POD value that classifies the properties of a section.
Definition: SectionKind.h:22
This is used to represent a portion of an LLVM function in a low-level Data Dependence DAG representa...
Definition: SelectionDAG.h:228
SDValue getExtLoad(ISD::LoadExtType ExtType, const SDLoc &dl, EVT VT, SDValue Chain, SDValue Ptr, MachinePointerInfo PtrInfo, EVT MemVT, MaybeAlign Alignment=MaybeAlign(), MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
SDValue getTargetGlobalAddress(const GlobalValue *GV, const SDLoc &DL, EVT VT, int64_t offset=0, unsigned TargetFlags=0)
Definition: SelectionDAG.h:750
const SDValue & getRoot() const
Return the root tag of the SelectionDAG.
Definition: SelectionDAG.h:577
SDValue getAddrSpaceCast(const SDLoc &dl, EVT VT, SDValue Ptr, unsigned SrcAS, unsigned DestAS)
Return an AddrSpaceCastSDNode.
SDValue getCopyToReg(SDValue Chain, const SDLoc &dl, Register Reg, SDValue N)
Definition: SelectionDAG.h:801
SDValue getMergeValues(ArrayRef< SDValue > Ops, const SDLoc &dl)
Create a MERGE_VALUES node from the given operands.
SDVTList getVTList(EVT VT)
Return an SDVTList that represents the list of values specified.
void ExtractVectorElements(SDValue Op, SmallVectorImpl< SDValue > &Args, unsigned Start=0, unsigned Count=0, EVT EltVT=EVT())
Append the extracted elements from Start to Count out of the vector Op in Args.
SDValue getSetCC(const SDLoc &DL, EVT VT, SDValue LHS, SDValue RHS, ISD::CondCode Cond, SDValue Chain=SDValue(), bool IsSignaling=false)
Helper function to make it easier to build SetCC's if you just have an ISD::CondCode instead of an SD...
SDValue getSymbolFunctionGlobalAddress(SDValue Op, Function **TargetFunction=nullptr)
Return a GlobalAddress of the function from the current module with name matching the given ExternalS...
SDValue getConstantFP(double Val, const SDLoc &DL, EVT VT, bool isTarget=false)
Create a ConstantFPSDNode wrapping a constant value.
SDValue getLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, MachinePointerInfo PtrInfo, MaybeAlign Alignment=MaybeAlign(), MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes(), const MDNode *Ranges=nullptr)
Loads are not normal binary operators: their result type is not determined by their operands,...
const TargetLowering & getTargetLoweringInfo() const
Definition: SelectionDAG.h:503
SDNode * MorphNodeTo(SDNode *N, unsigned Opc, SDVTList VTs, ArrayRef< SDValue > Ops)
This mutates the specified node to have the specified return type, opcode, and operands.
SDValue getCALLSEQ_END(SDValue Chain, SDValue Op1, SDValue Op2, SDValue InGlue, const SDLoc &DL)
Return a new CALLSEQ_END node, which always must have a glue result (to ensure it's not CSE'd).
SDValue getBuildVector(EVT VT, const SDLoc &DL, ArrayRef< SDValue > Ops)
Return an ISD::BUILD_VECTOR node.
Definition: SelectionDAG.h:856
SDValue getBitcast(EVT VT, SDValue V)
Return a bitcast using the SDLoc of the value operand, and casting to the provided type.
SDValue getCopyFromReg(SDValue Chain, const SDLoc &dl, Register Reg, EVT VT)
Definition: SelectionDAG.h:827
SDValue getSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS, SDValue RHS, SDNodeFlags Flags=SDNodeFlags())
Helper function to make it easier to build Select's if you just have operands and don't want to check...
const DataLayout & getDataLayout() const
Definition: SelectionDAG.h:497
SDValue getConstant(uint64_t Val, const SDLoc &DL, EVT VT, bool isTarget=false, bool isOpaque=false)
Create a ConstantSDNode wrapping a constant value.
SDValue getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, MachinePointerInfo PtrInfo, EVT SVT, Align Alignment, MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
void ReplaceAllUsesWith(SDValue From, SDValue To)
Modify anything using 'From' to use 'To' instead.
SDValue getStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, MachinePointerInfo PtrInfo, Align Alignment, MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
Helper function to build ISD::STORE nodes.
SDValue getSignedConstant(int64_t Val, const SDLoc &DL, EVT VT, bool isTarget=false, bool isOpaque=false)
SDValue getCALLSEQ_START(SDValue Chain, uint64_t InSize, uint64_t OutSize, const SDLoc &DL)
Return a new CALLSEQ_START node, that starts new call frame, in which InSize bytes are set up inside ...
void RemoveDeadNode(SDNode *N)
Remove the specified node from the system.
SDValue getBasicBlock(MachineBasicBlock *MBB)
SDValue getAnyExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of integer type, to the integer type VT, by either any-extending or truncat...
SDValue getSelectCC(const SDLoc &DL, SDValue LHS, SDValue RHS, SDValue True, SDValue False, ISD::CondCode Cond)
Helper function to make it easier to build SelectCC's if you just have an ISD::CondCode instead of an...
SDValue getIntPtrConstant(uint64_t Val, const SDLoc &DL, bool isTarget=false)
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, ArrayRef< SDUse > Ops)
Gets or creates the specified node.
SDValue getFPExtendOrRound(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of float type, to the float type VT, by either extending or rounding (by tr...
SDValue getTargetConstant(uint64_t Val, const SDLoc &DL, EVT VT, bool isOpaque=false)
Definition: SelectionDAG.h:700
MachineFunction & getMachineFunction() const
Definition: SelectionDAG.h:492
SDValue getZExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of integer type, to the integer type VT, by either zero-extending or trunca...
LLVMContext * getContext() const
Definition: SelectionDAG.h:510
const SDValue & setRoot(SDValue N)
Set the current root tag of the SelectionDAG.
Definition: SelectionDAG.h:586
SDValue getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, SDVTList VTList, ArrayRef< SDValue > Ops, EVT MemVT, MachinePointerInfo PtrInfo, Align Alignment, MachineMemOperand::Flags Flags=MachineMemOperand::MOLoad|MachineMemOperand::MOStore, LocationSize Size=0, const AAMDNodes &AAInfo=AAMDNodes())
Creates a MemIntrinsicNode that may produce a result and takes a list of operands.
SDValue getTargetExternalSymbol(const char *Sym, EVT VT, unsigned TargetFlags=0)
SDValue getEntryNode() const
Return the token chain corresponding to the entry of the function.
Definition: SelectionDAG.h:580
This SDNode is used to implement the code generator support for the llvm IR shufflevector instruction...
ArrayRef< int > getMask() const
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void assign(size_type NumElts, ValueParamT Elt)
Definition: SmallVector.h:704
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void resize(size_type N)
Definition: SmallVector.h:638
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
This class is used to represent ISD::STORE nodes.
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
constexpr size_t size() const
size - Get the string size.
Definition: StringRef.h:150
constexpr const char * data() const
data - Get a pointer to the start of the string (which may not be null terminated).
Definition: StringRef.h:144
Class to represent struct types.
Definition: DerivedTypes.h:218
void setBooleanVectorContents(BooleanContent Ty)
Specify how the target extends the result of a vector boolean value from a vector of i1 to a wider ty...
void setOperationAction(unsigned Op, MVT VT, LegalizeAction Action)
Indicate that the specified operation does not work with the specified type and indicate what to do a...
void setMaxDivRemBitWidthSupported(unsigned SizeInBits)
Set the size in bits of the maximum div/rem the backend supports.
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
LegalizeAction
This enum indicates whether operations are valid for a target, and if not, what action should be used...
unsigned MaxStoresPerMemcpyOptSize
Likewise for functions with the OptSize attribute.
virtual const TargetRegisterClass * getRegClassFor(MVT VT, bool isDivergent=false) const
Return the register class that should be used for the specified value type.
const TargetMachine & getTargetMachine() const
void setOperationPromotedToType(unsigned Opc, MVT OrigVT, MVT DestVT)
Convenience method to set an operation to Promote and specify the type in a single call.
LegalizeTypeAction
This enum indicates whether a types are legal for a target, and if not, what action should be used to...
void addBypassSlowDiv(unsigned int SlowBitWidth, unsigned int FastBitWidth)
Tells the code generator which bitwidths to bypass.
virtual unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const
Return the number of registers that this ValueType will eventually require.
void setMaxAtomicSizeInBitsSupported(unsigned SizeInBits)
Set the maximum atomic operation size supported by the backend.
virtual TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const
Return the preferred vector type legalization action.
unsigned MaxStoresPerMemsetOptSize
Likewise for functions with the OptSize attribute.
void setBooleanContents(BooleanContent Ty)
Specify how the target extends the result of integer and floating point boolean values from i1 to a w...
unsigned MaxStoresPerMemmove
Specify maximum number of store instructions per memmove call.
void computeRegisterProperties(const TargetRegisterInfo *TRI)
Once all of the register classes are added, this allows us to compute derived properties we expose.
unsigned MaxStoresPerMemmoveOptSize
Likewise for functions with the OptSize attribute.
void addRegisterClass(MVT VT, const TargetRegisterClass *RC)
Add the specified register class as an available regclass for the specified value type.
virtual MVT getPointerTy(const DataLayout &DL, uint32_t AS=0) const
Return the pointer type for the given address space, defaults to the pointer type from the data layou...
unsigned MaxStoresPerMemset
Specify maximum number of store instructions per memset call.
void setTruncStoreAction(MVT ValVT, MVT MemVT, LegalizeAction Action)
Indicate that the specified truncating store does not work with the specified type and indicate what ...
void setMinCmpXchgSizeInBits(unsigned SizeInBits)
Sets the minimum cmpxchg or ll/sc size supported by the backend.
void AddPromotedToType(unsigned Opc, MVT OrigVT, MVT DestVT)
If Opc/OrigVT is specified as being promoted, the promotion code defaults to trying a larger integer/...
AtomicExpansionKind
Enum that specifies what an atomic load/AtomicRMWInst is expanded to, if at all.
void setCondCodeAction(ArrayRef< ISD::CondCode > CCs, MVT VT, LegalizeAction Action)
Indicate that the specified condition code is or isn't supported on the target and indicate what to d...
void setTargetDAGCombine(ArrayRef< ISD::NodeType > NTs)
Targets should invoke this method for each target independent node that they want to provide a custom...
Align getMinStackArgumentAlignment() const
Return the minimum stack alignment of an argument.
void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT, LegalizeAction Action)
Indicate that the specified load with extension does not work with the specified type and indicate wh...
std::vector< ArgListEntry > ArgListTy
bool allowsMemoryAccessForAlignment(LLVMContext &Context, const DataLayout &DL, EVT VT, unsigned AddrSpace=0, Align Alignment=Align(1), MachineMemOperand::Flags Flags=MachineMemOperand::MONone, unsigned *Fast=nullptr) const
This function returns true if the memory access is aligned or if the target allows this specific unal...
unsigned MaxStoresPerMemcpy
Specify maximum number of store instructions per memcpy call.
void setSchedulingPreference(Sched::Preference Pref)
Specify the target scheduling preference.
void setJumpIsExpensive(bool isExpensive=true)
Tells the code generator not to expand logic operations on comparison predicates into separate sequen...
LegalizeAction getOperationAction(unsigned Op, EVT VT) const
Return how this operation should be treated: either it is legal, needs to be promoted to a larger siz...
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
SDValue expandUnalignedStore(StoreSDNode *ST, SelectionDAG &DAG) const
Expands an unaligned store to 2 half-size stores for integer values, and possibly more for vectors.
virtual ConstraintType getConstraintType(StringRef Constraint) const
Given a constraint, return the type of constraint it is for this target.
std::pair< SDValue, SDValue > expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const
Expands an unaligned load to 2 half-size loads for an integer, and possibly more for vectors.
virtual std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const
Given a physical register constraint (e.g.
SDValue expandRoundInexactToOdd(EVT ResultVT, SDValue Op, const SDLoc &DL, SelectionDAG &DAG) const
Truncate Op to ResultVT.
SDValue expandFP_ROUND(SDNode *Node, SelectionDAG &DAG) const
Expand round(fp) to fp conversion.
virtual void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, std::vector< SDValue > &Ops, SelectionDAG &DAG) const
Lower the specified operand into the Ops vector.
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:77
CodeGenOptLevel getOptLevel() const
Returns the optimization level: None, Less, Default, or Aggressive.
TargetOptions Options
MCSymbol * getSymbol(const GlobalValue *GV) const
unsigned UnsafeFPMath
UnsafeFPMath - This flag is enabled when the -enable-unsafe-fp-math flag is specified on the command ...
FPOpFusion::FPOpFusionMode AllowFPOpFusion
AllowFPOpFusion - This flag is set by the -fp-contract=xxx option.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:270
bool isFloatTy() const
Return true if this is 'float', a 32-bit IEEE fp type.
Definition: Type.h:153
bool isBFloatTy() const
Return true if this is 'bfloat', a 16-bit bfloat type.
Definition: Type.h:145
@ VoidTyID
type with no size
Definition: Type.h:63
bool isAggregateType() const
Return true if the type is an aggregate type.
Definition: Type.h:303
bool isHalfTy() const
Return true if this is 'half', a 16-bit IEEE fp type.
Definition: Type.h:142
bool isDoubleTy() const
Return true if this is 'double', a 64-bit IEEE fp type.
Definition: Type.h:156
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition: Type.h:184
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:237
TypeID getTypeID() const
Return the type id for the type.
Definition: Type.h:136
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
StringRef save(const char *S)
Definition: StringSaver.h:52
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
int getNumOccurrences() const
Definition: CommandLine.h:399
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
NodeType
ISD::NodeType enum - This enum defines the target-independent operators for a SelectionDAG.
Definition: ISDOpcodes.h:40
@ SETCC
SetCC operator - This evaluates to a true value iff the condition is true.
Definition: ISDOpcodes.h:780
@ STACKRESTORE
STACKRESTORE has two operands, an input chain and a pointer to restore to; it returns an output chain.
Definition: ISDOpcodes.h:1197
@ STACKSAVE
STACKSAVE - STACKSAVE has one operand, an input chain.
Definition: ISDOpcodes.h:1193
@ SMUL_LOHI
SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing a signed/unsigned value of type i[2*N], and return the full value as two results, each of type iN.
Definition: ISDOpcodes.h:257
@ BSWAP
Byte Swap and Counting operators.
Definition: ISDOpcodes.h:744
@ VAEND
VAEND, VASTART - VAEND and VASTART have three operands: an input chain, pointer, and a SRCVALUE.
Definition: ISDOpcodes.h:1226
@ ConstantFP
Definition: ISDOpcodes.h:77
@ ADDC
Carry-setting nodes for multiple precision addition and subtraction.
Definition: ISDOpcodes.h:276
@ ADD
Simple integer binary arithmetic operators.
Definition: ISDOpcodes.h:246
@ LOAD
LOAD and STORE have token chains as their first operand, then the same operands as an LLVM load/store instruction, then an offset node that is added/subtracted from the base pointer to form the address (for indexed memory ops).
Definition: ISDOpcodes.h:1102
@ ANY_EXTEND
ANY_EXTEND - Used for integer types. The high bits are undefined.
Definition: ISDOpcodes.h:814
@ FMA
FMA - Perform a * b + c with no intermediate rounding step.
Definition: ISDOpcodes.h:498
@ INTRINSIC_VOID
OUTCHAIN = INTRINSIC_VOID(INCHAIN, INTRINSICID, arg1, arg2, ...) This node represents a target intrinsic function with side effects that does not return a result.
Definition: ISDOpcodes.h:205
@ GlobalAddress
Definition: ISDOpcodes.h:78
@ SINT_TO_FP
[SU]INT_TO_FP - These operators convert integers (whose interpreted sign depends on the first letter) to floating point.
Definition: ISDOpcodes.h:841
@ CONCAT_VECTORS
CONCAT_VECTORS(VECTOR0, VECTOR1, ...) - Given a number of values of vector type with the same length and element type, this produces a concatenated vector result value, with length equal to the sum of the lengths of the inputs.
Definition: ISDOpcodes.h:558
@ FADD
Simple binary floating point operators.
Definition: ISDOpcodes.h:397
@ ABS
ABS - Determine the unsigned absolute value of a signed integer value of the same bitwidth.
Definition: ISDOpcodes.h:717
@ SDIVREM
SDIVREM/UDIVREM - Divide two integers and produce both a quotient and remainder result.
Definition: ISDOpcodes.h:262
@ BITCAST
BITCAST - This operator converts between integer, vector and FP values, as if the value was stored to memory with one type and loaded from the same address with the other type.
Definition: ISDOpcodes.h:954
@ BUILD_PAIR
BUILD_PAIR - This is the opposite of EXTRACT_ELEMENT in some ways.
Definition: ISDOpcodes.h:236
@ SIGN_EXTEND
Conversion operators.
Definition: ISDOpcodes.h:805
@ READSTEADYCOUNTER
READSTEADYCOUNTER - This corresponds to the readsteadycounter intrinsic.
Definition: ISDOpcodes.h:1259
@ FNEG
Perform various unary floating-point operations inspired by libm.
Definition: ISDOpcodes.h:981
@ BR_CC
BR_CC - Conditional branch.
Definition: ISDOpcodes.h:1148
@ SSUBO
Same for subtraction.
Definition: ISDOpcodes.h:334
@ BRIND
BRIND - Indirect branch.
Definition: ISDOpcodes.h:1123
@ BR_JT
BR_JT - Jumptable branch.
Definition: ISDOpcodes.h:1127
@ SSUBSAT
RESULT = [US]SUBSAT(LHS, RHS) - Perform saturation subtraction on 2 integers with the same bit width (W).
Definition: ISDOpcodes.h:356
@ SELECT
Select(COND, TRUEVAL, FALSEVAL).
Definition: ISDOpcodes.h:757
@ UNDEF
UNDEF - An undefined node.
Definition: ISDOpcodes.h:218
@ VACOPY
VACOPY - VACOPY has 5 operands: an input chain, a destination pointer, a source pointer, a SRCVALUE for the destination, and a SRCVALUE for the source.
Definition: ISDOpcodes.h:1222
@ CopyFromReg
CopyFromReg - This node indicates that the input value is a virtual or physical register that is defined outside of the scope of this SelectionDAG.
Definition: ISDOpcodes.h:215
@ SADDO
RESULT, BOOL = [SU]ADDO(LHS, RHS) - Overflow-aware nodes for addition.
Definition: ISDOpcodes.h:330
@ MULHU
MULHU/MULHS - Multiply high - Multiply two integers of type iN, producing an unsigned/signed value of type iN containing the high bits of the result.
Definition: ISDOpcodes.h:674
@ SHL
Shift and rotation operations.
Definition: ISDOpcodes.h:735
@ VECTOR_SHUFFLE
VECTOR_SHUFFLE(VEC1, VEC2) - Returns a vector, of the same type as VEC1/VEC2.
Definition: ISDOpcodes.h:615
@ EXTRACT_SUBVECTOR
EXTRACT_SUBVECTOR(VECTOR, IDX) - Returns a subvector from VECTOR.
Definition: ISDOpcodes.h:588
@ FMINNUM_IEEE
FMINNUM_IEEE/FMAXNUM_IEEE - Perform floating-point minimumNumber or maximumNumber on two values, following the IEEE-754 2008 definition.
Definition: ISDOpcodes.h:1044
@ EXTRACT_VECTOR_ELT
EXTRACT_VECTOR_ELT(VECTOR, IDX) - Returns a single element from VECTOR identified by the (potentially variable) element number IDX.
Definition: ISDOpcodes.h:550
@ CopyToReg
CopyToReg - This node has three operands: a chain, a register number to set to this value, and a value.
Definition: ISDOpcodes.h:209
@ ZERO_EXTEND
ZERO_EXTEND - Used for integer types, zeroing the new bits.
Definition: ISDOpcodes.h:811
@ DEBUGTRAP
DEBUGTRAP - Trap intended to get the attention of a debugger.
Definition: ISDOpcodes.h:1282
@ SELECT_CC
Select with condition operator - This selects between a true value and a false value (ops #2 and #3) based on a comparison of the two input operands (ops #0 and #1).
Definition: ISDOpcodes.h:772
@ FMINNUM
FMINNUM/FMAXNUM - Perform floating-point minimum or maximum on two values.
Definition: ISDOpcodes.h:1031
@ SSHLSAT
RESULT = [US]SHLSAT(LHS, RHS) - Perform saturation left shift.
Definition: ISDOpcodes.h:366
@ SMULO
Same for multiplication.
Definition: ISDOpcodes.h:338
@ DYNAMIC_STACKALLOC
DYNAMIC_STACKALLOC - Allocate some number of bytes on the stack aligned to a specified boundary.
Definition: ISDOpcodes.h:1112
@ SIGN_EXTEND_INREG
SIGN_EXTEND_INREG - This operator atomically performs a SHL/SRA pair to sign extend a small value in a large integer register (e.g. sign extending the low 8 bits of a 32-bit register).
Definition: ISDOpcodes.h:849
@ SMIN
[US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned integers.
Definition: ISDOpcodes.h:697
@ FP_EXTEND
X = FP_EXTEND(Y) - Extend a smaller FP type into a larger FP type.
Definition: ISDOpcodes.h:939
@ VSELECT
Select with a vector condition (op #0) and two vector operands (ops #1 and #2), returning a vector result.
Definition: ISDOpcodes.h:766
@ UADDO_CARRY
Carry-using nodes for multiple precision addition and subtraction.
Definition: ISDOpcodes.h:310
@ BF16_TO_FP
BF16_TO_FP, FP_TO_BF16 - These operators are used to perform promotions and truncation for bfloat16.
Definition: ISDOpcodes.h:973
@ FRAMEADDR
FRAMEADDR, RETURNADDR - These nodes represent llvm.frameaddress and llvm.returnaddress on the DAG.
Definition: ISDOpcodes.h:100
@ FMINIMUM
FMINIMUM/FMAXIMUM - NaN-propagating minimum/maximum that also treat -0.0 as less than 0.0.
Definition: ISDOpcodes.h:1050
@ FP_TO_SINT
FP_TO_[US]INT - Convert a floating point value to a signed or unsigned integer.
Definition: ISDOpcodes.h:887
@ READCYCLECOUNTER
READCYCLECOUNTER - This corresponds to the readcyclecounter intrinsic.
Definition: ISDOpcodes.h:1253
@ AND
Bitwise operators - logical and, logical or, logical xor.
Definition: ISDOpcodes.h:709
@ TRAP
TRAP - Trapping instruction.
Definition: ISDOpcodes.h:1279
@ INTRINSIC_WO_CHAIN
RESULT = INTRINSIC_WO_CHAIN(INTRINSICID, arg1, arg2, ...) This node represents a target intrinsic function with no side effects.
Definition: ISDOpcodes.h:190
@ ADDE
Carry-using nodes for multiple precision addition and subtraction.
Definition: ISDOpcodes.h:286
@ FREEZE
FREEZE - FREEZE(VAL) returns an arbitrary value if VAL is UNDEF (or is evaluated to UNDEF), otherwise returns VAL.
Definition: ISDOpcodes.h:223
@ INSERT_VECTOR_ELT
INSERT_VECTOR_ELT(VECTOR, VAL, IDX) - Returns VECTOR with the element at IDX replaced with VAL.
Definition: ISDOpcodes.h:539
@ TokenFactor
TokenFactor - This node takes multiple tokens as input and produces a single token result.
Definition: ISDOpcodes.h:52
@ FP_ROUND
X = FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating point type down to the precision of the destination VT.
Definition: ISDOpcodes.h:920
@ TRUNCATE
TRUNCATE - Completely drop the high bits.
Definition: ISDOpcodes.h:817
@ VAARG
VAARG - VAARG has four operands: an input chain, a pointer, a SRCVALUE, and the alignment.
Definition: ISDOpcodes.h:1217
@ SHL_PARTS
SHL_PARTS/SRA_PARTS/SRL_PARTS - These operators are used for expanded integer shift operations.
Definition: ISDOpcodes.h:794
@ FCOPYSIGN
FCOPYSIGN(X, Y) - Return the value of X with the sign of Y.
Definition: ISDOpcodes.h:508
@ SADDSAT
RESULT = [US]ADDSAT(LHS, RHS) - Perform saturation addition on 2 integers with the same bit width (W).
Definition: ISDOpcodes.h:347
@ SADDO_CARRY
Carry-using overflow-aware nodes for multiple precision addition and subtraction.
Definition: ISDOpcodes.h:320
@ INTRINSIC_W_CHAIN
RESULT,OUTCHAIN = INTRINSIC_W_CHAIN(INCHAIN, INTRINSICID, arg1, ...) This node represents a target intrinsic function with side effects that returns a result.
Definition: ISDOpcodes.h:198
@ BUILD_VECTOR
BUILD_VECTOR(ELT0, ELT1, ELT2, ELT3,...) - Return a fixed-width vector with the specified, possibly variable elements.
Definition: ISDOpcodes.h:530
bool allOperandsUndef(const SDNode *N)
Return true if the node has at least one operand and all operands of the specified node are ISD::UNDEF.
@ Bitcast
Perform the operation on a different, but equivalently sized type.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
static bool isIndirectCall(const MachineInstr &MI)
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1697
bool Isv2x16VT(EVT VT)
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B, C, ...) where A is the zero-based index of the item and B, C, ... are the values from the input ranges.
Definition: STLExtras.h:2448
MaybeAlign getAlign(const Function &F, unsigned Index)
uint64_t PowerOf2Ceil(uint64_t A)
Returns the power of two which is greater than or equal to the given value.
Definition: MathExtras.h:396
OutputIt transform(R &&Range, OutputIt d_first, UnaryFunction F)
Wrapper function around std::transform to apply a function to a range and store the result elsewhere.
Definition: STLExtras.h:1952
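A small self-contained sketch (illustrative, not from this file) showing all_of, enumerate, and transform together:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

static llvm::SmallVector<int, 4> doubleIfAllPositive(llvm::ArrayRef<int> In) {
  llvm::SmallVector<int, 4> Out(In.size());
  // all_of: range-based wrapper around std::all_of.
  if (llvm::all_of(In, [](int V) { return V > 0; }))
    // transform: apply a function to each element, writing into Out.
    llvm::transform(In, Out.begin(), [](int V) { return 2 * V; });
  // enumerate: visit each element together with its zero-based index.
  for (auto [Idx, V] : llvm::enumerate(Out)) {
    (void)Idx; // zero-based index
    (void)V;   // element value
  }
  return Out;
}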
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition: MathExtras.h:293
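For concreteness, assumed-typical uses of isPowerOf2_32 and the PowerOf2Ceil helper documented above:

#include "llvm/Support/MathExtras.h"
#include <cassert>

void mathExtrasExamples() {
  assert(llvm::isPowerOf2_32(8));       // 8 is a power of two > 0
  assert(!llvm::isPowerOf2_32(12));     // 12 is not
  assert(llvm::PowerOf2Ceil(12) == 16); // round up to the next power of two
  assert(llvm::PowerOf2Ceil(16) == 16); // already a power of two
}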
unsigned promoteScalarArgumentSize(unsigned size)
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:167
CodeGenOptLevel
Code generation optimization level.
Definition: CodeGen.h:54
@ Mul
Product of integers.
@ Add
Sum of integers.
uint64_t alignTo(uint64_t Size, Align A)
Returns a multiple of A needed to store Size bytes.
Definition: Alignment.h:155
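A worked example of the rounding behavior (illustrative only):

#include "llvm/Support/Alignment.h"
#include <cassert>

void alignToExample() {
  // 10 bytes rounded up to a multiple of a 4-byte alignment is 12.
  assert(llvm::alignTo(10, llvm::Align(4)) == 12);
  // Sizes that are already aligned are returned unchanged.
  assert(llvm::alignTo(12, llvm::Align(4)) == 12);
}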
void ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl< EVT > &ValueVTs, SmallVectorImpl< EVT > *MemVTs, SmallVectorImpl< TypeSize > *Offsets=nullptr, TypeSize StartingOffset=TypeSize::getZero())
ComputeValueVTs - Given an LLVM IR type, compute a sequence of EVTs that represent all the individual underlying non-aggregate types that comprise it.
Definition: Analysis.cpp:79
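A hedged sketch of the usual calling pattern, matching the signature above; TLI, DL, and Ty are assumed to be in scope (a TargetLowering, the module DataLayout, and an IR type):

// Illustrative only: decompose an IR type into its leaf EVTs.
SmallVector<EVT, 8> ValueVTs;
SmallVector<TypeSize, 8> Offsets; // byte offset of each leaf within Ty
ComputeValueVTs(TLI, DL, Ty, ValueVTs, /*MemVTs=*/nullptr, &Offsets);
for (unsigned I = 0, E = ValueVTs.size(); I != E; ++I) {
  EVT VT = ValueVTs[I];      // one non-aggregate piece of Ty
  TypeSize Off = Offsets[I]; // its byte offset within Ty
  (void)VT;
  (void)Off;
}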
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:217
bool isKernelFunction(const Function &F)
Function * getMaybeBitcastedCallee(const CallBase *CB)
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:212
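Two illustrative data points for the semantics described above:

#include "llvm/Support/Alignment.h"
#include <cassert>

void commonAlignmentExample() {
  // A 16-byte-aligned base plus a byte offset of 4 is only 4-byte aligned.
  assert(llvm::commonAlignment(llvm::Align(16), 4) == llvm::Align(4));
  // Offsets that are multiples of the alignment preserve it.
  assert(llvm::commonAlignment(llvm::Align(16), 32) == llvm::Align(16));
}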
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
static const fltSemantics & IEEEsingle() LLVM_READNONE
Definition: APFloat.cpp:257
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
uint64_t value() const
This is a hole in the type system and should not be abused.
Definition: Alignment.h:85
@ PreserveSign
The sign of a flushed-to-zero number is preserved in the sign of 0.
DenormalModeKind Output
Denormal flushing mode for floating point instruction results in the default floating point environment.
Extended Value Type.
Definition: ValueTypes.h:35
TypeSize getStoreSize() const
Return the number of bytes overwritten by a store of the specified value type.
Definition: ValueTypes.h:390
bool isSimple() const
Test if the given EVT is simple (as opposed to being extended).
Definition: ValueTypes.h:137
static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements, bool IsScalable=false)
Returns the EVT that represents a vector NumElements in length, where each element is of type VT.
Definition: ValueTypes.h:74
EVT changeTypeToInteger() const
Return the type converted to an equivalently sized integer or vector with integer element type.
Definition: ValueTypes.h:121
bool isFloatingPoint() const
Return true if this is a FP or a vector FP type.
Definition: ValueTypes.h:147
ElementCount getVectorElementCount() const
Definition: ValueTypes.h:345
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition: ValueTypes.h:368
uint64_t getScalarSizeInBits() const
Definition: ValueTypes.h:380
MVT getSimpleVT() const
Return the SimpleValueType held in the specified simple EVT.
Definition: ValueTypes.h:311
uint64_t getFixedSizeInBits() const
Return the size of the specified fixed width value type in bits.
Definition: ValueTypes.h:376
bool isVector() const
Return true if this is a vector value type.
Definition: ValueTypes.h:168
EVT getScalarType() const
If this is a vector type, return the element type, otherwise return this.
Definition: ValueTypes.h:318
bool bitsEq(EVT VT) const
Return true if this has the same number of bits as VT.
Definition: ValueTypes.h:251
Type * getTypeForEVT(LLVMContext &Context) const
This method returns an LLVM type corresponding to the specified EVT.
Definition: ValueTypes.cpp:210
EVT getVectorElementType() const
Given a vector type, return the type of each element.
Definition: ValueTypes.h:323
bool isScalarInteger() const
Return true if this is an integer, but not a vector.
Definition: ValueTypes.h:157
EVT changeVectorElementType(EVT EltVT) const
Return a VT for a vector type whose attributes match ourselves with the exception of the element type that is chosen by the caller.
Definition: ValueTypes.h:102
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
Definition: ValueTypes.h:331
bool isInteger() const
Return true if this is an integer or a vector integer type.
Definition: ValueTypes.h:152
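To tie these EVT queries together, a minimal sketch (an assumed example, not code from this file):

#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/LLVMContext.h"
#include <cassert>

void evtExample(llvm::LLVMContext &Ctx) {
  using namespace llvm;
  EVT VT = EVT::getVectorVT(Ctx, MVT::f32, 4); // v4f32
  assert(VT.isVector() && VT.isFloatingPoint());
  assert(VT.getVectorNumElements() == 4);
  assert(VT.getVectorElementType() == MVT::f32);
  assert(VT.getSizeInBits() == 128);
  EVT IntVT = VT.changeTypeToInteger(); // equivalently sized v4i32
  assert(IntVT.getVectorElementType() == MVT::i32);
}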
This class contains a discriminated union of information about pointers in memory operands, relating them back to LLVM IR or to virtual locations (such as frame indices) that are exposed during codegen.
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
This represents a list of ValueTypes that has been intern'd by a SelectionDAG.
This represents an addressing mode of: BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + ScalableOffset*vscale.
This structure contains all information that is necessary for lowering calls.
SmallVector< ISD::InputArg, 32 > Ins
SmallVector< ISD::OutputArg, 32 > Outs
SmallVector< SDValue, 32 > OutVals
SDValue CombineTo(SDNode *N, ArrayRef< SDValue > To, bool AddTo=true)