fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "strings"
13 "testing"
14
15 sglog "github.com/sourcegraph/log"
16 "github.com/sourcegraph/log/logtest"
17 "github.com/stretchr/testify/require"
18 "github.com/xeipuuv/gojsonschema"
19 "google.golang.org/grpc"
20
21 "github.com/google/go-cmp/cmp"
22
23 "slices"
24
25 "github.com/sourcegraph/zoekt"
26 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1"
27 "github.com/sourcegraph/zoekt/internal/tenant"
28)
29
30func TestServer_defaultArgs(t *testing.T) {
31 root, err := url.Parse("http://api.test")
32 if err != nil {
33 t.Fatal(err)
34 }
35
36 s := &Server{
37 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
38 IndexDir: "/testdata/index",
39 CPUCount: 6,
40 IndexConcurrency: 1,
41 }
42 want := &indexArgs{
43 IndexOptions: IndexOptions{
44 Name: "testName",
45 },
46 IndexDir: "/testdata/index",
47 Parallelism: 6,
48 Incremental: true,
49 FileLimit: 1 << 20,
50 }
51 got := s.indexArgs(IndexOptions{Name: "testName"})
52 if !cmp.Equal(got, want) {
53 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
54 }
55}
56
57func TestIndexNoTenant(t *testing.T) {
58 s := &Server{}
59 _, err := s.Index(&indexArgs{})
60 require.ErrorIs(t, err, tenant.ErrMissingTenant)
61}
62
63func TestServer_parallelism(t *testing.T) {
64 root, err := url.Parse("http://api.test")
65 if err != nil {
66 t.Fatal(err)
67 }
68
69 cases := []struct {
70 name string
71 cpuCount int
72 indexConcurrency int
73 options IndexOptions
74 want int
75 }{
76 {
77 name: "CPU count divides evenly",
78 cpuCount: 16,
79 indexConcurrency: 8,
80 want: 2,
81 },
82 {
83 name: "no shard level parallelism",
84 cpuCount: 4,
85 indexConcurrency: 4,
86 want: 1,
87 },
88 {
89 name: "index option overrides server flag",
90 cpuCount: 2,
91 indexConcurrency: 1,
92 options: IndexOptions{
93 ShardConcurrency: 1,
94 },
95 want: 1,
96 },
97 {
98 name: "ignore invalid index option",
99 cpuCount: 8,
100 indexConcurrency: 2,
101 options: IndexOptions{
102 ShardConcurrency: -1,
103 },
104 want: 4,
105 },
106 }
107
108 for _, tt := range cases {
109 t.Run(tt.name, func(t *testing.T) {
110 s := &Server{
111 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
112 IndexDir: "/testdata/index",
113 CPUCount: tt.cpuCount,
114 IndexConcurrency: tt.indexConcurrency,
115 }
116
117 maxProcs := 16
118 got := s.parallelism(tt.options, maxProcs)
119 if tt.want != got {
120 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
121 }
122 })
123 }
124
125 t.Run("index option is limited by available CPU", func(t *testing.T) {
126 s := &Server{
127 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
128 IndexDir: "/testdata/index",
129 IndexConcurrency: 1,
130 }
131
132 got := s.indexArgs(IndexOptions{
133 ShardConcurrency: 2048, // Some number that's way too high
134 })
135
136 if got.Parallelism >= 2048 {
137 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
138 }
139 })
140}
141
142func TestListRepoIDs(t *testing.T) {
143 grpcClient := &mockGRPCClient{}
144
145 clientOptions := []SourcegraphClientOption{
146 WithBatchSize(0),
147 }
148
149 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
150 testHostname := "test-hostname"
151 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
152
153 listCalled := false
154 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) {
155 listCalled = true
156
157 gotRepoIDs := in.GetIndexedIds()
158 slices.Sort(gotRepoIDs)
159
160 wantRepoIDs := []int32{1, 3}
161 slices.Sort(wantRepoIDs)
162
163 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
164 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
165 }
166
167 hostname := in.GetHostname()
168 if diff := cmp.Diff(testHostname, hostname); diff != "" {
169 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
170 }
171
172 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
173 }
174
175 ctx := context.Background()
176 got, err := s.List(ctx, []uint32{1, 3})
177 if err != nil {
178 t.Fatal(err)
179 }
180
181 if !listCalled {
182 t.Fatalf("List was not called")
183 }
184
185 receivedRepoIDs := got.IDs
186 slices.Sort(receivedRepoIDs)
187
188 expectedRepoIDs := []uint32{1, 2, 3}
189 slices.Sort(expectedRepoIDs)
190
191 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
192 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
193 }
194}
195
196func TestMain(m *testing.M) {
197 flag.Parse()
198 level := sglog.LevelInfo
199 if !testing.Verbose() {
200 log.SetOutput(io.Discard)
201 debugLog.SetOutput(io.Discard)
202 infoLog.SetOutput(io.Discard)
203 errorLog.SetOutput(io.Discard)
204 level = sglog.LevelNone
205 }
206
207 logtest.InitWithLevel(m, level)
208 os.Exit(m.Run())
209}
210
211func TestCreateEmptyShard(t *testing.T) {
212 dir := t.TempDir()
213
214 args := &indexArgs{
215 IndexOptions: IndexOptions{
216 RepoID: 7,
217 Name: "empty-repo",
218 CloneURL: "code/host",
219 },
220 Incremental: true,
221 IndexDir: dir,
222 Parallelism: 1,
223 FileLimit: 1,
224 }
225
226 if err := createEmptyShard(args); err != nil {
227 t.Fatal(err)
228 }
229
230 bo := args.BuildOptions()
231 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
232
233 if got := bo.IncrementalSkipIndexing(); !got {
234 t.Fatalf("wanted %t, got %t", true, got)
235 }
236}
237
238func TestFormatListUint32(t *testing.T) {
239 cases := []struct {
240 in []uint32
241 want string
242 }{
243 {
244 in: []uint32{42, 8, 3},
245 want: "42, 8, ...",
246 },
247 {
248 in: []uint32{42, 8},
249 want: "42, 8",
250 },
251 {
252 in: []uint32{42},
253 want: "42",
254 },
255 {
256 in: []uint32{},
257 want: "",
258 },
259 }
260
261 for _, tt := range cases {
262 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
263 out := formatListUint32(tt.in, 2)
264 if out != tt.want {
265 t.Fatalf("want %s, got %s", tt.want, out)
266 }
267 })
268 }
269}
270
271func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
272 wd, err := os.Getwd()
273 if err != nil {
274 t.Fatalf("failed to get working directory: %v", err)
275 }
276
277 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
278 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
279
280 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
281
282 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
283 if err != nil {
284 t.Fatalf("failed to validate default service config: %v", err)
285 }
286
287 if !result.Valid() {
288 var errs strings.Builder
289 for _, err := range result.Errors() {
290 errs.WriteString(fmt.Sprintf("- %s\n", err))
291 }
292
293 t.Fatalf("default service config is invalid:\n%s", errs.String())
294 }
295}
296
297func TestGetBoolFromEnvironmentVariables(t *testing.T) {
298 testCases := []struct {
299 name string
300 envVarsToSet map[string]string
301
302 envVarNames []string
303 defaultBool bool
304
305 wantBool bool
306 wantErr bool
307 }{
308 {
309 name: "respect default value: true",
310
311 envVarsToSet: map[string]string{},
312
313 envVarNames: []string{"FOO", "BAR"},
314 defaultBool: true,
315
316 wantBool: true,
317 },
318 {
319 name: "respect default value: false",
320
321 envVarsToSet: map[string]string{},
322
323 envVarNames: []string{"FOO", "BAR"},
324 defaultBool: false,
325
326 wantBool: false,
327 },
328 {
329 name: "read from environment",
330
331 envVarsToSet: map[string]string{"FOO": "1"},
332
333 envVarNames: []string{"FOO"},
334 defaultBool: false,
335
336 wantBool: true,
337 },
338 {
339 name: "read from first env var that is set",
340
341 envVarsToSet: map[string]string{
342 "BAR": "false",
343 "BAZ": "true",
344 },
345
346 envVarNames: []string{"FOO", "BAR", "BAZ"},
347 defaultBool: true,
348
349 wantBool: false,
350 },
351
352 {
353 name: "should error for invalid input",
354
355 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
356
357 envVarNames: []string{"INVALID"},
358 defaultBool: false,
359
360 wantErr: true,
361 },
362 }
363
364 for _, tc := range testCases {
365 t.Run("", func(t *testing.T) {
366 // Prepare the environment by loading all the appropriate environment variables
367 for _, v := range tc.envVarNames {
368 _ = os.Unsetenv(v)
369 }
370
371 for k := range tc.envVarsToSet {
372 _ = os.Unsetenv(k)
373 }
374
375 for k, v := range tc.envVarsToSet {
376 t.Setenv(k, v)
377 }
378
379 // Run the test
380 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
381
382 // Examine the results
383 if tc.wantErr != (err != nil) {
384 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
385 }
386
387 if got != tc.wantBool {
388 t.Errorf("got %v, want %v", got, tc.wantBool)
389 }
390 })
391 }
392}
393
394func TestAddDefaultPort(t *testing.T) {
395 tests := []struct {
396 name string
397 input string
398 want string
399 }{
400 {
401 name: "http no port",
402 input: "http://example.com",
403 want: "http://example.com:80",
404 },
405 {
406 name: "http custom port",
407 input: "http://example.com:90",
408 want: "http://example.com:90",
409 },
410 {
411 name: "https no port",
412 input: "https://example.com",
413 want: "https://example.com:443",
414 },
415 {
416 name: "https custom port",
417 input: "https://example.com:444",
418 want: "https://example.com:444",
419 },
420 {
421 name: "non-http scheme",
422 input: "ftp://example.com",
423 want: "ftp://example.com",
424 },
425 {
426 name: "empty string",
427 input: "",
428 want: "",
429 },
430 {
431 name: "local file path",
432 input: "/etc/hosts",
433 want: "/etc/hosts",
434 },
435 }
436
437 for _, test := range tests {
438 t.Run(test.name, func(t *testing.T) {
439 input, err := url.Parse(test.input)
440 if err != nil {
441 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
442 }
443
444 got := addDefaultPort(input)
445 if diff := cmp.Diff(test.want, got.String()); diff != "" {
446 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
447 }
448 })
449 }
450}