fork of https://github.com/sourcegraph/zoekt
1// Copyright 2018 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package index
16
17import (
18 "encoding/binary"
19 "fmt"
20
21 "github.com/sourcegraph/zoekt"
22)
23
24// hitIterator finds potential search matches, measured in offsets of
25// the concatenation of all documents.
26type hitIterator interface {
27 // Return the first hit, or maxUInt32 if none.
28 first() uint32
29
30 // Skip until past limit. The argument maxUInt32 should be
31 // treated specially.
32 next(limit uint32)
33
34 // Return how many bytes were read.
35 updateStats(s *zoekt.Stats)
36}
37
38// distanceHitIterator looks for hits at a fixed distance apart.
39type distanceHitIterator struct {
40 i1 hitIterator
41 i2 hitIterator
42 distance uint32
43 started bool
44}
45
46func (i *distanceHitIterator) String() string {
47 return fmt.Sprintf("dist(%d, %v, %v)", i.distance, i.i1, i.i2)
48}
49
50func (i *distanceHitIterator) findNext() {
51 for {
52 var p1, p2 uint32
53 p1 = i.i1.first()
54 p2 = i.i2.first()
55 if p1 == maxUInt32 || p2 == maxUInt32 {
56 i.i1.next(maxUInt32)
57 break
58 }
59
60 if p1+i.distance < p2 {
61 i.i1.next(p2 - i.distance - 1)
62 } else if p1+i.distance > p2 {
63 i.i2.next(p1 + i.distance - 1)
64 } else {
65 break
66 }
67 }
68}
69
70func (i *distanceHitIterator) first() uint32 {
71 if !i.started {
72 i.findNext()
73 i.started = true
74 }
75 return i.i1.first()
76}
77
78func (i *distanceHitIterator) updateStats(s *zoekt.Stats) {
79 i.i1.updateStats(s)
80 i.i2.updateStats(s)
81}
82
83func (i *distanceHitIterator) next(limit uint32) {
84 i.i1.next(limit)
85 l2 := limit + i.distance
86
87 if l2 < limit { // overflow.
88 l2 = maxUInt32
89 }
90 i.i2.next(l2)
91 i.findNext()
92}
93
94func (d *indexData) newDistanceTrigramIter(ng1, ng2 ngram, dist uint32, caseSensitive, fileName bool) (hitIterator, error) {
95 if dist == 0 {
96 return nil, fmt.Errorf("d == 0")
97 }
98
99 i1, err := d.trigramHitIterator(ng1, caseSensitive, fileName)
100 if err != nil {
101 return nil, err
102 }
103 i2, err := d.trigramHitIterator(ng2, caseSensitive, fileName)
104 if err != nil {
105 return nil, err
106 }
107 return &distanceHitIterator{
108 i1: i1,
109 i2: i2,
110 distance: dist,
111 }, nil
112}
113
114func (d *indexData) trigramHitIterator(ng ngram, caseSensitive, fileName bool) (hitIterator, error) {
115 variants := []ngram{ng}
116 if !caseSensitive {
117 variants = generateCaseNgrams(ng)
118 }
119
120 iters := make([]hitIterator, 0, len(variants))
121 ngramLookups := 0
122 ngrams := d.ngrams(fileName)
123 for _, v := range variants {
124 sec := ngrams.Get(v)
125 ngramLookups++
126 blob, err := d.readSectionBlob(sec)
127 if err != nil {
128 return nil, err
129 }
130 if len(blob) > 0 {
131 iters = append(iters, newCompressedPostingIterator(blob, v))
132 }
133 }
134
135 if len(iters) == 1 {
136 // if we only return 1 then we need to include our ngramLookups stats
137 iter := (iters[0]).(*compressedPostingIterator)
138 iter.ngramLookups = ngramLookups
139 return iter, nil
140 }
141 return &mergingIterator{
142 ngramLookups: ngramLookups,
143 iters: iters,
144 }, nil
145}
146
147// inMemoryIterator is hitIterator that goes over an in-memory uint32 posting list.
148type inMemoryIterator struct {
149 postings []uint32
150 what ngram
151}
152
153func (i *inMemoryIterator) String() string {
154 return fmt.Sprintf("mem(%s):%v", i.what, i.postings)
155}
156
157func (i *inMemoryIterator) first() uint32 {
158 if len(i.postings) > 0 {
159 return i.postings[0]
160 }
161 return maxUInt32
162}
163
164func (i *inMemoryIterator) updateStats(s *zoekt.Stats) {
165}
166
167func (i *inMemoryIterator) next(limit uint32) {
168 if limit == maxUInt32 {
169 i.postings = nil
170 }
171
172 for len(i.postings) > 0 && i.postings[0] <= limit {
173 i.postings = i.postings[1:]
174 }
175}
176
177// compressedPostingIterator goes over a delta varint encoded posting
178// list.
179type compressedPostingIterator struct {
180 blob []byte
181 indexBytesLoaded int
182 ngramLookups int
183 _first uint32
184 what ngram
185}
186
187func newCompressedPostingIterator(b []byte, w ngram) *compressedPostingIterator {
188 d, sz := binary.Uvarint(b)
189 return &compressedPostingIterator{
190 _first: uint32(d),
191 blob: b[sz:],
192 indexBytesLoaded: sz,
193 what: w,
194 }
195}
196
197func (i *compressedPostingIterator) String() string {
198 return fmt.Sprintf("compressed(%s, %d, [%d bytes])", i.what, i._first, len(i.blob))
199}
200
201func (i *compressedPostingIterator) first() uint32 {
202 return i._first
203}
204
205func (i *compressedPostingIterator) next(limit uint32) {
206 if limit == maxUInt32 {
207 i.blob = nil
208 i._first = maxUInt32
209 return
210 }
211
212 for i._first <= limit && len(i.blob) > 0 {
213 delta, sz := binary.Uvarint(i.blob)
214 i._first += uint32(delta)
215 i.indexBytesLoaded += sz
216 i.blob = i.blob[sz:]
217 }
218
219 if i._first <= limit && len(i.blob) == 0 {
220 i._first = maxUInt32
221 }
222}
223
224func (i *compressedPostingIterator) updateStats(s *zoekt.Stats) {
225 s.IndexBytesLoaded += int64(i.indexBytesLoaded)
226 s.NgramLookups += i.ngramLookups
227 i.indexBytesLoaded = 0
228 i.ngramLookups = 0
229}
230
231// mergingIterator forms the merge of a set of hitIterators, to
232// implement an OR operation at the hit level.
233type mergingIterator struct {
234 iters []hitIterator
235 ngramLookups int
236}
237
238func (i *mergingIterator) String() string {
239 return fmt.Sprintf("merge:%v", i.iters)
240}
241
242func (i *mergingIterator) updateStats(s *zoekt.Stats) {
243 s.NgramLookups += i.ngramLookups
244 i.ngramLookups = 0
245 for _, j := range i.iters {
246 j.updateStats(s)
247 }
248}
249
250func (i *mergingIterator) first() uint32 {
251 r := uint32(maxUInt32)
252 for _, j := range i.iters {
253 f := j.first()
254 if f < r {
255 r = f
256 }
257 }
258
259 return r
260}
261
262func (i *mergingIterator) next(limit uint32) {
263 for _, j := range i.iters {
264 j.next(limit)
265 }
266}