Fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

On branch main · 10 kB · View raw
1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/url" 10 "os" 11 "path/filepath" 12 "slices" 13 "strings" 14 "testing" 15 "unicode/utf8" 16 17 "github.com/google/go-cmp/cmp" 18 sglog "github.com/sourcegraph/log" 19 "github.com/sourcegraph/log/logtest" 20 "github.com/stretchr/testify/require" 21 "github.com/xeipuuv/gojsonschema" 22 "google.golang.org/grpc" 23 24 "github.com/sourcegraph/zoekt" 25 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1" 26 "github.com/sourcegraph/zoekt/internal/tenant" 27) 28 29func TestServer_defaultArgs(t *testing.T) { 30 root, err := url.Parse("http://api.test") 31 if err != nil { 32 t.Fatal(err) 33 } 34 35 s := &Server{ 36 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 37 IndexDir: "/testdata/index", 38 CPUCount: 6, 39 IndexConcurrency: 1, 40 } 41 want := &indexArgs{ 42 IndexOptions: IndexOptions{ 43 Name: "testName", 44 }, 45 IndexDir: "/testdata/index", 46 Parallelism: 6, 47 Incremental: true, 48 } 49 got := s.indexArgs(IndexOptions{Name: "testName"}) 50 if !cmp.Equal(got, want) { 51 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 52 } 53} 54 55func TestIndexNoTenant(t *testing.T) { 56 s := &Server{} 57 _, err := s.index(context.Background(), &indexArgs{}) 58 require.ErrorIs(t, err, tenant.ErrMissingTenant) 59} 60 61func TestTruncateFailureMessageForSourcegraph(t *testing.T) { 62 t.Run("preserves short message", func(t *testing.T) { 63 require.Equal(t, "boom", truncateFailureMessageForSourcegraph("boom")) 64 }) 65 66 t.Run("truncates oversized utf8 message", func(t *testing.T) { 67 input := strings.Repeat("é", maxFailureMessageBytes) + "tail" 68 69 got := truncateFailureMessageForSourcegraph(input) 70 parts := strings.Split(got, failureMessageTruncationMarker) 71 72 require.Len(t, parts, 2) 73 require.LessOrEqual(t, len(got), maxFailureMessageBytes) 74 require.True(t, utf8.ValidString(got)) 75 
require.NotEmpty(t, parts[0]) 76 require.True(t, strings.HasPrefix(input, parts[0])) 77 require.Contains(t, parts[1], "tail") 78 require.Less(t, len(parts[0])+len(parts[1]), len(input)) 79 }) 80} 81 82func TestServer_parallelism(t *testing.T) { 83 root, err := url.Parse("http://api.test") 84 if err != nil { 85 t.Fatal(err) 86 } 87 88 cases := []struct { 89 name string 90 cpuCount int 91 indexConcurrency int 92 options IndexOptions 93 want int 94 }{ 95 { 96 name: "CPU count divides evenly", 97 cpuCount: 16, 98 indexConcurrency: 8, 99 want: 2, 100 }, 101 { 102 name: "no shard level parallelism", 103 cpuCount: 4, 104 indexConcurrency: 4, 105 want: 1, 106 }, 107 { 108 name: "index option overrides server flag", 109 cpuCount: 2, 110 indexConcurrency: 1, 111 options: IndexOptions{ 112 ShardConcurrency: 1, 113 }, 114 want: 1, 115 }, 116 { 117 name: "ignore invalid index option", 118 cpuCount: 8, 119 indexConcurrency: 2, 120 options: IndexOptions{ 121 ShardConcurrency: -1, 122 }, 123 want: 4, 124 }, 125 } 126 127 for _, tt := range cases { 128 t.Run(tt.name, func(t *testing.T) { 129 s := &Server{ 130 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 131 IndexDir: "/testdata/index", 132 CPUCount: tt.cpuCount, 133 IndexConcurrency: tt.indexConcurrency, 134 } 135 136 maxProcs := 16 137 got := s.parallelism(tt.options, maxProcs) 138 if tt.want != got { 139 t.Errorf("mismatch, want: %d, got: %d", tt.want, got) 140 } 141 }) 142 } 143 144 t.Run("index option is limited by available CPU", func(t *testing.T) { 145 s := &Server{ 146 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)), 147 IndexDir: "/testdata/index", 148 IndexConcurrency: 1, 149 } 150 151 got := s.indexArgs(IndexOptions{ 152 ShardConcurrency: 2048, // Some number that's way too high 153 }) 154 155 if got.Parallelism >= 2048 { 156 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism) 157 } 158 }) 159} 160 161func TestListRepoIDs(t *testing.T) { 
162 grpcClient := &mockGRPCClient{} 163 164 clientOptions := []SourcegraphClientOption{ 165 WithBatchSize(0), 166 } 167 168 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 169 testHostname := "test-hostname" 170 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...) 171 172 listCalled := false 173 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) { 174 listCalled = true 175 176 gotRepoIDs := in.GetIndexedIds() 177 slices.Sort(gotRepoIDs) 178 179 wantRepoIDs := []int32{1, 3} 180 slices.Sort(wantRepoIDs) 181 182 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 183 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 184 } 185 186 hostname := in.GetHostname() 187 if diff := cmp.Diff(testHostname, hostname); diff != "" { 188 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 189 } 190 191 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 192 } 193 194 ctx := context.Background() 195 got, err := s.List(ctx, []uint32{1, 3}) 196 if err != nil { 197 t.Fatal(err) 198 } 199 200 if !listCalled { 201 t.Fatalf("List was not called") 202 } 203 204 receivedRepoIDs := got.IDs 205 slices.Sort(receivedRepoIDs) 206 207 expectedRepoIDs := []uint32{1, 2, 3} 208 slices.Sort(expectedRepoIDs) 209 210 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 211 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 212 } 213} 214 215func TestMain(m *testing.M) { 216 flag.Parse() 217 level := sglog.LevelInfo 218 if !testing.Verbose() { 219 log.SetOutput(io.Discard) 220 debugLog.SetOutput(io.Discard) 221 infoLog.SetOutput(io.Discard) 222 errorLog.SetOutput(io.Discard) 223 level = sglog.LevelNone 224 } 225 226 logtest.InitWithLevel(m, level) 227 os.Exit(m.Run()) 228} 229 230func TestCreateEmptyShard(t *testing.T) { 231 dir := t.TempDir() 232 233 args := &indexArgs{ 234 IndexOptions: IndexOptions{ 235 RepoID: 7, 236 Name: 
"empty-repo", 237 CloneURL: "code/host", 238 }, 239 Incremental: true, 240 IndexDir: dir, 241 Parallelism: 1, 242 } 243 244 if err := createEmptyShard(args); err != nil { 245 t.Fatal(err) 246 } 247 248 bo := args.BuildOptions() 249 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 250 251 if got := bo.IncrementalSkipIndexing(); !got { 252 t.Fatalf("wanted %t, got %t", true, got) 253 } 254} 255 256func TestFormatListUint32(t *testing.T) { 257 cases := []struct { 258 in []uint32 259 want string 260 }{ 261 { 262 in: []uint32{42, 8, 3}, 263 want: "42, 8, ...", 264 }, 265 { 266 in: []uint32{42, 8}, 267 want: "42, 8", 268 }, 269 { 270 in: []uint32{42}, 271 want: "42", 272 }, 273 { 274 in: []uint32{}, 275 want: "", 276 }, 277 } 278 279 for _, tt := range cases { 280 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 281 out := formatListUint32(tt.in, 2) 282 if out != tt.want { 283 t.Fatalf("want %s, got %s", tt.want, out) 284 } 285 }) 286 } 287} 288 289func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 290 wd, err := os.Getwd() 291 if err != nil { 292 t.Fatalf("failed to get working directory: %v", err) 293 } 294 295 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 296 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 297 298 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 299 300 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 301 if err != nil { 302 t.Fatalf("failed to validate default service config: %v", err) 303 } 304 305 if !result.Valid() { 306 var errs strings.Builder 307 for _, err := range result.Errors() { 308 errs.WriteString(fmt.Sprintf("- %s\n", err)) 309 } 310 311 t.Fatalf("default service config is invalid:\n%s", errs.String()) 312 } 313} 314 315func TestGetBoolFromEnvironmentVariables(t *testing.T) { 316 testCases := []struct { 317 name string 318 
envVarsToSet map[string]string 319 320 envVarNames []string 321 defaultBool bool 322 323 wantBool bool 324 wantErr bool 325 }{ 326 { 327 name: "respect default value: true", 328 329 envVarsToSet: map[string]string{}, 330 331 envVarNames: []string{"FOO", "BAR"}, 332 defaultBool: true, 333 334 wantBool: true, 335 }, 336 { 337 name: "respect default value: false", 338 339 envVarsToSet: map[string]string{}, 340 341 envVarNames: []string{"FOO", "BAR"}, 342 defaultBool: false, 343 344 wantBool: false, 345 }, 346 { 347 name: "read from environment", 348 349 envVarsToSet: map[string]string{"FOO": "1"}, 350 351 envVarNames: []string{"FOO"}, 352 defaultBool: false, 353 354 wantBool: true, 355 }, 356 { 357 name: "read from first env var that is set", 358 359 envVarsToSet: map[string]string{ 360 "BAR": "false", 361 "BAZ": "true", 362 }, 363 364 envVarNames: []string{"FOO", "BAR", "BAZ"}, 365 defaultBool: true, 366 367 wantBool: false, 368 }, 369 370 { 371 name: "should error for invalid input", 372 373 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 374 375 envVarNames: []string{"INVALID"}, 376 defaultBool: false, 377 378 wantErr: true, 379 }, 380 } 381 382 for _, tc := range testCases { 383 t.Run("", func(t *testing.T) { 384 // Prepare the environment by loading all the appropriate environment variables 385 for _, v := range tc.envVarNames { 386 _ = os.Unsetenv(v) 387 } 388 389 for k := range tc.envVarsToSet { 390 _ = os.Unsetenv(k) 391 } 392 393 for k, v := range tc.envVarsToSet { 394 t.Setenv(k, v) 395 } 396 397 // Run the test 398 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 399 400 // Examine the results 401 if tc.wantErr != (err != nil) { 402 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 403 } 404 405 if got != tc.wantBool { 406 t.Errorf("got %v, want %v", got, tc.wantBool) 407 } 408 }) 409 } 410} 411 412func TestAddDefaultPort(t *testing.T) { 413 tests := []struct { 414 name string 415 input string 416 want 
string 417 }{ 418 { 419 name: "http no port", 420 input: "http://example.com", 421 want: "http://example.com:80", 422 }, 423 { 424 name: "http custom port", 425 input: "http://example.com:90", 426 want: "http://example.com:90", 427 }, 428 { 429 name: "https no port", 430 input: "https://example.com", 431 want: "https://example.com:443", 432 }, 433 { 434 name: "https custom port", 435 input: "https://example.com:444", 436 want: "https://example.com:444", 437 }, 438 { 439 name: "non-http scheme", 440 input: "ftp://example.com", 441 want: "ftp://example.com", 442 }, 443 { 444 name: "empty string", 445 input: "", 446 want: "", 447 }, 448 { 449 name: "local file path", 450 input: "/etc/hosts", 451 want: "/etc/hosts", 452 }, 453 } 454 455 for _, test := range tests { 456 t.Run(test.name, func(t *testing.T) { 457 input, err := url.Parse(test.input) 458 if err != nil { 459 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 460 } 461 462 got := addDefaultPort(input) 463 if diff := cmp.Diff(test.want, got.String()); diff != "" { 464 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 465 } 466 }) 467 } 468}