Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds tests so that GraphCypherQAChain class now has 100% coverage #23

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def from_llm(
if validate_cypher:
corrector_schema = [
Schema(el["start"], el["type"], el["end"])
for el in kwargs["graph"].structured_schema.get("relationships")
for el in kwargs["graph"].get_structured_schema.get("relationships", [])
stellasia marked this conversation as resolved.
Show resolved Hide resolved
]
cypher_query_corrector = CypherQueryCorrector(corrector_schema)

Expand Down
55 changes: 50 additions & 5 deletions libs/neo4j/tests/integration_tests/chains/test_graph_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from unittest.mock import MagicMock
stellasia marked this conversation as resolved.
Show resolved Hide resolved

from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.fake import FakeListLLM

from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph
from tests.unit_tests.llms.fake_llm import FakeLLM
stellasia marked this conversation as resolved.
Show resolved Hide resolved


def test_connect_neo4j() -> None:
Expand Down Expand Up @@ -71,10 +71,13 @@ def test_cypher_generating_run() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeListLLM(responses=[query, "Bruce Willis"])
llm = FakeLLM(
queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True
)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
validate_cypher=True,
allow_dangerous_requests=True,
)
output = chain.run("Who starred in Pulp Fiction?")
Expand Down Expand Up @@ -111,7 +114,7 @@ def test_cypher_top_k() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeListLLM(responses=[query])
llm = FakeLLM(queries={"query": query}, sequential_responses=True)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand Down Expand Up @@ -149,7 +152,9 @@ def test_cypher_intermediate_steps() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeListLLM(responses=[query, "Bruce Willis"])
llm = FakeLLM(
queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True
)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand Down Expand Up @@ -194,7 +199,7 @@ def test_cypher_return_direct() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeListLLM(responses=[query])
llm = FakeLLM(queries={"query": query}, sequential_responses=True)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand All @@ -206,6 +211,46 @@ def test_cypher_return_direct() -> None:
assert output == expected_output


def test_function_response() -> None:
"""Test returning a function response."""
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein")

graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
)
# Refresh schema information
graph.refresh_schema()

query = (
"MATCH (a:Actor)-[:ACTED_IN]->(m:Movie) "
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeLLM(
queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True
)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
allow_dangerous_requests=True,
use_function_response=True,
)
output = chain.run("Who starred in Pulp Fiction?")
expected_output = "Bruce Willis"
assert output == expected_output


def test_exclude_types() -> None:
"""Test exclude types from schema."""
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
Expand Down
Loading
Loading