diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index bb1915a6f..b49bd87c3 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -8,6 +8,7 @@ from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.utils.filters import convert from psycopg import Error, IntegrityError, connect from psycopg.abc import Query from psycopg.cursor import Cursor @@ -18,6 +19,8 @@ from pgvector.psycopg import register_vector +from .filters import _convert_filters_to_where_clause_and_params + logger = logging.getLogger(__name__) CREATE_TABLE_STATEMENT = """ @@ -158,11 +161,16 @@ def _execute_sql( params = params or () cursor = cursor or self._cursor + sql_query_str = sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query + logger.debug("SQL query: %s\nParameters: %s", sql_query_str, params) + try: result = cursor.execute(sql_query, params) except Error as e: self._connection.rollback() - raise DocumentStoreError(error_msg) from e + detailed_error_msg = f"{error_msg}.\nYou can find the SQL query and the parameters in the debug logs." 
+ raise DocumentStoreError(detailed_error_msg) from e + return result def _create_table_if_not_exists(self): @@ -257,15 +265,37 @@ def count_documents(self) -> int: ] return count - def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: # noqa: ARG002 - # TODO: implement filters - sql_get_docs = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the filters provided. + + For a detailed specification of the filters, + refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering) + + :param filters: The filters to apply to the document list. + :return: A list of Documents that match the given filters. + """ + if filters: + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise TypeError(msg) + if "operator" not in filters and "conditions" not in filters: + filters = convert(filters) + + sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + + params = () + if filters: + sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters) + sql_filter += sql_where_clause result = self._execute_sql( - sql_get_docs, error_msg="Could not filter documents from PgvectorDocumentStore", cursor=self._dict_cursor + sql_filter, + params, + error_msg="Could not filter documents from PgvectorDocumentStore.", + cursor=self._dict_cursor, ) - # Fetch all the records records = result.fetchall() docs = self._from_pg_to_haystack_documents(records) return docs @@ -300,6 +330,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D sql_insert += SQL(" RETURNING id") + sql_query_str = sql_insert.as_string(self._cursor) if not isinstance(sql_insert, str) else sql_insert + logger.debug("SQL query: %s\nParameters: %s", sql_query_str, db_documents) + try: 
self._cursor.executemany(sql_insert, db_documents, returning=True) except IntegrityError as ie: @@ -307,7 +340,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D raise DuplicateDocumentError from ie except Error as e: self._connection.rollback() - raise DocumentStoreError from e + error_msg = ( + "Could not write documents to PgvectorDocumentStore. \n" + "You can find the SQL query and the parameters in the debug logs." + ) + raise DocumentStoreError(error_msg) from e # get the number of the inserted documents, inspired by psycopg3 docs # https://www.psycopg.org/psycopg3/docs/api/cursors.html#psycopg.Cursor.executemany diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py new file mode 100644 index 000000000..daa90f502 --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from datetime import datetime +from itertools import chain +from typing import Any, Dict, List + +from haystack.errors import FilterError +from pandas import DataFrame +from psycopg.sql import SQL +from psycopg.types.json import Jsonb + +# we need this mapping to cast meta values to the correct type, +# since they are stored in the JSONB field as strings. +# this dict can be extended if needed +PYTHON_TYPES_TO_PG_TYPES = { + int: "integer", + float: "real", + bool: "boolean", +} + +NO_VALUE = "no_value" + + +def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tuple[SQL, tuple]: + """ + Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL. 
+ """ + if "field" in filters: + query, values = _parse_comparison_condition(filters) + else: + query, values = _parse_logical_condition(filters) + + where_clause = SQL(" WHERE ") + SQL(query) + params = tuple(value for value in values if value != NO_VALUE) + + return where_clause, params + + +def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + operator = condition["operator"] + if operator not in ["AND", "OR"]: + msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR'" + raise FilterError(msg) + + # logical conditions can be nested, so we need to parse them recursively + conditions = [] + for c in condition["conditions"]: + if "field" in c: + query, vals = _parse_comparison_condition(c) + else: + query, vals = _parse_logical_condition(c) + conditions.append((query, vals)) + + query_parts, values = [], [] + for c in conditions: + query_parts.append(c[0]) + values.append(c[1]) + if isinstance(values[0], list): + values = list(chain.from_iterable(values)) + + if operator == "AND": + sql_query = f"({' AND '.join(query_parts)})" + elif operator == "OR": + sql_query = f"({' OR '.join(query_parts)})" + else: + msg = f"Unknown logical operator '{operator}'" + raise FilterError(msg) + + return sql_query, values + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: + field: str = condition["field"] + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + if operator not in COMPARISON_OPERATORS: + msg = f"Unknown comparison operator '{operator}'. 
Valid operators are: {list(COMPARISON_OPERATORS.keys())}" + raise FilterError(msg) + + value: Any = condition["value"] + if isinstance(value, DataFrame): + # DataFrames are stored as JSONB and we query them as such + value = Jsonb(value.to_json()) + field = f"({field})::jsonb" + + if field.startswith("meta."): + field = _treat_meta_field(field, value) + + field, value = COMPARISON_OPERATORS[operator](field, value) + return field, [value] + + +def _treat_meta_field(field: str, value: Any) -> str: + """ + Internal method that modifies the field str + to make the meta JSONB field queryable. + + Examples: + >>> _treat_meta_field(field="meta.number", value=9) + "(meta->>'number')::integer" + + >>> _treat_meta_field(field="meta.name", value="my_name") + "meta->>'name'" + """ + + # use the ->> operator to access keys in the meta JSONB field + field_name = field.split(".", 1)[-1] + field = f"meta->>'{field_name}'" + + # meta fields are stored as strings in the JSONB field, + # so we need to cast them to the correct type + type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value)) + if isinstance(value, list) and len(value) > 0: + type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0])) + + if type_value: + field = f"({field})::{type_value}" + + return field + + +def _equal(field: str, value: Any) -> tuple[str, Any]: + if value is None: + # NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params + return f"{field} IS NULL", NO_VALUE + return f"{field} = %s", value + + +def _not_equal(field: str, value: Any) -> tuple[str, Any]: + # we use IS DISTINCT FROM to correctly handle NULL values + # (not handled by !=) + return f"{field} IS DISTINCT FROM %s", value + + +def _greater_than(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. 
" + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} > %s", value + + +def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} >= %s", value + + +def _less_than(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} < %s", value + + +def _less_than_equal(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." 
+ ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} <= %s", value + + +def _not_in(field: str, value: Any) -> tuple[str, List]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'not in' comparator in PgvectorDocumentStore" + raise FilterError(msg) + + return f"{field} IS NULL OR {field} != ALL(%s)", [value] + + +def _in(field: str, value: Any) -> tuple[str, List]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'in' comparator in PgvectorDocumentStore" + raise FilterError(msg) + + # see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation + return f"{field} = ANY(%s)", [value] + + +COMPARISON_OPERATORS = { + "==": _equal, + "!=": _not_equal, + ">": _greater_than, + ">=": _greater_than_equal, + "<": _less_than, + "<=": _less_than_equal, + "in": _in, + "not in": _not_in, +} diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py new file mode 100644 index 000000000..34260f409 --- /dev/null +++ b/integrations/pgvector/tests/conftest.py @@ -0,0 +1,24 @@ +import pytest +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +@pytest.fixture +def document_store(request): + connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + table_name = f"haystack_{request.node.name}" + embedding_dimension = 768 + vector_function = "cosine_distance" + recreate_table = True + search_strategy = "exact_nearest_neighbor" + + store = PgvectorDocumentStore( + connection_string=connection_string, + table_name=table_name, + embedding_dimension=embedding_dimension, + vector_function=vector_function, + recreate_table=recreate_table, + search_strategy=search_strategy, + ) + yield store + + store.delete_table() diff --git a/integrations/pgvector/tests/test_document_store.py 
b/integrations/pgvector/tests/test_document_store.py index 9f3521838..e8d9107d7 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -14,27 +14,6 @@ class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): - @pytest.fixture - def document_store(self, request): - connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" - table_name = f"haystack_{request.node.name}" - embedding_dimension = 768 - vector_function = "cosine_distance" - recreate_table = True - search_strategy = "exact_nearest_neighbor" - - store = PgvectorDocumentStore( - connection_string=connection_string, - table_name=table_name, - embedding_dimension=embedding_dimension, - vector_function=vector_function, - recreate_table=recreate_table, - search_strategy=search_strategy, - ) - yield store - - store.delete_table() - def test_write_documents(self, document_store: PgvectorDocumentStore): docs = [Document(id="1")] assert document_store.write_documents(docs) == 1 diff --git a/integrations/pgvector/tests/test_filters.py b/integrations/pgvector/tests/test_filters.py new file mode 100644 index 000000000..8b2dc8ec9 --- /dev/null +++ b/integrations/pgvector/tests/test_filters.py @@ -0,0 +1,179 @@ +from typing import List + +import pytest +from haystack.dataclasses.document import Document +from haystack.testing.document_store import FilterDocumentsTest +from haystack_integrations.document_stores.pgvector.filters import ( + FilterError, + _convert_filters_to_where_clause_and_params, + _parse_comparison_condition, + _parse_logical_condition, + _treat_meta_field, +) +from pandas import DataFrame +from psycopg.sql import SQL +from psycopg.types.json import Jsonb + + +class TestFilters(FilterDocumentsTest): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + This overrides the default assert_documents_are_equal from FilterDocumentsTest. 
+ It is needed because the embeddings are not exactly the same when they are retrieved from Postgres. + """ + + assert len(received) == len(expected) + received.sort(key=lambda x: x.id) + expected.sort(key=lambda x: x.id) + for received_doc, expected_doc in zip(received, expected): + # we first compare the embeddings approximately + if received_doc.embedding is None: + assert expected_doc.embedding is None + else: + assert received_doc.embedding == pytest.approx(expected_doc.embedding) + + received_doc.embedding, expected_doc.embedding = None, None + assert received_doc == expected_doc + + def test_complex_filter(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.page", "operator": "==", "value": "90"}, + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + ], + }, + ], + } + + result = document_store.filter_documents(filters=filters) + + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro") + or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion") + ], + ) + + @pytest.mark.skip(reason="NOT operator is not supported in PgvectorDocumentStore") + def test_not_operator(self, document_store, filterable_docs): ... 
+ + def test_treat_meta_field(self): + assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'" + assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'" + assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean" + assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean" + + # do not cast the field if its value is not one of the known types, an empty list or None + assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'" + assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'" + assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'" + + def test_comparison_condition_dataframe_jsonb_conversion(self): + dataframe = DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + condition = {"field": "meta.df", "operator": "==", "value": dataframe} + field, values = _parse_comparison_condition(condition) + assert field == "(meta.df)::jsonb = %s" + + # we check each slot of the Jsonb object because it does not implement __eq__ + assert values[0].obj == Jsonb(dataframe.to_json()).obj + assert values[0].dumps == Jsonb(dataframe.to_json()).dumps + + def test_comparison_condition_missing_operator(self): + condition = {"field": "meta.type", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_comparison_condition_missing_value(self): + condition = {"field": "meta.type", "operator": "=="} + with pytest.raises(FilterError): + 
_parse_comparison_condition(condition) + + def test_comparison_condition_unknown_operator(self): + condition = {"field": "meta.type", "operator": "unknown", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_logical_condition_missing_operator(self): + condition = {"conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_missing_conditions(self): + condition = {"operator": "AND"} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_unknown_operator(self): + condition = {"operator": "unknown", "conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_nested(self): + condition = { + "operator": "AND", + "conditions": [ + { + "operator": "OR", + "conditions": [ + {"field": "meta.domain", "operator": "!=", "value": "science"}, + {"field": "meta.chapter", "operator": "in", "value": ["intro", "conclusion"]}, + ], + }, + { + "operator": "OR", + "conditions": [ + {"field": "meta.number", "operator": ">=", "value": 90}, + {"field": "meta.author", "operator": "not in", "value": ["John", "Jane"]}, + ], + }, + ], + } + query, values = _parse_logical_condition(condition) + assert query == ( + "((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) " + "AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))" + ) + assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]] + + def test_convert_filters_to_where_clause_and_params(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = 
%s AND meta->>'chapter' = %s)") + assert params == (100, "intro") + + def test_convert_filters_to_where_clause_and_params_handle_null(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": None}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)") + assert params == ("intro",)