// Fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/http"
10 "net/http/httptest"
11 "net/url"
12 "os"
13 "path/filepath"
14 "sort"
15 "strings"
16 "testing"
17
18 sglog "github.com/sourcegraph/log"
19 "github.com/sourcegraph/log/logtest"
20 "github.com/xeipuuv/gojsonschema"
21 "google.golang.org/grpc"
22
23 "github.com/google/go-cmp/cmp"
24
25 "github.com/sourcegraph/zoekt"
26 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1"
27)
28
29func TestServer_defaultArgs(t *testing.T) {
30 root, err := url.Parse("http://api.test")
31 if err != nil {
32 t.Fatal(err)
33 }
34
35 s := &Server{
36 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)),
37 IndexDir: "/testdata/index",
38 CPUCount: 6,
39 IndexConcurrency: 1,
40 }
41 want := &indexArgs{
42 IndexOptions: IndexOptions{
43 Name: "testName",
44 },
45 IndexDir: "/testdata/index",
46 Parallelism: 6,
47 Incremental: true,
48 FileLimit: 1 << 20,
49 }
50 got := s.indexArgs(IndexOptions{Name: "testName"})
51 if !cmp.Equal(got, want) {
52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
53 }
54}
55
56func TestServer_parallelism(t *testing.T) {
57 root, err := url.Parse("http://api.test")
58 if err != nil {
59 t.Fatal(err)
60 }
61
62 cases := []struct {
63 name string
64 cpuCount int
65 indexConcurrency int
66 options IndexOptions
67 want int
68 }{
69 {
70 name: "CPU count divides evenly",
71 cpuCount: 16,
72 indexConcurrency: 8,
73 want: 2,
74 },
75 {
76 name: "no shard level parallelism",
77 cpuCount: 4,
78 indexConcurrency: 4,
79 want: 1,
80 },
81 {
82 name: "index option overrides server flag",
83 cpuCount: 2,
84 indexConcurrency: 1,
85 options: IndexOptions{
86 ShardConcurrency: 1,
87 },
88 want: 1,
89 },
90 {
91 name: "ignore invalid index option",
92 cpuCount: 8,
93 indexConcurrency: 2,
94 options: IndexOptions{
95 ShardConcurrency: -1,
96 },
97 want: 4,
98 },
99 }
100
101 for _, tt := range cases {
102 t.Run(tt.name, func(t *testing.T) {
103 s := &Server{
104 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)),
105 IndexDir: "/testdata/index",
106 CPUCount: tt.cpuCount,
107 IndexConcurrency: tt.indexConcurrency,
108 }
109
110 maxProcs := 16
111 got := s.parallelism(tt.options, maxProcs)
112 if tt.want != got {
113 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
114 }
115 })
116 }
117
118 t.Run("index option is limited by available CPU", func(t *testing.T) {
119 s := &Server{
120 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)),
121 IndexDir: "/testdata/index",
122 IndexConcurrency: 1,
123 }
124
125 got := s.indexArgs(IndexOptions{
126 ShardConcurrency: 2048, // Some number that's way too high
127 })
128
129 if got.Parallelism >= 2048 {
130 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
131 }
132 })
133}
134
135func TestListRepoIDs(t *testing.T) {
136 t.Run("gRPC", func(t *testing.T) {
137 grpcClient := &mockGRPCClient{}
138
139 clientOptions := []SourcegraphClientOption{
140 WithShouldUseGRPC(true),
141 WithGRPCClient(grpcClient),
142 WithBatchSize(0),
143 }
144
145 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
146 testHostname := "test-hostname"
147 s := newSourcegraphClient(&testURL, testHostname, clientOptions...)
148
149 listCalled := false
150 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) {
151 listCalled = true
152
153 gotRepoIDs := in.GetIndexedIds()
154 sort.Slice(gotRepoIDs, func(i, j int) bool {
155 return gotRepoIDs[i] < gotRepoIDs[j]
156 })
157
158 wantRepoIDs := []int32{1, 3}
159 sort.Slice(wantRepoIDs, func(i, j int) bool {
160 return wantRepoIDs[i] < wantRepoIDs[j]
161 })
162
163 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
164 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
165 }
166
167 hostname := in.GetHostname()
168 if diff := cmp.Diff(testHostname, hostname); diff != "" {
169 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
170 }
171
172 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
173 }
174
175 ctx := context.Background()
176 got, err := s.List(ctx, []uint32{1, 3})
177 if err != nil {
178 t.Fatal(err)
179 }
180
181 if !listCalled {
182 t.Fatalf("List was not called")
183 }
184
185 receivedRepoIDs := got.IDs
186 sort.Slice(receivedRepoIDs, func(i, j int) bool {
187 return receivedRepoIDs[i] < receivedRepoIDs[j]
188 })
189
190 expectedRepoIDs := []uint32{1, 2, 3}
191 sort.Slice(expectedRepoIDs, func(i, j int) bool {
192 return expectedRepoIDs[i] < expectedRepoIDs[j]
193 })
194
195 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
196 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
197 }
198 })
199
200 t.Run("REST", func(t *testing.T) {
201 var gotBody string
202 var gotURL *url.URL
203 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
204 gotURL = r.URL
205
206 b, err := io.ReadAll(r.Body)
207 if err != nil {
208 t.Fatal(err)
209 }
210 gotBody = string(b)
211
212 _, err = w.Write([]byte(`{"RepoIDs": [1, 2, 3]}`))
213 if err != nil {
214 t.Fatal(err)
215 }
216 }))
217 defer ts.Close()
218
219 u, err := url.Parse(ts.URL)
220 if err != nil {
221 t.Fatal(err)
222 }
223
224 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0))
225
226 gotRepos, err := s.List(context.Background(), []uint32{1, 3})
227 if err != nil {
228 t.Fatal(err)
229 }
230
231 if want := []uint32{1, 2, 3}; !cmp.Equal(gotRepos.IDs, want) {
232 t.Errorf("repos mismatch (-want +got):\n%s", cmp.Diff(want, gotRepos.IDs))
233 }
234 if want := `{"Hostname":"test-indexed-search-1","IndexedIDs":[1,3]}`; gotBody != want {
235 t.Errorf("body mismatch (-want +got):\n%s", cmp.Diff(want, gotBody))
236 }
237 if want := "/.internal/repos/index"; gotURL.Path != want {
238 t.Errorf("request path mismatch (-want +got):\n%s", cmp.Diff(want, gotURL.Path))
239 }
240 })
241}
242
243func TestListRepoIDs_Error_REST(t *testing.T) {
244 // Note: There is no gRPC equivalent to this test because gRPC errors are
245 // always returned as an error to the caller.
246
247 msg := "deadbeaf deadbeaf"
248 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
249 // This is how Sourcegraph returns error messages to the caller.
250 http.Error(w, msg, http.StatusInternalServerError)
251 }))
252 defer ts.Close()
253
254 u, err := url.Parse(ts.URL)
255 if err != nil {
256 t.Fatal(err)
257 }
258
259 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0))
260 s.restClient.RetryMax = 0
261
262 _, err = s.List(context.Background(), []uint32{1, 3})
263
264 if !strings.Contains(err.Error(), msg) {
265 t.Fatalf("%s does not contain %s", err.Error(), msg)
266 }
267}
268
269func TestMain(m *testing.M) {
270 flag.Parse()
271 level := sglog.LevelInfo
272 if !testing.Verbose() {
273 log.SetOutput(io.Discard)
274 level = sglog.LevelNone
275 }
276
277 logtest.InitWithLevel(m, level)
278 os.Exit(m.Run())
279}
280
281func TestCreateEmptyShard(t *testing.T) {
282 dir := t.TempDir()
283
284 args := &indexArgs{
285 IndexOptions: IndexOptions{
286 RepoID: 7,
287 Name: "empty-repo",
288 CloneURL: "code/host",
289 },
290 Incremental: true,
291 IndexDir: dir,
292 Parallelism: 1,
293 FileLimit: 1,
294 }
295
296 if err := createEmptyShard(args); err != nil {
297 t.Fatal(err)
298 }
299
300 bo := args.BuildOptions()
301 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
302
303 if got := bo.IncrementalSkipIndexing(); !got {
304 t.Fatalf("wanted %t, got %t", true, got)
305 }
306}
307
308func TestFormatListUint32(t *testing.T) {
309 cases := []struct {
310 in []uint32
311 want string
312 }{
313 {
314 in: []uint32{42, 8, 3},
315 want: "42, 8, ...",
316 },
317 {
318 in: []uint32{42, 8},
319 want: "42, 8",
320 },
321 {
322 in: []uint32{42},
323 want: "42",
324 },
325 {
326 in: []uint32{},
327 want: "",
328 },
329 }
330
331 for _, tt := range cases {
332 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
333 out := formatListUint32(tt.in, 2)
334 if out != tt.want {
335 t.Fatalf("want %s, got %s", tt.want, out)
336 }
337 })
338 }
339}
340
341func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
342 wd, err := os.Getwd()
343 if err != nil {
344 t.Fatalf("failed to get working directory: %v", err)
345 }
346
347 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
348 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
349
350 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
351
352 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
353 if err != nil {
354 t.Fatalf("failed to validate default service config: %v", err)
355 }
356
357 if !result.Valid() {
358 var errs strings.Builder
359 for _, err := range result.Errors() {
360 errs.WriteString(fmt.Sprintf("- %s\n", err))
361 }
362
363 t.Fatalf("default service config is invalid:\n%s", errs.String())
364 }
365}
366
367func TestGetBoolFromEnvironmentVariables(t *testing.T) {
368 testCases := []struct {
369 name string
370 envVarsToSet map[string]string
371
372 envVarNames []string
373 defaultBool bool
374
375 wantBool bool
376 wantErr bool
377 }{
378 {
379 name: "respect default value: true",
380
381 envVarsToSet: map[string]string{},
382
383 envVarNames: []string{"FOO", "BAR"},
384 defaultBool: true,
385
386 wantBool: true,
387 },
388 {
389 name: "respect default value: false",
390
391 envVarsToSet: map[string]string{},
392
393 envVarNames: []string{"FOO", "BAR"},
394 defaultBool: false,
395
396 wantBool: false,
397 },
398 {
399 name: "read from environment",
400
401 envVarsToSet: map[string]string{"FOO": "1"},
402
403 envVarNames: []string{"FOO"},
404 defaultBool: false,
405
406 wantBool: true,
407 },
408 {
409 name: "read from first env var that is set",
410
411 envVarsToSet: map[string]string{
412 "BAR": "false",
413 "BAZ": "true",
414 },
415
416 envVarNames: []string{"FOO", "BAR", "BAZ"},
417 defaultBool: true,
418
419 wantBool: false,
420 },
421
422 {
423 name: "should error for invalid input",
424
425 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
426
427 envVarNames: []string{"INVALID"},
428 defaultBool: false,
429
430 wantErr: true,
431 },
432 }
433
434 for _, tc := range testCases {
435 t.Run("", func(t *testing.T) {
436 // Prepare the environment by loading all the appropriate environment variables
437 for _, v := range tc.envVarNames {
438 _ = os.Unsetenv(v)
439 }
440
441 for k := range tc.envVarsToSet {
442 _ = os.Unsetenv(k)
443 }
444
445 for k, v := range tc.envVarsToSet {
446 t.Setenv(k, v)
447 }
448
449 // Run the test
450 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
451
452 // Examine the results
453 if tc.wantErr != (err != nil) {
454 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
455 }
456
457 if got != tc.wantBool {
458 t.Errorf("got %v, want %v", got, tc.wantBool)
459 }
460 })
461 }
462}
463
464func TestAddDefaultPort(t *testing.T) {
465 tests := []struct {
466 name string
467 input string
468 want string
469 }{
470 {
471 name: "http no port",
472 input: "http://example.com",
473 want: "http://example.com:80",
474 },
475 {
476 name: "http custom port",
477 input: "http://example.com:90",
478 want: "http://example.com:90",
479 },
480 {
481 name: "https no port",
482 input: "https://example.com",
483 want: "https://example.com:443",
484 },
485 {
486 name: "https custom port",
487 input: "https://example.com:444",
488 want: "https://example.com:444",
489 },
490 {
491 name: "non-http scheme",
492 input: "ftp://example.com",
493 want: "ftp://example.com",
494 },
495 {
496 name: "empty string",
497 input: "",
498 want: "",
499 },
500 {
501 name: "local file path",
502 input: "/etc/hosts",
503 want: "/etc/hosts",
504 },
505 }
506
507 for _, test := range tests {
508 t.Run(test.name, func(t *testing.T) {
509 input, err := url.Parse(test.input)
510 if err != nil {
511 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
512 }
513
514 got := addDefaultPort(input)
515 if diff := cmp.Diff(test.want, got.String()); diff != "" {
516 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
517 }
518 })
519 }
520}