// This file is a fork of https://github.com/sourcegraph/zoekt
1package internalerrs
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "sort"
8 "strings"
9 "testing"
10
11 "github.com/google/go-cmp/cmp"
12 "github.com/google/go-cmp/cmp/cmpopts"
13 "google.golang.org/grpc"
14 "google.golang.org/grpc/codes"
15 "google.golang.org/grpc/status"
16 "google.golang.org/protobuf/proto"
17 "google.golang.org/protobuf/types/known/timestamppb"
18
19 newsv1 "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1"
20)
21
22func TestCallBackClientStream(t *testing.T) {
23 t.Run("SendMsg calls postMessageSend with message and error", func(t *testing.T) {
24 sentinelMessage := struct{}{}
25 sentinelErr := errors.New("send error")
26
27 var called bool
28 stream := callBackClientStream{
29 ClientStream: &mockClientStream{
30 sendErr: sentinelErr,
31 },
32 postMessageSend: func(message any, err error) {
33 called = true
34
35 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
36 t.Errorf("postMessageSend called with unexpected message (-want +got):\n%s", diff)
37 }
38 if !errors.Is(err, sentinelErr) {
39 t.Errorf("got %v, want %v", err, sentinelErr)
40 }
41 },
42 }
43
44 sendErr := stream.SendMsg(sentinelMessage)
45 if !called {
46 t.Error("postMessageSend not called")
47 }
48
49 if !errors.Is(sendErr, sentinelErr) {
50 t.Errorf("got %v, want %v", sendErr, sentinelErr)
51 }
52 })
53
54 t.Run("RecvMsg calls postMessageReceive with message and error", func(t *testing.T) {
55 sentinelMessage := struct{}{}
56 sentinelErr := errors.New("receive error")
57
58 var called bool
59 stream := callBackClientStream{
60 ClientStream: &mockClientStream{
61 recvErr: sentinelErr,
62 },
63 postMessageReceive: func(message any, err error) {
64 called = true
65
66 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
67 t.Errorf("postMessageReceive called with unexpected message (-want +got):\n%s", diff)
68 }
69 if !errors.Is(err, sentinelErr) {
70 t.Errorf("got %v, want %v", err, sentinelErr)
71 }
72 },
73 }
74
75 receiveErr := stream.RecvMsg(sentinelMessage)
76 if !called {
77 t.Error("postMessageReceive not called")
78 }
79
80 if !errors.Is(receiveErr, sentinelErr) {
81 t.Errorf("got %v, want %v", receiveErr, sentinelErr)
82 }
83 })
84}
85
86func TestRequestSavingClientStream_InitialRequest(t *testing.T) {
87 // Setup: create a mock ClientStream that returns a sentinel error on SendMsg
88 sentinelErr := errors.New("send error")
89 mockClientStream := &mockClientStream{
90 sendErr: sentinelErr,
91 }
92
93 // Setup: create a requestSavingClientStream with the mock ClientStream
94 stream := &requestSavingClientStream{
95 ClientStream: mockClientStream,
96 }
97
98 // Setup: create a sample proto.Message for the request
99 request := &newsv1.BinaryAttachment{
100 Name: "sample_request",
101 Data: []byte("sample data"),
102 }
103
104 // Test: call SendMsg with the request
105 err := stream.SendMsg(request)
106
107 // Check: assert SendMsg propagates the error
108 if !errors.Is(err, sentinelErr) {
109 t.Errorf("got %v, want %v", err, sentinelErr)
110 }
111
112 // Check: assert InitialRequest returns the request
113 if diff := cmp.Diff(request, *stream.InitialRequest(), cmpopts.IgnoreUnexported(newsv1.BinaryAttachment{})); diff != "" {
114 t.Fatalf("InitialRequest() (-want +got):\n%s", diff)
115 }
116}
117
118// mockClientStream is a grpc.ClientStream that returns a given error on SendMsg and RecvMsg.
119type mockClientStream struct {
120 grpc.ClientStream
121 sendErr error
122 recvErr error
123}
124
125func (s *mockClientStream) SendMsg(any) error {
126 return s.sendErr
127}
128
129func (s *mockClientStream) RecvMsg(any) error {
130 return s.recvErr
131}
132
133func TestCallBackServerStream(t *testing.T) {
134 t.Run("SendMsg calls postMessageSend with message and error", func(t *testing.T) {
135 sentinelMessage := struct{}{}
136 sentinelErr := errors.New("send error")
137
138 var called bool
139 stream := callBackServerStream{
140 ServerStream: &mockServerStream{
141 sendErr: sentinelErr,
142 },
143 postMessageSend: func(message any, err error) {
144 called = true
145
146 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
147 t.Errorf("postMessageSend called with unexpected message (-want +got):\n%s", diff)
148 }
149 if !errors.Is(err, sentinelErr) {
150 t.Errorf("got %v, want %v", err, sentinelErr)
151 }
152 },
153 }
154
155 sendErr := stream.SendMsg(sentinelMessage)
156 if !called {
157 t.Error("postMessageSend not called")
158 }
159
160 if !errors.Is(sendErr, sentinelErr) {
161 t.Errorf("got %v, want %v", sendErr, sentinelErr)
162 }
163 })
164
165 t.Run("RecvMsg calls postMessageReceive with message and error", func(t *testing.T) {
166 sentinelMessage := struct{}{}
167 sentinelErr := errors.New("receive error")
168
169 var called bool
170 stream := callBackServerStream{
171 ServerStream: &mockServerStream{
172 recvErr: sentinelErr,
173 },
174 postMessageReceive: func(message any, err error) {
175 called = true
176
177 if diff := cmp.Diff(message, sentinelMessage); diff != "" {
178 t.Errorf("postMessageReceive called with unexpected message (-want +got):\n%s", diff)
179 }
180 if !errors.Is(err, sentinelErr) {
181 t.Errorf("got %v, want %v", err, sentinelErr)
182 }
183 },
184 }
185
186 receiveErr := stream.RecvMsg(sentinelMessage)
187 if !called {
188 t.Error("postMessageReceive not called")
189 }
190
191 if !errors.Is(receiveErr, sentinelErr) {
192 t.Errorf("got %v, want %v", receiveErr, sentinelErr)
193 }
194 })
195}
196
197func TestRequestSavingServerStream_InitialRequest(t *testing.T) {
198 // Setup: create a mock ServerStream that returns a sentinel error on SendMsg
199 sentinelErr := errors.New("receive error")
200 mockServerStream := &mockServerStream{
201 recvErr: sentinelErr,
202 }
203
204 // Setup: create a requestSavingServerStream with the mock ServerStream
205 stream := &requestSavingServerStream{
206 ServerStream: mockServerStream,
207 }
208
209 // Setup: create a sample proto.Message for the request
210 request := &newsv1.BinaryAttachment{
211 Name: "sample_request",
212 Data: []byte("sample data"),
213 }
214
215 // Test: call RecvMsg with the request
216 err := stream.RecvMsg(request)
217
218 // Check: assert RecvMsg propagates the error
219 if !errors.Is(err, sentinelErr) {
220 t.Errorf("got %v, want %v", err, sentinelErr)
221 }
222
223 // Check: assert InitialRequest returns the request
224 if diff := cmp.Diff(request, *stream.InitialRequest(), cmpopts.IgnoreUnexported(newsv1.BinaryAttachment{})); diff != "" {
225 t.Fatalf("InitialRequest() (-want +got):\n%s", diff)
226 }
227}
228
229// mockServerStream is a grpc.ServerStream that returns a given error on SendMsg and RecvMsg.
230type mockServerStream struct {
231 grpc.ServerStream
232 sendErr error
233 recvErr error
234}
235
236func (s *mockServerStream) SendMsg(any) error {
237 return s.sendErr
238}
239
240func (s *mockServerStream) RecvMsg(any) error {
241 return s.recvErr
242}
243
244func TestProbablyInternalGRPCError(t *testing.T) {
245 checker := func(s *status.Status) bool {
246 return strings.HasPrefix(s.Message(), "custom error")
247 }
248
249 testCases := []struct {
250 status *status.Status
251 checkers []internalGRPCErrorChecker
252 wantResult bool
253 }{
254 {
255 status: status.New(codes.OK, ""),
256 checkers: []internalGRPCErrorChecker{func(*status.Status) bool { return true }},
257 wantResult: false,
258 },
259 {
260 status: status.New(codes.Internal, "custom error message"),
261 checkers: []internalGRPCErrorChecker{checker},
262 wantResult: true,
263 },
264 {
265 status: status.New(codes.Internal, "some other error"),
266 checkers: []internalGRPCErrorChecker{checker},
267 wantResult: false,
268 },
269 }
270
271 for _, tc := range testCases {
272 gotResult := probablyInternalGRPCError(tc.status, tc.checkers)
273 if gotResult != tc.wantResult {
274 t.Errorf("probablyInternalGRPCError(%v, %v) = %v, want %v", tc.status, tc.checkers, gotResult, tc.wantResult)
275 }
276 }
277}
278
279func TestGRPCResourceExhaustedChecker(t *testing.T) {
280 testCases := []struct {
281 status *status.Status
282 expectPass bool
283 }{
284 {
285 status: status.New(codes.ResourceExhausted, "trying to send message larger than max (1024 vs 2)"),
286 expectPass: true,
287 },
288 {
289 status: status.New(codes.ResourceExhausted, "some other error"),
290 expectPass: false,
291 },
292 {
293 status: status.New(codes.OK, "trying to send message larger than max (1024 vs 5)"),
294 expectPass: false,
295 },
296 }
297
298 for _, tc := range testCases {
299 actual := gRPCResourceExhaustedChecker(tc.status)
300 if actual != tc.expectPass {
301 t.Errorf("gRPCResourceExhaustedChecker(%v) got %t, want %t", tc.status, actual, tc.expectPass)
302 }
303 }
304}
305
306func TestGRPCPrefixChecker(t *testing.T) {
307 tests := []struct {
308 status *status.Status
309 want bool
310 }{
311 {
312 status: status.New(codes.OK, "not a grpc error"),
313 want: false,
314 },
315 {
316 status: status.New(codes.Internal, "grpc: internal server error"),
317 want: true,
318 },
319 {
320 status: status.New(codes.Unavailable, "some other error"),
321 want: false,
322 },
323 }
324 for _, test := range tests {
325 got := gRPCPrefixChecker(test.status)
326 if got != test.want {
327 t.Errorf("gRPCPrefixChecker(%v) = %v, want %v", test.status, got, test.want)
328 }
329 }
330}
331
332func TestGRPCUnexpectedContentTypeChecker(t *testing.T) {
333 tests := []struct {
334 name string
335 status *status.Status
336 want bool
337 }{
338 {
339 name: "gRPC error with OK status",
340 status: status.New(codes.OK, "transport: received unexpected content-type"),
341 want: false,
342 },
343 {
344 name: "gRPC error without unexpected content-type message",
345 status: status.New(codes.Internal, "some random error"),
346 want: false,
347 },
348 {
349 name: "gRPC error with unexpected content-type message",
350 status: status.Newf(codes.Internal, "transport: received unexpected content-type %q", "application/octet-stream"),
351 want: true,
352 },
353 {
354 name: "gRPC error with unexpected content-type message as part of chain",
355 status: status.Newf(codes.Unknown, "transport: malformed grpc-status %q; transport: received unexpected content-type %q", "random-status", "application/octet-stream"),
356 want: true,
357 },
358 }
359
360 for _, tt := range tests {
361 t.Run(tt.name, func(t *testing.T) {
362 if got := gRPCUnexpectedContentTypeChecker(tt.status); got != tt.want {
363 t.Errorf("gRPCUnexpectedContentTypeChecker() = %v, want %v", got, tt.want)
364 }
365 })
366 }
367}
368
369func TestFindNonUTF8StringFields(t *testing.T) {
370 // Create instances of the BinaryAttachment and KeyValueAttachment messages
371 invalidBinaryAttachment := &newsv1.BinaryAttachment{
372 Name: "inval\x80id_binary",
373 Data: []byte("sample data"),
374 }
375
376 invalidKeyValueAttachment := &newsv1.KeyValueAttachment{
377 Name: "inval\x80id_key_value",
378 Data: map[string]string{
379 "key1": "value1",
380 "key2": "inval\x80id_value",
381 },
382 }
383
384 // Create a sample Article message with invalid UTF-8 strings
385 article := &newsv1.Article{
386 Author: "inval\x80id_author",
387 Date: ×tamppb.Timestamp{Seconds: 1234567890},
388 Title: "valid_title",
389 Content: "valid_content",
390 Status: newsv1.Article_STATUS_PUBLISHED,
391 Attachments: []*newsv1.Attachment{
392 {Contents: &newsv1.Attachment_BinaryAttachment{BinaryAttachment: invalidBinaryAttachment}},
393 {Contents: &newsv1.Attachment_KeyValueAttachment{KeyValueAttachment: invalidKeyValueAttachment}},
394 },
395 }
396
397 tests := []struct {
398 name string
399 message proto.Message
400 expectedPaths []string
401 }{
402 {
403 name: "Article with invalid UTF-8 strings",
404 message: article,
405 expectedPaths: []string{
406 "author",
407 "attachments[0].binary_attachment.name",
408 "attachments[1].key_value_attachment.name",
409 `attachments[1].key_value_attachment.data["key2"]`,
410 },
411 },
412 {
413 name: "nil message",
414 message: nil,
415 expectedPaths: []string{},
416 },
417 }
418
419 for _, tt := range tests {
420 t.Run(tt.name, func(t *testing.T) {
421 invalidFields, err := findNonUTF8StringFields(tt.message)
422 if err != nil {
423 t.Fatalf("unexpected error: %v", err)
424 }
425
426 sort.Strings(invalidFields)
427 sort.Strings(tt.expectedPaths)
428
429 if diff := cmp.Diff(tt.expectedPaths, invalidFields, cmpopts.EquateEmpty()); diff != "" {
430 t.Fatalf("unexpected invalid fields (-want +got):\n%s", diff)
431 }
432 })
433 }
434}
435
436func TestMassageIntoStatusErr(t *testing.T) {
437 testCases := []struct {
438 description string
439 input error
440 expected *status.Status
441 expectedOk bool
442 }{
443 {
444 description: "nil error",
445 input: nil,
446 expected: nil,
447 expectedOk: false,
448 },
449 {
450 description: "status error",
451 input: status.Errorf(codes.InvalidArgument, "invalid argument"),
452 expected: status.New(codes.InvalidArgument, "invalid argument"),
453 expectedOk: true,
454 },
455 {
456 description: "context.Canceled error",
457 input: context.Canceled,
458 expected: status.New(codes.Canceled, "context canceled"),
459 expectedOk: true,
460 },
461 {
462 description: "context.DeadlineExceeded error",
463 input: context.DeadlineExceeded,
464 expected: status.New(codes.DeadlineExceeded, "context deadline exceeded"),
465 expectedOk: true,
466 },
467 {
468 description: "non-status error",
469 input: errors.New("non-status error"),
470 expected: nil,
471 expectedOk: false,
472 },
473 }
474
475 for _, tc := range testCases {
476 t.Run(tc.description, func(t *testing.T) {
477 result, ok := massageIntoStatusErr(tc.input)
478 if ok != tc.expectedOk {
479 t.Errorf("Expected ok to be %v, but got %v", tc.expectedOk, ok)
480 }
481
482 expectedStatusString := fmt.Sprintf("%s", tc.expected)
483 actualStatusString := fmt.Sprintf("%s", result)
484
485 if diff := cmp.Diff(expectedStatusString, actualStatusString); diff != "" {
486 t.Fatalf("Unexpected status string (-want +got):\n%s", diff)
487 }
488 })
489 }
490}