Fork of https://github.com/sourcegraph/zoekt.
1package internalerrs
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "strconv"
9 "strings"
10 "sync"
11 "sync/atomic"
12 "unicode/utf8"
13
14 "github.com/dustin/go-humanize"
15 "google.golang.org/protobuf/proto"
16 "google.golang.org/protobuf/reflect/protopath"
17 "google.golang.org/protobuf/reflect/protorange"
18
19 "google.golang.org/grpc"
20 "google.golang.org/grpc/codes"
21 "google.golang.org/grpc/status"
22)
23
24// callBackClientStream is a grpc.ClientStream that calls a function after SendMsg and RecvMsg.
25type callBackClientStream struct {
26 grpc.ClientStream
27
28 postMessageSend func(message any, err error)
29 postMessageReceive func(message any, err error)
30}
31
32func (c *callBackClientStream) SendMsg(m any) error {
33 err := c.ClientStream.SendMsg(m)
34 if c.postMessageSend != nil {
35 c.postMessageSend(m, err)
36 }
37
38 return err
39}
40
41func (c *callBackClientStream) RecvMsg(m any) error {
42 err := c.ClientStream.RecvMsg(m)
43 if c.postMessageReceive != nil {
44 c.postMessageReceive(m, err)
45 }
46
47 return err
48}
49
50var _ grpc.ClientStream = &callBackClientStream{}
51
52// requestSavingClientStream is a grpc.ClientStream that saves the initial request sent to the server.
53type requestSavingClientStream struct {
54 grpc.ClientStream
55
56 initialRequest atomic.Pointer[proto.Message]
57 saveRequestOnce sync.Once
58}
59
60func (c *requestSavingClientStream) SendMsg(m any) error {
61 c.saveRequestOnce.Do(func() {
62 message, ok := m.(proto.Message)
63 if !ok {
64 return
65 }
66
67 c.initialRequest.Store(&message)
68 })
69
70 return c.ClientStream.SendMsg(m)
71}
72
73// InitialRequest returns the initial request sent by the client on the stream.
74func (c *requestSavingClientStream) InitialRequest() *proto.Message {
75 return c.initialRequest.Load()
76}
77
78var _ grpc.ClientStream = &requestSavingClientStream{}
79
80// requestSavingServerStream is a grpc.ServerStream that saves the initial request sent by the client.
81type requestSavingServerStream struct {
82 grpc.ServerStream
83
84 initialRequest atomic.Pointer[proto.Message]
85 saveRequestOnce sync.Once
86}
87
88func (s *requestSavingServerStream) RecvMsg(m any) error {
89 s.saveRequestOnce.Do(func() {
90 message, ok := m.(proto.Message)
91 if !ok {
92 return
93 }
94
95 s.initialRequest.Store(&message)
96 })
97
98 return s.ServerStream.RecvMsg(m)
99}
100
101// InitialRequest returns the initial request sent by the client on the stream.
102func (s *requestSavingServerStream) InitialRequest() *proto.Message {
103 return s.initialRequest.Load()
104}
105
106var _ grpc.ServerStream = &requestSavingServerStream{}
107
108// callBackServerStream is a grpc.ServerStream that calls a function after SendMsg and RecvMsg.
109type callBackServerStream struct {
110 grpc.ServerStream
111
112 postMessageSend func(message any, err error)
113 postMessageReceive func(message any, err error)
114}
115
116func (c *callBackServerStream) SendMsg(m any) error {
117 err := c.ServerStream.SendMsg(m)
118
119 if c.postMessageSend != nil {
120 c.postMessageSend(m, err)
121 }
122
123 return err
124}
125
126func (c *callBackServerStream) RecvMsg(m any) error {
127 err := c.ServerStream.RecvMsg(m)
128
129 if c.postMessageReceive != nil {
130 c.postMessageReceive(m, err)
131 }
132
133 return err
134}
135
136var _ grpc.ServerStream = &callBackServerStream{}
137
138// probablyInternalGRPCError checks if a gRPC status likely represents an error that comes from
139// the go-grpc library.
140//
141// Note: this is a heuristic and may not be 100% accurate.
142// From a cursory glance at the go-grpc source code, it seems most errors are prefixed with "grpc:". This may break in the future, but
143// it's better than nothing.
144// Some other ad-hoc errors that we traced back to the go-grpc library are also checked for.
145func probablyInternalGRPCError(s *status.Status, checkers []internalGRPCErrorChecker) bool {
146 if s.Code() == codes.OK {
147 return false
148 }
149
150 for _, checker := range checkers {
151 if checker(s) {
152 return true
153 }
154 }
155
156 return false
157}
158
159// internalGRPCErrorChecker is a function that checks if a gRPC status likely represents an error that comes from
160// the go-grpc library.
161type internalGRPCErrorChecker func(*status.Status) bool
162
163// allCheckers is a list of functions that check if a gRPC status likely represents an
164// error that comes from the go-grpc library.
165var allCheckers = []internalGRPCErrorChecker{
166 gRPCPrefixChecker,
167 gRPCResourceExhaustedChecker,
168 gRPCUnexpectedContentTypeChecker,
169}
170
171// gRPCPrefixChecker checks if a gRPC status likely represents an error that comes from the go-grpc library, by checking if the error message
172// is prefixed with "grpc: ".
173func gRPCPrefixChecker(s *status.Status) bool {
174 return s.Code() != codes.OK && strings.HasPrefix(s.Message(), "grpc: ")
175}
176
177// gRPCResourceExhaustedChecker checks if a gRPC status likely represents an error that comes from the go-grpc library, by checking if the error message
178// is prefixed with "trying to send message larger than max".
179func gRPCResourceExhaustedChecker(s *status.Status) bool {
180 // Observed from https://github.com/grpc/grpc-go/blob/756119c7de49e91b6f3b9d693b9850e1598938eb/stream.go#L884
181 return s.Code() == codes.ResourceExhausted && strings.HasPrefix(s.Message(), "trying to send message larger than max (")
182}
183
184// gRPCUnexpectedContentTypeChecker checks if a gRPC status likely represents an error that comes from the go-grpc library, by checking if the error message
185// is prefixed with "transport: received unexpected content-type".
186func gRPCUnexpectedContentTypeChecker(s *status.Status) bool {
187 // Observed from https://github.com/grpc/grpc-go/blob/2997e84fd8d18ddb000ac6736129b48b3c9773ec/internal/transport/http2_client.go#L1415-L1417
188 return s.Code() != codes.OK && strings.Contains(s.Message(), "transport: received unexpected content-type")
189}
190
191// findNonUTF8StringFields returns a list of field names that contain invalid UTF-8 strings
192// in the given proto message.
193//
194// Example: ["author", "attachments[1].key_value_attachment.data["key2"]`]
195func findNonUTF8StringFields(m proto.Message) ([]string, error) {
196 if m == nil {
197 return nil, nil
198 }
199
200 var fields []string
201 err := protorange.Range(m.ProtoReflect(), func(p protopath.Values) error {
202 last := p.Index(-1)
203 s, ok := last.Value.Interface().(string)
204 if ok && !utf8.ValidString(s) {
205 fieldName := p.Path[1:].String()
206 fields = append(fields, strings.TrimPrefix(fieldName, "."))
207 }
208
209 return nil
210 })
211 if err != nil {
212 return nil, fmt.Errorf("iterating over proto message: %w", err)
213 }
214
215 return fields, nil
216}
217
218// massageIntoStatusErr converts an error into a status.Status if possible.
219func massageIntoStatusErr(err error) (s *status.Status, ok bool) {
220 if err == nil {
221 return nil, false
222 }
223
224 if s, ok := status.FromError(err); ok {
225 return s, true
226 }
227
228 if errors.Is(err, context.Canceled) {
229 return status.New(codes.Canceled, context.Canceled.Error()), true
230 }
231
232 if errors.Is(err, context.DeadlineExceeded) {
233 return status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), true
234 }
235
236 return nil, false
237}
238
// envMustGetBool returns the boolean value of the environment variable key, or
// defaultValue when key is not set at all.
//
// The raw value is parsed with strconv.ParseBool, which accepts 1, t, T, TRUE,
// true, True, 0, f, F, FALSE, false and False. It panics if key is set to any
// other value, including the empty string.
func envMustGetBool(key string, defaultValue bool) bool {
	rawValue, ok := os.LookupEnv(key)
	if !ok {
		return defaultValue
	}

	value, err := strconv.ParseBool(rawValue)
	if err != nil {
		panic(fmt.Sprintf("Failed to parse environment variable %q as valid boolean. Got %q. Err: %s", key, rawValue, err))
	}

	return value
}
252
253func envMustGetBytes(key string, defaultByteSize string) uint64 {
254 defaultByteSizeValue, err := humanize.ParseBytes(defaultByteSize)
255 if err != nil {
256 panic(fmt.Sprintf("Failed to parse default byte size %q as valid byte size. Err: %s", defaultByteSize, err))
257 }
258
259 rawValue, ok := os.LookupEnv(key)
260 if !ok {
261 return defaultByteSizeValue
262 }
263
264 value, err := humanize.ParseBytes(rawValue)
265 if err != nil {
266 panic(fmt.Sprintf("Failed to parse enviroment variable %q as valid byte size. Got %q. Err: %s", key, rawValue, err))
267 }
268
269 return value
270}