Skip to content

Commit

Permalink
chore: update Jina Embedder usage for V3 release (#1077)
Browse files Browse the repository at this point in the history
* chore: update Jina Embedder usage for V3 release

* fix: resolve lint issue

* fix: resolve test error

* fix: resolve test error

* fix: resolve lint issues

* fix: resolve lint issues

* fix: resolve lint issues

* fix: resolve lint issues

* fix: resolve lint issues

* chore: update JinaEmbedding for v3 release

* fix: resolve test errors

* fix: resolve test errors

* chore: added test case

* fix: resolve lint issues

* fix: update function call

* fix: resolve lint issues

* fix: lint error

* fix: lint error

* chore: remove unnecessary test cases

* chore: use 'task' instead of 'task_type'

* chore: add 'late_chunking' for Jina embedders

* Update integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py

Co-authored-by: Silvano Cerza <[email protected]>

---------

Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
2 people authored and Amnah199 committed Oct 2, 2024
1 parent cf299d5 commit e37741c
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class JinaDocumentEmbedder:
# Make sure that the environment variable JINA_API_KEY is set
document_embedder = JinaDocumentEmbedder()
document_embedder = JinaDocumentEmbedder(task="retrieval.query")
doc = Document(content="I love pizza!")
Expand All @@ -38,13 +38,16 @@ class JinaDocumentEmbedder:
def __init__(
self,
api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008
model: str = "jina-embeddings-v2-base-en",
model: str = "jina-embeddings-v3",
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
task: Optional[str] = None,
dimensions: Optional[int] = None,
late_chunking: Optional[bool] = None,
):
"""
Create a JinaDocumentEmbedder component.
Expand Down Expand Up @@ -78,6 +81,9 @@ def __init__(
"Content-type": "application/json",
}
)
self.task = task
self.dimensions = dimensions
self.late_chunking = late_chunking

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -91,17 +97,25 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
api_key=self.api_key.to_dict(),
model=self.model_name,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
kwargs = {
"api_key": self.api_key.to_dict(),
"model": self.model_name,
"prefix": self.prefix,
"suffix": self.suffix,
"batch_size": self.batch_size,
"progress_bar": self.progress_bar,
"meta_fields_to_embed": self.meta_fields_to_embed,
"embedding_separator": self.embedding_separator,
}
# Optional parameters, the following two are only supported by embeddings-v3 for now
if self.task is not None:
kwargs["task"] = self.task
if self.dimensions is not None:
kwargs["dimensions"] = self.dimensions
if self.late_chunking is not None:
kwargs["late_chunking"] = self.late_chunking

return default_to_dict(self, **kwargs)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "JinaDocumentEmbedder":
Expand Down Expand Up @@ -131,7 +145,9 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
texts_to_embed.append(text_to_embed)
return texts_to_embed

def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
def _embed_batch(
self, texts_to_embed: List[str], batch_size: int, parameters: Optional[Dict] = None
) -> Tuple[List[List[float]], Dict[str, Any]]:
"""
Embed a list of texts in batches.
"""
Expand All @@ -142,7 +158,10 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
response = self._session.post(JINA_API_URL, json={"input": batch, "model": self.model_name}).json()
response = self._session.post(
JINA_API_URL,
json={"input": batch, "model": self.model_name, **(parameters or {})},
).json()
if "data" not in response:
raise RuntimeError(response["detail"])

Expand Down Expand Up @@ -179,8 +198,16 @@ def run(self, documents: List[Document]):
raise TypeError(msg)

texts_to_embed = self._prepare_texts_to_embed(documents=documents)

embeddings, metadata = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
parameters: Dict[str, Any] = {}
if self.task is not None:
parameters["task"] = self.task
if self.dimensions is not None:
parameters["dimensions"] = self.dimensions
if self.late_chunking is not None:
parameters["late_chunking"] = self.late_chunking
embeddings, metadata = self._embed_batch(
texts_to_embed=texts_to_embed, batch_size=self.batch_size, parameters=parameters
)

for doc, emb in zip(documents, embeddings):
doc.embedding = emb
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import requests
from haystack import component, default_from_dict, default_to_dict
Expand All @@ -21,24 +21,27 @@ class JinaTextEmbedder:
# Make sure that the environment variable JINA_API_KEY is set
text_embedder = JinaTextEmbedder()
text_embedder = JinaTextEmbedder(task="retrieval.query")
text_to_embed = "I love pizza!"
print(text_embedder.run(text_to_embed))
# {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
# 'meta': {'model': 'jina-embeddings-v2-base-en',
# 'meta': {'model': 'jina-embeddings-v3',
# 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
```
"""

def __init__(
self,
api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008
model: str = "jina-embeddings-v2-base-en",
model: str = "jina-embeddings-v3",
prefix: str = "",
suffix: str = "",
task: Optional[str] = None,
dimensions: Optional[int] = None,
late_chunking: Optional[bool] = None,
):
"""
Create a JinaTextEmbedder component.
Expand All @@ -65,6 +68,9 @@ def __init__(
"Content-type": "application/json",
}
)
self.task = task
self.dimensions = dimensions
self.late_chunking = late_chunking

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -78,9 +84,20 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self, api_key=self.api_key.to_dict(), model=self.model_name, prefix=self.prefix, suffix=self.suffix
)
kwargs = {
"api_key": self.api_key.to_dict(),
"model": self.model_name,
"prefix": self.prefix,
"suffix": self.suffix,
}
# Optional parameters, the following two are only supported by embeddings-v3 for now
if self.task is not None:
kwargs["task"] = self.task
if self.dimensions is not None:
kwargs["dimensions"] = self.dimensions
if self.late_chunking is not None:
kwargs["late_chunking"] = self.late_chunking
return default_to_dict(self, **kwargs)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "JinaTextEmbedder":
Expand Down Expand Up @@ -114,7 +131,19 @@ def run(self, text: str):

text_to_embed = self.prefix + text + self.suffix

resp = self._session.post(JINA_API_URL, json={"input": [text_to_embed], "model": self.model_name}).json()
parameters: Dict[str, Any] = {}
if self.task is not None:
parameters["task"] = self.task
if self.dimensions is not None:
parameters["dimensions"] = self.dimensions
if self.late_chunking is not None:
parameters["late_chunking"] = self.late_chunking

resp = self._session.post(
JINA_API_URL,
json={"input": [text_to_embed], "model": self.model_name, **parameters},
).json()

if "data" not in resp:
raise RuntimeError(resp["detail"])

Expand Down
46 changes: 44 additions & 2 deletions integrations/jina/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_init_default(self, monkeypatch):
embedder = JinaDocumentEmbedder()

assert embedder.api_key == Secret.from_env_var("JINA_API_KEY")
assert embedder.model_name == "jina-embeddings-v2-base-en"
assert embedder.model_name == "jina-embeddings-v3"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 32
Expand All @@ -49,6 +49,9 @@ def test_init_with_parameters(self):
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
task="retrieval.query",
dimensions=1024,
late_chunking=True,
)

assert embedder.api_key == Secret.from_token("fake-api-key")
Expand All @@ -59,6 +62,9 @@ def test_init_with_parameters(self):
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
assert embedder.task == "retrieval.query"
assert embedder.dimensions == 1024
assert embedder.late_chunking is True

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
Expand All @@ -73,7 +79,7 @@ def test_to_dict(self, monkeypatch):
"type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "jina-embeddings-v2-base-en",
"model": "jina-embeddings-v3",
"prefix": "",
"suffix": "",
"batch_size": 32,
Expand All @@ -93,6 +99,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
task="retrieval.query",
dimensions=1024,
)
data = component.to_dict()
assert data == {
Expand All @@ -106,6 +114,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
"progress_bar": False,
"meta_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
"task": "retrieval.query",
"dimensions": 1024,
},
}

Expand Down Expand Up @@ -246,3 +256,35 @@ def test_run_on_empty_list(self):

assert result["documents"] is not None
assert not result["documents"] # empty list

def test_run_with_v3(self):
docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]

model = "jina-embeddings-v3"
with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
embedder = JinaDocumentEmbedder(
api_key=Secret.from_token("fake-api-key"),
model=model,
prefix="prefix ",
suffix=" suffix",
meta_fields_to_embed=["topic"],
embedding_separator=" | ",
batch_size=1,
task="retrieval.query",
)
result = embedder.run(documents=docs)

documents_with_embeddings = result["documents"]
metadata = result["meta"]

assert isinstance(documents_with_embeddings, list)
assert len(documents_with_embeddings) == len(docs)
for doc in documents_with_embeddings:
assert isinstance(doc, Document)
assert isinstance(doc.embedding, list)
assert len(doc.embedding) == 3
assert all(isinstance(x, float) for x in doc.embedding)
assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}}
43 changes: 41 additions & 2 deletions integrations/jina/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_init_default(self, monkeypatch):
embedder = JinaTextEmbedder()

assert embedder.api_key == Secret.from_env_var("JINA_API_KEY")
assert embedder.model_name == "jina-embeddings-v2-base-en"
assert embedder.model_name == "jina-embeddings-v3"
assert embedder.prefix == ""
assert embedder.suffix == ""

Expand All @@ -27,11 +27,13 @@ def test_init_with_parameters(self):
model="model",
prefix="prefix",
suffix="suffix",
late_chunking=True,
)
assert embedder.api_key == Secret.from_token("fake-api-key")
assert embedder.model_name == "model"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.late_chunking is True

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
Expand All @@ -46,7 +48,7 @@ def test_to_dict(self, monkeypatch):
"type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "jina-embeddings-v2-base-en",
"model": "jina-embeddings-v3",
"prefix": "",
"suffix": "",
},
Expand All @@ -58,6 +60,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
model="model",
prefix="prefix",
suffix="suffix",
task="retrieval.query",
dimensions=1024,
)
data = component.to_dict()
assert data == {
Expand All @@ -67,6 +71,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
"model": "model",
"prefix": "prefix",
"suffix": "suffix",
"task": "retrieval.query",
"dimensions": 1024,
},
}

Expand Down Expand Up @@ -106,3 +112,36 @@ def test_run_wrong_input_format(self):

with pytest.raises(TypeError, match="JinaTextEmbedder expects a string as an input"):
embedder.run(text=list_integers_input)

def test_with_v3(self):
model = "jina-embeddings-v3"
with patch("requests.sessions.Session.post") as mock_post:
# Configure the mock to return a specific response
mock_response = requests.Response()
mock_response.status_code = 200
mock_response._content = json.dumps(
{
"model": "jina-embeddings-v3",
"object": "list",
"usage": {"total_tokens": 6, "prompt_tokens": 6},
"data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}],
}
).encode()

mock_post.return_value = mock_response

embedder = JinaTextEmbedder(
api_key=Secret.from_token("fake-api-key"),
model=model,
prefix="prefix ",
suffix=" suffix",
task="retrieval.query",
)
result = embedder.run(text="The food was delicious")

assert len(result["embedding"]) == 3
assert all(isinstance(x, float) for x in result["embedding"])
assert result["meta"] == {
"model": "jina-embeddings-v3",
"usage": {"prompt_tokens": 6, "total_tokens": 6},
}

0 comments on commit e37741c

Please sign in to comment.