From b42ec5c4c38a2fbcb6528c9631b4a633e92dab04 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 21 Nov 2024 10:06:22 +0100 Subject: [PATCH] feat: Pgvector - recreate the connection if it is no longer valid (#1202) * try refreshing connection * small improvements * rename method --- .../pgvector/document_store.py | 46 +++++++++++++++++-- .../pgvector/tests/test_document_store.py | 19 ++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 8e9c0f2fc..6682c2fee 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -156,29 +156,41 @@ def __init__( self._connection = None self._cursor = None self._dict_cursor = None + self._table_initialized = False @property def cursor(self): - if self._cursor is None: + if self._cursor is None or not self._connection_is_valid(self._connection): self._create_connection() return self._cursor @property def dict_cursor(self): - if self._dict_cursor is None: + if self._dict_cursor is None or not self._connection_is_valid(self._connection): self._create_connection() return self._dict_cursor @property def connection(self): - if self._connection is None: + if self._connection is None or not self._connection_is_valid(self._connection): self._create_connection() return self._connection def _create_connection(self): + """ + Internal method to create a connection to the PostgreSQL database. + """ + + # close the connection if it already exists + if self._connection: + try: + self._connection.close() + except Error as e: + logger.debug("Failed to close connection: %s", str(e)) + conn_str = self.connection_string.resolve_value() or "" connection = connect(conn_str) connection.autocommit = True @@ -189,16 +201,40 @@ def _create_connection(self): self._cursor = self._connection.cursor() self._dict_cursor = self._connection.cursor(row_factory=dict_row) - # Init schema + if not self._table_initialized: + self._initialize_table() + + return self._connection + + def _initialize_table(self): + """ + Internal method to initialize the table. + """ if self.recreate_table: self.delete_table() + self._create_table_if_not_exists() self._create_keyword_index_if_not_exists() if self.search_strategy == "hnsw": self._handle_hnsw() - return self._connection + self._table_initialized = True + + @staticmethod + def _connection_is_valid(connection): + """ + Internal method to check if the connection is still valid. + """ + + # implementation inspired to psycopg pool + # https://github.com/psycopg/psycopg/blob/d38cf7798b0c602ff43dac9f20bbab96237a9c38/psycopg_pool/psycopg_pool/pool.py#L528 + + try: + connection.execute("") + except Error: + return False + return True def to_dict(self) -> Dict[str, Any]: """ diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 4af4fc8de..c6f160f91 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -41,6 +41,25 @@ def test_write_dataframe(self, document_store: PgvectorDocumentStore): retrieved_docs = document_store.filter_documents() assert retrieved_docs == docs + def test_connection_check_and_recreation(self, document_store: PgvectorDocumentStore): + original_connection = document_store.connection + + with patch.object(PgvectorDocumentStore, "_connection_is_valid", return_value=False): + new_connection = document_store.connection + + # verify that a new connection is created + assert new_connection is not original_connection + assert document_store._connection == new_connection + assert original_connection.closed + + assert document_store._cursor is not None + assert document_store._dict_cursor is not None + + # test with new connection + with patch.object(PgvectorDocumentStore, "_connection_is_valid", return_value=True): + same_connection = document_store.connection + assert same_connection is document_store._connection + @pytest.mark.usefixtures("patches_for_unit_tests") def test_init(monkeypatch):