Skip to content

Commit

Permalink
Include variable scope clause in deprecated Cypher query (#6)
Browse files Browse the repository at this point in the history
* Include variable scope clause in deprecated Cypher query

* Added original query for versions below 5.23
  • Loading branch information
willtai authored Nov 18, 2024
1 parent f58b60b commit 598917a
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ __pycache__
.mypy_cache_test
.env
.venv*
.idea
49 changes: 34 additions & 15 deletions libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,34 +84,39 @@ class IndexType(str, enum.Enum):


def _get_search_index_query(
search_type: SearchType, index_type: IndexType = DEFAULT_INDEX_TYPE
search_type: SearchType,
index_type: IndexType = DEFAULT_INDEX_TYPE,
neo4j_version_is_5_23_or_above: bool = False,
) -> str:
if index_type == IndexType.NODE:
type_to_query_map = {
SearchType.VECTOR: (
if search_type == SearchType.VECTOR:
return (
"CALL db.index.vector.queryNodes($index, $k, $embedding) "
"YIELD node, score "
),
SearchType.HYBRID: (
"CALL { "
)
elif search_type == SearchType.HYBRID:
call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { "

query_body = (
"CALL db.index.vector.queryNodes($index, $k, $embedding) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
"UNWIND nodes AS n "
# We use 0 as min
"RETURN n.node AS node, (n.score / max) AS score UNION "
"CALL db.index.fulltext.queryNodes($keyword_index, $query, "
"{limit: $k}) YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
"UNWIND nodes AS n "
# We use 0 as min
"RETURN n.node AS node, (n.score / max) AS score "
"} "
# dedup
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
),
}
return type_to_query_map[search_type]
)

call_suffix = (
"} WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
)

return call_prefix + query_body + call_suffix
else:
raise ValueError(f"Unsupported SearchType: {search_type}")
else:
return (
"CALL db.index.vector.queryRelationships($index, $k, $embedding) "
Expand Down Expand Up @@ -666,6 +671,10 @@ def verify_version(self) -> None:
else:
version_tuple = tuple(map(int, version.split(".")))

self.neo4j_version_is_5_23_or_above = self._check_if_version_5_23_or_above(
version_tuple
)

target_version = (5, 11, 0)

if version_tuple < target_version:
Expand All @@ -682,6 +691,14 @@ def verify_version(self) -> None:
# Flag for enterprise
self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False

def _check_if_version_5_23_or_above(self, version_tuple: tuple[int, ...]) -> bool:
"""
Check if the connected Neo4j database version supports the required features.
Sets a flag if the connected Neo4j version is 5.23 or above.
"""
return version_tuple >= (5, 23, 0)

def retrieve_existing_index(self) -> Tuple[Optional[int], Optional[str]]:
"""
Check if the vector index exists in the Neo4j database
Expand Down Expand Up @@ -1064,7 +1081,9 @@ def similarity_search_with_score_by_vector(
index_query = base_index_query + filter_snippets + base_cosine_query

else:
index_query = _get_search_index_query(self.search_type, self._index_type)
index_query = _get_search_index_query(
self.search_type, self._index_type, self.neo4j_version_is_5_23_or_above
)
filter_params = {}

if self._index_type == IndexType.RELATIONSHIP:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
texts = ["foo", "bar", "baz", "It is the end of the world. Take shelter!"]

"""
cd tests/integration_tests/vectorstores/docker-compose
cd tests/integration_tests/docker-compose
docker-compose -f neo4j.yml up
"""

Expand Down
47 changes: 47 additions & 0 deletions libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Test Neo4j functionality."""

from langchain_neo4j.vectorstores.neo4j_vector import (
IndexType,
SearchType,
_get_search_index_query,
dict_to_yaml_str,
remove_lucene_chars,
)
Expand Down Expand Up @@ -65,3 +68,47 @@ def test_converting_to_yaml() -> None:
)

assert yaml_str == expected_output


def test_get_search_index_query_hybrid_node_neo4j_5_23_above() -> None:
    """Hybrid node query built with the 5.23+ flag opens with ``CALL () {``."""
    fragments = [
        "CALL () { ",
        "CALL db.index.vector.queryNodes($index, $k, $embedding) ",
        "YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score UNION ",
        "CALL db.index.fulltext.queryNodes($keyword_index, $query, ",
        "{limit: $k}) YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score ",
        "} ",
        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k ",
    ]
    wanted = "".join(fragments)

    produced = _get_search_index_query(
        search_type=SearchType.HYBRID,
        index_type=IndexType.NODE,
        neo4j_version_is_5_23_or_above=True,
    )

    assert produced == wanted


def test_get_search_index_query_hybrid_node_neo4j_5_23_below() -> None:
    """Hybrid node query built without the 5.23+ flag opens with ``CALL {``."""
    fragments = [
        "CALL { ",
        "CALL db.index.vector.queryNodes($index, $k, $embedding) ",
        "YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score UNION ",
        "CALL db.index.fulltext.queryNodes($keyword_index, $query, ",
        "{limit: $k}) YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score ",
        "} ",
        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k ",
    ]
    wanted = "".join(fragments)

    produced = _get_search_index_query(
        search_type=SearchType.HYBRID,
        index_type=IndexType.NODE,
        neo4j_version_is_5_23_or_above=False,
    )

    assert produced == wanted

0 comments on commit 598917a

Please sign in to comment.