// Fork of https://github.com/sourcegraph/zoekt
1package internalerrs
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "sort"
8 "strings"
9 "testing"
10
11 "github.com/google/go-cmp/cmp"
12 "github.com/google/go-cmp/cmp/cmpopts"
13 newsv1 "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1"
14 "google.golang.org/grpc"
15 "google.golang.org/grpc/codes"
16 "google.golang.org/grpc/status"
17 "google.golang.org/protobuf/proto"
18 "google.golang.org/protobuf/types/known/timestamppb"
19)
20
21func TestCallBackClientStream(t *testing.T) {
22 t.Run("SendMsg calls postMessageSend with message and error", func(t *testing.T) {
23 sentinelMessage := struct{}{}
24 sentinelErr := errors.New("send error")
25
26 var called bool
27 stream := callBackClientStream{
28 ClientStream: &mockClientStream{
29 sendErr: sentinelErr,
30 },
31 postMessageSend: func(message any, err error) {
32 called = true
33
34 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
35 t.Errorf("postMessageSend called with unexpected message (-want +got):\n%s", diff)
36 }
37 if !errors.Is(err, sentinelErr) {
38 t.Errorf("got %v, want %v", err, sentinelErr)
39 }
40 },
41 }
42
43 sendErr := stream.SendMsg(sentinelMessage)
44 if !called {
45 t.Error("postMessageSend not called")
46 }
47
48 if !errors.Is(sendErr, sentinelErr) {
49 t.Errorf("got %v, want %v", sendErr, sentinelErr)
50 }
51 })
52
53 t.Run("RecvMsg calls postMessageReceive with message and error", func(t *testing.T) {
54 sentinelMessage := struct{}{}
55 sentinelErr := errors.New("receive error")
56
57 var called bool
58 stream := callBackClientStream{
59 ClientStream: &mockClientStream{
60 recvErr: sentinelErr,
61 },
62 postMessageReceive: func(message any, err error) {
63 called = true
64
65 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
66 t.Errorf("postMessageReceive called with unexpected message (-want +got):\n%s", diff)
67 }
68 if !errors.Is(err, sentinelErr) {
69 t.Errorf("got %v, want %v", err, sentinelErr)
70 }
71 },
72 }
73
74 receiveErr := stream.RecvMsg(sentinelMessage)
75 if !called {
76 t.Error("postMessageReceive not called")
77 }
78
79 if !errors.Is(receiveErr, sentinelErr) {
80 t.Errorf("got %v, want %v", receiveErr, sentinelErr)
81 }
82 })
83}
84
85func TestRequestSavingClientStream_InitialRequest(t *testing.T) {
86 // Setup: create a mock ClientStream that returns a sentinel error on SendMsg
87 sentinelErr := errors.New("send error")
88 mockClientStream := &mockClientStream{
89 sendErr: sentinelErr,
90 }
91
92 // Setup: create a requestSavingClientStream with the mock ClientStream
93 stream := &requestSavingClientStream{
94 ClientStream: mockClientStream,
95 }
96
97 // Setup: create a sample proto.Message for the request
98 request := &newsv1.BinaryAttachment{
99 Name: "sample_request",
100 Data: []byte("sample data"),
101 }
102
103 // Test: call SendMsg with the request
104 err := stream.SendMsg(request)
105
106 // Check: assert SendMsg propagates the error
107 if !errors.Is(err, sentinelErr) {
108 t.Errorf("got %v, want %v", err, sentinelErr)
109 }
110
111 // Check: assert InitialRequest returns the request
112 if diff := cmp.Diff(request, *stream.InitialRequest(), cmpopts.IgnoreUnexported(newsv1.BinaryAttachment{})); diff != "" {
113 t.Fatalf("InitialRequest() (-want +got):\n%s", diff)
114 }
115}
116
117// mockClientStream is a grpc.ClientStream that returns a given error on SendMsg and RecvMsg.
118type mockClientStream struct {
119 grpc.ClientStream
120 sendErr error
121 recvErr error
122}
123
124func (s *mockClientStream) SendMsg(any) error {
125 return s.sendErr
126}
127
128func (s *mockClientStream) RecvMsg(any) error {
129 return s.recvErr
130}
131
132func TestCallBackServerStream(t *testing.T) {
133 t.Run("SendMsg calls postMessageSend with message and error", func(t *testing.T) {
134 sentinelMessage := struct{}{}
135 sentinelErr := errors.New("send error")
136
137 var called bool
138 stream := callBackServerStream{
139 ServerStream: &mockServerStream{
140 sendErr: sentinelErr,
141 },
142 postMessageSend: func(message any, err error) {
143 called = true
144
145 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
146 t.Errorf("postMessageSend called with unexpected message (-want +got):\n%s", diff)
147 }
148 if !errors.Is(err, sentinelErr) {
149 t.Errorf("got %v, want %v", err, sentinelErr)
150 }
151 },
152 }
153
154 sendErr := stream.SendMsg(sentinelMessage)
155 if !called {
156 t.Error("postMessageSend not called")
157 }
158
159 if !errors.Is(sendErr, sentinelErr) {
160 t.Errorf("got %v, want %v", sendErr, sentinelErr)
161 }
162 })
163
164 t.Run("RecvMsg calls postMessageReceive with message and error", func(t *testing.T) {
165 sentinelMessage := struct{}{}
166 sentinelErr := errors.New("receive error")
167
168 var called bool
169 stream := callBackServerStream{
170 ServerStream: &mockServerStream{
171 recvErr: sentinelErr,
172 },
173 postMessageReceive: func(message any, err error) {
174 called = true
175
176 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
177 t.Errorf("postMessageReceive called with unexpected message (-want +got):\n%s", diff)
178 }
179 if !errors.Is(err, sentinelErr) {
180 t.Errorf("got %v, want %v", err, sentinelErr)
181 }
182 },
183 }
184
185 receiveErr := stream.RecvMsg(sentinelMessage)
186 if !called {
187 t.Error("postMessageReceive not called")
188 }
189
190 if !errors.Is(receiveErr, sentinelErr) {
191 t.Errorf("got %v, want %v", receiveErr, sentinelErr)
192 }
193 })
194}
195
196func TestRequestSavingServerStream_InitialRequest(t *testing.T) {
197 // Setup: create a mock ServerStream that returns a sentinel error on SendMsg
198 sentinelErr := errors.New("receive error")
199 mockServerStream := &mockServerStream{
200 recvErr: sentinelErr,
201 }
202
203 // Setup: create a requestSavingServerStream with the mock ServerStream
204 stream := &requestSavingServerStream{
205 ServerStream: mockServerStream,
206 }
207
208 // Setup: create a sample proto.Message for the request
209 request := &newsv1.BinaryAttachment{
210 Name: "sample_request",
211 Data: []byte("sample data"),
212 }
213
214 // Test: call RecvMsg with the request
215 err := stream.RecvMsg(request)
216
217 // Check: assert RecvMsg propagates the error
218 if !errors.Is(err, sentinelErr) {
219 t.Errorf("got %v, want %v", err, sentinelErr)
220 }
221
222 // Check: assert InitialRequest returns the request
223 if diff := cmp.Diff(request, *stream.InitialRequest(), cmpopts.IgnoreUnexported(newsv1.BinaryAttachment{})); diff != "" {
224 t.Fatalf("InitialRequest() (-want +got):\n%s", diff)
225 }
226}
227
228// mockServerStream is a grpc.ServerStream that returns a given error on SendMsg and RecvMsg.
229type mockServerStream struct {
230 grpc.ServerStream
231 sendErr error
232 recvErr error
233}
234
235func (s *mockServerStream) SendMsg(any) error {
236 return s.sendErr
237}
238
239func (s *mockServerStream) RecvMsg(any) error {
240 return s.recvErr
241}
242
243func TestProbablyInternalGRPCError(t *testing.T) {
244 checker := func(s *status.Status) bool {
245 return strings.HasPrefix(s.Message(), "custom error")
246 }
247
248 testCases := []struct {
249 status *status.Status
250 checkers []internalGRPCErrorChecker
251 wantResult bool
252 }{
253 {
254 status: status.New(codes.OK, ""),
255 checkers: []internalGRPCErrorChecker{func(*status.Status) bool { return true }},
256 wantResult: false,
257 },
258 {
259 status: status.New(codes.Internal, "custom error message"),
260 checkers: []internalGRPCErrorChecker{checker},
261 wantResult: true,
262 },
263 {
264 status: status.New(codes.Internal, "some other error"),
265 checkers: []internalGRPCErrorChecker{checker},
266 wantResult: false,
267 },
268 }
269
270 for _, tc := range testCases {
271 gotResult := probablyInternalGRPCError(tc.status, tc.checkers)
272 if gotResult != tc.wantResult {
273 t.Errorf("probablyInternalGRPCError(%v, %v) = %v, want %v", tc.status, tc.checkers, gotResult, tc.wantResult)
274 }
275 }
276}
277
278func TestGRPCResourceExhaustedChecker(t *testing.T) {
279 testCases := []struct {
280 status *status.Status
281 expectPass bool
282 }{
283 {
284 status: status.New(codes.ResourceExhausted, "trying to send message larger than max (1024 vs 2)"),
285 expectPass: true,
286 },
287 {
288 status: status.New(codes.ResourceExhausted, "some other error"),
289 expectPass: false,
290 },
291 {
292 status: status.New(codes.OK, "trying to send message larger than max (1024 vs 5)"),
293 expectPass: false,
294 },
295 }
296
297 for _, tc := range testCases {
298 actual := gRPCResourceExhaustedChecker(tc.status)
299 if actual != tc.expectPass {
300 t.Errorf("gRPCResourceExhaustedChecker(%v) got %t, want %t", tc.status, actual, tc.expectPass)
301 }
302 }
303}
304
305func TestGRPCPrefixChecker(t *testing.T) {
306 tests := []struct {
307 status *status.Status
308 want bool
309 }{
310 {
311 status: status.New(codes.OK, "not a grpc error"),
312 want: false,
313 },
314 {
315 status: status.New(codes.Internal, "grpc: internal server error"),
316 want: true,
317 },
318 {
319 status: status.New(codes.Unavailable, "some other error"),
320 want: false,
321 },
322 }
323 for _, test := range tests {
324 got := gRPCPrefixChecker(test.status)
325 if got != test.want {
326 t.Errorf("gRPCPrefixChecker(%v) = %v, want %v", test.status, got, test.want)
327 }
328 }
329}
330
331func TestGRPCUnexpectedContentTypeChecker(t *testing.T) {
332 tests := []struct {
333 name string
334 status *status.Status
335 want bool
336 }{
337 {
338 name: "gRPC error with OK status",
339 status: status.New(codes.OK, "transport: received unexpected content-type"),
340 want: false,
341 },
342 {
343 name: "gRPC error without unexpected content-type message",
344 status: status.New(codes.Internal, "some random error"),
345 want: false,
346 },
347 {
348 name: "gRPC error with unexpected content-type message",
349 status: status.Newf(codes.Internal, "transport: received unexpected content-type %q", "application/octet-stream"),
350 want: true,
351 },
352 {
353 name: "gRPC error with unexpected content-type message as part of chain",
354 status: status.Newf(codes.Unknown, "transport: malformed grpc-status %q; transport: received unexpected content-type %q", "random-status", "application/octet-stream"),
355 want: true,
356 },
357 }
358
359 for _, tt := range tests {
360 t.Run(tt.name, func(t *testing.T) {
361 if got := gRPCUnexpectedContentTypeChecker(tt.status); got != tt.want {
362 t.Errorf("gRPCUnexpectedContentTypeChecker() = %v, want %v", got, tt.want)
363 }
364 })
365 }
366}
367
368func TestFindNonUTF8StringFields(t *testing.T) {
369 // Create instances of the BinaryAttachment and KeyValueAttachment messages
370 invalidBinaryAttachment := &newsv1.BinaryAttachment{
371 Name: "inval\x80id_binary",
372 Data: []byte("sample data"),
373 }
374
375 invalidKeyValueAttachment := &newsv1.KeyValueAttachment{
376 Name: "inval\x80id_key_value",
377 Data: map[string]string{
378 "key1": "value1",
379 "key2": "inval\x80id_value",
380 },
381 }
382
383 // Create a sample Article message with invalid UTF-8 strings
384 article := &newsv1.Article{
385 Author: "inval\x80id_author",
386 Date: ×tamppb.Timestamp{Seconds: 1234567890},
387 Title: "valid_title",
388 Content: "valid_content",
389 Status: newsv1.Article_STATUS_PUBLISHED,
390 Attachments: []*newsv1.Attachment{
391 {Contents: &newsv1.Attachment_BinaryAttachment{BinaryAttachment: invalidBinaryAttachment}},
392 {Contents: &newsv1.Attachment_KeyValueAttachment{KeyValueAttachment: invalidKeyValueAttachment}},
393 },
394 }
395
396 tests := []struct {
397 name string
398 message proto.Message
399 expectedPaths []string
400 }{
401 {
402 name: "Article with invalid UTF-8 strings",
403 message: article,
404 expectedPaths: []string{
405 "author",
406 "attachments[0].binary_attachment.name",
407 "attachments[1].key_value_attachment.name",
408 `attachments[1].key_value_attachment.data["key2"]`,
409 },
410 },
411 {
412 name: "nil message",
413 message: nil,
414 expectedPaths: []string{},
415 },
416 }
417
418 for _, tt := range tests {
419 t.Run(tt.name, func(t *testing.T) {
420 invalidFields, err := findNonUTF8StringFields(tt.message)
421 if err != nil {
422 t.Fatalf("unexpected error: %v", err)
423 }
424
425 sort.Strings(invalidFields)
426 sort.Strings(tt.expectedPaths)
427
428 if diff := cmp.Diff(tt.expectedPaths, invalidFields, cmpopts.EquateEmpty()); diff != "" {
429 t.Fatalf("unexpected invalid fields (-want +got):\n%s", diff)
430 }
431 })
432 }
433}
434
435func TestMassageIntoStatusErr(t *testing.T) {
436 testCases := []struct {
437 description string
438 input error
439 expected *status.Status
440 expectedOk bool
441 }{
442 {
443 description: "nil error",
444 input: nil,
445 expected: nil,
446 expectedOk: false,
447 },
448 {
449 description: "status error",
450 input: status.Errorf(codes.InvalidArgument, "invalid argument"),
451 expected: status.New(codes.InvalidArgument, "invalid argument"),
452 expectedOk: true,
453 },
454 {
455 description: "context.Canceled error",
456 input: context.Canceled,
457 expected: status.New(codes.Canceled, "context canceled"),
458 expectedOk: true,
459 },
460 {
461 description: "context.DeadlineExceeded error",
462 input: context.DeadlineExceeded,
463 expected: status.New(codes.DeadlineExceeded, "context deadline exceeded"),
464 expectedOk: true,
465 },
466 {
467 description: "non-status error",
468 input: errors.New("non-status error"),
469 expected: nil,
470 expectedOk: false,
471 },
472 }
473
474 for _, tc := range testCases {
475 t.Run(tc.description, func(t *testing.T) {
476 result, ok := massageIntoStatusErr(tc.input)
477 if ok != tc.expectedOk {
478 t.Errorf("Expected ok to be %v, but got %v", tc.expectedOk, ok)
479 }
480
481 expectedStatusString := fmt.Sprintf("%s", tc.expected)
482 actualStatusString := fmt.Sprintf("%s", result)
483
484 if diff := cmp.Diff(expectedStatusString, actualStatusString); diff != "" {
485 t.Fatalf("Unexpected status string (-want +got):\n%s", diff)
486 }
487 })
488 }
489}