Fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "slices"
13 "strings"
14 "testing"
15
16 "github.com/google/go-cmp/cmp"
17 sglog "github.com/sourcegraph/log"
18 "github.com/sourcegraph/log/logtest"
19 "github.com/stretchr/testify/require"
20 "github.com/xeipuuv/gojsonschema"
21 "google.golang.org/grpc"
22
23 "github.com/sourcegraph/zoekt"
24 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1"
25 "github.com/sourcegraph/zoekt/internal/tenant"
26)
27
28func TestServer_defaultArgs(t *testing.T) {
29 root, err := url.Parse("http://api.test")
30 if err != nil {
31 t.Fatal(err)
32 }
33
34 s := &Server{
35 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
36 IndexDir: "/testdata/index",
37 CPUCount: 6,
38 IndexConcurrency: 1,
39 }
40 want := &indexArgs{
41 IndexOptions: IndexOptions{
42 Name: "testName",
43 },
44 IndexDir: "/testdata/index",
45 Parallelism: 6,
46 Incremental: true,
47 FileLimit: 1 << 20,
48 }
49 got := s.indexArgs(IndexOptions{Name: "testName"})
50 if !cmp.Equal(got, want) {
51 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
52 }
53}
54
55func TestIndexNoTenant(t *testing.T) {
56 s := &Server{}
57 _, err := s.index(context.Background(), &indexArgs{})
58 require.ErrorIs(t, err, tenant.ErrMissingTenant)
59}
60
61func TestServer_parallelism(t *testing.T) {
62 root, err := url.Parse("http://api.test")
63 if err != nil {
64 t.Fatal(err)
65 }
66
67 cases := []struct {
68 name string
69 cpuCount int
70 indexConcurrency int
71 options IndexOptions
72 want int
73 }{
74 {
75 name: "CPU count divides evenly",
76 cpuCount: 16,
77 indexConcurrency: 8,
78 want: 2,
79 },
80 {
81 name: "no shard level parallelism",
82 cpuCount: 4,
83 indexConcurrency: 4,
84 want: 1,
85 },
86 {
87 name: "index option overrides server flag",
88 cpuCount: 2,
89 indexConcurrency: 1,
90 options: IndexOptions{
91 ShardConcurrency: 1,
92 },
93 want: 1,
94 },
95 {
96 name: "ignore invalid index option",
97 cpuCount: 8,
98 indexConcurrency: 2,
99 options: IndexOptions{
100 ShardConcurrency: -1,
101 },
102 want: 4,
103 },
104 }
105
106 for _, tt := range cases {
107 t.Run(tt.name, func(t *testing.T) {
108 s := &Server{
109 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
110 IndexDir: "/testdata/index",
111 CPUCount: tt.cpuCount,
112 IndexConcurrency: tt.indexConcurrency,
113 }
114
115 maxProcs := 16
116 got := s.parallelism(tt.options, maxProcs)
117 if tt.want != got {
118 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
119 }
120 })
121 }
122
123 t.Run("index option is limited by available CPU", func(t *testing.T) {
124 s := &Server{
125 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
126 IndexDir: "/testdata/index",
127 IndexConcurrency: 1,
128 }
129
130 got := s.indexArgs(IndexOptions{
131 ShardConcurrency: 2048, // Some number that's way too high
132 })
133
134 if got.Parallelism >= 2048 {
135 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
136 }
137 })
138}
139
140func TestListRepoIDs(t *testing.T) {
141 grpcClient := &mockGRPCClient{}
142
143 clientOptions := []SourcegraphClientOption{
144 WithBatchSize(0),
145 }
146
147 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
148 testHostname := "test-hostname"
149 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
150
151 listCalled := false
152 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) {
153 listCalled = true
154
155 gotRepoIDs := in.GetIndexedIds()
156 slices.Sort(gotRepoIDs)
157
158 wantRepoIDs := []int32{1, 3}
159 slices.Sort(wantRepoIDs)
160
161 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
162 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
163 }
164
165 hostname := in.GetHostname()
166 if diff := cmp.Diff(testHostname, hostname); diff != "" {
167 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
168 }
169
170 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
171 }
172
173 ctx := context.Background()
174 got, err := s.List(ctx, []uint32{1, 3})
175 if err != nil {
176 t.Fatal(err)
177 }
178
179 if !listCalled {
180 t.Fatalf("List was not called")
181 }
182
183 receivedRepoIDs := got.IDs
184 slices.Sort(receivedRepoIDs)
185
186 expectedRepoIDs := []uint32{1, 2, 3}
187 slices.Sort(expectedRepoIDs)
188
189 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
190 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
191 }
192}
193
194func TestMain(m *testing.M) {
195 flag.Parse()
196 level := sglog.LevelInfo
197 if !testing.Verbose() {
198 log.SetOutput(io.Discard)
199 debugLog.SetOutput(io.Discard)
200 infoLog.SetOutput(io.Discard)
201 errorLog.SetOutput(io.Discard)
202 level = sglog.LevelNone
203 }
204
205 logtest.InitWithLevel(m, level)
206 os.Exit(m.Run())
207}
208
209func TestCreateEmptyShard(t *testing.T) {
210 dir := t.TempDir()
211
212 args := &indexArgs{
213 IndexOptions: IndexOptions{
214 RepoID: 7,
215 Name: "empty-repo",
216 CloneURL: "code/host",
217 },
218 Incremental: true,
219 IndexDir: dir,
220 Parallelism: 1,
221 FileLimit: 1,
222 }
223
224 if err := createEmptyShard(args); err != nil {
225 t.Fatal(err)
226 }
227
228 bo := args.BuildOptions()
229 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
230
231 if got := bo.IncrementalSkipIndexing(); !got {
232 t.Fatalf("wanted %t, got %t", true, got)
233 }
234}
235
236func TestFormatListUint32(t *testing.T) {
237 cases := []struct {
238 in []uint32
239 want string
240 }{
241 {
242 in: []uint32{42, 8, 3},
243 want: "42, 8, ...",
244 },
245 {
246 in: []uint32{42, 8},
247 want: "42, 8",
248 },
249 {
250 in: []uint32{42},
251 want: "42",
252 },
253 {
254 in: []uint32{},
255 want: "",
256 },
257 }
258
259 for _, tt := range cases {
260 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
261 out := formatListUint32(tt.in, 2)
262 if out != tt.want {
263 t.Fatalf("want %s, got %s", tt.want, out)
264 }
265 })
266 }
267}
268
269func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
270 wd, err := os.Getwd()
271 if err != nil {
272 t.Fatalf("failed to get working directory: %v", err)
273 }
274
275 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
276 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
277
278 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
279
280 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
281 if err != nil {
282 t.Fatalf("failed to validate default service config: %v", err)
283 }
284
285 if !result.Valid() {
286 var errs strings.Builder
287 for _, err := range result.Errors() {
288 errs.WriteString(fmt.Sprintf("- %s\n", err))
289 }
290
291 t.Fatalf("default service config is invalid:\n%s", errs.String())
292 }
293}
294
295func TestGetBoolFromEnvironmentVariables(t *testing.T) {
296 testCases := []struct {
297 name string
298 envVarsToSet map[string]string
299
300 envVarNames []string
301 defaultBool bool
302
303 wantBool bool
304 wantErr bool
305 }{
306 {
307 name: "respect default value: true",
308
309 envVarsToSet: map[string]string{},
310
311 envVarNames: []string{"FOO", "BAR"},
312 defaultBool: true,
313
314 wantBool: true,
315 },
316 {
317 name: "respect default value: false",
318
319 envVarsToSet: map[string]string{},
320
321 envVarNames: []string{"FOO", "BAR"},
322 defaultBool: false,
323
324 wantBool: false,
325 },
326 {
327 name: "read from environment",
328
329 envVarsToSet: map[string]string{"FOO": "1"},
330
331 envVarNames: []string{"FOO"},
332 defaultBool: false,
333
334 wantBool: true,
335 },
336 {
337 name: "read from first env var that is set",
338
339 envVarsToSet: map[string]string{
340 "BAR": "false",
341 "BAZ": "true",
342 },
343
344 envVarNames: []string{"FOO", "BAR", "BAZ"},
345 defaultBool: true,
346
347 wantBool: false,
348 },
349
350 {
351 name: "should error for invalid input",
352
353 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
354
355 envVarNames: []string{"INVALID"},
356 defaultBool: false,
357
358 wantErr: true,
359 },
360 }
361
362 for _, tc := range testCases {
363 t.Run("", func(t *testing.T) {
364 // Prepare the environment by loading all the appropriate environment variables
365 for _, v := range tc.envVarNames {
366 _ = os.Unsetenv(v)
367 }
368
369 for k := range tc.envVarsToSet {
370 _ = os.Unsetenv(k)
371 }
372
373 for k, v := range tc.envVarsToSet {
374 t.Setenv(k, v)
375 }
376
377 // Run the test
378 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
379
380 // Examine the results
381 if tc.wantErr != (err != nil) {
382 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
383 }
384
385 if got != tc.wantBool {
386 t.Errorf("got %v, want %v", got, tc.wantBool)
387 }
388 })
389 }
390}
391
392func TestAddDefaultPort(t *testing.T) {
393 tests := []struct {
394 name string
395 input string
396 want string
397 }{
398 {
399 name: "http no port",
400 input: "http://example.com",
401 want: "http://example.com:80",
402 },
403 {
404 name: "http custom port",
405 input: "http://example.com:90",
406 want: "http://example.com:90",
407 },
408 {
409 name: "https no port",
410 input: "https://example.com",
411 want: "https://example.com:443",
412 },
413 {
414 name: "https custom port",
415 input: "https://example.com:444",
416 want: "https://example.com:444",
417 },
418 {
419 name: "non-http scheme",
420 input: "ftp://example.com",
421 want: "ftp://example.com",
422 },
423 {
424 name: "empty string",
425 input: "",
426 want: "",
427 },
428 {
429 name: "local file path",
430 input: "/etc/hosts",
431 want: "/etc/hosts",
432 },
433 }
434
435 for _, test := range tests {
436 t.Run(test.name, func(t *testing.T) {
437 input, err := url.Parse(test.input)
438 if err != nil {
439 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
440 }
441
442 got := addDefaultPort(input)
443 if diff := cmp.Diff(test.want, got.String()); diff != "" {
444 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
445 }
446 })
447 }
448}