diff --git a/integrations/pgvector/examples/example.py b/integrations/pgvector/examples/embedding_retrieval.py similarity index 100% rename from integrations/pgvector/examples/example.py rename to integrations/pgvector/examples/embedding_retrieval.py diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index 39e2183cb..b440cf28e 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -174,6 +174,11 @@ exclude_lines = [ "if TYPE_CHECKING:", ] +[tool.pytest.ini_options] +markers = [ + "integration: integration tests" +] + [[tool.mypy.overrides]] module = [ diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py index ec0cf0dc4..ea9fa8fe7 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py @@ -2,5 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 from .embedding_retriever import PgvectorEmbeddingRetriever +from .keyword_retriever import PgvectorKeywordRetriever -__all__ = ["PgvectorEmbeddingRetriever"] +__all__ = ["PgvectorEmbeddingRetriever", "PgvectorKeywordRetriever"] diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py index 6085545cb..be894dcf7 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py @@ -64,7 +64,7 @@ def __init__( vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, ): """ - :param document_store: An instance of 
`PgvectorDocumentStore}. + :param document_store: An instance of `PgvectorDocumentStore`. :param filters: Filters applied to the retrieved Documents. :param top_k: Maximum number of Documents to return. :param vector_function: The similarity function to use when searching for similar embeddings. diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py new file mode 100644 index 000000000..c09ac9bb5 --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +@component +class PgvectorKeywordRetriever: + """ + Retrieve documents from the `PgvectorDocumentStore`, based on keywords. + + To rank the documents, the `ts_rank_cd` function of PostgreSQL is used. + It considers how often the query terms appear in the document, how close together the terms are in the document, + and how important is the part of the document where they occur. + For more details, see + [Postgres documentation](https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-RANKING). + + Usage example: + ```python + from haystack.document_stores import DuplicatePolicy + from haystack import Document + + from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + from haystack_integrations.components.retrievers.pgvector import PgvectorKeywordRetriever + + # Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database. 
+ # e.g., "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" + + document_store = PgvectorDocumentStore(language="english", recreate_table=True) + + documents = [Document(content="There are over 7,000 languages spoken around the world today."), + Document(content="Elephants have been observed to behave in a way that indicates..."), + Document(content="In certain places, you can witness the phenomenon of bioluminescent waves.")] + + document_store.write_documents(documents, policy=DuplicatePolicy.OVERWRITE) + + retriever = PgvectorKeywordRetriever(document_store=document_store) + + result = retriever.run(query="languages") + + assert result['documents'][0].content == "There are over 7,000 languages spoken around the world today." + """ + + def __init__( + self, + *, + document_store: PgvectorDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ): + """ + :param document_store: An instance of `PgvectorDocumentStore`. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + + :raises ValueError: If `document_store` is not an instance of `PgvectorDocumentStore`. + """ + if not isinstance(document_store, PgvectorDocumentStore): + msg = "document_store must be an instance of PgvectorDocumentStore" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PgvectorKeywordRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component.
+ """ + doc_store_params = data["init_parameters"]["document_store"] + data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + ): + """ + Retrieve documents from the `PgvectorDocumentStore`, based on keywords. + + :param query: String to search in `Document`s' content. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + + :returns: A dictionary with the following keys: + - `documents`: List of `Document`s that match the query. + """ + filters = filters or self.filters + top_k = top_k or self.top_k + + docs = self.document_store._keyword_retrieval( + query=query, + filters=filters, + top_k=top_k, + ) + return {"documents": docs} diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index da08a5f19..bb663f936 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -53,6 +53,12 @@ meta = EXCLUDED.meta """ +KEYWORD_QUERY = """ +SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score +FROM {table_name}, plainto_tsquery({language}, %s) query +WHERE to_tsvector({language}, content) @@ query +""" + VALID_VECTOR_FUNCTIONS = ["cosine_similarity", "inner_product", "l2_distance"] VECTOR_FUNCTION_TO_POSTGRESQL_OPS = { @@ -65,6 +71,8 @@ HNSW_INDEX_NAME = "haystack_hnsw_index" +KEYWORD_INDEX_NAME = "haystack_keyword_index" + class PgvectorDocumentStore: """ @@ -76,6 +84,7 @@ def __init__( *, connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), table_name: 
str = "haystack_documents", + language: str = "english", embedding_dimension: int = 768, vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity", recreate_table: bool = False, @@ -92,6 +101,10 @@ def __init__( :param connection_string: The connection string to use to connect to the PostgreSQL database, defined as an environment variable, e.g.: `PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"` :param table_name: The name of the table to use to store Haystack documents. + :param language: The language to be used to parse query and document content in keyword retrieval. + To see the list of available languages, you can run the following SQL query in your PostgreSQL database: + `SELECT cfgname FROM pg_ts_config;`. + More information can be found in this [StackOverflow answer](https://stackoverflow.com/a/39752553). :param embedding_dimension: The dimension of the embedding. :param vector_function: The similarity function to use when searching for similar embeddings. `"cosine_similarity"` and `"inner_product"` are similarity functions and @@ -116,7 +129,7 @@ def __init__( [pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw) :param hnsw_ef_search: The `ef_search` parameter to use at query time. Only used if search_strategy is set to `"hnsw"`. You can find more information about this parameter in the - [pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw) + [pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw). 
""" self.connection_string = connection_string @@ -131,6 +144,7 @@ def __init__( self.hnsw_recreate_index_if_exists = hnsw_recreate_index_if_exists self.hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {} self.hnsw_ef_search = hnsw_ef_search + self.language = language connection = connect(self.connection_string.resolve_value()) connection.autocommit = True @@ -146,6 +160,7 @@ def __init__( if recreate_table: self.delete_table() self._create_table_if_not_exists() + self._create_keyword_index_if_not_exists() if search_strategy == "hnsw": self._handle_hnsw() @@ -168,6 +183,7 @@ def to_dict(self) -> Dict[str, Any]: hnsw_recreate_index_if_exists=self.hnsw_recreate_index_if_exists, hnsw_index_creation_kwargs=self.hnsw_index_creation_kwargs, hnsw_ef_search=self.hnsw_ef_search, + language=self.language, ) @classmethod @@ -231,6 +247,29 @@ def delete_table(self): self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") + def _create_keyword_index_if_not_exists(self): + """ + Internal method to create the keyword index if not exists. + """ + index_exists = bool( + self._execute_sql( + "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", + (self.table_name, KEYWORD_INDEX_NAME), + "Could not check if keyword index exists", + ).fetchone() + ) + + sql_create_index = SQL( + "CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))" + ).format( + index_name=Identifier(KEYWORD_INDEX_NAME), + table_name=Identifier(self.table_name), + language=SQLLiteral(self.language), + ) + + if not index_exists: + self._execute_sql(sql_create_index, error_msg="Could not create keyword index on table") + def _handle_hnsw(self): """ Internal method to handle the HNSW index creation. 
@@ -475,6 +514,54 @@ def delete_documents(self, document_ids: List[str]) -> None: self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") + def _keyword_retrieval( + self, + query: str, + *, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query using a full-text search. + + This method is not meant to be part of the public interface of + `PgvectorDocumentStore` and it should not be called directly. + `PgvectorKeywordRetriever` uses this method directly and is the public interface for it. + + :returns: List of Documents that are most similar to `query` + """ + if not query: + msg = "query must be a non-empty string" + raise ValueError(msg) + + sql_select = SQL(KEYWORD_QUERY).format( + table_name=Identifier(self.table_name), + language=SQLLiteral(self.language), + query=SQLLiteral(query), + ) + + where_params = () + sql_where_clause = SQL("") + if filters: + sql_where_clause, where_params = _convert_filters_to_where_clause_and_params( + filters=filters, operator="AND" + ) + + sql_sort = SQL(" ORDER BY score DESC LIMIT {top_k}").format(top_k=SQLLiteral(top_k)) + + sql_query = sql_select + sql_where_clause + sql_sort + + result = self._execute_sql( + sql_query, + (query, *where_params), + error_msg="Could not retrieve documents from PgvectorDocumentStore.", + cursor=self._dict_cursor, + ) + + records = result.fetchall() + docs = self._from_pg_to_haystack_documents(records) + return docs + def _embedding_retrieval( self, query_embedding: List[float], @@ -489,6 +576,7 @@ def _embedding_retrieval( This method is not meant to be part of the public interface of `PgvectorDocumentStore` and it should not be called directly. `PgvectorEmbeddingRetriever` uses this method directly and is the public interface for it. 
+ :returns: List of Documents that are most similar to `query_embedding` """ diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py index 6199d93ce..d3604cfb3 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from datetime import datetime from itertools import chain -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Literal, Tuple from haystack.errors import FilterError from pandas import DataFrame @@ -22,7 +22,9 @@ NO_VALUE = "no_value" -def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> Tuple[SQL, Tuple]: +def _convert_filters_to_where_clause_and_params( + filters: Dict[str, Any], operator: Literal["WHERE", "AND"] = "WHERE" +) -> Tuple[SQL, Tuple]: """ Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL. 
""" @@ -31,7 +33,7 @@ def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> Tupl else: query, values = _parse_logical_condition(filters) - where_clause = SQL(" WHERE ") + SQL(query) + where_clause = SQL(f" {operator} ") + SQL(query) params = tuple(value for value in values if value != NO_VALUE) return where_clause, params diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py index 94b35a04d..6547db9eb 100644 --- a/integrations/pgvector/tests/conftest.py +++ b/integrations/pgvector/tests/conftest.py @@ -36,10 +36,12 @@ def patches_for_unit_tests(): ) as mock_delete, patch( "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._create_table_if_not_exists" ) as mock_create, patch( + "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._create_keyword_index_if_not_exists" + ) as mock_create_kw_index, patch( "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._handle_hnsw" ) as mock_hnsw: - yield mock_connect, mock_register, mock_delete, mock_create, mock_hnsw + yield mock_connect, mock_register, mock_delete, mock_create, mock_create_kw_index, mock_hnsw @pytest.fixture diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index bf5ccd5d4..6fd7e0dc0 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -89,6 +89,7 @@ def test_to_dict(monkeypatch): "recreate_table": True, "search_strategy": "hnsw", "hnsw_recreate_index_if_exists": True, + "language": "english", "hnsw_index_creation_kwargs": {"m": 32, "ef_construction": 128}, "hnsw_ef_search": 50, }, diff --git a/integrations/pgvector/tests/test_keyword_retrieval.py b/integrations/pgvector/tests/test_keyword_retrieval.py new file mode 100644 index 000000000..4a5614165 --- /dev/null +++ 
b/integrations/pgvector/tests/test_keyword_retrieval.py @@ -0,0 +1,50 @@ +import pytest +from haystack.dataclasses.document import Document +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +@pytest.mark.integration +class TestKeywordRetrieval: + def test_keyword_retrieval(self, document_store: PgvectorDocumentStore): + docs = [ + Document(content="The quick brown fox chased the dog", embedding=[0.1] * 768), + Document(content="The fox was brown", embedding=[0.1] * 768), + Document(content="The lazy dog", embedding=[0.1] * 768), + Document(content="fox fox fox", embedding=[0.1] * 768), + ] + + document_store.write_documents(docs) + + results = document_store._keyword_retrieval(query="fox", top_k=2) + + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert results[0].id == docs[-1].id + assert results[0].score > results[1].score + + def test_keyword_retrieval_with_filters(self, document_store: PgvectorDocumentStore): + docs = [ + Document( + content="The quick brown fox chased the dog", + embedding=([0.1] * 768), + meta={"meta_field": "right_value"}, + ), + Document(content="The fox was brown", embedding=([0.1] * 768), meta={"meta_field": "right_value"}), + Document(content="The lazy dog", embedding=([0.1] * 768), meta={"meta_field": "right_value"}), + Document(content="fox fox fox", embedding=([0.1] * 768), meta={"meta_field": "wrong_value"}), + ] + + document_store.write_documents(docs) + + filters = {"field": "meta.meta_field", "operator": "==", "value": "right_value"} + + results = document_store._keyword_retrieval(query="fox", top_k=3, filters=filters) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert doc.meta["meta_field"] == "right_value" + + def test_empty_query(self, document_store: PgvectorDocumentStore): + with pytest.raises(ValueError): + document_store._keyword_retrieval(query="") diff --git a/integrations/pgvector/tests/test_retriever.py 
b/integrations/pgvector/tests/test_retrievers.py similarity index 53% rename from integrations/pgvector/tests/test_retriever.py rename to integrations/pgvector/tests/test_retrievers.py index 61381c24e..ef6f918ed 100644 --- a/integrations/pgvector/tests/test_retriever.py +++ b/integrations/pgvector/tests/test_retrievers.py @@ -6,11 +6,11 @@ import pytest from haystack.dataclasses import Document from haystack.utils.auth import EnvVarSecret -from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever +from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever, PgvectorKeywordRetriever from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore -class TestRetriever: +class TestEmbeddingRetriever: def test_init_default(self, mock_store): retriever = PgvectorEmbeddingRetriever(document_store=mock_store) assert retriever.document_store == mock_store @@ -46,6 +46,7 @@ def test_to_dict(self, mock_store): "recreate_table": True, "search_strategy": "exact_nearest_neighbor", "hnsw_recreate_index_if_exists": False, + "language": "english", "hnsw_index_creation_kwargs": {}, "hnsw_ef_search": None, }, @@ -114,3 +115,99 @@ def test_run(self): ) assert res == {"documents": [doc]} + + +class TestKeywordRetriever: + def test_init_default(self, mock_store): + retriever = PgvectorKeywordRetriever(document_store=mock_store) + assert retriever.document_store == mock_store + assert retriever.filters == {} + assert retriever.top_k == 10 + + def test_init(self, mock_store): + retriever = PgvectorKeywordRetriever(document_store=mock_store, filters={"field": "value"}, top_k=5) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + + def test_to_dict(self, mock_store): + retriever = PgvectorKeywordRetriever(document_store=mock_store, filters={"field": "value"}, top_k=5) + res = retriever.to_dict() + t = 
"haystack_integrations.components.retrievers.pgvector.keyword_retriever.PgvectorKeywordRetriever" + assert res == { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "haystack", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "language": "english", + "hnsw_index_creation_kwargs": {}, + "hnsw_ef_search": None, + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + @pytest.mark.usefixtures("patches_for_unit_tests") + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("PG_CONN_STR", "some-connection-string") + t = "haystack_integrations.components.retrievers.pgvector.keyword_retriever.PgvectorKeywordRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_ef_search": None, + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + retriever = PgvectorKeywordRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, PgvectorDocumentStore) + assert isinstance(document_store.connection_string, EnvVarSecret) + assert document_store.table_name == "haystack_test_to_dict" + assert document_store.embedding_dimension == 768 + assert 
document_store.vector_function == "cosine_similarity" + assert document_store.recreate_table + assert document_store.search_strategy == "exact_nearest_neighbor" + assert not document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {} + assert document_store.hnsw_ef_search is None + + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + + def test_run(self): + mock_store = Mock(spec=PgvectorDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store._keyword_retrieval.return_value = [doc] + + retriever = PgvectorKeywordRetriever(document_store=mock_store) + res = retriever.run(query="test query") + + mock_store._keyword_retrieval.assert_called_once_with(query="test query", filters={}, top_k=10) + + assert res == {"documents": [doc]}