-
Notifications
You must be signed in to change notification settings - Fork 16.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
langchain[minor]: Add PebbloRetrievalQA chain with Identity & Semanti…
…c Enforcement support (#20641) - **Description:** PebbloRetrievalQA chain introduces identity enforcement using vector-db metadata filtering - **Dependencies:** None - **Issue:** None - **Documentation:** Adding documentation for PebbloRetrievalQA chain in a separate PR(#20746) - **Unit tests:** New unit-tests added --------- Co-authored-by: Eugene Yurtsev <[email protected]>
- Loading branch information
Showing
6 changed files
with
698 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
Chains module for langchain_community | ||
This module contains the community chains. | ||
""" | ||
|
||
import importlib | ||
from typing import TYPE_CHECKING, Any | ||
|
||
if TYPE_CHECKING: | ||
from langchain_community.chains.pebblo_retrieval.base import PebbloRetrievalQA | ||
|
||
__all__ = ["PebbloRetrievalQA"] | ||
|
||
_module_lookup = { | ||
"PebbloRetrievalQA": "langchain_community.chains.pebblo_retrieval.base" | ||
} | ||
|
||
|
||
def __getattr__(name: str) -> Any: | ||
if name in _module_lookup: | ||
module = importlib.import_module(_module_lookup[name]) | ||
return getattr(module, name) | ||
raise AttributeError(f"module {__name__} has no attribute {name}") |
Empty file.
218 changes: 218 additions & 0 deletions
218
libs/community/langchain_community/chains/pebblo_retrieval/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
""" | ||
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering | ||
against a vector database. | ||
""" | ||
|
||
import inspect | ||
from typing import Any, Dict, List, Optional | ||
|
||
from langchain.chains.base import Chain | ||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain | ||
from langchain_core.callbacks import ( | ||
AsyncCallbackManagerForChainRun, | ||
CallbackManagerForChainRun, | ||
) | ||
from langchain_core.documents import Document | ||
from langchain_core.language_models import BaseLanguageModel | ||
from langchain_core.pydantic_v1 import Extra, Field, validator | ||
from langchain_core.vectorstores import VectorStoreRetriever | ||
|
||
from langchain_community.chains.pebblo_retrieval.enforcement_filters import ( | ||
SUPPORTED_VECTORSTORES, | ||
set_enforcement_filters, | ||
) | ||
from langchain_community.chains.pebblo_retrieval.models import ( | ||
AuthContext, | ||
SemanticContext, | ||
) | ||
|
||
|
||
class PebbloRetrievalQA(Chain): | ||
""" | ||
Retrieval Chain with Identity & Semantic Enforcement for question-answering | ||
against a vector database. | ||
""" | ||
|
||
combine_documents_chain: BaseCombineDocumentsChain | ||
"""Chain to use to combine the documents.""" | ||
input_key: str = "query" #: :meta private: | ||
output_key: str = "result" #: :meta private: | ||
return_source_documents: bool = False | ||
"""Return the source documents or not.""" | ||
|
||
retriever: VectorStoreRetriever = Field(exclude=True) | ||
"""VectorStore to use for retrieval.""" | ||
auth_context_key: str = "auth_context" #: :meta private: | ||
"""Authentication context for identity enforcement.""" | ||
semantic_context_key: str = "semantic_context" #: :meta private: | ||
"""Semantic context for semantic enforcement.""" | ||
|
||
def _call( | ||
self, | ||
inputs: Dict[str, Any], | ||
run_manager: Optional[CallbackManagerForChainRun] = None, | ||
) -> Dict[str, Any]: | ||
"""Run get_relevant_text and llm on input query. | ||
If chain has 'return_source_documents' as 'True', returns | ||
the retrieved documents as well under the key 'source_documents'. | ||
Example: | ||
.. code-block:: python | ||
res = indexqa({'query': 'This is my query'}) | ||
answer, docs = res['result'], res['source_documents'] | ||
""" | ||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() | ||
question = inputs[self.input_key] | ||
auth_context = inputs.get(self.auth_context_key) | ||
semantic_context = inputs.get(self.semantic_context_key) | ||
accepts_run_manager = ( | ||
"run_manager" in inspect.signature(self._get_docs).parameters | ||
) | ||
if accepts_run_manager: | ||
docs = self._get_docs( | ||
question, auth_context, semantic_context, run_manager=_run_manager | ||
) | ||
else: | ||
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg] | ||
answer = self.combine_documents_chain.run( | ||
input_documents=docs, question=question, callbacks=_run_manager.get_child() | ||
) | ||
|
||
if self.return_source_documents: | ||
return {self.output_key: answer, "source_documents": docs} | ||
else: | ||
return {self.output_key: answer} | ||
|
||
async def _acall( | ||
self, | ||
inputs: Dict[str, Any], | ||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, | ||
) -> Dict[str, Any]: | ||
"""Run get_relevant_text and llm on input query. | ||
If chain has 'return_source_documents' as 'True', returns | ||
the retrieved documents as well under the key 'source_documents'. | ||
Example: | ||
.. code-block:: python | ||
res = indexqa({'query': 'This is my query'}) | ||
answer, docs = res['result'], res['source_documents'] | ||
""" | ||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() | ||
question = inputs[self.input_key] | ||
auth_context = inputs.get(self.auth_context_key) | ||
semantic_context = inputs.get(self.semantic_context_key) | ||
accepts_run_manager = ( | ||
"run_manager" in inspect.signature(self._aget_docs).parameters | ||
) | ||
if accepts_run_manager: | ||
docs = await self._aget_docs( | ||
question, auth_context, semantic_context, run_manager=_run_manager | ||
) | ||
else: | ||
docs = await self._aget_docs(question, auth_context, semantic_context) # type: ignore[call-arg] | ||
answer = await self.combine_documents_chain.arun( | ||
input_documents=docs, question=question, callbacks=_run_manager.get_child() | ||
) | ||
|
||
if self.return_source_documents: | ||
return {self.output_key: answer, "source_documents": docs} | ||
else: | ||
return {self.output_key: answer} | ||
|
||
class Config: | ||
"""Configuration for this pydantic object.""" | ||
|
||
extra = Extra.forbid | ||
arbitrary_types_allowed = True | ||
allow_population_by_field_name = True | ||
|
||
@property | ||
def input_keys(self) -> List[str]: | ||
"""Input keys. | ||
:meta private: | ||
""" | ||
return [self.input_key, self.auth_context_key, self.semantic_context_key] | ||
|
||
@property | ||
def output_keys(self) -> List[str]: | ||
"""Output keys. | ||
:meta private: | ||
""" | ||
_output_keys = [self.output_key] | ||
if self.return_source_documents: | ||
_output_keys += ["source_documents"] | ||
return _output_keys | ||
|
||
@property | ||
def _chain_type(self) -> str: | ||
"""Return the chain type.""" | ||
return "pebblo_retrieval_qa" | ||
|
||
@classmethod | ||
def from_chain_type( | ||
cls, | ||
llm: BaseLanguageModel, | ||
chain_type: str = "stuff", | ||
chain_type_kwargs: Optional[dict] = None, | ||
**kwargs: Any, | ||
) -> "PebbloRetrievalQA": | ||
"""Load chain from chain type.""" | ||
from langchain.chains.question_answering import load_qa_chain | ||
|
||
_chain_type_kwargs = chain_type_kwargs or {} | ||
combine_documents_chain = load_qa_chain( | ||
llm, chain_type=chain_type, **_chain_type_kwargs | ||
) | ||
return cls(combine_documents_chain=combine_documents_chain, **kwargs) | ||
|
||
@validator("retriever", pre=True, always=True) | ||
def validate_vectorstore( | ||
cls, retriever: VectorStoreRetriever | ||
) -> VectorStoreRetriever: | ||
""" | ||
Validate that the vectorstore of the retriever is supported vectorstores. | ||
""" | ||
if not any( | ||
isinstance(retriever.vectorstore, supported_class) | ||
for supported_class in SUPPORTED_VECTORSTORES | ||
): | ||
raise ValueError( | ||
f"Vectorstore must be an instance of one of the supported " | ||
f"vectorstores: {SUPPORTED_VECTORSTORES}. " | ||
f"Got {type(retriever.vectorstore).__name__} instead." | ||
) | ||
return retriever | ||
|
||
def _get_docs( | ||
self, | ||
question: str, | ||
auth_context: Optional[AuthContext], | ||
semantic_context: Optional[SemanticContext], | ||
*, | ||
run_manager: CallbackManagerForChainRun, | ||
) -> List[Document]: | ||
"""Get docs.""" | ||
set_enforcement_filters(self.retriever, auth_context, semantic_context) | ||
return self.retriever.get_relevant_documents( | ||
question, callbacks=run_manager.get_child() | ||
) | ||
|
||
async def _aget_docs( | ||
self, | ||
question: str, | ||
auth_context: Optional[AuthContext], | ||
semantic_context: Optional[SemanticContext], | ||
*, | ||
run_manager: AsyncCallbackManagerForChainRun, | ||
) -> List[Document]: | ||
"""Get docs.""" | ||
set_enforcement_filters(self.retriever, auth_context, semantic_context) | ||
return await self.retriever.aget_relevant_documents( | ||
question, callbacks=run_manager.get_child() | ||
) |
Oops, something went wrong.