-
Notifications
You must be signed in to change notification settings - Fork 16.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
community: fixed bug in GraphVectorStoreRetriever #27846
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
ClassVar, | ||
Optional, | ||
Sequence, | ||
cast, | ||
) | ||
|
||
from langchain_core._api import beta | ||
|
@@ -701,7 +702,7 @@ def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever: | |
docsearch.as_retriever(search_kwargs={'k': 1}) | ||
|
||
""" | ||
return GraphVectorStoreRetriever(vector_store=self, **kwargs) | ||
return GraphVectorStoreRetriever(vectorstore=self, **kwargs) | ||
|
||
|
||
@beta(message="Added in version 0.3.1 of langchain_community. API subject to change.") | ||
|
@@ -837,8 +838,8 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): | |
retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5}) | ||
""" # noqa: E501 | ||
|
||
vector_store: GraphVectorStore | ||
"""GraphVectorStore to use for retrieval.""" | ||
vectorstore: VectorStore | ||
"""VectorStore to use for retrieval.""" | ||
search_type: str = "traversal" | ||
"""Type of search to perform. Defaults to "traversal".""" | ||
allowed_search_types: ClassVar[Collection[str]] = ( | ||
|
@@ -849,14 +850,20 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): | |
"mmr_traversal", | ||
) | ||
|
||
@property | ||
def graph_vectorstore(self) -> GraphVectorStore: | ||
return cast(GraphVectorStore, self.vectorstore) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had to add this cast property to get the 3.13 lint to pass. If there is a better way to do this, please let me know. |
||
def _get_relevant_documents( | ||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | ||
) -> list[Document]: | ||
if self.search_type == "traversal": | ||
return list(self.vector_store.traversal_search(query, **self.search_kwargs)) | ||
return list( | ||
self.graph_vectorstore.traversal_search(query, **self.search_kwargs) | ||
) | ||
elif self.search_type == "mmr_traversal": | ||
return list( | ||
self.vector_store.mmr_traversal_search(query, **self.search_kwargs) | ||
self.graph_vectorstore.mmr_traversal_search(query, **self.search_kwargs) | ||
) | ||
else: | ||
return super()._get_relevant_documents(query, run_manager=run_manager) | ||
|
@@ -867,14 +874,14 @@ async def _aget_relevant_documents( | |
if self.search_type == "traversal": | ||
return [ | ||
doc | ||
async for doc in self.vector_store.atraversal_search( | ||
async for doc in self.graph_vectorstore.atraversal_search( | ||
query, **self.search_kwargs | ||
) | ||
] | ||
elif self.search_type == "mmr_traversal": | ||
return [ | ||
doc | ||
async for doc in self.vector_store.ammr_traversal_search( | ||
async for doc in self.graph_vectorstore.ammr_traversal_search( | ||
query, **self.search_kwargs | ||
) | ||
] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. jfyi the community integration tests don't run in CI ever, so might be more effective to add this to the astradb ci instead |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changing this from
vectorstore
tovector_store
created the bug.