Add a ListRerank document compressor #13311

Merged: ccurme merged 26 commits into langchain-ai:master from denver1117:feature/list-rerank-compressor on Jul 18, 2024
Commits (26):

- 12a6033 add list reranker (denver1117)
- 3bab92e add tests (denver1117)
- 963bfb8 add example to documentation (denver1117)
- 12c0d62 fix linting errors in tests (denver1117)
- 2407631 sort imports (denver1117)
- ba9a055 fix typo in docs (denver1117)
- 8cb08b2 updated docs (denver1117)
- 8ad9bca cr (hwchase17)
- ce4332a fix arg order and callable type (denver1117)
- e8720be Merge branch 'feature/list-rerank-compressor' of github.com:denver111… (denver1117)
- 571bfd1 fmt (baskaryan)
- e0afa8a fmt (baskaryan)
- aa7470f fmt (baskaryan)
- 5eb4975 poetry (baskaryan)
- 4873379 poetry (baskaryan)
- 6314575 merge (ccurme)
- 52da5c1 move test (ccurme)
- 62b497a undo change to lock file (ccurme)
- 9abb92b format (ccurme)
- 1b66387 expand docstring (ccurme)
- 1efa986 typing (ccurme)
- e1f3e61 remove unused typeddict (ccurme)
- d6ff137 format (ccurme)
- ad6b700 export LLMListwiseRerank from module (ccurme)
- 5843510 update docs (ccurme)
- 51d18c8 fix link typo (ccurme)
libs/langchain/langchain/retrievers/document_compressors/list_rerank.py (99 additions, 0 deletions)
````python
"""Filter that uses an LLM to rerank documents listwise and select top-k."""
from typing import Any, Callable, Dict, Optional, Sequence

from langchain.callbacks.manager import Callbacks
from langchain.chains import LLMChain
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts import PromptTemplate
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.schema import BasePromptTemplate, Document
from langchain.schema.language_model import BaseLanguageModel


def _get_default_chain_prompt() -> PromptTemplate:
    prompt_template = """
{context}
Query = ```{question}```
Sort the Documents by their relevance to the Query.

{format_instructions}
Sorted Documents:
"""
    response_schemas = [
        ResponseSchema(
            name="reranked_documents",
            description="""Reranked documents. Format: {"document_id": <int>, "score": <number>}""",
            type="array[dict]",
        )
    ]
    output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
    format_instructions = output_parser.get_format_instructions()
    return PromptTemplate(
        template=prompt_template,
        input_variables=["question", "context"],
        output_parser=output_parser,
        partial_variables={"format_instructions": format_instructions},
    )


def default_get_input(query: str, documents: Sequence[Document]) -> Dict[str, Any]:
    """Return the compression chain input."""
    context = ""
    for index, doc in enumerate(documents):
        context += f"Document ID: {index} ```{doc.page_content}```\n"
    context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]"
    return {"question": query, "context": context}


class ListRerank(BaseDocumentCompressor):
    """
    Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    Source: https://arxiv.org/pdf/2305.02156.pdf
    """

    top_n: int = 3
    """Number of documents to return."""

    llm_chain: LLMChain
    """LLM wrapper to use for filtering documents."""

    get_input: Callable[[str, Sequence[Document]], Dict[str, Any]] = default_get_input
    """Callable for constructing the chain input from the query and a sequence of Documents."""

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        _input = self.get_input(query, documents)
        results = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
        final_results = []
        for r in results["reranked_documents"][: self.top_n]:
            doc = documents[r["document_id"]]
            doc.metadata["relevance_score"] = r["score"]
            final_results.append(doc)
        return final_results

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "ListRerank":
        """Create a ListRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering.
            prompt: The prompt to use for the filter.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            A ListRerank document compressor that uses the given language model.
        """
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        llm_chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=llm_chain, **kwargs)
````
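The selection step inside `compress_documents` can be illustrated standalone, without LangChain or an LLM. In this sketch, plain dicts stand in for `Document` objects and the parsed LLM output is hypothetical; the copy-before-mutate behavior is an addition for the sake of the example (the PR code attaches scores to the original documents in place):

```python
from typing import Any, Dict, List, Sequence


def select_top_n(
    documents: Sequence[Dict[str, Any]],
    parsed: Dict[str, Any],
    top_n: int = 3,
) -> List[Dict[str, Any]]:
    """Keep the top_n entries of the parsed ranking and attach each
    LLM-assigned score, mirroring ListRerank.compress_documents."""
    final_results = []
    for r in parsed["reranked_documents"][:top_n]:
        doc = {**documents[r["document_id"]]}  # shallow copy of the input doc
        doc["metadata"] = {**doc.get("metadata", {}), "relevance_score": r["score"]}
        final_results.append(doc)
    return final_results


# Hypothetical parsed output for four input documents:
docs = [{"page_content": f"doc {i}", "metadata": {}} for i in range(4)]
parsed = {
    "reranked_documents": [
        {"document_id": 2, "score": 0.9},
        {"document_id": 0, "score": 0.7},
        {"document_id": 3, "score": 0.4},
        {"document_id": 1, "score": 0.1},
    ]
}
top_docs = select_top_n(docs, parsed, top_n=2)
# top_docs holds "doc 2" then "doc 0", each carrying its relevance_score
```

Note that the ranking order comes entirely from the model's reply; documents the model omits from `reranked_documents` are silently dropped.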
I experimented with a Pydantic parser that defines the full nested structure explicitly and saw notably more output parsing errors. Expressing the `array[dict]` type as an implicit nested type within a single `ResponseSchema` `type` argument was much more successful.