From f34411630c8899920a2bab142d7364f120e72868 Mon Sep 17 00:00:00 2001 From: hsm207 Date: Mon, 9 Dec 2024 08:32:46 +0000 Subject: [PATCH] fix: enforce client parameter validation in WeaviateVectorStore Signed-off-by: hsm207 --- libs/weaviate/langchain_weaviate/vectorstores.py | 5 ++++- .../tests/unit_tests/test_vectorstores_unit.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/libs/weaviate/langchain_weaviate/vectorstores.py b/libs/weaviate/langchain_weaviate/vectorstores.py index 43d1f8f..29fa854 100644 --- a/libs/weaviate/langchain_weaviate/vectorstores.py +++ b/libs/weaviate/langchain_weaviate/vectorstores.py @@ -428,7 +428,7 @@ def from_texts( metadatas: Optional[List[dict]] = None, *, tenant: Optional[str] = None, - client: weaviate.WeaviateClient = None, + client: Optional[weaviate.WeaviateClient] = None, index_name: Optional[str] = None, text_key: str = "text", relevance_score_fn: Optional[ @@ -474,6 +474,9 @@ def from_texts( attributes = list(metadatas[0].keys()) if metadatas else None + if client is None: + raise ValueError("client must be an instance of WeaviateClient") + weaviate_vector_store = cls( client, index_name, diff --git a/libs/weaviate/tests/unit_tests/test_vectorstores_unit.py b/libs/weaviate/tests/unit_tests/test_vectorstores_unit.py index 281ce16..9f32fa0 100644 --- a/libs/weaviate/tests/unit_tests/test_vectorstores_unit.py +++ b/libs/weaviate/tests/unit_tests/test_vectorstores_unit.py @@ -5,6 +5,7 @@ import pytest from langchain_weaviate.vectorstores import ( + WeaviateVectorStore, _default_score_normalizer, _json_serializable, ) @@ -29,3 +30,12 @@ def test_json_serializable( expected_result: Union[str, int, None], ) -> None: assert _json_serializable(value) == expected_result + + +def test_from_texts_raises_value_error_when_client_is_none(): + with pytest.raises( + ValueError, match="client must be an instance of WeaviateClient" + ): + WeaviateVectorStore.from_texts( + texts=["sample text"], embedding=None, client=None + )