Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: filters in chroma integration #1072

Merged
merged 18 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
from collections import defaultdict
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional

import chromadb
from chromadb.api.types import GetResult, QueryResult, validate_where, validate_where_document
from chromadb.api.types import GetResult, QueryResult
from haystack import default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.types import DuplicatePolicy

from .errors import ChromaDocumentStoreFilterError
from .filters import _convert_filters
from .utils import get_embedding_function

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -114,83 +113,74 @@ def count_documents(self) -> int:

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Returns the documents that match the filters provided.

Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`,
`"$or"`, `"$not"`), a comparison operator (`"$eq"`, `$ne`, `"$in"`, `$nin`, `"$gt"`, `"$gte"`, `"$lt"`,
`"$lte"`) or a metadata field name.

Logical operator keys take a dictionary of metadata field names and/or logical operators as value. Metadata
field names take a dictionary of comparison operators as value. Comparison operator keys take a single value or
(in case of `"$in"`) a list of values as value. If no logical operator is provided, `"$and"` is used as default
operation. If no comparison operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used
as default operation.

Example:

```python
filters = {
"$and": {
"type": {"$eq": "article"},
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": {"$in": ["economy", "politics"]},
"publisher": {"$eq": "nytimes"}
}
Returns the documents that match the filters provided.

Filters can be provided as a dictionary supporting filtering by ids, metadata, and document content.
Metadata filters should use the `"meta.<metadata_key>"` syntax, while content-based filters
use the `"content"` field directly.
Content filters support the `contains` and `not contains` operators,
while id filters only support the `==` operator.

Due to Chroma's distinction between metadata filters and document filters, filters with `"field": "content"`
(i.e., document content filters) and metadata fields must be supplied separately. For details on chroma filters,
see the [Chroma documentation](https://docs.trychroma.com/guides).

Example:

```python
filter_1 = {
"operator": "AND",
"conditions": [
{"field": "meta.name", "operator": "==", "value": "name_0"},
{"field": "meta.number", "operator": "not in", "value": [2, 9]},
],
}
}
# or simpler using default operators
filters = {
"type": "article",
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": ["economy", "politics"],
"publisher": "nytimes"
filter_2 = {
"operator": "AND",
"conditions": [
{"field": "content", "operator": "contains", "value": "FOO"},
{"field": "content", "operator": "not contains", "value": "BAR"},
],
}
}
```

To use the same logical operator multiple times on the same level, logical operators can take a list of
dictionaries as value.

Example:

```python
filters = {
"$or": [
{
"$and": {
"Type": "News Paper",
"Date": {
"$lt": "2019-01-01"
}
}
},
{
"$and": {
"Type": "Blog Post",
"Date": {
"$gte": "2019-01-01"
}
}
}
]
}
```

:param filters: the filters to apply to the document list.
:returns: a list of Documents that match the given filters.
```

If you need to apply the same logical operator (e.g., "AND", "OR") to multiple conditions at the same level,
you can provide a list of dictionaries as the value for the operator, like in the example below:

```python
filters = {
"operator": "OR",
"conditions": [
{"field": "meta.author", "operator": "==", "value": "author_1"},
{
"operator": "AND",
"conditions": [
{"field": "meta.tag", "operator": "==", "value": "tag_1"},
{"field": "meta.page", "operator": ">", "value": 100},
],
},
{
"operator": "AND",
"conditions": [
{"field": "meta.tag", "operator": "==", "value": "tag_2"},
{"field": "meta.page", "operator": ">", "value": 200},
],
},
],
}
```

:param filters: the filters to apply to the document list.
:returns: a list of Documents that match the given filters.
"""
if filters:
ids, where, where_document = self._normalize_filters(filters)
kwargs: Dict[str, Any] = {"where": where}
chroma_filter = _convert_filters(filters)
kwargs: Dict[str, Any] = {"where": chroma_filter.where}

if ids:
kwargs["ids"] = ids
if where_document:
kwargs["where_document"] = where_document
if chroma_filter.ids:
kwargs["ids"] = chroma_filter.ids
if chroma_filter.where_document:
kwargs["where_document"] = chroma_filter.where_document

result = self._collection.get(**kwargs)
else:
Expand Down Expand Up @@ -285,12 +275,12 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any
include=["embeddings", "documents", "metadatas", "distances"],
)
else:
chroma_filters = self._normalize_filters(filters=filters)
chroma_filters = _convert_filters(filters=filters)
results = self._collection.query(
query_texts=queries,
n_results=top_k,
where=chroma_filters[1],
where_document=chroma_filters[2],
where=chroma_filters.where,
where_document=chroma_filters.where_document,
include=["embeddings", "documents", "metadatas", "distances"],
)

Expand All @@ -316,12 +306,12 @@ def search_embeddings(
include=["embeddings", "documents", "metadatas", "distances"],
)
else:
chroma_filters = self._normalize_filters(filters=filters)
chroma_filters = _convert_filters(filters=filters)
results = self._collection.query(
query_embeddings=query_embeddings,
n_results=top_k,
where=chroma_filters[1],
where_document=chroma_filters[2],
where=chroma_filters.where,
where_document=chroma_filters.where_document,
include=["embeddings", "documents", "metadatas", "distances"],
)

Expand Down Expand Up @@ -355,62 +345,6 @@ def to_dict(self) -> Dict[str, Any]:
**self._embedding_function_params,
)

@staticmethod
def _normalize_filters(filters: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]:
"""
Translate Haystack filters to Chroma filters. It returns three dictionaries, to be
passed to `ids`, `where` and `where_document` respectively.
"""
if not isinstance(filters, dict):
msg = "'filters' parameter must be a dictionary"
raise ChromaDocumentStoreFilterError(msg)

ids = []
where = defaultdict(list)
where_document = defaultdict(list)
keys_to_remove = []

for field, value in filters.items():
if field == "content":
# Schedule for removal the original key, we're going to change it
keys_to_remove.append(field)
where_document["$contains"] = value
elif field == "id":
# Schedule for removal the original key, we're going to change it
keys_to_remove.append(field)
ids.append(value)
elif isinstance(value, (list, tuple)):
# Schedule for removal the original key, we're going to change it
keys_to_remove.append(field)

# if the list is empty the filter is invalid, let's just remove it
if len(value) == 0:
continue

# if the list has a single item, just make it a regular key:value filter pair
if len(value) == 1:
where[field] = value[0]
continue

# if the list contains multiple items, we need an $or chain
for v in value:
where["$or"].append({field: v})

for k in keys_to_remove:
del filters[k]

final_where = dict(filters)
final_where.update(dict(where))
try:
if final_where:
validate_where(final_where)
if where_document:
validate_where_document(where_document)
except ValueError as e:
raise ChromaDocumentStoreFilterError(e) from e

return ids, final_where, where_document

@staticmethod
def _get_result_to_documents(result: GetResult) -> List[Document]:
"""
Expand Down
Loading
Loading