forked from langchain-ai/langchain
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Description: Add a BM25 retriever that does not need Elasticsearch
- Dependencies: rank_bm25 (if it is not installed, the user is asked to install it with pip, just like TFIDFRetriever does)
- Tag maintainer: @rlancemartin, @eyurtsev
- Twitter handle: DayuanJian21687

---------

Co-authored-by: Bagatur <[email protected]>
1 parent fa0a9e5, commit ee40d37
Showing 6 changed files with 323 additions and 2 deletions.
175 changes: 175 additions & 0 deletions
docs/extras/modules/data_connection/retrievers/integrations/bm25.ipynb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ab66dd43", | ||
"metadata": {}, | ||
"source": [ | ||
"# BM25\n", | ||
"\n", | ||
"[BM25](https://en.wikipedia.org/wiki/Okapi_BM25), also known as Okapi BM25, is a ranking function used in information retrieval systems to estimate the relevance of documents to a given search query.\n", | ||
"\n", | ||
"This notebook shows how to use a retriever that uses BM25 under the hood via the [`rank_bm25`](https://github.com/dorianbrown/rank_bm25) package.\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a801b57c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# !pip install rank_bm25" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "393ac030", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/workspaces/langchain/.venv/lib/python3.10/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.10) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n", | ||
" warnings.warn(\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from langchain.retrievers import BM25Retriever" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "aaf80e7f", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create a New Retriever with Texts" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "98b1c017", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"retriever = BM25Retriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c016b266", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create a New Retriever with Documents\n", | ||
"\n", | ||
"You can also create a new retriever from a list of `Document` objects." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "53af4f00", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.schema import Document\n", | ||
"\n", | ||
"retriever = BM25Retriever.from_documents(\n", | ||
" [\n", | ||
" Document(page_content=\"foo\"),\n", | ||
" Document(page_content=\"bar\"),\n", | ||
" Document(page_content=\"world\"),\n", | ||
" Document(page_content=\"hello\"),\n", | ||
" Document(page_content=\"foo bar\"),\n", | ||
" ]\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "08437fa2", | ||
"metadata": {}, | ||
"source": [ | ||
"## Use Retriever\n", | ||
"\n", | ||
"We can now use the retriever!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "c0455218", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"result = retriever.get_relevant_documents(\"foo\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "7dfa5c29", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[Document(page_content='foo', metadata={}),\n", | ||
" Document(page_content='foo bar', metadata={}),\n", | ||
" Document(page_content='hello', metadata={}),\n", | ||
" Document(page_content='world', metadata={})]" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"result" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "997aaa8d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
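The notebook describes BM25 as a ranking function but does not show how a score is formed. The following is a minimal, standard-library sketch of Okapi BM25 scoring, not the code the retriever runs. The function name `bm25_scores`, the defaults `k1=1.5` and `b=0.75`, and the non-negative IDF variant are assumptions chosen for illustration; `rank_bm25`'s `BM25Okapi` may compute IDF differently.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(
    query: List[str], corpus: List[List[str]], k1: float = 1.5, b: float = 0.75
) -> List[float]:
    """Score every tokenized document in `corpus` against a tokenized `query`."""
    n_docs = len(corpus)
    avgdl = sum(len(doc) for doc in corpus) / n_docs
    # Document frequency: number of documents that contain each term.
    df = Counter(term for doc in corpus for term in set(doc))
    scores = []
    for doc in corpus:
        tf = Counter(doc)
        score = 0.0
        for term in query:
            if term not in tf:
                continue
            # Assumed IDF variant: always non-negative.
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # Length normalization: longer documents are penalized via `b`.
            norm = tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


# Same toy corpus as the notebook, tokenized on whitespace.
corpus = [text.split() for text in ["foo", "bar", "world", "hello", "foo bar"]]
scores = bm25_scores("foo".split(), corpus)
# Document indices ordered from most to least relevant.
ranked = sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)
```

With this sketch, "foo" outranks "foo bar" for the query "foo" because the shorter document is penalized less by length normalization, which agrees with the order of the first two results in the notebook output.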
@@ -0,0 +1,86 @@
"""
BM25 Retriever without elastic search
"""


from __future__ import annotations

from typing import Any, Callable, Dict, Iterable, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document


def default_preprocessing_func(text: str) -> List[str]:
    return text.split()


class BM25Retriever(BaseRetriever):
    vectorizer: Any
    docs: List[Document]
    k: int = 4
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            raise ImportError(
                "Could not import rank_bm25, please install with `pip install "
                "rank_bm25`."
            )

        texts_processed = [preprocess_func(t) for t in texts]
        bm25_params = bm25_params or {}
        vectorizer = BM25Okapi(texts_processed, **bm25_params)
        metadatas = metadatas or ({} for _ in texts)
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        return cls(
            vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
        return cls.from_texts(
            texts=texts,
            bm25_params=bm25_params,
            metadatas=metadatas,
            preprocess_func=preprocess_func,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        processed_query = self.preprocess_func(query)
        return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
        return return_docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        raise NotImplementedError
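One detail of `from_texts` above is worth a note: it iterates `texts` twice, once to build `texts_processed` and again inside `zip(texts, metadatas)`. That is harmless for a list or tuple, but the `Iterable[str]` annotation also admits one-shot iterators such as generators, which are exhausted after the first pass. The sketch below reproduces the pattern without any LangChain dependency. The helper names `build_pairs_two_pass` and `build_pairs_materialized` are hypothetical, and the guard shown is one possible approach, not something this commit contains.

```python
from typing import Iterable, Iterator, List, Tuple


def build_pairs_two_pass(texts: Iterable[str]) -> Tuple[List[List[str]], List[str]]:
    # Mirrors the two passes in from_texts: tokenize first, then collect contents.
    tokenized = [t.split() for t in texts]
    contents = [t for t in texts]  # empty if `texts` was a one-shot iterator
    return tokenized, contents


def build_pairs_materialized(texts: Iterable[str]) -> Tuple[List[List[str]], List[str]]:
    # Possible guard: copy the iterable into a list before the first pass.
    texts = list(texts)
    tokenized = [t.split() for t in texts]
    contents = [t for t in texts]
    return tokenized, contents


def sample() -> Iterator[str]:
    # A generator stands in for any one-shot iterable of texts.
    yield "foo"
    yield "foo bar"


lost = build_pairs_two_pass(sample())
kept = build_pairs_materialized(sample())
```

Here `lost` holds the tokenized texts but an empty contents list, while `kept` holds both. `from_documents` is not affected in the same way, because it passes the tuple produced by `zip(*...)` to `from_texts`.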
@@ -0,0 +1,34 @@
import pytest

from langchain.retrievers.bm25 import BM25Retriever
from langchain.schema import Document


@pytest.mark.requires("rank_bm25")
def test_from_texts() -> None:
    input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
    bm25_retriever = BM25Retriever.from_texts(texts=input_texts)
    assert len(bm25_retriever.docs) == 3
    assert bm25_retriever.vectorizer.doc_len == [4, 5, 4]


@pytest.mark.requires("rank_bm25")
def test_from_texts_with_bm25_params() -> None:
    input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
    bm25_retriever = BM25Retriever.from_texts(
        texts=input_texts, bm25_params={"epsilon": 10}
    )
    # should count only multiple words (have, pan)
    assert bm25_retriever.vectorizer.epsilon == 10


@pytest.mark.requires("rank_bm25")
def test_from_documents() -> None:
    input_docs = [
        Document(page_content="I have a pen."),
        Document(page_content="Do you have a pen?"),
        Document(page_content="I have a bag."),
    ]
    bm25_retriever = BM25Retriever.from_documents(documents=input_docs)
    assert len(bm25_retriever.docs) == 3
    assert bm25_retriever.vectorizer.doc_len == [4, 5, 4]
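The expected `doc_len` of `[4, 5, 4]` in these tests follows from the default whitespace tokenizer, which leaves punctuation attached to words (so "pen." and "pen?" count as different tokens). The sketch below checks that arithmetic with the standard library only. `lowercase_alnum_tokens` is a hypothetical alternative that could be passed as `preprocess_func`; it is an illustration and not part of the commit.

```python
import re
from typing import List


def default_preprocessing_func(text: str) -> List[str]:
    # Same behavior as the function in the diff: split on whitespace only.
    return text.split()


def lowercase_alnum_tokens(text: str) -> List[str]:
    # Hypothetical alternative: lowercase, then keep runs of letters and digits.
    return re.findall(r"[a-z0-9]+", text.lower())


input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
# Token counts under the default tokenizer, as the tests expect.
doc_len = [len(default_preprocessing_func(t)) for t in input_texts]
default_tokens = default_preprocessing_func(input_texts[1])
clean_tokens = lowercase_alnum_tokens(input_texts[1])
```

Under the default tokenizer the last token of the second text is `"pen?"`, whereas the alternative yields `"pen"`, so a query for "pen" would match all documents that mention it. Whatever function is chosen is applied to both the indexed texts and the query, since `_get_relevant_documents` calls `self.preprocess_func(query)`.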