-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support vector index filtering #2
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,11 +8,13 @@ | |
from langchain.chains.base import Chain | ||
from langchain.chains.llm import LLMChain | ||
from langchain_community.chains.graph_qa.prompts import ( | ||
CYPHER_FILTER_GENERATION_PROMPT, | ||
CYPHER_GENERATION_PROMPT, | ||
CYPHER_QA_PROMPT, | ||
) | ||
from langchain_community.graphs.graph_store import GraphStore | ||
from langchain_core.callbacks import CallbackManagerForChainRun | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.language_models import BaseLanguageModel | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
|
@@ -62,6 +64,23 @@ def extract_cypher(text: str) -> str: | |
return matches[0] if matches else text | ||
|
||
|
||
def extract_filter(text: str) -> str: | ||
"""Extract the filter to be used in the query (if any) from a text. | ||
|
||
Args: | ||
text: Text to extract filter from. | ||
|
||
Returns: | ||
Filter extracted from the text. | ||
""" | ||
# The pattern to find Cypher code enclosed in triple backticks | ||
pattern = r"###(.*?)###" | ||
Comment on lines
+76
to
+77
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does the text get enclosed in Also, I think the comment can be corrected here. |
||
# Find all matches in the input text | ||
matches = re.findall(pattern, text, re.DOTALL) | ||
|
||
return matches[0] if matches else "" | ||
|
||
|
||
def construct_schema( | ||
structured_schema: Dict[str, Any], | ||
include_types: List[str], | ||
|
@@ -88,6 +107,7 @@ def filter_func(x: str) -> bool: | |
for r in structured_schema.get("relationships", []) | ||
if all(filter_func(r[t]) for t in ["start", "end", "type"]) | ||
], | ||
"vector_indexes": structured_schema.get("vector_indexes", []), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you also add some test to ensure this field is included in libs/neo4j/tests/unit_tests/chains/test_graph_qa.py ? |
||
} | ||
|
||
# Format node properties | ||
|
@@ -112,6 +132,12 @@ def filter_func(x: str) -> bool: | |
for el in filtered_schema["relationships"] | ||
] | ||
|
||
formatted_vector_indexes = [ | ||
f"{index['name']}: {index['type']} index on {index['entityType']} " | ||
f"{index['labelsOrTypes']} and PROPERTIES {index['properties']}" | ||
for index in filtered_schema["vector_indexes"] | ||
] | ||
|
||
return "\n".join( | ||
[ | ||
"Node properties are the following:", | ||
|
@@ -120,6 +146,8 @@ def filter_func(x: str) -> bool: | |
",".join(formatted_rel_props), | ||
"The relationships are the following:", | ||
",".join(formatted_rels), | ||
"The vector indexes are the following:", | ||
",".join(formatted_vector_indexes), | ||
Comment on lines
+149
to
+150
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly, please write some tests to test this under libs/neo4j/tests/unit_tests/chains/test_graph_qa.py |
||
] | ||
) | ||
|
||
|
@@ -170,6 +198,8 @@ class GraphCypherQAChain(Chain): | |
graph_schema: str | ||
input_key: str = "query" #: :meta private: | ||
output_key: str = "result" #: :meta private: | ||
embeddings: Optional[Embeddings] = None | ||
"""Embedding model to be used for filtering on vector indexes (if any)""" | ||
top_k: int = 10 | ||
"""Number of results to return from the query""" | ||
return_intermediate_steps: bool = False | ||
|
@@ -281,7 +311,11 @@ def from_llm( | |
) | ||
if "prompt" not in use_cypher_llm_kwargs: | ||
use_cypher_llm_kwargs["prompt"] = ( | ||
cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT | ||
cypher_prompt | ||
if cypher_prompt is not None | ||
else CYPHER_FILTER_GENERATION_PROMPT | ||
if kwargs.get("embeddings") | ||
else CYPHER_GENERATION_PROMPT | ||
) | ||
|
||
qa_llm = qa_llm or llm | ||
|
@@ -349,10 +383,10 @@ def _call( | |
|
||
intermediate_steps: List = [] | ||
|
||
generated_cypher = self.cypher_generation_chain.run(args, callbacks=callbacks) | ||
generated_query = self.cypher_generation_chain.run(args, callbacks=callbacks) | ||
|
||
# Extract Cypher code if it is wrapped in backticks | ||
generated_cypher = extract_cypher(generated_cypher) | ||
generated_cypher = extract_cypher(generated_query) | ||
|
||
# Correct Cypher query if enabled | ||
if self.cypher_query_corrector: | ||
|
@@ -363,12 +397,26 @@ def _call( | |
generated_cypher, color="green", end="\n", verbose=self.verbose | ||
) | ||
|
||
intermediate_steps.append({"query": generated_cypher}) | ||
intermediate_steps.append({"query": generated_query}) | ||
|
||
generated_filter = extract_filter(generated_query) | ||
|
||
# Retrieve and limit the number of results | ||
# Generated Cypher be null if query corrector identifies invalid schema | ||
if generated_cypher: | ||
context = self.graph.query(generated_cypher)[: self.top_k] | ||
if generated_filter and self.embeddings: | ||
_run_manager.on_text( | ||
"Generated Filter:", end="\n", verbose=self.verbose | ||
) | ||
_run_manager.on_text( | ||
generated_filter, color="green", end="\n", verbose=self.verbose | ||
) | ||
search_emb = self.embeddings.embed_query(generated_filter) | ||
context = self.graph.query( | ||
query=generated_cypher, params=dict(search_emb=search_emb) | ||
)[: self.top_k] | ||
else: | ||
context = self.graph.query(generated_cypher)[: self.top_k] | ||
else: | ||
context = [] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add a unit test for this under libs/neo4j/tests/unit_tests/chains/test_graph_qa.py?