From 81a0b2a4db6863033f02c6b420bebdd6b7925274 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Fri, 1 Nov 2024 16:26:05 -0500 Subject: [PATCH] fixed bug in GraphVectorStoreRetriever --- .../graph_vectorstores/base.py | 21 ++++++++++++------ .../graph_vectorstores/test_cassandra.py | 22 +++++++++++++++++++ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/libs/community/langchain_community/graph_vectorstores/base.py b/libs/community/langchain_community/graph_vectorstores/base.py index 0a320d98f9eca..3c0d81b915f48 100644 --- a/libs/community/langchain_community/graph_vectorstores/base.py +++ b/libs/community/langchain_community/graph_vectorstores/base.py @@ -8,6 +8,7 @@ ClassVar, Optional, Sequence, + cast, ) from langchain_core._api import beta @@ -701,7 +702,7 @@ def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever: docsearch.as_retriever(search_kwargs={'k': 1}) """ - return GraphVectorStoreRetriever(vector_store=self, **kwargs) + return GraphVectorStoreRetriever(vectorstore=self, **kwargs) @beta(message="Added in version 0.3.1 of langchain_community. API subject to change.") @@ -837,8 +838,8 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5}) """ # noqa: E501 - vector_store: GraphVectorStore - """GraphVectorStore to use for retrieval.""" + vectorstore: VectorStore + """VectorStore to use for retrieval.""" search_type: str = "traversal" """Type of search to perform. Defaults to "traversal".""" allowed_search_types: ClassVar[Collection[str]] = ( @@ -849,14 +850,20 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): "mmr_traversal", ) + @property + def graph_vectorstore(self) -> GraphVectorStore: + return cast(GraphVectorStore, self.vectorstore) + def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> list[Document]: if self.search_type == "traversal": - return list(self.vector_store.traversal_search(query, **self.search_kwargs)) + return list( + self.graph_vectorstore.traversal_search(query, **self.search_kwargs) + ) elif self.search_type == "mmr_traversal": return list( - self.vector_store.mmr_traversal_search(query, **self.search_kwargs) + self.graph_vectorstore.mmr_traversal_search(query, **self.search_kwargs) ) else: return super()._get_relevant_documents(query, run_manager=run_manager) @@ -867,14 +874,14 @@ async def _aget_relevant_documents( if self.search_type == "traversal": return [ doc - async for doc in self.vector_store.atraversal_search( + async for doc in self.graph_vectorstore.atraversal_search( query, **self.search_kwargs ) ] elif self.search_type == "mmr_traversal": return [ doc - async for doc in self.vector_store.ammr_traversal_search( + async for doc in self.graph_vectorstore.ammr_traversal_search( query, **self.search_kwargs ) ] diff --git a/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py b/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py index d55f5469e546a..337497a528446 100644 --- a/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py +++ b/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py @@ -440,6 +440,17 @@ def test_gvs_traversal_search_sync( ts_labels = {doc.metadata["label"] for doc in ts_response} assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} + # verify the same works as a retriever + retriever = g_store.as_retriever( + search_type="traversal", search_kwargs={"k": 2, "depth": 2} + ) + + ts_labels = { + doc.metadata["label"] + for doc in retriever.get_relevant_documents(query="[2, 10]") + } + assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} + async def test_gvs_traversal_search_async( self, populated_graph_vector_store_d2: CassandraGraphVectorStore, @@ -453,6 +464,17 @@ async def test_gvs_traversal_search_async( # so ordering is not deterministic: assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} + # verify the same works as a retriever + retriever = g_store.as_retriever( + search_type="traversal", search_kwargs={"k": 2, "depth": 2} + ) + + ts_labels = { + doc.metadata["label"] + for doc in await retriever.aget_relevant_documents(query="[2, 10]") + } + assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} + def test_gvs_mmr_traversal_search_sync( self, populated_graph_vector_store_d2: CassandraGraphVectorStore,