Skip to content

Commit

Permalink
fix: Cypher query extraction for node names with spaces (#24)
Browse files: browse the repository at this point in the history
- Updated `extract_cypher` to wrap node labels that contain spaces in single quotes.
- Added tests for node names with one and multiple spaces.
- Renamed test cases to improve clarity.
- Checked lint and format.

Co-authored-by: Alex Thomas <[email protected]>
  • Loading branch information
chkaty and alexthomas93 authored Jan 8, 2025
1 parent 5af91a1 commit ceaf64a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
14 changes: 13 additions & 1 deletion libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,20 @@ def extract_cypher(text: str) -> str:

# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
if matches:
cypher_query = matches[0]
else:
return text

return matches[0] if matches else text
# Remove backticks
cypher_query = cypher_query.replace("`", "")

# Quote node labels if they contain spaces
cypher_query = re.sub(
r":\s*(\s*)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(\s*)", r":'\2'", cypher_query
)

return cypher_query


def construct_schema(
Expand Down
31 changes: 28 additions & 3 deletions libs/neo4j/tests/unit_tests/chains/test_graph_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,18 +317,43 @@ def test_cypher_generation_failure() -> None:
assert response == []


def test_no_backticks() -> None:
def test_extract_cypher_on_no_backticks() -> None:
"""Test if there are no backticks, so the original text should be returned."""
query = "MATCH (n) RETURN n"
output = extract_cypher(query)
assert output == query


def test_backticks() -> None:
def test_extract_cypher_on_backticks() -> None:
"""Test if there are backticks. Query from within backticks should be returned."""
query = "You can use the following query: ```MATCH (n) RETURN n```"
expected_output = "MATCH (n) RETURN n"
output = extract_cypher(query)
assert output == "MATCH (n) RETURN n"
assert output == expected_output


def test_extract_cypher_on_label_with_spaces() -> None:
    """A fenced query whose node label contains spaces gets that label quoted."""
    fenced_query = "```MATCH (n:Label With Space) RETURN n```"
    # The fence is stripped and the spaced label is wrapped in single quotes.
    assert extract_cypher(fenced_query) == "MATCH (n:'Label With Space') RETURN n"


def test_extract_cypher_on_label_with_multi_spaces() -> None:
    """Test if node labels with runs of consecutive spaces are quoted.

    The words of the label are separated by more than one space, unlike a
    label whose words are separated by single spaces. The whitespace inside
    the label is expected to be preserved verbatim between the quotes.
    """
    # Two spaces, then three spaces, so the multi-space path is really exercised.
    query = "```MATCH (n:Label  With   Space) RETURN n```"
    expected_output = "MATCH (n:'Label  With   Space') RETURN n"
    output = extract_cypher(query)
    assert output == expected_output


def test_extract_cypher_on_label_without_spaces() -> None:
    """A fenced query whose node label has no spaces is returned unquoted."""
    bare_query = "MATCH (n:LabelWithoutSpace) RETURN n"
    fenced_query = "```MATCH (n:LabelWithoutSpace) RETURN n```"
    # Only the surrounding fence should be removed; the label stays as written.
    assert extract_cypher(fenced_query) == bare_query


def test_exclude_types() -> None:
Expand Down

0 comments on commit ceaf64a

Please sign in to comment.