fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "sort" 13 "strings" 14 "testing" 15 16 sglog "github.com/sourcegraph/log" 17 "github.com/sourcegraph/log/logtest" 18 "github.com/xeipuuv/gojsonschema" 19 "google.golang.org/grpc" 20 21 "github.com/google/go-cmp/cmp" 22 23 "github.com/sourcegraph/zoekt" 24 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1" 25) 26 27func TestServer_defaultArgs(t *testing.T) { 28 root, err := url.Parse("http://api.test") 29 if err != nil { 30 t.Fatal(err) 31 } 32 33 s := &Server{ 34 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 35 IndexDir: "/testdata/index", 36 CPUCount: 6, 37 IndexConcurrency: 1, 38 } 39 want := &indexArgs{ 40 IndexOptions: IndexOptions{ 41 Name: "testName", 42 }, 43 IndexDir: "/testdata/index", 44 Parallelism: 6, 45 Incremental: true, 46 FileLimit: 1 << 20, 47 } 48 got := s.indexArgs(IndexOptions{Name: "testName"}) 49 if !cmp.Equal(got, want) { 50 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 51 } 52} 53 54func TestServer_parallelism(t *testing.T) { 55 root, err := url.Parse("http://api.test") 56 if err != nil { 57 t.Fatal(err) 58 } 59 60 cases := []struct { 61 name string 62 cpuCount int 63 indexConcurrency int 64 options IndexOptions 65 want int 66 }{ 67 { 68 name: "CPU count divides evenly", 69 cpuCount: 16, 70 indexConcurrency: 8, 71 want: 2, 72 }, 73 { 74 name: "no shard level parallelism", 75 cpuCount: 4, 76 indexConcurrency: 4, 77 want: 1, 78 }, 79 { 80 name: "index option overrides server flag", 81 cpuCount: 2, 82 indexConcurrency: 1, 83 options: IndexOptions{ 84 ShardConcurrency: 1, 85 }, 86 want: 1, 87 }, 88 { 89 name: "ignore invalid index option", 90 cpuCount: 8, 91 indexConcurrency: 2, 92 options: IndexOptions{ 93 ShardConcurrency: -1, 94 }, 95 want: 4, 96 }, 97 } 98 99 for _, tt := range cases { 100 t.Run(tt.name, func(t *testing.T) { 101 s := &Server{ 102 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 103 IndexDir: "/testdata/index", 104 CPUCount: tt.cpuCount, 105 IndexConcurrency: tt.indexConcurrency, 106 } 107 108 maxProcs := 16 109 got := s.parallelism(tt.options, maxProcs) 110 if tt.want != got { 111 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 112 } 113 }) 114 } 115 116 t.Run("index option is limited by available CPU", func(t *testing.T) { 117 s := &Server{ 118 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 119 IndexDir: "/testdata/index", 120 IndexConcurrency: 1, 121 } 122 123 got := s.indexArgs(IndexOptions{ 124 ShardConcurrency: 2048, // Some number that's way too high 125 }) 126 127 if got.Parallelism >= 2048 { 128 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 129 } 130 }) 131} 132 133func TestListRepoIDs(t *testing.T) { 134 grpcClient := &mockGRPCClient{} 135 136 clientOptions := []SourcegraphClientOption{ 137 WithBatchSize(0), 138 } 139 140 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 141 testHostname := "test-hostname" 142 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 143 144 listCalled := false 145 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) { 146 listCalled = true 147 148 gotRepoIDs := in.GetIndexedIds() 149 sort.Slice(gotRepoIDs, func(i, j int) bool { 150 return gotRepoIDs[i] < gotRepoIDs[j] 151 }) 152 153 wantRepoIDs := []int32{1, 3} 154 sort.Slice(wantRepoIDs, func(i, j int) bool { 155 return wantRepoIDs[i] < wantRepoIDs[j] 156 }) 157 158 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 159 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 160 } 161 162 hostname := in.GetHostname() 163 if diff := cmp.Diff(testHostname, hostname); diff != "" { 164 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 165 } 166 167 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 168 } 169 170 ctx := context.Background() 171 got, err := s.List(ctx, []uint32{1, 3}) 172 if err != nil { 173 t.Fatal(err) 174 } 175 176 if !listCalled { 177 t.Fatalf("List was not called") 178 } 179 180 receivedRepoIDs := got.IDs 181 sort.Slice(receivedRepoIDs, func(i, j int) bool { 182 return receivedRepoIDs[i] < receivedRepoIDs[j] 183 }) 184 185 expectedRepoIDs := []uint32{1, 2, 3} 186 sort.Slice(expectedRepoIDs, func(i, j int) bool { 187 return expectedRepoIDs[i] < expectedRepoIDs[j] 188 }) 189 190 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 191 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 192 } 193} 194 195func TestMain(m *testing.M) { 196 flag.Parse() 197 level := sglog.LevelInfo 198 if !testing.Verbose() { 199 log.SetOutput(io.Discard) 200 level = sglog.LevelNone 201 } 202 203 logtest.InitWithLevel(m, level) 204 os.Exit(m.Run()) 205} 206 207func TestCreateEmptyShard(t *testing.T) { 208 dir := t.TempDir() 209 210 args := &indexArgs{ 211 IndexOptions: IndexOptions{ 212 RepoID: 7, 213 Name: "empty-repo", 214 CloneURL: "code/host", 215 }, 216 Incremental: true, 217 IndexDir: dir, 218 Parallelism: 1, 219 FileLimit: 1, 220 } 221 222 if err := createEmptyShard(args); err != nil { 223 t.Fatal(err) 224 } 225 226 bo := args.BuildOptions() 227 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 228 229 if got := bo.IncrementalSkipIndexing(); !got { 230 t.Fatalf("wanted %t, got %t", true, got) 231 } 232} 233 234func TestFormatListUint32(t *testing.T) { 235 cases := []struct { 236 in []uint32 237 want string 238 }{ 239 { 240 in: []uint32{42, 8, 3}, 241 want: "42, 8, ...", 242 }, 243 { 244 in: []uint32{42, 8}, 245 want: "42, 8", 246 }, 247 { 248 in: []uint32{42}, 249 want: "42", 250 }, 251 { 252 in: []uint32{}, 253 want: "", 254 }, 255 } 256 257 for _, tt := range cases { 258 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 259 out := formatListUint32(tt.in, 2) 260 if out != tt.want { 261 t.Fatalf("want %s, got %s", tt.want, out) 262 } 263 }) 264 } 265} 266 267func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 268 wd, err := os.Getwd() 269 if err != nil { 270 t.Fatalf("failed to get working directory: %v", err) 271 } 272 273 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 274 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 275 276 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 277 278 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 279 if err != nil { 280 t.Fatalf("failed to validate default service config: %v", err) 281 } 282 283 if !result.Valid() { 284 var errs strings.Builder 285 for _, err := range result.Errors() { 286 errs.WriteString(fmt.Sprintf("- %s\n", err)) 287 } 288 289 t.Fatalf("default service config is invalid:\n%s", errs.String()) 290 } 291} 292 293func TestGetBoolFromEnvironmentVariables(t *testing.T) { 294 testCases := []struct { 295 name string 296 envVarsToSet map[string]string 297 298 envVarNames []string 299 defaultBool bool 300 301 wantBool bool 302 wantErr bool 303 }{ 304 { 305 name: "respect default value: true", 306 307 envVarsToSet: map[string]string{}, 308 309 envVarNames: []string{"FOO", "BAR"}, 310 defaultBool: true, 311 312 wantBool: true, 313 }, 314 { 315 name: "respect default value: false", 316 317 envVarsToSet: map[string]string{}, 318 319 envVarNames: []string{"FOO", "BAR"}, 320 defaultBool: false, 321 322 wantBool: false, 323 }, 324 { 325 name: "read from environment", 326 327 envVarsToSet: map[string]string{"FOO": "1"}, 328 329 envVarNames: []string{"FOO"}, 330 defaultBool: false, 331 332 wantBool: true, 333 }, 334 { 335 name: "read from first env var that is set", 336 337 envVarsToSet: map[string]string{ 338 "BAR": "false", 339 "BAZ": "true", 340 }, 341 342 envVarNames: []string{"FOO", "BAR", "BAZ"}, 343 defaultBool: true, 344 345 wantBool: false, 346 }, 347 348 { 349 name: "should error for invalid input", 350 351 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 352 353 envVarNames: []string{"INVALID"}, 354 defaultBool: false, 355 356 wantErr: true, 357 }, 358 } 359 360 for _, tc := range testCases { 361 t.Run("", func(t *testing.T) { 362 // Prepare the environment by loading all the appropriate environment variables 363 for _, v := range tc.envVarNames { 364 _ = os.Unsetenv(v) 365 } 366 367 for k := range tc.envVarsToSet { 368 _ = os.Unsetenv(k) 369 } 370 371 for k, v := range tc.envVarsToSet { 372 t.Setenv(k, v) 373 } 374 375 // Run the test 376 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 377 378 // Examine the results 379 if tc.wantErr != (err != nil) { 380 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 381 } 382 383 if got != tc.wantBool { 384 t.Errorf("got %v, want %v", got, tc.wantBool) 385 } 386 }) 387 } 388} 389 390func TestAddDefaultPort(t *testing.T) { 391 tests := []struct { 392 name string 393 input string 394 want string 395 }{ 396 { 397 name: "http no port", 398 input: "http://example.com", 399 want: "http://example.com:80", 400 }, 401 { 402 name: "http custom port", 403 input: "http://example.com:90", 404 want: "http://example.com:90", 405 }, 406 { 407 name: "https no port", 408 input: "https://example.com", 409 want: "https://example.com:443", 410 }, 411 { 412 name: "https custom port", 413 input: "https://example.com:444", 414 want: "https://example.com:444", 415 }, 416 { 417 name: "non-http scheme", 418 input: "ftp://example.com", 419 want: "ftp://example.com", 420 }, 421 { 422 name: "empty string", 423 input: "", 424 want: "", 425 }, 426 { 427 name: "local file path", 428 input: "/etc/hosts", 429 want: "/etc/hosts", 430 }, 431 } 432 433 for _, test := range tests { 434 t.Run(test.name, func(t *testing.T) { 435 input, err := url.Parse(test.input) 436 if err != nil { 437 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 438 } 439 440 got := addDefaultPort(input) 441 if diff := cmp.Diff(test.want, got.String()); diff != "" { 442 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 443 } 444 }) 445 } 446}