// Fork of https://github.com/sourcegraph/zoekt
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package zoekt
16
17import (
18 "encoding/binary"
19 "fmt"
20)
21
22// hitIterator finds potential search matches, measured in offsets of
23// the concatenation of all documents.
24type hitIterator interface {
25 // Return the first hit, or maxUInt32 if none.
26 first() uint32
27
28 // Skip until past limit. The argument maxUInt32 should be
29 // treated specially.
30 next(limit uint32)
31
32 // Return how many bytes were read.
33 updateStats(s *Stats)
34}
35
36// distanceHitIterator looks for hits at a fixed distance apart.
37type distanceHitIterator struct {
38 started bool
39 distance uint32
40 i1 hitIterator
41 i2 hitIterator
42}
43
44func (i *distanceHitIterator) String() string {
45 return fmt.Sprintf("dist(%d, %v, %v)", i.distance, i.i1, i.i2)
46}
47
48func (i *distanceHitIterator) findNext() {
49 for {
50 var p1, p2 uint32
51 p1 = i.i1.first()
52 p2 = i.i2.first()
53 if p1 == maxUInt32 || p2 == maxUInt32 {
54 i.i1.next(maxUInt32)
55 break
56 }
57
58 if p1+i.distance < p2 {
59 i.i1.next(p2 - i.distance - 1)
60 } else if p1+i.distance > p2 {
61 i.i2.next(p1 + i.distance - 1)
62 } else {
63 break
64 }
65 }
66}
67
68func (i *distanceHitIterator) first() uint32 {
69 if !i.started {
70 i.findNext()
71 i.started = true
72 }
73 return i.i1.first()
74}
75
76func (i *distanceHitIterator) updateStats(s *Stats) {
77 i.i1.updateStats(s)
78 i.i2.updateStats(s)
79}
80
81func (i *distanceHitIterator) next(limit uint32) {
82 i.i1.next(limit)
83 l2 := limit + i.distance
84
85 if l2 < limit { // overflow.
86 l2 = maxUInt32
87 }
88 i.i2.next(l2)
89 i.findNext()
90}
91
92func (d *indexData) newDistanceTrigramIter(ng1, ng2 ngram, dist uint32, caseSensitive, fileName bool) (hitIterator, error) {
93 if dist == 0 {
94 return nil, fmt.Errorf("d == 0")
95 }
96
97 i1, err := d.trigramHitIterator(ng1, caseSensitive, fileName)
98 if err != nil {
99 return nil, err
100 }
101 i2, err := d.trigramHitIterator(ng2, caseSensitive, fileName)
102 if err != nil {
103 return nil, err
104 }
105 return &distanceHitIterator{
106 i1: i1,
107 i2: i2,
108 distance: dist,
109 }, nil
110}
111
112func (d *indexData) trigramHitIterator(ng ngram, caseSensitive, fileName bool) (hitIterator, error) {
113 variants := []ngram{ng}
114 if !caseSensitive {
115 variants = generateCaseNgrams(ng)
116 }
117
118 iters := make([]hitIterator, 0, len(variants))
119 ngramLookups := 0
120 for _, v := range variants {
121 if fileName {
122 blob, err := d.fileNameNgrams.GetBlob(v)
123 ngramLookups++
124 if err != nil {
125 return nil, err
126 }
127 if len(blob) > 0 {
128 iters = append(iters, newCompressedPostingIterator(blob, v))
129 }
130 continue
131 }
132
133 sec := d.ngrams.Get(v)
134 ngramLookups++
135 blob, err := d.readSectionBlob(sec)
136 if err != nil {
137 return nil, err
138 }
139 if len(blob) > 0 {
140 iters = append(iters, newCompressedPostingIterator(blob, v))
141 }
142 }
143
144 if len(iters) == 1 {
145 // if we only return 1 then we need to include our ngramLookups stats
146 iter := (iters[0]).(*compressedPostingIterator)
147 iter.ngramLookups = ngramLookups
148 return iter, nil
149 }
150 return &mergingIterator{
151 ngramLookups: ngramLookups,
152 iters: iters,
153 }, nil
154}
155
156// inMemoryIterator is hitIterator that goes over an in-memory uint32 posting list.
157type inMemoryIterator struct {
158 postings []uint32
159 what ngram
160}
161
162func (i *inMemoryIterator) String() string {
163 return fmt.Sprintf("mem(%s):%v", i.what, i.postings)
164}
165
166func (i *inMemoryIterator) first() uint32 {
167 if len(i.postings) > 0 {
168 return i.postings[0]
169 }
170 return maxUInt32
171}
172
173func (i *inMemoryIterator) updateStats(s *Stats) {
174}
175
176func (i *inMemoryIterator) next(limit uint32) {
177 if limit == maxUInt32 {
178 i.postings = nil
179 }
180
181 for len(i.postings) > 0 && i.postings[0] <= limit {
182 i.postings = i.postings[1:]
183 }
184}
185
186// compressedPostingIterator goes over a delta varint encoded posting
187// list.
188type compressedPostingIterator struct {
189 blob []byte
190 indexBytesLoaded int
191 ngramLookups int
192 _first uint32
193 what ngram
194}
195
196func newCompressedPostingIterator(b []byte, w ngram) *compressedPostingIterator {
197 d, sz := binary.Uvarint(b)
198 return &compressedPostingIterator{
199 _first: uint32(d),
200 blob: b[sz:],
201 indexBytesLoaded: sz,
202 what: w,
203 }
204}
205
206func (i *compressedPostingIterator) String() string {
207 return fmt.Sprintf("compressed(%s, %d, [%d bytes])", i.what, i._first, len(i.blob))
208}
209
210func (i *compressedPostingIterator) first() uint32 {
211 return i._first
212}
213
214func (i *compressedPostingIterator) next(limit uint32) {
215 if limit == maxUInt32 {
216 i.blob = nil
217 i._first = maxUInt32
218 return
219 }
220
221 for i._first <= limit && len(i.blob) > 0 {
222 delta, sz := binary.Uvarint(i.blob)
223 i._first += uint32(delta)
224 i.indexBytesLoaded += sz
225 i.blob = i.blob[sz:]
226 }
227
228 if i._first <= limit && len(i.blob) == 0 {
229 i._first = maxUInt32
230 }
231}
232
233func (i *compressedPostingIterator) updateStats(s *Stats) {
234 s.IndexBytesLoaded += int64(i.indexBytesLoaded)
235 s.NgramLookups += i.ngramLookups
236 i.indexBytesLoaded = 0
237 i.ngramLookups = 0
238}
239
240// mergingIterator forms the merge of a set of hitIterators, to
241// implement an OR operation at the hit level.
242type mergingIterator struct {
243 iters []hitIterator
244 ngramLookups int
245}
246
247func (i *mergingIterator) String() string {
248 return fmt.Sprintf("merge:%v", i.iters)
249}
250
251func (i *mergingIterator) updateStats(s *Stats) {
252 s.NgramLookups += i.ngramLookups
253 i.ngramLookups = 0
254 for _, j := range i.iters {
255 j.updateStats(s)
256 }
257}
258
259func (i *mergingIterator) first() uint32 {
260 r := uint32(maxUInt32)
261 for _, j := range i.iters {
262 f := j.first()
263 if f < r {
264 r = f
265 }
266 }
267
268 return r
269}
270
271func (i *mergingIterator) next(limit uint32) {
272 for _, j := range i.iters {
273 j.next(limit)
274 }
275}