Skip to content

Commit

Permalink
Boost symbol matches in BM25
Browse files Browse the repository at this point in the history
  • Loading branch information
jtibshirani committed Dec 10, 2024
1 parent 37c4df8 commit 5a5f3de
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 80 deletions.
8 changes: 4 additions & 4 deletions build/scoring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func TestBM25(t *testing.T) {
query: &query.Substring{Pattern: "example"},
content: exampleJava,
language: "Java",
// bm25-score: 0.57 <- sum-termFrequencyScore: 10.00, length-ratio: 1.00
wantScore: 0.57,
// bm25-score: 0.58 <- sum-termFrequencyScore: 14.00, length-ratio: 1.00
wantScore: 0.58,
}, {
// Matches only on content
fileName: "example.java",
Expand All @@ -89,8 +89,8 @@ func TestBM25(t *testing.T) {
}},
content: exampleJava,
language: "Java",
// bm25-score: 1.75 <- sum-termFrequencyScore: 56.00, length-ratio: 1.00
wantScore: 1.75,
// bm25-score: 1.81 <- sum-termFrequencyScore: 116.00, length-ratio: 1.00
wantScore: 1.81,
},
{
// Matches only on filename
Expand Down
39 changes: 39 additions & 0 deletions contentprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,22 @@ func findMaxOverlappingSection(secs []DocumentSection, off, sz uint32) (uint32,
return uint32(j), ol1 > 0
}

// matchesSymbol reports whether the candidate match should be counted as a
// symbol match. Filename candidates never qualify. Candidates flagged as
// symbol matches always qualify. Any other candidate qualifies only when its
// byte range overlaps one of the document sections returned by docSections.
func (p *contentProvider) matchesSymbol(cm *candidateMatch) bool {
	switch {
	case cm.fileName:
		// Filename candidates are excluded up front, before any section lookup.
		return false
	case cm.symbol:
		// The candidate came from a symbol matchTree, so no overlap check is needed.
		return true
	}

	// Otherwise, see whether the match overlaps with a symbol section.
	_, overlaps := findMaxOverlappingSection(p.docSections(), cm.byteOffset, cm.byteMatchSz)
	return overlaps
}

func (p *contentProvider) findSymbol(cm *candidateMatch) (DocumentSection, *Symbol, bool) {
if cm.fileName {
return DocumentSection{}, nil, false
Expand Down Expand Up @@ -619,6 +635,29 @@ func (p *contentProvider) findSymbol(cm *candidateMatch) (DocumentSection, *Symb
return sec, si, true
}

// calculateTermFrequency computes the term frequency for the file match and
// records, in df, that each distinct term occurred in this file.
//
// Notes:
//   - Filename matches count more than content matches. This mimics a common
//     text search strategy to 'boost' matches on document titles.
//   - Symbol matches also count more than content matches, to reward matches
//     on symbol definitions.
//
// The returned map is keyed by term; df is updated in place.
func (p *contentProvider) calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int {
	// Treat each candidate match as a term, keyed by its lowered substring, so
	// case sensitivity is ignored for now.
	termFreqs := map[string]int{}
	for _, cand := range cands {
		// A plain content match contributes 1; filename and symbol matches
		// contribute 5.
		weight := 1
		if cand.fileName || p.matchesSymbol(cand) {
			weight = 5
		}
		termFreqs[string(cand.substrLowered)] += weight
	}

	// Each distinct term seen in this file bumps its document frequency by
	// exactly one, regardless of how many times it matched.
	for term := range termFreqs {
		df[term]++
	}
	return termFreqs
}

func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language string, debug bool) (float64, string, []*Symbol) {
type debugScore struct {
what string
Expand Down
2 changes: 1 addition & 1 deletion eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ nextFileMatch:
// document frequencies. Since we don't store document frequencies in the index,
// we have to defer the calculation of the final BM25 score to after the whole
// shard has been processed.
tf = calculateTermFrequency(finalCands, df)
tf = cp.calculateTermFrequency(finalCands, df)
} else {
// Use the standard, non-experimental scoring method by default
d.scoreFile(&fileMatch, nextDoc, mt, known, opts)
Expand Down
24 changes: 0 additions & 24 deletions score.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,30 +110,6 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn
}
}

// calculateTermFrequency computes the term frequency for the file match and
// records, in df, that each distinct term occurred in this file.
//
// Filename matches count more than content matches. This mimics a common text
// search strategy where you 'boost' matches on document titles.
//
// The returned map is keyed by term; df is updated in place.
func calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int {
	// Treat each candidate match as a term, keyed by its lowered substring, so
	// case sensitivity is ignored for now. Only filename candidates are
	// boosted here; every other candidate counts once.
	termFreqs := make(map[string]int)
	for _, c := range cands {
		weight := 1
		if c.fileName {
			weight = 5
		}
		termFreqs[string(c.substrLowered)] += weight
	}

	// Each distinct term seen in this file bumps its document frequency by
	// exactly one, regardless of how many times it matched.
	for term := range termFreqs {
		df[term]++
	}

	return termFreqs
}

// idf computes the inverse document frequency for a term. nq is the number of
// documents that contain the term and documentCount is the total number of
// documents in the corpus.
Expand Down
51 changes: 0 additions & 51 deletions score_test.go

This file was deleted.

0 comments on commit 5a5f3de

Please sign in to comment.