Skip to content

Commit

Permalink
feat: efficient knn filtering support for OpenSearch (#1134)
Browse files Browse the repository at this point in the history
* feat: efficient filtering support for OpenSearch

* add hint about supported knn engines to docstring

* Apply suggestions from code review
  • Loading branch information
tstadel authored Oct 29, 2024
1 parent 6cc39e6 commit 701a790
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
custom_query: Optional[Dict[str, Any]] = None,
raise_on_failure: bool = True,
efficient_filtering: bool = False,
):
"""
Create the OpenSearchEmbeddingRetriever component.
Expand Down Expand Up @@ -85,6 +86,8 @@ def __init__(
:param raise_on_failure:
If `True`, raises an exception if the API call fails.
If `False`, logs a warning and returns an empty list.
:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".
:raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore.
"""
Expand All @@ -100,6 +103,7 @@ def __init__(
)
self._custom_query = custom_query
self._raise_on_failure = raise_on_failure
self._efficient_filtering = efficient_filtering

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -116,6 +120,7 @@ def to_dict(self) -> Dict[str, Any]:
filter_policy=self._filter_policy.value,
custom_query=self._custom_query,
raise_on_failure=self._raise_on_failure,
efficient_filtering=self._efficient_filtering,
)

@classmethod
Expand Down Expand Up @@ -146,6 +151,7 @@ def run(
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
custom_query: Optional[Dict[str, Any]] = None,
efficient_filtering: Optional[bool] = None,
):
"""
Retrieve documents using a vector similarity metric.
Expand Down Expand Up @@ -196,6 +202,9 @@ def run(
)
```
:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".
:returns:
Dictionary with key "documents" containing the retrieved Documents.
- documents: List of Document similar to `query_embedding`.
Expand All @@ -208,6 +217,8 @@ def run(
top_k = self._top_k
if custom_query is None:
custom_query = self._custom_query
if efficient_filtering is None:
efficient_filtering = self._efficient_filtering

docs: List[Document] = []

Expand All @@ -217,6 +228,7 @@ def run(
filters=filters,
top_k=top_k,
custom_query=custom_query,
efficient_filtering=efficient_filtering,
)
except Exception as e:
if self._raise_on_failure:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def _embedding_retrieval(
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
custom_query: Optional[Dict[str, Any]] = None,
efficient_filtering: bool = False,
) -> List[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
Expand Down Expand Up @@ -474,6 +475,8 @@ def _embedding_retrieval(
}
```
:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".
:raises ValueError: If `query_embedding` is an empty list
:returns: List of Document that are most similar to `query_embedding`
"""
Expand Down Expand Up @@ -509,7 +512,10 @@ def _embedding_retrieval(
}

if filters:
body["query"]["bool"]["filter"] = normalize_filters(filters)
if efficient_filtering:
body["query"]["bool"]["must"][0]["knn"]["embedding"]["filter"] = normalize_filters(filters)
else:
body["query"]["bool"]["filter"] = normalize_filters(filters)

body["size"] = top_k

Expand Down
44 changes: 44 additions & 0 deletions integrations/opensearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,27 @@ def document_store_embedding_dim_4(self, request):
yield store
store.client.indices.delete(index=index, params={"ignore": [400, 404]})

@pytest.fixture
def document_store_embedding_dim_4_faiss(self, request):
"""
This is the most basic requirement for the child class: provide
an instance of this document store so the base class can use it.
"""
hosts = ["https://localhost:9200"]
# Use a different index for each test so we can run them in parallel
index = f"{request.node.name}"

store = OpenSearchDocumentStore(
hosts=hosts,
index=index,
http_auth=("admin", "admin"),
verify_certs=False,
embedding_dim=4,
method={"space_type": "innerproduct", "engine": "faiss", "name": "hnsw"},
)
yield store
store.client.indices.delete(index=index, params={"ignore": [400, 404]})

def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
"""
The OpenSearchDocumentStore.filter_documents() method returns a Documents with their score set.
Expand Down Expand Up @@ -690,6 +711,29 @@ def test_embedding_retrieval_with_filters(self, document_store_embedding_dim_4:
assert len(results) == 1
assert results[0].content == "Not very similar document with meta field"

def test_embedding_retrieval_with_filters_efficient_filtering(
self, document_store_embedding_dim_4_faiss: OpenSearchDocumentStore
):
docs = [
Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
Document(
content="Not very similar document with meta field",
embedding=[0.0, 0.8, 0.3, 0.9],
meta={"meta_field": "custom_value"},
),
]
document_store_embedding_dim_4_faiss.write_documents(docs)

filters = {"field": "meta_field", "operator": "==", "value": "custom_value"}
results = document_store_embedding_dim_4_faiss._embedding_retrieval(
query_embedding=[0.1, 0.1, 0.1, 0.1],
filters=filters,
efficient_filtering=True,
)
assert len(results) == 1
assert results[0].content == "Not very similar document with meta field"

def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
"""
Test that handling of pagination works as expected, when the matching documents are > 10.
Expand Down
15 changes: 13 additions & 2 deletions integrations/opensearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_init_default():
assert retriever._filters == {}
assert retriever._top_k == 10
assert retriever._filter_policy == FilterPolicy.REPLACE
assert retriever._efficient_filtering is False

retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace")
assert retriever._filter_policy == FilterPolicy.REPLACE
Expand Down Expand Up @@ -82,6 +83,7 @@ def test_to_dict(_mock_opensearch_client):
"filter_policy": "replace",
"custom_query": {"some": "custom query"},
"raise_on_failure": True,
"efficient_filtering": False,
},
}

Expand All @@ -101,6 +103,7 @@ def test_from_dict(_mock_opensearch_client):
"filter_policy": "replace",
"custom_query": {"some": "custom query"},
"raise_on_failure": False,
"efficient_filtering": True,
},
}
retriever = OpenSearchEmbeddingRetriever.from_dict(data)
Expand All @@ -110,6 +113,7 @@ def test_from_dict(_mock_opensearch_client):
assert retriever._custom_query == {"some": "custom query"}
assert retriever._raise_on_failure is False
assert retriever._filter_policy == FilterPolicy.REPLACE
assert retriever._efficient_filtering is True

# For backwards compatibility with older versions of the retriever without a filter policy
data = {
Expand Down Expand Up @@ -139,6 +143,7 @@ def test_run():
filters={},
top_k=10,
custom_query=None,
efficient_filtering=False,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand All @@ -150,14 +155,19 @@ def test_run_init_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
retriever = OpenSearchEmbeddingRetriever(
document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query"
document_store=mock_store,
filters={"from": "init"},
top_k=11,
custom_query="custom_query",
efficient_filtering=True,
)
res = retriever.run(query_embedding=[0.5, 0.7])
mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.5, 0.7],
filters={"from": "init"},
top_k=11,
custom_query="custom_query",
efficient_filtering=True,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand All @@ -169,12 +179,13 @@ def test_run_time_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11)
res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9)
res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, efficient_filtering=True)
mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.5, 0.7],
filters={"from": "run"},
top_k=9,
custom_query=None,
efficient_filtering=True,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand Down

0 comments on commit 701a790

Please sign in to comment.