From 5a5f3ded2e349fff6b873fc23a3e94569de8f812 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 10 Dec 2024 14:57:06 -0800 Subject: [PATCH] Boost symbol matches in BM25 --- build/scoring_test.go | 8 +++---- contentprovider.go | 39 +++++++++++++++++++++++++++++++++ eval.go | 2 +- score.go | 24 -------------------- score_test.go | 51 ------------------------------------------- 5 files changed, 44 insertions(+), 80 deletions(-) delete mode 100644 score_test.go diff --git a/build/scoring_test.go b/build/scoring_test.go index 37bb55e2d..ac703e242 100644 --- a/build/scoring_test.go +++ b/build/scoring_test.go @@ -77,8 +77,8 @@ func TestBM25(t *testing.T) { query: &query.Substring{Pattern: "example"}, content: exampleJava, language: "Java", - // bm25-score: 0.57 <- sum-termFrequencyScore: 10.00, length-ratio: 1.00 - wantScore: 0.57, + // bm25-score: 0.58 <- sum-termFrequencyScore: 14.00, length-ratio: 1.00 + wantScore: 0.58, }, { // Matches only on content fileName: "example.java", @@ -89,8 +89,8 @@ func TestBM25(t *testing.T) { }}, content: exampleJava, language: "Java", - // bm25-score: 1.75 <- sum-termFrequencyScore: 56.00, length-ratio: 1.00 - wantScore: 1.75, + // bm25-score: 1.81 <- sum-termFrequencyScore: 116.00, length-ratio: 1.00 + wantScore: 1.81, }, { // Matches only on filename diff --git a/contentprovider.go b/contentprovider.go index 448637527..bbd334d75 100644 --- a/contentprovider.go +++ b/contentprovider.go @@ -588,6 +588,22 @@ func findMaxOverlappingSection(secs []DocumentSection, off, sz uint32) (uint32, return uint32(j), ol1 > 0 } +func (p *contentProvider) matchesSymbol(cm *candidateMatch) bool { + if cm.fileName { + return false + } + + // Check if this candidate came from a symbol matchTree + if cm.symbol { + return true + } + + // Check if it overlaps with a symbol. + secs := p.docSections() + _, ok := findMaxOverlappingSection(secs, cm.byteOffset, cm.byteMatchSz) + return ok +} + func (p *contentProvider) findSymbol(cm *candidateMatch) (DocumentSection, *Symbol, bool) { if cm.fileName { return DocumentSection{}, nil, false @@ -619,6 +635,29 @@ func (p *contentProvider) findSymbol(cm *candidateMatch) (DocumentSection, *Symb return sec, si, true } +// calculateTermFrequency computes the term frequency for the file match. +// Notes: +// * Filename matches count more than content matches. This mimics a common text search strategy to 'boost' matches on document titles. +// * Symbol matches also count more than content matches, to reward matches on symbol definitions. +func (p *contentProvider) calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int { + // Treat each candidate match as a term and compute the frequencies. For now, ignore case + // sensitivity and treat filenames and symbols the same as content. + termFreqs := map[string]int{} + for _, m := range cands { + term := string(m.substrLowered) + if m.fileName || p.matchesSymbol(m) { + termFreqs[term] += 5 + } else { + termFreqs[term]++ + } + } + + for term := range termFreqs { + df[term] += 1 + } + return termFreqs +} + func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language string, debug bool) (float64, string, []*Symbol) { type debugScore struct { what string diff --git a/eval.go b/eval.go index 0788fe35d..c54070f25 100644 --- a/eval.go +++ b/eval.go @@ -339,7 +339,7 @@ nextFileMatch: // document frequencies. Since we don't store document frequencies in the index, // we have to defer the calculation of the final BM25 score to after the whole // shard has been processed. - tf = calculateTermFrequency(finalCands, df) + tf = cp.calculateTermFrequency(finalCands, df) } else { // Use the standard, non-experimental scoring method by default d.scoreFile(&fileMatch, nextDoc, mt, known, opts) diff --git a/score.go b/score.go index bba54fd79..09faad837 100644 --- a/score.go +++ b/score.go @@ -110,30 +110,6 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn } } -// calculateTermFrequency computes the term frequency for the file match. -// -// Filename matches count more than content matches. This mimics a common text -// search strategy where you 'boost' matches on document titles. -func calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int { - // Treat each candidate match as a term and compute the frequencies. For now, ignore case - // sensitivity and treat filenames and symbols the same as content. - termFreqs := map[string]int{} - for _, cand := range cands { - term := string(cand.substrLowered) - if cand.fileName { - termFreqs[term] += 5 - } else { - termFreqs[term]++ - } - } - - for term := range termFreqs { - df[term] += 1 - } - - return termFreqs -} - // idf computes the inverse document frequency for a term. nq is the number of // documents that contain the term and documentCount is the total number of // documents in the corpus. diff --git a/score_test.go b/score_test.go deleted file mode 100644 index 2e3b13840..000000000 --- a/score_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package zoekt - -import ( - "maps" - "testing" -) - -func TestCalculateTermFrequency(t *testing.T) { - cases := []struct { - cands []*candidateMatch - wantDF termDocumentFrequency - wantTermFrequencies map[string]int - }{{ - cands: []*candidateMatch{ - {substrLowered: []byte("foo")}, - {substrLowered: []byte("foo")}, - {substrLowered: []byte("bar")}, - { - substrLowered: []byte("bas"), - fileName: true, - }, - }, - wantDF: termDocumentFrequency{ - "foo": 1, - "bar": 1, - "bas": 1, - }, - wantTermFrequencies: map[string]int{ - "foo": 2, - "bar": 1, - "bas": 5, - }, - }, - } - - for _, c := range cases { - t.Run("", func(t *testing.T) { - fm := FileMatch{} - df := make(termDocumentFrequency) - tf := calculateTermFrequency(c.cands, df) - - if !maps.Equal(df, c.wantDF) { - t.Errorf("got %v, want %v", df, c.wantDF) - } - - if !maps.Equal(tf, c.wantTermFrequencies) { - t.Errorf("got %v, want %v", fm, c.wantTermFrequencies) - } - }) - } -}