Fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "sort"
13 "strings"
14 "testing"
15
16 sglog "github.com/sourcegraph/log"
17 "github.com/sourcegraph/log/logtest"
18 "github.com/stretchr/testify/require"
19 "github.com/xeipuuv/gojsonschema"
20 "google.golang.org/grpc"
21
22 "github.com/google/go-cmp/cmp"
23
24 "github.com/sourcegraph/zoekt"
25 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1"
26 "github.com/sourcegraph/zoekt/internal/tenant"
27)
28
29func TestServer_defaultArgs(t *testing.T) {
30 root, err := url.Parse("http://api.test")
31 if err != nil {
32 t.Fatal(err)
33 }
34
35 s := &Server{
36 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
37 IndexDir: "/testdata/index",
38 CPUCount: 6,
39 IndexConcurrency: 1,
40 }
41 want := &indexArgs{
42 IndexOptions: IndexOptions{
43 Name: "testName",
44 },
45 IndexDir: "/testdata/index",
46 Parallelism: 6,
47 Incremental: true,
48 FileLimit: 1 << 20,
49 }
50 got := s.indexArgs(IndexOptions{Name: "testName"})
51 if !cmp.Equal(got, want) {
52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
53 }
54}
55
56func TestIndexNoTenant(t *testing.T) {
57 s := &Server{}
58 _, err := s.Index(&indexArgs{})
59 require.ErrorIs(t, err, tenant.ErrMissingTenant)
60}
61
62func TestServer_parallelism(t *testing.T) {
63 root, err := url.Parse("http://api.test")
64 if err != nil {
65 t.Fatal(err)
66 }
67
68 cases := []struct {
69 name string
70 cpuCount int
71 indexConcurrency int
72 options IndexOptions
73 want int
74 }{
75 {
76 name: "CPU count divides evenly",
77 cpuCount: 16,
78 indexConcurrency: 8,
79 want: 2,
80 },
81 {
82 name: "no shard level parallelism",
83 cpuCount: 4,
84 indexConcurrency: 4,
85 want: 1,
86 },
87 {
88 name: "index option overrides server flag",
89 cpuCount: 2,
90 indexConcurrency: 1,
91 options: IndexOptions{
92 ShardConcurrency: 1,
93 },
94 want: 1,
95 },
96 {
97 name: "ignore invalid index option",
98 cpuCount: 8,
99 indexConcurrency: 2,
100 options: IndexOptions{
101 ShardConcurrency: -1,
102 },
103 want: 4,
104 },
105 }
106
107 for _, tt := range cases {
108 t.Run(tt.name, func(t *testing.T) {
109 s := &Server{
110 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
111 IndexDir: "/testdata/index",
112 CPUCount: tt.cpuCount,
113 IndexConcurrency: tt.indexConcurrency,
114 }
115
116 maxProcs := 16
117 got := s.parallelism(tt.options, maxProcs)
118 if tt.want != got {
119 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
120 }
121 })
122 }
123
124 t.Run("index option is limited by available CPU", func(t *testing.T) {
125 s := &Server{
126 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
127 IndexDir: "/testdata/index",
128 IndexConcurrency: 1,
129 }
130
131 got := s.indexArgs(IndexOptions{
132 ShardConcurrency: 2048, // Some number that's way too high
133 })
134
135 if got.Parallelism >= 2048 {
136 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
137 }
138 })
139}
140
141func TestListRepoIDs(t *testing.T) {
142 grpcClient := &mockGRPCClient{}
143
144 clientOptions := []SourcegraphClientOption{
145 WithBatchSize(0),
146 }
147
148 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
149 testHostname := "test-hostname"
150 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
151
152 listCalled := false
153 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) {
154 listCalled = true
155
156 gotRepoIDs := in.GetIndexedIds()
157 sort.Slice(gotRepoIDs, func(i, j int) bool {
158 return gotRepoIDs[i] < gotRepoIDs[j]
159 })
160
161 wantRepoIDs := []int32{1, 3}
162 sort.Slice(wantRepoIDs, func(i, j int) bool {
163 return wantRepoIDs[i] < wantRepoIDs[j]
164 })
165
166 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
167 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
168 }
169
170 hostname := in.GetHostname()
171 if diff := cmp.Diff(testHostname, hostname); diff != "" {
172 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
173 }
174
175 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
176 }
177
178 ctx := context.Background()
179 got, err := s.List(ctx, []uint32{1, 3})
180 if err != nil {
181 t.Fatal(err)
182 }
183
184 if !listCalled {
185 t.Fatalf("List was not called")
186 }
187
188 receivedRepoIDs := got.IDs
189 sort.Slice(receivedRepoIDs, func(i, j int) bool {
190 return receivedRepoIDs[i] < receivedRepoIDs[j]
191 })
192
193 expectedRepoIDs := []uint32{1, 2, 3}
194 sort.Slice(expectedRepoIDs, func(i, j int) bool {
195 return expectedRepoIDs[i] < expectedRepoIDs[j]
196 })
197
198 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
199 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
200 }
201}
202
203func TestMain(m *testing.M) {
204 flag.Parse()
205 level := sglog.LevelInfo
206 if !testing.Verbose() {
207 log.SetOutput(io.Discard)
208 debugLog.SetOutput(io.Discard)
209 infoLog.SetOutput(io.Discard)
210 errorLog.SetOutput(io.Discard)
211 level = sglog.LevelNone
212 }
213
214 logtest.InitWithLevel(m, level)
215 os.Exit(m.Run())
216}
217
218func TestCreateEmptyShard(t *testing.T) {
219 dir := t.TempDir()
220
221 args := &indexArgs{
222 IndexOptions: IndexOptions{
223 RepoID: 7,
224 Name: "empty-repo",
225 CloneURL: "code/host",
226 },
227 Incremental: true,
228 IndexDir: dir,
229 Parallelism: 1,
230 FileLimit: 1,
231 }
232
233 if err := createEmptyShard(args); err != nil {
234 t.Fatal(err)
235 }
236
237 bo := args.BuildOptions()
238 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
239
240 if got := bo.IncrementalSkipIndexing(); !got {
241 t.Fatalf("wanted %t, got %t", true, got)
242 }
243}
244
245func TestFormatListUint32(t *testing.T) {
246 cases := []struct {
247 in []uint32
248 want string
249 }{
250 {
251 in: []uint32{42, 8, 3},
252 want: "42, 8, ...",
253 },
254 {
255 in: []uint32{42, 8},
256 want: "42, 8",
257 },
258 {
259 in: []uint32{42},
260 want: "42",
261 },
262 {
263 in: []uint32{},
264 want: "",
265 },
266 }
267
268 for _, tt := range cases {
269 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
270 out := formatListUint32(tt.in, 2)
271 if out != tt.want {
272 t.Fatalf("want %s, got %s", tt.want, out)
273 }
274 })
275 }
276}
277
278func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
279 wd, err := os.Getwd()
280 if err != nil {
281 t.Fatalf("failed to get working directory: %v", err)
282 }
283
284 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
285 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
286
287 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
288
289 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
290 if err != nil {
291 t.Fatalf("failed to validate default service config: %v", err)
292 }
293
294 if !result.Valid() {
295 var errs strings.Builder
296 for _, err := range result.Errors() {
297 errs.WriteString(fmt.Sprintf("- %s\n", err))
298 }
299
300 t.Fatalf("default service config is invalid:\n%s", errs.String())
301 }
302}
303
304func TestGetBoolFromEnvironmentVariables(t *testing.T) {
305 testCases := []struct {
306 name string
307 envVarsToSet map[string]string
308
309 envVarNames []string
310 defaultBool bool
311
312 wantBool bool
313 wantErr bool
314 }{
315 {
316 name: "respect default value: true",
317
318 envVarsToSet: map[string]string{},
319
320 envVarNames: []string{"FOO", "BAR"},
321 defaultBool: true,
322
323 wantBool: true,
324 },
325 {
326 name: "respect default value: false",
327
328 envVarsToSet: map[string]string{},
329
330 envVarNames: []string{"FOO", "BAR"},
331 defaultBool: false,
332
333 wantBool: false,
334 },
335 {
336 name: "read from environment",
337
338 envVarsToSet: map[string]string{"FOO": "1"},
339
340 envVarNames: []string{"FOO"},
341 defaultBool: false,
342
343 wantBool: true,
344 },
345 {
346 name: "read from first env var that is set",
347
348 envVarsToSet: map[string]string{
349 "BAR": "false",
350 "BAZ": "true",
351 },
352
353 envVarNames: []string{"FOO", "BAR", "BAZ"},
354 defaultBool: true,
355
356 wantBool: false,
357 },
358
359 {
360 name: "should error for invalid input",
361
362 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
363
364 envVarNames: []string{"INVALID"},
365 defaultBool: false,
366
367 wantErr: true,
368 },
369 }
370
371 for _, tc := range testCases {
372 t.Run("", func(t *testing.T) {
373 // Prepare the environment by loading all the appropriate environment variables
374 for _, v := range tc.envVarNames {
375 _ = os.Unsetenv(v)
376 }
377
378 for k := range tc.envVarsToSet {
379 _ = os.Unsetenv(k)
380 }
381
382 for k, v := range tc.envVarsToSet {
383 t.Setenv(k, v)
384 }
385
386 // Run the test
387 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
388
389 // Examine the results
390 if tc.wantErr != (err != nil) {
391 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
392 }
393
394 if got != tc.wantBool {
395 t.Errorf("got %v, want %v", got, tc.wantBool)
396 }
397 })
398 }
399}
400
401func TestAddDefaultPort(t *testing.T) {
402 tests := []struct {
403 name string
404 input string
405 want string
406 }{
407 {
408 name: "http no port",
409 input: "http://example.com",
410 want: "http://example.com:80",
411 },
412 {
413 name: "http custom port",
414 input: "http://example.com:90",
415 want: "http://example.com:90",
416 },
417 {
418 name: "https no port",
419 input: "https://example.com",
420 want: "https://example.com:443",
421 },
422 {
423 name: "https custom port",
424 input: "https://example.com:444",
425 want: "https://example.com:444",
426 },
427 {
428 name: "non-http scheme",
429 input: "ftp://example.com",
430 want: "ftp://example.com",
431 },
432 {
433 name: "empty string",
434 input: "",
435 want: "",
436 },
437 {
438 name: "local file path",
439 input: "/etc/hosts",
440 want: "/etc/hosts",
441 },
442 }
443
444 for _, test := range tests {
445 t.Run(test.name, func(t *testing.T) {
446 input, err := url.Parse(test.input)
447 if err != nil {
448 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
449 }
450
451 got := addDefaultPort(input)
452 if diff := cmp.Diff(test.want, got.String()); diff != "" {
453 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
454 }
455 })
456 }
457}