Skip to content

Commit

Permalink
Add check to not use deprecated Cypher syntax when Neo4j version is >…
Browse files Browse the repository at this point in the history
…= 5.23.0 (neo4j#183)

* Add check to not use deprecated Cypher syntax when Neo4j version is >= 5.23.0

* Update CHANGELOG

* Add variable scope query in Hybrid Retriever based on neo4j version

* Include E2E test to test for deprecation warning from deprecated Cypher subquery syntax

* Resolve mypy errors

* Add neo4j:latest to pr and scheduled E2E tests
  • Loading branch information
willtai authored Oct 21, 2024
1 parent d0fe5ea commit d0528f4
Show file tree
Hide file tree
Showing 15 changed files with 500 additions and 66 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/pr-e2e-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ jobs:
strategy:
matrix:
python-version: ['3.9', '3.12']
neo4j-version:
- 5
neo4j-edition:
- enterprise
neo4j-tag:
- 'latest'
services:
t2v-transformers:
image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx
Expand All @@ -37,7 +35,7 @@ jobs:
- 8080:8080
- 50051:50051
neo4j:
image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }}
image: neo4j:${{ matrix.neo4j-tag }}
env:
NEO4J_AUTH: neo4j/password
NEO4J_ACCEPT_LICENSE_AGREEMENT: 'eval'
Expand Down Expand Up @@ -93,7 +91,7 @@ jobs:
- name: Run tests
shell: bash
run: |
if [[ "${{ matrix.neo4j-edition }}" == "community" ]]; then
if [[ "${{ matrix.neo4j-tag }}" == "latest" || "${{ matrix.neo4j-tag }}" == *-community ]]; then
poetry run pytest -m 'not enterprise_only' ./tests/e2e
else
poetry run pytest ./tests/e2e
Expand Down
13 changes: 6 additions & 7 deletions .github/workflows/scheduled-e2e-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ jobs:
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
neo4j-version:
- 5
neo4j-edition:
- community
- enterprise
neo4j-tag:
- '5-community'
- '5-enterprise'
- 'latest'
services:
t2v-transformers:
image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx
Expand All @@ -41,7 +40,7 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
neo4j:
image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }}
image: neo4j:${{ matrix.neo4j-tag }}
env:
NEO4J_AUTH: neo4j/password
NEO4J_ACCEPT_LICENSE_AGREEMENT: 'eval'
Expand Down Expand Up @@ -100,7 +99,7 @@ jobs:
- name: Run tests
shell: bash
run: |
if [[ "${{ matrix.neo4j-edition }}" == "community" ]]; then
if [[ "${{ matrix.neo4j-tag }}" == "latest" || "${{ matrix.neo4j-tag }}" == *-community ]]; then
poetry run pytest -m 'not enterprise_only' ./tests/e2e
else
poetry run pytest ./tests/e2e
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Added
- Made `relations` and `potential_schema` optional in `SchemaBuilder`.
- Added a check to prevent the use of deprecated Cypher syntax for Neo4j versions 5.23.0 and above.

## 1.1.0

Expand Down
69 changes: 54 additions & 15 deletions src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
Neo4jRelationship,
)
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.neo4j_queries import UPSERT_NODE_QUERY, UPSERT_RELATIONSHIP_QUERY
from neo4j_graphrag.neo4j_queries import (
UPSERT_NODE_QUERY,
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE,
UPSERT_RELATIONSHIP_QUERY,
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,6 +118,7 @@ def __init__(
self.neo4j_database = neo4j_database
self.batch_size = batch_size
self.max_concurrency = max_concurrency
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()

def _db_setup(self) -> None:
# create index on __Entity__.id
Expand Down Expand Up @@ -147,7 +153,12 @@ def _upsert_nodes(self, nodes: list[Neo4jNode]) -> None:
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
"""
parameters = {"rows": self._nodes_to_rows(nodes)}
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
if self.is_version_5_23_or_above:
self.driver.execute_query(
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
)
else:
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)

async def _async_upsert_nodes(
self,
Expand All @@ -161,7 +172,32 @@ async def _async_upsert_nodes(
"""
async with sem:
parameters = {"rows": self._nodes_to_rows(nodes)}
await self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
await self.driver.execute_query(
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
)

def _get_version(self) -> tuple[int, ...]:
records, _, _ = self.driver.execute_query(
"CALL dbms.components()", database_=self.neo4j_database
)
version = records[0]["versions"][0]
# Drop everything after the '-' first
version_main, *_ = version.split("-")
# Convert each number between '.' into int
version_tuple = tuple(map(int, version_main.split(".")))
# If no patch version, consider it's 0
if len(version_tuple) < 3:
version_tuple = (*version_tuple, 0)
return version_tuple

def _check_if_version_5_23_or_above(self) -> bool:
"""
Check if the connected Neo4j database version supports the required features.
Sets a flag if the connected Neo4j version is 5.23 or above.
"""
version_tuple = self._get_version()
return version_tuple >= (5, 23, 0)

def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
"""Upserts a single relationship into the Neo4j database.
Expand All @@ -170,7 +206,12 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
"""
parameters = {"rows": [rel.model_dump() for rel in rels]}
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)
if self.is_version_5_23_or_above:
self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
)
else:
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)

async def _async_upsert_relationships(
self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore
Expand All @@ -182,9 +223,15 @@ async def _async_upsert_relationships(
"""
async with sem:
parameters = {"rows": [rel.model_dump() for rel in rels]}
await self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY, parameters_=parameters
)
if self.is_version_5_23_or_above:
await self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
parameters_=parameters,
)
else:
await self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY, parameters_=parameters
)

@validate_call
async def run(self, graph: Neo4jGraph) -> KGWriterModel:
Expand All @@ -193,12 +240,6 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
Args:
graph (Neo4jGraph): The knowledge graph to upsert into the database.
"""
# we disable the notification logger to get rid of the deprecation
# warning about Cypher subqueries. Once the queries are updated
# for Neo4j 5.23, we can remove this line and the 'finally' block
notification_logger = logging.getLogger("neo4j.notifications")
notification_level = notification_logger.level
notification_logger.setLevel(logging.ERROR)
try:
if inspect.iscoroutinefunction(self.driver.execute_query):
await self._async_db_setup()
Expand Down Expand Up @@ -233,5 +274,3 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
except neo4j.exceptions.ClientError as e:
logger.exception(e)
return KGWriterModel(status="FAILURE", metadata={"error": str(e)})
finally:
notification_logger.setLevel(notification_level)
72 changes: 58 additions & 14 deletions src/neo4j_graphrag/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@
"RETURN elementId(n)"
)

# Batched node write using the variable scope clause form of CALL subqueries
# (`CALL (n, row) { ... }`); the KG writer selects it when the server is
# Neo4j 5.23 or above.
# For every entry of $rows it:
#   1. creates (CREATE, not MERGE) a `__KGBuilder__` node with `id: row.id`
#      and adds `row.properties` to it,
#   2. adds `row.labels` through `apoc.create.addLabels`,
#   3. when `row.embedding_properties` is not null, stores each of its keys
#      with `db.create.setNodeVectorProperty`,
# and finally returns `elementId(n)`.
# NOTE(review): the leading `WITH n, row` inside the subquery looks like a
# carry-over from the importing-WITH form and may be redundant here - confirm
# before removing.
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE = (
    "UNWIND $rows AS row "
    "CREATE (n:__KGBuilder__ {id: row.id}) "
    "SET n += row.properties "
    "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node "
    "WITH node as n, row CALL (n, row) { "
    "WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL "
    "UNWIND keys(row.embedding_properties) as emb "
    "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) "
    "RETURN count(*) as nbEmb "
    "} "
    "RETURN elementId(n)"
)

UPSERT_RELATIONSHIP_QUERY = (
"UNWIND $rows as row "
"MATCH (start:__KGBuilder__ {id: row.start_node_id}) "
Expand All @@ -69,6 +83,21 @@
"RETURN elementId(rel)"
)


# Batched relationship upsert using the variable scope clause form of CALL
# subqueries (`CALL (rel, row) { ... }`); the KG writer selects it when the
# server is Neo4j 5.23 or above.
# For every entry of $rows it:
#   1. matches the `__KGBuilder__` start and end nodes by
#      `row.start_node_id` / `row.end_node_id`,
#   2. merges a relationship of type `row.type` between them with
#      `apoc.merge.relationship`, passing `row.properties` for both the
#      on-create and on-match property maps,
#   3. when `row.embedding_properties` is not null, stores each of its keys
#      with `db.create.setRelationshipVectorProperty`,
# and finally returns `elementId(rel)`.
# The subquery has no RETURN clause, unlike the node query's
# `RETURN count(*) as nbEmb`.
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = (
    "UNWIND $rows as row "
    "MATCH (start:__KGBuilder__ {id: row.start_node_id}) "
    "MATCH (end:__KGBuilder__ {id: row.end_node_id}) "
    "WITH start, end, row "
    "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel "
    "WITH rel, row CALL (rel, row) { "
    "WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL "
    "UNWIND keys(row.embedding_properties) as emb "
    "CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) "
    "} "
    "RETURN elementId(rel)"
)

UPSERT_VECTOR_ON_NODE_QUERY = (
"MATCH (n) "
"WHERE elementId(n) = $id "
Expand All @@ -86,19 +115,33 @@
)


def _get_hybrid_query() -> str:
return (
f"CALL {{ {VECTOR_INDEX_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
f"UNION "
f"{FULL_TEXT_SEARCH_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
)
def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
    """Build the Cypher fragment that runs the hybrid (vector + full-text) search.

    The fragment wraps two searches in a single ``CALL`` subquery joined by
    ``UNION``: ``VECTOR_INDEX_QUERY`` and ``FULL_TEXT_SEARCH_QUERY``. Each
    side divides its scores by that side's maximum score. After the subquery,
    the best score per node is kept, ordered descending and limited to
    ``$top_k``.

    Args:
        neo4j_version_is_5_23_or_above (bool): When true, the subquery is
            opened with the variable scope clause form ``CALL () {``;
            otherwise the older ``CALL {`` form is used.

    Returns:
        str: The hybrid search query text.
    """
    # The opener is the only difference between the two variants; keeping a
    # single template means the two can never drift apart.
    call_opener = "CALL () {" if neo4j_version_is_5_23_or_above else "CALL {"
    return (
        f"{call_opener} {VECTOR_INDEX_QUERY} "
        f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
        f"UNWIND nodes AS n "
        f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
        f"UNION "
        f"{FULL_TEXT_SEARCH_QUERY} "
        f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
        f"UNWIND nodes AS n "
        f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
        f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
    )


def _get_filtered_vector_query(
Expand Down Expand Up @@ -139,6 +182,7 @@ def get_search_query(
embedding_node_property: Optional[str] = None,
embedding_dimension: Optional[int] = None,
filters: Optional[dict[str, Any]] = None,
neo4j_version_is_5_23_or_above: bool = False,
) -> tuple[str, dict[str, Any]]:
"""Build the search query, including pre-filtering if needed, and return clause.
Expand All @@ -160,7 +204,7 @@ def get_search_query(
if search_type == SearchType.HYBRID:
if filters:
raise Exception("Filters are not supported with Hybrid Search")
query = _get_hybrid_query()
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
params: dict[str, Any] = {}
elif search_type == SearchType.VECTOR:
if filters:
Expand Down
11 changes: 11 additions & 0 deletions src/neo4j_graphrag/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ def _get_version(self) -> tuple[tuple[int, ...], bool]:
version_tuple = (*version_tuple, 0)
return version_tuple, "aura" in version

def _check_if_version_5_23_or_above(self, version_tuple: tuple[int, ...]) -> bool:
"""
Check if the connected Neo4j database version supports the required features.
Sets a flag if the connected Neo4j version is 5.23 or above.
"""
return version_tuple >= (5, 23, 0)

def _verify_version(self) -> None:
"""
Check if the connected Neo4j database version supports vector indexing.
Expand All @@ -111,6 +119,9 @@ def _verify_version(self) -> None:
not supported.
"""
version_tuple, is_aura = self._get_version()
self.neo4j_version_is_5_23_or_above = self._check_if_version_5_23_or_above(
version_tuple
)

if is_aura:
target_version = (5, 18, 0)
Expand Down
3 changes: 0 additions & 3 deletions src/neo4j_graphrag/retrievers/external/pinecone/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
from typing import Any, Callable, Optional, Union

import neo4j


from pinecone import Pinecone

from pydantic import (
BaseModel,
ConfigDict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Callable, Optional

import neo4j
import weaviate.classes as wvc
Expand Down
10 changes: 8 additions & 2 deletions src/neo4j_graphrag/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ def get_search_results(
query_vector = self.embedder.embed_query(query_text)
parameters["query_vector"] = query_vector

search_query, _ = get_search_query(SearchType.HYBRID, self.return_properties)
search_query, _ = get_search_query(
SearchType.HYBRID,
self.return_properties,
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
)

logger.debug("HybridRetriever Cypher parameters: %s", parameters)
logger.debug("HybridRetriever Cypher query: %s", search_query)
Expand Down Expand Up @@ -336,7 +340,9 @@ def get_search_results(
del parameters["query_params"]

search_query, _ = get_search_query(
SearchType.HYBRID, retrieval_query=self.retrieval_query
SearchType.HYBRID,
retrieval_query=self.retrieval_query,
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
)

logger.debug("HybridCypherRetriever Cypher parameters: %s", parameters)
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ services:
environment:
ENABLE_CUDA: "0"
neo4j:
image: neo4j:5-enterprise
image: neo4j:5.24-enterprise
ports:
- 7687:7687
- 7474:7474
Expand Down
Loading

0 comments on commit d0528f4

Please sign in to comment.