Fork of https://github.com/sourcegraph/zoekt.
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "sort"
13 "strings"
14 "testing"
15
16 sglog "github.com/sourcegraph/log"
17 "github.com/sourcegraph/log/logtest"
18 "github.com/xeipuuv/gojsonschema"
19 "google.golang.org/grpc"
20
21 "github.com/google/go-cmp/cmp"
22
23 "github.com/sourcegraph/zoekt"
24 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1"
25)
26
27func TestServer_defaultArgs(t *testing.T) {
28 root, err := url.Parse("http://api.test")
29 if err != nil {
30 t.Fatal(err)
31 }
32
33 s := &Server{
34 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
35 IndexDir: "/testdata/index",
36 CPUCount: 6,
37 IndexConcurrency: 1,
38 }
39 want := &indexArgs{
40 IndexOptions: IndexOptions{
41 Name: "testName",
42 },
43 IndexDir: "/testdata/index",
44 Parallelism: 6,
45 Incremental: true,
46 FileLimit: 1 << 20,
47 }
48 got := s.indexArgs(IndexOptions{Name: "testName"})
49 if !cmp.Equal(got, want) {
50 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
51 }
52}
53
54func TestServer_parallelism(t *testing.T) {
55 root, err := url.Parse("http://api.test")
56 if err != nil {
57 t.Fatal(err)
58 }
59
60 cases := []struct {
61 name string
62 cpuCount int
63 indexConcurrency int
64 options IndexOptions
65 want int
66 }{
67 {
68 name: "CPU count divides evenly",
69 cpuCount: 16,
70 indexConcurrency: 8,
71 want: 2,
72 },
73 {
74 name: "no shard level parallelism",
75 cpuCount: 4,
76 indexConcurrency: 4,
77 want: 1,
78 },
79 {
80 name: "index option overrides server flag",
81 cpuCount: 2,
82 indexConcurrency: 1,
83 options: IndexOptions{
84 ShardConcurrency: 1,
85 },
86 want: 1,
87 },
88 {
89 name: "ignore invalid index option",
90 cpuCount: 8,
91 indexConcurrency: 2,
92 options: IndexOptions{
93 ShardConcurrency: -1,
94 },
95 want: 4,
96 },
97 }
98
99 for _, tt := range cases {
100 t.Run(tt.name, func(t *testing.T) {
101 s := &Server{
102 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
103 IndexDir: "/testdata/index",
104 CPUCount: tt.cpuCount,
105 IndexConcurrency: tt.indexConcurrency,
106 }
107
108 maxProcs := 16
109 got := s.parallelism(tt.options, maxProcs)
110 if tt.want != got {
111 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
112 }
113 })
114 }
115
116 t.Run("index option is limited by available CPU", func(t *testing.T) {
117 s := &Server{
118 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
119 IndexDir: "/testdata/index",
120 IndexConcurrency: 1,
121 }
122
123 got := s.indexArgs(IndexOptions{
124 ShardConcurrency: 2048, // Some number that's way too high
125 })
126
127 if got.Parallelism >= 2048 {
128 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
129 }
130 })
131}
132
133func TestListRepoIDs(t *testing.T) {
134 grpcClient := &mockGRPCClient{}
135
136 clientOptions := []SourcegraphClientOption{
137 WithBatchSize(0),
138 }
139
140 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
141 testHostname := "test-hostname"
142 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
143
144 listCalled := false
145 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) {
146 listCalled = true
147
148 gotRepoIDs := in.GetIndexedIds()
149 sort.Slice(gotRepoIDs, func(i, j int) bool {
150 return gotRepoIDs[i] < gotRepoIDs[j]
151 })
152
153 wantRepoIDs := []int32{1, 3}
154 sort.Slice(wantRepoIDs, func(i, j int) bool {
155 return wantRepoIDs[i] < wantRepoIDs[j]
156 })
157
158 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
159 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
160 }
161
162 hostname := in.GetHostname()
163 if diff := cmp.Diff(testHostname, hostname); diff != "" {
164 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
165 }
166
167 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
168 }
169
170 ctx := context.Background()
171 got, err := s.List(ctx, []uint32{1, 3})
172 if err != nil {
173 t.Fatal(err)
174 }
175
176 if !listCalled {
177 t.Fatalf("List was not called")
178 }
179
180 receivedRepoIDs := got.IDs
181 sort.Slice(receivedRepoIDs, func(i, j int) bool {
182 return receivedRepoIDs[i] < receivedRepoIDs[j]
183 })
184
185 expectedRepoIDs := []uint32{1, 2, 3}
186 sort.Slice(expectedRepoIDs, func(i, j int) bool {
187 return expectedRepoIDs[i] < expectedRepoIDs[j]
188 })
189
190 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
191 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
192 }
193}
194
195func TestMain(m *testing.M) {
196 flag.Parse()
197 level := sglog.LevelInfo
198 if !testing.Verbose() {
199 log.SetOutput(io.Discard)
200 level = sglog.LevelNone
201 }
202
203 logtest.InitWithLevel(m, level)
204 os.Exit(m.Run())
205}
206
207func TestCreateEmptyShard(t *testing.T) {
208 dir := t.TempDir()
209
210 args := &indexArgs{
211 IndexOptions: IndexOptions{
212 RepoID: 7,
213 Name: "empty-repo",
214 CloneURL: "code/host",
215 },
216 Incremental: true,
217 IndexDir: dir,
218 Parallelism: 1,
219 FileLimit: 1,
220 }
221
222 if err := createEmptyShard(args); err != nil {
223 t.Fatal(err)
224 }
225
226 bo := args.BuildOptions()
227 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
228
229 if got := bo.IncrementalSkipIndexing(); !got {
230 t.Fatalf("wanted %t, got %t", true, got)
231 }
232}
233
234func TestFormatListUint32(t *testing.T) {
235 cases := []struct {
236 in []uint32
237 want string
238 }{
239 {
240 in: []uint32{42, 8, 3},
241 want: "42, 8, ...",
242 },
243 {
244 in: []uint32{42, 8},
245 want: "42, 8",
246 },
247 {
248 in: []uint32{42},
249 want: "42",
250 },
251 {
252 in: []uint32{},
253 want: "",
254 },
255 }
256
257 for _, tt := range cases {
258 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
259 out := formatListUint32(tt.in, 2)
260 if out != tt.want {
261 t.Fatalf("want %s, got %s", tt.want, out)
262 }
263 })
264 }
265}
266
267func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
268 wd, err := os.Getwd()
269 if err != nil {
270 t.Fatalf("failed to get working directory: %v", err)
271 }
272
273 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
274 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
275
276 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
277
278 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
279 if err != nil {
280 t.Fatalf("failed to validate default service config: %v", err)
281 }
282
283 if !result.Valid() {
284 var errs strings.Builder
285 for _, err := range result.Errors() {
286 errs.WriteString(fmt.Sprintf("- %s\n", err))
287 }
288
289 t.Fatalf("default service config is invalid:\n%s", errs.String())
290 }
291}
292
293func TestGetBoolFromEnvironmentVariables(t *testing.T) {
294 testCases := []struct {
295 name string
296 envVarsToSet map[string]string
297
298 envVarNames []string
299 defaultBool bool
300
301 wantBool bool
302 wantErr bool
303 }{
304 {
305 name: "respect default value: true",
306
307 envVarsToSet: map[string]string{},
308
309 envVarNames: []string{"FOO", "BAR"},
310 defaultBool: true,
311
312 wantBool: true,
313 },
314 {
315 name: "respect default value: false",
316
317 envVarsToSet: map[string]string{},
318
319 envVarNames: []string{"FOO", "BAR"},
320 defaultBool: false,
321
322 wantBool: false,
323 },
324 {
325 name: "read from environment",
326
327 envVarsToSet: map[string]string{"FOO": "1"},
328
329 envVarNames: []string{"FOO"},
330 defaultBool: false,
331
332 wantBool: true,
333 },
334 {
335 name: "read from first env var that is set",
336
337 envVarsToSet: map[string]string{
338 "BAR": "false",
339 "BAZ": "true",
340 },
341
342 envVarNames: []string{"FOO", "BAR", "BAZ"},
343 defaultBool: true,
344
345 wantBool: false,
346 },
347
348 {
349 name: "should error for invalid input",
350
351 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
352
353 envVarNames: []string{"INVALID"},
354 defaultBool: false,
355
356 wantErr: true,
357 },
358 }
359
360 for _, tc := range testCases {
361 t.Run("", func(t *testing.T) {
362 // Prepare the environment by loading all the appropriate environment variables
363 for _, v := range tc.envVarNames {
364 _ = os.Unsetenv(v)
365 }
366
367 for k := range tc.envVarsToSet {
368 _ = os.Unsetenv(k)
369 }
370
371 for k, v := range tc.envVarsToSet {
372 t.Setenv(k, v)
373 }
374
375 // Run the test
376 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
377
378 // Examine the results
379 if tc.wantErr != (err != nil) {
380 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
381 }
382
383 if got != tc.wantBool {
384 t.Errorf("got %v, want %v", got, tc.wantBool)
385 }
386 })
387 }
388}
389
390func TestAddDefaultPort(t *testing.T) {
391 tests := []struct {
392 name string
393 input string
394 want string
395 }{
396 {
397 name: "http no port",
398 input: "http://example.com",
399 want: "http://example.com:80",
400 },
401 {
402 name: "http custom port",
403 input: "http://example.com:90",
404 want: "http://example.com:90",
405 },
406 {
407 name: "https no port",
408 input: "https://example.com",
409 want: "https://example.com:443",
410 },
411 {
412 name: "https custom port",
413 input: "https://example.com:444",
414 want: "https://example.com:444",
415 },
416 {
417 name: "non-http scheme",
418 input: "ftp://example.com",
419 want: "ftp://example.com",
420 },
421 {
422 name: "empty string",
423 input: "",
424 want: "",
425 },
426 {
427 name: "local file path",
428 input: "/etc/hosts",
429 want: "/etc/hosts",
430 },
431 }
432
433 for _, test := range tests {
434 t.Run(test.name, func(t *testing.T) {
435 input, err := url.Parse(test.input)
436 if err != nil {
437 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
438 }
439
440 got := addDefaultPort(input)
441 if diff := cmp.Diff(test.want, got.String()); diff != "" {
442 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
443 }
444 })
445 }
446}