fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/http"
10 "net/http/httptest"
11 "net/url"
12 "os"
13 "path/filepath"
14 "sort"
15 "strings"
16 "testing"
17
18 sglog "github.com/sourcegraph/log"
19 "github.com/sourcegraph/log/logtest"
20 "github.com/xeipuuv/gojsonschema"
21 "google.golang.org/grpc"
22
23 "github.com/google/go-cmp/cmp"
24
25 "github.com/sourcegraph/zoekt"
26 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1"
27)
28
29func TestServer_defaultArgs(t *testing.T) {
30 root, err := url.Parse("http://api.test")
31 if err != nil {
32 t.Fatal(err)
33 }
34
35 s := &Server{
36 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)),
37 IndexDir: "/testdata/index",
38 CPUCount: 6,
39 }
40 want := &indexArgs{
41 IndexOptions: IndexOptions{
42 Name: "testName",
43 },
44 IndexDir: "/testdata/index",
45 Parallelism: 6,
46 Incremental: true,
47 FileLimit: 1 << 20,
48 }
49 got := s.indexArgs(IndexOptions{Name: "testName"})
50 if !cmp.Equal(got, want) {
51 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
52 }
53}
54
55func TestListRepoIDs(t *testing.T) {
56 t.Run("gRPC", func(t *testing.T) {
57
58 grpcClient := &mockGRPCClient{}
59
60 clientOptions := []SourcegraphClientOption{
61 WithShouldUseGRPC(true),
62 WithGRPCClient(grpcClient),
63 WithBatchSize(0),
64 }
65
66 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
67 testHostname := "test-hostname"
68 s := newSourcegraphClient(&testURL, testHostname, clientOptions...)
69
70 listCalled := false
71 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) {
72 listCalled = true
73
74 gotRepoIDs := in.GetIndexedIds()
75 sort.Slice(gotRepoIDs, func(i, j int) bool {
76 return gotRepoIDs[i] < gotRepoIDs[j]
77 })
78
79 wantRepoIDs := []int32{1, 3}
80 sort.Slice(wantRepoIDs, func(i, j int) bool {
81 return wantRepoIDs[i] < wantRepoIDs[j]
82 })
83
84 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
85 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
86 }
87
88 hostname := in.GetHostname()
89 if diff := cmp.Diff(testHostname, hostname); diff != "" {
90 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
91 }
92
93 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
94 }
95
96 ctx := context.Background()
97 got, err := s.List(ctx, []uint32{1, 3})
98 if err != nil {
99 t.Fatal(err)
100 }
101
102 if !listCalled {
103 t.Fatalf("List was not called")
104 }
105
106 receivedRepoIDs := got.IDs
107 sort.Slice(receivedRepoIDs, func(i, j int) bool {
108 return receivedRepoIDs[i] < receivedRepoIDs[j]
109 })
110
111 expectedRepoIDs := []uint32{1, 2, 3}
112 sort.Slice(expectedRepoIDs, func(i, j int) bool {
113 return expectedRepoIDs[i] < expectedRepoIDs[j]
114 })
115
116 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
117 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
118 }
119 })
120
121 t.Run("REST", func(t *testing.T) {
122 var gotBody string
123 var gotURL *url.URL
124 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
125 gotURL = r.URL
126
127 b, err := io.ReadAll(r.Body)
128 if err != nil {
129 t.Fatal(err)
130 }
131 gotBody = string(b)
132
133 _, err = w.Write([]byte(`{"RepoIDs": [1, 2, 3]}`))
134 if err != nil {
135 t.Fatal(err)
136 }
137 }))
138 defer ts.Close()
139
140 u, err := url.Parse(ts.URL)
141 if err != nil {
142 t.Fatal(err)
143 }
144
145 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0))
146
147 gotRepos, err := s.List(context.Background(), []uint32{1, 3})
148 if err != nil {
149 t.Fatal(err)
150 }
151
152 if want := []uint32{1, 2, 3}; !cmp.Equal(gotRepos.IDs, want) {
153 t.Errorf("repos mismatch (-want +got):\n%s", cmp.Diff(want, gotRepos.IDs))
154 }
155 if want := `{"Hostname":"test-indexed-search-1","IndexedIDs":[1,3]}`; gotBody != want {
156 t.Errorf("body mismatch (-want +got):\n%s", cmp.Diff(want, gotBody))
157 }
158 if want := "/.internal/repos/index"; gotURL.Path != want {
159 t.Errorf("request path mismatch (-want +got):\n%s", cmp.Diff(want, gotURL.Path))
160 }
161 })
162}
163
164func TestListRepoIDs_Error_REST(t *testing.T) {
165 // Note: There is no gRPC equivalent to this test because gRPC errors are
166 // always returned as an error to the caller.
167
168 msg := "deadbeaf deadbeaf"
169 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
170
171 // This is how Sourcegraph returns error messages to the caller.
172 http.Error(w, msg, http.StatusInternalServerError)
173 }))
174 defer ts.Close()
175
176 u, err := url.Parse(ts.URL)
177 if err != nil {
178 t.Fatal(err)
179 }
180
181 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0))
182 s.restClient.RetryMax = 0
183
184 _, err = s.List(context.Background(), []uint32{1, 3})
185
186 if !strings.Contains(err.Error(), msg) {
187 t.Fatalf("%s does not contain %s", err.Error(), msg)
188 }
189}
190
191func TestMain(m *testing.M) {
192 flag.Parse()
193 level := sglog.LevelInfo
194 if !testing.Verbose() {
195 log.SetOutput(io.Discard)
196 level = sglog.LevelNone
197 }
198
199 logtest.InitWithLevel(m, level)
200 os.Exit(m.Run())
201}
202
203func TestCreateEmptyShard(t *testing.T) {
204 dir := t.TempDir()
205
206 args := &indexArgs{
207 IndexOptions: IndexOptions{
208 RepoID: 7,
209 Name: "empty-repo",
210 CloneURL: "code/host",
211 },
212 Incremental: true,
213 IndexDir: dir,
214 Parallelism: 1,
215 FileLimit: 1,
216 }
217
218 if err := createEmptyShard(args); err != nil {
219 t.Fatal(err)
220 }
221
222 bo := args.BuildOptions()
223 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
224
225 if got := bo.IncrementalSkipIndexing(); !got {
226 t.Fatalf("wanted %t, got %t", true, got)
227 }
228}
229
230func TestFormatListUint32(t *testing.T) {
231 cases := []struct {
232 in []uint32
233 want string
234 }{
235 {
236 in: []uint32{42, 8, 3},
237 want: "42, 8, ...",
238 },
239 {
240 in: []uint32{42, 8},
241 want: "42, 8",
242 },
243 {
244 in: []uint32{42},
245 want: "42",
246 },
247 {
248 in: []uint32{},
249 want: "",
250 },
251 }
252
253 for _, tt := range cases {
254 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
255 out := formatListUint32(tt.in, 2)
256 if out != tt.want {
257 t.Fatalf("want %s, got %s", tt.want, out)
258 }
259 })
260 }
261}
262
263func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
264 wd, err := os.Getwd()
265 if err != nil {
266 t.Fatalf("failed to get working directory: %v", err)
267 }
268
269 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
270 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
271
272 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
273
274 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
275 if err != nil {
276 t.Fatalf("failed to validate default service config: %v", err)
277 }
278
279 if !result.Valid() {
280 var errs strings.Builder
281 for _, err := range result.Errors() {
282 errs.WriteString(fmt.Sprintf("- %s\n", err))
283 }
284
285 t.Fatalf("default service config is invalid:\n%s", errs.String())
286 }
287}
288
289func TestGetBoolFromEnvironmentVariables(t *testing.T) {
290 testCases := []struct {
291 name string
292 envVarsToSet map[string]string
293
294 envVarNames []string
295 defaultBool bool
296
297 wantBool bool
298 wantErr bool
299 }{
300 {
301 name: "respect default value: true",
302
303 envVarsToSet: map[string]string{},
304
305 envVarNames: []string{"FOO", "BAR"},
306 defaultBool: true,
307
308 wantBool: true,
309 },
310 {
311 name: "respect default value: false",
312
313 envVarsToSet: map[string]string{},
314
315 envVarNames: []string{"FOO", "BAR"},
316 defaultBool: false,
317
318 wantBool: false,
319 },
320 {
321 name: "read from environment",
322
323 envVarsToSet: map[string]string{"FOO": "1"},
324
325 envVarNames: []string{"FOO"},
326 defaultBool: false,
327
328 wantBool: true,
329 },
330 {
331 name: "read from first env var that is set",
332
333 envVarsToSet: map[string]string{
334 "BAR": "false",
335 "BAZ": "true",
336 },
337
338 envVarNames: []string{"FOO", "BAR", "BAZ"},
339 defaultBool: true,
340
341 wantBool: false,
342 },
343
344 {
345 name: "should error for invalid input",
346
347 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
348
349 envVarNames: []string{"INVALID"},
350 defaultBool: false,
351
352 wantErr: true,
353 },
354 }
355
356 for _, tc := range testCases {
357 t.Run("", func(t *testing.T) {
358 // Prepare the environment by loading all the appropriate environment variables
359 for _, v := range tc.envVarNames {
360 _ = os.Unsetenv(v)
361 }
362
363 for k, _ := range tc.envVarsToSet {
364 _ = os.Unsetenv(k)
365 }
366
367 for k, v := range tc.envVarsToSet {
368 t.Setenv(k, v)
369 }
370
371 // Run the test
372 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
373
374 // Examine the results
375 if tc.wantErr != (err != nil) {
376 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
377 }
378
379 if got != tc.wantBool {
380 t.Errorf("got %v, want %v", got, tc.wantBool)
381 }
382 })
383 }
384}
385
386func TestAddDefaultPort(t *testing.T) {
387 tests := []struct {
388 name string
389 input string
390 want string
391 }{
392 {
393 name: "http no port",
394 input: "http://example.com",
395 want: "http://example.com:80",
396 },
397 {
398 name: "http custom port",
399 input: "http://example.com:90",
400 want: "http://example.com:90",
401 },
402 {
403 name: "https no port",
404 input: "https://example.com",
405 want: "https://example.com:443",
406 },
407 {
408 name: "https custom port",
409 input: "https://example.com:444",
410 want: "https://example.com:444",
411 },
412 {
413 name: "non-http scheme",
414 input: "ftp://example.com",
415 want: "ftp://example.com",
416 },
417 {
418 name: "empty string",
419 input: "",
420 want: "",
421 },
422 {
423 name: "local file path",
424 input: "/etc/hosts",
425 want: "/etc/hosts",
426 },
427 }
428
429 for _, test := range tests {
430 t.Run(test.name, func(t *testing.T) {
431 input, err := url.Parse(test.input)
432 if err != nil {
433 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
434 }
435
436 got := addDefaultPort(input)
437 if diff := cmp.Diff(test.want, got.String()); diff != "" {
438 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
439 }
440 })
441 }
442}