From 77215afd63ea0721aacf07af6041466ef8217a2b Mon Sep 17 00:00:00 2001 From: Leila Messallem Date: Fri, 17 May 2024 08:34:17 +0200 Subject: [PATCH] formatting --- src/neo4j_genai/schema.py | 28 ++++++++++++++-------------- tests/unit/test_schema.py | 8 ++++---- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/neo4j_genai/schema.py b/src/neo4j_genai/schema.py index 26f5b7a2..c18164c5 100644 --- a/src/neo4j_genai/schema.py +++ b/src/neo4j_genai/schema.py @@ -53,12 +53,12 @@ def _query( driver: neo4j.Driver, - query: str, + query: str, params: dict = {} ) -> List[Dict[str, Any]]: """ Queries the database. - + Args: driver (neo4j.Driver): Neo4j Python driver instance. query (str): The cypher query. @@ -70,13 +70,13 @@ def _query( Returns: List[Dict[str, Any]]: the result of the query in json format. """ - try: + try: data = driver.execute_query(query, params) json_data = [r.data() for r in data.records] return json_data except CypherSyntaxError as e: raise ValueError(f"Cypher Statement is not valid: {e}") - + def get_schema( driver: neo4j.Driver, @@ -103,7 +103,7 @@ def get_schema( el["output"] for el in _query( driver, - REL_PROPERTIES_QUERY, + REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS} ) ] @@ -115,7 +115,7 @@ def get_schema( params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, ) ] - + # Format node properties formatted_node_props = [] for el in node_properties: @@ -138,14 +138,14 @@ def get_schema( ] formatted_schema = "\n".join( - [ - "Node properties:", - "\n".join(formatted_node_props), - "Relationship properties:", - "\n".join(formatted_rel_props), - "The relationships:", - "\n".join(formatted_rels), - ] + [ + "Node properties:", + "\n".join(formatted_node_props), + "Relationship properties:", + "\n".join(formatted_rel_props), + "The relationships:", + "\n".join(formatted_rels), + ] ) return formatted_schema diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 05382aa4..36680b1d 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -14,19 +14,19 @@ # limitations under the License. from neo4j_genai.schema import get_schema, NODE_PROPERTIES_QUERY, REL_PROPERTIES_QUERY, REL_QUERY, EXCLUDED_LABELS, BASE_ENTITY_LABEL, EXCLUDED_RELS - + def test_get_schema_happy_path(driver): get_schema(driver) assert 3 == driver.execute_query.call_count driver.execute_query.assert_any_call( - NODE_PROPERTIES_QUERY, + NODE_PROPERTIES_QUERY, {"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, ) driver.execute_query.assert_any_call( - REL_PROPERTIES_QUERY, + REL_PROPERTIES_QUERY, {"EXCLUDED_LABELS": EXCLUDED_RELS}, ) driver.execute_query.assert_any_call( - REL_QUERY, + REL_QUERY, {"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, )