From 12a60330a30645e85db72aa35665e69f7539296f Mon Sep 17 00:00:00 2001 From: denver1117 Date: Mon, 13 Nov 2023 15:27:08 -0700 Subject: [PATCH 01/23] add list reranker --- .../document_compressors/list_rerank.py | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 libs/langchain/langchain/retrievers/document_compressors/list_rerank.py diff --git a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py new file mode 100644 index 0000000000000..cab84a170943a --- /dev/null +++ b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py @@ -0,0 +1,99 @@ +"""Filter that uses an LLM to rerank documents listwise and select top-k.""" +from typing import Any, Callable, Dict, Optional, Sequence + +from langchain.callbacks.manager import Callbacks +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor +from langchain.schema import BasePromptTemplate, Document +from langchain.schema.language_model import BaseLanguageModel +from langchain.output_parsers import StructuredOutputParser, ResponseSchema + + +def _get_default_chain_prompt() -> PromptTemplate: + prompt_template = """ +{context} +Query = ```{question}``` +Sort the Documents by their relevance to the Query. + +{format_instructions} +Sorted Documents: + """ + response_schemas = [ + ResponseSchema( + name="reranked_documents", + description="""Reranked documents. Format: {"document_id": , "score": }""", + type="array[dict]", + ) + ] + output_parser = StructuredOutputParser.from_response_schemas(response_schemas) + format_instructions = output_parser.get_format_instructions() + return PromptTemplate( + template=prompt_template, + input_variables=["question", "context"], + output_parser=output_parser, + partial_variables={"format_instructions": format_instructions}, + ) + + +def default_get_input(query: str, documents: Sequence[Document]) -> Dict[str, Any]: + """Return the compression chain input.""" + context = "" + for index, doc in enumerate(documents): + context += f"Document ID: {index} ```{doc.page_content}```\n" + context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]" + return {"question": query, "context": context} + + +class ListRerank(BaseDocumentCompressor): + """ + Document compressor that uses `Zero-Shot Listwise Document Reranking`. 
+ + Source: https://arxiv.org/pdf/2305.02156.pdf + """ + + top_n: int = 3 + """Number of documents to return.""" + + llm_chain: LLMChain + """LLM wrapper to use for filtering documents.""" + + get_input: Callable[[str, Document], dict] = default_get_input + """Callable for constructing the chain input from the query and a sequence of Documents.""" + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + """Filter down documents based on their relevance to the query.""" + _input = self.get_input(query, documents) + results = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks) + final_results = [] + for r in results["reranked_documents"][: self.top_n]: + doc = documents[r["document_id"]] + doc.metadata["relevance_score"] = r["score"] + final_results.append(doc) + return final_results + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + prompt: Optional[BasePromptTemplate] = None, + **kwargs: Any, + ) -> "ListRerank": + """Create a ListRerank document compressor from a language model. + + Args: + llm: The language model to use for filtering. + prompt: The prompt to use for the filter. + **kwargs: Additional arguments to pass to the constructor. + + Returns: + A ListRerank document compressor that uses the given language model. + """ + _prompt = prompt if prompt is not None else _get_default_chain_prompt() + llm_chain = LLMChain(llm=llm, prompt=_prompt) + return cls(llm_chain=llm_chain, **kwargs) From 3bab92e9c248a796462df217d370753798c7ff44 Mon Sep 17 00:00:00 2001 From: denver1117 Date: Fri, 24 Nov 2023 11:50:37 -0500 Subject: [PATCH 02/23] add tests --- .../document_compressors/list_rerank.py | 37 ++++++++++-- .../document_compressors/__init__.py | 0 .../document_compressors/test_list_rerank.py | 57 +++++++++++++++++++ 3 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/retrievers/document_compressors/__init__.py create mode 100644 libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py diff --git a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py index cab84a170943a..66696bfa5931e 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py @@ -1,4 +1,6 @@ """Filter that uses an LLM to rerank documents listwise and select top-k.""" +import logging + from typing import Any, Callable, Dict, Optional, Sequence from langchain.callbacks.manager import Callbacks @@ -9,6 +11,8 @@ from langchain.schema.language_model import BaseLanguageModel from langchain.output_parsers import StructuredOutputParser, ResponseSchema +logger = logging.getLogger(__name__) + def _get_default_chain_prompt() -> PromptTemplate: prompt_template = """ @@ -61,6 +65,9 @@ class ListRerank(BaseDocumentCompressor): get_input: Callable[[str, Document], dict] = default_get_input """Callable for constructing the chain input from the query and a sequence of Documents.""" + fallback: bool = False + """Whether to fallback to the original document scores if the LLM fails.""" + def compress_documents( self, documents: Sequence[Document], @@ -69,11 +76,21 @@ def compress_documents( ) -> Sequence[Document]: """Filter down documents based on their relevance to the query.""" _input = self.get_input(query, documents) - results = 
self.llm_chain.predict_and_parse(**_input, callbacks=callbacks) + try: + results = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks) + top_documents = results["reranked_documents"][: self.top_n] + except Exception as e: + return self._handle_exception(documents, e) + final_results = [] - for r in results["reranked_documents"][: self.top_n]: - doc = documents[r["document_id"]] - doc.metadata["relevance_score"] = r["score"] + for r in top_documents: + try: + doc = documents[r["document_id"]] + score = float(r["score"]) + except Exception as e: + return self._handle_exception(documents, e) + + doc.metadata["relevance_score"] = score final_results.append(doc) return final_results @@ -97,3 +114,15 @@ def from_llm( _prompt = prompt if prompt is not None else _get_default_chain_prompt() llm_chain = LLMChain(llm=llm, prompt=_prompt) return cls(llm_chain=llm_chain, **kwargs) + + def _handle_exception( + self, documents: Sequence[Document], exception: Exception + ) -> Sequence[Document]: + """Return the top documents by original ranking or raise an exception.""" + if self.fallback: + logger.warning( + "Failed to generate or parse LLM response. Falling back to original scores." + ) + return documents[: self.top_n] + else: + raise exception diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/__init__.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py new file mode 100644 index 0000000000000..a389c85cbed06 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py @@ -0,0 +1,57 @@ +import pytest + +from langchain.retrievers.document_compressors.list_rerank import ListRerank +from langchain.schema import Document +from tests.unit_tests.llms.fake_llm import FakeLLM + +query = "Do you have a pencil?" +top_n = 2 +input_docs = [ + Document(page_content="I have a pen."), + Document(page_content="Do you have a pen?"), + Document(page_content="I have a bag."), +] + + +def test__list_rerank_success() -> None: + llm = FakeLLM( + queries={ + query: '```json {"reranked_documents": [{"document_id": 1, "score": 0.99}, {"document_id": 0, "score": 0.95}, {"document_id": 2, "score": 0.50}]}```' + }, + sequential_responses=True, + ) + + list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n) + output_docs = list_rerank.compress_documents(input_docs, query) + + assert len(output_docs) == top_n + assert output_docs[0].metadata["relevance_score"] == 0.99 + assert output_docs[0].page_content == "Do you have a pen?" 
+ + +def test__list_rerank_error() -> None: + llm = FakeLLM( + queries={ + query: '```json {"reranked_documents": [{"<>": 1, "score": 0.99}, {"document_id": 0, "score": 0.95}, {"document_id": 2, "score": 0.50}]}```' + }, + sequential_responses=True, + ) + + list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n) + + with pytest.raises(KeyError) as excinfo: + output_docs = list_rerank.compress_documents(input_docs, query) + assert "document_id" in str(excinfo.value) + + +def test__list_rerank_fallback() -> None: + llm = FakeLLM( + queries={ + query: '```json {"reranked_documents": [{"<>": 1, "score": 0.99}, {"document_id": 0, "score": 0.95}, {"document_id": 2, "score": 0.50}]}```' + }, + sequential_responses=True, + ) + + list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n, fallback=True) + output_docs = list_rerank.compress_documents(input_docs, query) + assert len(output_docs) == top_n From 963bfb8020afb3be6d5ca2d4d1a16dac84470047 Mon Sep 17 00:00:00 2001 From: denver1117 Date: Wed, 29 Nov 2023 17:48:41 -0700 Subject: [PATCH 03/23] add example to documentation --- .../contextual_compression/index.mdx | 45 +++++++++++++++++++ .../document_compressors/__init__.py | 2 + 2 files changed, 47 insertions(+) diff --git a/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx b/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx index 7d6a623929fb4..09091862c3a09 100644 --- a/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx +++ b/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx @@ -227,6 +227,51 @@ pretty_print_docs(compressed_docs) +### `ListRerank` + +`ListRerank` uses [zero-shot listwise document reranking](https://arxiv.org/pdf/2305.02156.pdf) and functions similarly to `LLMChainFilter` as a robust but more expensive option. + + +```python +from langchain.retrievers.document_compressors import ListRerank + +_filter = LLMChainFilter.from_llm(llm, top_n=1) +compression_retriever = ContextualCompressionRetriever(base_compressor=_filter, base_retriever=retriever) + +compressed_docs = compression_retriever.get_relevant_documents("What did the president say about Ketanji Jackson Brown") +pretty_print_docs(compressed_docs) +``` + + + +``` + Document 1: + + Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. + + Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. + + One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. + + And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. +``` + + + +Similar to `CohereRerank`, `ListRerank` will add a relevant score to the document metadata: + +```python +print(compressed_docs[0].metadata["relevance_score"]) +``` + + + +``` +0.99 +``` + + + # Stringing compressors and document transformers together Using the `DocumentCompressorPipeline` we can also easily combine multiple compressors in sequence. 
Along with compressors we can add `BaseDocumentTransformer`s to our pipeline, which don't perform any contextual compression but simply perform some transformation on a set of documents. For example `TextSplitter`s can be used as document transformers to split documents into smaller pieces, and the `EmbeddingsRedundantFilter` can be used to filter out redundant documents based on embedding similarity between documents. diff --git a/libs/langchain/langchain/retrievers/document_compressors/__init__.py b/libs/langchain/langchain/retrievers/document_compressors/__init__.py index 410ad540d19e8..0a58aa0e5dbd9 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/__init__.py +++ b/libs/langchain/langchain/retrievers/document_compressors/__init__.py @@ -9,6 +9,7 @@ from langchain.retrievers.document_compressors.embeddings_filter import ( EmbeddingsFilter, ) +from langchain.retrievers.document_compressors.list_rerank import ListRerank __all__ = [ "DocumentCompressorPipeline", @@ -16,4 +17,5 @@ "LLMChainExtractor", "LLMChainFilter", "CohereRerank", + "ListRerank", ] From 12c0d62a5758a40d5a3ac206419e5f368335a71b Mon Sep 17 00:00:00 2001 From: denver1117 Date: Wed, 29 Nov 2023 18:07:28 -0700 Subject: [PATCH 04/23] fix linting errors in tests --- .../document_compressors/list_rerank.py | 10 +++-- .../document_compressors/test_list_rerank.py | 38 +++++++++++++++++-- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py index 66696bfa5931e..2d11b077cedda 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py @@ -23,10 +23,13 @@ def _get_default_chain_prompt() -> PromptTemplate: {format_instructions} Sorted Documents: """ + description = ( + """Reranked documents. Format: {"document_id": , "score": }""" + ) response_schemas = [ ResponseSchema( name="reranked_documents", - description="""Reranked documents. Format: {"document_id": , "score": }""", + description=description, type="array[dict]", ) ] @@ -63,7 +66,7 @@ class ListRerank(BaseDocumentCompressor): """LLM wrapper to use for filtering documents.""" get_input: Callable[[str, Document], dict] = default_get_input - """Callable for constructing the chain input from the query and a sequence of Documents.""" + """Callable for constructing the chain input from the query and Documents.""" fallback: bool = False """Whether to fallback to the original document scores if the LLM fails.""" @@ -121,7 +124,8 @@ def _handle_exception( """Return the top documents by original ranking or raise an exception.""" if self.fallback: logger.warning( - "Failed to generate or parse LLM response. Falling back to original scores." + "Failed to generate or parse LLM response. " + "Falling back to original scores." 
) return documents[: self.top_n] else: diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py index a389c85cbed06..8e41d688f0969 100644 --- a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py @@ -16,7 +16,17 @@ def test__list_rerank_success() -> None: llm = FakeLLM( queries={ - query: '```json {"reranked_documents": [{"document_id": 1, "score": 0.99}, {"document_id": 0, "score": 0.95}, {"document_id": 2, "score": 0.50}]}```' + query: """ + ```json + { + "reranked_documents": [ + {"document_id": 1, "score": 0.99}, + {"document_id": 0, "score": 0.95}, + {"document_id": 2, "score": 0.50} + ] + } + ``` + """ }, sequential_responses=True, ) @@ -32,7 +42,17 @@ def test__list_rerank_success() -> None: def test__list_rerank_error() -> None: llm = FakeLLM( queries={ - query: '```json {"reranked_documents": [{"<>": 1, "score": 0.99}, {"document_id": 0, "score": 0.95}, {"document_id": 2, "score": 0.50}]}```' + query: """ + ```json + { + "reranked_documents": [ + {"<>": 1, "score": 0.99}, + {"document_id": 0, "score": 0.95}, + {"document_id": 2, "score": 0.50} + ] + } + ``` + """ }, sequential_responses=True, ) @@ -40,14 +60,24 @@ def test__list_rerank_error() -> None: list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n) with pytest.raises(KeyError) as excinfo: - output_docs = list_rerank.compress_documents(input_docs, query) + list_rerank.compress_documents(input_docs, query) assert "document_id" in str(excinfo.value) def test__list_rerank_fallback() -> None: llm = FakeLLM( queries={ - query: '```json {"reranked_documents": [{"<>": 1, "score": 0.99}, {"document_id": 0, "score": 0.95}, {"document_id": 2, "score": 0.50}]}```' + query: """ + ```json + { + "reranked_documents": [ + {"<>": 1, "score": 0.99}, + {"document_id": 0, "score": 0.95}, + {"document_id": 2, "score": 0.50} + ] + } + ``` + """ }, sequential_responses=True, ) From 240763124c4ce95e1a6189a2fe5196295326b650 Mon Sep 17 00:00:00 2001 From: denver1117 Date: Wed, 29 Nov 2023 18:11:06 -0700 Subject: [PATCH 05/23] sort imports --- .../langchain/retrievers/document_compressors/list_rerank.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py index 2d11b077cedda..9509731208821 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py @@ -1,15 +1,14 @@ """Filter that uses an LLM to rerank documents listwise and select top-k.""" import logging - from typing import Any, Callable, Dict, Optional, Sequence from langchain.callbacks.manager import Callbacks from langchain.chains import LLMChain +from langchain.output_parsers import StructuredOutputParser, ResponseSchema from langchain.prompts import PromptTemplate from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.schema import BasePromptTemplate, Document from langchain.schema.language_model import BaseLanguageModel -from langchain.output_parsers import StructuredOutputParser, ResponseSchema logger = logging.getLogger(__name__) From ba9a055b80bbc3c8a3a2375979e3b248ec1068e3 Mon Sep 17 00:00:00 2001 From: denver1117 Date: Wed, 29 Nov 2023 
18:24:56 -0700 Subject: [PATCH 06/23] fix typo in docs --- .../data_connection/retrievers/contextual_compression/index.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx b/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx index 09091862c3a09..3dae09152b4ef 100644 --- a/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx +++ b/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx @@ -235,7 +235,7 @@ pretty_print_docs(compressed_docs) ```python from langchain.retrievers.document_compressors import ListRerank -_filter = LLMChainFilter.from_llm(llm, top_n=1) +_filter = ListRerank.from_llm(llm, top_n=1) compression_retriever = ContextualCompressionRetriever(base_compressor=_filter, base_retriever=retriever) compressed_docs = compression_retriever.get_relevant_documents("What did the president say about Ketanji Jackson Brown") From 8cb08b27d6143d98a0f3e3d2431bd321caa3b1cd Mon Sep 17 00:00:00 2001 From: denver1117 Date: Thu, 30 Nov 2023 08:19:56 -0700 Subject: [PATCH 07/23] updated docs --- .../retrievers/contextual_compression/index.mdx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx b/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx index 3dae09152b4ef..eef93418ab3ad 100644 --- a/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx +++ b/docs/docs/modules/data_connection/retrievers/contextual_compression/index.mdx @@ -229,12 +229,14 @@ pretty_print_docs(compressed_docs) ### `ListRerank` -`ListRerank` uses [zero-shot listwise document reranking](https://arxiv.org/pdf/2305.02156.pdf) and functions similarly to `LLMChainFilter` as a robust but more expensive option. +`ListRerank` uses [zero-shot listwise document reranking](https://arxiv.org/pdf/2305.02156.pdf) and functions similarly to `LLMChainFilter` as a robust but more expensive option. It is recommended to use a more powerful LLM. 
```python +from langchain.chat_models import ChatOpenAI from langchain.retrievers.document_compressors import ListRerank +llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0) _filter = ListRerank.from_llm(llm, top_n=1) compression_retriever = ContextualCompressionRetriever(base_compressor=_filter, base_retriever=retriever) @@ -267,7 +269,7 @@ print(compressed_docs[0].metadata["relevance_score"]) ``` -0.99 +0.9 ``` From 8ad9bcae2eb7bdbb6384c29463a8b0396dc0f9f4 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 4 Dec 2023 19:57:40 -0800 Subject: [PATCH 08/23] cr --- .../langchain/retrievers/document_compressors/list_rerank.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py index 9509731208821..8fd26c18253ce 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py @@ -4,7 +4,7 @@ from langchain.callbacks.manager import Callbacks from langchain.chains import LLMChain -from langchain.output_parsers import StructuredOutputParser, ResponseSchema +from langchain.output_parsers import ResponseSchema, StructuredOutputParser from langchain.prompts import PromptTemplate from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.schema import BasePromptTemplate, Document From ce4332a0694b87d302bc085ed218504e872aa6b0 Mon Sep 17 00:00:00 2001 From: denver1117 Date: Tue, 5 Dec 2023 09:50:00 -0700 Subject: [PATCH 09/23] fix arg order and callable type --- .../retrievers/document_compressors/list_rerank.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py index 9509731208821..3e629def4e451 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/list_rerank.py @@ -58,13 +58,13 @@ class ListRerank(BaseDocumentCompressor): Source: https://arxiv.org/pdf/2305.02156.pdf """ - top_n: int = 3 - """Number of documents to return.""" - llm_chain: LLMChain """LLM wrapper to use for filtering documents.""" - get_input: Callable[[str, Document], dict] = default_get_input + top_n: int = 3 + """Number of documents to return.""" + + get_input: Callable[[str, Sequence[Document]], dict] = default_get_input """Callable for constructing the chain input from the query and Documents.""" fallback: bool = False From e0afa8a3483b5904183b4cb5563cf7d9374830b8 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 1 Apr 2024 16:16:33 -0700 Subject: [PATCH 10/23] fmt --- .../document_compressors/listwise_rerank.py | 16 ++-- .../document_compressors/test_list_rerank.py | 87 ------------------- .../test_listwise_rerank.py | 10 +++ 3 files changed, 21 insertions(+), 92 deletions(-) delete mode 100644 libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py create mode 100644 libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py diff --git a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py index c7e22f7f31962..a5e4482c2d9c3 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py 
+++ b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class CompressorInput(TypedDict): +class _CompressorInput(TypedDict): documents: Sequence[Document] query: str @@ -25,12 +25,12 @@ class CompressorInput(TypedDict): ) -def _get_prompt_input(input_: CompressorInput) -> Dict[str, Any]: +def _get_prompt_input(input_: dict) -> Dict[str, Any]: """Return the compression chain input.""" documents = input_["documents"] context = "" for index, doc in enumerate(documents): - context += f"Document ID: {index} ```{doc.page_content}```\n" + context += f"Document ID: {index}\n```{doc.page_content}```\n\n" context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]" return {"query": input_["query"], "context": context} @@ -47,12 +47,17 @@ class LLMListwiseRerank(BaseDocumentCompressor): Source: https://arxiv.org/pdf/2305.02156.pdf """ - reranker: Runnable[CompressorInput, List[Document]] - """LLM-based reranker to use for filtering documents.""" + reranker: Runnable[Dict, List[Document]] + """LLM-based reranker to use for filtering documents. Expected to take in a dict + with 'documents: Sequence[Document]' and 'query: str' keys and output a + List[Document].""" top_n: int = 3 """Number of documents to return.""" + class Config: + arbitrary_types_allowed = True + def compress_documents( self, documents: Sequence[Document], @@ -89,6 +94,7 @@ def from_llm( raise ValueError( f"llm of type {type(llm)} does not implement `with_structured_output`." ) + class RankDocuments(BaseModel): """Rank the documents by their relevance to the user question. Rank from most to least relevant.""" diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py deleted file mode 100644 index 708abc971360e..0000000000000 --- a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_list_rerank.py +++ /dev/null @@ -1,87 +0,0 @@ -import pytest - -from langchain.retrievers.document_compressors.listwise_rerank import ListRerank -from langchain.schema import Document -from tests.unit_tests.llms.fake_llm import FakeLLM - -query = "Do you have a pencil?" -top_n = 2 -input_docs = [ - Document(page_content="I have a pen."), - Document(page_content="Do you have a pen?"), - Document(page_content="I have a bag."), -] - - -def test__list_rerank_success() -> None: - llm = FakeLLM( - queries={ - query: """ - ```json - { - "reranked_documents": [ - {"document_id": 1, "score": 0.99}, - {"document_id": 0, "score": 0.95}, - {"document_id": 2, "score": 0.50} - ] - } - ``` - """ - }, - sequential_responses=True, - ) - - list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n) - output_docs = list_rerank.compress_documents(input_docs, query) - - assert len(output_docs) == top_n - assert output_docs[0].metadata["relevance_score"] == 0.99 - assert output_docs[0].page_content == "Do you have a pen?" 
- - -def test__list_rerank_error() -> None: - llm = FakeLLM( - queries={ - query: """ - ```json - { - "reranked_documents": [ - {"<>": 1, "score": 0.99}, - {"document_id": 0, "score": 0.95}, - {"document_id": 2, "score": 0.50} - ] - } - ``` - """ - }, - sequential_responses=True, - ) - - list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n) - - with pytest.raises(KeyError) as excinfo: - list_rerank.compress_documents(input_docs, query) - assert "document_id" in str(excinfo.value) - - -def test__list_rerank_fallback() -> None: - llm = FakeLLM( - queries={ - query: """ - ```json - { - "reranked_documents": [ - {"<>": 1, "score": 0.99}, - {"document_id": 0, "score": 0.95}, - {"document_id": 2, "score": 0.50} - ] - } - ``` - """ - }, - sequential_responses=True, - ) - - list_rerank = ListRerank.from_llm(llm=llm, top_n=top_n, fallback=True) - output_docs = list_rerank.compress_documents(input_docs, query) - assert len(output_docs) == top_n diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py new file mode 100644 index 0000000000000..d21f0f3658488 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py @@ -0,0 +1,10 @@ +import pytest + +from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank + + +@pytest.mark.requires("langchain_openai") +def test__list_rerank_init() -> None: + from langchain_openai import ChatOpenAI + + LLMListwiseRerank.from_llm(llm=ChatOpenAI(api_key="foo"), top_n=10) From aa7470f6844e443998a7fcc0e8d19d98c7fbdb14 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 1 Apr 2024 16:16:46 -0700 Subject: [PATCH 11/23] fmt --- .../test_listwise_rerank.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 libs/langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py new file mode 100644 index 0000000000000..b7e6496dd8b8f --- /dev/null +++ b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py @@ -0,0 +1,22 @@ +from langchain_core.documents import Document + +from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank + + +def test_list_rerank() -> None: + from langchain_openai import ChatOpenAI + + documents = [ + Document("Sally is my friend from school"), + Document("Steve is my friend from home"), + Document("I didn't always like yogurt"), + Document("I wonder why it's called football"), + Document("Where's waldo"), + ] + + reranker = LLMListwiseRerank.from_llm( + llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3 + ) + compressed_docs = reranker.compress_documents(documents, "Who is steve") + assert len(compressed_docs) == 3 + assert "Steve" in compressed_docs[0].page_content From 5eb49759b336259eed97bf3a7ffface69acd04c3 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 1 Apr 2024 16:17:31 -0700 Subject: [PATCH 12/23] poetry --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 727fa3b03b0ae..1b9894a8d2f83 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -164,7 +164,6 @@ optional = true # 
https://python.langchain.com/docs/contributing/code#working-with-optional-dependencies pytest-vcr = "^1.0.2" wrapt = "^1.15.0" -openai = "^1" python-dotenv = "^1.0.0" cassio = "^0.1.0" tiktoken = ">=0.3.2,<0.6.0" @@ -172,6 +171,7 @@ anthropic = "^0.3.11" langchain-core = {path = "../core", develop = true} langchain-community = {path = "../community", develop = true} langchain-text-splitters = {path = "../text-splitters", develop = true} +langchain-openai = {path = "../partners/openai", develop = true} langchainhub = "^0.1.15" [tool.poetry.group.lint] From 48733793ebf383a116bfb18c6dff3686dec970ed Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 1 Apr 2024 16:18:41 -0700 Subject: [PATCH 13/23] poetry --- libs/langchain/poetry.lock | 12 ++++++------ libs/langchain/pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 6fa66329e36b9..08b18d3df7a73 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiodns" @@ -3469,7 +3469,7 @@ files = [ [[package]] name = "langchain-community" -version = "0.0.30" +version = "0.0.31" description = "Community contributed LangChain integrations." optional = false python-versions = ">=3.8.1,<4.0" @@ -3489,7 +3489,7 @@ tenacity = "^8.1.0" [package.extras] cli = ["typer (>=0.9.0,<0.10.0)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt 
(>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "vdms (>=0.0.20,<0.0.21)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "httpx-sse (>=0.4.0,<0.5.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pyjwt (>=2.8.0,<3.0.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "vdms (>=0.0.20,<0.0.21)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] [package.source] type = "directory" @@ -3524,7 +3524,7 @@ url = "../core" name = "langchain-openai" version = "0.0.8" description = "An integration package connecting OpenAI and LangChain" -optional = true +optional = false python-versions = ">=3.8.1,<4.0" files = [ {file = "langchain_openai-0.0.8-py3-none-any.whl", hash = "sha256:4862fc72cecbee0240aaa6df0234d5893dd30cd33ca23ac5cfdd86c11d2c44df"}, @@ -3549,7 +3549,7 @@ develop = true langchain-core = "^0.1.28" [package.extras] -extended-testing = ["lxml 
(>=4.9.3,<6.0)"] +extended-testing = ["beautifulsoup4 (>=4.12.3,<5.0.0)", "lxml (>=4.9.3,<6.0)"] [package.source] type = "directory" @@ -9411,4 +9411,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "d032ef20444c420c7b53af86704bddffe24705bd7c97644dd2e47c9a922dd154" +content-hash = "7fb2e2f955bb010565c2024cf24372cda3aac56ad7e07e5fa994fea897605801" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 1b9894a8d2f83..78b7debe8e5de 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -171,7 +171,7 @@ anthropic = "^0.3.11" langchain-core = {path = "../core", develop = true} langchain-community = {path = "../community", develop = true} langchain-text-splitters = {path = "../text-splitters", develop = true} -langchain-openai = {path = "../partners/openai", develop = true} +langchain-openai = ">=0.0.2,<0.1" langchainhub = "^0.1.15" [tool.poetry.group.lint] From 52da5c1f53ef087c58074e5a5fb73ab6b330d2a5 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 14:36:57 -0400 Subject: [PATCH 14/23] move test --- .../tests/unit_tests/retrievers/document_compressors/__init__.py | 0 .../retrievers/document_compressors/test_listwise_rerank.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/retrievers/document_compressors/__init__.py rename libs/{community => langchain}/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py (100%) diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/__init__.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/community/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py similarity index 100% rename from libs/community/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py rename to libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py From 62b497a633d9c5110ab99e992b9719214848f0d1 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 14:38:55 -0400 Subject: [PATCH 15/23] undo change to lock file --- libs/langchain/poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index e159a85135d61..34fb07edd6c7b 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1786,7 +1786,7 @@ url = "../core" name = "langchain-openai" version = "0.1.16" description = "An integration package connecting OpenAI and LangChain" -optional = false +optional = true python-versions = ">=3.8.1,<4.0" files = [] develop = true From 9abb92b58349f9521da1cca137bbb5aa64d5c887 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 14:45:01 -0400 Subject: [PATCH 16/23] format --- .../langchain/retrievers/document_compressors/listwise_rerank.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py index a5e4482c2d9c3..bd79a7263c48d 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py @@ -1,4 +1,5 @@ """Filter that uses an LLM to rerank 
documents listwise and select top-k.""" + import logging from typing import Any, Dict, List, Optional, Sequence, TypedDict From 1b66387af4a9d30d57e854b41d6e7fa94efe440e Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 15:29:02 -0400 Subject: [PATCH 17/23] expand docstring --- .../document_compressors/listwise_rerank.py | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py index bd79a7263c48d..f1f554d6b5ae1 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py @@ -45,7 +45,36 @@ def _parse_ranking(results: dict) -> List[Document]: class LLMListwiseRerank(BaseDocumentCompressor): """Document compressor that uses `Zero-Shot Listwise Document Reranking`. - Source: https://arxiv.org/pdf/2305.02156.pdf + Adapted from: https://arxiv.org/pdf/2305.02156.pdf + + ``LLMListwiseRerank`` uses a language model to rerank a list of documents based on + their relevance to a query. + + **NOTE**: requires that underlying model implement ``with_structured_output``. + + Example usage: + .. code-block:: python + + from langchain.retrievers.document_compressors.listwise_rerank import ( + LLMListwiseRerank, + ) + from langchain_core.documents import Document + from langchain_openai import ChatOpenAI + + documents = [ + Document("Sally is my friend from school"), + Document("Steve is my friend from home"), + Document("I didn't always like yogurt"), + Document("I wonder why it's called football"), + Document("Where's waldo"), + ] + + reranker = LLMListwiseRerank.from_llm( + llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3 + ) + compressed_docs = reranker.compress_documents(documents, "Who is steve") + assert len(compressed_docs) == 3 + assert "Steve" in compressed_docs[0].page_content """ reranker: Runnable[Dict, List[Document]] From 1efa986e325527f6a85d1d2f0bb5919c4f3b90de Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 15:29:11 -0400 Subject: [PATCH 18/23] typing --- .../retrievers/document_compressors/test_listwise_rerank.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py index d21f0f3658488..d57d9f2cffca8 100644 --- a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_listwise_rerank.py @@ -7,4 +7,7 @@ def test__list_rerank_init() -> None: from langchain_openai import ChatOpenAI - LLMListwiseRerank.from_llm(llm=ChatOpenAI(api_key="foo"), top_n=10) + LLMListwiseRerank.from_llm( + llm=ChatOpenAI(api_key="foo"), # type: ignore[arg-type] + top_n=10, + ) From e1f3e61060887611ef300da0bc135d7af45aa237 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 15:29:47 -0400 Subject: [PATCH 19/23] remove unused typeddict --- .../retrievers/document_compressors/listwise_rerank.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py index f1f554d6b5ae1..b7e301ad6dbd0 100644 --- 
a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py @@ -1,7 +1,6 @@ """Filter that uses an LLM to rerank documents listwise and select top-k.""" -import logging -from typing import Any, Dict, List, Optional, Sequence, TypedDict +from typing import Any, Dict, List, Optional, Sequence from langchain_core.callbacks import Callbacks from langchain_core.documents import BaseDocumentCompressor, Document @@ -10,13 +9,6 @@ from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough -logger = logging.getLogger(__name__) - - -class _CompressorInput(TypedDict): - documents: Sequence[Document] - query: str - _default_system_tmpl = """{context} From d6ff137ec4136a63a83736c58f1810835a79323b Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 15:33:06 -0400 Subject: [PATCH 20/23] format --- .../langchain/retrievers/document_compressors/listwise_rerank.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py index b7e301ad6dbd0..76fdef05b117b 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py @@ -9,7 +9,6 @@ from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough - _default_system_tmpl = """{context} Sort the Documents by their relevance to the Query.""" From ad6b700ad1185bfac7881ccbf4f72fb801e9dde6 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 15:49:28 -0400 Subject: [PATCH 21/23] export LLMListwiseRerank from module --- .../langchain/retrievers/document_compressors/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/langchain/langchain/retrievers/document_compressors/__init__.py b/libs/langchain/langchain/retrievers/document_compressors/__init__.py index 03f66977113d2..de0710bf3b270 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/__init__.py +++ b/libs/langchain/langchain/retrievers/document_compressors/__init__.py @@ -15,6 +15,9 @@ from langchain.retrievers.document_compressors.embeddings_filter import ( EmbeddingsFilter, ) +from langchain.retrievers.document_compressors.listwise_rerank import ( + LLMListwiseRerank, +) _module_lookup = { "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank", @@ -31,6 +34,7 @@ def __getattr__(name: str) -> Any: __all__ = [ "DocumentCompressorPipeline", "EmbeddingsFilter", + "LLMListwiseRerank", "LLMChainExtractor", "LLMChainFilter", "CohereRerank", From 5843510c292e42b9af35b503ede737f74bf6520b Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 15:50:02 -0400 Subject: [PATCH 22/23] update docs --- docs/docs/how_to/contextual_compression.ipynb | 53 ++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/docs/docs/how_to/contextual_compression.ipynb b/docs/docs/how_to/contextual_compression.ipynb index fc2f04724383a..6400b8b8ab2b6 100644 --- a/docs/docs/how_to/contextual_compression.ipynb +++ b/docs/docs/how_to/contextual_compression.ipynb @@ -220,6 +220,57 @@ "pretty_print_docs(compressed_docs)" ] }, + { + "cell_type": "markdown", + "id": "14002ec8-7ee5-4f91-9315-dd21c3808776", + "metadata": {}, + 
"source": [ + "### `LLMListwiseRerank`\n", + "\n", + "[LLMListwiseRerank](https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.document_compressors.chain_filter.LLMListwiseRerank.html) uses [zero-shot listwise document reranking](https://arxiv.org/pdf/2305.02156) and functions similarly to `LLMChainFilter` as a robust but more expensive option. It is recommended to use a more powerful LLM.\n", + "\n", + "Note that `LLMListwiseRerank` requires a model with the [with_structured_output](/docs/integrations/chat/) method implemented." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4ab9ee9f-917e-4d6f-9344-eb7f01533228", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document 1:\n", + "\n", + "Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n", + "\n", + "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n", + "\n", + "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n", + "\n", + "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n" + ] + } + ], + "source": [ + "from langchain.retrievers.document_compressors import LLMListwiseRerank\n", + "from langchain_openai import ChatOpenAI\n", + "\n", + "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n", + "\n", + "_filter = LLMListwiseRerank.from_llm(llm, top_n=1)\n", + "compression_retriever = ContextualCompressionRetriever(\n", + " base_compressor=_filter, base_retriever=retriever\n", + ")\n", + "\n", + "compressed_docs = compression_retriever.invoke(\n", + " \"What did the president say about Ketanji Jackson Brown\"\n", + ")\n", + "pretty_print_docs(compressed_docs)" + ] + }, { "cell_type": "markdown", "id": "7194da42", @@ -295,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "617a1756", "metadata": {}, "outputs": [], From 51d18c8b9e87fe286afffeaaeef5af9eaf419235 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 18 Jul 2024 16:23:11 -0400 Subject: [PATCH 23/23] fix link typo --- docs/docs/how_to/contextual_compression.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/how_to/contextual_compression.ipynb b/docs/docs/how_to/contextual_compression.ipynb index 6400b8b8ab2b6..78ead6e7b94e9 100644 --- a/docs/docs/how_to/contextual_compression.ipynb +++ b/docs/docs/how_to/contextual_compression.ipynb @@ -227,7 +227,7 @@ "source": [ "### `LLMListwiseRerank`\n", "\n", - "[LLMListwiseRerank](https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.document_compressors.chain_filter.LLMListwiseRerank.html) uses [zero-shot listwise document reranking](https://arxiv.org/pdf/2305.02156) and functions similarly to `LLMChainFilter` as a robust but more expensive option. 
It is recommended to use a more powerful LLM.\n", + "[LLMListwiseRerank](https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.document_compressors.listwise_rerank.LLMListwiseRerank.html) uses [zero-shot listwise document reranking](https://arxiv.org/pdf/2305.02156) and functions similarly to `LLMChainFilter` as a robust but more expensive option. It is recommended to use a more powerful LLM.\n", "\n", "Note that `LLMListwiseRerank` requires a model with the [with_structured_output](/docs/integrations/chat/) method implemented." ]
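Taken together, these 23 patches add `LLMListwiseRerank` (originally `ListRerank`) to `langchain.retrievers.document_compressors`, reworked along the way from an `LLMChain` + `StructuredOutputParser` implementation into an LCEL runnable built on `with_structured_output`. A minimal sketch of the resulting API, adapted from the integration test added in PATCH 11/23 and the docstring example from PATCH 17/23; it assumes `langchain-openai` is installed and an OpenAI API key is configured in the environment:

```python
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

from langchain.retrievers.document_compressors import LLMListwiseRerank

documents = [
    Document("Sally is my friend from school"),
    Document("Steve is my friend from home"),
    Document("I didn't always like yogurt"),
    Document("I wonder why it's called football"),
]

# from_llm requires a chat model that implements `with_structured_output`
# and raises a ValueError otherwise (see PATCH 10/23).
reranker = LLMListwiseRerank.from_llm(
    llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=2
)

# Returns at most top_n documents, reranked listwise by relevance to the query.
reranked = reranker.compress_documents(documents, "Who is Steve?")
print([doc.page_content for doc in reranked])
```

The same compressor also plugs into `ContextualCompressionRetriever` as the `base_compressor`, which is the usage shown in the `contextual_compression.ipynb` docs added in PATCH 22/23.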