Skip to content

Commit

Permalink
Adds tests so that GraphCypherQAChain class now has 100% coverage (#23)
Browse files Browse the repository at this point in the history
* Added tests to improve GraphCypherQAChain test coverage

* Fixed linting issue

* Added validate_cypher test

* Added test_function_response

* 100% coverage for GraphCypherQAChain

* Refactoring
  • Loading branch information
alexthomas93 authored Dec 18, 2024
1 parent cfa583a commit a53e8ee
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 20 deletions.
2 changes: 1 addition & 1 deletion libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def from_llm(
if validate_cypher:
corrector_schema = [
Schema(el["start"], el["type"], el["end"])
for el in kwargs["graph"].structured_schema.get("relationships")
for el in kwargs["graph"].get_structured_schema.get("relationships", [])
]
cypher_query_corrector = CypherQueryCorrector(corrector_schema)

Expand Down
55 changes: 50 additions & 5 deletions libs/neo4j/tests/integration_tests/chains/test_graph_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from unittest.mock import MagicMock

from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.fake import FakeListLLM

from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph
from tests.llms.fake_llm import FakeLLM


def test_connect_neo4j() -> None:
Expand Down Expand Up @@ -71,10 +71,13 @@ def test_cypher_generating_run() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeListLLM(responses=[query, "Bruce Willis"])
llm = FakeLLM(
queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True
)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
validate_cypher=True,
allow_dangerous_requests=True,
)
output = chain.run("Who starred in Pulp Fiction?")
Expand Down Expand Up @@ -111,7 +114,7 @@ def test_cypher_top_k() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeListLLM(responses=[query])
llm = FakeLLM(queries={"query": query}, sequential_responses=True)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand Down Expand Up @@ -149,7 +152,9 @@ def test_cypher_intermediate_steps() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeListLLM(responses=[query, "Bruce Willis"])
llm = FakeLLM(
queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True
)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand Down Expand Up @@ -194,7 +199,7 @@ def test_cypher_return_direct() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeListLLM(responses=[query])
llm = FakeLLM(queries={"query": query}, sequential_responses=True)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand All @@ -206,6 +211,46 @@ def test_cypher_return_direct() -> None:
assert output == expected_output


def test_function_response() -> None:
"""Test returning a function response."""
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein")

graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
)
# Refresh schema information
graph.refresh_schema()

query = (
"MATCH (a:Actor)-[:ACTED_IN]->(m:Movie) "
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = FakeLLM(
queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True
)
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
allow_dangerous_requests=True,
use_function_response=True,
)
output = chain.run("Who starred in Pulp Fiction?")
expected_output = "Bruce Willis"
assert output == expected_output


def test_exclude_types() -> None:
"""Test exclude types from schema."""
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
Expand Down
Empty file.
File renamed without changes.
Loading

0 comments on commit a53e8ee

Please sign in to comment.