Skip to content

Commit

Permalink
cover also complex cases
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Jan 23, 2024
1 parent 7306336 commit cc62733
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _execute_sql(
except Error as e:
self._connection.rollback()
sql_query_str = sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query
detailed_error_msg = f"{error_msg}.\nSQL query: {sql_query_str}\n Parameters: {params}"
detailed_error_msg = f"{error_msg}.\nSQL query: {sql_query_str} \nParameters: {params}"
raise DocumentStoreError(detailed_error_msg) from e

return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pandas import DataFrame
from psycopg.sql import SQL
from psycopg.types.json import Jsonb
from itertools import chain


def _build_where_clause(filters: Dict[str, Any], cursor) -> str:
Expand Down Expand Up @@ -84,21 +85,36 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
query_parts = []
values = []
for c in conditions:
query_parts.append(SQL(c[0]))
print("c0", c[0])

query_parts.append(c[0])
values.append(c[1])

# values = list(chain.from_iterable(values))

print("query_parts", query_parts)
# if isinstance(query_parts[0], list):
# query_parts = list(chain.from_iterable(query_parts))
# print("chained", query_parts)
sql_query_parts = [SQL(q) if isinstance(q, str) else q for q in query_parts]
if isinstance(values[0], list):
values = list(chain.from_iterable(values))

if operator == "AND":
return SQL(" AND ").join(query_parts), values
sql_query = SQL("(") + SQL(" AND ").join(sql_query_parts)+ SQL(")")

elif operator == "OR":
return SQL(" OR ").join(query_parts), values
sql_query = SQL("(") + SQL(" OR ").join(sql_query_parts)+ SQL(")")

elif operator == "NOT":
query_parts = [SQL("NOT (") + query_part + SQL(")") for query_part in query_parts]
return SQL(" AND ").join(query_parts), values
joined_query_parts = SQL(" AND ").join(sql_query_parts)
sql_query = SQL("NOT (") + joined_query_parts + SQL(")")

else:
msg = f"Unknown logical operator '{operator}'"
raise FilterError(msg)

return sql_query, values


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
17 changes: 8 additions & 9 deletions integrations/pgvector/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

class TestFilters(FilterDocumentsTest):
def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
print("received", received)
print("expected", expected)
assert len(received) == len(expected)
received.sort(key=lambda x: x.id)
expected.sort(key=lambda x: x.id)
Expand All @@ -29,26 +31,23 @@ def test_complex_filter(self, document_store, filterable_docs):
"operator": "AND",
"conditions": [
{"field": "meta.number", "operator": "==", "value": 100},
{"field": "meta.name", "operator": "==", "value": "name_0"},
{"field": "meta.chapter", "operator": "==", "value": "intro"},
],
},
{
"operator": "AND",
"conditions": [
{"field": "meta.page", "operator": "==", "value": 90},
{"field": "meta.page", "operator": "==", "value": "90"},
{"field": "meta.chapter", "operator": "==", "value": "conclusion"},
],
},
],
}

result = document_store.filter_documents(filters=filters)

self.assert_documents_are_equal(
result,
[
d
for d in filterable_docs
if (d.meta.get("number") == 100 and d.meta.get("name") == "name_0")
or (d.meta.get("page") == 90 and d.meta.get("chapter") == "conclusion")
],
)
[d for d in filterable_docs if
(d.meta.get("number") == 100 and d.meta.get("chapter") == "intro")
or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion")])

0 comments on commit cc62733

Please sign in to comment.