Add a ListRerank document compressor (langchain-ai#13311)
- **Description:** This PR adds a new document compressor called
`ListRerank`, derived from `BaseDocumentCompressor`. It is a near-exact
implementation of the approach introduced in the paper [Zero-Shot Listwise
Document Reranking with a Large Language
Model](https://arxiv.org/pdf/2305.02156.pdf), which finds listwise reranking
to outperform pointwise reranking. Pointwise reranking is partially
implemented in LangChain as
[LLMChainFilter](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py).
A minimal usage sketch follows this list.
- **Issue:** None
- **Dependencies:** None
- **Tag maintainer:** @hwchase17 @izzymsft
- **Twitter handle:** @HarrisEMitchell
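
A minimal usage sketch of the compressor as merged (it was renamed
`LLMListwiseRerank` before landing); the documents and query here are
illustrative:

```python
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

from langchain.retrievers.document_compressors import LLMListwiseRerank

documents = [
    Document("Sally is my friend from school"),
    Document("Steve is my friend from home"),
    Document("I didn't always like yogurt"),
]

# Requires a chat model that implements `with_structured_output`: the LLM
# ranks all documents in a single listwise call and the top_n are kept.
reranker = LLMListwiseRerank.from_llm(
    llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=1
)
compressed = reranker.compress_documents(documents, "Who is Steve?")
# compressed[0].page_content should mention Steve.
```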

Notes:
1. I didn't add anything to `docs`. I wasn't sure which pattern to follow:
the [Cohere reranker is documented under
Retrievers](https://python.langchain.com/docs/integrations/retrievers/cohere-reranker)
alongside other external document retrieval integrations, while other
contextual compression docs live
[here](https://python.langchain.com/docs/modules/data_connection/retrievers/contextual_compression/).
Happy to contribute to either with some direction.
2. I followed the syntax, docstrings, and implementation patterns of nearby
modules as closely as I could. One thing I didn't do was put the default
prompt in a separate `.py` file like [Chain
Filter](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_filter_prompt.py)
and [Chain
Extract](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_extract_prompt.py)
do. Happy to follow that pattern if it would be preferred.

---------

Co-authored-by: Harrison Chase <[email protected]>
Co-authored-by: Bagatur <[email protected]>
Co-authored-by: Chester Curme <[email protected]>
4 people authored and olgamurraft committed Aug 16, 2024
1 parent 6355e9f commit 9d34f32
Showing 6 changed files with 228 additions and 1 deletion.
53 changes: 52 additions & 1 deletion docs/docs/how_to/contextual_compression.ipynb
@@ -220,6 +220,57 @@
"pretty_print_docs(compressed_docs)"
]
},
{
"cell_type": "markdown",
"id": "14002ec8-7ee5-4f91-9315-dd21c3808776",
"metadata": {},
"source": [
"### `LLMListwiseRerank`\n",
"\n",
"[LLMListwiseRerank](https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.document_compressors.listwise_rerank.LLMListwiseRerank.html) uses [zero-shot listwise document reranking](https://arxiv.org/pdf/2305.02156) and functions similarly to `LLMChainFilter` as a robust but more expensive option. It is recommended to use a more powerful LLM.\n",
"\n",
"Note that `LLMListwiseRerank` requires a model with the [with_structured_output](/docs/integrations/chat/) method implemented."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4ab9ee9f-917e-4d6f-9344-eb7f01533228",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Document 1:\n",
"\n",
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
"\n",
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
"\n",
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n"
]
}
],
"source": [
"from langchain.retrievers.document_compressors import LLMListwiseRerank\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n",
"\n",
"_filter = LLMListwiseRerank.from_llm(llm, top_n=1)\n",
"compression_retriever = ContextualCompressionRetriever(\n",
" base_compressor=_filter, base_retriever=retriever\n",
")\n",
"\n",
"compressed_docs = compression_retriever.invoke(\n",
" \"What did the president say about Ketanji Jackson Brown\"\n",
")\n",
"pretty_print_docs(compressed_docs)"
]
},
{
"cell_type": "markdown",
"id": "7194da42",
@@ -295,7 +346,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "617a1756",
"metadata": {},
"outputs": [],
libs/langchain/langchain/retrievers/document_compressors/__init__.py
@@ -15,6 +15,9 @@
from langchain.retrievers.document_compressors.embeddings_filter import (
    EmbeddingsFilter,
)
from langchain.retrievers.document_compressors.listwise_rerank import (
    LLMListwiseRerank,
)

_module_lookup = {
    "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank",
@@ -31,6 +34,7 @@ def __getattr__(name: str) -> Any:
__all__ = [
    "DocumentCompressorPipeline",
    "EmbeddingsFilter",
    "LLMListwiseRerank",
    "LLMChainExtractor",
    "LLMChainFilter",
    "CohereRerank",
libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py
@@ -0,0 +1,137 @@
"""Filter that uses an LLM to rerank documents listwise and select top-k."""

from typing import Any, Dict, List, Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough

_default_system_tmpl = """{context}
Sort the Documents by their relevance to the Query."""
_DEFAULT_PROMPT = ChatPromptTemplate.from_messages(
    [("system", _default_system_tmpl), ("human", "{query}")],
)


def _get_prompt_input(input_: dict) -> Dict[str, Any]:
    """Return the compression chain input."""
    documents = input_["documents"]
    context = ""
    for index, doc in enumerate(documents):
        context += f"Document ID: {index}\n```{doc.page_content}```\n\n"
    context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]"
    return {"query": input_["query"], "context": context}


def _parse_ranking(results: dict) -> List[Document]:
    """Reorder the documents according to the LLM's ranked document IDs."""
    ranking = results["ranking"]
    docs = results["documents"]
    return [docs[i] for i in ranking.ranked_document_ids]


class LLMListwiseRerank(BaseDocumentCompressor):
    """Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    Adapted from: https://arxiv.org/pdf/2305.02156.pdf

    ``LLMListwiseRerank`` uses a language model to rerank a list of documents
    based on their relevance to a query.

    **NOTE**: requires that underlying model implement ``with_structured_output``.

    Example usage:
        .. code-block:: python

            from langchain.retrievers.document_compressors.listwise_rerank import (
                LLMListwiseRerank,
            )
            from langchain_core.documents import Document
            from langchain_openai import ChatOpenAI

            documents = [
                Document("Sally is my friend from school"),
                Document("Steve is my friend from home"),
                Document("I didn't always like yogurt"),
                Document("I wonder why it's called football"),
                Document("Where's waldo"),
            ]

            reranker = LLMListwiseRerank.from_llm(
                llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
            )
            compressed_docs = reranker.compress_documents(documents, "Who is steve")

            assert len(compressed_docs) == 3
            assert "Steve" in compressed_docs[0].page_content
    """

    reranker: Runnable[Dict, List[Document]]
    """LLM-based reranker to use for filtering documents. Expected to take in a dict
    with 'documents: Sequence[Document]' and 'query: str' keys and output a
    List[Document]."""

    top_n: int = 3
    """Number of documents to return."""

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        results = self.reranker.invoke(
            {"documents": documents, "query": query}, config={"callbacks": callbacks}
        )
        return results[: self.top_n]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "LLMListwiseRerank":
        """Create a LLMListwiseRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering. **Must implement
                BaseLanguageModel.with_structured_output().**
            prompt: The prompt to use for the filter.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            A LLMListwiseRerank document compressor that uses the given language model.
        """
        if llm.with_structured_output == BaseLanguageModel.with_structured_output:
            raise ValueError(
                f"llm of type {type(llm)} does not implement `with_structured_output`."
            )

        class RankDocuments(BaseModel):
            """Rank the documents by their relevance to the user question.
            Rank from most to least relevant."""

            ranked_document_ids: List[int] = Field(
                ...,
                description=(
                    "The integer IDs of the documents, sorted from most to least "
                    "relevant to the user question."
                ),
            )

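        # The chain formats the documents into a numbered context string, asks the
        # LLM for a structured ranking, then maps the ranked IDs back to Documents.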
        _prompt = prompt if prompt is not None else _DEFAULT_PROMPT
        reranker = RunnablePassthrough.assign(
            ranking=RunnableLambda(_get_prompt_input)
            | _prompt
            | llm.with_structured_output(RankDocuments)
        ) | RunnableLambda(_parse_ranking)
        return cls(reranker=reranker, **kwargs)
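
A note on the `prompt` parameter of `from_llm`: a custom template can be
supplied as long as it exposes the same `{context}` and `{query}` variables
that `_get_prompt_input` produces. A minimal sketch (the instruction wording
and `top_n` value are illustrative, not part of this commit):

```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from langchain.retrievers.document_compressors import LLMListwiseRerank

# {context} receives the numbered document list built by _get_prompt_input;
# {query} receives the user's question.
custom_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "{context}\nRank the Documents above by relevance to the Query."),
        ("human", "{query}"),
    ]
)

reranker = LLMListwiseRerank.from_llm(
    llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
    prompt=custom_prompt,
    top_n=2,
)
```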
@@ -0,0 +1,22 @@
from langchain_core.documents import Document

from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank


def test_list_rerank() -> None:
    from langchain_openai import ChatOpenAI

    documents = [
        Document("Sally is my friend from school"),
        Document("Steve is my friend from home"),
        Document("I didn't always like yogurt"),
        Document("I wonder why it's called football"),
        Document("Where's waldo"),
    ]

    reranker = LLMListwiseRerank.from_llm(
        llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
    )
    compressed_docs = reranker.compress_documents(documents, "Who is steve")
    assert len(compressed_docs) == 3
    assert "Steve" in compressed_docs[0].page_content
Empty file.
@@ -0,0 +1,13 @@
import pytest

from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank


@pytest.mark.requires("langchain_openai")
def test__list_rerank_init() -> None:
    from langchain_openai import ChatOpenAI

    LLMListwiseRerank.from_llm(
        llm=ChatOpenAI(api_key="foo"),  # type: ignore[arg-type]
        top_n=10,
    )
