fork of https://github.com/sourcegraph/zoekt
1// Copyright 2018 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package zoekt
16
17import (
18 "encoding/binary"
19 "fmt"
20)
21
// hitIterator finds potential search matches, measured in offsets of
// the concatenation of all documents.
type hitIterator interface {
	// first returns the current hit, or maxUInt32 if there is none
	// (left).
	first() uint32

	// next skips forward until the current hit is past limit. The
	// argument maxUInt32 is treated specially: the implementations in
	// this file respond to it by dropping all remaining hits.
	next(limit uint32)

	// updateStats adds the counters this iterator has accumulated
	// (such as the number of index bytes read) to s. It does not
	// return anything.
	updateStats(s *Stats)
}
35
// distanceHitIterator looks for hits at a fixed distance apart: it
// reports the hits of i1 for which i2 has a hit exactly distance
// further on.
type distanceHitIterator struct {
	// started records whether first() has already run the initial
	// findNext.
	started bool
	// distance is the required difference between the i2 hit and the
	// i1 hit of a match.
	distance uint32
	// i1 supplies the positions that first() reports; i2 supplies the
	// positions that must appear distance after them.
	i1 hitIterator
	i2 hitIterator
}
43
44func (i *distanceHitIterator) String() string {
45 return fmt.Sprintf("dist(%d, %v, %v)", i.distance, i.i1, i.i2)
46}
47
48func (i *distanceHitIterator) findNext() {
49 for {
50 var p1, p2 uint32
51 p1 = i.i1.first()
52 p2 = i.i2.first()
53 if p1 == maxUInt32 || p2 == maxUInt32 {
54 i.i1.next(maxUInt32)
55 break
56 }
57
58 if p1+i.distance < p2 {
59 i.i1.next(p2 - i.distance - 1)
60 } else if p1+i.distance > p2 {
61 i.i2.next(p1 + i.distance - 1)
62 } else {
63 break
64 }
65 }
66}
67
68func (i *distanceHitIterator) first() uint32 {
69 if !i.started {
70 i.findNext()
71 i.started = true
72 }
73 return i.i1.first()
74}
75
76func (i *distanceHitIterator) updateStats(s *Stats) {
77 i.i1.updateStats(s)
78 i.i2.updateStats(s)
79}
80
81func (i *distanceHitIterator) next(limit uint32) {
82 i.i1.next(limit)
83 l2 := limit + i.distance
84
85 if l2 < limit { // overflow.
86 l2 = maxUInt32
87 }
88 i.i2.next(l2)
89 i.findNext()
90}
91
92func (d *indexData) newDistanceTrigramIter(ng1, ng2 ngram, dist uint32, caseSensitive, fileName bool) (hitIterator, error) {
93 if dist == 0 {
94 return nil, fmt.Errorf("d == 0")
95 }
96
97 i1, err := d.trigramHitIterator(ng1, caseSensitive, fileName)
98 if err != nil {
99 return nil, err
100 }
101 i2, err := d.trigramHitIterator(ng2, caseSensitive, fileName)
102 if err != nil {
103 return nil, err
104 }
105 return &distanceHitIterator{
106 i1: i1,
107 i2: i2,
108 distance: dist,
109 }, nil
110}
111
112func (d *indexData) trigramHitIterator(ng ngram, caseSensitive, fileName bool) (hitIterator, error) {
113 if d.ngrams == nil {
114 return nil, fmt.Errorf("trigramHitIterator: ngrams=nil")
115 }
116
117 variants := []ngram{ng}
118 if !caseSensitive {
119 variants = generateCaseNgrams(ng)
120 }
121
122 iters := make([]hitIterator, 0, len(variants))
123 ngramLookups := 0
124 for _, v := range variants {
125 if fileName {
126 blob, err := d.fileNameNgrams.GetBlob(v)
127 ngramLookups++
128 if err != nil {
129 return nil, err
130 }
131 if len(blob) > 0 {
132 iters = append(iters, newCompressedPostingIterator(blob, v))
133 }
134 continue
135 }
136
137 sec := d.ngrams.Get(v)
138 ngramLookups++
139 blob, err := d.readSectionBlob(sec)
140 if err != nil {
141 return nil, err
142 }
143 if len(blob) > 0 {
144 iters = append(iters, newCompressedPostingIterator(blob, v))
145 }
146 }
147
148 if len(iters) == 1 {
149 // if we only return 1 then we need to include our ngramLookups stats
150 iter := (iters[0]).(*compressedPostingIterator)
151 iter.ngramLookups = ngramLookups
152 return iter, nil
153 }
154 return &mergingIterator{
155 ngramLookups: ngramLookups,
156 iters: iters,
157 }, nil
158}
159
// inMemoryIterator is a hitIterator that goes over an in-memory uint32
// posting list.
type inMemoryIterator struct {
	// postings holds the hits that have not been skipped yet;
	// postings[0] is the current hit. next() drops entries from the
	// front.
	postings []uint32
	// what is the ngram shown by String(); the methods in this file do
	// not use it for anything else.
	what ngram
}
165
166func (i *inMemoryIterator) String() string {
167 return fmt.Sprintf("mem(%s):%v", i.what, i.postings)
168}
169
170func (i *inMemoryIterator) first() uint32 {
171 if len(i.postings) > 0 {
172 return i.postings[0]
173 }
174 return maxUInt32
175}
176
// updateStats does nothing: this iterator keeps no counters that could be
// added to s.
func (i *inMemoryIterator) updateStats(s *Stats) {
}
179
180func (i *inMemoryIterator) next(limit uint32) {
181 if limit == maxUInt32 {
182 i.postings = nil
183 }
184
185 for len(i.postings) > 0 && i.postings[0] <= limit {
186 i.postings = i.postings[1:]
187 }
188}
189
// compressedPostingIterator goes over a delta varint encoded posting
// list.
type compressedPostingIterator struct {
	// blob is the part of the encoded posting list that has not been
	// decoded yet; each varint in it is the delta to add to _first.
	blob []byte
	// indexBytesLoaded counts the bytes decoded since the last
	// updateStats call.
	indexBytesLoaded int
	// ngramLookups is added to the stats by updateStats and then
	// reset. It is set by trigramHitIterator when this iterator is
	// returned on its own.
	ngramLookups int
	// _first is the current hit, or maxUInt32 once the list is
	// exhausted.
	_first uint32
	// what is the ngram shown by String().
	what ngram
}
199
200func newCompressedPostingIterator(b []byte, w ngram) *compressedPostingIterator {
201 d, sz := binary.Uvarint(b)
202 return &compressedPostingIterator{
203 _first: uint32(d),
204 blob: b[sz:],
205 indexBytesLoaded: sz,
206 what: w,
207 }
208}
209
210func (i *compressedPostingIterator) String() string {
211 return fmt.Sprintf("compressed(%s, %d, [%d bytes])", i.what, i._first, len(i.blob))
212}
213
// first returns the current hit. next() sets it to maxUInt32 once the
// posting list has been consumed.
func (i *compressedPostingIterator) first() uint32 {
	return i._first
}
217
218func (i *compressedPostingIterator) next(limit uint32) {
219 if limit == maxUInt32 {
220 i.blob = nil
221 i._first = maxUInt32
222 return
223 }
224
225 for i._first <= limit && len(i.blob) > 0 {
226 delta, sz := binary.Uvarint(i.blob)
227 i._first += uint32(delta)
228 i.indexBytesLoaded += sz
229 i.blob = i.blob[sz:]
230 }
231
232 if i._first <= limit && len(i.blob) == 0 {
233 i._first = maxUInt32
234 }
235}
236
237func (i *compressedPostingIterator) updateStats(s *Stats) {
238 s.IndexBytesLoaded += int64(i.indexBytesLoaded)
239 s.NgramLookups += i.ngramLookups
240 i.indexBytesLoaded = 0
241 i.ngramLookups = 0
242}
243
// mergingIterator forms the merge of a set of hitIterators, to
// implement an OR operation at the hit level.
type mergingIterator struct {
	// iters are the iterators being merged; first() reports the
	// smallest of their current hits.
	iters []hitIterator
	// ngramLookups is added to the stats by updateStats and then
	// reset. It is set by trigramHitIterator when it builds the merge.
	ngramLookups int
}
250
251func (i *mergingIterator) String() string {
252 return fmt.Sprintf("merge:%v", i.iters)
253}
254
255func (i *mergingIterator) updateStats(s *Stats) {
256 s.NgramLookups += i.ngramLookups
257 i.ngramLookups = 0
258 for _, j := range i.iters {
259 j.updateStats(s)
260 }
261}
262
263func (i *mergingIterator) first() uint32 {
264 r := uint32(maxUInt32)
265 for _, j := range i.iters {
266 f := j.first()
267 if f < r {
268 r = f
269 }
270 }
271
272 return r
273}
274
275func (i *mergingIterator) next(limit uint32) {
276 for _, j := range i.iters {
277 j.next(limit)
278 }
279}