Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement keyword retrieval for pgvector integration #644

Merged
merged 27 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions integrations/pgvector/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ exclude_lines = [
"if TYPE_CHECKING:",
]

[tool.pytest.ini_options]
markers = [
"integration: integration tests"
]


[[tool.mypy.overrides]]
module = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0
from .embedding_retriever import PgvectorEmbeddingRetriever
from .keyword_retriever import PgvectorKeywordRetriever

__all__ = ["PgvectorEmbeddingRetriever"]
__all__ = ["PgvectorEmbeddingRetriever", "PgvectorKeywordRetriever"]
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
):
"""
:param document_store: An instance of `PgvectorDocumentStore}.
:param document_store: An instance of `PgvectorDocumentStore`.
:param filters: Filters applied to the retrieved Documents.
:param top_k: Maximum number of Documents to return.
:param vector_function: The similarity function to use when searching for similar embeddings.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore


@component
class PgvectorKeywordRetriever:
    """
    Retrieve documents from the `PgvectorDocumentStore`, based on keywords.

    To rank the documents, the `ts_rank_cd` function of PostgreSQL is used.
    It considers how often the query terms appear in the document, how close together the terms are in the document,
    and how important the part of the document where they occur is.
    For more details, see
    [Postgres documentation](https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-RANKING).

    Usage example:
    ```python
    from haystack.document_stores import DuplicatePolicy
    from haystack import Document

    from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
    from haystack_integrations.components.retrievers.pgvector import PgvectorKeywordRetriever

    # Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database.
    # e.g., "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"

    document_store = PgvectorDocumentStore(language="english", recreate_table=True)

    documents = [Document(content="There are over 7,000 languages spoken around the world today."),
                 Document(content="Elephants have been observed to behave in a way that indicates..."),
                 Document(content="In certain places, you can witness the phenomenon of bioluminescent waves.")]

    document_store.write_documents(documents, policy=DuplicatePolicy.OVERWRITE)

    retriever = PgvectorKeywordRetriever(document_store=document_store)

    result = retriever.run(query="languages")

    assert result["documents"][0].content == "There are over 7,000 languages spoken around the world today."
    ```
    """

    def __init__(
        self,
        *,
        document_store: PgvectorDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
    ):
        """
        :param document_store: An instance of `PgvectorDocumentStore`.
        :param filters: Filters applied to the retrieved Documents.
        :param top_k: Maximum number of Documents to return.

        :raises ValueError: If `document_store` is not an instance of `PgvectorDocumentStore`.
        """
        if not isinstance(document_store, PgvectorDocumentStore):
            msg = "document_store must be an instance of PgvectorDocumentStore"
            raise ValueError(msg)

        self.document_store = document_store
        self.filters = filters or {}
        self.top_k = top_k

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            filters=self.filters,
            top_k=self.top_k,
            document_store=self.document_store.to_dict(),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PgvectorKeywordRetriever":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        # The nested document store dict must be rebuilt into a real
        # PgvectorDocumentStore before the generic deserializer runs.
        doc_store_params = data["init_parameters"]["document_store"]
        data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params)
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: str,
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ):
        """
        Retrieve documents from the `PgvectorDocumentStore`, based on keywords.

        :param query: String to search in `Document`s' content.
        :param filters: Filters applied to the retrieved Documents.
        :param top_k: Maximum number of Documents to return.

        :returns: A dictionary with the following keys:
            - `documents`: List of `Document`s that match the query.
        """
        # Runtime arguments take precedence over the values set at init time.
        filters = filters or self.filters
        top_k = top_k or self.top_k

        docs = self.document_store._keyword_retrieval(
            query=query,
            filters=filters,
            top_k=top_k,
        )
        return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@
meta = EXCLUDED.meta
"""

KEYWORD_QUERY = """
SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score
FROM {table_name}, plainto_tsquery({language}, %s) query
WHERE to_tsvector({language}, content) @@ query
"""

VALID_VECTOR_FUNCTIONS = ["cosine_similarity", "inner_product", "l2_distance"]

VECTOR_FUNCTION_TO_POSTGRESQL_OPS = {
Expand All @@ -65,6 +71,8 @@

HNSW_INDEX_NAME = "haystack_hnsw_index"

KEYWORD_INDEX_NAME = "haystack_keyword_index"


class PgvectorDocumentStore:
"""
Expand All @@ -76,6 +84,7 @@ def __init__(
*,
connection_string: Secret = Secret.from_env_var("PG_CONN_STR"),
table_name: str = "haystack_documents",
language: str = "english",
embedding_dimension: int = 768,
vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity",
recreate_table: bool = False,
Expand All @@ -92,6 +101,10 @@ def __init__(
:param connection_string: The connection string to use to connect to the PostgreSQL database, defined as an
environment variable, e.g.: `PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"`
:param table_name: The name of the table to use to store Haystack documents.
:param language: The language to be used to parse query and document content in keyword retrieval.
To see the list of available languages, you can run the following SQL query in your PostgreSQL database:
`SELECT cfgname FROM pg_ts_config;`.
More information can be found in this [StackOverflow answer](https://stackoverflow.com/a/39752553).
:param embedding_dimension: The dimension of the embedding.
:param vector_function: The similarity function to use when searching for similar embeddings.
`"cosine_similarity"` and `"inner_product"` are similarity functions and
Expand All @@ -116,7 +129,7 @@ def __init__(
[pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw)
:param hnsw_ef_search: The `ef_search` parameter to use at query time. Only used if search_strategy is set to
`"hnsw"`. You can find more information about this parameter in the
[pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw)
[pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw).
"""

self.connection_string = connection_string
Expand All @@ -131,6 +144,7 @@ def __init__(
self.hnsw_recreate_index_if_exists = hnsw_recreate_index_if_exists
self.hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {}
self.hnsw_ef_search = hnsw_ef_search
self.language = language

connection = connect(self.connection_string.resolve_value())
connection.autocommit = True
Expand All @@ -146,6 +160,7 @@ def __init__(
if recreate_table:
self.delete_table()
self._create_table_if_not_exists()
self._create_keyword_index_if_not_exists()

if search_strategy == "hnsw":
self._handle_hnsw()
Expand All @@ -168,6 +183,7 @@ def to_dict(self) -> Dict[str, Any]:
hnsw_recreate_index_if_exists=self.hnsw_recreate_index_if_exists,
hnsw_index_creation_kwargs=self.hnsw_index_creation_kwargs,
hnsw_ef_search=self.hnsw_ef_search,
language=self.language,
)

@classmethod
Expand Down Expand Up @@ -231,6 +247,29 @@ def delete_table(self):

self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore")

def _create_keyword_index_if_not_exists(self):
    """
    Internal method to create the keyword index if not exists.

    Checks `pg_indexes` for an index named `KEYWORD_INDEX_NAME` on this
    store's table and, if absent, creates a GIN index over
    `to_tsvector(language, content)` to speed up full-text keyword search.
    """
    check_result = self._execute_sql(
        "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s",
        (self.table_name, KEYWORD_INDEX_NAME),
        "Could not check if keyword index exists",
    )
    # A row comes back iff the index already exists; nothing to do then.
    if check_result.fetchone() is not None:
        return

    create_index_stmt = SQL(
        "CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))"
    ).format(
        index_name=Identifier(KEYWORD_INDEX_NAME),
        table_name=Identifier(self.table_name),
        language=SQLLiteral(self.language),
    )
    self._execute_sql(create_index_stmt, error_msg="Could not create keyword index on table")

def _handle_hnsw(self):
"""
Internal method to handle the HNSW index creation.
Expand Down Expand Up @@ -475,6 +514,54 @@ def delete_documents(self, document_ids: List[str]) -> None:

self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore")

def _keyword_retrieval(
    self,
    query: str,
    *,
    filters: Optional[Dict[str, Any]] = None,
    top_k: int = 10,
) -> List[Document]:
    """
    Retrieves documents that are most similar to the query using a full-text search.

    This method is not meant to be part of the public interface of
    `PgvectorDocumentStore` and it should not be called directly.
    `PgvectorKeywordRetriever` uses this method directly and is the public interface for it.

    :param query: String to search in the documents' content.
    :param filters: Haystack filters applied in addition to the full-text match.
    :param top_k: Maximum number of Documents to return.
    :raises ValueError: If `query` is empty.
    :returns: List of Documents that are most similar to `query`
    """
    if not query:
        msg = "query must be a non-empty string"
        raise ValueError(msg)

    # KEYWORD_QUERY binds the query text through the `%s` placeholder at
    # execution time (see the params tuple below); only table name and
    # language are composed into the statement here. Do NOT interpolate
    # the user-supplied query string into the SQL itself.
    sql_select = SQL(KEYWORD_QUERY).format(
        table_name=Identifier(self.table_name),
        language=SQLLiteral(self.language),
    )

    where_params = ()
    sql_where_clause = SQL("")
    if filters:
        # The full-text condition is already in the WHERE clause, so the
        # filter conditions are appended with AND rather than WHERE.
        sql_where_clause, where_params = _convert_filters_to_where_clause_and_params(
            filters=filters, operator="AND"
        )

    sql_sort = SQL(" ORDER BY score DESC LIMIT {top_k}").format(top_k=SQLLiteral(top_k))

    sql_query = sql_select + sql_where_clause + sql_sort

    result = self._execute_sql(
        sql_query,
        (query, *where_params),
        error_msg="Could not retrieve documents from PgvectorDocumentStore.",
        cursor=self._dict_cursor,
    )

    records = result.fetchall()
    docs = self._from_pg_to_haystack_documents(records)
    return docs

def _embedding_retrieval(
self,
query_embedding: List[float],
Expand All @@ -489,6 +576,7 @@ def _embedding_retrieval(
This method is not meant to be part of the public interface of
`PgvectorDocumentStore` and it should not be called directly.
`PgvectorEmbeddingRetriever` uses this method directly and is the public interface for it.

:returns: List of Documents that are most similar to `query_embedding`
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from itertools import chain
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Literal, Tuple

from haystack.errors import FilterError
from pandas import DataFrame
Expand All @@ -22,7 +22,9 @@
NO_VALUE = "no_value"


def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> Tuple[SQL, Tuple]:
def _convert_filters_to_where_clause_and_params(
filters: Dict[str, Any], operator: Literal["WHERE", "AND"] = "WHERE"
) -> Tuple[SQL, Tuple]:
"""
Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL.
"""
Expand All @@ -31,7 +33,7 @@ def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> Tupl
else:
query, values = _parse_logical_condition(filters)

where_clause = SQL(" WHERE ") + SQL(query)
where_clause = SQL(f" {operator} ") + SQL(query)
params = tuple(value for value in values if value != NO_VALUE)

return where_clause, params
Expand Down
4 changes: 3 additions & 1 deletion integrations/pgvector/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ def patches_for_unit_tests():
) as mock_delete, patch(
"haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._create_table_if_not_exists"
) as mock_create, patch(
"haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._create_keyword_index_if_not_exists"
) as mock_create_kw_index, patch(
"haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._handle_hnsw"
) as mock_hnsw:

yield mock_connect, mock_register, mock_delete, mock_create, mock_hnsw
yield mock_connect, mock_register, mock_delete, mock_create, mock_create_kw_index, mock_hnsw


@pytest.fixture
Expand Down
1 change: 1 addition & 0 deletions integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_to_dict(monkeypatch):
"recreate_table": True,
"search_strategy": "hnsw",
"hnsw_recreate_index_if_exists": True,
"language": "english",
"hnsw_index_creation_kwargs": {"m": 32, "ef_construction": 128},
"hnsw_ef_search": 50,
},
Expand Down
Loading