def extract_filter(text: str) -> str:
    """Extract the filter to be used in the query (if any) from a text.

    The filter is whatever sits between the first pair of ``###`` markers;
    it may span several lines.

    Args:
        text: Text that may contain a ``###``-delimited filter.

    Returns:
        The text between the first pair of ``###`` markers, or an empty
        string when no such pair is present.
    """
    # Non-greedy capture stops at the nearest closing marker; DOTALL lets the
    # captured filter run across line breaks.
    first_hit = re.search(r"###(.*?)###", text, re.DOTALL)
    if first_hit is None:
        return ""
    return first_hit.group(1)
# Cypher statement that lists the database's vector indexes. The yielded
# columns (name, type, entityType, labelsOrTypes, properties) are the fields
# copied into the "vector_indexes" entry of the structured schema.
vector_indexes_query = (
    "\n"
    "SHOW VECTOR INDEXES YIELD name, type, entityType, labelsOrTypes, properties\n"
)
_format_schema(schema: Dict, is_enhanced: bool) -> str: formatted_node_props = [] formatted_rel_props = [] + formatted_vector_indexes = [] if is_enhanced: # Enhanced formatting for nodes for node_type, properties in schema["node_props"].items(): @@ -274,6 +279,13 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: for el in schema["relationships"] ] + # Format vector indexes + for index in schema["vector_indexes"]: + formatted_vector_indexes.append( + f"{index['name']}: {index['type']} index on {index['entityType']} " + f"{index['labelsOrTypes']} and PROPERTIES {index['properties']}" + ) + return "\n".join( [ "Node properties:", @@ -282,6 +294,8 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: "\n".join(formatted_rel_props), "The relationships:", "\n".join(formatted_rels), + "Vector indexes:", + "\n".join(formatted_vector_indexes), ] ) @@ -489,6 +503,16 @@ def refresh_schema(self) -> None: params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, ) ] + vector_indexes = [ + { + "name": el["name"], + "type": el["type"], + "entityType": el["entityType"], + "labelsOrTypes": el["labelsOrTypes"], + "properties": el["properties"], + } + for el in self.query(vector_indexes_query) + ] # Get constraints & indexes try: @@ -508,6 +532,7 @@ def refresh_schema(self) -> None: "node_props": {el["labels"]: el["properties"] for el in node_properties}, "rel_props": {el["type"]: el["properties"] for el in rel_properties}, "relationships": relationships, + "vector_indexes": vector_indexes, "metadata": {"constraint": constraint, "index": index}, } if self._enhanced_schema: