From e58d68e229324e10a7f622d60c8d7517f5e88217 Mon Sep 17 00:00:00 2001 From: willtai Date: Tue, 14 May 2024 15:16:11 +0100 Subject: [PATCH] Add new types for validating inputs to retrievers init (#31) --- src/neo4j_genai/retrievers/hybrid.py | 64 +++++++++++++++++++++++----- src/neo4j_genai/retrievers/vector.py | 49 ++++++++++++++++++--- src/neo4j_genai/types.py | 64 +++++++++++++++++++++++++++- tests/unit/retrievers/test_hybrid.py | 15 +++++++ tests/unit/retrievers/test_vector.py | 15 ++++++- 5 files changed, 187 insertions(+), 20 deletions(-) diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index fea96a2d..f9342781 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -19,7 +19,15 @@ from neo4j_genai.embedder import Embedder from neo4j_genai.retrievers.base import Retriever -from neo4j_genai.types import HybridSearchModel, SearchType, HybridCypherSearchModel +from neo4j_genai.types import ( + HybridSearchModel, + SearchType, + HybridCypherSearchModel, + Neo4jDriverModel, + EmbedderModel, + HybridRetrieverModel, + HybridCypherRetrieverModel, +) from neo4j_genai.neo4j_queries import get_search_query import logging @@ -35,11 +43,28 @@ def __init__( embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, ) -> None: - super().__init__(driver) - self.vector_index_name = vector_index_name - self.fulltext_index_name = fulltext_index_name - self.embedder = embedder - self.return_properties = return_properties + try: + driver_model = Neo4jDriverModel(driver=driver) + embedder_model = EmbedderModel(embedder=embedder) if embedder else None + validated_data = HybridRetrieverModel( + driver_model=driver_model, + vector_index_name=vector_index_name, + fulltext_index_name=fulltext_index_name, + embedder_model=embedder_model, + return_properties=return_properties, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + 
super().__init__(validated_data.driver_model.driver) + self.vector_index_name = validated_data.vector_index_name + self.fulltext_index_name = validated_data.fulltext_index_name + self.return_properties = validated_data.return_properties + self.embedder = ( + validated_data.embedder_model.embedder + if validated_data.embedder_model + else None + ) def search( self, @@ -102,11 +127,28 @@ def __init__( retrieval_query: str, embedder: Optional[Embedder] = None, ) -> None: - super().__init__(driver) - self.vector_index_name = vector_index_name - self.fulltext_index_name = fulltext_index_name - self.retrieval_query = retrieval_query - self.embedder = embedder + try: + driver_model = Neo4jDriverModel(driver=driver) + embedder_model = EmbedderModel(embedder=embedder) if embedder else None + validated_data = HybridCypherRetrieverModel( + driver_model=driver_model, + vector_index_name=vector_index_name, + fulltext_index_name=fulltext_index_name, + retrieval_query=retrieval_query, + embedder_model=embedder_model, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + super().__init__(validated_data.driver_model.driver) + self.vector_index_name = validated_data.vector_index_name + self.fulltext_index_name = validated_data.fulltext_index_name + self.retrieval_query = validated_data.retrieval_query + self.embedder = ( + validated_data.embedder_model.embedder + if validated_data.embedder_model + else None + ) def search( self, diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index 32314801..3307af26 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -24,6 +24,10 @@ VectorSearchModel, VectorCypherSearchModel, SearchType, + Neo4jDriverModel, + EmbedderModel, + VectorRetrieverModel, + VectorCypherRetrieverModel, ) from neo4j_genai.neo4j_queries import get_search_query import logging @@ -44,10 +48,26 @@ def __init__( embedder: Optional[Embedder] = None, 
return_properties: Optional[list[str]] = None, ) -> None: + try: + driver_model = Neo4jDriverModel(driver=driver) + embedder_model = EmbedderModel(embedder=embedder) if embedder else None + validated_data = VectorRetrieverModel( + driver_model=driver_model, + index_name=index_name, + embedder_model=embedder_model, + return_properties=return_properties, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + super().__init__(driver) - self.index_name = index_name - self.return_properties = return_properties - self.embedder = embedder + self.index_name = validated_data.index_name + self.return_properties = validated_data.return_properties + self.embedder = ( + validated_data.embedder_model.embedder + if validated_data.embedder_model + else None + ) self._node_label = None self._embedding_node_property = None self._embedding_dimension = None @@ -138,10 +158,26 @@ def __init__( retrieval_query: str, embedder: Optional[Embedder] = None, ) -> None: + try: + driver_model = Neo4jDriverModel(driver=driver) + embedder_model = EmbedderModel(embedder=embedder) if embedder else None + validated_data = VectorCypherRetrieverModel( + driver_model=driver_model, + index_name=index_name, + retrieval_query=retrieval_query, + embedder_model=embedder_model, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + super().__init__(driver) - self.index_name = index_name - self.retrieval_query = retrieval_query - self.embedder = embedder + self.index_name = validated_data.index_name + self.retrieval_query = validated_data.retrieval_query + self.embedder = ( + validated_data.embedder_model.embedder + if validated_data.embedder_model + else None + ) self._node_label = None self._node_embedding_property = None self._embedding_dimension = None @@ -166,6 +202,7 @@ def search( query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. 
top_k (int, optional): The number of neighbors to return. Defaults to 5. query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None. + filters (Optional[dict[str, Any]], optional): Filters for metadata pre-filtering. Defaults to None. Raises: ValueError: If validation of the input arguments fail. diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 357bb44e..f13e6748 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -14,7 +14,13 @@ # limitations under the License. from enum import Enum from typing import Any, Literal, Optional -from pydantic import BaseModel, PositiveInt, model_validator, field_validator +from pydantic import ( + BaseModel, + PositiveInt, + model_validator, + field_validator, + ConfigDict, +) import neo4j @@ -93,3 +99,59 @@ class SearchType(str, Enum): VECTOR = "vector" HYBRID = "hybrid" + + +class EmbedderModel(BaseModel): + embedder: Optional[Any] + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("embedder") + def check_embedder(cls, value): + if not hasattr(value, "embed_query") or not callable( + getattr(value, "embed_query", None) + ): + raise ValueError( + "Provided embedder object must have an 'embed_query' callable method."
+ ) + return value + + +class Neo4jDriverModel(BaseModel): + driver: neo4j.Driver + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("driver") + def check_driver(cls, value): + if not isinstance(value, neo4j.Driver): + raise ValueError("Provided driver needs to be of type neo4j.Driver") + return value + + +class VectorRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + index_name: str + embedder_model: Optional[EmbedderModel] = None + return_properties: Optional[list[str]] = None + + +class VectorCypherRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + index_name: str + retrieval_query: str + embedder_model: Optional[EmbedderModel] = None + + +class HybridRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + vector_index_name: str + fulltext_index_name: str + embedder_model: Optional[EmbedderModel] = None + return_properties: Optional[list[str]] = None + + +class HybridCypherRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + vector_index_name: str + fulltext_index_name: str + retrieval_query: str + embedder_model: Optional[EmbedderModel] = None diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index 79486835..dff2ebcf 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -43,6 +43,21 @@ def test_vector_cypher_retriever_initialization(driver): mock_verify.assert_called_once() +def test_hybrid_retriever_bad_data_validation(driver): + with pytest.raises(ValueError): + HybridRetriever(driver=driver, vector_index_name=42, fulltext_index_name=42) + + +def test_hybrid_cypher_retriever_bad_data_validation(driver): + with pytest.raises(ValueError): + HybridCypherRetriever( + driver=driver, + vector_index_name="my-index", + fulltext_index_name="fulltext-index", + retrieval_query=42, + ) + + @patch("neo4j_genai.HybridRetriever._verify_version") def test_hybrid_search_text_happy_path(_verify_version_mock, driver): 
embed_query_vector = [1.0 for _ in range(1536)] diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index c3fd1ade..8b388fce 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -18,6 +18,7 @@ from neo4j.exceptions import CypherSyntaxError from neo4j_genai import VectorRetriever, VectorCypherRetriever +from neo4j_genai.embedder import Embedder from neo4j_genai.neo4j_queries import get_search_query from neo4j_genai.types import SearchType, VectorSearchRecord @@ -28,6 +29,16 @@ def test_vector_retriever_initialization(driver): mock_verify.assert_called_once() +def test_vector_retriever_bad_data_validation(driver): + with pytest.raises(ValueError): + VectorRetriever(driver=driver, index_name=42) + + +def test_vector_cypher_retriever_bad_data_validation(driver): + with pytest.raises(ValueError): + VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query=42) + + def test_vector_cypher_retriever_initialization(driver): with patch("neo4j_genai.retrievers.base.Retriever._verify_version") as mock_verify: VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query="") @@ -70,7 +81,7 @@ def test_similarity_search_text_happy_path( _verify_version_mock, _fetch_index_infos, driver ): embed_query_vector = [1.0 for _ in range(1536)] - custom_embeddings = MagicMock() + custom_embeddings = MagicMock(spec=Embedder) custom_embeddings.embed_query.return_value = embed_query_vector index_name = "my-index" query_text = "may thy knife chip and shatter" @@ -104,7 +115,7 @@ def test_similarity_search_text_return_properties( _verify_version_mock, _fetch_index_infos, driver ): embed_query_vector = [1.0 for _ in range(3)] - custom_embeddings = MagicMock() + custom_embeddings = MagicMock(spec=Embedder) custom_embeddings.embed_query.return_value = embed_query_vector index_name = "my-index" query_text = "may thy knife chip and shatter"