Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Leila Messallem committed May 17, 2024
1 parent abef5df commit 77215af
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
28 changes: 14 additions & 14 deletions src/neo4j_genai/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@

def _query(
driver: neo4j.Driver,
query: str,
query: str,
params: dict = {}
) -> List[Dict[str, Any]]:
"""
Queries the database.
Args:
driver (neo4j.Driver): Neo4j Python driver instance.
query (str): The cypher query.
Expand All @@ -70,13 +70,13 @@ def _query(
Returns:
List[Dict[str, Any]]: the result of the query in json format.
"""
try:
try:
data = driver.execute_query(query, params)
json_data = [r.data() for r in data.records]
return json_data
except CypherSyntaxError as e:
raise ValueError(f"Cypher Statement is not valid: {e}")


def get_schema(
driver: neo4j.Driver,
Expand All @@ -103,7 +103,7 @@ def get_schema(
el["output"]
for el in _query(
driver,
REL_PROPERTIES_QUERY,
REL_PROPERTIES_QUERY,
params={"EXCLUDED_LABELS": EXCLUDED_RELS}
)
]
Expand All @@ -115,7 +115,7 @@ def get_schema(
params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)
]

# Format node properties
formatted_node_props = []
for el in node_properties:
Expand All @@ -138,14 +138,14 @@ def get_schema(
]

formatted_schema = "\n".join(
[
"Node properties:",
"\n".join(formatted_node_props),
"Relationship properties:",
"\n".join(formatted_rel_props),
"The relationships:",
"\n".join(formatted_rels),
]
[
"Node properties:",
"\n".join(formatted_node_props),
"Relationship properties:",
"\n".join(formatted_rel_props),
"The relationships:",
"\n".join(formatted_rels),
]
)

return formatted_schema
8 changes: 4 additions & 4 deletions tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
# limitations under the License.
from neo4j_genai.schema import get_schema, NODE_PROPERTIES_QUERY, REL_PROPERTIES_QUERY, REL_QUERY, EXCLUDED_LABELS, BASE_ENTITY_LABEL, EXCLUDED_RELS


def test_get_schema_happy_path(driver):
get_schema(driver)
assert 3 == driver.execute_query.call_count
driver.execute_query.assert_any_call(
NODE_PROPERTIES_QUERY,
NODE_PROPERTIES_QUERY,
{"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)
driver.execute_query.assert_any_call(
REL_PROPERTIES_QUERY,
REL_PROPERTIES_QUERY,
{"EXCLUDED_LABELS": EXCLUDED_RELS},
)
driver.execute_query.assert_any_call(
REL_QUERY,
REL_QUERY,
{"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)

0 comments on commit 77215af

Please sign in to comment.