// Fork of https://github.com/sourcegraph/zoekt.
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package zoekt
16
17import (
18 "encoding/binary"
19 "fmt"
20)
21
22// hitIterator finds potential search matches, measured in offsets of
23// the concatenation of all documents.
24type hitIterator interface {
25 // Return the first hit, or maxUInt32 if none.
26 first() uint32
27
28 // Skip until past limit. The argument maxUInt32 should be
29 // treated specially.
30 next(limit uint32)
31
32 // Return how many bytes were read.
33 updateStats(s *Stats)
34}
35
36// distanceHitIterator looks for hits at a fixed distance apart.
37type distanceHitIterator struct {
38 started bool
39 distance uint32
40 i1 hitIterator
41 i2 hitIterator
42}
43
44func (i *distanceHitIterator) String() string {
45 return fmt.Sprintf("dist(%d, %v, %v)", i.distance, i.i1, i.i2)
46}
47
48func (i *distanceHitIterator) findNext() {
49 for {
50 var p1, p2 uint32
51 p1 = i.i1.first()
52 p2 = i.i2.first()
53 if p1 == maxUInt32 || p2 == maxUInt32 {
54 i.i1.next(maxUInt32)
55 break
56 }
57
58 if p1+i.distance < p2 {
59 i.i1.next(p2 - i.distance - 1)
60 } else if p1+i.distance > p2 {
61 i.i2.next(p1 + i.distance - 1)
62 } else {
63 break
64 }
65 }
66}
67
68func (i *distanceHitIterator) first() uint32 {
69 if !i.started {
70 i.findNext()
71 i.started = true
72 }
73 return i.i1.first()
74}
75
76func (i *distanceHitIterator) updateStats(s *Stats) {
77 i.i1.updateStats(s)
78 i.i2.updateStats(s)
79}
80
81func (i *distanceHitIterator) next(limit uint32) {
82 i.i1.next(limit)
83 l2 := limit + i.distance
84
85 if l2 < limit { // overflow.
86 l2 = maxUInt32
87 }
88 i.i2.next(l2)
89 i.findNext()
90}
91
92func (d *indexData) newDistanceTrigramIter(ng1, ng2 ngram, dist uint32, caseSensitive, fileName bool) (hitIterator, error) {
93 if dist == 0 {
94 return nil, fmt.Errorf("d == 0")
95 }
96
97 i1, err := d.trigramHitIterator(ng1, caseSensitive, fileName)
98 if err != nil {
99 return nil, err
100 }
101 i2, err := d.trigramHitIterator(ng2, caseSensitive, fileName)
102 if err != nil {
103 return nil, err
104 }
105 return &distanceHitIterator{
106 i1: i1,
107 i2: i2,
108 distance: dist,
109 }, nil
110}
111
112func (d *indexData) trigramHitIterator(ng ngram, caseSensitive, fileName bool) (hitIterator, error) {
113 variants := []ngram{ng}
114 if !caseSensitive {
115 variants = generateCaseNgrams(ng)
116 }
117
118 iters := make([]hitIterator, 0, len(variants))
119 ngramLookups := 0
120 ngrams := d.ngrams(fileName)
121 for _, v := range variants {
122 sec := ngrams.Get(v)
123 ngramLookups++
124 blob, err := d.readSectionBlob(sec)
125 if err != nil {
126 return nil, err
127 }
128 if len(blob) > 0 {
129 iters = append(iters, newCompressedPostingIterator(blob, v))
130 }
131 }
132
133 if len(iters) == 1 {
134 // if we only return 1 then we need to include our ngramLookups stats
135 iter := (iters[0]).(*compressedPostingIterator)
136 iter.ngramLookups = ngramLookups
137 return iter, nil
138 }
139 return &mergingIterator{
140 ngramLookups: ngramLookups,
141 iters: iters,
142 }, nil
143}
144
145// inMemoryIterator is hitIterator that goes over an in-memory uint32 posting list.
146type inMemoryIterator struct {
147 postings []uint32
148 what ngram
149}
150
151func (i *inMemoryIterator) String() string {
152 return fmt.Sprintf("mem(%s):%v", i.what, i.postings)
153}
154
155func (i *inMemoryIterator) first() uint32 {
156 if len(i.postings) > 0 {
157 return i.postings[0]
158 }
159 return maxUInt32
160}
161
162func (i *inMemoryIterator) updateStats(s *Stats) {
163}
164
165func (i *inMemoryIterator) next(limit uint32) {
166 if limit == maxUInt32 {
167 i.postings = nil
168 }
169
170 for len(i.postings) > 0 && i.postings[0] <= limit {
171 i.postings = i.postings[1:]
172 }
173}
174
175// compressedPostingIterator goes over a delta varint encoded posting
176// list.
177type compressedPostingIterator struct {
178 blob []byte
179 indexBytesLoaded int
180 ngramLookups int
181 _first uint32
182 what ngram
183}
184
185func newCompressedPostingIterator(b []byte, w ngram) *compressedPostingIterator {
186 d, sz := binary.Uvarint(b)
187 return &compressedPostingIterator{
188 _first: uint32(d),
189 blob: b[sz:],
190 indexBytesLoaded: sz,
191 what: w,
192 }
193}
194
195func (i *compressedPostingIterator) String() string {
196 return fmt.Sprintf("compressed(%s, %d, [%d bytes])", i.what, i._first, len(i.blob))
197}
198
199func (i *compressedPostingIterator) first() uint32 {
200 return i._first
201}
202
203func (i *compressedPostingIterator) next(limit uint32) {
204 if limit == maxUInt32 {
205 i.blob = nil
206 i._first = maxUInt32
207 return
208 }
209
210 for i._first <= limit && len(i.blob) > 0 {
211 delta, sz := binary.Uvarint(i.blob)
212 i._first += uint32(delta)
213 i.indexBytesLoaded += sz
214 i.blob = i.blob[sz:]
215 }
216
217 if i._first <= limit && len(i.blob) == 0 {
218 i._first = maxUInt32
219 }
220}
221
222func (i *compressedPostingIterator) updateStats(s *Stats) {
223 s.IndexBytesLoaded += int64(i.indexBytesLoaded)
224 s.NgramLookups += i.ngramLookups
225 i.indexBytesLoaded = 0
226 i.ngramLookups = 0
227}
228
229// mergingIterator forms the merge of a set of hitIterators, to
230// implement an OR operation at the hit level.
231type mergingIterator struct {
232 iters []hitIterator
233 ngramLookups int
234}
235
236func (i *mergingIterator) String() string {
237 return fmt.Sprintf("merge:%v", i.iters)
238}
239
240func (i *mergingIterator) updateStats(s *Stats) {
241 s.NgramLookups += i.ngramLookups
242 i.ngramLookups = 0
243 for _, j := range i.iters {
244 j.updateStats(s)
245 }
246}
247
248func (i *mergingIterator) first() uint32 {
249 r := uint32(maxUInt32)
250 for _, j := range i.iters {
251 f := j.first()
252 if f < r {
253 r = f
254 }
255 }
256
257 return r
258}
259
260func (i *mergingIterator) next(limit uint32) {
261 for _, j := range i.iters {
262 j.next(limit)
263 }
264}