Forked from https://github.com/sourcegraph/zoekt
1package server
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "net/http/httptest"
9 "net/url"
10 "testing"
11 "testing/quick"
12
13 "github.com/google/go-cmp/cmp"
14 "github.com/google/go-cmp/cmp/cmpopts"
15 "go.uber.org/atomic"
16 "golang.org/x/net/http2"
17 "golang.org/x/net/http2/h2c"
18 "google.golang.org/grpc"
19 "google.golang.org/grpc/credentials/insecure"
20 "google.golang.org/protobuf/proto"
21 "google.golang.org/protobuf/testing/protocmp"
22
23 "github.com/sourcegraph/zoekt"
24 webserverv1 "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1"
25 "github.com/sourcegraph/zoekt/internal/mockSearcher"
26 "github.com/sourcegraph/zoekt/query"
27)
28
29func TestClientServer(t *testing.T) {
30 mock := &mockSearcher.MockSearcher{
31 WantSearch: query.NewAnd(mustParse("hello world|universe"), query.NewSingleBranchesRepos("HEAD", 1, 2)),
32 SearchResult: &zoekt.SearchResult{
33 Files: []zoekt.FileMatch{
34 {FileName: "bin.go"},
35 {FileName: "foo.go"},
36 },
37 },
38
39 WantList: &query.Const{Value: true},
40 RepoList: &zoekt.RepoList{
41 Repos: []*zoekt.RepoListEntry{
42 {
43 Repository: zoekt.Repository{
44 ID: 2,
45 Name: "foo/bar",
46 },
47 },
48 },
49 },
50 }
51
52 gs := grpc.NewServer()
53 defer gs.Stop()
54
55 webserverv1.RegisterWebserverServiceServer(gs, NewServer(adapter{mock}))
56 ts := httptest.NewServer(h2c.NewHandler(gs, &http2.Server{}))
57 defer ts.Close()
58
59 u, err := url.Parse(ts.URL)
60 if err != nil {
61 t.Fatal(err)
62 }
63 cc, err := grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()))
64 if err != nil {
65 t.Fatal(err)
66 }
67 defer cc.Close()
68
69 client := webserverv1.NewWebserverServiceClient(cc)
70
71 r, err := client.Search(context.Background(), &webserverv1.SearchRequest{Query: query.QToProto(mock.WantSearch)})
72 if err != nil {
73 t.Fatal(err)
74 }
75 if !proto.Equal(r, mock.SearchResult.ToProto()) {
76 t.Fatalf("got %+v, want %+v", r, mock.SearchResult.ToProto())
77 }
78
79 l, err := client.List(context.Background(), &webserverv1.ListRequest{Query: query.QToProto(mock.WantList)})
80 if err != nil {
81 t.Fatal(err)
82 }
83
84 if !proto.Equal(l, mock.RepoList.ToProto()) {
85 t.Fatalf("got %+v, want %+v", l, mock.RepoList.ToProto())
86 }
87
88 request := webserverv1.StreamSearchRequest{
89 Request: &webserverv1.SearchRequest{Query: query.QToProto(mock.WantSearch)},
90 }
91
92 cs, err := client.StreamSearch(context.Background(), &request)
93 if err != nil {
94 t.Fatal(err)
95 }
96
97 allResponses := readAllStream(t, cs)
98
99 // check to make sure that we get the same set of file matches back
100 var receivedFileMatches []*webserverv1.FileMatch
101 for _, r := range allResponses {
102 receivedFileMatches = append(receivedFileMatches, r.GetFiles()...)
103 }
104
105 if diff := cmp.Diff(receivedFileMatches, mock.SearchResult.ToProto().GetFiles(), protocmp.Transform()); diff != "" {
106 t.Fatalf("unexpected difference in file matches (-want +got):\n%s", diff)
107 }
108}
109
110func TestFuzzGRPCChunkSender(t *testing.T) {
111 validateResult := func(input zoekt.SearchResult) error {
112 clientStream, serverStream := newPairedSearchStream(t)
113 sender := gRPCChunkSender(serverStream)
114
115 sender.Send(&input)
116
117 allResponses := readAllStream(t, clientStream)
118 if len(allResponses) == 0 {
119 return errors.New("received no responses")
120 }
121
122 expectedResult := input.ToProto()
123
124 for i, receivedResponse := range allResponses {
125 // First, check some invariants about the progress field
126
127 if i == len(allResponses)-1 {
128 // The last response should have the same progress as the original search result
129 if diff := cmp.Diff(expectedResult.GetProgress(), receivedResponse.GetProgress(), protocmp.Transform()); diff != "" {
130 return fmt.Errorf("unexpected difference in progress (-want +got):\n%s", diff)
131 }
132 } else {
133 // All other responses should ensure that the progress' priority is less than the max-pending priority, to
134 // ensure that the client consumes the entire set of chunks
135
136 if receivedResponse.GetProgress().GetPriority() > receivedResponse.GetProgress().GetMaxPendingPriority() {
137 return fmt.Errorf(
138 "received response %d (%s) has priority %.6f, which is greater than the max pending priority %.6f",
139 i, receivedResponse,
140 receivedResponse.GetProgress().GetPriority(), receivedResponse.GetProgress().GetMaxPendingPriority(),
141 )
142 }
143 }
144
145 // Safety, ensure that all other fields are echoed back correctly if the schema ever changes
146 opts := []cmp.Option{
147 protocmp.Transform(),
148 protocmp.IgnoreFields(&webserverv1.SearchResponse{},
149 "progress", // progress is tested above
150 "stats", // aggregated stats are tested below
151 "files", // files are tested separately
152 ),
153 }
154
155 if diff := cmp.Diff(expectedResult, receivedResponse, opts...); diff != "" {
156 return fmt.Errorf("unexpected difference in response fields (-want +got):\n%s", diff)
157 }
158 }
159
160 receivedStats := &zoekt.Stats{}
161
162 var receivedFileMatches []*webserverv1.FileMatch
163 for _, r := range allResponses {
164 receivedStats.Add(zoekt.StatsFromProto(r.GetStats()))
165 receivedFileMatches = append(receivedFileMatches, r.GetFiles()...)
166 }
167
168 // Check to make sure that we get one set of stats back
169 if diff := cmp.Diff(expectedResult.GetStats(), receivedStats.ToProto(),
170 protocmp.Transform(),
171 protocmp.IgnoreFields(&webserverv1.Stats{},
172 "duration", // for whatever the duration field isn't updated when zoekt.Stats.Add is called
173 ),
174 ); diff != "" {
175 return fmt.Errorf("unexpected difference in stats (-want +got):\n%s", diff)
176 }
177
178 // Check to make sure that we get the same set of file matches back
179 if diff := cmp.Diff(expectedResult.GetFiles(), receivedFileMatches,
180 protocmp.Transform(), cmpopts.EquateEmpty()); diff != "" {
181 return fmt.Errorf("unexpected difference in file matches (-want +got):\n%s", diff)
182 }
183
184 return nil
185 }
186
187 var lastErr error
188 if err := quick.Check(func(r zoekt.SearchResult) bool {
189 lastErr = validateResult(r)
190
191 return lastErr == nil
192 }, nil); err != nil {
193 t.Fatal(lastErr.Error())
194 }
195}
196
197// newPairedSearchStream returns a pair of client and server search streams that are connected to each other.
198func newPairedSearchStream(t *testing.T) (webserverv1.WebserverService_StreamSearchClient, webserverv1.WebserverService_StreamSearchServer) {
199 client := &mockSearchStreamClient{t: t}
200 server := &mockSearchStreamServer{t: t, pairedClient: client}
201
202 return client, server
203}
204
205type mockSearchStreamClient struct {
206 t *testing.T
207
208 storedResponses []*webserverv1.StreamSearchResponse
209 index int
210
211 startedReading atomic.Bool
212
213 grpc.ClientStream
214}
215
216func (m *mockSearchStreamClient) Recv() (*webserverv1.StreamSearchResponse, error) {
217 m.startedReading.Store(true)
218
219 if m.index >= len(m.storedResponses) {
220 return nil, io.EOF
221 }
222
223 r := m.storedResponses[m.index]
224 m.index++
225 return r, nil
226}
227
228func (m *mockSearchStreamClient) storeResponse(r *webserverv1.StreamSearchResponse) {
229 if m.startedReading.Load() {
230 m.t.Fatalf("cannot store additional responses after starting to read from stream")
231 }
232
233 m.storedResponses = append(m.storedResponses, r)
234}
235
236type mockSearchStreamServer struct {
237 t *testing.T
238
239 pairedClient *mockSearchStreamClient
240
241 grpc.ServerStream
242}
243
244func (m *mockSearchStreamServer) Send(r *webserverv1.StreamSearchResponse) error {
245 m.pairedClient.storeResponse(r)
246 return nil
247}
248
249var (
250 _ webserverv1.WebserverService_StreamSearchServer = &mockSearchStreamServer{}
251 _ webserverv1.WebserverService_StreamSearchClient = &mockSearchStreamClient{}
252)
253
254func readAllStream(t *testing.T, cs webserverv1.WebserverService_StreamSearchClient) []*webserverv1.SearchResponse {
255 var got []*webserverv1.SearchResponse
256 for { // collect all responses from the stream
257 r, err := cs.Recv()
258 if errors.Is(err, io.EOF) {
259 break
260 }
261
262 if err != nil {
263 t.Fatal(err)
264 }
265
266 got = append(got, r.GetResponseChunk())
267 }
268
269 return got
270}
271
272func mustParse(s string) query.Q {
273 q, err := query.Parse(s)
274 if err != nil {
275 panic(err)
276 }
277 return q
278}
279
280type adapter struct {
281 zoekt.Searcher
282}
283
284func (a adapter) StreamSearch(ctx context.Context, q query.Q, opts *zoekt.SearchOptions, sender zoekt.Sender) (err error) {
285 sr, err := a.Searcher.Search(ctx, q, opts)
286 if err != nil {
287 return err
288 }
289 sender.Send(sr)
290 return nil
291}