From 6b31b1fcc6913314dcbc5faa21fd3dc499544112 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 11 Dec 2024 10:32:23 +0000 Subject: [PATCH 1/9] Adds tests so that Neo4jChatMessageHistory class now has 100% coverage (#20) * Adds tests so that Neo4jChatMessageHistory class now has 100% coverage * Fixed 3.10 issue --- .../chat_message_histories/test_neo4j.py | 38 ++++++++++ .../chat_message_histories/__init__.py | 0 .../test_neo4j_chat_message_history.py | 69 +++++++++++++++++++ 3 files changed, 107 insertions(+) create mode 100644 libs/neo4j/tests/unit_tests/chat_message_histories/__init__.py create mode 100644 libs/neo4j/tests/unit_tests/chat_message_histories/test_neo4j_chat_message_history.py diff --git a/libs/neo4j/tests/integration_tests/chat_message_histories/test_neo4j.py b/libs/neo4j/tests/integration_tests/chat_message_histories/test_neo4j.py index ebc7f53..8343da2 100644 --- a/libs/neo4j/tests/integration_tests/chat_message_histories/test_neo4j.py +++ b/libs/neo4j/tests/integration_tests/chat_message_histories/test_neo4j.py @@ -1,5 +1,7 @@ import os +import urllib.parse +import pytest from langchain_core.messages import AIMessage, HumanMessage from langchain_neo4j.chat_message_histories.neo4j import Neo4jChatMessageHistory @@ -82,3 +84,39 @@ def test_add_messages_graph_object() -> None: del os.environ["NEO4J_URI"] del os.environ["NEO4J_USERNAME"] del os.environ["NEO4J_PASSWORD"] + + +def test_invalid_url() -> None: + """Test initializing with invalid credentials raises ValueError.""" + # Parse the original URL + parsed_url = urllib.parse.urlparse(url) + # Increment the port number by 1 and wrap around if necessary + original_port = parsed_url.port or 7687 + new_port = (original_port + 1) % 65535 or 1 + # Reconstruct the netloc (hostname:port) + new_netloc = f"{parsed_url.hostname}:{new_port}" + # Rebuild the URL with the new netloc + new_url = parsed_url._replace(netloc=new_netloc).geturl() + + with pytest.raises(ValueError) as exc_info: + Neo4jChatMessageHistory( + "test_session", + url=new_url, + username=username, + password=password, + ) + assert "Please ensure that the url is correct" in str(exc_info.value) + + +def test_invalid_credentials() -> None: + """Test initializing with invalid credentials raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + Neo4jChatMessageHistory( + "test_session", + url=url, + username="invalid_username", + password="invalid_password", + ) + assert "Please ensure that the username and password are correct" in str( + exc_info.value + ) diff --git a/libs/neo4j/tests/unit_tests/chat_message_histories/__init__.py b/libs/neo4j/tests/unit_tests/chat_message_histories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/neo4j/tests/unit_tests/chat_message_histories/test_neo4j_chat_message_history.py b/libs/neo4j/tests/unit_tests/chat_message_histories/test_neo4j_chat_message_history.py new file mode 100644 index 0000000..45e84b0 --- /dev/null +++ b/libs/neo4j/tests/unit_tests/chat_message_histories/test_neo4j_chat_message_history.py @@ -0,0 +1,69 @@ +import gc +from types import ModuleType +from typing import Mapping, Sequence, Union +from unittest.mock import MagicMock, patch + +import pytest + +from langchain_neo4j.chat_message_histories.neo4j import Neo4jChatMessageHistory + + +def test_init_without_session_id() -> None: + """Test initializing without session_id raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + Neo4jChatMessageHistory(None) # type: ignore[arg-type] + assert "Please ensure that the session_id parameter is provided" in str( + exc_info.value + ) + + +def test_messages_setter() -> None: + """Test that assigning to messages raises NotImplementedError.""" + with patch("neo4j.GraphDatabase.driver", autospec=True): + message_store = Neo4jChatMessageHistory( + session_id="test_session", + url="bolt://url", + username="username", + password="password", + ) + + with pytest.raises(NotImplementedError) as exc_info: + message_store.messages = [] + assert "Direct assignment to 'messages' is not allowed." in str(exc_info.value) + + +def test_import_error() -> None: + """Test that ImportError is raised when neo4j package is not installed.""" + original_import = __import__ + + def mock_import( + name: str, + globals: Union[Mapping[str, object], None] = None, + locals: Union[Mapping[str, object], None] = None, + fromlist: Sequence[str] = (), + level: int = 0, + ) -> ModuleType: + if name == "neo4j": + raise ImportError() + return original_import(name, globals, locals, fromlist, level) + + with patch("builtins.__import__", side_effect=mock_import): + with pytest.raises(ImportError) as exc_info: + Neo4jChatMessageHistory("test_session") + assert "Could not import neo4j python package." in str(exc_info.value) + + +def test_driver_closed_on_delete() -> None: + """Test that the driver is closed when the object is deleted.""" + with patch("neo4j.GraphDatabase.driver", autospec=True): + message_store = Neo4jChatMessageHistory( + session_id="test_session", + url="bolt://url", + username="username", + password="password", + ) + mock_driver = message_store._driver + assert isinstance(mock_driver.close, MagicMock) + message_store.__del__() + gc.collect() + mock_driver.close.assert_called_once() From cfa583a9373d660eae0c66d9b495db444e2f618c Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Fri, 13 Dec 2024 12:56:44 +0000 Subject: [PATCH 2/9] Adds tests so that Neo4jGraph class now has 100% coverage (#21) * Added tests for _format_schema helper function * Added mock Neo4j driver to Neo4jGraph tests * Added _enhanced_schema_cypher tests * Parametrised _format_schema tests * Parametrised value_sanitize tests * test_format_schema refactoring * More refactoring + minor tests added * Neo4jGraph 100% coverage * Fixed linting issues --- .../langchain_neo4j/graphs/neo4j_graph.py | 10 +- .../integration_tests/graphs/test_neo4j.py | 89 +- .../unit_tests/graphs/test_neo4j_graph.py | 976 +++++++++++++++--- 3 files changed, 948 insertions(+), 127 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index dd97de0..237c5d5 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -191,7 +191,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: "DATE_TIME", "LOCAL_DATE_TIME", ]: - if prop.get("min") is not None: + if prop.get("min") and prop.get("max"): example = f'Min: {prop["min"]}, Max: {prop["max"]}' else: example = ( @@ -215,7 +215,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: formatted_rel_props.append(f"- **{rel_type}**") for prop in properties: example = "" - if prop["type"] == "STRING": + if prop["type"] == "STRING" and prop.get("values"): if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: example = ( f'Example: "{clean_string_values(prop["values"][0])}"' @@ -238,8 +238,8 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: "DATE_TIME", "LOCAL_DATE_TIME", ]: - if prop.get("min"): # If we have min/max - example = f'Min: {prop["min"]}, Max: {prop["max"]}' + if prop.get("min") and prop.get("max"): # If we have min/max + example = f'Min: {prop["min"]}, Max: {prop["max"]}' else: # return a single value example = ( f'Example: "{prop["values"][0]}"' if prop["values"] else "" @@ -252,7 +252,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' ) formatted_rel_props.append( - f" - `{prop['property']}: {prop['type']}` {example}" + f" - `{prop['property']}`: {prop['type']} {example}" ) else: # Format node properties diff --git a/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py b/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py index 6c27707..bd47454 100644 --- a/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py @@ -1,5 +1,7 @@ import os +import urllib +import pytest from langchain_core.documents import Document from langchain_neo4j import Neo4jGraph @@ -19,6 +21,7 @@ source=Node(id="foo", type="foo"), target=Node(id="bar", type="bar"), type="REL", + properties={"key": "val"}, ) ], source=Document(page_content="source document"), @@ -130,7 +133,7 @@ def test_neo4j_timeout() -> None: def test_neo4j_sanitize_values() -> None: - """Test that neo4j uses the timeout correctly.""" + """Test that lists with more than 128 elements are removed from the results.""" url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") username = os.environ.get("NEO4J_USERNAME", "neo4j") password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") @@ -347,7 +350,16 @@ def test_enhanced_schema() -> None: } ], }, - "rel_props": {}, + "rel_props": { + "REL": [ + { + "distinct_count": 1, + "property": "key", + "type": "STRING", + "values": ["val"], + } + ] + }, "relationships": [{"start": "foo", "type": "REL", "end": "bar"}], } # remove metadata portion of schema @@ -365,16 +377,37 @@ def test_enhanced_schema_exception() -> None: assert password is not None graph = Neo4jGraph( - url=url, username=username, password=password, enhanced_schema=True + url=url, + username=username, + password=password, + enhanced_schema=True, + refresh_schema=False, ) graph.query("MATCH (n) DETACH DELETE n") - graph.query("CREATE (:Node {foo:'bar'})," "(:Node {foo: 1}), (:Node {foo: [1,2]})") + graph.query( + "CREATE (:Node {foo: 'bar'}), (:Node {foo: 1}), (:Node {foo: [1,2]}), " + "(: EmptyNode)" + ) + graph.query( + "MATCH (a:Node {foo: 'bar'}), (b:Node {foo: 1}), " + "(c:Node {foo: [1,2]}), (d: EmptyNode) " + "CREATE (a)-[:REL {foo: 'bar'}]->(b), (b)-[:REL {foo: 1}]->(c), " + "(c)-[:REL {foo: [1,2]}]->(a), (d)-[:EMPTY_REL {}]->(d)" + ) graph.refresh_schema() expected_output = { "node_props": {"Node": [{"property": "foo", "type": "STRING"}]}, - "rel_props": {}, - "relationships": [], + "rel_props": {"REL": [{"property": "foo", "type": "STRING"}]}, + "relationships": [ + { + "end": "Node", + "start": "Node", + "type": "REL", + }, + {"end": "EmptyNode", "start": "EmptyNode", "type": "EMPTY_REL"}, + ], } + # remove metadata portion of schema del graph.structured_schema["metadata"] assert graph.structured_schema == expected_output @@ -558,3 +591,47 @@ def test_neo4j_multiple_close() -> None: # Test that multiple closes don't raise errors graph.close() graph.close() # This should not raise an error + + +def test_invalid_url() -> None: + """Test initializing with invalid credentials raises ValueError.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + assert url is not None + assert username is not None + assert password is not None + + # Parse the original URL + parsed_url = urllib.parse.urlparse(url) + # Increment the port number by 1 and wrap around if necessary + original_port = parsed_url.port or 7687 + new_port = (original_port + 1) % 65535 or 1 + # Reconstruct the netloc (hostname:port) + new_netloc = f"{parsed_url.hostname}:{new_port}" + # Rebuild the URL with the new netloc + new_url = parsed_url._replace(netloc=new_netloc).geturl() + + with pytest.raises(ValueError) as exc_info: + Neo4jGraph( + url=new_url, + username=username, + password=password, + ) + assert "Please ensure that the url is correct" in str(exc_info.value) + + +def test_invalid_credentials() -> None: + """Test initializing with invalid credentials raises ValueError.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + assert url is not None + + with pytest.raises(ValueError) as exc_info: + Neo4jGraph( + url=url, + username="invalid_username", + password="invalid_password", + ) + assert "Please ensure that the username and password are correct" in str( + exc_info.value + ) diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 3a617b5..265af31 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -1,150 +1,894 @@ +from types import ModuleType +from typing import Any, Dict, Generator, Mapping, Sequence, Union from unittest.mock import MagicMock, patch import pytest +from neo4j.exceptions import ClientError, Neo4jError -from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph, value_sanitize +from langchain_neo4j.graphs.neo4j_graph import ( + LIST_LIMIT, + Neo4jGraph, + _format_schema, + value_sanitize, +) -def test_value_sanitize_with_small_list() -> None: - small_list = list(range(15)) # list size > LIST_LIMIT - input_dict = {"key1": "value1", "small_list": small_list} - expected_output = {"key1": "value1", "small_list": small_list} - assert value_sanitize(input_dict) == expected_output - - -def test_value_sanitize_with_oversized_list() -> None: - oversized_list = list(range(150)) # list size > LIST_LIMIT - input_dict = {"key1": "value1", "oversized_list": oversized_list} - expected_output = { - "key1": "value1" - # oversized_list should not be included - } - assert value_sanitize(input_dict) == expected_output - - -def test_value_sanitize_with_nested_oversized_list() -> None: - oversized_list = list(range(150)) # list size > LIST_LIMIT - input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}} - expected_output = {"key1": "value1", "oversized_list": {}} - assert value_sanitize(input_dict) == expected_output - - -def test_value_sanitize_with_dict_in_list() -> None: - oversized_list = list(range(150)) # list size > LIST_LIMIT - input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]} - expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]} - assert value_sanitize(input_dict) == expected_output - - -def test_value_sanitize_with_dict_in_nested_list() -> None: - input_dict = { - "key1": "value1", - "deeply_nested_lists": [[[[{"final_nested_key": list(range(200))}]]]], - } - expected_output = {"key1": "value1", "deeply_nested_lists": [[[[{}]]]]} - assert value_sanitize(input_dict) == expected_output - - -def test_driver_state_management() -> None: - """Comprehensive test for driver state management.""" +@pytest.fixture +def mock_neo4j_driver() -> Generator[MagicMock, None, None]: with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver: - # Setup mock driver mock_driver_instance = MagicMock() mock_driver.return_value = mock_driver_instance + mock_driver_instance.verify_connectivity.return_value = None mock_driver_instance.execute_query = MagicMock(return_value=([], None, None)) - - # Create graph instance + mock_driver_instance._closed = False + yield mock_driver_instance + + +@pytest.mark.parametrize( + "description, input_value, expected_output", + [ + ( + "Small list", + {"key1": "value1", "small_list": list(range(15))}, + {"key1": "value1", "small_list": list(range(15))}, + ), + ( + "Oversized list", + {"key1": "value1", "oversized_list": list(range(LIST_LIMIT + 1))}, + {"key1": "value1"}, + ), + ( + "Nested oversized list", + {"key1": "value1", "oversized_list": {"key": list(range(150))}}, + {"key1": "value1", "oversized_list": {}}, + ), + ( + "Dict in list", + { + "key1": "value1", + "oversized_list": [1, 2, {"key": list(range(LIST_LIMIT + 1))}], + }, + {"key1": "value1", "oversized_list": [1, 2, {}]}, + ), + ( + "Dict in nested list", + { + "key1": "value1", + "deeply_nested_lists": [ + [[[{"final_nested_key": list(range(LIST_LIMIT + 1))}]]] + ], + }, + {"key1": "value1", "deeply_nested_lists": [[[[{}]]]]}, + ), + ( + "Bare oversized list", + list(range(LIST_LIMIT + 1)), + None, + ), + ( + "None value", + None, + None, + ), + ], +) +def test_value_sanitize( + description: str, input_value: Dict[str, Any], expected_output: Any +) -> None: + """Test the value_sanitize function.""" + assert ( + value_sanitize(input_value) == expected_output + ), f"Failed test case: {description}" + + +def test_driver_state_management(mock_neo4j_driver: MagicMock) -> None: + """Comprehensive test for driver state management.""" + # Create graph instance + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + + # Store original driver + original_driver = graph._driver + assert isinstance(original_driver.close, MagicMock) + + # Test initial state + assert hasattr(graph, "_driver") + + # First close + graph.close() + original_driver.close.assert_called_once() + assert not hasattr(graph, "_driver") + + # Verify methods raise error when driver is closed + with pytest.raises( + RuntimeError, + match="Cannot perform operations - Neo4j connection has been closed", + ): + graph.query("RETURN 1") + + with pytest.raises( + RuntimeError, + match="Cannot perform operations - Neo4j connection has been closed", + ): + graph.refresh_schema() + + +def test_neo4j_graph_del_method(mock_neo4j_driver: MagicMock) -> None: + """Test the __del__ method.""" + with patch.object(Neo4jGraph, "close") as mock_close: graph = Neo4jGraph( url="bolt://localhost:7687", username="neo4j", password="password" ) + # Ensure exceptions are suppressed when the graph's destructor is called + mock_close.side_effect = Exception() + mock_close.assert_not_called() + graph.__del__() + mock_close.assert_called_once() - # Store original driver - original_driver = graph._driver - assert isinstance(original_driver.close, MagicMock) - # Test initial state - assert hasattr(graph, "_driver") +def test_close_method_removes_driver(mock_neo4j_driver: MagicMock) -> None: + """Test that close method removes the _driver attribute.""" + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) - # First close - graph.close() - original_driver.close.assert_called_once() - assert not hasattr(graph, "_driver") + # Store a reference to the original driver + original_driver = graph._driver + assert isinstance(original_driver.close, MagicMock) - # Verify methods raise error when driver is closed - with pytest.raises( - RuntimeError, - match="Cannot perform operations - Neo4j connection has been closed", - ): - graph.query("RETURN 1") + # Call close method + graph.close() - with pytest.raises( - RuntimeError, - match="Cannot perform operations - Neo4j connection has been closed", - ): - graph.refresh_schema() + # Verify driver.close was called + original_driver.close.assert_called_once() + # Verify _driver attribute is removed + assert not hasattr(graph, "_driver") -def test_close_method_removes_driver() -> None: - """Test that close method removes the _driver attribute.""" - with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver: - # Configure mock to return a mock driver - mock_driver_instance = MagicMock() - mock_driver.return_value = mock_driver_instance + # Verify second close does not raise an error + graph.close() # Should not raise any exception - # Configure mock execute_query to return empty result - mock_driver_instance.execute_query = MagicMock(return_value=([], None, None)) - # Add a _closed attribute to simulate driver state - mock_driver_instance._closed = False +def test_multiple_close_calls_safe(mock_neo4j_driver: MagicMock) -> None: + """Test that multiple close calls do not raise errors.""" + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) + # Store a reference to the original driver + original_driver = graph._driver + assert isinstance(original_driver.close, MagicMock) - # Store a reference to the original driver - original_driver = graph._driver - assert isinstance(original_driver.close, MagicMock) + # First close + graph.close() + original_driver.close.assert_called_once() - # Call close method - graph.close() + # Verify _driver attribute is removed + assert not hasattr(graph, "_driver") - # Verify driver.close was called - original_driver.close.assert_called_once() + # Second close should not raise an error + graph.close() # Should not raise any exception - # Verify _driver attribute is removed - assert not hasattr(graph, "_driver") - # Verify second close does not raise an error - graph.close() # Should not raise any exception +def test_import_error() -> None: + """Test that ImportError is raised when neo4j package is not installed.""" + original_import = __import__ + def mock_import( + name: str, + globals: Union[Mapping[str, object], None] = None, + locals: Union[Mapping[str, object], None] = None, + fromlist: Sequence[str] = (), + level: int = 0, + ) -> ModuleType: + if name == "neo4j": + raise ImportError() + return original_import(name, globals, locals, fromlist, level) -def test_multiple_close_calls_safe() -> None: - """Test that multiple close calls do not raise errors.""" + with patch("builtins.__import__", side_effect=mock_import): + with pytest.raises(ImportError) as exc_info: + Neo4jGraph() + assert "Could not import neo4j python package." in str(exc_info.value) + + +def test_neo4j_graph_init_with_empty_credentials() -> None: + """Test the __init__ method when no credentials have been provided.""" with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver: - # Configure mock to return a mock driver mock_driver_instance = MagicMock() mock_driver.return_value = mock_driver_instance - - # Configure mock execute_query to return empty result - mock_driver_instance.execute_query = MagicMock(return_value=([], None, None)) - - # Add a _closed attribute to simulate driver state - mock_driver_instance._closed = False - - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" + mock_driver_instance.verify_connectivity.return_value = None + Neo4jGraph( + url="bolt://localhost:7687", username="", password="", refresh_schema=False ) - - # Store a reference to the original driver - original_driver = graph._driver - assert isinstance(original_driver.close, MagicMock) - - # First close - graph.close() - original_driver.close.assert_called_once() - - # Verify _driver attribute is removed - assert not hasattr(graph, "_driver") - - # Second close should not raise an error - graph.close() # Should not raise any exception + mock_driver.assert_called_with("bolt://localhost:7687", auth=None) + + +def test_init_apoc_procedure_not_found( + mock_neo4j_driver: MagicMock, +) -> None: + """Test an error is raised when APOC is not installed.""" + with patch("langchain_neo4j.Neo4jGraph.refresh_schema") as mock_refresh_schema: + err = ClientError() + err.code = "Neo.ClientError.Procedure.ProcedureNotFound" + mock_refresh_schema.side_effect = err + with pytest.raises(ValueError) as exc_info: + Neo4jGraph(url="bolt://localhost:7687", username="", password="") + assert "Could not use APOC procedures." in str(exc_info.value) + + +def test_init_refresh_schema_other_err( + mock_neo4j_driver: MagicMock, +) -> None: + """Test any other ClientErrors raised when calling refresh_schema in __init__ are + re-raised.""" + with patch("langchain_neo4j.Neo4jGraph.refresh_schema") as mock_refresh_schema: + err = ClientError() + err.code = "other_error" + mock_refresh_schema.side_effect = err + with pytest.raises(ClientError) as exc_info: + Neo4jGraph(url="bolt://localhost:7687", username="", password="") + assert exc_info.value == err + + +def test_query_fallback_execution(mock_neo4j_driver: MagicMock) -> None: + """Test the fallback to allow for implicit transactions in query.""" + err = Neo4jError() + err.code = "Neo.DatabaseError.Statement.ExecutionFailed" + err.message = "in an implicit transaction" + mock_neo4j_driver.execute_query.side_effect = err + graph = Neo4jGraph( + url="bolt://localhost:7687", + username="neo4j", + password="password", + database="test_db", + sanitize=True, + ) + mock_session = MagicMock() + mock_result = MagicMock() + mock_result.data.return_value = { + "key1": "value1", + "oversized_list": list(range(LIST_LIMIT + 1)), + } + mock_session.run.return_value = [mock_result] + mock_neo4j_driver.session.return_value.__enter__.return_value = mock_session + mock_neo4j_driver.session.return_value.__exit__.return_value = None + query = "MATCH (n) RETURN n;" + params = {"param1": "value1"} + json_data = graph.query(query, params) + mock_neo4j_driver.session.assert_called_with(database="test_db") + called_args, _ = mock_session.run.call_args + called_query = called_args[0] + assert called_query.text == query + assert called_query.timeout == graph.timeout + assert called_args[1] == params + assert json_data == [{"key1": "value1"}] + + +def test_refresh_schema_handles_client_error(mock_neo4j_driver: MagicMock) -> None: + """Test refresh schema handles a client error which might arise due to a user + not having access to schema information""" + + graph = Neo4jGraph( + url="bolt://localhost:7687", + username="neo4j", + password="password", + database="test_db", + ) + node_properties = [ + { + "output": { + "properties": [{"property": "property_a", "type": "STRING"}], + "labels": "LabelA", + } + } + ] + relationships_properties = [ + { + "output": { + "type": "REL_TYPE", + "properties": [{"property": "rel_prop", "type": "STRING"}], + } + } + ] + relationships = [ + {"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}}, + {"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}}, + ] + + # Mock the query method to raise ClientError for constraint and index queries + graph.query = MagicMock( # type: ignore[method-assign] + side_effect=[ + node_properties, + relationships_properties, + relationships, + ClientError("Mock ClientError"), + ] + ) + graph.refresh_schema() + + # Assertions + # Ensure constraints and indexes are empty due to the ClientError + assert graph.structured_schema["metadata"]["constraint"] == [] + assert graph.structured_schema["metadata"]["index"] == [] + + # Ensure the query method was called as expected + assert graph.query.call_count == 4 + graph.query.assert_any_call("SHOW CONSTRAINTS") + + +def test_get_schema(mock_neo4j_driver: MagicMock) -> None: + """Tests the get_schema property.""" + graph = Neo4jGraph( + url="bolt://localhost:7687", + username="neo4j", + password="password", + refresh_schema=False, + ) + graph.schema = "test" + assert graph.get_schema == "test" + + +@pytest.mark.parametrize( + "description, schema, is_enhanced, expected_output", + [ + ( + "Enhanced, string property with high distinct count", + { + "node_props": { + "Person": [ + { + "property": "name", + "type": "STRING", + "values": ["Alice", "Bob", "Charlie"], + "distinct_count": 11, + } + ] + }, + "rel_props": {}, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "- **Person**\n" + ' - `name`: STRING Example: "Alice"\n' + "Relationship properties:\n" + "\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, string property with low distinct count", + { + "node_props": { + "Animal": [ + { + "property": "species", + "type": "STRING", + "values": ["Cat", "Dog"], + "distinct_count": 2, + } + ] + }, + "rel_props": {}, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "- **Animal**\n" + " - `species`: STRING Available options: ['Cat', 'Dog']\n" + "Relationship properties:\n" + "\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, numeric property with min and max", + { + "node_props": { + "Person": [ + {"property": "age", "type": "INTEGER", "min": 20, "max": 70} + ] + }, + "rel_props": {}, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "- **Person**\n" + " - `age`: INTEGER Min: 20, Max: 70\n" + "Relationship properties:\n" + "\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, numeric property with values", + { + "node_props": { + "Event": [ + { + "property": "date", + "type": "DATE", + "values": ["2021-01-01", "2021-01-02"], + } + ] + }, + "rel_props": {}, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "- **Event**\n" + ' - `date`: DATE Example: "2021-01-01"\n' + "Relationship properties:\n" + "\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, list property that should be skipped", + { + "node_props": { + "Document": [ + { + "property": "embedding", + "type": "LIST", + "min_size": 150, + "max_size": 200, + } + ] + }, + "rel_props": {}, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "- **Document**\n" + "Relationship properties:\n" + "\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, list property that should be included", + { + "node_props": { + "Document": [ + { + "property": "keywords", + "type": "LIST", + "min_size": 2, + "max_size": 5, + } + ] + }, + "rel_props": {}, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "- **Document**\n" + " - `keywords`: LIST Min Size: 2, Max Size: 5\n" + "Relationship properties:\n" + "\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, relationship string property with high distinct count", + { + "node_props": {}, + "rel_props": { + "KNOWS": [ + { + "property": "since", + "type": "STRING", + "values": ["2000", "2001", "2002"], + "distinct_count": 15, + } + ] + }, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "\n" + "Relationship properties:\n" + "- **KNOWS**\n" + ' - `since`: STRING Example: "2000"\n' + "The relationships:\n" + ), + ), + ( + "Enhanced, relationship string property with low distinct count", + { + "node_props": {}, + "rel_props": { + "LIKES": [ + { + "property": "intensity", + "type": "STRING", + "values": ["High", "Medium", "Low"], + "distinct_count": 3, + } + ] + }, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "\n" + "Relationship properties:\n" + "- **LIKES**\n" + " - `intensity`: STRING Available options: ['High', 'Medium', 'Low']\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, relationship numeric property with min and max", + { + "node_props": {}, + "rel_props": { + "WORKS_WITH": [ + { + "property": "since", + "type": "INTEGER", + "min": 1995, + "max": 2020, + } + ] + }, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "\n" + "Relationship properties:\n" + "- **WORKS_WITH**\n" + " - `since`: INTEGER Min: 1995, Max: 2020\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, relationship list property that should be skipped", + { + "node_props": {}, + "rel_props": { + "KNOWS": [ + { + "property": "embedding", + "type": "LIST", + "min_size": 150, + "max_size": 200, + } + ] + }, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "\n" + "Relationship properties:\n" + "- **KNOWS**\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, relationship list property that should be included", + { + "node_props": {}, + "rel_props": { + "KNOWS": [ + { + "property": "messages", + "type": "LIST", + "min_size": 2, + "max_size": 5, + } + ] + }, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "\n" + "Relationship properties:\n" + "- **KNOWS**\n" + " - `messages`: LIST Min Size: 2, Max Size: 5\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, relationship numeric property without min and max", + { + "node_props": {}, + "rel_props": { + "OWES": [ + { + "property": "amount", + "type": "FLOAT", + "values": [3.14, 2.71], + } + ] + }, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "\n" + "Relationship properties:\n" + "- **OWES**\n" + ' - `amount`: FLOAT Example: "3.14"\n' + "The relationships:\n" + ), + ), + ( + "Enhanced, property with empty values list", + { + "node_props": { + "Person": [ + { + "property": "name", + "type": "STRING", + "values": [], + "distinct_count": 15, + } + ] + }, + "rel_props": {}, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "- **Person**\n" + " - `name`: STRING \n" + "Relationship properties:\n" + "\n" + "The relationships:\n" + ), + ), + ( + "Enhanced, property with missing values", + { + "node_props": { + "Person": [ + { + "property": "name", + "type": "STRING", + "distinct_count": 15, + } + ] + }, + "rel_props": {}, + "relationships": [], + }, + True, + ( + "Node properties:\n" + "- **Person**\n" + " - `name`: STRING \n" + "Relationship properties:\n" + "\n" + "The relationships:\n" + ), + ), + ], +) +def test_format_schema( + description: str, schema: Dict, is_enhanced: bool, expected_output: str +) -> None: + result = _format_schema(schema, is_enhanced) + assert result == expected_output, f"Failed test case: {description}" + + +# _enhanced_schema_cypher tests + + +def test_enhanced_schema_cypher_integer_exhaustive_true( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + + graph.structured_schema = {"metadata": {"index": []}} + properties = [{"property": "age", "type": "INTEGER"}] + query = graph._enhanced_schema_cypher("Person", properties, exhaustive=True) + assert "min(n.`age`) AS `age_min`" in query + assert "max(n.`age`) AS `age_max`" in query + assert "count(distinct n.`age`) AS `age_distinct`" in query + assert ( + "min: toString(`age_min`), max: toString(`age_max`), " + "distinct_count: `age_distinct`" in query + ) + + +def test_enhanced_schema_cypher_list_exhaustive_true( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + graph.structured_schema = {"metadata": {"index": []}} + properties = [{"property": "tags", "type": "LIST"}] + query = graph._enhanced_schema_cypher("Article", properties, exhaustive=True) + assert "min(size(n.`tags`)) AS `tags_size_min`" in query + assert "max(size(n.`tags`)) AS `tags_size_max`" in query + assert "min_size: `tags_size_min`, max_size: `tags_size_max`" in query + + +def test_enhanced_schema_cypher_boolean_exhaustive_true( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + properties = [{"property": "active", "type": "BOOLEAN"}] + query = graph._enhanced_schema_cypher("User", properties, exhaustive=True) + # BOOLEAN types should be skipped, so their properties should not be in the query + assert "n.`active`" not in query + + +def test_enhanced_schema_cypher_integer_exhaustive_false_no_index( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + graph.structured_schema = {"metadata": {"index": []}} + properties = [{"property": "age", "type": "INTEGER"}] + query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False) + assert "collect(distinct toString(n.`age`)) AS `age_values`" in query + assert "values: `age_values`" in query + + +def test_enhanced_schema_cypher_integer_exhaustive_false_with_index( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + graph.structured_schema = { + "metadata": { + "index": [ + { + "label": "Person", + "properties": ["age"], + "type": "RANGE", + } + ] + } + } + properties = [{"property": "age", "type": "INTEGER"}] + query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False) + assert "min(n.`age`) AS `age_min`" in query + assert "max(n.`age`) AS `age_max`" in query + assert "count(distinct n.`age`) AS `age_distinct`" in query + assert ( + "min: toString(`age_min`), max: toString(`age_max`), " + "distinct_count: `age_distinct`" in query + ) + + +def test_enhanced_schema_cypher_list_exhaustive_false( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + properties = [{"property": "tags", "type": "LIST"}] + query = graph._enhanced_schema_cypher("Article", properties, exhaustive=False) + assert "min(size(n.`tags`)) AS `tags_size_min`" in query + assert "max(size(n.`tags`)) AS `tags_size_max`" in query + assert "min_size: `tags_size_min`, max_size: `tags_size_max`" in query + + +def test_enhanced_schema_cypher_boolean_exhaustive_false( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + properties = [{"property": "active", "type": "BOOLEAN"}] + query = graph._enhanced_schema_cypher("User", properties, exhaustive=False) + # BOOLEAN types should be skipped, so their properties should not be in the query + assert "n.`active`" not in query + + +def test_enhanced_schema_cypher_string_exhaustive_false_with_index( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + graph.structured_schema = { + "metadata": { + "index": [ + { + "label": "Person", + "properties": ["status"], + "type": "RANGE", + "size": 5, + "distinctValues": 5, + } + ] + } + } + graph.query = MagicMock(return_value=[{"value": ["Single", "Married", "Divorced"]}]) # type: ignore[method-assign] + properties = [{"property": "status", "type": "STRING"}] + query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False) + assert "values: ['Single', 'Married', 'Divorced'], distinct_count: 3" in query + + +def test_enhanced_schema_cypher_string_exhaustive_false_no_index( + mock_neo4j_driver: MagicMock, +) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + graph.structured_schema = {"metadata": {"index": []}} + properties = [{"property": "status", "type": "STRING"}] + query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False) + assert ( + "collect(distinct substring(toString(n.`status`), 0, 50)) AS `status_values`" + in query + ) + assert "values: `status_values`" in query + + +def test_enhanced_schema_cypher_point_type(mock_neo4j_driver: MagicMock) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + properties = [{"property": "location", "type": "POINT"}] + query = graph._enhanced_schema_cypher("Place", properties, exhaustive=True) + # POINT types should be skipped + assert "n.`location`" not in query + + +def test_enhanced_schema_cypher_duration_type(mock_neo4j_driver: MagicMock) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + properties = [{"property": "duration", "type": "DURATION"}] + query = graph._enhanced_schema_cypher("Event", properties, exhaustive=False) + # DURATION types should be skipped + assert "n.`duration`" not in query + + +def test_enhanced_schema_cypher_relationship(mock_neo4j_driver: MagicMock) -> None: + graph = Neo4jGraph( + url="bolt://localhost:7687", username="neo4j", password="password" + ) + properties = [{"property": "since", "type": "INTEGER"}] + + query = graph._enhanced_schema_cypher( + label_or_type="FRIENDS_WITH", + properties=properties, + exhaustive=True, + is_relationship=True, + ) + + assert query.startswith("MATCH ()-[n:`FRIENDS_WITH`]->()") + assert "min(n.`since`) AS `since_min`" in query + assert "max(n.`since`) AS `since_max`" in query + assert "count(distinct n.`since`) AS `since_distinct`" in query + expected_return_clause = ( + "`since`: {min: toString(`since_min`), max: toString(`since_max`), " + "distinct_count: `since_distinct`}" + ) + assert expected_return_clause in query From a53e8eea150818de316ff12d8446fe65389c5ebd Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 18 Dec 2024 11:04:29 +0000 Subject: [PATCH 3/9] Adds tests so that GraphCypherQAChain class now has 100% coverage (#23) * Added tests to improve GraphCypherQAChain test coverage * Fixed linting issue * Added validate_cypher test * Added test_function_response * 100% coverage for GraphCypherQAChain * Refactoring --- .../langchain_neo4j/chains/graph_qa/cypher.py | 2 +- .../chains}/__init__.py | 0 .../chains/test_graph_database.py | 55 +++- libs/neo4j/tests/llms/__init__.py | 0 .../tests/{unit_tests => }/llms/fake_llm.py | 0 .../tests/unit_tests/chains/test_graph_qa.py | 234 ++++++++++++++++-- 6 files changed, 271 insertions(+), 20 deletions(-) rename libs/neo4j/tests/{unit_tests/llms => integration_tests/chains}/__init__.py (100%) create mode 100644 libs/neo4j/tests/llms/__init__.py rename libs/neo4j/tests/{unit_tests => }/llms/fake_llm.py (100%) diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index e84a4df..514870f 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -346,7 +346,7 @@ def from_llm( if validate_cypher: corrector_schema = [ Schema(el["start"], el["type"], el["end"]) - for el in kwargs["graph"].structured_schema.get("relationships") + for el in kwargs["graph"].get_structured_schema.get("relationships", []) ] cypher_query_corrector = CypherQueryCorrector(corrector_schema) diff --git a/libs/neo4j/tests/unit_tests/llms/__init__.py b/libs/neo4j/tests/integration_tests/chains/__init__.py similarity index 100% rename from libs/neo4j/tests/unit_tests/llms/__init__.py rename to libs/neo4j/tests/integration_tests/chains/__init__.py diff --git a/libs/neo4j/tests/integration_tests/chains/test_graph_database.py b/libs/neo4j/tests/integration_tests/chains/test_graph_database.py index 56ecc1b..eda5683 100644 --- a/libs/neo4j/tests/integration_tests/chains/test_graph_database.py +++ b/libs/neo4j/tests/integration_tests/chains/test_graph_database.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock from langchain_core.language_models import BaseLanguageModel -from langchain_core.language_models.fake import FakeListLLM from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph +from tests.llms.fake_llm import FakeLLM def test_connect_neo4j() -> None: @@ -71,10 +71,13 @@ def test_cypher_generating_run() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = FakeListLLM(responses=[query, "Bruce Willis"]) + llm = FakeLLM( + queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True + ) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, + validate_cypher=True, allow_dangerous_requests=True, ) output = chain.run("Who starred in Pulp Fiction?") @@ -111,7 +114,7 @@ def test_cypher_top_k() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = FakeListLLM(responses=[query]) + llm = FakeLLM(queries={"query": query}, sequential_responses=True) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -149,7 +152,9 @@ def test_cypher_intermediate_steps() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = FakeListLLM(responses=[query, "Bruce Willis"]) + llm = FakeLLM( + queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True + ) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -194,7 +199,7 @@ def test_cypher_return_direct() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = FakeListLLM(responses=[query]) + llm = FakeLLM(queries={"query": query}, sequential_responses=True) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -206,6 +211,46 @@ def test_cypher_return_direct() -> None: assert output == expected_output +def test_function_response() -> None: + """Test returning a function response.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + + graph = Neo4jGraph( + url=url, + username=username, + password=password, + ) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Create two nodes and a relationship + graph.query( + "CREATE (a:Actor {name:'Bruce Willis'})" + "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})" + ) + # Refresh schema information + graph.refresh_schema() + + query = ( + "MATCH (a:Actor)-[:ACTED_IN]->(m:Movie) " + "WHERE m.title = 'Pulp Fiction' " + "RETURN a.name" + ) + llm = FakeLLM( + queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True + ) + chain = GraphCypherQAChain.from_llm( + llm=llm, + graph=graph, + allow_dangerous_requests=True, + use_function_response=True, + ) + output = chain.run("Who starred in Pulp Fiction?") + expected_output = "Bruce Willis" + assert output == expected_output + + def test_exclude_types() -> None: """Test exclude types from schema.""" url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") diff --git a/libs/neo4j/tests/llms/__init__.py b/libs/neo4j/tests/llms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/neo4j/tests/unit_tests/llms/fake_llm.py b/libs/neo4j/tests/llms/fake_llm.py similarity index 100% rename from libs/neo4j/tests/unit_tests/llms/fake_llm.py rename to libs/neo4j/tests/llms/fake_llm.py diff --git a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py index a857fb4..34d1f38 100644 --- a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py @@ -1,9 +1,16 @@ import pathlib from csv import DictReader from typing import Any, Dict, List +from unittest.mock import MagicMock, patch +import pytest from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory -from langchain_core.messages import SystemMessage +from langchain_core.language_models.llms import LLM +from langchain_core.messages import ( + AIMessage, + SystemMessage, + ToolMessage, +) from langchain_core.prompts import ( ChatPromptTemplate, HumanMessagePromptTemplate, @@ -15,6 +22,7 @@ GraphCypherQAChain, construct_schema, extract_cypher, + get_function_response, ) from langchain_neo4j.chains.graph_qa.cypher_utils import ( CypherQueryCorrector, @@ -26,7 +34,7 @@ ) from langchain_neo4j.graphs.graph_document import GraphDocument from langchain_neo4j.graphs.graph_store import GraphStore -from tests.unit_tests.llms.fake_llm import FakeLLM +from tests.llms.fake_llm import FakeLLM class FakeGraphStore(GraphStore): @@ -141,21 +149,34 @@ def test_graph_cypher_qa_chain_prompt_selection_5() -> None: readonlymemory = ReadOnlySharedMemory(memory=memory) qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[]) cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[]) - try: + with pytest.raises(ValueError) as exc_info: GraphCypherQAChain.from_llm( llm=FakeLLM(), graph=FakeGraphStore(), verbose=True, return_intermediate_steps=False, - qa_prompt=qa_prompt, cypher_prompt=cypher_prompt, cypher_llm_kwargs={"memory": readonlymemory}, + allow_dangerous_requests=True, + ) + assert ( + "Specifying cypher_prompt and cypher_llm_kwargs together is" + " not allowed. Please pass prompt via cypher_llm_kwargs." + ) == str(exc_info.value) + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + verbose=True, + return_intermediate_steps=False, + qa_prompt=qa_prompt, qa_llm_kwargs={"memory": readonlymemory}, allow_dangerous_requests=True, ) - assert False - except ValueError: - assert True + assert ( + "Specifying qa_prompt and qa_llm_kwargs together is" + " not allowed. Please pass prompt via qa_llm_kwargs." + ) == str(exc_info.value) def test_graph_cypher_qa_chain_prompt_selection_6() -> None: @@ -182,6 +203,53 @@ def test_graph_cypher_qa_chain_prompt_selection_6() -> None: assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT +def test_graph_cypher_qa_chain_prompt_selection_7() -> None: + # Pass prompts which do not inherit from BasePromptTemplate + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + cypher_llm_kwargs={"prompt": None}, + allow_dangerous_requests=True, + ) + assert "The cypher_llm_kwargs `prompt` must inherit from BasePromptTemplate" == str( + exc_info.value + ) + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + qa_llm_kwargs={"prompt": None}, + allow_dangerous_requests=True, + ) + assert "The qa_llm_kwargs `prompt` must inherit from BasePromptTemplate" == str( + exc_info.value + ) + + +def test_validate_cypher() -> None: + with patch( + "langchain_neo4j.chains.graph_qa.cypher.CypherQueryCorrector", + autospec=True, + ) as cypher_query_corrector_mock: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + validate_cypher=True, + allow_dangerous_requests=True, + ) + cypher_query_corrector_mock.assert_called_once_with([]) + + +def test_chain_type() -> None: + chain = GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + assert chain._chain_type == "graph_cypher_chain" + + def test_graph_cypher_qa_chain() -> None: template = """You are a nice chatbot having a conversation with a human. @@ -236,6 +304,19 @@ def test_graph_cypher_qa_chain() -> None: assert True +def test_cypher_generation_failure() -> None: + """Test the chain doesn't fail if the Cypher query fails to be generated.""" + llm = FakeLLM(queries={"query": ""}, sequential_responses=True) + chain = GraphCypherQAChain.from_llm( + llm=llm, + graph=FakeGraphStore(), + allow_dangerous_requests=True, + return_direct=True, + ) + response = chain.run("Test question") + assert response == [] + + def test_no_backticks() -> None: """Test if there are no backticks, so the original text should be returned.""" query = "MATCH (n) RETURN n" @@ -257,7 +338,7 @@ def test_exclude_types() -> None: "Actor": [{"property": "name", "type": "STRING"}], "Person": [{"property": "name", "type": "STRING"}], }, - "rel_props": {}, + "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]}, "relationships": [ {"start": "Actor", "end": "Movie", "type": "ACTED_IN"}, {"start": "Person", "end": "Movie", "type": "DIRECTED"}, @@ -268,7 +349,8 @@ def test_exclude_types() -> None: expected_schema = ( "Node properties are the following:\n" "Movie {title: STRING},Actor {name: STRING}\n" - "Relationship properties are the following:\n\n" + "Relationship properties are the following:\n" + "ACTED_IN {role: STRING}\n" "The relationships are the following:\n" "(:Actor)-[:ACTED_IN]->(:Movie)" ) @@ -282,7 +364,7 @@ def test_include_types() -> None: "Actor": [{"property": "name", "type": "STRING"}], "Person": [{"property": "name", "type": "STRING"}], }, - "rel_props": {}, + "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]}, "relationships": [ {"start": "Actor", "end": "Movie", "type": "ACTED_IN"}, {"start": "Person", "end": "Movie", "type": "DIRECTED"}, @@ -293,7 +375,8 @@ def test_include_types() -> None: expected_schema = ( "Node properties are the following:\n" "Movie {title: STRING},Actor {name: STRING}\n" - "Relationship properties are the following:\n\n" + "Relationship properties are the following:\n" + "ACTED_IN {role: STRING}\n" "The relationships are the following:\n" "(:Actor)-[:ACTED_IN]->(:Movie)" ) @@ -307,7 +390,7 @@ def test_include_types2() -> None: "Actor": [{"property": "name", "type": "STRING"}], "Person": [{"property": "name", "type": "STRING"}], }, - "rel_props": {}, + "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]}, "relationships": [ {"start": "Actor", "end": "Movie", "type": "ACTED_IN"}, {"start": "Person", "end": "Movie", "type": "DIRECTED"}, @@ -331,7 +414,7 @@ def test_include_types3() -> None: "Actor": [{"property": "name", "type": "STRING"}], "Person": [{"property": "name", "type": "STRING"}], }, - "rel_props": {}, + "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]}, "relationships": [ {"start": "Actor", "end": "Movie", "type": "ACTED_IN"}, {"start": "Person", "end": "Movie", "type": "DIRECTED"}, @@ -342,13 +425,136 @@ def test_include_types3() -> None: expected_schema = ( "Node properties are the following:\n" "Movie {title: STRING},Actor {name: STRING}\n" - "Relationship properties are the following:\n\n" + "Relationship properties are the following:\n" + "ACTED_IN {role: STRING}\n" "The relationships are the following:\n" "(:Actor)-[:ACTED_IN]->(:Movie)" ) assert output == expected_schema +def test_include_exclude_types_err() -> None: + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + include_types=["Movie", "Actor"], + exclude_types=["Person", "DIRECTED"], + allow_dangerous_requests=True, + ) + assert ( + "Either `exclude_types` or `include_types` can be provided, but not both" + == str(exc_info.value) + ) + + +def test_get_function_response() -> None: + question = "Who directed Dune?" + context = [{"director": "Denis Villeneuve"}] + messages = get_function_response(question, context) + assert len(messages) == 2 + # Validate AIMessage + ai_message = messages[0] + assert isinstance(ai_message, AIMessage) + assert ai_message.content == "" + assert "tool_calls" in ai_message.additional_kwargs + tool_call = ai_message.additional_kwargs["tool_calls"][0] + assert tool_call["function"]["arguments"] == f'{{"question":"{question}"}}' + # Validate ToolMessage + tool_message = messages[1] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == str(context) + + +def test_allow_dangerous_requests_err() -> None: + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + ) + assert ( + "In order to use this chain, you must acknowledge that it can make " + "dangerous requests by setting `allow_dangerous_requests` to `True`." + ) in str(exc_info.value) + + +def test_llm_arg_combinations() -> None: + # No llm + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + graph=FakeGraphStore(), allow_dangerous_requests=True + ) + assert "At least one LLM must be provided" == str(exc_info.value) + # llm only + GraphCypherQAChain.from_llm( + llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True + ) + # qa_llm only + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + qa_llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True + ) + assert ( + "If `llm` is not provided, both `qa_llm` and `cypher_llm` must be provided." + == str(exc_info.value) + ) + # cypher_llm only + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + cypher_llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True + ) + assert ( + "If `llm` is not provided, both `qa_llm` and `cypher_llm` must be provided." + == str(exc_info.value) + ) + # llm + qa_llm + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + qa_llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + # llm + cypher_llm + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + cypher_llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + # qa_llm + cypher_llm + GraphCypherQAChain.from_llm( + qa_llm=FakeLLM(), + cypher_llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + # llm + qa_llm + cypher_llm + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + qa_llm=FakeLLM(), + cypher_llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + assert ( + "You can specify up to two of 'cypher_llm', 'qa_llm'" + ", and 'llm', but not all three simultaneously." + ) == str(exc_info.value) + + +def test_use_function_response_err() -> None: + llm = MagicMock(spec=LLM) + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=llm, + graph=FakeGraphStore(), + allow_dangerous_requests=True, + use_function_response=True, + ) + assert "Provided LLM does not support native tools/functions" == str(exc_info.value) + + HERE = pathlib.Path(__file__).parent UNIT_TESTS_ROOT = HERE.parent From 1248b6426b751d22c07a8edaef97b3342fe245e4 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 18 Dec 2024 11:27:46 +0000 Subject: [PATCH 4/9] Release 0.2.0 (#25) --- CHANGELOG.md | 2 ++ libs/neo4j/pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca6fdb1..0976c74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Next +## 0.2.0 + ### Added - Enhanced Neo4j driver connection management with more robust error handling. diff --git a/libs/neo4j/pyproject.toml b/libs/neo4j/pyproject.toml index d1393a2..112835c 100644 --- a/libs/neo4j/pyproject.toml +++ b/libs/neo4j/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-neo4j" -version = "0.1.1" +version = "0.2.0" description = "An integration package connecting Neo4j and LangChain" authors = [] readme = "README.md" From ef342d00a60f22d7f58796bde0977b8e8f096533 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 18 Dec 2024 15:37:20 +0000 Subject: [PATCH 5/9] Moved langchain-core depedency from 0.3.0 to 0.3.1 to fix a bug (#26) --- libs/neo4j/.gitignore | 1 + libs/neo4j/poetry.lock | 4 ++-- libs/neo4j/pyproject.toml | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/libs/neo4j/.gitignore b/libs/neo4j/.gitignore index bee8a64..073a081 100644 --- a/libs/neo4j/.gitignore +++ b/libs/neo4j/.gitignore @@ -1 +1,2 @@ __pycache__ +.python-version diff --git a/libs/neo4j/poetry.lock b/libs/neo4j/poetry.lock index 9af7616..afeff93 100644 --- a/libs/neo4j/poetry.lock +++ b/libs/neo4j/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1912,4 +1912,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "3a6a2a86b0b2af7e6d6d947711f12afe71ee582c4b753359be920d27c047e958" +content-hash = "f53f504ff1199eb1021feeae70864948c106446ffbabaecfc6edee56f5ea2f13" diff --git a/libs/neo4j/pyproject.toml b/libs/neo4j/pyproject.toml index 112835c..827fa0f 100644 --- a/libs/neo4j/pyproject.toml +++ b/libs/neo4j/pyproject.toml @@ -13,7 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.9,<4.0" -langchain-core = "^0.3.0" +langchain-core = "^0.3.1" neo4j = "^5.25.0" langchain = "^0.3.7" From a6c8e139aa4beb505cb79c446a72d0c53a28e7ef Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 18 Dec 2024 16:32:35 +0000 Subject: [PATCH 6/9] Moved LangChain core to 0.3.8 (#27) --- libs/neo4j/poetry.lock | 2 +- libs/neo4j/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/neo4j/poetry.lock b/libs/neo4j/poetry.lock index afeff93..6dfbe97 100644 --- a/libs/neo4j/poetry.lock +++ b/libs/neo4j/poetry.lock @@ -1912,4 +1912,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "f53f504ff1199eb1021feeae70864948c106446ffbabaecfc6edee56f5ea2f13" +content-hash = "5b6582f252d7d74f2fe188e34865d68cbe16bb3477f3e09a8855ad68f824597a" diff --git a/libs/neo4j/pyproject.toml b/libs/neo4j/pyproject.toml index 827fa0f..66db058 100644 --- a/libs/neo4j/pyproject.toml +++ b/libs/neo4j/pyproject.toml @@ -13,7 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.9,<4.0" -langchain-core = "^0.3.1" +langchain-core = "^0.3.8" neo4j = "^5.25.0" langchain = "^0.3.7" From 6c9ac2453ae12bef1d27fc2fa78161d9be660667 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 7 Jan 2025 10:13:53 +0000 Subject: [PATCH 7/9] Disables Neo4jGraph driver warnings (#29) * Disables Neo4jGraph driver warnings * Fixed linting error * More linting fixes --- .../langchain_neo4j/graphs/neo4j_graph.py | 22 +++++++-- .../unit_tests/graphs/test_neo4j_graph.py | 49 ++++++++++++++++++- 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index 237c5d5..b1ced4b 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -310,6 +310,7 @@ class Neo4jGraph(GraphStore): enhanced_schema (bool): A flag whether to scan the database for example values and use them in the graph schema. Default is False. driver_config (Dict): Configuration passed to Neo4j Driver. + Defaults to {"notifications_min_severity", "OFF"} if not set. *Security note*: Make sure that the database connection uses credentials that are narrowly-scoped to only include necessary permissions. @@ -365,9 +366,10 @@ def __init__( {"database": database}, "database", "NEO4J_DATABASE", "neo4j" ) - self._driver = neo4j.GraphDatabase.driver( - url, auth=auth, **(driver_config or {}) - ) + if driver_config is None: + driver_config = {} + driver_config.setdefault("notifications_min_severity", "OFF") + self._driver = neo4j.GraphDatabase.driver(url, auth=auth, **driver_config) self._database = database self.timeout = timeout self.sanitize = sanitize @@ -377,6 +379,20 @@ def __init__( # Verify connection try: self._driver.verify_connectivity() + except neo4j.exceptions.ConfigurationError as e: + # If notification filtering is not supported + if "Notification filtering is not supported" in str(e): + # Retry without notifications_min_severity + driver_config.pop("notifications_min_severity", None) + self._driver = neo4j.GraphDatabase.driver( + url, auth=auth, **driver_config + ) + self._driver.verify_connectivity() + else: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the driver config is correct" + ) except neo4j.exceptions.ServiceUnavailable: raise ValueError( "Could not connect to Neo4j database. " diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 265af31..60b79b1 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from neo4j.exceptions import ClientError, Neo4jError +from neo4j.exceptions import ClientError, ConfigurationError, Neo4jError from langchain_neo4j.graphs.neo4j_graph import ( LIST_LIMIT, @@ -201,7 +201,52 @@ def test_neo4j_graph_init_with_empty_credentials() -> None: Neo4jGraph( url="bolt://localhost:7687", username="", password="", refresh_schema=False ) - mock_driver.assert_called_with("bolt://localhost:7687", auth=None) + mock_driver.assert_called_with( + "bolt://localhost:7687", auth=None, notifications_min_severity="OFF" + ) + + +def test_neo4j_graph_init_notification_filtering_err() -> None: + """Test the __init__ method when notification filtering is disabled.""" + with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver: + mock_driver_instance = MagicMock() + mock_driver.return_value = mock_driver_instance + err = ConfigurationError("Notification filtering is not supported") + mock_driver_instance.verify_connectivity.side_effect = [err, None] + Neo4jGraph( + url="bolt://localhost:7687", + username="username", + password="password", + refresh_schema=False, + ) + mock_driver.assert_any_call( + "bolt://localhost:7687", + auth=("username", "password"), + notifications_min_severity="OFF", + ) + # The first call verify_connectivity should fail causing the driver to be + # recreated without the notifications_min_severity parameter + mock_driver.assert_any_call( + "bolt://localhost:7687", + auth=("username", "password"), + ) + + +def test_neo4j_graph_init_driver_config_err() -> None: + """Test the __init__ method with an incorrect driver config.""" + with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver: + mock_driver_instance = MagicMock() + mock_driver.return_value = mock_driver_instance + err = ConfigurationError() + mock_driver_instance.verify_connectivity.side_effect = err + with pytest.raises(ValueError) as exc_info: + Neo4jGraph( + url="bolt://localhost:7687", + username="username", + password="password", + refresh_schema=False, + ) + assert "Please ensure that the driver config is correct" in str(exc_info.value) def test_init_apoc_procedure_not_found( From daeb84abe7c8af81a10de47af6c2536431cae093 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 8 Jan 2025 11:05:57 +0000 Subject: [PATCH 8/9] Makes the source parameter of GraphDocument optional (#32) * Makes the source parameter of GraphDocument optional * Updated CHANGELOG --- CHANGELOG.md | 8 +++++ .../langchain_neo4j/graphs/graph_document.py | 7 ++-- .../langchain_neo4j/graphs/neo4j_graph.py | 32 ++++++++++++------- .../unit_tests/graphs/test_neo4j_graph.py | 27 ++++++++++++++++ 4 files changed, 59 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0976c74..c5ee525 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ ## Next +### Changed + +- Made the `source` parameter of `GraphDocument` optional and updated related methods to support this. + +### Fixed + +- Disabled warnings from the Neo4j driver for the Neo4jGraph class. + ## 0.2.0 ### Added diff --git a/libs/neo4j/langchain_neo4j/graphs/graph_document.py b/libs/neo4j/langchain_neo4j/graphs/graph_document.py index ff82ca4..ff32837 100644 --- a/libs/neo4j/langchain_neo4j/graphs/graph_document.py +++ b/libs/neo4j/langchain_neo4j/graphs/graph_document.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Union +from typing import List, Optional, Union from langchain_core.documents import Document from langchain_core.load.serializable import Serializable @@ -43,9 +43,10 @@ class GraphDocument(Serializable): Attributes: nodes (List[Node]): A list of nodes in the graph. relationships (List[Relationship]): A list of relationships in the graph. - source (Document): The document from which the graph information is derived. + source (Optional[Document]): The document from which the graph information is + derived. """ nodes: List[Node] relationships: List[Relationship] - source: Document + source: Optional[Document] = None diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index b1ced4b..21ac2dd 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -616,7 +616,7 @@ def add_graph_documents( - graph_documents (List[GraphDocument]): A list of GraphDocument objects that contain the nodes and relationships to be added to the graph. Each GraphDocument should encapsulate the structure of part of the graph, - including nodes, relationships, and the source document information. + including nodes, relationships, and optionally the source document information. - include_source (bool, optional): If True, stores the source document and links it to nodes in the graph using the MENTIONS relationship. This is useful for tracing back the origin of data. Merges source @@ -650,25 +650,33 @@ def add_graph_documents( ) self.refresh_schema() # Refresh constraint information + # Check each graph_document has a source when include_source is true + if include_source: + for doc in graph_documents: + if doc.source is None: + raise TypeError( + "include_source is set to True, " + "but at least one document has no `source`." + ) + node_import_query = _get_node_import_query(baseEntityLabel, include_source) rel_import_query = _get_rel_import_query(baseEntityLabel) for document in graph_documents: - if not document.source.metadata.get("id"): - document.source.metadata["id"] = md5( - document.source.page_content.encode("utf-8") - ).hexdigest() + node_import_query_params: dict[str, Any] = { + "data": [el.__dict__ for el in document.nodes] + } + if include_source and document.source: + if not document.source.metadata.get("id"): + document.source.metadata["id"] = md5( + document.source.page_content.encode("utf-8") + ).hexdigest() + node_import_query_params["document"] = document.source.__dict__ # Remove backticks from node types for node in document.nodes: node.type = _remove_backticks(node.type) # Import nodes - self.query( - node_import_query, - { - "data": [el.__dict__ for el in document.nodes], - "document": document.source.__dict__, - }, - ) + self.query(node_import_query, node_import_query_params) # Import relationships self.query( rel_import_query, diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 60b79b1..c62a0cc 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -5,6 +5,7 @@ import pytest from neo4j.exceptions import ClientError, ConfigurationError, Neo4jError +from langchain_neo4j.graphs.graph_document import GraphDocument, Node, Relationship from langchain_neo4j.graphs.neo4j_graph import ( LIST_LIMIT, Neo4jGraph, @@ -374,6 +375,32 @@ def test_get_schema(mock_neo4j_driver: MagicMock) -> None: assert graph.get_schema == "test" +def test_add_graph_docs_inc_src_err(mock_neo4j_driver: MagicMock) -> None: + """Tests an error is raised when using add_graph_documents with include_source set + to True and a document is missing a source.""" + graph = Neo4jGraph( + url="bolt://localhost:7687", + username="neo4j", + password="password", + refresh_schema=False, + ) + node_1 = Node(id=1) + node_2 = Node(id=2) + rel = Relationship(source=node_1, target=node_2, type="REL") + + graph_doc = GraphDocument( + nodes=[node_1, node_2], + relationships=[rel], + ) + with pytest.raises(TypeError) as exc_info: + graph.add_graph_documents(graph_documents=[graph_doc], include_source=True) + + assert ( + "include_source is set to True, but at least one document has no `source`." + in str(exc_info.value) + ) + + @pytest.mark.parametrize( "description, schema, is_enhanced, expected_output", [ From 5af91a1b940dc5907d823aec18247035a128d1ff Mon Sep 17 00:00:00 2001 From: Dennis Liu Date: Wed, 8 Jan 2025 22:24:45 +1100 Subject: [PATCH 9/9] Provide optional parameter for embedding dimension in Neo4jVector (#31) * add optional parameter * additional checks in methods * tests * linting --- .../vectorstores/neo4j_vector.py | 44 +++++++++++++++++-- .../unit_tests/vectorstores/test_neo4j.py | 21 +++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py b/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py index 9438f1b..fb8f695 100644 --- a/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py +++ b/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py @@ -466,6 +466,8 @@ class Neo4jVector(VectorStore): (default: False). Useful for testing. effective_search_ratio: Controls the candidate pool size by multiplying $k to balance query accuracy and performance. + embedding_dimension: The dimension of the embeddings. If not provided, + will query the embedding model to calculate the dimension. Example: .. code-block:: python @@ -509,6 +511,7 @@ def __init__( relevance_score_fn: Optional[Callable[[float], float]] = None, index_type: IndexType = DEFAULT_INDEX_TYPE, graph: Optional[Neo4jGraph] = None, + embedding_dimension: Optional[int] = None, ) -> None: try: import neo4j @@ -593,8 +596,11 @@ def __init__( self.search_type = search_type self._index_type = index_type - # Calculate embedding dimension - self.embedding_dimension = len(embedding.embed_query("foo")) + if embedding_dimension: + self.embedding_dimension = embedding_dimension + else: + # Calculate embedding dimension + self.embedding_dimension = len(embedding.embed_query("foo")) # Delete existing data if flagged if pre_delete_collection: @@ -1336,6 +1342,7 @@ def from_existing_index( index_name: str, search_type: SearchType = DEFAULT_SEARCH_TYPE, keyword_index_name: Optional[str] = None, + embedding_dimension: Optional[int] = None, **kwargs: Any, ) -> Neo4jVector: """ @@ -1358,10 +1365,24 @@ def from_existing_index( index_name=index_name, keyword_index_name=keyword_index_name, search_type=search_type, + embedding_dimension=embedding_dimension, **kwargs, ) - embedding_dimension, index_type = store.retrieve_existing_index() + if embedding_dimension: + ( + embedding_dimension_from_existing, + index_type, + ) = store.retrieve_existing_index() + if embedding_dimension_from_existing != embedding_dimension: + raise ValueError( + "The provided embedding function and vector index " + "dimensions do not match.\n" + f"Embedding function dimension: {embedding_dimension}\n" + f"Vector index dimension: {embedding_dimension_from_existing}" + ) + else: + embedding_dimension, index_type = store.retrieve_existing_index() # Raise error if relationship index type if index_type == "RELATIONSHIP": @@ -1408,6 +1429,7 @@ def from_existing_relationship_index( embedding: Embeddings, index_name: str, search_type: SearchType = DEFAULT_SEARCH_TYPE, + embedding_dimension: Optional[int] = None, **kwargs: Any, ) -> Neo4jVector: """ @@ -1428,10 +1450,24 @@ def from_existing_relationship_index( store = cls( embedding=embedding, index_name=index_name, + embedding_dimension=embedding_dimension, **kwargs, ) - embedding_dimension, index_type = store.retrieve_existing_index() + if embedding_dimension: + ( + embedding_dimension_from_existing, + index_type, + ) = store.retrieve_existing_index() + if embedding_dimension_from_existing != embedding_dimension: + raise ValueError( + "The provided embedding function and vector index " + "dimensions do not match.\n" + f"Embedding function dimension: {embedding_dimension}\n" + f"Vector index dimension: {embedding_dimension_from_existing}" + ) + else: + embedding_dimension, index_type = store.retrieve_existing_index() if not index_type: raise ValueError( diff --git a/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py b/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py index 201864d..d6dbf32 100644 --- a/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py +++ b/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py @@ -1022,3 +1022,24 @@ def test_select_relevance_score_fn_unsupported_strategy( f"Expected error message to contain '{expected_message}' " f"but got '{str(exc_info.value)}'" ) + + +def test_embedding_dimension_inconsistent_raises_value_error( + neo4j_vector_factory: Any, +) -> None: + mock_embedding = MagicMock() + mock_embedding.embed_query.return_value = [0.1] * 64 + + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(128, "NODE") + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_index", + embedding=mock_embedding, + index_name="test_index", + ) + assert ( + "The provided embedding function and vector index dimensions do not match." + in str(exc_info.value) + )