fork of https://github.com/sourcegraph/zoekt
1package tenant
2
3import (
4 "context"
5 "fmt"
6 "runtime/pprof"
7 "strconv"
8
9 "google.golang.org/grpc"
10 "google.golang.org/grpc/codes"
11 "google.golang.org/grpc/metadata"
12 "google.golang.org/grpc/status"
13
14 grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
15
16 "github.com/sourcegraph/zoekt/grpc/propagator"
17 "github.com/sourcegraph/zoekt/internal/tenant/internal/tenanttype"
18)
19
// gRPC metadata used by Propagator to carry the tenant across RPC boundaries.
const (
	// headerKeyTenantID is the metadata key under which the tenant is sent.
	headerKeyTenantID = "X-Sourcegraph-Tenant-ID"

	// headerValueNoTenant is the sentinel written under headerKeyTenantID when
	// the outgoing context has no tenant; on the receiving side it is treated
	// the same as an empty or missing value.
	headerValueNoTenant = "none"
)
27
28// Propagator implements the propagator.Propagator interface
29// for propagating tenants across RPC calls. This is modeled directly on
30// the HTTP middleware in this package, and should work exactly the same.
31type Propagator struct{}
32
33var _ propagator.Propagator = &Propagator{}
34
// FromContext builds the outgoing gRPC metadata describing the tenant in ctx.
//
// The result always holds exactly one value under headerKeyTenantID: the
// tenant's ID formatted in base 10 when ctx carries a tenant, or
// headerValueNoTenant when it does not.
func (Propagator) FromContext(ctx context.Context) metadata.MD {
	value := headerValueNoTenant
	if tenant, ok := tenanttype.GetTenant(ctx); ok {
		value = strconv.Itoa(tenant.ID())
	}
	return metadata.Pairs(headerKeyTenantID, value)
}
45
46func (Propagator) InjectContext(ctx context.Context, md metadata.MD) (context.Context, error) {
47 var raw string
48 if vals := md.Get(headerKeyTenantID); len(vals) > 0 {
49 raw = vals[0]
50 }
51 switch raw {
52 case "", headerValueNoTenant:
53 // Nothing to do, empty tenant.
54 return ctx, nil
55 default:
56 tenant, err := tenanttype.Unmarshal(raw)
57 if err != nil {
58 // The tenant value is invalid.
59 return ctx, status.New(codes.InvalidArgument, fmt.Errorf("bad tenant value in metadata: %w", err).Error()).Err()
60 }
61 return tenanttype.WithTenant(ctx, tenant), nil
62 }
63}
64
// UnaryServerInterceptor is a grpc.UnaryServerInterceptor that tags request
// handling with the tenant for profiling.
//
// When the incoming context carries a tenant, the handler runs with a context
// extended by the pprof label "tenant" (the tenanttype.Marshal encoding of
// the tenant). That labeled context is also installed as the current
// goroutine's labels for the duration of the call. Afterwards the goroutine
// labels are reset to those of the incoming context. Requests without a
// tenant are passed to the handler unchanged.
func UnaryServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (response any, err error) {
	tnt, ok := tenanttype.GetTenant(ctx)
	if !ok {
		return handler(ctx, req)
	}

	labels := pprof.Labels("tenant", tenanttype.Marshal(tnt))
	// pprof.Do applies the labels to the goroutine, invokes the callback with
	// the labeled context, and restores ctx's labels once the callback returns.
	pprof.Do(ctx, labels, func(labeled context.Context) {
		response, err = handler(labeled, req)
	})
	return response, err
}
76
77// StreamServerInterceptor is a grpc.StreamServerInterceptor that injects the tenant ID
78// from the context into pprof labels.
79func StreamServerInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
80 if tnt, ok := tenanttype.GetTenant(ss.Context()); ok {
81 ctx := ss.Context()
82 defer pprof.SetGoroutineLabels(ctx)
83 ctx = pprof.WithLabels(ctx, pprof.Labels("tenant", tenanttype.Marshal(tnt)))
84
85 pprof.SetGoroutineLabels(ctx)
86
87 ss = &grpc_middleware.WrappedServerStream{
88 ServerStream: ss,
89 WrappedContext: ctx,
90 }
91 }
92
93 return handler(srv, ss)
94}