diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index c9230f463..f9342781e 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -54,9 +54,7 @@ def __init__( return_properties=return_properties, ) except ValidationError as e: - msg = f"Validation failed: {e.errors()}" - logger.error(msg) - raise ValueError(msg) + raise ValueError(f"Validation failed: {e.errors()}") super().__init__(validated_data.driver_model.driver) self.vector_index_name = validated_data.vector_index_name @@ -101,9 +99,7 @@ def search( query_text=query_text, ) except ValidationError as e: - msg = f"Validation failed: {e.errors()}" - logger.error(msg) - raise ValueError(msg) + raise ValueError(f"Validation failed: {e.errors()}") parameters = validated_data.model_dump(exclude_none=True) @@ -142,9 +138,7 @@ def __init__( embedder_model=embedder_model, ) except ValidationError as e: - msg = f"Validation failed: {e.errors()}" - logger.error(msg) - raise ValueError(msg) + raise ValueError(f"Validation failed: {e.errors()}") super().__init__(validated_data.driver_model.driver) self.vector_index_name = validated_data.vector_index_name @@ -192,9 +186,7 @@ def search( query_params=query_params, ) except ValidationError as e: - msg = f"Validation failed: {e.errors()}" - logger.error(msg) - raise ValueError(msg) + raise ValueError(f"Validation failed: {e.errors()}") parameters = validated_data.model_dump(exclude_none=True) diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index af3a60685..c3eb3eae5 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -43,17 +43,23 @@ def __init__( index_name: str, embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, + filters: Optional[dict[str, Any]] = None, ) -> None: super().__init__(driver) self.index_name = index_name self.return_properties = return_properties self.embedder = embedder + self._node_label = None + self._embedding_node_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, + filters: Optional[dict[str, Any]] = None, ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -75,7 +81,7 @@ def search( """ try: validated_data = VectorSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -93,7 +99,15 @@ def search( parameters["query_vector"] = query_vector del parameters["query_text"] - search_query = get_search_query(SearchType.VECTOR, self.return_properties) + search_query, search_params = get_search_query( + SearchType.VECTOR, + self.return_properties, + node_label=self._node_label, + embedding_node_property=self._embedding_node_property, + embedding_dimension=self._embedding_dimension, + filters=filters, + ) + parameters.update(search_params) logger.debug("VectorRetriever Cypher parameters: %s", parameters) logger.debug("VectorRetriever Cypher query: %s", search_query) @@ -129,6 +143,10 @@ def __init__( self.index_name = index_name self.retrieval_query = retrieval_query self.embedder = embedder + self._node_label = None + self._node_embedding_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, @@ -136,6 +154,7 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, + filters: Optional[dict[str, Any]] = None, ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -158,7 +177,7 @@ def search( """ try: validated_data = VectorCypherSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -181,9 +200,15 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( - SearchType.VECTOR, retrieval_query=self.retrieval_query + search_query, search_params = get_search_query( + SearchType.VECTOR, + retrieval_query=self.retrieval_query, + node_label=self._node_label, + embedding_node_property=self._node_embedding_property, + embedding_dimension=self._embedding_dimension, + filters=filters, ) + parameters.update(search_params) logger.debug("VectorCypherRetriever Cypher parameters: %s", parameters) logger.debug("VectorCypherRetriever Cypher query: %s", search_query) diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 98a892398..8b388fcec 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -36,9 +36,7 @@ def test_vector_retriever_bad_data_validation(driver): def test_vector_cypher_retriever_bad_data_validation(driver): with pytest.raises(ValueError): - VectorCypherRetriever( - driver=driver, index_name="my-index", retrieval_query=42 - ) + VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query=42) def test_vector_cypher_retriever_initialization(driver):