From 598917a11ac342f0457ea3a0a6bb3a6485e60675 Mon Sep 17 00:00:00 2001 From: willtai Date: Mon, 18 Nov 2024 13:42:53 +0000 Subject: [PATCH] Include variable scope clause in deprecated Cypher query (#6) * Include variable scope clause in deprecated Cypher query * Added original query for versions below 5.23 --- .gitignore | 1 + .../vectorstores/neo4j_vector.py | 49 +++++++++++++------ .../vectorstores/test_neo4jvector.py | 2 +- .../unit_tests/vectorstores/test_neo4j.py | 47 ++++++++++++++++++ 4 files changed, 83 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 45d553b..716a06f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__ .mypy_cache_test .env .venv* +.idea \ No newline at end of file diff --git a/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py b/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py index 863de6d..0dbf383 100644 --- a/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py +++ b/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py @@ -84,34 +84,39 @@ class IndexType(str, enum.Enum): def _get_search_index_query( - search_type: SearchType, index_type: IndexType = DEFAULT_INDEX_TYPE + search_type: SearchType, + index_type: IndexType = DEFAULT_INDEX_TYPE, + neo4j_version_is_5_23_or_above: bool = False, ) -> str: if index_type == IndexType.NODE: - type_to_query_map = { - SearchType.VECTOR: ( + if search_type == SearchType.VECTOR: + return ( "CALL db.index.vector.queryNodes($index, $k, $embedding) " "YIELD node, score " - ), - SearchType.HYBRID: ( - "CALL { " + ) + elif search_type == SearchType.HYBRID: + call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { " + + query_body = ( "CALL db.index.vector.queryNodes($index, $k, $embedding) " "YIELD node, score " "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " "UNWIND nodes AS n " - # We use 0 as min "RETURN n.node AS node, (n.score / max) AS score UNION " "CALL db.index.fulltext.queryNodes($keyword_index, $query, " "{limit: $k}) YIELD node, score " "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " "UNWIND nodes AS n " - # We use 0 as min "RETURN n.node AS node, (n.score / max) AS score " - "} " - # dedup - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " - ), - } - return type_to_query_map[search_type] + ) + + call_suffix = ( + "} WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " + ) + + return call_prefix + query_body + call_suffix + else: + raise ValueError(f"Unsupported SearchType: {search_type}") else: return ( "CALL db.index.vector.queryRelationships($index, $k, $embedding) " @@ -666,6 +671,10 @@ def verify_version(self) -> None: else: version_tuple = tuple(map(int, version.split("."))) + self.neo4j_version_is_5_23_or_above = self._check_if_version_5_23_or_above( + version_tuple + ) + target_version = (5, 11, 0) if version_tuple < target_version: @@ -682,6 +691,14 @@ def verify_version(self) -> None: # Flag for enterprise self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False + def _check_if_version_5_23_or_above(self, version_tuple: tuple[int, ...]) -> bool: + """ + Check if the connected Neo4j database version supports the required features. + + Sets a flag if the connected Neo4j version is 5.23 or above. + """ + return version_tuple >= (5, 23, 0) + def retrieve_existing_index(self) -> Tuple[Optional[int], Optional[str]]: """ Check if the vector index exists in the Neo4j database @@ -1064,7 +1081,9 @@ def similarity_search_with_score_by_vector( index_query = base_index_query + filter_snippets + base_cosine_query else: - index_query = _get_search_index_query(self.search_type, self._index_type) + index_query = _get_search_index_query( + self.search_type, self._index_type, self.neo4j_version_is_5_23_or_above + ) filter_params = {} if self._index_type == IndexType.RELATIONSHIP: diff --git a/libs/neo4j/tests/integration_tests/vectorstores/test_neo4jvector.py b/libs/neo4j/tests/integration_tests/vectorstores/test_neo4jvector.py index f9c92c2..ccf62c6 100644 --- a/libs/neo4j/tests/integration_tests/vectorstores/test_neo4jvector.py +++ b/libs/neo4j/tests/integration_tests/vectorstores/test_neo4jvector.py @@ -35,7 +35,7 @@ texts = ["foo", "bar", "baz", "It is the end of the world. Take shelter!"] """ -cd tests/integration_tests/vectorstores/docker-compose +cd tests/integration_tests/docker-compose docker-compose -f neo4j.yml up """ diff --git a/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py b/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py index 72b371f..c63435a 100644 --- a/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py +++ b/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py @@ -1,6 +1,9 @@ """Test Neo4j functionality.""" from langchain_neo4j.vectorstores.neo4j_vector import ( + IndexType, + SearchType, + _get_search_index_query, dict_to_yaml_str, remove_lucene_chars, ) @@ -65,3 +68,47 @@ def test_converting_to_yaml() -> None: ) assert yaml_str == expected_output + + +def test_get_search_index_query_hybrid_node_neo4j_5_23_above() -> None: + expected_query = ( + "CALL () { " + "CALL db.index.vector.queryNodes($index, $k, $embedding) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score UNION " + "CALL db.index.fulltext.queryNodes($keyword_index, $query, " + "{limit: $k}) YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " + ) + + actual_query = _get_search_index_query(SearchType.HYBRID, IndexType.NODE, True) + + assert actual_query == expected_query + + +def test_get_search_index_query_hybrid_node_neo4j_5_23_below() -> None: + expected_query = ( + "CALL { " + "CALL db.index.vector.queryNodes($index, $k, $embedding) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score UNION " + "CALL db.index.fulltext.queryNodes($keyword_index, $query, " + "{limit: $k}) YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " + ) + + actual_query = _get_search_index_query(SearchType.HYBRID, IndexType.NODE, False) + + assert actual_query == expected_query