Skip to content

Commit

Permalink
feat: add custom_query param to OpenSearch retrievers (#841)
Browse files Browse the repository at this point in the history
* feat: add custom_query param to OpenSearch retrievers

* feat: add custom_query to OpenSearch retrievers

* add as run param

* fix lint

* switch to jinja2 templates

* Revert "switch to jinja2 templates"

This reverts commit f36ed13.

* support custom_query as dict

* remove unnecessary comments

* remove str

* fix lint
  • Loading branch information
tstadel authored Jun 25, 2024
1 parent be09adf commit 69c29a9
Show file tree
Hide file tree
Showing 6 changed files with 357 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
top_k: int = 10,
scale_score: bool = False,
all_terms_must_match: bool = False,
custom_query: Optional[Dict[str, Any]] = None,
):
"""
Create the OpenSearchBM25Retriever component.
Expand All @@ -31,6 +32,31 @@ def __init__(
This is useful when comparing documents across different indexes. Defaults to False.
:param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents.
This is useful when searching for short text where even one term can make a difference. Defaults to False.
:param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder
**An example custom_query:**
```python
{
"query": {
"bool": {
"should": [{"multi_match": {
"query": "$query",  # mandatory query placeholder
"type": "most_fields",
"fields": ["content", "title"]}}],
"filter": "$filters"  # optional filter placeholder
}
}
}
```
**For this custom_query, a sample `run()` could be:**
```python
retriever.run(query="Why did the revenue increase?",
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```
:raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore.
"""
Expand All @@ -44,6 +70,7 @@ def __init__(
self._top_k = top_k
self._scale_score = scale_score
self._all_terms_must_match = all_terms_must_match
self._custom_query = custom_query

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -86,6 +113,7 @@ def run(
top_k: Optional[int] = None,
fuzziness: Optional[str] = None,
scale_score: Optional[bool] = None,
custom_query: Optional[Dict[str, Any]] = None,
):
"""
Retrieve documents using BM25 retrieval.
Expand All @@ -97,6 +125,31 @@ def run(
:param fuzziness: Fuzziness parameter for full-text queries.
:param scale_score: Whether to scale the score of retrieved documents between 0 and 1.
This is useful when comparing documents across different indexes.
:param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder
**An example custom_query:**
```python
{
"query": {
"bool": {
"should": [{"multi_match": {
"query": "$query",  # mandatory query placeholder
"type": "most_fields",
"fields": ["content", "title"]}}],
"filter": "$filters"  # optional filter placeholder
}
}
}
```
**For this custom_query, a sample `run()` could be:**
```python
retriever.run(query="Why did the revenue increase?",
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```
:returns:
A dictionary containing the retrieved documents with the following structure:
Expand All @@ -113,6 +166,8 @@ def run(
fuzziness = self._fuzziness
if scale_score is None:
scale_score = self._scale_score
if custom_query is None:
custom_query = self._custom_query

docs = self._document_store._bm25_retrieval(
query=query,
Expand All @@ -121,5 +176,6 @@ def run(
top_k=top_k,
scale_score=scale_score,
all_terms_must_match=all_terms_must_match,
custom_query=custom_query,
)
return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
document_store: OpenSearchDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
custom_query: Optional[Dict[str, Any]] = None,
):
"""
Create the OpenSearchEmbeddingRetriever component.
Expand All @@ -30,6 +31,37 @@ def __init__(
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return, defaults to 10
:param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder
**An example custom_query:**
```python
{
"query": {
"bool": {
"must": [
{
"knn": {
"embedding": {
"vector": "$query_embedding",  # mandatory query placeholder
"k": 10000,
}
}
}
],
"filter": "$filters"  # optional filter placeholder
}
}
}
```
**For this custom_query, a sample `run()` could be:**
```python
retriever.run(query_embedding=embedding,
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```
:raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore.
"""
if not isinstance(document_store, OpenSearchDocumentStore):
Expand All @@ -39,6 +71,7 @@ def __init__(
self._document_store = document_store
self._filters = filters or {}
self._top_k = top_k
self._custom_query = custom_query

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -71,13 +104,50 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
def run(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
custom_query: Optional[Dict[str, Any]] = None,
):
"""
Retrieve documents using a vector similarity metric.
:param query_embedding: Embedding of the query.
:param filters: Optional filters to narrow down the search space.
:param top_k: Maximum number of Documents to return.
:param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder
**An example custom_query:**
```python
{
"query": {
"bool": {
"must": [
{
"knn": {
"embedding": {
"vector": "$query_embedding",  # mandatory query placeholder
"k": 10000,
}
}
}
],
"filter": "$filters"  # optional filter placeholder
}
}
}
```
**For this custom_query, a sample `run()` could be:**
```python
retriever.run(query_embedding=embedding,
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```
:returns:
Dictionary with key "documents" containing the retrieved Documents.
- documents: List of Document similar to `query_embedding`.
Expand All @@ -86,10 +156,13 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] =
filters = self._filters
if top_k is None:
top_k = self._top_k
if custom_query is None:
custom_query = self._custom_query

docs = self._document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=filters,
top_k=top_k,
custom_query=custom_query,
)
return {"documents": docs}
Loading

0 comments on commit 69c29a9

Please sign in to comment.