diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py index 6bcd94220..bbac547c3 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py @@ -24,7 +24,7 @@ class JinaDocumentEmbedder: # Make sure that the environment variable JINA_API_KEY is set - document_embedder = JinaDocumentEmbedder() + document_embedder = JinaDocumentEmbedder(task="retrieval.query") doc = Document(content="I love pizza!") @@ -38,13 +38,16 @@ class JinaDocumentEmbedder: def __init__( self, api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008 - model: str = "jina-embeddings-v2-base-en", + model: str = "jina-embeddings-v3", prefix: str = "", suffix: str = "", batch_size: int = 32, progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + task: Optional[str] = None, + dimensions: Optional[int] = None, + late_chunking: Optional[bool] = None, ): """ Create a JinaDocumentEmbedder component. @@ -78,6 +81,9 @@ def __init__( "Content-type": "application/json", } ) + self.task = task + self.dimensions = dimensions + self.late_chunking = late_chunking def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -91,17 +97,25 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ - return default_to_dict( - self, - api_key=self.api_key.to_dict(), - model=self.model_name, - prefix=self.prefix, - suffix=self.suffix, - batch_size=self.batch_size, - progress_bar=self.progress_bar, - meta_fields_to_embed=self.meta_fields_to_embed, - embedding_separator=self.embedding_separator, - ) + kwargs = { + "api_key": self.api_key.to_dict(), + "model": self.model_name, + "prefix": self.prefix, + "suffix": self.suffix, + "batch_size": self.batch_size, + "progress_bar": self.progress_bar, + "meta_fields_to_embed": self.meta_fields_to_embed, + "embedding_separator": self.embedding_separator, + } + # Optional parameters, the following two are only supported by embeddings-v3 for now + if self.task is not None: + kwargs["task"] = self.task + if self.dimensions is not None: + kwargs["dimensions"] = self.dimensions + if self.late_chunking is not None: + kwargs["late_chunking"] = self.late_chunking + + return default_to_dict(self, **kwargs) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "JinaDocumentEmbedder": @@ -131,7 +145,9 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: texts_to_embed.append(text_to_embed) return texts_to_embed - def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: + def _embed_batch( + self, texts_to_embed: List[str], batch_size: int, parameters: Optional[Dict] = None + ) -> Tuple[List[List[float]], Dict[str, Any]]: """ Embed a list of texts in batches. """ @@ -142,7 +158,10 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): batch = texts_to_embed[i : i + batch_size] - response = self._session.post(JINA_API_URL, json={"input": batch, "model": self.model_name}).json() + response = self._session.post( + JINA_API_URL, + json={"input": batch, "model": self.model_name, **(parameters or {})}, + ).json() if "data" not in response: raise RuntimeError(response["detail"]) @@ -179,8 +198,16 @@ def run(self, documents: List[Document]): raise TypeError(msg) texts_to_embed = self._prepare_texts_to_embed(documents=documents) - - embeddings, metadata = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) + parameters: Dict[str, Any] = {} + if self.task is not None: + parameters["task"] = self.task + if self.dimensions is not None: + parameters["dimensions"] = self.dimensions + if self.late_chunking is not None: + parameters["late_chunking"] = self.late_chunking + embeddings, metadata = self._embed_batch( + texts_to_embed=texts_to_embed, batch_size=self.batch_size, parameters=parameters + ) for doc, emb in zip(documents, embeddings): doc.embedding = emb diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py index 6398122a4..c22f9ea2c 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import requests from haystack import component, default_from_dict, default_to_dict @@ -21,14 +21,14 @@ class JinaTextEmbedder: # Make sure that the environment variable JINA_API_KEY is set - text_embedder = JinaTextEmbedder() + text_embedder = JinaTextEmbedder(task="retrieval.query") text_to_embed = "I love pizza!" print(text_embedder.run(text_to_embed)) # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], - # 'meta': {'model': 'jina-embeddings-v2-base-en', + # 'meta': {'model': 'jina-embeddings-v3', # 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}} ``` """ @@ -36,9 +36,12 @@ class JinaTextEmbedder: def __init__( self, api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008 - model: str = "jina-embeddings-v2-base-en", + model: str = "jina-embeddings-v3", prefix: str = "", suffix: str = "", + task: Optional[str] = None, + dimensions: Optional[int] = None, + late_chunking: Optional[bool] = None, ): """ Create a JinaTextEmbedder component. @@ -65,6 +68,9 @@ def __init__( "Content-type": "application/json", } ) + self.task = task + self.dimensions = dimensions + self.late_chunking = late_chunking def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -78,9 +84,20 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ - return default_to_dict( - self, api_key=self.api_key.to_dict(), model=self.model_name, prefix=self.prefix, suffix=self.suffix - ) + kwargs = { + "api_key": self.api_key.to_dict(), + "model": self.model_name, + "prefix": self.prefix, + "suffix": self.suffix, + } + # Optional parameters, the following two are only supported by embeddings-v3 for now + if self.task is not None: + kwargs["task"] = self.task + if self.dimensions is not None: + kwargs["dimensions"] = self.dimensions + if self.late_chunking is not None: + kwargs["late_chunking"] = self.late_chunking + return default_to_dict(self, **kwargs) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "JinaTextEmbedder": @@ -114,7 +131,19 @@ def run(self, text: str): text_to_embed = self.prefix + text + self.suffix - resp = self._session.post(JINA_API_URL, json={"input": [text_to_embed], "model": self.model_name}).json() + parameters: Dict[str, Any] = {} + if self.task is not None: + parameters["task"] = self.task + if self.dimensions is not None: + parameters["dimensions"] = self.dimensions + if self.late_chunking is not None: + parameters["late_chunking"] = self.late_chunking + + resp = self._session.post( + JINA_API_URL, + json={"input": [text_to_embed], "model": self.model_name, **parameters}, + ).json() + if "data" not in resp: raise RuntimeError(resp["detail"]) diff --git a/integrations/jina/tests/test_document_embedder.py b/integrations/jina/tests/test_document_embedder.py index 9d63f8302..247b95eff 100644 --- a/integrations/jina/tests/test_document_embedder.py +++ b/integrations/jina/tests/test_document_embedder.py @@ -31,7 +31,7 @@ def test_init_default(self, monkeypatch): embedder = JinaDocumentEmbedder() assert embedder.api_key == Secret.from_env_var("JINA_API_KEY") - assert embedder.model_name == "jina-embeddings-v2-base-en" + assert embedder.model_name == "jina-embeddings-v3" assert embedder.prefix == "" assert embedder.suffix == "" assert embedder.batch_size == 32 @@ -49,6 +49,9 @@ def test_init_with_parameters(self): progress_bar=False, meta_fields_to_embed=["test_field"], embedding_separator=" | ", + task="retrieval.query", + dimensions=1024, + late_chunking=True, ) assert embedder.api_key == Secret.from_token("fake-api-key") @@ -59,6 +62,9 @@ def test_init_with_parameters(self): assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " + assert embedder.task == "retrieval.query" + assert embedder.dimensions == 1024 + assert embedder.late_chunking is True def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("JINA_API_KEY", raising=False) @@ -73,7 +79,7 @@ def test_to_dict(self, monkeypatch): "type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder", "init_parameters": { "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, - "model": "jina-embeddings-v2-base-en", + "model": "jina-embeddings-v3", "prefix": "", "suffix": "", "batch_size": 32, @@ -93,6 +99,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): progress_bar=False, meta_fields_to_embed=["test_field"], embedding_separator=" | ", + task="retrieval.query", + dimensions=1024, ) data = component.to_dict() assert data == { @@ -106,6 +114,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "progress_bar": False, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", + "task": "retrieval.query", + "dimensions": 1024, }, } @@ -246,3 +256,35 @@ def test_run_on_empty_list(self): assert result["documents"] is not None assert not result["documents"] # empty list + + def test_run_with_v3(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + model = "jina-embeddings-v3" + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): + embedder = JinaDocumentEmbedder( + api_key=Secret.from_token("fake-api-key"), + model=model, + prefix="prefix ", + suffix=" suffix", + meta_fields_to_embed=["topic"], + embedding_separator=" | ", + batch_size=1, + task="retrieval.query", + ) + result = embedder.run(documents=docs) + + documents_with_embeddings = result["documents"] + metadata = result["meta"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 3 + assert all(isinstance(x, float) for x in doc.embedding) + assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}} diff --git a/integrations/jina/tests/test_text_embedder.py b/integrations/jina/tests/test_text_embedder.py index 5c0f80d02..058712a18 100644 --- a/integrations/jina/tests/test_text_embedder.py +++ b/integrations/jina/tests/test_text_embedder.py @@ -17,7 +17,7 @@ def test_init_default(self, monkeypatch): embedder = JinaTextEmbedder() assert embedder.api_key == Secret.from_env_var("JINA_API_KEY") - assert embedder.model_name == "jina-embeddings-v2-base-en" + assert embedder.model_name == "jina-embeddings-v3" assert embedder.prefix == "" assert embedder.suffix == "" @@ -27,11 +27,13 @@ def test_init_with_parameters(self): model="model", prefix="prefix", suffix="suffix", + late_chunking=True, ) assert embedder.api_key == Secret.from_token("fake-api-key") assert embedder.model_name == "model" assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" + assert embedder.late_chunking is True def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("JINA_API_KEY", raising=False) @@ -46,7 +48,7 @@ def test_to_dict(self, monkeypatch): "type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder", "init_parameters": { "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, - "model": "jina-embeddings-v2-base-en", + "model": "jina-embeddings-v3", "prefix": "", "suffix": "", }, @@ -58,6 +60,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): model="model", prefix="prefix", suffix="suffix", + task="retrieval.query", + dimensions=1024, ) data = component.to_dict() assert data == { @@ -67,6 +71,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "model": "model", "prefix": "prefix", "suffix": "suffix", + "task": "retrieval.query", + "dimensions": 1024, }, } @@ -106,3 +112,36 @@ def test_run_wrong_input_format(self): with pytest.raises(TypeError, match="JinaTextEmbedder expects a string as an input"): embedder.run(text=list_integers_input) + + def test_with_v3(self): + model = "jina-embeddings-v3" + with patch("requests.sessions.Session.post") as mock_post: + # Configure the mock to return a specific response + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps( + { + "model": "jina-embeddings-v3", + "object": "list", + "usage": {"total_tokens": 6, "prompt_tokens": 6}, + "data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}], + } + ).encode() + + mock_post.return_value = mock_response + + embedder = JinaTextEmbedder( + api_key=Secret.from_token("fake-api-key"), + model=model, + prefix="prefix ", + suffix=" suffix", + task="retrieval.query", + ) + result = embedder.run(text="The food was delicious") + + assert len(result["embedding"]) == 3 + assert all(isinstance(x, float) for x in result["embedding"]) + assert result["meta"] == { + "model": "jina-embeddings-v3", + "usage": {"prompt_tokens": 6, "total_tokens": 6}, + }