fork of https://github.com/sourcegraph/zoekt
1// Copyright 2016 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package index
16
17import (
18 "context"
19 "fmt"
20 "log"
21 "regexp/syntax"
22 "sort"
23 "strings"
24 "time"
25
26 enry_data "github.com/go-enry/go-enry/v2/data"
27 "github.com/grafana/regexp"
28
29 "github.com/sourcegraph/zoekt"
30 "github.com/sourcegraph/zoekt/internal/tenant"
31 "github.com/sourcegraph/zoekt/query"
32)
33
34// simplifyMultiRepo takes a query and a predicate. It returns Const(true) if all
35// repository names fulfill the predicate, Const(false) if none of them do, and q
36// otherwise.
37func (d *indexData) simplifyMultiRepo(q query.Q, predicate func(*zoekt.Repository) bool) query.Q {
38 count := 0
39 alive := len(d.repoMetaData)
40 for i := range d.repoMetaData {
41 if d.repoMetaData[i].Tombstone {
42 alive--
43 } else if predicate(&d.repoMetaData[i]) {
44 count++
45 }
46 }
47 if count == alive {
48 return &query.Const{Value: true}
49 }
50 if count > 0 {
51 return q
52 }
53 return &query.Const{Value: false}
54}
55
56func (d *indexData) simplify(in query.Q) query.Q {
57 eval := query.Map(in, func(q query.Q) query.Q {
58 switch r := q.(type) {
59 case *query.Repo:
60 return d.simplifyMultiRepo(q, func(repo *zoekt.Repository) bool {
61 return r.Regexp.MatchString(repo.Name)
62 })
63 case *query.RepoRegexp:
64 return d.simplifyMultiRepo(q, func(repo *zoekt.Repository) bool {
65 return r.Regexp.MatchString(repo.Name)
66 })
67 case *query.BranchesRepos:
68 for i := range d.repoMetaData {
69 for _, br := range r.List {
70 if br.Repos.Contains(d.repoMetaData[i].ID) {
71 return q
72 }
73 }
74 }
75 return &query.Const{Value: false}
76 case *query.RepoSet:
77 return d.simplifyMultiRepo(q, func(repo *zoekt.Repository) bool {
78 return r.Set[repo.Name]
79 })
80 case query.RawConfig:
81 return d.simplifyMultiRepo(q, func(repo *zoekt.Repository) bool { return uint8(r)&encodeRawConfig(repo.RawConfig) == uint8(r) })
82 case *query.RepoIDs:
83 return d.simplifyMultiRepo(q, func(repo *zoekt.Repository) bool {
84 return r.Repos.Contains(repo.ID)
85 })
86 case *query.Language:
87 _, has := d.metaData.LanguageMap[r.Language]
88 if !has && d.metaData.IndexFeatureVersion < 12 {
89 // For index files that haven't been re-indexed by go-enry,
90 // fall back to file-based matching and continue even if this
91 // repo doesn't have the specific language present.
92 extsForLang := enry_data.ExtensionsByLanguage[r.Language]
93 if extsForLang != nil {
94 extFrags := make([]string, 0, len(extsForLang))
95 for _, ext := range extsForLang {
96 extFrags = append(extFrags, regexp.QuoteMeta(ext))
97 }
98 if len(extFrags) > 0 {
99 pattern := fmt.Sprintf("(?i)(%s)$", strings.Join(extFrags, "|"))
100 // inlined copy of query.regexpQuery
101 re, err := syntax.Parse(pattern, syntax.Perl)
102 if err != nil {
103 return &query.Const{Value: false}
104 }
105 if re.Op == syntax.OpLiteral {
106 return &query.Substring{
107 Pattern: string(re.Rune),
108 FileName: true,
109 }
110 }
111 return &query.Regexp{
112 Regexp: re,
113 FileName: true,
114 }
115 }
116 }
117 }
118 if !has {
119 return &query.Const{Value: false}
120 }
121 case *query.Meta:
122 return d.simplifyMultiRepo(q, func(repo *zoekt.Repository) bool {
123 if repo.Metadata == nil {
124 return false
125 }
126 v, ok := repo.Metadata[r.Field]
127 if !ok {
128 return false
129 }
130 return r.Value.MatchString(v)
131 })
132 }
133 return q
134 })
135 return query.Simplify(eval)
136}
137
138func (d *indexData) Search(ctx context.Context, q query.Q, opts *zoekt.SearchOptions) (sr *zoekt.SearchResult, err error) {
139 timer := newTimer()
140
141 copyOpts := *opts
142 opts = ©Opts
143 opts.SetDefaults()
144
145 var res zoekt.SearchResult
146 if len(d.fileNameIndex) == 0 {
147 return &res, nil
148 }
149
150 select {
151 case <-ctx.Done():
152 res.Stats.ShardsSkipped++
153 return &res, nil
154 default:
155 }
156
157 q = d.simplify(q)
158 if c, ok := q.(*query.Const); ok && !c.Value {
159 return &res, nil
160 }
161
162 if opts.EstimateDocCount {
163 res.Stats.ShardFilesConsidered = len(d.fileBranchMasks)
164 return &res, nil
165 }
166
167 q = query.Map(q, query.ExpandFileContent)
168
169 mt, err := d.newMatchTree(q, matchTreeOpt{})
170 if err != nil {
171 return nil, err
172 }
173
174 // Capture the costs of construction before pruning
175 updateMatchTreeStats(mt, &res.Stats)
176
177 mt, err = pruneMatchTree(mt)
178 if err != nil {
179 return nil, err
180 }
181 res.Stats.MatchTreeConstruction = timer.Elapsed()
182 if mt == nil {
183 res.Stats.ShardsSkippedFilter++
184 return &res, nil
185 }
186
187 res.Stats.ShardsScanned++
188
189 cp := &contentProvider{
190 id: d,
191 stats: &res.Stats,
192 }
193
194 // Track the number of documents found in a repository for
195 // ShardRepoMaxMatchCount
196 var (
197 lastRepoID uint16
198 repoMatchCount int
199 )
200
201 docCount := uint32(len(d.fileBranchMasks))
202 lastDoc := int(-1)
203
204nextFileMatch:
205 for {
206 canceled := false
207 select {
208 case <-ctx.Done():
209 canceled = true
210 default:
211 }
212
213 nextDoc := mt.nextDoc()
214 if int(nextDoc) <= lastDoc {
215 nextDoc = uint32(lastDoc + 1)
216 }
217
218 for ; nextDoc < docCount; nextDoc++ {
219 repoID := d.repos[nextDoc]
220 repoMetadata := &d.repoMetaData[repoID]
221
222 // Skip tombstoned repositories
223 if repoMetadata.Tombstone {
224 continue
225 }
226
227 // 🚨 SECURITY: Skip documents that don't belong to the tenant. This check is
228 // necessary to prevent leaking data across tenants.
229 if !tenant.HasAccess(ctx, repoMetadata.TenantID) {
230 continue
231 }
232
233 // Skip documents that are tombstoned
234 if len(repoMetadata.FileTombstones) > 0 {
235 if _, tombstoned := repoMetadata.FileTombstones[string(d.fileName(nextDoc))]; tombstoned {
236 continue
237 }
238 }
239
240 // Skip documents over ShardRepoMaxMatchCount if specified.
241 if opts.ShardRepoMaxMatchCount > 0 {
242 if repoMatchCount >= opts.ShardRepoMaxMatchCount && repoID == lastRepoID {
243 res.Stats.FilesSkipped++
244 continue
245 }
246 }
247
248 break
249 }
250
251 if nextDoc >= docCount {
252 break
253 }
254
255 lastDoc = int(nextDoc)
256
257 // We track lastRepoID for ShardRepoMaxMatchCount
258 if lastRepoID != d.repos[nextDoc] {
259 lastRepoID = d.repos[nextDoc]
260 repoMatchCount = 0
261 }
262
263 if canceled || (res.Stats.MatchCount >= opts.ShardMaxMatchCount && opts.ShardMaxMatchCount > 0) {
264 res.Stats.FilesSkipped += int(docCount - nextDoc)
265 break
266 }
267
268 res.Stats.FilesConsidered++
269 mt.prepare(nextDoc)
270
271 cp.setDocument(nextDoc)
272
273 known := make(map[matchTree]bool)
274 md := d.repoMetaData[d.repos[nextDoc]]
275
276 for cost := costMin; cost <= costMax; cost++ {
277 switch evalMatchTree(cp, cost, known, mt) {
278 case matchesRequiresHigherCost:
279 if cost == costMax {
280 log.Panicf("did not decide. Repo %s, doc %d, known %v",
281 md.Name, nextDoc, known)
282 }
283 case matchesFound:
284 // could short-circuit now, but we want to run higher costs to
285 // potentially find higher ranked matches.
286 case matchesNone:
287 continue nextFileMatch
288 }
289 }
290
291 fileMatch := zoekt.FileMatch{
292 Repository: md.Name,
293 RepositoryID: md.ID,
294 RepositoryPriority: md.GetPriority(),
295 FileName: string(d.fileName(nextDoc)),
296 Checksum: d.getChecksum(nextDoc),
297 Language: d.languageMap[d.getLanguage(nextDoc)],
298 }
299
300 if s := d.subRepos[nextDoc]; s > 0 {
301 if s >= uint32(len(d.subRepoPaths[d.repos[nextDoc]])) {
302 log.Panicf("corrupt index: subrepo %d beyond %v", s, d.subRepoPaths)
303 }
304 path := d.subRepoPaths[d.repos[nextDoc]][s]
305 fileMatch.SubRepositoryPath = path
306 sr := md.SubRepoMap[path]
307 fileMatch.SubRepositoryName = sr.Name
308 if idx := d.branchIndex(nextDoc); idx >= 0 {
309 fileMatch.Version = sr.Branches[idx].Version
310 }
311 } else {
312 idx := d.branchIndex(nextDoc)
313 if idx >= 0 {
314 fileMatch.Version = md.Branches[idx].Version
315 }
316 }
317
318 // Important invariant for performance: finalCands is sorted by offset and
319 // non-overlapping. gatherMatches respects this invariant and all later
320 // transformations respect this.
321 finalCands := d.gatherMatches(nextDoc, mt, known)
322
323 if opts.ChunkMatches {
324 fileMatch.ChunkMatches = cp.fillChunkMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts)
325 } else {
326 fileMatch.LineMatches = cp.fillMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts)
327 }
328
329 if opts.UseBM25Scoring {
330 d.scoreFileBM25(&fileMatch, nextDoc, finalCands, cp, opts)
331 } else {
332 // Use the standard, non-experimental scoring method by default
333 d.scoreFile(&fileMatch, nextDoc, mt, known, opts)
334 }
335
336 fileMatch.Branches = d.gatherBranches(nextDoc, mt, known)
337 sortMatchesByScore(fileMatch.LineMatches)
338 sortChunkMatchesByScore(fileMatch.ChunkMatches)
339 if opts.Whole {
340 fileMatch.Content = cp.data(false)
341 }
342
343 matchedChunkRanges := 0
344 for _, cm := range fileMatch.ChunkMatches {
345 matchedChunkRanges += len(cm.Ranges)
346 }
347
348 repoMatchCount += len(fileMatch.LineMatches)
349 repoMatchCount += matchedChunkRanges
350
351 res.Files = append(res.Files, fileMatch)
352
353 res.Stats.MatchCount += len(fileMatch.LineMatches)
354 res.Stats.MatchCount += matchedChunkRanges
355 res.Stats.FileCount++
356 }
357
358 for _, md := range d.repoMetaData {
359 r := md
360 addRepo(&res, &r)
361 for _, v := range r.SubRepoMap {
362 addRepo(&res, v)
363 }
364 }
365
366 // Update stats based on work done during document search.
367 updateMatchTreeStats(mt, &res.Stats)
368
369 res.Stats.MatchTreeSearch = timer.Elapsed()
370
371 return &res, nil
372}
373
374func addRepo(res *zoekt.SearchResult, repo *zoekt.Repository) {
375 if res.RepoURLs == nil {
376 res.RepoURLs = map[string]string{}
377 }
378 res.RepoURLs[repo.Name] = repo.FileURLTemplate
379
380 if res.LineFragments == nil {
381 res.LineFragments = map[string]string{}
382 }
383 res.LineFragments[repo.Name] = repo.LineFragmentTemplate
384}
385
386// Gather matches from this document. The matches are returned in document
387// order and are non-overlapping. All filename and content matches are
388// returned, with filename matches first.
389//
390// If `merge` is set, overlapping and adjacent matches will be merged
391// into a single index. Otherwise, overlapping matches will be removed,
392// but adjacent matches will remain.
393func (d *indexData) gatherMatches(nextDoc uint32, mt matchTree, known map[matchTree]bool) []*candidateMatch {
394 var cands []*candidateMatch
395 visitMatches(mt, known, 1, func(mt matchTree, scoreWeight float64) {
396 if smt, ok := mt.(*substrMatchTree); ok {
397 cands = append(cands, setScoreWeight(scoreWeight, smt.current)...)
398 }
399 if rmt, ok := mt.(*regexpMatchTree); ok {
400 cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...)
401 }
402 if rmt, ok := mt.(*wordMatchTree); ok {
403 cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...)
404 }
405 if smt, ok := mt.(*symbolRegexpMatchTree); ok {
406 cands = append(cands, setScoreWeight(scoreWeight, smt.found)...)
407 }
408 })
409
410 // If we found no candidate matches at all, assume there must have been a match on filename.
411 if len(cands) == 0 {
412 nm := d.fileName(nextDoc)
413 return []*candidateMatch{{
414 caseSensitive: false,
415 fileName: true,
416 substrBytes: nm,
417 substrLowered: nm,
418 file: nextDoc,
419 runeOffset: 0,
420 byteOffset: 0,
421 byteMatchSz: uint32(len(nm)),
422 }}
423 }
424
425 // Remove overlapping candidates. This guarantees that the matches
426 // are non-overlapping, but also preserves expected match counts.
427 sort.Sort((sortByOffsetSlice)(cands))
428 res := cands[:0]
429 for i, c := range cands {
430 if i == 0 {
431 res = append(res, c)
432 continue
433 }
434
435 last := res[len(res)-1]
436
437 // Never compare filename and content matches
438 if last.fileName != c.fileName {
439 res = append(res, c)
440 continue
441 }
442
443 // Only add the match if its range doesn't overlap
444 lastEnd := last.byteOffset + last.byteMatchSz
445 if lastEnd <= c.byteOffset {
446 res = append(res, c)
447 continue
448 }
449 }
450 return res
451}
452
453type sortByOffsetSlice []*candidateMatch
454
455func (m sortByOffsetSlice) Len() int { return len(m) }
456func (m sortByOffsetSlice) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
457func (m sortByOffsetSlice) Less(i, j int) bool {
458 // Sort all filename matches to the start
459 if m[i].fileName != m[j].fileName {
460 return m[i].fileName
461 }
462
463 if m[i].byteOffset == m[j].byteOffset { // tie break if same offset
464 // Prefer longer candidates if starting at same position
465 return m[i].byteMatchSz > m[j].byteMatchSz
466 }
467 return m[i].byteOffset < m[j].byteOffset
468}
469
470// setScoreWeight is a helper used by gatherMatches to set the weight based on
471// the score weight of the matchTree.
472func setScoreWeight(scoreWeight float64, cm []*candidateMatch) []*candidateMatch {
473 for _, m := range cm {
474 m.scoreWeight = scoreWeight
475 }
476 return cm
477}
478
479func (d *indexData) branchIndex(docID uint32) int {
480 mask := d.fileBranchMasks[docID]
481 idx := 0
482 for mask != 0 {
483 if mask&0x1 != 0 {
484 return idx
485 }
486 idx++
487 mask >>= 1
488 }
489 return -1
490}
491
492// gatherBranches returns a list of branch names taking into account any branch
493// filters in the query. If the query contains a branch filter, it returns all
494// branches containing the docID and matching the branch filter. Otherwise, it
495// returns all branches containing docID.
496func (d *indexData) gatherBranches(docID uint32, mt matchTree, known map[matchTree]bool) []string {
497 var mask uint64
498 visitMatchAtoms(mt, known, func(mt matchTree) {
499 bq, ok := mt.(*branchQueryMatchTree)
500 if !ok {
501 return
502 }
503
504 mask = mask | bq.branchMask()
505 })
506
507 if mask == 0 {
508 mask = d.fileBranchMasks[docID]
509 }
510
511 var branches []string
512 id := uint64(1)
513 branchNames := d.branchNames[d.repos[docID]]
514 for mask != 0 {
515 if mask&0x1 != 0 {
516 branches = append(branches, branchNames[uint(id)])
517 }
518 id <<= 1
519 mask >>= 1
520 }
521
522 return branches
523}
524
525func (d *indexData) List(ctx context.Context, q query.Q, opts *zoekt.ListOptions) (rl *zoekt.RepoList, err error) {
526 var include func(rle *zoekt.RepoListEntry) bool
527
528 q = d.simplify(q)
529 if c, ok := q.(*query.Const); ok {
530 if !c.Value {
531 return &zoekt.RepoList{}, nil
532 }
533 include = func(rle *zoekt.RepoListEntry) bool {
534 return true
535 }
536 } else {
537 sr, err := d.Search(ctx, q, &zoekt.SearchOptions{
538 ShardRepoMaxMatchCount: 1,
539 })
540 if err != nil {
541 return nil, err
542 }
543
544 foundRepos := make(map[string]struct{}, len(sr.Files))
545 for _, file := range sr.Files {
546 foundRepos[file.Repository] = struct{}{}
547 }
548
549 include = func(rle *zoekt.RepoListEntry) bool {
550 _, ok := foundRepos[rle.Repository.Name]
551 return ok
552 }
553 }
554
555 var l zoekt.RepoList
556
557 field, err := opts.GetField()
558 if err != nil {
559 return nil, err
560 }
561 switch field {
562 case zoekt.RepoListFieldRepos:
563 l.Repos = make([]*zoekt.RepoListEntry, 0, len(d.repoListEntry))
564 case zoekt.RepoListFieldReposMap:
565 l.ReposMap = make(zoekt.ReposMap, len(d.repoListEntry))
566 }
567
568 for i := range d.repoListEntry {
569 if d.repoMetaData[i].Tombstone {
570 continue
571 }
572 // 🚨 SECURITY: Skip documents that don't belong to the tenant. This check is
573 // necessary to prevent leaking data across tenants.
574 if !tenant.HasAccess(ctx, d.repoMetaData[i].TenantID) {
575 continue
576 }
577 rle := &d.repoListEntry[i]
578 if !include(rle) {
579 continue
580 }
581
582 l.Stats.Add(&rle.Stats)
583
584 // Backwards compat for when ID is missing
585 if rle.Repository.ID == 0 {
586 l.Repos = append(l.Repos, rle)
587 continue
588 }
589
590 switch field {
591 case zoekt.RepoListFieldRepos:
592 l.Repos = append(l.Repos, rle)
593 case zoekt.RepoListFieldReposMap:
594 l.ReposMap[rle.Repository.ID] = zoekt.MinimalRepoListEntry{
595 HasSymbols: rle.Repository.HasSymbols,
596 Branches: rle.Repository.Branches,
597 IndexTimeUnix: rle.IndexMetadata.IndexTime.Unix(),
598 }
599 }
600
601 }
602
603 // Only one of these fields is populated and in all cases the size of that
604 // field is the number of Repos in this shard.
605 l.Stats.Repos = len(l.Repos) + len(l.ReposMap)
606
607 return &l, nil
608}
609
610// regexpToMatchTreeRecursive converts a regular expression to a matchTree mt. If
611// mt is equivalent to the input r, isEqual = true and the matchTree can be used
612// in place of the regex r. If singleLine = true, then the matchTree and all
613// its children only match terms on the same line. singleLine is used during
614// recursion to decide whether to return an andLineMatchTree (singleLine = true)
615// or a andMatchTree (singleLine = false).
616func (d *indexData) regexpToMatchTreeRecursive(r *syntax.Regexp, minTextSize int, fileName bool, caseSensitive bool) (mt matchTree, isEqual bool, singleLine bool, err error) {
617 // TODO - we could perhaps transform Begin/EndText in '\n'?
618 // TODO - we could perhaps transform CharClass in (OrQuery )
619 // if there are just a few runes, and part of a OpConcat?
620 switch r.Op {
621 case syntax.OpLiteral:
622 s := string(r.Rune)
623 if len(s) >= minTextSize {
624 ignoreCase := syntax.FoldCase == (r.Flags & syntax.FoldCase)
625 mt, err := d.newSubstringMatchTree(&query.Substring{Pattern: s, FileName: fileName, CaseSensitive: !ignoreCase && caseSensitive})
626 return mt, true, !strings.Contains(s, "\n"), err
627 }
628 case syntax.OpCapture:
629 return d.regexpToMatchTreeRecursive(r.Sub[0], minTextSize, fileName, caseSensitive)
630
631 case syntax.OpPlus:
632 return d.regexpToMatchTreeRecursive(r.Sub[0], minTextSize, fileName, caseSensitive)
633
634 case syntax.OpRepeat:
635 if r.Min == 1 {
636 return d.regexpToMatchTreeRecursive(r.Sub[0], minTextSize, fileName, caseSensitive)
637 } else if r.Min > 1 {
638 // (x){2,} can't be expressed precisely by the matchTree
639 mt, _, singleLine, err := d.regexpToMatchTreeRecursive(r.Sub[0], minTextSize, fileName, caseSensitive)
640 return mt, false, singleLine, err
641 }
642 case syntax.OpConcat, syntax.OpAlternate:
643 var qs []matchTree
644 isEq := true
645 singleLine = true
646 for _, sr := range r.Sub {
647 if sq, subIsEq, subSingleLine, err := d.regexpToMatchTreeRecursive(sr, minTextSize, fileName, caseSensitive); sq != nil {
648 if err != nil {
649 return nil, false, false, err
650 }
651 isEq = isEq && subIsEq
652 singleLine = singleLine && subSingleLine
653 qs = append(qs, sq)
654 }
655 }
656 if r.Op == syntax.OpConcat {
657 if len(qs) > 1 {
658 isEq = false
659 }
660 newQs := make([]matchTree, 0, len(qs))
661 for _, q := range qs {
662 if _, ok := q.(*bruteForceMatchTree); ok {
663 continue
664 }
665 newQs = append(newQs, q)
666 }
667 if len(newQs) == 1 {
668 return newQs[0], isEq, singleLine, nil
669 }
670 if len(newQs) == 0 {
671 return &bruteForceMatchTree{}, isEq, singleLine, nil
672 }
673 if singleLine {
674 return &andLineMatchTree{andMatchTree{children: newQs}}, isEq, singleLine, nil
675 }
676 return &andMatchTree{newQs}, isEq, singleLine, nil
677 }
678 for _, q := range qs {
679 if _, ok := q.(*bruteForceMatchTree); ok {
680 return q, isEq, false, nil
681 }
682 }
683 if len(qs) == 0 {
684 return &noMatchTree{Why: "const"}, isEq, false, nil
685 }
686 return &orMatchTree{qs}, isEq, false, nil
687 case syntax.OpStar:
688 if r.Sub[0].Op == syntax.OpAnyCharNotNL {
689 return &bruteForceMatchTree{}, false, true, nil
690 }
691 }
692 return &bruteForceMatchTree{}, false, false, nil
693}
694
// timer measures the time between successive checkpoints.
type timer struct {
	// last is the moment of the most recent checkpoint.
	last time.Time
}

// newTimer returns a timer whose first checkpoint is the current time.
func newTimer() *timer {
	t := new(timer)
	t.last = time.Now()
	return t
}

// Elapsed returns the duration since the previous checkpoint and makes the
// current time the new checkpoint, so consecutive calls measure consecutive,
// non-overlapping intervals.
func (t *timer) Elapsed() time.Duration {
	prev := t.last
	t.last = time.Now()
	return t.last.Sub(prev)
}
710}