fork of https://github.com/sourcegraph/zoekt
1package zoekt
2
3import (
4 "bytes"
5 "encoding/binary"
6 "fmt"
7 "slices"
8 "unsafe"
9)
10
11// Wire-format of map[uint32]MinimalRepoListEntry is pretty straightforward:
12//
13// byte(2) version
14// uvarint(len(minimal))
15// uvarint(sum(len(entry.Branches) for entry in minimal))
16// for repoID, entry in minimal:
17// uvarint(repoID)
18// byte(entry.HasSymbols)
19// uvarint(entry.IndexTimeUnix)
20// uvarint(len(entry.Branches))
21// for b in entry.Branches:
22// str(b.Name)
23// str(b.Version)
24//
25// Version 1 was the same, except it didn't have the IndexTimeUnix field.
26
27// reposMapEncode implements an efficient encoder for ReposMap.
28func reposMapEncode(minimal ReposMap) ([]byte, error) {
29 if minimal == nil {
30 return nil, nil
31 }
32
33 var b bytes.Buffer
34 var enc [binary.MaxVarintLen64]byte
35 varint := func(n int) {
36 m := binary.PutUvarint(enc[:], uint64(n))
37 b.Write(enc[:m])
38 }
39 str := func(s string) {
40 varint(len(s))
41 b.WriteString(s)
42 }
43 strSize := func(s string) int {
44 return binary.PutUvarint(enc[:], uint64(len(s))) + len(s)
45 }
46
47 // We calculate this up front so when decoding we only need to allocate the
48 // underlying array once.
49 allBranchesLen := 0
50 for _, entry := range minimal {
51 allBranchesLen += len(entry.Branches)
52 }
53
54 // Calculate size
55 size := 1 // version
56 size += binary.PutUvarint(enc[:], uint64(len(minimal)))
57 size += binary.PutUvarint(enc[:], uint64(allBranchesLen))
58 for repoID, entry := range minimal {
59 size += binary.PutUvarint(enc[:], uint64(repoID))
60 size += 1 // HasSymbols
61 size += binary.PutUvarint(enc[:], uint64(entry.IndexTimeUnix))
62 size += binary.PutUvarint(enc[:], uint64(len(entry.Branches)))
63 for _, b := range entry.Branches {
64 size += strSize(b.Name)
65 size += strSize(b.Version)
66 }
67 }
68 b.Grow(size)
69
70 // Version
71 b.WriteByte(2)
72
73 // Length
74 varint(len(minimal))
75
76 varint(allBranchesLen)
77
78 for repoID, entry := range minimal {
79 varint(int(repoID))
80
81 hasSymbols := byte(1)
82 if !entry.HasSymbols {
83 hasSymbols = 0
84 }
85 b.WriteByte(hasSymbols)
86
87 varint(int(entry.IndexTimeUnix))
88
89 varint(len(entry.Branches))
90 for _, b := range entry.Branches {
91 str(b.Name)
92 str(b.Version)
93 }
94 }
95
96 return b.Bytes(), nil
97}
98
99// reposMapDecode implements an efficient decoder for map[string]struct{}.
100func reposMapDecode(b []byte) (ReposMap, error) {
101 // nil input
102 if len(b) == 0 {
103 return nil, nil
104 }
105
106 // binaryReader returns strings pointing into b to avoid allocations. We
107 // don't own b, so we create a copy of it.
108 r := binaryReader{
109 typ: "ReposMap",
110 b: slices.Clone(b),
111 }
112
113 // Version
114 var readIndexTime bool
115 v := r.byt()
116 switch v {
117 case 1:
118 readIndexTime = false
119 case 2:
120 readIndexTime = true
121 default:
122 return nil, fmt.Errorf("unsupported stringSet encoding version %d", v)
123 }
124
125 // Length
126 l := r.uvarint()
127 m := make(map[uint32]MinimalRepoListEntry, l)
128
129 // Pre-allocate slice for all branches
130 allBranchesLen := r.uvarint()
131 allBranches := make([]RepositoryBranch, 0, allBranchesLen)
132
133 for range l {
134 repoID := r.uvarint()
135 hasSymbols := r.byt() == 1
136 var indexTimeUnix int64
137 if readIndexTime {
138 indexTimeUnix = int64(r.uvarint())
139 }
140 lb := r.uvarint()
141 for range lb {
142 allBranches = append(allBranches, RepositoryBranch{
143 Name: r.str(),
144 Version: r.str(),
145 })
146 }
147 branches := allBranches[len(allBranches)-lb:]
148 m[uint32(repoID)] = MinimalRepoListEntry{
149 HasSymbols: hasSymbols,
150 Branches: branches,
151 IndexTimeUnix: indexTimeUnix,
152 }
153 }
154
155 return m, r.err
156}
157
// binaryReader consumes values from the front of a byte slice. A failed read
// is recorded in err rather than returned, so callers can issue a series of
// reads and inspect err afterwards.
type binaryReader struct {
	// typ names what is being decoded; it only appears in error messages.
	typ string
	// b holds the input that has not been consumed yet. It is set to nil
	// when a read fails. Strings returned by str point into this slice.
	b []byte
	// err is set when a read fails.
	err error
}
163
164func (b *binaryReader) uvarint() int {
165 x, n := binary.Uvarint(b.b)
166 if n < 0 {
167 b.b = nil
168 b.err = fmt.Errorf("malformed %s", b.typ)
169 return 0
170 }
171 b.b = b.b[n:]
172 return int(x)
173}
174
175func (b *binaryReader) str() string {
176 l := b.uvarint()
177 if l > len(b.b) {
178 b.b = nil
179 b.err = fmt.Errorf("malformed %s", b.typ)
180 return ""
181 }
182 s := b2s(b.b[:l])
183 b.b = b.b[l:]
184 return s
185}
186
187func (b *binaryReader) byt() byte {
188 if len(b.b) < 1 {
189 b.b = nil
190 b.err = fmt.Errorf("malformed %s", b.typ)
191 return 0
192 }
193 x := b.b[0]
194 b.b = b.b[1:]
195 return x
196}
197
// b2s returns a string that shares memory with b instead of copying it.
//
// Because no copy is made, the bytes of b must not be modified for as long as
// the returned string is in use. A nil or empty slice yields "".
func b2s(b []byte) string {
	// unsafe.String is the supported way to build a string over existing
	// bytes; it does not depend on the in-memory layout of slice headers.
	return unsafe.String(unsafe.SliceData(b), len(b))
}