Skip to content

Commit

Permalink
Makes the source parameter of GraphDocument optional (#32)
Browse files Browse the repository at this point in the history
* Makes the source parameter of GraphDocument optional

* Updated CHANGELOG
  • Loading branch information
alexthomas93 authored Jan 8, 2025
1 parent 6c9ac24 commit daeb84a
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 15 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

## Next

### Changed

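- Made the `source` parameter of `GraphDocument` optional and updated `Neo4jGraph.add_graph_documents` to accept documents without a source; a `TypeError` is raised if `include_source=True` and any document has no `source`.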

### Fixed

- Disabled warnings from the Neo4j driver for the Neo4jGraph class.

## 0.2.0

### Added
Expand Down
7 changes: 4 additions & 3 deletions libs/neo4j/langchain_neo4j/graphs/graph_document.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List, Union
from typing import List, Optional, Union

from langchain_core.documents import Document
from langchain_core.load.serializable import Serializable
Expand Down Expand Up @@ -43,9 +43,10 @@ class GraphDocument(Serializable):
Attributes:
nodes (List[Node]): A list of nodes in the graph.
relationships (List[Relationship]): A list of relationships in the graph.
source (Document): The document from which the graph information is derived.
source (Optional[Document]): The document from which the graph information is
derived.
"""

nodes: List[Node]
relationships: List[Relationship]
source: Document
source: Optional[Document] = None
32 changes: 20 additions & 12 deletions libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def add_graph_documents(
- graph_documents (List[GraphDocument]): A list of GraphDocument objects
that contain the nodes and relationships to be added to the graph. Each
GraphDocument should encapsulate the structure of part of the graph,
including nodes, relationships, and the source document information.
including nodes, relationships, and optionally the source document information.
- include_source (bool, optional): If True, stores the source document
and links it to nodes in the graph using the MENTIONS relationship.
This is useful for tracing back the origin of data. Merges source
Expand Down Expand Up @@ -650,25 +650,33 @@ def add_graph_documents(
)
self.refresh_schema() # Refresh constraint information

# Check each graph_document has a source when include_source is true
if include_source:
for doc in graph_documents:
if doc.source is None:
raise TypeError(
"include_source is set to True, "
"but at least one document has no `source`."
)

node_import_query = _get_node_import_query(baseEntityLabel, include_source)
rel_import_query = _get_rel_import_query(baseEntityLabel)
for document in graph_documents:
if not document.source.metadata.get("id"):
document.source.metadata["id"] = md5(
document.source.page_content.encode("utf-8")
).hexdigest()
node_import_query_params: dict[str, Any] = {
"data": [el.__dict__ for el in document.nodes]
}
if include_source and document.source:
if not document.source.metadata.get("id"):
document.source.metadata["id"] = md5(
document.source.page_content.encode("utf-8")
).hexdigest()
node_import_query_params["document"] = document.source.__dict__

# Remove backticks from node types
for node in document.nodes:
node.type = _remove_backticks(node.type)
# Import nodes
self.query(
node_import_query,
{
"data": [el.__dict__ for el in document.nodes],
"document": document.source.__dict__,
},
)
self.query(node_import_query, node_import_query_params)
# Import relationships
self.query(
rel_import_query,
Expand Down
27 changes: 27 additions & 0 deletions libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from neo4j.exceptions import ClientError, ConfigurationError, Neo4jError

from langchain_neo4j.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_neo4j.graphs.neo4j_graph import (
LIST_LIMIT,
Neo4jGraph,
Expand Down Expand Up @@ -374,6 +375,32 @@ def test_get_schema(mock_neo4j_driver: MagicMock) -> None:
assert graph.get_schema == "test"


def test_add_graph_docs_inc_src_err(mock_neo4j_driver: MagicMock) -> None:
    """Check that add_graph_documents rejects a document without a source.

    A GraphDocument built with no `source` is passed to add_graph_documents
    with include_source=True; the call must raise a TypeError whose message
    contains the expected text.
    """
    expected_message = (
        "include_source is set to True, but at least one document has no `source`."
    )

    neo4j_graph = Neo4jGraph(
        url="bolt://localhost:7687",
        username="neo4j",
        password="password",
        refresh_schema=False,
    )

    # Two nodes joined by one relationship; `source` is deliberately omitted.
    first_node, second_node = Node(id=1), Node(id=2)
    link = Relationship(source=first_node, target=second_node, type="REL")
    sourceless_doc = GraphDocument(
        nodes=[first_node, second_node],
        relationships=[link],
    )

    with pytest.raises(TypeError) as raised:
        neo4j_graph.add_graph_documents(
            graph_documents=[sourceless_doc], include_source=True
        )

    assert expected_message in str(raised.value)


@pytest.mark.parametrize(
"description, schema, is_enhanced, expected_output",
[
Expand Down

0 comments on commit daeb84a

Please sign in to comment.