Skip to content

Commit

Permalink
feat: add driver connection lifecycle mgmt to adhere to neo4j Driver …
Browse files Browse the repository at this point in the history
…expectations (langchain-ai#14)

* feat: add context/connection lifecycle management, including state checks for Neo4jGraph methods, prevention of operations on closed connections, and driver resource management, plus tests

* slight refactor, add unit tests

* format and lint

* changelog

* add defaults for integration tests
  • Loading branch information
jtanningbed authored Nov 29, 2024
1 parent 0889744 commit 3489bf5
Show file tree
Hide file tree
Showing 4 changed files with 384 additions and 2 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## Next

### Added

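- Enhanced Neo4j driver connection management with more robust error handling: `Neo4jGraph` now supports an explicit `close()`, use as a context manager (`with Neo4jGraph(...) as graph:`), and best-effort cleanup on garbage collection
- Simplified connection state checking in Neo4jGraph: `query`, `refresh_schema`, and `add_graph_documents` raise `RuntimeError` when called after the connection has been closed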

## 0.1.1

### Changed
Expand Down
107 changes: 106 additions & 1 deletion libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from hashlib import md5
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type

from langchain_core.utils import get_from_dict_or_env

Expand Down Expand Up @@ -400,6 +400,18 @@ def __init__(
)
raise e

def _check_driver_state(self) -> None:
    """
    Verify that this graph still holds a driver before running an operation.

    The check is attribute-based: the graph counts as usable for as long as
    a ``_driver`` attribute is present on the instance.

    Raises:
        RuntimeError: If the instance has no ``_driver`` attribute.
    """
    if hasattr(self, "_driver"):
        return
    raise RuntimeError(
        "Cannot perform operations - Neo4j connection has been closed"
    )

@property
def get_schema(self) -> str:
"""Returns the schema of the Graph"""
Expand All @@ -423,7 +435,11 @@ def query(
Returns:
List[Dict[str, Any]]: The list of dictionaries containing the query results.
Raises:
RuntimeError: If the connection has been closed.
"""
self._check_driver_state()
from neo4j import Query
from neo4j.exceptions import Neo4jError

Expand Down Expand Up @@ -467,7 +483,11 @@ def query(
def refresh_schema(self) -> None:
"""
Refreshes the Neo4j graph schema information.
Raises:
RuntimeError: If the connection has been closed.
"""
self._check_driver_state()
from neo4j.exceptions import ClientError, CypherTypeError

node_properties = [
Expand Down Expand Up @@ -588,7 +608,11 @@ def add_graph_documents(
- baseEntityLabel (bool, optional): If True, each newly created node
gets a secondary __Entity__ label, which is indexed and improves import
speed and performance. Defaults to False.
Raises:
RuntimeError: If the connection has been closed.
"""
self._check_driver_state()
if baseEntityLabel: # Check if constraint already exists
constraint_exists = any(
[
Expand Down Expand Up @@ -810,3 +834,84 @@ def _enhanced_schema_cypher(
# Combine all parts of the Cypher query
cypher_query = "\n".join([match_clause, with_clause, return_clause])
return cypher_query

def close(self) -> None:
    """
    Shut down the underlying Neo4j driver, if this graph still has one.

    When the instance has no ``_driver`` attribute the call does nothing,
    so calling ``close()`` repeatedly is harmless.
    """
    if not hasattr(self, "_driver"):
        return
    self._driver.close()
    # Dropping the attribute is what marks this graph as closed.
    del self._driver

def __enter__(self) -> "Neo4jGraph":
    """
    Begin a ``with`` block and hand back this same graph instance.

    No connection work happens here; the instance is returned unchanged so
    it can be bound by the ``as`` clause.

    Returns:
        Neo4jGraph: This instance.

    Example:
        with Neo4jGraph(...) as graph:
            graph.query(...)  # closed again when the block ends
    """
    return self

def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[Any],
) -> None:
    """
    Leave a ``with`` block by closing the connection.

    ``close()`` is called whether or not the block raised; the three
    exception arguments are accepted but never inspected.

    Args:
        exc_type: Class of the exception that ended the block, or None.
        exc_val: The exception instance that ended the block, or None.
        exc_tb: Traceback for that exception, or None.

    Note:
        The method returns None, so an exception raised inside the ``with``
        block is not suppressed and continues to propagate.
    """
    self.close()
    return None

def __del__(self) -> None:
    """
    Finalizer: make a last attempt to close the connection.

    Calls ``close()`` and discards any ``Exception`` it raises.

    Caution:
        - Finalizers are not a deterministic cleanup mechanism.
        - Prefer an explicit ``.close()`` or the context manager.

    Recommended usage:
        1. Context manager:
            with Neo4jGraph(...) as graph:
                ...
        2. Explicit close:
            graph = Neo4jGraph(...)
            try:
                ...
            finally:
                graph.close()
    """
    try:
        self.close()
    except Exception:
        # Errors raised while the object is being finalized are dropped
        # on purpose.
        return
159 changes: 159 additions & 0 deletions libs/neo4j/tests/integration_tests/graphs/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,162 @@ def test_backticks() -> None:

assert nodes == expected_nodes
assert rels == expected_rels


def test_neo4j_context_manager() -> None:
    """Check that leaving a ``with`` block makes the graph reject queries."""
    credentials = {
        "url": os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        "username": os.environ.get("NEO4J_USERNAME", "neo4j"),
        "password": os.environ.get("NEO4J_PASSWORD", "pleaseletmein"),
    }
    assert all(value is not None for value in credentials.values())

    with Neo4jGraph(**credentials) as graph:
        # Inside the block the query must go through.
        graph.query("RETURN 1 as n")

    # Outside the block the same query must be rejected.
    rejected = False
    try:
        graph.query("RETURN 1 as n")
    except RuntimeError:
        rejected = True
    assert rejected, "Expected RuntimeError when using closed connection"


def test_neo4j_explicit_close() -> None:
    """Check that an explicit ``close()`` makes the graph reject queries."""
    credentials = {
        "url": os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        "username": os.environ.get("NEO4J_USERNAME", "neo4j"),
        "password": os.environ.get("NEO4J_PASSWORD", "pleaseletmein"),
    }
    assert all(value is not None for value in credentials.values())

    graph = Neo4jGraph(**credentials)
    # While open, the query must go through.
    graph.query("RETURN 1 as n")

    graph.close()

    # Once closed, the same query must be rejected.
    rejected = False
    try:
        graph.query("RETURN 1 as n")
    except RuntimeError:
        rejected = True
    assert rejected, "Expected RuntimeError when using closed connection"


def test_neo4j_error_after_close() -> None:
    """Check the error raised by each guarded operation after ``close()``."""
    credentials = {
        "url": os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        "username": os.environ.get("NEO4J_USERNAME", "neo4j"),
        "password": os.environ.get("NEO4J_PASSWORD", "pleaseletmein"),
    }
    assert all(value is not None for value in credentials.values())

    graph = Neo4jGraph(**credentials)
    graph.query("RETURN 1")  # Should work
    graph.close()

    # Schema refresh on a closed graph.
    refresh_error = None
    try:
        graph.refresh_schema()
    except RuntimeError as exc:
        refresh_error = exc
    assert (
        refresh_error is not None
    ), "Expected RuntimeError when refreshing schema on closed connection"
    assert "connection has been closed" in str(refresh_error)

    # Plain query on a closed graph.
    query_error = None
    try:
        graph.query("RETURN 1")
    except RuntimeError as exc:
        query_error = exc
    assert (
        query_error is not None
    ), "Expected RuntimeError when querying closed connection"
    assert "connection has been closed" in str(query_error)

    # Document import on a closed graph.
    import_error = None
    try:
        graph.add_graph_documents([test_data[0]])
    except RuntimeError as exc:
        import_error = exc
    assert (
        import_error is not None
    ), "Expected RuntimeError when adding documents to closed connection"
    assert "connection has been closed" in str(import_error)


def test_neo4j_concurrent_connections() -> None:
    """Test that multiple Neo4jGraph instances can be used independently.

    Both graphs are closed in ``finally`` blocks so that a failing assertion
    does not leave a driver connection open for the rest of the test run.
    Closing an already-closed graph is a no-op, so the extra ``close()``
    calls are safe.
    """
    url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
    username = os.environ.get("NEO4J_USERNAME", "neo4j")
    password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein")
    assert url is not None
    assert username is not None
    assert password is not None

    graph1 = Neo4jGraph(url=url, username=username, password=password)
    try:
        graph2 = Neo4jGraph(url=url, username=username, password=password)
        try:
            # Both connections should work independently
            assert graph1.query("RETURN 1 as n") == [{"n": 1}]
            assert graph2.query("RETURN 2 as n") == [{"n": 2}]

            # Closing one shouldn't affect the other
            graph1.close()
            rejected = False
            try:
                graph1.query("RETURN 1")
            except RuntimeError:
                rejected = True
            assert rejected, "Expected RuntimeError when using closed connection"
            assert graph2.query("RETURN 2 as n") == [{"n": 2}]
        finally:
            graph2.close()
    finally:
        graph1.close()


def test_neo4j_nested_context_managers() -> None:
    """Check that nested ``with`` blocks close each graph at its own exit."""
    credentials = {
        "url": os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        "username": os.environ.get("NEO4J_USERNAME", "neo4j"),
        "password": os.environ.get("NEO4J_PASSWORD", "pleaseletmein"),
    }
    assert all(value is not None for value in credentials.values())

    with Neo4jGraph(**credentials) as graph1:
        with Neo4jGraph(**credentials) as graph2:
            # Both graphs are open here.
            assert graph1.query("RETURN 1 as n") == [{"n": 1}]
            assert graph2.query("RETURN 2 as n") == [{"n": 2}]

        # The inner graph is closed now; the outer one is still usable.
        inner_rejected = False
        try:
            graph2.query("RETURN 2")
        except RuntimeError:
            inner_rejected = True
        assert inner_rejected, "Expected RuntimeError when using closed connection"
        assert graph1.query("RETURN 1 as n") == [{"n": 1}]

    # After the outer block both graphs must reject queries.
    outer_rejected = False
    try:
        graph1.query("RETURN 1")
    except RuntimeError:
        outer_rejected = True
    assert outer_rejected, "Expected RuntimeError when using closed connection"

    inner_still_rejected = False
    try:
        graph2.query("RETURN 2")
    except RuntimeError:
        inner_still_rejected = True
    assert (
        inner_still_rejected
    ), "Expected RuntimeError when using closed connection"


def test_neo4j_multiple_close() -> None:
    """Check that calling ``close()`` twice in a row raises nothing."""
    credentials = {
        "url": os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        "username": os.environ.get("NEO4J_USERNAME", "neo4j"),
        "password": os.environ.get("NEO4J_PASSWORD", "pleaseletmein"),
    }
    assert all(value is not None for value in credentials.values())

    graph = Neo4jGraph(**credentials)
    # The second pass exercises close() on an already-closed graph.
    for _ in range(2):
        graph.close()
Loading

0 comments on commit 3489bf5

Please sign in to comment.