1 | // Copyright 2017 The Go Authors. All rights reserved. |
---|---|
2 | // Use of this source code is governed by a BSD-style |
3 | // license that can be found in the LICENSE file. |
4 | |
5 | package astutil |
6 | |
7 | import ( |
8 | "fmt" |
9 | "go/ast" |
10 | "reflect" |
11 | "sort" |
12 | |
13 | "golang.org/x/tools/internal/typeparams" |
14 | ) |
15 | |
16 | // An ApplyFunc is invoked by Apply for each node n, even if n is nil, |
17 | // before and/or after the node's children, using a Cursor describing |
18 | // the current node and providing operations on it. |
19 | // |
20 | // The return value of ApplyFunc controls the syntax tree traversal. |
21 | // See Apply for details. |
22 | type ApplyFunc func(*Cursor) bool |
23 | |
// Apply walks the syntax tree rooted at root, calling pre and post for
// each node as described below, and returns the (possibly modified) tree.
//
// If pre is not nil, it is called for each node before the node's
// children are traversed (pre-order). If pre returns false, no
// children are traversed, and post is not called for that node.
//
// If post is not nil, and a prior call of pre didn't return false,
// post is called for each node after its children are traversed
// (post-order). If post returns false, traversal is terminated and
// Apply returns immediately.
//
// Only fields that refer to AST nodes are considered children;
// i.e., token.Pos, Scopes, Objects, and fields of basic types
// (strings, etc.) are ignored.
//
// Children are traversed in the order in which they appear in the
// respective node's struct definition. A package's files are
// traversed in the filenames' alphabetical order.
func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
	// Keep the root in a wrapper whose embedded field is named "Node", so
	// the root can be addressed like any other parent field; whatever ends
	// up in holder.Node is the tree Apply hands back.
	holder := &struct{ ast.Node }{root}

	// A post callback returning false unwinds the walk with panic(abort).
	// Swallow only that sentinel (or no panic at all) and report the
	// current root; any other panic is re-raised unchanged.
	defer func() {
		r := recover()
		if r == nil || r == abort {
			result = holder.Node
			return
		}
		panic(r)
	}()

	app := &application{pre: pre, post: post}
	app.apply(holder, "Node", nil, root)
	return holder.Node
}
56 | |
var (
	// abort is a unique sentinel value: the traversal panics with it when a
	// post callback returns false, and Apply's deferred recover recognizes
	// it by pointer identity to end the walk without propagating the panic.
	abort = new(int)
)
58 | |
59 | // A Cursor describes a node encountered during Apply. |
60 | // Information about the node and its parent is available |
61 | // from the Node, Parent, Name, and Index methods. |
62 | // |
63 | // If p is a variable of type and value of the current parent node |
64 | // c.Parent(), and f is the field identifier with name c.Name(), |
65 | // the following invariants hold: |
66 | // |
67 | // p.f == c.Node() if c.Index() < 0 |
68 | // p.f[c.Index()] == c.Node() if c.Index() >= 0 |
69 | // |
70 | // The methods Replace, Delete, InsertBefore, and InsertAfter |
71 | // can be used to change the AST without disrupting Apply. |
72 | type Cursor struct { |
73 | parent ast.Node |
74 | name string |
75 | iter *iterator // valid if non-nil |
76 | node ast.Node |
77 | } |
78 | |
79 | // Node returns the current Node. |
80 | func (c *Cursor) Node() ast.Node { return c.node } |
81 | |
82 | // Parent returns the parent of the current Node. |
83 | func (c *Cursor) Parent() ast.Node { return c.parent } |
84 | |
85 | // Name returns the name of the parent Node field that contains the current Node. |
86 | // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns |
87 | // the filename for the current Node. |
88 | func (c *Cursor) Name() string { return c.name } |
89 | |
90 | // Index reports the index >= 0 of the current Node in the slice of Nodes that |
91 | // contains it, or a value < 0 if the current Node is not part of a slice. |
92 | // The index of the current node changes if InsertBefore is called while |
93 | // processing the current node. |
94 | func (c *Cursor) Index() int { |
95 | if c.iter != nil { |
96 | return c.iter.index |
97 | } |
98 | return -1 |
99 | } |
100 | |
101 | // field returns the current node's parent field value. |
102 | func (c *Cursor) field() reflect.Value { |
103 | return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name) |
104 | } |
105 | |
106 | // Replace replaces the current Node with n. |
107 | // The replacement node is not walked by Apply. |
108 | func (c *Cursor) Replace(n ast.Node) { |
109 | if _, ok := c.node.(*ast.File); ok { |
110 | file, ok := n.(*ast.File) |
111 | if !ok { |
112 | panic("attempt to replace *ast.File with non-*ast.File") |
113 | } |
114 | c.parent.(*ast.Package).Files[c.name] = file |
115 | return |
116 | } |
117 | |
118 | v := c.field() |
119 | if i := c.Index(); i >= 0 { |
120 | v = v.Index(i) |
121 | } |
122 | v.Set(reflect.ValueOf(n)) |
123 | } |
124 | |
125 | // Delete deletes the current Node from its containing slice. |
126 | // If the current Node is not part of a slice, Delete panics. |
127 | // As a special case, if the current node is a package file, |
128 | // Delete removes it from the package's Files map. |
129 | func (c *Cursor) Delete() { |
130 | if _, ok := c.node.(*ast.File); ok { |
131 | delete(c.parent.(*ast.Package).Files, c.name) |
132 | return |
133 | } |
134 | |
135 | i := c.Index() |
136 | if i < 0 { |
137 | panic("Delete node not contained in slice") |
138 | } |
139 | v := c.field() |
140 | l := v.Len() |
141 | reflect.Copy(v.Slice(i, l), v.Slice(i+1, l)) |
142 | v.Index(l - 1).Set(reflect.Zero(v.Type().Elem())) |
143 | v.SetLen(l - 1) |
144 | c.iter.step-- |
145 | } |
146 | |
147 | // InsertAfter inserts n after the current Node in its containing slice. |
148 | // If the current Node is not part of a slice, InsertAfter panics. |
149 | // Apply does not walk n. |
150 | func (c *Cursor) InsertAfter(n ast.Node) { |
151 | i := c.Index() |
152 | if i < 0 { |
153 | panic("InsertAfter node not contained in slice") |
154 | } |
155 | v := c.field() |
156 | v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) |
157 | l := v.Len() |
158 | reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l)) |
159 | v.Index(i + 1).Set(reflect.ValueOf(n)) |
160 | c.iter.step++ |
161 | } |
162 | |
163 | // InsertBefore inserts n before the current Node in its containing slice. |
164 | // If the current Node is not part of a slice, InsertBefore panics. |
165 | // Apply will not walk n. |
166 | func (c *Cursor) InsertBefore(n ast.Node) { |
167 | i := c.Index() |
168 | if i < 0 { |
169 | panic("InsertBefore node not contained in slice") |
170 | } |
171 | v := c.field() |
172 | v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) |
173 | l := v.Len() |
174 | reflect.Copy(v.Slice(i+1, l), v.Slice(i, l)) |
175 | v.Index(i).Set(reflect.ValueOf(n)) |
176 | c.iter.index++ |
177 | } |
178 | |
179 | // application carries all the shared data so we can pass it around cheaply. |
180 | type application struct { |
181 | pre, post ApplyFunc |
182 | cursor Cursor |
183 | iter iterator |
184 | } |
185 | |
// apply visits the single node n and, recursively, its children.
//
// parent, name, and iter locate n in the tree. n is the field parent.name.
// When iter is non-nil, n is instead the element at iter.index of that slice
// field. For an *ast.Package parent, name is the filename.
//
// apply calls a.pre before walking n's children and a.post afterwards. Both
// receive a.cursor describing n. If a.pre returns false, the children and
// a.post are skipped. If a.post returns false, apply panics with abort,
// which Apply recovers to end the whole traversal.
//
// Children are visited in the order their fields appear in the go/ast
// struct definitions. This order is part of Apply's documented contract.
func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
	// convert typed nil into untyped nil
	// (e.g. a nil *ast.BlockStmt field), so it is handled by "case nil"
	// below and callbacks see a plain nil from c.Node().
	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
		n = nil
	}

	// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
	// The caller's cursor is saved here and restored before every normal
	// return. When control goes back to the parent's apply, its post
	// callback therefore sees the parent's own position.
	saved := a.cursor
	a.cursor.parent = parent
	a.cursor.name = name
	a.cursor.iter = iter
	a.cursor.node = n

	if a.pre != nil && !a.pre(&a.cursor) {
		a.cursor = saved
		return
	}

	// walk children
	// (the order of the cases matches the order of the corresponding node types in go/ast)
	// The switch works on the local n captured before pre ran. If pre called
	// Cursor.Replace, the children walked are still those of the original node.
	switch n := n.(type) {
	case nil:
		// nothing to do

	// Comments and fields
	case *ast.Comment:
		// nothing to do

	case *ast.CommentGroup:
		// NOTE(review): n cannot be nil here, because typed nils were converted
		// above. The check is redundant but harmless.
		if n != nil {
			a.applyList(n, "List")
		}

	case *ast.Field:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.FieldList:
		a.applyList(n, "List")

	// Expressions
	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
		// nothing to do

	case *ast.Ellipsis:
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.FuncLit:
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CompositeLit:
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Elts")

	case *ast.ParenExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.SelectorExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Sel", nil, n.Sel)

	case *ast.IndexExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Index", nil, n.Index)

	case *typeparams.IndexListExpr:
		a.apply(n, "X", nil, n.X)
		a.applyList(n, "Indices")

	case *ast.SliceExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Low", nil, n.Low)
		a.apply(n, "High", nil, n.High)
		a.apply(n, "Max", nil, n.Max)

	case *ast.TypeAssertExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Type", nil, n.Type)

	case *ast.CallExpr:
		a.apply(n, "Fun", nil, n.Fun)
		a.applyList(n, "Args")

	case *ast.StarExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.UnaryExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.BinaryExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Y", nil, n.Y)

	case *ast.KeyValueExpr:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	// Types
	case *ast.ArrayType:
		a.apply(n, "Len", nil, n.Len)
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.StructType:
		a.apply(n, "Fields", nil, n.Fields)

	case *ast.FuncType:
		// NOTE(review): typeparams.ForFuncType presumably returns the type
		// parameter list, or nil when there is none or this Go version lacks
		// generics. TODO confirm.
		if tparams := typeparams.ForFuncType(n); tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Params", nil, n.Params)
		a.apply(n, "Results", nil, n.Results)

	case *ast.InterfaceType:
		a.apply(n, "Methods", nil, n.Methods)

	case *ast.MapType:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	case *ast.ChanType:
		a.apply(n, "Value", nil, n.Value)

	// Statements
	case *ast.BadStmt:
		// nothing to do

	case *ast.DeclStmt:
		a.apply(n, "Decl", nil, n.Decl)

	case *ast.EmptyStmt:
		// nothing to do

	case *ast.LabeledStmt:
		a.apply(n, "Label", nil, n.Label)
		a.apply(n, "Stmt", nil, n.Stmt)

	case *ast.ExprStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.SendStmt:
		a.apply(n, "Chan", nil, n.Chan)
		a.apply(n, "Value", nil, n.Value)

	case *ast.IncDecStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.AssignStmt:
		a.applyList(n, "Lhs")
		a.applyList(n, "Rhs")

	case *ast.GoStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.DeferStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.ReturnStmt:
		a.applyList(n, "Results")

	case *ast.BranchStmt:
		a.apply(n, "Label", nil, n.Label)

	case *ast.BlockStmt:
		a.applyList(n, "List")

	case *ast.IfStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Body", nil, n.Body)
		a.apply(n, "Else", nil, n.Else)

	case *ast.CaseClause:
		a.applyList(n, "List")
		a.applyList(n, "Body")

	case *ast.SwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Body", nil, n.Body)

	case *ast.TypeSwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Assign", nil, n.Assign)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CommClause:
		a.apply(n, "Comm", nil, n.Comm)
		a.applyList(n, "Body")

	case *ast.SelectStmt:
		a.apply(n, "Body", nil, n.Body)

	case *ast.ForStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Post", nil, n.Post)
		a.apply(n, "Body", nil, n.Body)

	case *ast.RangeStmt:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Body", nil, n.Body)

	// Declarations
	case *ast.ImportSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Path", nil, n.Path)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.ValueSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Values")
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.TypeSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		// NOTE(review): typeparams.ForTypeSpec presumably returns the type
		// parameter list, or nil when absent/unsupported. TODO confirm.
		if tparams := typeparams.ForTypeSpec(n); tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.BadDecl:
		// nothing to do

	case *ast.GenDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Specs")

	case *ast.FuncDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Recv", nil, n.Recv)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	// Files and packages
	case *ast.File:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.applyList(n, "Decls")
		// Don't walk n.Comments; they have either been walked already if
		// they are Doc comments, or they can be easily walked explicitly.

	case *ast.Package:
		// collect and sort names for reproducible behavior
		// (map iteration order is random). Each file is visited with its
		// filename as the cursor name.
		var names []string
		for name := range n.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			a.apply(n, name, nil, n.Files[name])
		}

	default:
		// Any node type not listed above is a programming error here.
		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
	}

	// Post-order callback. Returning false aborts the entire Apply by
	// panicking with the abort sentinel.
	if a.post != nil && !a.post(&a.cursor) {
		panic(abort)
	}

	a.cursor = saved
}
460 | |
// An iterator tracks applyList's position while it walks a slice of nodes.
type iterator struct {
	// index is the slice position of the node currently being visited.
	// Cursor.InsertBefore bumps it so that it keeps pointing at that node.
	index int

	// step is how far applyList advances after the current node. It is
	// reset to 1 per element and adjusted by Cursor.Delete (-1) and
	// Cursor.InsertAfter (+1).
	step int
}
465 | |
466 | func (a *application) applyList(parent ast.Node, name string) { |
467 | // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead |
468 | saved := a.iter |
469 | a.iter.index = 0 |
470 | for { |
471 | // must reload parent.name each time, since cursor modifications might change it |
472 | v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name) |
473 | if a.iter.index >= v.Len() { |
474 | break |
475 | } |
476 | |
477 | // element x may be nil in a bad AST - be cautious |
478 | var x ast.Node |
479 | if e := v.Index(a.iter.index); e.IsValid() { |
480 | x = e.Interface().(ast.Node) |
481 | } |
482 | |
483 | a.iter.step = 1 |
484 | a.apply(parent, name, &a.iter, x) |
485 | a.iter.index += a.iter.step |
486 | } |
487 | a.iter = saved |
488 | } |
489 |
Members