Skip to content

Commit

Permalink
Add adapter and test for postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
DataDaoDe committed Feb 1, 2025
1 parent 4649343 commit a7f84e4
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 42 deletions.
105 changes: 84 additions & 21 deletions advanced_alchemy/utils/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Protocol

from sqlalchemy import Engine, text
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy import URL, Engine, create_engine, text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from advanced_alchemy.config.sync import SQLAlchemySyncConfig

Expand All @@ -13,27 +13,26 @@


class Adapter(Protocol):
supported_drivers: list[str] = []
supported_drivers: set[str] = set()
dialect: str
config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig
encoding: str
encoding: str | None = None

_engine: Engine | None
_async_engine: AsyncEngine | None
_database: str | None
engine: Engine | None = None
async_engine: AsyncEngine | None = None
original_database_name: str | None = None

def __init__(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None:
self.config = config
self.encoding = encoding
self._engine: Engine | None = None
self._async_engine: AsyncEngine | None = None

if isinstance(config, SQLAlchemySyncConfig):
self._engine = config.get_engine()
self._database = self._engine.url.database
self.setup(config)
else:
self._async_engine = config.get_engine()
self._database = self._async_engine.url.database
self.setup_async(config)

def setup_async(self, config: SQLAlchemyAsyncConfig) -> None: ...
def setup(self, config: SQLAlchemySyncConfig) -> None: ...

async def create_async(self) -> None: ...
def create(self) -> None: ...
Expand All @@ -43,18 +42,26 @@ def drop(self) -> None: ...


class SQLiteAdapter(Adapter):
supported_drivers: list[str] = ["pysqlite", "aiosqlite"]
supported_drivers: set[str] = {"pysqlite", "aiosqlite"}
dialect: str = "sqlite"

def setup(self, config: SQLAlchemySyncConfig) -> None:
self.engine = config.get_engine()
self.original_database_name = self.engine.url.database

def setup_async(self, config: SQLAlchemyAsyncConfig) -> None:
self.async_engine = config.get_engine()
self.original_database_name = self.async_engine.url.database

def create(self) -> None:
if self._engine is not None and self._database and self._database != ":memory:":
with self._engine.begin() as conn:
if self.engine is not None and self.original_database_name and self.original_database_name != ":memory:":
with self.engine.begin() as conn:
conn.execute(text("CREATE TABLE DB(id int)"))
conn.execute(text("DROP TABLE DB"))

async def create_async(self) -> None:
if self._async_engine is not None:
async with self._async_engine.begin() as conn:
if self.async_engine is not None:
async with self.async_engine.begin() as conn:
await conn.execute(text("CREATE TABLE DB(id int)"))
await conn.execute(text("DROP TABLE DB"))

Expand All @@ -65,11 +72,67 @@ async def drop_async(self) -> None:
return self._drop()

def _drop(self) -> None:
if self._database and self._database != ":memory:":
Path(self._database).unlink()
if self.original_database_name and self.original_database_name != ":memory:":
Path(self.original_database_name).unlink()


class PostgresAdapter(Adapter):
supported_drivers: set[str] = {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
dialect: str = "postgresql"

def _set_url(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> URL:
original_url = self.config.get_engine().url
self.original_database_name = original_url.database
return original_url._replace(database="postgres")

def setup(self, config: SQLAlchemySyncConfig) -> None:
updated_url = self._set_url(config)
self.engine = create_engine(updated_url, isolation_level="AUTOCOMMIT")

def setup_async(self, config: SQLAlchemyAsyncConfig) -> None:
updated_url = self._set_url(config)
self.async_engine = create_async_engine(updated_url, isolation_level="AUTOCOMMIT")

ADAPTERS = {"sqlite": SQLiteAdapter}
def create(self) -> None:
if self.engine:
with self.engine.begin() as conn:
sql = f"CREATE DATABASE {self.original_database_name} ENCODING '{self.encoding}'"
conn.execute(text(sql))

async def create_async(self) -> None:
if self.async_engine:
async with self.async_engine.begin() as conn:
sql = f"CREATE DATABASE {self.original_database_name} ENCODING '{self.encoding}'"
await conn.execute(text(sql))

def drop(self) -> None:
if self.engine:
with self.engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
sql = self._disconnect_users_sql(version, self.original_database_name)
conn.execute(text(sql))
conn.execute(text(f"DROP DATABASE {self.original_database_name}"))

async def drop_async(self) -> None:
if self.async_engine:
async with self.async_engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
sql = self._disconnect_users_sql(version, self.original_database_name)
await conn.execute(text(sql))
await conn.execute(text(f"DROP DATABASE {self.original_database_name}"))

def _disconnect_users_sql(self, version: tuple[int, int] | None, database: str | None) -> str:
pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid"
return f"""
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();""" # noqa: S608


ADAPTERS = {"sqlite": SQLiteAdapter, "postgresql": PostgresAdapter}


def get_adapter(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> Adapter:
Expand Down
88 changes: 67 additions & 21 deletions tests/integration/test_database_commands.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

from pathlib import Path
from typing import AsyncGenerator, Generator, cast
from typing import AsyncGenerator, Generator

import pytest
from pytest import FixtureRequest
from sqlalchemy import Engine, NullPool, create_engine, select, text
from sqlalchemy import URL, Engine, NullPool, create_engine, select, text
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker

Expand Down Expand Up @@ -46,24 +45,37 @@ async def aiosqlite_engine_cd(tmp_path: Path) -> AsyncGenerator[AsyncEngine, Non
await engine.dispose()


@pytest.fixture(
params=[
pytest.param(
"sqlite_engine_cd",
marks=[
pytest.mark.sqlite,
pytest.mark.integration,
],
)
]
)
def sync_sqlalchemy_config(request: FixtureRequest) -> Generator[SQLAlchemySyncConfig, None, None]:
engine = cast(Engine, request.getfixturevalue(request.param))
orm_registry = base.create_registry()
yield SQLAlchemySyncConfig(
engine_instance=engine,
session_maker=sessionmaker(bind=engine, expire_on_commit=False),
metadata=orm_registry.metadata,
@pytest.fixture()
async def asyncpg_engine_cd(docker_ip: str, postgres_service: None) -> AsyncGenerator[AsyncEngine, None]:
"""Postgresql instance for end-to-end testing."""
yield create_async_engine(
URL(
drivername="postgresql+asyncpg",
username="postgres",
password="super-secret",
host=docker_ip,
port=5423,
database="testing_create_delete",
query={}, # type:ignore[arg-type]
),
poolclass=NullPool,
)


@pytest.fixture()
def psycopg_engine_cd(docker_ip: str, postgres_service: None) -> Generator[Engine, None, None]:
"""Postgresql instance for end-to-end testing."""
yield create_engine(
URL(
drivername="postgresql+psycopg",
username="postgres",
password="super-secret",
host=docker_ip,
port=5423,
database="postgres",
query={}, # type:ignore[arg-type]
),
poolclass=NullPool,
)


Expand Down Expand Up @@ -111,3 +123,37 @@ async def test_create_and_drop_sqlite_async(aiosqlite_engine_cd: AsyncEngine, tm
# always clean up
if Path(file_path).exists():
Path(file_path).unlink()


async def test_create_and_drop_postgres_async(asyncpg_engine_cd: AsyncEngine, asyncpg_engine: AsyncEngine) -> None:
orm_registry = base.create_registry()
cfg = SQLAlchemyAsyncConfig(
engine_instance=asyncpg_engine_cd,
session_maker=async_sessionmaker(bind=asyncpg_engine_cd, expire_on_commit=False),
metadata=orm_registry.metadata,
)

dbname = asyncpg_engine_cd.url.database
exists_sql = f"""
select exists(
SELECT datname FROM pg_catalog.pg_database WHERE lower(datname) = lower('{dbname}')
);
"""

# ensure database does not exist
async with asyncpg_engine.begin() as conn:
result = await conn.execute(text(exists_sql))
assert not result.scalar_one()

await create_database(cfg)
async with asyncpg_engine.begin() as conn:
result = await conn.execute(text(exists_sql))
assert result.scalar_one()

await drop_database(cfg)

async with asyncpg_engine.begin() as conn:
result = await conn.execute(text(exists_sql))
assert not result.scalar_one()

await asyncpg_engine.dispose()

0 comments on commit a7f84e4

Please sign in to comment.