fork of https://github.com/sourcegraph/zoekt
1// Package chunk provides a utility for sending sets of protobuf messages in
2// groups of smaller chunks. This is useful for gRPC, which has limitations around the maximum
3// size of a message that you can send.
4//
5// This code is adapted from the gitaly project, which is licensed
6// under the MIT license. A copy of that license text can be found at
7// https://mit-license.org/.
8//
// The code this file is based on can be found here: https://gitlab.com/gitlab-org/gitaly/-/blob/v16.2.0/internal/helper/chunk/chunker_test.go
10package chunk
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "net"
19 "strconv"
20 "testing"
21 "testing/quick"
22
23 "github.com/dustin/go-humanize"
24 "github.com/google/go-cmp/cmp"
25 "github.com/stretchr/testify/require"
26 "google.golang.org/grpc"
27 "google.golang.org/grpc/credentials/insecure"
28 "google.golang.org/grpc/interop/grpc_testing"
29 "google.golang.org/protobuf/proto"
30)
31
32func TestChunker_DeliverAllMessages(t *testing.T) {
33 runTest := func(inputPayloads [][]byte) error {
34 expectedPayloadSizeBytes := 0
35 for _, payload := range inputPayloads {
36 expectedPayloadSizeBytes += len(payload)
37 }
38
39 var receivedPayloads []*grpc_testing.Payload
40
41 // Tell the chunker to just gather all the payloads for later inspection.
42 sendFunc := func(payloads []*grpc_testing.Payload) error {
43 receivedPayloads = append(receivedPayloads, payloads...)
44 return nil
45 }
46
47 c := New(sendFunc)
48
49 // send all the payloads
50 for _, payload := range inputPayloads {
51 if err := c.Send(&grpc_testing.Payload{Body: payload}); err != nil {
52 return fmt.Errorf("error sending payload: %s", err)
53 }
54 }
55
56 if err := c.Flush(); err != nil {
57 return fmt.Errorf("error flushing chunker: %s", err)
58 }
59
60 // Confirm that we received the same number of payloads as we sent.
61 if diff := cmp.Diff(len(inputPayloads), len(receivedPayloads)); diff != "" {
62 return fmt.Errorf("unexpected number of payloads (-want +got):\n%s", diff)
63 }
64
65 // Confirm that each received payload is the same as the original.
66 for i, payload := range receivedPayloads {
67 expectedPayload := inputPayloads[i]
68 if diff := cmp.Diff(expectedPayload, payload.GetBody()); diff != "" {
69 return fmt.Errorf("for payload #%d (-want +got):\n%s", i, diff)
70 }
71 }
72
73 receivedPayloadSizeBytes := 0
74 for _, payload := range receivedPayloads {
75 receivedPayloadSizeBytes += len(payload.GetBody())
76 }
77
78 // Confirm that the total size of the payloads we received is the same as the total size of the payloads we sent.
79 if diff := cmp.Diff(expectedPayloadSizeBytes, receivedPayloadSizeBytes); diff != "" {
80 return fmt.Errorf("unexpected payload size (-want +got):\n%s", diff)
81 }
82
83 return nil
84 }
85
86 t.Run("normal", func(t *testing.T) {
87 t.Parallel()
88
89 inputPayloads := [][]byte{
90 {1, 2, 3},
91 bytes.Repeat([]byte("a"), int(3.5*maxMessageSize)),
92 {4, 5, 6},
93 }
94
95 if err := runTest(inputPayloads); err != nil {
96 t.Fatal(err)
97 }
98 })
99
100 t.Run("some empty", func(t *testing.T) {
101 t.Parallel()
102
103 inputPayloads := [][]byte{
104 {},
105 []byte("foo, bar, baz"),
106 bytes.Repeat([]byte("a"), int(3.5*maxMessageSize)),
107 {},
108 }
109
110 if err := runTest(inputPayloads); err != nil {
111 t.Fatal(err)
112 }
113 })
114
115 t.Run("fuzz", func(t *testing.T) {
116 t.Parallel()
117
118 var lastErr error
119
120 if err := quick.Check(func(payloads [][]byte) bool {
121 lastErr = runTest(payloads)
122 if lastErr != nil {
123 return false
124 }
125
126 return true
127 }, nil); err != nil {
128 t.Fatal(lastErr)
129 }
130 })
131}
132
133func TestChunkerE2E(t *testing.T) {
134 for _, test := range []struct {
135 name string
136
137 inputSizeBytes int
138 expectedMessageCount int
139 }{
140 {
141 name: "normal",
142
143 inputSizeBytes: int(3.5 * maxMessageSize),
144 expectedMessageCount: 4,
145 },
146 {
147 name: "empty payload",
148 inputSizeBytes: 0,
149 expectedMessageCount: 1,
150 },
151 } {
152 t.Run(test.name, func(t *testing.T) {
153 s := &server{}
154 srv, serverSocketPath := runServer(t, s)
155 t.Cleanup(func() {
156 srv.Stop()
157 })
158
159 client, conn := newClient(t, serverSocketPath)
160 t.Cleanup(func() {
161 _ = conn.Close()
162 })
163
164 ctx := context.Background()
165
166 stream, err := client.StreamingOutputCall(ctx, &grpc_testing.StreamingOutputCallRequest{
167 Payload: &grpc_testing.Payload{
168 Body: []byte(strconv.FormatInt(int64(test.inputSizeBytes), 10)),
169 },
170 })
171
172 require.NoError(t, err)
173
174 messageCount := 0
175 var receivedPayload []byte
176 for {
177 resp, err := stream.Recv()
178 if errors.Is(err, io.EOF) {
179 break
180 }
181
182 if err != nil {
183 t.Fatal(err)
184 }
185
186 messageCount++
187 receivedPayload = append(receivedPayload, resp.GetPayload().GetBody()...)
188
189 require.Less(t, proto.Size(resp), maxMessageSize)
190 }
191
192 require.Equal(t, test.expectedMessageCount, messageCount)
193
194 receivedPayloadSizeBytes := len(receivedPayload)
195
196 expectedSizeBytes := test.inputSizeBytes
197
198 if receivedPayloadSizeBytes != expectedSizeBytes {
199 t.Fatalf("input payload size is not %d bytes (~ %q), got size: %d (~ %q)",
200 expectedSizeBytes, humanize.Bytes(uint64(expectedSizeBytes)),
201 receivedPayloadSizeBytes, humanize.Bytes(uint64(receivedPayloadSizeBytes)),
202 )
203 }
204
205 })
206 }
207}
208
209type server struct {
210 grpc_testing.UnimplementedTestServiceServer
211}
212
213func (s *server) StreamingOutputCall(req *grpc_testing.StreamingOutputCallRequest, stream grpc_testing.TestService_StreamingOutputCallServer) error {
214 const kilobyte = 1024
215
216 c := New[*grpc_testing.Payload](func(payloads []*grpc_testing.Payload) error {
217 var body []byte
218 for _, p := range payloads {
219 body = append(body, p.GetBody()...)
220 }
221
222 return stream.Send(&grpc_testing.StreamingOutputCallResponse{Payload: &grpc_testing.Payload{Body: body}})
223 })
224
225 bytesToSend, err := strconv.ParseInt(string(req.GetPayload().GetBody()), 10, 64)
226 if err != nil {
227 return err
228 }
229
230 if bytesToSend == 0 {
231 if err := c.Send(&grpc_testing.Payload{}); err != nil {
232 return err
233 }
234
235 return c.Flush()
236 }
237
238 for numBytes := int64(0); numBytes < bytesToSend; numBytes += kilobyte {
239 if err := c.Send(&grpc_testing.Payload{Body: make([]byte, kilobyte)}); err != nil {
240 return err
241 }
242 }
243
244 return c.Flush()
245}
246
247func runServer(t *testing.T, s *server, opt ...grpc.ServerOption) (*grpc.Server, string) {
248 grpcServer := grpc.NewServer(opt...)
249 grpc_testing.RegisterTestServiceServer(grpcServer, s)
250
251 lis, err := net.Listen("tcp", ":0")
252 require.NoError(t, err)
253
254 go func() {
255 err := grpcServer.Serve(lis)
256 require.NoError(t, err)
257 }()
258
259 t.Cleanup(func() {
260 grpcServer.Stop()
261 lis.Close()
262 })
263
264 return grpcServer, lis.Addr().String()
265}
266
267func newClient(t *testing.T, serverSocketPath string) (grpc_testing.TestServiceClient, *grpc.ClientConn) {
268 connOpts := []grpc.DialOption{
269 grpc.WithTransportCredentials(insecure.NewCredentials()),
270 }
271
272 conn, err := grpc.Dial(serverSocketPath, connOpts...)
273 if err != nil {
274 t.Fatal(err)
275 }
276
277 return grpc_testing.NewTestServiceClient(conn), conn
278}