// Fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "slices"
13 "strings"
14 "testing"
15 "time"
16
17 "github.com/google/go-cmp/cmp"
18 sglog "github.com/sourcegraph/log"
19 "github.com/sourcegraph/log/logtest"
20 "github.com/stretchr/testify/require"
21 "github.com/xeipuuv/gojsonschema"
22 "google.golang.org/grpc"
23 "google.golang.org/grpc/codes"
24 "google.golang.org/grpc/status"
25
26 "github.com/sourcegraph/zoekt"
27 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1"
28 indexserverv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/zoekt/indexserver/v1"
29 "github.com/sourcegraph/zoekt/index"
30 "github.com/sourcegraph/zoekt/internal/tenant"
31)
32
33func TestServer_defaultArgs(t *testing.T) {
34 root, err := url.Parse("http://api.test")
35 if err != nil {
36 t.Fatal(err)
37 }
38
39 s := &Server{
40 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
41 IndexDir: "/testdata/index",
42 CPUCount: 6,
43 IndexConcurrency: 1,
44 }
45 want := &indexArgs{
46 IndexOptions: IndexOptions{
47 Name: "testName",
48 },
49 IndexDir: "/testdata/index",
50 Parallelism: 6,
51 Incremental: true,
52 FileLimit: 1 << 20,
53 }
54 got := s.indexArgs(IndexOptions{Name: "testName"})
55 if !cmp.Equal(got, want) {
56 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
57 }
58}
59
60func TestIndexNoTenant(t *testing.T) {
61 s := &Server{}
62 _, err := s.index(context.Background(), &indexArgs{})
63 require.ErrorIs(t, err, tenant.ErrMissingTenant)
64}
65
66func TestServer_parallelism(t *testing.T) {
67 root, err := url.Parse("http://api.test")
68 if err != nil {
69 t.Fatal(err)
70 }
71
72 cases := []struct {
73 name string
74 cpuCount int
75 indexConcurrency int
76 options IndexOptions
77 want int
78 }{
79 {
80 name: "CPU count divides evenly",
81 cpuCount: 16,
82 indexConcurrency: 8,
83 want: 2,
84 },
85 {
86 name: "no shard level parallelism",
87 cpuCount: 4,
88 indexConcurrency: 4,
89 want: 1,
90 },
91 {
92 name: "index option overrides server flag",
93 cpuCount: 2,
94 indexConcurrency: 1,
95 options: IndexOptions{
96 ShardConcurrency: 1,
97 },
98 want: 1,
99 },
100 {
101 name: "ignore invalid index option",
102 cpuCount: 8,
103 indexConcurrency: 2,
104 options: IndexOptions{
105 ShardConcurrency: -1,
106 },
107 want: 4,
108 },
109 }
110
111 for _, tt := range cases {
112 t.Run(tt.name, func(t *testing.T) {
113 s := &Server{
114 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
115 IndexDir: "/testdata/index",
116 CPUCount: tt.cpuCount,
117 IndexConcurrency: tt.indexConcurrency,
118 }
119
120 maxProcs := 16
121 got := s.parallelism(tt.options, maxProcs)
122 if tt.want != got {
123 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
124 }
125 })
126 }
127
128 t.Run("index option is limited by available CPU", func(t *testing.T) {
129 s := &Server{
130 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
131 IndexDir: "/testdata/index",
132 IndexConcurrency: 1,
133 }
134
135 got := s.indexArgs(IndexOptions{
136 ShardConcurrency: 2048, // Some number that's way too high
137 })
138
139 if got.Parallelism >= 2048 {
140 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
141 }
142 })
143}
144
145func TestListRepoIDs(t *testing.T) {
146 grpcClient := &mockGRPCClient{}
147
148 clientOptions := []SourcegraphClientOption{
149 WithBatchSize(0),
150 }
151
152 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
153 testHostname := "test-hostname"
154 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
155
156 listCalled := false
157 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) {
158 listCalled = true
159
160 gotRepoIDs := in.GetIndexedIds()
161 slices.Sort(gotRepoIDs)
162
163 wantRepoIDs := []int32{1, 3}
164 slices.Sort(wantRepoIDs)
165
166 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
167 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
168 }
169
170 hostname := in.GetHostname()
171 if diff := cmp.Diff(testHostname, hostname); diff != "" {
172 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
173 }
174
175 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
176 }
177
178 ctx := context.Background()
179 got, err := s.List(ctx, []uint32{1, 3})
180 if err != nil {
181 t.Fatal(err)
182 }
183
184 if !listCalled {
185 t.Fatalf("List was not called")
186 }
187
188 receivedRepoIDs := got.IDs
189 slices.Sort(receivedRepoIDs)
190
191 expectedRepoIDs := []uint32{1, 2, 3}
192 slices.Sort(expectedRepoIDs)
193
194 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
195 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
196 }
197}
198
199func TestMain(m *testing.M) {
200 flag.Parse()
201 level := sglog.LevelInfo
202 if !testing.Verbose() {
203 log.SetOutput(io.Discard)
204 debugLog.SetOutput(io.Discard)
205 infoLog.SetOutput(io.Discard)
206 errorLog.SetOutput(io.Discard)
207 level = sglog.LevelNone
208 }
209
210 logtest.InitWithLevel(m, level)
211 os.Exit(m.Run())
212}
213
214func TestCreateEmptyShard(t *testing.T) {
215 dir := t.TempDir()
216
217 args := &indexArgs{
218 IndexOptions: IndexOptions{
219 RepoID: 7,
220 Name: "empty-repo",
221 CloneURL: "code/host",
222 },
223 Incremental: true,
224 IndexDir: dir,
225 Parallelism: 1,
226 FileLimit: 1,
227 }
228
229 if err := createEmptyShard(args); err != nil {
230 t.Fatal(err)
231 }
232
233 bo := args.BuildOptions()
234 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
235
236 if got := bo.IncrementalSkipIndexing(); !got {
237 t.Fatalf("wanted %t, got %t", true, got)
238 }
239}
240
241func TestFormatListUint32(t *testing.T) {
242 cases := []struct {
243 in []uint32
244 want string
245 }{
246 {
247 in: []uint32{42, 8, 3},
248 want: "42, 8, ...",
249 },
250 {
251 in: []uint32{42, 8},
252 want: "42, 8",
253 },
254 {
255 in: []uint32{42},
256 want: "42",
257 },
258 {
259 in: []uint32{},
260 want: "",
261 },
262 }
263
264 for _, tt := range cases {
265 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
266 out := formatListUint32(tt.in, 2)
267 if out != tt.want {
268 t.Fatalf("want %s, got %s", tt.want, out)
269 }
270 })
271 }
272}
273
274func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
275 wd, err := os.Getwd()
276 if err != nil {
277 t.Fatalf("failed to get working directory: %v", err)
278 }
279
280 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
281 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
282
283 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
284
285 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
286 if err != nil {
287 t.Fatalf("failed to validate default service config: %v", err)
288 }
289
290 if !result.Valid() {
291 var errs strings.Builder
292 for _, err := range result.Errors() {
293 errs.WriteString(fmt.Sprintf("- %s\n", err))
294 }
295
296 t.Fatalf("default service config is invalid:\n%s", errs.String())
297 }
298}
299
300func TestGetBoolFromEnvironmentVariables(t *testing.T) {
301 testCases := []struct {
302 name string
303 envVarsToSet map[string]string
304
305 envVarNames []string
306 defaultBool bool
307
308 wantBool bool
309 wantErr bool
310 }{
311 {
312 name: "respect default value: true",
313
314 envVarsToSet: map[string]string{},
315
316 envVarNames: []string{"FOO", "BAR"},
317 defaultBool: true,
318
319 wantBool: true,
320 },
321 {
322 name: "respect default value: false",
323
324 envVarsToSet: map[string]string{},
325
326 envVarNames: []string{"FOO", "BAR"},
327 defaultBool: false,
328
329 wantBool: false,
330 },
331 {
332 name: "read from environment",
333
334 envVarsToSet: map[string]string{"FOO": "1"},
335
336 envVarNames: []string{"FOO"},
337 defaultBool: false,
338
339 wantBool: true,
340 },
341 {
342 name: "read from first env var that is set",
343
344 envVarsToSet: map[string]string{
345 "BAR": "false",
346 "BAZ": "true",
347 },
348
349 envVarNames: []string{"FOO", "BAR", "BAZ"},
350 defaultBool: true,
351
352 wantBool: false,
353 },
354
355 {
356 name: "should error for invalid input",
357
358 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
359
360 envVarNames: []string{"INVALID"},
361 defaultBool: false,
362
363 wantErr: true,
364 },
365 }
366
367 for _, tc := range testCases {
368 t.Run("", func(t *testing.T) {
369 // Prepare the environment by loading all the appropriate environment variables
370 for _, v := range tc.envVarNames {
371 _ = os.Unsetenv(v)
372 }
373
374 for k := range tc.envVarsToSet {
375 _ = os.Unsetenv(k)
376 }
377
378 for k, v := range tc.envVarsToSet {
379 t.Setenv(k, v)
380 }
381
382 // Run the test
383 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
384
385 // Examine the results
386 if tc.wantErr != (err != nil) {
387 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
388 }
389
390 if got != tc.wantBool {
391 t.Errorf("got %v, want %v", got, tc.wantBool)
392 }
393 })
394 }
395}
396
397func TestAddDefaultPort(t *testing.T) {
398 tests := []struct {
399 name string
400 input string
401 want string
402 }{
403 {
404 name: "http no port",
405 input: "http://example.com",
406 want: "http://example.com:80",
407 },
408 {
409 name: "http custom port",
410 input: "http://example.com:90",
411 want: "http://example.com:90",
412 },
413 {
414 name: "https no port",
415 input: "https://example.com",
416 want: "https://example.com:443",
417 },
418 {
419 name: "https custom port",
420 input: "https://example.com:444",
421 want: "https://example.com:444",
422 },
423 {
424 name: "non-http scheme",
425 input: "ftp://example.com",
426 want: "ftp://example.com",
427 },
428 {
429 name: "empty string",
430 input: "",
431 want: "",
432 },
433 {
434 name: "local file path",
435 input: "/etc/hosts",
436 want: "/etc/hosts",
437 },
438 }
439
440 for _, test := range tests {
441 t.Run(test.name, func(t *testing.T) {
442 input, err := url.Parse(test.input)
443 if err != nil {
444 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
445 }
446
447 got := addDefaultPort(input)
448 if diff := cmp.Diff(test.want, got.String()); diff != "" {
449 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
450 }
451 })
452 }
453}
454
455func TestIndexGRPC(t *testing.T) {
456 indexDir := t.TempDir()
457
458 // Minimal server setup
459 s := &Server{
460 logger: logtest.NoOp(t),
461 IndexDir: indexDir,
462 rootURL: &url.URL{Scheme: "http", Host: "example.com"},
463 indexSemaphore: make(chan struct{}, 1),
464 timeout: time.Hour, // no timeout
465 }
466
467 branches := []*configv1.ZoektRepositoryBranch{
468 {
469 Name: "HEAD",
470 Version: "abc123",
471 },
472 }
473
474 req := &indexserverv1.IndexRequest{
475 Options: &configv1.ZoektIndexOptions{
476 RepoId: 42,
477 Name: "repo",
478 TenantId: 1,
479 Branches: branches,
480 },
481 }
482
483 resp, err := s.indexGRPC(context.Background(), req, mockIndexFunc(t))
484 require.NoError(t, err)
485 require.Equal(t, &indexserverv1.IndexResponse{
486 RepoId: 42,
487 Branches: branches,
488 IndexTimeUnix: resp.IndexTimeUnix, // Hack: this changes every time so we don't check it
489 }, resp)
490
491 require.NotZero(t, resp.IndexTimeUnix)
492}
493
494func TestIndexGRPC_Timeout(t *testing.T) {
495 indexDir := t.TempDir()
496
497 s := &Server{
498 logger: logtest.NoOp(t),
499 IndexDir: indexDir,
500 IndexConcurrency: 0, // impossible to acquire index slot
501 timeout: time.Millisecond,
502 }
503
504 req := &indexserverv1.IndexRequest{
505 Options: &configv1.ZoektIndexOptions{
506 RepoId: 42,
507 Name: "repo",
508 },
509 }
510
511 // use context.Background() to make sure we don't return because of context cancellation
512 _, err := s.indexGRPC(context.Background(), req, mockIndexFunc(t))
513 require.Error(t, err)
514 require.Equal(t, codes.DeadlineExceeded, status.Code(err))
515}
516
517func TestDelete(t *testing.T) {
518 indexDir := t.TempDir()
519 trashDir := filepath.Join(indexDir, ".trash")
520 if err := os.MkdirAll(trashDir, 0o755); err != nil {
521 t.Fatal(err)
522 }
523
524 // Create a simple shard
525 createShard(t, indexDir)
526
527 // Verify the shard exists in index dir
528 shards := getShards(indexDir)
529 if len(shards) != 1 {
530 t.Fatalf("expected 1 shard, got %d", len(shards))
531 }
532
533 // Create server and call Delete
534 s := &Server{
535 logger: logtest.NoOp(t),
536 IndexDir: indexDir,
537 rootURL: &url.URL{Scheme: "http", Host: "example.com"},
538 indexSemaphore: make(chan struct{}, 1),
539 timeout: time.Hour, // no timeout
540 }
541
542 req := &indexserverv1.DeleteRequest{
543 RepoIds: []uint32{42}, // matches the repo ID in createShard
544 }
545
546 // Test case: context is canceled
547 cancledCtx, cancel := context.WithCancel(context.Background())
548 cancel()
549 _, err := s.Delete(cancledCtx, req)
550 require.Error(t, err)
551
552 shards = getShards(indexDir)
553 require.Len(t, shards, 1)
554
555 // Test case: context is not canceled
556 _, err = s.Delete(context.Background(), req)
557 require.NoError(t, err)
558
559 // Verify shard was moved to trash
560 trashShards := getShards(trashDir)
561 require.Len(t, trashShards, 1)
562
563 // Verify shard is no longer in index dir
564 shards = getShards(indexDir)
565 require.Len(t, shards, 0)
566}
567
568func mockIndexFunc(t *testing.T) func(ctx context.Context, args *indexArgs) (indexState, error) {
569 return func(ctx context.Context, args *indexArgs) (indexState, error) {
570 createShard(t, args.IndexDir)
571 return indexStateSuccess, nil
572 }
573}
574
575func createShard(t *testing.T, dir string) {
576 opts := index.Options{
577 IndexDir: dir,
578 RepositoryDescription: zoekt.Repository{
579 ID: 42,
580 Name: "repo",
581 Branches: []zoekt.RepositoryBranch{
582 {
583 Name: "HEAD",
584 Version: "abc123",
585 },
586 },
587 },
588 }
589
590 b, err := index.NewBuilder(opts)
591 require.NoError(t, err)
592 require.NoError(t, b.AddFile("test.txt", []byte("hello")))
593 require.NoError(t, b.Finish())
594}
595
596func TestRecoverFromTrash(t *testing.T) {
597 dir := t.TempDir()
598 trashDir := filepath.Join(dir, ".trash")
599 require.NoError(t, os.MkdirAll(trashDir, 0o755))
600
601 // Create a simple shard in trash
602 createTestShard(t, "repo1", 1, filepath.Join(trashDir, "repo1.zoekt"))
603
604 // Create a compound shard with two repos, one of them tombstoned
605 cs := createCompoundShard(t, dir, []uint32{2, 3})
606 require.NoError(t, index.SetTombstone(cs, 2))
607
608 s := &Server{
609 IndexDir: dir,
610 }
611
612 // Test recovering from trash
613 recovered := s.recoverFromTrash(1)
614 require.True(t, recovered, "should have recovered repo1 from trash")
615
616 // Verify shard was moved from trash to index
617 indexShards := getShards(dir)
618 trashShards := getShards(trashDir)
619
620 require.Contains(t, indexShards, uint32(1), "repo1 should be in index")
621 require.NotContains(t, trashShards, uint32(1), "repo1 should not be in trash")
622
623 // Test unsetting tombstone
624 recovered = s.recoverFromTrash(2)
625 require.True(t, recovered, "should have recovered repo2 from tombstone")
626
627 // Verify tombstone was unset
628 repos, _, err := index.ReadMetadataPath(cs)
629 require.NoError(t, err)
630
631 for _, repo := range repos {
632 if repo.ID == 2 {
633 require.False(t, repo.Tombstone, "repo2 should not be tombstoned")
634 }
635 }
636
637 // Test non-existent repo
638 recovered = s.recoverFromTrash(99)
639 require.False(t, recovered, "should not have recovered non-existent repo")
640}