From 1c6410e3f2f1c8758285df17f74e826c66050a42 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Mon, 20 Nov 2023 17:26:03 +0100 Subject: [PATCH] Elasticsearch - refactor `_search_documents` (#57) * set scale_score default to False * unrelated: replace text w content * first implementation * test * fix some tests * make tests more robust; skip unsupported ones * rm unsupported test * ignore import-not-found * first chunk addressing PR feedback * improve tests * use _search_documents also in bm25 retrieval * improve logic and tests * fix format * better format * Update document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * remove wrong increment * move ruff ignore error --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> --- .../elasticsearch_haystack/document_store.py | 21 ++++++--- .../tests/test_document_store.py | 44 +++++++++++++++++++ 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py index 083918d71..4d1903e9f 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py @@ -106,6 +106,10 @@ def _search_documents(self, **kwargs) -> List[Document]: Calls the Elasticsearch client's search method and handles pagination. """ + top_k = kwargs.get("size") + if top_k is None and "knn" in kwargs and "k" in kwargs["knn"]: + top_k = kwargs["knn"]["k"] + documents: List[Document] = [] from_ = 0 # Handle pagination @@ -115,8 +119,12 @@ def _search_documents(self, **kwargs) -> List[Document]: from_=from_, **kwargs, ) + documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"]) from_ = len(documents) + + if top_k is not None and from_ >= top_k: + break if from_ >= res["hits"]["total"]["value"]: break return documents @@ -326,14 +334,13 @@ def _bm25_retrieval( if filters: body["query"]["bool"]["filter"] = _normalize_filters(filters) - res = self._client.search(index=self._index, **body) + documents = self._search_documents(**body) - docs = [] - for hit in res["hits"]["hits"]: - if scale_score: - hit["_score"] = float(1 / (1 + np.exp(-np.asarray(hit["_score"] / BM25_SCALING_FACTOR)))) - docs.append(self._deserialize_document(hit)) - return docs + if scale_score: + for doc in documents: + doc.score = float(1 / (1 + np.exp(-np.asarray(doc.score / BM25_SCALING_FACTOR)))) + + return documents def _embedding_retrieval( self, diff --git a/document_stores/elasticsearch/tests/test_document_store.py b/document_stores/elasticsearch/tests/test_document_store.py index 11443546c..1e7b3f115 100644 --- a/document_stores/elasticsearch/tests/test_document_store.py +++ b/document_stores/elasticsearch/tests/test_document_store.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2023-present Silvano Cerza # # SPDX-License-Identifier: Apache-2.0 + +import random from typing import List from unittest.mock import patch @@ -92,6 +94,34 @@ def test_bm25_retrieval(self, docstore: ElasticsearchDocumentStore): assert "functional" in res[1].content assert "functional" in res[2].content + def test_bm25_retrieval_pagination(self, docstore: ElasticsearchDocumentStore): + """ + Test that handling of pagination works as expected, when the matching documents are > 10. + """ + docstore.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + Document(content="Java is an object oriented programming language"), + Document(content="Javascript is a programming language"), + Document(content="Typescript is a programming language"), + Document(content="C is a programming language"), + ] + ) + + res = docstore._bm25_retrieval("programming", top_k=11) + assert len(res) == 11 + assert all("programming" in doc.content for doc in res) + def test_bm25_retrieval_with_fuzziness(self, docstore: ElasticsearchDocumentStore): docstore.write_documents( [ @@ -282,6 +312,20 @@ def test_embedding_retrieval_w_filters(self, docstore: ElasticsearchDocumentStor assert len(results) == 1 assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_pagination(self, docstore: ElasticsearchDocumentStore): + """ + Test that handling of pagination works as expected, when the matching documents are > 10. + """ + + docs = [ + Document(content=f"Document {i}", embedding=[random.random() for _ in range(4)]) # noqa: S311 + for i in range(20) + ] + + docstore.write_documents(docs) + results = docstore._embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=11, filters={}) + assert len(results) == 11 + def test_embedding_retrieval_query_documents_different_embedding_sizes(self, docstore: ElasticsearchDocumentStore): """ Test that the retrieval fails if the query embedding and the documents have different embedding sizes.