···11+package internalerrs
22+33+import (
44+ "context"
55+ "errors"
66+ "fmt"
77+ "os"
88+ "strconv"
99+ "strings"
1010+ "sync"
1111+ "sync/atomic"
1212+ "unicode/utf8"
1313+1414+ "github.com/dustin/go-humanize"
1515+ "google.golang.org/protobuf/proto"
1616+ "google.golang.org/protobuf/reflect/protopath"
1717+ "google.golang.org/protobuf/reflect/protorange"
1818+1919+ "google.golang.org/grpc"
2020+ "google.golang.org/grpc/codes"
2121+ "google.golang.org/grpc/status"
2222+)
2323+2424+// callBackClientStream is a grpc.ClientStream that calls a function after SendMsg and RecvMsg.
2525+type callBackClientStream struct {
2626+ grpc.ClientStream
2727+2828+ postMessageSend func(message any, err error)
2929+ postMessageReceive func(message any, err error)
3030+}
3131+3232+func (c *callBackClientStream) SendMsg(m any) error {
3333+ err := c.ClientStream.SendMsg(m)
3434+ if c.postMessageSend != nil {
3535+ c.postMessageSend(m, err)
3636+ }
3737+3838+ return err
3939+}
4040+4141+func (c *callBackClientStream) RecvMsg(m any) error {
4242+ err := c.ClientStream.RecvMsg(m)
4343+ if c.postMessageReceive != nil {
4444+ c.postMessageReceive(m, err)
4545+ }
4646+4747+ return err
4848+}
4949+5050+var _ grpc.ClientStream = &callBackClientStream{}
5151+5252+// requestSavingClientStream is a grpc.ClientStream that saves the initial request sent to the server.
5353+type requestSavingClientStream struct {
5454+ grpc.ClientStream
5555+5656+ initialRequest atomic.Pointer[proto.Message]
5757+ saveRequestOnce sync.Once
5858+}
5959+6060+func (c *requestSavingClientStream) SendMsg(m any) error {
6161+ c.saveRequestOnce.Do(func() {
6262+ message, ok := m.(proto.Message)
6363+ if !ok {
6464+ return
6565+ }
6666+6767+ c.initialRequest.Store(&message)
6868+ })
6969+7070+ return c.ClientStream.SendMsg(m)
7171+}
7272+7373+// InitialRequest returns the initial request sent by the client on the stream.
7474+func (c *requestSavingClientStream) InitialRequest() *proto.Message {
7575+ return c.initialRequest.Load()
7676+}
7777+7878+var _ grpc.ClientStream = &requestSavingClientStream{}
7979+8080+// requestSavingServerStream is a grpc.ServerStream that saves the initial request sent by the client.
8181+type requestSavingServerStream struct {
8282+ grpc.ServerStream
8383+8484+ initialRequest atomic.Pointer[proto.Message]
8585+ saveRequestOnce sync.Once
8686+}
8787+8888+func (s *requestSavingServerStream) RecvMsg(m any) error {
8989+ s.saveRequestOnce.Do(func() {
9090+ message, ok := m.(proto.Message)
9191+ if !ok {
9292+ return
9393+ }
9494+9595+ s.initialRequest.Store(&message)
9696+ })
9797+9898+ return s.ServerStream.RecvMsg(m)
9999+}
100100+101101+// InitialRequest returns the initial request sent by the client on the stream.
102102+func (s *requestSavingServerStream) InitialRequest() *proto.Message {
103103+ return s.initialRequest.Load()
104104+}
105105+106106+var _ grpc.ServerStream = &requestSavingServerStream{}
107107+108108+// callBackServerStream is a grpc.ServerStream that calls a function after SendMsg and RecvMsg.
109109+type callBackServerStream struct {
110110+ grpc.ServerStream
111111+112112+ postMessageSend func(message any, err error)
113113+ postMessageReceive func(message any, err error)
114114+}
115115+116116+func (c *callBackServerStream) SendMsg(m any) error {
117117+ err := c.ServerStream.SendMsg(m)
118118+119119+ if c.postMessageSend != nil {
120120+ c.postMessageSend(m, err)
121121+ }
122122+123123+ return err
124124+}
125125+126126+func (c *callBackServerStream) RecvMsg(m any) error {
127127+ err := c.ServerStream.RecvMsg(m)
128128+129129+ if c.postMessageReceive != nil {
130130+ c.postMessageReceive(m, err)
131131+ }
132132+133133+ return err
134134+}
135135+136136+var _ grpc.ServerStream = &callBackServerStream{}
137137+138138+// probablyInternalGRPCError checks if a gRPC status likely represents an error that comes from
139139+// the go-grpc library.
140140+//
141141+// Note: this is a heuristic and may not be 100% accurate.
142142+// From a cursory glance at the go-grpc source code, it seems most errors are prefixed with "grpc:". This may break in the future, but
143143+// it's better than nothing.
144144+// Some other ad-hoc errors that we traced back to the go-grpc library are also checked for.
145145+func probablyInternalGRPCError(s *status.Status, checkers []internalGRPCErrorChecker) bool {
146146+ if s.Code() == codes.OK {
147147+ return false
148148+ }
149149+150150+ for _, checker := range checkers {
151151+ if checker(s) {
152152+ return true
153153+ }
154154+ }
155155+156156+ return false
157157+}
158158+159159+// internalGRPCErrorChecker is a function that checks if a gRPC status likely represents an error that comes from
160160+// the go-grpc library.
161161+type internalGRPCErrorChecker func(*status.Status) bool
162162+163163+// allCheckers is a list of functions that check if a gRPC status likely represents an
164164+// error that comes from the go-grpc library.
165165+var allCheckers = []internalGRPCErrorChecker{
166166+ gRPCPrefixChecker,
167167+ gRPCResourceExhaustedChecker,
168168+ gRPCUnexpectedContentTypeChecker,
169169+}
170170+171171+// gRPCPrefixChecker checks if a gRPC status likely represents an error that comes from the go-grpc library, by checking if the error message
172172+// is prefixed with "grpc: ".
173173+func gRPCPrefixChecker(s *status.Status) bool {
174174+ return s.Code() != codes.OK && strings.HasPrefix(s.Message(), "grpc: ")
175175+}
176176+177177+// gRPCResourceExhaustedChecker checks if a gRPC status likely represents an error that comes from the go-grpc library, by checking if the error message
178178+// is prefixed with "trying to send message larger than max".
179179+func gRPCResourceExhaustedChecker(s *status.Status) bool {
180180+ // Observed from https://github.com/grpc/grpc-go/blob/756119c7de49e91b6f3b9d693b9850e1598938eb/stream.go#L884
181181+ return s.Code() == codes.ResourceExhausted && strings.HasPrefix(s.Message(), "trying to send message larger than max (")
182182+}
183183+184184+// gRPCUnexpectedContentTypeChecker checks if a gRPC status likely represents an error that comes from the go-grpc library, by checking if the error message
185185+// is prefixed with "transport: received unexpected content-type".
186186+func gRPCUnexpectedContentTypeChecker(s *status.Status) bool {
187187+ // Observed from https://github.com/grpc/grpc-go/blob/2997e84fd8d18ddb000ac6736129b48b3c9773ec/internal/transport/http2_client.go#L1415-L1417
188188+ return s.Code() != codes.OK && strings.Contains(s.Message(), "transport: received unexpected content-type")
189189+}
190190+191191+// findNonUTF8StringFields returns a list of field names that contain invalid UTF-8 strings
192192+// in the given proto message.
193193+//
194194+// Example: ["author", "attachments[1].key_value_attachment.data["key2"]`]
195195+func findNonUTF8StringFields(m proto.Message) ([]string, error) {
196196+ if m == nil {
197197+ return nil, nil
198198+ }
199199+200200+ var fields []string
201201+ err := protorange.Range(m.ProtoReflect(), func(p protopath.Values) error {
202202+ last := p.Index(-1)
203203+ s, ok := last.Value.Interface().(string)
204204+ if ok && !utf8.ValidString(s) {
205205+ fieldName := p.Path[1:].String()
206206+ fields = append(fields, strings.TrimPrefix(fieldName, "."))
207207+ }
208208+209209+ return nil
210210+ })
211211+212212+ if err != nil {
213213+ return nil, fmt.Errorf("iterating over proto message: %w", err)
214214+ }
215215+216216+ return fields, nil
217217+}
218218+219219+// massageIntoStatusErr converts an error into a status.Status if possible.
220220+func massageIntoStatusErr(err error) (s *status.Status, ok bool) {
221221+ if err == nil {
222222+ return nil, false
223223+ }
224224+225225+ if s, ok := status.FromError(err); ok {
226226+ return s, true
227227+ }
228228+229229+ if errors.Is(err, context.Canceled) {
230230+ return status.New(codes.Canceled, context.Canceled.Error()), true
231231+232232+ }
233233+234234+ if errors.Is(err, context.DeadlineExceeded) {
235235+ return status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), true
236236+ }
237237+238238+ return nil, false
239239+}
// envMustGetBool reads the environment variable named key and parses it as a
// boolean using strconv.ParseBool.
//
// If the variable is not set, defaultValue is returned. If it is set (even to
// the empty string) but cannot be parsed as a boolean, envMustGetBool panics
// with a message naming the variable and the offending value.
func envMustGetBool(key string, defaultValue bool) bool {
	rawValue, ok := os.LookupEnv(key)
	if !ok {
		return defaultValue
	}

	value, err := strconv.ParseBool(rawValue)
	if err != nil {
		panic(fmt.Sprintf("Failed to parse environment variable %q as valid boolean. Got %q. Err: %s", key, rawValue, err))
	}

	return value
}
254254+255255+func envMustGetBytes(key string, defaultByteSize string) uint64 {
256256+ defaultByteSizeValue, err := humanize.ParseBytes(defaultByteSize)
257257+ if err != nil {
258258+ panic(fmt.Sprintf("Failed to parse default byte size %q as valid byte size. Err: %s", defaultByteSize, err))
259259+ }
260260+261261+ rawValue, ok := os.LookupEnv(key)
262262+ if !ok {
263263+ return defaultByteSizeValue
264264+ }
265265+266266+ value, err := humanize.ParseBytes(rawValue)
267267+ if err != nil {
268268+ panic(fmt.Sprintf("Failed to parse enviroment variable %q as valid byte size. Got %q. Err: %s", key, rawValue, err))
269269+ }
270270+271271+ return value
272272+}
···11+package internalerrs
22+33+import (
44+ "context"
55+ "encoding/json"
66+ "fmt"
77+ "io"
88+ "strings"
99+1010+ "github.com/dustin/go-humanize"
1111+ "github.com/sourcegraph/zoekt/grpc/grpcutil"
1212+1313+ "google.golang.org/grpc/codes"
1414+ "google.golang.org/protobuf/proto"
1515+1616+ "github.com/sourcegraph/log"
1717+ "google.golang.org/grpc"
1818+ "google.golang.org/grpc/status"
1919+)
2020+2121+var (
2222+ logScope = "gRPC.internal.error.reporter"
2323+ logDescription = "logs gRPC errors that appear to come from the go-grpc implementation"
2424+2525+ envLoggingEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_ENABLED", true) // "Enables logging of gRPC internal errors"
2626+ envLogStackTracesEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_LOG_STACK_TRACES", false) // "Enables including stack traces in logs of gRPC internal errors"
2727+2828+ envLogMessagesEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_ENABLED", false) // "Enables inclusion of raw protobuf messages in the gRPC internal error logs"
2929+ envLogMessagesHandleMaxMessageSizeBytes = envMustGetBytes("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_HANDLING_MAX_MESSAGE_SIZE_BYTES", "100MB") // "Maximum size of protobuf messages that can be included in gRPC internal error logs. The purpose of this is to avoid excessive allocations. 0 bytes mean no limit.")
3030+ envLogMessagesMaxJSONSizeBytes = envMustGetBytes("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_JSON_TRUNCATION_SIZE_BYTES", "1KB") // "Maximum size of the JSON representation of protobuf messages to log. JSON representations larger than this value will be truncated. 0 bytes disables truncation.")
3131+)
3232+3333+// LoggingUnaryClientInterceptor returns a grpc.UnaryClientInterceptor that logs
3434+// errors that appear to come from the go-grpc implementation.
3535+func LoggingUnaryClientInterceptor(l log.Logger) grpc.UnaryClientInterceptor {
3636+ if !envLoggingEnabled {
3737+ // Just return the default invoker if logging is disabled.
3838+ return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
3939+ return invoker(ctx, method, req, reply, cc, opts...)
4040+ }
4141+ }
4242+4343+ logger := l.Scoped(logScope, logDescription)
4444+ logger = logger.Scoped("unaryMethod", "errors that originated from a unary method")
4545+4646+ return func(ctx context.Context, fullMethod string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
4747+ err := invoker(ctx, fullMethod, req, reply, cc, opts...)
4848+ if err != nil {
4949+ serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
5050+5151+ var initialRequest proto.Message
5252+ if m, ok := req.(proto.Message); ok {
5353+ initialRequest = m
5454+ }
5555+5656+ doLog(logger, serviceName, methodName, &initialRequest, req, err)
5757+ }
5858+5959+ return err
6060+ }
6161+}
6262+6363+// LoggingStreamClientInterceptor returns a grpc.StreamClientInterceptor that logs
6464+// errors that appear to come from the go-grpc implementation.
6565+func LoggingStreamClientInterceptor(l log.Logger) grpc.StreamClientInterceptor {
6666+ if !envLoggingEnabled {
6767+ // Just return the default streamer if logging is disabled.
6868+ return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
6969+ return streamer(ctx, desc, cc, method, opts...)
7070+ }
7171+ }
7272+7373+ logger := l.Scoped(logScope, logDescription)
7474+ logger = logger.Scoped("streamingMethod", "errors that originated from a streaming method")
7575+7676+ return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, fullMethod string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
7777+ serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
7878+7979+ stream, err := streamer(ctx, desc, cc, fullMethod, opts...)
8080+ if err != nil {
8181+ // Note: This is a bit hacky, we provide nil initial and payload messages here since the message isn't available
8282+ // until after the stream is created.
8383+ //
8484+ // This is fine since the error is already available, and the non-utf8 string check is robust against nil messages.
8585+ logger := logger.Scoped("postInit", "errors that occurred after stream initialization, but before the first message was sent")
8686+ doLog(logger, serviceName, methodName, nil, nil, err)
8787+ return nil, err
8888+ }
8989+9090+ stream = newLoggingClientStream(stream, logger, serviceName, methodName)
9191+ return stream, nil
9292+ }
9393+}
9494+9595+// LoggingUnaryServerInterceptor returns a grpc.UnaryServerInterceptor that logs
9696+// errors that appear to come from the go-grpc implementation.
9797+func LoggingUnaryServerInterceptor(l log.Logger) grpc.UnaryServerInterceptor {
9898+ if !envLoggingEnabled {
9999+ // Just return the default handler if logging is disabled.
100100+ return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
101101+ return handler(ctx, req)
102102+ }
103103+ }
104104+105105+ logger := l.Scoped(logScope, logDescription)
106106+ logger = logger.Scoped("unaryMethod", "errors that originated from a unary method")
107107+108108+ return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
109109+ response, err := handler(ctx, req)
110110+ if err != nil {
111111+ serviceName, methodName := grpcutil.SplitMethodName(info.FullMethod)
112112+113113+ var initialRequest proto.Message
114114+ if m, ok := req.(proto.Message); ok {
115115+ initialRequest = m
116116+ }
117117+118118+ doLog(logger, serviceName, methodName, &initialRequest, response, err)
119119+ }
120120+121121+ return response, err
122122+ }
123123+}
124124+125125+// LoggingStreamServerInterceptor returns a grpc.StreamServerInterceptor that logs
126126+// errors that appear to come from the go-grpc implementation.
127127+func LoggingStreamServerInterceptor(l log.Logger) grpc.StreamServerInterceptor {
128128+ if !envLoggingEnabled {
129129+ // Just return the default handler if logging is disabled.
130130+ return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
131131+ return handler(srv, ss)
132132+ }
133133+ }
134134+135135+ logger := l.Scoped(logScope, logDescription)
136136+ logger = logger.Scoped("streamingMethod", "errors that originated from a streaming method")
137137+138138+ return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
139139+ serviceName, methodName := grpcutil.SplitMethodName(info.FullMethod)
140140+141141+ stream := newLoggingServerStream(ss, logger, serviceName, methodName)
142142+ return handler(srv, stream)
143143+ }
144144+}
145145+146146+func newLoggingServerStream(s grpc.ServerStream, logger log.Logger, serviceName, methodName string) grpc.ServerStream {
147147+ sendLogger := logger.Scoped("postMessageSend", "errors that occurred after sending a message")
148148+ receiveLogger := logger.Scoped("postMessageReceive", "errors that occurred after receiving a message")
149149+150150+ requestSaver := requestSavingServerStream{ServerStream: s}
151151+152152+ return &callBackServerStream{
153153+ ServerStream: &requestSaver,
154154+155155+ postMessageSend: func(m any, err error) {
156156+ if err != nil {
157157+ doLog(sendLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
158158+ }
159159+ },
160160+161161+ postMessageReceive: func(m any, err error) {
162162+ if err != nil && err != io.EOF { // EOF is expected at the end of a stream, so no need to log an error
163163+ doLog(receiveLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
164164+ }
165165+ },
166166+ }
167167+}
168168+169169+func newLoggingClientStream(s grpc.ClientStream, logger log.Logger, serviceName, methodName string) grpc.ClientStream {
170170+ sendLogger := logger.Scoped("postMessageSend", "errors that occurred after sending a message")
171171+ receiveLogger := logger.Scoped("postMessageReceive", "errors that occurred after receiving a message")
172172+173173+ requestSaver := requestSavingClientStream{ClientStream: s}
174174+175175+ return &callBackClientStream{
176176+ ClientStream: &requestSaver,
177177+178178+ postMessageSend: func(m any, err error) {
179179+ if err != nil {
180180+ doLog(sendLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
181181+ }
182182+ },
183183+184184+ postMessageReceive: func(m any, err error) {
185185+ if err != nil && err != io.EOF { // EOF is expected at the end of a stream, so no need to log an error
186186+ doLog(receiveLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
187187+ }
188188+ },
189189+ }
190190+}
191191+192192+func doLog(logger log.Logger, serviceName, methodName string, initialRequest *proto.Message, payload any, err error) {
193193+ if err == nil {
194194+ return
195195+ }
196196+197197+ s, ok := massageIntoStatusErr(err)
198198+ if !ok {
199199+ // If the error isn't a grpc error, we don't know how to handle it.
200200+ // Just return.
201201+ return
202202+ }
203203+204204+ if !probablyInternalGRPCError(s, allCheckers) {
205205+ return
206206+ }
207207+208208+ allFields := []log.Field{
209209+ log.String("grpcService", serviceName),
210210+ log.String("grpcMethod", methodName),
211211+ log.String("grpcCode", s.Code().String()),
212212+ }
213213+214214+ if envLogStackTracesEnabled {
215215+ allFields = append(allFields, log.String("errWithStack", fmt.Sprintf("%+v", err)))
216216+ }
217217+218218+ // Log the initial request message
219219+ if envLogMessagesEnabled {
220220+ fs := messageJSONFields(initialRequest, "initialRequestJSON", envLogMessagesHandleMaxMessageSizeBytes, envLogMessagesMaxJSONSizeBytes)
221221+ allFields = append(allFields, fs...)
222222+ }
223223+224224+ if isNonUTF8StringError(s) {
225225+ m, ok := payload.(proto.Message)
226226+ if ok {
227227+ allFields = append(allFields, nonUTF8StringLogFields(m)...)
228228+229229+ if envLogMessagesEnabled { // Log the latest message as well for non-utf8 errors
230230+ fs := messageJSONFields(&m, "messageJSON", envLogMessagesHandleMaxMessageSizeBytes, envLogMessagesMaxJSONSizeBytes)
231231+ allFields = append(allFields, fs...)
232232+ }
233233+ }
234234+ }
235235+236236+ logger.Error(s.Message(), allFields...)
237237+}
238238+239239+// messageJSONFields converts a protobuf message to a JSON string and returns it as a log field using the provided "key".
240240+// The resulting JSON string is truncated to maxJSONSizeBytes.
241241+//
242242+// If the size of the original protobuf message exceeds maxMessageSizeBytes or any serialization errors are encountered, log fields
243243+// describing the error are returned instead.
244244+func messageJSONFields(m *proto.Message, key string, maxMessageSizeBytes, maxJSONSizeBytes uint64) []log.Field {
245245+ if m == nil || *m == nil {
246246+ return nil
247247+ }
248248+249249+ if maxMessageSizeBytes > 0 {
250250+ size := uint64(proto.Size(*m))
251251+ if size > maxMessageSizeBytes {
252252+ err := fmt.Errorf(
253253+ "failed to marshal protobuf message (key: %q) to string: message too large (size %q, limit %q)",
254254+ key,
255255+ humanize.Bytes(size), humanize.Bytes(maxMessageSizeBytes),
256256+ )
257257+258258+ return []log.Field{log.Error(err)}
259259+ }
260260+ }
261261+262262+ // Note: we can't use the protojson library here since it doesn't support messages with non-UTF8 strings.
263263+ bs, err := json.Marshal(*m)
264264+ if err != nil {
265265+ err := fmt.Errorf("failed to marshal protobuf message (key: %q) to string: %w", key, err)
266266+ return []log.Field{log.Error(err)}
267267+ }
268268+269269+ s := truncate(string(bs), maxJSONSizeBytes)
270270+ return []log.Field{log.String(key, s)}
271271+}
// truncate shortens s to at most maxBytes bytes, appending a note stating how
// many bytes were dropped whenever any were.
//
// If maxBytes is 0, or s already fits within maxBytes, s is returned unchanged.
//
// When the cut would land in the middle of a multi-byte UTF-8 sequence, it is
// moved back (by at most 3 bytes) to the start of that sequence so the kept
// prefix does not end in a partial character.
func truncate(s string, maxBytes uint64) string {
	// Compare in uint64 space: converting a very large maxBytes to int first
	// would wrap around to a negative number.
	if maxBytes == 0 || uint64(len(s)) <= maxBytes {
		return s
	}

	// Safe conversion: maxBytes < len(s), which itself fits in an int.
	cut := int(maxBytes)

	// s[cut] is the first byte that would be dropped. If it is a UTF-8
	// continuation byte (10xxxxxx), the cut splits a character; step back to the
	// character's leading byte. The step count is bounded so malformed input
	// can't make us walk back arbitrarily far.
	for i := 0; i < 3 && cut > 0 && s[cut]&0xC0 == 0x80; i++ {
		cut--
	}

	return fmt.Sprintf("%s...(truncated %d bytes)", s[:cut], len(s)-cut)
}
289289+290290+func isNonUTF8StringError(s *status.Status) bool {
291291+ if s.Code() != codes.Internal {
292292+ return false
293293+ }
294294+295295+ return strings.Contains(s.Message(), "string field contains invalid UTF-8")
296296+}
297297+298298+// nonUTF8StringLogFields checks a protobuf message for fields that contain non-utf8 strings, and returns them as log fields.
299299+func nonUTF8StringLogFields(m proto.Message) []log.Field {
300300+ fs, err := findNonUTF8StringFields(m)
301301+ if err != nil {
302302+ err := fmt.Errorf("failed to find non-UTF8 string fields in protobuf message: %w", err)
303303+ return []log.Field{log.Error(err)}
304304+305305+ }
306306+307307+ return []log.Field{log.Strings("nonUTF8StringFields", fs)}
308308+}
+115
grpc/internalerrs/prometheus.go
···11+package internalerrs
22+33+import (
44+ "context"
55+ "io"
66+ "sync"
77+88+ "github.com/prometheus/client_golang/prometheus"
99+ "github.com/prometheus/client_golang/prometheus/promauto"
1010+ "github.com/sourcegraph/zoekt/grpc/grpcutil"
1111+ "google.golang.org/grpc"
1212+ "google.golang.org/grpc/codes"
1313+)
1414+1515+var metricGRPCMethodStatus = promauto.NewCounterVec(prometheus.CounterOpts{
1616+ Name: "grpc_method_status",
1717+ Help: "Counts the number of gRPC methods that return a given status code, and whether a possible error is an go-grpc internal error.",
1818+},
1919+ []string{
2020+ "grpc_service", // e.g. "gitserver.v1.GitserverService"
2121+ "grpc_method", // e.g. "Exec"
2222+ "grpc_code", // e.g. "NotFound"
2323+ "is_internal_error", // e.g. "true"
2424+ },
2525+)
2626+2727+// PrometheusUnaryClientInterceptor returns a grpc.UnaryClientInterceptor that observes the result of
2828+// the RPC and records it as a Prometheus metric ("src_grpc_method_status").
2929+func PrometheusUnaryClientInterceptor(ctx context.Context, fullMethod string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
3030+ serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
3131+3232+ err := invoker(ctx, fullMethod, req, reply, cc, opts...)
3333+ doObservation(serviceName, methodName, err)
3434+ return err
3535+}
3636+3737+// PrometheusStreamClientInterceptor returns a grpc.StreamClientInterceptor that observes the result of
3838+// the RPC and records it as a Prometheus metric ("src_grpc_method_status").
3939+//
4040+// If any errors are encountered during the stream, the first error is recorded. Otherwise, the
4141+// final status of the stream is recorded.
4242+func PrometheusStreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, fullMethod string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
4343+ serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
4444+4545+ s, err := streamer(ctx, desc, cc, fullMethod, opts...)
4646+ if err != nil {
4747+ doObservation(serviceName, methodName, err) // method failed to be invoked at all, record it
4848+ return nil, err
4949+ }
5050+5151+ return newPrometheusServerStream(s, serviceName, methodName), err
5252+}
5353+5454+// newPrometheusServerStream wraps a grpc.ClientStream to observe the first error
5555+// encountered during the stream, if any.
5656+func newPrometheusServerStream(s grpc.ClientStream, serviceName, methodName string) grpc.ClientStream {
5757+ // Design note: We only want a single observation for each RPC call: it either succeeds or fails
5858+ // with a single error. This ensures we do not double-count RPCs in Prometheus metrics.
5959+ //
6060+ // For unary calls this is straightforward, but for streaming RPCs we need to make a compromise. We only
6161+ // observe the first error (either sending or receiving) that occurs during the stream, instead of every
6262+ // error that occurs during the stream's lifespan. While this approach swallows some errors, it keeps the
6363+ // Prometheus metric count clean and non-duplicated. The logging interceptor handles surfacing all errors
6464+ // that are encountered during a stream.
6565+ var observeOnce sync.Once
6666+6767+ return &callBackClientStream{
6868+ ClientStream: s,
6969+ postMessageSend: func(_ any, err error) {
7070+ if err != nil {
7171+ observeOnce.Do(func() {
7272+ doObservation(serviceName, methodName, err)
7373+ })
7474+ }
7575+ },
7676+ postMessageReceive: func(_ any, err error) {
7777+ if err != nil {
7878+ if err == io.EOF {
7979+ // EOF signals end of stream, not an error. We handle this by setting err to nil, because
8080+ // we want to treat the stream as successfully completed.
8181+ err = nil
8282+ }
8383+8484+ observeOnce.Do(func() {
8585+ doObservation(serviceName, methodName, err)
8686+ })
8787+ }
8888+ },
8989+ }
9090+9191+}
9292+9393+func doObservation(serviceName, methodName string, rpcErr error) {
9494+ if rpcErr == nil {
9595+ // No error occurred, so we record a successful call.
9696+ metricGRPCMethodStatus.WithLabelValues(serviceName, methodName, codes.OK.String(), "false").Inc()
9797+ return
9898+ }
9999+100100+ s, ok := massageIntoStatusErr(rpcErr)
101101+ if !ok {
102102+ // An error occurred, but it was not an error that has a status.Status implementation. We record this as an unknown error.
103103+ metricGRPCMethodStatus.WithLabelValues(serviceName, methodName, codes.Unknown.String(), "false").Inc()
104104+ return
105105+ }
106106+107107+ if !probablyInternalGRPCError(s, allCheckers) {
108108+ // An error occurred, but it was not an internal gRPC error. We record this as a non-internal error.
109109+ metricGRPCMethodStatus.WithLabelValues(serviceName, methodName, s.Code().String(), "false").Inc()
110110+ return
111111+ }
112112+113113+ // An error occurred, and it looks like an internal gRPC error. We record this as an internal error.
114114+ metricGRPCMethodStatus.WithLabelValues(serviceName, methodName, s.Code().String(), "true").Inc()
115115+}
+1-1
grpc/server.go
grpc/server/server.go
···11-package grpc
11+package server
2233import (
44 "context"
+1-1
grpc/server_test.go
grpc/server/server_test.go
···11-package grpc
11+package server
2233import (
44 "context"
+7
grpc/testprotos/news/v1/buf.gen.yaml
···11+# Configuration file for https://buf.build/, which we use for Protobuf code generation.
22+version: v1
33+plugins:
44+ - plugin: buf.build/protocolbuffers/go:v1.29.1
55+ out: .
66+ opt:
77+ - paths=source_relative