Pgvector - filters (#257)
* very first draft

* setup integration folder and workflow

* update readme

* making progress!

* mypy overrides

* making progress on index

* drop sqlalchemy in favor of psycopg

* good improvements!

* docstrings

* improve definition

* small improvements

* more test cases

* standardize

* start working on filters

* inner_product

* explicit create statement

* address feedback

* tests separation

* filters - draft

* change embedding_similarity_function to vector_function

* explicit insert and update statements

* remove useless condition

* unit tests for conversion functions

* tests change

* simplify!

* progress!

* better error messages and more

* cover also complex cases

* fmt

* make things work again

* progress on simplification

* further simplification

* filters simplification

* fmt

* rm print

* uncomment line

* fix name

* mv check filters is a dict in filter_documents

* f-strings

* NO_VALUE constant

* handle nested logical conditions in _parse_logical_condition

* add examples to _treat_meta_field

* fix fmt

* ellipsis fmt

* more tests for unhappy paths

* more tests for internal methods

* black

* log debug query and params
anakin87 authored Jan 31, 2024
1 parent dabf071 commit ae80056
Showing 5 changed files with 489 additions and 28 deletions.
@@ -8,6 +8,7 @@
from haystack.dataclasses.document import ByteStream, Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils.filters import convert
from psycopg import Error, IntegrityError, connect
from psycopg.abc import Query
from psycopg.cursor import Cursor
@@ -18,6 +19,8 @@

from pgvector.psycopg import register_vector

from .filters import _convert_filters_to_where_clause_and_params

logger = logging.getLogger(__name__)

CREATE_TABLE_STATEMENT = """
@@ -158,11 +161,16 @@ def _execute_sql(
params = params or ()
cursor = cursor or self._cursor

sql_query_str = sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query
logger.debug("SQL query: %s\nParameters: %s", sql_query_str, params)

try:
result = cursor.execute(sql_query, params)
except Error as e:
self._connection.rollback()
raise DocumentStoreError(error_msg) from e
detailed_error_msg = f"{error_msg}.\nYou can find the SQL query and the parameters in the debug logs."
raise DocumentStoreError(detailed_error_msg) from e

return result

def _create_table_if_not_exists(self):
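The new error path defers details to the debug logs; a minimal sketch of how a caller might surface them (the logger name is an assumption based on the package layout visible in the test imports below):

    import logging

    # _execute_sql logs each SQL query and its parameters at DEBUG level;
    # the logger name assumes the module's package path
    logging.basicConfig()
    logging.getLogger("haystack_integrations.document_stores.pgvector").setLevel(logging.DEBUG)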
@@ -257,15 +265,37 @@ def count_documents(self) -> int:
]
return count

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: # noqa: ARG002
# TODO: implement filters
sql_get_docs = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name))
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Returns the documents that match the filters provided.
For a detailed specification of the filters,
refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering)
:param filters: The filters to apply to the document list.
:return: A list of Documents that match the given filters.
"""
if filters:
if not isinstance(filters, dict):
msg = "Filters must be a dictionary"
raise TypeError(msg)
if "operator" not in filters and "conditions" not in filters:
filters = convert(filters)

sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name))

params = ()
if filters:
sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters)
sql_filter += sql_where_clause

result = self._execute_sql(
sql_get_docs, error_msg="Could not filter documents from PgvectorDocumentStore", cursor=self._dict_cursor
sql_filter,
params,
error_msg="Could not filter documents from PgvectorDocumentStore.",
cursor=self._dict_cursor,
)

# Fetch all the records
records = result.fetchall()
docs = self._from_pg_to_haystack_documents(records)
return docs
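For context, a minimal sketch of the new filtering path using the Haystack 2.x filter syntax linked in the docstring. The connection string mirrors the test fixture below; the remaining constructor arguments are assumed to have defaults, and the metadata values are hypothetical:

    from haystack.dataclasses.document import Document
    from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore

    store = PgvectorDocumentStore(connection_string="postgresql://postgres:postgres@localhost:5432/postgres")
    store.write_documents([
        Document(content="about filters", meta={"topic": "filters", "version": 2}),
        Document(content="about stores", meta={"topic": "stores", "version": 1}),
    ])

    # a logical condition wrapping two comparison conditions;
    # legacy-style filters are converted via haystack.utils.filters.convert
    docs = store.filter_documents({
        "operator": "AND",
        "conditions": [
            {"field": "meta.topic", "operator": "==", "value": "filters"},
            {"field": "meta.version", "operator": ">=", "value": 2},
        ],
    })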
@@ -300,14 +330,21 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

sql_insert += SQL(" RETURNING id")

sql_query_str = sql_insert.as_string(self._cursor) if not isinstance(sql_insert, str) else sql_insert
logger.debug("SQL query: %s\nParameters: %s", sql_query_str, db_documents)

try:
self._cursor.executemany(sql_insert, db_documents, returning=True)
except IntegrityError as ie:
self._connection.rollback()
raise DuplicateDocumentError from ie
except Error as e:
self._connection.rollback()
raise DocumentStoreError from e
error_msg = (
"Could not write documents to PgvectorDocumentStore. \n"
"You can find the SQL query and the parameters in the debug logs."
)
raise DocumentStoreError(error_msg) from e

# get the number of the inserted documents, inspired by psycopg3 docs
# https://www.psycopg.org/psycopg3/docs/api/cursors.html#psycopg.Cursor.executemany
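From the caller's side, the rewritten error paths behave as sketched below (both exception types are imported at the top of the module; store and docs as in the sketch above):

    from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError

    try:
        store.write_documents(docs)
    except DuplicateDocumentError:
        ...  # id collision under the default FAIL policy
    except DocumentStoreError:
        # the failing SQL query and its parameters are in the debug logs
        raise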
@@ -0,0 +1,242 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from itertools import chain
from typing import Any, Dict, List

from haystack.errors import FilterError
from pandas import DataFrame
from psycopg.sql import SQL
from psycopg.types.json import Jsonb

# we need this mapping to cast meta values to the correct type,
# since they are stored in the JSONB field as strings.
# this dict can be extended if needed
PYTHON_TYPES_TO_PG_TYPES = {
int: "integer",
float: "real",
bool: "boolean",
}

NO_VALUE = "no_value"


def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tuple[SQL, tuple]:
"""
Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL.
"""
if "field" in filters:
query, values = _parse_comparison_condition(filters)
else:
query, values = _parse_logical_condition(filters)

where_clause = SQL(" WHERE ") + SQL(query)
params = tuple(value for value in values if value != NO_VALUE)

return where_clause, params
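To make the contract concrete, an illustrative round-trip (outputs worked out from the helpers below; the clause is a psycopg SQL composable that renders to the commented string):

    clause, params = _convert_filters_to_where_clause_and_params(
        {"field": "meta.category", "operator": "==", "value": "news"}
    )
    # clause renders as:  WHERE meta->>'category' = %s
    # params == ("news",)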


def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "conditions" not in condition:
msg = f"'conditions' key missing in {condition}"
raise FilterError(msg)

operator = condition["operator"]
if operator not in ["AND", "OR"]:
msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR'"
raise FilterError(msg)

# logical conditions can be nested, so we need to parse them recursively
conditions = []
for c in condition["conditions"]:
if "field" in c:
query, vals = _parse_comparison_condition(c)
else:
query, vals = _parse_logical_condition(c)
conditions.append((query, vals))

query_parts, values = [], []
for c in conditions:
query_parts.append(c[0])
values.append(c[1])
if isinstance(values[0], list):
values = list(chain.from_iterable(values))

if operator == "AND":
sql_query = f"({' AND '.join(query_parts)})"
elif operator == "OR":
sql_query = f"({' OR '.join(query_parts)})"
else:
msg = f"Unknown logical operator '{operator}'"
raise FilterError(msg)

return sql_query, values
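An illustrative nested call, worked out against the recursion above (field names and values are hypothetical):

    query, values = _parse_logical_condition({
        "operator": "OR",
        "conditions": [
            {"field": "meta.type", "operator": "==", "value": "article"},
            {
                "operator": "AND",
                "conditions": [
                    {"field": "meta.type", "operator": "==", "value": "comment"},
                    {"field": "meta.likes", "operator": ">", "value": 100},
                ],
            },
        ],
    })
    # query  == "(meta->>'type' = %s OR (meta->>'type' = %s AND (meta->>'likes')::integer > %s))"
    # values == ["article", "comment", 100]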


def _parse_comparison_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]:
field: str = condition["field"]
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "value" not in condition:
msg = f"'value' key missing in {condition}"
raise FilterError(msg)
operator: str = condition["operator"]
if operator not in COMPARISON_OPERATORS:
msg = f"Unknown comparison operator '{operator}'. Valid operators are: {list(COMPARISON_OPERATORS.keys())}"
raise FilterError(msg)

value: Any = condition["value"]
if isinstance(value, DataFrame):
# DataFrames are stored as JSONB and we query them as such
value = Jsonb(value.to_json())
field = f"({field})::jsonb"

if field.startswith("meta."):
field = _treat_meta_field(field, value)

field, value = COMPARISON_OPERATORS[operator](field, value)
return field, [value]


def _treat_meta_field(field: str, value: Any) -> str:
"""
Internal method that modifies the field str
to make the meta JSONB field queryable.
Examples:
>>> _treat_meta_field(field="meta.number", value=9)
"(meta->>'number')::integer"
>>> _treat_meta_field(field="meta.name", value="my_name")
"meta->>'name'"
"""

# use the ->> operator to access keys in the meta JSONB field
field_name = field.split(".", 1)[-1]
field = f"meta->>'{field_name}'"

# meta fields are stored as strings in the JSONB field,
# so we need to cast them to the correct type
type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value))
if isinstance(value, list) and len(value) > 0:
type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0]))

if type_value:
field = f"({field})::{type_value}"

return field


def _equal(field: str, value: Any) -> tuple[str, Any]:
if value is None:
# NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params
return f"{field} IS NULL", NO_VALUE
return f"{field} = %s", value


def _not_equal(field: str, value: Any) -> tuple[str, Any]:
# we use IS DISTINCT FROM to correctly handle NULL values
# (not handled by !=)
return f"{field} IS DISTINCT FROM %s", value


def _greater_than(field: str, value: Any) -> tuple[str, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)

return f"{field} > %s", value


def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)

return f"{field} >= %s", value


def _less_than(field: str, value: Any) -> tuple[str, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)

return f"{field} < %s", value


def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)

return f"{field} <= %s", value


def _not_in(field: str, value: Any) -> tuple[str, List]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

return f"{field} IS NULL OR {field} != ALL(%s)", [value]


def _in(field: str, value: Any) -> tuple[str, List]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)

# see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation
return f"{field} = ANY(%s)", [value]


COMPARISON_OPERATORS = {
"==": _equal,
"!=": _not_equal,
">": _greater_than,
">=": _greater_than_equal,
"<": _less_than,
"<=": _less_than_equal,
"in": _in,
"not in": _not_in,
}
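A final illustrative call showing the list adaptation referenced in the psycopg link above:

    query, params = _in("meta->>'category'", ["news", "sports"])
    # query  == "meta->>'category' = ANY(%s)"
    # params == [["news", "sports"]] -- the list is passed as a single array
    # parameter, so psycopg adapts it and no value is ever interpolated
    # into the SQL string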
24 changes: 24 additions & 0 deletions integrations/pgvector/tests/conftest.py
@@ -0,0 +1,24 @@
import pytest
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore


@pytest.fixture
def document_store(request):
connection_string = "postgresql://postgres:postgres@localhost:5432/postgres"
table_name = f"haystack_{request.node.name}"
embedding_dimension = 768
vector_function = "cosine_distance"
recreate_table = True
search_strategy = "exact_nearest_neighbor"

store = PgvectorDocumentStore(
connection_string=connection_string,
table_name=table_name,
embedding_dimension=embedding_dimension,
vector_function=vector_function,
recreate_table=recreate_table,
search_strategy=search_strategy,
)
yield store

store.delete_table()
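With the fixture moved to conftest.py, any test module in the package can request it by name; a hypothetical example:

    def test_starts_empty(document_store):
        assert document_store.count_documents() == 0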
21 changes: 0 additions & 21 deletions integrations/pgvector/tests/test_document_store.py
@@ -14,27 +14,6 @@


class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
@pytest.fixture
def document_store(self, request):
connection_string = "postgresql://postgres:postgres@localhost:5432/postgres"
table_name = f"haystack_{request.node.name}"
embedding_dimension = 768
vector_function = "cosine_distance"
recreate_table = True
search_strategy = "exact_nearest_neighbor"

store = PgvectorDocumentStore(
connection_string=connection_string,
table_name=table_name,
embedding_dimension=embedding_dimension,
vector_function=vector_function,
recreate_table=recreate_table,
search_strategy=search_strategy,
)
yield store

store.delete_table()

def test_write_documents(self, document_store: PgvectorDocumentStore):
docs = [Document(id="1")]
assert document_store.write_documents(docs) == 1