diff --git a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py
index a857fb4..9e62a11 100644
--- a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py
+++ b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py
@@ -1,9 +1,16 @@
 import pathlib
 from csv import DictReader
 from typing import Any, Dict, List
+from unittest.mock import MagicMock, patch
 
+import pytest
 from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
-from langchain_core.messages import SystemMessage
+from langchain_core.language_models.llms import LLM
+from langchain_core.messages import (
+    AIMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langchain_core.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
@@ -15,6 +22,7 @@
     GraphCypherQAChain,
     construct_schema,
     extract_cypher,
+    get_function_response,
 )
 from langchain_neo4j.chains.graph_qa.cypher_utils import (
     CypherQueryCorrector,
@@ -141,21 +149,34 @@ def test_graph_cypher_qa_chain_prompt_selection_5() -> None:
     readonlymemory = ReadOnlySharedMemory(memory=memory)
     qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
     cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
-    try:
+    with pytest.raises(ValueError) as exc_info:
         GraphCypherQAChain.from_llm(
             llm=FakeLLM(),
             graph=FakeGraphStore(),
             verbose=True,
             return_intermediate_steps=False,
-            qa_prompt=qa_prompt,
             cypher_prompt=cypher_prompt,
             cypher_llm_kwargs={"memory": readonlymemory},
+            allow_dangerous_requests=True,
+        )
+    assert (
+        "Specifying cypher_prompt and cypher_llm_kwargs together is"
+        " not allowed. Please pass prompt via cypher_llm_kwargs."
+    ) == str(exc_info.value)
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            llm=FakeLLM(),
+            graph=FakeGraphStore(),
+            verbose=True,
+            return_intermediate_steps=False,
+            qa_prompt=qa_prompt,
             qa_llm_kwargs={"memory": readonlymemory},
             allow_dangerous_requests=True,
         )
-        assert False
-    except ValueError:
-        assert True
+    assert (
+        "Specifying qa_prompt and qa_llm_kwargs together is"
+        " not allowed. Please pass prompt via qa_llm_kwargs."
+    ) == str(exc_info.value)
 
 
 def test_graph_cypher_qa_chain_prompt_selection_6() -> None:
@@ -182,6 +203,30 @@ def test_graph_cypher_qa_chain_prompt_selection_6() -> None:
     assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT
 
 
+def test_graph_cypher_qa_chain_prompt_selection_7() -> None:
+    # Pass prompts which do not inherit from BasePromptTemplate
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            llm=FakeLLM(),
+            graph=FakeGraphStore(),
+            cypher_llm_kwargs={"prompt": None},
+            allow_dangerous_requests=True,
+        )
+    assert "The cypher_llm_kwargs `prompt` must inherit from BasePromptTemplate" == str(
+        exc_info.value
+    )
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            llm=FakeLLM(),
+            graph=FakeGraphStore(),
+            qa_llm_kwargs={"prompt": None},
+            allow_dangerous_requests=True,
+        )
+    assert "The qa_llm_kwargs `prompt` must inherit from BasePromptTemplate" == str(
+        exc_info.value
+    )
+
+
 def test_graph_cypher_qa_chain() -> None:
     template = """You are a nice chatbot having a conversation with a human.
 
@@ -257,7 +302,7 @@ def test_exclude_types() -> None:
             "Actor": [{"property": "name", "type": "STRING"}],
             "Person": [{"property": "name", "type": "STRING"}],
         },
-        "rel_props": {},
+        "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]},
         "relationships": [
             {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
             {"start": "Person", "end": "Movie", "type": "DIRECTED"},
@@ -268,7 +313,8 @@
     expected_schema = (
         "Node properties are the following:\n"
         "Movie {title: STRING},Actor {name: STRING}\n"
-        "Relationship properties are the following:\n\n"
+        "Relationship properties are the following:\n"
+        "ACTED_IN {role: STRING}\n"
         "The relationships are the following:\n"
         "(:Actor)-[:ACTED_IN]->(:Movie)"
     )
@@ -282,7 +328,7 @@ def test_include_types() -> None:
             "Actor": [{"property": "name", "type": "STRING"}],
             "Person": [{"property": "name", "type": "STRING"}],
         },
-        "rel_props": {},
+        "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]},
         "relationships": [
             {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
             {"start": "Person", "end": "Movie", "type": "DIRECTED"},
@@ -293,7 +339,8 @@
     expected_schema = (
         "Node properties are the following:\n"
         "Movie {title: STRING},Actor {name: STRING}\n"
-        "Relationship properties are the following:\n\n"
+        "Relationship properties are the following:\n"
+        "ACTED_IN {role: STRING}\n"
         "The relationships are the following:\n"
        "(:Actor)-[:ACTED_IN]->(:Movie)"
     )
@@ -307,7 +354,7 @@ def test_include_types2() -> None:
             "Actor": [{"property": "name", "type": "STRING"}],
             "Person": [{"property": "name", "type": "STRING"}],
         },
-        "rel_props": {},
+        "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]},
         "relationships": [
             {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
             {"start": "Person", "end": "Movie", "type": "DIRECTED"},
@@ -331,7 +378,7 @@ def test_include_types3() -> None:
             "Actor": [{"property": "name", "type": "STRING"}],
             "Person": [{"property": "name", "type": "STRING"}],
         },
-        "rel_props": {},
+        "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]},
         "relationships": [
             {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
             {"start": "Person", "end": "Movie", "type": "DIRECTED"},
@@ -342,13 +389,136 @@
     expected_schema = (
         "Node properties are the following:\n"
         "Movie {title: STRING},Actor {name: STRING}\n"
-        "Relationship properties are the following:\n\n"
+        "Relationship properties are the following:\n"
+        "ACTED_IN {role: STRING}\n"
         "The relationships are the following:\n"
         "(:Actor)-[:ACTED_IN]->(:Movie)"
     )
     assert output == expected_schema
 
 
+def test_include_exclude_types_err() -> None:
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            llm=FakeLLM(),
+            graph=FakeGraphStore(),
+            include_types=["Movie", "Actor"],
+            exclude_types=["Person", "DIRECTED"],
+            allow_dangerous_requests=True,
+        )
+    assert (
+        "Either `exclude_types` or `include_types` can be provided, but not both"
+        == str(exc_info.value)
+    )
+
+
+def test_get_function_response() -> None:
+    question = "Who directed Dune?"
+    context = [{"director": "Denis Villeneuve"}]
+    messages = get_function_response(question, context)
+    assert len(messages) == 2
+    # Validate AIMessage
+    ai_message = messages[0]
+    assert isinstance(ai_message, AIMessage)
+    assert ai_message.content == ""
+    assert "tool_calls" in ai_message.additional_kwargs
+    tool_call = ai_message.additional_kwargs["tool_calls"][0]
+    assert tool_call["function"]["arguments"] == f'{{"question":"{question}"}}'
+    # Validate ToolMessage
+    tool_message = messages[1]
+    assert isinstance(tool_message, ToolMessage)
+    assert tool_message.content == str(context)
+
+
+def test_allow_dangerous_requests_err() -> None:
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            llm=FakeLLM(),
+            graph=FakeGraphStore(),
+        )
+    assert (
+        "In order to use this chain, you must acknowledge that it can make "
+        "dangerous requests by setting `allow_dangerous_requests` to `True`."
+    ) in str(exc_info.value)
+
+
+def test_llm_arg_combinations() -> None:
+    # No llm
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            graph=FakeGraphStore(), allow_dangerous_requests=True
+        )
+    assert "At least one LLM must be provided" == str(exc_info.value)
+    # llm only
+    GraphCypherQAChain.from_llm(
+        llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True
+    )
+    # qa_llm only
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            qa_llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True
+        )
+    assert (
+        "If `llm` is not provided, both `qa_llm` and `cypher_llm` must be provided."
+        == str(exc_info.value)
+    )
+    # cypher_llm only
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            cypher_llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True
+        )
+    assert (
+        "If `llm` is not provided, both `qa_llm` and `cypher_llm` must be provided."
+        == str(exc_info.value)
+    )
+    # llm + qa_llm
+    GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        qa_llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        allow_dangerous_requests=True,
+    )
+    # llm + cypher_llm
+    GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        cypher_llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        allow_dangerous_requests=True,
+    )
+    # qa_llm + cypher_llm
+    GraphCypherQAChain.from_llm(
+        qa_llm=FakeLLM(),
+        cypher_llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        allow_dangerous_requests=True,
+    )
+    # llm + qa_llm + cypher_llm
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            llm=FakeLLM(),
+            qa_llm=FakeLLM(),
+            cypher_llm=FakeLLM(),
+            graph=FakeGraphStore(),
+            allow_dangerous_requests=True,
+        )
+    assert (
+        "You can specify up to two of 'cypher_llm', 'qa_llm'"
+        ", and 'llm', but not all three simultaneously."
+    ) == str(exc_info.value)
+
+
+def test_use_function_response_err() -> None:
+    llm = MagicMock(spec=LLM)
+    with pytest.raises(ValueError) as exc_info:
+        GraphCypherQAChain.from_llm(
+            llm=llm,
+            graph=FakeGraphStore(),
+            allow_dangerous_requests=True,
+            use_function_response=True,
+        )
+    assert "Provided LLM does not support native tools/functions" == str(exc_info.value)
+
+
 HERE = pathlib.Path(__file__).parent
 
 UNIT_TESTS_ROOT = HERE.parent