Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement filtering for WeaviateDocumentStore #278

Merged
merged 13 commits into from
Feb 14, 2024
Merged
65 changes: 20 additions & 45 deletions integrations/weaviate/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ readme = "README.md"
requires-python = ">=3.8"
license = "Apache-2.0"
keywords = []
authors = [
{ name = "deepset GmbH", email = "[email protected]" },
]
authors = [{ name = "deepset GmbH", email = "[email protected]" }]
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
Expand All @@ -28,6 +26,7 @@ dependencies = [
"haystack-ai",
"weaviate-client==3.*",
"haystack-pydoc-tools",
"python-dateutil",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to add this to have proper parsing of ISO 8601 dates as datetime.fromisoformat() supports it properly only from 3.11.

]

[project.urls]
Expand All @@ -47,51 +46,25 @@ root = "../.."
git_describe_command = 'git describe --tags --match="integrations/weaviate-v[0-9]*"'

[tool.hatch.envs.default]
dependencies = [
"coverage[toml]>=6.5",
"pytest",
"ipython",
]
dependencies = ["coverage[toml]>=6.5", "pytest", "ipython"]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
cov-report = [
"- coverage combine",
"coverage report",
]
cov = [
"test-cov",
"cov-report",
]
docs = [
"pydoc-markdown pydoc/config.yml"
]
cov-report = ["- coverage combine", "coverage report"]
cov = ["test-cov", "cov-report"]
docs = ["pydoc-markdown pydoc/config.yml"]

[[tool.hatch.envs.all.matrix]]
python = ["3.8", "3.9", "3.10", "3.11", "3.12"]

[tool.hatch.envs.lint]
detached = true
dependencies = [
"black>=23.1.0",
"mypy>=1.0.0",
"ruff>=0.0.243",
]
dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
style = [
"ruff {args:.}",
"black --check --diff {args:.}",
]
fmt = [
"black {args:.}",
"ruff --fix {args:.}",
"style",
]
all = [
"style",
"typing",
]
style = ["ruff {args:.}", "black --check --diff {args:.}"]
fmt = ["black {args:.}", "ruff --fix {args:.}", "style"]
all = ["style", "typing"]

[tool.black]
target-version = ["py37"]
Expand Down Expand Up @@ -134,9 +107,15 @@ ignore = [
# Allow boolean positional values in function calls, like `dict.get(... True)`
"FBT003",
# Ignore checks for possible passwords
"S105", "S106", "S107",
"S105",
"S106",
"S107",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
"C901",
"PLR0911",
"PLR0912",
"PLR0913",
"PLR0915",
]
unfixable = [
# Don't touch unused imports
Expand Down Expand Up @@ -164,11 +143,7 @@ weaviate_haystack = ["src/haystack_integrations", "*/weaviate-haystack/src"]
tests = ["tests", "*/weaviate-haystack/tests"]

[tool.coverage.report]
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]

[[tool.mypy.overrides]]
module = [
Expand All @@ -177,6 +152,6 @@ module = [
"pytest.*",
"weaviate.*",
"numpy",
"grpc"
"grpc",
]
ignore_missing_imports = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
from typing import Any, Dict

from dateutil import parser
from haystack.errors import FilterError
from pandas import DataFrame


def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert filters from Haystack format to Weaviate format.
"""
if not isinstance(filters, dict):
msg = "Filters must be a dictionary"
raise FilterError(msg)

if "field" in filters:
return {"operator": "And", "operands": [_parse_comparison_condition(filters)]}
return _parse_logical_condition(filters)


OPERATOR_INVERSE = {
"==": "!=",
"!=": "==",
">": "<=",
">=": "<",
"<": ">=",
"<=": ">",
"in": "not in",
"not in": "in",
"AND": "OR",
"OR": "AND",
"NOT": "AND",
}


def _invert_condition(filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Invert condition recursively.
Weaviate doesn't support NOT filters so we need to invert them ourselves.
"""
inverted_condition = filters.copy()
if "operator" not in filters:
# Let's not handle this stuff in here, we'll fail later on anyway.
return inverted_condition
inverted_condition["operator"] = OPERATOR_INVERSE[filters["operator"]]
if "conditions" in filters:
inverted_condition["conditions"] = []
for condition in filters["conditions"]:
inverted_condition["conditions"].append(_invert_condition(condition))

return inverted_condition


def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "conditions" not in condition:
msg = f"'conditions' key missing in {condition}"
raise FilterError(msg)

operator = condition["operator"]
if operator in ["AND", "OR"]:
operands = []
for c in condition["conditions"]:
if "field" not in c:
operands.append(_parse_logical_condition(c))
else:
operands.append(_parse_comparison_condition(c))
return {"operator": operator.lower().capitalize(), "operands": operands}
elif operator == "NOT":
inverted_conditions = _invert_condition(condition)
return _parse_logical_condition(inverted_conditions)
else:
msg = f"Unknown logical operator '{operator}'"
raise FilterError(msg)


def _infer_value_type(value: Any) -> str:
if value is None:
return "valueNull"

if isinstance(value, bool):
return "valueBoolean"
if isinstance(value, int):
return "valueInt"
if isinstance(value, float):
return "valueNumber"

if isinstance(value, str):
try:
parser.isoparse(value)
return "valueDate"
except ValueError:
return "valueText"

msg = f"Unknown value type {type(value)}"
raise FilterError(msg)


def _handle_date(value: Any) -> str:
if isinstance(value, str):
try:
return parser.isoparse(value).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
except ValueError:
pass
return value


def _equal(field: str, value: Any) -> Dict[str, Any]:
if value is None:
return {"path": field, "operator": "IsNull", "valueBoolean": True}
return {"path": field, "operator": "Equal", _infer_value_type(value): _handle_date(value)}


def _not_equal(field: str, value: Any) -> Dict[str, Any]:
if value is None:
return {"path": field, "operator": "IsNull", "valueBoolean": False}
return {
"operator": "Or",
"operands": [
{"path": field, "operator": "NotEqual", _infer_value_type(value): _handle_date(value)},
{"path": field, "operator": "IsNull", "valueBoolean": True},
],
}


def _greater_than(field: str, value: Any) -> Dict[str, Any]:
if value is None:
# When the value is None and '>' is used we create a filter that would return a Document
# if it has a field set and not set at the same time.
# This will cause the filter to match no Document.
# This way we keep the behavior consistent with other Document Stores.
return _match_no_document(field)
if isinstance(value, str):
try:
parser.isoparse(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
if type(value) in [list, DataFrame]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return {"path": field, "operator": "GreaterThan", _infer_value_type(value): _handle_date(value)}


def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
if value is None:
# When the value is None and '>=' is used we create a filter that would return a Document
# if it has a field set and not set at the same time.
# This will cause the filter to match no Document.
# This way we keep the behavior consistent with other Document Stores.
return _match_no_document(field)
if isinstance(value, str):
try:
parser.isoparse(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
if type(value) in [list, DataFrame]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return {"path": field, "operator": "GreaterThanEqual", _infer_value_type(value): _handle_date(value)}


def _less_than(field: str, value: Any) -> Dict[str, Any]:
if value is None:
# When the value is None and '<' is used we create a filter that would return a Document
# if it has a field set and not set at the same time.
# This will cause the filter to match no Document.
# This way we keep the behavior consistent with other Document Stores.
return _match_no_document(field)
if isinstance(value, str):
try:
parser.isoparse(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
if type(value) in [list, DataFrame]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return {"path": field, "operator": "LessThan", _infer_value_type(value): _handle_date(value)}


def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
if value is None:
# When the value is None and '<=' is used we create a filter that would return a Document
# if it has a field set and not set at the same time.
# This will cause the filter to match no Document.
# This way we keep the behavior consistent with other Document Stores.
return _match_no_document(field)
if isinstance(value, str):
try:
parser.isoparse(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
if type(value) in [list, DataFrame]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return {"path": field, "operator": "LessThanEqual", _infer_value_type(value): _handle_date(value)}


def _in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators"
raise FilterError(msg)

return {"operator": "And", "operands": [_equal(field, v) for v in value]}


def _not_in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators"
raise FilterError(msg)
return {"operator": "And", "operands": [_not_equal(field, v) for v in value]}


COMPARISON_OPERATORS = {
"==": _equal,
"!=": _not_equal,
">": _greater_than,
">=": _greater_than_equal,
"<": _less_than,
"<=": _less_than_equal,
"in": _in,
"not in": _not_in,
}


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
field: str = condition["field"]

if field.startswith("meta."):
# Documents are flattened otherwise we wouldn't be able to properly query them.
# We're forced to flatten because Weaviate doesn't support querying of nested properties
# as of now. If we don't flatten the documents we can't filter them.
# As time of writing this they have it in their backlog, see:
# https://github.com/weaviate/weaviate/issues/3694
field = field.replace("meta.", "")

if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "value" not in condition:
msg = f"'value' key missing in {condition}"
raise FilterError(msg)
operator: str = condition["operator"]
value: Any = condition["value"]
if isinstance(value, DataFrame):
value = value.to_json()

return COMPARISON_OPERATORS[operator](field, value)


def _match_no_document(field: str) -> Dict[str, Any]:
"""
Returns a filters that will match no Document, this is used to keep the behavior consistent
between different Document Stores.
"""
return {
"operator": "And",
"operands": [
{"path": field, "operator": "IsNull", "valueBoolean": False},
{"path": field, "operator": "IsNull", "valueBoolean": True},
],
}
Loading