39#define DEBUG_TYPE "aggressive-instcombine"
41STATISTIC(NumExprsReduced,
"Number of truncations eliminated by reducing bit "
42 "width of expression graph");
44 "Number of instructions whose bit width was reduced");
49 unsigned Opc =
I->getOpcode();
51 case Instruction::Trunc:
52 case Instruction::ZExt:
53 case Instruction::SExt:
57 case Instruction::Add:
58 case Instruction::Sub:
59 case Instruction::Mul:
60 case Instruction::And:
62 case Instruction::Xor:
63 case Instruction::Shl:
64 case Instruction::LShr:
65 case Instruction::AShr:
66 case Instruction::UDiv:
67 case Instruction::URem:
68 case Instruction::InsertElement:
72 case Instruction::ExtractElement:
75 case Instruction::Select:
79 case Instruction::PHI:
80 for (
Value *V : cast<PHINode>(
I)->incoming_values())
88bool TruncInstCombine::buildTruncExpressionGraph() {
96 while (!Worklist.
empty()) {
99 if (isa<Constant>(Curr)) {
104 auto *
I = dyn_cast<Instruction>(Curr);
114 InstInfoMap.insert(std::make_pair(
I,
Info()));
118 if (InstInfoMap.count(
I)) {
126 unsigned Opc =
I->getOpcode();
128 case Instruction::Trunc:
129 case Instruction::ZExt:
130 case Instruction::SExt:
136 case Instruction::Add:
137 case Instruction::Sub:
138 case Instruction::Mul:
139 case Instruction::And:
140 case Instruction::Or:
141 case Instruction::Xor:
142 case Instruction::Shl:
143 case Instruction::LShr:
144 case Instruction::AShr:
145 case Instruction::UDiv:
146 case Instruction::URem:
147 case Instruction::InsertElement:
148 case Instruction::ExtractElement:
149 case Instruction::Select: {
155 case Instruction::PHI: {
175unsigned TruncInstCombine::getMinBitWidth() {
182 unsigned OrigBitWidth =
185 if (isa<Constant>(Src))
186 return TruncBitWidth;
189 InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;
191 while (!Worklist.
empty()) {
194 if (isa<Constant>(Curr)) {
200 auto *
I = cast<Instruction>(Curr);
202 auto &
Info = InstInfoMap[
I];
213 if (
auto *IOp = dyn_cast<Instruction>(Operand))
215 std::max(
Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
221 unsigned ValidBitWidth =
Info.ValidBitWidth;
225 Info.MinBitWidth = std::max(
Info.MinBitWidth,
Info.ValidBitWidth);
228 if (
auto *IOp = dyn_cast<Instruction>(Operand)) {
232 unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
233 if (IOpBitwidth >= ValidBitWidth)
235 InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
239 unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
240 assert(MinBitWidth >= TruncBitWidth);
242 if (MinBitWidth > TruncBitWidth) {
258 bool FromLegal = MinBitWidth == 1 || DL.
isLegalInteger(OrigBitWidth);
259 bool ToLegal = MinBitWidth == 1 || DL.
isLegalInteger(MinBitWidth);
260 if (!DstTy->
isVectorTy() && FromLegal && !ToLegal)
266Type *TruncInstCombine::getBestTruncatedType() {
267 if (!buildTruncExpressionGraph())
274 unsigned DesiredBitWidth = 0;
275 for (
auto Itr : InstInfoMap) {
279 bool IsExtInst = (isa<ZExtInst>(
I) || isa<SExtInst>(
I));
280 for (
auto *U :
I->users())
281 if (
auto *UI = dyn_cast<Instruction>(U))
282 if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
288 unsigned ExtInstBitWidth =
289 I->getOperand(0)->getType()->getScalarSizeInBits();
290 if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
292 DesiredBitWidth = ExtInstBitWidth;
296 unsigned OrigBitWidth =
307 for (
auto &Itr : InstInfoMap) {
310 KnownBits KnownRHS = computeKnownBits(
I->getOperand(1));
314 if (MinBitWidth == OrigBitWidth)
316 if (
I->getOpcode() == Instruction::LShr) {
317 KnownBits KnownLHS = computeKnownBits(
I->getOperand(0));
321 if (
I->getOpcode() == Instruction::AShr) {
322 unsigned NumSignBits = ComputeNumSignBits(
I->getOperand(0));
323 MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
325 if (MinBitWidth >= OrigBitWidth)
327 Itr.second.MinBitWidth = MinBitWidth;
329 if (
I->getOpcode() == Instruction::UDiv ||
330 I->getOpcode() == Instruction::URem) {
331 unsigned MinBitWidth = 0;
332 for (
const auto &
Op :
I->operands()) {
336 if (MinBitWidth >= OrigBitWidth)
339 Itr.second.MinBitWidth = MinBitWidth;
345 unsigned MinBitWidth = getMinBitWidth();
349 if (MinBitWidth >= OrigBitWidth ||
350 (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
361 if (
auto *VTy = dyn_cast<VectorType>(V->getType()))
366Value *TruncInstCombine::getReducedOperand(
Value *V,
Type *SclTy) {
368 if (
auto *
C = dyn_cast<Constant>(V)) {
374 auto *
I = cast<Instruction>(V);
375 Info Entry = InstInfoMap.lookup(
I);
377 return Entry.NewValue;
380void TruncInstCombine::ReduceExpressionGraph(
Type *SclTy) {
381 NumInstrsReduced += InstInfoMap.size();
384 for (
auto &Itr : InstInfoMap) {
386 TruncInstCombine::Info &NodeInfo = Itr.second;
388 assert(!NodeInfo.NewValue &&
"Instruction has been evaluated");
391 Value *Res =
nullptr;
392 unsigned Opc =
I->getOpcode();
394 case Instruction::Trunc:
395 case Instruction::ZExt:
396 case Instruction::SExt: {
401 if (
I->getOperand(0)->getType() == Ty) {
402 assert(!isa<TruncInst>(
I) &&
"Cannot reach here with TruncInst");
403 NodeInfo.NewValue =
I->getOperand(0);
408 Res =
Builder.CreateIntCast(
I->getOperand(0), Ty,
409 Opc == Instruction::SExt);
416 auto *Entry =
find(Worklist,
I);
417 if (Entry != Worklist.
end()) {
418 if (
auto *NewCI = dyn_cast<TruncInst>(Res))
421 Worklist.
erase(Entry);
422 }
else if (
auto *NewCI = dyn_cast<TruncInst>(Res))
426 case Instruction::Add:
427 case Instruction::Sub:
428 case Instruction::Mul:
429 case Instruction::And:
430 case Instruction::Or:
431 case Instruction::Xor:
432 case Instruction::Shl:
433 case Instruction::LShr:
434 case Instruction::AShr:
435 case Instruction::UDiv:
436 case Instruction::URem: {
437 Value *
LHS = getReducedOperand(
I->getOperand(0), SclTy);
438 Value *
RHS = getReducedOperand(
I->getOperand(1), SclTy);
441 if (
auto *PEO = dyn_cast<PossiblyExactOperator>(
I))
442 if (
auto *ResI = dyn_cast<Instruction>(Res))
443 ResI->setIsExact(PEO->isExact());
446 case Instruction::ExtractElement: {
447 Value *Vec = getReducedOperand(
I->getOperand(0), SclTy);
452 case Instruction::InsertElement: {
453 Value *Vec = getReducedOperand(
I->getOperand(0), SclTy);
454 Value *NewElt = getReducedOperand(
I->getOperand(1), SclTy);
456 Res =
Builder.CreateInsertElement(Vec, NewElt,
Idx);
459 case Instruction::Select: {
460 Value *Op0 =
I->getOperand(0);
461 Value *
LHS = getReducedOperand(
I->getOperand(1), SclTy);
462 Value *
RHS = getReducedOperand(
I->getOperand(2), SclTy);
463 Res =
Builder.CreateSelect(Op0, LHS, RHS);
466 case Instruction::PHI: {
469 std::make_pair(cast<PHINode>(
I), cast<PHINode>(Res)));
476 NodeInfo.NewValue = Res;
477 if (
auto *ResI = dyn_cast<Instruction>(Res))
481 for (
auto &
Node : OldNewPHINodes) {
485 NewPN->
addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy),
486 std::get<1>(Incoming));
489 Value *Res = getReducedOperand(CurrentTruncInst->
getOperand(0), SclTy);
493 Res =
Builder.CreateIntCast(Res, DstTy,
false);
494 if (
auto *ResI = dyn_cast<Instruction>(Res))
495 ResI->takeName(CurrentTruncInst);
503 for (
auto &
Node : OldNewPHINodes) {
506 InstInfoMap.erase(OldPN);
517 if (
I.first->use_empty())
518 I.first->eraseFromParent();
520 assert((isa<SExtInst>(
I.first) || isa<ZExtInst>(
I.first)) &&
521 "Only {SExt, ZExt}Inst might have unreduced users");
526 bool MadeIRChange =
false;
534 if (
auto *CI = dyn_cast<TruncInst>(&
I))
541 while (!Worklist.
empty()) {
544 if (
Type *NewDstSclTy = getBestTruncatedType()) {
546 dbgs() <<
"ICE: TruncInstCombine reducing type of expression graph "
548 << CurrentTruncInst <<
'\n');
549 ReduceExpressionGraph(NewDstSclTy);
Analysis containing CSE Info
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
mir Rename Register Operands
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
static Type * getReducedType(Value *V, Type *Ty)
Given a reduced scalar type Ty and a V value, return a reduced type for V, according to its type,...
static void getRelevantOperands(Instruction *I, SmallVectorImpl< Value * > &Ops)
Given an instruction and a container, it fills all the relevant operands of that instruction,...
Class for arbitrary precision integers.
unsigned getActiveBits() const
Compute the number of active bits in the value.
uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const
If this value is smaller than the specified limit, return it, otherwise return the limit value.
APInt uadd_sat(const APInt &RHS) const
static Constant * getIntegerCast(Constant *C, Type *Ty, bool IsSigned)
Create a ZExt, Bitcast or Trunc for integer -> integer casts.
This class represents an Operation in the Expression.
bool isLegalInteger(uint64_t Width) const
Returns true if the specified type is known to be a native integer type supported by the CPU.
Type * getSmallestLegalIntType(LLVMContext &C, unsigned Width=0) const
Returns the smallest integer type with size at least as big as Width bits.
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
iterator erase(const_iterator CI)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool run(Function &F)
Perform TruncInst pattern optimization on given function.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isVectorTy() const
True if this is an instance of VectorType.
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVMContext & getContext() const
All values hold a context through their type.
void takeName(Value *V)
Transfer the name from V to this value.
static VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct a VectorType.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
detail::zippy< detail::zip_shortest, T, U, Args... > zip(T &&t, U &&u, Args &&...args)
zip iterator for two or more iteratable types.
auto find(R &&Range, const T &Val)
Provide wrappers to std::find which take ranges instead of having to pass begin/end explicitly.
void append_range(Container &C, Range &&R)
Wrapper function to append a range to a container.
Constant * ConstantFoldConstant(const Constant *C, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr)
ConstantFoldConstant - Fold the constant using the specified DataLayout.
auto reverse(ContainerTy &&C)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.