Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a ListRerank document compressor #13311

Merged
merged 26 commits into from
Jul 18, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Filter that uses an LLM to rerank documents listwise and select top-k."""
from typing import Any, Callable, Dict, Optional, Sequence

from langchain.callbacks.manager import Callbacks
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.schema import BasePromptTemplate, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.output_parsers import StructuredOutputParser, ResponseSchema


def _get_default_chain_prompt() -> PromptTemplate:
    """Build the default listwise-reranking prompt.

    The prompt presents all candidate documents (``context``) and the
    query (``question``) and asks the LLM to sort the documents by
    relevance, emitting structured output per the attached
    ``StructuredOutputParser``'s format instructions.

    Returns:
        A ``PromptTemplate`` with ``question`` and ``context`` input
        variables, pre-filled format instructions, and the structured
        output parser attached.
    """
    prompt_template = """
{context}
Query = ```{question}```
Sort the Documents by their relevance to the Query.

{format_instructions}
Sorted Documents:
"""
    # NOTE: the array[dict] shape is expressed implicitly inside a single
    # ResponseSchema on purpose — a Pydantic parser that defined the full
    # nested structure explicitly produced notably more output parsing
    # errors in experiments.
    response_schemas = [
        ResponseSchema(
            name="reranked_documents",
            description="""Reranked documents. Format: {"document_id": <int>, "score": <number>}""",
            type="array[dict]",
        )
    ]
    output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
    format_instructions = output_parser.get_format_instructions()
    return PromptTemplate(
        template=prompt_template,
        input_variables=["question", "context"],
        output_parser=output_parser,
        partial_variables={"format_instructions": format_instructions},
    )


def default_get_input(query: str, documents: Sequence[Document]) -> Dict[str, Any]:
    """Construct the default chain input for listwise reranking.

    Each document is rendered as a numbered entry with its page content
    in triple backticks; a trailing summary line states the valid ID
    range so the LLM knows which document IDs it may return.
    """
    doc_lines = [
        f"Document ID: {doc_id} ```{doc.page_content}```\n"
        for doc_id, doc in enumerate(documents)
    ]
    id_range = f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]"
    return {"question": query, "context": "".join(doc_lines) + id_range}


class ListRerank(BaseDocumentCompressor):
    """
    Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    The wrapped LLM is shown every candidate document at once and asked to
    sort them by relevance to the query; the top ``top_n`` results are kept
    and annotated with the model-assigned ``relevance_score``.

    Source: https://arxiv.org/pdf/2305.02156.pdf
    """

    top_n: int = 3
    """Number of documents to return."""

    llm_chain: LLMChain
    """LLM wrapper to use for filtering documents."""

    get_input: Callable[[str, Sequence[Document]], Dict[str, Any]] = default_get_input
    """Callable for constructing the chain input from the query and a sequence of Documents."""

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query.

        Args:
            documents: The candidate documents to rerank.
            query: The query to rank the documents against.
            callbacks: Optional callbacks passed through to the LLM chain.

        Returns:
            At most ``top_n`` documents, most relevant first, each with a
            ``relevance_score`` added to its metadata.

        Raises:
            ValueError: If the LLM returns a ``document_id`` outside the
                range of the supplied documents (a hallucinated ID).
        """
        _input = self.get_input(query, documents)
        results = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
        final_results = []
        for r in results["reranked_documents"][: self.top_n]:
            doc_id = r["document_id"]
            # Guard against hallucinated IDs: a negative ID would silently
            # index from the end of the list, and an out-of-range ID would
            # raise an opaque IndexError.
            if not 0 <= doc_id < len(documents):
                raise ValueError(
                    f"LLM returned document_id {doc_id}, which is outside "
                    f"the valid range [0, {len(documents) - 1}]."
                )
            doc = documents[doc_id]
            # NOTE(review): mutates the caller's Document in place — confirm
            # this matches the other compressors' conventions.
            doc.metadata["relevance_score"] = r["score"]
            final_results.append(doc)
        return final_results

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "ListRerank":
        """Create a ListRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering.
            prompt: The prompt to use for the filter.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            A ListRerank document compressor that uses the given language model.
        """
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        llm_chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=llm_chain, **kwargs)