Fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "slices"
13 "strings"
14 "testing"
15 "time"
16
17 "github.com/google/go-cmp/cmp"
18 sglog "github.com/sourcegraph/log"
19 "github.com/sourcegraph/log/logtest"
20 "github.com/sourcegraph/zoekt"
21 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1"
22 indexserverv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/zoekt/indexserver/v1"
23 "github.com/sourcegraph/zoekt/index"
24 "github.com/sourcegraph/zoekt/internal/tenant"
25 "github.com/stretchr/testify/require"
26 "github.com/xeipuuv/gojsonschema"
27 "google.golang.org/grpc"
28 "google.golang.org/grpc/codes"
29 "google.golang.org/grpc/status"
30)
31
32func TestServer_defaultArgs(t *testing.T) {
33 root, err := url.Parse("http://api.test")
34 if err != nil {
35 t.Fatal(err)
36 }
37
38 s := &Server{
39 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
40 IndexDir: "/testdata/index",
41 CPUCount: 6,
42 IndexConcurrency: 1,
43 }
44 want := &indexArgs{
45 IndexOptions: IndexOptions{
46 Name: "testName",
47 },
48 IndexDir: "/testdata/index",
49 Parallelism: 6,
50 Incremental: true,
51 FileLimit: 1 << 20,
52 }
53 got := s.indexArgs(IndexOptions{Name: "testName"})
54 if !cmp.Equal(got, want) {
55 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
56 }
57}
58
59func TestIndexNoTenant(t *testing.T) {
60 s := &Server{}
61 _, err := s.index(context.Background(), &indexArgs{})
62 require.ErrorIs(t, err, tenant.ErrMissingTenant)
63}
64
65func TestServer_parallelism(t *testing.T) {
66 root, err := url.Parse("http://api.test")
67 if err != nil {
68 t.Fatal(err)
69 }
70
71 cases := []struct {
72 name string
73 cpuCount int
74 indexConcurrency int
75 options IndexOptions
76 want int
77 }{
78 {
79 name: "CPU count divides evenly",
80 cpuCount: 16,
81 indexConcurrency: 8,
82 want: 2,
83 },
84 {
85 name: "no shard level parallelism",
86 cpuCount: 4,
87 indexConcurrency: 4,
88 want: 1,
89 },
90 {
91 name: "index option overrides server flag",
92 cpuCount: 2,
93 indexConcurrency: 1,
94 options: IndexOptions{
95 ShardConcurrency: 1,
96 },
97 want: 1,
98 },
99 {
100 name: "ignore invalid index option",
101 cpuCount: 8,
102 indexConcurrency: 2,
103 options: IndexOptions{
104 ShardConcurrency: -1,
105 },
106 want: 4,
107 },
108 }
109
110 for _, tt := range cases {
111 t.Run(tt.name, func(t *testing.T) {
112 s := &Server{
113 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
114 IndexDir: "/testdata/index",
115 CPUCount: tt.cpuCount,
116 IndexConcurrency: tt.indexConcurrency,
117 }
118
119 maxProcs := 16
120 got := s.parallelism(tt.options, maxProcs)
121 if tt.want != got {
122 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
123 }
124 })
125 }
126
127 t.Run("index option is limited by available CPU", func(t *testing.T) {
128 s := &Server{
129 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
130 IndexDir: "/testdata/index",
131 IndexConcurrency: 1,
132 }
133
134 got := s.indexArgs(IndexOptions{
135 ShardConcurrency: 2048, // Some number that's way too high
136 })
137
138 if got.Parallelism >= 2048 {
139 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
140 }
141 })
142}
143
144func TestListRepoIDs(t *testing.T) {
145 grpcClient := &mockGRPCClient{}
146
147 clientOptions := []SourcegraphClientOption{
148 WithBatchSize(0),
149 }
150
151 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
152 testHostname := "test-hostname"
153 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
154
155 listCalled := false
156 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) {
157 listCalled = true
158
159 gotRepoIDs := in.GetIndexedIds()
160 slices.Sort(gotRepoIDs)
161
162 wantRepoIDs := []int32{1, 3}
163 slices.Sort(wantRepoIDs)
164
165 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
166 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
167 }
168
169 hostname := in.GetHostname()
170 if diff := cmp.Diff(testHostname, hostname); diff != "" {
171 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
172 }
173
174 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
175 }
176
177 ctx := context.Background()
178 got, err := s.List(ctx, []uint32{1, 3})
179 if err != nil {
180 t.Fatal(err)
181 }
182
183 if !listCalled {
184 t.Fatalf("List was not called")
185 }
186
187 receivedRepoIDs := got.IDs
188 slices.Sort(receivedRepoIDs)
189
190 expectedRepoIDs := []uint32{1, 2, 3}
191 slices.Sort(expectedRepoIDs)
192
193 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
194 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
195 }
196}
197
198func TestMain(m *testing.M) {
199 flag.Parse()
200 level := sglog.LevelInfo
201 if !testing.Verbose() {
202 log.SetOutput(io.Discard)
203 debugLog.SetOutput(io.Discard)
204 infoLog.SetOutput(io.Discard)
205 errorLog.SetOutput(io.Discard)
206 level = sglog.LevelNone
207 }
208
209 logtest.InitWithLevel(m, level)
210 os.Exit(m.Run())
211}
212
213func TestCreateEmptyShard(t *testing.T) {
214 dir := t.TempDir()
215
216 args := &indexArgs{
217 IndexOptions: IndexOptions{
218 RepoID: 7,
219 Name: "empty-repo",
220 CloneURL: "code/host",
221 },
222 Incremental: true,
223 IndexDir: dir,
224 Parallelism: 1,
225 FileLimit: 1,
226 }
227
228 if err := createEmptyShard(args); err != nil {
229 t.Fatal(err)
230 }
231
232 bo := args.BuildOptions()
233 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
234
235 if got := bo.IncrementalSkipIndexing(); !got {
236 t.Fatalf("wanted %t, got %t", true, got)
237 }
238}
239
240func TestFormatListUint32(t *testing.T) {
241 cases := []struct {
242 in []uint32
243 want string
244 }{
245 {
246 in: []uint32{42, 8, 3},
247 want: "42, 8, ...",
248 },
249 {
250 in: []uint32{42, 8},
251 want: "42, 8",
252 },
253 {
254 in: []uint32{42},
255 want: "42",
256 },
257 {
258 in: []uint32{},
259 want: "",
260 },
261 }
262
263 for _, tt := range cases {
264 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
265 out := formatListUint32(tt.in, 2)
266 if out != tt.want {
267 t.Fatalf("want %s, got %s", tt.want, out)
268 }
269 })
270 }
271}
272
273func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
274 wd, err := os.Getwd()
275 if err != nil {
276 t.Fatalf("failed to get working directory: %v", err)
277 }
278
279 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
280 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
281
282 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
283
284 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
285 if err != nil {
286 t.Fatalf("failed to validate default service config: %v", err)
287 }
288
289 if !result.Valid() {
290 var errs strings.Builder
291 for _, err := range result.Errors() {
292 errs.WriteString(fmt.Sprintf("- %s\n", err))
293 }
294
295 t.Fatalf("default service config is invalid:\n%s", errs.String())
296 }
297}
298
299func TestGetBoolFromEnvironmentVariables(t *testing.T) {
300 testCases := []struct {
301 name string
302 envVarsToSet map[string]string
303
304 envVarNames []string
305 defaultBool bool
306
307 wantBool bool
308 wantErr bool
309 }{
310 {
311 name: "respect default value: true",
312
313 envVarsToSet: map[string]string{},
314
315 envVarNames: []string{"FOO", "BAR"},
316 defaultBool: true,
317
318 wantBool: true,
319 },
320 {
321 name: "respect default value: false",
322
323 envVarsToSet: map[string]string{},
324
325 envVarNames: []string{"FOO", "BAR"},
326 defaultBool: false,
327
328 wantBool: false,
329 },
330 {
331 name: "read from environment",
332
333 envVarsToSet: map[string]string{"FOO": "1"},
334
335 envVarNames: []string{"FOO"},
336 defaultBool: false,
337
338 wantBool: true,
339 },
340 {
341 name: "read from first env var that is set",
342
343 envVarsToSet: map[string]string{
344 "BAR": "false",
345 "BAZ": "true",
346 },
347
348 envVarNames: []string{"FOO", "BAR", "BAZ"},
349 defaultBool: true,
350
351 wantBool: false,
352 },
353
354 {
355 name: "should error for invalid input",
356
357 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
358
359 envVarNames: []string{"INVALID"},
360 defaultBool: false,
361
362 wantErr: true,
363 },
364 }
365
366 for _, tc := range testCases {
367 t.Run("", func(t *testing.T) {
368 // Prepare the environment by loading all the appropriate environment variables
369 for _, v := range tc.envVarNames {
370 _ = os.Unsetenv(v)
371 }
372
373 for k := range tc.envVarsToSet {
374 _ = os.Unsetenv(k)
375 }
376
377 for k, v := range tc.envVarsToSet {
378 t.Setenv(k, v)
379 }
380
381 // Run the test
382 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
383
384 // Examine the results
385 if tc.wantErr != (err != nil) {
386 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
387 }
388
389 if got != tc.wantBool {
390 t.Errorf("got %v, want %v", got, tc.wantBool)
391 }
392 })
393 }
394}
395
396func TestAddDefaultPort(t *testing.T) {
397 tests := []struct {
398 name string
399 input string
400 want string
401 }{
402 {
403 name: "http no port",
404 input: "http://example.com",
405 want: "http://example.com:80",
406 },
407 {
408 name: "http custom port",
409 input: "http://example.com:90",
410 want: "http://example.com:90",
411 },
412 {
413 name: "https no port",
414 input: "https://example.com",
415 want: "https://example.com:443",
416 },
417 {
418 name: "https custom port",
419 input: "https://example.com:444",
420 want: "https://example.com:444",
421 },
422 {
423 name: "non-http scheme",
424 input: "ftp://example.com",
425 want: "ftp://example.com",
426 },
427 {
428 name: "empty string",
429 input: "",
430 want: "",
431 },
432 {
433 name: "local file path",
434 input: "/etc/hosts",
435 want: "/etc/hosts",
436 },
437 }
438
439 for _, test := range tests {
440 t.Run(test.name, func(t *testing.T) {
441 input, err := url.Parse(test.input)
442 if err != nil {
443 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
444 }
445
446 got := addDefaultPort(input)
447 if diff := cmp.Diff(test.want, got.String()); diff != "" {
448 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
449 }
450 })
451 }
452}
453
454func TestIndexGRPC(t *testing.T) {
455 indexDir := t.TempDir()
456
457 // Minimal server setup
458 s := &Server{
459 logger: logtest.NoOp(t),
460 IndexDir: indexDir,
461 rootURL: &url.URL{Scheme: "http", Host: "example.com"},
462 indexSemaphore: make(chan struct{}, 1),
463 timeout: time.Hour, // no timeout
464 }
465
466 branches := []*configv1.ZoektRepositoryBranch{
467 {
468 Name: "HEAD",
469 Version: "abc123",
470 },
471 }
472
473 req := &indexserverv1.IndexRequest{
474 Options: &configv1.ZoektIndexOptions{
475 RepoId: 42,
476 Name: "repo",
477 TenantId: 1,
478 Branches: branches,
479 },
480 }
481
482 resp, err := s.indexGRPC(context.Background(), req, mockIndexFunc(t))
483 require.NoError(t, err)
484 require.Equal(t, &indexserverv1.IndexResponse{
485 RepoId: 42,
486 Branches: branches,
487 IndexTimeUnix: resp.IndexTimeUnix, // Hack: this changes every time so we don't check it
488 }, resp)
489
490 require.NotZero(t, resp.IndexTimeUnix)
491}
492
493func TestIndexGRPC_Timeout(t *testing.T) {
494 indexDir := t.TempDir()
495
496 s := &Server{
497 logger: logtest.NoOp(t),
498 IndexDir: indexDir,
499 IndexConcurrency: 0, // impossible to acquire index slot
500 timeout: time.Millisecond,
501 }
502
503 req := &indexserverv1.IndexRequest{
504 Options: &configv1.ZoektIndexOptions{
505 RepoId: 42,
506 Name: "repo",
507 },
508 }
509
510 // use context.Background() to make sure we don't return because of context cancellation
511 _, err := s.indexGRPC(context.Background(), req, mockIndexFunc(t))
512 require.Error(t, err)
513 require.Equal(t, codes.DeadlineExceeded, status.Code(err))
514}
515
516func TestDelete(t *testing.T) {
517 indexDir := t.TempDir()
518 trashDir := filepath.Join(indexDir, ".trash")
519 if err := os.MkdirAll(trashDir, 0o755); err != nil {
520 t.Fatal(err)
521 }
522
523 // Create a simple shard
524 createShard(t, indexDir)
525
526 // Verify the shard exists in index dir
527 shards := getShards(indexDir)
528 if len(shards) != 1 {
529 t.Fatalf("expected 1 shard, got %d", len(shards))
530 }
531
532 // Create server and call Delete
533 s := &Server{
534 logger: logtest.NoOp(t),
535 IndexDir: indexDir,
536 rootURL: &url.URL{Scheme: "http", Host: "example.com"},
537 indexSemaphore: make(chan struct{}, 1),
538 timeout: time.Hour, // no timeout
539 }
540
541 req := &indexserverv1.DeleteRequest{
542 RepoIds: []uint32{42}, // matches the repo ID in createShard
543 }
544
545 // Test case: context is canceled
546 cancledCtx, cancel := context.WithCancel(context.Background())
547 cancel()
548 _, err := s.Delete(cancledCtx, req)
549 require.Error(t, err)
550
551 shards = getShards(indexDir)
552 require.Len(t, shards, 1)
553
554 // Test case: context is not canceled
555 _, err = s.Delete(context.Background(), req)
556 require.NoError(t, err)
557
558 // Verify shard was moved to trash
559 trashShards := getShards(trashDir)
560 require.Len(t, trashShards, 1)
561
562 // Verify shard is no longer in index dir
563 shards = getShards(indexDir)
564 require.Len(t, shards, 0)
565}
566
567func mockIndexFunc(t *testing.T) func(ctx context.Context, args *indexArgs) (indexState, error) {
568 return func(ctx context.Context, args *indexArgs) (indexState, error) {
569 createShard(t, args.IndexDir)
570 return indexStateSuccess, nil
571 }
572}
573
574func createShard(t *testing.T, dir string) {
575 opts := index.Options{
576 IndexDir: dir,
577 RepositoryDescription: zoekt.Repository{
578 ID: 42,
579 Name: "repo",
580 Branches: []zoekt.RepositoryBranch{
581 {
582 Name: "HEAD",
583 Version: "abc123",
584 },
585 },
586 },
587 }
588
589 b, err := index.NewBuilder(opts)
590 require.NoError(t, err)
591 require.NoError(t, b.AddFile("test.txt", []byte("hello")))
592 require.NoError(t, b.Finish())
593}
594
595func TestRecoverFromTrash(t *testing.T) {
596 dir := t.TempDir()
597 trashDir := filepath.Join(dir, ".trash")
598 require.NoError(t, os.MkdirAll(trashDir, 0o755))
599
600 // Create a simple shard in trash
601 createTestShard(t, "repo1", 1, filepath.Join(trashDir, "repo1.zoekt"))
602
603 // Create a compound shard with two repos, one of them tombstoned
604 cs := createCompoundShard(t, dir, []uint32{2, 3})
605 require.NoError(t, index.SetTombstone(cs, 2))
606
607 s := &Server{
608 IndexDir: dir,
609 }
610
611 // Test recovering from trash
612 recovered := s.recoverFromTrash(1)
613 require.True(t, recovered, "should have recovered repo1 from trash")
614
615 // Verify shard was moved from trash to index
616 indexShards := getShards(dir)
617 trashShards := getShards(trashDir)
618
619 require.Contains(t, indexShards, uint32(1), "repo1 should be in index")
620 require.NotContains(t, trashShards, uint32(1), "repo1 should not be in trash")
621
622 // Test unsetting tombstone
623 recovered = s.recoverFromTrash(2)
624 require.True(t, recovered, "should have recovered repo2 from tombstone")
625
626 // Verify tombstone was unset
627 repos, _, err := index.ReadMetadataPath(cs)
628 require.NoError(t, err)
629
630 for _, repo := range repos {
631 if repo.ID == 2 {
632 require.False(t, repo.Tombstone, "repo2 should not be tombstoned")
633 }
634 }
635
636 // Test non-existent repo
637 recovered = s.recoverFromTrash(99)
638 require.False(t, recovered, "should not have recovered non-existent repo")
639}