// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | |
5 | package inspector_test |
6 | |
7 | import ( |
8 | "go/ast" |
9 | "go/build" |
10 | "go/parser" |
11 | "go/token" |
12 | "log" |
13 | "path/filepath" |
14 | "reflect" |
15 | "strconv" |
16 | "strings" |
17 | "testing" |
18 | |
19 | "golang.org/x/tools/go/ast/inspector" |
20 | "golang.org/x/tools/internal/typeparams" |
21 | ) |
22 | |
23 | var netFiles []*ast.File |
24 | |
25 | func init() { |
26 | files, err := parseNetFiles() |
27 | if err != nil { |
28 | log.Fatal(err) |
29 | } |
30 | netFiles = files |
31 | } |
32 | |
// parseNetFiles locates the standard library "net" package through
// build.Default and parses every file listed in its GoFiles. It returns the
// syntax trees in that order.
//
// It stops at the first import or parse error. The FileSet used for parsing
// is local and discarded, so node positions in the returned trees cannot be
// mapped back to file/line information by the caller.
func parseNetFiles() ([]*ast.File, error) {
	pkg, err := build.Default.Import("net", "", 0)
	if err != nil {
		return nil, err
	}

	var (
		fset   = token.NewFileSet()
		parsed []*ast.File
	)
	for i := range pkg.GoFiles {
		full := filepath.Join(pkg.Dir, pkg.GoFiles[i])
		file, perr := parser.ParseFile(fset, full, nil, 0)
		if perr != nil {
			return nil, perr
		}
		parsed = append(parsed, file)
	}
	return parsed, nil
}
50 | |
51 | // TestAllNodes compares Inspector against ast.Inspect. |
52 | func TestInspectAllNodes(t *testing.T) { |
53 | inspect := inspector.New(netFiles) |
54 | |
55 | var nodesA []ast.Node |
56 | inspect.Nodes(nil, func(n ast.Node, push bool) bool { |
57 | if push { |
58 | nodesA = append(nodesA, n) |
59 | } |
60 | return true |
61 | }) |
62 | var nodesB []ast.Node |
63 | for _, f := range netFiles { |
64 | ast.Inspect(f, func(n ast.Node) bool { |
65 | if n != nil { |
66 | nodesB = append(nodesB, n) |
67 | } |
68 | return true |
69 | }) |
70 | } |
71 | compare(t, nodesA, nodesB) |
72 | } |
73 | |
74 | func TestInspectGenericNodes(t *testing.T) { |
75 | if !typeparams.Enabled { |
76 | t.Skip("type parameters are not supported at this Go version") |
77 | } |
78 | |
79 | // src is using the 16 identifiers i0, i1, ... i15 so |
80 | // we can easily verify that we've found all of them. |
81 | const src = `package a |
82 | |
83 | type I interface { ~i0|i1 } |
84 | |
85 | type T[i2, i3 interface{ ~i4 }] struct {} |
86 | |
87 | func f[i5, i6 any]() { |
88 | _ = f[i7, i8] |
89 | var x T[i9, i10] |
90 | } |
91 | |
92 | func (*T[i11, i12]) m() |
93 | |
94 | var _ i13[i14, i15] |
95 | ` |
96 | fset := token.NewFileSet() |
97 | f, _ := parser.ParseFile(fset, "a.go", src, 0) |
98 | inspect := inspector.New([]*ast.File{f}) |
99 | found := make([]bool, 16) |
100 | |
101 | indexListExprs := make(map[*typeparams.IndexListExpr]bool) |
102 | |
103 | // Verify that we reach all i* identifiers, and collect IndexListExpr nodes. |
104 | inspect.Preorder(nil, func(n ast.Node) { |
105 | switch n := n.(type) { |
106 | case *ast.Ident: |
107 | if n.Name[0] == 'i' { |
108 | index, err := strconv.Atoi(n.Name[1:]) |
109 | if err != nil { |
110 | t.Fatal(err) |
111 | } |
112 | found[index] = true |
113 | } |
114 | case *typeparams.IndexListExpr: |
115 | indexListExprs[n] = false |
116 | } |
117 | }) |
118 | for i, v := range found { |
119 | if !v { |
120 | t.Errorf("missed identifier i%d", i) |
121 | } |
122 | } |
123 | |
124 | // Verify that we can filter to IndexListExprs that we found in the first |
125 | // step. |
126 | if len(indexListExprs) == 0 { |
127 | t.Fatal("no index list exprs found") |
128 | } |
129 | inspect.Preorder([]ast.Node{&typeparams.IndexListExpr{}}, func(n ast.Node) { |
130 | ix := n.(*typeparams.IndexListExpr) |
131 | indexListExprs[ix] = true |
132 | }) |
133 | for ix, v := range indexListExprs { |
134 | if !v { |
135 | t.Errorf("inspected node %v not filtered", ix) |
136 | } |
137 | } |
138 | } |
139 | |
140 | // TestPruning compares Inspector against ast.Inspect, |
141 | // pruning descent within ast.CallExpr nodes. |
142 | func TestInspectPruning(t *testing.T) { |
143 | inspect := inspector.New(netFiles) |
144 | |
145 | var nodesA []ast.Node |
146 | inspect.Nodes(nil, func(n ast.Node, push bool) bool { |
147 | if push { |
148 | nodesA = append(nodesA, n) |
149 | _, isCall := n.(*ast.CallExpr) |
150 | return !isCall // don't descend into function calls |
151 | } |
152 | return false |
153 | }) |
154 | var nodesB []ast.Node |
155 | for _, f := range netFiles { |
156 | ast.Inspect(f, func(n ast.Node) bool { |
157 | if n != nil { |
158 | nodesB = append(nodesB, n) |
159 | _, isCall := n.(*ast.CallExpr) |
160 | return !isCall // don't descend into function calls |
161 | } |
162 | return false |
163 | }) |
164 | } |
165 | compare(t, nodesA, nodesB) |
166 | } |
167 | |
// compare reports a test failure unless nodesA and nodesB hold the same
// nodes, compared by identity, in the same order.
//
// When the lengths differ, it also reports the first index within the common
// prefix where the lists diverge, which usually pinpoints the traversal
// discrepancy. When the lengths match, every mismatching index is reported.
func compare(t *testing.T, nodesA, nodesB []ast.Node) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()
	if len(nodesA) != len(nodesB) {
		t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
		common := len(nodesA)
		if len(nodesB) < common {
			common = len(nodesB)
		}
		for i := 0; i < common; i++ {
			if a, b := nodesA[i], nodesB[i]; a != b {
				t.Errorf("first divergence at node %d: %T, %T", i, a, b)
				return
			}
		}
		return
	}
	for i := range nodesA {
		if a, b := nodesA[i], nodesB[i]; a != b {
			t.Errorf("node %d is inconsistent: %T, %T", i, a, b)
		}
	}
}
179 | |
180 | func TestTypeFiltering(t *testing.T) { |
181 | const src = `package a |
182 | func f() { |
183 | print("hi") |
184 | panic("oops") |
185 | } |
186 | ` |
187 | fset := token.NewFileSet() |
188 | f, _ := parser.ParseFile(fset, "a.go", src, 0) |
189 | inspect := inspector.New([]*ast.File{f}) |
190 | |
191 | var got []string |
192 | fn := func(n ast.Node, push bool) bool { |
193 | if push { |
194 | got = append(got, typeOf(n)) |
195 | } |
196 | return true |
197 | } |
198 | |
199 | // no type filtering |
200 | inspect.Nodes(nil, fn) |
201 | if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) { |
202 | t.Errorf("inspect: got %s, want %s", got, want) |
203 | } |
204 | |
205 | // type filtering |
206 | nodeTypes := []ast.Node{ |
207 | (*ast.BasicLit)(nil), |
208 | (*ast.CallExpr)(nil), |
209 | } |
210 | got = nil |
211 | inspect.Nodes(nodeTypes, fn) |
212 | if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) { |
213 | t.Errorf("inspect: got %s, want %s", got, want) |
214 | } |
215 | |
216 | // inspect with stack |
217 | got = nil |
218 | inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool { |
219 | if push { |
220 | var line []string |
221 | for _, n := range stack { |
222 | line = append(line, typeOf(n)) |
223 | } |
224 | got = append(got, strings.Join(line, " ")) |
225 | } |
226 | return true |
227 | }) |
228 | want := []string{ |
229 | "File FuncDecl BlockStmt ExprStmt CallExpr", |
230 | "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit", |
231 | "File FuncDecl BlockStmt ExprStmt CallExpr", |
232 | "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit", |
233 | } |
234 | if !reflect.DeepEqual(got, want) { |
235 | t.Errorf("inspect: got %s, want %s", got, want) |
236 | } |
237 | } |
238 | |
// typeOf returns the name of n's dynamic type with a leading "*ast." removed,
// so an *ast.CallExpr yields "CallExpr". Types outside package ast keep their
// full reflect name. n must be non-nil: reflect.TypeOf(nil) returns a nil
// Type, and calling String on it would panic.
func typeOf(n ast.Node) string {
	const astPtrPrefix = "*ast."
	full := reflect.TypeOf(n).String()
	return strings.TrimPrefix(full, astPtrPrefix)
}
242 | |
// The numbers below show a marginal improvement (ASTInspect/Inspect) of 3.5x,
// but a break-even point (NewInspector/(ASTInspect-Inspect)) of about 5
// traversals. Building an Inspector therefore pays off only when the same
// files are traversed roughly five or more times.
//
// BenchmarkNewInspector 4.5 ms
// BenchmarkInspect      0.33 ms
// BenchmarkASTInspect   1.2 ms
250 | |
251 | func BenchmarkNewInspector(b *testing.B) { |
252 | // Measure one-time construction overhead. |
253 | for i := 0; i < b.N; i++ { |
254 | inspector.New(netFiles) |
255 | } |
256 | } |
257 | |
258 | func BenchmarkInspect(b *testing.B) { |
259 | b.StopTimer() |
260 | inspect := inspector.New(netFiles) |
261 | b.StartTimer() |
262 | |
263 | // Measure marginal cost of traversal. |
264 | var ndecls, nlits int |
265 | for i := 0; i < b.N; i++ { |
266 | inspect.Preorder(nil, func(n ast.Node) { |
267 | switch n.(type) { |
268 | case *ast.FuncDecl: |
269 | ndecls++ |
270 | case *ast.FuncLit: |
271 | nlits++ |
272 | } |
273 | }) |
274 | } |
275 | } |
276 | |
277 | func BenchmarkASTInspect(b *testing.B) { |
278 | var ndecls, nlits int |
279 | for i := 0; i < b.N; i++ { |
280 | for _, f := range netFiles { |
281 | ast.Inspect(f, func(n ast.Node) bool { |
282 | switch n.(type) { |
283 | case *ast.FuncDecl: |
284 | ndecls++ |
285 | case *ast.FuncLit: |
286 | nlits++ |
287 | } |
288 | return true |
289 | }) |
290 | } |
291 | } |
292 | } |
293 |
Members