fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "slices" 13 "strings" 14 "testing" 15 16 "github.com/google/go-cmp/cmp" 17 sglog "github.com/sourcegraph/log" 18 "github.com/sourcegraph/log/logtest" 19 "github.com/stretchr/testify/require" 20 "github.com/xeipuuv/gojsonschema" 21 "google.golang.org/grpc" 22 23 "github.com/sourcegraph/zoekt" 24 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1" 25 "github.com/sourcegraph/zoekt/internal/tenant" 26) 27 28func TestServer_defaultArgs(t *testing.T) { 29 root, err := url.Parse("http://api.test") 30 if err != nil { 31 t.Fatal(err) 32 } 33 34 s := &Server{ 35 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 36 IndexDir: "/testdata/index", 37 CPUCount: 6, 38 IndexConcurrency: 1, 39 } 40 want := &indexArgs{ 41 IndexOptions: IndexOptions{ 42 Name: "testName", 43 }, 44 IndexDir: "/testdata/index", 45 Parallelism: 6, 46 Incremental: true, 47 FileLimit: 1 << 20, 48 } 49 got := s.indexArgs(IndexOptions{Name: "testName"}) 50 if !cmp.Equal(got, want) { 51 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 52 } 53} 54 55func TestIndexNoTenant(t *testing.T) { 56 s := &Server{} 57 _, err := s.index(context.Background(), &indexArgs{}) 58 require.ErrorIs(t, err, tenant.ErrMissingTenant) 59} 60 61func TestServer_parallelism(t *testing.T) { 62 root, err := url.Parse("http://api.test") 63 if err != nil { 64 t.Fatal(err) 65 } 66 67 cases := []struct { 68 name string 69 cpuCount int 70 indexConcurrency int 71 options IndexOptions 72 want int 73 }{ 74 { 75 name: "CPU count divides evenly", 76 cpuCount: 16, 77 indexConcurrency: 8, 78 want: 2, 79 }, 80 { 81 name: "no shard level parallelism", 82 cpuCount: 4, 83 indexConcurrency: 4, 84 want: 1, 85 }, 86 { 87 name: "index option overrides server flag", 88 cpuCount: 2, 89 indexConcurrency: 1, 90 options: IndexOptions{ 91 ShardConcurrency: 1, 92 }, 93 want: 1, 94 }, 95 { 96 name: "ignore invalid index option", 97 cpuCount: 8, 98 indexConcurrency: 2, 99 options: IndexOptions{ 100 ShardConcurrency: -1, 101 }, 102 want: 4, 103 }, 104 } 105 106 for _, tt := range cases { 107 t.Run(tt.name, func(t *testing.T) { 108 s := &Server{ 109 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 110 IndexDir: "/testdata/index", 111 CPUCount: tt.cpuCount, 112 IndexConcurrency: tt.indexConcurrency, 113 } 114 115 maxProcs := 16 116 got := s.parallelism(tt.options, maxProcs) 117 if tt.want != got { 118 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 119 } 120 }) 121 } 122 123 t.Run("index option is limited by available CPU", func(t *testing.T) { 124 s := &Server{ 125 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 126 IndexDir: "/testdata/index", 127 IndexConcurrency: 1, 128 } 129 130 got := s.indexArgs(IndexOptions{ 131 ShardConcurrency: 2048, // Some number that's way too high 132 }) 133 134 if got.Parallelism >= 2048 { 135 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 136 } 137 }) 138} 139 140func TestListRepoIDs(t *testing.T) { 141 grpcClient := &mockGRPCClient{} 142 143 clientOptions := []SourcegraphClientOption{ 144 WithBatchSize(0), 145 } 146 147 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 148 testHostname := "test-hostname" 149 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 150 151 listCalled := false 152 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) { 153 listCalled = true 154 155 gotRepoIDs := in.GetIndexedIds() 156 slices.Sort(gotRepoIDs) 157 158 wantRepoIDs := []int32{1, 3} 159 slices.Sort(wantRepoIDs) 160 161 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 162 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 163 } 164 165 hostname := in.GetHostname() 166 if diff := cmp.Diff(testHostname, hostname); diff != "" { 167 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 168 } 169 170 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 171 } 172 173 ctx := context.Background() 174 got, err := s.List(ctx, []uint32{1, 3}) 175 if err != nil { 176 t.Fatal(err) 177 } 178 179 if !listCalled { 180 t.Fatalf("List was not called") 181 } 182 183 receivedRepoIDs := got.IDs 184 slices.Sort(receivedRepoIDs) 185 186 expectedRepoIDs := []uint32{1, 2, 3} 187 slices.Sort(expectedRepoIDs) 188 189 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 190 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 191 } 192} 193 194func TestMain(m *testing.M) { 195 flag.Parse() 196 level := sglog.LevelInfo 197 if !testing.Verbose() { 198 log.SetOutput(io.Discard) 199 debugLog.SetOutput(io.Discard) 200 infoLog.SetOutput(io.Discard) 201 errorLog.SetOutput(io.Discard) 202 level = sglog.LevelNone 203 } 204 205 logtest.InitWithLevel(m, level) 206 os.Exit(m.Run()) 207} 208 209func TestCreateEmptyShard(t *testing.T) { 210 dir := t.TempDir() 211 212 args := &indexArgs{ 213 IndexOptions: IndexOptions{ 214 RepoID: 7, 215 Name: "empty-repo", 216 CloneURL: "code/host", 217 }, 218 Incremental: true, 219 IndexDir: dir, 220 Parallelism: 1, 221 FileLimit: 1, 222 } 223 224 if err := createEmptyShard(args); err != nil { 225 t.Fatal(err) 226 } 227 228 bo := args.BuildOptions() 229 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 230 231 if got := bo.IncrementalSkipIndexing(); !got { 232 t.Fatalf("wanted %t, got %t", true, got) 233 } 234} 235 236func TestFormatListUint32(t *testing.T) { 237 cases := []struct { 238 in []uint32 239 want string 240 }{ 241 { 242 in: []uint32{42, 8, 3}, 243 want: "42, 8, ...", 244 }, 245 { 246 in: []uint32{42, 8}, 247 want: "42, 8", 248 }, 249 { 250 in: []uint32{42}, 251 want: "42", 252 }, 253 { 254 in: []uint32{}, 255 want: "", 256 }, 257 } 258 259 for _, tt := range cases { 260 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 261 out := formatListUint32(tt.in, 2) 262 if out != tt.want { 263 t.Fatalf("want %s, got %s", tt.want, out) 264 } 265 }) 266 } 267} 268 269func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 270 wd, err := os.Getwd() 271 if err != nil { 272 t.Fatalf("failed to get working directory: %v", err) 273 } 274 275 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 276 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 277 278 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 279 280 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 281 if err != nil { 282 t.Fatalf("failed to validate default service config: %v", err) 283 } 284 285 if !result.Valid() { 286 var errs strings.Builder 287 for _, err := range result.Errors() { 288 errs.WriteString(fmt.Sprintf("- %s\n", err)) 289 } 290 291 t.Fatalf("default service config is invalid:\n%s", errs.String()) 292 } 293} 294 295func TestGetBoolFromEnvironmentVariables(t *testing.T) { 296 testCases := []struct { 297 name string 298 envVarsToSet map[string]string 299 300 envVarNames []string 301 defaultBool bool 302 303 wantBool bool 304 wantErr bool 305 }{ 306 { 307 name: "respect default value: true", 308 309 envVarsToSet: map[string]string{}, 310 311 envVarNames: []string{"FOO", "BAR"}, 312 defaultBool: true, 313 314 wantBool: true, 315 }, 316 { 317 name: "respect default value: false", 318 319 envVarsToSet: map[string]string{}, 320 321 envVarNames: []string{"FOO", "BAR"}, 322 defaultBool: false, 323 324 wantBool: false, 325 }, 326 { 327 name: "read from environment", 328 329 envVarsToSet: map[string]string{"FOO": "1"}, 330 331 envVarNames: []string{"FOO"}, 332 defaultBool: false, 333 334 wantBool: true, 335 }, 336 { 337 name: "read from first env var that is set", 338 339 envVarsToSet: map[string]string{ 340 "BAR": "false", 341 "BAZ": "true", 342 }, 343 344 envVarNames: []string{"FOO", "BAR", "BAZ"}, 345 defaultBool: true, 346 347 wantBool: false, 348 }, 349 350 { 351 name: "should error for invalid input", 352 353 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 354 355 envVarNames: []string{"INVALID"}, 356 defaultBool: false, 357 358 wantErr: true, 359 }, 360 } 361 362 for _, tc := range testCases { 363 t.Run("", func(t *testing.T) { 364 // Prepare the environment by loading all the appropriate environment variables 365 for _, v := range tc.envVarNames { 366 _ = os.Unsetenv(v) 367 } 368 369 for k := range tc.envVarsToSet { 370 _ = os.Unsetenv(k) 371 } 372 373 for k, v := range tc.envVarsToSet { 374 t.Setenv(k, v) 375 } 376 377 // Run the test 378 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 379 380 // Examine the results 381 if tc.wantErr != (err != nil) { 382 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 383 } 384 385 if got != tc.wantBool { 386 t.Errorf("got %v, want %v", got, tc.wantBool) 387 } 388 }) 389 } 390} 391 392func TestAddDefaultPort(t *testing.T) { 393 tests := []struct { 394 name string 395 input string 396 want string 397 }{ 398 { 399 name: "http no port", 400 input: "http://example.com", 401 want: "http://example.com:80", 402 }, 403 { 404 name: "http custom port", 405 input: "http://example.com:90", 406 want: "http://example.com:90", 407 }, 408 { 409 name: "https no port", 410 input: "https://example.com", 411 want: "https://example.com:443", 412 }, 413 { 414 name: "https custom port", 415 input: "https://example.com:444", 416 want: "https://example.com:444", 417 }, 418 { 419 name: "non-http scheme", 420 input: "ftp://example.com", 421 want: "ftp://example.com", 422 }, 423 { 424 name: "empty string", 425 input: "", 426 want: "", 427 }, 428 { 429 name: "local file path", 430 input: "/etc/hosts", 431 want: "/etc/hosts", 432 }, 433 } 434 435 for _, test := range tests { 436 t.Run(test.name, func(t *testing.T) { 437 input, err := url.Parse(test.input) 438 if err != nil { 439 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 440 } 441 442 got := addDefaultPort(input) 443 if diff := cmp.Diff(test.want, got.String()); diff != "" { 444 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 445 } 446 }) 447 } 448}