diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index e84a4df..8343485 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -57,8 +57,20 @@ def extract_cypher(text: str) -> str: # Find all matches in the input text matches = re.findall(pattern, text, re.DOTALL) + if matches: + cypher_query = matches[0] + else: + return text - return matches[0] if matches else text + # Remove backticks + cypher_query = cypher_query.replace("`", "") + + # Quote node labels if they contain spaces + cypher_query = re.sub( + r":\s*(\s*)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(\s*)", r":'\2'", cypher_query + ) + + return cypher_query def construct_schema( diff --git a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py index a857fb4..ab760d4 100644 --- a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py @@ -236,18 +236,43 @@ def test_graph_cypher_qa_chain() -> None: assert True -def test_no_backticks() -> None: +def test_extract_cypher_on_no_backticks() -> None: """Test if there are no backticks, so the original text should be returned.""" query = "MATCH (n) RETURN n" output = extract_cypher(query) assert output == query -def test_backticks() -> None: +def test_extract_cypher_on_backticks() -> None: """Test if there are backticks. Query from within backticks should be returned.""" query = "You can use the following query: ```MATCH (n) RETURN n```" + expected_output = "MATCH (n) RETURN n" output = extract_cypher(query) - assert output == "MATCH (n) RETURN n" + assert output == expected_output + + +def test_extract_cypher_on_label_with_spaces() -> None: + """Test if node labels with spaces are quoted.""" + query = "```MATCH (n:Label With Space) RETURN n```" + expected_output = "MATCH (n:'Label With Space') RETURN n" + output = extract_cypher(query) + assert output == expected_output + + +def test_extract_cypher_on_label_with_multi_spaces() -> None: + """Test if node labels with multiple spaces are quoted.""" + query = "```MATCH (n:Label With Space) RETURN n```" + expected_output = "MATCH (n:'Label With Space') RETURN n" + output = extract_cypher(query) + assert output == expected_output + + +def test_extract_cypher_on_label_without_spaces() -> None: + """Test if node labels without spaces are not quoted.""" + query = "```MATCH (n:LabelWithoutSpace) RETURN n```" + expected_output = "MATCH (n:LabelWithoutSpace) RETURN n" + output = extract_cypher(query) + assert output == expected_output def test_exclude_types() -> None: