Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: update Jina Embedder usage for V3 release #1077

Merged
merged 30 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5f87bdd
chore: update Jina Embedder usage for V3 release
DresAaron Sep 11, 2024
e0f48cd
fix: resolve lint issue
DresAaron Sep 11, 2024
dff73c2
fix: resolve test error
DresAaron Sep 11, 2024
88f0eba
fix: resolve test error
DresAaron Sep 11, 2024
0ad3ffe
fix: resolve lint issues
DresAaron Sep 11, 2024
3819a18
fix: resolve lint issues
DresAaron Sep 11, 2024
e9fa853
fix: resolve lint issues
DresAaron Sep 11, 2024
4c02045
fix: resolve lint issues
DresAaron Sep 11, 2024
da2238a
fix: resolve lint issues
DresAaron Sep 11, 2024
4fe1ca4
chore: update JinaEmbedding for v3 release
DresAaron Sep 12, 2024
d517aa1
fix: resolve test errors
DresAaron Sep 12, 2024
9ef324e
fix: resolve test errors
DresAaron Sep 12, 2024
f79fa73
chore: added test case
DresAaron Sep 12, 2024
cfb18ee
fix: resolve lint issues
DresAaron Sep 12, 2024
531433f
fix: update function call
DresAaron Sep 13, 2024
c039af7
fix: resolve lint issues
DresAaron Sep 13, 2024
38781cb
fix: lint error
DresAaron Sep 13, 2024
c2d80f0
fix: lint error
DresAaron Sep 13, 2024
3244170
Merge branch 'main' into update-jina-embedders
DresAaron Sep 13, 2024
1a9b936
Merge branch 'main' into update-jina-embedders
DresAaron Sep 13, 2024
86cc892
chore: remove unnecessary test cases
DresAaron Sep 13, 2024
16756e5
Merge branch 'main' into update-jina-embedders
DresAaron Sep 13, 2024
c37cc79
Merge branch 'main' into update-jina-embedders
DresAaron Sep 17, 2024
3743948
Merge branch 'main' into update-jina-embedders
DresAaron Sep 18, 2024
5847efb
chore: use 'task' instead of 'task_type'
DresAaron Sep 18, 2024
9e01dd5
chore: add 'late_chunking' for Jina embedders
DresAaron Sep 18, 2024
6895860
Update integrations/jina/src/haystack_integrations/components/embedde…
DresAaron Sep 18, 2024
048f7a6
Update integrations/jina/src/haystack_integrations/components/embedde…
DresAaron Sep 18, 2024
c0bccaa
Update integrations/jina/src/haystack_integrations/components/embedde…
DresAaron Sep 18, 2024
3d45a2c
Update integrations/jina/src/haystack_integrations/components/embedde…
DresAaron Sep 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class JinaDocumentEmbedder:

# Make sure that the environment variable JINA_API_KEY is set

document_embedder = JinaDocumentEmbedder()
document_embedder = JinaDocumentEmbedder(task="retrieval.query")

doc = Document(content="I love pizza!")

Expand All @@ -38,13 +38,16 @@ class JinaDocumentEmbedder:
def __init__(
self,
api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008
model: str = "jina-embeddings-v2-base-en",
model: str = "jina-embeddings-v3",
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
task: Optional[str] = None,
dimensions: Optional[int] = None,
late_chunking: Optional[bool] = None,
):
"""
Create a JinaDocumentEmbedder component.
Expand Down Expand Up @@ -78,6 +81,9 @@ def __init__(
"Content-type": "application/json",
}
)
self.task = task
self.dimensions = dimensions
self.late_chunking = late_chunking

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -91,17 +97,25 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
api_key=self.api_key.to_dict(),
model=self.model_name,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
kwargs = {
"api_key": self.api_key.to_dict(),
"model": self.model_name,
"prefix": self.prefix,
"suffix": self.suffix,
"batch_size": self.batch_size,
"progress_bar": self.progress_bar,
"meta_fields_to_embed": self.meta_fields_to_embed,
"embedding_separator": self.embedding_separator,
}
# Optional parameters, the following two are only supported by embeddings-v3 for now
if self.task:
kwargs["task"] = self.task
if self.dimensions:
kwargs["dimensions"] = self.dimensions
DresAaron marked this conversation as resolved.
Show resolved Hide resolved
if self.late_chunking is not None:
kwargs["late_chunking"] = self.late_chunking

return default_to_dict(self, **kwargs)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "JinaDocumentEmbedder":
Expand Down Expand Up @@ -131,7 +145,9 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
texts_to_embed.append(text_to_embed)
return texts_to_embed

def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
def _embed_batch(
self, texts_to_embed: List[str], batch_size: int, parameters: Optional[Dict] = None
) -> Tuple[List[List[float]], Dict[str, Any]]:
"""
Embed a list of texts in batches.
"""
Expand All @@ -142,7 +158,10 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
response = self._session.post(JINA_API_URL, json={"input": batch, "model": self.model_name}).json()
response = self._session.post(
JINA_API_URL,
json={"input": batch, "model": self.model_name, **(parameters if parameters is not None else {})},
DresAaron marked this conversation as resolved.
Show resolved Hide resolved
).json()
if "data" not in response:
raise RuntimeError(response["detail"])

Expand Down Expand Up @@ -179,8 +198,16 @@ def run(self, documents: List[Document]):
raise TypeError(msg)

texts_to_embed = self._prepare_texts_to_embed(documents=documents)

embeddings, metadata = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
parameters: Dict[str, Any] = {}
if self.task:
parameters["task"] = self.task
if self.dimensions:
parameters["dimensions"] = self.dimensions
DresAaron marked this conversation as resolved.
Show resolved Hide resolved
if self.late_chunking is not None:
parameters["late_chunking"] = self.late_chunking
embeddings, metadata = self._embed_batch(
texts_to_embed=texts_to_embed, batch_size=self.batch_size, parameters=parameters
)

for doc, emb in zip(documents, embeddings):
doc.embedding = emb
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import requests
from haystack import component, default_from_dict, default_to_dict
Expand All @@ -21,24 +21,27 @@ class JinaTextEmbedder:

# Make sure that the environment variable JINA_API_KEY is set

text_embedder = JinaTextEmbedder()
text_embedder = JinaTextEmbedder(task="retrieval.query")

text_to_embed = "I love pizza!"

print(text_embedder.run(text_to_embed))

# {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
# 'meta': {'model': 'jina-embeddings-v2-base-en',
# 'meta': {'model': 'jina-embeddings-v3',
# 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
```
"""

def __init__(
self,
api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008
model: str = "jina-embeddings-v2-base-en",
model: str = "jina-embeddings-v3",
prefix: str = "",
suffix: str = "",
task: Optional[str] = None,
dimensions: Optional[int] = None,
late_chunking: Optional[bool] = None,
):
"""
Create a JinaTextEmbedder component.
Expand All @@ -65,6 +68,9 @@ def __init__(
"Content-type": "application/json",
}
)
self.task = task
self.dimensions = dimensions
self.late_chunking = late_chunking

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -78,9 +84,20 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self, api_key=self.api_key.to_dict(), model=self.model_name, prefix=self.prefix, suffix=self.suffix
)
kwargs = {
"api_key": self.api_key.to_dict(),
"model": self.model_name,
"prefix": self.prefix,
"suffix": self.suffix,
}
# Optional parameters, the following two are only supported by embeddings-v3 for now
if self.task:
kwargs["task"] = self.task
if self.dimensions:
kwargs["dimensions"] = self.dimensions
DresAaron marked this conversation as resolved.
Show resolved Hide resolved
if self.late_chunking is not None:
kwargs["late_chunking"] = self.late_chunking
return default_to_dict(self, **kwargs)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "JinaTextEmbedder":
Expand Down Expand Up @@ -114,7 +131,19 @@ def run(self, text: str):

text_to_embed = self.prefix + text + self.suffix

resp = self._session.post(JINA_API_URL, json={"input": [text_to_embed], "model": self.model_name}).json()
parameters: Dict[str, Any] = {}
if self.task is not None:
parameters["task"] = self.task
if self.dimensions is not None:
parameters["dimensions"] = self.dimensions
if self.late_chunking is not None:
parameters["late_chunking"] = self.late_chunking

resp = self._session.post(
JINA_API_URL,
json={"input": [text_to_embed], "model": self.model_name, **parameters},
).json()

if "data" not in resp:
raise RuntimeError(resp["detail"])

Expand Down
46 changes: 44 additions & 2 deletions integrations/jina/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_init_default(self, monkeypatch):
embedder = JinaDocumentEmbedder()

assert embedder.api_key == Secret.from_env_var("JINA_API_KEY")
assert embedder.model_name == "jina-embeddings-v2-base-en"
assert embedder.model_name == "jina-embeddings-v3"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 32
Expand All @@ -49,6 +49,9 @@ def test_init_with_parameters(self):
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
task="retrieval.query",
dimensions=1024,
late_chunking=True,
)

assert embedder.api_key == Secret.from_token("fake-api-key")
Expand All @@ -59,6 +62,9 @@ def test_init_with_parameters(self):
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
assert embedder.task == "retrieval.query"
assert embedder.dimensions == 1024
assert embedder.late_chunking is True

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
Expand All @@ -73,7 +79,7 @@ def test_to_dict(self, monkeypatch):
"type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "jina-embeddings-v2-base-en",
"model": "jina-embeddings-v3",
"prefix": "",
"suffix": "",
"batch_size": 32,
Expand All @@ -93,6 +99,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
task="retrieval.query",
dimensions=1024,
)
data = component.to_dict()
assert data == {
Expand All @@ -106,6 +114,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
"progress_bar": False,
"meta_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
"task": "retrieval.query",
"dimensions": 1024,
},
}

Expand Down Expand Up @@ -246,3 +256,35 @@ def test_run_on_empty_list(self):

assert result["documents"] is not None
assert not result["documents"] # empty list

def test_run_with_v3(self):
docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]

model = "jina-embeddings-v3"
with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
embedder = JinaDocumentEmbedder(
api_key=Secret.from_token("fake-api-key"),
model=model,
prefix="prefix ",
suffix=" suffix",
meta_fields_to_embed=["topic"],
embedding_separator=" | ",
batch_size=1,
task="retrieval.query",
)
result = embedder.run(documents=docs)

documents_with_embeddings = result["documents"]
metadata = result["meta"]

assert isinstance(documents_with_embeddings, list)
assert len(documents_with_embeddings) == len(docs)
for doc in documents_with_embeddings:
assert isinstance(doc, Document)
assert isinstance(doc.embedding, list)
assert len(doc.embedding) == 3
assert all(isinstance(x, float) for x in doc.embedding)
assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}}
43 changes: 41 additions & 2 deletions integrations/jina/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_init_default(self, monkeypatch):
embedder = JinaTextEmbedder()

assert embedder.api_key == Secret.from_env_var("JINA_API_KEY")
assert embedder.model_name == "jina-embeddings-v2-base-en"
assert embedder.model_name == "jina-embeddings-v3"
assert embedder.prefix == ""
assert embedder.suffix == ""

Expand All @@ -27,11 +27,13 @@ def test_init_with_parameters(self):
model="model",
prefix="prefix",
suffix="suffix",
late_chunking=True,
)
assert embedder.api_key == Secret.from_token("fake-api-key")
assert embedder.model_name == "model"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.late_chunking is True

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
Expand All @@ -46,7 +48,7 @@ def test_to_dict(self, monkeypatch):
"type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "jina-embeddings-v2-base-en",
"model": "jina-embeddings-v3",
"prefix": "",
"suffix": "",
},
Expand All @@ -58,6 +60,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
model="model",
prefix="prefix",
suffix="suffix",
task="retrieval.query",
dimensions=1024,
)
data = component.to_dict()
assert data == {
Expand All @@ -67,6 +71,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
"model": "model",
"prefix": "prefix",
"suffix": "suffix",
"task": "retrieval.query",
"dimensions": 1024,
},
}

Expand Down Expand Up @@ -106,3 +112,36 @@ def test_run_wrong_input_format(self):

with pytest.raises(TypeError, match="JinaTextEmbedder expects a string as an input"):
embedder.run(text=list_integers_input)

def test_with_v3(self):
model = "jina-embeddings-v3"
with patch("requests.sessions.Session.post") as mock_post:
# Configure the mock to return a specific response
mock_response = requests.Response()
mock_response.status_code = 200
mock_response._content = json.dumps(
{
"model": "jina-embeddings-v3",
"object": "list",
"usage": {"total_tokens": 6, "prompt_tokens": 6},
"data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}],
}
).encode()

mock_post.return_value = mock_response

embedder = JinaTextEmbedder(
api_key=Secret.from_token("fake-api-key"),
model=model,
prefix="prefix ",
suffix=" suffix",
task="retrieval.query",
)
result = embedder.run(text="The food was delicious")

assert len(result["embedding"]) == 3
assert all(isinstance(x, float) for x in result["embedding"])
assert result["meta"] == {
"model": "jina-embeddings-v3",
"usage": {"prompt_tokens": 6, "total_tokens": 6},
}