Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support vector index filtering #2

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_community.chains.graph_qa.prompts import (
CYPHER_FILTER_GENERATION_PROMPT,
CYPHER_GENERATION_PROMPT,
CYPHER_QA_PROMPT,
)
from langchain_community.graphs.graph_store import GraphStore
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import (
AIMessage,
Expand Down Expand Up @@ -62,6 +64,23 @@ def extract_cypher(text: str) -> str:
return matches[0] if matches else text


def extract_filter(text: str) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a unit test for this under libs/neo4j/tests/unit_tests/chains/test_graph_qa.py?

"""Extract the filter to be used in the query (if any) from a text.

Args:
text: Text to extract filter from.

Returns:
Filter extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"###(.*?)###"
Comment on lines +76 to +77
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the text get enclosed in ###?

Also, I think the comment can be corrected here.

# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)

return matches[0] if matches else ""


def construct_schema(
structured_schema: Dict[str, Any],
include_types: List[str],
Expand All @@ -88,6 +107,7 @@ def filter_func(x: str) -> bool:
for r in structured_schema.get("relationships", [])
if all(filter_func(r[t]) for t in ["start", "end", "type"])
],
"vector_indexes": structured_schema.get("vector_indexes", []),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add some test to ensure this field is included in libs/neo4j/tests/unit_tests/chains/test_graph_qa.py ?

}

# Format node properties
Expand All @@ -112,6 +132,12 @@ def filter_func(x: str) -> bool:
for el in filtered_schema["relationships"]
]

formatted_vector_indexes = [
f"{index['name']}: {index['type']} index on {index['entityType']} "
f"{index['labelsOrTypes']} and PROPERTIES {index['properties']}"
for index in filtered_schema["vector_indexes"]
]

return "\n".join(
[
"Node properties are the following:",
Expand All @@ -120,6 +146,8 @@ def filter_func(x: str) -> bool:
",".join(formatted_rel_props),
"The relationships are the following:",
",".join(formatted_rels),
"The vector indexes are the following:",
",".join(formatted_vector_indexes),
Comment on lines +149 to +150
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, please write some tests to test this under libs/neo4j/tests/unit_tests/chains/test_graph_qa.py

]
)

Expand Down Expand Up @@ -170,6 +198,8 @@ class GraphCypherQAChain(Chain):
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
embeddings: Optional[Embeddings] = None
"""Embedding model to be used for filtering on vector indexes (if any)"""
top_k: int = 10
"""Number of results to return from the query"""
return_intermediate_steps: bool = False
Expand Down Expand Up @@ -281,7 +311,11 @@ def from_llm(
)
if "prompt" not in use_cypher_llm_kwargs:
use_cypher_llm_kwargs["prompt"] = (
cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
cypher_prompt
if cypher_prompt is not None
else CYPHER_FILTER_GENERATION_PROMPT
if kwargs.get("embeddings")
else CYPHER_GENERATION_PROMPT
)

qa_llm = qa_llm or llm
Expand Down Expand Up @@ -349,10 +383,10 @@ def _call(

intermediate_steps: List = []

generated_cypher = self.cypher_generation_chain.run(args, callbacks=callbacks)
generated_query = self.cypher_generation_chain.run(args, callbacks=callbacks)

# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
generated_cypher = extract_cypher(generated_query)

# Correct Cypher query if enabled
if self.cypher_query_corrector:
Expand All @@ -363,12 +397,26 @@ def _call(
generated_cypher, color="green", end="\n", verbose=self.verbose
)

intermediate_steps.append({"query": generated_cypher})
intermediate_steps.append({"query": generated_query})

generated_filter = extract_filter(generated_query)

# Retrieve and limit the number of results
# Generated Cypher be null if query corrector identifies invalid schema
if generated_cypher:
context = self.graph.query(generated_cypher)[: self.top_k]
if generated_filter and self.embeddings:
_run_manager.on_text(
"Generated Filter:", end="\n", verbose=self.verbose
)
_run_manager.on_text(
generated_filter, color="green", end="\n", verbose=self.verbose
)
search_emb = self.embeddings.embed_query(generated_filter)
context = self.graph.query(
query=generated_cypher, params=dict(search_emb=search_emb)
)[: self.top_k]
else:
context = self.graph.query(generated_cypher)[: self.top_k]
else:
context = []

Expand Down
25 changes: 25 additions & 0 deletions libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
# Threshold for returning all available prop values in graph schema
DISTINCT_VALUE_LIMIT = 10

vector_indexes_query = """
SHOW VECTOR INDEXES YIELD name, type, entityType, labelsOrTypes, properties
"""

node_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
Expand Down Expand Up @@ -160,6 +164,7 @@ def _get_rel_import_query(baseEntityLabel: bool) -> str:
def _format_schema(schema: Dict, is_enhanced: bool) -> str:
formatted_node_props = []
formatted_rel_props = []
formatted_vector_indexes = []
if is_enhanced:
# Enhanced formatting for nodes
for node_type, properties in schema["node_props"].items():
Expand Down Expand Up @@ -274,6 +279,13 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
for el in schema["relationships"]
]

# Format vector indexes
for index in schema["vector_indexes"]:
formatted_vector_indexes.append(
f"{index['name']}: {index['type']} index on {index['entityType']} "
f"{index['labelsOrTypes']} and PROPERTIES {index['properties']}"
)

return "\n".join(
[
"Node properties:",
Expand All @@ -282,6 +294,8 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
"\n".join(formatted_rel_props),
"The relationships:",
"\n".join(formatted_rels),
"Vector indexes:",
"\n".join(formatted_vector_indexes),
]
)

Expand Down Expand Up @@ -489,6 +503,16 @@ def refresh_schema(self) -> None:
params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)
]
vector_indexes = [
{
"name": el["name"],
"type": el["type"],
"entityType": el["entityType"],
"labelsOrTypes": el["labelsOrTypes"],
"properties": el["properties"],
}
for el in self.query(vector_indexes_query)
]

# Get constraints & indexes
try:
Expand All @@ -508,6 +532,7 @@ def refresh_schema(self) -> None:
"node_props": {el["labels"]: el["properties"] for el in node_properties},
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
"relationships": relationships,
"vector_indexes": vector_indexes,
"metadata": {"constraint": constraint, "index": index},
}
if self._enhanced_schema:
Expand Down