From 3489bf5427c2656c72ae10d88bc519287af4d8eb Mon Sep 17 00:00:00 2001 From: jason Date: Fri, 29 Nov 2024 12:24:06 -0500 Subject: [PATCH] feat: add driver connection lifecycle mgmt to adhere to neo4j Driver expectations (#14) * feat:add context/connection lifecycle management including: state checks for neo4jgraph methods, prevent op on closed connections, driver resource management + tests * slight refactor, add unit tests * format and lint * changelog * add defaults for integration tests --- CHANGELOG.md | 5 + .../langchain_neo4j/graphs/neo4j_graph.py | 107 +++++++++++- .../integration_tests/graphs/test_neo4j.py | 159 ++++++++++++++++++ .../unit_tests/graphs/test_neo4j_graph.py | 115 ++++++++++++- 4 files changed, 384 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62e89d0..2d7f3f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## Next +### Added + +- Enhanced Neo4j driver connection management with more robust error handling +- Simplified connection state checking in Neo4jGraph + ## 0.1.1 ### Changed diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index 54bb643..506533c 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -1,5 +1,5 @@ from hashlib import md5 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from langchain_core.utils import get_from_dict_or_env @@ -400,6 +400,18 @@ def __init__( ) raise e + def _check_driver_state(self) -> None: + """ + Check if the driver is available and ready for operations. + + Raises: + RuntimeError: If the driver has been closed or is not initialized. + """ + if not hasattr(self, "_driver"): + raise RuntimeError( + "Cannot perform operations - Neo4j connection has been closed" + ) + @property def get_schema(self) -> str: """Returns the schema of the Graph""" @@ -423,7 +435,11 @@ def query( Returns: List[Dict[str, Any]]: The list of dictionaries containing the query results. + + Raises: + RuntimeError: If the connection has been closed. """ + self._check_driver_state() from neo4j import Query from neo4j.exceptions import Neo4jError @@ -467,7 +483,11 @@ def query( def refresh_schema(self) -> None: """ Refreshes the Neo4j graph schema information. + + Raises: + RuntimeError: If the connection has been closed. """ + self._check_driver_state() from neo4j.exceptions import ClientError, CypherTypeError node_properties = [ @@ -588,7 +608,11 @@ def add_graph_documents( - baseEntityLabel (bool, optional): If True, each newly created node gets a secondary __Entity__ label, which is indexed and improves import speed and performance. Defaults to False. + + Raises: + RuntimeError: If the connection has been closed. """ + self._check_driver_state() if baseEntityLabel: # Check if constraint already exists constraint_exists = any( [ @@ -810,3 +834,84 @@ def _enhanced_schema_cypher( # Combine all parts of the Cypher query cypher_query = "\n".join([match_clause, with_clause, return_clause]) return cypher_query + + def close(self) -> None: + """ + Explicitly close the Neo4j driver connection. + + Delegates connection management to the Neo4j driver. + """ + if hasattr(self, "_driver"): + self._driver.close() + # Remove the driver attribute to indicate closure + delattr(self, "_driver") + + def __enter__(self) -> "Neo4jGraph": + """ + Enter the runtime context for the Neo4j graph connection. + + Enables use of the graph connection with the 'with' statement. + This method allows for automatic resource management and ensures + that the connection is properly handled. + + Returns: + Neo4jGraph: The current graph connection instance + + Example: + with Neo4jGraph(...) as graph: + graph.query(...) # Connection automatically managed + """ + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + """ + Exit the runtime context for the Neo4j graph connection. + + This method is automatically called when exiting a 'with' statement. + It ensures that the database connection is closed, regardless of + whether an exception occurred during the context's execution. + + Args: + exc_type: The type of exception that caused the context to exit + (None if no exception occurred) + exc_val: The exception instance that caused the context to exit + (None if no exception occurred) + exc_tb: The traceback for the exception (None if no exception occurred) + + Note: + Any exception is re-raised after the connection is closed. + """ + self.close() + + def __del__(self) -> None: + """ + Destructor for the Neo4j graph connection. + + This method is called during garbage collection to ensure that + database resources are released if not explicitly closed. + + Caution: + - Do not rely on this method for deterministic resource cleanup + - Always prefer explicit .close() or context manager + + Best practices: + 1. Use context manager: + with Neo4jGraph(...) as graph: + ... + 2. Explicitly close: + graph = Neo4jGraph(...) + try: + ... + finally: + graph.close() + """ + try: + self.close() + except Exception: + # Suppress any exceptions during garbage collection + pass diff --git a/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py b/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py index 1a12276..17fcace 100644 --- a/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py @@ -398,3 +398,162 @@ def test_backticks() -> None: assert nodes == expected_nodes assert rels == expected_rels + + +def test_neo4j_context_manager() -> None: + """Test that Neo4jGraph works correctly with context manager.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + assert url is not None + assert username is not None + assert password is not None + + with Neo4jGraph(url=url, username=username, password=password) as graph: + # Test that the connection is working + graph.query("RETURN 1 as n") + + # Test that the connection is closed after exiting context + try: + graph.query("RETURN 1 as n") + assert False, "Expected RuntimeError when using closed connection" + except RuntimeError: + pass + + +def test_neo4j_explicit_close() -> None: + """Test that Neo4jGraph can be explicitly closed.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password) + # Test that the connection is working + graph.query("RETURN 1 as n") + + # Close the connection + graph.close() + + # Test that the connection is closed + try: + graph.query("RETURN 1 as n") + assert False, "Expected RuntimeError when using closed connection" + except RuntimeError: + pass + + +def test_neo4j_error_after_close() -> None: + """Test that Neo4jGraph operations raise proper errors after closing.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password) + graph.query("RETURN 1") # Should work + graph.close() + + # Test various operations after close + try: + graph.refresh_schema() + assert ( + False + ), "Expected RuntimeError when refreshing schema on closed connection" + except RuntimeError as e: + assert "connection has been closed" in str(e) + + try: + graph.query("RETURN 1") + assert False, "Expected RuntimeError when querying closed connection" + except RuntimeError as e: + assert "connection has been closed" in str(e) + + try: + graph.add_graph_documents([test_data[0]]) + assert False, "Expected RuntimeError when adding documents to closed connection" + except RuntimeError as e: + assert "connection has been closed" in str(e) + + +def test_neo4j_concurrent_connections() -> None: + """Test that multiple Neo4jGraph instances can be used independently.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + assert url is not None + assert username is not None + assert password is not None + + graph1 = Neo4jGraph(url=url, username=username, password=password) + graph2 = Neo4jGraph(url=url, username=username, password=password) + + # Both connections should work independently + assert graph1.query("RETURN 1 as n") == [{"n": 1}] + assert graph2.query("RETURN 2 as n") == [{"n": 2}] + + # Closing one shouldn't affect the other + graph1.close() + try: + graph1.query("RETURN 1") + assert False, "Expected RuntimeError when using closed connection" + except RuntimeError: + pass + assert graph2.query("RETURN 2 as n") == [{"n": 2}] + + graph2.close() + + +def test_neo4j_nested_context_managers() -> None: + """Test that nested context managers work correctly.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + assert url is not None + assert username is not None + assert password is not None + + with Neo4jGraph(url=url, username=username, password=password) as graph1: + with Neo4jGraph(url=url, username=username, password=password) as graph2: + # Both connections should work + assert graph1.query("RETURN 1 as n") == [{"n": 1}] + assert graph2.query("RETURN 2 as n") == [{"n": 2}] + + # Inner connection should be closed, outer still works + try: + graph2.query("RETURN 2") + assert False, "Expected RuntimeError when using closed connection" + except RuntimeError: + pass + assert graph1.query("RETURN 1 as n") == [{"n": 1}] + + # Both connections should be closed + try: + graph1.query("RETURN 1") + assert False, "Expected RuntimeError when using closed connection" + except RuntimeError: + pass + try: + graph2.query("RETURN 2") + assert False, "Expected RuntimeError when using closed connection" + except RuntimeError: + pass + + +def test_neo4j_multiple_close() -> None: + """Test that Neo4jGraph can be closed multiple times without error.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password) + # Test that multiple closes don't raise errors + graph.close() + graph.close() # This should not raise an error diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 12daf57..0a74ffa 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -1,4 +1,8 @@ -from langchain_neo4j.graphs.neo4j_graph import value_sanitize +from unittest.mock import MagicMock, patch + +import pytest + +from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph, value_sanitize def test_value_sanitize_with_small_list(): # type: ignore[no-untyped-def] @@ -39,3 +43,112 @@ def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-d } expected_output = {"key1": "value1", "deeply_nested_lists": [[[[{}]]]]} assert value_sanitize(input_dict) == expected_output + + +def test_driver_state_management(): # type: ignore[no-untyped-def] + """Comprehensive test for driver state management.""" + with patch("neo4j.GraphDatabase.driver") as mock_driver: + # Setup mock driver + mock_driver_instance = MagicMock() + mock_driver.return_value = mock_driver_instance + mock_driver_instance.execute_query = MagicMock(return_value=([], None, None)) + + # Create graph instance + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + + # Store original driver + original_driver = graph._driver + original_driver.close = MagicMock() + + # Test initial state + assert hasattr(graph, "_driver") + + # First close + graph.close() + original_driver.close.assert_called_once() + assert not hasattr(graph, "_driver") + + # Verify methods raise error when driver is closed + with pytest.raises( + RuntimeError, + match="Cannot perform operations - Neo4j connection has been closed", + ): + graph.query("RETURN 1") + + with pytest.raises( + RuntimeError, + match="Cannot perform operations - Neo4j connection has been closed", + ): + graph.refresh_schema() + + +def test_close_method_removes_driver(): # type: ignore[no-untyped-def] + """Test that close method removes the _driver attribute.""" + with patch("neo4j.GraphDatabase.driver") as mock_driver: + # Configure mock to return a mock driver + mock_driver_instance = MagicMock() + mock_driver.return_value = mock_driver_instance + + # Configure mock execute_query to return empty result + mock_driver_instance.execute_query = MagicMock(return_value=([], None, None)) + + # Add a _closed attribute to simulate driver state + mock_driver_instance._closed = False + + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + + # Store a reference to the original driver + original_driver = graph._driver + + # Ensure driver's close method can be mocked + original_driver.close = MagicMock() + + # Call close method + graph.close() + + # Verify driver.close was called + original_driver.close.assert_called_once() + + # Verify _driver attribute is removed + assert not hasattr(graph, "_driver") + + # Verify second close does not raise an error + graph.close() # Should not raise any exception + + +def test_multiple_close_calls_safe(): # type: ignore[no-untyped-def] + """Test that multiple close calls do not raise errors.""" + with patch("neo4j.GraphDatabase.driver") as mock_driver: + # Configure mock to return a mock driver + mock_driver_instance = MagicMock() + mock_driver.return_value = mock_driver_instance + + # Configure mock execute_query to return empty result + mock_driver_instance.execute_query = MagicMock(return_value=([], None, None)) + + # Add a _closed attribute to simulate driver state + mock_driver_instance._closed = False + + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + + # Store a reference to the original driver + original_driver = graph._driver + + # Mock the driver's close method + original_driver.close = MagicMock() + + # First close + graph.close() + original_driver.close.assert_called_once() + + # Verify _driver attribute is removed + assert not hasattr(graph, "_driver") + + # Second close should not raise an error + graph.close() # Should not raise any exception