1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
16#include "NVPTX.h"
17#include "NVPTXSubtarget.h"
18#include "NVPTXTargetMachine.h"
20#include "NVPTXUtilities.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/StringRef.h"
36#include "llvm/IR/Argument.h"
37#include "llvm/IR/Attributes.h"
38#include "llvm/IR/Constants.h"
39#include "llvm/IR/DataLayout.h"
42#include "llvm/IR/FPEnv.h"
43#include "llvm/IR/Function.h"
44#include "llvm/IR/GlobalValue.h"
45#include "llvm/IR/Instruction.h"
47#include "llvm/IR/IntrinsicsNVPTX.h"
48#include "llvm/IR/Module.h"
49#include "llvm/IR/Type.h"
50#include "llvm/IR/Value.h"
60#include <algorithm>
61#include <cassert>
62#include <cmath>
63#include <cstdint>
64#include <iterator>
65#include <optional>
66#include <string>
67#include <utility>
68#include <vector>
69
70#define DEBUG_TYPE "nvptx-lower"
71
72using namespace llvm;
73
74static std::atomic<unsigned> GlobalUniqueCallSite;
75
77 "nvptx-sched4reg",
78 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
79
81 "nvptx-fma-level", cl::Hidden,
82 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
83 " 1: do it 2: do it aggressively"),
84 cl::init(2));
85
87 "nvptx-prec-divf32", cl::Hidden,
88 cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
89 " IEEE Compliant F32 div.rnd if available."),
90 cl::init(2));
91
93 "nvptx-prec-sqrtf32", cl::Hidden,
94 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
95 cl::init(true));
96
98 "nvptx-force-min-byval-param-align", cl::Hidden,
99 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
100 " params of device functions."),
101 cl::init(false));
102
int NVPTXTargetLowering::getDivF32Level() const {
  if (UsePrecDivF32.getNumOccurrences() > 0) {
    // If nvptx-prec-divf32=N is used on the command-line, always honor it
    return UsePrecDivF32;
  } else {
    // Otherwise, use div.approx if fast math is enabled
    if (getTargetMachine().Options.UnsafeFPMath)
      return 0;
    else
      return 2;
  }
}
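
// Editorial sketch, not part of the original LLVM source: the level returned by
// getDivF32Level() selects the PTX instruction used for f32 division. The helper
// below only illustrates that mapping (taken from the option description above);
// the name divF32LevelToPTX is hypothetical.
static const char *divF32LevelToPTX(int Level) {
  switch (Level) {
  case 0:
    return "div.approx.f32"; // fastest, approximate
  case 1:
    return "div.full.f32";   // full-range approximation
  default:
    return "div.rn.f32";     // IEEE-compliant rounded division
  }
}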
115
bool NVPTXTargetLowering::usePrecSqrtF32() const {
  if (UsePrecSqrtF32.getNumOccurrences() > 0) {
    // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
    return UsePrecSqrtF32;
  } else {
    // Otherwise, use sqrt.approx if fast math is enabled
    return !getTargetMachine().Options.UnsafeFPMath;
  }
}
125
bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
  return MF.getDenormalMode(APFloat::IEEEsingle()).Output ==
         DenormalMode::PreserveSign;
}
130
131static bool IsPTXVectorType(MVT VT) {
132 switch (VT.SimpleTy) {
133 default:
134 return false;
135 case MVT::v2i1:
136 case MVT::v4i1:
137 case MVT::v2i8:
138 case MVT::v4i8:
139 case MVT::v8i8: // <2 x i8x4>
140 case MVT::v16i8: // <4 x i8x4>
141 case MVT::v2i16:
142 case MVT::v4i16:
143 case MVT::v8i16: // <4 x i16x2>
144 case MVT::v2i32:
145 case MVT::v4i32:
146 case MVT::v2i64:
147 case MVT::v2f16:
148 case MVT::v4f16:
149 case MVT::v8f16: // <4 x f16x2>
150 case MVT::v2bf16:
151 case MVT::v4bf16:
152 case MVT::v8bf16: // <4 x bf16x2>
153 case MVT::v2f32:
154 case MVT::v4f32:
155 case MVT::v2f64:
156 return true;
157 }
158}
159
160static bool Is16bitsType(MVT VT) {
161 return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
162 VT.SimpleTy == MVT::i16);
163}
164
// When legalizing vector loads/stores, this function is called, which does two
// things:
// 1. Determines whether the vector is something we want to custom lower;
//    std::nullopt is returned if we do not want to custom lower it.
// 2. If we do want to handle it, returns two parameters:
//    - unsigned int NumElts - The number of elements in the final vector
//    - EVT EltVT - The type of the elements in the final vector
static std::optional<std::pair<unsigned int, EVT>>
getVectorLoweringShape(EVT VectorVT) {
  if (!VectorVT.isVector() || !VectorVT.isSimple())
175 return std::nullopt;
176
177 EVT EltVT = VectorVT.getVectorElementType();
178 unsigned NumElts = VectorVT.getVectorNumElements();
179
180 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
181 // legal. We can (and should) split that into 2 stores of <2 x double> here
182 // but I'm leaving that as a TODO for now.
183 switch (VectorVT.getSimpleVT().SimpleTy) {
184 default:
185 return std::nullopt;
186 case MVT::v2i8:
187 case MVT::v2i16:
188 case MVT::v2i32:
189 case MVT::v2i64:
190 case MVT::v2f16:
191 case MVT::v2bf16:
192 case MVT::v2f32:
193 case MVT::v2f64:
194 case MVT::v4i8:
195 case MVT::v4i16:
196 case MVT::v4i32:
197 case MVT::v4f16:
198 case MVT::v4bf16:
199 case MVT::v4f32:
200 // This is a "native" vector type
201 return std::pair(NumElts, EltVT);
202 case MVT::v8i8: // <2 x i8x4>
203 case MVT::v8f16: // <4 x f16x2>
204 case MVT::v8bf16: // <4 x bf16x2>
205 case MVT::v8i16: // <4 x i16x2>
206 case MVT::v16i8: // <4 x i8x4>
207 // This can be upsized into a "native" vector type.
208 // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
209 // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
210 // vectorized loads/stores with the actual element type for i8/i16 as that
211 // would require v8/v16 variants that do not exist.
212 // In order to load/store such vectors efficiently, here in Type
213 // Legalization, we split the vector into word-sized chunks (v2x16/v4i8).
214 // Later, we will lower to PTX as vectors of b32.
215
216 // Number of elements to pack in one word.
217 unsigned NPerWord = 32 / EltVT.getSizeInBits();
218
219 return std::pair(NumElts / NPerWord,
220 MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord));
221 }
222
223 llvm_unreachable("All cases in switch should return.");
224}
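
// Editorial sketch, not part of the original LLVM source: expected results of the
// shape helper above for a few representative types. The calls below assume the
// reconstructed name getVectorLoweringShape; the example function itself is
// purely illustrative.
static void exampleVectorLoweringShapes() {
  // A native PTX vector type is left alone: 4 elements of f32.
  auto Native = getVectorLoweringShape(EVT(MVT::v4f32));
  assert(Native && Native->first == 4 && Native->second == MVT::f32);
  // v8f16 exceeds PTX's v2/v4 syntax, so it is repacked as 4 words of v2f16,
  // each later emitted as a single b32 element.
  auto Packed = getVectorLoweringShape(EVT(MVT::v8f16));
  assert(Packed && Packed->first == 4 && Packed->second == MVT::v2f16);
  // <4 x double> is not handled here and falls back to default legalization.
  assert(!getVectorLoweringShape(EVT(MVT::v4f64)));
}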
225
226/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
227/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
228/// into their primitive components.
229/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
230/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
231/// LowerCall, and LowerReturn.
232static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
233 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
234 SmallVectorImpl<uint64_t> *Offsets = nullptr,
235 uint64_t StartingOffset = 0) {
236 SmallVector<EVT, 16> TempVTs;
237 SmallVector<uint64_t, 16> TempOffsets;
238
239 // Special case for i128 - decompose to (i64, i64)
240 if (Ty->isIntegerTy(128)) {
241 ValueVTs.push_back(EVT(MVT::i64));
242 ValueVTs.push_back(EVT(MVT::i64));
243
244 if (Offsets) {
245 Offsets->push_back(StartingOffset + 0);
246 Offsets->push_back(StartingOffset + 8);
247 }
248
249 return;
250 }
251
252 // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
253 if (StructType *STy = dyn_cast<StructType>(Ty)) {
254 auto const *SL = DL.getStructLayout(STy);
255 auto ElementNum = 0;
256 for(auto *EI : STy->elements()) {
257 ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
258 StartingOffset + SL->getElementOffset(ElementNum));
259 ++ElementNum;
260 }
261 return;
262 }
263
264 // Given an array type, recursively traverse the elements with custom ComputePTXValueVTs.
265 if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
266 Type *EltTy = ATy->getElementType();
267 uint64_t EltSize = DL.getTypeAllocSize(EltTy);
268 for (int I : llvm::seq<int>(ATy->getNumElements()))
269 ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, StartingOffset + I * EltSize);
270 return;
271 }
272
273 ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
274 for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
275 EVT VT = TempVTs[i];
276 uint64_t Off = TempOffsets[i];
277 // Split vectors into individual elements, except for v2f16, which
278 // we will pass as a single scalar.
279 if (VT.isVector()) {
280 unsigned NumElts = VT.getVectorNumElements();
281 EVT EltVT = VT.getVectorElementType();
      // We require power-of-2 sized vectors because
      // TargetLoweringBase::getVectorTypeBreakdown(), which is invoked in
      // ComputePTXValueVTs(), cannot currently break down non-power-of-2 sized
      // vectors.
286 if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
287 isPowerOf2_32(NumElts)) {
288 // Vectors with an even number of f16 elements will be passed to
289 // us as an array of v2f16/v2bf16 elements. We must match this so we
290 // stay in sync with Ins/Outs.
291 switch (EltVT.getSimpleVT().SimpleTy) {
292 case MVT::f16:
293 EltVT = MVT::v2f16;
294 break;
295 case MVT::bf16:
296 EltVT = MVT::v2bf16;
297 break;
298 case MVT::i16:
299 EltVT = MVT::v2i16;
300 break;
301 default:
302 llvm_unreachable("Unexpected type");
303 }
304 NumElts /= 2;
305 } else if (EltVT.getSimpleVT() == MVT::i8 &&
306 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
307 NumElts == 3)) {
308 // v*i8 are formally lowered as v4i8
309 EltVT = MVT::v4i8;
310 NumElts = (NumElts + 3) / 4;
311 } else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
312 // v2i8 is promoted to v2i16
313 NumElts = 1;
314 EltVT = MVT::v2i16;
315 }
316 for (unsigned j = 0; j != NumElts; ++j) {
317 ValueVTs.push_back(EltVT);
318 if (Offsets)
319 Offsets->push_back(Off + j * EltVT.getStoreSize());
320 }
321 } else {
322 ValueVTs.push_back(VT);
323 if (Offsets)
324 Offsets->push_back(Off);
325 }
326 }
327}
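
// Editorial sketch, not part of the original LLVM source: a concrete flattening
// produced by the helper above. The example function and its arguments are
// illustrative only.
static void exampleComputePTXValueVTs(const TargetLowering &TLI,
                                      const DataLayout &DL, LLVMContext &Ctx) {
  SmallVector<EVT, 8> VTs;
  SmallVector<uint64_t, 8> Offsets;
  // <4 x half> is not split into four f16 scalars; the halves are kept in
  // pairs, yielding two v2f16 pieces at byte offsets 0 and 4.
  ComputePTXValueVTs(TLI, DL, FixedVectorType::get(Type::getHalfTy(Ctx), 4),
                     VTs, &Offsets);
  assert(VTs.size() == 2 && VTs[0] == MVT::v2f16 && Offsets[1] == 4);
}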
328
329/// PromoteScalarIntegerPTX
330/// Used to make sure the arguments/returns are suitable for passing
331/// and promote them to a larger size if they're not.
332///
/// The promoted type is placed in \p PromotedVT if the function returns true.
334static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
335 if (VT.isScalarInteger()) {
336 switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
  default:
    llvm_unreachable(
        "Promotion is not suitable for scalars of size larger than 64-bits");
340 case 1:
341 *PromotedVT = MVT::i1;
342 break;
343 case 2:
344 case 4:
345 case 8:
346 *PromotedVT = MVT::i8;
347 break;
348 case 16:
349 *PromotedVT = MVT::i16;
350 break;
351 case 32:
352 *PromotedVT = MVT::i32;
353 break;
354 case 64:
355 *PromotedVT = MVT::i64;
356 break;
357 }
358 return EVT(*PromotedVT) != VT;
359 }
360 return false;
361}
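
// Editorial sketch, not part of the original LLVM source: expected behaviour of
// the promotion helper above; the example function and the 24-bit integer type
// are illustrative only.
static void examplePromoteScalarIntegerPTX(LLVMContext &Ctx) {
  MVT Promoted;
  // A 24-bit integer does not match any PTX register width, so it is widened to
  // the next power of two, i32, and the helper reports that a promotion happened.
  assert(PromoteScalarIntegerPTX(EVT::getIntegerVT(Ctx, 24), &Promoted) &&
         Promoted == MVT::i32);
  // i32 already matches a register width, so no promotion is reported.
  assert(!PromoteScalarIntegerPTX(EVT(MVT::i32), &Promoted));
}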
362
363// Check whether we can merge loads/stores of some of the pieces of a
364// flattened function parameter or return value into a single vector
365// load/store.
366//
367// The flattened parameter is represented as a list of EVTs and
368// offsets, and the whole structure is aligned to ParamAlignment. This
369// function determines whether we can load/store pieces of the
370// parameter starting at index Idx using a single vectorized op of
371// size AccessSize. If so, it returns the number of param pieces
372// covered by the vector op. Otherwise, it returns 1.
static unsigned CanMergeParamLoadStoresStartingAt(
    unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
    const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
376
377 // Can't vectorize if param alignment is not sufficient.
378 if (ParamAlignment < AccessSize)
379 return 1;
380 // Can't vectorize if offset is not aligned.
381 if (Offsets[Idx] & (AccessSize - 1))
382 return 1;
383
384 EVT EltVT = ValueVTs[Idx];
385 unsigned EltSize = EltVT.getStoreSize();
386
387 // Element is too large to vectorize.
388 if (EltSize >= AccessSize)
389 return 1;
390
391 unsigned NumElts = AccessSize / EltSize;
  // Can't vectorize if AccessSize is not a multiple of EltSize.
393 if (AccessSize != EltSize * NumElts)
394 return 1;
395
396 // We don't have enough elements to vectorize.
397 if (Idx + NumElts > ValueVTs.size())
398 return 1;
399
400 // PTX ISA can only deal with 2- and 4-element vector ops.
401 if (NumElts != 4 && NumElts != 2)
402 return 1;
403
404 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
405 // Types do not match.
406 if (ValueVTs[j] != EltVT)
407 return 1;
408
409 // Elements are not contiguous.
410 if (Offsets[j] - Offsets[j - 1] != EltSize)
411 return 1;
412 }
413 // OK. We can vectorize ValueVTs[i..i+NumElts)
414 return NumElts;
415}
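
// Editorial note, not part of the original LLVM source: for example, with four
// f32 pieces at offsets {0, 4, 8, 12} and 16-byte parameter alignment, calling
// the helper above at Idx = 0 with AccessSize = 16 returns 4 (one 128-bit access
// covers all four pieces); with AccessSize = 8 it returns 2. If the offsets were
// {0, 4, 12, 16}, the gap breaks contiguity and the helper returns 1.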
416
417// Flags for tracking per-element vectorization state of loads/stores
418// of a flattened function parameter or return value.
enum ParamVectorizationFlags {
  PVF_INNER = 0x0, // Middle elements of a vector.
  PVF_FIRST = 0x1, // First element of the vector.
  PVF_LAST = 0x2,  // Last element of the vector.
  // Scalar is effectively a 1-element vector.
  PVF_SCALAR = PVF_FIRST | PVF_LAST,
};

427// Computes whether and how we can vectorize the loads/stores of a
428// flattened function parameter or return value.
429//
430// The flattened parameter is represented as the list of ValueVTs and
431// Offsets, and is aligned to ParamAlignment bytes. We return a vector
432// of the same size as ValueVTs indicating how each piece should be
433// loaded/stored (i.e. as a scalar, or as part of a vector
434// load/store).
static SmallVector<ParamVectorizationFlags, 16>
VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
                     const SmallVectorImpl<uint64_t> &Offsets,
                     Align ParamAlignment, bool IsVAArg = false) {
  // Set vector size to match ValueVTs and mark all elements as
  // scalars by default.
  SmallVector<ParamVectorizationFlags, 16> VectorInfo;
  VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
443
444 if (IsVAArg)
445 return VectorInfo;
446
447 // Check what we can vectorize using 128/64/32-bit accesses.
448 for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
449 // Skip elements we've already processed.
450 assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
451 for (unsigned AccessSize : {16, 8, 4, 2}) {
452 unsigned NumElts = CanMergeParamLoadStoresStartingAt(
453 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
454 // Mark vectorized elements.
455 switch (NumElts) {
456 default:
457 llvm_unreachable("Unexpected return value");
458 case 1:
459 // Can't vectorize using this size, try next smaller size.
460 continue;
461 case 2:
462 assert(I + 1 < E && "Not enough elements.");
463 VectorInfo[I] = PVF_FIRST;
464 VectorInfo[I + 1] = PVF_LAST;
465 I += 1;
466 break;
467 case 4:
468 assert(I + 3 < E && "Not enough elements.");
469 VectorInfo[I] = PVF_FIRST;
470 VectorInfo[I + 1] = PVF_INNER;
471 VectorInfo[I + 2] = PVF_INNER;
472 VectorInfo[I + 3] = PVF_LAST;
473 I += 3;
474 break;
475 }
476 // Break out of the inner loop because we've already succeeded
477 // using largest possible AccessSize.
478 break;
479 }
480 }
481 return VectorInfo;
482}
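
// Editorial sketch, not part of the original LLVM source: how the flags above
// come out for a parameter flattened into four f32 pieces. The example function
// is illustrative only and assumes the reconstructed declarations above.
static void exampleVectorizePTXValueVTs() {
  SmallVector<EVT, 4> VTs(4, EVT(MVT::f32));
  SmallVector<uint64_t, 4> Offsets = {0, 4, 8, 12};
  // With 16-byte parameter alignment the four pieces merge into a single
  // 128-bit access: FIRST, INNER, INNER, LAST.
  auto Wide = VectorizePTXValueVTs(VTs, Offsets, Align(16));
  assert(Wide[0] == PVF_FIRST && Wide[1] == PVF_INNER && Wide[3] == PVF_LAST);
  // With only 4-byte alignment nothing can be merged; every piece stays a
  // scalar access.
  auto Narrow = VectorizePTXValueVTs(VTs, Offsets, Align(4));
  assert(Narrow[0] == PVF_SCALAR && Narrow[3] == PVF_SCALAR);
}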
483
485 SDValue Value) {
486 if (Value->getValueType(0) == VT)
487 return Value;
488 return DAG.getNode(ISD::BITCAST, DL, VT, Value);
489}
490
// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                                         const NVPTXSubtarget &STI)
    : TargetLowering(TM), nvTM(&TM), STI(STI) {
  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy, or memmove.
  MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
  MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned)0xFFFFFFFF;
  MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned)0xFFFFFFFF;
504
505 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
506 // condition branches.
507 setJumpIsExpensive(true);
508
509 // Wide divides are _very_ slow. Try to reduce the width of the divide if
510 // possible.
511 addBypassSlowDiv(64, 32);
512
  // By default, use the Source scheduling
  if (sched4reg)
    setSchedulingPreference(Sched::RegPressure);
  else
    setSchedulingPreference(Sched::Source);

519 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
520 LegalizeAction NoF16Action) {
521 bool IsOpSupported = STI.allowFP16Math();
522 switch (Op) {
523 // Several FP16 instructions are available on sm_80 only.
524 case ISD::FMINNUM:
525 case ISD::FMAXNUM:
528 case ISD::FMAXIMUM:
529 case ISD::FMINIMUM:
530 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
531 break;
532 }
533 setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
534 };
535
536 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
537 LegalizeAction NoBF16Action) {
538 bool IsOpSupported = STI.hasNativeBF16Support(Op);
    setOperationAction(Op, VT, IsOpSupported ? Action : NoBF16Action);
  };
542
543 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
544 LegalizeAction NoI16x2Action) {
545 bool IsOpSupported = false;
546 // instructions are available on sm_90 only
547 switch (Op) {
548 case ISD::ADD:
549 case ISD::SMAX:
550 case ISD::SMIN:
551 case ISD::UMIN:
552 case ISD::UMAX:
553 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
554 break;
555 }
556 setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
557 };
558
559 addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
560 addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
561 addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
562 addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
563 addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
564 addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
565 addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
566 addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
567 addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
568 addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
569 addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
570 addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
571
572 // Conversion to/from FP16/FP16x2 is always legal.
577
579 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
581
582 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
583 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
584
  // Conversion to/from BF16/BF16x2 is always legal.
590
591 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
592 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
593 if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
594 AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);
595
596 // Conversion to/from i16/i16x2 is always legal.
601
606
607 // Custom conversions to/from v2i8.
609
610 // Only logical ops can be done on v4i8 directly, others must be done
611 // elementwise.
628 MVT::v4i8, Expand);
629
630 // Operations not directly supported by NVPTX.
631 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
632 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
633 MVT::i32, MVT::i64}) {
636 }
637
638 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
639 // For others we will expand to a SHL/SRA pair.
646
653
656
658 {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
659 Expand);
660
661 if (STI.hasHWROT32())
663
665
668
671
  // We want to legalize constant-related memmove and memcpy
  // intrinsics.
675
676 // Turn FP extload into load/fpextend
677 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
678 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
679 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
680 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
681 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
682 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
683 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
684 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
685 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
686 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
687 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
688 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
689 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
690 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
691 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
692 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
693 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
694 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
695 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
696 // Turn FP truncstore into trunc + store.
697 // FIXME: vector types should also be expanded
698 setTruncStoreAction(MVT::f32, MVT::f16, Expand);
699 setTruncStoreAction(MVT::f64, MVT::f16, Expand);
700 setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
701 setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
702 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
703
704 // PTX does not support load / store predicate registers
707
708 for (MVT VT : MVT::integer_valuetypes()) {
712 setTruncStoreAction(VT, MVT::i1, Expand);
713 }
714
718 MVT::i1, Expand);
719
720 // expand extload of vector of integers.
722 MVT::v2i8, Expand);
723 setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
724
725 // This is legal in NVPTX
730
731 setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
733
734 // TRAP can be lowered to PTX trap
735 setOperationAction(ISD::TRAP, MVT::Other, Legal);
736 // DEBUGTRAP can be lowered to PTX brkpt
738
739 // Register custom handling for vector loads/stores
741 if (IsPTXVectorType(VT)) {
745 }
746 }
747
748 // Support varargs.
753
754 // Custom handling for i8 intrinsics
756
757 for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) {
763
766 }
767
768 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
769 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
770 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
771 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
772 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
773 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
774 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
775
776 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
777 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
778 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
779 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
780 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
781 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
782
783 // Other arithmetic and logic ops are unsupported.
787 MVT::v2i16, Expand);
788
793 if (STI.getPTXVersion() >= 43) {
798 }
799
801 setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
804
805 // PTX does not directly support SELP of i1, so promote to i32 first
807
808 // PTX cannot multiply two i64s in a single instruction.
811
812 // We have some custom DAG combine patterns for these nodes
816
817 // setcc for f16x2 and bf16x2 needs special handling to prevent
818 // legalizer's attempt to scalarize it due to v2i1 not being legal.
819 if (STI.allowFP16Math() || STI.hasBF16Math())
821
822 // Promote fp16 arithmetic if fp16 hardware isn't available or the
823 // user passed --nvptx-no-fp16-math. The flag is useful because,
824 // although sm_53+ GPUs have some sort of FP16 support in
825 // hardware, only sm_53 and sm_60 have full implementation. Others
  // only have a token amount of hardware and are likely to run faster
827 // by using fp32 units instead.
828 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
829 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
830 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
831 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
832 // bf16 must be promoted to f32.
833 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
834 if (getOperationAction(Op, MVT::bf16) == Promote)
835 AddPromotedToType(Op, MVT::bf16, MVT::f32);
836 }
837
838 // On SM80, we select add/mul/sub as fma to avoid promotion to float
839 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
840 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
843 }
844 }
845 }
846
847 // f16/f16x2 neg was introduced in PTX 60, SM_53.
848 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
849 STI.getPTXVersion() >= 60 &&
850 STI.allowFP16Math();
851 for (const auto &VT : {MVT::f16, MVT::v2f16})
853 IsFP16FP16x2NegAvailable ? Legal : Expand);
854
855 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
856 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
857 // (would be) Library functions.
858
859 // These map to conversion instructions for scalar FP types.
860 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
862 setOperationAction(Op, MVT::f16, Legal);
863 setOperationAction(Op, MVT::f32, Legal);
864 setOperationAction(Op, MVT::f64, Legal);
865 setOperationAction(Op, MVT::v2f16, Expand);
866 setOperationAction(Op, MVT::v2bf16, Expand);
867 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
868 if (getOperationAction(Op, MVT::bf16) == Promote)
869 AddPromotedToType(Op, MVT::bf16, MVT::f32);
870 }
871
872 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
874 }
875 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
876 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
879 }
880 }
881
882 // sm_80 only has conversions between f32 and bf16. Custom lower all other
883 // bf16 conversions.
884 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
885 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
888 VT, Custom);
889 }
892 MVT::bf16, Custom);
893 }
894
901 AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);
902
903 // 'Expand' implements FCOPYSIGN without calling an external library.
910
911 // These map to corresponding instructions for f32/f64. f16 must be
912 // promoted to f32. v2f16 is expanded to f16, which is then promoted
913 // to f32.
914 for (const auto &Op :
916 setOperationAction(Op, MVT::f16, Promote);
917 setOperationAction(Op, MVT::f32, Legal);
918 setOperationAction(Op, MVT::f64, Legal);
919 setOperationAction(Op, MVT::v2f16, Expand);
920 setOperationAction(Op, MVT::v2bf16, Expand);
921 setOperationAction(Op, MVT::bf16, Promote);
922 AddPromotedToType(Op, MVT::bf16, MVT::f32);
923 }
924
925 setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
926 if (STI.getPTXVersion() >= 65) {
927 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
928 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
929 } else {
931 setOperationAction(ISD::FABS, MVT::v2f16, Expand);
932 }
933 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
934 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
935 if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
936 AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);
937
938 for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
939 setOperationAction(Op, MVT::f32, Legal);
940 setOperationAction(Op, MVT::f64, Legal);
941 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
942 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
943 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
944 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
945 if (getOperationAction(Op, MVT::bf16) == Promote)
946 AddPromotedToType(Op, MVT::bf16, MVT::f32);
947 }
948 bool SupportsF32MinMaxNaN =
949 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
950 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
951 setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand);
952 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
953 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
954 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
955 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
956 }
957
958 // Custom lowering for inline asm with 128-bit operands
961
962 // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
963 // No FPOW or FREM in PTX.
964
965 // Now deduce the information based on the above mentioned
966 // actions
968
969 setMinCmpXchgSizeInBits(STI.hasAtomCas16() ? 16 : 32);
972}
973
974const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
975
976#define MAKE_CASE(V) \
977 case V: \
978 return #V;
979
980 switch ((NVPTXISD::NodeType)Opcode) {
982 break;
983
1046 }
1047 return nullptr;
1048
1049#undef MAKE_CASE
1050}
1051
TargetLoweringBase::LegalizeTypeAction
NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
  if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
      VT.getScalarType() == MVT::i1)
    return TypeSplitVector;
  return TargetLoweringBase::getPreferredVectorAction(VT);
}
1059
SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
                                             int Enabled, int &ExtraSteps,
                                             bool &UseOneConst,
                                             bool Reciprocal) const {
  if (!(Enabled == ReciprocalEstimate::Enabled ||
        (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
    return SDValue();
1067
1068 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1069 ExtraSteps = 0;
1070
1071 SDLoc DL(Operand);
1072 EVT VT = Operand.getValueType();
1073 bool Ftz = useF32FTZ(DAG.getMachineFunction());
1074
1075 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1076 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
1077 DAG.getConstant(IID, DL, MVT::i32), Operand);
1078 };
1079
1080 // The sqrt and rsqrt refinement processes assume we always start out with an
1081 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1082 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1083 // any refinement, we must return a regular sqrt.
1084 if (Reciprocal || ExtraSteps > 0) {
1085 if (VT == MVT::f32)
1086 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1087 : Intrinsic::nvvm_rsqrt_approx_f);
1088 else if (VT == MVT::f64)
1089 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1090 else
1091 return SDValue();
1092 } else {
1093 if (VT == MVT::f32)
1094 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1095 : Intrinsic::nvvm_sqrt_approx_f);
1096 else {
1097 // There's no sqrt.approx.f64 instruction, so we emit
1098 // reciprocal(rsqrt(x)). This is faster than
1099 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1100 // x * rsqrt(x).)
1101 return DAG.getNode(
1103 DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
1104 MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1105 }
1106 }
1107}
1108
SDValue
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
  SDLoc dl(Op);
1112 const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
1113 auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
1114 Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
1115 return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
1116}
1117
1118static bool IsTypePassedAsArray(const Type *Ty) {
1119 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
1120 Ty->isHalfTy() || Ty->isBFloatTy();
1121}
1122
std::string NVPTXTargetLowering::getPrototype(
    const DataLayout &DL, Type *retTy, const ArgListTy &Args,
1125 const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
1126 std::optional<std::pair<unsigned, const APInt &>> VAInfo,
1127 const CallBase &CB, unsigned UniqueCallSite) const {
1128 auto PtrVT = getPointerTy(DL);
1129
1130 bool isABI = (STI.getSmVersion() >= 20);
1131 assert(isABI && "Non-ABI compilation is not supported");
1132 if (!isABI)
1133 return "";
1134
1135 std::string Prototype;
1136 raw_string_ostream O(Prototype);
1137 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1138
1139 if (retTy->getTypeID() == Type::VoidTyID) {
1140 O << "()";
1141 } else {
1142 O << "(";
1143 if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
1144 !IsTypePassedAsArray(retTy)) {
1145 unsigned size = 0;
1146 if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
1147 size = ITy->getBitWidth();
1148 } else {
1149 assert(retTy->isFloatingPointTy() &&
1150 "Floating point type expected here");
1151 size = retTy->getPrimitiveSizeInBits();
1152 }
1153 // PTX ABI requires all scalar return values to be at least 32
1154 // bits in size. fp16 normally uses .b16 as its storage type in
1155 // PTX, so its size must be adjusted here, too.
      size = promoteScalarArgumentSize(size);
1158 O << ".param .b" << size << " _";
1159 } else if (isa<PointerType>(retTy)) {
1160 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1161 } else if (IsTypePassedAsArray(retTy)) {
1162 O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
1163 << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
1164 } else {
1165 llvm_unreachable("Unknown return type");
1166 }
1167 O << ") ";
1168 }
1169 O << "_ (";
1170
1171 bool first = true;
1172
1173 unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
1174 for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
1175 Type *Ty = Args[i].Ty;
1176 if (!first) {
1177 O << ", ";
1178 }
1179 first = false;
1180
1181 if (!Outs[OIdx].Flags.isByVal()) {
1182 if (IsTypePassedAsArray(Ty)) {
1183 Align ParamAlign =
1184 getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
1185 O << ".param .align " << ParamAlign.value() << " .b8 ";
1186 O << "_";
1187 O << "[" << DL.getTypeAllocSize(Ty) << "]";
1188 // update the index for Outs
1189 SmallVector<EVT, 16> vtparts;
1190 ComputeValueVTs(*this, DL, Ty, vtparts);
1191 if (unsigned len = vtparts.size())
1192 OIdx += len - 1;
1193 continue;
1194 }
1195 // i8 types in IR will be i16 types in SDAG
1196 assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
1197 (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
1198 "type mismatch between callee prototype and arguments");
1199 // scalar type
1200 unsigned sz = 0;
1201 if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        sz = promoteScalarArgumentSize(sz);
      } else if (isa<PointerType>(Ty)) {
1205 sz = PtrVT.getSizeInBits();
1206 } else {
1207 sz = Ty->getPrimitiveSizeInBits();
1208 }
1209 O << ".param .b" << sz << " ";
1210 O << "_";
1211 continue;
1212 }
1213
1214 // Indirect calls need strict ABI alignment so we disable optimizations by
1215 // not providing a function to optimize.
1216 Type *ETy = Args[i].IndirectType;
1217 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1218 Align ParamByValAlign =
1219 getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
1220
1221 O << ".param .align " << ParamByValAlign.value() << " .b8 ";
1222 O << "_";
1223 O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
1224 }
1225
1226 if (VAInfo)
1227 O << (first ? "" : ",") << " .param .align " << VAInfo->second
1228 << " .b8 _[]\n";
1229 O << ")";
1231 O << " .noreturn";
1232 O << ";";
1233
1234 return Prototype;
1235}
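
// Editorial note, not part of the original LLVM source: a rough example of the
// string built above. For an indirect call returning float and taking (i32,
// float) on a 64-bit target, the emitted prototype looks approximately like
//   prototype_1 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .b32 _);
// Scalars narrower than 32 bits are widened first, and aggregate or vector
// arguments appear as ".param .align A .b8 _[size]" entries instead.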
1236
Align NVPTXTargetLowering::getFunctionArgumentAlignment(
    const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1239 return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
1240}
1241
1242Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1243 unsigned Idx,
1244 const DataLayout &DL) const {
1245 if (!CB) {
1246 // CallSite is zero, fallback to ABI type alignment
1247 return DL.getABITypeAlign(Ty);
1248 }
1249
1250 const Function *DirectCallee = CB->getCalledFunction();
1251
1252 if (!DirectCallee) {
1253 // We don't have a direct function symbol, but that may be because of
1254 // constant cast instructions in the call.
1255
1256 // With bitcast'd call targets, the instruction will be the call
1257 if (const auto *CI = dyn_cast<CallInst>(CB)) {
1258 // Check if we have call alignment metadata
1259 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1260 return StackAlign.value();
1261 }
1262 DirectCallee = getMaybeBitcastedCallee(CB);
1263 }
1264
1265 // Check for function alignment information if we found that the
1266 // ultimate target is a Function
1267 if (DirectCallee)
1268 return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
1269
1270 // Call is indirect, fall back to the ABI type alignment
1271 return DL.getABITypeAlign(Ty);
1272}
1273
1274static bool adjustElementType(EVT &ElementType) {
1275 switch (ElementType.getSimpleVT().SimpleTy) {
1276 default:
1277 return false;
1278 case MVT::f16:
1279 case MVT::bf16:
1280 ElementType = MVT::i16;
1281 return true;
1282 case MVT::f32:
1283 case MVT::v2f16:
1284 case MVT::v2bf16:
1285 ElementType = MVT::i32;
1286 return true;
1287 case MVT::f64:
1288 ElementType = MVT::i64;
1289 return true;
1290 }
1291}
1292
1293// Use byte-store when the param address of the argument value is unaligned.
1294// This may happen when the return value is a field of a packed structure.
1295//
1296// This is called in LowerCall() when passing the param values.
static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
                                        uint64_t Offset, EVT ElementType,
1299 SDValue StVal, SDValue &InGlue,
1300 unsigned ArgID, const SDLoc &dl) {
1301 // Bit logic only works on integer types
1302 if (adjustElementType(ElementType))
1303 StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
1304
1305 // Store each byte
1306 SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1307 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1308 // Shift the byte to the last byte position
1309 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
1310 DAG.getConstant(i * 8, dl, MVT::i32));
1311 SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
1312 DAG.getConstant(Offset + i, dl, MVT::i32),
1313 ShiftVal, InGlue};
1314 // Trunc store only the last byte by using
1315 // st.param.b8
1316 // The register type can be larger than b8.
1317 Chain = DAG.getMemIntrinsicNode(
1318 NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
        MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
    InGlue = Chain.getValue(1);
1321 }
1322 return Chain;
1323}
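
// Editorial note, not part of the original LLVM source: for a 32-bit value at an
// unaligned param offset, the loop above emits four st.param.b8 stores, one per
// byte, each storing (value >> (i * 8)) truncated to a single byte.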
1324
// Use byte-load when the param address of the returned value is unaligned.
1326// This may happen when the returned value is a field of a packed structure.
static SDValue
LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
                           EVT ElementType, SDValue &InGlue,
1330 SmallVectorImpl<SDValue> &TempProxyRegOps,
1331 const SDLoc &dl) {
1332 // Bit logic only works on integer types
1333 EVT MergedType = ElementType;
1334 adjustElementType(MergedType);
1335
1336 // Load each byte and construct the whole value. Initial value to 0
1337 SDValue RetVal = DAG.getConstant(0, dl, MergedType);
1338 // LoadParamMemI8 loads into i16 register only
1339 SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
1340 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1341 SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
1342 DAG.getConstant(Offset + i, dl, MVT::i32),
1343 InGlue};
1344 // This will be selected to LoadParamMemI8
1345 SDValue LdVal =
1346 DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
1347 MVT::i8, MachinePointerInfo(), Align(1));
1348 SDValue TmpLdVal = LdVal.getValue(0);
1349 Chain = LdVal.getValue(1);
1350 InGlue = LdVal.getValue(2);
1351
1352 TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
1353 TmpLdVal.getSimpleValueType(), TmpLdVal);
1354 TempProxyRegOps.push_back(TmpLdVal);
1355
1356 SDValue CMask = DAG.getConstant(255, dl, MergedType);
1357 SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
1358 // Need to extend the i16 register to the whole width.
1359 TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
1360 // Mask off the high bits. Leave only the lower 8bits.
1361 // Do this because we are using loadparam.b8.
1362 TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
1363 // Shift and merge
1364 TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
1365 RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
1366 }
1367 if (ElementType != MergedType)
1368 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
1369
1370 return RetVal;
1371}
1372
static bool shouldConvertToIndirectCall(const CallBase *CB,
                                        const GlobalAddressSDNode *Func) {
1375 if (!Func)
1376 return false;
1377 if (auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal()))
1378 return CB->getFunctionType() != CalleeFunc->getFunctionType();
1379 return false;
1380}
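
// Editorial note, not part of the original LLVM source: this predicate fires for
// call sites whose type disagrees with the callee's declared type, e.g. IR along
// the lines of "call i32 @f(i32 %x)" where @f is actually declared as
// "float (float)". Such calls keep the callee symbol but are lowered through the
// indirect-call path so the emitted prototype matches what the call site passes.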
1381
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                       SmallVectorImpl<SDValue> &InVals) const {
1384
1385 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1387 "Support for variadic functions (unsized array parameter) introduced "
1388 "in PTX ISA version 6.0 and requires target sm_30.");
1389
1390 SelectionDAG &DAG = CLI.DAG;
1391 SDLoc dl = CLI.DL;
  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
  SDValue Chain = CLI.Chain;
1396 SDValue Callee = CLI.Callee;
1397 bool &isTailCall = CLI.IsTailCall;
1398 ArgListTy &Args = CLI.getArgs();
1399 Type *RetTy = CLI.RetTy;
1400 const CallBase *CB = CLI.CB;
1401 const DataLayout &DL = DAG.getDataLayout();
1402
1403 bool isABI = (STI.getSmVersion() >= 20);
1404 assert(isABI && "Non-ABI compilation is not supported");
1405 if (!isABI)
1406 return Chain;
1407
1408 // Variadic arguments.
1409 //
1410 // Normally, for each argument, we declare a param scalar or a param
1411 // byte array in the .param space, and store the argument value to that
1412 // param scalar or array starting at offset 0.
1413 //
1414 // In the case of the first variadic argument, we declare a vararg byte array
1415 // with size 0. The exact size of this array isn't known at this point, so
1416 // it'll be patched later. All the variadic arguments will be stored to this
1417 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1418 // initially set to 0, so it can be used for non-variadic arguments (which use
1419 // 0 offset) to simplify the code.
1420 //
1421 // After all vararg is processed, 'VAOffset' holds the size of the
1422 // vararg byte array.
1423
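  // Editorial sketch, not part of the original LLVM source: for a call such as
  //   printf("%d %d", i, j)
  // the fixed format-string pointer gets its own .param, while the two variadic
  // i32 values are stored into a single byte-array param at offsets 0 and 4,
  // roughly:
  //   .param .align 8 .b8 param1[8];
  //   st.param.b32 [param1+0], %r1;
  //   st.param.b32 [param1+4], %r2;
  // The declared size (8 here) is patched in after all variadic arguments have
  // been processed, via the MorphNodeTo call further down.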
1424 SDValue VADeclareParam; // vararg byte array
1425 unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
1426 unsigned VAOffset = 0; // current offset in the param array
1427
1428 unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
1429 SDValue TempChain = Chain;
1430 Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
1431 SDValue InGlue = Chain.getValue(1);
1432
1433 unsigned ParamCount = 0;
1434 // Args.size() and Outs.size() need not match.
1435 // Outs.size() will be larger
1436 // * if there is an aggregate argument with multiple fields (each field
1437 // showing up separately in Outs)
1438 // * if there is a vector argument with more than typical vector-length
1439 // elements (generally if more than 4) where each vector element is
1440 // individually present in Outs.
1441 // So a different index should be used for indexing into Outs/OutVals.
1442 // See similar issue in LowerFormalArguments.
1443 unsigned OIdx = 0;
  // Declare the .params or .reg needed to pass values
  // to the function
1446 for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
1447 EVT VT = Outs[OIdx].VT;
1448 Type *Ty = Args[i].Ty;
1449 bool IsVAArg = (i >= CLI.NumFixedArgs);
1450 bool IsByVal = Outs[OIdx].Flags.isByVal();
1451
    SmallVector<EVT, 16> VTs;
    SmallVector<uint64_t, 16> Offsets;
1455 assert((!IsByVal || Args[i].IndirectType) &&
1456 "byval arg must have indirect type");
1457 Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
1458 ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
1459
1460 Align ArgAlign;
1461 if (IsByVal) {
1462 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1463 // so we don't need to worry whether it's naturally aligned or not.
1464 // See TargetLowering::LowerCallTo().
1465 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1466 ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
1467 InitialAlign, DL);
1468 if (IsVAArg)
1469 VAOffset = alignTo(VAOffset, ArgAlign);
1470 } else {
1471 ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
1472 }
1473
1474 unsigned TypeSize =
1475 (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
1476 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1477
1478 bool NeedAlign; // Does argument declaration specify alignment?
1479 bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
1480 if (IsVAArg) {
1481 if (ParamCount == FirstVAArg) {
1482 SDValue DeclareParamOps[] = {
1483 Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
1484 DAG.getConstant(ParamCount, dl, MVT::i32),
1485 DAG.getConstant(1, dl, MVT::i32), InGlue};
1486 VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
1487 DeclareParamVTs, DeclareParamOps);
1488 }
1489 NeedAlign = PassAsArray;
1490 } else if (PassAsArray) {
1491 // declare .param .align <align> .b8 .param<n>[<size>];
1492 SDValue DeclareParamOps[] = {
1493 Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
1494 DAG.getConstant(ParamCount, dl, MVT::i32),
1495 DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
1496 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
1497 DeclareParamOps);
1498 NeedAlign = true;
1499 } else {
1500 // declare .param .b<size> .param<n>;
1501 if (VT.isInteger() || VT.isFloatingPoint()) {
1502 // PTX ABI requires integral types to be at least 32 bits in
1503 // size. FP16 is loaded/stored using i16, so it's handled
1504 // here as well.
1506 }
1507 SDValue DeclareScalarParamOps[] = {
1508 Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
1509 DAG.getConstant(TypeSize * 8, dl, MVT::i32),
1510 DAG.getConstant(0, dl, MVT::i32), InGlue};
1511 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
1512 DeclareScalarParamOps);
1513 NeedAlign = false;
1514 }
1515 InGlue = Chain.getValue(1);
1516
1517 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1518 // than 32-bits are sign extended or zero extended, depending on
1519 // whether they are signed or unsigned types. This case applies
1520 // only to scalar parameters and not to aggregate values.
1521 bool ExtendIntegerParam =
1522 Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
1523
1524 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
1525 SmallVector<SDValue, 6> StoreOperands;
1526 for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
1527 EVT EltVT = VTs[j];
1528 int CurOffset = Offsets[j];
1529 MaybeAlign PartAlign;
1530 if (NeedAlign)
1531 PartAlign = commonAlignment(ArgAlign, CurOffset);
1532
1533 SDValue StVal = OutVals[OIdx];
1534
1535 MVT PromotedVT;
1536 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
1537 EltVT = EVT(PromotedVT);
1538 }
1539 if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
        unsigned Ext =
            Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1542 StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
1543 }
1544
1545 if (IsByVal) {
1546 auto PtrVT = getPointerTy(DL);
1547 SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
1548 DAG.getConstant(CurOffset, dl, PtrVT));
1549 StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
1550 PartAlign);
1551 } else if (ExtendIntegerParam) {
1552 assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
1553 // zext/sext to i32
1554 StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
1556 dl, MVT::i32, StVal);
1557 }
1558
1559 if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
1560 // Use 16-bit registers for small stores as it's the
1561 // smallest general purpose register size supported by NVPTX.
1562 StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
1563 }
1564
1565 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1566 // scalar store. In such cases, fall back to byte stores.
1567 if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
1568 PartAlign.value() <
1569 DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
        assert(StoreOperands.empty() && "Unfinished preceding store.");
        Chain = LowerUnalignedStoreParam(
            DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
            StVal, InGlue, ParamCount, dl);
1574
1575 // LowerUnalignedStoreParam took care of inserting the necessary nodes
1576 // into the SDAG, so just move on to the next element.
1577 if (!IsByVal)
1578 ++OIdx;
1579 continue;
1580 }
1581
1582 // New store.
1583 if (VectorInfo[j] & PVF_FIRST) {
1584 assert(StoreOperands.empty() && "Unfinished preceding store.");
1585 StoreOperands.push_back(Chain);
1586 StoreOperands.push_back(
1587 DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
1588
1589 StoreOperands.push_back(DAG.getConstant(
1590 IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
1591 dl, MVT::i32));
1592 }
1593
1594 // Record the value to store.
1595 StoreOperands.push_back(StVal);
1596
1597 if (VectorInfo[j] & PVF_LAST) {
1598 unsigned NumElts = StoreOperands.size() - 3;
        NVPTXISD::NodeType Op;
        switch (NumElts) {
        case 1:
          Op = NVPTXISD::StoreParam;
          break;
        case 2:
          Op = NVPTXISD::StoreParamV2;
          break;
        case 4:
          Op = NVPTXISD::StoreParamV4;
          break;
        default:
          llvm_unreachable("Invalid vector info.");
        }
1613
1614 StoreOperands.push_back(InGlue);
1615
1616 // Adjust type of the store op if we've extended the scalar
1617 // return value.
1618 EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1619
1620 Chain = DAG.getMemIntrinsicNode(
1621 Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
1622 TheStoreType, MachinePointerInfo(), PartAlign,
            MachineMemOperand::MOStore);
        InGlue = Chain.getValue(1);
1625
1626 // Cleanup.
1627 StoreOperands.clear();
1628
1629 // TODO: We may need to support vector types that can be passed
1630 // as scalars in variadic arguments.
1631 if (!IsByVal && IsVAArg) {
1632 assert(NumElts == 1 &&
1633 "Vectorization is expected to be disabled for variadics.");
1634 VAOffset += DL.getTypeAllocSize(
1635 TheStoreType.getTypeForEVT(*DAG.getContext()));
1636 }
1637 }
1638 if (!IsByVal)
1639 ++OIdx;
1640 }
1641 assert(StoreOperands.empty() && "Unfinished parameter store.");
1642 if (!IsByVal && VTs.size() > 0)
1643 --OIdx;
1644 ++ParamCount;
1645 if (IsByVal && IsVAArg)
1646 VAOffset += TypeSize;
1647 }
1648
1649 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
1650 MaybeAlign retAlignment = std::nullopt;
1651
1652 // Handle Result
1653 if (Ins.size() > 0) {
1654 SmallVector<EVT, 16> resvtparts;
1655 ComputeValueVTs(*this, DL, RetTy, resvtparts);
1656
1657 // Declare
1658 // .param .align N .b8 retval0[<size-in-bytes>], or
1659 // .param .b<size-in-bits> retval0
1660 unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
1661 if (!IsTypePassedAsArray(RetTy)) {
1662 resultsz = promoteScalarArgumentSize(resultsz);
1663 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1664 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
1665 DAG.getConstant(resultsz, dl, MVT::i32),
1666 DAG.getConstant(0, dl, MVT::i32), InGlue };
1667 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
1668 DeclareRetOps);
1669 InGlue = Chain.getValue(1);
1670 } else {
1671 retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
1672 assert(retAlignment && "retAlignment is guaranteed to be set");
1673 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1674 SDValue DeclareRetOps[] = {
1675 Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32),
1676 DAG.getConstant(resultsz / 8, dl, MVT::i32),
1677 DAG.getConstant(0, dl, MVT::i32), InGlue};
1678 Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
1679 DeclareRetOps);
1680 InGlue = Chain.getValue(1);
1681 }
1682 }
1683
1684 bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1685 // Set the size of the vararg param byte array if the callee is a variadic
1686 // function and the variadic part is not empty.
1687 if (HasVAArgs) {
1688 SDValue DeclareParamOps[] = {
1689 VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
1690 VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
1691 VADeclareParam.getOperand(4)};
1692 DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
1693 VADeclareParam->getVTList(), DeclareParamOps);
1694 }
1695
1696 // If the type of the callsite does not match that of the function, convert
1697 // the callsite to an indirect call.
1698 bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1699
1700 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1701 // between them we must rely on the call site value which is valid for
1702 // indirect calls but is always null for libcalls.
1703 bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1704
1705 if (isa<ExternalSymbolSDNode>(Callee)) {
1706 Function* CalleeFunc = nullptr;
1707
1708 // Try to find the callee in the current module.
1709 Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc);
1710 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1711
1712 // Set the "libcall callee" attribute to indicate that the function
1713 // must always have a declaration.
1714 CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
1715 }
1716
1717 if (isIndirectCall) {
1718 // This is indirect function call case : PTX requires a prototype of the
1719 // form
1720 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
    // to be emitted, and the label has to be used as the last arg of the call
    // instruction.
1723 // The prototype is embedded in a string and put as the operand for a
1724 // CallPrototype SDNode which will print out to the value of the string.
1725 SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1726 std::string Proto = getPrototype(
1727 DL, RetTy, Args, Outs, retAlignment,
1728 HasVAArgs
1729 ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
1730 CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
1731 : std::nullopt,
1732 *CB, UniqueCallSite);
1733 const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
1734 SDValue ProtoOps[] = {
1735 Chain,
1736 DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
1737 InGlue,
1738 };
1739 Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
1740 InGlue = Chain.getValue(1);
1741 }
1742 // Op to just print "call"
1743 SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1744 SDValue PrintCallOps[] = {
1745 Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue
1746 };
1747 // We model convergent calls as separate opcodes.
  unsigned Opcode =
      isIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
  if (CLI.IsConvergent)
    Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
                                              : NVPTXISD::PrintConvergentCall;
1752 Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
1753 InGlue = Chain.getValue(1);
1754
1755 if (ConvertToIndirectCall) {
1756 // Copy the function ptr to a ptx register and use the register to call the
1757 // function.
1758 EVT DestVT = Callee.getValueType();
    MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
1761 unsigned DestReg =
1762 RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
1763 auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
1764 Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
1765 }
1766
1767 // Ops to print out the function name
1768 SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1769 SDValue CallVoidOps[] = { Chain, Callee, InGlue };
1770 Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
1771 InGlue = Chain.getValue(1);
1772
1773 // Ops to print out the param list
1774 SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1775 SDValue CallArgBeginOps[] = { Chain, InGlue };
1776 Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
1777 CallArgBeginOps);
1778 InGlue = Chain.getValue(1);
1779
1780 for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
1781 ++i) {
1782 unsigned opcode;
1783 if (i == (e - 1))
1784 opcode = NVPTXISD::LastCallArg;
1785 else
1786 opcode = NVPTXISD::CallArg;
1787 SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1788 SDValue CallArgOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
1789 DAG.getConstant(i, dl, MVT::i32), InGlue };
1790 Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
1791 InGlue = Chain.getValue(1);
1792 }
1793 SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1794 SDValue CallArgEndOps[] = { Chain,
1795 DAG.getConstant(isIndirectCall ? 0 : 1, dl, MVT::i32),
1796 InGlue };
1797 Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
1798 InGlue = Chain.getValue(1);
1799
1800 if (isIndirectCall) {
1801 SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1802 SDValue PrototypeOps[] = {
1803 Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue};
1804 Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
1805 InGlue = Chain.getValue(1);
1806 }
1807
1808 SmallVector<SDValue, 16> ProxyRegOps;
1809 SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
1810 // An item of the vector is filled if the element does not need a ProxyReg
1811 // operation on it and should be added to InVals as is. ProxyRegOps and
  // ProxyRegTruncates contain empty/none items at the same index.
  SmallVector<SDValue, 16> RetElts;
  // Temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()`
  // to use the values of `LoadParam`s and to be replaced later when
  // `CALLSEQ_END` is added.
1817 SmallVector<SDValue, 16> TempProxyRegOps;
1818
1819 // Generate loads from param memory/moves from registers for result
  if (Ins.size() > 0) {
    SmallVector<EVT, 16> VTs;
    SmallVector<uint64_t, 16> Offsets;
    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
1824 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1825
1826 Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1827 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
1828
1829 SmallVector<EVT, 6> LoadVTs;
1830 int VecIdx = -1; // Index of the first element of the vector.
1831
1832 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1833 // 32-bits are sign extended or zero extended, depending on whether
1834 // they are signed or unsigned types.
1835 bool ExtendIntegerRetVal =
1836 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
1837
1838 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
1839 bool needTruncate = false;
1840 EVT TheLoadType = VTs[i];
1841 EVT EltType = Ins[i].VT;
1842 Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
1843 MVT PromotedVT;
1844
1845 if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
1846 TheLoadType = EVT(PromotedVT);
1847 EltType = EVT(PromotedVT);
1848 needTruncate = true;
1849 }
1850
1851 if (ExtendIntegerRetVal) {
1852 TheLoadType = MVT::i32;
1853 EltType = MVT::i32;
1854 needTruncate = true;
1855 } else if (TheLoadType.getSizeInBits() < 16) {
1856 if (VTs[i].isInteger())
1857 needTruncate = true;
1858 EltType = MVT::i16;
1859 }
1860
1861 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1862 // scalar load. In such cases, fall back to byte loads.
1863 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
1864 EltAlign < DL.getABITypeAlign(
1865 TheLoadType.getTypeForEVT(*DAG.getContext()))) {
1866 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
1868 DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
1869 ProxyRegOps.push_back(SDValue());
1870 ProxyRegTruncates.push_back(std::optional<MVT>());
1871 RetElts.resize(i);
1872 RetElts.push_back(Ret);
1873
1874 continue;
1875 }
1876
1877 // Record index of the very first element of the vector.
1878 if (VectorInfo[i] & PVF_FIRST) {
1879 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
1880 VecIdx = i;
1881 }
1882
1883 LoadVTs.push_back(EltType);
1884
1885 if (VectorInfo[i] & PVF_LAST) {
1886 unsigned NumElts = LoadVTs.size();
1887 LoadVTs.push_back(MVT::Other);
1888 LoadVTs.push_back(MVT::Glue);
1890 switch (NumElts) {
1891 case 1:
1893 break;
1894 case 2:
1896 break;
1897 case 4:
1899 break;
1900 default:
1901 llvm_unreachable("Invalid vector info.");
1902 }
1903
1904 SDValue LoadOperands[] = {
1905 Chain, DAG.getConstant(1, dl, MVT::i32),
1906 DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue};
1907 SDValue RetVal = DAG.getMemIntrinsicNode(
1908 Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
1909 MachinePointerInfo(), EltAlign,
1911
1912 for (unsigned j = 0; j < NumElts; ++j) {
1913 ProxyRegOps.push_back(RetVal.getValue(j));
1914
1915 if (needTruncate)
1916 ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT));
1917 else
1918 ProxyRegTruncates.push_back(std::optional<MVT>());
1919 }
1920
1921 Chain = RetVal.getValue(NumElts);
1922 InGlue = RetVal.getValue(NumElts + 1);
1923
1924 // Cleanup
1925 VecIdx = -1;
1926 LoadVTs.clear();
1927 }
1928 }
1929 }
1930
1931 Chain =
1932 DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl);
1933 InGlue = Chain.getValue(1);
1934
1935 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1936 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1937 // dangling.
1938 for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
1939 if (i < RetElts.size() && RetElts[i]) {
1940 InVals.push_back(RetElts[i]);
1941 continue;
1942 }
1943
1944 SDValue Ret = DAG.getNode(
1946 DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
1947 { Chain, ProxyRegOps[i], InGlue }
1948 );
1949
1950 Chain = Ret.getValue(1);
1951 InGlue = Ret.getValue(2);
1952
1953 if (ProxyRegTruncates[i]) {
1954 Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret);
1955 }
1956
1957 InVals.push_back(Ret);
1958 }
1959
1960 for (SDValue &T : TempProxyRegOps) {
1961 SDValue Repl = DAG.getNode(
1963 DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
1964 {Chain, T.getOperand(0), InGlue});
1965 DAG.ReplaceAllUsesWith(T, Repl);
1966 DAG.RemoveDeadNode(T.getNode());
1967
1968 Chain = Repl.getValue(1);
1969 InGlue = Repl.getValue(2);
1970 }
1971
1972 // set isTailCall to false for now, until we figure out how to express
1973 // tail call optimization in PTX
1974 isTailCall = false;
1975 return Chain;
1976}
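// Editorial note: taken together, the nodes built above correspond roughly to
// a PTX call sequence of the following shape (illustrative only; parameter
// names, widths, and the presence of a prototype depend on the callee):
//
//   { // callseq N
//   .param .b32 param0;
//   st.param.b32 [param0], %r1;
//   .param .b32 retval0;
//   call.uni (retval0), callee, (param0);
//   ld.param.b32 %r2, [retval0];
//   } // callseq N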
1977
1979 SelectionDAG &DAG) const {
1980
1981 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1982 const Function &Fn = DAG.getMachineFunction().getFunction();
1983
1984 DiagnosticInfoUnsupported NoDynamicAlloca(
1985 Fn,
1986 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1987 "requires target sm_52.",
1988 SDLoc(Op).getDebugLoc());
1989 DAG.getContext()->diagnose(NoDynamicAlloca);
1990 auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()),
1991 Op.getOperand(0)};
1992 return DAG.getMergeValues(Ops, SDLoc());
1993 }
1994
1995 SDValue Chain = Op.getOperand(0);
1996 SDValue Size = Op.getOperand(1);
1997 uint64_t Align = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
1998 SDLoc DL(Op.getNode());
1999
2000 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32.
2001 MVT ValueSizeTy = nvTM->is64Bit() ? MVT::i64 : MVT::i32;
2002
2003 SDValue AllocOps[] = {Chain, DAG.getZExtOrTrunc(Size, DL, ValueSizeTy),
2004 DAG.getTargetConstant(Align, DL, MVT::i32)};
2005 EVT RetTypes[] = {ValueSizeTy, MVT::Other};
2006 return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
2007}
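// For example, on a 64-bit target an `alloca` of 40 bytes with 8-byte
// alignment is lowered here to
//   DYNAMIC_STACKALLOC(Chain, i64 40, TargetConstant:i32 8)
// with the requested size zero-extended or truncated to the m64/m32 pointer
// width and the alignment carried as an i32 target constant.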
2008
2010 SelectionDAG &DAG) const {
2011 SDLoc DL(Op.getNode());
2012 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2013 const Function &Fn = DAG.getMachineFunction().getFunction();
2014
2015 DiagnosticInfoUnsupported NoStackRestore(
2016 Fn,
2017 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
2018 ">= sm_52.",
2019 DL.getDebugLoc());
2020 DAG.getContext()->diagnose(NoStackRestore);
2021 return Op.getOperand(0);
2022 }
2023
2024 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2025 SDValue Chain = Op.getOperand(0);
2026 SDValue Ptr = Op.getOperand(1);
2029 return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC});
2030}
2031
2033 SelectionDAG &DAG) const {
2034 SDLoc DL(Op.getNode());
2035 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2036 const Function &Fn = DAG.getMachineFunction().getFunction();
2037
2038 DiagnosticInfoUnsupported NoStackSave(
2039 Fn,
2040 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
2041 "sm_52.",
2042 DL.getDebugLoc());
2043 DAG.getContext()->diagnose(NoStackSave);
2044 auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)};
2045 return DAG.getMergeValues(Ops, DL);
2046 }
2047
2048 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2049 SDValue Chain = Op.getOperand(0);
2050 SDValue SS =
2051 DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
2052 SDValue ASC = DAG.getAddrSpaceCast(
2053 DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
2054 return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL);
2055}
2056
2057// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
2058// (see LegalizeDAG.cpp). This is slow and uses local memory.
2059// We use extract/insert/build vector just as LegalizeOp() did in LLVM 2.5.
2060SDValue
2061NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2062 SDNode *Node = Op.getNode();
2063 SDLoc dl(Node);
2065 unsigned NumOperands = Node->getNumOperands();
2066 for (unsigned i = 0; i < NumOperands; ++i) {
2067 SDValue SubOp = Node->getOperand(i);
2068 EVT VVT = SubOp.getNode()->getValueType(0);
2069 EVT EltVT = VVT.getVectorElementType();
2070 unsigned NumSubElem = VVT.getVectorNumElements();
2071 for (unsigned j = 0; j < NumSubElem; ++j) {
2072 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
2073 DAG.getIntPtrConstant(j, dl)));
2074 }
2075 }
2076 return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
2077}
2078
2079SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2080 // Handle bitcasting from v2i8 without hitting the default promotion
2081 // strategy which goes through stack memory.
2082 EVT FromVT = Op->getOperand(0)->getValueType(0);
2083 if (FromVT != MVT::v2i8) {
2084 return Op;
2085 }
2086
2087 // Pack vector elements into i16 and bitcast to final type
2088 SDLoc DL(Op);
2089 SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2090 Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
2091 SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2092 Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
2093 SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
2094 SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
2095 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
2096 SDValue AsInt = DAG.getNode(
2097 ISD::OR, DL, MVT::i16,
2098 {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
2099 EVT ToVT = Op->getValueType(0);
2100 return MaybeBitcast(DAG, DL, ToVT, AsInt);
2101}
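// Worked example: bitcasting the v2i8 value <0x12, 0x34> extracts the two
// bytes, zero-extends them to i16, and combines them as
//   0x12 | (0x34 << 8) == 0x3412
// before the final bitcast to the destination type.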
2102
2103// We can initialize a constant f16x2/v2i16/v4i8 with a single .b32 move.
2104// Normally it would get lowered as two constant loads and a vector-packing move.
2105// Instead we want just a constant move:
2106// mov.b32 %r2, 0x40003C00
2107SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2108 SelectionDAG &DAG) const {
2109 EVT VT = Op->getValueType(0);
2110 if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
2111 return Op;
2112 SDLoc DL(Op);
2113
2114 if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
2115 return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
2116 isa<ConstantFPSDNode>(Operand);
2117 })) {
2118 if (VT != MVT::v4i8)
2119 return Op;
2120 // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
2121 // to optimize calculation of constant parts.
2122 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2123 uint64_t SelectionValue) -> SDValue {
2124 SDValue L = Left;
2125 SDValue R = Right;
2126 if (Cast) {
2127 L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
2128 R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
2129 }
2130 return DAG.getNode(
2131 NVPTXISD::PRMT, DL, MVT::v4i8,
2132 {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
2133 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2134 };
2135 auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
2136 auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
2137 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2138 return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
2139 }
2140
2141 // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
2142 auto GetOperand = [](SDValue Op, int N) -> APInt {
2143 const SDValue &Operand = Op->getOperand(N);
2144 EVT VT = Op->getValueType(0);
2145 if (Operand->isUndef())
2146 return APInt(32, 0);
2147 APInt Value;
2148 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2149 Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
2150 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2151 Value = Operand->getAsAPIntVal();
2152 else
2153 llvm_unreachable("Unsupported type");
 2154 // i8 values are carried around as i16, so we need to zero out the upper
 2155 // bits so that they do not get in the way of combining individual byte values.
2156 if (VT == MVT::v4i8)
2157 Value = Value.trunc(8);
2158 return Value.zext(32);
2159 };
2160 APInt Value;
2161 if (Isv2x16VT(VT)) {
2162 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
2163 } else if (VT == MVT::v4i8) {
2164 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
2165 GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
2166 } else {
2167 llvm_unreachable("Unsupported type");
2168 }
2169 SDValue Const = DAG.getConstant(Value, DL, MVT::i32);
2170 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const);
2171}
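// Worked example for the constant path above: BUILD_VECTOR <half 1.0, half 2.0>
// packs the f16 bit patterns 0x3C00 (1.0) and 0x4000 (2.0) as
//   0x3C00 | (0x4000 << 16) == 0x40003C00,
// matching the single `mov.b32 %r2, 0x40003C00` shown in the function comment.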
2172
2173SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2174 SelectionDAG &DAG) const {
2175 SDValue Index = Op->getOperand(1);
2176 SDValue Vector = Op->getOperand(0);
2177 SDLoc DL(Op);
2178 EVT VectorVT = Vector.getValueType();
2179
2180 if (VectorVT == MVT::v4i8) {
2181 SDValue BFE =
2182 DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
2183 {Vector,
2184 DAG.getNode(ISD::MUL, DL, MVT::i32,
2185 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2186 DAG.getConstant(8, DL, MVT::i32)),
2187 DAG.getConstant(8, DL, MVT::i32)});
2188 return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
2189 }
2190
2191 // Constant index will be matched by tablegen.
2192 if (isa<ConstantSDNode>(Index.getNode()))
2193 return Op;
2194
2195 // Extract individual elements and select one of them.
2196 assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
2197 EVT EltVT = VectorVT.getVectorElementType();
2198
2199 SDLoc dl(Op.getNode());
2200 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2201 DAG.getIntPtrConstant(0, dl));
2202 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2203 DAG.getIntPtrConstant(1, dl));
2204 return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
2206}
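// For the v4i8 case above, the BFE node extracts the 8-bit field starting at
// bit Index * 8; e.g. extracting element 2 reads bits [16..23] of the
// underlying i32 and then any-extends/truncates to the requested result type.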
2207
2208SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2209 SelectionDAG &DAG) const {
2210 SDValue Vector = Op->getOperand(0);
2211 EVT VectorVT = Vector.getValueType();
2212
2213 if (VectorVT != MVT::v4i8)
2214 return Op;
2215 SDLoc DL(Op);
2216 SDValue Value = Op->getOperand(1);
2217 if (Value->isUndef())
2218 return Vector;
2219
2220 SDValue Index = Op->getOperand(2);
2221
2222 SDValue BFI =
2223 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2224 {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
2225 DAG.getNode(ISD::MUL, DL, MVT::i32,
2226 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2227 DAG.getConstant(8, DL, MVT::i32)),
2228 DAG.getConstant(8, DL, MVT::i32)});
2229 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
2230}
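// For example, inserting a byte at index 1 emits BFI with bit offset 8 and
// width 8, i.e. bits [8..15] of the packed i32 are replaced by the new value
// while the remaining three bytes are preserved.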
2231
2232SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2233 SelectionDAG &DAG) const {
2234 SDValue V1 = Op.getOperand(0);
2235 EVT VectorVT = V1.getValueType();
2236 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2237 return Op;
2238
2239 // Lower shuffle to PRMT instruction.
2240 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
2241 SDValue V2 = Op.getOperand(1);
2242 uint32_t Selector = 0;
2243 for (auto I : llvm::enumerate(SVN->getMask())) {
2244 if (I.value() != -1) // -1 is a placeholder for undef.
2245 Selector |= (I.value() << (I.index() * 4));
2246 }
2247
2248 SDLoc DL(Op);
2249 return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2250 DAG.getConstant(Selector, DL, MVT::i32),
2251 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
2252}
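// Worked example for the shuffle lowering above: the mask <3, 2, 1, 0> yields
//   Selector = 3 | (2 << 4) | (1 << 8) | (0 << 12) = 0x0123,
// where selector nibble i picks the source byte for result byte i from the
// concatenation {V2, V1} (bytes 0-3 come from V1, bytes 4-7 from V2), so this
// particular mask reverses the bytes of V1.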
2253/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2254/// 1) returns two i32 values and takes a 2 x i32 value to shift plus a shift
2255/// amount, or
2256/// 2) returns two i64 values and takes a 2 x i64 value to shift plus a shift
2257/// amount.
2258SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2259 SelectionDAG &DAG) const {
2260 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2261 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2262
2263 EVT VT = Op.getValueType();
2264 unsigned VTBits = VT.getSizeInBits();
2265 SDLoc dl(Op);
2266 SDValue ShOpLo = Op.getOperand(0);
2267 SDValue ShOpHi = Op.getOperand(1);
2268 SDValue ShAmt = Op.getOperand(2);
2269 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2270
2271 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2272 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2273 // {dHi, dLo} = {aHi, aLo} >> Amt
2274 // dHi = aHi >> Amt
2275 // dLo = shf.r.clamp aLo, aHi, Amt
2276
2277 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2278 SDValue Lo =
2279 DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2280
2281 SDValue Ops[2] = { Lo, Hi };
2282 return DAG.getMergeValues(Ops, dl);
2283 }
2284 else {
2285 // {dHi, dLo} = {aHi, aLo} >> Amt
2286 // - if (Amt>=size) then
2287 // dLo = aHi >> (Amt-size)
2288 // dHi = aHi >> Amt (this is either all 0 or all 1)
2289 // else
2290 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2291 // dHi = aHi >> Amt
2292
2293 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2294 DAG.getConstant(VTBits, dl, MVT::i32),
2295 ShAmt);
2296 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
2297 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2298 DAG.getConstant(VTBits, dl, MVT::i32));
2299 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
2300 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2301 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
2302
2303 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2304 DAG.getConstant(VTBits, dl, MVT::i32),
2305 ISD::SETGE);
2306 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2307 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2308
2309 SDValue Ops[2] = { Lo, Hi };
2310 return DAG.getMergeValues(Ops, dl);
2311 }
2312}
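// Funnel-shift example (VTBits == 32, sm_35+): for {aHi, aLo} =
// {0x00000001, 0x00000000} and Amt = 4, shf.r.clamp produces
// dLo = 0x10000000 (the low half of the full 64-bit shift) and dHi = aHi >> 4.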
2313
2314/// LowerShiftLeftParts - Lower SHL_PARTS, which
2315/// 1) returns two i32 values and takes a 2 x i32 value to shift plus a shift
2316/// amount, or
2317/// 2) returns two i64 values and takes a 2 x i64 value to shift plus a shift
2318/// amount.
2319SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2320 SelectionDAG &DAG) const {
2321 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2322 assert(Op.getOpcode() == ISD::SHL_PARTS);
2323
2324 EVT VT = Op.getValueType();
2325 unsigned VTBits = VT.getSizeInBits();
2326 SDLoc dl(Op);
2327 SDValue ShOpLo = Op.getOperand(0);
2328 SDValue ShOpHi = Op.getOperand(1);
2329 SDValue ShAmt = Op.getOperand(2);
2330
2331 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2332 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2333 // {dHi, dLo} = {aHi, aLo} << Amt
2334 // dHi = shf.l.clamp aLo, aHi, Amt
2335 // dLo = aLo << Amt
2336
2337 SDValue Hi =
2338 DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2339 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2340
2341 SDValue Ops[2] = { Lo, Hi };
2342 return DAG.getMergeValues(Ops, dl);
2343 }
2344 else {
2345 // {dHi, dLo} = {aHi, aLo} << Amt
2346 // - if (Amt>=size) then
2347 // dLo = aLo << Amt (all 0)
 2348 // dHi = aLo << (Amt-size)
2349 // else
2350 // dLo = aLo << Amt
2351 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2352
2353 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2354 DAG.getConstant(VTBits, dl, MVT::i32),
2355 ShAmt);
2356 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
2357 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2358 DAG.getConstant(VTBits, dl, MVT::i32));
2359 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
2360 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2361 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
2362
2363 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2364 DAG.getConstant(VTBits, dl, MVT::i32),
2365 ISD::SETGE);
2366 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2367 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2368
2369 SDValue Ops[2] = { Lo, Hi };
2370 return DAG.getMergeValues(Ops, dl);
2371 }
2372}
2373
2374/// If the types match, convert the generic copysign to the NVPTXISD version,
2375/// otherwise bail, ensuring that mismatched cases are properly expanded.
2376SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2377 SelectionDAG &DAG) const {
2378 EVT VT = Op.getValueType();
2379 SDLoc DL(Op);
2380
2381 SDValue In1 = Op.getOperand(0);
2382 SDValue In2 = Op.getOperand(1);
2383 EVT SrcVT = In2.getValueType();
2384
2385 if (!SrcVT.bitsEq(VT))
2386 return SDValue();
2387
2388 return DAG.getNode(NVPTXISD::FCOPYSIGN, DL, VT, In1, In2);
2389}
2390
2391SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2392 EVT VT = Op.getValueType();
2393
2394 if (VT == MVT::f32)
2395 return LowerFROUND32(Op, DAG);
2396
2397 if (VT == MVT::f64)
2398 return LowerFROUND64(Op, DAG);
2399
2400 llvm_unreachable("unhandled type");
2401}
2402
2403// This is the rounding method used in CUDA libdevice, in C-like code:
2404// float roundf(float A)
2405// {
2406// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2407// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2408// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2409// }
2410SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2411 SelectionDAG &DAG) const {
2412 SDLoc SL(Op);
2413 SDValue A = Op.getOperand(0);
2414 EVT VT = Op.getValueType();
2415
2416 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2417
2418 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2419 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2420 const unsigned SignBitMask = 0x80000000;
2421 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2422 DAG.getConstant(SignBitMask, SL, MVT::i32));
2423 const unsigned PointFiveInBits = 0x3F000000;
2424 SDValue PointFiveWithSignRaw =
2425 DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2426 DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2427 SDValue PointFiveWithSign =
2428 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2429 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2430 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2431
2432 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2433 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2434 SDValue IsLarge =
2435 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2436 ISD::SETOGT);
2437 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2438
2439 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
 2440 SDValue IsSmall = DAG.getSetCC(SL, SetCCVT, AbsA,
2441 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2442 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2443 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2444}
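// Worked examples of the sequence above: A = 2.5f gives AdjustedA = 3.0f and
// RoundedA = 3.0f (round half away from zero); A = 0.3f is caught by the
// |A| < 0.5 check and returns trunc(0.3f) = 0.0f; values with |A| > 2^23 are
// already integral and are returned unchanged.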
2445
2446// The implementation of round(double) is similar to that of round(float) in
2447// that they both separate the value range into three regions and use a method
2448// specific to the region to round the values. However, round(double) first
2449// calculates the round of the absolute value and then adds the sign back while
2450// round(float) directly rounds the value with sign.
2451SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2452 SelectionDAG &DAG) const {
2453 SDLoc SL(Op);
2454 SDValue A = Op.getOperand(0);
2455 EVT VT = Op.getValueType();
2456
2457 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2458
2459 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2460 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2461 DAG.getConstantFP(0.5, SL, VT));
2462 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2463
2464 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2465 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
 2466 SDValue IsSmall = DAG.getSetCC(SL, SetCCVT, AbsA,
2467 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2468 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2469 DAG.getConstantFP(0, SL, VT),
2470 RoundedA);
2471
2472 // Add sign to rounded_A
2473 RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2474 DAG.getNode(ISD::FTRUNC, SL, VT, A);
2475
2476 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2477 SDValue IsLarge =
2478 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2479 ISD::SETOGT);
2480 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2481}
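// Worked example: A = -2.5 gives |A| + 0.5 = 3.0 and trunc = 3.0; |A| is not
// below 0.5, so the zero-select is skipped, the FCOPYSIGN restores the sign
// to give -3.0, and since |A| <= 2^52 that is the final result.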
2482
2484 EVT VT = N->getValueType(0);
2485 EVT NVT = MVT::f32;
2486 if (VT.isVector()) {
2487 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2488 }
2489 SDLoc DL(N);
2490 SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
2491 SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
2492 SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
2493 return DAG.getFPExtendOrRound(Res, DL, VT);
2494}
2495
2496SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2497 SelectionDAG &DAG) const {
2498 if (useF32FTZ(DAG.getMachineFunction())) {
2499 return PromoteBinOpToF32(Op.getNode(), DAG);
2500 }
2501 return Op;
2502}
2503
2504SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2505 SelectionDAG &DAG) const {
2506 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2507
2508 if (Op.getValueType() == MVT::bf16) {
2509 SDLoc Loc(Op);
2510 return DAG.getNode(
2511 ISD::FP_ROUND, Loc, MVT::bf16,
2512 DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
2513 DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
2514 }
2515
2516 // Everything else is considered legal.
2517 return Op;
2518}
2519
2520SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2521 SelectionDAG &DAG) const {
2522 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2523
2524 if (Op.getOperand(0).getValueType() == MVT::bf16) {
2525 SDLoc Loc(Op);
2526 return DAG.getNode(
2527 Op.getOpcode(), Loc, Op.getValueType(),
2528 DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
2529 }
2530
2531 // Everything else is considered legal.
2532 return Op;
2533}
2534
2535SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2536 SelectionDAG &DAG) const {
2537 EVT NarrowVT = Op.getValueType();
2538 SDValue Wide = Op.getOperand(0);
2539 EVT WideVT = Wide.getValueType();
2540 if (NarrowVT.getScalarType() == MVT::bf16) {
2541 const TargetLowering *TLI = STI.getTargetLowering();
2542 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2543 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2544 }
2545 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2546 // This combination was the first to support f32 -> bf16.
2547 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2548 if (WideVT.getScalarType() == MVT::f32) {
2549 return Op;
2550 }
2551 if (WideVT.getScalarType() == MVT::f64) {
2552 SDLoc Loc(Op);
2553 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2554 // the hardware f32 -> bf16 instruction.
2556 WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32)
2557 : MVT::f32,
2558 Wide, Loc, DAG);
2559 return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2560 }
2561 }
2562 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2563 }
2564 }
2565
2566 // Everything else is considered legal.
2567 return Op;
2568}
2569
2570SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2571 SelectionDAG &DAG) const {
2572 SDValue Narrow = Op.getOperand(0);
2573 EVT NarrowVT = Narrow.getValueType();
2574 EVT WideVT = Op.getValueType();
2575 if (NarrowVT.getScalarType() == MVT::bf16) {
2576 if (WideVT.getScalarType() == MVT::f32 &&
2577 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2578 SDLoc Loc(Op);
2579 return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2580 }
2581 if (WideVT.getScalarType() == MVT::f64 &&
2582 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2583 EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32)
2584 : MVT::f32;
2585 SDLoc Loc(Op);
2586 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2587 Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2588 } else {
2589 Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2590 }
2591 return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op);
2592 }
2593 }
2594
2595 // Everything else is considered legal.
2596 return Op;
2597}
2598
2600 SDLoc DL(Op);
2601 if (Op.getValueType() != MVT::v2i16)
2602 return Op;
2603 EVT EltVT = Op.getValueType().getVectorElementType();
2604 SmallVector<SDValue> VecElements;
2605 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2606 SmallVector<SDValue> ScalarArgs;
2607 llvm::transform(Op->ops(), std::back_inserter(ScalarArgs),
2608 [&](const SDUse &O) {
2609 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
2610 O.get(), DAG.getIntPtrConstant(I, DL));
2611 });
2612 VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs));
2613 }
2614 SDValue V =
2615 DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements);
2616 return V;
2617}
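// For example, a v2i16 `add` is scalarized here into two i16 adds on the
// extracted lanes and reassembled with BUILD_VECTOR; all other types are
// returned unchanged for the default lowering.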
2618
2619SDValue
2621 switch (Op.getOpcode()) {
2622 case ISD::RETURNADDR:
2623 return SDValue();
2624 case ISD::FRAMEADDR:
2625 return SDValue();
2626 case ISD::GlobalAddress:
2627 return LowerGlobalAddress(Op, DAG);
2629 return Op;
2630 case ISD::BUILD_VECTOR:
2631 return LowerBUILD_VECTOR(Op, DAG);
2632 case ISD::BITCAST:
2633 return LowerBITCAST(Op, DAG);
2635 return Op;
2637 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
2639 return LowerINSERT_VECTOR_ELT(Op, DAG);
2641 return LowerVECTOR_SHUFFLE(Op, DAG);
2643 return LowerCONCAT_VECTORS(Op, DAG);
2644 case ISD::STORE:
2645 return LowerSTORE(Op, DAG);
2646 case ISD::LOAD:
2647 return LowerLOAD(Op, DAG);
2648 case ISD::SHL_PARTS:
2649 return LowerShiftLeftParts(Op, DAG);
2650 case ISD::SRA_PARTS:
2651 case ISD::SRL_PARTS:
2652 return LowerShiftRightParts(Op, DAG);
2653 case ISD::SELECT:
2654 return LowerSelect(Op, DAG);
2655 case ISD::FROUND:
2656 return LowerFROUND(Op, DAG);
2657 case ISD::FCOPYSIGN:
2658 return LowerFCOPYSIGN(Op, DAG);
2659 case ISD::SINT_TO_FP:
2660 case ISD::UINT_TO_FP:
2661 return LowerINT_TO_FP(Op, DAG);
2662 case ISD::FP_TO_SINT:
2663 case ISD::FP_TO_UINT:
2664 return LowerFP_TO_INT(Op, DAG);
2665 case ISD::FP_ROUND:
2666 return LowerFP_ROUND(Op, DAG);
2667 case ISD::FP_EXTEND:
2668 return LowerFP_EXTEND(Op, DAG);
2669 case ISD::BR_JT:
2670 return LowerBR_JT(Op, DAG);
2671 case ISD::VAARG:
2672 return LowerVAARG(Op, DAG);
2673 case ISD::VASTART:
2674 return LowerVASTART(Op, DAG);
2675 case ISD::ABS:
2676 case ISD::SMIN:
2677 case ISD::SMAX:
2678 case ISD::UMIN:
2679 case ISD::UMAX:
2680 case ISD::ADD:
2681 case ISD::SUB:
2682 case ISD::MUL:
2683 case ISD::SHL:
2684 case ISD::SREM:
2685 case ISD::UREM:
2686 return LowerVectorArith(Op, DAG);
2688 return LowerDYNAMIC_STACKALLOC(Op, DAG);
2689 case ISD::STACKRESTORE:
2690 return LowerSTACKRESTORE(Op, DAG);
2691 case ISD::STACKSAVE:
2692 return LowerSTACKSAVE(Op, DAG);
2693 case ISD::CopyToReg:
2694 return LowerCopyToReg_128(Op, DAG);
2695 case ISD::FADD:
2696 case ISD::FSUB:
2697 case ISD::FMUL:
2698 // Used only for bf16 on SM80, where we select fma for non-ftz operation
2699 return PromoteBinOpIfF32FTZ(Op, DAG);
2700
2701 default:
2702 llvm_unreachable("Custom lowering not defined for operation");
2703 }
2704}
2705
2706SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
2707 SDLoc DL(Op);
2708 SDValue Chain = Op.getOperand(0);
2709 const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
2710 SDValue Index = Op.getOperand(2);
2711
2712 unsigned JId = JT->getIndex();
2714 ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
2715
2716 SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
2717
2718 // Generate BrxStart node
2719 SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
2720 Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
2721
2722 // Generate BrxItem nodes
2723 assert(!MBBs.empty());
2724 for (MachineBasicBlock *MBB : MBBs.drop_back())
2725 Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
2726 DAG.getBasicBlock(MBB), Chain.getValue(1));
2727
2728 // Generate BrxEnd nodes
2729 SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
2730 IdV, Chain.getValue(1)};
2731 SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
2732
2733 return BrxEnd;
2734}
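// For a jump table with blocks {BB0, BB1, BB2}, the sequence built above is
// BrxStart, BrxItem(BB0), BrxItem(BB1), and finally BrxEnd carrying BB2, the
// switch index, and the jump-table id, all glued together so the table can
// later be emitted as a single indexed branch (brx.idx) in PTX.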
2735
2736// This will prevent AsmPrinter from trying to print the jump tables itself.
2739}
2740
2741// This function is almost a copy of SelectionDAG::expandVAArg().
2742// The only diff is that this one produces loads from local address space.
2743SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
2744 const TargetLowering *TLI = STI.getTargetLowering();
2745 SDLoc DL(Op);
2746
2747 SDNode *Node = Op.getNode();
2748 const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
2749 EVT VT = Node->getValueType(0);
2750 auto *Ty = VT.getTypeForEVT(*DAG.getContext());
2751 SDValue Tmp1 = Node->getOperand(0);
2752 SDValue Tmp2 = Node->getOperand(1);
2753 const MaybeAlign MA(Node->getConstantOperandVal(3));
2754
2755 SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
2756 Tmp1, Tmp2, MachinePointerInfo(V));
2757 SDValue VAList = VAListLoad;
2758
2759 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
2760 VAList = DAG.getNode(
2761 ISD::ADD, DL, VAList.getValueType(), VAList,
2762 DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
2763
2764 VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList,
2765 DAG.getSignedConstant(-(int64_t)MA->value(), DL,
2766 VAList.getValueType()));
2767 }
2768
2769 // Increment the pointer, VAList, to the next vaarg
2770 Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
2772 DL, VAList.getValueType()));
2773
2774 // Store the incremented VAList to the legalized pointer
2775 Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
2777
2778 const Value *SrcV =
2780
2781 // Load the actual argument out of the pointer VAList
2782 return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
2783}
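// Example of the over-alignment fix-up above: if the declared alignment is 8
// and exceeds the minimum stack argument alignment, a va_list pointer ending
// in 0x..0c is rounded up as (p + 7) & -8 to 0x..10 before the argument load.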
2784
2785SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
2786 const TargetLowering *TLI = STI.getTargetLowering();
2787 SDLoc DL(Op);
2788 EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
2789
2790 // Store the address of unsized array <function>_vararg[] in the ap object.
2791 SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
2792 SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
2793
2794 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
2795 return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
2796 MachinePointerInfo(SV));
2797}
2798
2799SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
2800 SDValue Op0 = Op->getOperand(0);
2801 SDValue Op1 = Op->getOperand(1);
2802 SDValue Op2 = Op->getOperand(2);
2803 SDLoc DL(Op.getNode());
2804
2805 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
2806
2807 Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
2808 Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
2809 SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
2810 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
2811
2812 return Trunc;
2813}
2814
2815SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
2816 if (Op.getValueType() == MVT::i1)
2817 return LowerLOADi1(Op, DAG);
2818
2819 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
2820 // unaligned loads and have to handle it here.
2821 EVT VT = Op.getValueType();
2822 if (Isv2x16VT(VT) || VT == MVT::v4i8) {
2823 LoadSDNode *Load = cast<LoadSDNode>(Op);
2824 EVT MemVT = Load->getMemoryVT();
2826 MemVT, *Load->getMemOperand())) {
2827 SDValue Ops[2];
2828 std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
2829 return DAG.getMergeValues(Ops, SDLoc(Op));
2830 }
2831 }
2832
2833 return SDValue();
2834}
2835
2836// v = ld i1* addr
2837// =>
2838// v1 = ld i8* addr (-> i16)
2839// v = trunc i16 to i1
2840SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
2841 SDNode *Node = Op.getNode();
2842 LoadSDNode *LD = cast<LoadSDNode>(Node);
2843 SDLoc dl(Node);
2844 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
2845 assert(Node->getValueType(0) == MVT::i1 &&
2846 "Custom lowering for i1 load only");
2847 SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
2848 LD->getBasePtr(), LD->getPointerInfo(),
2849 MVT::i8, LD->getAlign(),
2850 LD->getMemOperand()->getFlags());
2851 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
2852 // The legalizer (the caller) is expecting two values from the legalized
2853 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
2854 // in LegalizeDAG.cpp which also uses MergeValues.
2855 SDValue Ops[] = { result, LD->getChain() };
2856 return DAG.getMergeValues(Ops, dl);
2857}
2858
2859SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
2860 StoreSDNode *Store = cast<StoreSDNode>(Op);
2861 EVT VT = Store->getMemoryVT();
2862
2863 if (VT == MVT::i1)
2864 return LowerSTOREi1(Op, DAG);
2865
2866 // v2f16 is legal, so we can't rely on legalizer to handle unaligned
2867 // stores and have to handle it here.
2868 if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
2870 VT, *Store->getMemOperand()))
2871 return expandUnalignedStore(Store, DAG);
2872
2873 // v2f16, v2bf16 and v2i16 don't need special handling.
2874 if (Isv2x16VT(VT) || VT == MVT::v4i8)
2875 return SDValue();
2876
2877 if (VT.isVector())
2878 return LowerSTOREVector(Op, DAG);
2879
2880 return SDValue();
2881}
2882
2883SDValue
2884NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2885 SDNode *N = Op.getNode();
2886 SDValue Val = N->getOperand(1);
2887 SDLoc DL(N);
2888 EVT ValVT = Val.getValueType();
2889
2890 auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
2891 if (!NumEltsAndEltVT)
2892 return SDValue();
2893 auto [NumElts, EltVT] = NumEltsAndEltVT.value();
2894
2895 MemSDNode *MemSD = cast<MemSDNode>(N);
2896 const DataLayout &TD = DAG.getDataLayout();
2897
2898 Align Alignment = MemSD->getAlign();
2899 Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
2900 if (Alignment < PrefAlign) {
2901 // This store is not sufficiently aligned, so bail out and let this vector
2902 // store be scalarized. Note that we may still be able to emit smaller
2903 // vector stores. For example, if we are storing a <4 x float> with an
2904 // alignment of 8, this check will fail but the legalizer will try again
2905 // with 2 x <2 x float>, which will succeed with an alignment of 8.
2906 return SDValue();
2907 }
2908
2909 // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
2910 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
2911 // stored type to i16 and propagate the "real" type as the memory type.
2912 bool NeedExt = false;
2913 if (EltVT.getSizeInBits() < 16)
2914 NeedExt = true;
2915
2916 unsigned Opcode = 0;
2917 switch (NumElts) {
2918 default:
2919 return SDValue();
2920 case 2:
2921 Opcode = NVPTXISD::StoreV2;
2922 break;
2923 case 4:
2924 Opcode = NVPTXISD::StoreV4;
2925 break;
2926 }
2927
2929
2930 // First is the chain
2931 Ops.push_back(N->getOperand(0));
2932
2933 // Then the split values
2934 assert(NumElts <= ValVT.getVectorNumElements() &&
2935 "NumElts should not increase, only decrease or stay the same.");
2936 if (NumElts < ValVT.getVectorNumElements()) {
2937 // If the number of elements has decreased, getVectorLoweringShape has
2938 // upsized the element types
2939 assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
2940 EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
2941 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2942 // stored as b32s
2943 unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
2944 for (unsigned i = 0; i < NumElts; ++i) {
2945 SmallVector<SDValue, 4> SubVectorElts;
2946 DAG.ExtractVectorElements(Val, SubVectorElts, i * NumEltsPerSubVector,
2947 NumEltsPerSubVector);
2948 SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
2949 Ops.push_back(SubVector);
2950 }
2951 } else {
2952 for (unsigned i = 0; i < NumElts; ++i) {
2953 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2954 DAG.getIntPtrConstant(i, DL));
2955 if (NeedExt)
2956 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
2957 Ops.push_back(ExtVal);
2958 }
2959 }
2960
2961 // Then any remaining arguments
2962 Ops.append(N->op_begin() + 2, N->op_end());
2963
2964 SDValue NewSt =
2965 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
2966 MemSD->getMemoryVT(), MemSD->getMemOperand());
2967
2968 // return DCI.CombineTo(N, NewSt, true);
2969 return NewSt;
2970}
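// For example, a sufficiently aligned store of v8f16 is typically emitted as
// a single StoreV4 whose four value operands are v2f16 subvectors, each of
// which is ultimately stored as one 32-bit value.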
2971
2972// st i1 v, addr
2973// =>
2974// v1 = zxt v to i16
2975// st.u8 i16, addr
2976SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
2977 SDNode *Node = Op.getNode();
2978 SDLoc dl(Node);
2979 StoreSDNode *ST = cast<StoreSDNode>(Node);
2980 SDValue Tmp1 = ST->getChain();
2981 SDValue Tmp2 = ST->getBasePtr();
2982 SDValue Tmp3 = ST->getValue();
2983 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
2984 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
2985 SDValue Result =
2986 DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8,
2987 ST->getAlign(), ST->getMemOperand()->getFlags());
2988 return Result;
2989}
2990
2991SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
2992 SelectionDAG &DAG) const {
2993 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
2994 // operand so that it can pass the legalization.
2995
2996 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
2997 "Custom lowering for 128-bit CopyToReg only");
2998
2999 SDNode *Node = Op.getNode();
3000 SDLoc DL(Node);
3001
3002 SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
3003 SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3004 DAG.getIntPtrConstant(0, DL));
3005 SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3006 DAG.getIntPtrConstant(1, DL));
3007
3009 SmallVector<EVT, 3> ResultsType(Node->values());
3010
3011 NewOps[0] = Op->getOperand(0); // Chain
3012 NewOps[1] = Op->getOperand(1); // Dst Reg
3013 NewOps[2] = Lo; // Lower 64-bit
3014 NewOps[3] = Hi; // Higher 64-bit
3015 if (Op.getNumOperands() == 4)
3016 NewOps[4] = Op->getOperand(3); // Glue if exists
3017
3018 return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
3019}
3020
3021unsigned NVPTXTargetLowering::getNumRegisters(
3022 LLVMContext &Context, EVT VT,
3023 std::optional<MVT> RegisterVT = std::nullopt) const {
3024 if (VT == MVT::i128 && RegisterVT == MVT::i128)
3025 return 1;
3026 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
3027}
3028
3029bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3030 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3031 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3032 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
3033 Parts[0] = Val;
3034 return true;
3035 }
3036 return false;
3037}
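// Together with getNumRegisters above, this keeps an i128 value as a single
// i128 part when the register type is explicitly i128, instead of letting the
// common code split it into two i64 registers.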
3038
3039// This creates a target external symbol for a function parameter.
3040// The name of the symbol is composed from its index and the function name.
3041// A negative index corresponds to the special parameter (unsized array) used
3042// for passing variable arguments.
3043SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
3044 EVT v) const {
3045 StringRef SavedStr = nvTM->getStrPool().save(
3047 return DAG.getTargetExternalSymbol(SavedStr.data(), v);
3048}
3049
3051 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
3052 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
3053 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
3055 const DataLayout &DL = DAG.getDataLayout();
3056 auto PtrVT = getPointerTy(DAG.getDataLayout());
3057
3058 const Function *F = &MF.getFunction();
3059 const AttributeList &PAL = F->getAttributes();
3060 const TargetLowering *TLI = STI.getTargetLowering();
3061
3062 SDValue Root = DAG.getRoot();
3063 std::vector<SDValue> OutChains;
3064
3065 bool isABI = (STI.getSmVersion() >= 20);
3066 assert(isABI && "Non-ABI compilation is not supported");
3067 if (!isABI)
3068 return Chain;
3069
3070 std::vector<Type *> argTypes;
3071 std::vector<const Argument *> theArgs;
3072 for (const Argument &I : F->args()) {
3073 theArgs.push_back(&I);
3074 argTypes.push_back(I.getType());
3075 }
3076 // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
3077 // Ins.size() will be larger
3078 // * if there is an aggregate argument with multiple fields (each field
3079 // showing up separately in Ins)
3080 // * if there is a vector argument with more than typical vector-length
3081 // elements (generally if more than 4) where each vector element is
3082 // individually present in Ins.
3083 // So a different index should be used for indexing into Ins.
3084 // See similar issue in LowerCall.
3085 unsigned InsIdx = 0;
3086
3087 for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
3088 Type *Ty = argTypes[i];
3089
3090 if (theArgs[i]->use_empty()) {
3091 // argument is dead
3092 if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
3093 SmallVector<EVT, 16> vtparts;
3094
3095 ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
3096 if (vtparts.empty())
3097 report_fatal_error("Empty parameter types are not supported");
3098
3099 for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
3100 ++parti) {
3101 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3102 ++InsIdx;
3103 }
3104 if (vtparts.size() > 0)
3105 --InsIdx;
3106 continue;
3107 }
3108 if (Ty->isVectorTy()) {
3109 EVT ObjectVT = getValueType(DL, Ty);
3110 unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
3111 for (unsigned parti = 0; parti < NumRegs; ++parti) {
3112 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3113 ++InsIdx;
3114 }
3115 if (NumRegs > 0)
3116 --InsIdx;
3117 continue;
3118 }
3119 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3120 continue;
3121 }
3122
3123 // In the following cases, assign a node order of "i+1"
3124 // to newly created nodes. The SDNodes for params have to
3125 // appear in the same order as their order of appearance
3126 // in the original function. "i+1" holds that order.
3127 if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
3128 bool aggregateIsPacked = false;
3129 if (StructType *STy = dyn_cast<StructType>(Ty))
3130 aggregateIsPacked = STy->isPacked();
3131
3134 ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
3135 if (VTs.empty())
3136 report_fatal_error("Empty parameter types are not supported");
3137
3140 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
3141
3142 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3143 int VecIdx = -1; // Index of the first element of the current vector.
3144 for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
3145 if (VectorInfo[parti] & PVF_FIRST) {
3146 assert(VecIdx == -1 && "Orphaned vector.");
3147 VecIdx = parti;
3148 }
3149
3150 // That's the last element of this store op.
3151 if (VectorInfo[parti] & PVF_LAST) {
3152 unsigned NumElts = parti - VecIdx + 1;
3153 EVT EltVT = VTs[parti];
3154 // i1 is loaded/stored as i8.
3155 EVT LoadVT = EltVT;
3156 if (EltVT == MVT::i1)
3157 LoadVT = MVT::i8;
3158 else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
3159 // getLoad needs a vector type, but it can't handle
3160 // vectors which contain v2f16 or v2bf16 elements. So we must load
3161 // using i32 here and then bitcast back.
3162 LoadVT = MVT::i32;
3163
3164 EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
3165 SDValue VecAddr =
3166 DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
3167 DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
3169 EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
3170
3171 const MaybeAlign PartAlign = [&]() -> MaybeAlign {
3172 if (aggregateIsPacked)
3173 return Align(1);
3174 if (NumElts != 1)
3175 return std::nullopt;
3176 Align PartAlign =
3177 DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
3178 return commonAlignment(PartAlign, Offsets[parti]);
3179 }();
3180 SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
3181 MachinePointerInfo(srcValue), PartAlign,
3184 if (P.getNode())
3185 P.getNode()->setIROrder(i + 1);
3186 for (unsigned j = 0; j < NumElts; ++j) {
3187 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
3188 DAG.getIntPtrConstant(j, dl));
3189 // We've loaded i1 as an i8 and now must truncate it back to i1
3190 if (EltVT == MVT::i1)
3191 Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
3192 // v2f16 was loaded as an i32. Now we must bitcast it back.
3193 else if (EltVT != LoadVT)
3194 Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
3195
3196 // If a promoted integer type is used, truncate down to the original
3197 MVT PromotedVT;
3198 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
3199 Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
3200 }
3201
3202 // Extend the element if necessary (e.g. an i8 is loaded
3203 // into an i16 register)
3204 if (Ins[InsIdx].VT.isInteger() &&
3205 Ins[InsIdx].VT.getFixedSizeInBits() >
3206 LoadVT.getFixedSizeInBits()) {
3207 unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
3209 Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
3210 }
3211 InVals.push_back(Elt);
3212 }
3213
3214 // Reset vector tracking state.
3215 VecIdx = -1;
3216 }
3217 ++InsIdx;
3218 }
3219 if (VTs.size() > 0)
3220 --InsIdx;
3221 continue;
3222 }
3223
3224 // Param has ByVal attribute
3225 // Return MoveParam(param symbol).
 3226 // Ideally, the param symbol could be returned directly,
 3227 // but when the SDNode builder decides to use it in a CopyToReg(),
 3228 // creating the machine instruction fails because TargetExternalSymbol
 3229 // (not lowered) is target dependent, and CopyToReg assumes
 3230 // the source is lowered.
3231 EVT ObjectVT = getValueType(DL, Ty);
3232 assert(ObjectVT == Ins[InsIdx].VT &&
3233 "Ins type did not match function type");
3234 SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3235 SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3236 if (p.getNode())
3237 p.getNode()->setIROrder(i + 1);
3238 InVals.push_back(p);
3239 }
3240
3241 if (!OutChains.empty())
3242 DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
3243
3244 return Chain;
3245}
3246
3247// Use byte-store when the param address of the return value is unaligned.
3248// This may happen when the return value is a field of a packed structure.
3250 uint64_t Offset, EVT ElementType,
3251 SDValue RetVal, const SDLoc &dl) {
3252 // Bit logic only works on integer types
3253 if (adjustElementType(ElementType))
3254 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
3255
3256 // Store each byte
3257 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
3258 // Shift the byte to the last byte position
3259 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
3260 DAG.getConstant(i * 8, dl, MVT::i32));
3261 SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
3262 ShiftVal};
 3263 // Truncating store of only the last byte using
 3264 // st.param.b8.
3265 // The register type can be larger than b8.
3267 DAG.getVTList(MVT::Other), StoreOperands,
3268 MVT::i8, MachinePointerInfo(), std::nullopt,
3270 }
3271 return Chain;
3272}
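// For example, an i32 return value whose field offset is 5 in a packed struct
// is written back as four st.param.b8 stores at offsets 5..8, using the value
// shifted right by 0, 8, 16 and 24 bits respectively.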
3273
3274SDValue
3276 bool isVarArg,
3278 const SmallVectorImpl<SDValue> &OutVals,
3279 const SDLoc &dl, SelectionDAG &DAG) const {
3280 const MachineFunction &MF = DAG.getMachineFunction();
3281 const Function &F = MF.getFunction();
3283
3284 bool isABI = (STI.getSmVersion() >= 20);
3285 assert(isABI && "Non-ABI compilation is not supported");
3286 if (!isABI)
3287 return Chain;
3288
3289 const DataLayout &DL = DAG.getDataLayout();
3290 SmallVector<SDValue, 16> PromotedOutVals;
3293 ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
3294 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
3295
3296 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3297 SDValue PromotedOutVal = OutVals[i];
3298 MVT PromotedVT;
3299 if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
3300 VTs[i] = EVT(PromotedVT);
3301 }
3302 if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) {
3304 Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
3305 PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
3306 }
3307 PromotedOutVals.push_back(PromotedOutVal);
3308 }
3309
3310 auto VectorInfo = VectorizePTXValueVTs(
3311 VTs, Offsets,
3313 : Align(1));
3314
3315 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
3316 // 32-bits are sign extended or zero extended, depending on whether
3317 // they are signed or unsigned types.
3318 bool ExtendIntegerRetVal =
3319 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
3320
3321 SmallVector<SDValue, 6> StoreOperands;
3322 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3323 SDValue OutVal = OutVals[i];
3324 SDValue RetVal = PromotedOutVals[i];
3325
3326 if (ExtendIntegerRetVal) {
3327 RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
3329 dl, MVT::i32, RetVal);
3330 } else if (OutVal.getValueSizeInBits() < 16) {
3331 // Use 16-bit registers for small load-stores as it's the
3332 // smallest general purpose register size supported by NVPTX.
3333 RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
3334 }
3335
3336 // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
3337 // for a scalar store. In such cases, fall back to byte stores.
3338 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
3339 EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3340 Align ElementTypeAlign =
3341 DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
3342 Align ElementAlign =
3343 commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
3344 if (ElementAlign < ElementTypeAlign) {
3345 assert(StoreOperands.empty() && "Orphaned operand list.");
3346 Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
3347 RetVal, dl);
3348
3349 // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3350 // into the graph, so just move on to the next element.
3351 continue;
3352 }
3353 }
3354
3355 // New load/store. Record chain and offset operands.
3356 if (VectorInfo[i] & PVF_FIRST) {
3357 assert(StoreOperands.empty() && "Orphaned operand list.");
3358 StoreOperands.push_back(Chain);
3359 StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
3360 }
3361
3362 // Record the value to return.
3363 StoreOperands.push_back(RetVal);
3364
3365 // That's the last element of this store op.
3366 if (VectorInfo[i] & PVF_LAST) {
3368 unsigned NumElts = StoreOperands.size() - 2;
3369 switch (NumElts) {
3370 case 1:
3372 break;
3373 case 2:
3375 break;
3376 case 4:
3378 break;
3379 default:
3380 llvm_unreachable("Invalid vector info.");
3381 }
3382
3383 // Adjust type of load/store op if we've extended the scalar
3384 // return value.
3385 EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3386 Chain = DAG.getMemIntrinsicNode(
3387 Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
3389 // Cleanup vector state.
3390 StoreOperands.clear();
3391 }
3392 }
3393
3394 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
3395}
3396
3398 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
3399 SelectionDAG &DAG) const {
3400 if (Constraint.size() > 1)
3401 return;
3402 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
3403}
3404
3405// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
3406// TgtMemIntrinsic because we need information that is only available in the
3407// "Value" type of the destination pointer. In particular, the address space
3408// information.
3411 IntrinsicInfo &Info, const CallInst &I,
3412 MachineFunction &MF, unsigned Intrinsic) const {
3413 switch (Intrinsic) {
3414 default:
3415 return false;
3416 case Intrinsic::nvvm_match_all_sync_i32p:
3417 case Intrinsic::nvvm_match_all_sync_i64p:
3419 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
3420 // in order to model data exchange with other threads, but perform no real
3421 // memory accesses.
3422 Info.memVT = MVT::i1;
3423
3424 // Our result depends on both our and other thread's arguments.
3426 return true;
3427 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
3428 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
3429 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
3430 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
3431 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
3432 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
3433 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
3434 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
3435 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
3436 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
3437 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
3438 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
3439 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
3440 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
3441 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
3442 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
3443 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
3444 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
3445 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
3446 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
3447 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
3448 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
3449 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
3450 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
3452 Info.memVT = MVT::v8f16;
3453 Info.ptrVal = I.getArgOperand(0);
3454 Info.offset = 0;
3456 Info.align = Align(16);
3457 return true;
3458 }
3459 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
3460 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
3461 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
3462 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
3463 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
3464 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
3465 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
3466 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
3467 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
3468 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
3469 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
3470 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
3471 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
3472 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
3473 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
3474 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
3475 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
3476 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
3477 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
3478 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
3479 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
3480 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
3481 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
3482 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
3484 Info.memVT = MVT::v2i32;
3485 Info.ptrVal = I.getArgOperand(0);
3486 Info.offset = 0;
3488 Info.align = Align(8);
3489 return true;
3490 }
3491
3492 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
3493 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
3494 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
3495 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
3496 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
3497 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
3498 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
3499 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
3500 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
3501 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
3502 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
3503 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
3504 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
3505 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
3506 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
3507 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
3508
3509 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
3510 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
3511 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
3512 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
3513 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
3514 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
3515 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
3516 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
3517 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
3518 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
3519 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
3520 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
3521 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
3522 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
3523 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
3524 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
3525 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
3526 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
3528 Info.memVT = MVT::v4i32;
3529 Info.ptrVal = I.getArgOperand(0);
3530 Info.offset = 0;
3532 Info.align = Align(16);
3533 return true;
3534 }
3535
3536 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
3537 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
3538 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
3539 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
3540 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
3541 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
3542 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
3543 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
3544
3545 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
3546 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
3547 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
3548 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
3549 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
3550 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
3551 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
3552 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
3553 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
3554 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
3555 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
3556 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
3557 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
3558 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
3559 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
3560 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
3561 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
3562 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
3563 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
3564 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
3565 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
3566 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
3568 Info.memVT = MVT::i32;
3569 Info.ptrVal = I.getArgOperand(0);
3570 Info.offset = 0;
3572 Info.align = Align(4);
3573 return true;
3574 }
3575
3576 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
3577 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
3578 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
3579 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
3580 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
3581 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
3582 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
3583 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
3584 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
3585 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
3586 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
3587 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
3589 Info.memVT = MVT::v4f16;
3590 Info.ptrVal = I.getArgOperand(0);
3591 Info.offset = 0;
3593 Info.align = Align(16);
3594 return true;
3595 }
3596
3597 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
3598 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
3599 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
3600 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
3601 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
3602 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
3603 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
3604 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
3605 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
3606 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
3607 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
3608 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
3609 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
3610 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
3611 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
3612 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
3614 Info.memVT = MVT::v8f32;
3615 Info.ptrVal = I.getArgOperand(0);
3616 Info.offset = 0;
3618 Info.align = Align(16);
3619 return true;
3620 }
3621
3622 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
3623 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
3624 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
3625 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
3626
3627 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
3628 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
3629 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
3630 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
3631
3632 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
3633 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
3634 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
3635 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
3636 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
3637 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
3638 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
3639 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
3640 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
3641 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
3642 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
3643 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
3645 Info.memVT = MVT::v8i32;
3646 Info.ptrVal = I.getArgOperand(0);
3647 Info.offset = 0;
3649 Info.align = Align(16);
3650 return true;
3651 }
3652
3653 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
3654 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
3655 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
3656 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
3657 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
3658 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
3659 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
3660 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
3661 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
3662 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
3664 Info.memVT = MVT::v2i32;
3665 Info.ptrVal = I.getArgOperand(0);
3666 Info.offset = 0;
3668 Info.align = Align(8);
3669 return true;
3670 }
3671
3672 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
3673 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
3674 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
3675 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
3676
3677 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
3678 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
3679 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
3680 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
3682 Info.memVT = MVT::f64;
3683 Info.ptrVal = I.getArgOperand(0);
3684 Info.offset = 0;
3686 Info.align = Align(8);
3687 return true;
3688 }
3689
3690 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
3691 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
3692 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
3693 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
3695 Info.memVT = MVT::v2f64;
3696 Info.ptrVal = I.getArgOperand(0);
3697 Info.offset = 0;
3699 Info.align = Align(16);
3700 return true;
3701 }
3702
3703 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
3704 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
3705 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
3706 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
3707 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
3708 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
3709 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
3710 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
3711 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
3712 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
3713 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
3714 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
3716 Info.memVT = MVT::v4f16;
3717 Info.ptrVal = I.getArgOperand(0);
3718 Info.offset = 0;
3720 Info.align = Align(16);
3721 return true;
3722 }
3723
3724 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
3725 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
3726 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
3727 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
3728 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
3729 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
3730 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
3731 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
3732 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
3733 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
3734 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
3735 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
3736 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
3737 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
3738 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
3739 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
3741 Info.memVT = MVT::v8f32;
3742 Info.ptrVal = I.getArgOperand(0);
3743 Info.offset = 0;
3745 Info.align = Align(16);
3746 return true;
3747 }
3748
3749 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
3750 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
3751 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
3752 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
3753 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
3754 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
3755 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
3756 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
3757 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
3758 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
3759 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
3760 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
3762 Info.memVT = MVT::v8i32;
3763 Info.ptrVal = I.getArgOperand(0);
3764 Info.offset = 0;
3766 Info.align = Align(16);
3767 return true;
3768 }
3769
3770 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
3771 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
3772 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
3773 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
3774 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
3775 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
3776 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
3777 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
3779 Info.memVT = MVT::v2i32;
3780 Info.ptrVal = I.getArgOperand(0);
3781 Info.offset = 0;
3783 Info.align = Align(8);
3784 return true;
3785 }
3786
3787 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
3788 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
3789 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
3790 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
3792 Info.memVT = MVT::v2f64;
3793 Info.ptrVal = I.getArgOperand(0);
3794 Info.offset = 0;
3796 Info.align = Align(16);
3797 return true;
3798 }
3799
3800 case Intrinsic::nvvm_atomic_load_inc_32:
3801 case Intrinsic::nvvm_atomic_load_dec_32:
3802
3803 case Intrinsic::nvvm_atomic_add_gen_f_cta:
3804 case Intrinsic::nvvm_atomic_add_gen_f_sys:
3805 case Intrinsic::nvvm_atomic_add_gen_i_cta:
3806 case Intrinsic::nvvm_atomic_add_gen_i_sys:
3807 case Intrinsic::nvvm_atomic_and_gen_i_cta:
3808 case Intrinsic::nvvm_atomic_and_gen_i_sys:
3809 case Intrinsic::nvvm_atomic_cas_gen_i_cta:
3810 case Intrinsic::nvvm_atomic_cas_gen_i_sys:
3811 case Intrinsic::nvvm_atomic_dec_gen_i_cta:
3812 case Intrinsic::nvvm_atomic_dec_gen_i_sys:
3813 case Intrinsic::nvvm_atomic_inc_gen_i_cta:
3814 case Intrinsic::nvvm_atomic_inc_gen_i_sys:
3815 case Intrinsic::nvvm_atomic_max_gen_i_cta:
3816 case Intrinsic::nvvm_atomic_max_gen_i_sys:
3817 case Intrinsic::nvvm_atomic_min_gen_i_cta:
3818 case Intrinsic::nvvm_atomic_min_gen_i_sys:
3819 case Intrinsic::nvvm_atomic_or_gen_i_cta:
3820 case Intrinsic::nvvm_atomic_or_gen_i_sys:
3821 case Intrinsic::nvvm_atomic_exch_gen_i_cta:
3822 case Intrinsic::nvvm_atomic_exch_gen_i_sys:
3823 case Intrinsic::nvvm_atomic_xor_gen_i_cta:
3824 case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
3825 auto &DL = I.getDataLayout();
3827 Info.memVT = getValueType(DL, I.getType());
3828 Info.ptrVal = I.getArgOperand(0);
3829 Info.offset = 0;
3831 Info.align.reset();
3832 return true;
3833 }
3834
3835 case Intrinsic::nvvm_ldu_global_i:
3836 case Intrinsic::nvvm_ldu_global_f:
3837 case Intrinsic::nvvm_ldu_global_p: {
3838 auto &DL = I.getDataLayout();
3840 if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
3841 Info.memVT = getValueType(DL, I.getType());
3842 else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
3843 Info.memVT = getPointerTy(DL);
3844 else
3845 Info.memVT = getValueType(DL, I.getType());
3846 Info.ptrVal = I.getArgOperand(0);
3847 Info.offset = 0;
3849 Info.align = cast<ConstantInt>(I.getArgOperand(1))->getMaybeAlignValue();
3850
3851 return true;
3852 }
3853 case Intrinsic::nvvm_tex_1d_v4f32_s32:
3854 case Intrinsic::nvvm_tex_1d_v4f32_f32:
3855 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
3856 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
3857 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
3858 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
3859 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
3860 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
3861 case Intrinsic::nvvm_tex_2d_v4f32_s32:
3862 case Intrinsic::nvvm_tex_2d_v4f32_f32:
3863 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
3864 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
3865 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
3866 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
3867 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
3868 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
3869 case Intrinsic::nvvm_tex_3d_v4f32_s32:
3870 case Intrinsic::nvvm_tex_3d_v4f32_f32:
3871 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
3872 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
3873 case Intrinsic::nvvm_tex_cube_v4f32_f32:
3874 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
3875 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
3876 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
3877 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
3878 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
3879 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
3880 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
3881 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
3882 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
3883 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
3884 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
3885 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
3886 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
3887 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
3888 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
3889 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
3890 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
3891 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
3892 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
3893 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
3894 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
3895 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
3896 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
3897 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
3898 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
3899 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
3900 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
3901 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
3902 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
3903 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
3904 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
3905 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
3906 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
3907 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
3908 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
3909 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
3910 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
3912 Info.memVT = MVT::v4f32;
3913 Info.ptrVal = nullptr;
3914 Info.offset = 0;
3916 Info.align = Align(16);
3917 return true;
3918
3919 case Intrinsic::nvvm_tex_1d_v4s32_s32:
3920 case Intrinsic::nvvm_tex_1d_v4s32_f32:
3921 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
3922 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
3923 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
3924 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
3925 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
3926 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
3927 case Intrinsic::nvvm_tex_2d_v4s32_s32:
3928 case Intrinsic::nvvm_tex_2d_v4s32_f32:
3929 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
3930 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
3931 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
3932 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
3933 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
3934 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
3935 case Intrinsic::nvvm_tex_3d_v4s32_s32:
3936 case Intrinsic::nvvm_tex_3d_v4s32_f32:
3937 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
3938 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
3939 case Intrinsic::nvvm_tex_cube_v4s32_f32:
3940 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
3941 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
3942 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
3943 case Intrinsic::nvvm_tex_cube_v4u32_f32:
3944 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
3945 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
3946 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
3947 case Intrinsic::nvvm_tex_1d_v4u32_s32:
3948 case Intrinsic::nvvm_tex_1d_v4u32_f32:
3949 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
3950 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
3951 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
3952 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
3953 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
3954 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
3955 case Intrinsic::nvvm_tex_2d_v4u32_s32:
3956 case Intrinsic::nvvm_tex_2d_v4u32_f32:
3957 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
3958 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
3959 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
3960 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
3961 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
3962 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
3963 case Intrinsic::nvvm_tex_3d_v4u32_s32:
3964 case Intrinsic::nvvm_tex_3d_v4u32_f32:
3965 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
3966 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
3967 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
3968 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
3969 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
3970 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
3971 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
3972 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
3973 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
3974 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
3975 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
3976 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
3977 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
3978 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
3979 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
3980 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
3981 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
3982 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
3983 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
3984 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
3985 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
3986 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
3987 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
3988 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
3989 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
3990 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
3991 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
3992 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
3993 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
3994 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
3995 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
3996 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
3997 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
3998 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
3999 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4000 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4001 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4002 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4003 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4004 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4005 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4006 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4007 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4008 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4009 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4010 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4011 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4012 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4013 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4014 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4015 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4016 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4017 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4018 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4019 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4020 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4021 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4022 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4023 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4024 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4025 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4026 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4027 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4028 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4029 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4030 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4031 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4032 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4033 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4034 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4036 Info.memVT = MVT::v4i32;
4037 Info.ptrVal = nullptr;
4038 Info.offset = 0;
4040 Info.align = Align(16);
4041 return true;
4042
4043 case Intrinsic::nvvm_suld_1d_i8_clamp:
4044 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4045 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4046 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4047 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4048 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4049 case Intrinsic::nvvm_suld_2d_i8_clamp:
4050 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4051 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4052 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4053 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4054 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4055 case Intrinsic::nvvm_suld_3d_i8_clamp:
4056 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4057 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4058 case Intrinsic::nvvm_suld_1d_i8_trap:
4059 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4060 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4061 case Intrinsic::nvvm_suld_1d_array_i8_trap:
4062 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
4063 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
4064 case Intrinsic::nvvm_suld_2d_i8_trap:
4065 case Intrinsic::nvvm_suld_2d_v2i8_trap:
4066 case Intrinsic::nvvm_suld_2d_v4i8_trap:
4067 case Intrinsic::nvvm_suld_2d_array_i8_trap:
4068 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
4069 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
4070 case Intrinsic::nvvm_suld_3d_i8_trap:
4071 case Intrinsic::nvvm_suld_3d_v2i8_trap:
4072 case Intrinsic::nvvm_suld_3d_v4i8_trap:
4073 case Intrinsic::nvvm_suld_1d_i8_zero:
4074 case Intrinsic::nvvm_suld_1d_v2i8_zero:
4075 case Intrinsic::nvvm_suld_1d_v4i8_zero:
4076 case Intrinsic::nvvm_suld_1d_array_i8_zero:
4077 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
4078 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
4079 case Intrinsic::nvvm_suld_2d_i8_zero:
4080 case Intrinsic::nvvm_suld_2d_v2i8_zero:
4081 case Intrinsic::nvvm_suld_2d_v4i8_zero:
4082 case Intrinsic::nvvm_suld_2d_array_i8_zero:
4083 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
4084 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
4085 case Intrinsic::nvvm_suld_3d_i8_zero:
4086 case Intrinsic::nvvm_suld_3d_v2i8_zero:
4087 case Intrinsic::nvvm_suld_3d_v4i8_zero:
4089 Info.memVT = MVT::i8;
4090 Info.ptrVal = nullptr;
4091 Info.offset = 0;
4093 Info.align = Align(16);
4094 return true;
4095
4096 case Intrinsic::nvvm_suld_1d_i16_clamp:
4097 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
4098 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
4099 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
4100 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
4101 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
4102 case Intrinsic::nvvm_suld_2d_i16_clamp:
4103 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
4104 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
4105 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
4106 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
4107 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
4108 case Intrinsic::nvvm_suld_3d_i16_clamp:
4109 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
4110 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
4111 case Intrinsic::nvvm_suld_1d_i16_trap:
4112 case Intrinsic::nvvm_suld_1d_v2i16_trap:
4113 case Intrinsic::nvvm_suld_1d_v4i16_trap:
4114 case Intrinsic::nvvm_suld_1d_array_i16_trap:
4115 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
4116 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
4117 case Intrinsic::nvvm_suld_2d_i16_trap:
4118 case Intrinsic::nvvm_suld_2d_v2i16_trap:
4119 case Intrinsic::nvvm_suld_2d_v4i16_trap:
4120 case Intrinsic::nvvm_suld_2d_array_i16_trap:
4121 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
4122 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
4123 case Intrinsic::nvvm_suld_3d_i16_trap:
4124 case Intrinsic::nvvm_suld_3d_v2i16_trap:
4125 case Intrinsic::nvvm_suld_3d_v4i16_trap:
4126 case Intrinsic::nvvm_suld_1d_i16_zero:
4127 case Intrinsic::nvvm_suld_1d_v2i16_zero:
4128 case Intrinsic::nvvm_suld_1d_v4i16_zero:
4129 case Intrinsic::nvvm_suld_1d_array_i16_zero:
4130 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
4131 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
4132 case Intrinsic::nvvm_suld_2d_i16_zero:
4133 case Intrinsic::nvvm_suld_2d_v2i16_zero:
4134 case Intrinsic::nvvm_suld_2d_v4i16_zero:
4135 case Intrinsic::nvvm_suld_2d_array_i16_zero:
4136 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
4137 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
4138 case Intrinsic::nvvm_suld_3d_i16_zero:
4139 case Intrinsic::nvvm_suld_3d_v2i16_zero:
4140 case Intrinsic::nvvm_suld_3d_v4i16_zero:
4142 Info.memVT = MVT::i16;
4143 Info.ptrVal = nullptr;
4144 Info.offset = 0;
4146 Info.align = Align(16);
4147 return true;
4148
4149 case Intrinsic::nvvm_suld_1d_i32_clamp:
4150 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
4151 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
4152 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
4153 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
4154 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
4155 case Intrinsic::nvvm_suld_2d_i32_clamp:
4156 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
4157 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
4158 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
4159 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
4160 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
4161 case Intrinsic::nvvm_suld_3d_i32_clamp:
4162 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
4163 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
4164 case Intrinsic::nvvm_suld_1d_i32_trap:
4165 case Intrinsic::nvvm_suld_1d_v2i32_trap:
4166 case Intrinsic::nvvm_suld_1d_v4i32_trap:
4167 case Intrinsic::nvvm_suld_1d_array_i32_trap:
4168 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
4169 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
4170 case Intrinsic::nvvm_suld_2d_i32_trap:
4171 case Intrinsic::nvvm_suld_2d_v2i32_trap:
4172 case Intrinsic::nvvm_suld_2d_v4i32_trap:
4173 case Intrinsic::nvvm_suld_2d_array_i32_trap:
4174 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
4175 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
4176 case Intrinsic::nvvm_suld_3d_i32_trap:
4177 case Intrinsic::nvvm_suld_3d_v2i32_trap:
4178 case Intrinsic::nvvm_suld_3d_v4i32_trap:
4179 case Intrinsic::nvvm_suld_1d_i32_zero:
4180 case Intrinsic::nvvm_suld_1d_v2i32_zero:
4181 case Intrinsic::nvvm_suld_1d_v4i32_zero:
4182 case Intrinsic::nvvm_suld_1d_array_i32_zero:
4183 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
4184 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
4185 case Intrinsic::nvvm_suld_2d_i32_zero:
4186 case Intrinsic::nvvm_suld_2d_v2i32_zero:
4187 case Intrinsic::nvvm_suld_2d_v4i32_zero:
4188 case Intrinsic::nvvm_suld_2d_array_i32_zero:
4189 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
4190 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
4191 case Intrinsic::nvvm_suld_3d_i32_zero:
4192 case Intrinsic::nvvm_suld_3d_v2i32_zero:
4193 case Intrinsic::nvvm_suld_3d_v4i32_zero:
4195 Info.memVT = MVT::i32;
4196 Info.ptrVal = nullptr;
4197 Info.offset = 0;
4199 Info.align = Align(16);
4200 return true;
4201
4202 case Intrinsic::nvvm_suld_1d_i64_clamp:
4203 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
4204 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
4205 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
4206 case Intrinsic::nvvm_suld_2d_i64_clamp:
4207 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
4208 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
4209 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
4210 case Intrinsic::nvvm_suld_3d_i64_clamp:
4211 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
4212 case Intrinsic::nvvm_suld_1d_i64_trap:
4213 case Intrinsic::nvvm_suld_1d_v2i64_trap:
4214 case Intrinsic::nvvm_suld_1d_array_i64_trap:
4215 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
4216 case Intrinsic::nvvm_suld_2d_i64_trap:
4217 case Intrinsic::nvvm_suld_2d_v2i64_trap:
4218 case Intrinsic::nvvm_suld_2d_array_i64_trap:
4219 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
4220 case Intrinsic::nvvm_suld_3d_i64_trap:
4221 case Intrinsic::nvvm_suld_3d_v2i64_trap:
4222 case Intrinsic::nvvm_suld_1d_i64_zero:
4223 case Intrinsic::nvvm_suld_1d_v2i64_zero:
4224 case Intrinsic::nvvm_suld_1d_array_i64_zero:
4225 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
4226 case Intrinsic::nvvm_suld_2d_i64_zero:
4227 case Intrinsic::nvvm_suld_2d_v2i64_zero:
4228 case Intrinsic::nvvm_suld_2d_array_i64_zero:
4229 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
4230 case Intrinsic::nvvm_suld_3d_i64_zero:
4231 case Intrinsic::nvvm_suld_3d_v2i64_zero:
4233 Info.memVT = MVT::i64;
4234 Info.ptrVal = nullptr;
4235 Info.offset = 0;
4237 Info.align = Align(16);
4238 return true;
4239 }
4240 return false;
4241}
4242
4243/// getFunctionParamOptimizedAlign - since function arguments are passed via
4244/// .param space, we may want to increase their alignment in a way that
4245/// ensures that we can effectively vectorize their loads & stores. We can
4246 /// increase alignment only if the function has internal or private
4247 /// linkage, as for other linkage types callers may already rely on default
4248/// alignment. To allow using 128-bit vectorized loads/stores, this function
4249/// ensures that alignment is 16 or greater.
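/// For example (an illustrative sketch, not part of the interface): an
/// aggregate argument of, say, a struct of floats has an ABI alignment of
/// only 4, but raising its .param alignment to 16 lets the backend use
/// 128-bit vectorized ld.param/st.param accesses instead of scalar ones.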
4250 Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
4251 const Function *F, Type *ArgTy, const DataLayout &DL) const {
4252 // Capping the alignment to 128 bytes as that is the maximum alignment
4253 // supported by PTX.
4254 const Align ABITypeAlign = std::min(Align(128), DL.getABITypeAlign(ArgTy));
4255
4256 // If a function has linkage different from internal or private, we
4257 // must use default ABI alignment as external users rely on it. Same
4258 // for a function that may be called from a function pointer.
4259 if (!F || !F->hasLocalLinkage() ||
4260 F->hasAddressTaken(/*Users=*/nullptr,
4261 /*IgnoreCallbackUses=*/false,
4262 /*IgnoreAssumeLikeCalls=*/true,
4263 /*IgnoreLLVMUsed=*/true))
4264 return ABITypeAlign;
4265
4266 assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
4267 return std::max(Align(16), ABITypeAlign);
4268}
4269
4270/// Helper for computing alignment of a device function byval parameter.
4271 Align NVPTXTargetLowering::getFunctionByValParamAlign(
4272 const Function *F, Type *ArgTy, Align InitialAlign,
4273 const DataLayout &DL) const {
4274 Align ArgAlign = InitialAlign;
4275 // Try to increase alignment to enhance vectorization options.
4276 if (F)
4277 ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
4278
4279 // Old ptx versions have a bug. When PTX code takes the address of a
4280 // byval parameter with alignment < 4, ptxas generates code to
4281 // spill the argument into memory. Alas, on sm_50+ ptxas generates
4282 // SASS code that fails with a misaligned access. To work around
4283 // the problem, make sure that we align byval parameters by at
4284 // least 4. This bug seems to be fixed at least starting from
4285 // ptxas > 9.0.
4286 // TODO: remove this after verifying the bug is not reproduced
4287 // on non-deprecated ptxas versions.
4288 if (ForceMinByValParamAlign)
4289 ArgAlign = std::max(ArgAlign, Align(4));
4290
4291 return ArgAlign;
4292}
4293
4294// Helper for getting a function parameter name. Name is composed from
4295// its index and the function name. Negative index corresponds to special
4296// parameter (unsized array) used for passing variable arguments.
4297 std::string NVPTXTargetLowering::getParamName(const Function *F,
4298 int Idx) const {
4299 std::string ParamName;
4300 raw_string_ostream ParamStr(ParamName);
4301
4302 ParamStr << getTargetMachine().getSymbol(F)->getName();
4303 if (Idx < 0)
4304 ParamStr << "_vararg";
4305 else
4306 ParamStr << "_param_" << Idx;
4307
4308 return ParamName;
4309}
4310
4311/// isLegalAddressingMode - Return true if the addressing mode represented
4312/// by AM is legal for this target, for a load/store of the specified type.
4313/// Used to guide target specific optimizations, like loop strength reduction
4314/// (LoopStrengthReduce.cpp) and memory optimization for address mode
4315/// (CodeGenPrepare.cpp)
4316 bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
4317 const AddrMode &AM, Type *Ty,
4318 unsigned AS, Instruction *I) const {
4319 // AddrMode - This represents an addressing mode of:
4320 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
4321 //
4322 // The legal address modes are
4323 // - [avar]
4324 // - [areg]
4325 // - [areg+immoff]
4326 // - [immAddr]
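// Roughly, in PTX terms (illustrative only):
//   ld.global.f32 %f0, [globalvar];   // [avar]
//   ld.global.f32 %f0, [%rd1];        // [areg]
//   ld.global.f32 %f0, [%rd1+16];     // [areg+immoff]
//   ld.global.f32 %f0, [42];          // [immAddr]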
4327
4328 // immoff must fit in a signed 32-bit int
4329 if (!APInt(64, AM.BaseOffs).isSignedIntN(32))
4330 return false;
4331
4332 if (AM.BaseGV)
4333 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
4334
4335 switch (AM.Scale) {
4336 case 0: // "r", "r+i" or "i" is allowed
4337 break;
4338 case 1:
4339 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
4340 return false;
4341 // Otherwise we have r+i.
4342 break;
4343 default:
4344 // No scale > 1 is allowed
4345 return false;
4346 }
4347 return true;
4348}
4349
4350//===----------------------------------------------------------------------===//
4351// NVPTX Inline Assembly Support
4352//===----------------------------------------------------------------------===//
4353
4354/// getConstraintType - Given a constraint letter, return the type of
4355/// constraint it is for this target.
4356 NVPTXTargetLowering::ConstraintType
4357 NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
4358 if (Constraint.size() == 1) {
4359 switch (Constraint[0]) {
4360 default:
4361 break;
4362 case 'b':
4363 case 'r':
4364 case 'h':
4365 case 'c':
4366 case 'l':
4367 case 'f':
4368 case 'd':
4369 case 'q':
4370 case '0':
4371 case 'N':
4372 return C_RegisterClass;
4373 }
4374 }
4375 return TargetLowering::getConstraintType(Constraint);
4376}
4377
4378std::pair<unsigned, const TargetRegisterClass *>
4379 NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
4380 StringRef Constraint,
4381 MVT VT) const {
4382 if (Constraint.size() == 1) {
4383 switch (Constraint[0]) {
4384 case 'b':
4385 return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
4386 case 'c':
4387 return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
4388 case 'h':
4389 return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
4390 case 'r':
4391 return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
4392 case 'l':
4393 case 'N':
4394 return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
4395 case 'q': {
4396 if (STI.getSmVersion() < 70)
4397 report_fatal_error("Inline asm with 128 bit operands is only "
4398 "supported for sm_70 and higher!");
4399 return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
4400 }
4401 case 'f':
4402 return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
4403 case 'd':
4404 return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
4405 }
4406 }
4407 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
4408}
4409
4410//===----------------------------------------------------------------------===//
4411// NVPTX DAG Combining
4412//===----------------------------------------------------------------------===//
4413
4414 bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
4415 CodeGenOptLevel OptLevel) const {
4416 // Always honor command-line argument
4417 if (FMAContractLevelOpt.getNumOccurrences() > 0)
4418 return FMAContractLevelOpt > 0;
4419
4420 // Do not contract if we're not optimizing the code.
4421 if (OptLevel == CodeGenOptLevel::None)
4422 return false;
4423
4424 // Honor TargetOptions flags that explicitly say fusion is okay.
4425 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
4426 return true;
4427
4428 return allowUnsafeFPMath(MF);
4429}
4430
4431 bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
4432 // Honor TargetOptions flags that explicitly say unsafe math is okay.
4433 if (MF.getTarget().Options.UnsafeFPMath)
4434 return true;
4435
4436 // Allow unsafe math if unsafe-fp-math attribute explicitly says so.
4437 const Function &F = MF.getFunction();
4438 return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
4439}
4440
4441static bool isConstZero(const SDValue &Operand) {
4442 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
4443 return Const && Const->getZExtValue() == 0;
4444}
4445
4446/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
4447/// operands N0 and N1. This is a helper for PerformADDCombine that is
4448/// called with the default operands, and if that fails, with commuted
4449/// operands.
4450static SDValue
4451 PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4452 TargetLowering::DAGCombinerInfo &DCI) {
4453 EVT VT = N0.getValueType();
4454
4455 // Since integer multiply-add costs the same as integer multiply
4456 // but is more costly than integer add, do the fusion only when
4457 // the mul is only used in the add.
4458 // TODO: this may not be true for later architectures, consider relaxing this
4459 if (!N0.getNode()->hasOneUse())
4460 return SDValue();
4461
4462 // fold (add (select cond, 0, (mul a, b)), c)
4463 // -> (select cond, c, (add (mul a, b), c))
4464 //
4465 if (N0.getOpcode() == ISD::SELECT) {
4466 unsigned ZeroOpNum;
4467 if (isConstZero(N0->getOperand(1)))
4468 ZeroOpNum = 1;
4469 else if (isConstZero(N0->getOperand(2)))
4470 ZeroOpNum = 2;
4471 else
4472 return SDValue();
4473
4474 SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
4475 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
4476 return SDValue();
4477
4478 SDLoc DL(N);
4479 SDValue Mul =
4480 DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
4481 SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
4482 return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
4483 ((ZeroOpNum == 1) ? N1 : MAD),
4484 ((ZeroOpNum == 1) ? MAD : N1));
4485 }
4486
4487 return SDValue();
4488}
4489
4490static SDValue
4491 PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4492 TargetLowering::DAGCombinerInfo &DCI,
4493 CodeGenOptLevel OptLevel) {
4494 EVT VT = N0.getValueType();
4495 if (N0.getOpcode() == ISD::FMUL) {
4496 const auto *TLI = static_cast<const NVPTXTargetLowering *>(
4497 &DCI.DAG.getTargetLoweringInfo());
4498 if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
4499 return SDValue();
4500
4501 // For floating point:
4502 // Do the fusion only when the mul has fewer than 5 uses and all
4503 // of them are adds.
4504 // The heuristic is that if a use is not an add, then that use
4505 // cannot be fused into an fma, therefore the mul is still needed anyway.
4506 // If there are more than 4 uses, even if they are all adds, fusing
4507 // them will increase register pressure.
4508 //
4509 int numUses = 0;
4510 int nonAddCount = 0;
4511 for (const SDNode *User : N0.getNode()->users()) {
4512 numUses++;
4513 if (User->getOpcode() != ISD::FADD)
4514 ++nonAddCount;
4515 if (numUses >= 5)
4516 return SDValue();
4517 }
4518 if (nonAddCount) {
4519 int orderNo = N->getIROrder();
4520 int orderNo2 = N0.getNode()->getIROrder();
4521 // Simple heuristic here for considering potential register
4522 // pressure: the logic is that the difference is used
4523 // to measure the distance between def and use; the longer the distance,
4524 // the more likely it is to cause register pressure.
4525 if (orderNo - orderNo2 < 500)
4526 return SDValue();
4527
4528 // Now, check if at least one of the FMUL's operands is live beyond the
4529 // node N, which guarantees that the FMA will not increase register
4530 // pressure at node N.
4531 bool opIsLive = false;
4532 const SDNode *left = N0.getOperand(0).getNode();
4533 const SDNode *right = N0.getOperand(1).getNode();
4534
4535 if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
4536 opIsLive = true;
4537
4538 if (!opIsLive)
4539 for (const SDNode *User : left->users()) {
4540 int orderNo3 = User->getIROrder();
4541 if (orderNo3 > orderNo) {
4542 opIsLive = true;
4543 break;
4544 }
4545 }
4546
4547 if (!opIsLive)
4548 for (const SDNode *User : right->users()) {
4549 int orderNo3 = User->getIROrder();
4550 if (orderNo3 > orderNo) {
4551 opIsLive = true;
4552 break;
4553 }
4554 }
4555
4556 if (!opIsLive)
4557 return SDValue();
4558 }
4559
4560 return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
4561 N0.getOperand(1), N1);
4562 }
4563
4564 return SDValue();
4565}
4566
4567static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
4568 std::size_t Back) {
4569 if (all_of(N->ops().drop_front(Front).drop_back(Back),
4570 [](const SDUse &U) { return U.get()->isUndef(); }))
4571 // Operand 0 is the previous value in the chain. Cannot return EntryToken
4572 // as the previous value will become unused and eliminated later.
4573 return N->getOperand(0);
4574
4575 return SDValue();
4576}
4577
4578 static SDValue PerformStoreParamCombine(SDNode *N) {
4579 // Operands from the 3rd to the 2nd last one are the values to be stored.
4580 // {Chain, ArgID, Offset, Val, Glue}
4581 return PerformStoreCombineHelper(N, 3, 1);
4582}
4583
4584 static SDValue PerformStoreRetvalCombine(SDNode *N) {
4585 // Operands from the 2nd to the last one are the values to be stored.
4586 return PerformStoreCombineHelper(N, 2, 0);
4587}
4588
4589/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
4590///
4591 static SDValue PerformADDCombine(SDNode *N,
4592 TargetLowering::DAGCombinerInfo &DCI,
4593 CodeGenOptLevel OptLevel) {
4594 if (OptLevel == CodeGenOptLevel::None)
4595 return SDValue();
4596
4597 SDValue N0 = N->getOperand(0);
4598 SDValue N1 = N->getOperand(1);
4599
4600 // Skip non-integer, non-scalar case
4601 EVT VT = N0.getValueType();
4602 if (VT.isVector() || VT != MVT::i32)
4603 return SDValue();
4604
4605 // First try with the default operand order.
4606 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
4607 return Result;
4608
4609 // If that didn't work, try again with the operands commuted.
4610 return PerformADDCombineWithOperands(N, N1, N0, DCI);
4611}
4612
4613/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
4614///
4615 static SDValue PerformFADDCombine(SDNode *N,
4616 TargetLowering::DAGCombinerInfo &DCI,
4617 CodeGenOptLevel OptLevel) {
4618 SDValue N0 = N->getOperand(0);
4619 SDValue N1 = N->getOperand(1);
4620
4621 EVT VT = N0.getValueType();
4622 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
4623 return SDValue();
4624
4625 // First try with the default operand order.
4626 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
4627 return Result;
4628
4629 // If that didn't work, try again with the operands commuted.
4630 return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
4631}
4632
4633 static SDValue PerformANDCombine(SDNode *N,
4634 TargetLowering::DAGCombinerInfo &DCI) {
4635 // The type legalizer turns a vector load of i8 values into a zextload to i16
4636 // registers, optionally ANY_EXTENDs it (if target type is integer),
4637 // and ANDs off the high 8 bits. Since we turn this load into a
4638 // target-specific DAG node, the DAG combiner fails to eliminate these AND
4639 // nodes. Do that here.
4640 SDValue Val = N->getOperand(0);
4641 SDValue Mask = N->getOperand(1);
4642
4643 if (isa<ConstantSDNode>(Val)) {
4644 std::swap(Val, Mask);
4645 }
4646
4647 SDValue AExt;
4648
4649 // Convert BFE -> truncate i16 -> and 255
4650 // to just BFE -> truncate i16, as the value already has all the bits in the
4651 // right places.
4652 if (Val.getOpcode() == ISD::TRUNCATE) {
4653 SDValue BFE = Val.getOperand(0);
4654 if (BFE.getOpcode() != NVPTXISD::BFE)
4655 return SDValue();
4656
4657 ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0));
4658 if (!BFEBits)
4659 return SDValue();
4660 uint64_t BFEBitsVal = BFEBits->getZExtValue();
4661
4662 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
4663 if (!MaskCnst) {
4664 // Not an AND with a constant
4665 return SDValue();
4666 }
4667 uint64_t MaskVal = MaskCnst->getZExtValue();
4668
4669 if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
4670 return SDValue();
4671 // If we get here, the AND is unnecessary. Just replace it with the trunc
4672 DCI.CombineTo(N, Val, false);
4673 }
4674 // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
4675 if (Val.getOpcode() == ISD::ANY_EXTEND) {
4676 AExt = Val;
4677 Val = Val->getOperand(0);
4678 }
4679
4680 if (Val->getOpcode() == NVPTXISD::LoadV2 ||
4681 Val->getOpcode() == NVPTXISD::LoadV4) {
4682 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
4683 if (!MaskCnst) {
4684 // Not an AND with a constant
4685 return SDValue();
4686 }
4687
4688 uint64_t MaskVal = MaskCnst->getZExtValue();
4689 if (MaskVal != 0xff) {
4690 // Not an AND that chops off top 8 bits
4691 return SDValue();
4692 }
4693
4694 MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
4695 if (!Mem) {
4696 // Not a MemSDNode?!?
4697 return SDValue();
4698 }
4699
4700 EVT MemVT = Mem->getMemoryVT();
4701 if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
4702 // We only handle the i8 case
4703 return SDValue();
4704 }
4705
4706 unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
4707 if (ExtType == ISD::SEXTLOAD) {
4708 // If for some reason the load is a sextload, the and is needed to zero
4709 // out the high 8 bits
4710 return SDValue();
4711 }
4712
4713 bool AddTo = false;
4714 if (AExt.getNode() != nullptr) {
4715 // Re-insert the ext as a zext.
4716 Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
4717 AExt.getValueType(), Val);
4718 AddTo = true;
4719 }
4720
4721 // If we get here, the AND is unnecessary. Just replace it with the load
4722 DCI.CombineTo(N, Val, AddTo);
4723 }
4724
4725 return SDValue();
4726}
4727
4728 static SDValue PerformREMCombine(SDNode *N,
4729 TargetLowering::DAGCombinerInfo &DCI,
4730 CodeGenOptLevel OptLevel) {
4731 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
4732
4733 // Don't do anything at less than -O2.
4734 if (OptLevel < CodeGenOptLevel::Default)
4735 return SDValue();
4736
4737 SelectionDAG &DAG = DCI.DAG;
4738 SDLoc DL(N);
4739 EVT VT = N->getValueType(0);
4740 bool IsSigned = N->getOpcode() == ISD::SREM;
4741 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
4742
4743 const SDValue &Num = N->getOperand(0);
4744 const SDValue &Den = N->getOperand(1);
4745
4746 for (const SDNode *U : Num->users()) {
4747 if (U->getOpcode() == DivOpc && U->getOperand(0) == Num &&
4748 U->getOperand(1) == Den) {
4749 // Num % Den -> Num - (Num / Den) * Den
4750 return DAG.getNode(ISD::SUB, DL, VT, Num,
4751 DAG.getNode(ISD::MUL, DL, VT,
4752 DAG.getNode(DivOpc, DL, VT, Num, Den),
4753 Den));
4754 }
4755 }
4756 return SDValue();
4757}
4758
4759 enum OperandSignedness {
4760 Signed = 0,
4761 Unsigned,
4762 Unknown
4763 };
4764
4765/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
4766/// that can be demoted to \p OptSize bits without loss of information. The
4767/// signedness of the operand, if determinable, is placed in \p S.
4768 static bool IsMulWideOperandDemotable(SDValue Op,
4769 unsigned OptSize,
4770 OperandSignedness &S) {
4771 S = Unknown;
4772
4773 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
4774 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4775 EVT OrigVT = Op.getOperand(0).getValueType();
4776 if (OrigVT.getFixedSizeInBits() <= OptSize) {
4777 S = Signed;
4778 return true;
4779 }
4780 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
4781 EVT OrigVT = Op.getOperand(0).getValueType();
4782 if (OrigVT.getFixedSizeInBits() <= OptSize) {
4783 S = Unsigned;
4784 return true;
4785 }
4786 }
4787
4788 return false;
4789}
4790
4791/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
4792/// be demoted to \p OptSize bits without loss of information. If the operands
4793/// contain a constant, it should appear as the RHS operand. The signedness of
4794/// the operands is placed in \p IsSigned.
4795 static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
4796 unsigned OptSize,
4797 bool &IsSigned) {
4798 OperandSignedness LHSSign;
4799
4800 // The LHS operand must be a demotable op
4801 if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign))
4802 return false;
4803
4804 // We should have been able to determine the signedness from the LHS
4805 if (LHSSign == Unknown)
4806 return false;
4807
4808 IsSigned = (LHSSign == Signed);
4809
4810 // The RHS can be a demotable op or a constant
4811 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(RHS)) {
4812 const APInt &Val = CI->getAPIntValue();
4813 if (LHSSign == Unsigned) {
4814 return Val.isIntN(OptSize);
4815 } else {
4816 return Val.isSignedIntN(OptSize);
4817 }
4818 } else {
4819 OperandSignedness RHSSign;
4820 if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign))
4821 return false;
4822
4823 return LHSSign == RHSSign;
4824 }
4825}
4826
4827/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
4828/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
4829/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
4830/// amount.
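/// For example (illustrative): on i32, (mul (sext i16 %a), (sext i16 %b))
/// becomes (MUL_WIDE_SIGNED %a, %b), i.e. PTX mul.wide.s16, and
/// (shl (zext i16 %a), 3) is treated as a multiply by 8 and becomes
/// (MUL_WIDE_UNSIGNED %a, 8), i.e. mul.wide.u16.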
4831 static SDValue TryMULWIDECombine(SDNode *N,
4832 TargetLowering::DAGCombinerInfo &DCI) {
4833 EVT MulType = N->getValueType(0);
4834 if (MulType != MVT::i32 && MulType != MVT::i64) {
4835 return SDValue();
4836 }
4837
4838 SDLoc DL(N);
4839 unsigned OptSize = MulType.getSizeInBits() >> 1;
4840 SDValue LHS = N->getOperand(0);
4841 SDValue RHS = N->getOperand(1);
4842
4843 // Canonicalize the multiply so the constant (if any) is on the right
4844 if (N->getOpcode() == ISD::MUL) {
4845 if (isa<ConstantSDNode>(LHS)) {
4846 std::swap(LHS, RHS);
4847 }
4848 }
4849
4850 // If we have a SHL, determine the actual multiply amount
4851 if (N->getOpcode() == ISD::SHL) {
4852 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(RHS);
4853 if (!ShlRHS) {
4854 return SDValue();
4855 }
4856
4857 APInt ShiftAmt = ShlRHS->getAPIntValue();
4858 unsigned BitWidth = MulType.getSizeInBits();
4859 if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) {
4860 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
4861 RHS = DCI.DAG.getConstant(MulVal, DL, MulType);
4862 } else {
4863 return SDValue();
4864 }
4865 }
4866
4867 bool Signed;
4868 // Verify that our operands are demotable
4869 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) {
4870 return SDValue();
4871 }
4872
4873 EVT DemotedVT;
4874 if (MulType == MVT::i32) {
4875 DemotedVT = MVT::i16;
4876 } else {
4877 DemotedVT = MVT::i32;
4878 }
4879
4880 // Truncate the operands to the correct size. Note that these are just for
4881 // type consistency and will (likely) be eliminated in later phases.
4882 SDValue TruncLHS =
4883 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, LHS);
4884 SDValue TruncRHS =
4885 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, RHS);
4886
4887 unsigned Opc;
4888 if (Signed) {
4889 Opc = NVPTXISD::MUL_WIDE_SIGNED;
4890 } else {
4891 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
4892 }
4893
4894 return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
4895}
4896
4897static bool isConstOne(const SDValue &Operand) {
4898 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
4899 return Const && Const->getZExtValue() == 1;
4900}
4901
4902 static SDValue matchMADConstOnePattern(SDValue Add) {
4903 if (Add->getOpcode() != ISD::ADD)
4904 return SDValue();
4905
4906 if (isConstOne(Add->getOperand(0)))
4907 return Add->getOperand(1);
4908
4909 if (isConstOne(Add->getOperand(1)))
4910 return Add->getOperand(0);
4911
4912 return SDValue();
4913}
4914
4915 static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
4916 TargetLowering::DAGCombinerInfo &DCI) {
4917
4918 if (SDValue Y = matchMADConstOnePattern(Add)) {
4919 SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
4920 return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X);
4921 }
4922
4923 return SDValue();
4924}
4925
4926 static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
4927 SDLoc DL,
4928 TargetLowering::DAGCombinerInfo &DCI) {
4929 if (Select->getOpcode() != ISD::SELECT)
4930 return SDValue();
4931
4932 SDValue Cond = Select->getOperand(0);
4933
4934 unsigned ConstOpNo;
4935 if (isConstOne(Select->getOperand(1)))
4936 ConstOpNo = 1;
4937 else if (isConstOne(Select->getOperand(2)))
4938 ConstOpNo = 2;
4939 else
4940 return SDValue();
4941
4942 SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
4943
4944 // Do not combine if the resulting sequence is not obviously profitable.
4945 if (!matchMADConstOnePattern(Y))
4946 return SDValue();
4947
4948 SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
4949
4950 return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond,
4951 (ConstOpNo == 1) ? X : NewMul,
4952 (ConstOpNo == 1) ? NewMul : X);
4953}
4954
4955static SDValue
4956PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4957 TargetLowering::DAGCombinerInfo &DCI) {
4958
4959 EVT VT = N0.getValueType();
4960 if (VT.isVector())
4961 return SDValue();
4962
4963 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
4964 return SDValue();
4965
4966 SDLoc DL(N);
4967
4968 // (mul x, (add y, 1)) -> (add (mul x, y), x)
4969 if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
4970 return Res;
4971 if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
4972 return Res;
4973
4974 // (mul x, (select c, 1, y)) -> (select c, x, (mul x, y))
4975 if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI))
4976 return Res;
4977 if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI))
4978 return Res;
4979
4980 return SDValue();
4981}
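// Example (sketch): the two combines above mean that
//   (mul x, (add y, 1))        becomes (add (mul x, y), x), and
//   (mul x, (select c, 1, y))  becomes (select c, x, (mul x, y)),
// so the "+ 1" / "select of 1" no longer needs to be materialized, and the
// resulting mul+add form can later be matched as a single mad.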
4982
4983/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
4984static SDValue PerformMULCombine(SDNode *N,
4985 TargetLowering::DAGCombinerInfo &DCI,
4986 CodeGenOptLevel OptLevel) {
4987 if (OptLevel == CodeGenOptLevel::None)
4988 return SDValue();
4989
4990 if (SDValue Ret = TryMULWIDECombine(N, DCI))
4991 return Ret;
4992
4993 SDValue N0 = N->getOperand(0);
4994 SDValue N1 = N->getOperand(1);
4995 return PerformMULCombineWithOperands(N, N0, N1, DCI);
4996}
4997
4998/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
4999static SDValue PerformSHLCombine(SDNode *N,
5000 TargetLowering::DAGCombinerInfo &DCI,
5001 CodeGenOptLevel OptLevel) {
5002 if (OptLevel > CodeGenOptLevel::None) {
5003 // Try mul.wide combining at OptLevel > 0
5004 if (SDValue Ret = TryMULWIDECombine(N, DCI))
5005 return Ret;
5006 }
5007
5008 return SDValue();
5009}
5010
5011static SDValue PerformSETCCCombine(SDNode *N,
5012 TargetLowering::DAGCombinerInfo &DCI,
5013 unsigned int SmVersion) {
5014 EVT CCType = N->getValueType(0);
5015 SDValue A = N->getOperand(0);
5016 SDValue B = N->getOperand(1);
5017
5018 EVT AType = A.getValueType();
5019 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
5020 return SDValue();
5021
5022 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
5023 return SDValue();
5024
5025 SDLoc DL(N);
5026 // setp.f16x2 returns two scalar predicates, which we need to
5027 // convert back to v2i1. The returned result will be scalarized by
5028 // the legalizer, but the comparison will remain a single vector
5029 // instruction.
5030 SDValue CCNode = DCI.DAG.getNode(
5031 A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
5032 : NVPTXISD::SETP_BF16X2,
5033 DL, DCI.DAG.getVTList(MVT::i1, MVT::i1), {A, B, N->getOperand(2)});
5034 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
5035 CCNode.getValue(1));
5036}
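// Example (sketch): a setcc of two v2f16 values producing v2i1 becomes one
// SETP_F16X2 node that yields two i1 results (roughly a single setp.f16x2
// instruction), and the pair is packed back into v2i1 with BUILD_VECTOR.
// The same is done for v2bf16 with SETP_BF16X2, but only on sm_90 or newer.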
5037
5038static SDValue PerformEXTRACTCombine(SDNode *N,
5039 TargetLowering::DAGCombinerInfo &DCI) {
5040 SDValue Vector = N->getOperand(0);
5041 if (Vector->getOpcode() == ISD::FREEZE)
5042 Vector = Vector->getOperand(0);
5043 SDLoc DL(N);
5044 EVT VectorVT = Vector.getValueType();
5045 if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
5046 IsPTXVectorType(VectorVT.getSimpleVT()))
5047 return SDValue(); // Native vector loads already combine nicely w/
5048 // extract_vector_elt.
5049 // Don't mess with singletons or v2*16, v4i8 and v8i8 types; we already
5050 // handle them OK.
5051 if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
5052 VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5053 return SDValue();
5054
5055 // Don't mess with undef values as sra may be simplified to 0, not undef.
5056 if (Vector->isUndef() || ISD::allOperandsUndef(Vector.getNode()))
5057 return SDValue();
5058
5059 uint64_t VectorBits = VectorVT.getSizeInBits();
5060 // We only handle the types we can extract in-register.
5061 if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
5062 return SDValue();
5063
5064 ConstantSDNode *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
5065 // Index == 0 is handled by generic DAG combiner.
5066 if (!Index || Index->getZExtValue() == 0)
5067 return SDValue();
5068
5069 MVT IVT = MVT::getIntegerVT(VectorBits);
5070 EVT EltVT = VectorVT.getVectorElementType();
5071 EVT EltIVT = EltVT.changeTypeToInteger();
5072 uint64_t EltBits = EltVT.getScalarSizeInBits();
5073
5074 SDValue Result = DCI.DAG.getNode(
5075 ISD::TRUNCATE, DL, EltIVT,
5076 DCI.DAG.getNode(
5077 ISD::SRA, DL, IVT, DCI.DAG.getNode(ISD::BITCAST, DL, IVT, Vector),
5078 DCI.DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT)));
5079
5080 // If element has non-integer type, bitcast it back to the expected type.
5081 if (EltVT != EltIVT)
5082 Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
5083 // Past the legalizer, we may need to extend i8 -> i16 to match the register type.
5084 if (EltVT != N->getValueType(0))
5085 Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);
5086
5087 return Result;
5088}
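// Example (sketch): extracting element 2 from a v4i16 value held in a 64-bit
// register becomes roughly
//   (trunc i16 (sra i64 (bitcast i64 %vec), 32))
// i.e. a shift by Index * EltBits followed by a truncate (plus a bitcast back
// for floating-point element types).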
5089
5090static SDValue PerformVSELECTCombine(SDNode *N,
5091 TargetLowering::DAGCombinerInfo &DCI) {
5092 SDValue VA = N->getOperand(1);
5093 EVT VectorVT = VA.getValueType();
5094 if (VectorVT != MVT::v4i8)
5095 return SDValue();
5096
5097 // We need to split vselect into individual per-element operations. Because
5098 // we use BFE/BFI instructions for byte extraction/insertion, we end up with
5099 // 32-bit values, so we may as well do the comparison as i32 to avoid
5100 // conversions to/from i16 normally used for i8 values.
5101 SmallVector<SDValue, 4> E;
5102 SDLoc DL(N);
5103 SDValue VCond = N->getOperand(0);
5104 SDValue VB = N->getOperand(2);
5105 for (int I = 0; I < 4; ++I) {
5106 SDValue C = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i1, VCond,
5107 DCI.DAG.getConstant(I, DL, MVT::i32));
5108 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
5109 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VA,
5110 DCI.DAG.getConstant(I, DL, MVT::i32)),
5111 DL, MVT::i32);
5112 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
5113 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VB,
5114 DCI.DAG.getConstant(I, DL, MVT::i32)),
5115 DL, MVT::i32);
5116 E.push_back(DCI.DAG.getAnyExtOrTrunc(
5117 DCI.DAG.getNode(ISD::SELECT, DL, MVT::i32, C, EA, EB), DL, MVT::i8));
5118 }
5119 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
5120}
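// Example (sketch): (vselect <4 x i1> %c, <4 x i8> %a, <4 x i8> %b) is
// rewritten as four scalar i32 selects on the individually extracted bytes,
// which are truncated back to i8 and re-packed with BUILD_VECTOR; byte
// extraction/insertion uses BFE/BFI, so the per-element work stays in 32-bit
// registers.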
5121
5122static SDValue
5123PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5124 auto VT = N->getValueType(0);
5125 if (!DCI.isAfterLegalizeDAG() || !Isv2x16VT(VT))
5126 return SDValue();
5127
5128 auto Op0 = N->getOperand(0);
5129 auto Op1 = N->getOperand(1);
5130
5131 // Start out by assuming we want to take the lower 2 bytes of each i32
5132 // operand.
5133 uint64_t Op0Bytes = 0x10;
5134 uint64_t Op1Bytes = 0x54;
5135
5136 std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
5137 {&Op1, &Op1Bytes}};
5138
5139 // Check that each operand is an i16, truncated from an i32 operand. We'll
5140 // select individual bytes from those original operands. Optionally, fold in a
5141 // shift right of that original operand.
5142 for (auto &[Op, OpBytes] : OpData) {
5143 // Eat up any bitcast
5144 if (Op->getOpcode() == ISD::BITCAST)
5145 *Op = Op->getOperand(0);
5146
5147 if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
5148 Op->getOperand(0).getValueType() == MVT::i32))
5149 return SDValue();
5150
5151 // If the truncate has multiple uses, this optimization can increase
5152 // register pressure
5153 if (!Op->hasOneUse())
5154 return SDValue();
5155
5156 *Op = Op->getOperand(0);
5157
5158 // Optionally, fold in a shift-right of the original operand and let permute
5159 // pick the two higher bytes of the original value directly.
5160 if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
5161 if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
5162 // Shift the PRMT byte selector to pick upper bytes from each respective
5163 // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
5164 assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
5165 "PRMT selector values out of range");
5166 *OpBytes += 0x22;
5167 *Op = Op->getOperand(0);
5168 }
5169 }
5170 }
5171
5172 SDLoc DL(N);
5173 auto &DAG = DCI.DAG;
5174
5175 auto PRMT = DAG.getNode(
5176 NVPTXISD::PRMT, DL, MVT::v4i8,
5177 {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
5178 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
5179 return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
5180}
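// Example (sketch): building a v2i16 from two i32 truncations,
//   (v2i16 (build_vector (trunc i16 %a), (trunc i16 %b)))
// becomes a single PRMT with byte selector 0x5410 (low two bytes of %a, low
// two bytes of %b) that is bitcast back to v2i16; if a truncate is fed by
// (srl %x, 16), the corresponding selector nibbles are bumped by 2 so the
// upper bytes are picked directly.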
5181
5182SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5183 DAGCombinerInfo &DCI) const {
5184 CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
5185 switch (N->getOpcode()) {
5186 default: break;
5187 case ISD::ADD:
5188 return PerformADDCombine(N, DCI, OptLevel);
5189 case ISD::FADD:
5190 return PerformFADDCombine(N, DCI, OptLevel);
5191 case ISD::MUL:
5192 return PerformMULCombine(N, DCI, OptLevel);
5193 case ISD::SHL:
5194 return PerformSHLCombine(N, DCI, OptLevel);
5195 case ISD::AND:
5196 return PerformANDCombine(N, DCI);
5197 case ISD::UREM:
5198 case ISD::SREM:
5199 return PerformREMCombine(N, DCI, OptLevel);
5200 case ISD::SETCC:
5201 return PerformSETCCCombine(N, DCI, STI.getSmVersion());
5202 case NVPTXISD::StoreRetval:
5203 case NVPTXISD::StoreRetvalV2:
5204 case NVPTXISD::StoreRetvalV4:
5205 return PerformStoreRetvalCombine(N);
5206 case NVPTXISD::StoreParam:
5207 case NVPTXISD::StoreParamV2:
5208 case NVPTXISD::StoreParamV4:
5209 return PerformStoreParamCombine(N);
5210 case ISD::EXTRACT_VECTOR_ELT:
5211 return PerformEXTRACTCombine(N, DCI);
5212 case ISD::VSELECT:
5213 return PerformVSELECTCombine(N, DCI);
5214 case ISD::BUILD_VECTOR:
5215 return PerformBUILD_VECTORCombine(N, DCI);
5216 }
5217 return SDValue();
5218}
5219
5220static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
5221 SmallVectorImpl<SDValue> &Results) {
5222 // Handle bitcasting to v2i8 without hitting the default promotion
5223 // strategy which goes through stack memory.
5224 SDValue Op(Node, 0);
5225 EVT ToVT = Op->getValueType(0);
5226 if (ToVT != MVT::v2i8) {
5227 return;
5228 }
5229
5230 // Bitcast to i16 and unpack elements into a vector
5231 SDLoc DL(Node);
5232 SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
5233 SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
5234 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
5235 SDValue Vec1 =
5236 DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
5237 DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
5238 Results.push_back(
5239 DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
5240}
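// Example (sketch): (v2i8 (bitcast i16 %x)) is expanded above as roughly
//   %lo = trunc i16 %x to i8
//   %hi = trunc i16 (srl i16 %x, 8) to i8
//   build_vector %lo, %hi
// so the two bytes are unpacked in registers rather than via a stack slot.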
5241
5242/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
5243static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5244 SmallVectorImpl<SDValue> &Results) {
5245 EVT ResVT = N->getValueType(0);
5246 SDLoc DL(N);
5247
5248 assert(ResVT.isVector() && "Vector load must have vector type");
5249
5250 auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
5251 if (!NumEltsAndEltVT)
5252 return;
5253 auto [NumElts, EltVT] = NumEltsAndEltVT.value();
5254
5255 LoadSDNode *LD = cast<LoadSDNode>(N);
5256
5257 Align Alignment = LD->getAlign();
5258 auto &TD = DAG.getDataLayout();
5259 Align PrefAlign =
5260 TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
5261 if (Alignment < PrefAlign) {
5262 // This load is not sufficiently aligned, so bail out and let this vector
5263 // load be scalarized. Note that we may still be able to emit smaller
5264 // vector loads. For example, if we are loading a <4 x float> with an
5265 // alignment of 8, this check will fail but the legalizer will try again
5266 // with 2 x <2 x float>, which will succeed with an alignment of 8.
5267 return;
5268 }
5269
5270 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
5271 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
5272 // loaded type to i16 and propagate the "real" type as the memory type.
5273 bool NeedTrunc = false;
5274 if (EltVT.getSizeInBits() < 16) {
5275 EltVT = MVT::i16;
5276 NeedTrunc = true;
5277 }
5278
5279 unsigned Opcode = 0;
5280 SDVTList LdResVTs;
5281
5282 switch (NumElts) {
5283 default:
5284 return;
5285 case 2:
5286 Opcode = NVPTXISD::LoadV2;
5287 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
5288 break;
5289 case 4: {
5290 Opcode = NVPTXISD::LoadV4;
5291 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
5292 LdResVTs = DAG.getVTList(ListVTs);
5293 break;
5294 }
5295 }
5296
5297 // Copy regular operands
5298 SmallVector<SDValue, 8> OtherOps(N->ops());
5299
5300 // The select routine does not have access to the LoadSDNode instance, so
5301 // pass along the extension information
5302 OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
5303
5304 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
5305 LD->getMemoryVT(),
5306 LD->getMemOperand());
5307
5308 SmallVector<SDValue> ScalarRes;
5309 assert(NumElts <= ResVT.getVectorNumElements() &&
5310 "NumElts should not increase, only decrease or stay the same.");
5311 if (NumElts < ResVT.getVectorNumElements()) {
5312 // If the number of elements has decreased, getVectorLoweringShape has
5313 // upsized the element types
5314 assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
5315 EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
5316 // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5317 // into individual elements.
5318 for (unsigned i = 0; i < NumElts; ++i) {
5319 SDValue SubVector = NewLD.getValue(i);
5320 DAG.ExtractVectorElements(SubVector, ScalarRes);
5321 }
5322 } else {
5323 for (unsigned i = 0; i < NumElts; ++i) {
5324 SDValue Res = NewLD.getValue(i);
5325 if (NeedTrunc)
5326 Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
5327 ScalarRes.push_back(Res);
5328 }
5329 }
5330
5331 SDValue LoadChain = NewLD.getValue(NumElts);
5332
5333 SDValue BuildVec = DAG.getBuildVector(ResVT, DL, ScalarRes);
5334
5335 Results.push_back(BuildVec);
5336 Results.push_back(LoadChain);
5337}
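// Example (sketch): a sufficiently aligned load of <4 x float> becomes a
// single LoadV4 with four f32 results plus a chain (roughly a ld.v4.f32 after
// selection), and the scalars are re-packed with BUILD_VECTOR. For something
// like <8 x half>, getVectorLoweringShape instead yields four v2f16
// subvectors, whose elements are extracted again to rebuild the original
// 8-element vector.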
5338
5339static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
5340 SmallVectorImpl<SDValue> &Results) {
5341 SDValue Chain = N->getOperand(0);
5342 SDValue Intrin = N->getOperand(1);
5343 SDLoc DL(N);
5344
5345 // Get the intrinsic ID
5346 unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
5347 switch (IntrinNo) {
5348 default:
5349 return;
5350 case Intrinsic::nvvm_ldu_global_i:
5351 case Intrinsic::nvvm_ldu_global_f:
5352 case Intrinsic::nvvm_ldu_global_p: {
5353 EVT ResVT = N->getValueType(0);
5354
5355 if (ResVT.isVector()) {
5356 // Vector LDG/LDU
5357
5358 unsigned NumElts = ResVT.getVectorNumElements();
5359 EVT EltVT = ResVT.getVectorElementType();
5360
5361 // Since LDU/LDG are target nodes, we cannot rely on DAG type
5362 // legalization.
5363 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
5364 // loaded type to i16 and propagate the "real" type as the memory type.
5365 bool NeedTrunc = false;
5366 if (EltVT.getSizeInBits() < 16) {
5367 EltVT = MVT::i16;
5368 NeedTrunc = true;
5369 }
5370
5371 unsigned Opcode = 0;
5372 SDVTList LdResVTs;
5373
5374 switch (NumElts) {
5375 default:
5376 return;
5377 case 2:
5378 Opcode = NVPTXISD::LDUV2;
5379 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
5380 break;
5381 case 4: {
5382 Opcode = NVPTXISD::LDUV4;
5383 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
5384 LdResVTs = DAG.getVTList(ListVTs);
5385 break;
5386 }
5387 }
5388
5389 SmallVector<SDValue, 8> OtherOps;
5390
5391 // Copy regular operands
5392
5393 OtherOps.push_back(Chain); // Chain
5394 // Skip operand 1 (intrinsic ID)
5395 // Others
5396 OtherOps.append(N->op_begin() + 2, N->op_end());
5397
5398 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
5399
5400 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
5401 MemSD->getMemoryVT(),
5402 MemSD->getMemOperand());
5403
5404 SmallVector<SDValue, 4> ScalarRes;
5405
5406 for (unsigned i = 0; i < NumElts; ++i) {
5407 SDValue Res = NewLD.getValue(i);
5408 if (NeedTrunc)
5409 Res =
5410 DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
5411 ScalarRes.push_back(Res);
5412 }
5413
5414 SDValue LoadChain = NewLD.getValue(NumElts);
5415
5416 SDValue BuildVec =
5417 DAG.getBuildVector(ResVT, DL, ScalarRes);
5418
5419 Results.push_back(BuildVec);
5420 Results.push_back(LoadChain);
5421 } else {
5422 // i8 LDG/LDU
5423 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
5424 "Custom handling of non-i8 ldu/ldg?");
5425
5426 // Just copy all operands as-is
5427 SmallVector<SDValue, 4> Ops(N->ops());
5428
5429 // Force output to i16
5430 SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
5431
5432 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
5433
5434 // We make sure the memory type is i8, which will be used during isel
5435 // to select the proper instruction.
5436 SDValue NewLD =
5437 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, Ops,
5438 MVT::i8, MemSD->getMemOperand());
5439
5440 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
5441 NewLD.getValue(0)));
5442 Results.push_back(NewLD.getValue(1));
5443 }
5444 }
5445 }
5446}
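// Example (sketch): llvm.nvvm.ldu.global.f returning <4 x float> is replaced
// by an LDUV4 node with four f32 results plus a chain (roughly
// ldu.global.v4.f32 after selection); a scalar i8 ldu/ldg result is widened
// to i16 in-register, with the memory type kept at i8, and truncated back.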
5447
5448static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
5449 SmallVectorImpl<SDValue> &Results) {
5450 // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
5451 // result so that it can pass the legalization
5452 SDLoc DL(N);
5453 SDValue Chain = N->getOperand(0);
5454 SDValue Reg = N->getOperand(1);
5455 SDValue Glue = N->getOperand(2);
5456
5457 assert(Reg.getValueType() == MVT::i128 &&
5458 "Custom lowering for CopyFromReg with 128-bit reg only");
5459 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1),
5460 N->getValueType(2)};
5461 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
5462
5463 SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps);
5464 SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
5465 {NewValue.getValue(0), NewValue.getValue(1)});
5466
5467 Results.push_back(Pair);
5468 Results.push_back(NewValue.getValue(2));
5469 Results.push_back(NewValue.getValue(3));
5470}
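// Example (sketch): a CopyFromReg of an i128 virtual register is rebuilt as a
// CopyFromReg producing two i64 values (plus chain and glue); the halves are
// recombined with BUILD_PAIR so existing users still see an i128, while type
// legalization only has to deal with 64-bit pieces.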
5471
5472void NVPTXTargetLowering::ReplaceNodeResults(
5473 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
5474 switch (N->getOpcode()) {
5475 default:
5476 report_fatal_error("Unhandled custom legalization");
5477 case ISD::BITCAST:
5478 ReplaceBITCAST(N, DAG, Results);
5479 return;
5480 case ISD::LOAD:
5481 ReplaceLoadVector(N, DAG, Results);
5482 return;
5483 case ISD::INTRINSIC_W_CHAIN:
5484 ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
5485 return;
5486 case ISD::CopyFromReg:
5487 ReplaceCopyFromReg_128(N, DAG, Results);
5488 return;
5489 }
5490}
5491
5492NVPTXTargetLowering::AtomicExpansionKind
5493NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
5494 Type *Ty = AI->getValOperand()->getType();
5495
5496 if (AI->isFloatingPointOperation()) {
5497 if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
5498 if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
5499 STI.getPTXVersion() >= 63)
5500 return AtomicExpansionKind::None;
5501 if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
5502 STI.getPTXVersion() >= 78)
5503 return AtomicExpansionKind::None;
5504 if (Ty->isFloatTy())
5505 return AtomicExpansionKind::None;
5506 if (Ty->isDoubleTy() && STI.hasAtomAddF64())
5507 return AtomicExpansionKind::None;
5508 }
5509 return AtomicExpansionKind::CmpXChg;
5510 }
5511
5512 assert(Ty->isIntegerTy() && "Ty should be integer at this point");
5513 auto ITy = cast<llvm::IntegerType>(Ty);
5514
5515 switch (AI->getOperation()) {
5516 default:
5517 return AtomicExpansionKind::CmpXChg;
5518 case AtomicRMWInst::BinOp::And:
5519 case AtomicRMWInst::BinOp::Or:
5520 case AtomicRMWInst::BinOp::Xor:
5521 case AtomicRMWInst::BinOp::Xchg:
5522 switch (ITy->getBitWidth()) {
5523 case 8:
5524 case 16:
5525 return AtomicExpansionKind::CmpXChg;
5526 case 32:
5527 return AtomicExpansionKind::None;
5528 case 64:
5529 if (STI.hasAtomBitwise64())
5530 return AtomicExpansionKind::None;
5531 return AtomicExpansionKind::CmpXChg;
5532 default:
5533 llvm_unreachable("unsupported width encountered");
5534 }
5535 case AtomicRMWInst::BinOp::Add:
5536 case AtomicRMWInst::BinOp::Sub:
5537 case AtomicRMWInst::BinOp::Max:
5538 case AtomicRMWInst::BinOp::Min:
5539 case AtomicRMWInst::BinOp::UMax:
5540 case AtomicRMWInst::BinOp::UMin:
5541 switch (ITy->getBitWidth()) {
5542 case 8:
5543 case 16:
5544 return AtomicExpansionKind::CmpXChg;
5545 case 32:
5546 return AtomicExpansionKind::None;
5547 case 64:
5548 if (STI.hasAtomMinMax64())
5549 return AtomicExpansionKind::None;
5550 return AtomicExpansionKind::CmpXChg;
5551 default:
5552 llvm_unreachable("unsupported width encountered");
5553 }
5554 }
5555
5556 return AtomicExpansionKind::CmpXChg;
5557}
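// Example (sketch): under this policy, atomicrmw fadd on float is left alone
// (it maps to a native atom.add.f32), fadd on half is native only from
// sm_70/PTX 6.3, and an i8 or i16 integer atomicrmw is expanded by the
// AtomicExpand pass into a compare-and-swap loop.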
5558
5559// Pin NVPTXTargetObjectFile's vtables to this file.
5560NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
5561
5562MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
5563 const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
5564 return getDataSection();
5565}
#define MAKE_CASE(V)
static const LLT F32
AMDGPU Register Bank Select
This file implements a class to represent arbitrary precision integral constant values and operations...
static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformADDCombineWithOperands - Try DAG combinations for an ADD with operands N0 and N1.
static SDValue PerformADDCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
static SDValue PerformVSELECTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformMULCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformFADDCombine(SDNode *N, SelectionDAG &DAG, const ARMSubtarget *Subtarget)
static SDValue PerformANDCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformBUILD_VECTORCombine - Target-specific dag combine xforms for ISD::BUILD_VECTOR.
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Function Alias Analysis Results
This file contains the simple types necessary to represent the attributes associated with functions a...
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
This file contains the declarations for the subclasses of Constant, which represent the different fla...
return RetTy
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
uint64_t Size
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
This file contains the declarations of entities that describe floating point environment and related ...
Module.h This file contains the declarations for the Module class.
static LVOptions Options
Definition: LVOptions.cpp:25
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
static DebugLoc getDebugLoc(MachineBasicBlock::instr_iterator FirstMI, MachineBasicBlock::instr_iterator LastMI)
Return the first found DebugLoc that has a DILocation, given a range of instructions.
unsigned const TargetRegisterInfo * TRI
NVPTX address space definition.
static bool shouldConvertToIndirectCall(const CallBase *CB, const GlobalAddressSDNode *Func)
static cl::opt< bool > sched4reg("nvptx-sched4reg", cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false))
static SDValue PerformEXTRACTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
static bool isConstOne(const SDValue &Operand)
static cl::opt< unsigned > FMAContractLevelOpt("nvptx-fma-level", cl::Hidden, cl::desc("NVPTX Specific: FMA contraction (0: don't do it" " 1: do it 2: do it aggressively"), cl::init(2))
static bool IsPTXVectorType(MVT VT)
static cl::opt< int > UsePrecDivF32("nvptx-prec-divf32", cl::Hidden, cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use" " IEEE Compliant F32 div.rnd if available."), cl::init(2))
static SDValue PerformStoreParamCombine(SDNode *N)
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static bool Is16bitsType(MVT VT)
static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL, TargetLowering::DAGCombinerInfo &DCI)
static bool IsTypePassedAsArray(const Type *Ty)
static SmallVector< ParamVectorizationFlags, 16 > VectorizePTXValueVTs(const SmallVectorImpl< EVT > &ValueVTs, const SmallVectorImpl< uint64_t > &Offsets, Align ParamAlignment, bool IsVAArg=false)
static unsigned CanMergeParamLoadStoresStartingAt(unsigned Idx, uint32_t AccessSize, const SmallVectorImpl< EVT > &ValueVTs, const SmallVectorImpl< uint64_t > &Offsets, Align ParamAlignment)
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static SDValue PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static bool isConstZero(const SDValue &Operand)
static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG)
static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl< EVT > &ValueVTs, SmallVectorImpl< uint64_t > *Offsets=nullptr, uint64_t StartingOffset=0)
ComputePTXValueVTs - For the given Type Ty, returns the set of primitive EVTs that compose it.
static bool IsMulWideOperandDemotable(SDValue Op, unsigned OptSize, OperandSignedness &S)
IsMulWideOperandDemotable - Checks if the provided DAG node is an operand that can be demoted to OptS...
static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain, uint64_t Offset, EVT ElementType, SDValue StVal, SDValue &InGlue, unsigned ArgID, const SDLoc &dl)
static SDValue PerformREMCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static std::optional< std::pair< unsigned int, EVT > > getVectorLoweringShape(EVT VectorVT)
static SDValue PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI)
static SDValue PerformStoreRetvalCombine(SDNode *N)
static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS, unsigned OptSize, bool &IsSigned)
AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can be demoted to OptSize bits...
static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front, std::size_t Back)
static bool adjustElementType(EVT &ElementType)
static SDValue TryMULWIDECombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply of M/2 bits that produces...
static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT, SDLoc DL, TargetLowering::DAGCombinerInfo &DCI)
static SDValue matchMADConstOnePattern(SDValue Add)
static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT, SDValue Value)
static cl::opt< bool > UsePrecSqrtF32("nvptx-prec-sqrtf32", cl::Hidden, cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."), cl::init(true))
ParamVectorizationFlags
@ PVF_FIRST
@ PVF_SCALAR
@ PVF_INNER
@ PVF_LAST
static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain, uint64_t Offset, EVT ElementType, SDValue RetVal, const SDLoc &dl)
static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG)
static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT)
PromoteScalarIntegerPTX Used to make sure the arguments/returns are suitable for passing and promote ...
OperandSignedness
static SDValue PerformSETCCCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, unsigned int SmVersion)
static SDValue LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset, EVT ElementType, SDValue &InGlue, SmallVectorImpl< SDValue > &TempProxyRegOps, const SDLoc &dl)
static std::atomic< unsigned > GlobalUniqueCallSite
static cl::opt< bool > ForceMinByValParamAlign("nvptx-force-min-byval-param-align", cl::Hidden, cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval" " params of device functions."), cl::init(false))
static SDValue PerformSHLCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
unsigned SmVersion
Definition: NVVMReflect.cpp:79
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
if(PassOpts->AAPipeline)
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file contains some templates that are useful if you are working with the STL at all.
This file defines the SmallVector class.
static bool Enabled
Definition: Statistic.cpp:46
This file describes how to lower LLVM code to machine code.
Value * RHS
Value * LHS
Class for arbitrary precision integers.
Definition: APInt.h:78
bool isSignedIntN(unsigned N) const
Check if this APInt has an N-bits signed integer value.
Definition: APInt.h:435
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1130
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:432
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
Definition: APInt.h:1237
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
const T & back() const
back - Get the last element.
Definition: ArrayRef.h:177
ArrayRef< T > drop_back(size_t N=1) const
Drop the last N elements of the array.
Definition: ArrayRef.h:213
bool empty() const
empty - Check if the array is empty.
Definition: ArrayRef.h:163
an instruction that atomically reads a memory location, combines it with another value,...
Definition: Instructions.h:704
@ Add
*p = old + v
Definition: Instructions.h:720
@ FAdd
*p = old + v
Definition: Instructions.h:741
@ Min
*p = old <signed v ? old : v
Definition: Instructions.h:734
@ Or
*p = old | v
Definition: Instructions.h:728
@ Sub
*p = old - v
Definition: Instructions.h:722
@ And
*p = old & v
Definition: Instructions.h:724
@ Xor
*p = old ^ v
Definition: Instructions.h:730
@ Max
*p = old >signed v ? old : v
Definition: Instructions.h:732
@ UMin
*p = old <unsigned v ? old : v
Definition: Instructions.h:738
@ UMax
*p = old >unsigned v ? old : v
Definition: Instructions.h:736
bool isFloatingPointOperation() const
Definition: Instructions.h:882
BinOp getOperation() const
Definition: Instructions.h:805
Value * getValOperand()
Definition: Instructions.h:874
bool hasParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const
Return true if the attribute exists for the given argument.
Definition: Attributes.h:833
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1112
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1341
FunctionType * getFunctionType() const
Definition: InstrTypes.h:1199
This class represents a function call, abstracting a target machine's calling convention.
uint64_t getZExtValue() const
const APInt & getAPIntValue() const
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:373
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
TypeSize getTypeAllocSize(Type *Ty) const
Returns the offset in bytes between successive objects of the specified type, including alignment pad...
Definition: DataLayout.h:457
Align getPrefTypeAlign(Type *Ty) const
Returns the preferred stack/global alignment for the specified type.
Definition: DataLayout.cpp:847
Diagnostic information for unsupported feature in backend.
void addFnAttr(Attribute::AttrKind Kind)
Add function attributes to this function.
Definition: Function.cpp:641
Type * getReturnType() const
Returns the type of the ret val.
Definition: Function.h:221
unsigned getAddressSpace() const
const GlobalValue * getGlobal() const
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
void diagnose(const DiagnosticInfo &DI)
Report a message to the currently installed diagnostic handler.
This class is used to represent ISD::LOAD nodes.
MCSection * getDataSection() const
Instances of this class represent a uniqued identifier for a section in the current translation unit.
Definition: MCSection.h:36
StringRef getName() const
getName - Get the symbol name.
Definition: MCSymbol.h:205
Machine Value Type.
SimpleValueType SimpleTy
unsigned getVectorNumElements() const
bool isScalableVector() const
Return true if this is a vector value type where the runtime length is machine dependent.
static auto integer_valuetypes()
static auto fixedlen_vector_valuetypes()
static MVT getVectorVT(MVT VT, unsigned NumElements)
static MVT getIntegerVT(unsigned BitWidth)
MVT getScalarType() const
If this is a vector, return the element type, otherwise return this.
DenormalMode getDenormalMode(const fltSemantics &FPType) const
Returns the denormal handling type for the default rounding mode of the function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
const MachineJumpTableInfo * getJumpTableInfo() const
getJumpTableInfo - Return the jump table info object for the current function.
const TargetMachine & getTarget() const
getTarget - Return the target machine this machine code is compiled with
@ EK_Inline
EK_Inline - Jump table entries are emitted inline at their point of use.
const std::vector< MachineJumpTableEntry > & getJumpTables() const
@ MODereferenceable
The memory access is dereferenceable (i.e., doesn't trap).
@ MOLoad
The memory access reads data.
@ MOInvariant
The memory access always returns the same value (or traps).
@ MOStore
The memory access writes data.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Register createVirtualRegister(const TargetRegisterClass *RegClass, StringRef Name="")
createVirtualRegister - Create and return a new virtual register in the function with the specified r...
This SDNode is used for target intrinsics that touch memory and need an associated MachineMemOperand.
This is an abstract virtual class for memory operations.
Align getAlign() const
MachineMemOperand * getMemOperand() const
Return a MachineMemOperand object describing the memory reference performed by operation.
EVT getMemoryVT() const
Return the type of the in-memory value.
unsigned getMaxRequiredAlignment() const
bool hasAtomMinMax64() const
bool hasAtomAddF64() const
bool hasHWROT32() const
const NVPTXTargetLowering * getTargetLowering() const override
unsigned getPTXVersion() const
bool hasNativeBF16Support(int Opcode) const
const NVPTXRegisterInfo * getRegisterInfo() const override
unsigned int getSmVersion() const
bool hasAtomBitwise64() const
bool hasBF16Math() const
bool allowFP16Math() const
bool hasAtomCas16() const
ConstraintType getConstraintType(StringRef Constraint) const override
getConstraintType - Given a constraint letter, return the type of constraint it is for this target.
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override
This callback is invoked for operations that are unsupported by the target, which are registered to u...
const NVPTXTargetMachine * nvTM
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const
NVPTXTargetLowering(const NVPTXTargetMachine &TM, const NVPTXSubtarget &STI)
bool useF32FTZ(const MachineFunction &MF) const
SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const
Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const
SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps, bool &UseOneConst, bool Reciprocal) const override
Hooks for building estimates in place of slower divisions and square roots.
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl< ISD::OutputArg > &Outs, const SmallVectorImpl< SDValue > &OutVals, const SDLoc &dl, SelectionDAG &DAG) const override
This hook must be implemented to lower outgoing return values, described by the Outs array,...
SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl< ISD::InputArg > &Ins, const SDLoc &dl, SelectionDAG &DAG, SmallVectorImpl< SDValue > &InVals) const override
This hook must be implemented to lower the incoming (formal) arguments, described by the Ins array,...
void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, std::vector< SDValue > &Ops, SelectionDAG &DAG) const override
Lower the specified operand into the Ops vector.
SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const
std::string getParamName(const Function *F, int Idx) const
TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const override
Return the preferred vector type legalization action.
std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &, const SmallVectorImpl< ISD::OutputArg > &, MaybeAlign retAlignment, std::optional< std::pair< unsigned, const APInt & > > VAInfo, const CallBase &CB, unsigned UniqueCallSite) const
Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy, const DataLayout &DL) const
getFunctionParamOptimizedAlign - since function arguments are passed via .param space,...
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const
EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx, EVT VT) const override
Return the ValueType of the result of SETCC operations.
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty, unsigned AS, Instruction *I=nullptr) const override
isLegalAddressingMode - Return true if the addressing mode represented by AM is legal for this target...
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const override
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
Align getFunctionByValParamAlign(const Function *F, Type *ArgTy, Align InitialAlign, const DataLayout &DL) const
Helper for computing alignment of a device function byval parameter.
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
const char * getTargetNodeName(unsigned Opcode) const override
This method returns the name of a target specific DAG node.
bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const
unsigned getJumpTableEncoding() const override
Return the entry encoding for a jump table in the current function.
bool allowUnsafeFPMath(MachineFunction &MF) const
SDValue LowerCall(CallLoweringInfo &CLI, SmallVectorImpl< SDValue > &InVals) const override
This hook must be implemented to lower calls into the specified DAG.
UniqueStringSaver & getStrPool() const
MCSection * SelectSectionForGlobal(const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const override
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Wrapper class for IR location info (IR ordering and DebugLoc) to be passed into SDNode creation funct...
Represents one node in the SelectionDAG.
const APInt & getAsAPIntVal() const
Helper method returns the APInt value of a ConstantSDNode.
unsigned getOpcode() const
Return the SelectionDAG opcode value for this node.
bool hasOneUse() const
Return true if there is exactly one use of this node.
unsigned getIROrder() const
Return the node ordering.
uint64_t getAsZExtVal() const
Helper method returns the zero-extended integer value of a ConstantSDNode.
unsigned getNumOperands() const
Return the number of values used by this operation.
SDVTList getVTList() const
const SDValue & getOperand(unsigned Num) const
uint64_t getConstantOperandVal(unsigned Num) const
Helper method returns the integer value of a ConstantSDNode operand.
const APInt & getConstantOperandAPInt(unsigned Num) const
Helper method returns the APInt of a ConstantSDNode operand.
EVT getValueType(unsigned ResNo) const
Return the type of a specified result.
bool isUndef() const
Return true if the type of the node type undefined.
iterator_range< user_iterator > users()
Represents a use of a SDNode.
Unlike LLVM values, Selection DAG nodes may return multiple values as the result of a computation.
SDNode * getNode() const
get the SDNode which holds the desired result
SDValue getValue(unsigned R) const
EVT getValueType() const
Return the ValueType of the referenced return value.
TypeSize getValueSizeInBits() const
Returns the size of the value in bits.
const SDValue & getOperand(unsigned i) const
MVT getSimpleValueType() const
Return the simple ValueType of the referenced return value.
unsigned getOpcode() const
SectionKind - This is a simple POD value that classifies the properties of a section.
Definition: SectionKind.h:22
This is used to represent a portion of an LLVM function in a low-level Data Dependence DAG representa...
Definition: SelectionDAG.h:228
SDValue getExtLoad(ISD::LoadExtType ExtType, const SDLoc &dl, EVT VT, SDValue Chain, SDValue Ptr, MachinePointerInfo PtrInfo, EVT MemVT, MaybeAlign Alignment=MaybeAlign(), MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
SDValue getTargetGlobalAddress(const GlobalValue *GV, const SDLoc &DL, EVT VT, int64_t offset=0, unsigned TargetFlags=0)
Definition: SelectionDAG.h:750
const SDValue & getRoot() const
Return the root tag of the SelectionDAG.
Definition: SelectionDAG.h:577
SDValue getAddrSpaceCast(const SDLoc &dl, EVT VT, SDValue Ptr, unsigned SrcAS, unsigned DestAS)
Return an AddrSpaceCastSDNode.
SDValue getCopyToReg(SDValue Chain, const SDLoc &dl, Register Reg, SDValue N)
Definition: SelectionDAG.h:801
SDValue getMergeValues(ArrayRef< SDValue > Ops, const SDLoc &dl)
Create a MERGE_VALUES node from the given operands.
SDVTList getVTList(EVT VT)
Return an SDVTList that represents the list of values specified.
void ExtractVectorElements(SDValue Op, SmallVectorImpl< SDValue > &Args, unsigned Start=0, unsigned Count=0, EVT EltVT=EVT())
Append the extracted elements from Start to Count out of the vector Op in Args.
SDValue getSetCC(const SDLoc &DL, EVT VT, SDValue LHS, SDValue RHS, ISD::CondCode Cond, SDValue Chain=SDValue(), bool IsSignaling=false)
Helper function to make it easier to build SetCC's if you just have an ISD::CondCode instead of an SD...
SDValue getSymbolFunctionGlobalAddress(SDValue Op, Function **TargetFunction=nullptr)
Return a GlobalAddress of the function from the current module with name matching the given ExternalS...
SDValue getConstantFP(double Val, const SDLoc &DL, EVT VT, bool isTarget=false)
Create a ConstantFPSDNode wrapping a constant value.
SDValue getLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, MachinePointerInfo PtrInfo, MaybeAlign Alignment=MaybeAlign(), MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes(), const MDNode *Ranges=nullptr)
Loads are not normal binary operators: their result type is not determined by their operands,...
const TargetLowering & getTargetLoweringInfo() const
Definition: SelectionDAG.h:503
SDNode * MorphNodeTo(SDNode *N, unsigned Opc, SDVTList VTs, ArrayRef< SDValue > Ops)
This mutates the specified node to have the specified return type, opcode, and operands.
SDValue getCALLSEQ_END(SDValue Chain, SDValue Op1, SDValue Op2, SDValue InGlue, const SDLoc &DL)
Return a new CALLSEQ_END node, which always must have a glue result (to ensure it's not CSE'd).
SDValue getBuildVector(EVT VT, const SDLoc &DL, ArrayRef< SDValue > Ops)
Return an ISD::BUILD_VECTOR node.
Definition: SelectionDAG.h:856
SDValue getBitcast(EVT VT, SDValue V)
Return a bitcast using the SDLoc of the value operand, and casting to the provided type.
SDValue getCopyFromReg(SDValue Chain, const SDLoc &dl, Register Reg, EVT VT)
Definition: SelectionDAG.h:827
SDValue getSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS, SDValue RHS, SDNodeFlags Flags=SDNodeFlags())
Helper function to make it easier to build Select's if you just have operands and don't want to check...
const DataLayout & getDataLayout() const
Definition: SelectionDAG.h:497
SDValue getConstant(uint64_t Val, const SDLoc &DL, EVT VT, bool isTarget=false, bool isOpaque=false)
Create a ConstantSDNode wrapping a constant value.
SDValue getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, MachinePointerInfo PtrInfo, EVT SVT, Align Alignment, MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
void ReplaceAllUsesWith(SDValue From, SDValue To)
Modify anything using 'From' to use 'To' instead.
SDValue getStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, MachinePointerInfo PtrInfo, Align Alignment, MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
Helper function to build ISD::STORE nodes.
SDValue getSignedConstant(int64_t Val, const SDLoc &DL, EVT VT, bool isTarget=false, bool isOpaque=false)
SDValue getCALLSEQ_START(SDValue Chain, uint64_t InSize, uint64_t OutSize, const SDLoc &DL)
Return a new CALLSEQ_START node, that starts new call frame, in which InSize bytes are set up inside ...
void RemoveDeadNode(SDNode *N)
Remove the specified node from the system.
SDValue getBasicBlock(MachineBasicBlock *MBB)
SDValue getAnyExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of integer type, to the integer type VT, by either any-extending or truncat...
SDValue getSelectCC(const SDLoc &DL, SDValue LHS, SDValue RHS, SDValue True, SDValue False, ISD::CondCode Cond)
Helper function to make it easier to build SelectCC's if you just have an ISD::CondCode instead of an...
SDValue getIntPtrConstant(uint64_t Val, const SDLoc &DL, bool isTarget=false)
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, ArrayRef< SDUse > Ops)
Gets or creates the specified node.
SDValue getFPExtendOrRound(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of float type, to the float type VT, by either extending or rounding (by tr...
SDValue getTargetConstant(uint64_t Val, const SDLoc &DL, EVT VT, bool isOpaque=false)
Definition: SelectionDAG.h:700
MachineFunction & getMachineFunction() const
Definition: SelectionDAG.h:492
SDValue getZExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of integer type, to the integer type VT, by either zero-extending or trunca...
LLVMContext * getContext() const
Definition: SelectionDAG.h:510
const SDValue & setRoot(SDValue N)
Set the current root tag of the SelectionDAG.
Definition: SelectionDAG.h:586
SDValue getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, SDVTList VTList, ArrayRef< SDValue > Ops, EVT MemVT, MachinePointerInfo PtrInfo, Align Alignment, MachineMemOperand::Flags Flags=MachineMemOperand::MOLoad|MachineMemOperand::MOStore, LocationSize Size=0, const AAMDNodes &AAInfo=AAMDNodes())
Creates a MemIntrinsicNode that may produce a result and takes a list of operands.
SDValue getTargetExternalSymbol(const char *Sym, EVT VT, unsigned TargetFlags=0)
SDValue getEntryNode() const
Return the token chain corresponding to the entry of the function.
Definition: SelectionDAG.h:580
This SDNode is used to implement the code generator support for the llvm IR shufflevector instruction...
ArrayRef< int > getMask() const
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void assign(size_type NumElts, ValueParamT Elt)
Definition: SmallVector.h:704
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void resize(size_type N)
Definition: SmallVector.h:638
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
This class is used to represent ISD::STORE nodes.
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
constexpr size_t size() const
size - Get the string size.
Definition: StringRef.h:150
constexpr const char * data() const
data - Get a pointer to the start of the string (which may not be null terminated).
Definition: StringRef.h:144
Class to represent struct types.
Definition: DerivedTypes.h:218
void setBooleanVectorContents(BooleanContent Ty)
Specify how the target extends the result of a vector boolean value from a vector of i1 to a wider ty...
void setOperationAction(unsigned Op, MVT VT, LegalizeAction Action)
Indicate that the specified operation does not work with the specified type and indicate what to do a...
void setMaxDivRemBitWidthSupported(unsigned SizeInBits)
Set the size in bits of the maximum div/rem the backend supports.
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
LegalizeAction
This enum indicates whether operations are valid for a target, and if not, what action should be used...
unsigned MaxStoresPerMemcpyOptSize
Likewise for functions with the OptSize attribute.
virtual const TargetRegisterClass * getRegClassFor(MVT VT, bool isDivergent=false) const
Return the register class that should be used for the specified value type.
const TargetMachine & getTargetMachine() const
LegalizeTypeAction
This enum indicates whether a types are legal for a target, and if not, what action should be used to...
void addBypassSlowDiv(unsigned int SlowBitWidth, unsigned int FastBitWidth)
Tells the code generator which bitwidths to bypass.
virtual unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const
Return the number of registers that this ValueType will eventually require.
void setMaxAtomicSizeInBitsSupported(unsigned SizeInBits)
Set the maximum atomic operation size supported by the backend.
virtual TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const
Return the preferred vector type legalization action.
unsigned MaxStoresPerMemsetOptSize
Likewise for functions with the OptSize attribute.
void setBooleanContents(BooleanContent Ty)
Specify how the target extends the result of integer and floating point boolean values from i1 to a w...
unsigned MaxStoresPerMemmove
Specify maximum number of store instructions per memmove call.
void computeRegisterProperties(const TargetRegisterInfo *TRI)
Once all of the register classes are added, this allows us to compute derived properties we expose.
unsigned MaxStoresPerMemmoveOptSize
Likewise for functions with the OptSize attribute.
void addRegisterClass(MVT VT, const TargetRegisterClass *RC)
Add the specified register class as an available regclass for the specified value type.
virtual MVT getPointerTy(const DataLayout &DL, uint32_t AS=0) const
Return the pointer type for the given address space, defaults to the pointer type from the data layou...
unsigned MaxStoresPerMemset
Specify maximum number of store instructions per memset call.
void setTruncStoreAction(MVT ValVT, MVT MemVT, LegalizeAction Action)
Indicate that the specified truncating store does not work with the specified type and indicate what ...
void setMinCmpXchgSizeInBits(unsigned SizeInBits)
Sets the minimum cmpxchg or ll/sc size supported by the backend.
void AddPromotedToType(unsigned Opc, MVT OrigVT, MVT DestVT)
If Opc/OrigVT is specified as being promoted, the promotion code defaults to trying a larger integer/...
AtomicExpansionKind
Enum that specifies what an atomic load/AtomicRMWInst is expanded to, if at all.
void setCondCodeAction(ArrayRef< ISD::CondCode > CCs, MVT VT, LegalizeAction Action)
Indicate that the specified condition code is or isn't supported on the target and indicate what to d...
void setTargetDAGCombine(ArrayRef< ISD::NodeType > NTs)
Targets should invoke this method for each target independent node that they want to provide a custom...
Align getMinStackArgumentAlignment() const
Return the minimum stack alignment of an argument.
void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT, LegalizeAction Action)
Indicate that the specified load with extension does not work with the specified type and indicate wh...
std::vector< ArgListEntry > ArgListTy
bool allowsMemoryAccessForAlignment(LLVMContext &Context, const DataLayout &DL, EVT VT, unsigned AddrSpace=0, Align Alignment=Align(1), MachineMemOperand::Flags Flags=MachineMemOperand::MONone, unsigned *Fast=nullptr) const
This function returns true if the memory access is aligned or if the target allows this specific unal...
unsigned MaxStoresPerMemcpy
Specify maximum number of store instructions per memcpy call.
void setSchedulingPreference(Sched::Preference Pref)
Specify the target scheduling preference.
void setJumpIsExpensive(bool isExpensive=true)
Tells the code generator not to expand logic operations on comparison predicates into separate sequen...
LegalizeAction getOperationAction(unsigned Op, EVT VT) const
Return how this operation should be treated: either it is legal, needs to be promoted to a larger siz...
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
SDValue expandUnalignedStore(StoreSDNode *ST, SelectionDAG &DAG) const
Expands an unaligned store to 2 half-size stores for integer values, and possibly more for vectors.
virtual ConstraintType getConstraintType(StringRef Constraint) const
Given a constraint, return the type of constraint it is for this target.
std::pair< SDValue, SDValue > expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const
Expands an unaligned load to 2 half-size loads for an integer, and possibly more for vectors.
virtual std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const
Given a physical register constraint (e.g.
SDValue expandRoundInexactToOdd(EVT ResultVT, SDValue Op, const SDLoc &DL, SelectionDAG &DAG) const
Truncate Op to ResultVT.
SDValue expandFP_ROUND(SDNode *Node, SelectionDAG &DAG) const
Expand round(fp) to fp conversion.
virtual void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, std::vector< SDValue > &Ops, SelectionDAG &DAG) const
Lower the specified operand into the Ops vector.
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:77
CodeGenOptLevel getOptLevel() const
Returns the optimization level: None, Less, Default, or Aggressive.
TargetOptions Options
MCSymbol * getSymbol(const GlobalValue *GV) const
unsigned UnsafeFPMath
UnsafeFPMath - This flag is enabled when the -enable-unsafe-fp-math flag is specified on the command ...
FPOpFusion::FPOpFusionMode AllowFPOpFusion
AllowFPOpFusion - This flag is set by the -fp-contract=xxx option.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:270
bool isFloatTy() const
Return true if this is 'float', a 32-bit IEEE fp type.
Definition: Type.h:153
bool isBFloatTy() const
Return true if this is 'bfloat', a 16-bit bfloat type.
Definition: Type.h:145
@ VoidTyID
type with no size
Definition: Type.h:63
bool isAggregateType() const
Return true if the type is an aggregate type.
Definition: Type.h:303
bool isHalfTy() const
Return true if this is 'half', a 16-bit IEEE fp type.
Definition: Type.h:142
bool isDoubleTy() const
Return true if this is 'double', a 64-bit IEEE fp type.
Definition: Type.h:156
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition: Type.h:184
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:237
TypeID getTypeID() const
Return the type id for the type.
Definition: Type.h:136
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
StringRef save(const char *S)
Definition: StringSaver.h:52
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
int getNumOccurrences() const
Definition: CommandLine.h:399
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
NodeType
ISD::NodeType enum - This enum defines the target-independent operators for a SelectionDAG.
Definition: ISDOpcodes.h:40
@ SETCC
SetCC operator - This evaluates to a true value iff the condition is true.
Definition: ISDOpcodes.h:780
@ STACKRESTORE
STACKRESTORE has two operands, an input chain and a pointer to restore to it returns an output chain.
Definition: ISDOpcodes.h:1197
@ STACKSAVE
STACKSAVE - STACKSAVE has one operand, an input chain.
Definition: ISDOpcodes.h:1193
@ SMUL_LOHI
SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing a signed/unsigned value of type i[2...
Definition: ISDOpcodes.h:257
@ BSWAP
Byte Swap and Counting operators.
Definition: ISDOpcodes.h:744
@ VAEND
VAEND, VASTART - VAEND and VASTART have three operands: an input chain, pointer, and a SRCVALUE.
Definition: ISDOpcodes.h:1226
@ ConstantFP
Definition: ISDOpcodes.h:77
@ ADDC
Carry-setting nodes for multiple precision addition and subtraction.
Definition: ISDOpcodes.h:276
@ ADD
Simple integer binary arithmetic operators.
Definition: ISDOpcodes.h:246
@ LOAD
LOAD and STORE have token chains as their first operand, then the same operands as an LLVM load/store...
Definition: ISDOpcodes.h:1102
@ ANY_EXTEND
ANY_EXTEND - Used for integer types. The high bits are undefined.
Definition: ISDOpcodes.h:814
@ FMA
FMA - Perform a * b + c with no intermediate rounding step.
Definition: ISDOpcodes.h:498
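As a rough illustration (not taken from this file), an ISD::FMA node is typically built through SelectionDAG::getNode; the helper name emitFMA below is hypothetical, and a SelectionDAG, a debug location, and three same-typed operands are assumed to be in scope.

  #include "llvm/CodeGen/SelectionDAG.h"
  using namespace llvm;

  // Hypothetical, illustrative helper (not part of NVPTXISelLowering.cpp):
  // ISD::FMA evaluates A * B + C with a single rounding step at the end.
  static SDValue emitFMA(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
                         SDValue A, SDValue B, SDValue C) {
    return DAG.getNode(ISD::FMA, DL, VT, A, B, C);
  }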
@ INTRINSIC_VOID
OUTCHAIN = INTRINSIC_VOID(INCHAIN, INTRINSICID, arg1, arg2, ...) This node represents a target intrin...
Definition: ISDOpcodes.h:205
@ GlobalAddress
Definition: ISDOpcodes.h:78
@ SINT_TO_FP
[SU]INT_TO_FP - These operators convert integers (whose interpreted sign depends on the first letter)...
Definition: ISDOpcodes.h:841
@ CONCAT_VECTORS
CONCAT_VECTORS(VECTOR0, VECTOR1, ...) - Given a number of values of vector type with the same length ...
Definition: ISDOpcodes.h:558
@ FADD
Simple binary floating point operators.
Definition: ISDOpcodes.h:397
@ ABS
ABS - Determine the unsigned absolute value of a signed integer value of the same bitwidth.
Definition: ISDOpcodes.h:717
@ SDIVREM
SDIVREM/UDIVREM - Divide two integers and produce both a quotient and remainder result.
Definition: ISDOpcodes.h:262
@ BITCAST
BITCAST - This operator converts between integer, vector and FP values, as if the value was stored to...
Definition: ISDOpcodes.h:954
@ BUILD_PAIR
BUILD_PAIR - This is the opposite of EXTRACT_ELEMENT in some ways.
Definition: ISDOpcodes.h:236
@ SIGN_EXTEND
Conversion operators.
Definition: ISDOpcodes.h:805
@ READSTEADYCOUNTER
READSTEADYCOUNTER - This corresponds to the readsteadycounter intrinsic.
Definition: ISDOpcodes.h:1259
@ FNEG
Perform various unary floating-point operations inspired by libm.
Definition: ISDOpcodes.h:981
@ BR_CC
BR_CC - Conditional branch.
Definition: ISDOpcodes.h:1148
@ SSUBO
Same for subtraction.
Definition: ISDOpcodes.h:334
@ BRIND
BRIND - Indirect branch.
Definition: ISDOpcodes.h:1123
@ BR_JT
BR_JT - Jumptable branch.
Definition: ISDOpcodes.h:1127
@ SSUBSAT
RESULT = [US]SUBSAT(LHS, RHS) - Perform saturation subtraction on 2 integers with the same bit width ...
Definition: ISDOpcodes.h:356
@ SELECT
Select(COND, TRUEVAL, FALSEVAL).
Definition: ISDOpcodes.h:757
@ UNDEF
UNDEF - An undefined node.
Definition: ISDOpcodes.h:218
@ VACOPY
VACOPY - VACOPY has 5 operands: an input chain, a destination pointer, a source pointer,...
Definition: ISDOpcodes.h:1222
@ CopyFromReg
CopyFromReg - This node indicates that the input value is a virtual or physical register that is defi...
Definition: ISDOpcodes.h:215
@ SADDO
RESULT, BOOL = [SU]ADDO(LHS, RHS) - Overflow-aware nodes for addition.
Definition: ISDOpcodes.h:330
@ MULHU
MULHU/MULHS - Multiply high - Multiply two integers of type iN, producing an unsigned/signed value of...
Definition: ISDOpcodes.h:674
@ SHL
Shift and rotation operations.
Definition: ISDOpcodes.h:735
@ VECTOR_SHUFFLE
VECTOR_SHUFFLE(VEC1, VEC2) - Returns a vector, of the same type as VEC1/VEC2.
Definition: ISDOpcodes.h:615
@ EXTRACT_SUBVECTOR
EXTRACT_SUBVECTOR(VECTOR, IDX) - Returns a subvector from VECTOR.
Definition: ISDOpcodes.h:588
@ FMINNUM_IEEE
FMINNUM_IEEE/FMAXNUM_IEEE - Perform floating-point minimumNumber or maximumNumber on two values,...
Definition: ISDOpcodes.h:1044
@ EXTRACT_VECTOR_ELT
EXTRACT_VECTOR_ELT(VECTOR, IDX) - Returns a single element from VECTOR identified by the (potentially...
Definition: ISDOpcodes.h:550
@ CopyToReg
CopyToReg - This node has three operands: a chain, a register number to set to this value,...
Definition: ISDOpcodes.h:209
@ ZERO_EXTEND
ZERO_EXTEND - Used for integer types, zeroing the new bits.
Definition: ISDOpcodes.h:811
@ DEBUGTRAP
DEBUGTRAP - Trap intended to get the attention of a debugger.
Definition: ISDOpcodes.h:1282
@ SELECT_CC
Select with condition operator - This selects between a true value and a false value (ops #2 and #3) ...
Definition: ISDOpcodes.h:772
@ FMINNUM
FMINNUM/FMAXNUM - Perform floating-point minimum or maximum on two values.
Definition: ISDOpcodes.h:1031
@ SSHLSAT
RESULT = [US]SHLSAT(LHS, RHS) - Perform saturation left shift.
Definition: ISDOpcodes.h:366
@ SMULO
Same for multiplication.
Definition: ISDOpcodes.h:338
@ DYNAMIC_STACKALLOC
DYNAMIC_STACKALLOC - Allocate some number of bytes on the stack aligned to a specified boundary.
Definition: ISDOpcodes.h:1112
@ SIGN_EXTEND_INREG
SIGN_EXTEND_INREG - This operator atomically performs a SHL/SRA pair to sign extend a small value in ...
Definition: ISDOpcodes.h:849
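A minimal sketch of how such a node is commonly formed (the function name is hypothetical): the second operand is a value-type node naming the small type whose sign bit is replicated.

  #include "llvm/CodeGen/SelectionDAG.h"
  using namespace llvm;

  // Hypothetical helper: treat the low 8 bits of an i32 value as signed and
  // extend them across the full 32-bit register.
  static SDValue signExtendLowI8(SelectionDAG &DAG, const SDLoc &DL,
                                 SDValue V) {
    return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i32, V,
                       DAG.getValueType(MVT::i8));
  }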
@ SMIN
[US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned integers.
Definition: ISDOpcodes.h:697
@ FP_EXTEND
X = FP_EXTEND(Y) - Extend a smaller FP type into a larger FP type.
Definition: ISDOpcodes.h:939
@ VSELECT
Select with a vector condition (op #0) and two vector operands (ops #1 and #2), returning a vector re...
Definition: ISDOpcodes.h:766
@ UADDO_CARRY
Carry-using nodes for multiple precision addition and subtraction.
Definition: ISDOpcodes.h:310
@ BF16_TO_FP
BF16_TO_FP, FP_TO_BF16 - These operators are used to perform promotions and truncation for bfloat16.
Definition: ISDOpcodes.h:973
@ FRAMEADDR
FRAMEADDR, RETURNADDR - These nodes represent llvm.frameaddress and llvm.returnaddress on the DAG.
Definition: ISDOpcodes.h:100
@ FMINIMUM
FMINIMUM/FMAXIMUM - NaN-propagating minimum/maximum that also treat -0.0 as less than 0....
Definition: ISDOpcodes.h:1050
@ FP_TO_SINT
FP_TO_[US]INT - Convert a floating point value to a signed or unsigned integer.
Definition: ISDOpcodes.h:887
@ READCYCLECOUNTER
READCYCLECOUNTER - This corresponds to the readcyclecounter intrinsic.
Definition: ISDOpcodes.h:1253
@ AND
Bitwise operators - logical and, logical or, logical xor.
Definition: ISDOpcodes.h:709
@ TRAP
TRAP - Trapping instruction.
Definition: ISDOpcodes.h:1279
@ INTRINSIC_WO_CHAIN
RESULT = INTRINSIC_WO_CHAIN(INTRINSICID, arg1, arg2, ...) This node represents a target intrinsic fun...
Definition: ISDOpcodes.h:190
@ ADDE
Carry-using nodes for multiple precision addition and subtraction.
Definition: ISDOpcodes.h:286
@ FREEZE
FREEZE - FREEZE(VAL) returns an arbitrary value if VAL is UNDEF (or is evaluated to UNDEF),...
Definition: ISDOpcodes.h:223
@ INSERT_VECTOR_ELT
INSERT_VECTOR_ELT(VECTOR, VAL, IDX) - Returns VECTOR with the element at IDX replaced with VAL.
Definition: ISDOpcodes.h:539
@ TokenFactor
TokenFactor - This node takes multiple tokens as input and produces a single token result.
Definition: ISDOpcodes.h:52
@ FP_ROUND
X = FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating point type down to the precision of the ...
Definition: ISDOpcodes.h:920
@ TRUNCATE
TRUNCATE - Completely drop the high bits.
Definition: ISDOpcodes.h:817
@ VAARG
VAARG - VAARG has four operands: an input chain, a pointer, a SRCVALUE, and the alignment.
Definition: ISDOpcodes.h:1217
@ SHL_PARTS
SHL_PARTS/SRA_PARTS/SRL_PARTS - These operators are used for expanded integer shift operations.
Definition: ISDOpcodes.h:794
@ FCOPYSIGN
FCOPYSIGN(X, Y) - Return the value of X with the sign of Y.
Definition: ISDOpcodes.h:508
@ SADDSAT
RESULT = [US]ADDSAT(LHS, RHS) - Perform saturation addition on 2 integers with the same bit width (W)...
Definition: ISDOpcodes.h:347
@ SADDO_CARRY
Carry-using overflow-aware nodes for multiple precision addition and subtraction.
Definition: ISDOpcodes.h:320
@ INTRINSIC_W_CHAIN
RESULT,OUTCHAIN = INTRINSIC_W_CHAIN(INCHAIN, INTRINSICID, arg1, ...) This node represents a target in...
Definition: ISDOpcodes.h:198
@ BUILD_VECTOR
BUILD_VECTOR(ELT0, ELT1, ELT2, ELT3,...) - Return a fixed-width vector with the specified,...
Definition: ISDOpcodes.h:530
bool allOperandsUndef(const SDNode *N)
Return true if the node has at least one operand and all operands of the specified node are ISD::UNDE...
@ Bitcast
Perform the operation on a different, but equivalently sized type.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:480
static bool isIndirectCall(const MachineInstr &MI)
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
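A short, hypothetical usage sketch of the range-based wrapper (the function below is illustrative, not part of this file):

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  using namespace llvm;

  // Hypothetical helper: true when every element satisfies the predicate.
  static bool allNonNegative(const SmallVectorImpl<int> &Vals) {
    return all_of(Vals, [](int V) { return V >= 0; });
  }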
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1697
bool Isv2x16VT(EVT VT)
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition: STLExtras.h:2448
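A small hypothetical example of enumerate, pairing each element with its index:

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  #include "llvm/Support/raw_ostream.h"
  using namespace llvm;

  // Hypothetical helper: prints "index: value" without a manual counter.
  static void printIndexed(const SmallVectorImpl<int> &Vals) {
    for (const auto &En : enumerate(Vals))
      errs() << En.index() << ": " << En.value() << "\n";
  }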
MaybeAlign getAlign(const Function &F, unsigned Index)
uint64_t PowerOf2Ceil(uint64_t A)
Returns the power of two which is greater than or equal to the given value.
Definition: MathExtras.h:396
OutputIt transform(R &&Range, OutputIt d_first, UnaryFunction F)
Wrapper function around std::transform to apply a function to a range and store the result elsewhere.
Definition: STLExtras.h:1952
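A hypothetical sketch of llvm::transform writing results through an output iterator:

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  using namespace llvm;

  // Hypothetical helper: doubles every element of the input range.
  static SmallVector<int, 8> doubled(const SmallVectorImpl<int> &Vals) {
    SmallVector<int, 8> Out(Vals.size());
    transform(Vals, Out.begin(), [](int V) { return V * 2; });
    return Out;
  }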
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition: MathExtras.h:293
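As a quick illustration of the two MathExtras helpers listed above (PowerOf2Ceil and isPowerOf2_32), a hypothetical snippet with the expected results:

  #include "llvm/Support/MathExtras.h"
  #include <cassert>
  using namespace llvm;

  // Hypothetical, illustrative function (not part of this file).
  static void mathExtrasExamples() {
    assert(PowerOf2Ceil(20) == 32);  // round up to the next power of two
    assert(PowerOf2Ceil(16) == 16);  // already a power of two: unchanged
    assert(isPowerOf2_32(64));       // 64 == 2^6
    assert(!isPowerOf2_32(0));       // zero is not a power of two > 0
  }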
unsigned promoteScalarArgumentSize(unsigned size)
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:167
CodeGenOptLevel
Code generation optimization level.
Definition: CodeGen.h:54
@ Mul
Product of integers.
@ Add
Sum of integers.
uint64_t alignTo(uint64_t Size, Align A)
Returns a multiple of A needed to store Size bytes.
Definition: Alignment.h:155
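A hypothetical example of alignTo rounding a byte count up to an alignment boundary:

  #include "llvm/Support/Alignment.h"
  #include <cassert>
  using namespace llvm;

  // Hypothetical, illustrative function (not part of this file).
  static void alignToExample() {
    assert(alignTo(/*Size=*/10, Align(8)) == 16); // next multiple of 8
    assert(alignTo(/*Size=*/16, Align(8)) == 16); // already aligned
  }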
DWARFExpression::Operation Op
void ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl< EVT > &ValueVTs, SmallVectorImpl< EVT > *MemVTs, SmallVectorImpl< TypeSize > *Offsets=nullptr, TypeSize StartingOffset=TypeSize::getZero())
ComputeValueVTs - Given an LLVM IR type, compute a sequence of EVTs that represent all the individual...
Definition: Analysis.cpp:79
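A hedged sketch of how ComputeValueVTs is typically called; the helper name is hypothetical, and a TargetLowering and DataLayout are assumed to be available, as they are inside most lowering hooks.

  #include "llvm/ADT/SmallVector.h"
  #include "llvm/CodeGen/Analysis.h"
  #include "llvm/CodeGen/TargetLowering.h"
  using namespace llvm;

  // Hypothetical helper: counts how many scalar EVTs an IR type decomposes
  // into when lowered.
  static unsigned countLoweredValues(const TargetLowering &TLI,
                                     const DataLayout &DL, Type *Ty) {
    SmallVector<EVT, 16> ValueVTs;
    ComputeValueVTs(TLI, DL, Ty, ValueVTs); // one entry per scalar piece of Ty
    return ValueVTs.size();
  }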
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:217
bool isKernelFunction(const Function &F)
Function * getMaybeBitcastedCallee(const CallBase *CB)
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:212
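A hypothetical example of commonAlignment, which reports the alignment still guaranteed after adding an offset to an aligned base:

  #include "llvm/Support/Alignment.h"
  #include <cassert>
  using namespace llvm;

  // Hypothetical, illustrative function (not part of this file).
  static void commonAlignmentExample() {
    assert(commonAlignment(Align(16), 4) == Align(4));   // offset 4 drops it to 4
    assert(commonAlignment(Align(16), 32) == Align(16)); // multiples keep the base alignment
  }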
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
static const fltSemantics & IEEEsingle() LLVM_READNONE
Definition: APFloat.cpp:257
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
uint64_t value() const
This is a hole in the type system and should not be abused.
Definition: Alignment.h:85
@ PreserveSign
The sign of a flushed-to-zero number is preserved in the sign of 0.
DenormalModeKind Output
Denormal flushing mode for floating point instruction results in the default floating point environme...
Extended Value Type.
Definition: ValueTypes.h:35
TypeSize getStoreSize() const
Return the number of bytes overwritten by a store of the specified value type.
Definition: ValueTypes.h:390
bool isSimple() const
Test if the given EVT is simple (as opposed to being extended).
Definition: ValueTypes.h:137
static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements, bool IsScalable=false)
Returns the EVT that represents a vector NumElements in length, where each element is of type VT.
Definition: ValueTypes.h:74
EVT changeTypeToInteger() const
Return the type converted to an equivalently sized integer or vector with integer element type.
Definition: ValueTypes.h:121
bool isFloatingPoint() const
Return true if this is a FP or a vector FP type.
Definition: ValueTypes.h:147
ElementCount getVectorElementCount() const
Definition: ValueTypes.h:345
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition: ValueTypes.h:368
uint64_t getScalarSizeInBits() const
Definition: ValueTypes.h:380
MVT getSimpleVT() const
Return the SimpleValueType held in the specified simple EVT.
Definition: ValueTypes.h:311
uint64_t getFixedSizeInBits() const
Return the size of the specified fixed width value type in bits.
Definition: ValueTypes.h:376
bool isVector() const
Return true if this is a vector value type.
Definition: ValueTypes.h:168
EVT getScalarType() const
If this is a vector type, return the element type, otherwise return this.
Definition: ValueTypes.h:318
bool bitsEq(EVT VT) const
Return true if this has the same number of bits as VT.
Definition: ValueTypes.h:251
Type * getTypeForEVT(LLVMContext &Context) const
This method returns an LLVM type corresponding to the specified EVT.
Definition: ValueTypes.cpp:210
EVT getVectorElementType() const
Given a vector type, return the type of each element.
Definition: ValueTypes.h:323
bool isScalarInteger() const
Return true if this is an integer, but not a vector.
Definition: ValueTypes.h:157
EVT changeVectorElementType(EVT EltVT) const
Return a VT for a vector type whose attributes match ourselves with the exception of the element type...
Definition: ValueTypes.h:102
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
Definition: ValueTypes.h:331
bool isInteger() const
Return true if this is an integer or a vector integer type.
Definition: ValueTypes.h:152
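The EVT queries listed above can be illustrated with a small hypothetical snippet (an LLVMContext is assumed to be available):

  #include "llvm/CodeGen/ValueTypes.h"
  #include "llvm/IR/LLVMContext.h"
  #include <cassert>
  using namespace llvm;

  // Hypothetical, illustrative function (not part of this file).
  static void evtExamples(LLVMContext &Ctx) {
    EVT V4F32 = EVT::getVectorVT(Ctx, MVT::f32, 4);
    assert(V4F32.isVector() && V4F32.isFloatingPoint());
    assert(V4F32.getVectorNumElements() == 4);
    assert(V4F32.getVectorElementType() == MVT::f32);
    assert(V4F32.getFixedSizeInBits() == 128);
    // changeTypeToInteger keeps the vector shape, swapping in integer elements.
    assert(V4F32.changeTypeToInteger() == EVT::getVectorVT(Ctx, MVT::i32, 4));
  }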
This class contains a discriminated union of information about pointers in memory operands,...
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
This represents a list of ValueType's that has been intern'd by a SelectionDAG.
This represents an addressing mode of: BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + ScalableOffset*...
This structure contains all information that is necessary for lowering calls.
SmallVector< ISD::InputArg, 32 > Ins
SmallVector< ISD::OutputArg, 32 > Outs
SmallVector< SDValue, 32 > OutVals
SDValue CombineTo(SDNode *N, ArrayRef< SDValue > To, bool AddTo=true)