fork of https://github.com/sourcegraph/zoekt
1// Package chunk provides a utility for sending sets of protobuf messages in
2// groups of smaller chunks. This is useful for gRPC, which has limitations around the maximum
3// size of a message that you can send.
4//
5// This code is adapted from the gitaly project, which is licensed
6// under the MIT license. A copy of that license text can be found at
7// https://mit-license.org/.
8//
9// The code this file was based on can be found here: https://gitlab.com/gitlab-org/gitaly/-/blob/v16.2.0/internal/helper/chunk/chunker_test.go
10package chunk
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "net"
19 "strconv"
20 "testing"
21 "testing/quick"
22
23 "github.com/dustin/go-humanize"
24 "github.com/google/go-cmp/cmp"
25 "github.com/stretchr/testify/require"
26 "google.golang.org/grpc"
27 "google.golang.org/grpc/credentials/insecure"
28 "google.golang.org/grpc/interop/grpc_testing"
29 "google.golang.org/protobuf/proto"
30)
31
32func TestChunker_DeliverAllMessages(t *testing.T) {
33 runTest := func(inputPayloads [][]byte) error {
34 expectedPayloadSizeBytes := 0
35 for _, payload := range inputPayloads {
36 expectedPayloadSizeBytes += len(payload)
37 }
38
39 var receivedPayloads []*grpc_testing.Payload
40
41 // Tell the chunker to just gather all the payloads for later inspection.
42 sendFunc := func(payloads []*grpc_testing.Payload) error {
43 receivedPayloads = append(receivedPayloads, payloads...)
44 return nil
45 }
46
47 c := New(sendFunc)
48
49 // send all the payloads
50 for _, payload := range inputPayloads {
51 if err := c.Send(&grpc_testing.Payload{Body: payload}); err != nil {
52 return fmt.Errorf("error sending payload: %s", err)
53 }
54 }
55
56 if err := c.Flush(); err != nil {
57 return fmt.Errorf("error flushing chunker: %s", err)
58 }
59
60 // Confirm that we received the same number of payloads as we sent.
61 if diff := cmp.Diff(len(inputPayloads), len(receivedPayloads)); diff != "" {
62 return fmt.Errorf("unexpected number of payloads (-want +got):\n%s", diff)
63 }
64
65 // Confirm that each received payload is the same as the original.
66 for i, payload := range receivedPayloads {
67 expectedPayload := inputPayloads[i]
68 if diff := cmp.Diff(expectedPayload, payload.GetBody()); diff != "" {
69 return fmt.Errorf("for payload #%d (-want +got):\n%s", i, diff)
70 }
71 }
72
73 receivedPayloadSizeBytes := 0
74 for _, payload := range receivedPayloads {
75 receivedPayloadSizeBytes += len(payload.GetBody())
76 }
77
78 // Confirm that the total size of the payloads we received is the same as the total size of the payloads we sent.
79 if diff := cmp.Diff(expectedPayloadSizeBytes, receivedPayloadSizeBytes); diff != "" {
80 return fmt.Errorf("unexpected payload size (-want +got):\n%s", diff)
81 }
82
83 return nil
84 }
85
86 t.Run("normal", func(t *testing.T) {
87 t.Parallel()
88
89 inputPayloads := [][]byte{
90 {1, 2, 3},
91 bytes.Repeat([]byte("a"), int(3.5*maxMessageSize)),
92 {4, 5, 6},
93 }
94
95 if err := runTest(inputPayloads); err != nil {
96 t.Fatal(err)
97 }
98 })
99
100 t.Run("some empty", func(t *testing.T) {
101 t.Parallel()
102
103 inputPayloads := [][]byte{
104 {},
105 []byte("foo, bar, baz"),
106 bytes.Repeat([]byte("a"), int(3.5*maxMessageSize)),
107 {},
108 }
109
110 if err := runTest(inputPayloads); err != nil {
111 t.Fatal(err)
112 }
113 })
114
115 t.Run("fuzz", func(t *testing.T) {
116 t.Parallel()
117
118 var lastErr error
119
120 if err := quick.Check(func(payloads [][]byte) bool {
121 lastErr = runTest(payloads)
122 if lastErr != nil {
123 return false
124 }
125
126 return true
127 }, nil); err != nil {
128 t.Fatal(lastErr)
129 }
130 })
131}
132
133func TestChunkerE2E(t *testing.T) {
134 for _, test := range []struct {
135 name string
136
137 inputSizeBytes int
138 expectedMessageCount int
139 }{
140 {
141 name: "normal",
142
143 inputSizeBytes: int(3.5 * maxMessageSize),
144 expectedMessageCount: 4,
145 },
146 {
147 name: "empty payload",
148 inputSizeBytes: 0,
149 expectedMessageCount: 1,
150 },
151 } {
152 t.Run(test.name, func(t *testing.T) {
153 s := &server{}
154 srv, serverSocketPath := runServer(t, s)
155 t.Cleanup(func() {
156 srv.Stop()
157 })
158
159 client, conn := newClient(t, serverSocketPath)
160 t.Cleanup(func() {
161 _ = conn.Close()
162 })
163
164 ctx := context.Background()
165
166 stream, err := client.StreamingOutputCall(ctx, &grpc_testing.StreamingOutputCallRequest{
167 Payload: &grpc_testing.Payload{
168 Body: []byte(strconv.FormatInt(int64(test.inputSizeBytes), 10)),
169 },
170 })
171
172 require.NoError(t, err)
173
174 messageCount := 0
175 var receivedPayload []byte
176 for {
177 resp, err := stream.Recv()
178 if errors.Is(err, io.EOF) {
179 break
180 }
181
182 if err != nil {
183 t.Fatal(err)
184 }
185
186 messageCount++
187 receivedPayload = append(receivedPayload, resp.GetPayload().GetBody()...)
188
189 require.Less(t, proto.Size(resp), maxMessageSize)
190 }
191
192 require.Equal(t, test.expectedMessageCount, messageCount)
193
194 receivedPayloadSizeBytes := len(receivedPayload)
195
196 expectedSizeBytes := test.inputSizeBytes
197
198 if receivedPayloadSizeBytes != expectedSizeBytes {
199 t.Fatalf("input payload size is not %d bytes (~ %q), got size: %d (~ %q)",
200 expectedSizeBytes, humanize.Bytes(uint64(expectedSizeBytes)),
201 receivedPayloadSizeBytes, humanize.Bytes(uint64(receivedPayloadSizeBytes)),
202 )
203 }
204 })
205 }
206}
207
208type server struct {
209 grpc_testing.UnimplementedTestServiceServer
210}
211
212func (s *server) StreamingOutputCall(req *grpc_testing.StreamingOutputCallRequest, stream grpc_testing.TestService_StreamingOutputCallServer) error {
213 const kilobyte = 1024
214
215 c := New[*grpc_testing.Payload](func(payloads []*grpc_testing.Payload) error {
216 var body []byte
217 for _, p := range payloads {
218 body = append(body, p.GetBody()...)
219 }
220
221 return stream.Send(&grpc_testing.StreamingOutputCallResponse{Payload: &grpc_testing.Payload{Body: body}})
222 })
223
224 bytesToSend, err := strconv.ParseInt(string(req.GetPayload().GetBody()), 10, 64)
225 if err != nil {
226 return err
227 }
228
229 if bytesToSend == 0 {
230 if err := c.Send(&grpc_testing.Payload{}); err != nil {
231 return err
232 }
233
234 return c.Flush()
235 }
236
237 for numBytes := int64(0); numBytes < bytesToSend; numBytes += kilobyte {
238 if err := c.Send(&grpc_testing.Payload{Body: make([]byte, kilobyte)}); err != nil {
239 return err
240 }
241 }
242
243 return c.Flush()
244}
245
246func runServer(t *testing.T, s *server, opt ...grpc.ServerOption) (*grpc.Server, string) {
247 grpcServer := grpc.NewServer(opt...)
248 grpc_testing.RegisterTestServiceServer(grpcServer, s)
249
250 lis, err := net.Listen("tcp", ":0")
251 require.NoError(t, err)
252
253 go func() {
254 err := grpcServer.Serve(lis)
255 require.NoError(t, err)
256 }()
257
258 t.Cleanup(func() {
259 grpcServer.Stop()
260 lis.Close()
261 })
262
263 return grpcServer, lis.Addr().String()
264}
265
266func newClient(t *testing.T, serverSocketPath string) (grpc_testing.TestServiceClient, *grpc.ClientConn) {
267 connOpts := []grpc.DialOption{
268 grpc.WithTransportCredentials(insecure.NewCredentials()),
269 }
270
271 conn, err := grpc.Dial(serverSocketPath, connOpts...)
272 if err != nil {
273 t.Fatal(err)
274 }
275
276 return grpc_testing.NewTestServiceClient(conn), conn
277}