From 9efab3ed662fd789d2fc4f50c07ca6cb6d20a078 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Wed, 24 Apr 2024 23:14:41 +0200 Subject: [PATCH] community[patch]: Add driver config param for neo4j graph (#20772) Co-authored-by: Bagatur --- .../langchain_community/graphs/neo4j_graph.py | 7 ++++++- .../integration_tests/graphs/test_neo4j.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index 72d704a411fc5..fbe025a5d9c7d 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -151,6 +151,7 @@ class Neo4jGraph(GraphStore): embedding-like properties from database responses. Default is False. refresh_schema (bool): A flag whether to refresh schema information at initialization. Default is True. + driver_config (Dict): Configuration passed to Neo4j Driver. *Security note*: Make sure that the database connection uses credentials that are narrowly-scoped to only include necessary permissions. @@ -173,6 +174,8 @@ def __init__( timeout: Optional[float] = None, sanitize: bool = False, refresh_schema: bool = True, + *, + driver_config: Optional[Dict] = None, ) -> None: """Create a new Neo4j graph wrapper instance.""" try: @@ -194,7 +197,9 @@ def __init__( {"database": database}, "database", "NEO4J_DATABASE", "neo4j" ) - self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) + self._driver = neo4j.GraphDatabase.driver( + url, auth=(username, password), **(driver_config or {}) + ) self._database = database self.timeout = timeout self.sanitize = sanitize diff --git a/libs/community/tests/integration_tests/graphs/test_neo4j.py b/libs/community/tests/integration_tests/graphs/test_neo4j.py index 50e0a9f7244ad..c87b8514fec4d 100644 --- a/libs/community/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/community/tests/integration_tests/graphs/test_neo4j.py @@ -273,3 +273,21 @@ def test_neo4j_filtering_labels() -> None: # Assert both are empty assert graph.structured_schema["node_props"] == {} assert graph.structured_schema["relationships"] == [] + + +def test_driver_config() -> None: + """Test that neo4j works with driver config.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph( + url=url, + username=username, + password=password, + driver_config={"max_connection_pool_size": 1}, + ) + graph.query("RETURN 'foo'")