fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/http" 10 "net/http/httptest" 11 "net/url" 12 "os" 13 "path/filepath" 14 "sort" 15 "strings" 16 "testing" 17 18 sglog "github.com/sourcegraph/log" 19 "github.com/sourcegraph/log/logtest" 20 "github.com/xeipuuv/gojsonschema" 21 "google.golang.org/grpc" 22 23 "github.com/google/go-cmp/cmp" 24 25 "github.com/sourcegraph/zoekt" 26 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1" 27) 28 29func TestServer_defaultArgs(t *testing.T) { 30 root, err := url.Parse("http://api.test") 31 if err != nil { 32 t.Fatal(err) 33 } 34 35 s := &Server{ 36 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)), 37 IndexDir: "/testdata/index", 38 CPUCount: 6, 39 IndexConcurrency: 1, 40 } 41 want := &indexArgs{ 42 IndexOptions: IndexOptions{ 43 Name: "testName", 44 }, 45 IndexDir: "/testdata/index", 46 Parallelism: 6, 47 Incremental: true, 48 FileLimit: 1 << 20, 49 } 50 got := s.indexArgs(IndexOptions{Name: "testName"}) 51 if !cmp.Equal(got, want) { 52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 53 } 54} 55 56func TestServer_parallelism(t *testing.T) { 57 root, err := url.Parse("http://api.test") 58 if err != nil { 59 t.Fatal(err) 60 } 61 62 cases := []struct { 63 name string 64 cpuCount int 65 indexConcurrency int 66 options IndexOptions 67 want int 68 }{ 69 { 70 name: "CPU count divides evenly", 71 cpuCount: 16, 72 indexConcurrency: 8, 73 want: 2, 74 }, 75 { 76 name: "no shard level parallelism", 77 cpuCount: 4, 78 indexConcurrency: 4, 79 want: 1, 80 }, 81 { 82 name: "index option overrides server flag", 83 cpuCount: 2, 84 indexConcurrency: 1, 85 options: IndexOptions { 86 ShardConcurrency: 1, 87 }, 88 want: 1, 89 }, 90 { 91 name: "ignore invalid index option", 92 cpuCount: 8, 93 indexConcurrency: 2, 94 options: IndexOptions { 95 ShardConcurrency: -1, 96 }, 97 want: 4, 98 }, 99 } 100 101 for _, tt := range cases { 102 t.Run(tt.name, func(t *testing.T) { 103 s := &Server{ 104 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)), 105 IndexDir: "/testdata/index", 106 CPUCount: tt.cpuCount, 107 IndexConcurrency: tt.indexConcurrency, 108 } 109 110 maxProcs := 16 111 got := s.parallelism(tt.options, maxProcs) 112 if tt.want != got{ 113 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 114 } 115 }) 116 } 117 118 t.Run("index option is limited by available CPU", func(t *testing.T) { 119 s := &Server{ 120 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)), 121 IndexDir: "/testdata/index", 122 IndexConcurrency: 1, 123 } 124 125 got := s.indexArgs(IndexOptions { 126 ShardConcurrency: 2048, // Some number that's way too high 127 }) 128 129 if got.Parallelism >= 2048 { 130 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 131 } 132 }) 133} 134 135func TestListRepoIDs(t *testing.T) { 136 t.Run("gRPC", func(t *testing.T) { 137 138 grpcClient := &mockGRPCClient{} 139 140 clientOptions := []SourcegraphClientOption{ 141 WithShouldUseGRPC(true), 142 WithGRPCClient(grpcClient), 143 WithBatchSize(0), 144 } 145 146 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 147 testHostname := "test-hostname" 148 s := newSourcegraphClient(&testURL, testHostname, clientOptions...) 149 150 listCalled := false 151 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) { 152 listCalled = true 153 154 gotRepoIDs := in.GetIndexedIds() 155 sort.Slice(gotRepoIDs, func(i, j int) bool { 156 return gotRepoIDs[i] < gotRepoIDs[j] 157 }) 158 159 wantRepoIDs := []int32{1, 3} 160 sort.Slice(wantRepoIDs, func(i, j int) bool { 161 return wantRepoIDs[i] < wantRepoIDs[j] 162 }) 163 164 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 165 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 166 } 167 168 hostname := in.GetHostname() 169 if diff := cmp.Diff(testHostname, hostname); diff != "" { 170 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 171 } 172 173 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 174 } 175 176 ctx := context.Background() 177 got, err := s.List(ctx, []uint32{1, 3}) 178 if err != nil { 179 t.Fatal(err) 180 } 181 182 if !listCalled { 183 t.Fatalf("List was not called") 184 } 185 186 receivedRepoIDs := got.IDs 187 sort.Slice(receivedRepoIDs, func(i, j int) bool { 188 return receivedRepoIDs[i] < receivedRepoIDs[j] 189 }) 190 191 expectedRepoIDs := []uint32{1, 2, 3} 192 sort.Slice(expectedRepoIDs, func(i, j int) bool { 193 return expectedRepoIDs[i] < expectedRepoIDs[j] 194 }) 195 196 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 197 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 198 } 199 }) 200 201 t.Run("REST", func(t *testing.T) { 202 var gotBody string 203 var gotURL *url.URL 204 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 205 gotURL = r.URL 206 207 b, err := io.ReadAll(r.Body) 208 if err != nil { 209 t.Fatal(err) 210 } 211 gotBody = string(b) 212 213 _, err = w.Write([]byte(`{"RepoIDs": [1, 2, 3]}`)) 214 if err != nil { 215 t.Fatal(err) 216 } 217 })) 218 defer ts.Close() 219 220 u, err := url.Parse(ts.URL) 221 if err != nil { 222 t.Fatal(err) 223 } 224 225 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0)) 226 227 gotRepos, err := s.List(context.Background(), []uint32{1, 3}) 228 if err != nil { 229 t.Fatal(err) 230 } 231 232 if want := []uint32{1, 2, 3}; !cmp.Equal(gotRepos.IDs, want) { 233 t.Errorf("repos mismatch (-want +got):\n%s", cmp.Diff(want, gotRepos.IDs)) 234 } 235 if want := `{"Hostname":"test-indexed-search-1","IndexedIDs":[1,3]}`; gotBody != want { 236 t.Errorf("body mismatch (-want +got):\n%s", cmp.Diff(want, gotBody)) 237 } 238 if want := "/.internal/repos/index"; gotURL.Path != want { 239 t.Errorf("request path mismatch (-want +got):\n%s", cmp.Diff(want, gotURL.Path)) 240 } 241 }) 242} 243 244func TestListRepoIDs_Error_REST(t *testing.T) { 245 // Note: There is no gRPC equivalent to this test because gRPC errors are 246 // always returned as an error to the caller. 247 248 msg := "deadbeaf deadbeaf" 249 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 250 251 // This is how Sourcegraph returns error messages to the caller. 252 http.Error(w, msg, http.StatusInternalServerError) 253 })) 254 defer ts.Close() 255 256 u, err := url.Parse(ts.URL) 257 if err != nil { 258 t.Fatal(err) 259 } 260 261 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0)) 262 s.restClient.RetryMax = 0 263 264 _, err = s.List(context.Background(), []uint32{1, 3}) 265 266 if !strings.Contains(err.Error(), msg) { 267 t.Fatalf("%s does not contain %s", err.Error(), msg) 268 } 269} 270 271func TestMain(m *testing.M) { 272 flag.Parse() 273 level := sglog.LevelInfo 274 if !testing.Verbose() { 275 log.SetOutput(io.Discard) 276 level = sglog.LevelNone 277 } 278 279 logtest.InitWithLevel(m, level) 280 os.Exit(m.Run()) 281} 282 283func TestCreateEmptyShard(t *testing.T) { 284 dir := t.TempDir() 285 286 args := &indexArgs{ 287 IndexOptions: IndexOptions{ 288 RepoID: 7, 289 Name: "empty-repo", 290 CloneURL: "code/host", 291 }, 292 Incremental: true, 293 IndexDir: dir, 294 Parallelism: 1, 295 FileLimit: 1, 296 } 297 298 if err := createEmptyShard(args); err != nil { 299 t.Fatal(err) 300 } 301 302 bo := args.BuildOptions() 303 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 304 305 if got := bo.IncrementalSkipIndexing(); !got { 306 t.Fatalf("wanted %t, got %t", true, got) 307 } 308} 309 310func TestFormatListUint32(t *testing.T) { 311 cases := []struct { 312 in []uint32 313 want string 314 }{ 315 { 316 in: []uint32{42, 8, 3}, 317 want: "42, 8, ...", 318 }, 319 { 320 in: []uint32{42, 8}, 321 want: "42, 8", 322 }, 323 { 324 in: []uint32{42}, 325 want: "42", 326 }, 327 { 328 in: []uint32{}, 329 want: "", 330 }, 331 } 332 333 for _, tt := range cases { 334 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 335 out := formatListUint32(tt.in, 2) 336 if out != tt.want { 337 t.Fatalf("want %s, got %s", tt.want, out) 338 } 339 }) 340 } 341} 342 343func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 344 wd, err := os.Getwd() 345 if err != nil { 346 t.Fatalf("failed to get working directory: %v", err) 347 } 348 349 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 350 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 351 352 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 353 354 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 355 if err != nil { 356 t.Fatalf("failed to validate default service config: %v", err) 357 } 358 359 if !result.Valid() { 360 var errs strings.Builder 361 for _, err := range result.Errors() { 362 errs.WriteString(fmt.Sprintf("- %s\n", err)) 363 } 364 365 t.Fatalf("default service config is invalid:\n%s", errs.String()) 366 } 367} 368 369func TestGetBoolFromEnvironmentVariables(t *testing.T) { 370 testCases := []struct { 371 name string 372 envVarsToSet map[string]string 373 374 envVarNames []string 375 defaultBool bool 376 377 wantBool bool 378 wantErr bool 379 }{ 380 { 381 name: "respect default value: true", 382 383 envVarsToSet: map[string]string{}, 384 385 envVarNames: []string{"FOO", "BAR"}, 386 defaultBool: true, 387 388 wantBool: true, 389 }, 390 { 391 name: "respect default value: false", 392 393 envVarsToSet: map[string]string{}, 394 395 envVarNames: []string{"FOO", "BAR"}, 396 defaultBool: false, 397 398 wantBool: false, 399 }, 400 { 401 name: "read from environment", 402 403 envVarsToSet: map[string]string{"FOO": "1"}, 404 405 envVarNames: []string{"FOO"}, 406 defaultBool: false, 407 408 wantBool: true, 409 }, 410 { 411 name: "read from first env var that is set", 412 413 envVarsToSet: map[string]string{ 414 "BAR": "false", 415 "BAZ": "true", 416 }, 417 418 envVarNames: []string{"FOO", "BAR", "BAZ"}, 419 defaultBool: true, 420 421 wantBool: false, 422 }, 423 424 { 425 name: "should error for invalid input", 426 427 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 428 429 envVarNames: []string{"INVALID"}, 430 defaultBool: false, 431 432 wantErr: true, 433 }, 434 } 435 436 for _, tc := range testCases { 437 t.Run("", func(t *testing.T) { 438 // Prepare the environment by loading all the appropriate environment variables 439 for _, v := range tc.envVarNames { 440 _ = os.Unsetenv(v) 441 } 442 443 for k, _ := range tc.envVarsToSet { 444 _ = os.Unsetenv(k) 445 } 446 447 for k, v := range tc.envVarsToSet { 448 t.Setenv(k, v) 449 } 450 451 // Run the test 452 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 453 454 // Examine the results 455 if tc.wantErr != (err != nil) { 456 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 457 } 458 459 if got != tc.wantBool { 460 t.Errorf("got %v, want %v", got, tc.wantBool) 461 } 462 }) 463 } 464} 465 466func TestAddDefaultPort(t *testing.T) { 467 tests := []struct { 468 name string 469 input string 470 want string 471 }{ 472 { 473 name: "http no port", 474 input: "http://example.com", 475 want: "http://example.com:80", 476 }, 477 { 478 name: "http custom port", 479 input: "http://example.com:90", 480 want: "http://example.com:90", 481 }, 482 { 483 name: "https no port", 484 input: "https://example.com", 485 want: "https://example.com:443", 486 }, 487 { 488 name: "https custom port", 489 input: "https://example.com:444", 490 want: "https://example.com:444", 491 }, 492 { 493 name: "non-http scheme", 494 input: "ftp://example.com", 495 want: "ftp://example.com", 496 }, 497 { 498 name: "empty string", 499 input: "", 500 want: "", 501 }, 502 { 503 name: "local file path", 504 input: "/etc/hosts", 505 want: "/etc/hosts", 506 }, 507 } 508 509 for _, test := range tests { 510 t.Run(test.name, func(t *testing.T) { 511 input, err := url.Parse(test.input) 512 if err != nil { 513 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 514 } 515 516 got := addDefaultPort(input) 517 if diff := cmp.Diff(test.want, got.String()); diff != "" { 518 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 519 } 520 }) 521 } 522}