fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

ranking: add IDF to BM25 score calculation (#788)

So far, we didn't include IDF in our BM25 score function. Zoekt uses a
trigram index and hence doesn't compute document frequency during
indexing. We could add this information to the index, but it is not
immediately obvious how to tokenize code in a way that is compatible
with tokens from a natural language query.

Here we calulate the document frequency at query time under the
assumption that we visit all documents containing any of the query terms.

Notes:
Also fixed an off-by-1 bug with how we count documents.

Test plan:
- Updated unit test
- Context evaluation results are slightly worse with a decrease from 64/89 to 63/89

author
Stefan Hengl
committer
GitHub
date (Jun 10, 2024, 12:31 PM +0200) commit 376af3a6 parent b41ceb2c
+156 -49
+9 -3
api.go
··· 946 946 // will be used. This option is temporary and is only exposed for testing/ tuning purposes. 947 947 DocumentRanksWeight float64 948 948 949 - // EXPERIMENTAL. If true, use text-search style scoring instead of the default scoring formula. 950 - // The scoring algorithm treats each match in a file as a term and computes an approximation to 951 - // BM25. When enabled, all other scoring signals are ignored, including document ranks. 949 + // EXPERIMENTAL. If true, use text-search style scoring instead of the default 950 + // scoring formula. The scoring algorithm treats each match in a file as a term 951 + // and computes an approximation to BM25. 952 + // 953 + // The calculation of IDF assumes that Zoekt visits all documents containing any 954 + // of the query terms during evaluation. This is true, for example, if all query 955 + // terms are ORed together. 956 + // 957 + // When enabled, all other scoring signals are ignored, including document ranks. 952 958 UseBM25Scoring bool 953 959 954 960 // Trace turns on opentracing for this request if true and if the Jaeger address was provided as
+8 -8
build/scoring_test.go
··· 77 77 query: &query.Substring{Pattern: "example"}, 78 78 content: exampleJava, 79 79 language: "Java", 80 - // bm25-score:1.69 (sum-tf: 7.00, length-ratio: 2.00) 81 - wantScore: 1.82, 80 + // bm25-score: 0.57 <- sum-termFrequencyScore: 10.00, length-ratio: 1.00 81 + wantScore: 0.57, 82 82 }, { 83 83 // Matches only on content 84 84 fileName: "example.java", ··· 89 89 }}, 90 90 content: exampleJava, 91 91 language: "Java", 92 - // bm25-score:5.75 (sum-tf: 56.00, length-ratio: 2.00) 93 - wantScore: 5.75, 92 + // bm25-score: 1.75 <- sum-termFrequencyScore: 56.00, length-ratio: 1.00 93 + wantScore: 1.75, 94 94 }, 95 95 { 96 96 // Matches only on filename ··· 98 98 query: &query.Substring{Pattern: "java"}, 99 99 content: exampleJava, 100 100 language: "Java", 101 - // bm25-score:1.07 (sum-tf: 2.00, length-ratio: 2.00) 102 - wantScore: 1.55, 101 + // bm25-score: 0.51 <- sum-termFrequencyScore: 5.00, length-ratio: 1.00 102 + wantScore: 0.51, 103 103 }, 104 104 { 105 105 // Matches only on filename, and content is missing 106 106 fileName: "a/b/c/config.go", 107 107 query: &query.Substring{Pattern: "config.go"}, 108 108 language: "Go", 109 - // bm25-score:1.91 (sum-tf: 2.00, length-ratio: 0.00) 110 - wantScore: 2.08, 109 + // bm25-score: 0.60 <- sum-termFrequencyScore: 5.00, length-ratio: 0.00 110 + wantScore: 0.60, 111 111 }, 112 112 } 113 113
+28 -4
eval.go
··· 197 197 docCount := uint32(len(d.fileBranchMasks)) 198 198 lastDoc := int(-1) 199 199 200 + // document frequency per term 201 + df := make(termDocumentFrequency) 202 + 203 + // term frequency per file match 204 + var tfs []termFrequency 205 + 200 206 nextFileMatch: 201 207 for { 202 208 canceled := false ··· 317 323 fileMatch.LineMatches = cp.fillMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts.DebugScore) 318 324 } 319 325 326 + var tf map[string]int 320 327 if opts.UseBM25Scoring { 321 - d.scoreFileUsingBM25(&fileMatch, nextDoc, finalCands, opts) 328 + // For BM25 scoring, the calculation of the score is split in two parts. Here we 329 + // calculate the term frequencies for the current document and update the 330 + // document frequencies. Since we don't store document frequencies in the index, 331 + // we have to defer the calculation of the final BM25 score to after the whole 332 + // shard has been processed. 333 + tf = calculateTermFrequency(finalCands, df) 322 334 } else { 323 335 // Use the standard, non-experimental scoring method by default 324 336 d.scoreFile(&fileMatch, nextDoc, mt, known, opts) ··· 339 351 repoMatchCount += len(fileMatch.LineMatches) 340 352 repoMatchCount += matchedChunkRanges 341 353 342 - if opts.DebugScore { 343 - fileMatch.Debug = fmt.Sprintf("score:%.2f <- %s", fileMatch.Score, fileMatch.Debug) 354 + if opts.UseBM25Scoring { 355 + // Invariant: tfs[i] belongs to res.Files[i] 356 + tfs = append(tfs, termFrequency{ 357 + doc: nextDoc, 358 + tf: tf, 359 + }) 344 360 } 361 + res.Files = append(res.Files, fileMatch) 345 362 346 - res.Files = append(res.Files, fileMatch) 347 363 res.Stats.MatchCount += len(fileMatch.LineMatches) 348 364 res.Stats.MatchCount += matchedChunkRanges 349 365 res.Stats.FileCount++ 366 + } 367 + 368 + // Calculate BM25 score for all file matches in the shard. We assume that we 369 + // have seen all documents containing any of the terms in the query so that df 370 + // correctly reflects the document frequencies. This is true, for example, if 371 + // all terms in the query are ORed together. 372 + if opts.UseBM25Scoring { 373 + d.scoreFilesUsingBM25(res.Files, tfs, df, opts) 350 374 } 351 375 352 376 for _, md := range d.repoMetaData {
+60 -34
score.go
··· 39 39 m.Score += computed 40 40 } 41 41 42 - func (m *FileMatch) addBM25Score(score float64, sumTf float64, L float64, debugScore bool) { 43 - if debugScore { 44 - m.Debug += fmt.Sprintf("bm25-score:%.2f (sum-tf: %.2f, length-ratio: %.2f)", score, sumTf, L) 45 - } 46 - m.Score += score 47 - } 48 - 49 42 // scoreFile computes a score for the file match using various scoring signals, like 50 43 // whether there's an exact match on a symbol, the number of query clauses that matched, etc. 51 44 func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, known map[matchTree]bool, opts *SearchOptions) { ··· 111 104 addScore("repo-rank", scoreRepoRankFactor*float64(md.Rank)/maxUInt16) 112 105 113 106 if opts.DebugScore { 114 - fileMatch.Debug = strings.TrimSuffix(fileMatch.Debug, ", ") 107 + fileMatch.Debug = fmt.Sprintf("score: %.2f <- %s", fileMatch.Score, strings.TrimSuffix(fileMatch.Debug, ", ")) 115 108 } 116 109 } 117 110 118 - // scoreFileUsingBM25 computes a score for the file match using an approximation to BM25, the most common scoring 119 - // algorithm for text search: https://en.wikipedia.org/wiki/Okapi_BM25. It implements all parts of the formula 120 - // except inverse document frequency (idf), since we don't have access to global term frequency statistics. 111 + // calculateTermFrequency computes the term frequency for the file match. 121 112 // 122 - // Filename matches count twice as much as content matches. This mimics a common text search strategy where you 123 - // 'boost' matches on document titles. 124 - // 125 - // This scoring strategy ignores all other signals including document ranks. This keeps things simple for now, 126 - // since BM25 is not normalized and can be tricky to combine with other scoring signals. 127 - func (d *indexData) scoreFileUsingBM25(fileMatch *FileMatch, doc uint32, cands []*candidateMatch, opts *SearchOptions) { 113 + // Filename matches count more than content matches. This mimics a common text 114 + // search strategy where you 'boost' matches on document titles. 115 + func calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int { 128 116 // Treat each candidate match as a term and compute the frequencies. For now, ignore case 129 117 // sensitivity and treat filenames and symbols the same as content. 130 118 termFreqs := map[string]int{} 131 119 for _, cand := range cands { 132 120 term := string(cand.substrLowered) 133 - 134 121 if cand.fileName { 135 122 termFreqs[term] += 5 136 123 } else { ··· 138 125 } 139 126 } 140 127 141 - // Compute the file length ratio. Usually the calculation would be based on terms, but using 142 - // bytes should work fine, as we're just computing a ratio. 143 - fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc]) 144 - numFiles := len(d.boundaries) 145 - averageFileLength := float64(d.boundaries[numFiles-1]) / float64(numFiles) 128 + for term := range termFreqs { 129 + df[term] += 1 130 + } 146 131 132 + return termFreqs 133 + } 134 + 135 + // idf computes the inverse document frequency for a term. nq is the number of 136 + // documents that contain the term and documentCount is the total number of 137 + // documents in the corpus. 138 + func idf(nq, documentCount int) float64 { 139 + return math.Log(1.0 + ((float64(documentCount) - float64(nq) + 0.5) / (float64(nq) + 0.5))) 140 + } 141 + 142 + // termDocumentFrequency is a map "term" -> "number of documents that contain the term" 143 + type termDocumentFrequency map[string]int 144 + 145 + // termFrequency stores the term frequencies for doc. 146 + type termFrequency struct { 147 + doc uint32 148 + tf map[string]int 149 + } 150 + 151 + // scoreFilesUsingBM25 computes the score according to BM25, the most common 152 + // scoring algorithm for text search: https://en.wikipedia.org/wiki/Okapi_BM25. 153 + // 154 + // This scoring strategy ignores all other signals including document ranks. 155 + // This keeps things simple for now, since BM25 is not normalized and can be 156 + // tricky to combine with other scoring signals. 157 + func (d *indexData) scoreFilesUsingBM25(fileMatches []FileMatch, tfs []termFrequency, df termDocumentFrequency, opts *SearchOptions) { 158 + // Use standard parameter defaults (used in Lucene and academic papers) 159 + k, b := 1.2, 0.75 160 + 161 + averageFileLength := float64(d.boundaries[d.numDocs()]) / float64(d.numDocs()) 147 162 // This is very unlikely, but explicitly guard against division by zero. 148 163 if averageFileLength == 0 { 149 164 averageFileLength++ 150 165 } 151 - L := fileLength / averageFileLength 166 + 167 + for i := range tfs { 168 + score := 0.0 169 + 170 + // Compute the file length ratio. Usually the calculation would be based on terms, but using 171 + // bytes should work fine, as we're just computing a ratio. 172 + doc := tfs[i].doc 173 + fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc]) 174 + 175 + L := fileLength / averageFileLength 152 176 153 - // Use standard parameter defaults (used in Lucene and academic papers) 154 - k, b := 1.2, 0.75 155 - sumTf := 0.0 // Just for debugging 156 - score := 0.0 157 - for _, freq := range termFreqs { 158 - tf := float64(freq) 159 - sumTf += tf 160 - score += ((k + 1.0) * tf) / (k*(1.0-b+b*L) + tf) 177 + sumTF := 0 // Just for debugging 178 + for term, f := range tfs[i].tf { 179 + sumTF += f 180 + tfScore := ((k + 1.0) * float64(f)) / (k*(1.0-b+b*L) + float64(f)) 181 + score += idf(df[term], int(d.numDocs())) * tfScore 182 + } 183 + 184 + fileMatches[i].Score = score 185 + 186 + if opts.DebugScore { 187 + fileMatches[i].Debug = fmt.Sprintf("bm25-score: %.2f <- sum-termFrequencies: %d, length-ratio: %.2f", score, sumTF, L) 188 + } 161 189 } 162 - 163 - fileMatch.addBM25Score(score, sumTf, L, opts.DebugScore) 164 190 }
+51
score_test.go
··· 1 + package zoekt 2 + 3 + import ( 4 + "maps" 5 + "testing" 6 + ) 7 + 8 + func TestCalculateTermFrequency(t *testing.T) { 9 + cases := []struct { 10 + cands []*candidateMatch 11 + wantDF termDocumentFrequency 12 + wantTermFrequencies map[string]int 13 + }{{ 14 + cands: []*candidateMatch{ 15 + {substrLowered: []byte("foo")}, 16 + {substrLowered: []byte("foo")}, 17 + {substrLowered: []byte("bar")}, 18 + { 19 + substrLowered: []byte("bas"), 20 + fileName: true, 21 + }, 22 + }, 23 + wantDF: termDocumentFrequency{ 24 + "foo": 1, 25 + "bar": 1, 26 + "bas": 1, 27 + }, 28 + wantTermFrequencies: map[string]int{ 29 + "foo": 2, 30 + "bar": 1, 31 + "bas": 5, 32 + }, 33 + }, 34 + } 35 + 36 + for _, c := range cases { 37 + t.Run("", func(t *testing.T) { 38 + fm := FileMatch{} 39 + df := make(termDocumentFrequency) 40 + tf := calculateTermFrequency(c.cands, df) 41 + 42 + if !maps.Equal(df, c.wantDF) { 43 + t.Errorf("got %v, want %v", df, c.wantDF) 44 + } 45 + 46 + if !maps.Equal(tf, c.wantTermFrequencies) { 47 + t.Errorf("got %v, want %v", fm, c.wantTermFrequencies) 48 + } 49 + }) 50 + } 51 + }