// Fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/http"
10 "net/http/httptest"
11 "net/url"
12 "os"
13 "path/filepath"
14 "sort"
15 "strings"
16 "testing"
17
18 sglog "github.com/sourcegraph/log"
19 "github.com/sourcegraph/log/logtest"
20 "github.com/xeipuuv/gojsonschema"
21 "google.golang.org/grpc"
22
23 "github.com/google/go-cmp/cmp"
24
25 "github.com/sourcegraph/zoekt"
26 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1"
27)
28
29func TestServer_defaultArgs(t *testing.T) {
30 root, err := url.Parse("http://api.test")
31 if err != nil {
32 t.Fatal(err)
33 }
34
35 s := &Server{
36 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)),
37 IndexDir: "/testdata/index",
38 CPUCount: 6,
39 IndexConcurrency: 1,
40 }
41 want := &indexArgs{
42 IndexOptions: IndexOptions{
43 Name: "testName",
44 },
45 IndexDir: "/testdata/index",
46 Parallelism: 6,
47 Incremental: true,
48 FileLimit: 1 << 20,
49 }
50 got := s.indexArgs(IndexOptions{Name: "testName"})
51 if !cmp.Equal(got, want) {
52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
53 }
54}
55
56func TestServer_parallelism(t *testing.T) {
57 root, err := url.Parse("http://api.test")
58 if err != nil {
59 t.Fatal(err)
60 }
61
62 cases := []struct {
63 name string
64 cpuCount int
65 indexConcurrency int
66 options IndexOptions
67 want int
68 }{
69 {
70 name: "CPU count divides evenly",
71 cpuCount: 16,
72 indexConcurrency: 8,
73 want: 2,
74 },
75 {
76 name: "no shard level parallelism",
77 cpuCount: 4,
78 indexConcurrency: 4,
79 want: 1,
80 },
81 {
82 name: "index option overrides server flag",
83 cpuCount: 2,
84 indexConcurrency: 1,
85 options: IndexOptions {
86 ShardConcurrency: 1,
87 },
88 want: 1,
89 },
90 {
91 name: "ignore invalid index option",
92 cpuCount: 8,
93 indexConcurrency: 2,
94 options: IndexOptions {
95 ShardConcurrency: -1,
96 },
97 want: 4,
98 },
99 }
100
101 for _, tt := range cases {
102 t.Run(tt.name, func(t *testing.T) {
103 s := &Server{
104 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)),
105 IndexDir: "/testdata/index",
106 CPUCount: tt.cpuCount,
107 IndexConcurrency: tt.indexConcurrency,
108 }
109
110 maxProcs := 16
111 got := s.parallelism(tt.options, maxProcs)
112 if tt.want != got{
113 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
114 }
115 })
116 }
117
118 t.Run("index option is limited by available CPU", func(t *testing.T) {
119 s := &Server{
120 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)),
121 IndexDir: "/testdata/index",
122 IndexConcurrency: 1,
123 }
124
125 got := s.indexArgs(IndexOptions {
126 ShardConcurrency: 2048, // Some number that's way too high
127 })
128
129 if got.Parallelism >= 2048 {
130 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
131 }
132 })
133}
134
135func TestListRepoIDs(t *testing.T) {
136 t.Run("gRPC", func(t *testing.T) {
137
138 grpcClient := &mockGRPCClient{}
139
140 clientOptions := []SourcegraphClientOption{
141 WithShouldUseGRPC(true),
142 WithGRPCClient(grpcClient),
143 WithBatchSize(0),
144 }
145
146 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
147 testHostname := "test-hostname"
148 s := newSourcegraphClient(&testURL, testHostname, clientOptions...)
149
150 listCalled := false
151 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) {
152 listCalled = true
153
154 gotRepoIDs := in.GetIndexedIds()
155 sort.Slice(gotRepoIDs, func(i, j int) bool {
156 return gotRepoIDs[i] < gotRepoIDs[j]
157 })
158
159 wantRepoIDs := []int32{1, 3}
160 sort.Slice(wantRepoIDs, func(i, j int) bool {
161 return wantRepoIDs[i] < wantRepoIDs[j]
162 })
163
164 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
165 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
166 }
167
168 hostname := in.GetHostname()
169 if diff := cmp.Diff(testHostname, hostname); diff != "" {
170 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
171 }
172
173 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
174 }
175
176 ctx := context.Background()
177 got, err := s.List(ctx, []uint32{1, 3})
178 if err != nil {
179 t.Fatal(err)
180 }
181
182 if !listCalled {
183 t.Fatalf("List was not called")
184 }
185
186 receivedRepoIDs := got.IDs
187 sort.Slice(receivedRepoIDs, func(i, j int) bool {
188 return receivedRepoIDs[i] < receivedRepoIDs[j]
189 })
190
191 expectedRepoIDs := []uint32{1, 2, 3}
192 sort.Slice(expectedRepoIDs, func(i, j int) bool {
193 return expectedRepoIDs[i] < expectedRepoIDs[j]
194 })
195
196 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
197 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
198 }
199 })
200
201 t.Run("REST", func(t *testing.T) {
202 var gotBody string
203 var gotURL *url.URL
204 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
205 gotURL = r.URL
206
207 b, err := io.ReadAll(r.Body)
208 if err != nil {
209 t.Fatal(err)
210 }
211 gotBody = string(b)
212
213 _, err = w.Write([]byte(`{"RepoIDs": [1, 2, 3]}`))
214 if err != nil {
215 t.Fatal(err)
216 }
217 }))
218 defer ts.Close()
219
220 u, err := url.Parse(ts.URL)
221 if err != nil {
222 t.Fatal(err)
223 }
224
225 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0))
226
227 gotRepos, err := s.List(context.Background(), []uint32{1, 3})
228 if err != nil {
229 t.Fatal(err)
230 }
231
232 if want := []uint32{1, 2, 3}; !cmp.Equal(gotRepos.IDs, want) {
233 t.Errorf("repos mismatch (-want +got):\n%s", cmp.Diff(want, gotRepos.IDs))
234 }
235 if want := `{"Hostname":"test-indexed-search-1","IndexedIDs":[1,3]}`; gotBody != want {
236 t.Errorf("body mismatch (-want +got):\n%s", cmp.Diff(want, gotBody))
237 }
238 if want := "/.internal/repos/index"; gotURL.Path != want {
239 t.Errorf("request path mismatch (-want +got):\n%s", cmp.Diff(want, gotURL.Path))
240 }
241 })
242}
243
244func TestListRepoIDs_Error_REST(t *testing.T) {
245 // Note: There is no gRPC equivalent to this test because gRPC errors are
246 // always returned as an error to the caller.
247
248 msg := "deadbeaf deadbeaf"
249 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
250
251 // This is how Sourcegraph returns error messages to the caller.
252 http.Error(w, msg, http.StatusInternalServerError)
253 }))
254 defer ts.Close()
255
256 u, err := url.Parse(ts.URL)
257 if err != nil {
258 t.Fatal(err)
259 }
260
261 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0))
262 s.restClient.RetryMax = 0
263
264 _, err = s.List(context.Background(), []uint32{1, 3})
265
266 if !strings.Contains(err.Error(), msg) {
267 t.Fatalf("%s does not contain %s", err.Error(), msg)
268 }
269}
270
271func TestMain(m *testing.M) {
272 flag.Parse()
273 level := sglog.LevelInfo
274 if !testing.Verbose() {
275 log.SetOutput(io.Discard)
276 level = sglog.LevelNone
277 }
278
279 logtest.InitWithLevel(m, level)
280 os.Exit(m.Run())
281}
282
283func TestCreateEmptyShard(t *testing.T) {
284 dir := t.TempDir()
285
286 args := &indexArgs{
287 IndexOptions: IndexOptions{
288 RepoID: 7,
289 Name: "empty-repo",
290 CloneURL: "code/host",
291 },
292 Incremental: true,
293 IndexDir: dir,
294 Parallelism: 1,
295 FileLimit: 1,
296 }
297
298 if err := createEmptyShard(args); err != nil {
299 t.Fatal(err)
300 }
301
302 bo := args.BuildOptions()
303 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
304
305 if got := bo.IncrementalSkipIndexing(); !got {
306 t.Fatalf("wanted %t, got %t", true, got)
307 }
308}
309
310func TestFormatListUint32(t *testing.T) {
311 cases := []struct {
312 in []uint32
313 want string
314 }{
315 {
316 in: []uint32{42, 8, 3},
317 want: "42, 8, ...",
318 },
319 {
320 in: []uint32{42, 8},
321 want: "42, 8",
322 },
323 {
324 in: []uint32{42},
325 want: "42",
326 },
327 {
328 in: []uint32{},
329 want: "",
330 },
331 }
332
333 for _, tt := range cases {
334 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
335 out := formatListUint32(tt.in, 2)
336 if out != tt.want {
337 t.Fatalf("want %s, got %s", tt.want, out)
338 }
339 })
340 }
341}
342
343func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
344 wd, err := os.Getwd()
345 if err != nil {
346 t.Fatalf("failed to get working directory: %v", err)
347 }
348
349 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
350 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
351
352 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
353
354 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
355 if err != nil {
356 t.Fatalf("failed to validate default service config: %v", err)
357 }
358
359 if !result.Valid() {
360 var errs strings.Builder
361 for _, err := range result.Errors() {
362 errs.WriteString(fmt.Sprintf("- %s\n", err))
363 }
364
365 t.Fatalf("default service config is invalid:\n%s", errs.String())
366 }
367}
368
369func TestGetBoolFromEnvironmentVariables(t *testing.T) {
370 testCases := []struct {
371 name string
372 envVarsToSet map[string]string
373
374 envVarNames []string
375 defaultBool bool
376
377 wantBool bool
378 wantErr bool
379 }{
380 {
381 name: "respect default value: true",
382
383 envVarsToSet: map[string]string{},
384
385 envVarNames: []string{"FOO", "BAR"},
386 defaultBool: true,
387
388 wantBool: true,
389 },
390 {
391 name: "respect default value: false",
392
393 envVarsToSet: map[string]string{},
394
395 envVarNames: []string{"FOO", "BAR"},
396 defaultBool: false,
397
398 wantBool: false,
399 },
400 {
401 name: "read from environment",
402
403 envVarsToSet: map[string]string{"FOO": "1"},
404
405 envVarNames: []string{"FOO"},
406 defaultBool: false,
407
408 wantBool: true,
409 },
410 {
411 name: "read from first env var that is set",
412
413 envVarsToSet: map[string]string{
414 "BAR": "false",
415 "BAZ": "true",
416 },
417
418 envVarNames: []string{"FOO", "BAR", "BAZ"},
419 defaultBool: true,
420
421 wantBool: false,
422 },
423
424 {
425 name: "should error for invalid input",
426
427 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
428
429 envVarNames: []string{"INVALID"},
430 defaultBool: false,
431
432 wantErr: true,
433 },
434 }
435
436 for _, tc := range testCases {
437 t.Run("", func(t *testing.T) {
438 // Prepare the environment by loading all the appropriate environment variables
439 for _, v := range tc.envVarNames {
440 _ = os.Unsetenv(v)
441 }
442
443 for k, _ := range tc.envVarsToSet {
444 _ = os.Unsetenv(k)
445 }
446
447 for k, v := range tc.envVarsToSet {
448 t.Setenv(k, v)
449 }
450
451 // Run the test
452 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
453
454 // Examine the results
455 if tc.wantErr != (err != nil) {
456 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
457 }
458
459 if got != tc.wantBool {
460 t.Errorf("got %v, want %v", got, tc.wantBool)
461 }
462 })
463 }
464}
465
466func TestAddDefaultPort(t *testing.T) {
467 tests := []struct {
468 name string
469 input string
470 want string
471 }{
472 {
473 name: "http no port",
474 input: "http://example.com",
475 want: "http://example.com:80",
476 },
477 {
478 name: "http custom port",
479 input: "http://example.com:90",
480 want: "http://example.com:90",
481 },
482 {
483 name: "https no port",
484 input: "https://example.com",
485 want: "https://example.com:443",
486 },
487 {
488 name: "https custom port",
489 input: "https://example.com:444",
490 want: "https://example.com:444",
491 },
492 {
493 name: "non-http scheme",
494 input: "ftp://example.com",
495 want: "ftp://example.com",
496 },
497 {
498 name: "empty string",
499 input: "",
500 want: "",
501 },
502 {
503 name: "local file path",
504 input: "/etc/hosts",
505 want: "/etc/hosts",
506 },
507 }
508
509 for _, test := range tests {
510 t.Run(test.name, func(t *testing.T) {
511 input, err := url.Parse(test.input)
512 if err != nil {
513 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
514 }
515
516 got := addDefaultPort(input)
517 if diff := cmp.Diff(test.want, got.String()); diff != "" {
518 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
519 }
520 })
521 }
522}