fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/http"
10 "net/http/httptest"
11 "net/url"
12 "os"
13 "path/filepath"
14 "sort"
15 "strings"
16 "testing"
17
18 sglog "github.com/sourcegraph/log"
19 "github.com/sourcegraph/log/logtest"
20
21 "github.com/xeipuuv/gojsonschema"
22 "google.golang.org/grpc"
23
24 "github.com/google/go-cmp/cmp"
25
26 "github.com/sourcegraph/zoekt"
27 proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1"
28)
29
30func TestServer_defaultArgs(t *testing.T) {
31 root, err := url.Parse("http://api.test")
32 if err != nil {
33 t.Fatal(err)
34 }
35
36 s := &Server{
37 Sourcegraph: newSourcegraphClient(root, "", WithBatchSize(0)),
38 IndexDir: "/testdata/index",
39 CPUCount: 6,
40 }
41 want := &indexArgs{
42 IndexOptions: IndexOptions{
43 Name: "testName",
44 },
45 IndexDir: "/testdata/index",
46 Parallelism: 6,
47 Incremental: true,
48 FileLimit: 1 << 20,
49 }
50 got := s.indexArgs(IndexOptions{Name: "testName"})
51 if !cmp.Equal(got, want) {
52 t.Errorf("mismatch (-want +got):\n%s", cmp.Diff(want, got))
53 }
54}
55
56func TestListRepoIDs(t *testing.T) {
57 t.Run("gRPC", func(t *testing.T) {
58
59 grpcClient := &mockGRPCClient{}
60
61 clientOptions := []SourcegraphClientOption{
62 WithShouldUseGRPC(true),
63 WithGRPCClient(grpcClient),
64 WithBatchSize(0),
65 }
66
67 testURL := url.URL{Scheme: "http", Host: "does.not.matter"}
68 testHostname := "test-hostname"
69 s := newSourcegraphClient(&testURL, testHostname, clientOptions...)
70
71 listCalled := false
72 grpcClient.mockList = func(ctx context.Context, in *proto.ListRequest, opts ...grpc.CallOption) (*proto.ListResponse, error) {
73 listCalled = true
74
75 gotRepoIDs := in.GetIndexedIds()
76 sort.Slice(gotRepoIDs, func(i, j int) bool {
77 return gotRepoIDs[i] < gotRepoIDs[j]
78 })
79
80 wantRepoIDs := []int32{1, 3}
81 sort.Slice(wantRepoIDs, func(i, j int) bool {
82 return wantRepoIDs[i] < wantRepoIDs[j]
83 })
84
85 if diff := cmp.Diff(wantRepoIDs, gotRepoIDs); diff != "" {
86 t.Errorf("indexed repoIDs mismatch (-want +got):\n%s", diff)
87 }
88
89 hostname := in.GetHostname()
90 if diff := cmp.Diff(testHostname, hostname); diff != "" {
91 t.Errorf("hostname mismatch (-want +got):\n%s", diff)
92 }
93
94 return &proto.ListResponse{RepoIds: []int32{1, 2, 3}}, nil
95 }
96
97 ctx := context.Background()
98 got, err := s.List(ctx, []uint32{1, 3})
99 if err != nil {
100 t.Fatal(err)
101 }
102
103 if !listCalled {
104 t.Fatalf("List was not called")
105 }
106
107 receivedRepoIDs := got.IDs
108 sort.Slice(receivedRepoIDs, func(i, j int) bool {
109 return receivedRepoIDs[i] < receivedRepoIDs[j]
110 })
111
112 expectedRepoIDs := []uint32{1, 2, 3}
113 sort.Slice(expectedRepoIDs, func(i, j int) bool {
114 return expectedRepoIDs[i] < expectedRepoIDs[j]
115 })
116
117 if diff := cmp.Diff(expectedRepoIDs, receivedRepoIDs); diff != "" {
118 t.Errorf("mismatch in list of all repoIDs (-want +got):\n%s", diff)
119 }
120 })
121
122 t.Run("REST", func(t *testing.T) {
123 var gotBody string
124 var gotURL *url.URL
125 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
126 gotURL = r.URL
127
128 b, err := io.ReadAll(r.Body)
129 if err != nil {
130 t.Fatal(err)
131 }
132 gotBody = string(b)
133
134 _, err = w.Write([]byte(`{"RepoIDs": [1, 2, 3]}`))
135 if err != nil {
136 t.Fatal(err)
137 }
138 }))
139 defer ts.Close()
140
141 u, err := url.Parse(ts.URL)
142 if err != nil {
143 t.Fatal(err)
144 }
145
146 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0))
147
148 gotRepos, err := s.List(context.Background(), []uint32{1, 3})
149 if err != nil {
150 t.Fatal(err)
151 }
152
153 if want := []uint32{1, 2, 3}; !cmp.Equal(gotRepos.IDs, want) {
154 t.Errorf("repos mismatch (-want +got):\n%s", cmp.Diff(want, gotRepos.IDs))
155 }
156 if want := `{"Hostname":"test-indexed-search-1","IndexedIDs":[1,3]}`; gotBody != want {
157 t.Errorf("body mismatch (-want +got):\n%s", cmp.Diff(want, gotBody))
158 }
159 if want := "/.internal/repos/index"; gotURL.Path != want {
160 t.Errorf("request path mismatch (-want +got):\n%s", cmp.Diff(want, gotURL.Path))
161 }
162 })
163}
164
165func TestListRepoIDs_Error_REST(t *testing.T) {
166 // Note: There is no gRPC equivalent to this test because gRPC errors are
167 // always returned as an error to the caller.
168
169 msg := "deadbeaf deadbeaf"
170 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
171
172 // This is how Sourcegraph returns error messages to the caller.
173 http.Error(w, msg, http.StatusInternalServerError)
174 }))
175 defer ts.Close()
176
177 u, err := url.Parse(ts.URL)
178 if err != nil {
179 t.Fatal(err)
180 }
181
182 s := newSourcegraphClient(u, "test-indexed-search-1", WithBatchSize(0))
183 s.restClient.RetryMax = 0
184
185 _, err = s.List(context.Background(), []uint32{1, 3})
186
187 if !strings.Contains(err.Error(), msg) {
188 t.Fatalf("%s does not contain %s", err.Error(), msg)
189 }
190}
191
192func TestMain(m *testing.M) {
193 flag.Parse()
194 level := sglog.LevelInfo
195 if !testing.Verbose() {
196 log.SetOutput(io.Discard)
197 level = sglog.LevelNone
198 }
199
200 logtest.InitWithLevel(m, level)
201 os.Exit(m.Run())
202}
203
204func TestCreateEmptyShard(t *testing.T) {
205 dir := t.TempDir()
206
207 args := &indexArgs{
208 IndexOptions: IndexOptions{
209 RepoID: 7,
210 Name: "empty-repo",
211 CloneURL: "code/host",
212 },
213 Incremental: true,
214 IndexDir: dir,
215 Parallelism: 1,
216 FileLimit: 1,
217 }
218
219 if err := createEmptyShard(args); err != nil {
220 t.Fatal(err)
221 }
222
223 bo := args.BuildOptions()
224 bo.RepositoryDescription.Branches = []zoekt.RepositoryBranch{{Name: "HEAD", Version: "404aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}
225
226 if got := bo.IncrementalSkipIndexing(); !got {
227 t.Fatalf("wanted %t, got %t", true, got)
228 }
229}
230
231func TestFormatListUint32(t *testing.T) {
232 cases := []struct {
233 in []uint32
234 want string
235 }{
236 {
237 in: []uint32{42, 8, 3},
238 want: "42, 8, ...",
239 },
240 {
241 in: []uint32{42, 8},
242 want: "42, 8",
243 },
244 {
245 in: []uint32{42},
246 want: "42",
247 },
248 {
249 in: []uint32{},
250 want: "",
251 },
252 }
253
254 for _, tt := range cases {
255 t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
256 out := formatListUint32(tt.in, 2)
257 if out != tt.want {
258 t.Fatalf("want %s, got %s", tt.want, out)
259 }
260 })
261 }
262}
263
264func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) {
265 wd, err := os.Getwd()
266 if err != nil {
267 t.Fatalf("failed to get working directory: %v", err)
268 }
269
270 schemaFile := filepath.Join(wd, "json_schemas", "ServiceConfig.json")
271 schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", schemaFile))
272
273 documentLoader := gojsonschema.NewStringLoader(defaultGRPCServiceConfigurationJSON)
274
275 result, err := gojsonschema.Validate(schemaLoader, documentLoader)
276 if err != nil {
277 t.Fatalf("failed to validate default service config: %v", err)
278 }
279
280 if !result.Valid() {
281 var errs strings.Builder
282 for _, err := range result.Errors() {
283 errs.WriteString(fmt.Sprintf("- %s\n", err))
284 }
285
286 t.Fatalf("default service config is invalid:\n%s", errs.String())
287 }
288}
289
290func TestAddDefaultPort(t *testing.T) {
291 tests := []struct {
292 name string
293 input string
294 want string
295 }{
296 {
297 name: "http no port",
298 input: "http://example.com",
299 want: "http://example.com:80",
300 },
301 {
302 name: "http custom port",
303 input: "http://example.com:90",
304 want: "http://example.com:90",
305 },
306 {
307 name: "https no port",
308 input: "https://example.com",
309 want: "https://example.com:443",
310 },
311 {
312 name: "https custom port",
313 input: "https://example.com:444",
314 want: "https://example.com:444",
315 },
316 {
317 name: "non-http scheme",
318 input: "ftp://example.com",
319 want: "ftp://example.com",
320 },
321 {
322 name: "empty string",
323 input: "",
324 want: "",
325 },
326 {
327 name: "local file path",
328 input: "/etc/hosts",
329 want: "/etc/hosts",
330 },
331 }
332
333 for _, test := range tests {
334 t.Run(test.name, func(t *testing.T) {
335 input, err := url.Parse(test.input)
336 if err != nil {
337 t.Fatalf("failed to parse test URL %q: %v", test.input, err)
338 }
339
340 got := addDefaultPort(input)
341 if diff := cmp.Diff(test.want, got.String()); diff != "" {
342 t.Errorf("addDefaultPort(%q) mismatch (-want +got):\n%s", test.input, diff)
343 }
344 })
345 }
346}