1 | // Copyright 2018 The Go Authors. All rights reserved. |
---|---|
2 | // Use of this source code is governed by a BSD-style |
3 | // license that can be found in the LICENSE file. |
4 | |
5 | package jsonrpc2_test |
6 | |
7 | import ( |
8 | "context" |
9 | "encoding/json" |
10 | "fmt" |
11 | "path" |
12 | "reflect" |
13 | "testing" |
14 | |
15 | "golang.org/x/tools/internal/event/export/eventtest" |
16 | jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" |
17 | "golang.org/x/tools/internal/stack/stacktest" |
18 | ) |
19 | |
20 | var callTests = []invoker{ |
21 | call{"no_args", nil, true}, |
22 | call{"one_string", "fish", "got:fish"}, |
23 | call{"one_number", 10, "got:10"}, |
24 | call{"join", []string{"a", "b", "c"}, "a/b/c"}, |
25 | sequence{"notify", []invoker{ |
26 | notify{"set", 3}, |
27 | notify{"add", 5}, |
28 | call{"get", nil, 8}, |
29 | }}, |
30 | sequence{"preempt", []invoker{ |
31 | async{"a", "wait", "a"}, |
32 | notify{"unblock", "a"}, |
33 | collect{"a", true, false}, |
34 | }}, |
35 | sequence{"basic cancel", []invoker{ |
36 | async{"b", "wait", "b"}, |
37 | cancel{"b"}, |
38 | collect{"b", nil, true}, |
39 | }}, |
40 | sequence{"queue", []invoker{ |
41 | async{"a", "wait", "a"}, |
42 | notify{"set", 1}, |
43 | notify{"add", 2}, |
44 | notify{"add", 3}, |
45 | notify{"add", 4}, |
46 | call{"peek", nil, 0}, // accumulator will not have any adds yet |
47 | notify{"unblock", "a"}, |
48 | collect{"a", true, false}, |
49 | call{"get", nil, 10}, // accumulator now has all the adds |
50 | }}, |
51 | sequence{"fork", []invoker{ |
52 | async{"a", "fork", "a"}, |
53 | notify{"set", 1}, |
54 | notify{"add", 2}, |
55 | notify{"add", 3}, |
56 | notify{"add", 4}, |
57 | call{"get", nil, 10}, // fork will not have blocked the adds |
58 | notify{"unblock", "a"}, |
59 | collect{"a", true, false}, |
60 | }}, |
61 | sequence{"concurrent", []invoker{ |
62 | async{"a", "fork", "a"}, |
63 | notify{"unblock", "a"}, |
64 | async{"b", "fork", "b"}, |
65 | notify{"unblock", "b"}, |
66 | collect{"a", true, false}, |
67 | collect{"b", true, false}, |
68 | }}, |
69 | } |
70 | |
71 | type binder struct { |
72 | framer jsonrpc2.Framer |
73 | runTest func(*handler) |
74 | } |
75 | |
76 | type handler struct { |
77 | conn *jsonrpc2.Connection |
78 | accumulator int |
79 | waiters chan map[string]chan struct{} |
80 | calls map[string]*jsonrpc2.AsyncCall |
81 | } |
82 | |
83 | type invoker interface { |
84 | Name() string |
85 | Invoke(t *testing.T, ctx context.Context, h *handler) |
86 | } |
87 | |
88 | type notify struct { |
89 | method string |
90 | params interface{} |
91 | } |
92 | |
93 | type call struct { |
94 | method string |
95 | params interface{} |
96 | expect interface{} |
97 | } |
98 | |
99 | type async struct { |
100 | name string |
101 | method string |
102 | params interface{} |
103 | } |
104 | |
105 | type collect struct { |
106 | name string |
107 | expect interface{} |
108 | fails bool |
109 | } |
110 | |
111 | type cancel struct { |
112 | name string |
113 | } |
114 | |
115 | type sequence struct { |
116 | name string |
117 | tests []invoker |
118 | } |
119 | |
120 | type echo call |
121 | |
122 | type cancelParams struct{ ID int64 } |
123 | |
124 | func TestConnectionRaw(t *testing.T) { |
125 | testConnection(t, jsonrpc2.RawFramer()) |
126 | } |
127 | |
128 | func TestConnectionHeader(t *testing.T) { |
129 | testConnection(t, jsonrpc2.HeaderFramer()) |
130 | } |
131 | |
132 | func testConnection(t *testing.T, framer jsonrpc2.Framer) { |
133 | stacktest.NoLeak(t) |
134 | ctx := eventtest.NewContext(context.Background(), t) |
135 | listener, err := jsonrpc2.NetPipeListener(ctx) |
136 | if err != nil { |
137 | t.Fatal(err) |
138 | } |
139 | server := jsonrpc2.NewServer(ctx, listener, binder{framer, nil}) |
140 | defer func() { |
141 | listener.Close() |
142 | server.Wait() |
143 | }() |
144 | |
145 | for _, test := range callTests { |
146 | t.Run(test.Name(), func(t *testing.T) { |
147 | client, err := jsonrpc2.Dial(ctx, |
148 | listener.Dialer(), binder{framer, func(h *handler) { |
149 | defer h.conn.Close() |
150 | ctx := eventtest.NewContext(ctx, t) |
151 | test.Invoke(t, ctx, h) |
152 | if call, ok := test.(*call); ok { |
153 | // also run all simple call tests in echo mode |
154 | (*echo)(call).Invoke(t, ctx, h) |
155 | } |
156 | }}) |
157 | if err != nil { |
158 | t.Fatal(err) |
159 | } |
160 | client.Wait() |
161 | }) |
162 | } |
163 | } |
164 | |
165 | func (test notify) Name() string { return test.method } |
166 | func (test notify) Invoke(t *testing.T, ctx context.Context, h *handler) { |
167 | if err := h.conn.Notify(ctx, test.method, test.params); err != nil { |
168 | t.Fatalf("%v:Notify failed: %v", test.method, err) |
169 | } |
170 | } |
171 | |
172 | func (test call) Name() string { return test.method } |
173 | func (test call) Invoke(t *testing.T, ctx context.Context, h *handler) { |
174 | results := newResults(test.expect) |
175 | if err := h.conn.Call(ctx, test.method, test.params).Await(ctx, results); err != nil { |
176 | t.Fatalf("%v:Call failed: %v", test.method, err) |
177 | } |
178 | verifyResults(t, test.method, results, test.expect) |
179 | } |
180 | |
181 | func (test echo) Invoke(t *testing.T, ctx context.Context, h *handler) { |
182 | results := newResults(test.expect) |
183 | if err := h.conn.Call(ctx, "echo", []interface{}{test.method, test.params}).Await(ctx, results); err != nil { |
184 | t.Fatalf("%v:Echo failed: %v", test.method, err) |
185 | } |
186 | verifyResults(t, test.method, results, test.expect) |
187 | } |
188 | |
189 | func (test async) Name() string { return test.name } |
190 | func (test async) Invoke(t *testing.T, ctx context.Context, h *handler) { |
191 | h.calls[test.name] = h.conn.Call(ctx, test.method, test.params) |
192 | } |
193 | |
194 | func (test collect) Name() string { return test.name } |
195 | func (test collect) Invoke(t *testing.T, ctx context.Context, h *handler) { |
196 | o := h.calls[test.name] |
197 | results := newResults(test.expect) |
198 | err := o.Await(ctx, results) |
199 | switch { |
200 | case test.fails && err == nil: |
201 | t.Fatalf("%v:Collect was supposed to fail", test.name) |
202 | case !test.fails && err != nil: |
203 | t.Fatalf("%v:Collect failed: %v", test.name, err) |
204 | } |
205 | verifyResults(t, test.name, results, test.expect) |
206 | } |
207 | |
208 | func (test cancel) Name() string { return test.name } |
209 | func (test cancel) Invoke(t *testing.T, ctx context.Context, h *handler) { |
210 | o := h.calls[test.name] |
211 | if err := h.conn.Notify(ctx, "cancel", &cancelParams{o.ID().Raw().(int64)}); err != nil { |
212 | t.Fatalf("%v:Collect failed: %v", test.name, err) |
213 | } |
214 | } |
215 | |
216 | func (test sequence) Name() string { return test.name } |
217 | func (test sequence) Invoke(t *testing.T, ctx context.Context, h *handler) { |
218 | for _, child := range test.tests { |
219 | child.Invoke(t, ctx, h) |
220 | } |
221 | } |
222 | |
// newResults allocates an empty destination shaped like expect, for decoding
// a call's result into.
//
// A nil expect yields nil. A []interface{} expect yields a []interface{} of
// freshly allocated pointers, one per element, typed like that element (an
// empty or nil input yields a nil slice). Any other value yields a pointer to
// a new zero value of expect's dynamic type.
func newResults(expect interface{}) interface{} {
	if expect == nil {
		return nil
	}
	list, isList := expect.([]interface{})
	if !isList {
		return reflect.New(reflect.TypeOf(expect)).Interface()
	}
	var out []interface{}
	for i := range list {
		elemType := reflect.TypeOf(list[i])
		out = append(out, reflect.New(elemType).Interface())
	}
	return out
}
237 | } |
238 | |
// verifyResults compares decoded results against the expected value for
// method and reports any mismatch with t.Errorf, so the test continues.
//
// When expect is nil, results must be nil as well. Otherwise results is the
// destination built by newResults (normally a pointer); the value it points
// to is compared with expect using reflect.DeepEqual.
func verifyResults(t *testing.T, method string, results interface{}, expect interface{}) {
	t.Helper()
	if expect == nil {
		if results != nil {
			// Report what was actually received; expect is nil here.
			t.Errorf("%v:Got results %+v where none expected", method, results)
		}
		return
	}
	if results == nil {
		// reflect.Indirect on an invalid Value would panic below.
		t.Errorf("%v:Got no results, expect %+v", method, expect)
		return
	}
	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
	if !reflect.DeepEqual(val, expect) {
		t.Errorf("%v:Results are incorrect, got %+v expect %+v", method, val, expect)
	}
}
251 | } |
252 | |
253 | func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) jsonrpc2.ConnectionOptions { |
254 | h := &handler{ |
255 | conn: conn, |
256 | waiters: make(chan map[string]chan struct{}, 1), |
257 | calls: make(map[string]*jsonrpc2.AsyncCall), |
258 | } |
259 | h.waiters <- make(map[string]chan struct{}) |
260 | if b.runTest != nil { |
261 | go b.runTest(h) |
262 | } |
263 | return jsonrpc2.ConnectionOptions{ |
264 | Framer: b.framer, |
265 | Preempter: h, |
266 | Handler: h, |
267 | } |
268 | } |
269 | |
270 | func (h *handler) waiter(name string) chan struct{} { |
271 | waiters := <-h.waiters |
272 | defer func() { h.waiters <- waiters }() |
273 | waiter, found := waiters[name] |
274 | if !found { |
275 | waiter = make(chan struct{}) |
276 | waiters[name] = waiter |
277 | } |
278 | return waiter |
279 | } |
280 | |
281 | func (h *handler) Preempt(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) { |
282 | switch req.Method { |
283 | case "unblock": |
284 | var name string |
285 | if err := json.Unmarshal(req.Params, &name); err != nil { |
286 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
287 | } |
288 | close(h.waiter(name)) |
289 | return nil, nil |
290 | case "peek": |
291 | if len(req.Params) > 0 { |
292 | return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams) |
293 | } |
294 | return h.accumulator, nil |
295 | case "cancel": |
296 | var params cancelParams |
297 | if err := json.Unmarshal(req.Params, ¶ms); err != nil { |
298 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
299 | } |
300 | h.conn.Cancel(jsonrpc2.Int64ID(params.ID)) |
301 | return nil, nil |
302 | default: |
303 | return nil, jsonrpc2.ErrNotHandled |
304 | } |
305 | } |
306 | |
307 | func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) { |
308 | switch req.Method { |
309 | case "no_args": |
310 | if len(req.Params) > 0 { |
311 | return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams) |
312 | } |
313 | return true, nil |
314 | case "one_string": |
315 | var v string |
316 | if err := json.Unmarshal(req.Params, &v); err != nil { |
317 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
318 | } |
319 | return "got:" + v, nil |
320 | case "one_number": |
321 | var v int |
322 | if err := json.Unmarshal(req.Params, &v); err != nil { |
323 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
324 | } |
325 | return fmt.Sprintf("got:%d", v), nil |
326 | case "set": |
327 | var v int |
328 | if err := json.Unmarshal(req.Params, &v); err != nil { |
329 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
330 | } |
331 | h.accumulator = v |
332 | return nil, nil |
333 | case "add": |
334 | var v int |
335 | if err := json.Unmarshal(req.Params, &v); err != nil { |
336 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
337 | } |
338 | h.accumulator += v |
339 | return nil, nil |
340 | case "get": |
341 | if len(req.Params) > 0 { |
342 | return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams) |
343 | } |
344 | return h.accumulator, nil |
345 | case "join": |
346 | var v []string |
347 | if err := json.Unmarshal(req.Params, &v); err != nil { |
348 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
349 | } |
350 | return path.Join(v...), nil |
351 | case "echo": |
352 | var v []interface{} |
353 | if err := json.Unmarshal(req.Params, &v); err != nil { |
354 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
355 | } |
356 | var result interface{} |
357 | err := h.conn.Call(ctx, v[0].(string), v[1]).Await(ctx, &result) |
358 | return result, err |
359 | case "wait": |
360 | var name string |
361 | if err := json.Unmarshal(req.Params, &name); err != nil { |
362 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
363 | } |
364 | select { |
365 | case <-h.waiter(name): |
366 | return true, nil |
367 | case <-ctx.Done(): |
368 | return nil, ctx.Err() |
369 | } |
370 | case "fork": |
371 | var name string |
372 | if err := json.Unmarshal(req.Params, &name); err != nil { |
373 | return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) |
374 | } |
375 | waitFor := h.waiter(name) |
376 | go func() { |
377 | select { |
378 | case <-waitFor: |
379 | h.conn.Respond(req.ID, true, nil) |
380 | case <-ctx.Done(): |
381 | h.conn.Respond(req.ID, nil, ctx.Err()) |
382 | } |
383 | }() |
384 | return nil, jsonrpc2.ErrAsyncResponse |
385 | default: |
386 | return nil, jsonrpc2.ErrNotHandled |
387 | } |
388 | } |
389 |
Members