Skip to content

Commit

Permalink
cohere[patch]: Fix cohere rerank (#19624)
Browse files Browse the repository at this point in the history
Fix cohere rerank inspired by
#19486
  • Loading branch information
billytrend-cohere authored and hinthornw committed Apr 26, 2024
1 parent 69530b0 commit 9cdc2dc
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
8 changes: 6 additions & 2 deletions libs/partners/cohere/langchain_cohere/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,14 @@ def rerank(
model = model or self.model
top_n = top_n if (top_n is None or top_n > 0) else self.top_n
results = self.client.rerank(
query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc
query=query,
documents=docs,
model=model,
top_n=top_n,
max_chunks_per_doc=max_chunks_per_doc,
)
result_dicts = []
for res in results:
for res in results.results:
result_dicts.append(
{"index": res.index, "relevance_score": res.relevance_score}
)
Expand Down
16 changes: 16 additions & 0 deletions libs/partners/cohere/tests/integration_tests/test_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Test Cohere reranks."""
from langchain_core.documents import Document

from langchain_cohere import CohereRerank


def test_langchain_cohere_rerank_documents() -> None:
"""Test cohere rerank."""
rerank = CohereRerank()
test_documents = [
Document(page_content="This is a test document."),
Document(page_content="Another test document."),
]
test_query = "Test query"
results = rerank.rerank(test_documents, test_query)
assert len(results) == 2

0 comments on commit 9cdc2dc

Please sign in to comment.