fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "sort" 13 "strings" 14 "testing" 15 16 sglog "github.com/sourcegraph/log" 17 "github.com/sourcegraph/log/logtest" 18 "github.com/stretchr/testify/require" 19 "github.com/xeipuuv/gojsonschema" 20 "google.golang.org/grpc" 21 22 "github.com/google/go-cmp/cmp" 23 24 "github.com/sourcegraph/zoekt" 25 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1" 26 "github.com/sourcegraph/zoekt/internal/tenant" 27) 28 29func TestServer_defaultArgs(t *testing.T) { 30 root, err := url.Parse("http://api.test") 31 if err != nil { 32 t.Fatal(err) 33 } 34 35 s := &Server{ 36 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 37 IndexDir: "/testdata/index", 38 CPUCount: 6, 39 IndexConcurrency: 1, 40 } 41 want := &indexArgs{ 42 IndexOptions: IndexOptions{ 43 Name: "testName", 44 }, 45 IndexDir: "/testdata/index", 46 Parallelism: 6, 47 Incremental: true, 48 FileLimit: 1 << 20, 49 } 50 got := s.indexArgs(IndexOptions{Name: "testName"}) 51 if !cmp.Equal(got, want) { 52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 53 } 54} 55 56func TestIndexNoTenant(t *testing.T) { 57 s := &Server{} 58 _, err := s.Index(&indexArgs{}) 59 require.ErrorIs(t, err, tenant.ErrMissingTenant) 60} 61 62func TestServer_parallelism(t *testing.T) { 63 root, err := url.Parse("http://api.test") 64 if err != nil { 65 t.Fatal(err) 66 } 67 68 cases := []struct { 69 name string 70 cpuCount int 71 indexConcurrency int 72 options IndexOptions 73 want int 74 }{ 75 { 76 name: "CPU count divides evenly", 77 cpuCount: 16, 78 indexConcurrency: 8, 79 want: 2, 80 }, 81 { 82 name: "no shard level parallelism", 83 cpuCount: 4, 84 indexConcurrency: 4, 85 want: 1, 86 }, 87 { 88 name: "index option overrides server flag", 89 cpuCount: 2, 90 indexConcurrency: 1, 91 options: IndexOptions{ 92 ShardConcurrency: 1, 93 }, 94 want: 1, 95 }, 96 { 97 name: "ignore invalid index option", 98 cpuCount: 8, 99 indexConcurrency: 2, 100 options: IndexOptions{ 101 ShardConcurrency: -1, 102 }, 103 want: 4, 104 }, 105 } 106 107 for _, tt := range cases { 108 t.Run(tt.name, func(t *testing.T) { 109 s := &Server{ 110 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 111 IndexDir: "/testdata/index", 112 CPUCount: tt.cpuCount, 113 IndexConcurrency: tt.indexConcurrency, 114 } 115 116 maxProcs := 16 117 got := s.parallelism(tt.options, maxProcs) 118 if tt.want != got { 119 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 120 } 121 }) 122 } 123 124 t.Run("index option is limited by available CPU", func(t *testing.T) { 125 s := &Server{ 126 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 127 IndexDir: "/testdata/index", 128 IndexConcurrency: 1, 129 } 130 131 got := s.indexArgs(IndexOptions{ 132 ShardConcurrency: 2048, // Some number that's way too high 133 }) 134 135 if got.Parallelism >= 2048 { 136 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 137 } 138 }) 139} 140 141func TestListRepoIDs(t *testing.T) { 142 grpcClient := &mockGRPCClient{} 143 144 clientOptions := []SourcegraphClientOption{ 145 WithBatchSize(0), 146 } 147 148 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 149 testHostname := "test-hostname" 150 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 151 152 listCalled := false 153 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) { 154 listCalled = true 155 156 gotRepoIDs := in.GetIndexedIds() 157 sort.Slice(gotRepoIDs, func(i, j int) bool { 158 return gotRepoIDs[i] < gotRepoIDs[j] 159 }) 160 161 wantRepoIDs := []int32{1, 3} 162 sort.Slice(wantRepoIDs, func(i, j int) bool { 163 return wantRepoIDs[i] < wantRepoIDs[j] 164 }) 165 166 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 167 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 168 } 169 170 hostname := in.GetHostname() 171 if diff := cmp.Diff(testHostname, hostname); diff != "" { 172 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 173 } 174 175 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 176 } 177 178 ctx := context.Background() 179 got, err := s.List(ctx, []uint32{1, 3}) 180 if err != nil { 181 t.Fatal(err) 182 } 183 184 if !listCalled { 185 t.Fatalf("List was not called") 186 } 187 188 receivedRepoIDs := got.IDs 189 sort.Slice(receivedRepoIDs, func(i, j int) bool { 190 return receivedRepoIDs[i] < receivedRepoIDs[j] 191 }) 192 193 expectedRepoIDs := []uint32{1, 2, 3} 194 sort.Slice(expectedRepoIDs, func(i, j int) bool { 195 return expectedRepoIDs[i] < expectedRepoIDs[j] 196 }) 197 198 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 199 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 200 } 201} 202 203func TestMain(m *testing.M) { 204 flag.Parse() 205 level := sglog.LevelInfo 206 if !testing.Verbose() { 207 log.SetOutput(io.Discard) 208 debugLog.SetOutput(io.Discard) 209 infoLog.SetOutput(io.Discard) 210 errorLog.SetOutput(io.Discard) 211 level = sglog.LevelNone 212 } 213 214 logtest.InitWithLevel(m, level) 215 os.Exit(m.Run()) 216} 217 218func TestCreateEmptyShard(t *testing.T) { 219 dir := t.TempDir() 220 221 args := &indexArgs{ 222 IndexOptions: IndexOptions{ 223 RepoID: 7, 224 Name: "empty-repo", 225 CloneURL: "code/host", 226 }, 227 Incremental: true, 228 IndexDir: dir, 229 Parallelism: 1, 230 FileLimit: 1, 231 } 232 233 if err := createEmptyShard(args); err != nil { 234 t.Fatal(err) 235 } 236 237 bo := args.BuildOptions() 238 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 239 240 if got := bo.IncrementalSkipIndexing(); !got { 241 t.Fatalf("wanted %t, got %t", true, got) 242 } 243} 244 245func TestFormatListUint32(t *testing.T) { 246 cases := []struct { 247 in []uint32 248 want string 249 }{ 250 { 251 in: []uint32{42, 8, 3}, 252 want: "42, 8, ...", 253 }, 254 { 255 in: []uint32{42, 8}, 256 want: "42, 8", 257 }, 258 { 259 in: []uint32{42}, 260 want: "42", 261 }, 262 { 263 in: []uint32{}, 264 want: "", 265 }, 266 } 267 268 for _, tt := range cases { 269 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 270 out := formatListUint32(tt.in, 2) 271 if out != tt.want { 272 t.Fatalf("want %s, got %s", tt.want, out) 273 } 274 }) 275 } 276} 277 278func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 279 wd, err := os.Getwd() 280 if err != nil { 281 t.Fatalf("failed to get working directory: %v", err) 282 } 283 284 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 285 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 286 287 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 288 289 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 290 if err != nil { 291 t.Fatalf("failed to validate default service config: %v", err) 292 } 293 294 if !result.Valid() { 295 var errs strings.Builder 296 for _, err := range result.Errors() { 297 errs.WriteString(fmt.Sprintf("- %s\n", err)) 298 } 299 300 t.Fatalf("default service config is invalid:\n%s", errs.String()) 301 } 302} 303 304func TestGetBoolFromEnvironmentVariables(t *testing.T) { 305 testCases := []struct { 306 name string 307 envVarsToSet map[string]string 308 309 envVarNames []string 310 defaultBool bool 311 312 wantBool bool 313 wantErr bool 314 }{ 315 { 316 name: "respect default value: true", 317 318 envVarsToSet: map[string]string{}, 319 320 envVarNames: []string{"FOO", "BAR"}, 321 defaultBool: true, 322 323 wantBool: true, 324 }, 325 { 326 name: "respect default value: false", 327 328 envVarsToSet: map[string]string{}, 329 330 envVarNames: []string{"FOO", "BAR"}, 331 defaultBool: false, 332 333 wantBool: false, 334 }, 335 { 336 name: "read from environment", 337 338 envVarsToSet: map[string]string{"FOO": "1"}, 339 340 envVarNames: []string{"FOO"}, 341 defaultBool: false, 342 343 wantBool: true, 344 }, 345 { 346 name: "read from first env var that is set", 347 348 envVarsToSet: map[string]string{ 349 "BAR": "false", 350 "BAZ": "true", 351 }, 352 353 envVarNames: []string{"FOO", "BAR", "BAZ"}, 354 defaultBool: true, 355 356 wantBool: false, 357 }, 358 359 { 360 name: "should error for invalid input", 361 362 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 363 364 envVarNames: []string{"INVALID"}, 365 defaultBool: false, 366 367 wantErr: true, 368 }, 369 } 370 371 for _, tc := range testCases { 372 t.Run("", func(t *testing.T) { 373 // Prepare the environment by loading all the appropriate environment variables 374 for _, v := range tc.envVarNames { 375 _ = os.Unsetenv(v) 376 } 377 378 for k := range tc.envVarsToSet { 379 _ = os.Unsetenv(k) 380 } 381 382 for k, v := range tc.envVarsToSet { 383 t.Setenv(k, v) 384 } 385 386 // Run the test 387 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 388 389 // Examine the results 390 if tc.wantErr != (err != nil) { 391 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 392 } 393 394 if got != tc.wantBool { 395 t.Errorf("got %v, want %v", got, tc.wantBool) 396 } 397 }) 398 } 399} 400 401func TestAddDefaultPort(t *testing.T) { 402 tests := []struct { 403 name string 404 input string 405 want string 406 }{ 407 { 408 name: "http no port", 409 input: "http://example.com", 410 want: "http://example.com:80", 411 }, 412 { 413 name: "http custom port", 414 input: "http://example.com:90", 415 want: "http://example.com:90", 416 }, 417 { 418 name: "https no port", 419 input: "https://example.com", 420 want: "https://example.com:443", 421 }, 422 { 423 name: "https custom port", 424 input: "https://example.com:444", 425 want: "https://example.com:444", 426 }, 427 { 428 name: "non-http scheme", 429 input: "ftp://example.com", 430 want: "ftp://example.com", 431 }, 432 { 433 name: "empty string", 434 input: "", 435 want: "", 436 }, 437 { 438 name: "local file path", 439 input: "/etc/hosts", 440 want: "/etc/hosts", 441 }, 442 } 443 444 for _, test := range tests { 445 t.Run(test.name, func(t *testing.T) { 446 input, err := url.Parse(test.input) 447 if err != nil { 448 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 449 } 450 451 got := addDefaultPort(input) 452 if diff := cmp.Diff(test.want, got.String()); diff != "" { 453 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 454 } 455 }) 456 } 457}