diff --git a/.gitignore b/.gitignore index 16c7907e..489b0412 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ htmlcov/ .idea/ .env docs/build/ +.vscode/ diff --git a/src/neo4j_genai/schema.py b/src/neo4j_genai/schema.py new file mode 100644 index 00000000..7a2982b0 --- /dev/null +++ b/src/neo4j_genai/schema.py @@ -0,0 +1,140 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +import neo4j + + +BASE_ENTITY_LABEL = "__Entity__" +EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"] +EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"] + +NODE_PROPERTIES_QUERY = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE NOT type = "RELATIONSHIP" AND elementType = "node" + AND NOT label IN $EXCLUDED_LABELS +WITH label AS nodeLabels, collect({property:property, type:type}) AS properties +RETURN {labels: nodeLabels, properties: properties} AS output +""" + +REL_PROPERTIES_QUERY = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship" + AND NOT label in $EXCLUDED_LABELS +WITH label AS nodeLabels, collect({property:property, type:type}) AS properties +RETURN {type: nodeLabels, properties: properties} AS output +""" + +REL_QUERY = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE type = "RELATIONSHIP" AND elementType = "node" 
+UNWIND other AS other_node +WITH * WHERE NOT label IN $EXCLUDED_LABELS + AND NOT other_node IN $EXCLUDED_LABELS +RETURN {start: label, type: property, end: toString(other_node)} AS output +""" + + +def query_database( + driver: neo4j.Driver, query: str, params: Optional[dict] = None +) -> list[dict[str, Any]]: + """ + Queries the database. + + Args: + driver (neo4j.Driver): Neo4j Python driver instance. + query (str): The cypher query. + params (dict, optional): The query parameters. Defaults to None. + + Returns: + list[dict[str, Any]]: the result of the query in json format. + """ + if params is None: + params = {} + data = driver.execute_query(query, params) + return [r.data() for r in data.records] + + +def get_schema( + driver: neo4j.Driver, +) -> str: + """ + Returns the schema of the graph. + + Args: + driver (neo4j.Driver): Neo4j Python driver instance. + + Returns: + str: the graph schema information in a serialized format. + """ + node_properties = [ + data["output"] + for data in query_database( + driver, + NODE_PROPERTIES_QUERY, + params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, + ) + ] + + rel_properties = [ + data["output"] + for data in query_database( + driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS} + ) + ] + relationships = [ + data["output"] + for data in query_database( + driver, + REL_QUERY, + params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, + ) + ] + + # Format node properties + formatted_node_props = [] + for element in node_properties: + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in element["properties"]] + ) + formatted_node_props.append(f"{element['labels']} {{{props_str}}}") + + # Format relationship properties + formatted_rel_props = [] + for element in rel_properties: + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in element["properties"]] + ) + formatted_rel_props.append(f"{element['type']} {{{props_str}}}") + + # 
Format relationships + formatted_rels = [ + f"(:{element['start']})-[:{element['type']}]->(:{element['end']})" + for element in relationships + ] + + return "\n".join( + [ + "Node properties:", + "\n".join(formatted_node_props), + "Relationship properties:", + "\n".join(formatted_rel_props), + "The relationships:", + "\n".join(formatted_rels), + ] + ) diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py new file mode 100644 index 00000000..24d3ed8a --- /dev/null +++ b/tests/unit/test_schema.py @@ -0,0 +1,84 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from unittest.mock import patch
from neo4j_genai.schema import (
    get_schema,
    NODE_PROPERTIES_QUERY,
    REL_PROPERTIES_QUERY,
    REL_QUERY,
    EXCLUDED_LABELS,
    BASE_ENTITY_LABEL,
    EXCLUDED_RELS,
)


def _query_return_value(*args, **kwargs):
    """Stand-in for ``query_database``: return canned rows chosen by the query text.

    Used as a mock ``side_effect``, so it receives the same arguments as the
    patched function; the query string is the second positional argument.
    Raises AssertionError for any query other than the three schema queries.
    """
    query = args[1]
    if NODE_PROPERTIES_QUERY in query:
        return [
            {
                "output": {
                    "properties": [{"property": "property_a", "type": "STRING"}],
                    "labels": "LabelA",
                }
            }
        ]
    if REL_PROPERTIES_QUERY in query:
        return [
            {
                "output": {
                    "type": "REL_TYPE",
                    "properties": [{"property": "rel_prop", "type": "STRING"}],
                }
            }
        ]
    if REL_QUERY in query:
        return [
            {"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}},
            {"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}},
        ]

    raise AssertionError("Unexpected query")


def test_get_schema_happy_path(driver):
    """get_schema issues exactly the three schema queries with the expected exclusions."""
    get_schema(driver)
    assert driver.execute_query.call_count == 3
    driver.execute_query.assert_any_call(
        NODE_PROPERTIES_QUERY,
        {"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
    )
    driver.execute_query.assert_any_call(
        REL_PROPERTIES_QUERY,
        {"EXCLUDED_LABELS": EXCLUDED_RELS},
    )
    driver.execute_query.assert_any_call(
        REL_QUERY,
        {"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
    )


def test_get_schema_ensure_formatted_response(driver):
    """get_schema renders the rows returned by query_database in the documented text layout."""
    # patch() is used as a context manager on purpose: as a decorator it passes
    # the created mock as an extra positional argument, which would be bound to
    # the single ``driver`` parameter in place of the pytest fixture.
    with patch(
        "neo4j_genai.schema.query_database", side_effect=_query_return_value
    ) as query_database_mock:
        result = get_schema(driver)

    # One call each for node properties, relationship properties, relationships.
    assert query_database_mock.call_count == 3
    assert (
        result
        == """Node properties:
LabelA {property_a: STRING}
Relationship properties:
REL_TYPE {rel_prop: STRING}
The relationships:
(:LabelA)-[:REL_TYPE]->(:LabelB)
(:LabelA)-[:REL_TYPE]->(:LabelC)"""
    )