Skip to content

Commit

Permalink
Add new types for validating inputs to retrievers init (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai authored May 14, 2024
1 parent e30fa26 commit e58d68e
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 20 deletions.
64 changes: 53 additions & 11 deletions src/neo4j_genai/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@

from neo4j_genai.embedder import Embedder
from neo4j_genai.retrievers.base import Retriever
from neo4j_genai.types import HybridSearchModel, SearchType, HybridCypherSearchModel
from neo4j_genai.types import (
HybridSearchModel,
SearchType,
HybridCypherSearchModel,
Neo4jDriverModel,
EmbedderModel,
HybridRetrieverModel,
HybridCypherRetrieverModel,
)
from neo4j_genai.neo4j_queries import get_search_query
import logging

Expand All @@ -35,11 +43,28 @@ def __init__(
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
super().__init__(driver)
self.vector_index_name = vector_index_name
self.fulltext_index_name = fulltext_index_name
self.embedder = embedder
self.return_properties = return_properties
try:
driver_model = Neo4jDriverModel(driver=driver)
embedder_model = EmbedderModel(embedder=embedder) if embedder else None
validated_data = HybridRetrieverModel(
driver_model=driver_model,
vector_index_name=vector_index_name,
fulltext_index_name=fulltext_index_name,
embedder_model=embedder_model,
return_properties=return_properties,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

super().__init__(validated_data.driver_model.driver)
self.vector_index_name = validated_data.vector_index_name
self.fulltext_index_name = validated_data.fulltext_index_name
self.return_properties = validated_data.return_properties
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)

def search(
self,
Expand Down Expand Up @@ -102,11 +127,28 @@ def __init__(
retrieval_query: str,
embedder: Optional[Embedder] = None,
) -> None:
super().__init__(driver)
self.vector_index_name = vector_index_name
self.fulltext_index_name = fulltext_index_name
self.retrieval_query = retrieval_query
self.embedder = embedder
try:
driver_model = Neo4jDriverModel(driver=driver)
embedder_model = EmbedderModel(embedder=embedder) if embedder else None
validated_data = HybridCypherRetrieverModel(
driver_model=driver_model,
vector_index_name=vector_index_name,
fulltext_index_name=fulltext_index_name,
retrieval_query=retrieval_query,
embedder_model=embedder_model,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

super().__init__(validated_data.driver_model.driver)
self.vector_index_name = validated_data.vector_index_name
self.fulltext_index_name = validated_data.fulltext_index_name
self.retrieval_query = validated_data.retrieval_query
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)

def search(
self,
Expand Down
49 changes: 43 additions & 6 deletions src/neo4j_genai/retrievers/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
VectorSearchModel,
VectorCypherSearchModel,
SearchType,
Neo4jDriverModel,
EmbedderModel,
VectorRetrieverModel,
VectorCypherRetrieverModel,
)
from neo4j_genai.neo4j_queries import get_search_query
import logging
Expand All @@ -44,10 +48,26 @@ def __init__(
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
try:
driver_model = Neo4jDriverModel(driver=driver)
embedder_model = EmbedderModel(embedder=embedder) if embedder else None
validated_data = VectorRetrieverModel(
driver_model=driver_model,
index_name=index_name,
embedder_model=embedder_model,
return_properties=return_properties,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

super().__init__(driver)
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder
self.index_name = validated_data.index_name
self.return_properties = validated_data.return_properties
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)
self._node_label = None
self._embedding_node_property = None
self._embedding_dimension = None
Expand Down Expand Up @@ -138,10 +158,26 @@ def __init__(
retrieval_query: str,
embedder: Optional[Embedder] = None,
) -> None:
try:
driver_model = Neo4jDriverModel(driver=driver)
embedder_model = EmbedderModel(embedder=embedder) if embedder else None
validated_data = VectorCypherRetrieverModel(
driver_model=driver_model,
index_name=index_name,
retrieval_query=retrieval_query,
embedder_model=embedder_model,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

super().__init__(driver)
self.index_name = index_name
self.retrieval_query = retrieval_query
self.embedder = embedder
self.index_name = validated_data.index_name
self.retrieval_query = validated_data.retrieval_query
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)
self._node_label = None
self._node_embedding_property = None
self._embedding_dimension = None
Expand All @@ -166,6 +202,7 @@ def search(
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None.
filters (Optional[dict[str, Any]], optional): Filters for metadata pre-filtering. Defaults to None.
Raises:
ValueError: If validation of the input arguments fails.
Expand Down
64 changes: 63 additions & 1 deletion src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
# limitations under the License.
from enum import Enum
from typing import Any, Literal, Optional
from pydantic import BaseModel, PositiveInt, model_validator, field_validator
from pydantic import (
BaseModel,
PositiveInt,
model_validator,
field_validator,
ConfigDict,
)
import neo4j


Expand Down Expand Up @@ -93,3 +99,59 @@ class SearchType(str, Enum):

VECTOR = "vector"
HYBRID = "hybrid"


class EmbedderModel(BaseModel):
    """Pydantic wrapper that validates an embedder object by duck typing.

    The wrapped object is accepted when it exposes a callable
    ``embed_query`` attribute; nothing else about it is checked.
    """

    # NOTE(review): the annotation permits None, but ``check_embedder`` rejects
    # it (None has no ``embed_query``) -- confirm which of the two is intended.
    embedder: Optional[Any]
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_validator("embedder")
    @classmethod
    def check_embedder(cls, value: Any) -> Any:
        """Ensure the embedder exposes a callable ``embed_query``.

        Args:
            value: The candidate embedder object.

        Returns:
            The same object, unchanged, when the check passes.

        Raises:
            ValueError: If ``embed_query`` is missing or is not callable.
        """
        # getattr with a None default covers both "attribute missing" and
        # "attribute present but not callable" in a single test.
        if not callable(getattr(value, "embed_query", None)):
            raise ValueError(
                "Provided embedder object must have an 'embed_query' callable method."
            )
        return value


class Neo4jDriverModel(BaseModel):
    """Pydantic wrapper that validates a ``neo4j.Driver`` instance."""

    driver: neo4j.Driver
    # neo4j.Driver is not a pydantic-native type, so arbitrary types are enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_validator("driver")
    @classmethod
    def check_driver(cls, value: neo4j.Driver) -> neo4j.Driver:
        """Ensure the supplied driver is a ``neo4j.Driver``.

        Args:
            value: The candidate driver object.

        Returns:
            The same driver, unchanged, when the check passes.

        Raises:
            ValueError: If ``value`` is not an instance of ``neo4j.Driver``.
        """
        if not isinstance(value, neo4j.Driver):
            raise ValueError("Provided driver needs to be of type neo4j.Driver")
        return value


class VectorRetrieverModel(BaseModel):
    """Validation schema for the arguments passed to ``VectorRetriever.__init__``."""

    # Already-validated wrapper around the neo4j driver.
    driver_model: Neo4jDriverModel
    # Value of the retriever's ``index_name`` argument.
    index_name: str
    # Already-validated embedder wrapper; None when no embedder was supplied.
    embedder_model: Optional[EmbedderModel] = None
    # Value of the retriever's ``return_properties`` argument; defaults to None.
    return_properties: Optional[list[str]] = None


class VectorCypherRetrieverModel(BaseModel):
    """Validation schema for the arguments passed to ``VectorCypherRetriever.__init__``."""

    # Already-validated wrapper around the neo4j driver.
    driver_model: Neo4jDriverModel
    # Value of the retriever's ``index_name`` argument.
    index_name: str
    # Value of the retriever's ``retrieval_query`` argument; required.
    retrieval_query: str
    # Already-validated embedder wrapper; None when no embedder was supplied.
    embedder_model: Optional[EmbedderModel] = None


class HybridRetrieverModel(BaseModel):
    """Validation schema for the arguments passed to ``HybridRetriever.__init__``."""

    # Already-validated wrapper around the neo4j driver.
    driver_model: Neo4jDriverModel
    # Value of the retriever's ``vector_index_name`` argument.
    vector_index_name: str
    # Value of the retriever's ``fulltext_index_name`` argument.
    fulltext_index_name: str
    # Already-validated embedder wrapper; None when no embedder was supplied.
    embedder_model: Optional[EmbedderModel] = None
    # Value of the retriever's ``return_properties`` argument; defaults to None.
    return_properties: Optional[list[str]] = None


class HybridCypherRetrieverModel(BaseModel):
    """Validation schema for the arguments passed to ``HybridCypherRetriever.__init__``."""

    # Already-validated wrapper around the neo4j driver.
    driver_model: Neo4jDriverModel
    # Value of the retriever's ``vector_index_name`` argument.
    vector_index_name: str
    # Value of the retriever's ``fulltext_index_name`` argument.
    fulltext_index_name: str
    # Value of the retriever's ``retrieval_query`` argument; required.
    retrieval_query: str
    # Already-validated embedder wrapper; None when no embedder was supplied.
    embedder_model: Optional[EmbedderModel] = None
embedder_model: Optional[EmbedderModel] = None
15 changes: 15 additions & 0 deletions tests/unit/retrievers/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ def test_vector_cypher_retriever_initialization(driver):
mock_verify.assert_called_once()


def test_hybrid_retriever_bad_data_validation(driver):
    """Non-string index names must be rejected with the retriever's own ValueError."""
    # ``match`` pins the wrapper's message: pydantic's ValidationError is itself
    # a ValueError subclass, so a bare ``pytest.raises(ValueError)`` would still
    # pass if the raw pydantic error leaked out of ``__init__``.
    with pytest.raises(ValueError, match="Validation failed"):
        HybridRetriever(driver=driver, vector_index_name=42, fulltext_index_name=42)


def test_hybrid_cypher_retriever_bad_data_validation(driver):
    """A non-string retrieval query must be rejected with the retriever's own ValueError."""
    # ``match`` pins the wrapper's message: pydantic's ValidationError is itself
    # a ValueError subclass, so a bare ``pytest.raises(ValueError)`` would still
    # pass if the raw pydantic error leaked out of ``__init__``.
    with pytest.raises(ValueError, match="Validation failed"):
        HybridCypherRetriever(
            driver=driver,
            vector_index_name="my-index",
            fulltext_index_name="fulltext-index",
            retrieval_query=42,
        )


@patch("neo4j_genai.HybridRetriever._verify_version")
def test_hybrid_search_text_happy_path(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
Expand Down
15 changes: 13 additions & 2 deletions tests/unit/retrievers/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from neo4j.exceptions import CypherSyntaxError

from neo4j_genai import VectorRetriever, VectorCypherRetriever
from neo4j_genai.embedder import Embedder
from neo4j_genai.neo4j_queries import get_search_query
from neo4j_genai.types import SearchType, VectorSearchRecord

Expand All @@ -28,6 +29,16 @@ def test_vector_retriever_initialization(driver):
mock_verify.assert_called_once()


def test_vector_retriever_bad_data_validation(driver):
    """A non-string index name must be rejected with the retriever's own ValueError."""
    # ``match`` pins the wrapper's message: pydantic's ValidationError is itself
    # a ValueError subclass, so a bare ``pytest.raises(ValueError)`` would still
    # pass if the raw pydantic error leaked out of ``__init__``.
    with pytest.raises(ValueError, match="Validation failed"):
        VectorRetriever(driver=driver, index_name=42)


def test_vector_cypher_retriever_bad_data_validation(driver):
    """A non-string retrieval query must be rejected with the retriever's own ValueError."""
    # ``match`` pins the wrapper's message: pydantic's ValidationError is itself
    # a ValueError subclass, so a bare ``pytest.raises(ValueError)`` would still
    # pass if the raw pydantic error leaked out of ``__init__``.
    with pytest.raises(ValueError, match="Validation failed"):
        VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query=42)


def test_vector_cypher_retriever_initialization(driver):
with patch("neo4j_genai.retrievers.base.Retriever._verify_version") as mock_verify:
VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query="")
Expand Down Expand Up @@ -70,7 +81,7 @@ def test_similarity_search_text_happy_path(
_verify_version_mock, _fetch_index_infos, driver
):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
custom_embeddings = MagicMock(spec=Embedder)
custom_embeddings.embed_query.return_value = embed_query_vector
index_name = "my-index"
query_text = "may thy knife chip and shatter"
Expand Down Expand Up @@ -104,7 +115,7 @@ def test_similarity_search_text_return_properties(
_verify_version_mock, _fetch_index_infos, driver
):
embed_query_vector = [1.0 for _ in range(3)]
custom_embeddings = MagicMock()
custom_embeddings = MagicMock(spec=Embedder)
custom_embeddings.embed_query.return_value = embed_query_vector
index_name = "my-index"
query_text = "may thy knife chip and shatter"
Expand Down

0 comments on commit e58d68e

Please sign in to comment.