-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathrerank.go
128 lines (109 loc) · 3.86 KB
/
rerank.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Package main provides the main entry point and core functionality for the application.
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"sort"
"manifold/internal/sefii"
)
// RerankRequest is the JSON body POSTed to the reranker endpoint.
type RerankRequest struct {
	// Model names the reranker model the server should use.
	Model string `json:"model"`
	// Query is the text every document is scored against.
	Query string `json:"query"`
	// TopN is sent as "top_n"; reRankChunks sets it to the number of documents.
	TopN int `json:"top_n"`
	// Documents are the candidate texts; results refer back to them by position.
	Documents []string `json:"documents"`
}
// RerankResult is a single scored document from a reranker response.
type RerankResult struct {
	// Index is the position of the scored document in RerankRequest.Documents.
	Index int `json:"index"`
	// RelevanceScore is the reranker's score; callers treat higher as more relevant.
	RelevanceScore float64 `json:"relevance_score"`
}
// RerankResponse is the complete JSON body returned by the reranker.
type RerankResponse struct {
	Model  string `json:"model"`
	Object string `json:"object"`
	// Usage is decoded generically; nothing in this file inspects its structure.
	Usage   interface{}    `json:"usage"`
	Results []RerankResult `json:"results"`
}
// reRankChunks asks the reranker at config.Reranker.Host to score each chunk's
// Content against query. It then reorders chunks in place from most to least
// relevant and returns the same slice.
//
// An empty chunks slice is returned unchanged without contacting the reranker.
// On a request-building, transport, non-200 status, or decoding failure it
// returns a nil slice and a wrapped error. In that case chunks is left in its
// original order, because sorting only happens after all of those checks pass.
func reRankChunks(ctx context.Context, config *Config, query string, chunks []sefii.Chunk) ([]sefii.Chunk, error) {
	// Nothing to rank: skip the network round-trip entirely.
	if len(chunks) == 0 {
		return chunks, nil
	}

	documents := extractDocuments(chunks)
	rankReq, err := createRerankRequest(query, documents, len(chunks))
	if err != nil {
		return nil, fmt.Errorf("failed to create rerank request: %w", err)
	}

	resp, err := sendRerankRequest(ctx, config.Reranker.Host, rankReq)
	if err != nil {
		return nil, fmt.Errorf("rerank request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort read of the error body for diagnostics; a read failure
		// here must not mask the status error itself.
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("rerank failed with status %d: %s", resp.StatusCode, string(body))
	}

	rankResp, err := parseRerankResponse(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to decode rerank response: %w", err)
	}

	scoreMap := mapScores(rankResp.Results)
	sortChunksByScore(chunks, scoreMap)

	// Log the highest score the reranker returned. scoreMap[0] is the score of
	// the document originally at index 0, which is not necessarily the best one.
	if len(rankResp.Results) == 0 {
		log.Print("Reranking complete. No results returned.")
	} else {
		top := rankResp.Results[0].RelevanceScore
		for _, r := range rankResp.Results[1:] {
			if r.RelevanceScore > top {
				top = r.RelevanceScore
			}
		}
		log.Printf("Reranking complete. Top score: %v", top)
	}
	return chunks, nil
}
// extractDocuments returns the Content of every chunk, in the same order, for
// use as the document list of a rerank request. The returned slice is never
// nil, so an empty input still marshals as a JSON array rather than null.
func extractDocuments(chunks []sefii.Chunk) []string {
	docs := make([]string, 0, len(chunks))
	for i := range chunks {
		docs = append(docs, chunks[i].Content)
	}
	return docs
}
// createRerankRequest constructs the payload for the reranker.
func createRerankRequest(query string, documents []string, topN int) ([]byte, error) {
rankReq := RerankRequest{
Model: "slide-bge-reranker-v2-m3.Q8_0.gguf",
Query: query,
TopN: topN,
Documents: documents,
}
return json.Marshal(rankReq)
}
// sendRerankRequest POSTs payload as application/json to url and returns the
// raw response. The caller owns resp.Body and must close it. Cancellation and
// deadlines come only from ctx; the client sets no timeout of its own.
func sendRerankRequest(ctx context.Context, url string, payload []byte) (*http.Response, error) {
	body := bytes.NewReader(payload)
	req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
	if reqErr != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", reqErr)
	}
	req.Header.Set("Content-Type", "application/json")

	httpClient := &http.Client{}
	return httpClient.Do(req)
}
// parseRerankResponse decodes the rerank response body into a RerankResponse struct.
func parseRerankResponse(body io.Reader) (*RerankResponse, error) {
var rankResp RerankResponse
if err := json.NewDecoder(body).Decode(&rankResp); err != nil {
return nil, err
}
return &rankResp, nil
}
// mapScores indexes reranker results by document index, so scores[i] is the
// relevance of the document originally at position i. If an index appears
// more than once, the later result wins. Indices with no result are absent.
func mapScores(results []RerankResult) map[int]float64 {
	scores := make(map[int]float64, len(results))
	for i := range results {
		scores[results[i].Index] = results[i].RelevanceScore
	}
	return scores
}
// sortChunksByScore reorders chunks in place into descending relevance order.
//
// scoreMap is keyed by each chunk's ORIGINAL position in chunks (the reranker's
// document index). A chunk with no entry in scoreMap is treated as scoring 0,
// and chunks with equal scores keep their original relative order.
//
// The ordering is computed on a separate index permutation. Comparing
// scoreMap[i] and scoreMap[j] directly inside sort.Slice would be wrong: i and
// j are current positions, which stop matching the original indices as soon
// as the sort swaps elements.
func sortChunksByScore(chunks []sefii.Chunk, scoreMap map[int]float64) {
	order := make([]int, len(chunks))
	for i := range order {
		order[i] = i
	}
	sort.SliceStable(order, func(a, b int) bool {
		return scoreMap[order[a]] > scoreMap[order[b]]
	})

	ranked := make([]sefii.Chunk, len(chunks))
	for pos, orig := range order {
		ranked[pos] = chunks[orig]
	}
	// Write back into the caller's backing array so in-place semantics hold.
	copy(chunks, ranked)
}