fork of https://github.com/sourcegraph/zoekt
1package index
2
3import (
4 "crypto/sha1"
5 "fmt"
6 "io"
7 "log"
8 "os"
9 "path/filepath"
10 "runtime"
11 "sort"
12
13 "github.com/sourcegraph/zoekt"
14 "github.com/sourcegraph/zoekt/internal/tenant"
15)
16
17// Merge files into a compound shard in dstDir. Merge returns tmpName and a
18// dstName. It is the responsibility of the caller to delete the input shards and
19// rename the temporary compound shard from tmpName to dstName.
20func Merge(dstDir string, files ...IndexFile) (tmpName, dstName string, _ error) {
21 var ds []*indexData
22 for _, f := range files {
23 searcher, err := NewSearcher(f)
24 if err != nil {
25 return "", "", err
26 }
27 ds = append(ds, searcher.(*indexData))
28 }
29
30 ib, err := merge(ds...)
31 if err != nil {
32 return "", "", err
33 }
34
35 hasher := sha1.New()
36 for _, d := range ds {
37 for i, md := range d.repoMetaData {
38 if d.repoMetaData[i].Tombstone {
39 continue
40 }
41 hasher.Write([]byte(md.Name))
42 hasher.Write([]byte{0})
43 }
44 }
45
46 dstName = filepath.Join(dstDir, fmt.Sprintf("compound-%x_v%d.%05d.zoekt", hasher.Sum(nil), NextIndexFormatVersion, 0))
47 tmpName = dstName + ".tmp"
48 if err := builderWriteAll(tmpName, ib); err != nil {
49 return "", "", err
50 }
51 return tmpName, dstName, nil
52}
53
54func builderWriteAll(fn string, ib *ShardBuilder) error {
55 dir := filepath.Dir(fn)
56 if err := os.MkdirAll(dir, 0o700); err != nil {
57 return err
58 }
59
60 f, err := os.CreateTemp(dir, filepath.Base(fn)+".*.tmp")
61 if err != nil {
62 return err
63 }
64 if runtime.GOOS != "windows" {
65 // umask?
66 if err := f.Chmod(0o666); err != nil {
67 return err
68 }
69 }
70
71 defer f.Close()
72 if err := ib.Write(f); err != nil {
73 return err
74 }
75 fi, err := f.Stat()
76 if err != nil {
77 return err
78 }
79 if err := f.Close(); err != nil {
80 return err
81 }
82
83 if err := os.Rename(f.Name(), fn); err != nil {
84 return err
85 }
86
87 log.Printf("finished shard %s: %d index bytes (overhead %3.1f)", fn, fi.Size(),
88 float64(fi.Size())/float64(ib.ContentSize()+1))
89
90 return nil
91}
92
93func merge(ds ...*indexData) (*ShardBuilder, error) {
94 if len(ds) == 0 {
95 return nil, fmt.Errorf("need 1 or more indexData to merge")
96 }
97
98 sort.Slice(ds, func(i, j int) bool {
99 return ds[i].repoMetaData[0].GetPriority() > ds[j].repoMetaData[0].GetPriority()
100 })
101
102 sb := newShardBuilder()
103 sb.indexFormatVersion = NextIndexFormatVersion
104
105 for _, d := range ds {
106 lastRepoID := -1
107 for docID := uint32(0); int(docID) < len(d.fileBranchMasks); docID++ {
108 repoID := int(d.repos[docID])
109
110 if d.repoMetaData[repoID].Tombstone {
111 continue
112 }
113
114 if repoID != lastRepoID {
115 if lastRepoID > repoID {
116 return nil, fmt.Errorf("non-contiguous repo ids in %s for document %d: old=%d current=%d", d.String(), docID, lastRepoID, repoID)
117 }
118 lastRepoID = repoID
119
120 // TODO we are losing empty repos on merging since we only get here if
121 // there is an associated document.
122
123 if err := sb.setRepository(&d.repoMetaData[repoID]); err != nil {
124 return nil, err
125 }
126 }
127
128 if err := addDocument(d, sb, repoID, docID); err != nil {
129 return nil, err
130 }
131 }
132 }
133
134 return sb, nil
135}
136
137// Explode takes an input shard and creates 1 simple shard per repository. It is
138// a wrapper around explode that takes care of removing the input shard and
139// renaming the temporary shards.
140func Explode(dstDir string, inputShard string) error {
141 f, err := os.Open(inputShard)
142 if err != nil {
143 return err
144 }
145 defer f.Close()
146
147 indexFile, err := NewIndexFile(f)
148 if err != nil {
149 return err
150 }
151 defer indexFile.Close()
152
153 exploded, err := explode(dstDir, indexFile)
154 defer func() {
155 // best effort removal of tmp files. If os.Remove fails, indexserver will delete
156 // the leftover tmp files during the next cleanup.
157 for tmpFn := range exploded {
158 os.Remove(tmpFn)
159 }
160 }()
161 if err != nil {
162 return fmt.Errorf("zoekt.Explode: %w", err)
163 }
164
165 // remove the input shard first to avoid duplicate indexes. In the worst case,
166 // the process is interrupted just after we delete the compound shard, in which
167 // case we have to reindex the lost repos.
168 paths, err := IndexFilePaths(inputShard)
169 if err != nil {
170 return err
171 }
172 for _, path := range paths {
173 err = os.Remove(path)
174 if err != nil {
175 return err
176 }
177 }
178
179 // best effort rename shards.
180 for tmpFn, dstFn := range exploded {
181 if err := os.Rename(tmpFn, dstFn); err != nil {
182 log.Printf("explode: rename failed: %s", err)
183 }
184 }
185
186 return nil
187}
188
189type shardBuilderFunc func(ib *ShardBuilder)
190
191// explode takes an IndexFile f and creates 1 simple shard per repository
192// contained in f. explode returns a map of tmpName -> dstName. It is the
193// responsibility of the caller to rename the temporary shard(s) and delete the
194// input shard.
195func explode(dstDir string, f IndexFile, ibFuncs ...shardBuilderFunc) (map[string]string, error) {
196 searcher, err := NewSearcher(f)
197 if err != nil {
198 return nil, err
199 }
200 d := searcher.(*indexData)
201
202 shardNames := make(map[string]string, len(d.repoMetaData))
203
204 writeShard := func(ib *ShardBuilder) error {
205 if len(ib.repoList) != 1 {
206 return fmt.Errorf("expected sb to contain exactly 1 repository")
207 }
208 for _, ibFunc := range ibFuncs {
209 ibFunc(ib)
210 }
211
212 prefix := ""
213 if tenant.EnforceTenant() {
214 prefix = tenant.SrcPrefix(ib.repoList[0].TenantID, ib.repoList[0].ID)
215 } else {
216 prefix = ib.repoList[0].Name
217 }
218
219 shardName := ShardName(dstDir, prefix, ib.indexFormatVersion, 0)
220 shardNameTmp := shardName + ".tmp"
221 shardNames[shardNameTmp] = shardName
222 return builderWriteAll(shardNameTmp, ib)
223 }
224
225 var sb *ShardBuilder
226 lastRepoID := -1
227 for docID := uint32(0); int(docID) < len(d.fileBranchMasks); docID++ {
228 repoID := int(d.repos[docID])
229
230 if d.repoMetaData[repoID].Tombstone {
231 continue
232 }
233
234 if repoID != lastRepoID {
235 if lastRepoID > repoID {
236 return shardNames, fmt.Errorf("non-contiguous repo ids in %s for document %d: old=%d current=%d", d.String(), docID, lastRepoID, repoID)
237 }
238 lastRepoID = repoID
239
240 if sb != nil {
241 if err := writeShard(sb); err != nil {
242 return shardNames, err
243 }
244 }
245
246 sb = newShardBuilder()
247 sb.indexFormatVersion = IndexFormatVersion
248 if err := sb.setRepository(&d.repoMetaData[repoID]); err != nil {
249 return shardNames, err
250 }
251 }
252
253 err := addDocument(d, sb, repoID, docID)
254 if err != nil {
255 return shardNames, err
256 }
257 }
258
259 if sb != nil {
260 if err := writeShard(sb); err != nil {
261 return shardNames, err
262 }
263 }
264
265 return shardNames, nil
266}
267
268func addDocument(d *indexData, ib *ShardBuilder, repoID int, docID uint32) error {
269 doc := Document{
270 Name: string(d.fileName(docID)),
271 // Content set below since it can return an error
272 // Branches set below since it requires lookups
273 SubRepositoryPath: d.subRepoPaths[repoID][d.subRepos[docID]],
274 Language: d.languageMap[d.getLanguage(docID)],
275 // SkipReason not set, will be part of content from original indexer.
276 }
277
278 var err error
279 if doc.Content, err = d.readContents(docID); err != nil {
280 return err
281 }
282
283 if doc.Symbols, _, err = d.readDocSections(docID, nil); err != nil {
284 return err
285 }
286
287 doc.SymbolsMetaData = make([]*zoekt.Symbol, len(doc.Symbols))
288 for i := range doc.SymbolsMetaData {
289 doc.SymbolsMetaData[i] = d.symbols.data(d.fileEndSymbol[docID] + uint32(i))
290 }
291
292 // calculate branches
293 {
294 mask := d.fileBranchMasks[docID]
295 id := uint32(1)
296 for mask != 0 {
297 if mask&0x1 != 0 {
298 doc.Branches = append(doc.Branches, d.branchNames[repoID][uint(id)])
299 }
300 id <<= 1
301 mask >>= 1
302 }
303 }
304 return ib.Add(doc)
305}
306
// hashString returns the SHA-1 digest of s as a lowercase hexadecimal string.
//
// copied from builder package to avoid circular imports.
func hashString(s string) string {
	hasher := sha1.New()
	// The write result is discarded on purpose.
	_, _ = io.WriteString(hasher, s)
	digest := hasher.Sum(nil)
	return fmt.Sprintf("%x", digest)
}