fork of https://github.com/sourcegraph/zoekt
1// Copyright 2018 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package zoekt
16
17import (
18 "encoding/binary"
19 "fmt"
20)
21
22// hitIterator finds potential search matches, measured in offsets of
23// the concatenation of all documents.
24type hitIterator interface {
25 // Return the first hit, or maxUInt32 if none.
26 first() uint32
27
28 // Skip until past limit. The argument maxUInt32 should be
29 // treated specially.
30 next(limit uint32)
31
32 // Return how many bytes were read.
33 updateStats(s *Stats)
34}
35
36// distanceHitIterator looks for hits at a fixed distance apart.
37type distanceHitIterator struct {
38 started bool
39 distance uint32
40 i1 hitIterator
41 i2 hitIterator
42}
43
44func (i *distanceHitIterator) String() string {
45 return fmt.Sprintf("dist(%d, %v, %v)", i.distance, i.i1, i.i2)
46}
47
48func (i *distanceHitIterator) findNext() {
49 for {
50 var p1, p2 uint32
51 p1 = i.i1.first()
52 p2 = i.i2.first()
53 if p1 == maxUInt32 || p2 == maxUInt32 {
54 i.i1.next(maxUInt32)
55 break
56 }
57
58 if p1+i.distance < p2 {
59 i.i1.next(p2 - i.distance - 1)
60 } else if p1+i.distance > p2 {
61 i.i2.next(p1 + i.distance - 1)
62 } else {
63 break
64 }
65 }
66}
67
68func (i *distanceHitIterator) first() uint32 {
69 if !i.started {
70 i.findNext()
71 i.started = true
72 }
73 return i.i1.first()
74}
75
76func (i *distanceHitIterator) updateStats(s *Stats) {
77 i.i1.updateStats(s)
78 i.i2.updateStats(s)
79}
80
81func (i *distanceHitIterator) next(limit uint32) {
82 i.i1.next(limit)
83 l2 := limit + i.distance
84
85 if l2 < limit { // overflow.
86 l2 = maxUInt32
87 }
88 i.i2.next(l2)
89 i.findNext()
90}
91
92func (d *indexData) newDistanceTrigramIter(ng1, ng2 ngram, dist uint32, caseSensitive, fileName bool) (hitIterator, error) {
93 if dist == 0 {
94 return nil, fmt.Errorf("d == 0")
95 }
96
97 i1, err := d.trigramHitIterator(ng1, caseSensitive, fileName)
98 if err != nil {
99 return nil, err
100 }
101 i2, err := d.trigramHitIterator(ng2, caseSensitive, fileName)
102 if err != nil {
103 return nil, err
104 }
105 return &distanceHitIterator{
106 i1: i1,
107 i2: i2,
108 distance: dist,
109 }, nil
110}
111
112func (d *indexData) trigramHitIterator(ng ngram, caseSensitive, fileName bool) (hitIterator, error) {
113 if d.ngrams == nil {
114 return nil, fmt.Errorf("trigramHitIterator: ngrams=nil")
115 }
116
117 variants := []ngram{ng}
118 if !caseSensitive {
119 variants = generateCaseNgrams(ng)
120 }
121
122 iters := make([]hitIterator, 0, len(variants))
123 for _, v := range variants {
124 if fileName {
125 blob, err := d.fileNameNgrams.GetBlob(v)
126 if err != nil {
127 return nil, err
128 }
129 if len(blob) > 0 {
130 iters = append(iters, newCompressedPostingIterator(blob, v))
131 }
132 continue
133 }
134
135 sec := d.ngrams.Get(v)
136 blob, err := d.readSectionBlob(sec)
137 if err != nil {
138 return nil, err
139 }
140 if len(blob) > 0 {
141 iters = append(iters, newCompressedPostingIterator(blob, v))
142 }
143 }
144
145 if len(iters) == 1 {
146 return iters[0], nil
147 }
148 return &mergingIterator{
149 iters: iters,
150 }, nil
151}
152
153// inMemoryIterator is hitIterator that goes over an in-memory uint32 posting list.
154type inMemoryIterator struct {
155 postings []uint32
156 what ngram
157}
158
159func (i *inMemoryIterator) String() string {
160 return fmt.Sprintf("mem(%s):%v", i.what, i.postings)
161}
162
163func (i *inMemoryIterator) first() uint32 {
164 if len(i.postings) > 0 {
165 return i.postings[0]
166 }
167 return maxUInt32
168}
169
170func (i *inMemoryIterator) updateStats(s *Stats) {
171}
172
173func (i *inMemoryIterator) next(limit uint32) {
174 if limit == maxUInt32 {
175 i.postings = nil
176 }
177
178 for len(i.postings) > 0 && i.postings[0] <= limit {
179 i.postings = i.postings[1:]
180 }
181}
182
183// compressedPostingIterator goes over a delta varint encoded posting
184// list.
185type compressedPostingIterator struct {
186 blob, orig []byte
187 _first uint32
188 what ngram
189}
190
191func newCompressedPostingIterator(b []byte, w ngram) *compressedPostingIterator {
192 d, sz := binary.Uvarint(b)
193 return &compressedPostingIterator{
194 _first: uint32(d),
195 blob: b[sz:],
196 orig: b,
197 what: w,
198 }
199}
200
201func (i *compressedPostingIterator) String() string {
202 return fmt.Sprintf("compressed(%s, %d, [%d bytes])", i.what, i._first, len(i.blob))
203}
204
205func (i *compressedPostingIterator) first() uint32 {
206 return i._first
207}
208
209func (i *compressedPostingIterator) next(limit uint32) {
210 if limit == maxUInt32 {
211 i.blob = nil
212 i._first = maxUInt32
213 return
214 }
215
216 for i._first <= limit && len(i.blob) > 0 {
217 delta, sz := binary.Uvarint(i.blob)
218 i._first += uint32(delta)
219 i.blob = i.blob[sz:]
220 }
221
222 if i._first <= limit && len(i.blob) == 0 {
223 i._first = maxUInt32
224 }
225}
226
227func (i *compressedPostingIterator) updateStats(s *Stats) {
228 s.IndexBytesLoaded += int64(len(i.orig) - len(i.blob))
229}
230
231// mergingIterator forms the merge of a set of hitIterators, to
232// implement an OR operation at the hit level.
233type mergingIterator struct {
234 iters []hitIterator
235}
236
237func (i *mergingIterator) String() string {
238 return fmt.Sprintf("merge:%v", i.iters)
239}
240
241func (i *mergingIterator) updateStats(s *Stats) {
242 for _, j := range i.iters {
243 j.updateStats(s)
244 }
245}
246
247func (i *mergingIterator) first() uint32 {
248 r := uint32(maxUInt32)
249 for _, j := range i.iters {
250 f := j.first()
251 if f < r {
252 r = f
253 }
254 }
255
256 return r
257}
258
259func (i *mergingIterator) next(limit uint32) {
260 for _, j := range i.iters {
261 j.next(limit)
262 }
263}