fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "slices" 13 "strings" 14 "testing" 15 "time" 16 17 "github.com/google/go-cmp/cmp" 18 sglog "github.com/sourcegraph/log" 19 "github.com/sourcegraph/log/logtest" 20 "github.com/sourcegraph/zoekt" 21 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1" 22 indexserverv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/zoekt/indexserver/v1" 23 "github.com/sourcegraph/zoekt/index" 24 "github.com/sourcegraph/zoekt/internal/tenant" 25 "github.com/stretchr/testify/require" 26 "github.com/xeipuuv/gojsonschema" 27 "google.golang.org/grpc" 28 "google.golang.org/grpc/codes" 29 "google.golang.org/grpc/status" 30) 31 32func TestServer_defaultArgs(t *testing.T) { 33 root, err := url.Parse("http://api.test") 34 if err != nil { 35 t.Fatal(err) 36 } 37 38 s := &Server{ 39 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 40 IndexDir: "/testdata/index", 41 CPUCount: 6, 42 IndexConcurrency: 1, 43 } 44 want := &indexArgs{ 45 IndexOptions: IndexOptions{ 46 Name: "testName", 47 }, 48 IndexDir: "/testdata/index", 49 Parallelism: 6, 50 Incremental: true, 51 FileLimit: 1 << 20, 52 } 53 got := s.indexArgs(IndexOptions{Name: "testName"}) 54 if !cmp.Equal(got, want) { 55 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 56 } 57} 58 59func TestIndexNoTenant(t *testing.T) { 60 s := &Server{} 61 _, err := s.index(context.Background(), &indexArgs{}) 62 require.ErrorIs(t, err, tenant.ErrMissingTenant) 63} 64 65func TestServer_parallelism(t *testing.T) { 66 root, err := url.Parse("http://api.test") 67 if err != nil { 68 t.Fatal(err) 69 } 70 71 cases := []struct { 72 name string 73 cpuCount int 74 indexConcurrency int 75 options IndexOptions 76 want int 77 }{ 78 { 79 name: "CPU count divides evenly", 80 cpuCount: 16, 81 indexConcurrency: 8, 82 want: 2, 83 }, 84 { 85 name: "no shard level parallelism", 86 cpuCount: 4, 87 indexConcurrency: 4, 88 want: 1, 89 }, 90 { 91 name: "index option overrides server flag", 92 cpuCount: 2, 93 indexConcurrency: 1, 94 options: IndexOptions{ 95 ShardConcurrency: 1, 96 }, 97 want: 1, 98 }, 99 { 100 name: "ignore invalid index option", 101 cpuCount: 8, 102 indexConcurrency: 2, 103 options: IndexOptions{ 104 ShardConcurrency: -1, 105 }, 106 want: 4, 107 }, 108 } 109 110 for _, tt := range cases { 111 t.Run(tt.name, func(t *testing.T) { 112 s := &Server{ 113 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 114 IndexDir: "/testdata/index", 115 CPUCount: tt.cpuCount, 116 IndexConcurrency: tt.indexConcurrency, 117 } 118 119 maxProcs := 16 120 got := s.parallelism(tt.options, maxProcs) 121 if tt.want != got { 122 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 123 } 124 }) 125 } 126 127 t.Run("index option is limited by available CPU", func(t *testing.T) { 128 s := &Server{ 129 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 130 IndexDir: "/testdata/index", 131 IndexConcurrency: 1, 132 } 133 134 got := s.indexArgs(IndexOptions{ 135 ShardConcurrency: 2048, // Some number that's way too high 136 }) 137 138 if got.Parallelism >= 2048 { 139 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 140 } 141 }) 142} 143 144func TestListRepoIDs(t *testing.T) { 145 grpcClient := &mockGRPCClient{} 146 147 clientOptions := []SourcegraphClientOption{ 148 WithBatchSize(0), 149 } 150 151 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 152 testHostname := "test-hostname" 153 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 154 155 listCalled := false 156 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) { 157 listCalled = true 158 159 gotRepoIDs := in.GetIndexedIds() 160 slices.Sort(gotRepoIDs) 161 162 wantRepoIDs := []int32{1, 3} 163 slices.Sort(wantRepoIDs) 164 165 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 166 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 167 } 168 169 hostname := in.GetHostname() 170 if diff := cmp.Diff(testHostname, hostname); diff != "" { 171 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 172 } 173 174 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 175 } 176 177 ctx := context.Background() 178 got, err := s.List(ctx, []uint32{1, 3}) 179 if err != nil { 180 t.Fatal(err) 181 } 182 183 if !listCalled { 184 t.Fatalf("List was not called") 185 } 186 187 receivedRepoIDs := got.IDs 188 slices.Sort(receivedRepoIDs) 189 190 expectedRepoIDs := []uint32{1, 2, 3} 191 slices.Sort(expectedRepoIDs) 192 193 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 194 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 195 } 196} 197 198func TestMain(m *testing.M) { 199 flag.Parse() 200 level := sglog.LevelInfo 201 if !testing.Verbose() { 202 log.SetOutput(io.Discard) 203 debugLog.SetOutput(io.Discard) 204 infoLog.SetOutput(io.Discard) 205 errorLog.SetOutput(io.Discard) 206 level = sglog.LevelNone 207 } 208 209 logtest.InitWithLevel(m, level) 210 os.Exit(m.Run()) 211} 212 213func TestCreateEmptyShard(t *testing.T) { 214 dir := t.TempDir() 215 216 args := &indexArgs{ 217 IndexOptions: IndexOptions{ 218 RepoID: 7, 219 Name: "empty-repo", 220 CloneURL: "code/host", 221 }, 222 Incremental: true, 223 IndexDir: dir, 224 Parallelism: 1, 225 FileLimit: 1, 226 } 227 228 if err := createEmptyShard(args); err != nil { 229 t.Fatal(err) 230 } 231 232 bo := args.BuildOptions() 233 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 234 235 if got := bo.IncrementalSkipIndexing(); !got { 236 t.Fatalf("wanted %t, got %t", true, got) 237 } 238} 239 240func TestFormatListUint32(t *testing.T) { 241 cases := []struct { 242 in []uint32 243 want string 244 }{ 245 { 246 in: []uint32{42, 8, 3}, 247 want: "42, 8, ...", 248 }, 249 { 250 in: []uint32{42, 8}, 251 want: "42, 8", 252 }, 253 { 254 in: []uint32{42}, 255 want: "42", 256 }, 257 { 258 in: []uint32{}, 259 want: "", 260 }, 261 } 262 263 for _, tt := range cases { 264 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 265 out := formatListUint32(tt.in, 2) 266 if out != tt.want { 267 t.Fatalf("want %s, got %s", tt.want, out) 268 } 269 }) 270 } 271} 272 273func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 274 wd, err := os.Getwd() 275 if err != nil { 276 t.Fatalf("failed to get working directory: %v", err) 277 } 278 279 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 280 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 281 282 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 283 284 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 285 if err != nil { 286 t.Fatalf("failed to validate default service config: %v", err) 287 } 288 289 if !result.Valid() { 290 var errs strings.Builder 291 for _, err := range result.Errors() { 292 errs.WriteString(fmt.Sprintf("- %s\n", err)) 293 } 294 295 t.Fatalf("default service config is invalid:\n%s", errs.String()) 296 } 297} 298 299func TestGetBoolFromEnvironmentVariables(t *testing.T) { 300 testCases := []struct { 301 name string 302 envVarsToSet map[string]string 303 304 envVarNames []string 305 defaultBool bool 306 307 wantBool bool 308 wantErr bool 309 }{ 310 { 311 name: "respect default value: true", 312 313 envVarsToSet: map[string]string{}, 314 315 envVarNames: []string{"FOO", "BAR"}, 316 defaultBool: true, 317 318 wantBool: true, 319 }, 320 { 321 name: "respect default value: false", 322 323 envVarsToSet: map[string]string{}, 324 325 envVarNames: []string{"FOO", "BAR"}, 326 defaultBool: false, 327 328 wantBool: false, 329 }, 330 { 331 name: "read from environment", 332 333 envVarsToSet: map[string]string{"FOO": "1"}, 334 335 envVarNames: []string{"FOO"}, 336 defaultBool: false, 337 338 wantBool: true, 339 }, 340 { 341 name: "read from first env var that is set", 342 343 envVarsToSet: map[string]string{ 344 "BAR": "false", 345 "BAZ": "true", 346 }, 347 348 envVarNames: []string{"FOO", "BAR", "BAZ"}, 349 defaultBool: true, 350 351 wantBool: false, 352 }, 353 354 { 355 name: "should error for invalid input", 356 357 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 358 359 envVarNames: []string{"INVALID"}, 360 defaultBool: false, 361 362 wantErr: true, 363 }, 364 } 365 366 for _, tc := range testCases { 367 t.Run("", func(t *testing.T) { 368 // Prepare the environment by loading all the appropriate environment variables 369 for _, v := range tc.envVarNames { 370 _ = os.Unsetenv(v) 371 } 372 373 for k := range tc.envVarsToSet { 374 _ = os.Unsetenv(k) 375 } 376 377 for k, v := range tc.envVarsToSet { 378 t.Setenv(k, v) 379 } 380 381 // Run the test 382 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 383 384 // Examine the results 385 if tc.wantErr != (err != nil) { 386 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 387 } 388 389 if got != tc.wantBool { 390 t.Errorf("got %v, want %v", got, tc.wantBool) 391 } 392 }) 393 } 394} 395 396func TestAddDefaultPort(t *testing.T) { 397 tests := []struct { 398 name string 399 input string 400 want string 401 }{ 402 { 403 name: "http no port", 404 input: "http://example.com", 405 want: "http://example.com:80", 406 }, 407 { 408 name: "http custom port", 409 input: "http://example.com:90", 410 want: "http://example.com:90", 411 }, 412 { 413 name: "https no port", 414 input: "https://example.com", 415 want: "https://example.com:443", 416 }, 417 { 418 name: "https custom port", 419 input: "https://example.com:444", 420 want: "https://example.com:444", 421 }, 422 { 423 name: "non-http scheme", 424 input: "ftp://example.com", 425 want: "ftp://example.com", 426 }, 427 { 428 name: "empty string", 429 input: "", 430 want: "", 431 }, 432 { 433 name: "local file path", 434 input: "/etc/hosts", 435 want: "/etc/hosts", 436 }, 437 } 438 439 for _, test := range tests { 440 t.Run(test.name, func(t *testing.T) { 441 input, err := url.Parse(test.input) 442 if err != nil { 443 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 444 } 445 446 got := addDefaultPort(input) 447 if diff := cmp.Diff(test.want, got.String()); diff != "" { 448 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 449 } 450 }) 451 } 452} 453 454func TestIndexGRPC(t *testing.T) { 455 indexDir := t.TempDir() 456 457 // Minimal server setup 458 s := &Server{ 459 logger: logtest.NoOp(t), 460 IndexDir: indexDir, 461 rootURL: &url.URL{Scheme: "http", Host: "example.com"}, 462 indexSemaphore: make(chan struct{}, 1), 463 timeout: time.Hour, // no timeout 464 } 465 466 branches := []*configv1.ZoektRepositoryBranch{ 467 { 468 Name: "HEAD", 469 Version: "abc123", 470 }, 471 } 472 473 req := &indexserverv1.IndexRequest{ 474 Options: &configv1.ZoektIndexOptions{ 475 RepoId: 42, 476 Name: "repo", 477 TenantId: 1, 478 Branches: branches, 479 }, 480 } 481 482 resp, err := s.indexGRPC(context.Background(), req, mockIndexFunc(t)) 483 require.NoError(t, err) 484 require.Equal(t, &indexserverv1.IndexResponse{ 485 RepoId: 42, 486 Branches: branches, 487 IndexTimeUnix: resp.IndexTimeUnix, // Hack: this changes every time so we don't check it 488 }, resp) 489 490 require.NotZero(t, resp.IndexTimeUnix) 491} 492 493func TestIndexGRPC_Timeout(t *testing.T) { 494 indexDir := t.TempDir() 495 496 s := &Server{ 497 logger: logtest.NoOp(t), 498 IndexDir: indexDir, 499 IndexConcurrency: 0, // impossible to acquire index slot 500 timeout: time.Millisecond, 501 } 502 503 req := &indexserverv1.IndexRequest{ 504 Options: &configv1.ZoektIndexOptions{ 505 RepoId: 42, 506 Name: "repo", 507 }, 508 } 509 510 // use context.Background() to make sure we don't return because of context cancellation 511 _, err := s.indexGRPC(context.Background(), req, mockIndexFunc(t)) 512 require.Error(t, err) 513 require.Equal(t, codes.DeadlineExceeded, status.Code(err)) 514} 515 516func TestDelete(t *testing.T) { 517 indexDir := t.TempDir() 518 trashDir := filepath.Join(indexDir, ".trash") 519 if err := os.MkdirAll(trashDir, 0o755); err != nil { 520 t.Fatal(err) 521 } 522 523 // Create a simple shard 524 createShard(t, indexDir) 525 526 // Verify the shard exists in index dir 527 shards := getShards(indexDir) 528 if len(shards) != 1 { 529 t.Fatalf("expected 1 shard, got %d", len(shards)) 530 } 531 532 // Create server and call Delete 533 s := &Server{ 534 logger: logtest.NoOp(t), 535 IndexDir: indexDir, 536 rootURL: &url.URL{Scheme: "http", Host: "example.com"}, 537 indexSemaphore: make(chan struct{}, 1), 538 timeout: time.Hour, // no timeout 539 } 540 541 req := &indexserverv1.DeleteRequest{ 542 RepoIds: []uint32{42}, // matches the repo ID in createShard 543 } 544 545 // Test case: context is canceled 546 cancledCtx, cancel := context.WithCancel(context.Background()) 547 cancel() 548 _, err := s.Delete(cancledCtx, req) 549 require.Error(t, err) 550 551 shards = getShards(indexDir) 552 require.Len(t, shards, 1) 553 554 // Test case: context is not canceled 555 _, err = s.Delete(context.Background(), req) 556 require.NoError(t, err) 557 558 // Verify shard was moved to trash 559 trashShards := getShards(trashDir) 560 require.Len(t, trashShards, 1) 561 562 // Verify shard is no longer in index dir 563 shards = getShards(indexDir) 564 require.Len(t, shards, 0) 565} 566 567func mockIndexFunc(t *testing.T) func(ctx context.Context, args *indexArgs) (indexState, error) { 568 return func(ctx context.Context, args *indexArgs) (indexState, error) { 569 createShard(t, args.IndexDir) 570 return indexStateSuccess, nil 571 } 572} 573 574func createShard(t *testing.T, dir string) { 575 opts := index.Options{ 576 IndexDir: dir, 577 RepositoryDescription: zoekt.Repository{ 578 ID: 42, 579 Name: "repo", 580 Branches: []zoekt.RepositoryBranch{ 581 { 582 Name: "HEAD", 583 Version: "abc123", 584 }, 585 }, 586 }, 587 } 588 589 b, err := index.NewBuilder(opts) 590 require.NoError(t, err) 591 require.NoError(t, b.AddFile("test.txt", []byte("hello"))) 592 require.NoError(t, b.Finish()) 593} 594 595func TestRecoverFromTrash(t *testing.T) { 596 dir := t.TempDir() 597 trashDir := filepath.Join(dir, ".trash") 598 require.NoError(t, os.MkdirAll(trashDir, 0o755)) 599 600 // Create a simple shard in trash 601 createTestShard(t, "repo1", 1, filepath.Join(trashDir, "repo1.zoekt")) 602 603 // Create a compound shard with two repos, one of them tombstoned 604 cs := createCompoundShard(t, dir, []uint32{2, 3}) 605 require.NoError(t, index.SetTombstone(cs, 2)) 606 607 s := &Server{ 608 IndexDir: dir, 609 } 610 611 // Test recovering from trash 612 recovered := s.recoverFromTrash(1) 613 require.True(t, recovered, "should have recovered repo1 from trash") 614 615 // Verify shard was moved from trash to index 616 indexShards := getShards(dir) 617 trashShards := getShards(trashDir) 618 619 require.Contains(t, indexShards, uint32(1), "repo1 should be in index") 620 require.NotContains(t, trashShards, uint32(1), "repo1 should not be in trash") 621 622 // Test unsetting tombstone 623 recovered = s.recoverFromTrash(2) 624 require.True(t, recovered, "should have recovered repo2 from tombstone") 625 626 // Verify tombstone was unset 627 repos, _, err := index.ReadMetadataPath(cs) 628 require.NoError(t, err) 629 630 for _, repo := range repos { 631 if repo.ID == 2 { 632 require.False(t, repo.Tombstone, "repo2 should not be tombstoned") 633 } 634 } 635 636 // Test non-existent repo 637 recovered = s.recoverFromTrash(99) 638 require.False(t, recovered, "should not have recovered non-existent repo") 639}