fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package messagesize 2 3import ( 4 "bytes" 5 "context" 6 "errors" 7 "io" 8 "strings" 9 "testing" 10 11 "github.com/google/go-cmp/cmp" 12 "github.com/google/go-cmp/cmp/cmpopts" 13 "github.com/stretchr/testify/require" 14 "google.golang.org/grpc" 15 "google.golang.org/protobuf/proto" 16 "google.golang.org/protobuf/testing/protocmp" 17 "google.golang.org/protobuf/types/known/timestamppb" 18 19 newspb "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1" 20) 21 22var ( 23 binaryMessage = &newspb.BinaryAttachment{ 24 Name: "data", 25 Data: []byte(strings.Repeat("x", 1*1024*1024)), 26 } 27 28 keyValueMessage = &newspb.KeyValueAttachment{ 29 Name: "data", 30 Data: map[string]string{ 31 "key1": strings.Repeat("x", 1*1024*1024), 32 "key2": "value2", 33 }, 34 } 35 36 articleMessage = &newspb.Article{ 37 Author: "author", 38 Date: &timestamppb.Timestamp{Seconds: 1234567890}, 39 Title: "title", 40 Content: "content", 41 Status: newspb.Article_STATUS_PUBLISHED, 42 Attachments: []*newspb.Attachment{ 43 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}}, 44 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}}, 45 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}}, 46 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}}, 47 }, 48 } 49) 50 51func BenchmarkObserverBinary(b *testing.B) { 52 o := messageSizeObserver{ 53 onSingleFunc: func(messageSizeBytes uint64) {}, 54 onFinishFunc: func(totalSizeBytes uint64) {}, 55 } 56 57 benchmarkObserver(b, &o, binaryMessage) 58} 59 60func BenchmarkObserverKeyValue(b *testing.B) { 61 o := messageSizeObserver{ 62 onSingleFunc: func(messageSizeBytes uint64) {}, 63 onFinishFunc: func(totalSizeBytes uint64) {}, 64 } 65 66 benchmarkObserver(b, &o, keyValueMessage) 67} 68 69func BenchmarkObserverArticle(b *testing.B) { 70 o := messageSizeObserver{ 71 onSingleFunc: func(messageSizeBytes uint64) {}, 72 onFinishFunc: func(totalSizeBytes uint64) {}, 73 } 74 75 benchmarkObserver(b, &o, articleMessage) 76} 77 78func benchmarkObserver(b *testing.B, observer *messageSizeObserver, message proto.Message) { 79 b.ReportAllocs() 80 81 for n := 0; n < b.N; n++ { 82 observer.Observe(message) 83 } 84 85 observer.FinishRPC() 86} 87 88func TestUnaryServerInterceptor(t *testing.T) { 89 ctx := context.Background() 90 91 request := &newspb.BinaryAttachment{ 92 Data: bytes.Repeat([]byte("request"), 3), 93 } 94 95 response := &newspb.BinaryAttachment{ 96 Data: bytes.Repeat([]byte("response"), 7), 97 } 98 99 info := &grpc.UnaryServerInfo{ 100 FullMethod: "news.v1.NewsService/GetArticle", 101 } 102 103 sentinelError := errors.New("expected error") 104 105 tests := []struct { 106 name string 107 handler func(ctx context.Context, req any) (any, error) 108 expectedError error 109 expectedResult any 110 expectedSize uint64 111 }{ 112 { 113 name: "invoker successful - observe response", 114 handler: func(ctx context.Context, req any) (any, error) { 115 return response, nil 116 }, 117 expectedError: nil, 118 expectedResult: response, 119 expectedSize: uint64(proto.Size(response)), 120 }, 121 { 122 name: "invoker error - observe a zero-sized response", 123 handler: func(ctx context.Context, req any) (any, error) { 124 return nil, sentinelError 125 }, 126 expectedError: sentinelError, 127 expectedResult: nil, 128 expectedSize: uint64(0), 129 }, 130 } 131 132 for _, test := range tests { 133 t.Run(test.name, func(t *testing.T) { 134 onFinishCalledCount := 0 135 136 observer := messageSizeObserver{ 137 onSingleFunc: func(messageSizeBytes uint64) {}, 138 onFinishFunc: func(totalSizeBytes uint64) { 139 onFinishCalledCount++ 140 141 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 142 t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 143 } 144 }, 145 } 146 147 actualResult, err := unaryServerInterceptor(&observer, request, ctx, info, test.handler) 148 if err != test.expectedError { 149 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 150 } 151 152 if diff := cmp.Diff(test.expectedResult, actualResult, protocmp.Transform()); diff != "" { 153 t.Error("response mismatch (-want +got):\n", diff) 154 } 155 156 if diff := cmp.Diff(1, onFinishCalledCount); diff != "" { 157 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 158 } 159 }) 160 } 161} 162 163func TestStreamServerInterceptor(t *testing.T) { 164 165 response1 := &newspb.BinaryAttachment{ 166 Name: "", 167 Data: []byte("response"), 168 } 169 response2 := &newspb.BinaryAttachment{ 170 Name: "", 171 Data: bytes.Repeat([]byte("response"), 3), 172 } 173 response3 := &newspb.BinaryAttachment{ 174 Name: "", 175 Data: bytes.Repeat([]byte("response"), 7), 176 } 177 178 info := &grpc.StreamServerInfo{ 179 FullMethod: "news.v1.NewsService/GetArticle", 180 } 181 182 sentinelError := errors.New("expected error") 183 184 tests := []struct { 185 name string 186 187 mockSendMsg func(m any) error 188 handler func(srv any, stream grpc.ServerStream) error 189 190 expectedError error 191 expectedResponses []any 192 expectedSize uint64 193 }{ 194 { 195 name: "invoker successful - observe all 3 responses", 196 197 mockSendMsg: func(m any) error { 198 return nil // no error 199 }, 200 201 handler: func(srv any, stream grpc.ServerStream) error { 202 for _, r := range []proto.Message{response1, response2, response3} { 203 if err := stream.SendMsg(r); err != nil { 204 return err 205 } 206 } 207 208 return nil 209 }, 210 211 expectedError: nil, 212 expectedResponses: []any{response1, response2, response3}, 213 expectedSize: uint64(proto.Size(response1) + proto.Size(response2) + proto.Size(response3)), 214 }, 215 216 { 217 name: "invoker fails on 3rd response - only observe first 2", 218 219 mockSendMsg: func(m any) error { 220 if m == response3 { 221 return sentinelError 222 } 223 224 return nil 225 }, 226 handler: func(srv any, stream grpc.ServerStream) error { 227 for _, r := range []proto.Message{response1, response2, response3} { 228 if err := stream.SendMsg(r); err != nil { 229 return err 230 } 231 } 232 233 return nil 234 }, 235 236 expectedError: sentinelError, 237 expectedResponses: []any{response1, response2, response3}, // response 3 should still be attempted to be sent 238 expectedSize: uint64(proto.Size(response1) + proto.Size(response2)), // response 3 should not be counted since an error occurred while sending it 239 }, 240 241 { 242 name: "invoker fails immediately - should still observe a zero-sized response", 243 244 mockSendMsg: func(m any) error { 245 return errors.New("should not be called") 246 }, 247 248 handler: func(srv any, stream grpc.ServerStream) error { 249 return sentinelError 250 }, 251 252 expectedError: sentinelError, 253 expectedResponses: []any{}, // there are no responses 254 expectedSize: uint64(0), // there are no responses, so the size is 0 255 }, 256 } 257 258 for _, test := range tests { 259 t.Run(test.name, func(t *testing.T) { 260 onFinishCallCount := 0 261 262 observer := messageSizeObserver{ 263 onSingleFunc: func(messageSizeBytes uint64) {}, 264 onFinishFunc: func(totalSizeBytes uint64) { 265 onFinishCallCount++ 266 267 if totalSizeBytes != test.expectedSize { 268 t.Errorf("totalSizeBytes mismatch (wanted: %d, got: %d)", test.expectedSize, totalSizeBytes) 269 } 270 }, 271 } 272 273 var actualResponses []any 274 275 ss := &mockServerStream{ 276 mockSendMsg: func(m any) error { 277 actualResponses = append(actualResponses, m) 278 279 return test.mockSendMsg(m) 280 }, 281 } 282 283 err := streamServerInterceptor(&observer, nil, ss, info, test.handler) 284 if err != test.expectedError { 285 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 286 } 287 288 if diff := cmp.Diff(test.expectedResponses, actualResponses, protocmp.Transform(), cmpopts.EquateEmpty()); diff != "" { 289 t.Error("responses mismatch (-want +got):\n", diff) 290 } 291 292 if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 293 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 294 } 295 }) 296 } 297} 298 299func TestUnaryClientInterceptor(t *testing.T) { 300 ctx := context.Background() 301 302 request := &newspb.BinaryAttachment{ 303 Name: "data", 304 Data: bytes.Repeat([]byte("request"), 3), 305 } 306 307 method := "news.v1.NewsService/GetArticle" 308 309 sentinelError := errors.New("expected error") 310 311 tests := []struct { 312 name string 313 invoker grpc.UnaryInvoker 314 315 expectedError error 316 expectedRequest any 317 expectedSize uint64 318 }{ 319 { 320 name: "invoker successful - observe request size", 321 invoker: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 322 return nil 323 }, 324 325 expectedError: nil, 326 expectedRequest: request, 327 expectedSize: uint64(proto.Size(request)), 328 }, 329 330 { 331 name: "invoker error - observe a zero-sized response", 332 invoker: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 333 return sentinelError 334 }, 335 336 expectedError: sentinelError, 337 expectedRequest: request, 338 expectedSize: uint64(0), 339 }, 340 } 341 342 for _, test := range tests { 343 t.Run(test.name, func(t *testing.T) { 344 onFinishCallCount := 0 345 346 observer := messageSizeObserver{ 347 onSingleFunc: func(messageSizeBytes uint64) {}, 348 onFinishFunc: func(totalSizeBytes uint64) { 349 onFinishCallCount++ 350 351 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 352 t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 353 } 354 }, 355 } 356 357 var actualRequest any 358 359 invokerCalled := false 360 invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 361 invokerCalled = true 362 363 actualRequest = req 364 return test.invoker(ctx, method, req, reply, cc, opts...) 365 } 366 367 err := unaryClientInterceptor(&observer, ctx, method, request, nil, nil, invoker) 368 if err != test.expectedError { 369 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 370 } 371 372 if !invokerCalled { 373 t.Fatal("invoker not called") 374 } 375 376 if diff := cmp.Diff(test.expectedRequest, actualRequest, protocmp.Transform()); diff != "" { 377 t.Error("request mismatch (-want +got):\n", diff) 378 } 379 380 if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 381 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 382 } 383 }) 384 } 385} 386 387func TestStreamingClientInterceptor(t *testing.T) { 388 ctx := context.Background() 389 390 request1 := &newspb.BinaryAttachment{ 391 Name: "data", 392 Data: bytes.Repeat([]byte("request"), 3), 393 } 394 395 request2 := &newspb.BinaryAttachment{ 396 Name: "data", 397 Data: bytes.Repeat([]byte("request"), 7), 398 } 399 400 request3 := &newspb.BinaryAttachment{ 401 Name: "data", 402 Data: bytes.Repeat([]byte("request"), 13), 403 } 404 405 method := "news.v1.NewsService/GetArticle" 406 407 sentinelError := errors.New("expected error") 408 409 type stepType int 410 411 const ( 412 stepSend stepType = iota 413 stepRecv 414 stepCloseSend 415 ) 416 417 type step struct { 418 stepType stepType 419 420 message any 421 streamErr error 422 } 423 424 tests := []struct { 425 name string 426 427 steps []step 428 expectedSize uint64 429 }{ 430 { 431 name: "invoker successful - observe request size", 432 steps: []step{ 433 { 434 stepType: stepSend, 435 436 message: request1, 437 streamErr: nil, 438 }, 439 { 440 stepType: stepSend, 441 442 message: request2, 443 streamErr: nil, 444 }, 445 { 446 stepType: stepSend, 447 448 message: request3, 449 streamErr: nil, 450 }, 451 { 452 stepType: stepRecv, 453 454 message: nil, 455 streamErr: io.EOF, // end of stream 456 }, 457 }, 458 459 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 460 }, 461 { 462 name: "2nd send failed - stream aborts and should only observe first request", 463 steps: []step{ 464 { 465 stepType: stepSend, 466 message: request1, 467 streamErr: nil, 468 }, 469 { 470 stepType: stepSend, 471 message: request2, 472 streamErr: sentinelError, 473 }, 474 }, 475 476 expectedSize: uint64(proto.Size(request1)), 477 }, 478 { 479 name: "recv message fails with non io.EOF error - should still observe all requests", 480 steps: []step{ 481 { 482 stepType: stepSend, 483 484 message: request1, 485 streamErr: nil, 486 }, 487 { 488 stepType: stepSend, 489 490 message: request2, 491 streamErr: nil, 492 }, 493 { 494 stepType: stepSend, 495 496 message: request3, 497 streamErr: nil, 498 }, 499 { 500 stepType: stepRecv, 501 502 message: nil, 503 streamErr: sentinelError, 504 }, 505 }, 506 507 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 508 }, 509 510 { 511 name: "close send called - should observe all requests", 512 steps: []step{ 513 { 514 stepType: stepSend, 515 516 message: request1, 517 streamErr: nil, 518 }, 519 { 520 stepType: stepSend, 521 522 message: request2, 523 streamErr: nil, 524 }, 525 { 526 stepType: stepSend, 527 528 message: request3, 529 streamErr: nil, 530 }, 531 { 532 stepType: stepCloseSend, 533 534 message: nil, 535 streamErr: nil, 536 }, 537 }, 538 539 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 540 }, 541 { 542 name: "close send called immediately - should observe zero-sized response", 543 steps: []step{ 544 { 545 stepType: stepCloseSend, 546 547 message: nil, 548 streamErr: nil, 549 }, 550 }, 551 552 expectedSize: uint64(0), 553 }, 554 { 555 name: "first send fails - stream should abort and observe zero-sized response", 556 steps: []step{ 557 { 558 stepType: stepSend, 559 560 message: request1, 561 streamErr: sentinelError, 562 }, 563 }, 564 565 expectedSize: uint64(0), 566 }, 567 } 568 569 for _, test := range tests { 570 t.Run(test.name, func(t *testing.T) { 571 onFinishCallCount := 0 572 573 observer := messageSizeObserver{ 574 onSingleFunc: func(messageSizeBytes uint64) {}, 575 onFinishFunc: func(totalSizeBytes uint64) { 576 onFinishCallCount++ 577 578 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 579 t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 580 } 581 }, 582 } 583 584 baseStream := &mockClientStream{} 585 streamerCalled := false 586 streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { 587 streamerCalled = true 588 589 return baseStream, nil 590 } 591 592 ss, err := streamClientInterceptor(&observer, ctx, nil, nil, method, streamer) 593 require.NoError(t, err) 594 595 // Run through all the steps, preparing the mockClientStream to return the expected errors 596 for _, step := range test.steps { 597 baseStreamCalled := false 598 var streamErr error 599 600 switch step.stepType { 601 case stepSend: 602 baseStream.mockSendMsg = func(m any) error { 603 baseStreamCalled = true 604 return step.streamErr 605 } 606 607 streamErr = ss.SendMsg(step.message) 608 case stepRecv: 609 baseStream.mockRecvMsg = func(_ any) error { 610 baseStreamCalled = true 611 return step.streamErr 612 } 613 614 streamErr = ss.RecvMsg(step.message) 615 616 case stepCloseSend: 617 baseStream.mockCloseSend = func() error { 618 baseStreamCalled = true 619 return step.streamErr 620 } 621 622 streamErr = ss.CloseSend() 623 default: 624 t.Fatalf("unknown step type: %v", step.stepType) 625 } 626 627 // ensure that the baseStream was called and errors are propagated 628 require.True(t, baseStreamCalled) 629 require.Equal(t, step.streamErr, streamErr) 630 } 631 632 if !streamerCalled { 633 t.Fatal("streamer not called") 634 } 635 636 if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 637 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 638 } 639 }) 640 } 641} 642 643func TestObserver(t *testing.T) { 644 testCases := []struct { 645 name string 646 messages []proto.Message 647 }{ 648 { 649 name: "single message", 650 messages: []proto.Message{&newspb.BinaryAttachment{ 651 Name: "data1", 652 Data: []byte("sample data"), 653 }}, 654 }, 655 { 656 name: "multiple messages", 657 messages: []proto.Message{ 658 &newspb.BinaryAttachment{ 659 Name: "data1", 660 Data: []byte("sample data"), 661 }, 662 &newspb.KeyValueAttachment{ 663 Name: "data2", 664 Data: map[string]string{ 665 "key1": "value1", 666 "key2": "value2", 667 }, 668 }, 669 }}, 670 } 671 672 for _, tc := range testCases { 673 t.Run(tc.name, func(t *testing.T) { 674 var singleMessageSizes []uint64 675 var totalSize uint64 676 677 // Create a new observer with custom onSingleFunc and onFinishFunc 678 obs := &messageSizeObserver{ 679 onSingleFunc: func(messageSizeBytes uint64) { 680 singleMessageSizes = append(singleMessageSizes, messageSizeBytes) 681 }, 682 onFinishFunc: func(totalSizeBytes uint64) { 683 totalSize = totalSizeBytes 684 }, 685 } 686 687 // Call ObserveSingle for each message 688 for _, msg := range tc.messages { 689 obs.Observe(msg) 690 } 691 692 // Check that the singleMessageSizes are correct 693 for i, msg := range tc.messages { 694 expectedSize := uint64(proto.Size(msg)) 695 require.Equal(t, expectedSize, singleMessageSizes[i]) 696 } 697 698 // Call FinishRPC 699 obs.FinishRPC() 700 701 // Check that the totalSize is correct 702 expectedTotalSize := uint64(0) 703 for _, size := range singleMessageSizes { 704 expectedTotalSize += size 705 } 706 require.EqualValues(t, expectedTotalSize, totalSize) 707 }) 708 } 709} 710 711type mockServerStream struct { 712 mockSendMsg func(m any) error 713 714 grpc.ServerStream 715} 716 717func (s *mockServerStream) SendMsg(m any) error { 718 if s.mockSendMsg != nil { 719 return s.mockSendMsg(m) 720 } 721 722 return errors.New("send msg not implemented") 723} 724 725type mockClientStream struct { 726 mockRecvMsg func(m any) error 727 mockSendMsg func(m any) error 728 mockCloseSend func() error 729 730 grpc.ClientStream 731} 732 733func (s *mockClientStream) SendMsg(m any) error { 734 if s.mockSendMsg != nil { 735 return s.mockSendMsg(m) 736 } 737 738 return errors.New("send msg not implemented") 739} 740 741func (s *mockClientStream) RecvMsg(m any) error { 742 if s.mockRecvMsg != nil { 743 return s.mockRecvMsg(m) 744 } 745 746 return errors.New("recv msg not implemented") 747} 748 749func (s *mockClientStream) CloseSend() error { 750 if s.mockCloseSend != nil { 751 return s.mockCloseSend() 752 } 753 754 return errors.New("close send not implemented") 755} 756 757var _ grpc.ServerStream = &mockServerStream{} 758var _ grpc.ClientStream = &mockClientStream{}