fork of https://github.com/sourcegraph/zoekt
1package internalerrs
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "sort"
8 "strings"
9 "testing"
10
11 "github.com/google/go-cmp/cmp/cmpopts"
12 "google.golang.org/protobuf/proto"
13 "google.golang.org/protobuf/types/known/timestamppb"
14
15 newspb "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1"
16
17 "github.com/google/go-cmp/cmp"
18 "google.golang.org/grpc"
19 "google.golang.org/grpc/codes"
20 "google.golang.org/grpc/status"
21)
22
23func TestCallBackClientStream(t *testing.T) {
24 t.Run("SendMsg calls postMessageSend with message and error", func(t *testing.T) {
25 sentinelMessage := struct{}{}
26 sentinelErr := errors.New("send error")
27
28 var called bool
29 stream := callBackClientStream{
30 ClientStream: &mockClientStream{
31 sendErr: sentinelErr,
32 },
33 postMessageSend: func(message any, err error) {
34 called = true
35
36 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
37 t.Errorf("postMessageSend called with unexpected message (-want +got):\n%s", diff)
38 }
39 if !errors.Is(err, sentinelErr) {
40 t.Errorf("got %v, want %v", err, sentinelErr)
41 }
42 },
43 }
44
45 sendErr := stream.SendMsg(sentinelMessage)
46 if !called {
47 t.Error("postMessageSend not called")
48 }
49
50 if !errors.Is(sendErr, sentinelErr) {
51 t.Errorf("got %v, want %v", sendErr, sentinelErr)
52 }
53 })
54
55 t.Run("RecvMsg calls postMessageReceive with message and error", func(t *testing.T) {
56 sentinelMessage := struct{}{}
57 sentinelErr := errors.New("receive error")
58
59 var called bool
60 stream := callBackClientStream{
61 ClientStream: &mockClientStream{
62 recvErr: sentinelErr,
63 },
64 postMessageReceive: func(message any, err error) {
65 called = true
66
67 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
68 t.Errorf("postMessageReceive called with unexpected message (-want +got):\n%s", diff)
69 }
70 if !errors.Is(err, sentinelErr) {
71 t.Errorf("got %v, want %v", err, sentinelErr)
72 }
73 },
74 }
75
76 receiveErr := stream.RecvMsg(sentinelMessage)
77 if !called {
78 t.Error("postMessageReceive not called")
79 }
80
81 if !errors.Is(receiveErr, sentinelErr) {
82 t.Errorf("got %v, want %v", receiveErr, sentinelErr)
83 }
84 })
85}
86
87func TestRequestSavingClientStream_InitialRequest(t *testing.T) {
88 // Setup: create a mock ClientStream that returns a sentinel error on SendMsg
89 sentinelErr := errors.New("send error")
90 mockClientStream := &mockClientStream{
91 sendErr: sentinelErr,
92 }
93
94 // Setup: create a requestSavingClientStream with the mock ClientStream
95 stream := &requestSavingClientStream{
96 ClientStream: mockClientStream,
97 }
98
99 // Setup: create a sample proto.Message for the request
100 request := &newspb.BinaryAttachment{
101 Name: "sample_request",
102 Data: []byte("sample data"),
103 }
104
105 // Test: call SendMsg with the request
106 err := stream.SendMsg(request)
107
108 // Check: assert SendMsg propagates the error
109 if !errors.Is(err, sentinelErr) {
110 t.Errorf("got %v, want %v", err, sentinelErr)
111 }
112
113 // Check: assert InitialRequest returns the request
114 if diff := cmp.Diff(request, *stream.InitialRequest(), cmpopts.IgnoreUnexported(newspb.BinaryAttachment{})); diff != "" {
115 t.Fatalf("InitialRequest() (-want +got):\n%s", diff)
116 }
117}
118
119// mockClientStream is a grpc.ClientStream that returns a given error on SendMsg and RecvMsg.
120type mockClientStream struct {
121 grpc.ClientStream
122 sendErr error
123 recvErr error
124}
125
126func (s *mockClientStream) SendMsg(any) error {
127 return s.sendErr
128}
129
130func (s *mockClientStream) RecvMsg(any) error {
131 return s.recvErr
132}
133
134func TestCallBackServerStream(t *testing.T) {
135 t.Run("SendMsg calls postMessageSend with message and error", func(t *testing.T) {
136 sentinelMessage := struct{}{}
137 sentinelErr := errors.New("send error")
138
139 var called bool
140 stream := callBackServerStream{
141 ServerStream: &mockServerStream{
142 sendErr: sentinelErr,
143 },
144 postMessageSend: func(message any, err error) {
145 called = true
146
147 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
148 t.Errorf("postMessageSend called with unexpected message (-want +got):\n%s", diff)
149 }
150 if !errors.Is(err, sentinelErr) {
151 t.Errorf("got %v, want %v", err, sentinelErr)
152 }
153 },
154 }
155
156 sendErr := stream.SendMsg(sentinelMessage)
157 if !called {
158 t.Error("postMessageSend not called")
159 }
160
161 if !errors.Is(sendErr, sentinelErr) {
162 t.Errorf("got %v, want %v", sendErr, sentinelErr)
163 }
164 })
165
166 t.Run("RecvMsg calls postMessageReceive with message and error", func(t *testing.T) {
167 sentinelMessage := struct{}{}
168 sentinelErr := errors.New("receive error")
169
170 var called bool
171 stream := callBackServerStream{
172 ServerStream: &mockServerStream{
173 recvErr: sentinelErr,
174 },
175 postMessageReceive: func(message any, err error) {
176 called = true
177
178 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
179 t.Errorf("postMessageReceive called with unexpected message (-want +got):\n%s", diff)
180 }
181 if !errors.Is(err, sentinelErr) {
182 t.Errorf("got %v, want %v", err, sentinelErr)
183 }
184 },
185 }
186
187 receiveErr := stream.RecvMsg(sentinelMessage)
188 if !called {
189 t.Error("postMessageReceive not called")
190 }
191
192 if !errors.Is(receiveErr, sentinelErr) {
193 t.Errorf("got %v, want %v", receiveErr, sentinelErr)
194 }
195 })
196}
197
198func TestRequestSavingServerStream_InitialRequest(t *testing.T) {
199 // Setup: create a mock ServerStream that returns a sentinel error on SendMsg
200 sentinelErr := errors.New("receive error")
201 mockServerStream := &mockServerStream{
202 recvErr: sentinelErr,
203 }
204
205 // Setup: create a requestSavingServerStream with the mock ServerStream
206 stream := &requestSavingServerStream{
207 ServerStream: mockServerStream,
208 }
209
210 // Setup: create a sample proto.Message for the request
211 request := &newspb.BinaryAttachment{
212 Name: "sample_request",
213 Data: []byte("sample data"),
214 }
215
216 // Test: call RecvMsg with the request
217 err := stream.RecvMsg(request)
218
219 // Check: assert RecvMsg propagates the error
220 if !errors.Is(err, sentinelErr) {
221 t.Errorf("got %v, want %v", err, sentinelErr)
222 }
223
224 // Check: assert InitialRequest returns the request
225 if diff := cmp.Diff(request, *stream.InitialRequest(), cmpopts.IgnoreUnexported(newspb.BinaryAttachment{})); diff != "" {
226 t.Fatalf("InitialRequest() (-want +got):\n%s", diff)
227 }
228}
229
230// mockServerStream is a grpc.ServerStream that returns a given error on SendMsg and RecvMsg.
231type mockServerStream struct {
232 grpc.ServerStream
233 sendErr error
234 recvErr error
235}
236
237func (s *mockServerStream) SendMsg(any) error {
238 return s.sendErr
239}
240
241func (s *mockServerStream) RecvMsg(any) error {
242 return s.recvErr
243}
244
245func TestProbablyInternalGRPCError(t *testing.T) {
246 checker := func(s *status.Status) bool {
247 return strings.HasPrefix(s.Message(), "custom error")
248 }
249
250 testCases := []struct {
251 status *status.Status
252 checkers []internalGRPCErrorChecker
253 wantResult bool
254 }{
255 {
256 status: status.New(codes.OK, ""),
257 checkers: []internalGRPCErrorChecker{func(*status.Status) bool { return true }},
258 wantResult: false,
259 },
260 {
261 status: status.New(codes.Internal, "custom error message"),
262 checkers: []internalGRPCErrorChecker{checker},
263 wantResult: true,
264 },
265 {
266 status: status.New(codes.Internal, "some other error"),
267 checkers: []internalGRPCErrorChecker{checker},
268 wantResult: false,
269 },
270 }
271
272 for _, tc := range testCases {
273 gotResult := probablyInternalGRPCError(tc.status, tc.checkers)
274 if gotResult != tc.wantResult {
275 t.Errorf("probablyInternalGRPCError(%v, %v) = %v, want %v", tc.status, tc.checkers, gotResult, tc.wantResult)
276 }
277 }
278}
279
280func TestGRPCResourceExhaustedChecker(t *testing.T) {
281 testCases := []struct {
282 status *status.Status
283 expectPass bool
284 }{
285 {
286 status: status.New(codes.ResourceExhausted, "trying to send message larger than max (1024 vs 2)"),
287 expectPass: true,
288 },
289 {
290 status: status.New(codes.ResourceExhausted, "some other error"),
291 expectPass: false,
292 },
293 {
294 status: status.New(codes.OK, "trying to send message larger than max (1024 vs 5)"),
295 expectPass: false,
296 },
297 }
298
299 for _, tc := range testCases {
300 actual := gRPCResourceExhaustedChecker(tc.status)
301 if actual != tc.expectPass {
302 t.Errorf("gRPCResourceExhaustedChecker(%v) got %t, want %t", tc.status, actual, tc.expectPass)
303 }
304 }
305}
306
307func TestGRPCPrefixChecker(t *testing.T) {
308 tests := []struct {
309 status *status.Status
310 want bool
311 }{
312 {
313 status: status.New(codes.OK, "not a grpc error"),
314 want: false,
315 },
316 {
317 status: status.New(codes.Internal, "grpc: internal server error"),
318 want: true,
319 },
320 {
321 status: status.New(codes.Unavailable, "some other error"),
322 want: false,
323 },
324 }
325 for _, test := range tests {
326 got := gRPCPrefixChecker(test.status)
327 if got != test.want {
328 t.Errorf("gRPCPrefixChecker(%v) = %v, want %v", test.status, got, test.want)
329 }
330 }
331}
332
333func TestGRPCUnexpectedContentTypeChecker(t *testing.T) {
334 tests := []struct {
335 name string
336 status *status.Status
337 want bool
338 }{
339 {
340 name: "gRPC error with OK status",
341 status: status.New(codes.OK, "transport: received unexpected content-type"),
342 want: false,
343 },
344 {
345 name: "gRPC error without unexpected content-type message",
346 status: status.New(codes.Internal, "some random error"),
347 want: false,
348 },
349 {
350 name: "gRPC error with unexpected content-type message",
351 status: status.Newf(codes.Internal, "transport: received unexpected content-type %q", "application/octet-stream"),
352 want: true,
353 },
354 {
355 name: "gRPC error with unexpected content-type message as part of chain",
356 status: status.Newf(codes.Unknown, "transport: malformed grpc-status %q; transport: received unexpected content-type %q", "random-status", "application/octet-stream"),
357 want: true,
358 },
359 }
360
361 for _, tt := range tests {
362 t.Run(tt.name, func(t *testing.T) {
363 if got := gRPCUnexpectedContentTypeChecker(tt.status); got != tt.want {
364 t.Errorf("gRPCUnexpectedContentTypeChecker() = %v, want %v", got, tt.want)
365 }
366 })
367 }
368}
369
370func TestFindNonUTF8StringFields(t *testing.T) {
371 // Create instances of the BinaryAttachment and KeyValueAttachment messages
372 invalidBinaryAttachment := &newspb.BinaryAttachment{
373 Name: "inval\x80id_binary",
374 Data: []byte("sample data"),
375 }
376
377 invalidKeyValueAttachment := &newspb.KeyValueAttachment{
378 Name: "inval\x80id_key_value",
379 Data: map[string]string{
380 "key1": "value1",
381 "key2": "inval\x80id_value",
382 },
383 }
384
385 // Create a sample Article message with invalid UTF-8 strings
386 article := &newspb.Article{
387 Author: "inval\x80id_author",
388 Date: ×tamppb.Timestamp{Seconds: 1234567890},
389 Title: "valid_title",
390 Content: "valid_content",
391 Status: newspb.Article_STATUS_PUBLISHED,
392 Attachments: []*newspb.Attachment{
393 {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: invalidBinaryAttachment}},
394 {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: invalidKeyValueAttachment}},
395 },
396 }
397
398 tests := []struct {
399 name string
400 message proto.Message
401 expectedPaths []string
402 }{
403 {
404 name: "Article with invalid UTF-8 strings",
405 message: article,
406 expectedPaths: []string{
407 "author",
408 "attachments[0].binary_attachment.name",
409 "attachments[1].key_value_attachment.name",
410 `attachments[1].key_value_attachment.data["key2"]`,
411 },
412 },
413 {
414 name: "nil message",
415 message: nil,
416 expectedPaths: []string{},
417 },
418 }
419
420 for _, tt := range tests {
421 t.Run(tt.name, func(t *testing.T) {
422 invalidFields, err := findNonUTF8StringFields(tt.message)
423 if err != nil {
424 t.Fatalf("unexpected error: %v", err)
425 }
426
427 sort.Strings(invalidFields)
428 sort.Strings(tt.expectedPaths)
429
430 if diff := cmp.Diff(tt.expectedPaths, invalidFields, cmpopts.EquateEmpty()); diff != "" {
431 t.Fatalf("unexpected invalid fields (-want +got):\n%s", diff)
432 }
433 })
434 }
435}
436
437func TestMassageIntoStatusErr(t *testing.T) {
438 testCases := []struct {
439 description string
440 input error
441 expected *status.Status
442 expectedOk bool
443 }{
444 {
445 description: "nil error",
446 input: nil,
447 expected: nil,
448 expectedOk: false,
449 },
450 {
451 description: "status error",
452 input: status.Errorf(codes.InvalidArgument, "invalid argument"),
453 expected: status.New(codes.InvalidArgument, "invalid argument"),
454 expectedOk: true,
455 },
456 {
457 description: "context.Canceled error",
458 input: context.Canceled,
459 expected: status.New(codes.Canceled, "context canceled"),
460 expectedOk: true,
461 },
462 {
463 description: "context.DeadlineExceeded error",
464 input: context.DeadlineExceeded,
465 expected: status.New(codes.DeadlineExceeded, "context deadline exceeded"),
466 expectedOk: true,
467 },
468 {
469 description: "non-status error",
470 input: errors.New("non-status error"),
471 expected: nil,
472 expectedOk: false,
473 },
474 }
475
476 for _, tc := range testCases {
477 t.Run(tc.description, func(t *testing.T) {
478 result, ok := massageIntoStatusErr(tc.input)
479 if ok != tc.expectedOk {
480 t.Errorf("Expected ok to be %v, but got %v", tc.expectedOk, ok)
481 }
482
483 expectedStatusString := fmt.Sprintf("%s", tc.expected)
484 actualStatusString := fmt.Sprintf("%s", result)
485
486 if diff := cmp.Diff(expectedStatusString, actualStatusString); diff != "" {
487 t.Fatalf("Unexpected status string (-want +got):\n%s", diff)
488 }
489 })
490 }
491}