From 52474ed0537b9ee81eddfa9c8807902c1989aea8 Mon Sep 17 00:00:00 2001
From: Shukri
Date: Mon, 25 Mar 2024 17:26:07 +0100
Subject: [PATCH] Throw error when vector and query are None (#142)

These are okay:

```python
weaviate_vector_store._perform_search(query="hello", vector=None, k=5)
weaviate_vector_store._perform_search(query=None, vector=[1, 2, 3], k=5)
```

These are not okay:

```python
weaviate_vector_store._perform_search(query=None, k=5)
weaviate_vector_store._perform_search(query=None, vector=None, k=5)
```

Signed-off-by: hsm207

---------

Signed-off-by: hsm207
---
 langchain_weaviate/vectorstores.py           | 15 ++++++++-----
 tests/integration_tests/test_vectorstores.py | 22 ++++++++++++++++++++
 2 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/langchain_weaviate/vectorstores.py b/langchain_weaviate/vectorstores.py
index 1d94641..5706ece 100644
--- a/langchain_weaviate/vectorstores.py
+++ b/langchain_weaviate/vectorstores.py
@@ -230,10 +230,10 @@ def _perform_search(
         if self._embedding is None:
             raise ValueError("_embedding cannot be None for similarity_search")
 
-        if "return_metadata" in kwargs and "score" not in kwargs["return_metadata"]:
-            kwargs["return_metadata"].append("score")
-        else:
+        if "return_metadata" not in kwargs:
             kwargs["return_metadata"] = ["score"]
+        elif "score" not in kwargs["return_metadata"]:
+            kwargs["return_metadata"].append("score")
 
         if (
             "return_properties" in kwargs
@@ -241,9 +241,14 @@ def _perform_search(
         ):
             kwargs["return_properties"].append(self._text_key)
 
-        # workaround to handle test_max_marginal_relevance_search
         vector = kwargs.pop("vector", None)
-        if vector is None and query is not None:
+        if vector is None and query is None:
+            # raise an error because weaviate will do a fetch object query
+            # if both query and vector are None
+            raise ValueError("Either query or vector must be provided.")
+
+        # workaround to handle test_max_marginal_relevance_search
+        if vector is None:
             vector = self._embedding.embed_query(query)
 
         return_uuids = kwargs.pop("return_uuids", False)
diff --git a/tests/integration_tests/test_vectorstores.py b/tests/integration_tests/test_vectorstores.py
index baeb7a5..c1aaba9 100644
--- a/tests/integration_tests/test_vectorstores.py
+++ b/tests/integration_tests/test_vectorstores.py
@@ -802,3 +802,25 @@ def run_similarity_test(search_method):
 
     run_similarity_test("similarity_search")
     run_similarity_test("similarity_search_with_score")
+
+
+def test_invalid_search_param(weaviate_client, embedding_openai):
+    index_name = f"TestIndex_{uuid.uuid4().hex}"
+    text_key = "page"
+    weaviate_vector_store = WeaviateVectorStore(
+        weaviate_client, index_name, text_key, embedding_openai
+    )
+
+    with pytest.raises(ValueError) as excinfo:
+        weaviate_vector_store._perform_search(query=None, k=5)
+
+    assert str(excinfo.value) == "Either query or vector must be provided."
+
+    with pytest.raises(ValueError) as excinfo:
+        weaviate_vector_store._perform_search(query=None, vector=None, k=5)
+
+    assert str(excinfo.value) == "Either query or vector must be provided."
+
+    weaviate_vector_store._perform_search(query="hello", vector=None, k=5)
+
+    weaviate_vector_store._perform_search(query=None, vector=[1, 2, 3], k=5)