fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/http" 10 "net/http/httptest" 11 "net/url" 12 "os" 13 "path/filepath" 14 "sort" 15 "strings" 16 "testing" 17 18 sglog "github.com/sourcegraph/log" 19 "github.com/sourcegraph/log/logtest" 20 "github.com/xeipuuv/gojsonschema" 21 "google.golang.org/grpc" 22 23 "github.com/google/go-cmp/cmp" 24 25 "github.com/sourcegraph/zoekt" 26 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1" 27) 28 29func TestServer_defaultArgs(t *testing.T) { 30 root, err := url.Parse("http://api.test") 31 if err != nil { 32 t.Fatal(err) 33 } 34 35 s := &Server{ 36 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)), 37 IndexDir: "/testdata/index", 38 CPUCount: 6, 39 IndexConcurrency: 1, 40 } 41 want := &indexArgs{ 42 IndexOptions: IndexOptions{ 43 Name: "testName", 44 }, 45 IndexDir: "/testdata/index", 46 Parallelism: 6, 47 Incremental: true, 48 FileLimit: 1 << 20, 49 } 50 got := s.indexArgs(IndexOptions{Name: "testName"}) 51 if !cmp.Equal(got, want) { 52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 53 } 54} 55 56func TestServer_parallelism(t *testing.T) { 57 root, err := url.Parse("http://api.test") 58 if err != nil { 59 t.Fatal(err) 60 } 61 62 cases := []struct { 63 name string 64 cpuCount int 65 indexConcurrency int 66 options IndexOptions 67 want int 68 }{ 69 { 70 name: "CPU count divides evenly", 71 cpuCount: 16, 72 indexConcurrency: 8, 73 want: 2, 74 }, 75 { 76 name: "no shard level parallelism", 77 cpuCount: 4, 78 indexConcurrency: 4, 79 want: 1, 80 }, 81 { 82 name: "index option overrides server flag", 83 cpuCount: 2, 84 indexConcurrency: 1, 85 options: IndexOptions{ 86 ShardConcurrency: 1, 87 }, 88 want: 1, 89 }, 90 { 91 name: "ignore invalid index option", 92 cpuCount: 8, 93 indexConcurrency: 2, 94 options: IndexOptions{ 95 ShardConcurrency: -1, 96 }, 97 want: 4, 98 }, 99 } 100 101 for _, tt := range cases { 102 t.Run(tt.name, func(t *testing.T) { 103 s := &Server{ 104 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)), 105 IndexDir: "/testdata/index", 106 CPUCount: tt.cpuCount, 107 IndexConcurrency: tt.indexConcurrency, 108 } 109 110 maxProcs := 16 111 got := s.parallelism(tt.options, maxProcs) 112 if tt.want != got { 113 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 114 } 115 }) 116 } 117 118 t.Run("index option is limited by available CPU", func(t *testing.T) { 119 s := &Server{ 120 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)), 121 IndexDir: "/testdata/index", 122 IndexConcurrency: 1, 123 } 124 125 got := s.indexArgs(IndexOptions{ 126 ShardConcurrency: 2048, // Some number that's way too high 127 }) 128 129 if got.Parallelism >= 2048 { 130 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 131 } 132 }) 133} 134 135func TestListRepoIDs(t *testing.T) { 136 t.Run("gRPC", func(t *testing.T) { 137 grpcClient := &mockGRPCClient{} 138 139 clientOptions := []SourcegraphClientOption{ 140 WithShouldUseGRPC(true), 141 WithGRPCClient(grpcClient), 142 WithBatchSize(0), 143 } 144 145 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 146 testHostname := "test-hostname" 147 s := newSourcegraphClient(&testURL, testHostname, clientOptions...) 148 149 listCalled := false 150 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) { 151 listCalled = true 152 153 gotRepoIDs := in.GetIndexedIds() 154 sort.Slice(gotRepoIDs, func(i, j int) bool { 155 return gotRepoIDs[i] < gotRepoIDs[j] 156 }) 157 158 wantRepoIDs := []int32{1, 3} 159 sort.Slice(wantRepoIDs, func(i, j int) bool { 160 return wantRepoIDs[i] < wantRepoIDs[j] 161 }) 162 163 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 164 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 165 } 166 167 hostname := in.GetHostname() 168 if diff := cmp.Diff(testHostname, hostname); diff != "" { 169 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 170 } 171 172 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 173 } 174 175 ctx := context.Background() 176 got, err := s.List(ctx, []uint32{1, 3}) 177 if err != nil { 178 t.Fatal(err) 179 } 180 181 if !listCalled { 182 t.Fatalf("List was not called") 183 } 184 185 receivedRepoIDs := got.IDs 186 sort.Slice(receivedRepoIDs, func(i, j int) bool { 187 return receivedRepoIDs[i] < receivedRepoIDs[j] 188 }) 189 190 expectedRepoIDs := []uint32{1, 2, 3} 191 sort.Slice(expectedRepoIDs, func(i, j int) bool { 192 return expectedRepoIDs[i] < expectedRepoIDs[j] 193 }) 194 195 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 196 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 197 } 198 }) 199 200 t.Run("REST", func(t *testing.T) { 201 var gotBody string 202 var gotURL *url.URL 203 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 204 gotURL = r.URL 205 206 b, err := io.ReadAll(r.Body) 207 if err != nil { 208 t.Fatal(err) 209 } 210 gotBody = string(b) 211 212 _, err = w.Write([]byte(`{"RepoIDs": [1, 2, 3]}`)) 213 if err != nil { 214 t.Fatal(err) 215 } 216 })) 217 defer ts.Close() 218 219 u, err := url.Parse(ts.URL) 220 if err != nil { 221 t.Fatal(err) 222 } 223 224 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0)) 225 226 gotRepos, err := s.List(context.Background(), []uint32{1, 3}) 227 if err != nil { 228 t.Fatal(err) 229 } 230 231 if want := []uint32{1, 2, 3}; !cmp.Equal(gotRepos.IDs, want) { 232 t.Errorf("repos mismatch (-want +got):\n%s", cmp.Diff(want, gotRepos.IDs)) 233 } 234 if want := `{"Hostname":"test-indexed-search-1","IndexedIDs":[1,3]}`; gotBody != want { 235 t.Errorf("body mismatch (-want +got):\n%s", cmp.Diff(want, gotBody)) 236 } 237 if want := "/.internal/repos/index"; gotURL.Path != want { 238 t.Errorf("request path mismatch (-want +got):\n%s", cmp.Diff(want, gotURL.Path)) 239 } 240 }) 241} 242 243func TestListRepoIDs_Error_REST(t *testing.T) { 244 // Note: There is no gRPC equivalent to this test because gRPC errors are 245 // always returned as an error to the caller. 246 247 msg := "deadbeaf deadbeaf" 248 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 249 // This is how Sourcegraph returns error messages to the caller. 250 http.Error(w, msg, http.StatusInternalServerError) 251 })) 252 defer ts.Close() 253 254 u, err := url.Parse(ts.URL) 255 if err != nil { 256 t.Fatal(err) 257 } 258 259 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0)) 260 s.restClient.RetryMax = 0 261 262 _, err = s.List(context.Background(), []uint32{1, 3}) 263 264 if !strings.Contains(err.Error(), msg) { 265 t.Fatalf("%s does not contain %s", err.Error(), msg) 266 } 267} 268 269func TestMain(m *testing.M) { 270 flag.Parse() 271 level := sglog.LevelInfo 272 if !testing.Verbose() { 273 log.SetOutput(io.Discard) 274 level = sglog.LevelNone 275 } 276 277 logtest.InitWithLevel(m, level) 278 os.Exit(m.Run()) 279} 280 281func TestCreateEmptyShard(t *testing.T) { 282 dir := t.TempDir() 283 284 args := &indexArgs{ 285 IndexOptions: IndexOptions{ 286 RepoID: 7, 287 Name: "empty-repo", 288 CloneURL: "code/host", 289 }, 290 Incremental: true, 291 IndexDir: dir, 292 Parallelism: 1, 293 FileLimit: 1, 294 } 295 296 if err := createEmptyShard(args); err != nil { 297 t.Fatal(err) 298 } 299 300 bo := args.BuildOptions() 301 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 302 303 if got := bo.IncrementalSkipIndexing(); !got { 304 t.Fatalf("wanted %t, got %t", true, got) 305 } 306} 307 308func TestFormatListUint32(t *testing.T) { 309 cases := []struct { 310 in []uint32 311 want string 312 }{ 313 { 314 in: []uint32{42, 8, 3}, 315 want: "42, 8, ...", 316 }, 317 { 318 in: []uint32{42, 8}, 319 want: "42, 8", 320 }, 321 { 322 in: []uint32{42}, 323 want: "42", 324 }, 325 { 326 in: []uint32{}, 327 want: "", 328 }, 329 } 330 331 for _, tt := range cases { 332 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 333 out := formatListUint32(tt.in, 2) 334 if out != tt.want { 335 t.Fatalf("want %s, got %s", tt.want, out) 336 } 337 }) 338 } 339} 340 341func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 342 wd, err := os.Getwd() 343 if err != nil { 344 t.Fatalf("failed to get working directory: %v", err) 345 } 346 347 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 348 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 349 350 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 351 352 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 353 if err != nil { 354 t.Fatalf("failed to validate default service config: %v", err) 355 } 356 357 if !result.Valid() { 358 var errs strings.Builder 359 for _, err := range result.Errors() { 360 errs.WriteString(fmt.Sprintf("- %s\n", err)) 361 } 362 363 t.Fatalf("default service config is invalid:\n%s", errs.String()) 364 } 365} 366 367func TestGetBoolFromEnvironmentVariables(t *testing.T) { 368 testCases := []struct { 369 name string 370 envVarsToSet map[string]string 371 372 envVarNames []string 373 defaultBool bool 374 375 wantBool bool 376 wantErr bool 377 }{ 378 { 379 name: "respect default value: true", 380 381 envVarsToSet: map[string]string{}, 382 383 envVarNames: []string{"FOO", "BAR"}, 384 defaultBool: true, 385 386 wantBool: true, 387 }, 388 { 389 name: "respect default value: false", 390 391 envVarsToSet: map[string]string{}, 392 393 envVarNames: []string{"FOO", "BAR"}, 394 defaultBool: false, 395 396 wantBool: false, 397 }, 398 { 399 name: "read from environment", 400 401 envVarsToSet: map[string]string{"FOO": "1"}, 402 403 envVarNames: []string{"FOO"}, 404 defaultBool: false, 405 406 wantBool: true, 407 }, 408 { 409 name: "read from first env var that is set", 410 411 envVarsToSet: map[string]string{ 412 "BAR": "false", 413 "BAZ": "true", 414 }, 415 416 envVarNames: []string{"FOO", "BAR", "BAZ"}, 417 defaultBool: true, 418 419 wantBool: false, 420 }, 421 422 { 423 name: "should error for invalid input", 424 425 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 426 427 envVarNames: []string{"INVALID"}, 428 defaultBool: false, 429 430 wantErr: true, 431 }, 432 } 433 434 for _, tc := range testCases { 435 t.Run("", func(t *testing.T) { 436 // Prepare the environment by loading all the appropriate environment variables 437 for _, v := range tc.envVarNames { 438 _ = os.Unsetenv(v) 439 } 440 441 for k := range tc.envVarsToSet { 442 _ = os.Unsetenv(k) 443 } 444 445 for k, v := range tc.envVarsToSet { 446 t.Setenv(k, v) 447 } 448 449 // Run the test 450 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 451 452 // Examine the results 453 if tc.wantErr != (err != nil) { 454 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 455 } 456 457 if got != tc.wantBool { 458 t.Errorf("got %v, want %v", got, tc.wantBool) 459 } 460 }) 461 } 462} 463 464func TestAddDefaultPort(t *testing.T) { 465 tests := []struct { 466 name string 467 input string 468 want string 469 }{ 470 { 471 name: "http no port", 472 input: "http://example.com", 473 want: "http://example.com:80", 474 }, 475 { 476 name: "http custom port", 477 input: "http://example.com:90", 478 want: "http://example.com:90", 479 }, 480 { 481 name: "https no port", 482 input: "https://example.com", 483 want: "https://example.com:443", 484 }, 485 { 486 name: "https custom port", 487 input: "https://example.com:444", 488 want: "https://example.com:444", 489 }, 490 { 491 name: "non-http scheme", 492 input: "ftp://example.com", 493 want: "ftp://example.com", 494 }, 495 { 496 name: "empty string", 497 input: "", 498 want: "", 499 }, 500 { 501 name: "local file path", 502 input: "/etc/hosts", 503 want: "/etc/hosts", 504 }, 505 } 506 507 for _, test := range tests { 508 t.Run(test.name, func(t *testing.T) { 509 input, err := url.Parse(test.input) 510 if err != nil { 511 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 512 } 513 514 got := addDefaultPort(input) 515 if diff := cmp.Diff(test.want, got.String()); diff != "" { 516 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 517 } 518 }) 519 } 520}