Skip to content

Commit

Permalink
chore!: Rename model_name to model in the Jina integration (#230)
Browse files Browse the repository at this point in the history
* rename model_name to model in doc embedder

* rename model_name to model in text embedder

* fix tests

* leftover
  • Loading branch information
ZanSara authored Jan 18, 2024
1 parent 065a00f commit 00a55b2
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 20 deletions.
8 changes: 4 additions & 4 deletions integrations/jina/src/jina_haystack/document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class JinaDocumentEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "jina-embeddings-v2-base-en",
model: str = "jina-embeddings-v2-base-en",
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
Expand All @@ -48,7 +48,7 @@ def __init__(
Create a JinaDocumentEmbedder component.
:param api_key: The Jina API key. It can be explicitly provided or automatically read from the
environment variable JINA_API_KEY (recommended).
:param model_name: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/`
:param model: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/`
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of Documents to encode at once.
Expand All @@ -67,7 +67,7 @@ def __init__(
)
raise ValueError(msg)

self.model_name = model_name
self.model_name = model
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
Expand Down Expand Up @@ -96,7 +96,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model_name=self.model_name,
model=self.model_name,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
Expand Down
8 changes: 4 additions & 4 deletions integrations/jina/src/jina_haystack/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class JinaTextEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "jina-embeddings-v2-base-en",
model: str = "jina-embeddings-v2-base-en",
prefix: str = "",
suffix: str = "",
):
Expand All @@ -43,7 +43,7 @@ def __init__(
:param api_key: The Jina API key. It can be explicitly provided or automatically read from the
environment variable JINA_API_KEY (recommended).
:param model_name: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/`
:param model: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/`
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
Expand All @@ -57,7 +57,7 @@ def __init__(
)
raise ValueError(msg)

self.model_name = model_name
self.model_name = model
self.prefix = prefix
self.suffix = suffix
self._session = requests.Session()
Expand All @@ -81,7 +81,7 @@ def to_dict(self) -> Dict[str, Any]:
to the constructor.
"""

return default_to_dict(self, model_name=self.model_name, prefix=self.prefix, suffix=self.suffix)
return default_to_dict(self, model=self.model_name, prefix=self.prefix, suffix=self.suffix)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
Expand Down
14 changes: 7 additions & 7 deletions integrations/jina/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_init_default(self, monkeypatch):
def test_init_with_parameters(self):
embedder = JinaDocumentEmbedder(
api_key="fake-api-key",
model_name="model",
model="model",
prefix="prefix",
suffix="suffix",
batch_size=64,
Expand All @@ -67,7 +67,7 @@ def test_to_dict(self):
assert data == {
"type": "jina_haystack.document_embedder.JinaDocumentEmbedder",
"init_parameters": {
"model_name": "jina-embeddings-v2-base-en",
"model": "jina-embeddings-v2-base-en",
"prefix": "",
"suffix": "",
"batch_size": 32,
Expand All @@ -80,7 +80,7 @@ def test_to_dict(self):
def test_to_dict_with_custom_init_parameters(self):
component = JinaDocumentEmbedder(
api_key="fake-api-key",
model_name="model",
model="model",
prefix="prefix",
suffix="suffix",
batch_size=64,
Expand All @@ -92,7 +92,7 @@ def test_to_dict_with_custom_init_parameters(self):
assert data == {
"type": "jina_haystack.document_embedder.JinaDocumentEmbedder",
"init_parameters": {
"model_name": "model",
"model": "model",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_embed_batch(self):
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
embedder = JinaDocumentEmbedder(api_key="fake-api-key", model_name="model")
embedder = JinaDocumentEmbedder(api_key="fake-api-key", model="model")

embeddings, metadata = embedder._embed_batch(texts_to_embed=texts, batch_size=2)

Expand All @@ -164,7 +164,7 @@ def test_run(self):
with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
embedder = JinaDocumentEmbedder(
api_key="fake-api-key",
model_name=model,
model=model,
prefix="prefix ",
suffix=" suffix",
meta_fields_to_embed=["topic"],
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_run_custom_batch_size(self):
with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
embedder = JinaDocumentEmbedder(
api_key="fake-api-key",
model_name=model,
model=model,
prefix="prefix ",
suffix=" suffix",
meta_fields_to_embed=["topic"],
Expand Down
10 changes: 5 additions & 5 deletions integrations/jina/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_init_default(self, monkeypatch):
def test_init_with_parameters(self):
embedder = JinaTextEmbedder(
api_key="fake-api-key",
model_name="model",
model="model",
prefix="prefix",
suffix="suffix",
)
Expand All @@ -41,7 +41,7 @@ def test_to_dict(self):
assert data == {
"type": "jina_haystack.text_embedder.JinaTextEmbedder",
"init_parameters": {
"model_name": "jina-embeddings-v2-base-en",
"model": "jina-embeddings-v2-base-en",
"prefix": "",
"suffix": "",
},
Expand All @@ -50,15 +50,15 @@ def test_to_dict(self):
def test_to_dict_with_custom_init_parameters(self):
component = JinaTextEmbedder(
api_key="fake-api-key",
model_name="model",
model="model",
prefix="prefix",
suffix="suffix",
)
data = component.to_dict()
assert data == {
"type": "jina_haystack.text_embedder.JinaTextEmbedder",
"init_parameters": {
"model_name": "model",
"model": "model",
"prefix": "prefix",
"suffix": "suffix",
},
Expand All @@ -81,7 +81,7 @@ def test_run(self):

mock_post.return_value = mock_response

embedder = JinaTextEmbedder(api_key="fake-api-key", model_name=model, prefix="prefix ", suffix=" suffix")
embedder = JinaTextEmbedder(api_key="fake-api-key", model=model, prefix="prefix ", suffix=" suffix")
result = embedder.run(text="The food was delicious")

assert len(result["embedding"]) == 3
Expand Down

0 comments on commit 00a55b2

Please sign in to comment.