Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor langchain.retrievers.self_query #16115

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions docs/docs/integrations/chat/groq.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
"\n",
"chain = prompt | chat\n",
"chain.invoke({\n",
" \"text\": \"Explain the importance of low latency LLMs.\"\n",
"})"
"chain.invoke({\"text\": \"Explain the importance of low latency LLMs.\"})"
]
},
{
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Logic for converting internal query language to a valid AstraDB query."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from typing import Dict, Tuple, Union

from langchain_core.sql_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)

MULTIPLE_ARITY_COMPARATORS = [Comparator.IN, Comparator.NIN]


class AstraDBTranslator(Visitor):
"""Translate AstraDB internal query language elements to valid filters."""

"""Subset of allowed logical comparators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.IN,
Comparator.NIN,
]

"""Subset of allowed logical operators."""
allowed_operators = [Operator.AND, Operator.OR]

def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
map_dict = {
Operator.AND: "$and",
Operator.OR: "$or",
Comparator.EQ: "$eq",
Comparator.NE: "$ne",
Comparator.GTE: "$gte",
Comparator.LTE: "$lte",
Comparator.LT: "$lt",
Comparator.GT: "$gt",
Comparator.IN: "$in",
Comparator.NIN: "$nin",
}
return map_dict[func]

def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}

def visit_comparison(self, comparison: Comparison) -> Dict:
if comparison.comparator in MULTIPLE_ARITY_COMPARATORS and not isinstance(
comparison.value, list
):
comparison.value = [comparison.value]

comparator = self._format_func(comparison.comparator)
return {comparison.attribute: {comparator: comparison.value}}

def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
250 changes: 250 additions & 0 deletions libs/community/langchain_community/retrievers/self_query/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
"""Retriever that generates and executes structured queries over its own data source."""
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union

from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from langchain_core.sql_constructor.base import load_query_constructor_runnable
from langchain_core.sql_constructor.ir import StructuredQuery, Visitor
from langchain_core.sql_constructor.schema import AttributeInfo
from langchain_core.vectorstores import VectorStore

from langchain_community.retrievers.self_query.astradb import AstraDBTranslator
from langchain_community.retrievers.self_query.chroma import ChromaTranslator
from langchain_community.retrievers.self_query.dashvector import DashvectorTranslator
from langchain_community.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain_community.retrievers.self_query.elasticsearch import (
ElasticsearchTranslator,
)
from langchain_community.retrievers.self_query.milvus import MilvusTranslator
from langchain_community.retrievers.self_query.mongodb_atlas import (
MongoDBAtlasTranslator,
)
from langchain_community.retrievers.self_query.myscale import MyScaleTranslator
from langchain_community.retrievers.self_query.opensearch import OpenSearchTranslator
from langchain_community.retrievers.self_query.pgvector import PGVectorTranslator
from langchain_community.retrievers.self_query.pinecone import PineconeTranslator
from langchain_community.retrievers.self_query.qdrant import QdrantTranslator
from langchain_community.retrievers.self_query.redis import RedisTranslator
from langchain_community.retrievers.self_query.supabase import SupabaseVectorTranslator
from langchain_community.retrievers.self_query.timescalevector import (
TimescaleVectorTranslator,
)
from langchain_community.retrievers.self_query.vectara import VectaraTranslator
from langchain_community.retrievers.self_query.weaviate import WeaviateTranslator
from langchain_community.vectorstores import (
AstraDB,
Chroma,
DashVector,
DeepLake,
ElasticsearchStore,
Milvus,
MongoDBAtlasVectorSearch,
MyScale,
OpenSearchVectorSearch,
PGVector,
Pinecone,
Qdrant,
Redis,
SupabaseVectorStore,
TimescaleVector,
Vectara,
Weaviate,
)

logger = logging.getLogger(__name__)


def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
"""Get the translator class corresponding to the vector store class."""
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
AstraDB: AstraDBTranslator,
Chroma: ChromaTranslator,
DashVector: DashvectorTranslator,
DeepLake: DeepLakeTranslator,
ElasticsearchStore: ElasticsearchTranslator,
Milvus: MilvusTranslator,
MongoDBAtlasVectorSearch: MongoDBAtlasTranslator,
MyScale: MyScaleTranslator,
OpenSearchVectorSearch: OpenSearchTranslator,
PGVector: PGVectorTranslator,
Pinecone: PineconeTranslator,
Qdrant: QdrantTranslator,
SupabaseVectorStore: SupabaseVectorTranslator,
TimescaleVector: TimescaleVectorTranslator,
Vectara: VectaraTranslator,
Weaviate: WeaviateTranslator,
}
if isinstance(vectorstore, Qdrant):
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
elif isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
elif isinstance(vectorstore, Redis):
return RedisTranslator.from_vectorstore(vectorstore)
elif vectorstore.__class__ in BUILTIN_TRANSLATORS:
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
else:
raise ValueError(
f"Self query retriever with Vector Store type {vectorstore.__class__}"
f" not supported."
)


class SelfQueryRetriever(BaseRetriever):
"""Retriever that uses a vector store and an LLM to generate
the vector store queries."""

vectorstore: VectorStore
"""The underlying vector store from which documents will be retrieved."""
query_constructor: Runnable[dict, StructuredQuery] = Field(alias="llm_chain")
"""The query constructor chain for generating the vector store queries.

llm_chain is legacy name kept for backwards compatibility."""
search_type: str = "similarity"
"""The search type to perform on the vector store."""
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass in to the vector store search."""
structured_query_translator: Visitor
"""Translator for turning internal query language into vectorstore search params."""
verbose: bool = False

use_original_query: bool = False
"""Use original query instead of the revised new query from LLM"""

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
allow_population_by_field_name = True

@root_validator(pre=True)
def validate_translator(cls, values: Dict) -> Dict:
"""Validate translator."""
if "structured_query_translator" not in values:
values["structured_query_translator"] = _get_builtin_translator(
values["vectorstore"]
)
return values

@property
def llm_chain(self) -> Runnable:
"""llm_chain is legacy name kept for backwards compatibility."""
return self.query_constructor

def _prepare_query(
self, query: str, structured_query: StructuredQuery
) -> Tuple[str, Dict[str, Any]]:
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
if self.use_original_query:
new_query = query
search_kwargs = {**self.search_kwargs, **new_kwargs}
return new_query, search_kwargs

def _get_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
) -> List[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs

async def _aget_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
) -> List[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
return docs

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query.

Args:
query: string to find relevant documents for

Returns:
List of relevant documents
"""
structured_query = self.query_constructor.invoke(
{"query": query}, config={"callbacks": run_manager.get_child()}
)
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
return docs

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query.

Args:
query: string to find relevant documents for

Returns:
List of relevant documents
"""
structured_query = await self.query_constructor.ainvoke(
{"query": query}, config={"callbacks": run_manager.get_child()}
)
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
return docs

@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
vectorstore: VectorStore,
document_contents: str,
metadata_field_info: Sequence[Union[AttributeInfo, dict]],
structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None,
enable_limit: bool = False,
use_original_query: bool = False,
**kwargs: Any,
) -> "SelfQueryRetriever":
if structured_query_translator is None:
structured_query_translator = _get_builtin_translator(vectorstore)
chain_kwargs = chain_kwargs or {}

if (
"allowed_comparators" not in chain_kwargs
and structured_query_translator.allowed_comparators is not None
):
chain_kwargs[
"allowed_comparators"
] = structured_query_translator.allowed_comparators
if (
"allowed_operators" not in chain_kwargs
and structured_query_translator.allowed_operators is not None
):
chain_kwargs[
"allowed_operators"
] = structured_query_translator.allowed_operators
query_constructor = load_query_constructor_runnable(
llm,
document_contents,
metadata_field_info,
enable_limit=enable_limit,
**chain_kwargs,
)
return cls(
query_constructor=query_constructor,
vectorstore=vectorstore,
use_original_query=use_original_query,
structured_query_translator=structured_query_translator,
**kwargs,
)
50 changes: 50 additions & 0 deletions libs/community/langchain_community/retrievers/self_query/chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Dict, Tuple, Union

from langchain_core.sql_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)


class ChromaTranslator(Visitor):
"""Translate `Chroma` internal query language elements to valid filters."""

allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
]
"""Subset of allowed logical comparators."""

def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"

def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}

def visit_comparison(self, comparison: Comparison) -> Dict:
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}

def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
Loading
Loading