1 | |
2 | |
3 | |
4 | |
5 | |
6 | |
7 | |
8 | |
9 | |
10 | |
11 | #include "clang/Tooling/RefactoringCallbacks.h" |
12 | #include "clang/ASTMatchers/ASTMatchFinder.h" |
13 | #include "clang/Basic/SourceLocation.h" |
14 | #include "clang/Lex/Lexer.h" |
15 | |
16 | using llvm::StringError; |
17 | using llvm::make_error; |
18 | |
19 | namespace clang { |
20 | namespace tooling { |
21 | |
22 | RefactoringCallback::RefactoringCallback() {} |
23 | tooling::Replacements &RefactoringCallback::getReplacements() { |
24 | return Replace; |
25 | } |
26 | |
27 | ASTMatchRefactorer::ASTMatchRefactorer( |
28 | std::map<std::string, Replacements> &FileToReplaces) |
29 | : FileToReplaces(FileToReplaces) {} |
30 | |
31 | void ASTMatchRefactorer::addDynamicMatcher( |
32 | const ast_matchers::internal::DynTypedMatcher &Matcher, |
33 | RefactoringCallback *Callback) { |
34 | MatchFinder.addDynamicMatcher(Matcher, Callback); |
35 | Callbacks.push_back(Callback); |
36 | } |
37 | |
38 | class RefactoringASTConsumer : public ASTConsumer { |
39 | public: |
40 | explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring) |
41 | : Refactoring(Refactoring) {} |
42 | |
43 | void HandleTranslationUnit(ASTContext &Context) override { |
44 | |
45 | |
46 | for (const auto &Callback : Refactoring.Callbacks) { |
47 | Callback->getReplacements().clear(); |
48 | } |
49 | Refactoring.MatchFinder.matchAST(Context); |
50 | for (const auto &Callback : Refactoring.Callbacks) { |
51 | for (const auto &Replacement : Callback->getReplacements()) { |
52 | llvm::Error Err = |
53 | Refactoring.FileToReplaces[Replacement.getFilePath()].add( |
54 | Replacement); |
55 | if (Err) { |
56 | llvm::errs() << "Skipping replacement " << Replacement.toString() |
57 | << " due to this error:\n" |
58 | << toString(std::move(Err)) << "\n"; |
59 | } |
60 | } |
61 | } |
62 | } |
63 | |
64 | private: |
65 | ASTMatchRefactorer &Refactoring; |
66 | }; |
67 | |
68 | std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() { |
69 | return llvm::make_unique<RefactoringASTConsumer>(*this); |
70 | } |
71 | |
72 | static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From, |
73 | StringRef Text) { |
74 | return tooling::Replacement( |
75 | Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text); |
76 | } |
77 | static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From, |
78 | const Stmt &To) { |
79 | return replaceStmtWithText( |
80 | Sources, From, |
81 | Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()), |
82 | Sources, LangOptions())); |
83 | } |
84 | |
85 | ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText) |
86 | : FromId(FromId), ToText(ToText) {} |
87 | |
88 | void ReplaceStmtWithText::run( |
89 | const ast_matchers::MatchFinder::MatchResult &Result) { |
90 | if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) { |
91 | auto Err = Replace.add(tooling::Replacement( |
92 | *Result.SourceManager, |
93 | CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText)); |
94 | |
95 | |
96 | if (Err) { |
97 | llvm::errs() << llvm::toString(std::move(Err)) << "\n"; |
98 | assert(false); |
99 | } |
100 | } |
101 | } |
102 | |
103 | ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId) |
104 | : FromId(FromId), ToId(ToId) {} |
105 | |
106 | void ReplaceStmtWithStmt::run( |
107 | const ast_matchers::MatchFinder::MatchResult &Result) { |
108 | const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId); |
109 | const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId); |
110 | if (FromMatch && ToMatch) { |
111 | auto Err = Replace.add( |
112 | replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch)); |
113 | |
114 | |
115 | if (Err) { |
116 | llvm::errs() << llvm::toString(std::move(Err)) << "\n"; |
117 | assert(false); |
118 | } |
119 | } |
120 | } |
121 | |
122 | ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id, |
123 | bool PickTrueBranch) |
124 | : Id(Id), PickTrueBranch(PickTrueBranch) {} |
125 | |
126 | void ReplaceIfStmtWithItsBody::run( |
127 | const ast_matchers::MatchFinder::MatchResult &Result) { |
128 | if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) { |
129 | const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse(); |
130 | if (Body) { |
131 | auto Err = |
132 | Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body)); |
133 | |
134 | |
135 | if (Err) { |
136 | llvm::errs() << llvm::toString(std::move(Err)) << "\n"; |
137 | assert(false); |
138 | } |
139 | } else if (!PickTrueBranch) { |
140 | |
141 | |
142 | auto Err = |
143 | Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, "")); |
144 | |
145 | |
146 | if (Err) { |
147 | llvm::errs() << llvm::toString(std::move(Err)) << "\n"; |
148 | assert(false); |
149 | } |
150 | } |
151 | } |
152 | } |
153 | |
154 | ReplaceNodeWithTemplate::ReplaceNodeWithTemplate( |
155 | llvm::StringRef FromId, std::vector<TemplateElement> Template) |
156 | : FromId(FromId), Template(std::move(Template)) {} |
157 | |
158 | llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>> |
159 | ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) { |
160 | std::vector<TemplateElement> ParsedTemplate; |
161 | for (size_t Index = 0; Index < ToTemplate.size();) { |
162 | if (ToTemplate[Index] == '$') { |
163 | if (ToTemplate.substr(Index, 2) == "$$") { |
164 | Index += 2; |
165 | ParsedTemplate.push_back( |
166 | TemplateElement{TemplateElement::Literal, "$"}); |
167 | } else if (ToTemplate.substr(Index, 2) == "${") { |
168 | size_t EndOfIdentifier = ToTemplate.find("}", Index); |
169 | if (EndOfIdentifier == std::string::npos) { |
170 | return make_error<StringError>( |
171 | "Unterminated ${...} in replacement template near " + |
172 | ToTemplate.substr(Index), |
173 | llvm::inconvertibleErrorCode()); |
174 | } |
175 | std::string SourceNodeName = |
176 | ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2); |
177 | ParsedTemplate.push_back( |
178 | TemplateElement{TemplateElement::Identifier, SourceNodeName}); |
179 | Index = EndOfIdentifier + 1; |
180 | } else { |
181 | return make_error<StringError>( |
182 | "Invalid $ in replacement template near " + |
183 | ToTemplate.substr(Index), |
184 | llvm::inconvertibleErrorCode()); |
185 | } |
186 | } else { |
187 | size_t NextIndex = ToTemplate.find('$', Index + 1); |
188 | ParsedTemplate.push_back( |
189 | TemplateElement{TemplateElement::Literal, |
190 | ToTemplate.substr(Index, NextIndex - Index)}); |
191 | Index = NextIndex; |
192 | } |
193 | } |
194 | return std::unique_ptr<ReplaceNodeWithTemplate>( |
195 | new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate))); |
196 | } |
197 | |
198 | void ReplaceNodeWithTemplate::run( |
199 | const ast_matchers::MatchFinder::MatchResult &Result) { |
200 | const auto &NodeMap = Result.Nodes.getMap(); |
201 | |
202 | std::string ToText; |
203 | for (const auto &Element : Template) { |
204 | switch (Element.Type) { |
205 | case TemplateElement::Literal: |
206 | ToText += Element.Value; |
207 | break; |
208 | case TemplateElement::Identifier: { |
209 | auto NodeIter = NodeMap.find(Element.Value); |
210 | if (NodeIter == NodeMap.end()) { |
211 | llvm::errs() << "Node " << Element.Value |
212 | << " used in replacement template not bound in Matcher \n"; |
213 | llvm::report_fatal_error("Unbound node in replacement template."); |
214 | } |
215 | CharSourceRange Source = |
216 | CharSourceRange::getTokenRange(NodeIter->second.getSourceRange()); |
217 | ToText += Lexer::getSourceText(Source, *Result.SourceManager, |
218 | Result.Context->getLangOpts()); |
219 | break; |
220 | } |
221 | } |
222 | } |
223 | if (NodeMap.count(FromId) == 0) { |
224 | llvm::errs() << "Node to be replaced " << FromId |
225 | << " not bound in query.\n"; |
226 | llvm::report_fatal_error("FromId node not bound in MatchResult"); |
227 | } |
228 | auto Replacement = |
229 | tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText, |
230 | Result.Context->getLangOpts()); |
231 | llvm::Error Err = Replace.add(Replacement); |
232 | if (Err) { |
233 | llvm::errs() << "Query and replace failed in " << Replacement.getFilePath() |
234 | << "! " << llvm::toString(std::move(Err)) << "\n"; |
235 | llvm::report_fatal_error("Replacement failed"); |
236 | } |
237 | } |
238 | |
239 | } |
240 | } |
241 | |