fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1// Package chunk provides a utility for sending sets of protobuf messages in 2// groups of smaller chunks. This is useful for gRPC, which has limitations around the maximum 3// size of a message that you can send. 4// 5// This code is adapted from the gitaly project, which is licensed 6// under the MIT license. A copy of that license text can be found at 7// https://mit-license.org/. 8// 9// The code this file was based off can be found here: https://gitlab.com/gitlab-org/gitaly/-/blob/v16.2.0/internal/helper/chunk/chunker_test.go 10package chunk 11 12import ( 13 "bytes" 14 "context" 15 "errors" 16 "fmt" 17 "io" 18 "net" 19 "strconv" 20 "testing" 21 "testing/quick" 22 23 "github.com/dustin/go-humanize" 24 "github.com/google/go-cmp/cmp" 25 "github.com/stretchr/testify/require" 26 "google.golang.org/grpc" 27 "google.golang.org/grpc/credentials/insecure" 28 "google.golang.org/grpc/interop/grpc_testing" 29 "google.golang.org/protobuf/proto" 30) 31 32func TestChunker_DeliverAllMessages(t *testing.T) { 33 runTest := func(inputPayloads [][]byte) error { 34 expectedPayloadSizeBytes := 0 35 for _, payload := range inputPayloads { 36 expectedPayloadSizeBytes += len(payload) 37 } 38 39 var receivedPayloads []*grpc_testing.Payload 40 41 // Tell the chunker to just gather all the payloads for later inspection. 42 sendFunc := func(payloads []*grpc_testing.Payload) error { 43 receivedPayloads = append(receivedPayloads, payloads...) 44 return nil 45 } 46 47 c := New(sendFunc) 48 49 // send all the payloads 50 for _, payload := range inputPayloads { 51 if err := c.Send(&grpc_testing.Payload{Body: payload}); err != nil { 52 return fmt.Errorf("error sending payload: %s", err) 53 } 54 } 55 56 if err := c.Flush(); err != nil { 57 return fmt.Errorf("error flushing chunker: %s", err) 58 } 59 60 // Confirm that we received the same number of payloads as we sent. 61 if diff := cmp.Diff(len(inputPayloads), len(receivedPayloads)); diff != "" { 62 return fmt.Errorf("unexpected number of payloads (-want +got):\n%s", diff) 63 } 64 65 // Confirm that each received payload is the same as the original. 66 for i, payload := range receivedPayloads { 67 expectedPayload := inputPayloads[i] 68 if diff := cmp.Diff(expectedPayload, payload.GetBody()); diff != "" { 69 return fmt.Errorf("for payload #%d (-want +got):\n%s", i, diff) 70 } 71 } 72 73 receivedPayloadSizeBytes := 0 74 for _, payload := range receivedPayloads { 75 receivedPayloadSizeBytes += len(payload.GetBody()) 76 } 77 78 // Confirm that the total size of the payloads we received is the same as the total size of the payloads we sent. 79 if diff := cmp.Diff(expectedPayloadSizeBytes, receivedPayloadSizeBytes); diff != "" { 80 return fmt.Errorf("unexpected payload size (-want +got):\n%s", diff) 81 } 82 83 return nil 84 } 85 86 t.Run("normal", func(t *testing.T) { 87 t.Parallel() 88 89 inputPayloads := [][]byte{ 90 {1, 2, 3}, 91 bytes.Repeat([]byte("a"), int(3.5*maxMessageSize)), 92 {4, 5, 6}, 93 } 94 95 if err := runTest(inputPayloads); err != nil { 96 t.Fatal(err) 97 } 98 }) 99 100 t.Run("some empty", func(t *testing.T) { 101 t.Parallel() 102 103 inputPayloads := [][]byte{ 104 {}, 105 []byte("foo, bar, baz"), 106 bytes.Repeat([]byte("a"), int(3.5*maxMessageSize)), 107 {}, 108 } 109 110 if err := runTest(inputPayloads); err != nil { 111 t.Fatal(err) 112 } 113 }) 114 115 t.Run("fuzz", func(t *testing.T) { 116 t.Parallel() 117 118 var lastErr error 119 120 if err := quick.Check(func(payloads [][]byte) bool { 121 lastErr = runTest(payloads) 122 if lastErr != nil { 123 return false 124 } 125 126 return true 127 }, nil); err != nil { 128 t.Fatal(lastErr) 129 } 130 }) 131} 132 133func TestChunkerE2E(t *testing.T) { 134 for _, test := range []struct { 135 name string 136 137 inputSizeBytes int 138 expectedMessageCount int 139 }{ 140 { 141 name: "normal", 142 143 inputSizeBytes: int(3.5 * maxMessageSize), 144 expectedMessageCount: 4, 145 }, 146 { 147 name: "empty payload", 148 inputSizeBytes: 0, 149 expectedMessageCount: 1, 150 }, 151 } { 152 t.Run(test.name, func(t *testing.T) { 153 s := &server{} 154 srv, serverSocketPath := runServer(t, s) 155 t.Cleanup(func() { 156 srv.Stop() 157 }) 158 159 client, conn := newClient(t, serverSocketPath) 160 t.Cleanup(func() { 161 _ = conn.Close() 162 }) 163 164 ctx := context.Background() 165 166 stream, err := client.StreamingOutputCall(ctx, &grpc_testing.StreamingOutputCallRequest{ 167 Payload: &grpc_testing.Payload{ 168 Body: []byte(strconv.FormatInt(int64(test.inputSizeBytes), 10)), 169 }, 170 }) 171 172 require.NoError(t, err) 173 174 messageCount := 0 175 var receivedPayload []byte 176 for { 177 resp, err := stream.Recv() 178 if errors.Is(err, io.EOF) { 179 break 180 } 181 182 if err != nil { 183 t.Fatal(err) 184 } 185 186 messageCount++ 187 receivedPayload = append(receivedPayload, resp.GetPayload().GetBody()...) 188 189 require.Less(t, proto.Size(resp), maxMessageSize) 190 } 191 192 require.Equal(t, test.expectedMessageCount, messageCount) 193 194 receivedPayloadSizeBytes := len(receivedPayload) 195 196 expectedSizeBytes := test.inputSizeBytes 197 198 if receivedPayloadSizeBytes != expectedSizeBytes { 199 t.Fatalf("input payload size is not %d bytes (~ %q), got size: %d (~ %q)", 200 expectedSizeBytes, humanize.Bytes(uint64(expectedSizeBytes)), 201 receivedPayloadSizeBytes, humanize.Bytes(uint64(receivedPayloadSizeBytes)), 202 ) 203 } 204 205 }) 206 } 207} 208 209type server struct { 210 grpc_testing.UnimplementedTestServiceServer 211} 212 213func (s *server) StreamingOutputCall(req *grpc_testing.StreamingOutputCallRequest, stream grpc_testing.TestService_StreamingOutputCallServer) error { 214 const kilobyte = 1024 215 216 c := New[*grpc_testing.Payload](func(payloads []*grpc_testing.Payload) error { 217 var body []byte 218 for _, p := range payloads { 219 body = append(body, p.GetBody()...) 220 } 221 222 return stream.Send(&grpc_testing.StreamingOutputCallResponse{Payload: &grpc_testing.Payload{Body: body}}) 223 }) 224 225 bytesToSend, err := strconv.ParseInt(string(req.GetPayload().GetBody()), 10, 64) 226 if err != nil { 227 return err 228 } 229 230 if bytesToSend == 0 { 231 if err := c.Send(&grpc_testing.Payload{}); err != nil { 232 return err 233 } 234 235 return c.Flush() 236 } 237 238 for numBytes := int64(0); numBytes < bytesToSend; numBytes += kilobyte { 239 if err := c.Send(&grpc_testing.Payload{Body: make([]byte, kilobyte)}); err != nil { 240 return err 241 } 242 } 243 244 return c.Flush() 245} 246 247func runServer(t *testing.T, s *server, opt ...grpc.ServerOption) (*grpc.Server, string) { 248 grpcServer := grpc.NewServer(opt...) 249 grpc_testing.RegisterTestServiceServer(grpcServer, s) 250 251 lis, err := net.Listen("tcp", ":0") 252 require.NoError(t, err) 253 254 go func() { 255 err := grpcServer.Serve(lis) 256 require.NoError(t, err) 257 }() 258 259 t.Cleanup(func() { 260 grpcServer.Stop() 261 lis.Close() 262 }) 263 264 return grpcServer, lis.Addr().String() 265} 266 267func newClient(t *testing.T, serverSocketPath string) (grpc_testing.TestServiceClient, *grpc.ClientConn) { 268 connOpts := []grpc.DialOption{ 269 grpc.WithTransportCredentials(insecure.NewCredentials()), 270 } 271 272 conn, err := grpc.Dial(serverSocketPath, connOpts...) 273 if err != nil { 274 t.Fatal(err) 275 } 276 277 return grpc_testing.NewTestServiceClient(conn), conn 278}