From fd9da60aeaed944a03231697115788e5a13c1595 Mon Sep 17 00:00:00 2001 From: Viktor Zhemchuzhnikov Date: Fri, 6 Oct 2023 09:54:21 +0800 Subject: [PATCH] Add async support to SelfQueryRetriever (#10175) ### Description SelfQueryRetriever is missing async support, so I am adding it. I also removed deprecated predict_and_parse method usage here, and added some tests. ### Issue N/A ### Tag maintainer Not yet ### Twitter handle N/A --- .../chains/query_constructor/base.py | 4 +- .../langchain/retrievers/self_query/base.py | 91 ++++++++--- .../retrievers/self_query/test_base.py | 142 ++++++++++++++++++ 3 files changed, 217 insertions(+), 20 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py diff --git a/libs/langchain/langchain/chains/query_constructor/base.py b/libs/langchain/langchain/chains/query_constructor/base.py index 71106753f0ac5..266ea58bd2e11 100644 --- a/libs/langchain/langchain/chains/query_constructor/base.py +++ b/libs/langchain/langchain/chains/query_constructor/base.py @@ -160,4 +160,6 @@ def load_query_constructor_chain( allowed_operators=allowed_operators, enable_limit=enable_limit, ) - return LLMChain(llm=llm, prompt=prompt, **kwargs) + return LLMChain( + llm=llm, prompt=prompt, output_parser=prompt.output_parser, **kwargs + ) diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 713a461099d3e..0514f6c50e52d 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -1,8 +1,11 @@ """Retriever that generates and executes structured queries over its own data source.""" +import logging +from typing import Any, Dict, List, Optional, Tuple, Type, cast -from typing import Any, Dict, List, Optional, Type, cast - -from langchain.callbacks.manager import CallbackManagerForRetrieverRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) from langchain.chains import LLMChain from langchain.chains.query_constructor.base import load_query_constructor_chain from langchain.chains.query_constructor.ir import StructuredQuery, Visitor @@ -42,6 +45,8 @@ Weaviate, ) +logger = logging.getLogger(__name__) + def _get_builtin_translator(vectorstore: VectorStore) -> Visitor: """Get the translator class corresponding to the vector store class.""" @@ -108,6 +113,49 @@ def validate_translator(cls, values: Dict) -> Dict: ) return values + def _get_structured_query( + self, inputs: Dict[str, Any], run_manager: CallbackManagerForRetrieverRun + ) -> StructuredQuery: + structured_query = cast( + StructuredQuery, + self.llm_chain.predict(callbacks=run_manager.get_child(), **inputs), + ) + return structured_query + + async def _aget_structured_query( + self, inputs: Dict[str, Any], run_manager: AsyncCallbackManagerForRetrieverRun + ) -> StructuredQuery: + structured_query = cast( + StructuredQuery, + await self.llm_chain.apredict(callbacks=run_manager.get_child(), **inputs), + ) + return structured_query + + def _prepare_query( + self, query: str, structured_query: StructuredQuery + ) -> Tuple[str, Dict[str, Any]]: + new_query, new_kwargs = self.structured_query_translator.visit_structured_query( + structured_query + ) + if structured_query.limit is not None: + new_kwargs["k"] = structured_query.limit + if self.use_original_query: + new_query = query + search_kwargs = {**self.search_kwargs, **new_kwargs} + return new_query, search_kwargs + + def _get_docs_with_query( + self, query: str, search_kwargs: Dict[str, Any] + ) -> List[Document]: + docs = self.vectorstore.search(query, self.search_type, **search_kwargs) + return docs + + async def _aget_docs_with_query( + self, query: str, search_kwargs: Dict[str, Any] + ) -> List[Document]: + docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs) + return docs + def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: @@ -120,25 +168,30 @@ def _get_relevant_documents( List of relevant documents """ inputs = self.llm_chain.prep_inputs({"query": query}) - structured_query = cast( - StructuredQuery, - self.llm_chain.predict_and_parse( - callbacks=run_manager.get_child(), **inputs - ), - ) + structured_query = self._get_structured_query(inputs, run_manager) if self.verbose: - print(structured_query) - new_query, new_kwargs = self.structured_query_translator.visit_structured_query( - structured_query - ) - if structured_query.limit is not None: - new_kwargs["k"] = structured_query.limit + logger.info(f"Generated Query: {structured_query}") + new_query, search_kwargs = self._prepare_query(query, structured_query) + docs = self._get_docs_with_query(new_query, search_kwargs) + return docs - if self.use_original_query: - new_query = query + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + """Get documents relevant for a query. - search_kwargs = {**self.search_kwargs, **new_kwargs} - docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs) + Args: + query: string to find relevant documents for + + Returns: + List of relevant documents + """ + inputs = self.llm_chain.prep_inputs({"query": query}) + structured_query = await self._aget_structured_query(inputs, run_manager) + if self.verbose: + logger.info(f"Generated Query: {structured_query}") + new_query, search_kwargs = self._prepare_query(query, structured_query) + docs = await self._aget_docs_with_query(new_query, search_kwargs) return docs @classmethod diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py new file mode 100644 index 0000000000000..cb9909c9015d8 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py @@ -0,0 +1,142 @@ +from typing import Any, Dict, List, Tuple, Union + +import pytest + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) +from langchain.chains.query_constructor.ir import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) +from langchain.chains.query_constructor.schema import AttributeInfo +from langchain.retrievers import SelfQueryRetriever +from langchain.schema import Document +from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore +from tests.unit_tests.llms.fake_llm import FakeLLM + + +class FakeTranslator(Visitor): + allowed_comparators = ( + Comparator.EQ, + Comparator.NE, + Comparator.LT, + Comparator.LTE, + Comparator.GT, + Comparator.GTE, + Comparator.CONTAIN, + Comparator.LIKE, + ) + allowed_operators = (Operator.AND, Operator.OR, Operator.NOT) + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + return f"${func.value}" + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + return {self._format_func(operation.operator): args} + + def visit_comparison(self, comparison: Comparison) -> Dict: + return { + comparison.attribute: { + self._format_func(comparison.comparator): comparison.value + } + } + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs + + +class InMemoryVectorstoreWithSearch(InMemoryVectorStore): + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + res = self.store.get(query) + if res is None: + return [] + return [res] + + +@pytest.fixture() +def fake_llm() -> FakeLLM: + return FakeLLM( + queries={ + "1": """```json +{ + "query": "test", + "filter": null +} +```""", + "bar": "baz", + }, + sequential_responses=True, + ) + + +@pytest.fixture() +def fake_vectorstore() -> InMemoryVectorstoreWithSearch: + vectorstore = InMemoryVectorstoreWithSearch() + vectorstore.add_documents( + [ + Document( + page_content="test", + metadata={ + "foo": "bar", + }, + ), + ], + ids=["test"], + ) + return vectorstore + + +@pytest.fixture() +def fake_self_query_retriever( + fake_llm: FakeLLM, fake_vectorstore: InMemoryVectorstoreWithSearch +) -> SelfQueryRetriever: + return SelfQueryRetriever.from_llm( + llm=fake_llm, + vectorstore=fake_vectorstore, + document_contents="test", + metadata_field_info=[ + AttributeInfo( + name="foo", + type="string", + description="test", + ), + ], + structured_query_translator=FakeTranslator(), + ) + + +def test__get_relevant_documents(fake_self_query_retriever: SelfQueryRetriever) -> None: + relevant_documents = fake_self_query_retriever._get_relevant_documents( + "foo", + run_manager=CallbackManagerForRetrieverRun.get_noop_manager(), + ) + assert len(relevant_documents) == 1 + assert relevant_documents[0].metadata["foo"] == "bar" + + +@pytest.mark.asyncio +async def test__aget_relevant_documents( + fake_self_query_retriever: SelfQueryRetriever, +) -> None: + relevant_documents = await fake_self_query_retriever._aget_relevant_documents( + "foo", + run_manager=AsyncCallbackManagerForRetrieverRun.get_noop_manager(), + ) + assert len(relevant_documents) == 1 + assert relevant_documents[0].metadata["foo"] == "bar"