Fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "slices"
13 "strings"
14 "testing"
15
16 "github.com/google/go-cmp/cmp"
17 sglog "github.com/sourcegraph/log"
18 "github.com/sourcegraph/log/logtest"
19 "github.com/stretchr/testify/require"
20 "github.com/xeipuuv/gojsonschema"
21 "google.golang.org/grpc"
22
23 "github.com/sourcegraph/zoekt"
24 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1"
25 "github.com/sourcegraph/zoekt/internal/tenant"
26)
27
28func TestServer_defaultArgs(t *testing.T) {
29 root, err := url.Parse("http://api.test")
30 if err != nil {
31 t.Fatal(err)
32 }
33
34 s := &Server{
35 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
36 IndexDir: "/testdata/index",
37 CPUCount: 6,
38 IndexConcurrency: 1,
39 }
40 want := &indexArgs{
41 IndexOptions: IndexOptions{
42 Name: "testName",
43 },
44 IndexDir: "/testdata/index",
45 Parallelism: 6,
46 Incremental: true,
47 }
48 got := s.indexArgs(IndexOptions{Name: "testName"})
49 if !cmp.Equal(got, want) {
50 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
51 }
52}
53
54func TestIndexNoTenant(t *testing.T) {
55 s := &Server{}
56 _, err := s.index(context.Background(), &indexArgs{})
57 require.ErrorIs(t, err, tenant.ErrMissingTenant)
58}
59
60func TestServer_parallelism(t *testing.T) {
61 root, err := url.Parse("http://api.test")
62 if err != nil {
63 t.Fatal(err)
64 }
65
66 cases := []struct {
67 name string
68 cpuCount int
69 indexConcurrency int
70 options IndexOptions
71 want int
72 }{
73 {
74 name: "CPU count divides evenly",
75 cpuCount: 16,
76 indexConcurrency: 8,
77 want: 2,
78 },
79 {
80 name: "no shard level parallelism",
81 cpuCount: 4,
82 indexConcurrency: 4,
83 want: 1,
84 },
85 {
86 name: "index option overrides server flag",
87 cpuCount: 2,
88 indexConcurrency: 1,
89 options: IndexOptions{
90 ShardConcurrency: 1,
91 },
92 want: 1,
93 },
94 {
95 name: "ignore invalid index option",
96 cpuCount: 8,
97 indexConcurrency: 2,
98 options: IndexOptions{
99 ShardConcurrency: -1,
100 },
101 want: 4,
102 },
103 }
104
105 for _, tt := range cases {
106 t.Run(tt.name, func(t *testing.T) {
107 s := &Server{
108 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
109 IndexDir: "/testdata/index",
110 CPUCount: tt.cpuCount,
111 IndexConcurrency: tt.indexConcurrency,
112 }
113
114 maxProcs := 16
115 got := s.parallelism(tt.options, maxProcs)
116 if tt.want != got {
117 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
118 }
119 })
120 }
121
122 t.Run("index option is limited by available CPU", func(t *testing.T) {
123 s := &Server{
124 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
125 IndexDir: "/testdata/index",
126 IndexConcurrency: 1,
127 }
128
129 got := s.indexArgs(IndexOptions{
130 ShardConcurrency: 2048, // Some number that's way too high
131 })
132
133 if got.Parallelism >= 2048 {
134 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
135 }
136 })
137}
138
139func TestListRepoIDs(t *testing.T) {
140 grpcClient := &mockGRPCClient{}
141
142 clientOptions := []SourcegraphClientOption{
143 WithBatchSize(0),
144 }
145
146 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
147 testHostname := "test-hostname"
148 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
149
150 listCalled := false
151 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) {
152 listCalled = true
153
154 gotRepoIDs := in.GetIndexedIds()
155 slices.Sort(gotRepoIDs)
156
157 wantRepoIDs := []int32{1, 3}
158 slices.Sort(wantRepoIDs)
159
160 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
161 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
162 }
163
164 hostname := in.GetHostname()
165 if diff := cmp.Diff(testHostname, hostname); diff != "" {
166 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
167 }
168
169 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
170 }
171
172 ctx := context.Background()
173 got, err := s.List(ctx, []uint32{1, 3})
174 if err != nil {
175 t.Fatal(err)
176 }
177
178 if !listCalled {
179 t.Fatalf("List was not called")
180 }
181
182 receivedRepoIDs := got.IDs
183 slices.Sort(receivedRepoIDs)
184
185 expectedRepoIDs := []uint32{1, 2, 3}
186 slices.Sort(expectedRepoIDs)
187
188 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
189 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
190 }
191}
192
193func TestMain(m *testing.M) {
194 flag.Parse()
195 level := sglog.LevelInfo
196 if !testing.Verbose() {
197 log.SetOutput(io.Discard)
198 debugLog.SetOutput(io.Discard)
199 infoLog.SetOutput(io.Discard)
200 errorLog.SetOutput(io.Discard)
201 level = sglog.LevelNone
202 }
203
204 logtest.InitWithLevel(m, level)
205 os.Exit(m.Run())
206}
207
208func TestCreateEmptyShard(t *testing.T) {
209 dir := t.TempDir()
210
211 args := &indexArgs{
212 IndexOptions: IndexOptions{
213 RepoID: 7,
214 Name: "empty-repo",
215 CloneURL: "code/host",
216 },
217 Incremental: true,
218 IndexDir: dir,
219 Parallelism: 1,
220 }
221
222 if err := createEmptyShard(args); err != nil {
223 t.Fatal(err)
224 }
225
226 bo := args.BuildOptions()
227 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
228
229 if got := bo.IncrementalSkipIndexing(); !got {
230 t.Fatalf("wanted %t, got %t", true, got)
231 }
232}
233
234func TestFormatListUint32(t *testing.T) {
235 cases := []struct {
236 in []uint32
237 want string
238 }{
239 {
240 in: []uint32{42, 8, 3},
241 want: "42, 8, ...",
242 },
243 {
244 in: []uint32{42, 8},
245 want: "42, 8",
246 },
247 {
248 in: []uint32{42},
249 want: "42",
250 },
251 {
252 in: []uint32{},
253 want: "",
254 },
255 }
256
257 for _, tt := range cases {
258 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
259 out := formatListUint32(tt.in, 2)
260 if out != tt.want {
261 t.Fatalf("want %s, got %s", tt.want, out)
262 }
263 })
264 }
265}
266
267func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
268 wd, err := os.Getwd()
269 if err != nil {
270 t.Fatalf("failed to get working directory: %v", err)
271 }
272
273 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
274 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
275
276 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
277
278 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
279 if err != nil {
280 t.Fatalf("failed to validate default service config: %v", err)
281 }
282
283 if !result.Valid() {
284 var errs strings.Builder
285 for _, err := range result.Errors() {
286 errs.WriteString(fmt.Sprintf("- %s\n", err))
287 }
288
289 t.Fatalf("default service config is invalid:\n%s", errs.String())
290 }
291}
292
293func TestGetBoolFromEnvironmentVariables(t *testing.T) {
294 testCases := []struct {
295 name string
296 envVarsToSet map[string]string
297
298 envVarNames []string
299 defaultBool bool
300
301 wantBool bool
302 wantErr bool
303 }{
304 {
305 name: "respect default value: true",
306
307 envVarsToSet: map[string]string{},
308
309 envVarNames: []string{"FOO", "BAR"},
310 defaultBool: true,
311
312 wantBool: true,
313 },
314 {
315 name: "respect default value: false",
316
317 envVarsToSet: map[string]string{},
318
319 envVarNames: []string{"FOO", "BAR"},
320 defaultBool: false,
321
322 wantBool: false,
323 },
324 {
325 name: "read from environment",
326
327 envVarsToSet: map[string]string{"FOO": "1"},
328
329 envVarNames: []string{"FOO"},
330 defaultBool: false,
331
332 wantBool: true,
333 },
334 {
335 name: "read from first env var that is set",
336
337 envVarsToSet: map[string]string{
338 "BAR": "false",
339 "BAZ": "true",
340 },
341
342 envVarNames: []string{"FOO", "BAR", "BAZ"},
343 defaultBool: true,
344
345 wantBool: false,
346 },
347
348 {
349 name: "should error for invalid input",
350
351 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
352
353 envVarNames: []string{"INVALID"},
354 defaultBool: false,
355
356 wantErr: true,
357 },
358 }
359
360 for _, tc := range testCases {
361 t.Run("", func(t *testing.T) {
362 // Prepare the environment by loading all the appropriate environment variables
363 for _, v := range tc.envVarNames {
364 _ = os.Unsetenv(v)
365 }
366
367 for k := range tc.envVarsToSet {
368 _ = os.Unsetenv(k)
369 }
370
371 for k, v := range tc.envVarsToSet {
372 t.Setenv(k, v)
373 }
374
375 // Run the test
376 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
377
378 // Examine the results
379 if tc.wantErr != (err != nil) {
380 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
381 }
382
383 if got != tc.wantBool {
384 t.Errorf("got %v, want %v", got, tc.wantBool)
385 }
386 })
387 }
388}
389
390func TestAddDefaultPort(t *testing.T) {
391 tests := []struct {
392 name string
393 input string
394 want string
395 }{
396 {
397 name: "http no port",
398 input: "http://example.com",
399 want: "http://example.com:80",
400 },
401 {
402 name: "http custom port",
403 input: "http://example.com:90",
404 want: "http://example.com:90",
405 },
406 {
407 name: "https no port",
408 input: "https://example.com",
409 want: "https://example.com:443",
410 },
411 {
412 name: "https custom port",
413 input: "https://example.com:444",
414 want: "https://example.com:444",
415 },
416 {
417 name: "non-http scheme",
418 input: "ftp://example.com",
419 want: "ftp://example.com",
420 },
421 {
422 name: "empty string",
423 input: "",
424 want: "",
425 },
426 {
427 name: "local file path",
428 input: "/etc/hosts",
429 want: "/etc/hosts",
430 },
431 }
432
433 for _, test := range tests {
434 t.Run(test.name, func(t *testing.T) {
435 input, err := url.Parse(test.input)
436 if err != nil {
437 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
438 }
439
440 got := addDefaultPort(input)
441 if diff := cmp.Diff(test.want, got.String()); diff != "" {
442 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
443 }
444 })
445 }
446}