fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "sort"
13 "strings"
14 "testing"
15
16 sglog "github.com/sourcegraph/log"
17 "github.com/sourcegraph/log/logtest"
18 "github.com/stretchr/testify/require"
19 "github.com/xeipuuv/gojsonschema"
20 "google.golang.org/grpc"
21
22 "github.com/google/go-cmp/cmp"
23
24 "github.com/sourcegraph/zoekt"
25 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1"
26 "github.com/sourcegraph/zoekt/internal/tenant"
27)
28
29func TestServer_defaultArgs(t *testing.T) {
30 root, err := url.Parse("http://api.test")
31 if err != nil {
32 t.Fatal(err)
33 }
34
35 s := &Server{
36 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
37 IndexDir: "/testdata/index",
38 CPUCount: 6,
39 IndexConcurrency: 1,
40 }
41 want := &indexArgs{
42 IndexOptions: IndexOptions{
43 Name: "testName",
44 },
45 IndexDir: "/testdata/index",
46 Parallelism: 6,
47 Incremental: true,
48 FileLimit: 1 << 20,
49 }
50 got := s.indexArgs(IndexOptions{Name: "testName"})
51 if !cmp.Equal(got, want) {
52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
53 }
54}
55
56func TestIndexNoTenant(t *testing.T) {
57 s := &Server{}
58 _, err := s.Index(&indexArgs{})
59 require.ErrorIs(t, err, tenant.ErrMissingTenant)
60}
61
62func TestServer_parallelism(t *testing.T) {
63 root, err := url.Parse("http://api.test")
64 if err != nil {
65 t.Fatal(err)
66 }
67
68 cases := []struct {
69 name string
70 cpuCount int
71 indexConcurrency int
72 options IndexOptions
73 want int
74 }{
75 {
76 name: "CPU count divides evenly",
77 cpuCount: 16,
78 indexConcurrency: 8,
79 want: 2,
80 },
81 {
82 name: "no shard level parallelism",
83 cpuCount: 4,
84 indexConcurrency: 4,
85 want: 1,
86 },
87 {
88 name: "index option overrides server flag",
89 cpuCount: 2,
90 indexConcurrency: 1,
91 options: IndexOptions{
92 ShardConcurrency: 1,
93 },
94 want: 1,
95 },
96 {
97 name: "ignore invalid index option",
98 cpuCount: 8,
99 indexConcurrency: 2,
100 options: IndexOptions{
101 ShardConcurrency: -1,
102 },
103 want: 4,
104 },
105 }
106
107 for _, tt := range cases {
108 t.Run(tt.name, func(t *testing.T) {
109 s := &Server{
110 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
111 IndexDir: "/testdata/index",
112 CPUCount: tt.cpuCount,
113 IndexConcurrency: tt.indexConcurrency,
114 }
115
116 maxProcs := 16
117 got := s.parallelism(tt.options, maxProcs)
118 if tt.want != got {
119 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
120 }
121 })
122 }
123
124 t.Run("index option is limited by available CPU", func(t *testing.T) {
125 s := &Server{
126 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
127 IndexDir: "/testdata/index",
128 IndexConcurrency: 1,
129 }
130
131 got := s.indexArgs(IndexOptions{
132 ShardConcurrency: 2048, // Some number that's way too high
133 })
134
135 if got.Parallelism >= 2048 {
136 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
137 }
138 })
139}
140
141func TestListRepoIDs(t *testing.T) {
142 grpcClient := &mockGRPCClient{}
143
144 clientOptions := []SourcegraphClientOption{
145 WithBatchSize(0),
146 }
147
148 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
149 testHostname := "test-hostname"
150 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
151
152 listCalled := false
153 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) {
154 listCalled = true
155
156 gotRepoIDs := in.GetIndexedIds()
157 sort.Slice(gotRepoIDs, func(i, j int) bool {
158 return gotRepoIDs[i] < gotRepoIDs[j]
159 })
160
161 wantRepoIDs := []int32{1, 3}
162 sort.Slice(wantRepoIDs, func(i, j int) bool {
163 return wantRepoIDs[i] < wantRepoIDs[j]
164 })
165
166 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
167 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
168 }
169
170 hostname := in.GetHostname()
171 if diff := cmp.Diff(testHostname, hostname); diff != "" {
172 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
173 }
174
175 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
176 }
177
178 ctx := context.Background()
179 got, err := s.List(ctx, []uint32{1, 3})
180 if err != nil {
181 t.Fatal(err)
182 }
183
184 if !listCalled {
185 t.Fatalf("List was not called")
186 }
187
188 receivedRepoIDs := got.IDs
189 sort.Slice(receivedRepoIDs, func(i, j int) bool {
190 return receivedRepoIDs[i] < receivedRepoIDs[j]
191 })
192
193 expectedRepoIDs := []uint32{1, 2, 3}
194 sort.Slice(expectedRepoIDs, func(i, j int) bool {
195 return expectedRepoIDs[i] < expectedRepoIDs[j]
196 })
197
198 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
199 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
200 }
201}
202
203func TestMain(m *testing.M) {
204 flag.Parse()
205 level := sglog.LevelInfo
206 if !testing.Verbose() {
207 log.SetOutput(io.Discard)
208 level = sglog.LevelNone
209 }
210
211 logtest.InitWithLevel(m, level)
212 os.Exit(m.Run())
213}
214
215func TestCreateEmptyShard(t *testing.T) {
216 dir := t.TempDir()
217
218 args := &indexArgs{
219 IndexOptions: IndexOptions{
220 RepoID: 7,
221 Name: "empty-repo",
222 CloneURL: "code/host",
223 },
224 Incremental: true,
225 IndexDir: dir,
226 Parallelism: 1,
227 FileLimit: 1,
228 }
229
230 if err := createEmptyShard(args); err != nil {
231 t.Fatal(err)
232 }
233
234 bo := args.BuildOptions()
235 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
236
237 if got := bo.IncrementalSkipIndexing(); !got {
238 t.Fatalf("wanted %t, got %t", true, got)
239 }
240}
241
242func TestFormatListUint32(t *testing.T) {
243 cases := []struct {
244 in []uint32
245 want string
246 }{
247 {
248 in: []uint32{42, 8, 3},
249 want: "42, 8, ...",
250 },
251 {
252 in: []uint32{42, 8},
253 want: "42, 8",
254 },
255 {
256 in: []uint32{42},
257 want: "42",
258 },
259 {
260 in: []uint32{},
261 want: "",
262 },
263 }
264
265 for _, tt := range cases {
266 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
267 out := formatListUint32(tt.in, 2)
268 if out != tt.want {
269 t.Fatalf("want %s, got %s", tt.want, out)
270 }
271 })
272 }
273}
274
275func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
276 wd, err := os.Getwd()
277 if err != nil {
278 t.Fatalf("failed to get working directory: %v", err)
279 }
280
281 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
282 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
283
284 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
285
286 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
287 if err != nil {
288 t.Fatalf("failed to validate default service config: %v", err)
289 }
290
291 if !result.Valid() {
292 var errs strings.Builder
293 for _, err := range result.Errors() {
294 errs.WriteString(fmt.Sprintf("- %s\n", err))
295 }
296
297 t.Fatalf("default service config is invalid:\n%s", errs.String())
298 }
299}
300
301func TestGetBoolFromEnvironmentVariables(t *testing.T) {
302 testCases := []struct {
303 name string
304 envVarsToSet map[string]string
305
306 envVarNames []string
307 defaultBool bool
308
309 wantBool bool
310 wantErr bool
311 }{
312 {
313 name: "respect default value: true",
314
315 envVarsToSet: map[string]string{},
316
317 envVarNames: []string{"FOO", "BAR"},
318 defaultBool: true,
319
320 wantBool: true,
321 },
322 {
323 name: "respect default value: false",
324
325 envVarsToSet: map[string]string{},
326
327 envVarNames: []string{"FOO", "BAR"},
328 defaultBool: false,
329
330 wantBool: false,
331 },
332 {
333 name: "read from environment",
334
335 envVarsToSet: map[string]string{"FOO": "1"},
336
337 envVarNames: []string{"FOO"},
338 defaultBool: false,
339
340 wantBool: true,
341 },
342 {
343 name: "read from first env var that is set",
344
345 envVarsToSet: map[string]string{
346 "BAR": "false",
347 "BAZ": "true",
348 },
349
350 envVarNames: []string{"FOO", "BAR", "BAZ"},
351 defaultBool: true,
352
353 wantBool: false,
354 },
355
356 {
357 name: "should error for invalid input",
358
359 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
360
361 envVarNames: []string{"INVALID"},
362 defaultBool: false,
363
364 wantErr: true,
365 },
366 }
367
368 for _, tc := range testCases {
369 t.Run("", func(t *testing.T) {
370 // Prepare the environment by loading all the appropriate environment variables
371 for _, v := range tc.envVarNames {
372 _ = os.Unsetenv(v)
373 }
374
375 for k := range tc.envVarsToSet {
376 _ = os.Unsetenv(k)
377 }
378
379 for k, v := range tc.envVarsToSet {
380 t.Setenv(k, v)
381 }
382
383 // Run the test
384 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
385
386 // Examine the results
387 if tc.wantErr != (err != nil) {
388 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
389 }
390
391 if got != tc.wantBool {
392 t.Errorf("got %v, want %v", got, tc.wantBool)
393 }
394 })
395 }
396}
397
398func TestAddDefaultPort(t *testing.T) {
399 tests := []struct {
400 name string
401 input string
402 want string
403 }{
404 {
405 name: "http no port",
406 input: "http://example.com",
407 want: "http://example.com:80",
408 },
409 {
410 name: "http custom port",
411 input: "http://example.com:90",
412 want: "http://example.com:90",
413 },
414 {
415 name: "https no port",
416 input: "https://example.com",
417 want: "https://example.com:443",
418 },
419 {
420 name: "https custom port",
421 input: "https://example.com:444",
422 want: "https://example.com:444",
423 },
424 {
425 name: "non-http scheme",
426 input: "ftp://example.com",
427 want: "ftp://example.com",
428 },
429 {
430 name: "empty string",
431 input: "",
432 want: "",
433 },
434 {
435 name: "local file path",
436 input: "/etc/hosts",
437 want: "/etc/hosts",
438 },
439 }
440
441 for _, test := range tests {
442 t.Run(test.name, func(t *testing.T) {
443 input, err := url.Parse(test.input)
444 if err != nil {
445 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
446 }
447
448 got := addDefaultPort(input)
449 if diff := cmp.Diff(test.want, got.String()); diff != "" {
450 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
451 }
452 })
453 }
454}