Add a ListRerank document compressor #13311

Merged · 26 commits · Jul 18, 2024
Changes from 7 commits
@@ -227,6 +227,53 @@ pretty_print_docs(compressed_docs)

</CodeOutputBlock>

### `ListRerank`

`ListRerank` uses [zero-shot listwise document reranking](https://arxiv.org/pdf/2305.02156.pdf) and, like `LLMChainFilter`, is a more robust but more expensive option; using a more powerful LLM is recommended.


```python
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import ListRerank

llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0)
_filter = ListRerank.from_llm(llm, top_n=1)
compression_retriever = ContextualCompressionRetriever(base_compressor=_filter, base_retriever=retriever)

compressed_docs = compression_retriever.get_relevant_documents("What did the president say about Ketanji Jackson Brown")
pretty_print_docs(compressed_docs)
```

<CodeOutputBlock lang="python">

```
Document 1:

Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections.

Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.

And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.
```

</CodeOutputBlock>

Similar to `CohereRerank`, `ListRerank` adds a relevance score to the document metadata:

```python
print(compressed_docs[0].metadata["relevance_score"])
```

<CodeOutputBlock lang="python">

```
0.9
```

</CodeOutputBlock>

## Stringing compressors and document transformers together
Using the `DocumentCompressorPipeline` we can also easily combine multiple compressors in sequence. Along with compressors, we can add `BaseDocumentTransformer`s to our pipeline, which don't perform any contextual compression but simply apply some transformation to a set of documents. For example, `TextSplitter`s can be used as document transformers to split documents into smaller pieces, and `EmbeddingsRedundantFilter` can be used to filter out redundant documents based on embedding similarity between documents.
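
For example, here is a minimal sketch of such a pipeline (the chunk size, separator, and similarity threshold are illustrative choices, not prescribed values):

```python
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain.text_splitter import CharacterTextSplitter

embeddings = OpenAIEmbeddings()
# A transformer, not a compressor: splits each document into smaller chunks.
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
# Drops chunks that are near-duplicates of one another.
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
# Keeps only chunks sufficiently similar to the query.
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)

pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=retriever
)
```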

@@ -9,11 +9,13 @@
from langchain.retrievers.document_compressors.embeddings_filter import (
    EmbeddingsFilter,
)
from langchain.retrievers.document_compressors.list_rerank import ListRerank

__all__ = [
    "DocumentCompressorPipeline",
    "EmbeddingsFilter",
    "LLMChainExtractor",
    "LLMChainFilter",
    "CohereRerank",
    "ListRerank",
]
libs/langchain/langchain/retrievers/document_compressors/list_rerank.py
@@ -0,0 +1,131 @@
"""Filter that uses an LLM to rerank documents listwise and select top-k."""
import logging
from typing import Any, Callable, Dict, Optional, Sequence

from langchain.callbacks.manager import Callbacks
from langchain.chains import LLMChain
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts import PromptTemplate
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.schema import BasePromptTemplate, Document
from langchain.schema.language_model import BaseLanguageModel

logger = logging.getLogger(__name__)


def _get_default_chain_prompt() -> PromptTemplate:
    prompt_template = """
{context}
Query = ```{question}```
Sort the Documents by their relevance to the Query.

{format_instructions}
Sorted Documents:
"""
    description = (
        """Reranked documents. Format: {"document_id": <int>, "score": <number>}"""
    )
    response_schemas = [
        ResponseSchema(
            name="reranked_documents",
            description=description,
            type="array[dict]",
        )
    ]
    output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
Contributor Author comment: I experimented with a Pydantic parser that defines the full nested structure explicitly and saw notably more output parsing errors. Expressing the `array[dict]` type as an implicit nested type within a single `ResponseSchema` type argument was much more successful.
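
For contrast, a minimal sketch of the explicit Pydantic alternative described above (these class names are hypothetical, not part of this PR):

```python
from typing import List

from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel


# Hypothetical explicit schema; in the author's experiments this parser
# produced notably more output parsing errors than the ResponseSchema approach.
class RankedDocument(BaseModel):
    document_id: int
    score: float


class RerankedDocuments(BaseModel):
    reranked_documents: List[RankedDocument]


pydantic_parser = PydanticOutputParser(pydantic_object=RerankedDocuments)
```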

    format_instructions = output_parser.get_format_instructions()
    return PromptTemplate(
        template=prompt_template,
        input_variables=["question", "context"],
        output_parser=output_parser,
        partial_variables={"format_instructions": format_instructions},
    )


def default_get_input(query: str, documents: Sequence[Document]) -> Dict[str, Any]:
    """Return the compression chain input."""
    context = ""
    for index, doc in enumerate(documents):
        context += f"Document ID: {index} ```{doc.page_content}```\n"
    context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]"
    return {"question": query, "context": context}


class ListRerank(BaseDocumentCompressor):
    """
    Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    Source: https://arxiv.org/pdf/2305.02156.pdf
    """

    top_n: int = 3
    """Number of documents to return."""

    llm_chain: LLMChain
    """LLM wrapper to use for filtering documents."""

    get_input: Callable[[str, Sequence[Document]], Dict[str, Any]] = default_get_input
    """Callable for constructing the chain input from the query and Documents."""

    fallback: bool = False
    """Whether to fall back to the original document ranking if the LLM fails."""

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        _input = self.get_input(query, documents)
        try:
            results = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
            top_documents = results["reranked_documents"][: self.top_n]
        except Exception as e:
            return self._handle_exception(documents, e)

        final_results = []
        for r in top_documents:
            try:
                doc = documents[r["document_id"]]
                score = float(r["score"])
            except Exception as e:
                return self._handle_exception(documents, e)

            doc.metadata["relevance_score"] = score
            final_results.append(doc)
        return final_results

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "ListRerank":
        """Create a ListRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering.
            prompt: The prompt to use for the filter.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            A ListRerank document compressor that uses the given language model.
        """
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        llm_chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=llm_chain, **kwargs)

    def _handle_exception(
        self, documents: Sequence[Document], exception: Exception
    ) -> Sequence[Document]:
        """Return the top documents by original ranking or raise an exception."""
        if self.fallback:
            logger.warning(
                "Failed to generate or parse LLM response. "
                "Falling back to the original document ranking."
            )
            return documents[: self.top_n]
        else:
            raise exception
@@ -0,0 +1,87 @@
import pytest

from langchain.retrievers.document_compressors.list_rerank import ListRerank
from langchain.schema import Document
from tests.unit_tests.llms.fake_llm import FakeLLM

query = "Do you have a pencil?"
top_n = 2
input_docs = [
    Document(page_content="I have a pen."),
    Document(page_content="Do you have a pen?"),
    Document(page_content="I have a bag."),
]


def test__list_rerank_success() -> None:
    llm = FakeLLM(
        queries={
            query: """
            ```json
            {
                "reranked_documents": [
                    {"document_id": 1, "score": 0.99},
                    {"document_id": 0, "score": 0.95},
                    {"document_id": 2, "score": 0.50}
                ]
            }
            ```
            """
        },
        sequential_responses=True,
    )

    list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n)
    output_docs = list_rerank.compress_documents(input_docs, query)

    assert len(output_docs) == top_n
    assert output_docs[0].metadata["relevance_score"] == 0.99
    assert output_docs[0].page_content == "Do you have a pen?"


def test__list_rerank_error() -> None:
    llm = FakeLLM(
        queries={
            query: """
            ```json
            {
                "reranked_documents": [
                    {"<>": 1, "score": 0.99},
                    {"document_id": 0, "score": 0.95},
                    {"document_id": 2, "score": 0.50}
                ]
            }
            ```
            """
        },
        sequential_responses=True,
    )

    list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n)

    with pytest.raises(KeyError) as excinfo:
        list_rerank.compress_documents(input_docs, query)
    assert "document_id" in str(excinfo.value)


def test__list_rerank_fallback() -> None:
    llm = FakeLLM(
        queries={
            query: """
            ```json
            {
                "reranked_documents": [
                    {"<>": 1, "score": 0.99},
                    {"document_id": 0, "score": 0.95},
                    {"document_id": 2, "score": 0.50}
                ]
            }
            ```
            """
        },
        sequential_responses=True,
    )

    list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n, fallback=True)
    output_docs = list_rerank.compress_documents(input_docs, query)
    assert len(output_docs) == top_n