forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a
ListRerank
document compressor (langchain-ai#13311)
- **Description:** This PR adds a new document compressor called `ListRerank`. It's derived from `BaseDocumentCompressor`. It's a near-exact implementation of the method introduced by this paper: [Zero-Shot Listwise Document Reranking with a Large Language Model](https://arxiv.org/pdf/2305.02156.pdf), which that paper finds to outperform pointwise reranking, which is somewhat implemented in LangChain as [LLMChainFilter](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py). - **Issue:** None - **Dependencies:** None - **Tag maintainer:** @hwchase17 @izzymsft - **Twitter handle:** @HarrisEMitchell Notes: 1. I didn't add anything to `docs`. I wasn't exactly sure which patterns to follow as [cohere reranker is under Retrievers](https://python.langchain.com/docs/integrations/retrievers/cohere-reranker) with other external document retrieval integrations, but other contextual compression is [here](https://python.langchain.com/docs/modules/data_connection/retrievers/contextual_compression/). Happy to contribute to either with some direction. 2. I followed syntax, docstrings, implementation patterns, etc. as well as I could looking at nearby modules. One thing I didn't do was put the default prompt in a separate `.py` file like [Chain Filter](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_filter_prompt.py) and [Chain Extract](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_extract_prompt.py). Happy to follow that pattern if it would be preferred. --------- Co-authored-by: Harrison Chase <[email protected]> Co-authored-by: Bagatur <[email protected]> Co-authored-by: Chester Curme <[email protected]>
- Loading branch information
1 parent
6355e9f
commit 9d34f32
Showing
6 changed files
with
228 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
137 changes: 137 additions & 0 deletions
137
libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
"""Filter that uses an LLM to rerank documents listwise and select top-k.""" | ||
|
||
from typing import Any, Dict, List, Optional, Sequence | ||
|
||
from langchain_core.callbacks import Callbacks | ||
from langchain_core.documents import BaseDocumentCompressor, Document | ||
from langchain_core.language_models import BaseLanguageModel | ||
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate | ||
from langchain_core.pydantic_v1 import BaseModel, Field | ||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough | ||
|
||
# Default listwise-reranking prompt: the rendered document list is injected as
# {context} in the system message; the user's query is sent as the human message.
_default_system_tmpl = """{context}
Sort the Documents by their relevance to the Query."""
_DEFAULT_PROMPT = ChatPromptTemplate.from_messages(
    [("system", _default_system_tmpl), ("human", "{query}")],
)
|
||
|
||
def _get_prompt_input(input_: dict) -> Dict[str, Any]: | ||
"""Return the compression chain input.""" | ||
documents = input_["documents"] | ||
context = "" | ||
for index, doc in enumerate(documents): | ||
context += f"Document ID: {index}\n```{doc.page_content}```\n\n" | ||
context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]" | ||
return {"query": input_["query"], "context": context} | ||
|
||
|
||
def _parse_ranking(results: dict) -> List[Document]:
    """Reorder the original documents according to the model's ranked IDs."""
    documents = results["documents"]
    ordered = []
    for doc_id in results["ranking"].ranked_document_ids:
        ordered.append(documents[doc_id])
    return ordered
|
||
|
||
class LLMListwiseRerank(BaseDocumentCompressor):
    """Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    Adapted from: https://arxiv.org/pdf/2305.02156.pdf

    ``LLMListwiseRerank`` uses a language model to rerank a list of documents
    based on their relevance to a query.

    **NOTE**: requires that underlying model implement ``with_structured_output``.

    Example usage:
        .. code-block:: python

            from langchain.retrievers.document_compressors.listwise_rerank import (
                LLMListwiseRerank,
            )
            from langchain_core.documents import Document
            from langchain_openai import ChatOpenAI

            documents = [
                Document("Sally is my friend from school"),
                Document("Steve is my friend from home"),
                Document("I didn't always like yogurt"),
                Document("I wonder why it's called football"),
                Document("Where's waldo"),
            ]

            reranker = LLMListwiseRerank.from_llm(
                llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
            )
            compressed_docs = reranker.compress_documents(documents, "Who is steve")
            assert len(compressed_docs) == 3
            assert "Steve" in compressed_docs[0].page_content
    """

    reranker: Runnable[Dict, List[Document]]
    """LLM-based reranker to use for filtering documents. Expected to take in a dict
    with 'documents: Sequence[Document]' and 'query: str' keys and output a
    List[Document]."""

    top_n: int = 3
    """Number of documents to return."""

    class Config:
        # The `reranker` field holds a Runnable, which pydantic v1 cannot
        # validate natively.
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query.

        Args:
            documents: The candidate documents to rerank.
            query: The query to rank the documents against.
            callbacks: Optional callbacks forwarded to the reranker chain.

        Returns:
            The ``top_n`` highest-ranked documents, most relevant first.
        """
        results = self.reranker.invoke(
            {"documents": documents, "query": query}, config={"callbacks": callbacks}
        )
        return results[: self.top_n]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "LLMListwiseRerank":
        """Create a LLMListwiseRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering. **Must implement
                BaseLanguageModel.with_structured_output().**
            prompt: The prompt to use for the filter.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            A LLMListwiseRerank document compressor that uses the given language
            model.

        Raises:
            ValueError: If ``llm`` does not override ``with_structured_output``.
        """
        # BUGFIX: the previous check compared the *bound method*
        # `llm.with_structured_output` against the plain *function* on the base
        # class; in Python 3 that comparison is always False, so the guard
        # never fired and a non-implementing model only failed later with an
        # opaque NotImplementedError inside the chain. Looking the attribute up
        # on the class and comparing functions by identity detects
        # un-overridden implementations correctly.
        if (
            getattr(type(llm), "with_structured_output", None)
            is BaseLanguageModel.with_structured_output
        ):
            raise ValueError(
                f"llm of type {type(llm)} does not implement `with_structured_output`."
            )

        class RankDocuments(BaseModel):
            """Rank the documents by their relevance to the user question.
            Rank from most to least relevant."""

            ranked_document_ids: List[int] = Field(
                ...,
                description=(
                    "The integer IDs of the documents, sorted from most to least "
                    "relevant to the user question."
                ),
            )

        _prompt = prompt if prompt is not None else _DEFAULT_PROMPT
        # Pipeline: render the listwise prompt input, ask the model for a
        # structured ranking, then map the ranked IDs back onto the documents.
        reranker = RunnablePassthrough.assign(
            ranking=RunnableLambda(_get_prompt_input)
            | _prompt
            | llm.with_structured_output(RankDocuments)
        ) | RunnableLambda(_parse_ranking)
        return cls(reranker=reranker, **kwargs)
22 changes: 22 additions & 0 deletions
22
...langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from langchain_core.documents import Document | ||
|
||
from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank | ||
|
||
|
||
def test_list_rerank() -> None:
    """End-to-end smoke test of LLMListwiseRerank against a live OpenAI model."""
    from langchain_openai import ChatOpenAI

    docs = [
        Document("Sally is my friend from school"),
        Document("Steve is my friend from home"),
        Document("I didn't always like yogurt"),
        Document("I wonder why it's called football"),
        Document("Where's waldo"),
    ]

    compressor = LLMListwiseRerank.from_llm(
        llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
    )
    reranked = compressor.compress_documents(docs, "Who is steve")
    assert len(reranked) == 3
    assert "Steve" in reranked[0].page_content
Empty file.
13 changes: 13 additions & 0 deletions
13
libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import pytest | ||
|
||
from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank | ||
|
||
|
||
@pytest.mark.requires("langchain_openai")
def test__list_rerank_init() -> None:
    """from_llm should construct without error for an LLM with structured output."""
    from langchain_openai import ChatOpenAI

    chat = ChatOpenAI(api_key="foo")  # type: ignore[arg-type]
    LLMListwiseRerank.from_llm(llm=chat, top_n=10)