fork of https://github.com/sourcegraph/zoekt
1package messagesize
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "io"
8 "strings"
9 "testing"
10
11 "github.com/google/go-cmp/cmp"
12 "github.com/google/go-cmp/cmp/cmpopts"
13 "github.com/stretchr/testify/require"
14 "google.golang.org/grpc"
15 "google.golang.org/protobuf/proto"
16 "google.golang.org/protobuf/testing/protocmp"
17 "google.golang.org/protobuf/types/known/timestamppb"
18
19 newspb "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1"
20)
21
22var (
23 binaryMessage = &newspb.BinaryAttachment{
24 Name: "data",
25 Data: []byte(strings.Repeat("x", 1*1024*1024)),
26 }
27
28 keyValueMessage = &newspb.KeyValueAttachment{
29 Name: "data",
30 Data: map[string]string{
31 "key1": strings.Repeat("x", 1*1024*1024),
32 "key2": "value2",
33 },
34 }
35
36 articleMessage = &newspb.Article{
37 Author: "author",
38 Date: ×tamppb.Timestamp{Seconds: 1234567890},
39 Title: "title",
40 Content: "content",
41 Status: newspb.Article_STATUS_PUBLISHED,
42 Attachments: []*newspb.Attachment{
43 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}},
44 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}},
45 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}},
46 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}},
47 },
48 }
49)
50
51func BenchmarkObserverBinary(b *testing.B) {
52 o := messageSizeObserver{
53 onSingleFunc: func(messageSizeBytes uint64) {},
54 onFinishFunc: func(totalSizeBytes uint64) {},
55 }
56
57 benchmarkObserver(b, &o, binaryMessage)
58}
59
60func BenchmarkObserverKeyValue(b *testing.B) {
61 o := messageSizeObserver{
62 onSingleFunc: func(messageSizeBytes uint64) {},
63 onFinishFunc: func(totalSizeBytes uint64) {},
64 }
65
66 benchmarkObserver(b, &o, keyValueMessage)
67}
68
69func BenchmarkObserverArticle(b *testing.B) {
70 o := messageSizeObserver{
71 onSingleFunc: func(messageSizeBytes uint64) {},
72 onFinishFunc: func(totalSizeBytes uint64) {},
73 }
74
75 benchmarkObserver(b, &o, articleMessage)
76}
77
78func benchmarkObserver(b *testing.B, observer *messageSizeObserver, message proto.Message) {
79 b.ReportAllocs()
80
81 for n := 0; n < b.N; n++ {
82 observer.Observe(message)
83 }
84
85 observer.FinishRPC()
86}
87
88func TestUnaryServerInterceptor(t *testing.T) {
89 ctx := context.Background()
90
91 request := &newspb.BinaryAttachment{
92 Data: bytes.Repeat([]byte("request"), 3),
93 }
94
95 response := &newspb.BinaryAttachment{
96 Data: bytes.Repeat([]byte("response"), 7),
97 }
98
99 info := &grpc.UnaryServerInfo{
100 FullMethod: "news.v1.NewsService/GetArticle",
101 }
102
103 sentinelError := errors.New("expected error")
104
105 tests := []struct {
106 name string
107 handler func(ctx context.Context, req any) (any, error)
108 expectedError error
109 expectedResult any
110 expectedSize uint64
111 }{
112 {
113 name: "invoker successful - observe response",
114 handler: func(ctx context.Context, req any) (any, error) {
115 return response, nil
116 },
117 expectedError: nil,
118 expectedResult: response,
119 expectedSize: uint64(proto.Size(response)),
120 },
121 {
122 name: "invoker error - observe a zero-sized response",
123 handler: func(ctx context.Context, req any) (any, error) {
124 return nil, sentinelError
125 },
126 expectedError: sentinelError,
127 expectedResult: nil,
128 expectedSize: uint64(0),
129 },
130 }
131
132 for _, test := range tests {
133 t.Run(test.name, func(t *testing.T) {
134 onFinishCalledCount := 0
135
136 observer := messageSizeObserver{
137 onSingleFunc: func(messageSizeBytes uint64) {},
138 onFinishFunc: func(totalSizeBytes uint64) {
139 onFinishCalledCount++
140
141 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" {
142 t.Error("totalSizeBytes mismatch (-want +got):\n", diff)
143 }
144 },
145 }
146
147 actualResult, err := unaryServerInterceptor(&observer, request, ctx, info, test.handler)
148 if err != test.expectedError {
149 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err)
150 }
151
152 if diff := cmp.Diff(test.expectedResult, actualResult, protocmp.Transform()); diff != "" {
153 t.Error("response mismatch (-want +got):\n", diff)
154 }
155
156 if diff := cmp.Diff(1, onFinishCalledCount); diff != "" {
157 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff)
158 }
159 })
160 }
161}
162
163func TestStreamServerInterceptor(t *testing.T) {
164 response1 := &newspb.BinaryAttachment{
165 Name: "",
166 Data: []byte("response"),
167 }
168 response2 := &newspb.BinaryAttachment{
169 Name: "",
170 Data: bytes.Repeat([]byte("response"), 3),
171 }
172 response3 := &newspb.BinaryAttachment{
173 Name: "",
174 Data: bytes.Repeat([]byte("response"), 7),
175 }
176
177 info := &grpc.StreamServerInfo{
178 FullMethod: "news.v1.NewsService/GetArticle",
179 }
180
181 sentinelError := errors.New("expected error")
182
183 tests := []struct {
184 name string
185
186 mockSendMsg func(m any) error
187 handler func(srv any, stream grpc.ServerStream) error
188
189 expectedError error
190 expectedResponses []any
191 expectedSize uint64
192 }{
193 {
194 name: "invoker successful - observe all 3 responses",
195
196 mockSendMsg: func(m any) error {
197 return nil // no error
198 },
199
200 handler: func(srv any, stream grpc.ServerStream) error {
201 for _, r := range []proto.Message{response1, response2, response3} {
202 if err := stream.SendMsg(r); err != nil {
203 return err
204 }
205 }
206
207 return nil
208 },
209
210 expectedError: nil,
211 expectedResponses: []any{response1, response2, response3},
212 expectedSize: uint64(proto.Size(response1) + proto.Size(response2) + proto.Size(response3)),
213 },
214
215 {
216 name: "invoker fails on 3rd response - only observe first 2",
217
218 mockSendMsg: func(m any) error {
219 if m == response3 {
220 return sentinelError
221 }
222
223 return nil
224 },
225 handler: func(srv any, stream grpc.ServerStream) error {
226 for _, r := range []proto.Message{response1, response2, response3} {
227 if err := stream.SendMsg(r); err != nil {
228 return err
229 }
230 }
231
232 return nil
233 },
234
235 expectedError: sentinelError,
236 expectedResponses: []any{response1, response2, response3}, // response 3 should still be attempted to be sent
237 expectedSize: uint64(proto.Size(response1) + proto.Size(response2)), // response 3 should not be counted since an error occurred while sending it
238 },
239
240 {
241 name: "invoker fails immediately - should still observe a zero-sized response",
242
243 mockSendMsg: func(m any) error {
244 return errors.New("should not be called")
245 },
246
247 handler: func(srv any, stream grpc.ServerStream) error {
248 return sentinelError
249 },
250
251 expectedError: sentinelError,
252 expectedResponses: []any{}, // there are no responses
253 expectedSize: uint64(0), // there are no responses, so the size is 0
254 },
255 }
256
257 for _, test := range tests {
258 t.Run(test.name, func(t *testing.T) {
259 onFinishCallCount := 0
260
261 observer := messageSizeObserver{
262 onSingleFunc: func(messageSizeBytes uint64) {},
263 onFinishFunc: func(totalSizeBytes uint64) {
264 onFinishCallCount++
265
266 if totalSizeBytes != test.expectedSize {
267 t.Errorf("totalSizeBytes mismatch (wanted: %d, got: %d)", test.expectedSize, totalSizeBytes)
268 }
269 },
270 }
271
272 var actualResponses []any
273
274 ss := &mockServerStream{
275 mockSendMsg: func(m any) error {
276 actualResponses = append(actualResponses, m)
277
278 return test.mockSendMsg(m)
279 },
280 }
281
282 err := streamServerInterceptor(&observer, nil, ss, info, test.handler)
283 if err != test.expectedError {
284 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err)
285 }
286
287 if diff := cmp.Diff(test.expectedResponses, actualResponses, protocmp.Transform(), cmpopts.EquateEmpty()); diff != "" {
288 t.Error("responses mismatch (-want +got):\n", diff)
289 }
290
291 if diff := cmp.Diff(1, onFinishCallCount); diff != "" {
292 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff)
293 }
294 })
295 }
296}
297
298func TestUnaryClientInterceptor(t *testing.T) {
299 ctx := context.Background()
300
301 request := &newspb.BinaryAttachment{
302 Name: "data",
303 Data: bytes.Repeat([]byte("request"), 3),
304 }
305
306 method := "news.v1.NewsService/GetArticle"
307
308 sentinelError := errors.New("expected error")
309
310 tests := []struct {
311 name string
312 invoker grpc.UnaryInvoker
313
314 expectedError error
315 expectedRequest any
316 expectedSize uint64
317 }{
318 {
319 name: "invoker successful - observe request size",
320 invoker: func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
321 return nil
322 },
323
324 expectedError: nil,
325 expectedRequest: request,
326 expectedSize: uint64(proto.Size(request)),
327 },
328
329 {
330 name: "invoker error - observe a zero-sized response",
331 invoker: func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
332 return sentinelError
333 },
334
335 expectedError: sentinelError,
336 expectedRequest: request,
337 expectedSize: uint64(0),
338 },
339 }
340
341 for _, test := range tests {
342 t.Run(test.name, func(t *testing.T) {
343 onFinishCallCount := 0
344
345 observer := messageSizeObserver{
346 onSingleFunc: func(messageSizeBytes uint64) {},
347 onFinishFunc: func(totalSizeBytes uint64) {
348 onFinishCallCount++
349
350 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" {
351 t.Error("totalSizeBytes mismatch (-want +got):\n", diff)
352 }
353 },
354 }
355
356 var actualRequest any
357
358 invokerCalled := false
359 invoker := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
360 invokerCalled = true
361
362 actualRequest = req
363 return test.invoker(ctx, method, req, reply, cc, opts...)
364 }
365
366 err := unaryClientInterceptor(&observer, ctx, method, request, nil, nil, invoker)
367 if err != test.expectedError {
368 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err)
369 }
370
371 if !invokerCalled {
372 t.Fatal("invoker not called")
373 }
374
375 if diff := cmp.Diff(test.expectedRequest, actualRequest, protocmp.Transform()); diff != "" {
376 t.Error("request mismatch (-want +got):\n", diff)
377 }
378
379 if diff := cmp.Diff(1, onFinishCallCount); diff != "" {
380 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff)
381 }
382 })
383 }
384}
385
386func TestStreamingClientInterceptor(t *testing.T) {
387 ctx := context.Background()
388
389 request1 := &newspb.BinaryAttachment{
390 Name: "data",
391 Data: bytes.Repeat([]byte("request"), 3),
392 }
393
394 request2 := &newspb.BinaryAttachment{
395 Name: "data",
396 Data: bytes.Repeat([]byte("request"), 7),
397 }
398
399 request3 := &newspb.BinaryAttachment{
400 Name: "data",
401 Data: bytes.Repeat([]byte("request"), 13),
402 }
403
404 method := "news.v1.NewsService/GetArticle"
405
406 sentinelError := errors.New("expected error")
407
408 type stepType int
409
410 const (
411 stepSend stepType = iota
412 stepRecv
413 stepCloseSend
414 )
415
416 type step struct {
417 stepType stepType
418
419 message any
420 streamErr error
421 }
422
423 tests := []struct {
424 name string
425
426 steps []step
427 expectedSize uint64
428 }{
429 {
430 name: "invoker successful - observe request size",
431 steps: []step{
432 {
433 stepType: stepSend,
434
435 message: request1,
436 streamErr: nil,
437 },
438 {
439 stepType: stepSend,
440
441 message: request2,
442 streamErr: nil,
443 },
444 {
445 stepType: stepSend,
446
447 message: request3,
448 streamErr: nil,
449 },
450 {
451 stepType: stepRecv,
452
453 message: nil,
454 streamErr: io.EOF, // end of stream
455 },
456 },
457
458 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)),
459 },
460 {
461 name: "2nd send failed - stream aborts and should only observe first request",
462 steps: []step{
463 {
464 stepType: stepSend,
465 message: request1,
466 streamErr: nil,
467 },
468 {
469 stepType: stepSend,
470 message: request2,
471 streamErr: sentinelError,
472 },
473 },
474
475 expectedSize: uint64(proto.Size(request1)),
476 },
477 {
478 name: "recv message fails with non io.EOF error - should still observe all requests",
479 steps: []step{
480 {
481 stepType: stepSend,
482
483 message: request1,
484 streamErr: nil,
485 },
486 {
487 stepType: stepSend,
488
489 message: request2,
490 streamErr: nil,
491 },
492 {
493 stepType: stepSend,
494
495 message: request3,
496 streamErr: nil,
497 },
498 {
499 stepType: stepRecv,
500
501 message: nil,
502 streamErr: sentinelError,
503 },
504 },
505
506 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)),
507 },
508
509 {
510 name: "close send called - should observe all requests",
511 steps: []step{
512 {
513 stepType: stepSend,
514
515 message: request1,
516 streamErr: nil,
517 },
518 {
519 stepType: stepSend,
520
521 message: request2,
522 streamErr: nil,
523 },
524 {
525 stepType: stepSend,
526
527 message: request3,
528 streamErr: nil,
529 },
530 {
531 stepType: stepCloseSend,
532
533 message: nil,
534 streamErr: nil,
535 },
536 },
537
538 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)),
539 },
540 {
541 name: "close send called immediately - should observe zero-sized response",
542 steps: []step{
543 {
544 stepType: stepCloseSend,
545
546 message: nil,
547 streamErr: nil,
548 },
549 },
550
551 expectedSize: uint64(0),
552 },
553 {
554 name: "first send fails - stream should abort and observe zero-sized response",
555 steps: []step{
556 {
557 stepType: stepSend,
558
559 message: request1,
560 streamErr: sentinelError,
561 },
562 },
563
564 expectedSize: uint64(0),
565 },
566 }
567
568 for _, test := range tests {
569 t.Run(test.name, func(t *testing.T) {
570 onFinishCallCount := 0
571
572 observer := messageSizeObserver{
573 onSingleFunc: func(messageSizeBytes uint64) {},
574 onFinishFunc: func(totalSizeBytes uint64) {
575 onFinishCallCount++
576
577 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" {
578 t.Error("totalSizeBytes mismatch (-want +got):\n", diff)
579 }
580 },
581 }
582
583 baseStream := &mockClientStream{}
584 streamerCalled := false
585 streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
586 streamerCalled = true
587
588 return baseStream, nil
589 }
590
591 ss, err := streamClientInterceptor(&observer, ctx, nil, nil, method, streamer)
592 require.NoError(t, err)
593
594 // Run through all the steps, preparing the mockClientStream to return the expected errors
595 for _, step := range test.steps {
596 baseStreamCalled := false
597 var streamErr error
598
599 switch step.stepType {
600 case stepSend:
601 baseStream.mockSendMsg = func(m any) error {
602 baseStreamCalled = true
603 return step.streamErr
604 }
605
606 streamErr = ss.SendMsg(step.message)
607 case stepRecv:
608 baseStream.mockRecvMsg = func(_ any) error {
609 baseStreamCalled = true
610 return step.streamErr
611 }
612
613 streamErr = ss.RecvMsg(step.message)
614
615 case stepCloseSend:
616 baseStream.mockCloseSend = func() error {
617 baseStreamCalled = true
618 return step.streamErr
619 }
620
621 streamErr = ss.CloseSend()
622 default:
623 t.Fatalf("unknown step type: %v", step.stepType)
624 }
625
626 // ensure that the baseStream was called and errors are propagated
627 require.True(t, baseStreamCalled)
628 require.Equal(t, step.streamErr, streamErr)
629 }
630
631 if !streamerCalled {
632 t.Fatal("streamer not called")
633 }
634
635 if diff := cmp.Diff(1, onFinishCallCount); diff != "" {
636 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff)
637 }
638 })
639 }
640}
641
642func TestObserver(t *testing.T) {
643 testCases := []struct {
644 name string
645 messages []proto.Message
646 }{
647 {
648 name: "single message",
649 messages: []proto.Message{&newspb.BinaryAttachment{
650 Name: "data1",
651 Data: []byte("sample data"),
652 }},
653 },
654 {
655 name: "multiple messages",
656 messages: []proto.Message{
657 &newspb.BinaryAttachment{
658 Name: "data1",
659 Data: []byte("sample data"),
660 },
661 &newspb.KeyValueAttachment{
662 Name: "data2",
663 Data: map[string]string{
664 "key1": "value1",
665 "key2": "value2",
666 },
667 },
668 },
669 },
670 }
671
672 for _, tc := range testCases {
673 t.Run(tc.name, func(t *testing.T) {
674 var singleMessageSizes []uint64
675 var totalSize uint64
676
677 // Create a new observer with custom onSingleFunc and onFinishFunc
678 obs := &messageSizeObserver{
679 onSingleFunc: func(messageSizeBytes uint64) {
680 singleMessageSizes = append(singleMessageSizes, messageSizeBytes)
681 },
682 onFinishFunc: func(totalSizeBytes uint64) {
683 totalSize = totalSizeBytes
684 },
685 }
686
687 // Call ObserveSingle for each message
688 for _, msg := range tc.messages {
689 obs.Observe(msg)
690 }
691
692 // Check that the singleMessageSizes are correct
693 for i, msg := range tc.messages {
694 expectedSize := uint64(proto.Size(msg))
695 require.Equal(t, expectedSize, singleMessageSizes[i])
696 }
697
698 // Call FinishRPC
699 obs.FinishRPC()
700
701 // Check that the totalSize is correct
702 expectedTotalSize := uint64(0)
703 for _, size := range singleMessageSizes {
704 expectedTotalSize += size
705 }
706 require.EqualValues(t, expectedTotalSize, totalSize)
707 })
708 }
709}
710
711type mockServerStream struct {
712 mockSendMsg func(m any) error
713
714 grpc.ServerStream
715}
716
717func (s *mockServerStream) SendMsg(m any) error {
718 if s.mockSendMsg != nil {
719 return s.mockSendMsg(m)
720 }
721
722 return errors.New("send msg not implemented")
723}
724
725type mockClientStream struct {
726 mockRecvMsg func(m any) error
727 mockSendMsg func(m any) error
728 mockCloseSend func() error
729
730 grpc.ClientStream
731}
732
733func (s *mockClientStream) SendMsg(m any) error {
734 if s.mockSendMsg != nil {
735 return s.mockSendMsg(m)
736 }
737
738 return errors.New("send msg not implemented")
739}
740
741func (s *mockClientStream) RecvMsg(m any) error {
742 if s.mockRecvMsg != nil {
743 return s.mockRecvMsg(m)
744 }
745
746 return errors.New("recv msg not implemented")
747}
748
749func (s *mockClientStream) CloseSend() error {
750 if s.mockCloseSend != nil {
751 return s.mockCloseSend()
752 }
753
754 return errors.New("close send not implemented")
755}
756
757var (
758 _ grpc.ServerStream = &mockServerStream{}
759 _ grpc.ClientStream = &mockClientStream{}
760)