// Fork of https://github.com/sourcegraph/zoekt
1package internalerrs
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "strings"
9
10 "github.com/dustin/go-humanize"
11 "github.com/sourcegraph/log"
12 "google.golang.org/grpc"
13 "google.golang.org/grpc/codes"
14 "google.golang.org/grpc/status"
15 "google.golang.org/protobuf/proto"
16
17 "github.com/sourcegraph/zoekt/grpc/grpcutil"
18)
19
20var (
21 logScope = "gRPC.internal.error.reporter"
22 logDescription = "logs gRPC errors that appear to come from the go-grpc implementation"
23
24 envLoggingEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_ENABLED", true) // "Enables logging of gRPC internal errors"
25 envLogStackTracesEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_LOG_STACK_TRACES", false) // "Enables including stack traces in logs of gRPC internal errors"
26
27 envLogMessagesEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_ENABLED", false) // "Enables inclusion of raw protobuf messages in the gRPC internal error logs"
28 envLogMessagesHandleMaxMessageSizeBytes = envMustGetBytes("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_HANDLING_MAX_MESSAGE_SIZE_BYTES", "100MB") // "Maximum size of protobuf messages that can be included in gRPC internal error logs. The purpose of this is to avoid excessive allocations. 0 bytes mean no limit.")
29 envLogMessagesMaxJSONSizeBytes = envMustGetBytes("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_JSON_TRUNCATION_SIZE_BYTES", "1KB") // "Maximum size of the JSON representation of protobuf messages to log. JSON representations larger than this value will be truncated. 0 bytes disables truncation.")
30)
31
32// LoggingUnaryClientInterceptor returns a grpc.UnaryClientInterceptor that logs
33// errors that appear to come from the go-grpc implementation.
34func LoggingUnaryClientInterceptor(l log.Logger) grpc.UnaryClientInterceptor {
35 if !envLoggingEnabled {
36 // Just return the default invoker if logging is disabled.
37 return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
38 return invoker(ctx, method, req, reply, cc, opts...)
39 }
40 }
41
42 logger := l.Scoped(logScope)
43 logger = logger.Scoped("unaryMethod")
44
45 return func(ctx context.Context, fullMethod string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
46 err := invoker(ctx, fullMethod, req, reply, cc, opts...)
47 if err != nil {
48 serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
49
50 var initialRequest proto.Message
51 if m, ok := req.(proto.Message); ok {
52 initialRequest = m
53 }
54
55 doLog(logger, serviceName, methodName, &initialRequest, req, err)
56 }
57
58 return err
59 }
60}
61
62// LoggingStreamClientInterceptor returns a grpc.StreamClientInterceptor that logs
63// errors that appear to come from the go-grpc implementation.
64func LoggingStreamClientInterceptor(l log.Logger) grpc.StreamClientInterceptor {
65 if !envLoggingEnabled {
66 // Just return the default streamer if logging is disabled.
67 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
68 return streamer(ctx, desc, cc, method, opts...)
69 }
70 }
71
72 logger := l.Scoped(logScope)
73 logger = logger.Scoped("streamingMethod")
74
75 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, fullMethod string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
76 serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
77
78 stream, err := streamer(ctx, desc, cc, fullMethod, opts...)
79 if err != nil {
80 // Note: This is a bit hacky, we provide nil initial and payload messages here since the message isn't available
81 // until after the stream is created.
82 //
83 // This is fine since the error is already available, and the non-utf8 string check is robust against nil messages.
84 logger := logger.Scoped("postInit")
85 doLog(logger, serviceName, methodName, nil, nil, err)
86 return nil, err
87 }
88
89 stream = newLoggingClientStream(stream, logger, serviceName, methodName)
90 return stream, nil
91 }
92}
93
94// LoggingUnaryServerInterceptor returns a grpc.UnaryServerInterceptor that logs
95// errors that appear to come from the go-grpc implementation.
96func LoggingUnaryServerInterceptor(l log.Logger) grpc.UnaryServerInterceptor {
97 if !envLoggingEnabled {
98 // Just return the default handler if logging is disabled.
99 return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
100 return handler(ctx, req)
101 }
102 }
103
104 logger := l.Scoped(logScope)
105 logger = logger.Scoped("unaryMethod")
106
107 return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
108 response, err := handler(ctx, req)
109 if err != nil {
110 serviceName, methodName := grpcutil.SplitMethodName(info.FullMethod)
111
112 var initialRequest proto.Message
113 if m, ok := req.(proto.Message); ok {
114 initialRequest = m
115 }
116
117 doLog(logger, serviceName, methodName, &initialRequest, response, err)
118 }
119
120 return response, err
121 }
122}
123
124// LoggingStreamServerInterceptor returns a grpc.StreamServerInterceptor that logs
125// errors that appear to come from the go-grpc implementation.
126func LoggingStreamServerInterceptor(l log.Logger) grpc.StreamServerInterceptor {
127 if !envLoggingEnabled {
128 // Just return the default handler if logging is disabled.
129 return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
130 return handler(srv, ss)
131 }
132 }
133
134 logger := l.Scoped(logScope)
135 logger = logger.Scoped("streamingMethod")
136
137 return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
138 serviceName, methodName := grpcutil.SplitMethodName(info.FullMethod)
139
140 stream := newLoggingServerStream(ss, logger, serviceName, methodName)
141 return handler(srv, stream)
142 }
143}
144
145func newLoggingServerStream(s grpc.ServerStream, logger log.Logger, serviceName, methodName string) grpc.ServerStream {
146 sendLogger := logger.Scoped("postMessageSend")
147 receiveLogger := logger.Scoped("postMessageReceive")
148
149 requestSaver := requestSavingServerStream{ServerStream: s}
150
151 return &callBackServerStream{
152 ServerStream: &requestSaver,
153
154 postMessageSend: func(m any, err error) {
155 if err != nil {
156 doLog(sendLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
157 }
158 },
159
160 postMessageReceive: func(m any, err error) {
161 if err != nil && err != io.EOF { // EOF is expected at the end of a stream, so no need to log an error
162 doLog(receiveLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
163 }
164 },
165 }
166}
167
168func newLoggingClientStream(s grpc.ClientStream, logger log.Logger, serviceName, methodName string) grpc.ClientStream {
169 sendLogger := logger.Scoped("postMessageSend")
170 receiveLogger := logger.Scoped("postMessageReceive")
171
172 requestSaver := requestSavingClientStream{ClientStream: s}
173
174 return &callBackClientStream{
175 ClientStream: &requestSaver,
176
177 postMessageSend: func(m any, err error) {
178 if err != nil {
179 doLog(sendLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
180 }
181 },
182
183 postMessageReceive: func(m any, err error) {
184 if err != nil && err != io.EOF { // EOF is expected at the end of a stream, so no need to log an error
185 doLog(receiveLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
186 }
187 },
188 }
189}
190
191func doLog(logger log.Logger, serviceName, methodName string, initialRequest *proto.Message, payload any, err error) {
192 if err == nil {
193 return
194 }
195
196 s, ok := massageIntoStatusErr(err)
197 if !ok {
198 // If the error isn't a grpc error, we don't know how to handle it.
199 // Just return.
200 return
201 }
202
203 if !probablyInternalGRPCError(s, allCheckers) {
204 return
205 }
206
207 allFields := []log.Field{
208 log.String("grpcService", serviceName),
209 log.String("grpcMethod", methodName),
210 log.String("grpcCode", s.Code().String()),
211 }
212
213 if envLogStackTracesEnabled {
214 allFields = append(allFields, log.String("errWithStack", fmt.Sprintf("%+v", err)))
215 }
216
217 // Log the initial request message
218 if envLogMessagesEnabled {
219 fs := messageJSONFields(initialRequest, "initialRequestJSON", envLogMessagesHandleMaxMessageSizeBytes, envLogMessagesMaxJSONSizeBytes)
220 allFields = append(allFields, fs...)
221 }
222
223 if isNonUTF8StringError(s) {
224 m, ok := payload.(proto.Message)
225 if ok {
226 allFields = append(allFields, nonUTF8StringLogFields(m)...)
227
228 if envLogMessagesEnabled { // Log the latest message as well for non-utf8 errors
229 fs := messageJSONFields(&m, "messageJSON", envLogMessagesHandleMaxMessageSizeBytes, envLogMessagesMaxJSONSizeBytes)
230 allFields = append(allFields, fs...)
231 }
232 }
233 }
234
235 logger.Error(s.Message(), allFields...)
236}
237
238// messageJSONFields converts a protobuf message to a JSON string and returns it as a log field using the provided "key".
239// The resulting JSON string is truncated to maxJSONSizeBytes.
240//
241// If the size of the original protobuf message exceeds maxMessageSizeBytes or any serialization errors are encountered, log fields
242// describing the error are returned instead.
243func messageJSONFields(m *proto.Message, key string, maxMessageSizeBytes, maxJSONSizeBytes uint64) []log.Field {
244 if m == nil || *m == nil {
245 return nil
246 }
247
248 if maxMessageSizeBytes > 0 {
249 size := uint64(proto.Size(*m))
250 if size > maxMessageSizeBytes {
251 err := fmt.Errorf(
252 "failed to marshal protobuf message (key: %q) to string: message too large (size %q, limit %q)",
253 key,
254 humanize.Bytes(size), humanize.Bytes(maxMessageSizeBytes),
255 )
256
257 return []log.Field{log.Error(err)}
258 }
259 }
260
261 // Note: we can't use the protojson library here since it doesn't support messages with non-UTF8 strings.
262 bs, err := json.Marshal(*m)
263 if err != nil {
264 err := fmt.Errorf("failed to marshal protobuf message (key: %q) to string: %w", key, err)
265 return []log.Field{log.Error(err)}
266 }
267
268 s := truncate(string(bs), maxJSONSizeBytes)
269 return []log.Field{log.String(key, s)}
270}
271
// truncate keeps at most maxBytes bytes of s. When anything is removed, it
// appends a "...(truncated N bytes)" marker, where N is the number of bytes
// dropped. The marker is not counted against maxBytes, so the result can be
// longer than maxBytes.
//
// The cut point is moved backwards to the nearest UTF-8 rune boundary, so a
// multi-byte character is never split into invalid UTF-8.
//
// If maxBytes is 0, then the string is not truncated.
func truncate(s string, maxBytes uint64) string {
	// Compare in uint64 space: converting a huge maxBytes to int could wrap to a
	// negative value and make the slice expression below panic.
	if maxBytes == 0 || uint64(len(s)) <= maxBytes {
		return s
	}

	cut := int(maxBytes) // safe: maxBytes < len(s) at this point

	// Step back over UTF-8 continuation bytes (0b10xxxxxx) so that s[:cut] ends on
	// a rune boundary; this is the same test utf8.RuneStart performs.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}

	return fmt.Sprintf("%s...(truncated %d bytes)", s[:cut], len(s)-cut)
}
288
289func isNonUTF8StringError(s *status.Status) bool {
290 if s.Code() != codes.Internal {
291 return false
292 }
293
294 return strings.Contains(s.Message(), "string field contains invalid UTF-8")
295}
296
297// nonUTF8StringLogFields checks a protobuf message for fields that contain non-utf8 strings, and returns them as log fields.
298func nonUTF8StringLogFields(m proto.Message) []log.Field {
299 fs, err := findNonUTF8StringFields(m)
300 if err != nil {
301 err := fmt.Errorf("failed to find non-UTF8 string fields in protobuf message: %w", err)
302 return []log.Field{log.Error(err)}
303
304 }
305
306 return []log.Field{log.Strings("nonUTF8StringFields", fs)}
307}