Clang Project

clang_source_code/lib/StaticAnalyzer/Core/LoopUnrolling.cpp
//===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// This file contains functions which are used to decide if a loop is worth
/// unrolling. Moreover, these functions manage the stack of loops which is
/// tracked by the ProgramState.
///
//===----------------------------------------------------------------------===//
14
15#include "clang/ASTMatchers/ASTMatchers.h"
16#include "clang/ASTMatchers/ASTMatchFinder.h"
17#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
18#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
19#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
20
21using namespace clang;
22using namespace ento;
23using namespace clang::ast_matchers;
24
/// Largest combined step count for which updateLoopStack still marks a loop as
/// unrolled. The combined count is the loop's own step estimate multiplied by
/// the step count recorded for the loop currently on top of the loop stack.
static const int MAXIMUM_STEP_UNROLLED = 128;
26
27struct LoopState {
28private:
29  enum Kind { NormalUnrolled } K;
30  const Stmt *LoopStmt;
31  const LocationContext *LCtx;
32  unsigned maxStep;
33  LoopState(Kind InKconst Stmt *Sconst LocationContext *Lunsigned N)
34      : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
35
36public:
37  static LoopState getNormal(const Stmt *Sconst LocationContext *L,
38                             unsigned N) {
39    return LoopState(NormalSLN);
40  }
41  static LoopState getUnrolled(const Stmt *Sconst LocationContext *L,
42                               unsigned N) {
43    return LoopState(UnrolledSLN);
44  }
45  bool isUnrolled() const { return K == Unrolled; }
46  unsigned getMaxStep() const { return maxStep; }
47  const Stmt *getLoopStmt() const { return LoopStmt; }
48  const LocationContext *getLocationContext() const { return LCtx; }
49  bool operator==(const LoopState &Xconst {
50    return K == X.K && LoopStmt == X.LoopStmt;
51  }
52  void Profile(llvm::FoldingSetNodeID &IDconst {
53    ID.AddInteger(K);
54    ID.AddPointer(LoopStmt);
55    ID.AddPointer(LCtx);
56    ID.AddInteger(maxStep);
57  }
58};
59
60// The tracked stack of loops. The stack indicates that which loops the
61// simulated element contained by. The loops are marked depending if we decided
62// to unroll them.
63// TODO: The loop stack should not need to be in the program state since it is
64// lexical in nature. Instead, the stack of loops should be tracked in the
65// LocationContext.
66REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
67
68namespace clang {
69namespace ento {
70
71static bool isLoopStmt(const Stmt *S) {
72  return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
73}
74
75ProgramStateRef processLoopEnd(const Stmt *LoopStmtProgramStateRef State) {
76  auto LS = State->get<LoopStack>();
77  if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
78    State = State->set<LoopStack>(LS.getTail());
79  return State;
80}
81
82static internal::Matcher<StmtsimpleCondition(StringRef BindName) {
83  return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"),
84                              hasOperatorName("<="), hasOperatorName(">="),
85                              hasOperatorName("!=")),
86                        hasEitherOperand(ignoringParenImpCasts(declRefExpr(
87                            to(varDecl(hasType(isInteger())).bind(BindName))))),
88                        hasEitherOperand(ignoringParenImpCasts(
89                            integerLiteral().bind("boundNum"))))
90      .bind("conditionOperator");
91}
92
93static internal::Matcher<Stmt>
94changeIntBoundNode(internal::Matcher<DeclVarNodeMatcher) {
95  return anyOf(
96      unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
97                    hasUnaryOperand(ignoringParenImpCasts(
98                        declRefExpr(to(varDecl(VarNodeMatcher)))))),
99      binaryOperator(isAssignmentOperator(),
100                     hasLHS(ignoringParenImpCasts(
101                         declRefExpr(to(varDecl(VarNodeMatcher)))))));
102}
103
104static internal::Matcher<Stmt>
105callByRef(internal::Matcher<DeclVarNodeMatcher) {
106  return callExpr(forEachArgumentWithParam(
107      declRefExpr(to(varDecl(VarNodeMatcher))),
108      parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
109}
110
111static internal::Matcher<Stmt>
112assignedToRef(internal::Matcher<DeclVarNodeMatcher) {
113  return declStmt(hasDescendant(varDecl(
114      allOf(hasType(referenceType()),
115            hasInitializer(anyOf(
116                initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
117                declRefExpr(to(varDecl(VarNodeMatcher)))))))));
118}
119
120static internal::Matcher<Stmt>
121getAddrTo(internal::Matcher<DeclVarNodeMatcher) {
122  return unaryOperator(
123      hasOperatorName("&"),
124      hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
125}
126
127static internal::Matcher<StmthasSuspiciousStmt(StringRef NodeName) {
128  return hasDescendant(stmt(
129      anyOf(gotoStmt(), switchStmt(), returnStmt(),
130            // Escaping and not known mutation of the loop counter is handled
131            // by exclusion of assigning and address-of operators and
132            // pass-by-ref function calls on the loop counter from the body.
133            changeIntBoundNode(equalsBoundNode(NodeName)),
134            callByRef(equalsBoundNode(NodeName)),
135            getAddrTo(equalsBoundNode(NodeName)),
136            assignedToRef(equalsBoundNode(NodeName)))));
137}
138
139static internal::Matcher<StmtforLoopMatcher() {
140  return forStmt(
141             hasCondition(simpleCondition("initVarName")),
142             // Initialization should match the form: 'int i = 6' or 'i = 42'.
143             hasLoopInit(
144                 anyOf(declStmt(hasSingleDecl(
145                           varDecl(allOf(hasInitializer(ignoringParenImpCasts(
146                                             integerLiteral().bind("initNum"))),
147                                         equalsBoundNode("initVarName"))))),
148                       binaryOperator(hasLHS(declRefExpr(to(varDecl(
149                                          equalsBoundNode("initVarName"))))),
150                                      hasRHS(ignoringParenImpCasts(
151                                          integerLiteral().bind("initNum")))))),
152             // Incrementation should be a simple increment or decrement
153             // operator call.
154             hasIncrement(unaryOperator(
155                 anyOf(hasOperatorName("++"), hasOperatorName("--")),
156                 hasUnaryOperand(declRefExpr(
157                     to(varDecl(allOf(equalsBoundNode("initVarName"),
158                                      hasType(isInteger())))))))),
159             unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
160}
161
162static bool isPossiblyEscaped(const VarDecl *VDExplodedNode *N) {
163  // Global variables assumed as escaped variables.
164  if (VD->hasGlobalStorage())
165    return true;
166
167  while (!N->pred_empty()) {
168    const Stmt *S = PathDiagnosticLocation::getStmt(N);
169    if (!S) {
170      N = N->getFirstPred();
171      continue;
172    }
173
174    if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
175      for (const Decl *D : DS->decls()) {
176        // Once we reach the declaration of the VD we can return.
177        if (D->getCanonicalDecl() == VD)
178          return false;
179      }
180    }
181    // Check the usage of the pass-by-ref function calls and adress-of operator
182    // on VD and reference initialized by VD.
183    ASTContext &ASTCtx =
184        N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
185    auto Match =
186        match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
187                         assignedToRef(equalsNode(VD)))),
188              *SASTCtx);
189    if (!Match.empty())
190      return true;
191
192    N = N->getFirstPred();
193  }
194  llvm_unreachable("Reached root without finding the declaration of VD");
195}
196
197bool shouldCompletelyUnroll(const Stmt *LoopStmtASTContext &ASTCtx,
198                            ExplodedNode *Predunsigned &maxStep) {
199
200  if (!isLoopStmt(LoopStmt))
201    return false;
202
203  // TODO: Match the cases where the bound is not a concrete literal but an
204  // integer with known value
205  auto Matches = match(forLoopMatcher(), *LoopStmtASTCtx);
206  if (Matches.empty())
207    return false;
208
209  auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
210  llvm::APInt BoundNum =
211      Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
212  llvm::APInt InitNum =
213      Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
214  auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
215  if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
216    InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth());
217    BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth());
218  }
219
220  if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
221    maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
222  else
223    maxStep = (BoundNum - InitNum).abs().getZExtValue();
224
225  // Check if the counter of the loop is not escaped before.
226  return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
227}
228
229bool madeNewBranch(ExplodedNode *Nconst Stmt *LoopStmt) {
230  const Stmt *S = nullptr;
231  while (!N->pred_empty()) {
232    if (N->succ_size() > 1)
233      return true;
234
235    ProgramPoint P = N->getLocation();
236    if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
237      S = BE->getBlock()->getTerminator();
238
239    if (S == LoopStmt)
240      return false;
241
242    N = N->getFirstPred();
243  }
244
245  llvm_unreachable("Reached root without encountering the previous step");
246}
247
248// updateLoopStack is called on every basic block, therefore it needs to be fast
249ProgramStateRef updateLoopStack(const Stmt *LoopStmtASTContext &ASTCtx,
250                                ExplodedNode *Predunsigned maxVisitOnPath) {
251  auto State = Pred->getState();
252  auto LCtx = Pred->getLocationContext();
253
254  if (!isLoopStmt(LoopStmt))
255    return State;
256
257  auto LS = State->get<LoopStack>();
258  if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
259      LCtx == LS.getHead().getLocationContext()) {
260    if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
261      State = State->set<LoopStack>(LS.getTail());
262      State = State->add<LoopStack>(
263          LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
264    }
265    return State;
266  }
267  unsigned maxStep;
268  if (!shouldCompletelyUnroll(LoopStmtASTCtxPredmaxStep)) {
269    State = State->add<LoopStack>(
270        LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
271    return State;
272  }
273
274  unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
275
276  unsigned innerMaxStep = maxStep * outerStep;
277  if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
278    State = State->add<LoopStack>(
279        LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
280  else
281    State = State->add<LoopStack>(
282        LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
283  return State;
284}
285
286bool isUnrolledState(ProgramStateRef State) {
287  auto LS = State->get<LoopStack>();
288  if (LS.isEmpty() || !LS.getHead().isUnrolled())
289    return false;
290  return true;
291}
292}
293}
294
LoopState::Kind