Skip to content

Commit

Permalink
langchain[minor]: Add PebbloRetrievalQA chain with Identity & Semantic Enforcement support (#20641)
Browse files Browse the repository at this point in the history

- **Description:** PebbloRetrievalQA chain introduces identity
enforcement using vector-db metadata filtering
- **Dependencies:** None
- **Issue:** None
- **Documentation:** Documentation for the PebbloRetrievalQA chain is added in
a separate PR (#20746)
- **Unit tests:** New unit-tests added

---------

Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
Raj725 and eyurtsev authored May 15, 2024
1 parent f2f970f commit 54e0032
Show file tree
Hide file tree
Showing 6 changed files with 698 additions and 0 deletions.
24 changes: 24 additions & 0 deletions libs/community/langchain_community/chains/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Chains module for langchain_community
This module contains the community chains.
"""

import importlib
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from langchain_community.chains.pebblo_retrieval.base import PebbloRetrievalQA

# Public names exported by this package.
__all__ = ["PebbloRetrievalQA"]

# Maps each exported name to the dotted path of the module that defines it.
# The module-level __getattr__ in this file consults this table to import the
# defining module on demand instead of at package import time.
_module_lookup = {
"PebbloRetrievalQA": "langchain_community.chains.pebblo_retrieval.base"
}


def __getattr__(name: str) -> Any:
    """Resolve a module attribute lazily via ``_module_lookup``.

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The attribute called ``name`` taken from the module registered for it
        in ``_module_lookup``.

    Raises:
        AttributeError: If ``name`` is not registered in ``_module_lookup``.
    """
    # Guard clause: unknown names fail fast, before any import is attempted.
    if name not in _module_lookup:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    target_module = importlib.import_module(_module_lookup[name])
    return getattr(target_module, name)
Empty file.
218 changes: 218 additions & 0 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
"""
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering
against a vector database.
"""

import inspect
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Extra, Field, validator
from langchain_core.vectorstores import VectorStoreRetriever

from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
SUPPORTED_VECTORSTORES,
set_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
SemanticContext,
)


class PebbloRetrievalQA(Chain):
    """
    Retrieval chain with Identity & Semantic Enforcement for question-answering
    against a vector database.

    On every call the chain reads the query and the optional auth/semantic
    contexts from the inputs, hands both contexts to
    ``set_enforcement_filters`` together with the retriever, fetches the
    relevant documents and passes them to ``combine_documents_chain`` to
    produce the answer.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents or not."""

    retriever: VectorStoreRetriever = Field(exclude=True)
    """Retriever, backed by a supported vectorstore, used to fetch documents."""
    auth_context_key: str = "auth_context"  #: :meta private:
    """Input key holding the authentication context for identity enforcement."""
    semantic_context_key: str = "semantic_context"  #: :meta private:
    """Input key holding the semantic context for semantic enforcement."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key, self.auth_context_key, self.semantic_context_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        if self.return_source_documents:
            _output_keys += ["source_documents"]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "pebblo_retrieval_qa"

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "PebbloRetrievalQA":
        """Load chain from chain type.

        Args:
            llm: Language model handed to ``load_qa_chain``.
            chain_type: Name of the combine-documents strategy to load.
            chain_type_kwargs: Extra keyword arguments for ``load_qa_chain``.
            **kwargs: Remaining constructor arguments (e.g. ``retriever``).

        Returns:
            A chain whose ``combine_documents_chain`` was built by
            ``load_qa_chain``.
        """
        from langchain.chains.question_answering import load_qa_chain

        _chain_type_kwargs = chain_type_kwargs or {}
        combine_documents_chain = load_qa_chain(
            llm, chain_type=chain_type, **_chain_type_kwargs
        )
        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

    @validator("retriever", pre=True, always=True)
    def validate_vectorstore(
        cls, retriever: VectorStoreRetriever
    ) -> VectorStoreRetriever:
        """
        Validate that the vectorstore of the retriever is a supported vectorstore.

        Raises:
            ValueError: If the value has no ``vectorstore`` attribute, or if the
                vectorstore is not an instance of one of
                ``SUPPORTED_VECTORSTORES``.
        """
        # This is a `pre` validator, so the value has not been type-checked yet.
        # Guard the attribute access so a wrong input surfaces as a validation
        # error rather than a raw AttributeError.
        if not hasattr(retriever, "vectorstore"):
            raise ValueError(
                "Retriever must expose a `vectorstore` attribute. "
                f"Got {type(retriever).__name__} instead."
            )
        if not isinstance(retriever.vectorstore, tuple(SUPPORTED_VECTORSTORES)):
            raise ValueError(
                f"Vectorstore must be an instance of one of the supported "
                f"vectorstores: {SUPPORTED_VECTORSTORES}. "
                f"Got {type(retriever.vectorstore).__name__} instead."
            )
        return retriever

    def _build_output(self, answer: Any, docs: List[Document]) -> Dict[str, Any]:
        """Assemble the chain output, adding source documents when requested."""
        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        return output

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve the relevant documents and run the combine chain on the query.

        If the chain has 'return_source_documents' set to 'True', the retrieved
        documents are returned as well under the key 'source_documents'.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        # Both contexts are read with `.get`, so a missing key yields None.
        auth_context = inputs.get(self.auth_context_key)
        semantic_context = inputs.get(self.semantic_context_key)
        # Only forward `run_manager` when `_get_docs` declares it, so an
        # implementation with a narrower signature (e.g. an override in a
        # subclass) can still be called.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(
                question, auth_context, semantic_context, run_manager=_run_manager
            )
        else:
            docs = self._get_docs(question, auth_context, semantic_context)  # type: ignore[call-arg]
        answer = self.combine_documents_chain.run(
            input_documents=docs, question=question, callbacks=_run_manager.get_child()
        )
        return self._build_output(answer, docs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async variant: retrieve the documents and run the combine chain.

        If the chain has 'return_source_documents' set to 'True', the retrieved
        documents are returned as well under the key 'source_documents'.

        Example:
            .. code-block:: python

                res = await indexqa.ainvoke({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        # Both contexts are read with `.get`, so a missing key yields None.
        auth_context = inputs.get(self.auth_context_key)
        semantic_context = inputs.get(self.semantic_context_key)
        # Same signature probe as in `_call`, applied to the async retriever hook.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._aget_docs).parameters
        )
        if accepts_run_manager:
            docs = await self._aget_docs(
                question, auth_context, semantic_context, run_manager=_run_manager
            )
        else:
            docs = await self._aget_docs(question, auth_context, semantic_context)  # type: ignore[call-arg]
        answer = await self.combine_documents_chain.arun(
            input_documents=docs, question=question, callbacks=_run_manager.get_child()
        )
        return self._build_output(answer, docs)

    def _get_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Apply the enforcement filters, then fetch documents for ``question``."""
        # NOTE(review): `set_enforcement_filters` receives the chain's own
        # retriever; presumably it updates that retriever's search filters in
        # place -- confirm against the enforcement_filters module.
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        return self.retriever.get_relevant_documents(
            question, callbacks=run_manager.get_child()
        )

    async def _aget_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Async variant of ``_get_docs``: apply filters, then fetch documents."""
        # NOTE(review): same in-place-filter assumption as `_get_docs` -- confirm.
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        return await self.retriever.aget_relevant_documents(
            question, callbacks=run_manager.get_child()
        )
Loading

0 comments on commit 54e0032

Please sign in to comment.