fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package messagesize 2 3import ( 4 "bytes" 5 "context" 6 "errors" 7 "io" 8 "strings" 9 "testing" 10 11 "github.com/google/go-cmp/cmp" 12 "github.com/google/go-cmp/cmp/cmpopts" 13 "github.com/stretchr/testify/require" 14 "google.golang.org/grpc" 15 "google.golang.org/protobuf/proto" 16 "google.golang.org/protobuf/testing/protocmp" 17 "google.golang.org/protobuf/types/known/timestamppb" 18 19 newspb "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1" 20) 21 22var ( 23 binaryMessage = &newspb.BinaryAttachment{ 24 Name: "data", 25 Data: []byte(strings.Repeat("x", 1*1024*1024)), 26 } 27 28 keyValueMessage = &newspb.KeyValueAttachment{ 29 Name: "data", 30 Data: map[string]string{ 31 "key1": strings.Repeat("x", 1*1024*1024), 32 "key2": "value2", 33 }, 34 } 35 36 articleMessage = &newspb.Article{ 37 Author: "author", 38 Date: &timestamppb.Timestamp{Seconds: 1234567890}, 39 Title: "title", 40 Content: "content", 41 Status: newspb.Article_STATUS_PUBLISHED, 42 Attachments: []*newspb.Attachment{ 43 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}}, 44 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}}, 45 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}}, 46 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}}, 47 }, 48 } 49) 50 51func BenchmarkObserverBinary(b *testing.B) { 52 o := messageSizeObserver{ 53 onSingleFunc: func(messageSizeBytes uint64) {}, 54 onFinishFunc: func(totalSizeBytes uint64) {}, 55 } 56 57 benchmarkObserver(b, &o, binaryMessage) 58} 59 60func BenchmarkObserverKeyValue(b *testing.B) { 61 o := messageSizeObserver{ 62 onSingleFunc: func(messageSizeBytes uint64) {}, 63 onFinishFunc: func(totalSizeBytes uint64) {}, 64 } 65 66 benchmarkObserver(b, &o, keyValueMessage) 67} 68 69func BenchmarkObserverArticle(b *testing.B) { 70 o := messageSizeObserver{ 71 onSingleFunc: func(messageSizeBytes uint64) {}, 72 onFinishFunc: func(totalSizeBytes uint64) {}, 73 } 74 75 benchmarkObserver(b, &o, articleMessage) 76} 77 78func benchmarkObserver(b *testing.B, observer *messageSizeObserver, message proto.Message) { 79 b.ReportAllocs() 80 81 for n := 0; n < b.N; n++ { 82 observer.Observe(message) 83 } 84 85 observer.FinishRPC() 86} 87 88func TestUnaryServerInterceptor(t *testing.T) { 89 ctx := context.Background() 90 91 request := &newspb.BinaryAttachment{ 92 Data: bytes.Repeat([]byte("request"), 3), 93 } 94 95 response := &newspb.BinaryAttachment{ 96 Data: bytes.Repeat([]byte("response"), 7), 97 } 98 99 info := &grpc.UnaryServerInfo{ 100 FullMethod: "news.v1.NewsService/GetArticle", 101 } 102 103 sentinelError := errors.New("expected error") 104 105 tests := []struct { 106 name string 107 handler func(ctx context.Context, req any) (any, error) 108 expectedError error 109 expectedResult any 110 expectedSize uint64 111 }{ 112 { 113 name: "invoker successful - observe response", 114 handler: func(ctx context.Context, req any) (any, error) { 115 return response, nil 116 }, 117 expectedError: nil, 118 expectedResult: response, 119 expectedSize: uint64(proto.Size(response)), 120 }, 121 { 122 name: "invoker error - observe a zero-sized response", 123 handler: func(ctx context.Context, req any) (any, error) { 124 return nil, sentinelError 125 }, 126 expectedError: sentinelError, 127 expectedResult: nil, 128 expectedSize: uint64(0), 129 }, 130 } 131 132 for _, test := range tests { 133 t.Run(test.name, func(t *testing.T) { 134 onFinishCalledCount := 0 135 136 observer := messageSizeObserver{ 137 onSingleFunc: func(messageSizeBytes uint64) {}, 138 onFinishFunc: func(totalSizeBytes uint64) { 139 onFinishCalledCount++ 140 141 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 142 t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 143 } 144 }, 145 } 146 147 actualResult, err := unaryServerInterceptor(&observer, request, ctx, info, test.handler) 148 if err != test.expectedError { 149 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 150 } 151 152 if diff := cmp.Diff(test.expectedResult, actualResult, protocmp.Transform()); diff != "" { 153 t.Error("response mismatch (-want +got):\n", diff) 154 } 155 156 if diff := cmp.Diff(1, onFinishCalledCount); diff != "" { 157 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 158 } 159 }) 160 } 161} 162 163func TestStreamServerInterceptor(t *testing.T) { 164 response1 := &newspb.BinaryAttachment{ 165 Name: "", 166 Data: []byte("response"), 167 } 168 response2 := &newspb.BinaryAttachment{ 169 Name: "", 170 Data: bytes.Repeat([]byte("response"), 3), 171 } 172 response3 := &newspb.BinaryAttachment{ 173 Name: "", 174 Data: bytes.Repeat([]byte("response"), 7), 175 } 176 177 info := &grpc.StreamServerInfo{ 178 FullMethod: "news.v1.NewsService/GetArticle", 179 } 180 181 sentinelError := errors.New("expected error") 182 183 tests := []struct { 184 name string 185 186 mockSendMsg func(m any) error 187 handler func(srv any, stream grpc.ServerStream) error 188 189 expectedError error 190 expectedResponses []any 191 expectedSize uint64 192 }{ 193 { 194 name: "invoker successful - observe all 3 responses", 195 196 mockSendMsg: func(m any) error { 197 return nil // no error 198 }, 199 200 handler: func(srv any, stream grpc.ServerStream) error { 201 for _, r := range []proto.Message{response1, response2, response3} { 202 if err := stream.SendMsg(r); err != nil { 203 return err 204 } 205 } 206 207 return nil 208 }, 209 210 expectedError: nil, 211 expectedResponses: []any{response1, response2, response3}, 212 expectedSize: uint64(proto.Size(response1) + proto.Size(response2) + proto.Size(response3)), 213 }, 214 215 { 216 name: "invoker fails on 3rd response - only observe first 2", 217 218 mockSendMsg: func(m any) error { 219 if m == response3 { 220 return sentinelError 221 } 222 223 return nil 224 }, 225 handler: func(srv any, stream grpc.ServerStream) error { 226 for _, r := range []proto.Message{response1, response2, response3} { 227 if err := stream.SendMsg(r); err != nil { 228 return err 229 } 230 } 231 232 return nil 233 }, 234 235 expectedError: sentinelError, 236 expectedResponses: []any{response1, response2, response3}, // response 3 should still be attempted to be sent 237 expectedSize: uint64(proto.Size(response1) + proto.Size(response2)), // response 3 should not be counted since an error occurred while sending it 238 }, 239 240 { 241 name: "invoker fails immediately - should still observe a zero-sized response", 242 243 mockSendMsg: func(m any) error { 244 return errors.New("should not be called") 245 }, 246 247 handler: func(srv any, stream grpc.ServerStream) error { 248 return sentinelError 249 }, 250 251 expectedError: sentinelError, 252 expectedResponses: []any{}, // there are no responses 253 expectedSize: uint64(0), // there are no responses, so the size is 0 254 }, 255 } 256 257 for _, test := range tests { 258 t.Run(test.name, func(t *testing.T) { 259 onFinishCallCount := 0 260 261 observer := messageSizeObserver{ 262 onSingleFunc: func(messageSizeBytes uint64) {}, 263 onFinishFunc: func(totalSizeBytes uint64) { 264 onFinishCallCount++ 265 266 if totalSizeBytes != test.expectedSize { 267 t.Errorf("totalSizeBytes mismatch (wanted: %d, got: %d)", test.expectedSize, totalSizeBytes) 268 } 269 }, 270 } 271 272 var actualResponses []any 273 274 ss := &mockServerStream{ 275 mockSendMsg: func(m any) error { 276 actualResponses = append(actualResponses, m) 277 278 return test.mockSendMsg(m) 279 }, 280 } 281 282 err := streamServerInterceptor(&observer, nil, ss, info, test.handler) 283 if err != test.expectedError { 284 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 285 } 286 287 if diff := cmp.Diff(test.expectedResponses, actualResponses, protocmp.Transform(), cmpopts.EquateEmpty()); diff != "" { 288 t.Error("responses mismatch (-want +got):\n", diff) 289 } 290 291 if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 292 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 293 } 294 }) 295 } 296} 297 298func TestUnaryClientInterceptor(t *testing.T) { 299 ctx := context.Background() 300 301 request := &newspb.BinaryAttachment{ 302 Name: "data", 303 Data: bytes.Repeat([]byte("request"), 3), 304 } 305 306 method := "news.v1.NewsService/GetArticle" 307 308 sentinelError := errors.New("expected error") 309 310 tests := []struct { 311 name string 312 invoker grpc.UnaryInvoker 313 314 expectedError error 315 expectedRequest any 316 expectedSize uint64 317 }{ 318 { 319 name: "invoker successful - observe request size", 320 invoker: func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 321 return nil 322 }, 323 324 expectedError: nil, 325 expectedRequest: request, 326 expectedSize: uint64(proto.Size(request)), 327 }, 328 329 { 330 name: "invoker error - observe a zero-sized response", 331 invoker: func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 332 return sentinelError 333 }, 334 335 expectedError: sentinelError, 336 expectedRequest: request, 337 expectedSize: uint64(0), 338 }, 339 } 340 341 for _, test := range tests { 342 t.Run(test.name, func(t *testing.T) { 343 onFinishCallCount := 0 344 345 observer := messageSizeObserver{ 346 onSingleFunc: func(messageSizeBytes uint64) {}, 347 onFinishFunc: func(totalSizeBytes uint64) { 348 onFinishCallCount++ 349 350 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 351 t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 352 } 353 }, 354 } 355 356 var actualRequest any 357 358 invokerCalled := false 359 invoker := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 360 invokerCalled = true 361 362 actualRequest = req 363 return test.invoker(ctx, method, req, reply, cc, opts...) 364 } 365 366 err := unaryClientInterceptor(&observer, ctx, method, request, nil, nil, invoker) 367 if err != test.expectedError { 368 t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 369 } 370 371 if !invokerCalled { 372 t.Fatal("invoker not called") 373 } 374 375 if diff := cmp.Diff(test.expectedRequest, actualRequest, protocmp.Transform()); diff != "" { 376 t.Error("request mismatch (-want +got):\n", diff) 377 } 378 379 if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 380 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 381 } 382 }) 383 } 384} 385 386func TestStreamingClientInterceptor(t *testing.T) { 387 ctx := context.Background() 388 389 request1 := &newspb.BinaryAttachment{ 390 Name: "data", 391 Data: bytes.Repeat([]byte("request"), 3), 392 } 393 394 request2 := &newspb.BinaryAttachment{ 395 Name: "data", 396 Data: bytes.Repeat([]byte("request"), 7), 397 } 398 399 request3 := &newspb.BinaryAttachment{ 400 Name: "data", 401 Data: bytes.Repeat([]byte("request"), 13), 402 } 403 404 method := "news.v1.NewsService/GetArticle" 405 406 sentinelError := errors.New("expected error") 407 408 type stepType int 409 410 const ( 411 stepSend stepType = iota 412 stepRecv 413 stepCloseSend 414 ) 415 416 type step struct { 417 stepType stepType 418 419 message any 420 streamErr error 421 } 422 423 tests := []struct { 424 name string 425 426 steps []step 427 expectedSize uint64 428 }{ 429 { 430 name: "invoker successful - observe request size", 431 steps: []step{ 432 { 433 stepType: stepSend, 434 435 message: request1, 436 streamErr: nil, 437 }, 438 { 439 stepType: stepSend, 440 441 message: request2, 442 streamErr: nil, 443 }, 444 { 445 stepType: stepSend, 446 447 message: request3, 448 streamErr: nil, 449 }, 450 { 451 stepType: stepRecv, 452 453 message: nil, 454 streamErr: io.EOF, // end of stream 455 }, 456 }, 457 458 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 459 }, 460 { 461 name: "2nd send failed - stream aborts and should only observe first request", 462 steps: []step{ 463 { 464 stepType: stepSend, 465 message: request1, 466 streamErr: nil, 467 }, 468 { 469 stepType: stepSend, 470 message: request2, 471 streamErr: sentinelError, 472 }, 473 }, 474 475 expectedSize: uint64(proto.Size(request1)), 476 }, 477 { 478 name: "recv message fails with non io.EOF error - should still observe all requests", 479 steps: []step{ 480 { 481 stepType: stepSend, 482 483 message: request1, 484 streamErr: nil, 485 }, 486 { 487 stepType: stepSend, 488 489 message: request2, 490 streamErr: nil, 491 }, 492 { 493 stepType: stepSend, 494 495 message: request3, 496 streamErr: nil, 497 }, 498 { 499 stepType: stepRecv, 500 501 message: nil, 502 streamErr: sentinelError, 503 }, 504 }, 505 506 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 507 }, 508 509 { 510 name: "close send called - should observe all requests", 511 steps: []step{ 512 { 513 stepType: stepSend, 514 515 message: request1, 516 streamErr: nil, 517 }, 518 { 519 stepType: stepSend, 520 521 message: request2, 522 streamErr: nil, 523 }, 524 { 525 stepType: stepSend, 526 527 message: request3, 528 streamErr: nil, 529 }, 530 { 531 stepType: stepCloseSend, 532 533 message: nil, 534 streamErr: nil, 535 }, 536 }, 537 538 expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 539 }, 540 { 541 name: "close send called immediately - should observe zero-sized response", 542 steps: []step{ 543 { 544 stepType: stepCloseSend, 545 546 message: nil, 547 streamErr: nil, 548 }, 549 }, 550 551 expectedSize: uint64(0), 552 }, 553 { 554 name: "first send fails - stream should abort and observe zero-sized response", 555 steps: []step{ 556 { 557 stepType: stepSend, 558 559 message: request1, 560 streamErr: sentinelError, 561 }, 562 }, 563 564 expectedSize: uint64(0), 565 }, 566 } 567 568 for _, test := range tests { 569 t.Run(test.name, func(t *testing.T) { 570 onFinishCallCount := 0 571 572 observer := messageSizeObserver{ 573 onSingleFunc: func(messageSizeBytes uint64) {}, 574 onFinishFunc: func(totalSizeBytes uint64) { 575 onFinishCallCount++ 576 577 if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 578 t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 579 } 580 }, 581 } 582 583 baseStream := &mockClientStream{} 584 streamerCalled := false 585 streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { 586 streamerCalled = true 587 588 return baseStream, nil 589 } 590 591 ss, err := streamClientInterceptor(&observer, ctx, nil, nil, method, streamer) 592 require.NoError(t, err) 593 594 // Run through all the steps, preparing the mockClientStream to return the expected errors 595 for _, step := range test.steps { 596 baseStreamCalled := false 597 var streamErr error 598 599 switch step.stepType { 600 case stepSend: 601 baseStream.mockSendMsg = func(m any) error { 602 baseStreamCalled = true 603 return step.streamErr 604 } 605 606 streamErr = ss.SendMsg(step.message) 607 case stepRecv: 608 baseStream.mockRecvMsg = func(_ any) error { 609 baseStreamCalled = true 610 return step.streamErr 611 } 612 613 streamErr = ss.RecvMsg(step.message) 614 615 case stepCloseSend: 616 baseStream.mockCloseSend = func() error { 617 baseStreamCalled = true 618 return step.streamErr 619 } 620 621 streamErr = ss.CloseSend() 622 default: 623 t.Fatalf("unknown step type: %v", step.stepType) 624 } 625 626 // ensure that the baseStream was called and errors are propagated 627 require.True(t, baseStreamCalled) 628 require.Equal(t, step.streamErr, streamErr) 629 } 630 631 if !streamerCalled { 632 t.Fatal("streamer not called") 633 } 634 635 if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 636 t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 637 } 638 }) 639 } 640} 641 642func TestObserver(t *testing.T) { 643 testCases := []struct { 644 name string 645 messages []proto.Message 646 }{ 647 { 648 name: "single message", 649 messages: []proto.Message{&newspb.BinaryAttachment{ 650 Name: "data1", 651 Data: []byte("sample data"), 652 }}, 653 }, 654 { 655 name: "multiple messages", 656 messages: []proto.Message{ 657 &newspb.BinaryAttachment{ 658 Name: "data1", 659 Data: []byte("sample data"), 660 }, 661 &newspb.KeyValueAttachment{ 662 Name: "data2", 663 Data: map[string]string{ 664 "key1": "value1", 665 "key2": "value2", 666 }, 667 }, 668 }, 669 }, 670 } 671 672 for _, tc := range testCases { 673 t.Run(tc.name, func(t *testing.T) { 674 var singleMessageSizes []uint64 675 var totalSize uint64 676 677 // Create a new observer with custom onSingleFunc and onFinishFunc 678 obs := &messageSizeObserver{ 679 onSingleFunc: func(messageSizeBytes uint64) { 680 singleMessageSizes = append(singleMessageSizes, messageSizeBytes) 681 }, 682 onFinishFunc: func(totalSizeBytes uint64) { 683 totalSize = totalSizeBytes 684 }, 685 } 686 687 // Call ObserveSingle for each message 688 for _, msg := range tc.messages { 689 obs.Observe(msg) 690 } 691 692 // Check that the singleMessageSizes are correct 693 for i, msg := range tc.messages { 694 expectedSize := uint64(proto.Size(msg)) 695 require.Equal(t, expectedSize, singleMessageSizes[i]) 696 } 697 698 // Call FinishRPC 699 obs.FinishRPC() 700 701 // Check that the totalSize is correct 702 expectedTotalSize := uint64(0) 703 for _, size := range singleMessageSizes { 704 expectedTotalSize += size 705 } 706 require.EqualValues(t, expectedTotalSize, totalSize) 707 }) 708 } 709} 710 711type mockServerStream struct { 712 mockSendMsg func(m any) error 713 714 grpc.ServerStream 715} 716 717func (s *mockServerStream) SendMsg(m any) error { 718 if s.mockSendMsg != nil { 719 return s.mockSendMsg(m) 720 } 721 722 return errors.New("send msg not implemented") 723} 724 725type mockClientStream struct { 726 mockRecvMsg func(m any) error 727 mockSendMsg func(m any) error 728 mockCloseSend func() error 729 730 grpc.ClientStream 731} 732 733func (s *mockClientStream) SendMsg(m any) error { 734 if s.mockSendMsg != nil { 735 return s.mockSendMsg(m) 736 } 737 738 return errors.New("send msg not implemented") 739} 740 741func (s *mockClientStream) RecvMsg(m any) error { 742 if s.mockRecvMsg != nil { 743 return s.mockRecvMsg(m) 744 } 745 746 return errors.New("recv msg not implemented") 747} 748 749func (s *mockClientStream) CloseSend() error { 750 if s.mockCloseSend != nil { 751 return s.mockCloseSend() 752 } 753 754 return errors.New("close send not implemented") 755} 756 757var ( 758 _ grpc.ServerStream = &mockServerStream{} 759 _ grpc.ClientStream = &mockClientStream{} 760)