fork of https://github.com/sourcegraph/zoekt
1package messagesize
2
3import (
4 "context"
5 "sync"
6 "sync/atomic"
7
8 "github.com/prometheus/client_golang/prometheus"
9 "github.com/prometheus/client_golang/prometheus/promauto"
10 "github.com/sourcegraph/zoekt/grpc/grpcutil"
11 "google.golang.org/grpc"
12 "google.golang.org/protobuf/proto"
13)
14
15var (
16 metricServerSingleMessageSize = promauto.NewHistogramVec(prometheus.HistogramOpts{
17 Name: "grpc_server_sent_individual_message_size_bytes_per_rpc",
18 Help: "Size of individual messages sent by the server per RPC.",
19 Buckets: sizeBuckets,
20 }, []string{
21 "grpc_service", // e.g. "gitserver.v1.GitserverService"
22 "grpc_method", // e.g. "Exec"
23 })
24
25 metricServerTotalSentPerRPCBytes = promauto.NewHistogramVec(prometheus.HistogramOpts{
26 Name: "grpc_server_sent_bytes_per_rpc",
27 Help: "Total size of all the messages sent by the server during the course of a single RPC call",
28 Buckets: sizeBuckets,
29 }, []string{
30 "grpc_service", // e.g. "gitserver.v1.GitserverService"
31 "grpc_method", // e.g. "Exec"
32 })
33
34 metricClientSingleMessageSize = promauto.NewHistogramVec(prometheus.HistogramOpts{
35 Name: "grpc_client_sent_individual_message_size_per_rpc_bytes",
36 Help: "Size of individual messages sent by the client per RPC.",
37 Buckets: sizeBuckets,
38 }, []string{
39 "grpc_service", // e.g. "gitserver.v1.GitserverService"
40 "grpc_method", // e.g. "Exec"
41 })
42
43 metricClientTotalSentPerRPCBytes = promauto.NewHistogramVec(prometheus.HistogramOpts{
44 Name: "grpc_client_sent_bytes_per_rpc",
45 Help: "Total size of all the messages sent by the client during the course of a single RPC call",
46 Buckets: sizeBuckets,
47 }, []string{
48 "grpc_service", // e.g. "gitserver.v1.GitserverService"
49 "grpc_method", // e.g. "Exec"
50 })
51)
52
// Byte size units, each 1024 times larger than the previous one.
const (
	B  = 1 << (10 * iota) // 1 byte
	KB                    // 1024 bytes
	MB                    // 1024 * 1024 bytes
	GB                    // 1024 * 1024 * 1024 bytes
)
59
60var sizeBuckets = []float64{
61 0,
62 1 * KB,
63 10 * KB,
64 50 * KB,
65 100 * KB,
66 500 * KB,
67 1 * MB,
68 5 * MB,
69 10 * MB,
70 50 * MB,
71 100 * MB,
72 500 * MB,
73 1 * GB,
74 5 * GB,
75 10 * GB,
76}
77
78// UnaryServerInterceptor is a grpc.UnaryServerInterceptor that records Prometheus metrics that observe the size of
79// the response message sent back by the server for a single RPC call.
80func UnaryServerInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
81 observer := newServerMessageSizeObserver(info.FullMethod)
82
83 return unaryServerInterceptor(observer, req, ctx, info, handler)
84}
85
86func unaryServerInterceptor(observer *messageSizeObserver, req any, ctx context.Context, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
87 defer observer.FinishRPC()
88
89 r, err := handler(ctx, req)
90 if err != nil {
91 return r, err
92 }
93
94 response, ok := r.(proto.Message)
95 if !ok {
96 return r, nil
97 }
98
99 observer.Observe(response)
100 return response, nil
101}
102
103// StreamServerInterceptor is a grpc.StreamServerInterceptor that records Prometheus metrics that observe both the sizes of the
104// individual response messages and the cumulative response size of all the message sent back by the server over the course
105// of a single RPC call.
106func StreamServerInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
107 observer := newServerMessageSizeObserver(info.FullMethod)
108
109 return streamServerInterceptor(observer, srv, ss, info, handler)
110}
111
112func streamServerInterceptor(observer *messageSizeObserver, srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
113 defer observer.FinishRPC()
114
115 wrappedStream := newObservingServerStream(ss, observer)
116
117 return handler(srv, wrappedStream)
118}
119
120// UnaryClientInterceptor is a grpc.UnaryClientInterceptor that records Prometheus metrics that observe the size of
121// the request message sent by client for a single RPC call.
122func UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
123 o := newClientMessageSizeObserver(method)
124 return unaryClientInterceptor(o, ctx, method, req, reply, cc, invoker, opts...)
125}
126
127func unaryClientInterceptor(observer *messageSizeObserver, ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
128 defer observer.FinishRPC()
129
130 err := invoker(ctx, method, req, reply, cc, opts...)
131 if err != nil {
132 // Don't record the size of the message if there was an error sending it, since it may not have been sent.
133 return err
134 }
135
136 // Observe the size of the request message.
137 request, ok := req.(proto.Message)
138 if !ok {
139 return nil
140 }
141
142 observer.Observe(request)
143 return nil
144}
145
146// StreamClientInterceptor is a grpc.StreamClientInterceptor that records Prometheus metrics that observe both the sizes of the
147// individual request messages and the cumulative request size of all the message sent by the client over the course
148// of a single RPC call.
149func StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
150 observer := newClientMessageSizeObserver(method)
151
152 return streamClientInterceptor(observer, ctx, desc, cc, method, streamer, opts...)
153}
154
155func streamClientInterceptor(observer *messageSizeObserver, ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
156 s, err := streamer(ctx, desc, cc, method, opts...)
157 if err != nil {
158 return nil, err
159 }
160
161 wrappedStream := newObservingClientStream(s, observer)
162 return wrappedStream, nil
163}
164
165type observingServerStream struct {
166 grpc.ServerStream
167
168 observer *messageSizeObserver
169}
170
171func newObservingServerStream(s grpc.ServerStream, observer *messageSizeObserver) grpc.ServerStream {
172 return &observingServerStream{
173 ServerStream: s,
174 observer: observer,
175 }
176}
177
178func (s *observingServerStream) SendMsg(m any) error {
179 err := s.ServerStream.SendMsg(m)
180 if err != nil {
181 // Don't record the size of the message if there was an error sending it, since it may not have been sent.
182 //
183 // However, the stream aborts on an error,
184 // so we need to record the total size of the messages sent during the course of the RPC call.
185 s.observer.FinishRPC()
186 return err
187 }
188
189 // Observe the size of the sent message.
190 message, ok := m.(proto.Message)
191 if !ok {
192 return nil
193 }
194
195 s.observer.Observe(message)
196 return nil
197}
198
199type observingClientStream struct {
200 grpc.ClientStream
201
202 observer *messageSizeObserver
203}
204
205func newObservingClientStream(s grpc.ClientStream, observer *messageSizeObserver) grpc.ClientStream {
206 return &observingClientStream{
207 ClientStream: s,
208 observer: observer,
209 }
210}
211
212func (s *observingClientStream) SendMsg(m any) error {
213 err := s.ClientStream.SendMsg(m)
214 if err != nil {
215 // Don't record the size of the message if there was an error sending it, since it may not have been sent.
216 //
217 // However, the stream aborts on an error,
218 // so we need to record the total size of the messages sent during the course of the RPC call.
219 s.observer.FinishRPC()
220 return err
221 }
222
223 // Observe the size of the sent message.
224 message, ok := m.(proto.Message)
225 if !ok {
226 return nil
227 }
228
229 s.observer.Observe(message)
230 return nil
231}
232
233func (s *observingClientStream) CloseSend() error {
234 err := s.ClientStream.CloseSend()
235
236 s.observer.FinishRPC()
237 return err
238}
239
240func (s *observingClientStream) RecvMsg(m any) error {
241 err := s.ClientStream.RecvMsg(m)
242 if err != nil {
243 // Record the total size of the messages sent during the course of the RPC call, even if there was an error.
244 s.observer.FinishRPC()
245 }
246
247 return err
248}
249
250func newServerMessageSizeObserver(fullMethod string) *messageSizeObserver {
251 serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
252
253 onSingle := func(messageSize uint64) {
254 metricServerSingleMessageSize.WithLabelValues(serviceName, methodName).Observe(float64(messageSize))
255 }
256
257 onFinish := func(messageSize uint64) {
258 metricServerTotalSentPerRPCBytes.WithLabelValues(serviceName, methodName).Observe(float64(messageSize))
259 }
260
261 return &messageSizeObserver{
262 onSingleFunc: onSingle,
263 onFinishFunc: onFinish,
264 }
265}
266
267func newClientMessageSizeObserver(fullMethod string) *messageSizeObserver {
268 serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
269
270 onSingle := func(messageSize uint64) {
271 metricClientSingleMessageSize.WithLabelValues(serviceName, methodName).Observe(float64(messageSize))
272 }
273
274 onFinish := func(messageSize uint64) {
275 metricClientTotalSentPerRPCBytes.WithLabelValues(serviceName, methodName).Observe(float64(messageSize))
276 }
277
278 return &messageSizeObserver{
279 onSingleFunc: onSingle,
280 onFinishFunc: onFinish,
281 }
282}
283
// messageSizeObserver tracks the size of every message sent during a single RPC call and the
// running total of those sizes, forwarding both to caller-supplied callbacks.
type messageSizeObserver struct {
	// onSingleFunc is called by Observe with the size, in bytes, of each observed message.
	onSingleFunc func(messageSizeBytes uint64)

	// finishOnce guarantees that onFinishFunc runs at most once per observer.
	finishOnce sync.Once
	// onFinishFunc is called by FinishRPC with the accumulated size, in bytes, of all observed messages.
	onFinishFunc func(totalSizeBytes uint64)

	// totalSizeBytes is the running sum of the sizes passed to onSingleFunc.
	totalSizeBytes atomic.Uint64
}
294
295// Observe records the size of a single message.
296func (o *messageSizeObserver) Observe(message proto.Message) {
297 s := uint64(proto.Size(message))
298 o.onSingleFunc(s)
299
300 o.totalSizeBytes.Add(s)
301}
302
303// FinishRPC records the total size of all sent messages during the course of a single RPC call.
304// This function should only be called once the RPC call has completed.
305func (o *messageSizeObserver) FinishRPC() {
306 o.finishOnce.Do(func() {
307 o.onFinishFunc(o.totalSizeBytes.Load())
308 })
309}
310
311var (
312 _ grpc.ServerStream = &observingServerStream{}
313 _ grpc.ClientStream = &observingClientStream{}
314)
315
316var (
317 _ grpc.UnaryServerInterceptor = UnaryServerInterceptor
318 _ grpc.StreamServerInterceptor = StreamServerInterceptor
319 _ grpc.UnaryClientInterceptor = UnaryClientInterceptor
320 _ grpc.StreamClientInterceptor = StreamClientInterceptor
321)