// Forked from https://github.com/sourcegraph/zoekt
1package messagesize
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "io"
8 "strings"
9 "testing"
10
11 "github.com/google/go-cmp/cmp"
12 "github.com/google/go-cmp/cmp/cmpopts"
13 "github.com/stretchr/testify/require"
14 "google.golang.org/grpc"
15 "google.golang.org/protobuf/proto"
16 "google.golang.org/protobuf/testing/protocmp"
17 "google.golang.org/protobuf/types/known/timestamppb"
18
19 newspb "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1"
20)
21
22var (
23 binaryMessage = &newspb.BinaryAttachment{
24 Name: "data",
25 Data: []byte(strings.Repeat("x", 1*1024*1024)),
26 }
27
28 keyValueMessage = &newspb.KeyValueAttachment{
29 Name: "data",
30 Data: map[string]string{
31 "key1": strings.Repeat("x", 1*1024*1024),
32 "key2": "value2",
33 },
34 }
35
36 articleMessage = &newspb.Article{
37 Author: "author",
38 Date: ×tamppb.Timestamp{Seconds: 1234567890},
39 Title: "title",
40 Content: "content",
41 Status: newspb.Article_STATUS_PUBLISHED,
42 Attachments: []*newspb.Attachment{
43 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}},
44 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}},
45 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}},
46 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}},
47 },
48 }
49)
50
51func BenchmarkObserverBinary(b *testing.B) {
52 o := messageSizeObserver{
53 onSingleFunc: func(messageSizeBytes uint64) {},
54 onFinishFunc: func(totalSizeBytes uint64) {},
55 }
56
57 benchmarkObserver(b, &o, binaryMessage)
58}
59
60func BenchmarkObserverKeyValue(b *testing.B) {
61 o := messageSizeObserver{
62 onSingleFunc: func(messageSizeBytes uint64) {},
63 onFinishFunc: func(totalSizeBytes uint64) {},
64 }
65
66 benchmarkObserver(b, &o, keyValueMessage)
67}
68
69func BenchmarkObserverArticle(b *testing.B) {
70 o := messageSizeObserver{
71 onSingleFunc: func(messageSizeBytes uint64) {},
72 onFinishFunc: func(totalSizeBytes uint64) {},
73 }
74
75 benchmarkObserver(b, &o, articleMessage)
76}
77
78func benchmarkObserver(b *testing.B, observer *messageSizeObserver, message proto.Message) {
79 b.ReportAllocs()
80
81 for n := 0; n < b.N; n++ {
82 observer.Observe(message)
83 }
84
85 observer.FinishRPC()
86}
87
88func TestUnaryServerInterceptor(t *testing.T) {
89 ctx := context.Background()
90
91 request := &newspb.BinaryAttachment{
92 Data: bytes.Repeat([]byte("request"), 3),
93 }
94
95 response := &newspb.BinaryAttachment{
96 Data: bytes.Repeat([]byte("response"), 7),
97 }
98
99 info := &grpc.UnaryServerInfo{
100 FullMethod: "news.v1.NewsService/GetArticle",
101 }
102
103 sentinelError := errors.New("expected error")
104
105 tests := []struct {
106 name string
107 handler func(ctx context.Context, req any) (any, error)
108 expectedError error
109 expectedResult any
110 expectedSize uint64
111 }{
112 {
113 name: "invoker successful - observe response",
114 handler: func(ctx context.Context, req any) (any, error) {
115 return response, nil
116 },
117 expectedError: nil,
118 expectedResult: response,
119 expectedSize: uint64(proto.Size(response)),
120 },
121 {
122 name: "invoker error - observe a zero-sized response",
123 handler: func(ctx context.Context, req any) (any, error) {
124 return nil, sentinelError
125 },
126 expectedError: sentinelError,
127 expectedResult: nil,
128 expectedSize: uint64(0),
129 },
130 }
131
132 for _, test := range tests {
133 t.Run(test.name, func(t *testing.T) {
134 onFinishCalledCount := 0
135
136 observer := messageSizeObserver{
137 onSingleFunc: func(messageSizeBytes uint64) {},
138 onFinishFunc: func(totalSizeBytes uint64) {
139 onFinishCalledCount++
140
141 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" {
142 t.Error("totalSizeBytes mismatch (-want +got):\n", diff)
143 }
144 },
145 }
146
147 actualResult, err := unaryServerInterceptor(&observer, request, ctx, info, test.handler)
148 if err != test.expectedError {
149 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err)
150 }
151
152 if diff := cmp.Diff(test.expectedResult, actualResult, protocmp.Transform()); diff != "" {
153 t.Error("response mismatch (-want +got):\n", diff)
154 }
155
156 if diff := cmp.Diff(1, onFinishCalledCount); diff != "" {
157 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff)
158 }
159 })
160 }
161}
162
163func TestStreamServerInterceptor(t *testing.T) {
164
165 response1 := &newspb.BinaryAttachment{
166 Name: "",
167 Data: []byte("response"),
168 }
169 response2 := &newspb.BinaryAttachment{
170 Name: "",
171 Data: bytes.Repeat([]byte("response"), 3),
172 }
173 response3 := &newspb.BinaryAttachment{
174 Name: "",
175 Data: bytes.Repeat([]byte("response"), 7),
176 }
177
178 info := &grpc.StreamServerInfo{
179 FullMethod: "news.v1.NewsService/GetArticle",
180 }
181
182 sentinelError := errors.New("expected error")
183
184 tests := []struct {
185 name string
186
187 mockSendMsg func(m any) error
188 handler func(srv any, stream grpc.ServerStream) error
189
190 expectedError error
191 expectedResponses []any
192 expectedSize uint64
193 }{
194 {
195 name: "invoker successful - observe all 3 responses",
196
197 mockSendMsg: func(m any) error {
198 return nil // no error
199 },
200
201 handler: func(srv any, stream grpc.ServerStream) error {
202 for _, r := range []proto.Message{response1, response2, response3} {
203 if err := stream.SendMsg(r); err != nil {
204 return err
205 }
206 }
207
208 return nil
209 },
210
211 expectedError: nil,
212 expectedResponses: []any{response1, response2, response3},
213 expectedSize: uint64(proto.Size(response1) + proto.Size(response2) + proto.Size(response3)),
214 },
215
216 {
217 name: "invoker fails on 3rd response - only observe first 2",
218
219 mockSendMsg: func(m any) error {
220 if m == response3 {
221 return sentinelError
222 }
223
224 return nil
225 },
226 handler: func(srv any, stream grpc.ServerStream) error {
227 for _, r := range []proto.Message{response1, response2, response3} {
228 if err := stream.SendMsg(r); err != nil {
229 return err
230 }
231 }
232
233 return nil
234 },
235
236 expectedError: sentinelError,
237 expectedResponses: []any{response1, response2, response3}, // response 3 should still be attempted to be sent
238 expectedSize: uint64(proto.Size(response1) + proto.Size(response2)), // response 3 should not be counted since an error occurred while sending it
239 },
240
241 {
242 name: "invoker fails immediately - should still observe a zero-sized response",
243
244 mockSendMsg: func(m any) error {
245 return errors.New("should not be called")
246 },
247
248 handler: func(srv any, stream grpc.ServerStream) error {
249 return sentinelError
250 },
251
252 expectedError: sentinelError,
253 expectedResponses: []any{}, // there are no responses
254 expectedSize: uint64(0), // there are no responses, so the size is 0
255 },
256 }
257
258 for _, test := range tests {
259 t.Run(test.name, func(t *testing.T) {
260 onFinishCallCount := 0
261
262 observer := messageSizeObserver{
263 onSingleFunc: func(messageSizeBytes uint64) {},
264 onFinishFunc: func(totalSizeBytes uint64) {
265 onFinishCallCount++
266
267 if totalSizeBytes != test.expectedSize {
268 t.Errorf("totalSizeBytes mismatch (wanted: %d, got: %d)", test.expectedSize, totalSizeBytes)
269 }
270 },
271 }
272
273 var actualResponses []any
274
275 ss := &mockServerStream{
276 mockSendMsg: func(m any) error {
277 actualResponses = append(actualResponses, m)
278
279 return test.mockSendMsg(m)
280 },
281 }
282
283 err := streamServerInterceptor(&observer, nil, ss, info, test.handler)
284 if err != test.expectedError {
285 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err)
286 }
287
288 if diff := cmp.Diff(test.expectedResponses, actualResponses, protocmp.Transform(), cmpopts.EquateEmpty()); diff != "" {
289 t.Error("responses mismatch (-want +got):\n", diff)
290 }
291
292 if diff := cmp.Diff(1, onFinishCallCount); diff != "" {
293 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff)
294 }
295 })
296 }
297}
298
299func TestUnaryClientInterceptor(t *testing.T) {
300 ctx := context.Background()
301
302 request := &newspb.BinaryAttachment{
303 Name: "data",
304 Data: bytes.Repeat([]byte("request"), 3),
305 }
306
307 method := "news.v1.NewsService/GetArticle"
308
309 sentinelError := errors.New("expected error")
310
311 tests := []struct {
312 name string
313 invoker grpc.UnaryInvoker
314
315 expectedError error
316 expectedRequest any
317 expectedSize uint64
318 }{
319 {
320 name: "invoker successful - observe request size",
321 invoker: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
322 return nil
323 },
324
325 expectedError: nil,
326 expectedRequest: request,
327 expectedSize: uint64(proto.Size(request)),
328 },
329
330 {
331 name: "invoker error - observe a zero-sized response",
332 invoker: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
333 return sentinelError
334 },
335
336 expectedError: sentinelError,
337 expectedRequest: request,
338 expectedSize: uint64(0),
339 },
340 }
341
342 for _, test := range tests {
343 t.Run(test.name, func(t *testing.T) {
344 onFinishCallCount := 0
345
346 observer := messageSizeObserver{
347 onSingleFunc: func(messageSizeBytes uint64) {},
348 onFinishFunc: func(totalSizeBytes uint64) {
349 onFinishCallCount++
350
351 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" {
352 t.Error("totalSizeBytes mismatch (-want +got):\n", diff)
353 }
354 },
355 }
356
357 var actualRequest any
358
359 invokerCalled := false
360 invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
361 invokerCalled = true
362
363 actualRequest = req
364 return test.invoker(ctx, method, req, reply, cc, opts...)
365 }
366
367 err := unaryClientInterceptor(&observer, ctx, method, request, nil, nil, invoker)
368 if err != test.expectedError {
369 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err)
370 }
371
372 if !invokerCalled {
373 t.Fatal("invoker not called")
374 }
375
376 if diff := cmp.Diff(test.expectedRequest, actualRequest, protocmp.Transform()); diff != "" {
377 t.Error("request mismatch (-want +got):\n", diff)
378 }
379
380 if diff := cmp.Diff(1, onFinishCallCount); diff != "" {
381 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff)
382 }
383 })
384 }
385}
386
387func TestStreamingClientInterceptor(t *testing.T) {
388 ctx := context.Background()
389
390 request1 := &newspb.BinaryAttachment{
391 Name: "data",
392 Data: bytes.Repeat([]byte("request"), 3),
393 }
394
395 request2 := &newspb.BinaryAttachment{
396 Name: "data",
397 Data: bytes.Repeat([]byte("request"), 7),
398 }
399
400 request3 := &newspb.BinaryAttachment{
401 Name: "data",
402 Data: bytes.Repeat([]byte("request"), 13),
403 }
404
405 method := "news.v1.NewsService/GetArticle"
406
407 sentinelError := errors.New("expected error")
408
409 type stepType int
410
411 const (
412 stepSend stepType = iota
413 stepRecv
414 stepCloseSend
415 )
416
417 type step struct {
418 stepType stepType
419
420 message any
421 streamErr error
422 }
423
424 tests := []struct {
425 name string
426
427 steps []step
428 expectedSize uint64
429 }{
430 {
431 name: "invoker successful - observe request size",
432 steps: []step{
433 {
434 stepType: stepSend,
435
436 message: request1,
437 streamErr: nil,
438 },
439 {
440 stepType: stepSend,
441
442 message: request2,
443 streamErr: nil,
444 },
445 {
446 stepType: stepSend,
447
448 message: request3,
449 streamErr: nil,
450 },
451 {
452 stepType: stepRecv,
453
454 message: nil,
455 streamErr: io.EOF, // end of stream
456 },
457 },
458
459 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)),
460 },
461 {
462 name: "2nd send failed - stream aborts and should only observe first request",
463 steps: []step{
464 {
465 stepType: stepSend,
466 message: request1,
467 streamErr: nil,
468 },
469 {
470 stepType: stepSend,
471 message: request2,
472 streamErr: sentinelError,
473 },
474 },
475
476 expectedSize: uint64(proto.Size(request1)),
477 },
478 {
479 name: "recv message fails with non io.EOF error - should still observe all requests",
480 steps: []step{
481 {
482 stepType: stepSend,
483
484 message: request1,
485 streamErr: nil,
486 },
487 {
488 stepType: stepSend,
489
490 message: request2,
491 streamErr: nil,
492 },
493 {
494 stepType: stepSend,
495
496 message: request3,
497 streamErr: nil,
498 },
499 {
500 stepType: stepRecv,
501
502 message: nil,
503 streamErr: sentinelError,
504 },
505 },
506
507 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)),
508 },
509
510 {
511 name: "close send called - should observe all requests",
512 steps: []step{
513 {
514 stepType: stepSend,
515
516 message: request1,
517 streamErr: nil,
518 },
519 {
520 stepType: stepSend,
521
522 message: request2,
523 streamErr: nil,
524 },
525 {
526 stepType: stepSend,
527
528 message: request3,
529 streamErr: nil,
530 },
531 {
532 stepType: stepCloseSend,
533
534 message: nil,
535 streamErr: nil,
536 },
537 },
538
539 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)),
540 },
541 {
542 name: "close send called immediately - should observe zero-sized response",
543 steps: []step{
544 {
545 stepType: stepCloseSend,
546
547 message: nil,
548 streamErr: nil,
549 },
550 },
551
552 expectedSize: uint64(0),
553 },
554 {
555 name: "first send fails - stream should abort and observe zero-sized response",
556 steps: []step{
557 {
558 stepType: stepSend,
559
560 message: request1,
561 streamErr: sentinelError,
562 },
563 },
564
565 expectedSize: uint64(0),
566 },
567 }
568
569 for _, test := range tests {
570 t.Run(test.name, func(t *testing.T) {
571 onFinishCallCount := 0
572
573 observer := messageSizeObserver{
574 onSingleFunc: func(messageSizeBytes uint64) {},
575 onFinishFunc: func(totalSizeBytes uint64) {
576 onFinishCallCount++
577
578 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" {
579 t.Error("totalSizeBytes mismatch (-want +got):\n", diff)
580 }
581 },
582 }
583
584 baseStream := &mockClientStream{}
585 streamerCalled := false
586 streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
587 streamerCalled = true
588
589 return baseStream, nil
590 }
591
592 ss, err := streamClientInterceptor(&observer, ctx, nil, nil, method, streamer)
593 require.NoError(t, err)
594
595 // Run through all the steps, preparing the mockClientStream to return the expected errors
596 for _, step := range test.steps {
597 baseStreamCalled := false
598 var streamErr error
599
600 switch step.stepType {
601 case stepSend:
602 baseStream.mockSendMsg = func(m any) error {
603 baseStreamCalled = true
604 return step.streamErr
605 }
606
607 streamErr = ss.SendMsg(step.message)
608 case stepRecv:
609 baseStream.mockRecvMsg = func(_ any) error {
610 baseStreamCalled = true
611 return step.streamErr
612 }
613
614 streamErr = ss.RecvMsg(step.message)
615
616 case stepCloseSend:
617 baseStream.mockCloseSend = func() error {
618 baseStreamCalled = true
619 return step.streamErr
620 }
621
622 streamErr = ss.CloseSend()
623 default:
624 t.Fatalf("unknown step type: %v", step.stepType)
625 }
626
627 // ensure that the baseStream was called and errors are propagated
628 require.True(t, baseStreamCalled)
629 require.Equal(t, step.streamErr, streamErr)
630 }
631
632 if !streamerCalled {
633 t.Fatal("streamer not called")
634 }
635
636 if diff := cmp.Diff(1, onFinishCallCount); diff != "" {
637 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff)
638 }
639 })
640 }
641}
642
643func TestObserver(t *testing.T) {
644 testCases := []struct {
645 name string
646 messages []proto.Message
647 }{
648 {
649 name: "single message",
650 messages: []proto.Message{&newspb.BinaryAttachment{
651 Name: "data1",
652 Data: []byte("sample data"),
653 }},
654 },
655 {
656 name: "multiple messages",
657 messages: []proto.Message{
658 &newspb.BinaryAttachment{
659 Name: "data1",
660 Data: []byte("sample data"),
661 },
662 &newspb.KeyValueAttachment{
663 Name: "data2",
664 Data: map[string]string{
665 "key1": "value1",
666 "key2": "value2",
667 },
668 },
669 }},
670 }
671
672 for _, tc := range testCases {
673 t.Run(tc.name, func(t *testing.T) {
674 var singleMessageSizes []uint64
675 var totalSize uint64
676
677 // Create a new observer with custom onSingleFunc and onFinishFunc
678 obs := &messageSizeObserver{
679 onSingleFunc: func(messageSizeBytes uint64) {
680 singleMessageSizes = append(singleMessageSizes, messageSizeBytes)
681 },
682 onFinishFunc: func(totalSizeBytes uint64) {
683 totalSize = totalSizeBytes
684 },
685 }
686
687 // Call ObserveSingle for each message
688 for _, msg := range tc.messages {
689 obs.Observe(msg)
690 }
691
692 // Check that the singleMessageSizes are correct
693 for i, msg := range tc.messages {
694 expectedSize := uint64(proto.Size(msg))
695 require.Equal(t, expectedSize, singleMessageSizes[i])
696 }
697
698 // Call FinishRPC
699 obs.FinishRPC()
700
701 // Check that the totalSize is correct
702 expectedTotalSize := uint64(0)
703 for _, size := range singleMessageSizes {
704 expectedTotalSize += size
705 }
706 require.EqualValues(t, expectedTotalSize, totalSize)
707 })
708 }
709}
710
711type mockServerStream struct {
712 mockSendMsg func(m any) error
713
714 grpc.ServerStream
715}
716
717func (s *mockServerStream) SendMsg(m any) error {
718 if s.mockSendMsg != nil {
719 return s.mockSendMsg(m)
720 }
721
722 return errors.New("send msg not implemented")
723}
724
725type mockClientStream struct {
726 mockRecvMsg func(m any) error
727 mockSendMsg func(m any) error
728 mockCloseSend func() error
729
730 grpc.ClientStream
731}
732
733func (s *mockClientStream) SendMsg(m any) error {
734 if s.mockSendMsg != nil {
735 return s.mockSendMsg(m)
736 }
737
738 return errors.New("send msg not implemented")
739}
740
741func (s *mockClientStream) RecvMsg(m any) error {
742 if s.mockRecvMsg != nil {
743 return s.mockRecvMsg(m)
744 }
745
746 return errors.New("recv msg not implemented")
747}
748
749func (s *mockClientStream) CloseSend() error {
750 if s.mockCloseSend != nil {
751 return s.mockCloseSend()
752 }
753
754 return errors.New("close send not implemented")
755}
756
757var _ grpc.ServerStream = &mockServerStream{}
758var _ grpc.ClientStream = &mockClientStream{}