diff --git a/integrations/weaviate/docker-compose.yml b/integrations/weaviate/docker-compose.yml index 24d1daaa3..c61b0ed57 100644 --- a/integrations/weaviate/docker-compose.yml +++ b/integrations/weaviate/docker-compose.yml @@ -19,4 +19,4 @@ services: PERSISTENCE_DATA_PATH: '/var/lib/weaviate' DEFAULT_VECTORIZER_MODULE: 'none' ENABLE_MODULES: '' - CLUSTER_HOSTNAME: 'node1' \ No newline at end of file + CLUSTER_HOSTNAME: 'node1' diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 4c15d707e..492b1826d 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -6,7 +6,7 @@ from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses.document import Document -from haystack.document_stores.protocol import DuplicatePolicy +from haystack.document_stores.types.policy import DuplicatePolicy import weaviate from weaviate.auth import AuthCredentials @@ -25,6 +25,20 @@ "weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey, } +# This is the default collection properties for Weaviate. +# It's a list of properties that will be created on the collection. +# These are extremely similar to the Document dataclass, but with a few differences: +# - `id` is renamed to `_original_id` as the `id` field is reserved by Weaviate. +# - `blob` is split into `blob_data` and `blob_mime_type` as it's more efficient to store them separately. +DOCUMENT_COLLECTION_PROPERTIES = [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, +] + class WeaviateDocumentStore: """ @@ -35,7 +49,7 @@ def __init__( self, *, url: Optional[str] = None, - collection_name: str = "default", + collection_settings: Optional[Dict[str, Any]] = None, auth_client_secret: Optional[AuthCredentials] = None, timeout_config: TimeoutType = (10, 60), proxies: Optional[Union[Dict, str]] = None, @@ -49,6 +63,16 @@ def __init__( Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance. :param url: The URL to the weaviate instance, defaults to None. + :param collection_settings: The collection settings to use, defaults to None. + If None it will use a collection named `default` with the following properties: + - _original_id: text + - content: text + - dataframe: text + - blob_data: blob + - blob_mime_type: text + - score: number + See the official `Weaviate documentation`_ + for more information on collections. :param auth_client_secret: Authentication credentials, defaults to None. Can be one of the following types depending on the authentication mode: - `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens @@ -80,8 +104,6 @@ def __init__( :param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None. For a full list of options see `weaviate.embedded.EmbeddedOptions`. :param additional_config: Additional and advanced configuration options for weaviate, defaults to None. - :param collection_name: The name of the collection to use, defaults to "default". - If the collection does not exist it will be created. """ self._client = weaviate.Client( url=url, @@ -98,11 +120,22 @@ def __init__( # Test connection, it will raise an exception if it fails. self._client.schema.get() - if not self._client.schema.exists(collection_name): - self._client.schema.create_class({"class": collection_name}) + if collection_settings is None: + collection_settings = { + "class": "Default", + "properties": DOCUMENT_COLLECTION_PROPERTIES, + } + else: + # Set the class if not set + collection_settings["class"] = collection_settings.get("class", "default").capitalize() + # Set the properties if they're not set + collection_settings["properties"] = collection_settings.get("properties", DOCUMENT_COLLECTION_PROPERTIES) + + if not self._client.schema.exists(collection_settings["class"]): + self._client.schema.create_class(collection_settings) self._url = url - self._collection_name = collection_name + self._collection_settings = collection_settings self._auth_client_secret = auth_client_secret self._timeout_config = timeout_config self._proxies = proxies @@ -124,7 +157,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, url=self._url, - collection_name=self._collection_name, + collection_settings=self._collection_settings, auth_client_secret=auth_client_secret, timeout_config=self._timeout_config, proxies=self._proxies, @@ -161,7 +194,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc return [] def write_documents( - self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE # noqa: ARG002 + self, + documents: List[Document], + policy: DuplicatePolicy = DuplicatePolicy.NONE, # noqa: ARG002 ) -> int: return 0 diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 0666151ee..d5b1a2380 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -1,6 +1,10 @@ from unittest.mock import MagicMock, patch -from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore +import pytest +from haystack_integrations.document_stores.weaviate.document_store import ( + DOCUMENT_COLLECTION_PROPERTIES, + WeaviateDocumentStore, +) from weaviate.auth import AuthApiKey from weaviate.config import Config from weaviate.embedded import ( @@ -13,6 +17,17 @@ class TestWeaviateDocumentStore: + @pytest.fixture + def document_store(self, request) -> WeaviateDocumentStore: + # Use a different index for each test so we can run them in parallel + collection_settings = {"class": f"{request.node.name}"} + store = WeaviateDocumentStore( + url="http://localhost:8080", + collection_settings=collection_settings, + ) + yield store + store._client.schema.delete_class(collection_settings["class"]) + @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client") def test_init(self, mock_weaviate_client_class): mock_client = MagicMock() @@ -21,7 +36,7 @@ def test_init(self, mock_weaviate_client_class): WeaviateDocumentStore( url="http://localhost:8080", - collection_name="my_collection", + collection_settings={"class": "My_collection"}, auth_client_secret=AuthApiKey("my_api_key"), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, @@ -54,14 +69,15 @@ def test_init(self, mock_weaviate_client_class): # Verify collection is created mock_client.schema.get.assert_called_once() - mock_client.schema.exists.assert_called_once_with("my_collection") - mock_client.schema.create_class.assert_called_once_with({"class": "my_collection"}) + mock_client.schema.exists.assert_called_once_with("My_collection") + mock_client.schema.create_class.assert_called_once_with( + {"class": "My_collection", "properties": DOCUMENT_COLLECTION_PROPERTIES} + ) @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") def test_to_dict(self, _mock_weaviate): document_store = WeaviateDocumentStore( url="http://localhost:8080", - collection_name="my_collection", auth_client_secret=AuthApiKey("my_api_key"), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, @@ -77,7 +93,17 @@ def test_to_dict(self, _mock_weaviate): "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", "init_parameters": { "url": "http://localhost:8080", - "collection_name": "my_collection", + "collection_settings": { + "class": "Default", + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + }, "auth_client_secret": { "type": "weaviate.auth.AuthApiKey", "init_parameters": {"api_key": "my_api_key"}, @@ -113,7 +139,7 @@ def test_from_dict(self, _mock_weaviate): "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", "init_parameters": { "url": "http://localhost:8080", - "collection_name": "my_collection", + "collection_settings": None, "auth_client_secret": { "type": "weaviate.auth.AuthApiKey", "init_parameters": {"api_key": "my_api_key"}, @@ -144,7 +170,17 @@ def test_from_dict(self, _mock_weaviate): ) assert document_store._url == "http://localhost:8080" - assert document_store._collection_name == "my_collection" + assert document_store._collection_settings == { + "class": "Default", + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + } assert document_store._auth_client_secret == AuthApiKey("my_api_key") assert document_store._timeout_config == (10, 60) assert document_store._proxies == {"http": "http://proxy:1234"}