fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "slices" 13 "strings" 14 "testing" 15 "time" 16 17 "github.com/google/go-cmp/cmp" 18 sglog "github.com/sourcegraph/log" 19 "github.com/sourcegraph/log/logtest" 20 "github.com/stretchr/testify/require" 21 "github.com/xeipuuv/gojsonschema" 22 "google.golang.org/grpc" 23 "google.golang.org/grpc/codes" 24 "google.golang.org/grpc/status" 25 26 "github.com/sourcegraph/zoekt" 27 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1" 28 indexserverv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/zoekt/indexserver/v1" 29 "github.com/sourcegraph/zoekt/index" 30 "github.com/sourcegraph/zoekt/internal/tenant" 31) 32 33func TestServer_defaultArgs(t *testing.T) { 34 root, err := url.Parse("http://api.test") 35 if err != nil { 36 t.Fatal(err) 37 } 38 39 s := &Server{ 40 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 41 IndexDir: "/testdata/index", 42 CPUCount: 6, 43 IndexConcurrency: 1, 44 } 45 want := &indexArgs{ 46 IndexOptions: IndexOptions{ 47 Name: "testName", 48 }, 49 IndexDir: "/testdata/index", 50 Parallelism: 6, 51 Incremental: true, 52 FileLimit: 1 << 20, 53 } 54 got := s.indexArgs(IndexOptions{Name: "testName"}) 55 if !cmp.Equal(got, want) { 56 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 57 } 58} 59 60func TestIndexNoTenant(t *testing.T) { 61 s := &Server{} 62 _, err := s.index(context.Background(), &indexArgs{}) 63 require.ErrorIs(t, err, tenant.ErrMissingTenant) 64} 65 66func TestServer_parallelism(t *testing.T) { 67 root, err := url.Parse("http://api.test") 68 if err != nil { 69 t.Fatal(err) 70 } 71 72 cases := []struct { 73 name string 74 cpuCount int 75 indexConcurrency int 76 options IndexOptions 77 want int 78 }{ 79 { 80 name: "CPU count divides evenly", 81 cpuCount: 16, 82 indexConcurrency: 8, 83 want: 2, 84 }, 85 { 86 name: "no shard level parallelism", 87 cpuCount: 4, 88 indexConcurrency: 4, 89 want: 1, 90 }, 91 { 92 name: "index option overrides server flag", 93 cpuCount: 2, 94 indexConcurrency: 1, 95 options: IndexOptions{ 96 ShardConcurrency: 1, 97 }, 98 want: 1, 99 }, 100 { 101 name: "ignore invalid index option", 102 cpuCount: 8, 103 indexConcurrency: 2, 104 options: IndexOptions{ 105 ShardConcurrency: -1, 106 }, 107 want: 4, 108 }, 109 } 110 111 for _, tt := range cases { 112 t.Run(tt.name, func(t *testing.T) { 113 s := &Server{ 114 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 115 IndexDir: "/testdata/index", 116 CPUCount: tt.cpuCount, 117 IndexConcurrency: tt.indexConcurrency, 118 } 119 120 maxProcs := 16 121 got := s.parallelism(tt.options, maxProcs) 122 if tt.want != got { 123 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 124 } 125 }) 126 } 127 128 t.Run("index option is limited by available CPU", func(t *testing.T) { 129 s := &Server{ 130 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 131 IndexDir: "/testdata/index", 132 IndexConcurrency: 1, 133 } 134 135 got := s.indexArgs(IndexOptions{ 136 ShardConcurrency: 2048, // Some number that's way too high 137 }) 138 139 if got.Parallelism >= 2048 { 140 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 141 } 142 }) 143} 144 145func TestListRepoIDs(t *testing.T) { 146 grpcClient := &mockGRPCClient{} 147 148 clientOptions := []SourcegraphClientOption{ 149 WithBatchSize(0), 150 } 151 152 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 153 testHostname := "test-hostname" 154 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 155 156 listCalled := false 157 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) { 158 listCalled = true 159 160 gotRepoIDs := in.GetIndexedIds() 161 slices.Sort(gotRepoIDs) 162 163 wantRepoIDs := []int32{1, 3} 164 slices.Sort(wantRepoIDs) 165 166 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 167 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 168 } 169 170 hostname := in.GetHostname() 171 if diff := cmp.Diff(testHostname, hostname); diff != "" { 172 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 173 } 174 175 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 176 } 177 178 ctx := context.Background() 179 got, err := s.List(ctx, []uint32{1, 3}) 180 if err != nil { 181 t.Fatal(err) 182 } 183 184 if !listCalled { 185 t.Fatalf("List was not called") 186 } 187 188 receivedRepoIDs := got.IDs 189 slices.Sort(receivedRepoIDs) 190 191 expectedRepoIDs := []uint32{1, 2, 3} 192 slices.Sort(expectedRepoIDs) 193 194 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 195 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 196 } 197} 198 199func TestMain(m *testing.M) { 200 flag.Parse() 201 level := sglog.LevelInfo 202 if !testing.Verbose() { 203 log.SetOutput(io.Discard) 204 debugLog.SetOutput(io.Discard) 205 infoLog.SetOutput(io.Discard) 206 errorLog.SetOutput(io.Discard) 207 level = sglog.LevelNone 208 } 209 210 logtest.InitWithLevel(m, level) 211 os.Exit(m.Run()) 212} 213 214func TestCreateEmptyShard(t *testing.T) { 215 dir := t.TempDir() 216 217 args := &indexArgs{ 218 IndexOptions: IndexOptions{ 219 RepoID: 7, 220 Name: "empty-repo", 221 CloneURL: "code/host", 222 }, 223 Incremental: true, 224 IndexDir: dir, 225 Parallelism: 1, 226 FileLimit: 1, 227 } 228 229 if err := createEmptyShard(args); err != nil { 230 t.Fatal(err) 231 } 232 233 bo := args.BuildOptions() 234 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 235 236 if got := bo.IncrementalSkipIndexing(); !got { 237 t.Fatalf("wanted %t, got %t", true, got) 238 } 239} 240 241func TestFormatListUint32(t *testing.T) { 242 cases := []struct { 243 in []uint32 244 want string 245 }{ 246 { 247 in: []uint32{42, 8, 3}, 248 want: "42, 8, ...", 249 }, 250 { 251 in: []uint32{42, 8}, 252 want: "42, 8", 253 }, 254 { 255 in: []uint32{42}, 256 want: "42", 257 }, 258 { 259 in: []uint32{}, 260 want: "", 261 }, 262 } 263 264 for _, tt := range cases { 265 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 266 out := formatListUint32(tt.in, 2) 267 if out != tt.want { 268 t.Fatalf("want %s, got %s", tt.want, out) 269 } 270 }) 271 } 272} 273 274func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 275 wd, err := os.Getwd() 276 if err != nil { 277 t.Fatalf("failed to get working directory: %v", err) 278 } 279 280 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 281 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 282 283 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 284 285 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 286 if err != nil { 287 t.Fatalf("failed to validate default service config: %v", err) 288 } 289 290 if !result.Valid() { 291 var errs strings.Builder 292 for _, err := range result.Errors() { 293 errs.WriteString(fmt.Sprintf("- %s\n", err)) 294 } 295 296 t.Fatalf("default service config is invalid:\n%s", errs.String()) 297 } 298} 299 300func TestGetBoolFromEnvironmentVariables(t *testing.T) { 301 testCases := []struct { 302 name string 303 envVarsToSet map[string]string 304 305 envVarNames []string 306 defaultBool bool 307 308 wantBool bool 309 wantErr bool 310 }{ 311 { 312 name: "respect default value: true", 313 314 envVarsToSet: map[string]string{}, 315 316 envVarNames: []string{"FOO", "BAR"}, 317 defaultBool: true, 318 319 wantBool: true, 320 }, 321 { 322 name: "respect default value: false", 323 324 envVarsToSet: map[string]string{}, 325 326 envVarNames: []string{"FOO", "BAR"}, 327 defaultBool: false, 328 329 wantBool: false, 330 }, 331 { 332 name: "read from environment", 333 334 envVarsToSet: map[string]string{"FOO": "1"}, 335 336 envVarNames: []string{"FOO"}, 337 defaultBool: false, 338 339 wantBool: true, 340 }, 341 { 342 name: "read from first env var that is set", 343 344 envVarsToSet: map[string]string{ 345 "BAR": "false", 346 "BAZ": "true", 347 }, 348 349 envVarNames: []string{"FOO", "BAR", "BAZ"}, 350 defaultBool: true, 351 352 wantBool: false, 353 }, 354 355 { 356 name: "should error for invalid input", 357 358 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 359 360 envVarNames: []string{"INVALID"}, 361 defaultBool: false, 362 363 wantErr: true, 364 }, 365 } 366 367 for _, tc := range testCases { 368 t.Run("", func(t *testing.T) { 369 // Prepare the environment by loading all the appropriate environment variables 370 for _, v := range tc.envVarNames { 371 _ = os.Unsetenv(v) 372 } 373 374 for k := range tc.envVarsToSet { 375 _ = os.Unsetenv(k) 376 } 377 378 for k, v := range tc.envVarsToSet { 379 t.Setenv(k, v) 380 } 381 382 // Run the test 383 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 384 385 // Examine the results 386 if tc.wantErr != (err != nil) { 387 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 388 } 389 390 if got != tc.wantBool { 391 t.Errorf("got %v, want %v", got, tc.wantBool) 392 } 393 }) 394 } 395} 396 397func TestAddDefaultPort(t *testing.T) { 398 tests := []struct { 399 name string 400 input string 401 want string 402 }{ 403 { 404 name: "http no port", 405 input: "http://example.com", 406 want: "http://example.com:80", 407 }, 408 { 409 name: "http custom port", 410 input: "http://example.com:90", 411 want: "http://example.com:90", 412 }, 413 { 414 name: "https no port", 415 input: "https://example.com", 416 want: "https://example.com:443", 417 }, 418 { 419 name: "https custom port", 420 input: "https://example.com:444", 421 want: "https://example.com:444", 422 }, 423 { 424 name: "non-http scheme", 425 input: "ftp://example.com", 426 want: "ftp://example.com", 427 }, 428 { 429 name: "empty string", 430 input: "", 431 want: "", 432 }, 433 { 434 name: "local file path", 435 input: "/etc/hosts", 436 want: "/etc/hosts", 437 }, 438 } 439 440 for _, test := range tests { 441 t.Run(test.name, func(t *testing.T) { 442 input, err := url.Parse(test.input) 443 if err != nil { 444 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 445 } 446 447 got := addDefaultPort(input) 448 if diff := cmp.Diff(test.want, got.String()); diff != "" { 449 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 450 } 451 }) 452 } 453} 454 455func TestIndexGRPC(t *testing.T) { 456 indexDir := t.TempDir() 457 458 // Minimal server setup 459 s := &Server{ 460 logger: logtest.NoOp(t), 461 IndexDir: indexDir, 462 rootURL: &url.URL{Scheme: "http", Host: "example.com"}, 463 indexSemaphore: make(chan struct{}, 1), 464 timeout: time.Hour, // no timeout 465 } 466 467 branches := []*configv1.ZoektRepositoryBranch{ 468 { 469 Name: "HEAD", 470 Version: "abc123", 471 }, 472 } 473 474 req := &indexserverv1.IndexRequest{ 475 Options: &configv1.ZoektIndexOptions{ 476 RepoId: 42, 477 Name: "repo", 478 TenantId: 1, 479 Branches: branches, 480 }, 481 } 482 483 resp, err := s.indexGRPC(context.Background(), req, mockIndexFunc(t)) 484 require.NoError(t, err) 485 require.Equal(t, &indexserverv1.IndexResponse{ 486 RepoId: 42, 487 Branches: branches, 488 IndexTimeUnix: resp.IndexTimeUnix, // Hack: this changes every time so we don't check it 489 }, resp) 490 491 require.NotZero(t, resp.IndexTimeUnix) 492} 493 494func TestIndexGRPC_Timeout(t *testing.T) { 495 indexDir := t.TempDir() 496 497 s := &Server{ 498 logger: logtest.NoOp(t), 499 IndexDir: indexDir, 500 IndexConcurrency: 0, // impossible to acquire index slot 501 timeout: time.Millisecond, 502 } 503 504 req := &indexserverv1.IndexRequest{ 505 Options: &configv1.ZoektIndexOptions{ 506 RepoId: 42, 507 Name: "repo", 508 }, 509 } 510 511 // use context.Background() to make sure we don't return because of context cancellation 512 _, err := s.indexGRPC(context.Background(), req, mockIndexFunc(t)) 513 require.Error(t, err) 514 require.Equal(t, codes.DeadlineExceeded, status.Code(err)) 515} 516 517func TestDelete(t *testing.T) { 518 indexDir := t.TempDir() 519 trashDir := filepath.Join(indexDir, ".trash") 520 if err := os.MkdirAll(trashDir, 0o755); err != nil { 521 t.Fatal(err) 522 } 523 524 // Create a simple shard 525 createShard(t, indexDir) 526 527 // Verify the shard exists in index dir 528 shards := getShards(indexDir) 529 if len(shards) != 1 { 530 t.Fatalf("expected 1 shard, got %d", len(shards)) 531 } 532 533 // Create server and call Delete 534 s := &Server{ 535 logger: logtest.NoOp(t), 536 IndexDir: indexDir, 537 rootURL: &url.URL{Scheme: "http", Host: "example.com"}, 538 indexSemaphore: make(chan struct{}, 1), 539 timeout: time.Hour, // no timeout 540 } 541 542 req := &indexserverv1.DeleteRequest{ 543 RepoIds: []uint32{42}, // matches the repo ID in createShard 544 } 545 546 // Test case: context is canceled 547 cancledCtx, cancel := context.WithCancel(context.Background()) 548 cancel() 549 _, err := s.Delete(cancledCtx, req) 550 require.Error(t, err) 551 552 shards = getShards(indexDir) 553 require.Len(t, shards, 1) 554 555 // Test case: context is not canceled 556 _, err = s.Delete(context.Background(), req) 557 require.NoError(t, err) 558 559 // Verify shard was moved to trash 560 trashShards := getShards(trashDir) 561 require.Len(t, trashShards, 1) 562 563 // Verify shard is no longer in index dir 564 shards = getShards(indexDir) 565 require.Len(t, shards, 0) 566} 567 568func mockIndexFunc(t *testing.T) func(ctx context.Context, args *indexArgs) (indexState, error) { 569 return func(ctx context.Context, args *indexArgs) (indexState, error) { 570 createShard(t, args.IndexDir) 571 return indexStateSuccess, nil 572 } 573} 574 575func createShard(t *testing.T, dir string) { 576 opts := index.Options{ 577 IndexDir: dir, 578 RepositoryDescription: zoekt.Repository{ 579 ID: 42, 580 Name: "repo", 581 Branches: []zoekt.RepositoryBranch{ 582 { 583 Name: "HEAD", 584 Version: "abc123", 585 }, 586 }, 587 }, 588 } 589 590 b, err := index.NewBuilder(opts) 591 require.NoError(t, err) 592 require.NoError(t, b.AddFile("test.txt", []byte("hello"))) 593 require.NoError(t, b.Finish()) 594} 595 596func TestRecoverFromTrash(t *testing.T) { 597 dir := t.TempDir() 598 trashDir := filepath.Join(dir, ".trash") 599 require.NoError(t, os.MkdirAll(trashDir, 0o755)) 600 601 // Create a simple shard in trash 602 createTestShard(t, "repo1", 1, filepath.Join(trashDir, "repo1.zoekt")) 603 604 // Create a compound shard with two repos, one of them tombstoned 605 cs := createCompoundShard(t, dir, []uint32{2, 3}) 606 require.NoError(t, index.SetTombstone(cs, 2)) 607 608 s := &Server{ 609 IndexDir: dir, 610 } 611 612 // Test recovering from trash 613 recovered := s.recoverFromTrash(1) 614 require.True(t, recovered, "should have recovered repo1 from trash") 615 616 // Verify shard was moved from trash to index 617 indexShards := getShards(dir) 618 trashShards := getShards(trashDir) 619 620 require.Contains(t, indexShards, uint32(1), "repo1 should be in index") 621 require.NotContains(t, trashShards, uint32(1), "repo1 should not be in trash") 622 623 // Test unsetting tombstone 624 recovered = s.recoverFromTrash(2) 625 require.True(t, recovered, "should have recovered repo2 from tombstone") 626 627 // Verify tombstone was unset 628 repos, _, err := index.ReadMetadataPath(cs) 629 require.NoError(t, err) 630 631 for _, repo := range repos { 632 if repo.ID == 2 { 633 require.False(t, repo.Tombstone, "repo2 should not be tombstoned") 634 } 635 } 636 637 // Test non-existent repo 638 recovered = s.recoverFromTrash(99) 639 require.False(t, recovered, "should not have recovered non-existent repo") 640}