fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package internalerrs 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "sort" 8 "strings" 9 "testing" 10 11 "github.com/google/go-cmp/cmp/cmpopts" 12 "google.golang.org/protobuf/proto" 13 "google.golang.org/protobuf/types/known/timestamppb" 14 15 newspb "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1" 16 17 "github.com/google/go-cmp/cmp" 18 "google.golang.org/grpc" 19 "google.golang.org/grpc/codes" 20 "google.golang.org/grpc/status" 21) 22 23func TestCallBackClientStream(t *testing.T) { 24 t.Run("SendMsg calls postMessageSend with message and error", func(t *testing.T) { 25 sentinelMessage := struct{}{} 26 sentinelErr := errors.New("send error") 27 28 var called bool 29 stream := callBackClientStream{ 30 ClientStream: &mockClientStream{ 31 sendErr: sentinelErr, 32 }, 33 postMessageSend: func(message any, err error) { 34 called = true 35 36 if diff := cmp.Diff(message, sentinelMessage); diff != "" { 37 t.Errorf("postMessageSend called with unexpected message (-want +got):\n%s", diff) 38 } 39 if !errors.Is(err, sentinelErr) { 40 t.Errorf("got %v, want %v", err, sentinelErr) 41 } 42 }, 43 } 44 45 sendErr := stream.SendMsg(sentinelMessage) 46 if !called { 47 t.Error("postMessageSend not called") 48 } 49 50 if !errors.Is(sendErr, sentinelErr) { 51 t.Errorf("got %v, want %v", sendErr, sentinelErr) 52 } 53 }) 54 55 t.Run("RecvMsg calls postMessageReceive with message and error", func(t *testing.T) { 56 sentinelMessage := struct{}{} 57 sentinelErr := errors.New("receive error") 58 59 var called bool 60 stream := callBackClientStream{ 61 ClientStream: &mockClientStream{ 62 recvErr: sentinelErr, 63 }, 64 postMessageReceive: func(message any, err error) { 65 called = true 66 67 if diff := cmp.Diff(message, sentinelMessage); diff != "" { 68 t.Errorf("postMessageReceive called with unexpected message (-want +got):\n%s", diff) 69 } 70 if !errors.Is(err, sentinelErr) { 71 t.Errorf("got %v, want %v", err, sentinelErr) 72 } 73 }, 74 } 75 76 receiveErr := stream.RecvMsg(sentinelMessage) 77 if !called { 78 t.Error("postMessageReceive not called") 79 } 80 81 if !errors.Is(receiveErr, sentinelErr) { 82 t.Errorf("got %v, want %v", receiveErr, sentinelErr) 83 } 84 }) 85} 86 87func TestRequestSavingClientStream_InitialRequest(t *testing.T) { 88 // Setup: create a mock ClientStream that returns a sentinel error on SendMsg 89 sentinelErr := errors.New("send error") 90 mockClientStream := &mockClientStream{ 91 sendErr: sentinelErr, 92 } 93 94 // Setup: create a requestSavingClientStream with the mock ClientStream 95 stream := &requestSavingClientStream{ 96 ClientStream: mockClientStream, 97 } 98 99 // Setup: create a sample proto.Message for the request 100 request := &newspb.BinaryAttachment{ 101 Name: "sample_request", 102 Data: []byte("sample data"), 103 } 104 105 // Test: call SendMsg with the request 106 err := stream.SendMsg(request) 107 108 // Check: assert SendMsg propagates the error 109 if !errors.Is(err, sentinelErr) { 110 t.Errorf("got %v, want %v", err, sentinelErr) 111 } 112 113 // Check: assert InitialRequest returns the request 114 if diff := cmp.Diff(request, *stream.InitialRequest(), cmpopts.IgnoreUnexported(newspb.BinaryAttachment{})); diff != "" { 115 t.Fatalf("InitialRequest() (-want +got):\n%s", diff) 116 } 117} 118 119// mockClientStream is a grpc.ClientStream that returns a given error on SendMsg and RecvMsg. 120type mockClientStream struct { 121 grpc.ClientStream 122 sendErr error 123 recvErr error 124} 125 126func (s *mockClientStream) SendMsg(any) error { 127 return s.sendErr 128} 129 130func (s *mockClientStream) RecvMsg(any) error { 131 return s.recvErr 132} 133 134func TestCallBackServerStream(t *testing.T) { 135 t.Run("SendMsg calls postMessageSend with message and error", func(t *testing.T) { 136 sentinelMessage := struct{}{} 137 sentinelErr := errors.New("send error") 138 139 var called bool 140 stream := callBackServerStream{ 141 ServerStream: &mockServerStream{ 142 sendErr: sentinelErr, 143 }, 144 postMessageSend: func(message any, err error) { 145 called = true 146 147 if diff := cmp.Diff(message, sentinelMessage); diff != "" { 148 t.Errorf("postMessageSend called with unexpected message (-want +got):\n%s", diff) 149 } 150 if !errors.Is(err, sentinelErr) { 151 t.Errorf("got %v, want %v", err, sentinelErr) 152 } 153 }, 154 } 155 156 sendErr := stream.SendMsg(sentinelMessage) 157 if !called { 158 t.Error("postMessageSend not called") 159 } 160 161 if !errors.Is(sendErr, sentinelErr) { 162 t.Errorf("got %v, want %v", sendErr, sentinelErr) 163 } 164 }) 165 166 t.Run("RecvMsg calls postMessageReceive with message and error", func(t *testing.T) { 167 sentinelMessage := struct{}{} 168 sentinelErr := errors.New("receive error") 169 170 var called bool 171 stream := callBackServerStream{ 172 ServerStream: &mockServerStream{ 173 recvErr: sentinelErr, 174 }, 175 postMessageReceive: func(message any, err error) { 176 called = true 177 178 if diff := cmp.Diff(message, sentinelMessage); diff != "" { 179 t.Errorf("postMessageReceive called with unexpected message (-want +got):\n%s", diff) 180 } 181 if !errors.Is(err, sentinelErr) { 182 t.Errorf("got %v, want %v", err, sentinelErr) 183 } 184 }, 185 } 186 187 receiveErr := stream.RecvMsg(sentinelMessage) 188 if !called { 189 t.Error("postMessageReceive not called") 190 } 191 192 if !errors.Is(receiveErr, sentinelErr) { 193 t.Errorf("got %v, want %v", receiveErr, sentinelErr) 194 } 195 }) 196} 197 198func TestRequestSavingServerStream_InitialRequest(t *testing.T) { 199 // Setup: create a mock ServerStream that returns a sentinel error on SendMsg 200 sentinelErr := errors.New("receive error") 201 mockServerStream := &mockServerStream{ 202 recvErr: sentinelErr, 203 } 204 205 // Setup: create a requestSavingServerStream with the mock ServerStream 206 stream := &requestSavingServerStream{ 207 ServerStream: mockServerStream, 208 } 209 210 // Setup: create a sample proto.Message for the request 211 request := &newspb.BinaryAttachment{ 212 Name: "sample_request", 213 Data: []byte("sample data"), 214 } 215 216 // Test: call RecvMsg with the request 217 err := stream.RecvMsg(request) 218 219 // Check: assert RecvMsg propagates the error 220 if !errors.Is(err, sentinelErr) { 221 t.Errorf("got %v, want %v", err, sentinelErr) 222 } 223 224 // Check: assert InitialRequest returns the request 225 if diff := cmp.Diff(request, *stream.InitialRequest(), cmpopts.IgnoreUnexported(newspb.BinaryAttachment{})); diff != "" { 226 t.Fatalf("InitialRequest() (-want +got):\n%s", diff) 227 } 228} 229 230// mockServerStream is a grpc.ServerStream that returns a given error on SendMsg and RecvMsg. 231type mockServerStream struct { 232 grpc.ServerStream 233 sendErr error 234 recvErr error 235} 236 237func (s *mockServerStream) SendMsg(any) error { 238 return s.sendErr 239} 240 241func (s *mockServerStream) RecvMsg(any) error { 242 return s.recvErr 243} 244 245func TestProbablyInternalGRPCError(t *testing.T) { 246 checker := func(s *status.Status) bool { 247 return strings.HasPrefix(s.Message(), "custom error") 248 } 249 250 testCases := []struct { 251 status *status.Status 252 checkers []internalGRPCErrorChecker 253 wantResult bool 254 }{ 255 { 256 status: status.New(codes.OK, ""), 257 checkers: []internalGRPCErrorChecker{func(*status.Status) bool { return true }}, 258 wantResult: false, 259 }, 260 { 261 status: status.New(codes.Internal, "custom error message"), 262 checkers: []internalGRPCErrorChecker{checker}, 263 wantResult: true, 264 }, 265 { 266 status: status.New(codes.Internal, "some other error"), 267 checkers: []internalGRPCErrorChecker{checker}, 268 wantResult: false, 269 }, 270 } 271 272 for _, tc := range testCases { 273 gotResult := probablyInternalGRPCError(tc.status, tc.checkers) 274 if gotResult != tc.wantResult { 275 t.Errorf("probablyInternalGRPCError(%v, %v) = %v, want %v", tc.status, tc.checkers, gotResult, tc.wantResult) 276 } 277 } 278} 279 280func TestGRPCResourceExhaustedChecker(t *testing.T) { 281 testCases := []struct { 282 status *status.Status 283 expectPass bool 284 }{ 285 { 286 status: status.New(codes.ResourceExhausted, "trying to send message larger than max (1024 vs 2)"), 287 expectPass: true, 288 }, 289 { 290 status: status.New(codes.ResourceExhausted, "some other error"), 291 expectPass: false, 292 }, 293 { 294 status: status.New(codes.OK, "trying to send message larger than max (1024 vs 5)"), 295 expectPass: false, 296 }, 297 } 298 299 for _, tc := range testCases { 300 actual := gRPCResourceExhaustedChecker(tc.status) 301 if actual != tc.expectPass { 302 t.Errorf("gRPCResourceExhaustedChecker(%v) got %t, want %t", tc.status, actual, tc.expectPass) 303 } 304 } 305} 306 307func TestGRPCPrefixChecker(t *testing.T) { 308 tests := []struct { 309 status *status.Status 310 want bool 311 }{ 312 { 313 status: status.New(codes.OK, "not a grpc error"), 314 want: false, 315 }, 316 { 317 status: status.New(codes.Internal, "grpc: internal server error"), 318 want: true, 319 }, 320 { 321 status: status.New(codes.Unavailable, "some other error"), 322 want: false, 323 }, 324 } 325 for _, test := range tests { 326 got := gRPCPrefixChecker(test.status) 327 if got != test.want { 328 t.Errorf("gRPCPrefixChecker(%v) = %v, want %v", test.status, got, test.want) 329 } 330 } 331} 332 333func TestGRPCUnexpectedContentTypeChecker(t *testing.T) { 334 tests := []struct { 335 name string 336 status *status.Status 337 want bool 338 }{ 339 { 340 name: "gRPC error with OK status", 341 status: status.New(codes.OK, "transport: received unexpected content-type"), 342 want: false, 343 }, 344 { 345 name: "gRPC error without unexpected content-type message", 346 status: status.New(codes.Internal, "some random error"), 347 want: false, 348 }, 349 { 350 name: "gRPC error with unexpected content-type message", 351 status: status.Newf(codes.Internal, "transport: received unexpected content-type %q", "application/octet-stream"), 352 want: true, 353 }, 354 { 355 name: "gRPC error with unexpected content-type message as part of chain", 356 status: status.Newf(codes.Unknown, "transport: malformed grpc-status %q; transport: received unexpected content-type %q", "random-status", "application/octet-stream"), 357 want: true, 358 }, 359 } 360 361 for _, tt := range tests { 362 t.Run(tt.name, func(t *testing.T) { 363 if got := gRPCUnexpectedContentTypeChecker(tt.status); got != tt.want { 364 t.Errorf("gRPCUnexpectedContentTypeChecker() = %v, want %v", got, tt.want) 365 } 366 }) 367 } 368} 369 370func TestFindNonUTF8StringFields(t *testing.T) { 371 // Create instances of the BinaryAttachment and KeyValueAttachment messages 372 invalidBinaryAttachment := &newspb.BinaryAttachment{ 373 Name: "inval\x80id_binary", 374 Data: []byte("sample data"), 375 } 376 377 invalidKeyValueAttachment := &newspb.KeyValueAttachment{ 378 Name: "inval\x80id_key_value", 379 Data: map[string]string{ 380 "key1": "value1", 381 "key2": "inval\x80id_value", 382 }, 383 } 384 385 // Create a sample Article message with invalid UTF-8 strings 386 article := &newspb.Article{ 387 Author: "inval\x80id_author", 388 Date: &timestamppb.Timestamp{Seconds: 1234567890}, 389 Title: "valid_title", 390 Content: "valid_content", 391 Status: newspb.Article_STATUS_PUBLISHED, 392 Attachments: []*newspb.Attachment{ 393 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: invalidBinaryAttachment}}, 394 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: invalidKeyValueAttachment}}, 395 }, 396 } 397 398 tests := []struct { 399 name string 400 message proto.Message 401 expectedPaths []string 402 }{ 403 { 404 name: "Article with invalid UTF-8 strings", 405 message: article, 406 expectedPaths: []string{ 407 "author", 408 "attachments[0].binary_attachment.name", 409 "attachments[1].key_value_attachment.name", 410 `attachments[1].key_value_attachment.data["key2"]`, 411 }, 412 }, 413 { 414 name: "nil message", 415 message: nil, 416 expectedPaths: []string{}, 417 }, 418 } 419 420 for _, tt := range tests { 421 t.Run(tt.name, func(t *testing.T) { 422 invalidFields, err := findNonUTF8StringFields(tt.message) 423 if err != nil { 424 t.Fatalf("unexpected error: %v", err) 425 } 426 427 sort.Strings(invalidFields) 428 sort.Strings(tt.expectedPaths) 429 430 if diff := cmp.Diff(tt.expectedPaths, invalidFields, cmpopts.EquateEmpty()); diff != "" { 431 t.Fatalf("unexpected invalid fields (-want +got):\n%s", diff) 432 } 433 }) 434 } 435} 436 437func TestMassageIntoStatusErr(t *testing.T) { 438 testCases := []struct { 439 description string 440 input error 441 expected *status.Status 442 expectedOk bool 443 }{ 444 { 445 description: "nil error", 446 input: nil, 447 expected: nil, 448 expectedOk: false, 449 }, 450 { 451 description: "status error", 452 input: status.Errorf(codes.InvalidArgument, "invalid argument"), 453 expected: status.New(codes.InvalidArgument, "invalid argument"), 454 expectedOk: true, 455 }, 456 { 457 description: "context.Canceled error", 458 input: context.Canceled, 459 expected: status.New(codes.Canceled, "context canceled"), 460 expectedOk: true, 461 }, 462 { 463 description: "context.DeadlineExceeded error", 464 input: context.DeadlineExceeded, 465 expected: status.New(codes.DeadlineExceeded, "context deadline exceeded"), 466 expectedOk: true, 467 }, 468 { 469 description: "non-status error", 470 input: errors.New("non-status error"), 471 expected: nil, 472 expectedOk: false, 473 }, 474 } 475 476 for _, tc := range testCases { 477 t.Run(tc.description, func(t *testing.T) { 478 result, ok := massageIntoStatusErr(tc.input) 479 if ok != tc.expectedOk { 480 t.Errorf("Expected ok to be %v, but got %v", tc.expectedOk, ok) 481 } 482 483 expectedStatusString := fmt.Sprintf("%s", tc.expected) 484 actualStatusString := fmt.Sprintf("%s", result) 485 486 if diff := cmp.Diff(expectedStatusString, actualStatusString); diff != "" { 487 t.Fatalf("Unexpected status string (-want +got):\n%s", diff) 488 } 489 }) 490 } 491}