diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py index 4696db5c8..d1f0c0227 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import os from typing import Any, Dict, List, Optional, Tuple import requests -from haystack import Document, component, default_to_dict +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace from tqdm import tqdm JINA_API_URL: str = "https://api.jina.ai/v1/embeddings" @@ -35,7 +35,7 @@ class JinaDocumentEmbedder: def __init__( self, - api_key: Optional[str] = None, + api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008 model: str = "jina-embeddings-v2-base-en", prefix: str = "", suffix: str = "", @@ -46,8 +46,7 @@ def __init__( ): """ Create a JinaDocumentEmbedder component. - :param api_key: The Jina API key. It can be explicitly provided or automatically read from the - environment variable JINA_API_KEY (recommended). + :param api_key: The Jina API key. :param model: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/` :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. @@ -57,16 +56,15 @@ def __init__( :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text. :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ - - api_key = api_key or os.environ.get("JINA_API_KEY") - # we check whether api_key is None or an empty string - if not api_key: + resolved_api_key = api_key.resolve_value() + if resolved_api_key is None: msg = ( "JinaDocumentEmbedder expects an API key. " "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." ) raise ValueError(msg) + self.api_key = api_key self.model_name = model self.prefix = prefix self.suffix = suffix @@ -77,7 +75,7 @@ def __init__( self._session = requests.Session() self._session.headers.update( { - "Authorization": f"Bearer {api_key}", + "Authorization": f"Bearer {resolved_api_key}", "Accept-Encoding": "identity", "Content-type": "application/json", } @@ -96,6 +94,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, + api_key=self.api_key.to_dict(), model=self.model_name, prefix=self.prefix, suffix=self.suffix, @@ -105,6 +104,11 @@ def to_dict(self) -> Dict[str, Any]: embedding_separator=self.embedding_separator, ) + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "JinaDocumentEmbedder": + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: """ Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py index 3f18aa037..83c3b2c43 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import requests -from haystack import component, default_to_dict +from haystack import component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace JINA_API_URL: str = "https://api.jina.ai/v1/embeddings" @@ -33,7 +33,7 @@ class JinaTextEmbedder: def __init__( self, - api_key: Optional[str] = None, + api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008 model: str = "jina-embeddings-v2-base-en", prefix: str = "", suffix: str = "", @@ -48,22 +48,22 @@ def __init__( :param suffix: A string to add to the end of each text. """ - api_key = api_key or os.environ.get("JINA_API_KEY") - # we check whether api_key is None or an empty string - if not api_key: + resolved_api_key = api_key.resolve_value() + if resolved_api_key is None: msg = ( "JinaTextEmbedder expects an API key. " "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." ) raise ValueError(msg) + self.api_key = api_key self.model_name = model self.prefix = prefix self.suffix = suffix self._session = requests.Session() self._session.headers.update( { - "Authorization": f"Bearer {api_key}", + "Authorization": f"Bearer {resolved_api_key}", "Accept-Encoding": "identity", "Content-type": "application/json", } @@ -81,7 +81,14 @@ def to_dict(self) -> Dict[str, Any]: to the constructor. """ - return default_to_dict(self, model=self.model_name, prefix=self.prefix, suffix=self.suffix) + return default_to_dict( + self, api_key=self.api_key.to_dict(), model=self.model_name, prefix=self.prefix, suffix=self.suffix + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "JinaTextEmbedder": + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) @component.output_types(embedding=List[float], meta=Dict[str, Any]) def run(self, text: str): diff --git a/integrations/jina/tests/test_document_embedder.py b/integrations/jina/tests/test_document_embedder.py index 4dd91860e..a9ba23ec0 100644 --- a/integrations/jina/tests/test_document_embedder.py +++ b/integrations/jina/tests/test_document_embedder.py @@ -7,6 +7,7 @@ import pytest import requests from haystack import Document +from haystack.utils import Secret from haystack_integrations.components.embedders.jina import JinaDocumentEmbedder @@ -28,6 +29,7 @@ def test_init_default(self, monkeypatch): monkeypatch.setenv("JINA_API_KEY", "fake-api-key") embedder = JinaDocumentEmbedder() + assert embedder.api_key == Secret.from_env_var("JINA_API_KEY") assert embedder.model_name == "jina-embeddings-v2-base-en" assert embedder.prefix == "" assert embedder.suffix == "" @@ -38,7 +40,7 @@ def test_init_default(self, monkeypatch): def test_init_with_parameters(self): embedder = JinaDocumentEmbedder( - api_key="fake-api-key", + api_key=Secret.from_token("fake-api-key"), model="model", prefix="prefix", suffix="suffix", @@ -47,6 +49,8 @@ def test_init_with_parameters(self): meta_fields_to_embed=["test_field"], embedding_separator=" | ", ) + + assert embedder.api_key == Secret.from_token("fake-api-key") assert embedder.model_name == "model" assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" @@ -60,12 +64,14 @@ def test_init_fail_wo_api_key(self, monkeypatch): with pytest.raises(ValueError): JinaDocumentEmbedder() - def test_to_dict(self): - component = JinaDocumentEmbedder(api_key="fake-api-key") + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "fake-api-key") + component = JinaDocumentEmbedder() data = component.to_dict() assert data == { "type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder", "init_parameters": { + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "jina-embeddings-v2-base-en", "prefix": "", "suffix": "", @@ -76,9 +82,9 @@ def test_to_dict(self): }, } - def test_to_dict_with_custom_init_parameters(self): + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "fake-api-key") component = JinaDocumentEmbedder( - api_key="fake-api-key", model="model", prefix="prefix", suffix="suffix", @@ -91,6 +97,7 @@ def test_to_dict_with_custom_init_parameters(self): assert data == { "type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder", "init_parameters": { + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "model", "prefix": "prefix", "suffix": "suffix", @@ -107,7 +114,7 @@ def test_prepare_texts_to_embed_w_metadata(self): ] embedder = JinaDocumentEmbedder( - api_key="fake-api-key", meta_fields_to_embed=["meta_field"], embedding_separator=" | " + api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | " ) prepared_texts = embedder._prepare_texts_to_embed(documents) @@ -124,7 +131,9 @@ def test_prepare_texts_to_embed_w_metadata(self): def test_prepare_texts_to_embed_w_suffix(self): documents = [Document(content=f"document number {i}") for i in range(5)] - embedder = JinaDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix") + embedder = JinaDocumentEmbedder( + api_key=Secret.from_token("fake-api-key"), prefix="my_prefix ", suffix=" my_suffix" + ) prepared_texts = embedder._prepare_texts_to_embed(documents) @@ -140,7 +149,7 @@ def test_embed_batch(self): texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): - embedder = JinaDocumentEmbedder(api_key="fake-api-key", model="model") + embedder = JinaDocumentEmbedder(api_key=Secret.from_token("fake-api-key"), model="model") embeddings, metadata = embedder._embed_batch(texts_to_embed=texts, batch_size=2) @@ -162,7 +171,7 @@ def test_run(self): model = "jina-embeddings-v2-base-en" with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): embedder = JinaDocumentEmbedder( - api_key="fake-api-key", + api_key=Secret.from_token("fake-api-key"), model=model, prefix="prefix ", suffix=" suffix", @@ -192,7 +201,7 @@ def test_run_custom_batch_size(self): model = "jina-embeddings-v2-base-en" with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): embedder = JinaDocumentEmbedder( - api_key="fake-api-key", + api_key=Secret.from_token("fake-api-key"), model=model, prefix="prefix ", suffix=" suffix", @@ -217,7 +226,7 @@ def test_run_custom_batch_size(self): assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}} def test_run_wrong_input_format(self): - embedder = JinaDocumentEmbedder(api_key="fake-api-key") + embedder = JinaDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) string_input = "text" list_integers_input = [1, 2, 3] @@ -229,7 +238,7 @@ def test_run_wrong_input_format(self): embedder.run(documents=list_integers_input) def test_run_on_empty_list(self): - embedder = JinaDocumentEmbedder(api_key="fake-api-key") + embedder = JinaDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) empty_list_input = [] result = embedder.run(documents=empty_list_input) diff --git a/integrations/jina/tests/test_text_embedder.py b/integrations/jina/tests/test_text_embedder.py index a4f6fd934..7cb669c68 100644 --- a/integrations/jina/tests/test_text_embedder.py +++ b/integrations/jina/tests/test_text_embedder.py @@ -6,6 +6,7 @@ import pytest import requests +from haystack.utils import Secret from haystack_integrations.components.embedders.jina import JinaTextEmbedder @@ -14,17 +15,19 @@ def test_init_default(self, monkeypatch): monkeypatch.setenv("JINA_API_KEY", "fake-api-key") embedder = JinaTextEmbedder() + assert embedder.api_key == Secret.from_env_var("JINA_API_KEY") assert embedder.model_name == "jina-embeddings-v2-base-en" assert embedder.prefix == "" assert embedder.suffix == "" def test_init_with_parameters(self): embedder = JinaTextEmbedder( - api_key="fake-api-key", + api_key=Secret.from_token("fake-api-key"), model="model", prefix="prefix", suffix="suffix", ) + assert embedder.api_key == Secret.from_token("fake-api-key") assert embedder.model_name == "model" assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" @@ -34,21 +37,23 @@ def test_init_fail_wo_api_key(self, monkeypatch): with pytest.raises(ValueError): JinaTextEmbedder() - def test_to_dict(self): - component = JinaTextEmbedder(api_key="fake-api-key") + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "fake-api-key") + component = JinaTextEmbedder() data = component.to_dict() assert data == { "type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder", "init_parameters": { + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "jina-embeddings-v2-base-en", "prefix": "", "suffix": "", }, } - def test_to_dict_with_custom_init_parameters(self): + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "fake-api-key") component = JinaTextEmbedder( - api_key="fake-api-key", model="model", prefix="prefix", suffix="suffix", @@ -57,6 +62,7 @@ def test_to_dict_with_custom_init_parameters(self): assert data == { "type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder", "init_parameters": { + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "model", "prefix": "prefix", "suffix": "suffix", @@ -80,7 +86,9 @@ def test_run(self): mock_post.return_value = mock_response - embedder = JinaTextEmbedder(api_key="fake-api-key", model=model, prefix="prefix ", suffix=" suffix") + embedder = JinaTextEmbedder( + api_key=Secret.from_token("fake-api-key"), model=model, prefix="prefix ", suffix=" suffix" + ) result = embedder.run(text="The food was delicious") assert len(result["embedding"]) == 3 @@ -91,7 +99,7 @@ def test_run(self): } def test_run_wrong_input_format(self): - embedder = JinaTextEmbedder(api_key="fake-api-key") + embedder = JinaTextEmbedder(api_key=Secret.from_token("fake-api-key")) list_integers_input = [1, 2, 3] diff --git a/integrations/pgvector/README.md b/integrations/pgvector/README.md index 277c732f4..a2d325c54 100644 --- a/integrations/pgvector/README.md +++ b/integrations/pgvector/README.md @@ -20,12 +20,23 @@ pip install pgvector-haystack ## Testing -TODO +Ensure that you have a PostgreSQL running with the `pgvector` extension. For a quick setup using Docker, run: +``` +docker run -d -p 5432:5432 -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres ankane/pgvector +``` + +then run the tests: ```console hatch run test ``` +To run the coverage report: + +```console +hatch run cov +``` + ## License `pgvector-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/pgvector/examples/example.py b/integrations/pgvector/examples/example.py index 14c2cba60..764c915d1 100644 --- a/integrations/pgvector/examples/example.py +++ b/integrations/pgvector/examples/example.py @@ -11,6 +11,7 @@ # git clone https://github.com/anakin87/neural-search-pills import glob +import os from haystack import Pipeline from haystack.components.converters import MarkdownToDocument @@ -20,9 +21,10 @@ from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore +os.environ["PG_CONN_STR"] = "postgresql://postgres:postgres@localhost:5432/postgres" + # Initialize PgvectorDocumentStore document_store = PgvectorDocumentStore( - connection_string="postgresql://postgres:postgres@localhost:5432/postgres", table_name="haystack_test", embedding_dimension=768, vector_function="cosine_similarity", diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index 65ded967f..178d9f7e8 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -138,6 +138,8 @@ ignore = [ "S105", "S106", "S107", # Ignore complexity "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + # ignore function-call-in-default-argument + "B008", ] unfixable = [ # Don't touch unused imports @@ -156,23 +158,22 @@ ban-relative-imports = "parents" # examples can contain "print" commands "examples/**/*" = ["T201"] + [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true parallel = true - -[tool.coverage.paths] -weaviate_haystack = ["src/haystack_integrations", "*/pgvector-haystack/src"] -tests = ["tests", "*/pgvector-haystack/tests"] - [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [[tool.mypy.overrides]] module = [ "haystack.*", diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py index 26807e9bd..4b8df868b 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py @@ -68,9 +68,8 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever": - data["init_parameters"]["document_store"] = default_from_dict( - PgvectorDocumentStore, data["init_parameters"]["document_store"] - ) + doc_store_params = data["init_parameters"]["document_store"] + data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 097e86c7e..798c75276 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -4,10 +4,11 @@ import logging from typing import Any, Dict, List, Literal, Optional -from haystack import default_to_dict +from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.filters import convert from psycopg import Error, IntegrityError, connect from psycopg.abc import Query @@ -69,7 +70,7 @@ class PgvectorDocumentStore: def __init__( self, *, - connection_string: str, + connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), table_name: str = "haystack_documents", embedding_dimension: int = 768, vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity", @@ -84,8 +85,8 @@ def __init__( It is meant to be connected to a PostgreSQL database with the pgvector extension installed. A specific table to store Haystack documents will be created if it doesn't exist yet. - :param connection_string: The connection string to use to connect to the PostgreSQL database. - e.g. "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" + :param connection_string: The connection string to use to connect to the PostgreSQL database, defined as an + environment variable, e.g.: PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" :param table_name: The name of the table to use to store Haystack documents. Defaults to "haystack_documents". :param embedding_dimension: The dimension of the embedding. Defaults to 768. :param vector_function: The similarity function to use when searching for similar embeddings. @@ -130,7 +131,7 @@ def __init__( self.hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {} self.hnsw_ef_search = hnsw_ef_search - connection = connect(connection_string) + connection = connect(self.connection_string.resolve_value()) connection.autocommit = True self._connection = connection @@ -151,7 +152,7 @@ def __init__( def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, - connection_string=self.connection_string, + connection_string=self.connection_string.to_dict(), table_name=self.table_name, embedding_dimension=self.embedding_dimension, vector_function=self.vector_function, @@ -162,6 +163,11 @@ def to_dict(self) -> Dict[str, Any]: hnsw_ef_search=self.hnsw_ef_search, ) + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PgvectorDocumentStore": + deserialize_secrets_inplace(data["init_parameters"], ["connection_string"]) + return default_from_dict(cls, data) + def _execute_sql( self, sql_query: Query, params: Optional[tuple] = None, error_msg: str = "", cursor: Optional[Cursor] = None ): @@ -221,7 +227,7 @@ def _handle_hnsw(self): ) self._execute_sql(sql_set_hnsw_ef_search, error_msg="Could not set hnsw.ef_search") - index_esists = bool( + index_exists = bool( self._execute_sql( "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", (self.table_name, HNSW_INDEX_NAME), @@ -229,7 +235,7 @@ def _handle_hnsw(self): ).fetchone() ) - if index_esists and not self.hnsw_recreate_index_if_exists: + if index_exists and not self.hnsw_recreate_index_if_exists: logger.warning( "HNSW index already exists and won't be recreated. " "If you want to recreate it, pass 'hnsw_recreate_index_if_exists=True' to the " @@ -373,7 +379,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D return written_docs - def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict[str, Any]]: + @staticmethod + def _from_haystack_to_pg_documents(documents: List[Document]) -> List[Dict[str, Any]]: """ Internal method to convert a list of Haystack Documents to a list of dictionaries that can be used to insert documents into the PgvectorDocumentStore. @@ -395,7 +402,8 @@ def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict return db_documents - def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> List[Document]: + @staticmethod + def _from_pg_to_haystack_documents(documents: List[Dict[str, Any]]) -> List[Document]: """ Internal method to convert a list of dictionaries from pgvector to a list of Haystack Documents. """ diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py index 743e8de14..068f2ac54 100644 --- a/integrations/pgvector/tests/conftest.py +++ b/integrations/pgvector/tests/conftest.py @@ -1,10 +1,12 @@ +import os + import pytest from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore @pytest.fixture def document_store(request): - connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + os.environ["PG_CONN_STR"] = "postgresql://postgres:postgres@localhost:5432/postgres" table_name = f"haystack_{request.node.name}" embedding_dimension = 768 vector_function = "cosine_similarity" @@ -12,13 +14,13 @@ def document_store(request): search_strategy = "exact_nearest_neighbor" store = PgvectorDocumentStore( - connection_string=connection_string, table_name=table_name, embedding_dimension=embedding_dimension, vector_function=vector_function, recreate_table=recreate_table, search_strategy=search_strategy, ) + yield store store.delete_table() diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index e8d9107d7..1e158f134 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -41,7 +41,6 @@ def test_write_dataframe(self, document_store: PgvectorDocumentStore): def test_init(self): document_store = PgvectorDocumentStore( - connection_string="postgresql://postgres:postgres@localhost:5432/postgres", table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -52,7 +51,6 @@ def test_init(self): hnsw_ef_search=50, ) - assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres" assert document_store.table_name == "my_table" assert document_store.embedding_dimension == 512 assert document_store.vector_function == "l2_distance" @@ -64,7 +62,6 @@ def test_init(self): def test_to_dict(self): document_store = PgvectorDocumentStore( - connection_string="postgresql://postgres:postgres@localhost:5432/postgres", table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -78,7 +75,7 @@ def test_to_dict(self): assert document_store.to_dict() == { "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { - "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, "table_name": "my_table", "embedding_dimension": 512, "vector_function": "l2_distance", diff --git a/integrations/pgvector/tests/test_retriever.py b/integrations/pgvector/tests/test_retriever.py index cca6bbc9f..8eab10de5 100644 --- a/integrations/pgvector/tests/test_retriever.py +++ b/integrations/pgvector/tests/test_retriever.py @@ -4,6 +4,7 @@ from unittest.mock import Mock from haystack.dataclasses import Document +from haystack.utils.auth import EnvVarSecret from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore @@ -37,7 +38,7 @@ def test_to_dict(self, document_store: PgvectorDocumentStore): "document_store": { "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { - "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, "table_name": "haystack_test_to_dict", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -62,7 +63,7 @@ def test_from_dict(self): "document_store": { "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { - "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, "table_name": "haystack_test_to_dict", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -83,7 +84,7 @@ def test_from_dict(self): document_store = retriever.document_store assert isinstance(document_store, PgvectorDocumentStore) - assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres" + assert isinstance(document_store.connection_string, EnvVarSecret) assert document_store.table_name == "haystack_test_to_dict" assert document_store.embedding_dimension == 768 assert document_store.vector_function == "cosine_similarity" diff --git a/integrations/pinecone/examples/example.py b/integrations/pinecone/examples/example.py index b2a534452..a10b951b5 100644 --- a/integrations/pinecone/examples/example.py +++ b/integrations/pinecone/examples/example.py @@ -15,8 +15,9 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.preprocessors import DocumentSplitter from haystack.components.writers import DocumentWriter -from pinecone_haystack import PineconeDocumentStore -from pinecone_haystack.dense_retriever import PineconeEmbeddingRetriever + +from haystack_integrations.components.retrievers.pinecone import PineconeEmbeddingRetriever +from haystack_integrations.document_stores.pinecone import PineconeDocumentStore file_paths = glob.glob("neural-search-pills/pills/*.md") diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index 58a3534c4..1db58ea0d 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "qdrant-client"] +dependencies = ["haystack-ai>=2.0.0b6", "qdrant-client"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index 50dd0220c..4a47bf59e 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -10,6 +10,7 @@ from haystack.dataclasses import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.filters import convert from qdrant_client import grpc from qdrant_client.http import models as rest @@ -55,7 +56,7 @@ def __init__( grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 https: Optional[bool] = None, - api_key: Optional[str] = None, + api_key: Optional[Secret] = None, prefix: Optional[str] = None, timeout: Optional[float] = None, host: Optional[str] = None, @@ -94,7 +95,7 @@ def __init__( grpc_port=grpc_port, prefer_grpc=prefer_grpc, https=https, - api_key=api_key, + api_key=api_key.resolve_value() if api_key else None, prefix=prefix, timeout=timeout, host=host, @@ -115,6 +116,7 @@ def __init__( self.host = host self.path = path self.metadata = metadata + self.api_key = api_key # Store the Qdrant collection specific attributes self.shard_number = shard_number @@ -232,6 +234,7 @@ def delete_documents(self, ids: List[str]): @classmethod def from_dict(cls, data: Dict[str, Any]) -> "QdrantDocumentStore": + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def to_dict(self) -> Dict[str, Any]: @@ -239,6 +242,7 @@ def to_dict(self) -> Dict[str, Any]: # All the __init__ params must be set as attributes # Set as init_parms without default values init_params = {k: getattr(self, k) for k in params} + init_params["api_key"] = self.api_key.to_dict() if self.api_key else None return default_to_dict( self, **init_params, diff --git a/integrations/qdrant/tests/test_dict_converters.py b/integrations/qdrant/tests/test_dict_converters.py index 1c9eb36e2..18940fbbf 100644 --- a/integrations/qdrant/tests/test_dict_converters.py +++ b/integrations/qdrant/tests/test_dict_converters.py @@ -1,3 +1,4 @@ +from haystack.utils import Secret from haystack_integrations.document_stores.qdrant import QdrantDocumentStore @@ -52,6 +53,7 @@ def test_from_dict(): { "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", "init_parameters": { + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "location": ":memory:", "index": "test", "embedding_dim": 768, @@ -98,5 +100,6 @@ def test_from_dict(): document_store.metadata == {}, document_store.write_batch_size == 1000, document_store.scroll_size == 10000, + document_store.api_key == Secret.from_env_var("ENV_VAR", strict=False), ] ) diff --git a/integrations/weaviate/pydoc/config.yml b/integrations/weaviate/pydoc/config.yml index fa59e6874..84334c2e6 100644 --- a/integrations/weaviate/pydoc/config.yml +++ b/integrations/weaviate/pydoc/config.yml @@ -1,9 +1,12 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../src] - modules: [ - "haystack_integrations.document_stores.weaviate.document_store", - ] + modules: + [ + "haystack_integrations.document_stores.weaviate.document_store", + "haystack_integrations.components.retrievers.weaviate.bm25_retriever", + "haystack_integrations.components.retrievers.weaviate.embedding_retriever", + ] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index fb132516c..00aa500e6 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", @@ -28,6 +26,7 @@ dependencies = [ "haystack-ai", "weaviate-client==3.*", "haystack-pydoc-tools", + "python-dateutil", ] [project.urls] @@ -47,51 +46,25 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/weaviate-v[0-9]*"' [tool.hatch.envs.default] -dependencies = [ - "coverage[toml]>=6.5", - "pytest", - "ipython", -] +dependencies = ["coverage[toml]>=6.5", "pytest", "ipython"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.black] target-version = ["py37"] @@ -134,9 +107,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -164,11 +143,7 @@ weaviate_haystack = ["src/haystack_integrations", "*/weaviate-haystack/src"] tests = ["tests", "*/weaviate-haystack/tests"] [tool.coverage.report] -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ @@ -177,6 +152,6 @@ module = [ "pytest.*", "weaviate.*", "numpy", - "grpc" + "grpc", ] ignore_missing_imports = true diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py new file mode 100644 index 000000000..34bfd0c7d --- /dev/null +++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py @@ -0,0 +1,4 @@ +from .bm25_retriever import WeaviateBM25Retriever +from .embedding_retriever import WeaviateEmbeddingRetriever + +__all__ = ["WeaviateBM25Retriever", "WeaviateEmbeddingRetriever"] diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py new file mode 100644 index 000000000..6c27378cf --- /dev/null +++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py @@ -0,0 +1,50 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore + + +@component +class WeaviateBM25Retriever: + """ + Retriever that uses BM25 to find the most promising documents for a given query. + """ + + def __init__( + self, + *, + document_store: WeaviateDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ): + """ + Create a new instance of WeaviateBM25Retriever. + + :param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever. + :param filters: Custom filters applied when running the retriever, defaults to None + :param top_k: Maximum number of documents to return, defaults to 10 + """ + self._document_store = document_store + self._filters = filters or {} + self._top_k = top_k + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever": + data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + filters = filters or self._filters + top_k = top_k or self._top_k + return self._document_store._bm25_retrieval(query=query, filters=filters, top_k=top_k) diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py new file mode 100644 index 000000000..b8a163b56 --- /dev/null +++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py @@ -0,0 +1,80 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore + + +@component +class WeaviateEmbeddingRetriever: + """ + A retriever that uses Weaviate's vector search to find similar documents based on the embeddings of the query. + """ + + def __init__( + self, + *, + document_store: WeaviateDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + distance: Optional[float] = None, + certainty: Optional[float] = None, + ): + """ + Create a new instance of WeaviateEmbeddingRetriever. + Raises ValueError if both `distance` and `certainty` are provided. + See the official Weaviate documentation to learn more about the `distance` and `certainty` parameters: + https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables + + :param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever. + :param filters: Custom filters applied when running the retriever, defaults to None + :param top_k: Maximum number of documents to return, defaults to 10 + :param distance: The maximum allowed distance between Documents' embeddings, defaults to None + :param certainty: Normalized distance between the result item and the search vector, defaults to None + """ + if distance is not None and certainty is not None: + msg = "Can't use 'distance' and 'certainty' parameters together" + raise ValueError(msg) + + self._document_store = document_store + self._filters = filters or {} + self._top_k = top_k + self._distance = distance + self._certainty = certainty + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + distance=self._distance, + certainty=self._certainty, + document_store=self._document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever": + data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + distance: Optional[float] = None, + certainty: Optional[float] = None, + ): + filters = filters or self._filters + top_k = top_k or self._top_k + distance = distance or self._distance + certainty = certainty or self._certainty + return self._document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + distance=distance, + certainty=certainty, + ) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py new file mode 100644 index 000000000..a192c6947 --- /dev/null +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py @@ -0,0 +1,279 @@ +from typing import Any, Dict + +from dateutil import parser +from haystack.errors import FilterError +from pandas import DataFrame + + +def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert filters from Haystack format to Weaviate format. + """ + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise FilterError(msg) + + if "field" in filters: + return {"operator": "And", "operands": [_parse_comparison_condition(filters)]} + return _parse_logical_condition(filters) + + +OPERATOR_INVERSE = { + "==": "!=", + "!=": "==", + ">": "<=", + ">=": "<", + "<": ">=", + "<=": ">", + "in": "not in", + "not in": "in", + "AND": "OR", + "OR": "AND", + "NOT": "AND", +} + + +def _invert_condition(filters: Dict[str, Any]) -> Dict[str, Any]: + """ + Invert condition recursively. + Weaviate doesn't support NOT filters so we need to invert them ourselves. + """ + inverted_condition = filters.copy() + if "operator" not in filters: + # Let's not handle this stuff in here, we'll fail later on anyway. + return inverted_condition + inverted_condition["operator"] = OPERATOR_INVERSE[filters["operator"]] + if "conditions" in filters: + inverted_condition["conditions"] = [] + for condition in filters["conditions"]: + inverted_condition["conditions"].append(_invert_condition(condition)) + + return inverted_condition + + +def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + operator = condition["operator"] + if operator in ["AND", "OR"]: + operands = [] + for c in condition["conditions"]: + if "field" not in c: + operands.append(_parse_logical_condition(c)) + else: + operands.append(_parse_comparison_condition(c)) + return {"operator": operator.lower().capitalize(), "operands": operands} + elif operator == "NOT": + inverted_conditions = _invert_condition(condition) + return _parse_logical_condition(inverted_conditions) + else: + msg = f"Unknown logical operator '{operator}'" + raise FilterError(msg) + + +def _infer_value_type(value: Any) -> str: + if value is None: + return "valueNull" + + if isinstance(value, bool): + return "valueBoolean" + if isinstance(value, int): + return "valueInt" + if isinstance(value, float): + return "valueNumber" + + if isinstance(value, str): + try: + parser.isoparse(value) + return "valueDate" + except ValueError: + return "valueText" + + msg = f"Unknown value type {type(value)}" + raise FilterError(msg) + + +def _handle_date(value: Any) -> str: + if isinstance(value, str): + try: + return parser.isoparse(value).strftime("%Y-%m-%dT%H:%M:%S.%fZ") + except ValueError: + pass + return value + + +def _equal(field: str, value: Any) -> Dict[str, Any]: + if value is None: + return {"path": field, "operator": "IsNull", "valueBoolean": True} + return {"path": field, "operator": "Equal", _infer_value_type(value): _handle_date(value)} + + +def _not_equal(field: str, value: Any) -> Dict[str, Any]: + if value is None: + return {"path": field, "operator": "IsNull", "valueBoolean": False} + return { + "operator": "Or", + "operands": [ + {"path": field, "operator": "NotEqual", _infer_value_type(value): _handle_date(value)}, + {"path": field, "operator": "IsNull", "valueBoolean": True}, + ], + } + + +def _greater_than(field: str, value: Any) -> Dict[str, Any]: + if value is None: + # When the value is None and '>' is used we create a filter that would return a Document + # if it has a field set and not set at the same time. + # This will cause the filter to match no Document. + # This way we keep the behavior consistent with other Document Stores. + return _match_no_document(field) + if isinstance(value, str): + try: + parser.isoparse(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, DataFrame]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + return {"path": field, "operator": "GreaterThan", _infer_value_type(value): _handle_date(value)} + + +def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: + if value is None: + # When the value is None and '>=' is used we create a filter that would return a Document + # if it has a field set and not set at the same time. + # This will cause the filter to match no Document. + # This way we keep the behavior consistent with other Document Stores. + return _match_no_document(field) + if isinstance(value, str): + try: + parser.isoparse(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, DataFrame]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + return {"path": field, "operator": "GreaterThanEqual", _infer_value_type(value): _handle_date(value)} + + +def _less_than(field: str, value: Any) -> Dict[str, Any]: + if value is None: + # When the value is None and '<' is used we create a filter that would return a Document + # if it has a field set and not set at the same time. + # This will cause the filter to match no Document. + # This way we keep the behavior consistent with other Document Stores. + return _match_no_document(field) + if isinstance(value, str): + try: + parser.isoparse(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, DataFrame]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + return {"path": field, "operator": "LessThan", _infer_value_type(value): _handle_date(value)} + + +def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: + if value is None: + # When the value is None and '<=' is used we create a filter that would return a Document + # if it has a field set and not set at the same time. + # This will cause the filter to match no Document. + # This way we keep the behavior consistent with other Document Stores. + return _match_no_document(field) + if isinstance(value, str): + try: + parser.isoparse(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, DataFrame]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + return {"path": field, "operator": "LessThanEqual", _infer_value_type(value): _handle_date(value)} + + +def _in(field: str, value: Any) -> Dict[str, Any]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators" + raise FilterError(msg) + + return {"operator": "And", "operands": [_equal(field, v) for v in value]} + + +def _not_in(field: str, value: Any) -> Dict[str, Any]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators" + raise FilterError(msg) + return {"operator": "And", "operands": [_not_equal(field, v) for v in value]} + + +COMPARISON_OPERATORS = { + "==": _equal, + "!=": _not_equal, + ">": _greater_than, + ">=": _greater_than_equal, + "<": _less_than, + "<=": _less_than_equal, + "in": _in, + "not in": _not_in, +} + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + field: str = condition["field"] + + if field.startswith("meta."): + # Documents are flattened otherwise we wouldn't be able to properly query them. + # We're forced to flatten because Weaviate doesn't support querying of nested properties + # as of now. If we don't flatten the documents we can't filter them. + # As time of writing this they have it in their backlog, see: + # https://github.com/weaviate/weaviate/issues/3694 + field = field.replace("meta.", "") + + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + value: Any = condition["value"] + if isinstance(value, DataFrame): + value = value.to_json() + + return COMPARISON_OPERATORS[operator](field, value) + + +def _match_no_document(field: str) -> Dict[str, Any]: + """ + Returns a filters that will match no Document, this is used to keep the behavior consistent + between different Document Stores. + """ + return { + "operator": "And", + "operands": [ + {"path": field, "operator": "IsNull", "valueBoolean": False}, + {"path": field, "operator": "IsNull", "valueBoolean": True}, + ], + } diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 3d658c316..38f0b38cd 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -16,6 +16,8 @@ from weaviate.embedded import EmbeddedOptions from weaviate.util import generate_uuid5 +from ._filters import convert_filters + Number = Union[int, float] TimeoutType = Union[Tuple[Number, Number], Number] @@ -33,6 +35,12 @@ # These are extremely similar to the Document dataclass, but with a few differences: # - `id` is renamed to `_original_id` as the `id` field is reserved by Weaviate. # - `blob` is split into `blob_data` and `blob_mime_type` as it's more efficient to store them separately. +# Blob meta is missing as it's not usually serialized when saving a Document as we rely on the Document own meta. +# +# Also the Document `meta` fields are omitted as we can't make assumptions on the structure of the meta field. +# We recommend the user to create a proper collection with the correct meta properties for their use case. +# We mostly rely on these defaults for testing purposes using Weaviate automatic schema generation, but that's not +# recommended for production use. DOCUMENT_COLLECTION_PROPERTIES = [ {"name": "_original_id", "dataType": ["text"]}, {"name": "content", "dataType": ["text"]}, @@ -74,8 +82,14 @@ def __init__( - blob_data: blob - blob_mime_type: text - score: number + The Document `meta` fields are omitted in the default collection settings as we can't make assumptions + on the structure of the meta field. + We heavily recommend to create a custom collection with the correct meta properties + for your use case. + Another option is relying on the automatic schema generation, but that's not recommended for + production use. See the official `Weaviate documentation`_ - for more information on collections. + for more information on collections and their properties. :param auth_client_secret: Authentication credentials, defaults to None. Can be one of the following types depending on the authentication mode: - `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens @@ -126,6 +140,7 @@ def __init__( if collection_settings is None: collection_settings = { "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, "properties": DOCUMENT_COLLECTION_PROPERTIES, } else: @@ -199,7 +214,7 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]: """ Convert a Document to a Weviate data object ready to be saved. """ - data = document.to_dict(flatten=False) + data = document.to_dict() # Weaviate forces a UUID as an id. # We don't know if the id of our Document is a UUID or not, so we save it on a different field # and let Weaviate a UUID that we're going to ignore completely. @@ -213,10 +228,6 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]: # The embedding vector is stored separately from the rest of the data del data["embedding"] - # Weaviate doesn't like empty objects, let's delete meta if it's empty - if data["meta"] == {}: - del data["meta"] - return data def _to_document(self, data: Dict[str, Any]) -> Document: @@ -241,7 +252,7 @@ def _to_document(self, data: Dict[str, Any]) -> Document: return Document.from_dict(data) - def _query(self, properties: List[str], batch_size: int, cursor=None): + def _query_paginated(self, properties: List[str], cursor=None): collection_name = self._collection_settings["class"] query = ( self._client.query.get( @@ -249,7 +260,7 @@ def _query(self, properties: List[str], batch_size: int, cursor=None): properties, ) .with_additional(["id vector"]) - .with_limit(batch_size) + .with_limit(100) ) if cursor: @@ -267,14 +278,39 @@ def _query(self, properties: List[str], batch_size: int, cursor=None): return result["data"]["Get"][collection_name] - def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: # noqa: ARG002 + def _query_with_filters(self, properties: List[str], filters: Dict[str, Any]) -> List[Dict[str, Any]]: + collection_name = self._collection_settings["class"] + query = ( + self._client.query.get( + collection_name, + properties, + ) + .with_additional(["id vector"]) + .with_where(convert_filters(filters)) + ) + + result = query.do() + + if "errors" in result: + errors = [e["message"] for e in result.get("errors", {})] + msg = "\n".join(errors) + msg = f"Failed to query documents in Weaviate. Errors:\n{msg}" + raise DocumentStoreError(msg) + + return result["data"]["Get"][collection_name] + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) properties = [prop["name"] for prop in properties] + if filters: + result = self._query_with_filters(properties, filters) + return [self._to_document(doc) for doc in result] + result = [] cursor = None - while batch := self._query(properties, 100, cursor): + while batch := self._query_paginated(properties, cursor): # Take the cursor before we convert the batch to Documents as we manipulate # the batch dictionary and might lose that information. cursor = batch[-1]["_additional"]["id"] @@ -383,3 +419,66 @@ def delete_documents(self, document_ids: List[str]) -> None: "valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids], }, ) + + def _bm25_retrieval( + self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None + ) -> List[Document]: + collection_name = self._collection_settings["class"] + properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) + properties = [prop["name"] for prop in properties] + + query_builder = ( + self._client.query.get(collection_name, properties=properties) + .with_bm25(query=query, properties=["content"]) + .with_additional(["vector"]) + ) + + if filters: + query_builder = query_builder.with_where(convert_filters(filters)) + + if top_k: + query_builder = query_builder.with_limit(top_k) + + result = query_builder.do() + + return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]] + + def _embedding_retrieval( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + distance: Optional[float] = None, + certainty: Optional[float] = None, + ) -> List[Document]: + if distance is not None and certainty is not None: + msg = "Can't use 'distance' and 'certainty' parameters together" + raise ValueError(msg) + + collection_name = self._collection_settings["class"] + properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) + properties = [prop["name"] for prop in properties] + + near_vector: Dict[str, Union[float, List[float]]] = { + "vector": query_embedding, + } + if distance is not None: + near_vector["distance"] = distance + + if certainty is not None: + near_vector["certainty"] = certainty + + query_builder = ( + self._client.query.get(collection_name, properties=properties) + .with_near_vector(near_vector) + .with_additional(["vector"]) + ) + + if filters: + query_builder = query_builder.with_where(convert_filters(filters)) + + if top_k: + query_builder = query_builder.with_limit(top_k) + + result = query_builder.do() + return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]] diff --git a/integrations/weaviate/tests/test_bm25_retriever.py b/integrations/weaviate/tests/test_bm25_retriever.py new file mode 100644 index 000000000..83f90735b --- /dev/null +++ b/integrations/weaviate/tests/test_bm25_retriever.py @@ -0,0 +1,102 @@ +from unittest.mock import Mock, patch + +from haystack_integrations.components.retrievers.weaviate import WeaviateBM25Retriever +from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore + + +def test_init_default(): + mock_document_store = Mock(spec=WeaviateDocumentStore) + retriever = WeaviateBM25Retriever(document_store=mock_document_store) + assert retriever._document_store == mock_document_store + assert retriever._filters == {} + assert retriever._top_k == 10 + + +@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") +def test_to_dict(_mock_weaviate): + document_store = WeaviateDocumentStore() + retriever = WeaviateBM25Retriever(document_store=document_store) + assert retriever.to_dict() == { + "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", + "init_parameters": { + "url": None, + "collection_settings": { + "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + }, + "auth_client_secret": None, + "timeout_config": (10, 60), + "proxies": None, + "trust_env": False, + "additional_headers": None, + "startup_period": 5, + "embedded_options": None, + "additional_config": None, + }, + }, + }, + } + + +@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") +def test_from_dict(_mock_weaviate): + retriever = WeaviateBM25Retriever.from_dict( + { + "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", + "init_parameters": { + "url": None, + "collection_settings": { + "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + }, + "auth_client_secret": None, + "timeout_config": (10, 60), + "proxies": None, + "trust_env": False, + "additional_headers": None, + "startup_period": 5, + "embedded_options": None, + "additional_config": None, + }, + }, + }, + } + ) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._top_k == 10 + + +@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore") +def test_run(mock_document_store): + retriever = WeaviateBM25Retriever(document_store=mock_document_store) + query = "some query" + filters = {"field": "content", "operator": "==", "value": "Some text"} + retriever.run(query=query, filters=filters, top_k=5) + mock_document_store._bm25_retrieval.assert_called_once_with(query=query, filters=filters, top_k=5) diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 0682282f3..359af3670 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -1,14 +1,28 @@ import base64 +import random +from typing import List from unittest.mock import MagicMock, patch import pytest +from dateutil import parser from haystack.dataclasses.byte_stream import ByteStream from haystack.dataclasses.document import Document -from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest +from haystack.testing.document_store import ( + TEST_EMBEDDING_1, + TEST_EMBEDDING_2, + CountDocumentsTest, + DeleteDocumentsTest, + FilterDocumentsTest, + WriteDocumentsTest, +) from haystack_integrations.document_stores.weaviate.document_store import ( DOCUMENT_COLLECTION_PROPERTIES, WeaviateDocumentStore, ) +from numpy import array as np_array +from numpy import array_equal as np_array_equal +from numpy import float32 as np_float32 +from pandas import DataFrame from weaviate.auth import AuthApiKey from weaviate.config import Config from weaviate.embedded import ( @@ -20,11 +34,19 @@ ) -class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): +class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest): @pytest.fixture def document_store(self, request) -> WeaviateDocumentStore: # Use a different index for each test so we can run them in parallel - collection_settings = {"class": f"{request.node.name}"} + collection_settings = { + "class": f"{request.node.name}", + "invertedIndexConfig": {"indexNullState": True}, + "properties": [ + *DOCUMENT_COLLECTION_PROPERTIES, + {"name": "number", "dataType": ["int"]}, + {"name": "date", "dataType": ["date"]}, + ], + } store = WeaviateDocumentStore( url="http://localhost:8080", collection_settings=collection_settings, @@ -32,6 +54,96 @@ def document_store(self, request) -> WeaviateDocumentStore: yield store store._client.schema.delete_class(collection_settings["class"]) + @pytest.fixture + def filterable_docs(self) -> List[Document]: + """ + This fixture has been copied from haystack/testing/document_store.py and modified to + use a different date format. + Weaviate forces RFC 3339 date strings. + The original fixture uses ISO 8601 date strings. + """ + documents = [] + for i in range(3): + documents.append( + Document( + content=f"A Foo Document {i}", + meta={ + "name": f"name_{i}", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40Z", + }, + embedding=[random.random() for _ in range(768)], # noqa: S311 + ) + ) + documents.append( + Document( + content=f"A Bar Document {i}", + meta={ + "name": f"name_{i}", + "page": "123", + "chapter": "abstract", + "number": -2, + "date": "1972-12-11T19:54:58Z", + }, + embedding=[random.random() for _ in range(768)], # noqa: S311 + ) + ) + documents.append( + Document( + content=f"A Foobar Document {i}", + meta={ + "name": f"name_{i}", + "page": "90", + "chapter": "conclusion", + "number": -10, + "date": "1989-11-09T17:53:00Z", + }, + embedding=[random.random() for _ in range(768)], # noqa: S311 + ) + ) + documents.append( + Document( + content=f"Document {i} without embedding", + meta={"name": f"name_{i}", "no_embedding": True, "chapter": "conclusion"}, + ) + ) + documents.append(Document(dataframe=DataFrame([i]), meta={"name": f"table_doc_{i}"})) + documents.append( + Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) + ) + documents.append( + Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) + ) + return documents + + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + assert len(received) == len(expected) + received = sorted(received, key=lambda doc: doc.id) + expected = sorted(expected, key=lambda doc: doc.id) + for received_doc, expected_doc in zip(received, expected): + received_doc_dict = received_doc.to_dict(flatten=False) + expected_doc_dict = expected_doc.to_dict(flatten=False) + + # Weaviate stores embeddings with lower precision floats so we handle that here. + assert np_array_equal( + np_array(received_doc_dict.pop("embedding", None), dtype=np_float32), + np_array(expected_doc_dict.pop("embedding", None), dtype=np_float32), + equal_nan=True, + ) + + received_meta = received_doc_dict.pop("meta", None) + expected_meta = expected_doc_dict.pop("meta", None) + + assert received_doc_dict == expected_doc_dict + + # If a meta field is not set in a saved document, it will be None when retrieved + # from Weaviate so we need to handle that. + meta_keys = set(received_meta.keys()).union(set(expected_meta.keys())) + for key in meta_keys: + assert received_meta.get(key) == expected_meta.get(key) + @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client") def test_init(self, mock_weaviate_client_class): mock_client = MagicMock() @@ -99,6 +211,7 @@ def test_to_dict(self, _mock_weaviate): "url": "http://localhost:8080", "collection_settings": { "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, "properties": [ {"name": "_original_id", "dataType": ["text"]}, {"name": "content", "dataType": ["text"]}, @@ -176,6 +289,7 @@ def test_from_dict(self, _mock_weaviate): assert document_store._url == "http://localhost:8080" assert document_store._collection_settings == { "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, "properties": [ {"name": "_original_id", "dataType": ["text"]}, {"name": "content", "dataType": ["text"]}, @@ -202,10 +316,6 @@ def test_from_dict(self, _mock_weaviate): assert document_store._additional_config.connection_config.session_pool_connections == 20 assert document_store._additional_config.connection_config.session_pool_maxsize == 20 - def test_count_not_empty(self, document_store): - # Skipped for the time being as we don't support writing documents - pass - def test_to_data_object(self, document_store, test_files_path): doc = Document(content="test doc") data = document_store._to_data_object(doc) @@ -231,7 +341,7 @@ def test_to_data_object(self, document_store, test_files_path): "blob_mime_type": "image/jpeg", "dataframe": None, "score": None, - "meta": {"key": "value"}, + "key": "value", } def test_to_document(self, document_store, test_files_path): @@ -283,3 +393,229 @@ def test_filter_documents_with_blob_data(self, document_store, test_files_path): assert len(docs) == 1 assert docs[0].blob == image + + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): + """ + This test has been copied from haystack/testing/document_store.py and modified to + use a different date format. + Same reason as the filterable_docs fixture. + Weaviate forces RFC 3339 date strings and the filterable_docs use ISO 8601 date strings. + """ + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">", "value": "1972-12-11T19:54:58"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and parser.isoparse(d.meta["date"]) > parser.isoparse("1972-12-11T19:54:58Z") + ], + ) + + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): + """ + This test has been copied from haystack/testing/document_store.py and modified to + use a different date format. + Same reason as the filterable_docs fixture. + Weaviate forces RFC 3339 date strings and the filterable_docs use ISO 8601 date strings. + """ + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">=", "value": "1969-07-21T20:17:40"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and parser.isoparse(d.meta["date"]) >= parser.isoparse("1969-07-21T20:17:40Z") + ], + ) + + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): + """ + This test has been copied from haystack/testing/document_store.py and modified to + use a different date format. + Same reason as the filterable_docs fixture. + Weaviate forces RFC 3339 date strings and the filterable_docs use ISO 8601 date strings. + """ + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<", "value": "1969-07-21T20:17:40"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and parser.isoparse(d.meta["date"]) < parser.isoparse("1969-07-21T20:17:40Z") + ], + ) + + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): + """ + This test has been copied from haystack/testing/document_store.py and modified to + use a different date format. + Same reason as the filterable_docs fixture. + Weaviate forces RFC 3339 date strings and the filterable_docs use ISO 8601 date strings. + """ + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<=", "value": "1969-07-21T20:17:40"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and parser.isoparse(d.meta["date"]) <= parser.isoparse("1969-07-21T20:17:40Z") + ], + ) + + @pytest.mark.skip(reason="Weaviate for some reason is not returning what we expect") + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): + return super().test_comparison_not_equal_with_dataframe(document_store, filterable_docs) + + def test_bm25_retrieval(self, document_store): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + result = document_store._bm25_retrieval("functional Haskell") + assert len(result) == 5 + assert "functional" in result[0].content + assert "functional" in result[1].content + assert "functional" in result[2].content + assert "functional" in result[3].content + assert "functional" in result[4].content + + def test_bm25_retrieval_with_filters(self, document_store): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + filters = {"field": "content", "operator": "==", "value": "Haskell"} + result = document_store._bm25_retrieval("functional Haskell", filters=filters) + assert len(result) == 1 + assert "Haskell is a functional programming language" == result[0].content + + def test_bm25_retrieval_with_topk(self, document_store): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + result = document_store._bm25_retrieval("functional Haskell", top_k=3) + assert len(result) == 3 + assert "functional" in result[0].content + assert "functional" in result[1].content + assert "functional" in result[2].content + + def test_embedding_retrieval(self, document_store): + document_store.write_documents( + [ + Document( + content="Yet another document", + embedding=[0.00001, 0.00001, 0.00001, 0.00002], + ), + Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]), + ] + ) + result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0]) + assert len(result) == 3 + assert "The document" == result[0].content + assert "Another document" == result[1].content + assert "Yet another document" == result[2].content + + def test_embedding_retrieval_with_filters(self, document_store): + document_store.write_documents( + [ + Document( + content="Yet another document", + embedding=[0.00001, 0.00001, 0.00001, 0.00002], + ), + Document(content="The document I want", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]), + ] + ) + filters = {"field": "content", "operator": "==", "value": "The document I want"} + result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], filters=filters) + assert len(result) == 1 + assert "The document I want" == result[0].content + + def test_embedding_retrieval_with_topk(self, document_store): + docs = [ + Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]), + ] + document_store.write_documents(docs) + results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], top_k=2) + assert len(results) == 2 + assert results[0].content == "The document" + assert results[1].content == "Another document" + + def test_embedding_retrieval_with_distance(self, document_store): + docs = [ + Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]), + ] + document_store.write_documents(docs) + results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], distance=0.0) + assert len(results) == 1 + assert results[0].content == "The document" + + def test_embedding_retrieval_with_certainty(self, document_store): + docs = [ + Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]), + ] + document_store.write_documents(docs) + results = document_store._embedding_retrieval(query_embedding=[0.8, 0.8, 0.8, 1.0], certainty=1.0) + assert len(results) == 1 + assert results[0].content == "Another document" + + def test_embedding_retrieval_with_distance_and_certainty(self, document_store): + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1) diff --git a/integrations/weaviate/tests/test_embedding_retriever.py b/integrations/weaviate/tests/test_embedding_retriever.py new file mode 100644 index 000000000..7f07d8a24 --- /dev/null +++ b/integrations/weaviate/tests/test_embedding_retriever.py @@ -0,0 +1,119 @@ +from unittest.mock import Mock, patch + +import pytest +from haystack_integrations.components.retrievers.weaviate import WeaviateEmbeddingRetriever +from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore + + +def test_init_default(): + mock_document_store = Mock(spec=WeaviateDocumentStore) + retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store) + assert retriever._document_store == mock_document_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._distance is None + assert retriever._certainty is None + + +def test_init_with_distance_and_certainty(): + mock_document_store = Mock(spec=WeaviateDocumentStore) + with pytest.raises(ValueError): + WeaviateEmbeddingRetriever(document_store=mock_document_store, distance=0.1, certainty=0.8) + + +@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") +def test_to_dict(_mock_weaviate): + document_store = WeaviateDocumentStore() + retriever = WeaviateEmbeddingRetriever(document_store=document_store) + assert retriever.to_dict() == { + "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "distance": None, + "certainty": None, + "document_store": { + "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", + "init_parameters": { + "url": None, + "collection_settings": { + "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + }, + "auth_client_secret": None, + "timeout_config": (10, 60), + "proxies": None, + "trust_env": False, + "additional_headers": None, + "startup_period": 5, + "embedded_options": None, + "additional_config": None, + }, + }, + }, + } + + +@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") +def test_from_dict(_mock_weaviate): + retriever = WeaviateEmbeddingRetriever.from_dict( + { + "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "distance": None, + "certainty": None, + "document_store": { + "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", + "init_parameters": { + "url": None, + "collection_settings": { + "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + }, + "auth_client_secret": None, + "timeout_config": (10, 60), + "proxies": None, + "trust_env": False, + "additional_headers": None, + "startup_period": 5, + "embedded_options": None, + "additional_config": None, + }, + }, + }, + } + ) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._distance is None + assert retriever._certainty is None + + +@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore") +def test_run(mock_document_store): + retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store) + query_embedding = [0.1, 0.1, 0.1, 0.1] + filters = {"field": "content", "operator": "==", "value": "Some text"} + retriever.run(query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1) + mock_document_store._embedding_retrieval.assert_called_once_with( + query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1 + ) diff --git a/integrations/weaviate/tests/test_filters.py b/integrations/weaviate/tests/test_filters.py new file mode 100644 index 000000000..cf38d84be --- /dev/null +++ b/integrations/weaviate/tests/test_filters.py @@ -0,0 +1,34 @@ +from haystack_integrations.document_stores.weaviate._filters import _invert_condition + + +def test_invert_conditions(): + filters = { + "operator": "NOT", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.name", "operator": "==", "value": "name_0"}, + { + "operator": "OR", + "conditions": [ + {"field": "meta.name", "operator": "==", "value": "name_1"}, + {"field": "meta.name", "operator": "==", "value": "name_2"}, + ], + }, + ], + } + + inverted = _invert_condition(filters) + assert inverted == { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "!=", "value": 100}, + {"field": "meta.name", "operator": "!=", "value": "name_0"}, + { + "conditions": [ + {"field": "meta.name", "operator": "!=", "value": "name_1"}, + {"field": "meta.name", "operator": "!=", "value": "name_2"}, + ], + "operator": "AND", + }, + ], + }