fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/url"
10 "os"
11 "path/filepath"
12 "slices"
13 "strings"
14 "testing"
15 "unicode/utf8"
16
17 "github.com/google/go-cmp/cmp"
18 sglog "github.com/sourcegraph/log"
19 "github.com/sourcegraph/log/logtest"
20 "github.com/stretchr/testify/require"
21 "github.com/xeipuuv/gojsonschema"
22 "google.golang.org/grpc"
23
24 "github.com/sourcegraph/zoekt"
25 configv1 "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/grpc/protos/sourcegraph/zoekt/configuration/v1"
26 "github.com/sourcegraph/zoekt/internal/tenant"
27)
28
29func TestServer_defaultArgs(t *testing.T) {
30 root, err := url.Parse("http://api.test")
31 if err != nil {
32 t.Fatal(err)
33 }
34
35 s := &Server{
36 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
37 IndexDir: "/testdata/index",
38 CPUCount: 6,
39 IndexConcurrency: 1,
40 }
41 want := &indexArgs{
42 IndexOptions: IndexOptions{
43 Name: "testName",
44 },
45 IndexDir: "/testdata/index",
46 Parallelism: 6,
47 Incremental: true,
48 }
49 got := s.indexArgs(IndexOptions{Name: "testName"})
50 if !cmp.Equal(got, want) {
51 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
52 }
53}
54
55func TestIndexNoTenant(t *testing.T) {
56 s := &Server{}
57 _, err := s.index(context.Background(), &indexArgs{})
58 require.ErrorIs(t, err, tenant.ErrMissingTenant)
59}
60
61func TestTruncateFailureMessageForSourcegraph(t *testing.T) {
62 t.Run("preserves short message", func(t *testing.T) {
63 require.Equal(t, "boom", truncateFailureMessageForSourcegraph("boom"))
64 })
65
66 t.Run("truncates oversized utf8 message", func(t *testing.T) {
67 input := strings.Repeat("é", maxFailureMessageBytes) + "tail"
68
69 got := truncateFailureMessageForSourcegraph(input)
70 parts := strings.Split(got, failureMessageTruncationMarker)
71
72 require.Len(t, parts, 2)
73 require.LessOrEqual(t, len(got), maxFailureMessageBytes)
74 require.True(t, utf8.ValidString(got))
75 require.NotEmpty(t, parts[0])
76 require.True(t, strings.HasPrefix(input, parts[0]))
77 require.Contains(t, parts[1], "tail")
78 require.Less(t, len(parts[0])+len(parts[1]), len(input))
79 })
80}
81
82func TestServer_parallelism(t *testing.T) {
83 root, err := url.Parse("http://api.test")
84 if err != nil {
85 t.Fatal(err)
86 }
87
88 cases := []struct {
89 name string
90 cpuCount int
91 indexConcurrency int
92 options IndexOptions
93 want int
94 }{
95 {
96 name: "CPU count divides evenly",
97 cpuCount: 16,
98 indexConcurrency: 8,
99 want: 2,
100 },
101 {
102 name: "no shard level parallelism",
103 cpuCount: 4,
104 indexConcurrency: 4,
105 want: 1,
106 },
107 {
108 name: "index option overrides server flag",
109 cpuCount: 2,
110 indexConcurrency: 1,
111 options: IndexOptions{
112 ShardConcurrency: 1,
113 },
114 want: 1,
115 },
116 {
117 name: "ignore invalid index option",
118 cpuCount: 8,
119 indexConcurrency: 2,
120 options: IndexOptions{
121 ShardConcurrency: -1,
122 },
123 want: 4,
124 },
125 }
126
127 for _, tt := range cases {
128 t.Run(tt.name, func(t *testing.T) {
129 s := &Server{
130 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
131 IndexDir: "/testdata/index",
132 CPUCount: tt.cpuCount,
133 IndexConcurrency: tt.indexConcurrency,
134 }
135
136 maxProcs := 16
137 got := s.parallelism(tt.options, maxProcs)
138 if tt.want != got {
139 t.Errorf("mismatch, want: %d, got: %d", tt.want, got)
140 }
141 })
142 }
143
144 t.Run("index option is limited by available CPU", func(t *testing.T) {
145 s := &Server{
146 Sourcegraph: newSourcegraphClient(root, "", nil, WithBatchSize(0)),
147 IndexDir: "/testdata/index",
148 IndexConcurrency: 1,
149 }
150
151 got := s.indexArgs(IndexOptions{
152 ShardConcurrency: 2048, // Some number that's way too high
153 })
154
155 if got.Parallelism >= 2048 {
156 t.Errorf("parallelism should be limited by available CPUs, instead got %d", got.Parallelism)
157 }
158 })
159}
160
161func TestListRepoIDs(t *testing.T) {
162 grpcClient := &mockGRPCClient{}
163
164 clientOptions := []SourcegraphClientOption{
165 WithBatchSize(0),
166 }
167
168 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
169 testHostname := "test-hostname"
170 s := newSourcegraphClient(&testURL, testHostname, grpcClient, clientOptions...)
171
172 listCalled := false
173 grpcClient.mockList = func(ctx context.Context, in *configv1.ListRequest, opts ...grpc.CallOption) (*configv1.ListResponse, error) {
174 listCalled = true
175
176 gotRepoIDs := in.GetIndexedIds()
177 slices.Sort(gotRepoIDs)
178
179 wantRepoIDs := []int32{1, 3}
180 slices.Sort(wantRepoIDs)
181
182 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
183 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
184 }
185
186 hostname := in.GetHostname()
187 if diff := cmp.Diff(testHostname, hostname); diff != "" {
188 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
189 }
190
191 return &configv1.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
192 }
193
194 ctx := context.Background()
195 got, err := s.List(ctx, []uint32{1, 3})
196 if err != nil {
197 t.Fatal(err)
198 }
199
200 if !listCalled {
201 t.Fatalf("List was not called")
202 }
203
204 receivedRepoIDs := got.IDs
205 slices.Sort(receivedRepoIDs)
206
207 expectedRepoIDs := []uint32{1, 2, 3}
208 slices.Sort(expectedRepoIDs)
209
210 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
211 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
212 }
213}
214
215func TestMain(m *testing.M) {
216 flag.Parse()
217 level := sglog.LevelInfo
218 if !testing.Verbose() {
219 log.SetOutput(io.Discard)
220 debugLog.SetOutput(io.Discard)
221 infoLog.SetOutput(io.Discard)
222 errorLog.SetOutput(io.Discard)
223 level = sglog.LevelNone
224 }
225
226 logtest.InitWithLevel(m, level)
227 os.Exit(m.Run())
228}
229
230func TestCreateEmptyShard(t *testing.T) {
231 dir := t.TempDir()
232
233 args := &indexArgs{
234 IndexOptions: IndexOptions{
235 RepoID: 7,
236 Name: "empty-repo",
237 CloneURL: "code/host",
238 },
239 Incremental: true,
240 IndexDir: dir,
241 Parallelism: 1,
242 }
243
244 if err := createEmptyShard(args); err != nil {
245 t.Fatal(err)
246 }
247
248 bo := args.BuildOptions()
249 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
250
251 if got := bo.IncrementalSkipIndexing(); !got {
252 t.Fatalf("wanted %t, got %t", true, got)
253 }
254}
255
256func TestFormatListUint32(t *testing.T) {
257 cases := []struct {
258 in []uint32
259 want string
260 }{
261 {
262 in: []uint32{42, 8, 3},
263 want: "42, 8, ...",
264 },
265 {
266 in: []uint32{42, 8},
267 want: "42, 8",
268 },
269 {
270 in: []uint32{42},
271 want: "42",
272 },
273 {
274 in: []uint32{},
275 want: "",
276 },
277 }
278
279 for _, tt := range cases {
280 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
281 out := formatListUint32(tt.in, 2)
282 if out != tt.want {
283 t.Fatalf("want %s, got %s", tt.want, out)
284 }
285 })
286 }
287}
288
289func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
290 wd, err := os.Getwd()
291 if err != nil {
292 t.Fatalf("failed to get working directory: %v", err)
293 }
294
295 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
296 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
297
298 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
299
300 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
301 if err != nil {
302 t.Fatalf("failed to validate default service config: %v", err)
303 }
304
305 if !result.Valid() {
306 var errs strings.Builder
307 for _, err := range result.Errors() {
308 errs.WriteString(fmt.Sprintf("- %s\n", err))
309 }
310
311 t.Fatalf("default service config is invalid:\n%s", errs.String())
312 }
313}
314
315func TestGetBoolFromEnvironmentVariables(t *testing.T) {
316 testCases := []struct {
317 name string
318 envVarsToSet map[string]string
319
320 envVarNames []string
321 defaultBool bool
322
323 wantBool bool
324 wantErr bool
325 }{
326 {
327 name: "respect default value: true",
328
329 envVarsToSet: map[string]string{},
330
331 envVarNames: []string{"FOO", "BAR"},
332 defaultBool: true,
333
334 wantBool: true,
335 },
336 {
337 name: "respect default value: false",
338
339 envVarsToSet: map[string]string{},
340
341 envVarNames: []string{"FOO", "BAR"},
342 defaultBool: false,
343
344 wantBool: false,
345 },
346 {
347 name: "read from environment",
348
349 envVarsToSet: map[string]string{"FOO": "1"},
350
351 envVarNames: []string{"FOO"},
352 defaultBool: false,
353
354 wantBool: true,
355 },
356 {
357 name: "read from first env var that is set",
358
359 envVarsToSet: map[string]string{
360 "BAR": "false",
361 "BAZ": "true",
362 },
363
364 envVarNames: []string{"FOO", "BAR", "BAZ"},
365 defaultBool: true,
366
367 wantBool: false,
368 },
369
370 {
371 name: "should error for invalid input",
372
373 envVarsToSet: map[string]string{"INVALID": "not a boolean"},
374
375 envVarNames: []string{"INVALID"},
376 defaultBool: false,
377
378 wantErr: true,
379 },
380 }
381
382 for _, tc := range testCases {
383 t.Run("", func(t *testing.T) {
384 // Prepare the environment by loading all the appropriate environment variables
385 for _, v := range tc.envVarNames {
386 _ = os.Unsetenv(v)
387 }
388
389 for k := range tc.envVarsToSet {
390 _ = os.Unsetenv(k)
391 }
392
393 for k, v := range tc.envVarsToSet {
394 t.Setenv(k, v)
395 }
396
397 // Run the test
398 got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool)
399
400 // Examine the results
401 if tc.wantErr != (err != nil) {
402 t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err)
403 }
404
405 if got != tc.wantBool {
406 t.Errorf("got %v, want %v", got, tc.wantBool)
407 }
408 })
409 }
410}
411
412func TestAddDefaultPort(t *testing.T) {
413 tests := []struct {
414 name string
415 input string
416 want string
417 }{
418 {
419 name: "http no port",
420 input: "http://example.com",
421 want: "http://example.com:80",
422 },
423 {
424 name: "http custom port",
425 input: "http://example.com:90",
426 want: "http://example.com:90",
427 },
428 {
429 name: "https no port",
430 input: "https://example.com",
431 want: "https://example.com:443",
432 },
433 {
434 name: "https custom port",
435 input: "https://example.com:444",
436 want: "https://example.com:444",
437 },
438 {
439 name: "non-http scheme",
440 input: "ftp://example.com",
441 want: "ftp://example.com",
442 },
443 {
444 name: "empty string",
445 input: "",
446 want: "",
447 },
448 {
449 name: "local file path",
450 input: "/etc/hosts",
451 want: "/etc/hosts",
452 },
453 }
454
455 for _, test := range tests {
456 t.Run(test.name, func(t *testing.T) {
457 input, err := url.Parse(test.input)
458 if err != nil {
459 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
460 }
461
462 got := addDefaultPort(input)
463 if diff := cmp.Diff(test.want, got.String()); diff != "" {
464 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
465 }
466 })
467 }
468}