fork of https://github.com/sourcegraph/zoekt
1package internalerrs
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "strings"
9
10 "github.com/dustin/go-humanize"
11 "github.com/sourcegraph/zoekt/grpc/grpcutil"
12
13 "google.golang.org/grpc/codes"
14 "google.golang.org/protobuf/proto"
15
16 "github.com/sourcegraph/log"
17 "google.golang.org/grpc"
18 "google.golang.org/grpc/status"
19)
20
var (
	// logScope is the logger scope name that every interceptor in this file applies
	// (via l.Scoped) before adding its own, more specific sub-scope.
	logScope = "gRPC.internal.error.reporter"

	// logDescription is a human-readable description of what this logger reports.
	// NOTE(review): it is not referenced in the visible portion of this file.
	logDescription = "logs gRPC errors that appear to come from the go-grpc implementation"

	// envLoggingEnabled enables logging of gRPC internal errors. When false, the
	// interceptor constructors in this file return plain pass-through interceptors.
	// Read from GRPC_INTERNAL_ERROR_LOGGING_ENABLED (the second argument is presumably the default).
	envLoggingEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_ENABLED", true)

	// envLogStackTracesEnabled enables including stack traces ("%+v" rendering of the error)
	// in logs of gRPC internal errors.
	// Read from GRPC_INTERNAL_ERROR_LOGGING_LOG_STACK_TRACES (presumably defaulting to false).
	envLogStackTracesEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_LOG_STACK_TRACES", false)

	// envLogMessagesEnabled enables inclusion of raw protobuf messages (as JSON) in the
	// gRPC internal error logs.
	// Read from GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_ENABLED (presumably defaulting to false).
	envLogMessagesEnabled = envMustGetBool("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_ENABLED", false)

	// envLogMessagesHandleMaxMessageSizeBytes is the maximum size of protobuf messages that can be
	// included in gRPC internal error logs. The purpose of this is to avoid excessive allocations.
	// 0 bytes means no limit.
	envLogMessagesHandleMaxMessageSizeBytes = envMustGetBytes("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_HANDLING_MAX_MESSAGE_SIZE_BYTES", "100MB")

	// envLogMessagesMaxJSONSizeBytes is the maximum size of the JSON representation of protobuf
	// messages to log. JSON representations larger than this value are truncated.
	// 0 bytes disables truncation.
	envLogMessagesMaxJSONSizeBytes = envMustGetBytes("GRPC_INTERNAL_ERROR_LOGGING_LOG_PROTOBUF_MESSAGES_JSON_TRUNCATION_SIZE_BYTES", "1KB")
)
32
33// LoggingUnaryClientInterceptor returns a grpc.UnaryClientInterceptor that logs
34// errors that appear to come from the go-grpc implementation.
35func LoggingUnaryClientInterceptor(l log.Logger) grpc.UnaryClientInterceptor {
36 if !envLoggingEnabled {
37 // Just return the default invoker if logging is disabled.
38 return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
39 return invoker(ctx, method, req, reply, cc, opts...)
40 }
41 }
42
43 logger := l.Scoped(logScope)
44 logger = logger.Scoped("unaryMethod")
45
46 return func(ctx context.Context, fullMethod string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
47 err := invoker(ctx, fullMethod, req, reply, cc, opts...)
48 if err != nil {
49 serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
50
51 var initialRequest proto.Message
52 if m, ok := req.(proto.Message); ok {
53 initialRequest = m
54 }
55
56 doLog(logger, serviceName, methodName, &initialRequest, req, err)
57 }
58
59 return err
60 }
61}
62
63// LoggingStreamClientInterceptor returns a grpc.StreamClientInterceptor that logs
64// errors that appear to come from the go-grpc implementation.
65func LoggingStreamClientInterceptor(l log.Logger) grpc.StreamClientInterceptor {
66 if !envLoggingEnabled {
67 // Just return the default streamer if logging is disabled.
68 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
69 return streamer(ctx, desc, cc, method, opts...)
70 }
71 }
72
73 logger := l.Scoped(logScope)
74 logger = logger.Scoped("streamingMethod")
75
76 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, fullMethod string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
77 serviceName, methodName := grpcutil.SplitMethodName(fullMethod)
78
79 stream, err := streamer(ctx, desc, cc, fullMethod, opts...)
80 if err != nil {
81 // Note: This is a bit hacky, we provide nil initial and payload messages here since the message isn't available
82 // until after the stream is created.
83 //
84 // This is fine since the error is already available, and the non-utf8 string check is robust against nil messages.
85 logger := logger.Scoped("postInit")
86 doLog(logger, serviceName, methodName, nil, nil, err)
87 return nil, err
88 }
89
90 stream = newLoggingClientStream(stream, logger, serviceName, methodName)
91 return stream, nil
92 }
93}
94
95// LoggingUnaryServerInterceptor returns a grpc.UnaryServerInterceptor that logs
96// errors that appear to come from the go-grpc implementation.
97func LoggingUnaryServerInterceptor(l log.Logger) grpc.UnaryServerInterceptor {
98 if !envLoggingEnabled {
99 // Just return the default handler if logging is disabled.
100 return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
101 return handler(ctx, req)
102 }
103 }
104
105 logger := l.Scoped(logScope)
106 logger = logger.Scoped("unaryMethod")
107
108 return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
109 response, err := handler(ctx, req)
110 if err != nil {
111 serviceName, methodName := grpcutil.SplitMethodName(info.FullMethod)
112
113 var initialRequest proto.Message
114 if m, ok := req.(proto.Message); ok {
115 initialRequest = m
116 }
117
118 doLog(logger, serviceName, methodName, &initialRequest, response, err)
119 }
120
121 return response, err
122 }
123}
124
125// LoggingStreamServerInterceptor returns a grpc.StreamServerInterceptor that logs
126// errors that appear to come from the go-grpc implementation.
127func LoggingStreamServerInterceptor(l log.Logger) grpc.StreamServerInterceptor {
128 if !envLoggingEnabled {
129 // Just return the default handler if logging is disabled.
130 return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
131 return handler(srv, ss)
132 }
133 }
134
135 logger := l.Scoped(logScope)
136 logger = logger.Scoped("streamingMethod")
137
138 return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
139 serviceName, methodName := grpcutil.SplitMethodName(info.FullMethod)
140
141 stream := newLoggingServerStream(ss, logger, serviceName, methodName)
142 return handler(srv, stream)
143 }
144}
145
146func newLoggingServerStream(s grpc.ServerStream, logger log.Logger, serviceName, methodName string) grpc.ServerStream {
147 sendLogger := logger.Scoped("postMessageSend")
148 receiveLogger := logger.Scoped("postMessageReceive")
149
150 requestSaver := requestSavingServerStream{ServerStream: s}
151
152 return &callBackServerStream{
153 ServerStream: &requestSaver,
154
155 postMessageSend: func(m any, err error) {
156 if err != nil {
157 doLog(sendLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
158 }
159 },
160
161 postMessageReceive: func(m any, err error) {
162 if err != nil && err != io.EOF { // EOF is expected at the end of a stream, so no need to log an error
163 doLog(receiveLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
164 }
165 },
166 }
167}
168
169func newLoggingClientStream(s grpc.ClientStream, logger log.Logger, serviceName, methodName string) grpc.ClientStream {
170 sendLogger := logger.Scoped("postMessageSend")
171 receiveLogger := logger.Scoped("postMessageReceive")
172
173 requestSaver := requestSavingClientStream{ClientStream: s}
174
175 return &callBackClientStream{
176 ClientStream: &requestSaver,
177
178 postMessageSend: func(m any, err error) {
179 if err != nil {
180 doLog(sendLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
181 }
182 },
183
184 postMessageReceive: func(m any, err error) {
185 if err != nil && err != io.EOF { // EOF is expected at the end of a stream, so no need to log an error
186 doLog(receiveLogger, serviceName, methodName, requestSaver.InitialRequest(), m, err)
187 }
188 },
189 }
190}
191
192func doLog(logger log.Logger, serviceName, methodName string, initialRequest *proto.Message, payload any, err error) {
193 if err == nil {
194 return
195 }
196
197 s, ok := massageIntoStatusErr(err)
198 if !ok {
199 // If the error isn't a grpc error, we don't know how to handle it.
200 // Just return.
201 return
202 }
203
204 if !probablyInternalGRPCError(s, allCheckers) {
205 return
206 }
207
208 allFields := []log.Field{
209 log.String("grpcService", serviceName),
210 log.String("grpcMethod", methodName),
211 log.String("grpcCode", s.Code().String()),
212 }
213
214 if envLogStackTracesEnabled {
215 allFields = append(allFields, log.String("errWithStack", fmt.Sprintf("%+v", err)))
216 }
217
218 // Log the initial request message
219 if envLogMessagesEnabled {
220 fs := messageJSONFields(initialRequest, "initialRequestJSON", envLogMessagesHandleMaxMessageSizeBytes, envLogMessagesMaxJSONSizeBytes)
221 allFields = append(allFields, fs...)
222 }
223
224 if isNonUTF8StringError(s) {
225 m, ok := payload.(proto.Message)
226 if ok {
227 allFields = append(allFields, nonUTF8StringLogFields(m)...)
228
229 if envLogMessagesEnabled { // Log the latest message as well for non-utf8 errors
230 fs := messageJSONFields(&m, "messageJSON", envLogMessagesHandleMaxMessageSizeBytes, envLogMessagesMaxJSONSizeBytes)
231 allFields = append(allFields, fs...)
232 }
233 }
234 }
235
236 logger.Error(s.Message(), allFields...)
237}
238
239// messageJSONFields converts a protobuf message to a JSON string and returns it as a log field using the provided "key".
240// The resulting JSON string is truncated to maxJSONSizeBytes.
241//
242// If the size of the original protobuf message exceeds maxMessageSizeBytes or any serialization errors are encountered, log fields
243// describing the error are returned instead.
244func messageJSONFields(m *proto.Message, key string, maxMessageSizeBytes, maxJSONSizeBytes uint64) []log.Field {
245 if m == nil || *m == nil {
246 return nil
247 }
248
249 if maxMessageSizeBytes > 0 {
250 size := uint64(proto.Size(*m))
251 if size > maxMessageSizeBytes {
252 err := fmt.Errorf(
253 "failed to marshal protobuf message (key: %q) to string: message too large (size %q, limit %q)",
254 key,
255 humanize.Bytes(size), humanize.Bytes(maxMessageSizeBytes),
256 )
257
258 return []log.Field{log.Error(err)}
259 }
260 }
261
262 // Note: we can't use the protojson library here since it doesn't support messages with non-UTF8 strings.
263 bs, err := json.Marshal(*m)
264 if err != nil {
265 err := fmt.Errorf("failed to marshal protobuf message (key: %q) to string: %w", key, err)
266 return []log.Field{log.Error(err)}
267 }
268
269 s := truncate(string(bs), maxJSONSizeBytes)
270 return []log.Field{log.String(key, s)}
271}
272
// truncate shortens s to at most maxBytes bytes of its original content. When
// anything is cut off, a suffix of the form "...(truncated N bytes)" is appended,
// where N is the number of bytes removed.
//
// If maxBytes is 0, or s is already no longer than maxBytes, s is returned unchanged.
//
// Truncation is byte-based: a multi-byte UTF-8 sequence that straddles the
// limit is cut in the middle.
func truncate(s string, maxBytes uint64) string {
	// Compare in uint64 space. Converting maxBytes to int first would wrap
	// negative for limits above the maximum int value, making the string look
	// oversized and the slice expression below panic.
	if maxBytes == 0 || uint64(len(s)) <= maxBytes {
		return s
	}

	// maxBytes < len(s) holds here, so the conversion to int cannot overflow.
	keep := int(maxBytes)
	return fmt.Sprintf("%s...(truncated %d bytes)", s[:keep], len(s)-keep)
}
289
290func isNonUTF8StringError(s *status.Status) bool {
291 if s.Code() != codes.Internal {
292 return false
293 }
294
295 return strings.Contains(s.Message(), "string field contains invalid UTF-8")
296}
297
298// nonUTF8StringLogFields checks a protobuf message for fields that contain non-utf8 strings, and returns them as log fields.
299func nonUTF8StringLogFields(m proto.Message) []log.Field {
300 fs, err := findNonUTF8StringFields(m)
301 if err != nil {
302 err := fmt.Errorf("failed to find non-UTF8 string fields in protobuf message: %w", err)
303 return []log.Field{log.Error(err)}
304
305 }
306
307 return []log.Field{log.Strings("nonUTF8StringFields", fs)}
308}