fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "slices" 13 "strings" 14 "testing" 15 16 "github.com/google/go-cmp/cmp" 17 sglog "github.com/sourcegraph/log" 18 "github.com/sourcegraph/log/logtest" 19 "github.com/stretchr/testify/require" 20 "github.com/xeipuuv/gojsonschema" 21 "google.golang.org/grpc" 22 23 "github.com/sourcegraph/zoekt" 24 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1" 25 "github.com/sourcegraph/zoekt/internal/tenant" 26) 27 28func TestServer_defaultArgs(t *testing.T) { 29 root, err := url.Parse("http://api.test") 30 if err != nil { 31 t.Fatal(err) 32 } 33 34 s := &Server{ 35 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 36 IndexDir: "/testdata/index", 37 CPUCount: 6, 38 IndexConcurrency: 1, 39 } 40 want := &indexArgs{ 41 IndexOptions: IndexOptions{ 42 Name: "testName", 43 }, 44 IndexDir: "/testdata/index", 45 Parallelism: 6, 46 Incremental: true, 47 } 48 got := s.indexArgs(IndexOptions{Name: "testName"}) 49 if !cmp.Equal(got, want) { 50 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 51 } 52} 53 54func TestIndexNoTenant(t *testing.T) { 55 s := &Server{} 56 _, err := s.index(context.Background(), &indexArgs{}) 57 require.ErrorIs(t, err, tenant.ErrMissingTenant) 58} 59 60func TestServer_parallelism(t *testing.T) { 61 root, err := url.Parse("http://api.test") 62 if err != nil { 63 t.Fatal(err) 64 } 65 66 cases := []struct { 67 name string 68 cpuCount int 69 indexConcurrency int 70 options IndexOptions 71 want int 72 }{ 73 { 74 name: "CPU count divides evenly", 75 cpuCount: 16, 76 indexConcurrency: 8, 77 want: 2, 78 }, 79 { 80 name: "no shard level parallelism", 81 cpuCount: 4, 82 indexConcurrency: 4, 83 want: 1, 84 }, 85 { 86 name: "index option overrides server flag", 87 cpuCount: 2, 88 indexConcurrency: 1, 89 options: IndexOptions{ 90 ShardConcurrency: 1, 91 }, 92 want: 1, 93 }, 94 { 95 name: "ignore invalid index option", 96 cpuCount: 8, 97 indexConcurrency: 2, 98 options: IndexOptions{ 99 ShardConcurrency: -1, 100 }, 101 want: 4, 102 }, 103 } 104 105 for _, tt := range cases { 106 t.Run(tt.name, func(t *testing.T) { 107 s := &Server{ 108 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 109 IndexDir: "/testdata/index", 110 CPUCount: tt.cpuCount, 111 IndexConcurrency: tt.indexConcurrency, 112 } 113 114 maxProcs := 16 115 got := s.parallelism(tt.options, maxProcs) 116 if tt.want != got { 117 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 118 } 119 }) 120 } 121 122 t.Run("index option is limited by available CPU", func(t *testing.T) { 123 s := &Server{ 124 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 125 IndexDir: "/testdata/index", 126 IndexConcurrency: 1, 127 } 128 129 got := s.indexArgs(IndexOptions{ 130 ShardConcurrency: 2048, // Some number that's way too high 131 }) 132 133 if got.Parallelism >= 2048 { 134 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 135 } 136 }) 137} 138 139func TestListRepoIDs(t *testing.T) { 140 grpcClient := &mockGRPCClient{} 141 142 clientOptions := []SourcegraphClientOption{ 143 WithBatchSize(0), 144 } 145 146 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 147 testHostname := "test-hostname" 148 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 149 150 listCalled := false 151 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) { 152 listCalled = true 153 154 gotRepoIDs := in.GetIndexedIds() 155 slices.Sort(gotRepoIDs) 156 157 wantRepoIDs := []int32{1, 3} 158 slices.Sort(wantRepoIDs) 159 160 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 161 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 162 } 163 164 hostname := in.GetHostname() 165 if diff := cmp.Diff(testHostname, hostname); diff != "" { 166 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 167 } 168 169 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 170 } 171 172 ctx := context.Background() 173 got, err := s.List(ctx, []uint32{1, 3}) 174 if err != nil { 175 t.Fatal(err) 176 } 177 178 if !listCalled { 179 t.Fatalf("List was not called") 180 } 181 182 receivedRepoIDs := got.IDs 183 slices.Sort(receivedRepoIDs) 184 185 expectedRepoIDs := []uint32{1, 2, 3} 186 slices.Sort(expectedRepoIDs) 187 188 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 189 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 190 } 191} 192 193func TestMain(m *testing.M) { 194 flag.Parse() 195 level := sglog.LevelInfo 196 if !testing.Verbose() { 197 log.SetOutput(io.Discard) 198 debugLog.SetOutput(io.Discard) 199 infoLog.SetOutput(io.Discard) 200 errorLog.SetOutput(io.Discard) 201 level = sglog.LevelNone 202 } 203 204 logtest.InitWithLevel(m, level) 205 os.Exit(m.Run()) 206} 207 208func TestCreateEmptyShard(t *testing.T) { 209 dir := t.TempDir() 210 211 args := &indexArgs{ 212 IndexOptions: IndexOptions{ 213 RepoID: 7, 214 Name: "empty-repo", 215 CloneURL: "code/host", 216 }, 217 Incremental: true, 218 IndexDir: dir, 219 Parallelism: 1, 220 } 221 222 if err := createEmptyShard(args); err != nil { 223 t.Fatal(err) 224 } 225 226 bo := args.BuildOptions() 227 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 228 229 if got := bo.IncrementalSkipIndexing(); !got { 230 t.Fatalf("wanted %t, got %t", true, got) 231 } 232} 233 234func TestFormatListUint32(t *testing.T) { 235 cases := []struct { 236 in []uint32 237 want string 238 }{ 239 { 240 in: []uint32{42, 8, 3}, 241 want: "42, 8, ...", 242 }, 243 { 244 in: []uint32{42, 8}, 245 want: "42, 8", 246 }, 247 { 248 in: []uint32{42}, 249 want: "42", 250 }, 251 { 252 in: []uint32{}, 253 want: "", 254 }, 255 } 256 257 for _, tt := range cases { 258 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 259 out := formatListUint32(tt.in, 2) 260 if out != tt.want { 261 t.Fatalf("want %s, got %s", tt.want, out) 262 } 263 }) 264 } 265} 266 267func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 268 wd, err := os.Getwd() 269 if err != nil { 270 t.Fatalf("failed to get working directory: %v", err) 271 } 272 273 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 274 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 275 276 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 277 278 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 279 if err != nil { 280 t.Fatalf("failed to validate default service config: %v", err) 281 } 282 283 if !result.Valid() { 284 var errs strings.Builder 285 for _, err := range result.Errors() { 286 errs.WriteString(fmt.Sprintf("- %s\n", err)) 287 } 288 289 t.Fatalf("default service config is invalid:\n%s", errs.String()) 290 } 291} 292 293func TestGetBoolFromEnvironmentVariables(t *testing.T) { 294 testCases := []struct { 295 name string 296 envVarsToSet map[string]string 297 298 envVarNames []string 299 defaultBool bool 300 301 wantBool bool 302 wantErr bool 303 }{ 304 { 305 name: "respect default value: true", 306 307 envVarsToSet: map[string]string{}, 308 309 envVarNames: []string{"FOO", "BAR"}, 310 defaultBool: true, 311 312 wantBool: true, 313 }, 314 { 315 name: "respect default value: false", 316 317 envVarsToSet: map[string]string{}, 318 319 envVarNames: []string{"FOO", "BAR"}, 320 defaultBool: false, 321 322 wantBool: false, 323 }, 324 { 325 name: "read from environment", 326 327 envVarsToSet: map[string]string{"FOO": "1"}, 328 329 envVarNames: []string{"FOO"}, 330 defaultBool: false, 331 332 wantBool: true, 333 }, 334 { 335 name: "read from first env var that is set", 336 337 envVarsToSet: map[string]string{ 338 "BAR": "false", 339 "BAZ": "true", 340 }, 341 342 envVarNames: []string{"FOO", "BAR", "BAZ"}, 343 defaultBool: true, 344 345 wantBool: false, 346 }, 347 348 { 349 name: "should error for invalid input", 350 351 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 352 353 envVarNames: []string{"INVALID"}, 354 defaultBool: false, 355 356 wantErr: true, 357 }, 358 } 359 360 for _, tc := range testCases { 361 t.Run("", func(t *testing.T) { 362 // Prepare the environment by loading all the appropriate environment variables 363 for _, v := range tc.envVarNames { 364 _ = os.Unsetenv(v) 365 } 366 367 for k := range tc.envVarsToSet { 368 _ = os.Unsetenv(k) 369 } 370 371 for k, v := range tc.envVarsToSet { 372 t.Setenv(k, v) 373 } 374 375 // Run the test 376 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 377 378 // Examine the results 379 if tc.wantErr != (err != nil) { 380 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 381 } 382 383 if got != tc.wantBool { 384 t.Errorf("got %v, want %v", got, tc.wantBool) 385 } 386 }) 387 } 388} 389 390func TestAddDefaultPort(t *testing.T) { 391 tests := []struct { 392 name string 393 input string 394 want string 395 }{ 396 { 397 name: "http no port", 398 input: "http://example.com", 399 want: "http://example.com:80", 400 }, 401 { 402 name: "http custom port", 403 input: "http://example.com:90", 404 want: "http://example.com:90", 405 }, 406 { 407 name: "https no port", 408 input: "https://example.com", 409 want: "https://example.com:443", 410 }, 411 { 412 name: "https custom port", 413 input: "https://example.com:444", 414 want: "https://example.com:444", 415 }, 416 { 417 name: "non-http scheme", 418 input: "ftp://example.com", 419 want: "ftp://example.com", 420 }, 421 { 422 name: "empty string", 423 input: "", 424 want: "", 425 }, 426 { 427 name: "local file path", 428 input: "/etc/hosts", 429 want: "/etc/hosts", 430 }, 431 } 432 433 for _, test := range tests { 434 t.Run(test.name, func(t *testing.T) { 435 input, err := url.Parse(test.input) 436 if err != nil { 437 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 438 } 439 440 got := addDefaultPort(input) 441 if diff := cmp.Diff(test.want, got.String()); diff != "" { 442 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 443 } 444 }) 445 } 446}