Skip to content

Commit

Permalink
Initialize postgres_cluster fixture lazily
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhollas committed May 31, 2024
1 parent dd31591 commit 5f3a21a
Showing 1 changed file with 37 additions and 26 deletions.
63 changes: 37 additions & 26 deletions src/aiida/tools/pytest_fixtures/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,49 @@

import pathlib
import typing as t
from uuid import uuid4

import pytest
from pgtest.pgtest import PGTest

if t.TYPE_CHECKING:
from pgtest.pgtest import PGTest


@pytest.fixture(scope='session')
def postgres_cluster():
"""Create a temporary and isolated PostgreSQL cluster using ``pgtest`` and cleanup after the yield.
class PostgresCluster:
def __init__(self):
# We initialize the cluster lazily
self.cluster = None

:param database_name: Name of the database.
:param database_username: Username to use for authentication.
:param database_password: Password to use for authentication.
:returns: Dictionary with parameters to connect to the PostgreSQL cluster.
"""
from uuid import uuid4
def _create(self):
try:
self.cluster = PGTest()
except OSError as e:
raise RuntimeError('Could not initialize PostgreSQL cluster') from e

from pgtest.pgtest import PGTest
def _close(self):
if self.cluster is not None:
self.cluster.close()

def create_database(
database_name: str | None = None, database_username: str | None = None, database_password: str | None = None
self,
database_name: str | None = None,
database_username: str | None = None,
database_password: str | None = None,
) -> dict[str, str]:
from aiida.manage.external.postgres import Postgres

if self.cluster is None:
self._create()

postgres_config = {
'database_engine': 'postgresql_psycopg2',
'database_name': database_name or str(uuid4()),
'database_username': database_username or 'guest',
'database_password': database_password or 'guest',
}

postgres = Postgres(interactive=False, quiet=True, dbinfo=cluster.dsn) # type: ignore[union-attr]
postgres = Postgres(interactive=False, quiet=True, dbinfo=self.cluster.dsn) # type: ignore[union-attr]
if not postgres.dbuser_exists(postgres_config['database_username']):
postgres.create_dbuser(
postgres_config['database_username'], postgres_config['database_password'], 'CREATEDB'
Expand All @@ -48,16 +58,21 @@ def create_database(

return postgres_config

cluster = None
try:
cluster = PGTest()
cluster.create_database = create_database
yield cluster
except OSError:
yield None
finally:
if cluster is not None:
cluster.close()

# TODO: Update docstring accordingly
@pytest.fixture(scope='session')
def postgres_cluster():
"""Create a temporary and isolated PostgreSQL cluster using ``pgtest`` and cleanup after the yield.
:param database_name: Name of the database.
:param database_username: Username to use for authentication.
:param database_password: Password to use for authentication.
:returns: Dictionary with parameters to connect to the PostgreSQL cluster.
"""

cluster = PostgresCluster()
yield cluster
cluster._close()


@pytest.fixture(scope='session')
Expand All @@ -78,10 +93,6 @@ def config_psql_dos(
def factory(
database_name: str | None = None, database_username: str | None = None, database_password: str | None = None
) -> dict[str, t.Any]:
if not postgres_cluster:
msg = 'Could not initialize PostgreSQL cluster'
raise RuntimeError(msg)

storage_config: dict[str, t.Any] = postgres_cluster.create_database(
database_name=database_name,
database_username=database_username,
Expand Down

0 comments on commit 5f3a21a

Please sign in to comment.