Skip to content

Commit

Permalink
feat: add custom_query to OpenSearch retrievers
Browse files Browse the repository at this point in the history
  • Loading branch information
tstadel committed Jun 21, 2024
1 parent 0d53584 commit 2489722
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ def __init__(
This is useful when comparing documents across different indexes. Defaults to False.
:param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents.
This is useful when searching for short text where even one term can make a difference. Defaults to False.
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder.
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder
**An example custom_query:**
```python
{
"size": 10,
"query": {
"bool": {
"should": [{"multi_match": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,26 @@ def __init__(
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return, defaults to 10
:param custom_query: The query string containing a mandatory `${query_embedding}` and an optional `${filters}` placeholder.
:param custom_query: The query string containing a mandatory `${query_embedding}` and an optional `${filters}`
placeholder
**An example custom_query:**
```python
{
"size": 10,
"query": {
"bool": {
"must": [{"knn": {
"embedding": {
"vector": ${query_embedding}, // mandatory query placeholder
"k": 10000,
}}],
"filter": ${filters} // optional filter placeholder
"must": [
{
"knn": {
"embedding": {
"vector": ${query_embedding}, // mandatory query placeholder
"k": 10000,
}
}
}
],
"filter": ${filters} // optional filter placeholder
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,12 @@ def _bm25_retrieval(
:param top_k: Maximum number of Documents to return, defaults to 10
:param scale_score: If `True` scales the Document's scores between 0 and 1, defaults to False
:param all_terms_must_match: If `True` all terms in `query` must be present in the Document, defaults to False
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder.
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder
**An example custom_query:**
```python
{
"size": 10,
"query": {
"bool": {
"should": [{"multi_match": {
Expand All @@ -330,7 +329,7 @@ def _bm25_retrieval(
body: Dict[str, Any] = {"query": {"bool": {"must": {"match_all": {}}}}}
if filters:
body["query"]["bool"]["filter"] = normalize_filters(filters)

if custom_query:
template = Template(custom_query)
# substitute placeholder for query and filters for the custom_query template string
Expand All @@ -343,7 +342,7 @@ def _bm25_retrieval(

else:
operator = "AND" if all_terms_must_match else "OR"
body: Dict[str, Any] = {
body = {
"query": {
"bool": {
"must": [
Expand Down Expand Up @@ -398,21 +397,25 @@ def _embedding_retrieval(
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return, defaults to 10
:param custom_query: The query string containing a mandatory `${query_embedding}` and an optional `${filters}` placeholder.
:param custom_query: The query string containing a mandatory `${query_embedding}` and an optional `${filters}`
placeholder
**An example custom_query:**
```python
{
"size": 10,
"query": {
"bool": {
"must": [{"knn": {
"embedding": {
"vector": ${query_embedding}, // mandatory query placeholder
"k": 10000,
}}],
"filter": ${filters} // optional filter placeholder
"must": [
{
"knn": {
"embedding": {
"vector": ${query_embedding}, // mandatory query placeholder
"k": 10000,
}
}
}
],
"filter": ${filters} // optional filter placeholder
}
}
}
Expand All @@ -437,7 +440,7 @@ def _embedding_retrieval(
body = json.loads(custom_query_json)

else:
body: Dict[str, Any] = {
body = {
"query": {
"bool": {
"must": [
Expand Down
4 changes: 4 additions & 0 deletions integrations/opensearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_run():
top_k=10,
scale_score=False,
all_terms_must_match=False,
custom_query=None,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand All @@ -105,6 +106,7 @@ def test_run_init_params():
scale_score=True,
top_k=11,
fuzziness="1",
custom_query="some custom query",
)
res = retriever.run(query="some query")
mock_store._bm25_retrieval.assert_called_once_with(
Expand All @@ -114,6 +116,7 @@ def test_run_init_params():
top_k=11,
scale_score=True,
all_terms_must_match=True,
custom_query="some custom query",
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand Down Expand Up @@ -146,6 +149,7 @@ def test_run_time_params():
top_k=9,
scale_score=False,
all_terms_must_match=False,
custom_query=None,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand Down
138 changes: 138 additions & 0 deletions integrations/opensearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,103 @@ def test_bm25_retrieval_with_fuzziness(self, document_store: OpenSearchDocumentS
assert "functional" in res[1].content
assert "functional" in res[2].content

def test_bm25_retrieval_with_custom_query(self, document_store: OpenSearchDocumentStore):
    """
    Check that `_bm25_retrieval` honours a `custom_query` containing `$query` and `$filters` placeholders.

    Eleven documents are written (ids "1" to "11"): five tagged `functional` and six tagged
    `object_oriented`, each with a `likes` value in its meta. The custom query wraps a filtered
    `match` on `content` in a `function_score` with a `field_value_factor` on `likes`.
    With `top_k=3` and a filter on `language_type == "functional"`, the test expects the
    documents with ids "1", "2" and "3", in that order.
    """
    # One row per document: (content, likes, language_type). Ids follow the row order, starting at "1".
    corpus = [
        ("Haskell is a functional programming language", 100000, "functional"),
        ("Lisp is a functional programming language", 10000, "functional"),
        ("Exilir is a functional programming language", 1000, "functional"),
        ("F# is a functional programming language", 100, "functional"),
        ("C# is a functional programming language", 10, "functional"),
        ("C++ is an object oriented programming language", 100000, "object_oriented"),
        ("Dart is an object oriented programming language", 10000, "object_oriented"),
        ("Go is an object oriented programming language", 1000, "object_oriented"),
        ("Python is a object oriented programming language", 100, "object_oriented"),
        ("Ruby is a object oriented programming language", 10, "object_oriented"),
        ("PHP is a object oriented programming language", 1, "object_oriented"),
    ]
    document_store.write_documents(
        [
            Document(
                content=text,
                meta={"likes": likes, "language_type": language_type},
                id=str(position),
            )
            for position, (text, likes, language_type) in enumerate(corpus, start=1)
        ]
    )

    # `$query` and `$filters` are the placeholders substituted by the document store.
    likes_boosted_query = """
    {
        "query": {
            "function_score": {
                "query": {
                    "bool": {
                        "must": {
                            "match": {
                                "content": $query
                            }
                        },
                        "filter": $filters
                    }
                },
                "field_value_factor": {
                    "field": "likes",
                    "factor": 0.1,
                    "modifier": "log1p",
                    "missing": 0
                }
            }
        }
    }
    """
    functional_only = {"field": "language_type", "operator": "==", "value": "functional"}

    hits = document_store._bm25_retrieval(
        "functional",
        top_k=3,
        custom_query=likes_boosted_query,
        filters=functional_only,
    )

    assert len(hits) == 3
    assert [hit.id for hit in hits] == ["1", "2", "3"]

def test_embedding_retrieval(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
docs = [
Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
Expand Down Expand Up @@ -343,6 +440,47 @@ def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: Op
)
assert len(results) == 11

def test_embedding_retrieval_with_custom_query(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
    """
    Check that `_embedding_retrieval` honours a `custom_query` containing `$query_embedding` and
    `$filters` placeholders.

    Three documents with 4-dimensional embeddings are written; only one of them carries
    `meta_field == "custom_value"`. With that filter and `top_k=1`, the test expects exactly
    that document to be returned.
    """
    store = document_store_embedding_dim_4
    store.write_documents(
        [
            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
            Document(
                content="Not very similar document with meta field",
                embedding=[0.0, 0.8, 0.3, 0.9],
                meta={"meta_field": "custom_value"},
            ),
        ]
    )

    # `$query_embedding` and `$filters` are the placeholders substituted by the document store.
    knn_query = """
    {
        "query": {
            "bool": {
                "must": [
                    {
                        "knn": {
                            "embedding": {
                                "vector": $query_embedding,
                                "k": 3
                            }
                        }
                    }
                ],
                "filter": $filters
            }
        }
    }
    """
    meta_filter = {"field": "meta_field", "operator": "==", "value": "custom_value"}

    # The knn clause asks for k=3 (as many neighbours as there are documents) while top_k stays at 1,
    # so the test passes even though we are not sure whether efficient filtering is supported for nmslib.
    # TODO: lower "k" in the custom query once efficient filtering is supported for nmslib
    hits = store._embedding_retrieval(
        query_embedding=[0.1, 0.1, 0.1, 0.1],
        top_k=1,
        filters=meta_filter,
        custom_query=knn_query,
    )

    assert len(hits) == 1
    assert hits[0].content == "Not very similar document with meta field"

def test_embedding_retrieval_query_documents_different_embedding_sizes(
self, document_store_embedding_dim_4: OpenSearchDocumentStore
):
Expand Down
7 changes: 6 additions & 1 deletion integrations/opensearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def test_run():
query_embedding=[0.5, 0.7],
filters={},
top_k=10,
custom_query=None,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand All @@ -106,12 +107,15 @@ def test_run():
def test_run_init_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11)
retriever = OpenSearchEmbeddingRetriever(
document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query"
)
res = retriever.run(query_embedding=[0.5, 0.7])
mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.5, 0.7],
filters={"from": "init"},
top_k=11,
custom_query="custom_query",
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand All @@ -128,6 +132,7 @@ def test_run_time_params():
query_embedding=[0.5, 0.7],
filters={"from": "run"},
top_k=9,
custom_query=None,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand Down

0 comments on commit 2489722

Please sign in to comment.