Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: filters in chroma integration #1072

Merged
merged 18 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
from collections import defaultdict
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Union

import chromadb
from chromadb.api.types import GetResult, QueryResult, validate_where, validate_where_document
from chromadb.api.types import GetResult, QueryResult
from haystack import default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.types import DuplicatePolicy

from .errors import ChromaDocumentStoreFilterError
from .filters import _convert_filters
from .utils import get_embedding_function

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -112,79 +111,73 @@ def count_documents(self) -> int:
"""
return self._collection.count()

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
def filter_documents(self, filters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None) -> List[Document]:
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns the documents that match the filters provided.

Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`,
`"$or"`, `"$not"`), a comparison operator (`"$eq"`, `$ne`, `"$in"`, `$nin`, `"$gt"`, `"$gte"`, `"$lt"`,
`"$lte"`) or a metadata field name.

Logical operator keys take a dictionary of metadata field names and/or logical operators as value. Metadata
field names take a dictionary of comparison operators as value. Comparison operator keys take a single value or
(in case of `"$in"`) a list of values as value. If no logical operator is provided, `"$and"` is used as default
operation. If no comparison operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used
as default operation.

Example:

```python
filters = {
"$and": {
"type": {"$eq": "article"},
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": {"$in": ["economy", "politics"]},
"publisher": {"$eq": "nytimes"}
}
}
}
# or simpler using default operators
filters = {
"type": "article",
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": ["economy", "politics"],
"publisher": "nytimes"
}
}
```

To use the same logical operator multiple times on the same level, logical operators can take a list of
dictionaries as value.

Example:

```python
filters = {
"$or": [
{
"$and": {
"Type": "News Paper",
"Date": {
"$lt": "2019-01-01"
}
}
},
{
"$and": {
"Type": "Blog Post",
"Date": {
"$gte": "2019-01-01"
}
}
}
]
}
```

:param filters: the filters to apply to the document list.
:returns: a list of Documents that match the given filters.
Returns the documents that match the filters provided.

Filters can be provided as a dictionary or a list of dictionaries, supporting filtering by
ids, metadata, and document content.
Metadata filters should use fields like `"meta.name"`, while content-based filters
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved
use the `"content"` field directly.
Content filters support the `contains` and `not contains` operators,
while id filters only support the `==` operator.

Due to Chroma's distinction between metadata filters and document filters, filters with `"field": "content"`
(i.e., document content filters) and metadata fields must be supplied separately. For details on chroma filters,
see the [Chroma documentation](https://docs.trychroma.com/guides).

Example:

```python
filters = [
{
"operator": "AND",
"conditions": [
{"field": "meta.name", "operator": "==", "value": "name_0"},
{"field": "meta.number", "operator": "not in", "value": [2, 9]},
],
},
{
"operator": "AND",
"conditions": [
{"field": "content", "operator": "contains", "value": "FOO"},
{"field": "content", "operator": "not contains", "value": "BAR"},
],
},
]
```

If you need to apply the same logical operator (e.g., "AND", "OR") to multiple conditions at the same level,
you can provide a list of dictionaries as the value for the operator, like in the example below:

```python
filters = {
"operator": "OR",
"conditions": [
{"field": "meta.author", "operator": "==", "value": "author_1"},
{
"operator": "AND",
"conditions": [
{"field": "meta.tag", "operator": "==", "value": "tag_1"},
{"field": "meta.page", "operator": ">", "value": 100},
],
},
{
"operator": "AND",
"conditions": [
{"field": "meta.tag", "operator": "==", "value": "tag_2"},
{"field": "meta.page", "operator": ">", "value": 200},
],
},
],
}
```

:param filters: the filters to apply to the document list.
:returns: a list of Documents that match the given filters.
"""
if filters:
ids, where, where_document = self._normalize_filters(filters)
ids, where, where_document = _convert_filters(filters)
kwargs: Dict[str, Any] = {"where": where}

if ids:
Expand Down Expand Up @@ -285,7 +278,7 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any
include=["embeddings", "documents", "metadatas", "distances"],
)
else:
chroma_filters = self._normalize_filters(filters=filters)
chroma_filters = _convert_filters(filters=filters)
results = self._collection.query(
query_texts=queries,
n_results=top_k,
Expand Down Expand Up @@ -316,7 +309,7 @@ def search_embeddings(
include=["embeddings", "documents", "metadatas", "distances"],
)
else:
chroma_filters = self._normalize_filters(filters=filters)
chroma_filters = _convert_filters(filters=filters)
results = self._collection.query(
query_embeddings=query_embeddings,
n_results=top_k,
Expand Down Expand Up @@ -355,62 +348,6 @@ def to_dict(self) -> Dict[str, Any]:
**self._embedding_function_params,
)

@staticmethod
def _normalize_filters(filters: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]:
    """
    Translate Haystack filters to Chroma filters.

    The input dictionary is read but never modified.

    :param filters: flat dictionary of filters. The `"content"` key becomes a `$contains`
        document filter, the `"id"` key is collected into the ids list, list/tuple values are
        expanded (see below), and every other key is passed through unchanged.
    :returns: a tuple `(ids, where, where_document)` to be passed to the `ids`, `where` and
        `where_document` arguments respectively.
    :raises ChromaDocumentStoreFilterError: if `filters` is not a dictionary or if Chroma's
        validators reject the resulting filters.
    """
    if not isinstance(filters, dict):
        msg = "'filters' parameter must be a dictionary"
        raise ChromaDocumentStoreFilterError(msg)

    ids = []
    where = defaultdict(list)
    where_document = defaultdict(list)
    # Keys forwarded as-is. Collected in a new dict instead of deleting the translated keys
    # from `filters`, so the caller's dictionary is left intact.
    passthrough: Dict[str, Any] = {}

    for field, value in filters.items():
        if field == "content":
            where_document["$contains"] = value
        elif field == "id":
            ids.append(value)
        elif isinstance(value, (list, tuple)):
            # if the list is empty the filter is invalid, let's just drop it
            if len(value) == 0:
                continue

            # if the list has a single item, just make it a regular key:value filter pair
            if len(value) == 1:
                where[field] = value[0]
                continue

            # if the list contains multiple items, we need an $or chain
            for v in value:
                where["$or"].append({field: v})
        else:
            passthrough[field] = value

    final_where = dict(passthrough)
    final_where.update(dict(where))
    try:
        if final_where:
            validate_where(final_where)
        if where_document:
            validate_where_document(where_document)
    except ValueError as e:
        raise ChromaDocumentStoreFilterError(e) from e

    return ids, final_where, where_document

@staticmethod
def _get_result_to_documents(result: GetResult) -> List[Document]:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union

from chromadb.api.types import validate_where, validate_where_document

from .errors import ChromaDocumentStoreFilterError

# Lookup table from the operator strings accepted in filter conditions (comparison,
# logical and content operators) to the `$`-prefixed operator names emitted in the
# converted filters.
OPERATORS = {
    "==": "$eq",
    "!=": "$ne",
    ">": "$gt",
    ">=": "$gte",
    "<": "$lt",
    "<=": "$lte",
    "in": "$in",
    "not in": "$nin",
    "AND": "$and",
    "OR": "$or",
    "contains": "$contains",
    "not contains": "$not_contains",
}


def _convert_filters(
    filters: Union[Dict[str, Any], List[Dict[str, Any]]]
) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]:
    """
    Converts Haystack filters into a format compatible with Chroma, separating them into ids, metadata filters,
    and content filters to be passed to chroma as ids, where, and where_document clauses respectively.

    :param filters: a single filter dictionary or a list of filter dictionaries.
    :returns: a tuple `(ids, where, where_document)`.
    :raises ChromaDocumentStoreFilterError: if an id filter does not carry a non-empty `==` value,
        or if Chroma's validators reject the metadata or document filter.
    """
    ids: List[str] = []
    where: Dict[str, Any] = defaultdict(list)
    where_document: Dict[str, Any] = defaultdict(list)

    # A single dict is shorthand for a one-element list of clauses.
    clauses = [filters] if isinstance(filters, dict) else filters

    for clause in clauses:
        normalized_clause = _normalize_filters(clause)
        for field, value in normalized_clause.items():
            if value is None:
                continue

            # Decide per field whether this is a content filter. Testing the accumulated
            # `where_document` here would skip every metadata/id filter that comes after the
            # first content filter.
            document_filter = create_where_document_filter(field, value)
            if document_filter:
                where_document.update(document_filter)
                continue

            # if field is "id", it'll be passed to Chroma's ids filter
            if field == "id":
                # `.get` so that a non-"==" operator produces the intended error, not a KeyError
                id_value = value.get("$eq") if isinstance(value, dict) else None
                if not id_value:
                    msg = f"id filter only supports '==' operator, got {value}"
                    raise ChromaDocumentStoreFilterError(msg)
                ids.append(id_value)
            else:
                where[field] = value

    # Both filters are handed to Chroma, so both are validated when present.
    try:
        if where_document:
            test_clause = "document content filter"
            validate_where_document(where_document)
        if where:
            test_clause = "metadata filter"
            validate_where(where)
    except ValueError as e:
        msg = f"Invalid '{test_clause}' : {e}"
        raise ChromaDocumentStoreFilterError(msg) from e

    return ids, where, where_document


def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
    """
    Converts Haystack filters to Chroma compatible filters.

    A clause with a `"field"` key is parsed as a comparison condition; any other clause is
    parsed as a logical condition.

    :param filters: one filter clause.
    :returns: a new dictionary holding the converted clause.
    :raises ChromaDocumentStoreFilterError: if the clause is not a dictionary, or if the
        delegated parser rejects it.
    """
    # Reject non-dict clauses up front: `"field" in <str>` is a substring test and would
    # route a string to the comparison parser, failing with an unrelated TypeError.
    if not isinstance(filters, dict):
        msg = f"Filter clause must be a dictionary, got {type(filters).__name__}"
        raise ChromaDocumentStoreFilterError(msg)

    if "field" in filters:
        return dict(_parse_comparison_condition(filters))
    return dict(_parse_logical_condition(filters))


def create_where_document_filter(field: str, value: Any) -> Dict[str, Any]:
    """
    Method to convert Haystack filters with the "content" field to Chroma-compatible document filters

    :param field: the key of an already converted filter entry: `"content"`, `"$and"`, `"$or"`
        or any other field name.
    :param value: the value stored under `field`.
    :returns: `value` itself when `field` is `"content"`; `{field: [...]}` when `field` is
        `"$and"`/`"$or"` and the first condition has a non-empty `"content"` entry; an empty
        mapping in every other case.
    """
    where_document: Dict[str, Any] = defaultdict(list)

    if value is None:
        return where_document
    if field == "content":
        return value
    if field not in ("$and", "$or"):
        return where_document

    # An empty list, or a value that is not a list of dicts, is not a content filter.
    # Returning the empty mapping avoids an IndexError/AttributeError on `value[0].get`.
    if not isinstance(value, (list, tuple)) or not value or not isinstance(value[0], dict):
        return where_document
    # Only the first condition decides whether this logical clause is a content filter.
    if not value[0].get("content"):
        return where_document

    # Build new lists rather than modifying the original structure.
    document_filters = [
        create_where_document_filter(sub_field, sub_value)
        for condition in value
        if isinstance(condition, dict)
        for sub_field, sub_value in condition.items()
    ]
    if document_filters:
        where_document[field] = document_filters
    return where_document


def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a logical filter clause into `{<converted operator>: [<converted conditions>]}`.

    :param condition: a dictionary with an `"operator"` key and a `"conditions"` list.
    :returns: a single-key dictionary mapping the converted operator to the list of
        converted sub-conditions.
    :raises ChromaDocumentStoreFilterError: if a required key is missing, the operator is
        unknown, or `"conditions"` is not a list.
    """
    if "operator" not in condition:
        msg = f"'operator' key missing in {condition}"
        raise ChromaDocumentStoreFilterError(msg)
    if "conditions" not in condition:
        msg = f"'conditions' key missing in {condition}"
        raise ChromaDocumentStoreFilterError(msg)

    operator = condition["operator"]
    # Validate the operator before recursing so a bad clause is reported at its own level.
    if operator not in OPERATORS:
        msg = f"Unknown operator {operator}"
        raise ChromaDocumentStoreFilterError(msg)

    raw_conditions = condition["conditions"]
    # Iterating a dict or a string here would feed keys/characters to the clause parser.
    if not isinstance(raw_conditions, (list, tuple)):
        msg = f"'conditions' must be a list in {condition}"
        raise ChromaDocumentStoreFilterError(msg)

    conditions = [_normalize_filters(c) for c in raw_conditions]
    return {OPERATORS[operator]: conditions}


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a comparison filter clause into `{<field>: {<converted operator>: <value>}}`.

    :param condition: a dictionary with `"field"`, `"operator"` and `"value"` keys.
    :returns: a single-key dictionary keyed by the field name with the `"meta."` prefix removed.
    :raises ChromaDocumentStoreFilterError: if a required key is missing or the operator is unknown.
    """
    if "field" not in condition:
        msg = f"'field' key missing in {condition}"
        raise ChromaDocumentStoreFilterError(msg)
    if "operator" not in condition:
        msg = f"'operator' key missing in {condition}"
        raise ChromaDocumentStoreFilterError(msg)
    if "value" not in condition:
        msg = f"'value' key missing in {condition}"
        raise ChromaDocumentStoreFilterError(msg)

    field: str = condition["field"]
    # Remove only the leading "meta." prefix. Keeping the rest intact preserves metadata keys
    # that themselves contain dots ("meta.file.name" -> "file.name", not "name").
    prefix = "meta."
    if field.startswith(prefix):
        field = field[len(prefix):]

    operator: str = condition["operator"]
    value: Any = condition["value"]

    # Report an unknown operator with the module's own error type rather than a bare KeyError.
    if operator not in OPERATORS:
        msg = f"Unknown operator {operator}"
        raise ChromaDocumentStoreFilterError(msg)

    return {field: {OPERATORS[operator]: value}}
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading