fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

grpc: add support for prometheus metrics that calculates message size (#651)

+1083
+2
cmd/zoekt-sourcegraph-indexserver/main.go
··· 1482 1482 grpc.WithTransportCredentials(insecure.NewCredentials()), 1483 1483 grpc.WithChainStreamInterceptor( 1484 1484 metrics.StreamClientInterceptor(), 1485 + messagesize.StreamClientInterceptor, 1485 1486 internalActorStreamInterceptor(), 1486 1487 internalerrs.LoggingStreamClientInterceptor(logger), 1487 1488 internalerrs.PrometheusStreamClientInterceptor, ··· 1489 1490 ), 1490 1491 grpc.WithChainUnaryInterceptor( 1491 1492 metrics.UnaryClientInterceptor(), 1493 + messagesize.UnaryClientInterceptor, 1492 1494 internalActorUnaryInterceptor(), 1493 1495 internalerrs.LoggingUnaryClientInterceptor(logger), 1494 1496 internalerrs.PrometheusUnaryClientInterceptor,
+2
cmd/zoekt-webserver/main.go
··· 648 648 grpc.ChainStreamInterceptor( 649 649 otelgrpc.StreamServerInterceptor(), 650 650 metrics.StreamServerInterceptor(), 651 + messagesize.StreamServerInterceptor, 651 652 internalerrs.LoggingStreamServerInterceptor(logger), 652 653 ), 653 654 grpc.ChainUnaryInterceptor( 654 655 otelgrpc.UnaryServerInterceptor(), 655 656 metrics.UnaryServerInterceptor(), 657 + messagesize.UnaryServerInterceptor, 656 658 internalerrs.LoggingUnaryServerInterceptor(logger), 657 659 ), 658 660 }
+321
grpc/messagesize/prometheus.go
··· 1 + package messagesize 2 + 3 + import ( 4 + "context" 5 + "sync" 6 + "sync/atomic" 7 + 8 + "github.com/prometheus/client_golang/prometheus" 9 + "github.com/prometheus/client_golang/prometheus/promauto" 10 + "github.com/sourcegraph/zoekt/grpc/grpcutil" 11 + "google.golang.org/grpc" 12 + "google.golang.org/protobuf/proto" 13 + ) 14 + 15 + var ( 16 + metricServerSingleMessageSize = promauto.NewHistogramVec(prometheus.HistogramOpts{ 17 + Name: "grpc_server_sent_individual_message_size_bytes_per_rpc", 18 + Help: "Size of individual messages sent by the server per RPC.", 19 + Buckets: sizeBuckets, 20 + }, []string{ 21 + "grpc_service", // e.g. "gitserver.v1.GitserverService" 22 + "grpc_method", // e.g. "Exec" 23 + }) 24 + 25 + metricServerTotalSentPerRPCBytes = promauto.NewHistogramVec(prometheus.HistogramOpts{ 26 + Name: "grpc_server_sent_bytes_per_rpc", 27 + Help: "Total size of all the messages sent by the server during the course of a single RPC call", 28 + Buckets: sizeBuckets, 29 + }, []string{ 30 + "grpc_service", // e.g. "gitserver.v1.GitserverService" 31 + "grpc_method", // e.g. "Exec" 32 + }) 33 + 34 + metricClientSingleMessageSize = promauto.NewHistogramVec(prometheus.HistogramOpts{ 35 + Name: "grpc_client_sent_individual_message_size_per_rpc_bytes", 36 + Help: "Size of individual messages sent by the client per RPC.", 37 + Buckets: sizeBuckets, 38 + }, []string{ 39 + "grpc_service", // e.g. "gitserver.v1.GitserverService" 40 + "grpc_method", // e.g. "Exec" 41 + }) 42 + 43 + metricClientTotalSentPerRPCBytes = promauto.NewHistogramVec(prometheus.HistogramOpts{ 44 + Name: "grpc_client_sent_bytes_per_rpc", 45 + Help: "Total size of all the messages sent by the client during the course of a single RPC call", 46 + Buckets: sizeBuckets, 47 + }, []string{ 48 + "grpc_service", // e.g. "gitserver.v1.GitserverService" 49 + "grpc_method", // e.g. "Exec" 50 + }) 51 + ) 52 + 53 + const ( 54 + B = 1 55 + KB = 1024 * B 56 + MB = 1024 * KB 57 + GB = 1024 * MB 58 + ) 59 + 60 + var sizeBuckets = []float64{ 61 + 0, 62 + 1 * KB, 63 + 10 * KB, 64 + 50 * KB, 65 + 100 * KB, 66 + 500 * KB, 67 + 1 * MB, 68 + 5 * MB, 69 + 10 * MB, 70 + 50 * MB, 71 + 100 * MB, 72 + 500 * MB, 73 + 1 * GB, 74 + 5 * GB, 75 + 10 * GB, 76 + } 77 + 78 + // UnaryServerInterceptor is a grpc.UnaryServerInterceptor that records Prometheus metrics that observe the size of 79 + // the response message sent back by the server for a single RPC call. 80 + func UnaryServerInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { 81 + observer := newServerMessageSizeObserver(info.FullMethod) 82 + 83 + return unaryServerInterceptor(observer, req, ctx, info, handler) 84 + } 85 + 86 + func unaryServerInterceptor(observer *messageSizeObserver, req any, ctx context.Context, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 87 + defer observer.FinishRPC() 88 + 89 + r, err := handler(ctx, req) 90 + if err != nil { 91 + return r, err 92 + } 93 + 94 + response, ok := r.(proto.Message) 95 + if !ok { 96 + return r, nil 97 + } 98 + 99 + observer.Observe(response) 100 + return response, nil 101 + } 102 + 103 + // StreamServerInterceptor is a grpc.StreamServerInterceptor that records Prometheus metrics that observe both the sizes of the 104 + // individual response messages and the cumulative response size of all the message sent back by the server over the course 105 + // of a single RPC call. 106 + func StreamServerInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 107 + observer := newServerMessageSizeObserver(info.FullMethod) 108 + 109 + return streamServerInterceptor(observer, srv, ss, info, handler) 110 + } 111 + 112 + func streamServerInterceptor(observer *messageSizeObserver, srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 113 + defer observer.FinishRPC() 114 + 115 + wrappedStream := newObservingServerStream(ss, observer) 116 + 117 + return handler(srv, wrappedStream) 118 + } 119 + 120 + // UnaryClientInterceptor is a grpc.UnaryClientInterceptor that records Prometheus metrics that observe the size of 121 + // the request message sent by client for a single RPC call. 122 + func UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 123 + o := newClientMessageSizeObserver(method) 124 + return unaryClientInterceptor(o, ctx, method, req, reply, cc, invoker, opts...) 125 + } 126 + 127 + func unaryClientInterceptor(observer *messageSizeObserver, ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 128 + defer observer.FinishRPC() 129 + 130 + err := invoker(ctx, method, req, reply, cc, opts...) 131 + if err != nil { 132 + // Don't record the size of the message if there was an error sending it, since it may not have been sent. 133 + return err 134 + } 135 + 136 + // Observe the size of the request message. 137 + request, ok := req.(proto.Message) 138 + if !ok { 139 + return nil 140 + } 141 + 142 + observer.Observe(request) 143 + return nil 144 + } 145 + 146 + // StreamClientInterceptor is a grpc.StreamClientInterceptor that records Prometheus metrics that observe both the sizes of the 147 + // individual request messages and the cumulative request size of all the message sent by the client over the course 148 + // of a single RPC call. 149 + func StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 150 + observer := newClientMessageSizeObserver(method) 151 + 152 + return streamClientInterceptor(observer, ctx, desc, cc, method, streamer, opts...) 153 + } 154 + 155 + func streamClientInterceptor(observer *messageSizeObserver, ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 156 + s, err := streamer(ctx, desc, cc, method, opts...) 157 + if err != nil { 158 + return nil, err 159 + } 160 + 161 + wrappedStream := newObservingClientStream(s, observer) 162 + return wrappedStream, nil 163 + } 164 + 165 + type observingServerStream struct { 166 + grpc.ServerStream 167 + 168 + observer *messageSizeObserver 169 + } 170 + 171 + func newObservingServerStream(s grpc.ServerStream, observer *messageSizeObserver) grpc.ServerStream { 172 + return &observingServerStream{ 173 + ServerStream: s, 174 + observer: observer, 175 + } 176 + } 177 + 178 + func (s *observingServerStream) SendMsg(m any) error { 179 + err := s.ServerStream.SendMsg(m) 180 + if err != nil { 181 + // Don't record the size of the message if there was an error sending it, since it may not have been sent. 182 + // 183 + // However, the stream aborts on an error, 184 + // so we need to record the total size of the messages sent during the course of the RPC call. 185 + s.observer.FinishRPC() 186 + return err 187 + } 188 + 189 + // Observe the size of the sent message. 190 + message, ok := m.(proto.Message) 191 + if !ok { 192 + return nil 193 + } 194 + 195 + s.observer.Observe(message) 196 + return nil 197 + } 198 + 199 + type observingClientStream struct { 200 + grpc.ClientStream 201 + 202 + observer *messageSizeObserver 203 + } 204 + 205 + func newObservingClientStream(s grpc.ClientStream, observer *messageSizeObserver) grpc.ClientStream { 206 + return &observingClientStream{ 207 + ClientStream: s, 208 + observer: observer, 209 + } 210 + } 211 + 212 + func (s *observingClientStream) SendMsg(m any) error { 213 + err := s.ClientStream.SendMsg(m) 214 + if err != nil { 215 + // Don't record the size of the message if there was an error sending it, since it may not have been sent. 216 + // 217 + // However, the stream aborts on an error, 218 + // so we need to record the total size of the messages sent during the course of the RPC call. 219 + s.observer.FinishRPC() 220 + return err 221 + } 222 + 223 + // Observe the size of the sent message. 224 + message, ok := m.(proto.Message) 225 + if !ok { 226 + return nil 227 + } 228 + 229 + s.observer.Observe(message) 230 + return nil 231 + } 232 + 233 + func (s *observingClientStream) CloseSend() error { 234 + err := s.ClientStream.CloseSend() 235 + 236 + s.observer.FinishRPC() 237 + return err 238 + } 239 + 240 + func (s *observingClientStream) RecvMsg(m any) error { 241 + err := s.ClientStream.RecvMsg(m) 242 + if err != nil { 243 + // Record the total size of the messages sent during the course of the RPC call, even if there was an error. 244 + s.observer.FinishRPC() 245 + } 246 + 247 + return err 248 + } 249 + 250 + func newServerMessageSizeObserver(fullMethod string) *messageSizeObserver { 251 + serviceName, methodName := grpcutil.SplitMethodName(fullMethod) 252 + 253 + onSingle := func(messageSize uint64) { 254 + metricServerSingleMessageSize.WithLabelValues(serviceName, methodName).Observe(float64(messageSize)) 255 + } 256 + 257 + onFinish := func(messageSize uint64) { 258 + metricServerTotalSentPerRPCBytes.WithLabelValues(serviceName, methodName).Observe(float64(messageSize)) 259 + } 260 + 261 + return &messageSizeObserver{ 262 + onSingleFunc: onSingle, 263 + onFinishFunc: onFinish, 264 + } 265 + } 266 + 267 + func newClientMessageSizeObserver(fullMethod string) *messageSizeObserver { 268 + serviceName, methodName := grpcutil.SplitMethodName(fullMethod) 269 + 270 + onSingle := func(messageSize uint64) { 271 + metricClientSingleMessageSize.WithLabelValues(serviceName, methodName).Observe(float64(messageSize)) 272 + } 273 + 274 + onFinish := func(messageSize uint64) { 275 + metricClientTotalSentPerRPCBytes.WithLabelValues(serviceName, methodName).Observe(float64(messageSize)) 276 + } 277 + 278 + return &messageSizeObserver{ 279 + onSingleFunc: onSingle, 280 + onFinishFunc: onFinish, 281 + } 282 + } 283 + 284 + // messageSizeObserver is a utility that records Prometheus metrics that observe the size of each sent message and the 285 + // cumulative size of all sent messages during the course of a single RPC call. 286 + type messageSizeObserver struct { 287 + onSingleFunc func(messageSizeBytes uint64) 288 + 289 + finishOnce sync.Once 290 + onFinishFunc func(totalSizeBytes uint64) 291 + 292 + totalSizeBytes atomic.Uint64 293 + } 294 + 295 + // Observe records the size of a single message. 296 + func (o *messageSizeObserver) Observe(message proto.Message) { 297 + s := uint64(proto.Size(message)) 298 + o.onSingleFunc(s) 299 + 300 + o.totalSizeBytes.Add(s) 301 + } 302 + 303 + // FinishRPC records the total size of all sent messages during the course of a single RPC call. 304 + // This function should only be called once the RPC call has completed. 305 + func (o *messageSizeObserver) FinishRPC() { 306 + o.finishOnce.Do(func() { 307 + o.onFinishFunc(o.totalSizeBytes.Load()) 308 + }) 309 + } 310 + 311 + var ( 312 + _ grpc.ServerStream = &observingServerStream{} 313 + _ grpc.ClientStream = &observingClientStream{} 314 + ) 315 + 316 + var ( 317 + _ grpc.UnaryServerInterceptor = UnaryServerInterceptor 318 + _ grpc.StreamServerInterceptor = StreamServerInterceptor 319 + _ grpc.UnaryClientInterceptor = UnaryClientInterceptor 320 + _ grpc.StreamClientInterceptor = StreamClientInterceptor 321 + )
+758
grpc/messagesize/prometheus_test.go
··· 1 + package messagesize 2 + 3 + import ( 4 + "bytes" 5 + "context" 6 + "errors" 7 + "io" 8 + "strings" 9 + "testing" 10 + 11 + "github.com/google/go-cmp/cmp" 12 + "github.com/google/go-cmp/cmp/cmpopts" 13 + "github.com/stretchr/testify/require" 14 + "google.golang.org/grpc" 15 + "google.golang.org/protobuf/proto" 16 + "google.golang.org/protobuf/testing/protocmp" 17 + "google.golang.org/protobuf/types/known/timestamppb" 18 + 19 + newspb "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1" 20 + ) 21 + 22 + var ( 23 + binaryMessage = &newspb.BinaryAttachment{ 24 + Name: "data", 25 + Data: []byte(strings.Repeat("x", 1*1024*1024)), 26 + } 27 + 28 + keyValueMessage = &newspb.KeyValueAttachment{ 29 + Name: "data", 30 + Data: map[string]string{ 31 + "key1": strings.Repeat("x", 1*1024*1024), 32 + "key2": "value2", 33 + }, 34 + } 35 + 36 + articleMessage = &newspb.Article{ 37 + Author: "author", 38 + Date: &timestamppb.Timestamp{Seconds: 1234567890}, 39 + Title: "title", 40 + Content: "content", 41 + Status: newspb.Article_STATUS_PUBLISHED, 42 + Attachments: []*newspb.Attachment{ 43 + {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}}, 44 + {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}}, 45 + {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}}, 46 + {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}}, 47 + }, 48 + } 49 + ) 50 + 51 + func BenchmarkObserverBinary(b *testing.B) { 52 + o := messageSizeObserver{ 53 + onSingleFunc: func(messageSizeBytes uint64) {}, 54 + onFinishFunc: func(totalSizeBytes uint64) {}, 55 + } 56 + 57 + benchmarkObserver(b, &o, binaryMessage) 58 + } 59 + 60 + func BenchmarkObserverKeyValue(b *testing.B) { 61 + o := messageSizeObserver{ 62 + onSingleFunc: func(messageSizeBytes uint64) {}, 63 + onFinishFunc: func(totalSizeBytes uint64) {}, 64 + } 65 + 66 + benchmarkObserver(b, &o, keyValueMessage) 67 + } 68 + 69 + func BenchmarkObserverArticle(b *testing.B) { 70 + o := messageSizeObserver{ 71 + onSingleFunc: func(messageSizeBytes uint64) {}, 72 + onFinishFunc: func(totalSizeBytes uint64) {}, 73 + } 74 + 75 + benchmarkObserver(b, &o, articleMessage) 76 + } 77 + 78 + func benchmarkObserver(b *testing.B, observer *messageSizeObserver, message proto.Message) { 79 + b.ReportAllocs() 80 + 81 + for n := 0; n < b.N; n++ { 82 + observer.Observe(message) 83 + } 84 + 85 + observer.FinishRPC() 86 + } 87 + 88 + func TestUnaryServerInterceptor(t *testing.T) { 89 + ctx := context.Background() 90 + 91 + request := &newspb.BinaryAttachment{ 92 + Data: bytes.Repeat([]byte("request"), 3), 93 + } 94 + 95 + response := &newspb.BinaryAttachment{ 96 + Data: bytes.Repeat([]byte("response"), 7), 97 + } 98 + 99 + info := &grpc.UnaryServerInfo{ 100 + FullMethod: "news.v1.NewsService/GetArticle", 101 + } 102 + 103 + sentinelError := errors.New("expected error") 104 + 105 + tests := []struct { 106 + name string 107 + handler func(ctx context.Context, req any) (any, error) 108 + expectedError error 109 + expectedResult any 110 + expectedSize uint64 111 + }{ 112 + { 113 + name: "invoker successful - observe response", 114 + handler: func(ctx context.Context, req any) (any, error) { 115 + return response, nil 116 + }, 117 + expectedError: nil, 118 + expectedResult: response, 119 + expectedSize: uint64(proto.Size(response)), 120 + }, 121 + { 122 + name: "invoker error - observe a zero-sized response", 123 + handler: func(ctx context.Context, req any) (any, error) { 124 + return nil, sentinelError 125 + }, 126 + expectedError: sentinelError, 127 + expectedResult: nil, 128 + expectedSize: uint64(0), 129 + }, 130 + } 131 + 132 + for _, test := range tests { 133 + t.Run(test.name, func(t *testing.T) { 134 + onFinishCalledCount := 0 135 + 136 + observer := messageSizeObserver{ 137 + onSingleFunc: func(messageSizeBytes uint64) {}, 138 + onFinishFunc: func(totalSizeBytes uint64) { 139 + onFinishCalledCount++ 140 + 141 + if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 142 + t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 143 + } 144 + }, 145 + } 146 + 147 + actualResult, err := unaryServerInterceptor(&observer, request, ctx, info, test.handler) 148 + if err != test.expectedError { 149 + t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 150 + } 151 + 152 + if diff := cmp.Diff(test.expectedResult, actualResult, protocmp.Transform()); diff != "" { 153 + t.Error("response mismatch (-want +got):\n", diff) 154 + } 155 + 156 + if diff := cmp.Diff(1, onFinishCalledCount); diff != "" { 157 + t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 158 + } 159 + }) 160 + } 161 + } 162 + 163 + func TestStreamServerInterceptor(t *testing.T) { 164 + 165 + response1 := &newspb.BinaryAttachment{ 166 + Name: "", 167 + Data: []byte("response"), 168 + } 169 + response2 := &newspb.BinaryAttachment{ 170 + Name: "", 171 + Data: bytes.Repeat([]byte("response"), 3), 172 + } 173 + response3 := &newspb.BinaryAttachment{ 174 + Name: "", 175 + Data: bytes.Repeat([]byte("response"), 7), 176 + } 177 + 178 + info := &grpc.StreamServerInfo{ 179 + FullMethod: "news.v1.NewsService/GetArticle", 180 + } 181 + 182 + sentinelError := errors.New("expected error") 183 + 184 + tests := []struct { 185 + name string 186 + 187 + mockSendMsg func(m any) error 188 + handler func(srv any, stream grpc.ServerStream) error 189 + 190 + expectedError error 191 + expectedResponses []any 192 + expectedSize uint64 193 + }{ 194 + { 195 + name: "invoker successful - observe all 3 responses", 196 + 197 + mockSendMsg: func(m any) error { 198 + return nil // no error 199 + }, 200 + 201 + handler: func(srv any, stream grpc.ServerStream) error { 202 + for _, r := range []proto.Message{response1, response2, response3} { 203 + if err := stream.SendMsg(r); err != nil { 204 + return err 205 + } 206 + } 207 + 208 + return nil 209 + }, 210 + 211 + expectedError: nil, 212 + expectedResponses: []any{response1, response2, response3}, 213 + expectedSize: uint64(proto.Size(response1) + proto.Size(response2) + proto.Size(response3)), 214 + }, 215 + 216 + { 217 + name: "invoker fails on 3rd response - only observe first 2", 218 + 219 + mockSendMsg: func(m any) error { 220 + if m == response3 { 221 + return sentinelError 222 + } 223 + 224 + return nil 225 + }, 226 + handler: func(srv any, stream grpc.ServerStream) error { 227 + for _, r := range []proto.Message{response1, response2, response3} { 228 + if err := stream.SendMsg(r); err != nil { 229 + return err 230 + } 231 + } 232 + 233 + return nil 234 + }, 235 + 236 + expectedError: sentinelError, 237 + expectedResponses: []any{response1, response2, response3}, // response 3 should still be attempted to be sent 238 + expectedSize: uint64(proto.Size(response1) + proto.Size(response2)), // response 3 should not be counted since an error occurred while sending it 239 + }, 240 + 241 + { 242 + name: "invoker fails immediately - should still observe a zero-sized response", 243 + 244 + mockSendMsg: func(m any) error { 245 + return errors.New("should not be called") 246 + }, 247 + 248 + handler: func(srv any, stream grpc.ServerStream) error { 249 + return sentinelError 250 + }, 251 + 252 + expectedError: sentinelError, 253 + expectedResponses: []any{}, // there are no responses 254 + expectedSize: uint64(0), // there are no responses, so the size is 0 255 + }, 256 + } 257 + 258 + for _, test := range tests { 259 + t.Run(test.name, func(t *testing.T) { 260 + onFinishCallCount := 0 261 + 262 + observer := messageSizeObserver{ 263 + onSingleFunc: func(messageSizeBytes uint64) {}, 264 + onFinishFunc: func(totalSizeBytes uint64) { 265 + onFinishCallCount++ 266 + 267 + if totalSizeBytes != test.expectedSize { 268 + t.Errorf("totalSizeBytes mismatch (wanted: %d, got: %d)", test.expectedSize, totalSizeBytes) 269 + } 270 + }, 271 + } 272 + 273 + var actualResponses []any 274 + 275 + ss := &mockServerStream{ 276 + mockSendMsg: func(m any) error { 277 + actualResponses = append(actualResponses, m) 278 + 279 + return test.mockSendMsg(m) 280 + }, 281 + } 282 + 283 + err := streamServerInterceptor(&observer, nil, ss, info, test.handler) 284 + if err != test.expectedError { 285 + t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 286 + } 287 + 288 + if diff := cmp.Diff(test.expectedResponses, actualResponses, protocmp.Transform(), cmpopts.EquateEmpty()); diff != "" { 289 + t.Error("responses mismatch (-want +got):\n", diff) 290 + } 291 + 292 + if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 293 + t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 294 + } 295 + }) 296 + } 297 + } 298 + 299 + func TestUnaryClientInterceptor(t *testing.T) { 300 + ctx := context.Background() 301 + 302 + request := &newspb.BinaryAttachment{ 303 + Name: "data", 304 + Data: bytes.Repeat([]byte("request"), 3), 305 + } 306 + 307 + method := "news.v1.NewsService/GetArticle" 308 + 309 + sentinelError := errors.New("expected error") 310 + 311 + tests := []struct { 312 + name string 313 + invoker grpc.UnaryInvoker 314 + 315 + expectedError error 316 + expectedRequest any 317 + expectedSize uint64 318 + }{ 319 + { 320 + name: "invoker successful - observe request size", 321 + invoker: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 322 + return nil 323 + }, 324 + 325 + expectedError: nil, 326 + expectedRequest: request, 327 + expectedSize: uint64(proto.Size(request)), 328 + }, 329 + 330 + { 331 + name: "invoker error - observe a zero-sized response", 332 + invoker: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 333 + return sentinelError 334 + }, 335 + 336 + expectedError: sentinelError, 337 + expectedRequest: request, 338 + expectedSize: uint64(0), 339 + }, 340 + } 341 + 342 + for _, test := range tests { 343 + t.Run(test.name, func(t *testing.T) { 344 + onFinishCallCount := 0 345 + 346 + observer := messageSizeObserver{ 347 + onSingleFunc: func(messageSizeBytes uint64) {}, 348 + onFinishFunc: func(totalSizeBytes uint64) { 349 + onFinishCallCount++ 350 + 351 + if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 352 + t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 353 + } 354 + }, 355 + } 356 + 357 + var actualRequest any 358 + 359 + invokerCalled := false 360 + invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 361 + invokerCalled = true 362 + 363 + actualRequest = req 364 + return test.invoker(ctx, method, req, reply, cc, opts...) 365 + } 366 + 367 + err := unaryClientInterceptor(&observer, ctx, method, request, nil, nil, invoker) 368 + if err != test.expectedError { 369 + t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) 370 + } 371 + 372 + if !invokerCalled { 373 + t.Fatal("invoker not called") 374 + } 375 + 376 + if diff := cmp.Diff(test.expectedRequest, actualRequest, protocmp.Transform()); diff != "" { 377 + t.Error("request mismatch (-want +got):\n", diff) 378 + } 379 + 380 + if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 381 + t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 382 + } 383 + }) 384 + } 385 + } 386 + 387 + func TestStreamingClientInterceptor(t *testing.T) { 388 + ctx := context.Background() 389 + 390 + request1 := &newspb.BinaryAttachment{ 391 + Name: "data", 392 + Data: bytes.Repeat([]byte("request"), 3), 393 + } 394 + 395 + request2 := &newspb.BinaryAttachment{ 396 + Name: "data", 397 + Data: bytes.Repeat([]byte("request"), 7), 398 + } 399 + 400 + request3 := &newspb.BinaryAttachment{ 401 + Name: "data", 402 + Data: bytes.Repeat([]byte("request"), 13), 403 + } 404 + 405 + method := "news.v1.NewsService/GetArticle" 406 + 407 + sentinelError := errors.New("expected error") 408 + 409 + type stepType int 410 + 411 + const ( 412 + stepSend stepType = iota 413 + stepRecv 414 + stepCloseSend 415 + ) 416 + 417 + type step struct { 418 + stepType stepType 419 + 420 + message any 421 + streamErr error 422 + } 423 + 424 + tests := []struct { 425 + name string 426 + 427 + steps []step 428 + expectedSize uint64 429 + }{ 430 + { 431 + name: "invoker successful - observe request size", 432 + steps: []step{ 433 + { 434 + stepType: stepSend, 435 + 436 + message: request1, 437 + streamErr: nil, 438 + }, 439 + { 440 + stepType: stepSend, 441 + 442 + message: request2, 443 + streamErr: nil, 444 + }, 445 + { 446 + stepType: stepSend, 447 + 448 + message: request3, 449 + streamErr: nil, 450 + }, 451 + { 452 + stepType: stepRecv, 453 + 454 + message: nil, 455 + streamErr: io.EOF, // end of stream 456 + }, 457 + }, 458 + 459 + expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 460 + }, 461 + { 462 + name: "2nd send failed - stream aborts and should only observe first request", 463 + steps: []step{ 464 + { 465 + stepType: stepSend, 466 + message: request1, 467 + streamErr: nil, 468 + }, 469 + { 470 + stepType: stepSend, 471 + message: request2, 472 + streamErr: sentinelError, 473 + }, 474 + }, 475 + 476 + expectedSize: uint64(proto.Size(request1)), 477 + }, 478 + { 479 + name: "recv message fails with non io.EOF error - should still observe all requests", 480 + steps: []step{ 481 + { 482 + stepType: stepSend, 483 + 484 + message: request1, 485 + streamErr: nil, 486 + }, 487 + { 488 + stepType: stepSend, 489 + 490 + message: request2, 491 + streamErr: nil, 492 + }, 493 + { 494 + stepType: stepSend, 495 + 496 + message: request3, 497 + streamErr: nil, 498 + }, 499 + { 500 + stepType: stepRecv, 501 + 502 + message: nil, 503 + streamErr: sentinelError, 504 + }, 505 + }, 506 + 507 + expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 508 + }, 509 + 510 + { 511 + name: "close send called - should observe all requests", 512 + steps: []step{ 513 + { 514 + stepType: stepSend, 515 + 516 + message: request1, 517 + streamErr: nil, 518 + }, 519 + { 520 + stepType: stepSend, 521 + 522 + message: request2, 523 + streamErr: nil, 524 + }, 525 + { 526 + stepType: stepSend, 527 + 528 + message: request3, 529 + streamErr: nil, 530 + }, 531 + { 532 + stepType: stepCloseSend, 533 + 534 + message: nil, 535 + streamErr: nil, 536 + }, 537 + }, 538 + 539 + expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), 540 + }, 541 + { 542 + name: "close send called immediately - should observe zero-sized response", 543 + steps: []step{ 544 + { 545 + stepType: stepCloseSend, 546 + 547 + message: nil, 548 + streamErr: nil, 549 + }, 550 + }, 551 + 552 + expectedSize: uint64(0), 553 + }, 554 + { 555 + name: "first send fails - stream should abort and observe zero-sized response", 556 + steps: []step{ 557 + { 558 + stepType: stepSend, 559 + 560 + message: request1, 561 + streamErr: sentinelError, 562 + }, 563 + }, 564 + 565 + expectedSize: uint64(0), 566 + }, 567 + } 568 + 569 + for _, test := range tests { 570 + t.Run(test.name, func(t *testing.T) { 571 + onFinishCallCount := 0 572 + 573 + observer := messageSizeObserver{ 574 + onSingleFunc: func(messageSizeBytes uint64) {}, 575 + onFinishFunc: func(totalSizeBytes uint64) { 576 + onFinishCallCount++ 577 + 578 + if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { 579 + t.Error("totalSizeBytes mismatch (-want +got):\n", diff) 580 + } 581 + }, 582 + } 583 + 584 + baseStream := &mockClientStream{} 585 + streamerCalled := false 586 + streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { 587 + streamerCalled = true 588 + 589 + return baseStream, nil 590 + } 591 + 592 + ss, err := streamClientInterceptor(&observer, ctx, nil, nil, method, streamer) 593 + require.NoError(t, err) 594 + 595 + // Run through all the steps, preparing the mockClientStream to return the expected errors 596 + for _, step := range test.steps { 597 + baseStreamCalled := false 598 + var streamErr error 599 + 600 + switch step.stepType { 601 + case stepSend: 602 + baseStream.mockSendMsg = func(m any) error { 603 + baseStreamCalled = true 604 + return step.streamErr 605 + } 606 + 607 + streamErr = ss.SendMsg(step.message) 608 + case stepRecv: 609 + baseStream.mockRecvMsg = func(_ any) error { 610 + baseStreamCalled = true 611 + return step.streamErr 612 + } 613 + 614 + streamErr = ss.RecvMsg(step.message) 615 + 616 + case stepCloseSend: 617 + baseStream.mockCloseSend = func() error { 618 + baseStreamCalled = true 619 + return step.streamErr 620 + } 621 + 622 + streamErr = ss.CloseSend() 623 + default: 624 + t.Fatalf("unknown step type: %v", step.stepType) 625 + } 626 + 627 + // ensure that the baseStream was called and errors are propagated 628 + require.True(t, baseStreamCalled) 629 + require.Equal(t, step.streamErr, streamErr) 630 + } 631 + 632 + if !streamerCalled { 633 + t.Fatal("streamer not called") 634 + } 635 + 636 + if diff := cmp.Diff(1, onFinishCallCount); diff != "" { 637 + t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) 638 + } 639 + }) 640 + } 641 + } 642 + 643 + func TestObserver(t *testing.T) { 644 + testCases := []struct { 645 + name string 646 + messages []proto.Message 647 + }{ 648 + { 649 + name: "single message", 650 + messages: []proto.Message{&newspb.BinaryAttachment{ 651 + Name: "data1", 652 + Data: []byte("sample data"), 653 + }}, 654 + }, 655 + { 656 + name: "multiple messages", 657 + messages: []proto.Message{ 658 + &newspb.BinaryAttachment{ 659 + Name: "data1", 660 + Data: []byte("sample data"), 661 + }, 662 + &newspb.KeyValueAttachment{ 663 + Name: "data2", 664 + Data: map[string]string{ 665 + "key1": "value1", 666 + "key2": "value2", 667 + }, 668 + }, 669 + }}, 670 + } 671 + 672 + for _, tc := range testCases { 673 + t.Run(tc.name, func(t *testing.T) { 674 + var singleMessageSizes []uint64 675 + var totalSize uint64 676 + 677 + // Create a new observer with custom onSingleFunc and onFinishFunc 678 + obs := &messageSizeObserver{ 679 + onSingleFunc: func(messageSizeBytes uint64) { 680 + singleMessageSizes = append(singleMessageSizes, messageSizeBytes) 681 + }, 682 + onFinishFunc: func(totalSizeBytes uint64) { 683 + totalSize = totalSizeBytes 684 + }, 685 + } 686 + 687 + // Call ObserveSingle for each message 688 + for _, msg := range tc.messages { 689 + obs.Observe(msg) 690 + } 691 + 692 + // Check that the singleMessageSizes are correct 693 + for i, msg := range tc.messages { 694 + expectedSize := uint64(proto.Size(msg)) 695 + require.Equal(t, expectedSize, singleMessageSizes[i]) 696 + } 697 + 698 + // Call FinishRPC 699 + obs.FinishRPC() 700 + 701 + // Check that the totalSize is correct 702 + expectedTotalSize := uint64(0) 703 + for _, size := range singleMessageSizes { 704 + expectedTotalSize += size 705 + } 706 + require.EqualValues(t, expectedTotalSize, totalSize) 707 + }) 708 + } 709 + } 710 + 711 + type mockServerStream struct { 712 + mockSendMsg func(m any) error 713 + 714 + grpc.ServerStream 715 + } 716 + 717 + func (s *mockServerStream) SendMsg(m any) error { 718 + if s.mockSendMsg != nil { 719 + return s.mockSendMsg(m) 720 + } 721 + 722 + return errors.New("send msg not implemented") 723 + } 724 + 725 + type mockClientStream struct { 726 + mockRecvMsg func(m any) error 727 + mockSendMsg func(m any) error 728 + mockCloseSend func() error 729 + 730 + grpc.ClientStream 731 + } 732 + 733 + func (s *mockClientStream) SendMsg(m any) error { 734 + if s.mockSendMsg != nil { 735 + return s.mockSendMsg(m) 736 + } 737 + 738 + return errors.New("send msg not implemented") 739 + } 740 + 741 + func (s *mockClientStream) RecvMsg(m any) error { 742 + if s.mockRecvMsg != nil { 743 + return s.mockRecvMsg(m) 744 + } 745 + 746 + return errors.New("recv msg not implemented") 747 + } 748 + 749 + func (s *mockClientStream) CloseSend() error { 750 + if s.mockCloseSend != nil { 751 + return s.mockCloseSend() 752 + } 753 + 754 + return errors.New("close send not implemented") 755 + } 756 + 757 + var _ grpc.ServerStream = &mockServerStream{} 758 + var _ grpc.ClientStream = &mockClientStream{}