From 5e916f8282b29c57cd0f2b1b4edfbd4996550dfe Mon Sep 17 00:00:00 2001 From: Katy Chen Date: Tue, 17 Dec 2024 22:57:56 -0500 Subject: [PATCH] fix: Cypher query extraction for node names with spaces - Updated `extract_cypher` to wrap node label with spaces in quotes. - Added tests for node names with one and multiple spaces. - Renamed test cases to improve clarity. - Checked lint and format. --- .../langchain_neo4j/chains/graph_qa/cypher.py | 14 ++++++++- .../tests/unit_tests/chains/test_graph_qa.py | 31 +++++++++++++++++-- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index e84a4df..8343485 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -57,8 +57,20 @@ def extract_cypher(text: str) -> str: # Find all matches in the input text matches = re.findall(pattern, text, re.DOTALL) + if matches: + cypher_query = matches[0] + else: + return text - return matches[0] if matches else text + # Remove backticks + cypher_query = cypher_query.replace("`", "") + + # Quote node labels if they contain spaces + cypher_query = re.sub( + r":\s*(\s*)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(\s*)", r":'\2'", cypher_query + ) + + return cypher_query def construct_schema( diff --git a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py index a857fb4..ab760d4 100644 --- a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py @@ -236,18 +236,43 @@ def test_graph_cypher_qa_chain() -> None: assert True -def test_no_backticks() -> None: +def test_extract_cypher_on_no_backticks() -> None: """Test if there are no backticks, so the original text should be returned.""" query = "MATCH (n) RETURN n" output = extract_cypher(query) assert output == query -def test_backticks() -> None: +def test_extract_cypher_on_backticks() -> None: """Test if there are backticks. Query from within backticks should be returned.""" query = "You can use the following query: ```MATCH (n) RETURN n```" + expected_output = "MATCH (n) RETURN n" output = extract_cypher(query) - assert output == "MATCH (n) RETURN n" + assert output == expected_output + + +def test_extract_cypher_on_label_with_spaces() -> None: + """Test if node labels with spaces are quoted.""" + query = "```MATCH (n:Label With Space) RETURN n```" + expected_output = "MATCH (n:'Label With Space') RETURN n" + output = extract_cypher(query) + assert output == expected_output + + +def test_extract_cypher_on_label_with_multi_spaces() -> None: + """Test if node labels with multiple spaces are quoted.""" + query = "```MATCH (n:Label With Space) RETURN n```" + expected_output = "MATCH (n:'Label With Space') RETURN n" + output = extract_cypher(query) + assert output == expected_output + + +def test_extract_cypher_on_label_without_spaces() -> None: + """Test if node labels without spaces are not quoted.""" + query = "```MATCH (n:LabelWithoutSpace) RETURN n```" + expected_output = "MATCH (n:LabelWithoutSpace) RETURN n" + output = extract_cypher(query) + assert output == expected_output def test_exclude_types() -> None: