fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

1package main 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/http" 10 "net/http/httptest" 11 "net/url" 12 "os" 13 "path/filepath" 14 "sort" 15 "strings" 16 "testing" 17 18 sglog "github.com/sourcegraph/log" 19 "github.com/sourcegraph/log/logtest" 20 "github.com/xeipuuv/gojsonschema" 21 "google.golang.org/grpc" 22 23 "github.com/google/go-cmp/cmp" 24 25 "github.com/sourcegraph/zoekt" 26 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1" 27) 28 29func TestServer_defaultArgs(t *testing.T) { 30 root, err := url.Parse("http://api.test") 31 if err != nil { 32 t.Fatal(err) 33 } 34 35 s := &Server{ 36 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)), 37 IndexDir: "/testdata/index", 38 CPUCount: 6, 39 } 40 want := &indexArgs{ 41 IndexOptions: IndexOptions{ 42 Name: "testName", 43 }, 44 IndexDir: "/testdata/index", 45 Parallelism: 6, 46 Incremental: true, 47 FileLimit: 1 << 20, 48 } 49 got := s.indexArgs(IndexOptions{Name: "testName"}) 50 if !cmp.Equal(got, want) { 51 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got)) 52 } 53} 54 55func TestListRepoIDs(t *testing.T) { 56 t.Run("gRPC", func(t *testing.T) { 57 58 grpcClient := &mockGRPCClient{} 59 60 clientOptions := []SourcegraphClientOption{ 61 WithShouldUseGRPC(true), 62 WithGRPCClient(grpcClient), 63 WithBatchSize(0), 64 } 65 66 testURL := url.URL{Scheme: "http", Host: "does.not.matter"} 67 testHostname := "test-hostname" 68 s := newSourcegraphClient(&testURL, testHostname, clientOptions...) 69 70 listCalled := false 71 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) { 72 listCalled = true 73 74 gotRepoIDs := in.GetIndexedIds() 75 sort.Slice(gotRepoIDs, func(i, j int) bool { 76 return gotRepoIDs[i] < gotRepoIDs[j] 77 }) 78 79 wantRepoIDs := []int32{1, 3} 80 sort.Slice(wantRepoIDs, func(i, j int) bool { 81 return wantRepoIDs[i] < wantRepoIDs[j] 82 }) 83 84 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" { 85 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff) 86 } 87 88 hostname := in.GetHostname() 89 if diff := cmp.Diff(testHostname, hostname); diff != "" { 90 t.Errorf("hostname mismatch (-want +got):\n%s", diff) 91 } 92 93 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil 94 } 95 96 ctx := context.Background() 97 got, err := s.List(ctx, []uint32{1, 3}) 98 if err != nil { 99 t.Fatal(err) 100 } 101 102 if !listCalled { 103 t.Fatalf("List was not called") 104 } 105 106 receivedRepoIDs := got.IDs 107 sort.Slice(receivedRepoIDs, func(i, j int) bool { 108 return receivedRepoIDs[i] < receivedRepoIDs[j] 109 }) 110 111 expectedRepoIDs := []uint32{1, 2, 3} 112 sort.Slice(expectedRepoIDs, func(i, j int) bool { 113 return expectedRepoIDs[i] < expectedRepoIDs[j] 114 }) 115 116 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" { 117 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff) 118 } 119 }) 120 121 t.Run("REST", func(t *testing.T) { 122 var gotBody string 123 var gotURL *url.URL 124 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 125 gotURL = r.URL 126 127 b, err := io.ReadAll(r.Body) 128 if err != nil { 129 t.Fatal(err) 130 } 131 gotBody = string(b) 132 133 _, err = w.Write([]byte(`{"RepoIDs": [1, 2, 3]}`)) 134 if err != nil { 135 t.Fatal(err) 136 } 137 })) 138 defer ts.Close() 139 140 u, err := url.Parse(ts.URL) 141 if err != nil { 142 t.Fatal(err) 143 } 144 145 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0)) 146 147 gotRepos, err := s.List(context.Background(), []uint32{1, 3}) 148 if err != nil { 149 t.Fatal(err) 150 } 151 152 if want := []uint32{1, 2, 3}; !cmp.Equal(gotRepos.IDs, want) { 153 t.Errorf("repos mismatch (-want +got):\n%s", cmp.Diff(want, gotRepos.IDs)) 154 } 155 if want := `{"Hostname":"test-indexed-search-1","IndexedIDs":[1,3]}`; gotBody != want { 156 t.Errorf("body mismatch (-want +got):\n%s", cmp.Diff(want, gotBody)) 157 } 158 if want := "/.internal/repos/index"; gotURL.Path != want { 159 t.Errorf("request path mismatch (-want +got):\n%s", cmp.Diff(want, gotURL.Path)) 160 } 161 }) 162} 163 164func TestListRepoIDs_Error_REST(t *testing.T) { 165 // Note: There is no gRPC equivalent to this test because gRPC errors are 166 // always returned as an error to the caller. 167 168 msg := "deadbeaf deadbeaf" 169 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 170 171 // This is how Sourcegraph returns error messages to the caller. 172 http.Error(w, msg, http.StatusInternalServerError) 173 })) 174 defer ts.Close() 175 176 u, err := url.Parse(ts.URL) 177 if err != nil { 178 t.Fatal(err) 179 } 180 181 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0)) 182 s.restClient.RetryMax = 0 183 184 _, err = s.List(context.Background(), []uint32{1, 3}) 185 186 if !strings.Contains(err.Error(), msg) { 187 t.Fatalf("%s does not contain %s", err.Error(), msg) 188 } 189} 190 191func TestMain(m *testing.M) { 192 flag.Parse() 193 level := sglog.LevelInfo 194 if !testing.Verbose() { 195 log.SetOutput(io.Discard) 196 level = sglog.LevelNone 197 } 198 199 logtest.InitWithLevel(m, level) 200 os.Exit(m.Run()) 201} 202 203func TestCreateEmptyShard(t *testing.T) { 204 dir := t.TempDir() 205 206 args := &indexArgs{ 207 IndexOptions: IndexOptions{ 208 RepoID: 7, 209 Name: "empty-repo", 210 CloneURL: "code/host", 211 }, 212 Incremental: true, 213 IndexDir: dir, 214 Parallelism: 1, 215 FileLimit: 1, 216 } 217 218 if err := createEmptyShard(args); err != nil { 219 t.Fatal(err) 220 } 221 222 bo := args.BuildOptions() 223 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}} 224 225 if got := bo.IncrementalSkipIndexing(); !got { 226 t.Fatalf("wanted %t, got %t", true, got) 227 } 228} 229 230func TestFormatListUint32(t *testing.T) { 231 cases := []struct { 232 in []uint32 233 want string 234 }{ 235 { 236 in: []uint32{42, 8, 3}, 237 want: "42, 8, ...", 238 }, 239 { 240 in: []uint32{42, 8}, 241 want: "42, 8", 242 }, 243 { 244 in: []uint32{42}, 245 want: "42", 246 }, 247 { 248 in: []uint32{}, 249 want: "", 250 }, 251 } 252 253 for _, tt := range cases { 254 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { 255 out := formatListUint32(tt.in, 2) 256 if out != tt.want { 257 t.Fatalf("want %s, got %s", tt.want, out) 258 } 259 }) 260 } 261} 262 263func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { 264 wd, err := os.Getwd() 265 if err != nil { 266 t.Fatalf("failed to get working directory: %v", err) 267 } 268 269 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json") 270 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile)) 271 272 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON) 273 274 result, err := gojsonschema.Validate(schemaLoader, documentLoader) 275 if err != nil { 276 t.Fatalf("failed to validate default service config: %v", err) 277 } 278 279 if !result.Valid() { 280 var errs strings.Builder 281 for _, err := range result.Errors() { 282 errs.WriteString(fmt.Sprintf("- %s\n", err)) 283 } 284 285 t.Fatalf("default service config is invalid:\n%s", errs.String()) 286 } 287} 288 289func TestGetBoolFromEnvironmentVariables(t *testing.T) { 290 testCases := []struct { 291 name string 292 envVarsToSet map[string]string 293 294 envVarNames []string 295 defaultBool bool 296 297 wantBool bool 298 wantErr bool 299 }{ 300 { 301 name: "respect default value: true", 302 303 envVarsToSet: map[string]string{}, 304 305 envVarNames: []string{"FOO", "BAR"}, 306 defaultBool: true, 307 308 wantBool: true, 309 }, 310 { 311 name: "respect default value: false", 312 313 envVarsToSet: map[string]string{}, 314 315 envVarNames: []string{"FOO", "BAR"}, 316 defaultBool: false, 317 318 wantBool: false, 319 }, 320 { 321 name: "read from environment", 322 323 envVarsToSet: map[string]string{"FOO": "1"}, 324 325 envVarNames: []string{"FOO"}, 326 defaultBool: false, 327 328 wantBool: true, 329 }, 330 { 331 name: "read from first env var that is set", 332 333 envVarsToSet: map[string]string{ 334 "BAR": "false", 335 "BAZ": "true", 336 }, 337 338 envVarNames: []string{"FOO", "BAR", "BAZ"}, 339 defaultBool: true, 340 341 wantBool: false, 342 }, 343 344 { 345 name: "should error for invalid input", 346 347 envVarsToSet: map[string]string{"INVALID": "not a boolean"}, 348 349 envVarNames: []string{"INVALID"}, 350 defaultBool: false, 351 352 wantErr: true, 353 }, 354 } 355 356 for _, tc := range testCases { 357 t.Run("", func(t *testing.T) { 358 // Prepare the environment by loading all the appropriate environment variables 359 for _, v := range tc.envVarNames { 360 _ = os.Unsetenv(v) 361 } 362 363 for k, _ := range tc.envVarsToSet { 364 _ = os.Unsetenv(k) 365 } 366 367 for k, v := range tc.envVarsToSet { 368 t.Setenv(k, v) 369 } 370 371 // Run the test 372 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) 373 374 // Examine the results 375 if tc.wantErr != (err != nil) { 376 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) 377 } 378 379 if got != tc.wantBool { 380 t.Errorf("got %v, want %v", got, tc.wantBool) 381 } 382 }) 383 } 384} 385 386func TestAddDefaultPort(t *testing.T) { 387 tests := []struct { 388 name string 389 input string 390 want string 391 }{ 392 { 393 name: "http no port", 394 input: "http://example.com", 395 want: "http://example.com:80", 396 }, 397 { 398 name: "http custom port", 399 input: "http://example.com:90", 400 want: "http://example.com:90", 401 }, 402 { 403 name: "https no port", 404 input: "https://example.com", 405 want: "https://example.com:443", 406 }, 407 { 408 name: "https custom port", 409 input: "https://example.com:444", 410 want: "https://example.com:444", 411 }, 412 { 413 name: "non-http scheme", 414 input: "ftp://example.com", 415 want: "ftp://example.com", 416 }, 417 { 418 name: "empty string", 419 input: "", 420 want: "", 421 }, 422 { 423 name: "local file path", 424 input: "/etc/hosts", 425 want: "/etc/hosts", 426 }, 427 } 428 429 for _, test := range tests { 430 t.Run(test.name, func(t *testing.T) { 431 input, err := url.Parse(test.input) 432 if err != nil { 433 t.Fatalf("failed to parse test URL %q: %v", test.input, err) 434 } 435 436 got := addDefaultPort(input) 437 if diff := cmp.Diff(test.want, got.String()); diff != "" { 438 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff) 439 } 440 }) 441 } 442}