Fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "strings" 13 "testing" 14 15 sglog "github.com/sourcegraph/log" 16 "github.com/sourcegraph/log/logtest" 17 "github.com/stretchr/testify/require" 18 "github.com/xeipuuv/gojsonschema" 19 "google.golang.org/grpc" 20 21 "github.com/google/go-cmp/cmp" 22 23 "slices" 24 25 "github.com/sourcegraph/zoekt" 26 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1" 27 "github.com/sourcegraph/zoekt/internal/tenant" 28) 29 30func TestServer_defaultArgs(t *testing.T) { 31 root, err := url.Parse("http://api.test") 32 if err != nil { 33 t.Fatal(err) 34 } 35 36 s := &Server{ 37 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 38 IndexDir: "/testdata/index", 39 CPUCount: 6, 40 IndexConcurrency: 1, 41 } 42 want := &indexArgs{ 43 IndexOptions: IndexOptions{ 44 Name: "testName", 45 }, 46 IndexDir: "/testdata/index", 47 Parallelism: 6, 48 Incremental: true, 49 FileLimit: 1 << 20, 50 } 51 got := s.indexArgs(IndexOptions{Name: "testName"}) 52 if !cmp.Equal(got, want) { 53 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 54 } 55} 56 57func TestIndexNoTenant(t *testing.T) { 58 s := &Server{} 59 _, err := s.Index(&indexArgs{}) 60 require.ErrorIs(t, err, tenant.ErrMissingTenant) 61} 62 63func TestServer_parallelism(t *testing.T) { 64 root, err := url.Parse("http://api.test") 65 if err != nil { 66 t.Fatal(err) 67 } 68 69 cases := []struct { 70 name string 71 cpuCount int 72 indexConcurrency int 73 options IndexOptions 74 want int 75 }{ 76 { 77 name: "CPU count divides evenly", 78 cpuCount: 16, 79 indexConcurrency: 8, 80 want: 2, 81 }, 82 { 83 name: "no shard level parallelism", 84 cpuCount: 4, 85 indexConcurrency: 4, 86 want: 1, 87 }, 88 { 89 name: "index option overrides server flag", 90 cpuCount: 2, 91 indexConcurrency: 1, 92 options: IndexOptions{ 93 ShardConcurrency: 1, 94 }, 95 want: 
1, 96 }, 97 { 98 name: "ignore invalid index option", 99 cpuCount: 8, 100 indexConcurrency: 2, 101 options: IndexOptions{ 102 ShardConcurrency: -1, 103 }, 104 want: 4, 105 }, 106 } 107 108 for _, tt := range cases { 109 t.Run(tt.name, func(t *testing.T) { 110 s := &Server{ 111 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 112 IndexDir: "/testdata/index", 113 CPUCount: tt.cpuCount, 114 IndexConcurrency: tt.indexConcurrency, 115 } 116 117 maxProcs := 16 118 got := s.parallelism(tt.options, maxProcs) 119 if tt.want != got { 120 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 121 } 122 }) 123 } 124 125 t.Run("index option is limited by available CPU", func(t *testing.T) { 126 s := &Server{ 127 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 128 IndexDir: "/testdata/index", 129 IndexConcurrency: 1, 130 } 131 132 got := s.indexArgs(IndexOptions{ 133 ShardConcurrency: 2048, // Some number that's way too high 134 }) 135 136 if got.Parallelism >= 2048 { 137 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 138 } 139 }) 140} 141 142func TestListRepoIDs(t *testing.T) { 143 grpcClient := &mockGRPCClient{} 144 145 clientOptions := []SourcegraphClientOption{ 146 WithBatchSize(0), 147 } 148 149 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 150 testHostname := "test-hostname" 151 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 
152 153 listCalled := false 154 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) { 155 listCalled = true 156 157 gotRepoIDs := in.GetIndexedIds() 158 slices.Sort(gotRepoIDs) 159 160 wantRepoIDs := []int32{1, 3} 161 slices.Sort(wantRepoIDs) 162 163 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 164 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 165 } 166 167 hostname := in.GetHostname() 168 if diff := cmp.Diff(testHostname, hostname); diff != "" { 169 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 170 } 171 172 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 173 } 174 175 ctx := context.Background() 176 got, err := s.List(ctx, []uint32{1, 3}) 177 if err != nil { 178 t.Fatal(err) 179 } 180 181 if !listCalled { 182 t.Fatalf("List was not called") 183 } 184 185 receivedRepoIDs := got.IDs 186 slices.Sort(receivedRepoIDs) 187 188 expectedRepoIDs := []uint32{1, 2, 3} 189 slices.Sort(expectedRepoIDs) 190 191 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 192 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 193 } 194} 195 196func TestMain(m *testing.M) { 197 flag.Parse() 198 level := sglog.LevelInfo 199 if !testing.Verbose() { 200 log.SetOutput(io.Discard) 201 debugLog.SetOutput(io.Discard) 202 infoLog.SetOutput(io.Discard) 203 errorLog.SetOutput(io.Discard) 204 level = sglog.LevelNone 205 } 206 207 logtest.InitWithLevel(m, level) 208 os.Exit(m.Run()) 209} 210 211func TestCreateEmptyShard(t *testing.T) { 212 dir := t.TempDir() 213 214 args := &indexArgs{ 215 IndexOptions: IndexOptions{ 216 RepoID: 7, 217 Name: "empty-repo", 218 CloneURL: "code/host", 219 }, 220 Incremental: true, 221 IndexDir: dir, 222 Parallelism: 1, 223 FileLimit: 1, 224 } 225 226 if err := createEmptyShard(args); err != nil { 227 t.Fatal(err) 228 } 229 230 bo := args.BuildOptions() 231 bo.RepositoryDescription.Branches = 
[]zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 232 233 if got := bo.IncrementalSkipIndexing(); !got { 234 t.Fatalf("wanted %t, got %t", true, got) 235 } 236} 237 238func TestFormatListUint32(t *testing.T) { 239 cases := []struct { 240 in []uint32 241 want string 242 }{ 243 { 244 in: []uint32{42, 8, 3}, 245 want: "42, 8, ...", 246 }, 247 { 248 in: []uint32{42, 8}, 249 want: "42, 8", 250 }, 251 { 252 in: []uint32{42}, 253 want: "42", 254 }, 255 { 256 in: []uint32{}, 257 want: "", 258 }, 259 } 260 261 for _, tt := range cases { 262 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 263 out := formatListUint32(tt.in, 2) 264 if out != tt.want { 265 t.Fatalf("want %s, got %s", tt.want, out) 266 } 267 }) 268 } 269} 270 271func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 272 wd, err := os.Getwd() 273 if err != nil { 274 t.Fatalf("failed to get working directory: %v", err) 275 } 276 277 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 278 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 279 280 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 281 282 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 283 if err != nil { 284 t.Fatalf("failed to validate default service config: %v", err) 285 } 286 287 if !result.Valid() { 288 var errs strings.Builder 289 for _, err := range result.Errors() { 290 errs.WriteString(fmt.Sprintf("- %s\n", err)) 291 } 292 293 t.Fatalf("default service config is invalid:\n%s", errs.String()) 294 } 295} 296 297func TestGetBoolFromEnvironmentVariables(t *testing.T) { 298 testCases := []struct { 299 name string 300 envVarsToSet map[string]string 301 302 envVarNames []string 303 defaultBool bool 304 305 wantBool bool 306 wantErr bool 307 }{ 308 { 309 name: "respect default value: true", 310 311 envVarsToSet: map[string]string{}, 312 313 envVarNames: []string{"FOO", "BAR"}, 314 
defaultBool: true, 315 316 wantBool: true, 317 }, 318 { 319 name: "respect default value: false", 320 321 envVarsToSet: map[string]string{}, 322 323 envVarNames: []string{"FOO", "BAR"}, 324 defaultBool: false, 325 326 wantBool: false, 327 }, 328 { 329 name: "read from environment", 330 331 envVarsToSet: map[string]string{"FOO": "1"}, 332 333 envVarNames: []string{"FOO"}, 334 defaultBool: false, 335 336 wantBool: true, 337 }, 338 { 339 name: "read from first env var that is set", 340 341 envVarsToSet: map[string]string{ 342 "BAR": "false", 343 "BAZ": "true", 344 }, 345 346 envVarNames: []string{"FOO", "BAR", "BAZ"}, 347 defaultBool: true, 348 349 wantBool: false, 350 }, 351 352 { 353 name: "should error for invalid input", 354 355 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 356 357 envVarNames: []string{"INVALID"}, 358 defaultBool: false, 359 360 wantErr: true, 361 }, 362 } 363 364 for _, tc := range testCases { 365 t.Run("", func(t *testing.T) { 366 // Prepare the environment by loading all the appropriate environment variables 367 for _, v := range tc.envVarNames { 368 _ = os.Unsetenv(v) 369 } 370 371 for k := range tc.envVarsToSet { 372 _ = os.Unsetenv(k) 373 } 374 375 for k, v := range tc.envVarsToSet { 376 t.Setenv(k, v) 377 } 378 379 // Run the test 380 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 381 382 // Examine the results 383 if tc.wantErr != (err != nil) { 384 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 385 } 386 387 if got != tc.wantBool { 388 t.Errorf("got %v, want %v", got, tc.wantBool) 389 } 390 }) 391 } 392} 393 394func TestAddDefaultPort(t *testing.T) { 395 tests := []struct { 396 name string 397 input string 398 want string 399 }{ 400 { 401 name: "http no port", 402 input: "http://example.com", 403 want: "http://example.com:80", 404 }, 405 { 406 name: "http custom port", 407 input: "http://example.com:90", 408 want: "http://example.com:90", 409 }, 410 { 411 name: "https no 
port", 412 input: "https://example.com", 413 want: "https://example.com:443", 414 }, 415 { 416 name: "https custom port", 417 input: "https://example.com:444", 418 want: "https://example.com:444", 419 }, 420 { 421 name: "non-http scheme", 422 input: "ftp://example.com", 423 want: "ftp://example.com", 424 }, 425 { 426 name: "empty string", 427 input: "", 428 want: "", 429 }, 430 { 431 name: "local file path", 432 input: "/etc/hosts", 433 want: "/etc/hosts", 434 }, 435 } 436 437 for _, test := range tests { 438 t.Run(test.name, func(t *testing.T) { 439 input, err := url.Parse(test.input) 440 if err != nil { 441 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 442 } 443 444 got := addDefaultPort(input) 445 if diff := cmp.Diff(test.want, got.String()); diff != "" { 446 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 447 } 448 }) 449 } 450}