Clang Project

clang_source_code/unittests/Tooling/RefactoringActionRulesTest.cpp
//===- unittest/Tooling/RefactoringActionRulesTest.cpp --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "ReplacementTest.h"
10#include "RewriterTestContext.h"
11#include "clang/Tooling/Refactoring.h"
12#include "clang/Tooling/Refactoring/Extract/Extract.h"
13#include "clang/Tooling/Refactoring/RefactoringAction.h"
14#include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
15#include "clang/Tooling/Refactoring/Rename/SymbolName.h"
16#include "clang/Tooling/Tooling.h"
17#include "llvm/Support/Errc.h"
18#include "gtest/gtest.h"
19
20using namespace clang;
21using namespace tooling;
22
23namespace {
24
25class RefactoringActionRulesTest : public ::testing::Test {
26protected:
27  void SetUp() override {
28    Context.Sources.setMainFileID(
29        Context.createInMemoryFile("input.cpp", DefaultCode));
30  }
31
32  RewriterTestContext Context;
33  std::string DefaultCode = std::string(100'a');
34};
35
36Expected<AtomicChanges>
37createReplacements(const std::unique_ptr<RefactoringActionRule> &Rule,
38                   RefactoringRuleContext &Context) {
39  class Consumer final : public RefactoringResultConsumer {
40    void handleError(llvm::Error Err) override { Result = std::move(Err); }
41
42    void handle(AtomicChanges SourceReplacements) override {
43      Result = std::move(SourceReplacements);
44    }
45    void handle(SymbolOccurrences Occurrences) override {
46      RefactoringResultConsumer::handle(std::move(Occurrences));
47    }
48
49  public:
50    Optional<Expected<AtomicChanges>> Result;
51  };
52
53  Consumer C;
54  Rule->invoke(C, Context);
55  return std::move(*C.Result);
56}
57
58TEST_F(RefactoringActionRulesTest, MyFirstRefactoringRule) {
59  class ReplaceAWithB : public SourceChangeRefactoringRule {
60    std::pair<SourceRange, intSelection;
61
62  public:
63    ReplaceAWithB(std::pair<SourceRange, intSelection)
64        : Selection(Selection) {}
65
66    static Expected<ReplaceAWithB>
67    initiate(RefactoringRuleContext &Cotnext,
68             std::pair<SourceRange, int> Selection) {
69      return ReplaceAWithB(Selection);
70    }
71
72    Expected<AtomicChanges>
73    createSourceReplacements(RefactoringRuleContext &Context) {
74      const SourceManager &SM = Context.getSources();
75      SourceLocation Loc =
76          Selection.first.getBegin().getLocWithOffset(Selection.second);
77      AtomicChange Change(SM, Loc);
78      llvm::Error E = Change.replace(SM, Loc, 1"b");
79      if (E)
80        return std::move(E);
81      return AtomicChanges{Change};
82    }
83  };
84
85  class SelectionRequirement : public SourceRangeSelectionRequirement {
86  public:
87    Expected<std::pair<SourceRange, int>>
88    evaluate(RefactoringRuleContext &Context) const {
89      Expected<SourceRange> R =
90          SourceRangeSelectionRequirement::evaluate(Context);
91      if (!R)
92        return R.takeError();
93      return std::make_pair(*R, 20);
94    }
95  };
96  auto Rule =
97      createRefactoringActionRule<ReplaceAWithB>(SelectionRequirement());
98
99  // When the requirements are satisfied, the rule's function must be invoked.
100  {
101    RefactoringRuleContext RefContext(Context.Sources);
102    SourceLocation Cursor =
103        Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID())
104            .getLocWithOffset(10);
105    RefContext.setSelectionRange({Cursor, Cursor});
106
107    Expected<AtomicChanges> ErrorOrResult =
108        createReplacements(Rule, RefContext);
109    ASSERT_FALSE(!ErrorOrResult);
110    AtomicChanges Result = std::move(*ErrorOrResult);
111    ASSERT_EQ(Result.size(), 1u);
112    std::string YAMLString =
113        const_cast<AtomicChange &>(Result[0]).toYAMLString();
114
115    ASSERT_STREQ("---\n"
116                 "Key:             'input.cpp:30'\n"
117                 "FilePath:        input.cpp\n"
118                 "Error:           ''\n"
119                 "InsertedHeaders: []\n"
120                 "RemovedHeaders:  []\n"
121                 "Replacements:    \n" // Extra whitespace here!
122                 "  - FilePath:        input.cpp\n"
123                 "    Offset:          30\n"
124                 "    Length:          1\n"
125                 "    ReplacementText: b\n"
126                 "...\n",
127                 YAMLString.c_str());
128  }
129
130  // When one of the requirements is not satisfied, invoke should return a
131  // valid error.
132  {
133    RefactoringRuleContext RefContext(Context.Sources);
134    Expected<AtomicChanges> ErrorOrResult =
135        createReplacements(Rule, RefContext);
136
137    ASSERT_TRUE(!ErrorOrResult);
138    unsigned DiagID;
139    llvm::handleAllErrors(ErrorOrResult.takeError(),
140                          [&](DiagnosticError &Error) {
141                            DiagID = Error.getDiagnostic().second.getDiagID();
142                          });
143    EXPECT_EQ(DiagID, diag::err_refactor_no_selection);
144  }
145}
146
147TEST_F(RefactoringActionRulesTest, ReturnError) {
148  class ErrorRule : public SourceChangeRefactoringRule {
149  public:
150    static Expected<ErrorRule> initiate(RefactoringRuleContext &,
151                                        SourceRange R) {
152      return ErrorRule(R);
153    }
154
155    ErrorRule(SourceRange R) {}
156    Expected<AtomicChanges> createSourceReplacements(RefactoringRuleContext &) {
157      return llvm::make_error<llvm::StringError>(
158          "Error", llvm::make_error_code(llvm::errc::invalid_argument));
159    }
160  };
161
162  auto Rule =
163      createRefactoringActionRule<ErrorRule>(SourceRangeSelectionRequirement());
164  RefactoringRuleContext RefContext(Context.Sources);
165  SourceLocation Cursor =
166      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
167  RefContext.setSelectionRange({Cursor, Cursor});
168  Expected<AtomicChanges> Result = createReplacements(Rule, RefContext);
169
170  ASSERT_TRUE(!Result);
171  std::string Message;
172  llvm::handleAllErrors(Result.takeError(), [&](llvm::StringError &Error) {
173    Message = Error.getMessage();
174  });
175  EXPECT_EQ(Message"Error");
176}
177
178Optional<SymbolOccurrences> findOccurrences(RefactoringActionRule &Rule,
179                                            RefactoringRuleContext &Context) {
180  class Consumer final : public RefactoringResultConsumer {
181    void handleError(llvm::Error) override {}
182    void handle(SymbolOccurrences Occurrences) override {
183      Result = std::move(Occurrences);
184    }
185    void handle(AtomicChanges Changes) override {
186      RefactoringResultConsumer::handle(std::move(Changes));
187    }
188
189  public:
190    Optional<SymbolOccurrences> Result;
191  };
192
193  Consumer C;
194  Rule.invoke(C, Context);
195  return std::move(C.Result);
196}
197
198TEST_F(RefactoringActionRulesTest, ReturnSymbolOccurrences) {
199  class FindOccurrences : public FindSymbolOccurrencesRefactoringRule {
200    SourceRange Selection;
201
202  public:
203    FindOccurrences(SourceRange Selection) : Selection(Selection) {}
204
205    static Expected<FindOccurrences> initiate(RefactoringRuleContext &,
206                                              SourceRange Selection) {
207      return FindOccurrences(Selection);
208    }
209
210    Expected<SymbolOccurrences>
211    findSymbolOccurrences(RefactoringRuleContext &) override {
212      SymbolOccurrences Occurrences;
213      Occurrences.push_back(SymbolOccurrence(SymbolName("test"),
214                                             SymbolOccurrence::MatchingSymbol,
215                                             Selection.getBegin()));
216      return std::move(Occurrences);
217    }
218  };
219
220  auto Rule = createRefactoringActionRule<FindOccurrences>(
221      SourceRangeSelectionRequirement());
222
223  RefactoringRuleContext RefContext(Context.Sources);
224  SourceLocation Cursor =
225      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
226  RefContext.setSelectionRange({Cursor, Cursor});
227  Optional<SymbolOccurrences> Result = findOccurrences(*Rule, RefContext);
228
229  ASSERT_FALSE(!Result);
230  SymbolOccurrences Occurrences = std::move(*Result);
231  EXPECT_EQ(Occurrences.size(), 1u);
232  EXPECT_EQ(Occurrences[0].getKind(), SymbolOccurrence::MatchingSymbol);
233  EXPECT_EQ(Occurrences[0].getNameRanges().size(), 1u);
234  EXPECT_EQ(Occurrences[0].getNameRanges()[0],
235            SourceRange(Cursor, Cursor.getLocWithOffset(strlen("test"))));
236}
237
238TEST_F(RefactoringActionRulesTest, EditorCommandBinding) {
239  const RefactoringDescriptor &Descriptor = ExtractFunction::describe();
240  EXPECT_EQ(Descriptor.Name, "extract-function");
241  EXPECT_EQ(
242      Descriptor.Description,
243      "(WIP action; use with caution!) Extracts code into a new function");
244  EXPECT_EQ(Descriptor.Title, "Extract Function");
245}
246
247// end anonymous namespace
248