Skip to content

Commit

Permalink
test: add tests for Jina Ranker
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Mar 6, 2024
1 parent 89f722d commit 5e9e24e
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from typing import Any, Dict, List, Optional

import requests
from haystack import Document, component
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

JINA_API_URL: str = "https://api.jina.ai/v1/rerank"

Expand All @@ -31,27 +32,19 @@ class JinaRanker:

def __init__(
self,
model_name: str = "jinaai/jina-reranker-v1-base-en",
api_key: Optional[str] = None,
top_k: int = 10,
query_prefix: str = "",
document_prefix: str = "",
model: str = "jina-reranker-v1-base-en",
api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008,
top_k: Optional[int] = None,
score_threshold: Optional[float] = None,
):
"""
Creates an instance of JinaRanker.
:param api_key: The Jina API key. It can be explicitly provided or automatically read from the
environment variable JINA_API_KEY (recommended).
:param model_name: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/`
:param model: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/`
:param top_k:
The maximum number of Documents to return per query.
:param query_prefix:
A string to add to the beginning of the query text before ranking.
Can be used to prepend the text with an instruction, as required by some reranking models, such as bge.
:param document_prefix:
A string to add to the beginning of each Document text before ranking. Can be used to prepend the text with
an instruction, as required by some embedding models, such as bge.
The maximum number of Documents to return per query. If None, all documents are returned
:param score_threshold:
If provided only returns documents with a score above this threshold.
Expand All @@ -60,39 +53,56 @@ def __init__(
If `scale_score` is True and `calibration_factor` is not provided.
"""
# if the user does not provide the API key, check if it is set in the module client
if api_key is None:
try:
api_key = os.environ["JINA_API_KEY"]
except KeyError as e:
msg = (
"JinaRanker expects a Jina API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg) from e
self.model_name = model_name
self.query_prefix = query_prefix
self.document_prefix = document_prefix
resolved_api_key = api_key.resolve_value()
self.api_key = api_key
self.model = model
self.top_k = top_k
self.score_threshold = score_threshold

if self.top_k <= 0:
if self.top_k is not None and self.top_k <= 0:
msg = f"top_k must be > 0, but got {top_k}"
raise ValueError(msg)
# if the user does not provide the API key, check if it is set in the module client
self._session = requests.Session()
self._session.headers.update(
{
"Authorization": f"Bearer {api_key}",
"Authorization": f"Bearer {resolved_api_key}",
"Accept-Encoding": "identity",
"Content-type": "application/json",
}
)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
api_key=self.api_key.to_dict(),
model=self.model,
top_k=self.top_k,
score_threshold=self.score_threshold,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "JinaRanker":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

@component.output_types(documents=List[Document])
def run(
Expand Down Expand Up @@ -122,21 +132,22 @@ def run(
"""
if not documents:
return {"documents": []}
print(f'top_k {top_k}')
if top_k is not None and top_k <= 0:
msg = f"top_k must be > 0, but got {top_k}"
raise ValueError(msg)

top_k = top_k or self.top_k
score_threshold = score_threshold or self.score_threshold

if top_k <= 0:
msg = f"top_k must be > 0, but got {top_k}"
raise ValueError(msg)
data = {"query": query,
"documents": [doc.content or "" for doc in documents],
"model": self.model}
if top_k is not None:
data["top_n"] = top_k
resp = self._session.post( # type: ignore
JINA_API_URL,
json={
"query": query,
"documents": [doc.content or "" for doc in documents],
"model": self.model,
"top_n": top_k,
},
json=data,
).json()
if "results" not in resp:
raise RuntimeError(resp["detail"])
Expand All @@ -146,14 +157,16 @@ def run(
ranked_docs = []
for result in results:
index = result["index"]
relevance_score = results["relevance_score"]
relevance_score = result["relevance_score"]
doc = documents[index]
if top_k is None or len(ranked_docs) < top_k:
doc.score = relevance_score
if score_threshold is not None:
if relevance_score >= score_threshold:
ranked_docs.append(documents[index])
ranked_docs.append(doc)
else:
ranked_docs.append(documents[index])
ranked_docs.append(doc)
else:
break

return {"documents": ranked_docs}
return {"documents": ranked_docs, "meta": {"model": resp["model"], "usage": resp["usage"]}}
125 changes: 125 additions & 0 deletions integrations/jina/tests/test_ranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import json
from unittest.mock import patch

import pytest
import requests
from haystack import Document
from haystack.utils import Secret
from haystack_integrations.components.rankers.jina import JinaRanker

def mock_session_post_response(*args, **kwargs): # noqa: ARG001
model = kwargs["json"]["model"]
documents = kwargs["json"]["documents"]
mock_response = requests.Response()
mock_response.status_code = 200
results = [{"index": i, "relevance_score": len(documents) - i, "document": {"text": doc}} for i, doc in enumerate(documents)]
mock_response._content = json.dumps(
{"model": model, "usage": {"total_tokens": 4, "prompt_tokens": 4}, "results": results}
).encode()

return mock_response


class TestJinaRanker:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
embedder = JinaRanker()

assert embedder.api_key == Secret.from_env_var("JINA_API_KEY")
assert embedder.model == "jina-reranker-v1-base-en"

def test_init_with_parameters(self):
embedder = JinaRanker(
api_key=Secret.from_token("fake-api-key"),
model="model",
top_k=64,
score_threshold=0.5
)

assert embedder.api_key == Secret.from_token("fake-api-key")
assert embedder.model == "model"
assert embedder.top_k == 64
assert embedder.score_threshold == 0.5

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
with pytest.raises(ValueError):
JinaRanker()

def test_to_dict(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
component = JinaRanker()
data = component.to_dict()
assert data == {
"type": "haystack_integrations.components.rankers.jina.ranker.JinaRanker",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "jina-reranker-v1-base-en",
"top_k": None,
"score_threshold": None
},
}

def test_to_dict_with_custom_init_parameters(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
component = JinaRanker(
model="model",
top_k=64,
score_threshold=0.5
)
data = component.to_dict()
assert data == {
"type": "haystack_integrations.components.rankers.jina.ranker.JinaRanker",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "model",
"top_k": 64,
"score_threshold": 0.5,
},
}

def test_run(self):
docs = [
Document(content="I love cheese"),
Document(content="A transformer is a deep learning architecture"),
Document(content="A transformer is something"),
Document(content="A transformer is not good"),
]
query = 'What is a transformer?'

model = "jina-ranker"
with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
ranker = JinaRanker(
api_key=Secret.from_token("fake-api-key"),
model=model,
)

result = ranker.run(query=query, documents=docs)

ranked_documents = result["documents"]
metadata = result["meta"]

assert isinstance(ranked_documents, list)
assert len(ranked_documents) == len(docs)
for i, doc in enumerate(ranked_documents):
assert isinstance(doc, Document)
assert doc.score == len(ranked_documents) - i
assert metadata == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}}

def test_run_wrong_input_format(self):
ranker = JinaRanker(api_key=Secret.from_token("fake-api-key"))

with pytest.raises(ValueError, match="top_k must be > 0, but got 0"):
ranker.run(query='query', documents=[Document(content='document')], top_k=0)

def test_run_on_empty_docs(self):
ranker = JinaRanker(api_key=Secret.from_token("fake-api-key"))

empty_list_input = []
result = ranker.run(query='a', documents=empty_list_input)

assert result["documents"] is not None
assert not result["documents"] # empty list

0 comments on commit 5e9e24e

Please sign in to comment.