Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method for getting the database schema #32

Merged
merged 4 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ htmlcov/
.idea/
.env
docs/build/
.vscode/
140 changes: 140 additions & 0 deletions src/neo4j_genai/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import neo4j


BASE_ENTITY_LABEL = "__Entity__"
EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"]
EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"]
leila-messallem marked this conversation as resolved.
Show resolved Hide resolved

NODE_PROPERTIES_QUERY = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
AND NOT label IN $EXCLUDED_LABELS
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output
"""

REL_PROPERTIES_QUERY = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
AND NOT label in $EXCLUDED_LABELS
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {type: nodeLabels, properties: properties} AS output
"""

REL_QUERY = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node"
UNWIND other AS other_node
WITH * WHERE NOT label IN $EXCLUDED_LABELS
AND NOT other_node IN $EXCLUDED_LABELS
RETURN {start: label, type: property, end: toString(other_node)} AS output
"""


def query_database(
driver: neo4j.Driver, query: str, params: Optional[dict] = None
) -> list[dict[str, Any]]:
"""
Queries the database.

Args:
driver (neo4j.Driver): Neo4j Python driver instance.
query (str): The cypher query.
params (dict, optional): The query parameters. Defaults to None.

Returns:
List[Dict[str, Any]]: the result of the query in json format.
leila-messallem marked this conversation as resolved.
Show resolved Hide resolved
"""
if params is None:
params = {}
data = driver.execute_query(query, params)
return [r.data() for r in data.records]


def get_schema(
leila-messallem marked this conversation as resolved.
Show resolved Hide resolved
driver: neo4j.Driver,
) -> str:
"""
Returns the schema of the graph.

Args:
driver (neo4j.Driver): Neo4j Python driver instance.

Returns:
str: the graph schema information in a serialized format.
"""
node_properties = [
data["output"]
for data in query_database(
driver,
NODE_PROPERTIES_QUERY,
params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)
]

rel_properties = [
data["output"]
for data in query_database(
driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS}
)
]
relationships = [
data["output"]
for data in query_database(
driver,
REL_QUERY,
params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)
]

# Format node properties
formatted_node_props = []
for element in node_properties:
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in element["properties"]]
)
formatted_node_props.append(f"{element['labels']} {{{props_str}}}")

# Format relationship properties
formatted_rel_props = []
for element in rel_properties:
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in element["properties"]]
)
formatted_rel_props.append(f"{element['type']} {{{props_str}}}")

# Format relationships
formatted_rels = [
f"(:{element['start']})-[:{element['type']}]->(:{element['end']})"
for element in relationships
]

return "\n".join(
[
"Node properties:",
"\n".join(formatted_node_props),
"Relationship properties:",
"\n".join(formatted_rel_props),
"The relationships:",
"\n".join(formatted_rels),
]
leila-messallem marked this conversation as resolved.
Show resolved Hide resolved
)
84 changes: 84 additions & 0 deletions tests/unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
from neo4j_genai.schema import (
get_schema,
NODE_PROPERTIES_QUERY,
REL_PROPERTIES_QUERY,
REL_QUERY,
EXCLUDED_LABELS,
BASE_ENTITY_LABEL,
EXCLUDED_RELS,
)


def _query_return_value(*args, **kwargs):
if NODE_PROPERTIES_QUERY in args[1]:
return [
{
"output": {
"properties": [{"property": "property_a", "type": "STRING"}],
"labels": "LabelA",
}
}
]
if REL_PROPERTIES_QUERY in args[1]:
return [
{
"output": {
"type": "REL_TYPE",
"properties": [{"property": "rel_prop", "type": "STRING"}],
}
}
]
if REL_QUERY in args[1]:
return [
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}},
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}},
]

raise AssertionError("Unexpected query")


def test_get_schema_happy_path(driver):
get_schema(driver)
assert 3 == driver.execute_query.call_count
driver.execute_query.assert_any_call(
NODE_PROPERTIES_QUERY,
{"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)
driver.execute_query.assert_any_call(
REL_PROPERTIES_QUERY,
{"EXCLUDED_LABELS": EXCLUDED_RELS},
)
driver.execute_query.assert_any_call(
REL_QUERY,
{"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)


@patch("neo4j_genai.schema.query_database", side_effect=_query_return_value)
def test_get_schema_ensure_formatted_response(driver):
result = get_schema(driver)
assert (
result
== """Node properties:
LabelA {property_a: STRING}
Relationship properties:
REL_TYPE {rel_prop: STRING}
The relationships:
(:LabelA)-[:REL_TYPE]->(:LabelB)
(:LabelA)-[:REL_TYPE]->(:LabelC)"""
)
Loading