fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "sort" 13 "strings" 14 "testing" 15 16 sglog "github.com/sourcegraph/log" 17 "github.com/sourcegraph/log/logtest" 18 "github.com/stretchr/testify/require" 19 "github.com/xeipuuv/gojsonschema" 20 "google.golang.org/grpc" 21 22 "github.com/google/go-cmp/cmp" 23 24 "github.com/sourcegraph/zoekt" 25 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1" 26 "github.com/sourcegraph/zoekt/internal/tenant" 27) 28 29func TestServer_defaultArgs(t *testing.T) { 30 root, err := url.Parse("http://api.test") 31 if err != nil { 32 t.Fatal(err) 33 } 34 35 s := &Server{ 36 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 37 IndexDir: "/testdata/index", 38 CPUCount: 6, 39 IndexConcurrency: 1, 40 } 41 want := &indexArgs{ 42 IndexOptions: IndexOptions{ 43 Name: "testName", 44 }, 45 IndexDir: "/testdata/index", 46 Parallelism: 6, 47 Incremental: true, 48 FileLimit: 1 << 20, 49 } 50 got := s.indexArgs(IndexOptions{Name: "testName"}) 51 if !cmp.Equal(got, want) { 52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 53 } 54} 55 56func TestIndexNoTenant(t *testing.T) { 57 s := &Server{} 58 _, err := s.Index(&indexArgs{}) 59 require.ErrorIs(t, err, tenant.ErrMissingTenant) 60} 61 62func TestServer_parallelism(t *testing.T) { 63 root, err := url.Parse("http://api.test") 64 if err != nil { 65 t.Fatal(err) 66 } 67 68 cases := []struct { 69 name string 70 cpuCount int 71 indexConcurrency int 72 options IndexOptions 73 want int 74 }{ 75 { 76 name: "CPU count divides evenly", 77 cpuCount: 16, 78 indexConcurrency: 8, 79 want: 2, 80 }, 81 { 82 name: "no shard level parallelism", 83 cpuCount: 4, 84 indexConcurrency: 4, 85 want: 1, 86 }, 87 { 88 name: "index option overrides server flag", 89 cpuCount: 2, 90 indexConcurrency: 1, 91 options: IndexOptions{ 92 ShardConcurrency: 1, 93 }, 94 want: 1, 95 }, 96 { 97 name: "ignore invalid index option", 98 cpuCount: 8, 99 indexConcurrency: 2, 100 options: IndexOptions{ 101 ShardConcurrency: -1, 102 }, 103 want: 4, 104 }, 105 } 106 107 for _, tt := range cases { 108 t.Run(tt.name, func(t *testing.T) { 109 s := &Server{ 110 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 111 IndexDir: "/testdata/index", 112 CPUCount: tt.cpuCount, 113 IndexConcurrency: tt.indexConcurrency, 114 } 115 116 maxProcs := 16 117 got := s.parallelism(tt.options, maxProcs) 118 if tt.want != got { 119 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 120 } 121 }) 122 } 123 124 t.Run("index option is limited by available CPU", func(t *testing.T) { 125 s := &Server{ 126 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 127 IndexDir: "/testdata/index", 128 IndexConcurrency: 1, 129 } 130 131 got := s.indexArgs(IndexOptions{ 132 ShardConcurrency: 2048, // Some number that's way too high 133 }) 134 135 if got.Parallelism >= 2048 { 136 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 137 } 138 }) 139} 140 141func TestListRepoIDs(t *testing.T) { 142 grpcClient := &mockGRPCClient{} 143 144 clientOptions := []SourcegraphClientOption{ 145 WithBatchSize(0), 146 } 147 148 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 149 testHostname := "test-hostname" 150 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 151 152 listCalled := false 153 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) { 154 listCalled = true 155 156 gotRepoIDs := in.GetIndexedIds() 157 sort.Slice(gotRepoIDs, func(i, j int) bool { 158 return gotRepoIDs[i] < gotRepoIDs[j] 159 }) 160 161 wantRepoIDs := []int32{1, 3} 162 sort.Slice(wantRepoIDs, func(i, j int) bool { 163 return wantRepoIDs[i] < wantRepoIDs[j] 164 }) 165 166 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 167 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 168 } 169 170 hostname := in.GetHostname() 171 if diff := cmp.Diff(testHostname, hostname); diff != "" { 172 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 173 } 174 175 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 176 } 177 178 ctx := context.Background() 179 got, err := s.List(ctx, []uint32{1, 3}) 180 if err != nil { 181 t.Fatal(err) 182 } 183 184 if !listCalled { 185 t.Fatalf("List was not called") 186 } 187 188 receivedRepoIDs := got.IDs 189 sort.Slice(receivedRepoIDs, func(i, j int) bool { 190 return receivedRepoIDs[i] < receivedRepoIDs[j] 191 }) 192 193 expectedRepoIDs := []uint32{1, 2, 3} 194 sort.Slice(expectedRepoIDs, func(i, j int) bool { 195 return expectedRepoIDs[i] < expectedRepoIDs[j] 196 }) 197 198 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 199 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 200 } 201} 202 203func TestMain(m *testing.M) { 204 flag.Parse() 205 level := sglog.LevelInfo 206 if !testing.Verbose() { 207 log.SetOutput(io.Discard) 208 level = sglog.LevelNone 209 } 210 211 logtest.InitWithLevel(m, level) 212 os.Exit(m.Run()) 213} 214 215func TestCreateEmptyShard(t *testing.T) { 216 dir := t.TempDir() 217 218 args := &indexArgs{ 219 IndexOptions: IndexOptions{ 220 RepoID: 7, 221 Name: "empty-repo", 222 CloneURL: "code/host", 223 }, 224 Incremental: true, 225 IndexDir: dir, 226 Parallelism: 1, 227 FileLimit: 1, 228 } 229 230 if err := createEmptyShard(args); err != nil { 231 t.Fatal(err) 232 } 233 234 bo := args.BuildOptions() 235 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 236 237 if got := bo.IncrementalSkipIndexing(); !got { 238 t.Fatalf("wanted %t, got %t", true, got) 239 } 240} 241 242func TestFormatListUint32(t *testing.T) { 243 cases := []struct { 244 in []uint32 245 want string 246 }{ 247 { 248 in: []uint32{42, 8, 3}, 249 want: "42, 8, ...", 250 }, 251 { 252 in: []uint32{42, 8}, 253 want: "42, 8", 254 }, 255 { 256 in: []uint32{42}, 257 want: "42", 258 }, 259 { 260 in: []uint32{}, 261 want: "", 262 }, 263 } 264 265 for _, tt := range cases { 266 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 267 out := formatListUint32(tt.in, 2) 268 if out != tt.want { 269 t.Fatalf("want %s, got %s", tt.want, out) 270 } 271 }) 272 } 273} 274 275func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 276 wd, err := os.Getwd() 277 if err != nil { 278 t.Fatalf("failed to get working directory: %v", err) 279 } 280 281 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 282 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 283 284 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 285 286 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 287 if err != nil { 288 t.Fatalf("failed to validate default service config: %v", err) 289 } 290 291 if !result.Valid() { 292 var errs strings.Builder 293 for _, err := range result.Errors() { 294 errs.WriteString(fmt.Sprintf("- %s\n", err)) 295 } 296 297 t.Fatalf("default service config is invalid:\n%s", errs.String()) 298 } 299} 300 301func TestGetBoolFromEnvironmentVariables(t *testing.T) { 302 testCases := []struct { 303 name string 304 envVarsToSet map[string]string 305 306 envVarNames []string 307 defaultBool bool 308 309 wantBool bool 310 wantErr bool 311 }{ 312 { 313 name: "respect default value: true", 314 315 envVarsToSet: map[string]string{}, 316 317 envVarNames: []string{"FOO", "BAR"}, 318 defaultBool: true, 319 320 wantBool: true, 321 }, 322 { 323 name: "respect default value: false", 324 325 envVarsToSet: map[string]string{}, 326 327 envVarNames: []string{"FOO", "BAR"}, 328 defaultBool: false, 329 330 wantBool: false, 331 }, 332 { 333 name: "read from environment", 334 335 envVarsToSet: map[string]string{"FOO": "1"}, 336 337 envVarNames: []string{"FOO"}, 338 defaultBool: false, 339 340 wantBool: true, 341 }, 342 { 343 name: "read from first env var that is set", 344 345 envVarsToSet: map[string]string{ 346 "BAR": "false", 347 "BAZ": "true", 348 }, 349 350 envVarNames: []string{"FOO", "BAR", "BAZ"}, 351 defaultBool: true, 352 353 wantBool: false, 354 }, 355 356 { 357 name: "should error for invalid input", 358 359 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 360 361 envVarNames: []string{"INVALID"}, 362 defaultBool: false, 363 364 wantErr: true, 365 }, 366 } 367 368 for _, tc := range testCases { 369 t.Run("", func(t *testing.T) { 370 // Prepare the environment by loading all the appropriate environment variables 371 for _, v := range tc.envVarNames { 372 _ = os.Unsetenv(v) 373 } 374 375 for k := range tc.envVarsToSet { 376 _ = os.Unsetenv(k) 377 } 378 379 for k, v := range tc.envVarsToSet { 380 t.Setenv(k, v) 381 } 382 383 // Run the test 384 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 385 386 // Examine the results 387 if tc.wantErr != (err != nil) { 388 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 389 } 390 391 if got != tc.wantBool { 392 t.Errorf("got %v, want %v", got, tc.wantBool) 393 } 394 }) 395 } 396} 397 398func TestAddDefaultPort(t *testing.T) { 399 tests := []struct { 400 name string 401 input string 402 want string 403 }{ 404 { 405 name: "http no port", 406 input: "http://example.com", 407 want: "http://example.com:80", 408 }, 409 { 410 name: "http custom port", 411 input: "http://example.com:90", 412 want: "http://example.com:90", 413 }, 414 { 415 name: "https no port", 416 input: "https://example.com", 417 want: "https://example.com:443", 418 }, 419 { 420 name: "https custom port", 421 input: "https://example.com:444", 422 want: "https://example.com:444", 423 }, 424 { 425 name: "non-http scheme", 426 input: "ftp://example.com", 427 want: "ftp://example.com", 428 }, 429 { 430 name: "empty string", 431 input: "", 432 want: "", 433 }, 434 { 435 name: "local file path", 436 input: "/etc/hosts", 437 want: "/etc/hosts", 438 }, 439 } 440 441 for _, test := range tests { 442 t.Run(test.name, func(t *testing.T) { 443 input, err := url.Parse(test.input) 444 if err != nil { 445 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 446 } 447 448 got := addDefaultPort(input) 449 if diff := cmp.Diff(test.want, got.String()); diff != "" { 450 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 451 } 452 }) 453 } 454}