diff --git a/CHANGELOG.md b/CHANGELOG.md index 0976c74..c5ee525 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ ## Next +### Changed + +- Made the `source` parameter of `GraphDocument` optional and updated related methods to support this. + +### Fixed + +- Disabled warnings from the Neo4j driver for the Neo4jGraph class. + ## 0.2.0 ### Added diff --git a/libs/neo4j/langchain_neo4j/graphs/graph_document.py b/libs/neo4j/langchain_neo4j/graphs/graph_document.py index ff82ca4..ff32837 100644 --- a/libs/neo4j/langchain_neo4j/graphs/graph_document.py +++ b/libs/neo4j/langchain_neo4j/graphs/graph_document.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Union +from typing import List, Optional, Union from langchain_core.documents import Document from langchain_core.load.serializable import Serializable @@ -43,9 +43,10 @@ class GraphDocument(Serializable): Attributes: nodes (List[Node]): A list of nodes in the graph. relationships (List[Relationship]): A list of relationships in the graph. - source (Document): The document from which the graph information is derived. + source (Optional[Document]): The document from which the graph information is + derived. """ nodes: List[Node] relationships: List[Relationship] - source: Document + source: Optional[Document] = None diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index b1ced4b..21ac2dd 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -616,7 +616,7 @@ def add_graph_documents( - graph_documents (List[GraphDocument]): A list of GraphDocument objects that contain the nodes and relationships to be added to the graph. Each GraphDocument should encapsulate the structure of part of the graph, - including nodes, relationships, and the source document information. + including nodes, relationships, and optionally the source document information. - include_source (bool, optional): If True, stores the source document and links it to nodes in the graph using the MENTIONS relationship. This is useful for tracing back the origin of data. Merges source @@ -650,25 +650,33 @@ def add_graph_documents( ) self.refresh_schema() # Refresh constraint information + # Check each graph_document has a source when include_source is true + if include_source: + for doc in graph_documents: + if doc.source is None: + raise TypeError( + "include_source is set to True, " + "but at least one document has no `source`." + ) + node_import_query = _get_node_import_query(baseEntityLabel, include_source) rel_import_query = _get_rel_import_query(baseEntityLabel) for document in graph_documents: - if not document.source.metadata.get("id"): - document.source.metadata["id"] = md5( - document.source.page_content.encode("utf-8") - ).hexdigest() + node_import_query_params: dict[str, Any] = { + "data": [el.__dict__ for el in document.nodes] + } + if include_source and document.source: + if not document.source.metadata.get("id"): + document.source.metadata["id"] = md5( + document.source.page_content.encode("utf-8") + ).hexdigest() + node_import_query_params["document"] = document.source.__dict__ # Remove backticks from node types for node in document.nodes: node.type = _remove_backticks(node.type) # Import nodes - self.query( - node_import_query, - { - "data": [el.__dict__ for el in document.nodes], - "document": document.source.__dict__, - }, - ) + self.query(node_import_query, node_import_query_params) # Import relationships self.query( rel_import_query, diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 60b79b1..c62a0cc 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -5,6 +5,7 @@ import pytest from neo4j.exceptions import ClientError, ConfigurationError, Neo4jError +from langchain_neo4j.graphs.graph_document import GraphDocument, Node, Relationship from langchain_neo4j.graphs.neo4j_graph import ( LIST_LIMIT, Neo4jGraph, @@ -374,6 +375,32 @@ def test_get_schema(mock_neo4j_driver: MagicMock) -> None: assert graph.get_schema == "test" +def test_add_graph_docs_inc_src_err(mock_neo4j_driver: MagicMock) -> None: + """Tests an error is raised when using add_graph_documents with include_source set + to True and a document is missing a source.""" + graph = Neo4jGraph( + url="bolt://localhost:7687", + username="neo4j", + password="password", + refresh_schema=False, + ) + node_1 = Node(id=1) + node_2 = Node(id=2) + rel = Relationship(source=node_1, target=node_2, type="REL") + + graph_doc = GraphDocument( + nodes=[node_1, node_2], + relationships=[rel], + ) + with pytest.raises(TypeError) as exc_info: + graph.add_graph_documents(graph_documents=[graph_doc], include_source=True) + + assert ( + "include_source is set to True, but at least one document has no `source`." + in str(exc_info.value) + ) + + @pytest.mark.parametrize( "description, schema, is_enhanced, expected_output", [