From 2489722cedafecede1d4916980fa24c114664ecc Mon Sep 17 00:00:00 2001 From: tstadel Date: Fri, 21 Jun 2024 21:00:18 +0200 Subject: [PATCH] feat: add custom_query to OpenSearch retrievers --- .../retrievers/opensearch/bm25_retriever.py | 3 +- .../opensearch/embedding_retriever.py | 21 ++- .../opensearch/document_store.py | 31 ++-- .../opensearch/tests/test_bm25_retriever.py | 4 + .../opensearch/tests/test_document_store.py | 138 ++++++++++++++++++ .../tests/test_embedding_retriever.py | 7 +- 6 files changed, 179 insertions(+), 25 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index c700339bc..37f8774bf 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -32,13 +32,12 @@ def __init__( This is useful when comparing documents across different indexes. Defaults to False. :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. This is useful when searching for short text where even one term can make a difference. Defaults to False. - :param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder. 
+ :param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder **An example custom_query:** ```python { - "size": 10, "query": { "bool": { "should": [{"multi_match": { diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index 0f7012a26..6ea0e2c3c 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -31,21 +31,26 @@ def __init__( :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. :param top_k: Maximum number of Documents to return, defaults to 10 - :param custom_query: The query string containing a mandatory `${query_embedding}` and an optional `${filters}` placeholder. 
+ :param custom_query: The query string containing a mandatory `${query_embedding}` and an optional `${filters}` + placeholder **An example custom_query:** ```python { - "size": 10, "query": { "bool": { - "must": [{"knn": { - "embedding": { - "vector": ${query_embedding}, // mandatory query placeholder - "k": 10000, - }}], - "filter": ${filters} // optional filter placeholder + "must": [ + { + "knn": { + "embedding": { + "vector": ${query_embedding}, // mandatory query placeholder + "k": 10000, + } + } + } + ], + "filter": ${filters} // optional filter placeholder } } } diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 69edb7143..dbcf4594e 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -303,13 +303,12 @@ def _bm25_retrieval( :param top_k: Maximum number of Documents to return, defaults to 10 :param scale_score: If `True` scales the Document`s scores between 0 and 1, defaults to False :param all_terms_must_match: If `True` all terms in `query` must be present in the Document, defaults to False - :param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder. 
+ :param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder **An example custom_query:** ```python { - "size": 10, "query": { "bool": { "should": [{"multi_match": { @@ -330,7 +329,7 @@ def _bm25_retrieval( body: Dict[str, Any] = {"query": {"bool": {"must": {"match_all": {}}}}} if filters: body["query"]["bool"]["filter"] = normalize_filters(filters) - + if custom_query: template = Template(custom_query) # substitute placeholder for query and filters for the custom_query template string @@ -343,7 +342,7 @@ def _bm25_retrieval( else: operator = "AND" if all_terms_must_match else "OR" - body: Dict[str, Any] = { + body = { "query": { "bool": { "must": [ @@ -398,21 +397,25 @@ def _embedding_retrieval( :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. :param top_k: Maximum number of Documents to return, defaults to 10 - :param custom_query: The query string containing a mandatory `${query_embedding}` and an optional `${filters}` placeholder. 
+ :param custom_query: The query string containing a mandatory `${query_embedding}` and an optional `${filters}` + placeholder **An example custom_query:** - ```python { - "size": 10, "query": { "bool": { - "must": [{"knn": { - "embedding": { - "vector": ${query_embedding}, // mandatory query placeholder - "k": 10000, - }}], - "filter": ${filters} // optional filter placeholder + "must": [ + { + "knn": { + "embedding": { + "vector": ${query_embedding}, // mandatory query placeholder + "k": 10000, + } + } + } + ], + "filter": ${filters} // optional filter placeholder } } } @@ -437,7 +440,7 @@ def _embedding_retrieval( body = json.loads(custom_query_json) else: - body: Dict[str, Any] = { + body = { "query": { "bool": { "must": [ diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index 71fc19c6a..e682a4614 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -89,6 +89,7 @@ def test_run(): top_k=10, scale_score=False, all_terms_must_match=False, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -105,6 +106,7 @@ def test_run_init_params(): scale_score=True, top_k=11, fuzziness="1", + custom_query="some custom query", ) res = retriever.run(query="some query") mock_store._bm25_retrieval.assert_called_once_with( @@ -114,6 +116,7 @@ def test_run_init_params(): top_k=11, scale_score=True, all_terms_must_match=True, + custom_query="some custom query", ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -146,6 +149,7 @@ def test_run_time_params(): top_k=9, scale_score=False, all_terms_must_match=False, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index af8fd8e25..fec12df0d 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ 
b/integrations/opensearch/tests/test_document_store.py @@ -292,6 +292,103 @@ def test_bm25_retrieval_with_fuzziness(self, document_store: OpenSearchDocumentS assert "functional" in res[1].content assert "functional" in res[2].content + def test_bm25_retrieval_with_custom_query(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document( + content="Haskell is a functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": "object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + content="Ruby is a object oriented programming language", + meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + + custom_query = """ + { + "query": { + "function_score": 
{ + "query": { + "bool": { + "must": { + "match": { + "content": $query + } + }, + "filter": $filters + } + }, + "field_value_factor": { + "field": "likes", + "factor": 0.1, + "modifier": "log1p", + "missing": 0 + } + } + } + } + """ + + res = document_store._bm25_retrieval( + "functional", + top_k=3, + custom_query=custom_query, + filters={"field": "language_type", "operator": "==", "value": "functional"}, + ) + assert len(res) == 3 + assert "1" == res[0].id + assert "2" == res[1].id + assert "3" == res[2].id + def test_embedding_retrieval(self, document_store_embedding_dim_4: OpenSearchDocumentStore): docs = [ Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), @@ -343,6 +440,47 @@ def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: Op ) assert len(results) == 11 + def test_embedding_retrieval_with_custom_query(self, document_store_embedding_dim_4: OpenSearchDocumentStore): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + custom_query = """ + { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": $query_embedding, + "k": 3 + } + } + } + ], + "filter": $filters + } + } + } + """ + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + # "k": 3 in custom_query lets the test pass, as efficient filtering may not be supported for nmslib + # TODO: revisit "k": 3, when efficient filtering is supported for nmslib + results = document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters=filters, custom_query=custom_query + ) + assert len(results) == 1 + assert results[0].content == "Not very similar 
document with meta field" + def test_embedding_retrieval_query_documents_different_embedding_sizes( self, document_store_embedding_dim_4: OpenSearchDocumentStore ): diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index c1015ca33..5be38699c 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -96,6 +96,7 @@ def test_run(): query_embedding=[0.5, 0.7], filters={}, top_k=10, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -106,12 +107,15 @@ def test_run(): def test_run_init_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] - retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11) + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query" + ) res = retriever.run(query_embedding=[0.5, 0.7]) mock_store._embedding_retrieval.assert_called_once_with( query_embedding=[0.5, 0.7], filters={"from": "init"}, top_k=11, + custom_query="custom_query", ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -128,6 +132,7 @@ def test_run_time_params(): query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1