Skip to content

Commit

Permalink
Added validate_cypher test
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Dec 16, 2024
1 parent 0e1686a commit e405e81
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
2 changes: 1 addition & 1 deletion libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def from_llm(
if validate_cypher:
corrector_schema = [
Schema(el["start"], el["type"], el["end"])
for el in kwargs["graph"].structured_schema.get("relationships")
for el in kwargs["graph"].get_structured_schema.get("relationships", [])
]
cypher_query_corrector = CypherQueryCorrector(corrector_schema)

Expand Down
25 changes: 24 additions & 1 deletion libs/neo4j/tests/unit_tests/chains/test_graph_qa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib
from csv import DictReader
from typing import Any, Dict, List
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
Expand Down Expand Up @@ -227,6 +227,29 @@ def test_graph_cypher_qa_chain_prompt_selection_7() -> None:
)


def test_validate_cypher() -> None:
with patch(
"langchain_neo4j.chains.graph_qa.cypher.CypherQueryCorrector",
autospec=True,
) as cypher_query_corrector_mock:
GraphCypherQAChain.from_llm(
llm=FakeLLM(),
graph=FakeGraphStore(),
validate_cypher=True,
allow_dangerous_requests=True,
)
cypher_query_corrector_mock.assert_called_once_with([])


def test_chain_type() -> None:
chain = GraphCypherQAChain.from_llm(
llm=FakeLLM(),
graph=FakeGraphStore(),
allow_dangerous_requests=True,
)
assert chain._chain_type == "graph_cypher_chain"


def test_graph_cypher_qa_chain() -> None:
template = """You are a nice chatbot having a conversation with a human.
Expand Down

0 comments on commit e405e81

Please sign in to comment.