fork of https://github.com/sourcegraph/zoekt
1// Copyright 2018 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package zoekt
16
17import (
18 "encoding/binary"
19 "fmt"
20)
21
22// hitIterator finds potential search matches, measured in offsets of
23// the concatenation of all documents.
24type hitIterator interface {
25 // Return the first hit, or maxUInt32 if none.
26 first() uint32
27
28 // Skip until past limit. The argument maxUInt32 should be
29 // treated specially.
30 next(limit uint32)
31
32 // Return how many bytes were read.
33 updateStats(s *Stats)
34}
35
36// distanceHitIterator looks for hits at a fixed distance apart.
37type distanceHitIterator struct {
38 started bool
39 distance uint32
40 i1 hitIterator
41 i2 hitIterator
42}
43
44func (i *distanceHitIterator) String() string {
45 return fmt.Sprintf("dist(%d, %v, %v)", i.distance, i.i1, i.i2)
46}
47
48func (i *distanceHitIterator) findNext() {
49 for {
50 var p1, p2 uint32
51 p1 = i.i1.first()
52 p2 = i.i2.first()
53 if p1 == maxUInt32 || p2 == maxUInt32 {
54 i.i1.next(maxUInt32)
55 break
56 }
57
58 if p1+i.distance < p2 {
59 i.i1.next(p2 - i.distance - 1)
60 } else if p1+i.distance > p2 {
61 i.i2.next(p1 + i.distance - 1)
62 } else {
63 break
64 }
65 }
66}
67
68func (i *distanceHitIterator) first() uint32 {
69 if !i.started {
70 i.findNext()
71 i.started = true
72 }
73 return i.i1.first()
74}
75
76func (i *distanceHitIterator) updateStats(s *Stats) {
77 i.i1.updateStats(s)
78 i.i2.updateStats(s)
79}
80
81func (i *distanceHitIterator) next(limit uint32) {
82 i.i1.next(limit)
83 l2 := limit + i.distance
84
85 if l2 < limit { // overflow.
86 l2 = maxUInt32
87 }
88 i.i2.next(l2)
89 i.findNext()
90}
91
92func (d *indexData) newDistanceTrigramIter(ng1, ng2 ngram, dist uint32, caseSensitive, fileName bool) (hitIterator, error) {
93 if dist == 0 {
94 return nil, fmt.Errorf("d == 0")
95 }
96
97 i1, err := d.trigramHitIterator(ng1, caseSensitive, fileName)
98 if err != nil {
99 return nil, err
100 }
101 i2, err := d.trigramHitIterator(ng2, caseSensitive, fileName)
102 if err != nil {
103 return nil, err
104 }
105 return &distanceHitIterator{
106 i1: i1,
107 i2: i2,
108 distance: dist,
109 }, nil
110}
111
112func (d *indexData) trigramHitIterator(ng ngram, caseSensitive, fileName bool) (hitIterator, error) {
113 variants := []ngram{ng}
114 if !caseSensitive {
115 variants = generateCaseNgrams(ng)
116 }
117
118 iters := make([]hitIterator, 0, len(variants))
119 for _, v := range variants {
120 if fileName {
121 blob := d.fileNameNgrams[v]
122 if len(blob) > 0 {
123 iters = append(iters, newCompressedPostingIterator(blob, v))
124 }
125 continue
126 }
127
128 sec := d.ngrams.Get(v)
129 blob, err := d.readSectionBlob(sec)
130 if err != nil {
131 return nil, err
132 }
133 if len(blob) > 0 {
134 iters = append(iters, newCompressedPostingIterator(blob, v))
135 }
136 }
137
138 if len(iters) == 1 {
139 return iters[0], nil
140 }
141 return &mergingIterator{
142 iters: iters,
143 }, nil
144}
145
146// inMemoryIterator is hitIterator that goes over an in-memory uint32 posting list.
147type inMemoryIterator struct {
148 postings []uint32
149 what ngram
150}
151
152func (i *inMemoryIterator) String() string {
153 return fmt.Sprintf("mem(%s):%v", i.what, i.postings)
154}
155
156func (i *inMemoryIterator) first() uint32 {
157 if len(i.postings) > 0 {
158 return i.postings[0]
159 }
160 return maxUInt32
161}
162
163func (i *inMemoryIterator) updateStats(s *Stats) {
164}
165
166func (i *inMemoryIterator) next(limit uint32) {
167 if limit == maxUInt32 {
168 i.postings = nil
169 }
170
171 for len(i.postings) > 0 && i.postings[0] <= limit {
172 i.postings = i.postings[1:]
173 }
174}
175
176// compressedPostingIterator goes over a delta varint encoded posting
177// list.
178type compressedPostingIterator struct {
179 blob, orig []byte
180 _first uint32
181 what ngram
182}
183
184func newCompressedPostingIterator(b []byte, w ngram) *compressedPostingIterator {
185 d, sz := binary.Uvarint(b)
186 return &compressedPostingIterator{
187 _first: uint32(d),
188 blob: b[sz:],
189 orig: b,
190 what: w,
191 }
192}
193
194func (i *compressedPostingIterator) String() string {
195 return fmt.Sprintf("compressed(%s, %d, [%d bytes])", i.what, i._first, len(i.blob))
196}
197
198func (i *compressedPostingIterator) first() uint32 {
199 return i._first
200}
201
202func (i *compressedPostingIterator) next(limit uint32) {
203 if limit == maxUInt32 {
204 i.blob = nil
205 i._first = maxUInt32
206 return
207 }
208
209 for i._first <= limit && len(i.blob) > 0 {
210 delta, sz := binary.Uvarint(i.blob)
211 i._first += uint32(delta)
212 i.blob = i.blob[sz:]
213 }
214
215 if i._first <= limit && len(i.blob) == 0 {
216 i._first = maxUInt32
217 }
218}
219
220func (i *compressedPostingIterator) updateStats(s *Stats) {
221 s.IndexBytesLoaded += int64(len(i.orig) - len(i.blob))
222}
223
224// mergingIterator forms the merge of a set of hitIterators, to
225// implement an OR operation at the hit level.
226type mergingIterator struct {
227 iters []hitIterator
228}
229
230func (i *mergingIterator) String() string {
231 return fmt.Sprintf("merge:%v", i.iters)
232}
233
234func (i *mergingIterator) updateStats(s *Stats) {
235 for _, j := range i.iters {
236 j.updateStats(s)
237 }
238}
239
240func (i *mergingIterator) first() uint32 {
241 r := uint32(maxUInt32)
242 for _, j := range i.iters {
243 f := j.first()
244 if f < r {
245 r = f
246 }
247 }
248
249 return r
250}
251
252func (i *mergingIterator) next(limit uint32) {
253 for _, j := range i.iters {
254 j.next(limit)
255 }
256}