Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langchain: Add PebbloRetrievalQA chain with Identity & Semantic Enforcement support #20641

Merged
merged 5 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions libs/community/langchain_community/chains/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Chains module for langchain_community

This module contains the community chains.
"""

import importlib
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from langchain_community.chains.pebblo_retrieval.base import PebbloRetrievalQA

__all__ = ["PebbloRetrievalQA"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should use a hard-coded list of string literals otherwise some dev tools won't be able to pick it up


_module_lookup = {
"PebbloRetrievalQA": "langchain_community.chains.pebblo_retrieval.base"
}


def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
Empty file.
218 changes: 218 additions & 0 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
"""
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering
against a vector database.
"""

import inspect
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Extra, Field, validator
from langchain_core.vectorstores import VectorStoreRetriever

from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
SUPPORTED_VECTORSTORES,
set_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
SemanticContext,
)


class PebbloRetrievalQA(Chain):
"""
Retrieval Chain with Identity & Semantic Enforcement for question-answering
against a vector database.
"""

combine_documents_chain: BaseCombineDocumentsChain
"""Chain to use to combine the documents."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents or not."""

retriever: VectorStoreRetriever = Field(exclude=True)
"""VectorStore to use for retrieval."""
auth_context_key: str = "auth_context" #: :meta private:
"""Authentication context for identity enforcement."""
semantic_context_key: str = "semantic_context" #: :meta private:
"""Semantic context for semantic enforcement."""

def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.

If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.

Example:
.. code-block:: python

res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(
question, auth_context, semantic_context, run_manager=_run_manager
)
else:
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)

if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}

async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.

If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.

Example:
.. code-block:: python

res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
if accepts_run_manager:
docs = await self._aget_docs(
question, auth_context, semantic_context, run_manager=_run_manager
)
else:
docs = await self._aget_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)

if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True
allow_population_by_field_name = True

@property
def input_keys(self) -> List[str]:
"""Input keys.

:meta private:
"""
return [self.input_key, self.auth_context_key, self.semantic_context_key]

@property
def output_keys(self) -> List[str]:
"""Output keys.

:meta private:
"""
_output_keys = [self.output_key]
if self.return_source_documents:
_output_keys += ["source_documents"]
return _output_keys

@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "pebblo_retrieval_qa"

@classmethod
def from_chain_type(
cls,
llm: BaseLanguageModel,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
from langchain.chains.question_answering import load_qa_chain

_chain_type_kwargs = chain_type_kwargs or {}
combine_documents_chain = load_qa_chain(
llm, chain_type=chain_type, **_chain_type_kwargs
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)

@validator("retriever", pre=True, always=True)
def validate_vectorstore(
cls, retriever: VectorStoreRetriever
) -> VectorStoreRetriever:
"""
Validate that the vectorstore of the retriever is supported vectorstores.
"""
if not any(
isinstance(retriever.vectorstore, supported_class)
for supported_class in SUPPORTED_VECTORSTORES
):
raise ValueError(
f"Vectorstore must be an instance of one of the supported "
f"vectorstores: {SUPPORTED_VECTORSTORES}. "
f"Got {type(retriever.vectorstore).__name__} instead."
)
return retriever

def _get_docs(
self,
question: str,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
set_enforcement_filters(self.retriever, auth_context, semantic_context)
return self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)

async def _aget_docs(
self,
question: str,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
set_enforcement_filters(self.retriever, auth_context, semantic_context)
return await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)
Loading
Loading