Skip to content

Commit

Permalink
feat: Pgvector - recreate the connection if it is no longer valid (#1202
Browse files Browse the repository at this point in the history
)

* try refreshing connection

* small improvements

* rename method
  • Loading branch information
anakin87 authored Nov 21, 2024
1 parent 472ada8 commit b42ec5c
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,29 +156,41 @@ def __init__(
self._connection = None
self._cursor = None
self._dict_cursor = None
self._table_initialized = False

@property
def cursor(self):
if self._cursor is None:
if self._cursor is None or not self._connection_is_valid(self._connection):
self._create_connection()

return self._cursor

@property
def dict_cursor(self):
if self._dict_cursor is None:
if self._dict_cursor is None or not self._connection_is_valid(self._connection):
self._create_connection()

return self._dict_cursor

@property
def connection(self):
if self._connection is None:
if self._connection is None or not self._connection_is_valid(self._connection):
self._create_connection()

return self._connection

def _create_connection(self):
"""
Internal method to create a connection to the PostgreSQL database.
"""

# close the connection if it already exists
if self._connection:
try:
self._connection.close()
except Error as e:
logger.debug("Failed to close connection: %s", str(e))

conn_str = self.connection_string.resolve_value() or ""
connection = connect(conn_str)
connection.autocommit = True
Expand All @@ -189,16 +201,40 @@ def _create_connection(self):
self._cursor = self._connection.cursor()
self._dict_cursor = self._connection.cursor(row_factory=dict_row)

# Init schema
if not self._table_initialized:
self._initialize_table()

return self._connection

def _initialize_table(self):
"""
Internal method to initialize the table.
"""
if self.recreate_table:
self.delete_table()

self._create_table_if_not_exists()
self._create_keyword_index_if_not_exists()

if self.search_strategy == "hnsw":
self._handle_hnsw()

return self._connection
self._table_initialized = True

@staticmethod
def _connection_is_valid(connection):
"""
Internal method to check if the connection is still valid.
"""

# implementation inspired to psycopg pool
# https://github.com/psycopg/psycopg/blob/d38cf7798b0c602ff43dac9f20bbab96237a9c38/psycopg_pool/psycopg_pool/pool.py#L528

try:
connection.execute("")
except Error:
return False
return True

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down
19 changes: 19 additions & 0 deletions integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ def test_write_dataframe(self, document_store: PgvectorDocumentStore):
retrieved_docs = document_store.filter_documents()
assert retrieved_docs == docs

def test_connection_check_and_recreation(self, document_store: PgvectorDocumentStore):
original_connection = document_store.connection

with patch.object(PgvectorDocumentStore, "_connection_is_valid", return_value=False):
new_connection = document_store.connection

# verify that a new connection is created
assert new_connection is not original_connection
assert document_store._connection == new_connection
assert original_connection.closed

assert document_store._cursor is not None
assert document_store._dict_cursor is not None

# test with new connection
with patch.object(PgvectorDocumentStore, "_connection_is_valid", return_value=True):
same_connection = document_store.connection
assert same_connection is document_store._connection


@pytest.mark.usefixtures("patches_for_unit_tests")
def test_init(monkeypatch):
Expand Down

0 comments on commit b42ec5c

Please sign in to comment.