From bbfb30c248b16258d00db34c236a574a8eacb21e Mon Sep 17 00:00:00 2001 From: Leila Messallem Date: Mon, 20 May 2024 16:39:44 +0200 Subject: [PATCH] Change default argument from empty dict to None --- src/neo4j_genai/schema.py | 16 +++++++++------- tests/unit/test_schema.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/neo4j_genai/schema.py b/src/neo4j_genai/schema.py index 426a8d13..6f5fdc37 100644 --- a/src/neo4j_genai/schema.py +++ b/src/neo4j_genai/schema.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import neo4j @@ -50,8 +50,8 @@ """ -def _query_database( - driver: neo4j.Driver, query: str, params: dict = {} +def query_database( + driver: neo4j.Driver, query: str, params: Optional[dict] = None ) -> list[dict[str, Any]]: """ Queries the database. @@ -59,11 +59,13 @@ def _query_database( Args: driver (neo4j.Driver): Neo4j Python driver instance. query (str): The cypher query. - params (dict, optional): The query parameters. Defaults to {}. + params (dict, optional): The query parameters. Defaults to None. Returns: List[Dict[str, Any]]: the result of the query in json format. """ + if params is None: + params = {} data = driver.execute_query(query, params) return [r.data() for r in data.records] @@ -82,7 +84,7 @@ def get_schema( """ node_properties = [ data["output"] - for data in _query_database( + for data in query_database( driver, NODE_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, @@ -91,13 +93,13 @@ def get_schema( rel_properties = [ data["output"] - for data in _query_database( + for data in query_database( driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS} ) ] relationships = [ data["output"] - for data in _query_database( + for data in query_database( driver, REL_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 8a38b9e1..24d3ed8a 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -69,7 +69,7 @@ def test_get_schema_happy_path(driver): ) -@patch("neo4j_genai.schema._query_database", side_effect=_query_return_value) +@patch("neo4j_genai.schema.query_database", side_effect=_query_return_value) def test_get_schema_ensure_formatted_response(driver): result = get_schema(driver) assert (