fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1// Package chunk provides a utility for sending sets of protobuf messages in 2// groups of smaller chunks. This is useful for gRPC, which has limitations around the maximum 3// size of a message that you can send. 4// 5// This code is adapted from the gitaly project, which is licensed 6// under the MIT license. A copy of that license text can be found at 7// https://mit-license.org/. 8// 9// The code this file was based off can be found here: https://gitlab.com/gitlab-org/gitaly/-/blob/v16.2.0/internal/helper/chunk/chunker_test.go 10package chunk 11 12import ( 13 "bytes" 14 "context" 15 "errors" 16 "fmt" 17 "io" 18 "net" 19 "strconv" 20 "testing" 21 "testing/quick" 22 23 "github.com/dustin/go-humanize" 24 "github.com/google/go-cmp/cmp" 25 "github.com/stretchr/testify/require" 26 "google.golang.org/grpc" 27 "google.golang.org/grpc/credentials/insecure" 28 "google.golang.org/grpc/interop/grpc_testing" 29 "google.golang.org/protobuf/proto" 30) 31 32func TestChunker_DeliverAllMessages(t *testing.T) { 33 runTest := func(inputPayloads [][]byte) error { 34 expectedPayloadSizeBytes := 0 35 for _, payload := range inputPayloads { 36 expectedPayloadSizeBytes += len(payload) 37 } 38 39 var receivedPayloads []*grpc_testing.Payload 40 41 // Tell the chunker to just gather all the payloads for later inspection. 42 sendFunc := func(payloads []*grpc_testing.Payload) error { 43 receivedPayloads = append(receivedPayloads, payloads...) 44 return nil 45 } 46 47 c := New(sendFunc) 48 49 // send all the payloads 50 for _, payload := range inputPayloads { 51 if err := c.Send(&grpc_testing.Payload{Body: payload}); err != nil { 52 return fmt.Errorf("error sending payload: %s", err) 53 } 54 } 55 56 if err := c.Flush(); err != nil { 57 return fmt.Errorf("error flushing chunker: %s", err) 58 } 59 60 // Confirm that we received the same number of payloads as we sent. 61 if diff := cmp.Diff(len(inputPayloads), len(receivedPayloads)); diff != "" { 62 return fmt.Errorf("unexpected number of payloads (-want +got):\n%s", diff) 63 } 64 65 // Confirm that each received payload is the same as the original. 66 for i, payload := range receivedPayloads { 67 expectedPayload := inputPayloads[i] 68 if diff := cmp.Diff(expectedPayload, payload.GetBody()); diff != "" { 69 return fmt.Errorf("for payload #%d (-want +got):\n%s", i, diff) 70 } 71 } 72 73 receivedPayloadSizeBytes := 0 74 for _, payload := range receivedPayloads { 75 receivedPayloadSizeBytes += len(payload.GetBody()) 76 } 77 78 // Confirm that the total size of the payloads we received is the same as the total size of the payloads we sent. 79 if diff := cmp.Diff(expectedPayloadSizeBytes, receivedPayloadSizeBytes); diff != "" { 80 return fmt.Errorf("unexpected payload size (-want +got):\n%s", diff) 81 } 82 83 return nil 84 } 85 86 t.Run("normal", func(t *testing.T) { 87 t.Parallel() 88 89 inputPayloads := [][]byte{ 90 {1, 2, 3}, 91 bytes.Repeat([]byte("a"), int(3.5*maxMessageSize)), 92 {4, 5, 6}, 93 } 94 95 if err := runTest(inputPayloads); err != nil { 96 t.Fatal(err) 97 } 98 }) 99 100 t.Run("some empty", func(t *testing.T) { 101 t.Parallel() 102 103 inputPayloads := [][]byte{ 104 {}, 105 []byte("foo, bar, baz"), 106 bytes.Repeat([]byte("a"), int(3.5*maxMessageSize)), 107 {}, 108 } 109 110 if err := runTest(inputPayloads); err != nil { 111 t.Fatal(err) 112 } 113 }) 114 115 t.Run("fuzz", func(t *testing.T) { 116 t.Parallel() 117 118 var lastErr error 119 120 if err := quick.Check(func(payloads [][]byte) bool { 121 lastErr = runTest(payloads) 122 if lastErr != nil { 123 return false 124 } 125 126 return true 127 }, nil); err != nil { 128 t.Fatal(lastErr) 129 } 130 }) 131} 132 133func TestChunkerE2E(t *testing.T) { 134 for _, test := range []struct { 135 name string 136 137 inputSizeBytes int 138 expectedMessageCount int 139 }{ 140 { 141 name: "normal", 142 143 inputSizeBytes: int(3.5 * maxMessageSize), 144 expectedMessageCount: 4, 145 }, 146 { 147 name: "empty payload", 148 inputSizeBytes: 0, 149 expectedMessageCount: 1, 150 }, 151 } { 152 t.Run(test.name, func(t *testing.T) { 153 s := &server{} 154 srv, serverSocketPath := runServer(t, s) 155 t.Cleanup(func() { 156 srv.Stop() 157 }) 158 159 client, conn := newClient(t, serverSocketPath) 160 t.Cleanup(func() { 161 _ = conn.Close() 162 }) 163 164 ctx := context.Background() 165 166 stream, err := client.StreamingOutputCall(ctx, &grpc_testing.StreamingOutputCallRequest{ 167 Payload: &grpc_testing.Payload{ 168 Body: []byte(strconv.FormatInt(int64(test.inputSizeBytes), 10)), 169 }, 170 }) 171 172 require.NoError(t, err) 173 174 messageCount := 0 175 var receivedPayload []byte 176 for { 177 resp, err := stream.Recv() 178 if errors.Is(err, io.EOF) { 179 break 180 } 181 182 if err != nil { 183 t.Fatal(err) 184 } 185 186 messageCount++ 187 receivedPayload = append(receivedPayload, resp.GetPayload().GetBody()...) 188 189 require.Less(t, proto.Size(resp), maxMessageSize) 190 } 191 192 require.Equal(t, test.expectedMessageCount, messageCount) 193 194 receivedPayloadSizeBytes := len(receivedPayload) 195 196 expectedSizeBytes := test.inputSizeBytes 197 198 if receivedPayloadSizeBytes != expectedSizeBytes { 199 t.Fatalf("input payload size is not %d bytes (~ %q), got size: %d (~ %q)", 200 expectedSizeBytes, humanize.Bytes(uint64(expectedSizeBytes)), 201 receivedPayloadSizeBytes, humanize.Bytes(uint64(receivedPayloadSizeBytes)), 202 ) 203 } 204 }) 205 } 206} 207 208type server struct { 209 grpc_testing.UnimplementedTestServiceServer 210} 211 212func (s *server) StreamingOutputCall(req *grpc_testing.StreamingOutputCallRequest, stream grpc_testing.TestService_StreamingOutputCallServer) error { 213 const kilobyte = 1024 214 215 c := New[*grpc_testing.Payload](func(payloads []*grpc_testing.Payload) error { 216 var body []byte 217 for _, p := range payloads { 218 body = append(body, p.GetBody()...) 219 } 220 221 return stream.Send(&grpc_testing.StreamingOutputCallResponse{Payload: &grpc_testing.Payload{Body: body}}) 222 }) 223 224 bytesToSend, err := strconv.ParseInt(string(req.GetPayload().GetBody()), 10, 64) 225 if err != nil { 226 return err 227 } 228 229 if bytesToSend == 0 { 230 if err := c.Send(&grpc_testing.Payload{}); err != nil { 231 return err 232 } 233 234 return c.Flush() 235 } 236 237 for numBytes := int64(0); numBytes < bytesToSend; numBytes += kilobyte { 238 if err := c.Send(&grpc_testing.Payload{Body: make([]byte, kilobyte)}); err != nil { 239 return err 240 } 241 } 242 243 return c.Flush() 244} 245 246func runServer(t *testing.T, s *server, opt ...grpc.ServerOption) (*grpc.Server, string) { 247 grpcServer := grpc.NewServer(opt...) 248 grpc_testing.RegisterTestServiceServer(grpcServer, s) 249 250 lis, err := net.Listen("tcp", ":0") 251 require.NoError(t, err) 252 253 go func() { 254 err := grpcServer.Serve(lis) 255 require.NoError(t, err) 256 }() 257 258 t.Cleanup(func() { 259 grpcServer.Stop() 260 lis.Close() 261 }) 262 263 return grpcServer, lis.Addr().String() 264} 265 266func newClient(t *testing.T, serverSocketPath string) (grpc_testing.TestServiceClient, *grpc.ClientConn) { 267 connOpts := []grpc.DialOption{ 268 grpc.WithTransportCredentials(insecure.NewCredentials()), 269 } 270 271 conn, err := grpc.Dial(serverSocketPath, connOpts...) 272 if err != nil { 273 t.Fatal(err) 274 } 275 276 return grpc_testing.NewTestServiceClient(conn), conn 277}