fork of https://github.com/sourcegraph/zoekt
1// Package rpc provides a zoekt.Searcher over RPC.
2package rpc
3
4import (
5 "context"
6 "encoding/gob"
7 "fmt"
8 "net/http"
9 "reflect"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/keegancsmith/rpc"
15 "github.com/sourcegraph/zoekt"
16 "github.com/sourcegraph/zoekt/query"
17 "github.com/sourcegraph/zoekt/rpc/internal/srv"
18)
19
const (
	// DefaultRPCPath is the RPC path used by zoekt-webserver. Client dials
	// this path when the caller does not supply one.
	DefaultRPCPath = "/rpc"
)
22
23// Server returns an http.Handler for searcher which is the server side of the
24// RPC calls.
25func Server(searcher zoekt.Searcher) http.Handler {
26 RegisterGob()
27 server := rpc.NewServer()
28 if err := server.Register(&srv.Searcher{Searcher: searcher}); err != nil {
29 // this should never fail, so we panic.
30 panic("unexpected error registering rpc server: " + err.Error())
31 }
32 return server
33}
34
35// Client connects to a Searcher HTTP RPC server at address (host:port) using
36// DefaultRPCPath path.
37func Client(address string) zoekt.Searcher {
38 return ClientAtPath(address, DefaultRPCPath)
39}
40
41// ClientAtPath connects to a Searcher HTTP RPC server at address and path
42// (http://host:port/path).
43func ClientAtPath(address, path string) zoekt.Searcher {
44 RegisterGob()
45 return &client{addr: address, path: path}
46}
47
48type client struct {
49 addr, path string
50
51 mu sync.Mutex // protects client and gen
52 cl *rpc.Client
53 gen int // incremented each time we dial
54}
55
56func (c *client) Search(ctx context.Context, q query.Q, opts *zoekt.SearchOptions) (*zoekt.SearchResult, error) {
57 var reply srv.SearchReply
58 err := c.call(ctx, "Searcher.Search", &srv.SearchArgs{Q: q, Opts: opts}, &reply)
59 return reply.Result, err
60}
61
62func (c *client) List(ctx context.Context, q query.Q, opts *zoekt.ListOptions) (*zoekt.RepoList, error) {
63 var reply srv.ListReply
64 err := c.call(ctx, "Searcher.List", &srv.ListArgs{Q: q, Opts: opts}, &reply)
65 return reply.List, err
66}
67
68func (c *client) call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error {
69 // We try twice. If we fail to dial or fail to call the function we try
70 // again after 100ms. Unrolled to make logic clear
71 cl, gen, err := c.getRPCClient(ctx, 0)
72 if err == nil {
73 err = cl.Call(ctx, serviceMethod, args, reply)
74 if err != rpc.ErrShutdown {
75 return err
76 }
77 }
78
79 select {
80 case <-ctx.Done():
81 return ctx.Err()
82 case <-time.After(100 * time.Millisecond):
83 }
84
85 cl, _, err = c.getRPCClient(ctx, gen)
86 if err != nil {
87 return err
88 }
89 return cl.Call(ctx, serviceMethod, args, reply)
90}
91
92// getRPCClient gets the rpc client. If gen matches the current generation, we
93// redial and increment the generation. This is used to prevent concurrent
94// redialing on network failure.
95func (c *client) getRPCClient(ctx context.Context, gen int) (*rpc.Client, int, error) {
96 // coarse lock so we only dial once
97 c.mu.Lock()
98 defer c.mu.Unlock()
99 if gen != c.gen {
100 return c.cl, c.gen, nil
101 }
102 var timeout time.Duration
103 if deadline, ok := ctx.Deadline(); ok {
104 timeout = time.Until(deadline)
105 }
106 cl, err := rpc.DialHTTPPathTimeout("tcp", c.addr, c.path, timeout)
107 if err != nil {
108 return nil, c.gen, err
109 }
110 c.cl = cl
111 c.gen++
112 return c.cl, c.gen, nil
113}
114
115func (c *client) Close() {
116 c.mu.Lock()
117 defer c.mu.Unlock()
118 if c.cl != nil {
119 c.cl.Close()
120 }
121}
122
123func (c *client) String() string {
124 return fmt.Sprintf("rpcSearcher(%s/%s)", c.addr, c.path)
125}
126
127var once sync.Once
128
129// RegisterGob registers various query types with gob. It can be called more than
130// once, because calls to gob.Register are protected by a sync.Once.
131func RegisterGob() {
132 once.Do(func() {
133 gobRegister(&query.And{})
134 gobRegister(&query.BranchRepos{})
135 gobRegister(&query.BranchesRepos{})
136 gobRegister(&query.Branch{})
137 gobRegister(&query.Const{})
138 gobRegister(&query.FileNameSet{})
139 gobRegister(&query.GobCache{})
140 gobRegister(&query.Language{})
141 gobRegister(&query.Not{})
142 gobRegister(&query.Or{})
143 gobRegister(&query.Regexp{})
144 gobRegister(&query.RepoRegexp{})
145 gobRegister(&query.RepoSet{})
146 gobRegister(&query.RepoIDs{})
147 gobRegister(&query.Repo{})
148 gobRegister(&query.Substring{})
149 gobRegister(&query.Symbol{})
150 gobRegister(&query.Type{})
151 gobRegister(query.RawConfig(41))
152 })
153}
154
155// gobRegister exists to keep backwards compatibility around renames of the go
156// module. This is to avoid breaking the wire protocol due to refactors. In
157// particular in August 2022 we renamed the go module from
158// github.com/google/zoekt to github.com/sourcegraph/zoekt which breaks the
159// wire protocol. So this function will replace those names so we keep using
160// google/zoekt.
161func gobRegister(value any) {
162 name := gobRegister_name(value)
163
164 name = strings.Replace(name, "github.com/sourcegraph/", "github.com/google/", 1)
165
166 gob.RegisterName(name, value)
167}
168
// gobRegister_name returns the name that the stdlib gob.Register would pass
// to gob.RegisterName for value. The rules are taken from the stdlib
// implementation:
//
//   - A type with no name of its own (pointers, slices, maps, ...) uses its
//     printed representation, reflect.Type.String.
//   - A named type is qualified with its full import path, or used bare when
//     it has no package path (for example built-in types).
//
// The stdlib intended to dereference one pointer level and produce
// "*full/p.T1" for a *T1 declared in a package imported as "full/p", but it
// never actually dereferences, so pointers fall under the first rule and are
// named "*p.T1". That cannot be fixed without breaking existing gob
// decoders, and the same naming is reproduced here on purpose:
//
//	Type    Intended string    Actual string
//
//	T1      full/p.T1          full/p.T1
//	*T1     *full/p.T1         *p.T1
func gobRegister_name(value any) string {
	typ := reflect.TypeOf(value)

	// Unnamed types, pointers included, keep their printed representation.
	if typ.Name() == "" {
		return typ.String()
	}

	// Named types are qualified with the import path when there is one.
	if pkg := typ.PkgPath(); pkg != "" {
		return pkg + "." + typ.Name()
	}
	return typ.Name()
}