diff --git a/src/voyage_embedders/__about__.py b/src/voyage_embedders/__about__.py new file mode 100644 index 0000000..16829bb --- /dev/null +++ b/src/voyage_embedders/__about__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2023-present Ashwin Mathur <> +# +# SPDX-License-Identifier: Apache-2.0 +__version__ = "1.0.0" diff --git a/src/voyage_embedders/__init__.py b/src/voyage_embedders/__init__.py new file mode 100644 index 0000000..05e74ad --- /dev/null +++ b/src/voyage_embedders/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2023-present John Doe +# +# SPDX-License-Identifier: Apache-2.0 + +from voyage_embedders.voyage_document_embedder import VoyageDocumentEmbedder +from voyage_embedders.voyage_text_embedder import VoyageTextEmbedder + +__all__ = ["VoyageDocumentEmbedder", "VoyageTextEmbedder"] diff --git a/src/voyage_embedders/voyage_document_embedder.py b/src/voyage_embedders/voyage_document_embedder.py new file mode 100644 index 0000000..e2f943b --- /dev/null +++ b/src/voyage_embedders/voyage_document_embedder.py @@ -0,0 +1,154 @@ +import os +from typing import Any, Dict, List, Optional + +import voyageai +from haystack.preview import Document, component, default_to_dict +from tqdm import tqdm +from voyageai import get_embeddings + +MAX_BATCH_SIZE = 8 + + +@component +class VoyageDocumentEmbedder: + """ + A component for computing Document embeddings using Voyage Embedding models. + The embedding of each Document is stored in the `embedding` field of the Document. + + Usage example: + ```python + from haystack.preview import Document + from haystack.preview.components.embedders import VoyageDocumentEmbedder + + doc = Document(text="I love pizza!") + + document_embedder = VoyageDocumentEmbedder() + + result = document_embedder.run([doc]) + print(result['documents'][0].embedding) + + # [0.017020374536514282, -0.023255806416273117, ...] + ``` + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "voyage-01", + prefix: str = "", + suffix: str = "", + batch_size: int = 8, + metadata_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + progress_bar: bool = True, # noqa + ): + """ + Create a VoyageDocumentEmbedder component. + :param api_key: The VoyageAI API key. It can be explicitly provided or automatically read from the + environment variable VOYAGE_API_KEY (recommended). + :param model_name: The name of the model to use. Defaults to "voyage-01". + For more details on the available models, + see [Voyage Embeddings documentation](https://docs.voyageai.com/embeddings/). + :param prefix: A string to add to the beginning of each text. + :param suffix: A string to add to the end of each text. + :param batch_size: Number of Documents to encode at once. + :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text. + :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments + to keep the logs clean. + """ + # if the user does not provide the API key, check if it is set in the module client + api_key = api_key or voyageai.api_key + if api_key is None: + try: + api_key = os.environ["VOYAGE_API_KEY"] + except KeyError as e: + msg = "VoyageDocumentEmbedder expects an VoyageAI API key. Set the VOYAGE_API_KEY environment variable (recommended) or pass it explicitly." # noqa + raise ValueError(msg) from e + + voyageai.api_key = api_key + + self.model_name = model_name + self.prefix = prefix + self.suffix = suffix + + if batch_size <= MAX_BATCH_SIZE: + self.batch_size = batch_size + else: + err_msg = f"""VoyageDocumentEmbedder has a maximum batch size of {MAX_BATCH_SIZE}. Set the Set the batch_size to {MAX_BATCH_SIZE} or less.""" # noqa + raise ValueError(err_msg) + + self.progress_bar = progress_bar + self.metadata_fields_to_embed = metadata_fields_to_embed or [] + self.embedding_separator = embedding_separator + + def to_dict(self) -> Dict[str, Any]: + """ + This method overrides the default serializer in order to avoid leaking the `api_key` value passed + to the constructor. + """ + return default_to_dict( + self, + model_name=self.model_name, + prefix=self.prefix, + suffix=self.suffix, + batch_size=self.batch_size, + progress_bar=self.progress_bar, + metadata_fields_to_embed=self.metadata_fields_to_embed, + embedding_separator=self.embedding_separator, + ) + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) + for key in self.metadata_fields_to_embed + if key in doc.meta and doc.meta[key] is not None + ] + + text_to_embed = ( + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + ) + + texts_to_embed.append(text_to_embed) + return texts_to_embed + + def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]: + """ + Embed a list of texts in batches. + """ + + all_embeddings = [] + for i in tqdm( + range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" + ): + batch = texts_to_embed[i : i + batch_size] + embeddings = get_embeddings(list_of_text=batch, batch_size=batch_size, model=self.model_name) + all_embeddings.extend(embeddings) + + return all_embeddings + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + """ + Embed a list of Documents. + The embedding of each Document is stored in the `embedding` field of the Document. + + :param documents: A list of Documents to embed. + """ + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + msg = "VoyageDocumentEmbedder expects a list of Documents as input.In case you want to embed a string, please use the VoyageTextEmbedder." # noqa + raise TypeError(msg) + + texts_to_embed = self._prepare_texts_to_embed(documents=documents) + + embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) + + for doc, emb in zip(documents, embeddings): + doc.embedding = emb + + return {"documents": documents} diff --git a/src/voyage_embedders/voyage_text_embedder.py b/src/voyage_embedders/voyage_text_embedder.py new file mode 100644 index 0000000..2192871 --- /dev/null +++ b/src/voyage_embedders/voyage_text_embedder.py @@ -0,0 +1,80 @@ +import os +from typing import Any, Dict, List, Optional + +import voyageai +from haystack.preview import component, default_to_dict +from voyageai import get_embedding + + +@component +class VoyageTextEmbedder: + """ + A component for embedding strings using Voyage models. + + Usage example: + ```python + from haystack.preview.components.embedders import VoyageTextEmbedder + + text_to_embed = "I love pizza!" + + text_embedder = VoyageTextEmbedder() + + print(text_embedder.run(text_to_embed)) + + # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], + ``` + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "voyage-01", + prefix: str = "", + suffix: str = "", + ): + """ + Create an VoyageTextEmbedder component. + + :param api_key: The VoyageAI API key. It can be explicitly provided or automatically read from the + environment variable VOYAGE_API_KEY (recommended). + :param model_name: The name of the Voyage model to use. Defaults to "voyage-01". + For more details on the available models, + see [Voyage Embeddings documentation](https://docs.voyageai.com/embeddings/). + :param prefix: A string to add to the beginning of each text. + :param suffix: A string to add to the end of each text. + """ + # if the user does not provide the API key, check if it is set in the module client + api_key = api_key or voyageai.api_key + if api_key is None: + try: + api_key = os.environ["VOYAGE_API_KEY"] + except KeyError as e: + msg = "VoyageTextEmbedder expects an VoyageAI API key. Set the VOYAGE_API_KEY environment variable (recommended) or pass it explicitly." # noqa + raise ValueError(msg) from e + + voyageai.api_key = api_key + + self.model_name = model_name + self.prefix = prefix + self.suffix = suffix + + def to_dict(self) -> Dict[str, Any]: + """ + This method overrides the default serializer in order to avoid leaking the `api_key` value passed + to the constructor. + """ + + return default_to_dict(self, model_name=self.model_name, prefix=self.prefix, suffix=self.suffix) + + @component.output_types(embedding=List[float]) + def run(self, text: str): + """Embed a string.""" + if not isinstance(text, str): + msg = "VoyageTextEmbedder expects a string as an input.In case you want to embed a list of Documents, please use the VoyageDocumentEmbedder." # noqa + raise TypeError(msg) + + text_to_embed = self.prefix + text + self.suffix + + embedding = get_embedding(text=text_to_embed, model=self.model_name) + + return {"embedding": embedding} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..1cb56b7 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present Ashwin Mathur <> +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_voyage_document_embedder.py b/tests/test_voyage_document_embedder.py new file mode 100644 index 0000000..770731c --- /dev/null +++ b/tests/test_voyage_document_embedder.py @@ -0,0 +1,257 @@ +from typing import List +from unittest.mock import patch + +import numpy as np +import pytest +import voyageai +from haystack.preview import Document + +from voyage_embedders.voyage_document_embedder import VoyageDocumentEmbedder + + +def mock_voyageai_response(list_of_text: List[str], model: str = "voyage-01", **kwargs) -> List[List[float]]: # noqa + response = [np.random.rand(1024).tolist() for i in range(len(list_of_text))] + return response + + +class TestVoyageDocumentEmbedder: + @pytest.mark.unit + def test_init_default(self, monkeypatch): + voyageai.api_key = None + monkeypatch.setenv("VOYAGE_API_KEY", "fake-api-key") + embedder = VoyageDocumentEmbedder() + + assert voyageai.api_key == "fake-api-key" + + assert embedder.model_name == "voyage-01" + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.batch_size == 8 + assert embedder.progress_bar is True + assert embedder.metadata_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + + @pytest.mark.unit + def test_init_with_parameters(self): + embedder = VoyageDocumentEmbedder( + api_key="fake-api-key", + model_name="model", + prefix="prefix", + suffix="suffix", + batch_size=4, + progress_bar=False, + metadata_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + assert voyageai.api_key == "fake-api-key" + + assert embedder.model_name == "model" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.batch_size == 4 + assert embedder.progress_bar is False + assert embedder.metadata_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " + + @pytest.mark.unit + def test_init_fail_wo_api_key(self, monkeypatch): + voyageai.api_key = None + monkeypatch.delenv("VOYAGE_API_KEY", raising=False) + with pytest.raises(ValueError, match="VoyageDocumentEmbedder expects an VoyageAI API key"): + VoyageDocumentEmbedder() + + @pytest.mark.unit + def test_to_dict(self): + component = VoyageDocumentEmbedder(api_key="fake-api-key") + data = component.to_dict() + assert data == { + "type": "VoyageDocumentEmbedder", + "init_parameters": { + "model_name": "voyage-01", + "prefix": "", + "suffix": "", + "batch_size": 8, + "progress_bar": True, + "metadata_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + + @pytest.mark.unit + def test_to_dict_with_custom_init_parameters(self): + component = VoyageDocumentEmbedder( + api_key="fake-api-key", + model_name="model", + prefix="prefix", + suffix="suffix", + batch_size=4, + progress_bar=False, + metadata_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + data = component.to_dict() + assert data == { + "type": "VoyageDocumentEmbedder", + "init_parameters": { + "model_name": "model", + "prefix": "prefix", + "suffix": "suffix", + "batch_size": 4, + "progress_bar": False, + "metadata_fields_to_embed": ["test_field"], + "embedding_separator": " | ", + }, + } + + @pytest.mark.unit + def test_prepare_texts_to_embed_w_metadata(self): + documents = [ + Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5) + ] + + embedder = VoyageDocumentEmbedder( + api_key="fake-api-key", metadata_fields_to_embed=["meta_field"], embedding_separator=" | " + ) + + prepared_texts = embedder._prepare_texts_to_embed(documents) + + # note that newline is replaced by space + assert prepared_texts == [ + "meta_value 0 | document number 0: content", + "meta_value 1 | document number 1: content", + "meta_value 2 | document number 2: content", + "meta_value 3 | document number 3: content", + "meta_value 4 | document number 4: content", + ] + + @pytest.mark.unit + def test_prepare_texts_to_embed_w_suffix(self): + documents = [Document(content=f"document number {i}") for i in range(5)] + + embedder = VoyageDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix") + + prepared_texts = embedder._prepare_texts_to_embed(documents) + + assert prepared_texts == [ + "my_prefix document number 0 my_suffix", + "my_prefix document number 1 my_suffix", + "my_prefix document number 2 my_suffix", + "my_prefix document number 3 my_suffix", + "my_prefix document number 4 my_suffix", + ] + + @pytest.mark.unit + def test_embed_batch(self): + texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] + + with patch("voyage_embedders.voyage_document_embedder.get_embeddings") as voyageai_embedding_patch: + voyageai_embedding_patch.side_effect = mock_voyageai_response + embedder = VoyageDocumentEmbedder(api_key="fake-api-key", model_name="model") + + embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2) + + assert voyageai_embedding_patch.call_count == 3 + + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) + for embedding in embeddings: + assert isinstance(embedding, list) + assert len(embedding) == 1024 + assert all(isinstance(x, float) for x in embedding) + + @pytest.mark.unit + def test_run(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + model = "voyage-01-lite" + with patch("voyage_embedders.voyage_document_embedder.get_embeddings") as voyageai_embedding_patch: + voyageai_embedding_patch.side_effect = mock_voyageai_response + embedder = VoyageDocumentEmbedder( + api_key="fake-api-key", + model_name=model, + prefix="prefix ", + suffix=" suffix", + metadata_fields_to_embed=["topic"], + embedding_separator=" | ", + ) + + result = embedder.run(documents=docs) + + voyageai_embedding_patch.assert_called_once_with( + model=model, + list_of_text=[ + "prefix Cuisine | I love cheese suffix", + "prefix ML | A transformer is a deep learning architecture suffix", + ], + batch_size=8, + ) + documents_with_embeddings = result["documents"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 1024 + assert all(isinstance(x, float) for x in doc.embedding) + + @pytest.mark.unit + def test_run_custom_batch_size(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + model = "voyage-01-lite" + with patch("voyage_embedders.voyage_document_embedder.get_embeddings") as voyageai_embedding_patch: + voyageai_embedding_patch.side_effect = mock_voyageai_response + embedder = VoyageDocumentEmbedder( + api_key="fake-api-key", + model_name=model, + prefix="prefix ", + suffix=" suffix", + metadata_fields_to_embed=["topic"], + embedding_separator=" | ", + batch_size=1, + ) + + result = embedder.run(documents=docs) + + assert voyageai_embedding_patch.call_count == 2 + + documents_with_embeddings = result["documents"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 1024 + assert all(isinstance(x, float) for x in doc.embedding) + + @pytest.mark.unit + def test_run_wrong_input_format(self): + embedder = VoyageDocumentEmbedder(api_key="fake-api-key") + + # wrong formats + string_input = "text" + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, match="VoyageDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=string_input) + + with pytest.raises(TypeError, match="VoyageDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=list_integers_input) + + @pytest.mark.unit + def test_run_on_empty_list(self): + embedder = VoyageDocumentEmbedder(api_key="fake-api-key") + + empty_list_input = [] + result = embedder.run(documents=empty_list_input) + + assert result["documents"] is not None + assert not result["documents"] # empty list diff --git a/tests/test_voyage_text_embedder.py b/tests/test_voyage_text_embedder.py new file mode 100644 index 0000000..e972556 --- /dev/null +++ b/tests/test_voyage_text_embedder.py @@ -0,0 +1,101 @@ +from typing import List +from unittest.mock import patch + +import numpy as np +import pytest +import voyageai + +from voyage_embedders.voyage_text_embedder import VoyageTextEmbedder + + +def mock_voyageai_response(text: str, model: str = "voyage-01", **kwargs) -> List[float]: # noqa + response = np.random.rand(1024).tolist() + return response + + +class TestVoyageTextEmbedder: + @pytest.mark.unit + def test_init_default(self, monkeypatch): + voyageai.api_key = None + monkeypatch.setenv("VOYAGE_API_KEY", "fake-api-key") + embedder = VoyageTextEmbedder() + + assert voyageai.api_key == "fake-api-key" + assert embedder.model_name == "voyage-01" + assert embedder.prefix == "" + assert embedder.suffix == "" + + @pytest.mark.unit + def test_init_with_parameters(self): + embedder = VoyageTextEmbedder( + api_key="fake-api-key", + model_name="model", + prefix="prefix", + suffix="suffix", + ) + assert voyageai.api_key == "fake-api-key" + assert embedder.model_name == "model" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + + @pytest.mark.unit + def test_init_fail_wo_api_key(self, monkeypatch): + voyageai.api_key = None + monkeypatch.delenv("VOYAGE_API_KEY", raising=False) + with pytest.raises(ValueError, match="VoyageTextEmbedder expects an VoyageAI API key"): + VoyageTextEmbedder() + + @pytest.mark.unit + def test_to_dict(self): + component = VoyageTextEmbedder(api_key="fake-api-key") + data = component.to_dict() + assert data == { + "type": "VoyageTextEmbedder", + "init_parameters": { + "model_name": "voyage-01", + "prefix": "", + "suffix": "", + }, + } + + @pytest.mark.unit + def test_to_dict_with_custom_init_parameters(self): + component = VoyageTextEmbedder( + api_key="fake-api-key", + model_name="model", + prefix="prefix", + suffix="suffix", + ) + data = component.to_dict() + assert data == { + "type": "VoyageTextEmbedder", + "init_parameters": { + "model_name": "model", + "prefix": "prefix", + "suffix": "suffix", + }, + } + + @pytest.mark.unit + def test_run(self): + model = "voyage-01-lite" + + with patch("voyage_embedders.voyage_text_embedder.get_embedding") as voyageai_embedding_patch: + voyageai_embedding_patch.side_effect = mock_voyageai_response + + embedder = VoyageTextEmbedder(api_key="fake-api-key", model_name=model, prefix="prefix ", suffix=" suffix") + result = embedder.run(text="The food was delicious") + + voyageai_embedding_patch.assert_called_once_with(model=model, text="prefix The food was delicious suffix") + + assert len(result["embedding"]) == 1024 + assert all(isinstance(x, float) for x in result["embedding"]) + + @pytest.mark.unit + def test_run_wrong_input_format(self): + embedder = VoyageTextEmbedder(api_key="fake-api-key") + + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, match="VoyageTextEmbedder expects a string as an input"): + embedder.run(text=list_integers_input)