···11+package messagesize
22+33+import (
44+ "context"
55+ "sync"
66+ "sync/atomic"
77+88+ "github.com/prometheus/client_golang/prometheus"
99+ "github.com/prometheus/client_golang/prometheus/promauto"
1010+ "github.com/sourcegraph/zoekt/grpc/grpcutil"
1111+ "google.golang.org/grpc"
1212+ "google.golang.org/protobuf/proto"
1313+)
1414+1515+var (
1616+ metricServerSingleMessageSize = promauto.NewHistogramVec(prometheus.HistogramOpts{
1717+ Name: "grpc_server_sent_individual_message_size_bytes_per_rpc",
1818+ Help: "Size of individual messages sent by the server per RPC.",
1919+ Buckets: sizeBuckets,
2020+ }, []string{
2121+ "grpc_service", // e.g. "gitserver.v1.GitserverService"
2222+ "grpc_method", // e.g. "Exec"
2323+ })
2424+2525+ metricServerTotalSentPerRPCBytes = promauto.NewHistogramVec(prometheus.HistogramOpts{
2626+ Name: "grpc_server_sent_bytes_per_rpc",
2727+ Help: "Total size of all the messages sent by the server during the course of a single RPC call",
2828+ Buckets: sizeBuckets,
2929+ }, []string{
3030+ "grpc_service", // e.g. "gitserver.v1.GitserverService"
3131+ "grpc_method", // e.g. "Exec"
3232+ })
3333+3434+ metricClientSingleMessageSize = promauto.NewHistogramVec(prometheus.HistogramOpts{
3535+ Name: "grpc_client_sent_individual_message_size_per_rpc_bytes",
3636+ Help: "Size of individual messages sent by the client per RPC.",
3737+ Buckets: sizeBuckets,
3838+ }, []string{
3939+ "grpc_service", // e.g. "gitserver.v1.GitserverService"
4040+ "grpc_method", // e.g. "Exec"
4141+ })
4242+4343+ metricClientTotalSentPerRPCBytes = promauto.NewHistogramVec(prometheus.HistogramOpts{
4444+ Name: "grpc_client_sent_bytes_per_rpc",
4545+ Help: "Total size of all the messages sent by the client during the course of a single RPC call",
4646+ Buckets: sizeBuckets,
4747+ }, []string{
4848+ "grpc_service", // e.g. "gitserver.v1.GitserverService"
4949+ "grpc_method", // e.g. "Exec"
5050+ })
5151+)
5252+5353+const (
5454+ B = 1
5555+ KB = 1024 * B
5656+ MB = 1024 * KB
5757+ GB = 1024 * MB
5858+)
5959+6060+var sizeBuckets = []float64{
6161+ 0,
6262+ 1 * KB,
6363+ 10 * KB,
6464+ 50 * KB,
6565+ 100 * KB,
6666+ 500 * KB,
6767+ 1 * MB,
6868+ 5 * MB,
6969+ 10 * MB,
7070+ 50 * MB,
7171+ 100 * MB,
7272+ 500 * MB,
7373+ 1 * GB,
7474+ 5 * GB,
7575+ 10 * GB,
7676+}
7777+7878+// UnaryServerInterceptor is a grpc.UnaryServerInterceptor that records Prometheus metrics that observe the size of
7979+// the response message sent back by the server for a single RPC call.
8080+func UnaryServerInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
8181+ observer := newServerMessageSizeObserver(info.FullMethod)
8282+8383+ return unaryServerInterceptor(observer, req, ctx, info, handler)
8484+}
8585+8686+func unaryServerInterceptor(observer *messageSizeObserver, req any, ctx context.Context, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
8787+ defer observer.FinishRPC()
8888+8989+ r, err := handler(ctx, req)
9090+ if err != nil {
9191+ return r, err
9292+ }
9393+9494+ response, ok := r.(proto.Message)
9595+ if !ok {
9696+ return r, nil
9797+ }
9898+9999+ observer.Observe(response)
100100+ return response, nil
101101+}
102102+103103+// StreamServerInterceptor is a grpc.StreamServerInterceptor that records Prometheus metrics that observe both the sizes of the
104104+// individual response messages and the cumulative response size of all the message sent back by the server over the course
105105+// of a single RPC call.
106106+func StreamServerInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
107107+ observer := newServerMessageSizeObserver(info.FullMethod)
108108+109109+ return streamServerInterceptor(observer, srv, ss, info, handler)
110110+}
111111+112112+func streamServerInterceptor(observer *messageSizeObserver, srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
113113+ defer observer.FinishRPC()
114114+115115+ wrappedStream := newObservingServerStream(ss, observer)
116116+117117+ return handler(srv, wrappedStream)
118118+}
119119+120120+// UnaryClientInterceptor is a grpc.UnaryClientInterceptor that records Prometheus metrics that observe the size of
121121+// the request message sent by client for a single RPC call.
122122+func UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
123123+ o := newClientMessageSizeObserver(method)
124124+ return unaryClientInterceptor(o, ctx, method, req, reply, cc, invoker, opts...)
125125+}
126126+127127+func unaryClientInterceptor(observer *messageSizeObserver, ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
128128+ defer observer.FinishRPC()
129129+130130+ err := invoker(ctx, method, req, reply, cc, opts...)
131131+ if err != nil {
132132+ // Don't record the size of the message if there was an error sending it, since it may not have been sent.
133133+ return err
134134+ }
135135+136136+ // Observe the size of the request message.
137137+ request, ok := req.(proto.Message)
138138+ if !ok {
139139+ return nil
140140+ }
141141+142142+ observer.Observe(request)
143143+ return nil
144144+}
145145+146146+// StreamClientInterceptor is a grpc.StreamClientInterceptor that records Prometheus metrics that observe both the sizes of the
147147+// individual request messages and the cumulative request size of all the message sent by the client over the course
148148+// of a single RPC call.
149149+func StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
150150+ observer := newClientMessageSizeObserver(method)
151151+152152+ return streamClientInterceptor(observer, ctx, desc, cc, method, streamer, opts...)
153153+}
154154+155155+func streamClientInterceptor(observer *messageSizeObserver, ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
156156+ s, err := streamer(ctx, desc, cc, method, opts...)
157157+ if err != nil {
158158+ return nil, err
159159+ }
160160+161161+ wrappedStream := newObservingClientStream(s, observer)
162162+ return wrappedStream, nil
163163+}
164164+165165+type observingServerStream struct {
166166+ grpc.ServerStream
167167+168168+ observer *messageSizeObserver
169169+}
170170+171171+func newObservingServerStream(s grpc.ServerStream, observer *messageSizeObserver) grpc.ServerStream {
172172+ return &observingServerStream{
173173+ ServerStream: s,
174174+ observer: observer,
175175+ }
176176+}
177177+178178+func (s *observingServerStream) SendMsg(m any) error {
179179+ err := s.ServerStream.SendMsg(m)
180180+ if err != nil {
181181+ // Don't record the size of the message if there was an error sending it, since it may not have been sent.
182182+ //
183183+ // However, the stream aborts on an error,
184184+ // so we need to record the total size of the messages sent during the course of the RPC call.
185185+ s.observer.FinishRPC()
186186+ return err
187187+ }
188188+189189+ // Observe the size of the sent message.
190190+ message, ok := m.(proto.Message)
191191+ if !ok {
192192+ return nil
193193+ }
194194+195195+ s.observer.Observe(message)
196196+ return nil
197197+}
198198+199199+type observingClientStream struct {
200200+ grpc.ClientStream
201201+202202+ observer *messageSizeObserver
203203+}
204204+205205+func newObservingClientStream(s grpc.ClientStream, observer *messageSizeObserver) grpc.ClientStream {
206206+ return &observingClientStream{
207207+ ClientStream: s,
208208+ observer: observer,
209209+ }
210210+}
211211+212212+func (s *observingClientStream) SendMsg(m any) error {
213213+ err := s.ClientStream.SendMsg(m)
214214+ if err != nil {
215215+ // Don't record the size of the message if there was an error sending it, since it may not have been sent.
216216+ //
217217+ // However, the stream aborts on an error,
218218+ // so we need to record the total size of the messages sent during the course of the RPC call.
219219+ s.observer.FinishRPC()
220220+ return err
221221+ }
222222+223223+ // Observe the size of the sent message.
224224+ message, ok := m.(proto.Message)
225225+ if !ok {
226226+ return nil
227227+ }
228228+229229+ s.observer.Observe(message)
230230+ return nil
231231+}
232232+233233+func (s *observingClientStream) CloseSend() error {
234234+ err := s.ClientStream.CloseSend()
235235+236236+ s.observer.FinishRPC()
237237+ return err
238238+}
239239+240240+func (s *observingClientStream) RecvMsg(m any) error {
241241+ err := s.ClientStream.RecvMsg(m)
242242+ if err != nil {
243243+ // Record the total size of the messages sent during the course of the RPC call, even if there was an error.
244244+ s.observer.FinishRPC()
245245+ }
246246+247247+ return err
248248+}
249249+250250+func newServerMessageSizeObserver(fullMethod string) *messageSizeObserver {
251251+ serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
252252+253253+ onSingle := func(messageSize uint64) {
254254+ metricServerSingleMessageSize.WithLabelValues(serviceName, methodName).Observe(float64(messageSize))
255255+ }
256256+257257+ onFinish := func(messageSize uint64) {
258258+ metricServerTotalSentPerRPCBytes.WithLabelValues(serviceName, methodName).Observe(float64(messageSize))
259259+ }
260260+261261+ return &messageSizeObserver{
262262+ onSingleFunc: onSingle,
263263+ onFinishFunc: onFinish,
264264+ }
265265+}
266266+267267+func newClientMessageSizeObserver(fullMethod string) *messageSizeObserver {
268268+ serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
269269+270270+ onSingle := func(messageSize uint64) {
271271+ metricClientSingleMessageSize.WithLabelValues(serviceName, methodName).Observe(float64(messageSize))
272272+ }
273273+274274+ onFinish := func(messageSize uint64) {
275275+ metricClientTotalSentPerRPCBytes.WithLabelValues(serviceName, methodName).Observe(float64(messageSize))
276276+ }
277277+278278+ return &messageSizeObserver{
279279+ onSingleFunc: onSingle,
280280+ onFinishFunc: onFinish,
281281+ }
282282+}
283283+284284+// messageSizeObserver is a utility that records Prometheus metrics that observe the size of each sent message and the
285285+// cumulative size of all sent messages during the course of a single RPC call.
286286+type messageSizeObserver struct {
287287+ onSingleFunc func(messageSizeBytes uint64)
288288+289289+ finishOnce sync.Once
290290+ onFinishFunc func(totalSizeBytes uint64)
291291+292292+ totalSizeBytes atomic.Uint64
293293+}
294294+295295+// Observe records the size of a single message.
296296+func (o *messageSizeObserver) Observe(message proto.Message) {
297297+ s := uint64(proto.Size(message))
298298+ o.onSingleFunc(s)
299299+300300+ o.totalSizeBytes.Add(s)
301301+}
302302+303303+// FinishRPC records the total size of all sent messages during the course of a single RPC call.
304304+// This function should only be called once the RPC call has completed.
305305+func (o *messageSizeObserver) FinishRPC() {
306306+ o.finishOnce.Do(func() {
307307+ o.onFinishFunc(o.totalSizeBytes.Load())
308308+ })
309309+}
310310+311311+var (
312312+ _ grpc.ServerStream = &observingServerStream{}
313313+ _ grpc.ClientStream = &observingClientStream{}
314314+)
315315+316316+var (
317317+ _ grpc.UnaryServerInterceptor = UnaryServerInterceptor
318318+ _ grpc.StreamServerInterceptor = StreamServerInterceptor
319319+ _ grpc.UnaryClientInterceptor = UnaryClientInterceptor
320320+ _ grpc.StreamClientInterceptor = StreamClientInterceptor
321321+)