From 39bd5faf587de162d247bfab187b740f99e5a099 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Fri, 19 Apr 2024 17:32:10 +0100 Subject: [PATCH] Adds HybridSearchRetriever and creates abstract base class Retriever --- examples/hybrid_search.py | 62 ++++++++++++++++++++++ pyproject.toml | 2 +- src/neo4j_genai/__init__.py | 4 +- src/neo4j_genai/indexes.py | 2 +- src/neo4j_genai/retrievers.py | 75 ++++++++++++++++++++++++++ src/neo4j_genai/types.py | 5 ++ tests/conftest.py | 8 ++- tests/test_indexes.py | 4 +- tests/test_retrievers.py | 99 +++++++++++++++++++++++++++++++---- 9 files changed, 243 insertions(+), 18 deletions(-) create mode 100644 examples/hybrid_search.py diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py new file mode 100644 index 000000000..7fffd1e6c --- /dev/null +++ b/examples/hybrid_search.py @@ -0,0 +1,62 @@ +from neo4j import GraphDatabase + +from random import random +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index, drop_index, create_fulltext_index +from neo4j_genai.retrievers import HybridSearchRetriever + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +FULLTEXT_INDEX_NAME = "fulltext-index-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random() for _ in range(DIMENSION)] + + +embedder = CustomEmbedder() + +# Creating the index +drop_index(driver, INDEX_NAME) +drop_index(driver, FULLTEXT_INDEX_NAME) +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) +create_fulltext_index( + driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"] +) + +# Initialize the retriever +retriever = HybridSearchRetriever(driver, INDEX_NAME, FULLTEXT_INDEX_NAME, embedder) + +# Upsert the query +vector = [random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (n:Document {id: $id})" + "WITH n " + "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" + "RETURN n" +) +parameters = { + "id": 0, + "vector": vector, +} +driver.execute_query(insert_query, parameters) + +# Perform the similarity search for a text query +query_text = "hello world" +fulltext_query = "fremen" +print(retriever.search(query_text=query_text, fulltext_query=fulltext_query, top_k=5)) diff --git a/pyproject.toml b/pyproject.toml index 09bc8fd0e..94f06d79d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ [tool.poetry] name = "neo4j-genai" -version = "0.1.3" +version = "0.1.4" description = "Python package to allow easy integration to Neo4j's GenAI features" authors = ["Neo4j, Inc "] license = "Apache License, Version 2.0" diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py index 5676e9c4e..6cbb50f8a 100644 --- a/src/neo4j_genai/__init__.py +++ b/src/neo4j_genai/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .retrievers import VectorRetriever, VectorCypherRetriever +from .retrievers import VectorRetriever, VectorCypherRetriever, HybridSearchRetriever -__all__ = ["VectorRetriever", "VectorCypherRetriever"] +__all__ = ["VectorRetriever", "VectorCypherRetriever", "HybridSearchRetriever"] diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index b457b3014..b09901e08 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -98,7 +98,7 @@ def create_fulltext_index( raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}") query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" ) diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index 9ce9d6dc9..f6b68dd5b 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -21,6 +21,7 @@ SimilaritySearchModel, VectorSearchRecord, VectorCypherSearchModel, + HybridSearchModel, ) @@ -231,3 +232,77 @@ def search( search_query = query_prefix + self.retrieval_query records, _, _ = self.driver.execute_query(search_query, parameters) return records + + +class HybridSearchRetriever(Retriever): + def __init__( + self, + driver: Driver, + index_name: str, + fulltext_index_name: str, + embedder: Optional[Embedder] = None, + ) -> None: + super().__init__(driver) + self._verify_version() + self.index_name = index_name + self.fulltext_index_name = fulltext_index_name + self.embedder = embedder + + def search( + self, + fulltext_query: str, + query_vector: Optional[list[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + ) -> list[Record]: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + See the following documentation for more details: + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + Args: + fulltext_query (str): String to query over the fulltext index. + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embedder is provided. + Returns: + list[Record]: The results of the search query + """ + try: + validated_data = HybridSearchModel( + index_name=self.index_name, + fulltext_index_name=self.fulltext_index_name, + fulltext_query=fulltext_query, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + parameters["query_vector"] = self.embedder.embed_query(query_text) + del parameters["query_text"] + + search_query = ( + "CALL { " + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $fulltext_query, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node, score" + ) + records, _, _ = self.driver.execute_query(search_query, parameters) + return records diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index ebe8cd166..cde5b1f70 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -78,3 +78,8 @@ def check_query(cls, values): class VectorCypherSearchModel(SimilaritySearchModel): query_params: Optional[dict[str, Any]] = None + + +class HybridSearchModel(VectorCypherSearchModel): + fulltext_index_name: str + fulltext_query: str diff --git a/tests/conftest.py b/tests/conftest.py index 3c210a9f1..01b12efce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ # limitations under the License. import pytest -from neo4j_genai import VectorRetriever, VectorCypherRetriever +from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridSearchRetriever from neo4j import Driver from unittest.mock import MagicMock, patch @@ -37,3 +37,9 @@ def vector_cypher_retriever(_verify_version_mock, driver): RETURN node.id AS node_id, node.text AS text, score """ return VectorCypherRetriever(driver, "my-index", retrieval_query) + + +@pytest.fixture +@patch("neo4j_genai.HybridSearchRetriever._verify_version") +def hybrid_search_retriever(_verify_version_mock, driver): + return HybridSearchRetriever(driver, "my-index", "my-fulltext-index") diff --git a/tests/test_indexes.py b/tests/test_indexes.py index ae2f98c32..c624607d5 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -89,7 +89,7 @@ def test_create_fulltext_index_happy_path(driver): label = "node-label" text_node_properties = ["property-1", "property-2"] create_query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]" ) @@ -116,7 +116,7 @@ def test_create_fulltext_index_ensure_escaping(driver): label = "node-label" text_node_properties = ["property-1", "property-2"] create_query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]" ) diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 62900cb01..de6538520 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -19,7 +19,7 @@ from neo4j.exceptions import CypherSyntaxError from neo4j_genai import VectorRetriever -from neo4j_genai.retrievers import VectorCypherRetriever +from neo4j_genai.retrievers import VectorCypherRetriever, HybridSearchRetriever from neo4j_genai.types import VectorSearchRecord @@ -55,14 +55,12 @@ def test_vector_retriever_no_supported_version(driver): @patch("neo4j_genai.VectorRetriever._verify_version") def test_similarity_search_vector_happy_path(_verify_version_mock, driver): - custom_embeddings = MagicMock() - index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - retriever = VectorRetriever(driver, index_name, custom_embeddings) + retriever = VectorRetriever(driver, index_name) retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], @@ -76,8 +74,6 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): records = retriever.search(query_vector=query_vector, top_k=top_k) - custom_embeddings.embed_query.assert_not_called() - retriever.driver.execute_query.assert_called_once_with( search_query, { @@ -222,14 +218,12 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri @patch("neo4j_genai.VectorRetriever._verify_version") def test_similarity_search_vector_bad_results(_verify_version_mock, driver): - custom_embeddings = MagicMock() - index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - retriever = VectorRetriever(driver, index_name, custom_embeddings) + retriever = VectorRetriever(driver, index_name) retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": "adsa"}], @@ -244,8 +238,6 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): with pytest.raises(ValueError): retriever.search(query_vector=query_vector, top_k=top_k) - custom_embeddings.embed_query.assert_not_called() - retriever.driver.execute_query.assert_called_once_with( search_query, { @@ -369,3 +361,88 @@ def test_retrieval_query_cypher_error(_verify_version_mock, driver): query_text=query_text, top_k=top_k, ) + + +@patch("neo4j_genai.HybridSearchRetriever._verify_version") +def test_hybrid_search_text_happy_path(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + fulltext_query_text = "which Dune quote contains the word 'fear'?" + top_k = 5 + + retriever = HybridSearchRetriever( + driver, index_name, fulltext_index_name, custom_embeddings + ) + + retriever.driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + search_query = ( + "CALL { " + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $fulltext_query, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node, score" + ) + + records = retriever.search( + fulltext_query=fulltext_query_text, query_text=query_text, top_k=top_k + ) + + retriever.driver.execute_query.assert_called_once_with( + search_query, + { + "index_name": index_name, + "top_k": top_k, + "fulltext_query": fulltext_query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + }, + ) + custom_embeddings.embed_query.assert_called_once_with(query_text) + assert records == [{"node": "dummy-node", "score": 1.0}] + + +def test_hybrid_retriever_search_both_text_and_vector(hybrid_search_retriever): + query_text = "may thy knife chip and shatter" + query_vector = [1.1, 2.2, 3.3] + fulltext_query_text = "which Dune quote contains the word 'fear'?" + top_k = 5 + + with pytest.raises( + ValueError, match="You must provide exactly one of query_vector or query_text." + ): + hybrid_search_retriever.search( + fulltext_query=fulltext_query_text, + query_text=query_text, + query_vector=query_vector, + top_k=top_k, + ) + + +def test_hybrid_search_retriever_search_missing_embedder_for_text( + hybrid_search_retriever, +): + query_text = "may thy knife chip and shatter" + fulltext_query_text = "which Dune quote contains the word 'fear'?" + top_k = 5 + + with pytest.raises(ValueError, match="Embedding method required for text query"): + hybrid_search_retriever.search( + fulltext_query=fulltext_query_text, + query_text=query_text, + top_k=top_k, + )