From e405e81557bd2cb28b3162ce71c0e5c0989b2eb4 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Mon, 16 Dec 2024 16:46:38 +0000 Subject: [PATCH] Added validate_cypher test --- .../langchain_neo4j/chains/graph_qa/cypher.py | 2 +- .../tests/unit_tests/chains/test_graph_qa.py | 25 ++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index e84a4df..514870f 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -346,7 +346,7 @@ def from_llm( if validate_cypher: corrector_schema = [ Schema(el["start"], el["type"], el["end"]) - for el in kwargs["graph"].structured_schema.get("relationships") + for el in kwargs["graph"].get_structured_schema.get("relationships", []) ] cypher_query_corrector = CypherQueryCorrector(corrector_schema) diff --git a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py index e458afd..d966366 100644 --- a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py @@ -1,7 +1,7 @@ import pathlib from csv import DictReader from typing import Any, Dict, List -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory @@ -227,6 +227,29 @@ def test_graph_cypher_qa_chain_prompt_selection_7() -> None: ) +def test_validate_cypher() -> None: + with patch( + "langchain_neo4j.chains.graph_qa.cypher.CypherQueryCorrector", + autospec=True, + ) as cypher_query_corrector_mock: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + validate_cypher=True, + allow_dangerous_requests=True, + ) + cypher_query_corrector_mock.assert_called_once_with([]) + + +def test_chain_type() -> None: + chain = GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + assert chain._chain_type == "graph_cypher_chain" + + def test_graph_cypher_qa_chain() -> None: template = """You are a nice chatbot having a conversation with a human.