Line data Source code
1 : //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : //
10 : // This file implements utilities for working with "normalized" expressions.
11 : // See the comments at the top of ScalarEvolutionNormalization.h for details.
12 : //
13 : //===----------------------------------------------------------------------===//
14 :
15 : #include "llvm/Analysis/ScalarEvolutionNormalization.h"
16 : #include "llvm/Analysis/LoopInfo.h"
17 : #include "llvm/Analysis/ScalarEvolutionExpressions.h"
18 : using namespace llvm;
19 :
namespace {
/// TransformKind - Selects which direction NormalizeDenormalizeRewriter
/// shifts the add recurrences picked out by its predicate.
///
/// Kept in an anonymous namespace: the enum is private to this file, and at
/// global scope it (and its unscoped enumerators) would have external linkage
/// and could collide with an unrelated global `TransformKind`/`Normalize` in
/// another translation unit.
enum TransformKind {
  /// Normalize - Normalize according to the given loops ("partial decrement"
  /// of each selected add recurrence).
  Normalize,
  /// Denormalize - Perform the inverse transform on the expression with the
  /// given loop set ("partial increment").
  Denormalize
};
} // end anonymous namespace
29 :
30 : namespace {
31 : struct NormalizeDenormalizeRewriter
32 : : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
33 : const TransformKind Kind;
34 :
35 : // NB! Pred is a function_ref. Storing it here is okay only because
36 : // we're careful about the lifetime of NormalizeDenormalizeRewriter.
37 : const NormalizePredTy Pred;
38 :
39 : NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
40 : ScalarEvolution &SE)
41 63474 : : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
42 82932 : Pred(Pred) {}
43 : const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
44 : };
45 : } // namespace
46 :
47 : const SCEV *
48 78391 : NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
49 : SmallVector<const SCEV *, 8> Operands;
50 :
51 : transform(AR->operands(), std::back_inserter(Operands),
52 158591 : [&](const SCEV *Op) { return visit(Op); });
53 :
54 78391 : if (!Pred(AR))
55 59240 : return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
56 :
57 : // Normalization and denormalization are fancy names for decrementing and
58 : // incrementing a SCEV expression with respect to a set of loops. Since
59 : // Pred(AR) has returned true, we know we need to normalize or denormalize AR
60 : // with respect to its loop.
61 :
62 19151 : if (Kind == Denormalize) {
63 : // Denormalization / "partial increment" is essentially the same as \c
64 : // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the
65 : // symmetry with Normalization clear.
66 11432 : for (int i = 0, e = Operands.size() - 1; i < e; i++)
67 17811 : Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
68 : } else {
69 : assert(Kind == Normalize && "Only two possibilities!");
70 :
71 : // Normalization / "partial decrement" is a bit more subtle. Since
72 : // incrementing a SCEV expression (in general) changes the step of the SCEV
73 : // expression as well, we cannot use the step of the current expression.
74 : // Instead, we have to use the step of the very expression we're trying to
75 : // compute!
76 : //
77 : // We solve the issue by recursively building up the result, starting from
78 : // the "least significant" operand in the add recurrence:
79 : //
80 : // Base case:
81 : // Single operand add recurrence. It's its own normalization.
82 : //
83 : // N-operand case:
84 : // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
85 : //
86 : // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
87 : // normalization by induction. We subtract the normalized step
88 : // recurrence from S_{N-1} to get the normalization of S.
89 :
90 27766 : for (int i = Operands.size() - 2; i >= 0; i--)
91 42330 : Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
92 : }
93 :
94 19151 : return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
95 : }
96 :
97 39854 : const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
98 : const PostIncLoopSet &Loops,
99 : ScalarEvolution &SE) {
100 : auto Pred = [&](const SCEVAddRecExpr *AR) {
101 39769 : return Loops.count(AR->getLoop());
102 39854 : };
103 39854 : return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
104 : }
105 :
106 19458 : const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
107 : ScalarEvolution &SE) {
108 19458 : return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
109 : }
110 :
111 23620 : const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
112 : const PostIncLoopSet &Loops,
113 : ScalarEvolution &SE) {
114 : auto Pred = [&](const SCEVAddRecExpr *AR) {
115 18215 : return Loops.count(AR->getLoop());
116 23620 : };
117 23620 : return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
118 : }
|