Skip to content

Commit

Permalink
Add async support to SelfQueryRetriever (langchain-ai#10175)
Browse files Browse the repository at this point in the history
### Description

SelfQueryRetriever is missing async support, so I am adding it.
I also removed deprecated predict_and_parse method usage here, and added
some tests.

### Issue
N/A

### Tag maintainer
Not yet

### Twitter handle
N/A
  • Loading branch information
asai95 authored Oct 6, 2023
1 parent 35297ca commit fd9da60
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 20 deletions.
4 changes: 3 additions & 1 deletion libs/langchain/langchain/chains/query_constructor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,6 @@ def load_query_constructor_chain(
allowed_operators=allowed_operators,
enable_limit=enable_limit,
)
return LLMChain(llm=llm, prompt=prompt, **kwargs)
return LLMChain(
llm=llm, prompt=prompt, output_parser=prompt.output_parser, **kwargs
)
91 changes: 72 additions & 19 deletions libs/langchain/langchain/retrievers/self_query/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Retriever that generates and executes structured queries over its own data source."""
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, cast

from typing import Any, Dict, List, Optional, Type, cast

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains import LLMChain
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
Expand Down Expand Up @@ -42,6 +45,8 @@
Weaviate,
)

logger = logging.getLogger(__name__)


def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
"""Get the translator class corresponding to the vector store class."""
Expand Down Expand Up @@ -108,6 +113,49 @@ def validate_translator(cls, values: Dict) -> Dict:
)
return values

def _get_structured_query(
self, inputs: Dict[str, Any], run_manager: CallbackManagerForRetrieverRun
) -> StructuredQuery:
structured_query = cast(
StructuredQuery,
self.llm_chain.predict(callbacks=run_manager.get_child(), **inputs),
)
return structured_query

async def _aget_structured_query(
self, inputs: Dict[str, Any], run_manager: AsyncCallbackManagerForRetrieverRun
) -> StructuredQuery:
structured_query = cast(
StructuredQuery,
await self.llm_chain.apredict(callbacks=run_manager.get_child(), **inputs),
)
return structured_query

def _prepare_query(
self, query: str, structured_query: StructuredQuery
) -> Tuple[str, Dict[str, Any]]:
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
if self.use_original_query:
new_query = query
search_kwargs = {**self.search_kwargs, **new_kwargs}
return new_query, search_kwargs

def _get_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
) -> List[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs

async def _aget_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
) -> List[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
return docs

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
Expand All @@ -120,25 +168,30 @@ def _get_relevant_documents(
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = cast(
StructuredQuery,
self.llm_chain.predict_and_parse(
callbacks=run_manager.get_child(), **inputs
),
)
structured_query = self._get_structured_query(inputs, run_manager)
if self.verbose:
print(structured_query)
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
return docs

if self.use_original_query:
new_query = query
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query.
search_kwargs = {**self.search_kwargs, **new_kwargs}
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = await self._aget_structured_query(inputs, run_manager)
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
return docs

@classmethod
Expand Down
142 changes: 142 additions & 0 deletions libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from typing import Any, Dict, List, Tuple, Union

import pytest

from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers import SelfQueryRetriever
from langchain.schema import Document
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
from tests.unit_tests.llms.fake_llm import FakeLLM


class FakeTranslator(Visitor):
allowed_comparators = (
Comparator.EQ,
Comparator.NE,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
Comparator.CONTAIN,
Comparator.LIKE,
)
allowed_operators = (Operator.AND, Operator.OR, Operator.NOT)

def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"

def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}

def visit_comparison(self, comparison: Comparison) -> Dict:
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}

def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs


class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
res = self.store.get(query)
if res is None:
return []
return [res]


@pytest.fixture()
def fake_llm() -> FakeLLM:
return FakeLLM(
queries={
"1": """```json
{
"query": "test",
"filter": null
}
```""",
"bar": "baz",
},
sequential_responses=True,
)


@pytest.fixture()
def fake_vectorstore() -> InMemoryVectorstoreWithSearch:
vectorstore = InMemoryVectorstoreWithSearch()
vectorstore.add_documents(
[
Document(
page_content="test",
metadata={
"foo": "bar",
},
),
],
ids=["test"],
)
return vectorstore


@pytest.fixture()
def fake_self_query_retriever(
fake_llm: FakeLLM, fake_vectorstore: InMemoryVectorstoreWithSearch
) -> SelfQueryRetriever:
return SelfQueryRetriever.from_llm(
llm=fake_llm,
vectorstore=fake_vectorstore,
document_contents="test",
metadata_field_info=[
AttributeInfo(
name="foo",
type="string",
description="test",
),
],
structured_query_translator=FakeTranslator(),
)


def test__get_relevant_documents(fake_self_query_retriever: SelfQueryRetriever) -> None:
relevant_documents = fake_self_query_retriever._get_relevant_documents(
"foo",
run_manager=CallbackManagerForRetrieverRun.get_noop_manager(),
)
assert len(relevant_documents) == 1
assert relevant_documents[0].metadata["foo"] == "bar"


@pytest.mark.asyncio
async def test__aget_relevant_documents(
fake_self_query_retriever: SelfQueryRetriever,
) -> None:
relevant_documents = await fake_self_query_retriever._aget_relevant_documents(
"foo",
run_manager=AsyncCallbackManagerForRetrieverRun.get_noop_manager(),
)
assert len(relevant_documents) == 1
assert relevant_documents[0].metadata["foo"] == "bar"

0 comments on commit fd9da60

Please sign in to comment.