From cf8428cf3efa7d39a333b17d36b10602d8444d13 Mon Sep 17 00:00:00 2001 From: John Faucett Date: Mon, 27 Jan 2025 10:27:18 +0100 Subject: [PATCH 1/9] Add create database command --- advanced_alchemy/cli.py | 23 +++++++ advanced_alchemy/utils/databases.py | 103 ++++++++++++++++++++++++++++ tests/unit/test_cli.py | 2 + 3 files changed, 128 insertions(+) create mode 100644 advanced_alchemy/utils/databases.py diff --git a/advanced_alchemy/cli.py b/advanced_alchemy/cli.py index 01857cef..6f6f6f98 100644 --- a/advanced_alchemy/cli.py +++ b/advanced_alchemy/cli.py @@ -399,4 +399,27 @@ async def _dump_tables() -> None: return run(_dump_tables) + @database_group.command(name="create", help="Create a new database.") + @bind_key_option + @no_prompt_option + def create_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] + from rich.prompt import Confirm + from advanced_alchemy.utils.databases import create_database as _create_database + + ctx = click.get_current_context() + sqlalchemy_config = get_config_by_bind_key(ctx, bind_key) + + db_name = sqlalchemy_config.get_engine().url.database + + console.rule("[yellow]Starting database creation process[/]", align="left") + input_confirmed = ( + True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to create a new database `{db_name}`?[/]") + ) + if input_confirmed: + _create_database(sqlalchemy_config) + + @database_group.command(name="drop", help="Drop the current database.") + def drop_database() -> None: # pyright: ignore[reportUnusedFunction] + raise NotImplementedError + return database_group diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py new file mode 100644 index 00000000..b26c2e4e --- /dev/null +++ b/advanced_alchemy/utils/databases.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import asyncio +from copy import copy +from typing import TYPE_CHECKING + +from sqlalchemy import Engine, text +from sqlalchemy.ext.asyncio import AsyncEngine + +from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig +from advanced_alchemy.config.sync import SQLAlchemySyncConfig + +if TYPE_CHECKING: + from sqlalchemy.engine.url import URL + + +def set_engine_database_url(url: URL, database: str | None) -> URL: + if hasattr(url, "_replace"): + new_url = url._replace(database=database) + else: # SQLAlchemy <1.4 + new_url = copy(url) + new_url.database = database # type: ignore # noqa: PGH003 + return new_url + + +def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None: + url = config.get_engine().url + database = url.database + dialect_name = url.get_dialect().name + dialect_driver = url.get_dialect().driver + + if dialect_name == "postgresql": + url = set_engine_database_url(url, database="postgres") + elif dialect_name == "mssql": + url = set_engine_database_url(url, database="master") + elif dialect_name == "cockroachdb": + url = set_engine_database_url(url, database="defaultdb") + elif dialect_name != "sqlite": + url = set_engine_database_url(url, database=None) + + if (dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}) or ( + dialect_name == "postgresql" and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"} + ): + config.engine_config.isolation_level = "AUTOCOMMIT" + + engine = config.create_engine_callable(str(url)) + + if isinstance(engine, Engine): + create_sync_database(engine, database, encoding) + else: + asyncio.run(create_async_database(engine, database, encoding)) + + +def create_sync_database(engine: Engine, database: str | None, encoding: str = "utf8") -> None: + dialect_name = engine.url.get_dialect().name + if dialect_name == "postgresql": + with engine.begin() as conn: + sql = f"CREATE DATABASE {database} ENCODING '{encoding}'" + conn.execute(text(sql)) + + elif dialect_name == "mysql": + with engine.begin() as conn: + sql = f"CREATE DATABASE {database} CHARACTER SET = '{encoding}'" + conn.execute(text(sql)) + + elif dialect_name == "sqlite" and database != ":memory:": + if database: + with engine.begin() as conn: + conn.execute(text("CREATE TABLE DB(id int)")) + conn.execute(text("DROP TABLE DB")) + + else: + with engine.begin() as conn: + sql = f"CREATE DATABASE {database}" + conn.execute(text(sql)) + + engine.dispose() + + +async def create_async_database(engine: AsyncEngine, database: str | None, encoding: str = "utf8") -> None: + dialect_name = engine.url.get_dialect().name + if dialect_name == "postgresql": + async with engine.begin() as conn: + sql = f"CREATE DATABASE {database} ENCODING '{encoding}'" + await conn.execute(text(sql)) + + elif dialect_name == "mysql": + async with engine.begin() as conn: + sql = f"CREATE DATABASE {database} CHARACTER SET = '{encoding}'" + await conn.execute(text(sql)) + + elif dialect_name == "sqlite" and database != ":memory:": + if database: + async with engine.begin() as conn: + await conn.execute(text("CREATE TABLE DB(id int)")) + await conn.execute(text("DROP TABLE DB")) + + else: + async with engine.begin() as conn: + sql = f"CREATE DATABASE {database}" + await conn.execute(text(sql)) + + await engine.dispose() diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index c34ff0de..615e6e84 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -156,3 +156,5 @@ def test_cli_group_creation() -> None: assert "make-migrations" in cli_group.commands assert "drop-all" in cli_group.commands assert "dump-data" in cli_group.commands + assert "create" in cli_group.commands + assert "drop" in cli_group.commands From 61fefcaa17410c319de99f88ff638c9bcdf900e8 Mon Sep 17 00:00:00 2001 From: John Faucett Date: Mon, 27 Jan 2025 11:20:27 +0100 Subject: [PATCH 2/9] Add drop database command --- advanced_alchemy/cli.py | 33 +++++++- advanced_alchemy/utils/databases.py | 113 ++++++++++++++++++++++++++-- 2 files changed, 134 insertions(+), 12 deletions(-) diff --git a/advanced_alchemy/cli.py b/advanced_alchemy/cli.py index 6f6f6f98..9b4f07d5 100644 --- a/advanced_alchemy/cli.py +++ b/advanced_alchemy/cli.py @@ -402,8 +402,17 @@ async def _dump_tables() -> None: @database_group.command(name="create", help="Create a new database.") @bind_key_option @no_prompt_option - def create_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] + @click.option( + "--encoding", + "encoding", + help="Set the encoding for the created database", + type=str, + required=False, + default="utf8", + ) + def create_database(bind_key: str | None, no_prompt: bool, encoding: str) -> None: # pyright: ignore[reportUnusedFunction] from rich.prompt import Confirm + from advanced_alchemy.utils.databases import create_database as _create_database ctx = click.get_current_context() @@ -416,10 +425,26 @@ def create_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to create a new database `{db_name}`?[/]") ) if input_confirmed: - _create_database(sqlalchemy_config) + _create_database(sqlalchemy_config, encoding) @database_group.command(name="drop", help="Drop the current database.") - def drop_database() -> None: # pyright: ignore[reportUnusedFunction] - raise NotImplementedError + @bind_key_option + @no_prompt_option + def drop_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] + from rich.prompt import Confirm + + from advanced_alchemy.utils.databases import drop_database as _drop_database + + ctx = click.get_current_context() + sqlalchemy_config = get_config_by_bind_key(ctx, bind_key) + + db_name = sqlalchemy_config.get_engine().url.database + + console.rule("[yellow]Starting database deletion process[/]", align="left") + input_confirmed = ( + True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to drop database `{db_name}`?[/]") + ) + if input_confirmed: + _drop_database(sqlalchemy_config) return database_group diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py index b26c2e4e..4cea9911 100644 --- a/advanced_alchemy/utils/databases.py +++ b/advanced_alchemy/utils/databases.py @@ -1,18 +1,22 @@ from __future__ import annotations import asyncio +import os from copy import copy +from pathlib import Path from typing import TYPE_CHECKING from sqlalchemy import Engine, text from sqlalchemy.ext.asyncio import AsyncEngine -from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig -from advanced_alchemy.config.sync import SQLAlchemySyncConfig +from advanced_alchemy.utils.dataclass import Empty if TYPE_CHECKING: from sqlalchemy.engine.url import URL + from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig + from advanced_alchemy.config.sync import SQLAlchemySyncConfig + def set_engine_database_url(url: URL, database: str | None) -> URL: if hasattr(url, "_replace"): @@ -44,14 +48,13 @@ def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encodi config.engine_config.isolation_level = "AUTOCOMMIT" engine = config.create_engine_callable(str(url)) - - if isinstance(engine, Engine): - create_sync_database(engine, database, encoding) + if isinstance(engine, AsyncEngine): + asyncio.run(create_database_async(engine, database, encoding)) else: - asyncio.run(create_async_database(engine, database, encoding)) + create_database_sync(engine, database, encoding) -def create_sync_database(engine: Engine, database: str | None, encoding: str = "utf8") -> None: +def create_database_sync(engine: Engine, database: str | None, encoding: str = "utf8") -> None: dialect_name = engine.url.get_dialect().name if dialect_name == "postgresql": with engine.begin() as conn: @@ -77,7 +80,7 @@ def create_sync_database(engine: Engine, database: str | None, encoding: str = " engine.dispose() -async def create_async_database(engine: AsyncEngine, database: str | None, encoding: str = "utf8") -> None: +async def create_database_async(engine: AsyncEngine, database: str | None, encoding: str = "utf8") -> None: dialect_name = engine.url.get_dialect().name if dialect_name == "postgresql": async with engine.begin() as conn: @@ -101,3 +104,97 @@ async def create_async_database(engine: AsyncEngine, database: str | None, encod await conn.execute(text(sql)) await engine.dispose() + + +def drop_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None: + url = config.get_engine().url + database = url.database + dialect_name = url.get_dialect().name + dialect_driver = url.get_dialect().driver + + if dialect_name == "postgresql": + url = set_engine_database_url(url, database="postgres") + elif dialect_name == "mssql": + url = set_engine_database_url(url, database="master") + elif dialect_name == "cockroachdb": + url = set_engine_database_url(url, database="defaultdb") + elif dialect_name != "sqlite": + url = set_engine_database_url(url, database=None) + + if dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}: + if config.engine_config.connect_args is Empty: + config.engine_config.connect_args = {} + else: + config.engine_config.connect_args["autocommit"] = True # type: ignore # noqa: PGH003 + elif dialect_name == "postgresql" and dialect_driver in { + "asyncpg", + "pg8000", + "psycopg", + "psycopg2", + "psycopg2cffi", + }: + config.engine_config.isolation_level = "AUTOCOMMIT" + + engine = config.create_engine_callable(str(url)) + if isinstance(engine, AsyncEngine): + asyncio.run(drop_database_async(engine, database)) + else: + drop_database_sync(engine, database) + + +async def drop_database_async(engine: AsyncEngine, database: str | None) -> None: + dialect_name = engine.url.get_dialect().name + if dialect_name == "sqlite" and database != ":memory:": + if database: + Path(database).unlink() + elif dialect_name == "postgresql": + async with engine.begin() as conn: + # Disconnect all users from the database we are dropping. + version = conn.dialect.server_version_info + pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid" + sql = f""" + SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{database}' + AND {pid_column} <> pg_backend_pid(); + """ # noqa: S608 + await conn.execute(text(sql)) + + # Drop the database. + sql = f"DROP DATABASE {database}" + await conn.execute(text(sql)) + else: + async with engine.begin() as conn: + sql = f"DROP DATABASE {database}" + await conn.execute(text(sql)) + + await engine.dispose() + + +def drop_database_sync(engine: Engine, database: str | None) -> None: + dialect_name = engine.url.get_dialect().name + if dialect_name == "sqlite" and database != ":memory:": + if database: + Path(database).unlink() + elif dialect_name == "postgresql": + with engine.begin() as conn: + # Disconnect all users from the database we are dropping. + version = conn.dialect.server_version_info + pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid" + sql = f""" + SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{database}' + AND {pid_column} <> pg_backend_pid(); + """ # noqa: S608 + conn.execute(text(sql)) + + # Drop the database. + sql = f"DROP DATABASE {database}" + conn.execute(text(sql)) + else: + with engine.begin() as conn: + sql = f"DROP DATABASE {database}" + conn.execute(text(sql)) + + engine.dispose() From fbef24eb1cf7096980cd9f64a6486af943cdaf45 Mon Sep 17 00:00:00 2001 From: John Faucett Date: Mon, 27 Jan 2025 11:32:44 +0100 Subject: [PATCH 3/9] A little clean-up --- advanced_alchemy/utils/databases.py | 62 ++++++++++++++--------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py index 4cea9911..5ed57db2 100644 --- a/advanced_alchemy/utils/databases.py +++ b/advanced_alchemy/utils/databases.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import os from copy import copy from pathlib import Path from typing import TYPE_CHECKING @@ -18,7 +17,7 @@ from advanced_alchemy.config.sync import SQLAlchemySyncConfig -def set_engine_database_url(url: URL, database: str | None) -> URL: +def _set_engine_database_url(url: URL, database: str | None) -> URL: if hasattr(url, "_replace"): new_url = url._replace(database=database) else: # SQLAlchemy <1.4 @@ -27,20 +26,28 @@ def set_engine_database_url(url: URL, database: str | None) -> URL: return new_url -def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None: - url = config.get_engine().url - database = url.database +def get_masterdb_url(url: URL) -> URL: dialect_name = url.get_dialect().name - dialect_driver = url.get_dialect().driver if dialect_name == "postgresql": - url = set_engine_database_url(url, database="postgres") + url = _set_engine_database_url(url, database="postgres") elif dialect_name == "mssql": - url = set_engine_database_url(url, database="master") + url = _set_engine_database_url(url, database="master") elif dialect_name == "cockroachdb": - url = set_engine_database_url(url, database="defaultdb") + url = _set_engine_database_url(url, database="defaultdb") elif dialect_name != "sqlite": - url = set_engine_database_url(url, database=None) + url = _set_engine_database_url(url, database=None) + + return url + + +def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None: + url = config.get_engine().url + database = url.database + dialect_name = url.get_dialect().name + dialect_driver = url.get_dialect().driver + + url = get_masterdb_url(url) if (dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}) or ( dialect_name == "postgresql" and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"} @@ -112,14 +119,7 @@ def drop_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None: dialect_name = url.get_dialect().name dialect_driver = url.get_dialect().driver - if dialect_name == "postgresql": - url = set_engine_database_url(url, database="postgres") - elif dialect_name == "mssql": - url = set_engine_database_url(url, database="master") - elif dialect_name == "cockroachdb": - url = set_engine_database_url(url, database="defaultdb") - elif dialect_name != "sqlite": - url = set_engine_database_url(url, database=None) + url = get_masterdb_url(url) if dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}: if config.engine_config.connect_args is Empty: @@ -142,6 +142,16 @@ def drop_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None: drop_database_sync(engine, database) +def _disconnect_users_sql(version: tuple[int, int] | None, database: str | None) -> str: + pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid" + return f""" + SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{database}' + AND {pid_column} <> pg_backend_pid(); + """ # noqa: S608 + + async def drop_database_async(engine: AsyncEngine, database: str | None) -> None: dialect_name = engine.url.get_dialect().name if dialect_name == "sqlite" and database != ":memory:": @@ -151,13 +161,7 @@ async def drop_database_async(engine: AsyncEngine, database: str | None) -> None async with engine.begin() as conn: # Disconnect all users from the database we are dropping. version = conn.dialect.server_version_info - pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid" - sql = f""" - SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{database}' - AND {pid_column} <> pg_backend_pid(); - """ # noqa: S608 + sql = _disconnect_users_sql(version, database) await conn.execute(text(sql)) # Drop the database. @@ -180,13 +184,7 @@ def drop_database_sync(engine: Engine, database: str | None) -> None: with engine.begin() as conn: # Disconnect all users from the database we are dropping. version = conn.dialect.server_version_info - pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid" - sql = f""" - SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{database}' - AND {pid_column} <> pg_backend_pid(); - """ # noqa: S608 + sql = _disconnect_users_sql(version, database) conn.execute(text(sql)) # Drop the database. From b6ab18a7940ecbc297a8e50fd047148360c9b16e Mon Sep 17 00:00:00 2001 From: John Faucett Date: Mon, 27 Jan 2025 11:56:11 +0100 Subject: [PATCH 4/9] Add unit tests --- advanced_alchemy/utils/databases.py | 22 ++++++++++++---------- tests/unit/test_cli.py | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py index 5ed57db2..7b5eadfe 100644 --- a/advanced_alchemy/utils/databases.py +++ b/advanced_alchemy/utils/databases.py @@ -73,12 +73,13 @@ def create_database_sync(engine: Engine, database: str | None, encoding: str = " sql = f"CREATE DATABASE {database} CHARACTER SET = '{encoding}'" conn.execute(text(sql)) - elif dialect_name == "sqlite" and database != ":memory:": - if database: + elif dialect_name == "sqlite": + if database and database != ":memory:": with engine.begin() as conn: conn.execute(text("CREATE TABLE DB(id int)")) conn.execute(text("DROP TABLE DB")) - + else: + pass else: with engine.begin() as conn: sql = f"CREATE DATABASE {database}" @@ -99,12 +100,13 @@ async def create_database_async(engine: AsyncEngine, database: str | None, encod sql = f"CREATE DATABASE {database} CHARACTER SET = '{encoding}'" await conn.execute(text(sql)) - elif dialect_name == "sqlite" and database != ":memory:": - if database: + elif dialect_name == "sqlite": + if database and database != ":memory:": async with engine.begin() as conn: await conn.execute(text("CREATE TABLE DB(id int)")) await conn.execute(text("DROP TABLE DB")) - + else: + pass else: async with engine.begin() as conn: sql = f"CREATE DATABASE {database}" @@ -154,8 +156,8 @@ def _disconnect_users_sql(version: tuple[int, int] | None, database: str | None) async def drop_database_async(engine: AsyncEngine, database: str | None) -> None: dialect_name = engine.url.get_dialect().name - if dialect_name == "sqlite" and database != ":memory:": - if database: + if dialect_name == "sqlite": + if database and database != ":memory:": Path(database).unlink() elif dialect_name == "postgresql": async with engine.begin() as conn: @@ -177,8 +179,8 @@ async def drop_database_async(engine: AsyncEngine, database: str | None) -> None def drop_database_sync(engine: Engine, database: str | None) -> None: dialect_name = engine.url.get_dialect().name - if dialect_name == "sqlite" and database != ":memory:": - if database: + if dialect_name == "sqlite": + if database and database != ":memory:": Path(database).unlink() elif dialect_name == "postgresql": with engine.begin() as conn: diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 615e6e84..0a1b3bf4 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -145,6 +145,26 @@ def test_dump_data(cli_runner: CliRunner, database_cli: Group, mock_context: Mag assert result.exit_code == 0 +def test_create(cli_runner: CliRunner, database_cli: Group, mock_context: MagicMock, tmp_path: Path) -> None: + """Test create the database.""" + + result = cli_runner.invoke( + database_cli, + ["--config", "tests.unit.fixtures.configs", "create", "--encoding", "latin1", "--no-prompt"], + ) + assert result.exit_code == 0 + + +def test_drop(cli_runner: CliRunner, database_cli: Group, mock_context: MagicMock, tmp_path: Path) -> None: + """Test drop the database.""" + + result = cli_runner.invoke( + database_cli, + ["--config", "tests.unit.fixtures.configs", "drop", "--no-prompt"], + ) + assert result.exit_code == 0 + + def test_cli_group_creation() -> None: """Test that the CLI group is created correctly.""" cli_group = add_migration_commands() From 45d3fcafcb8f52186fae62328361f7b4c11d3d5e Mon Sep 17 00:00:00 2001 From: John Faucett Date: Mon, 27 Jan 2025 22:04:51 +0100 Subject: [PATCH 5/9] small clean-ups --- advanced_alchemy/utils/databases.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py index 7b5eadfe..a796b74d 100644 --- a/advanced_alchemy/utils/databases.py +++ b/advanced_alchemy/utils/databases.py @@ -109,8 +109,7 @@ async def create_database_async(engine: AsyncEngine, database: str | None, encod pass else: async with engine.begin() as conn: - sql = f"CREATE DATABASE {database}" - await conn.execute(text(sql)) + await conn.execute(text(f"CREATE DATABASE {database}")) await engine.dispose() @@ -165,14 +164,10 @@ async def drop_database_async(engine: AsyncEngine, database: str | None) -> None version = conn.dialect.server_version_info sql = _disconnect_users_sql(version, database) await conn.execute(text(sql)) - - # Drop the database. - sql = f"DROP DATABASE {database}" - await conn.execute(text(sql)) + await conn.execute(text(f"DROP DATABASE {database}")) else: async with engine.begin() as conn: - sql = f"DROP DATABASE {database}" - await conn.execute(text(sql)) + await conn.execute(text(f"DROP DATABASE {database}")) await engine.dispose() @@ -188,13 +183,9 @@ def drop_database_sync(engine: Engine, database: str | None) -> None: version = conn.dialect.server_version_info sql = _disconnect_users_sql(version, database) conn.execute(text(sql)) - - # Drop the database. - sql = f"DROP DATABASE {database}" - conn.execute(text(sql)) + conn.execute(text(f"DROP DATABASE {database}")) else: with engine.begin() as conn: - sql = f"DROP DATABASE {database}" - conn.execute(text(sql)) + conn.execute(text(f"DROP DATABASE {database}")) engine.dispose() From b15840097012605dd24835b28016107a696fa677 Mon Sep 17 00:00:00 2001 From: John Faucett Date: Tue, 28 Jan 2025 18:24:19 +0100 Subject: [PATCH 6/9] add simple sqlite tests --- advanced_alchemy/cli.py | 7 +- advanced_alchemy/utils/databases.py | 12 +-- tests/integration/test_database_commands.py | 106 ++++++++++++++++++++ 3 files changed, 115 insertions(+), 10 deletions(-) create mode 100644 tests/integration/test_database_commands.py diff --git a/advanced_alchemy/cli.py b/advanced_alchemy/cli.py index 9b4f07d5..e578a36d 100644 --- a/advanced_alchemy/cli.py +++ b/advanced_alchemy/cli.py @@ -411,6 +411,7 @@ async def _dump_tables() -> None: default="utf8", ) def create_database(bind_key: str | None, no_prompt: bool, encoding: str) -> None: # pyright: ignore[reportUnusedFunction] + from anyio import run from rich.prompt import Confirm from advanced_alchemy.utils.databases import create_database as _create_database @@ -425,7 +426,11 @@ def create_database(bind_key: str | None, no_prompt: bool, encoding: str) -> Non True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to create a new database `{db_name}`?[/]") ) if input_confirmed: - _create_database(sqlalchemy_config, encoding) + + async def _create_database_wrapper() -> None: + await _create_database(sqlalchemy_config, encoding) + + run(_create_database_wrapper) @database_group.command(name="drop", help="Drop the current database.") @bind_key_option diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py index a796b74d..ad4679b2 100644 --- a/advanced_alchemy/utils/databases.py +++ b/advanced_alchemy/utils/databases.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from copy import copy from pathlib import Path from typing import TYPE_CHECKING @@ -18,12 +17,7 @@ def _set_engine_database_url(url: URL, database: str | None) -> URL: - if hasattr(url, "_replace"): - new_url = url._replace(database=database) - else: # SQLAlchemy <1.4 - new_url = copy(url) - new_url.database = database # type: ignore # noqa: PGH003 - return new_url + return url._replace(database=database) def get_masterdb_url(url: URL) -> URL: @@ -41,7 +35,7 @@ def get_masterdb_url(url: URL) -> URL: return url -def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None: +async def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None: url = config.get_engine().url database = url.database dialect_name = url.get_dialect().name @@ -56,7 +50,7 @@ def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encodi engine = config.create_engine_callable(str(url)) if isinstance(engine, AsyncEngine): - asyncio.run(create_database_async(engine, database, encoding)) + await create_database_async(engine, database, encoding) else: create_database_sync(engine, database, encoding) diff --git a/tests/integration/test_database_commands.py b/tests/integration/test_database_commands.py new file mode 100644 index 00000000..ff1e3fd9 --- /dev/null +++ b/tests/integration/test_database_commands.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import AsyncGenerator, Generator, cast + +import pytest +from pytest import CaptureFixture, FixtureRequest +from pytest_lazy_fixtures import lf +from sqlalchemy import Engine, NullPool, create_engine, select, text +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine +from sqlalchemy.orm import sessionmaker + +from advanced_alchemy import base +from advanced_alchemy.extensions.litestar import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig +from advanced_alchemy.utils.databases import create_database + +pytestmark = [ + pytest.mark.integration, +] + + +@pytest.fixture() +def sqlite_engine_cd(tmp_path: Path) -> Generator[Engine, None, None]: + """SQLite engine for end-to-end testing. + + Returns: + Async SQLAlchemy engine instance. + """ + engine = create_engine(f"sqlite:///{tmp_path}/test-cd.db", poolclass=NullPool) + try: + yield engine + finally: + engine.dispose() + + +@pytest.fixture() +async def aiosqlite_engine_cd(tmp_path: Path) -> AsyncGenerator[AsyncEngine, None]: + """SQLite engine for end-to-end testing. + + Returns: + Async SQLAlchemy engine instance. + """ + engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/test-cd-async.db", poolclass=NullPool) + try: + yield engine + finally: + await engine.dispose() + + +@pytest.fixture( + params=[ + pytest.param( + "sqlite_engine_cd", + marks=[ + pytest.mark.sqlite, + pytest.mark.integration, + ], + ) + ] +) +def sync_sqlalchemy_config(request: FixtureRequest) -> Generator[SQLAlchemySyncConfig, None, None]: + engine = cast(Engine, request.getfixturevalue(request.param)) + orm_registry = base.create_registry() + yield SQLAlchemySyncConfig( + engine_instance=engine, + session_maker=sessionmaker(bind=engine, expire_on_commit=False), + metadata=orm_registry.metadata, + ) + + +async def test_create_and_drop_sqlite_sync(sqlite_engine_cd: Engine, tmp_path: Path) -> None: + orm_registry = base.create_registry() + cfg = SQLAlchemySyncConfig( + engine_instance=sqlite_engine_cd, + session_maker=sessionmaker(bind=sqlite_engine_cd, expire_on_commit=False), + metadata=orm_registry.metadata, + ) + file_path = f"{tmp_path}/test-cd.db" + assert not Path(f"{tmp_path}/test-cd.db").exists() + try: + await create_database(cfg) + assert Path(file_path).exists() + finally: + if Path(file_path).exists(): + Path(file_path).unlink() + + +async def test_create_and_drop_sqlite_async(aiosqlite_engine_cd: AsyncEngine, tmp_path: Path) -> None: + orm_registry = base.create_registry() + cfg = SQLAlchemyAsyncConfig( + engine_instance=aiosqlite_engine_cd, + session_maker=async_sessionmaker(bind=aiosqlite_engine_cd, expire_on_commit=False), + metadata=orm_registry.metadata, + ) + file_path = f"{tmp_path}/test-cd-async.db" + assert not Path(file_path).exists() + try: + await create_database(cfg) + assert Path(file_path).exists() + async with cfg.get_session() as sess: + result = await sess.execute(select(text("1"))) + assert result.scalar_one() == 1 + finally: + if Path(file_path).exists(): + Path(file_path).unlink() From b9bb88d9df23fe98d4bae40c0922a3b3a563188e Mon Sep 17 00:00:00 2001 From: John Faucett Date: Wed, 29 Jan 2025 16:14:41 +0100 Subject: [PATCH 7/9] Use adapter pattern to make it more extendable --- advanced_alchemy/utils/databases.py | 218 +++++++------------- tests/integration/test_database_commands.py | 15 +- 2 files changed, 81 insertions(+), 152 deletions(-) diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py index ad4679b2..99d2e128 100644 --- a/advanced_alchemy/utils/databases.py +++ b/advanced_alchemy/utils/databases.py @@ -1,185 +1,107 @@ from __future__ import annotations -import asyncio from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol from sqlalchemy import Engine, text from sqlalchemy.ext.asyncio import AsyncEngine -from advanced_alchemy.utils.dataclass import Empty +from advanced_alchemy.config.sync import SQLAlchemySyncConfig if TYPE_CHECKING: - from sqlalchemy.engine.url import URL - from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig - from advanced_alchemy.config.sync import SQLAlchemySyncConfig -def _set_engine_database_url(url: URL, database: str | None) -> URL: - return url._replace(database=database) +class Adapter(Protocol): + supported_drivers: list[str] = [] + dialect: str + config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig + encoding: str + _engine: Engine | None + _async_engine: AsyncEngine | None + _database: str | None -def get_masterdb_url(url: URL) -> URL: - dialect_name = url.get_dialect().name + def __init__(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None: + self.config = config + self.encoding = encoding + self._engine: Engine | None = None + self._async_engine: AsyncEngine | None = None - if dialect_name == "postgresql": - url = _set_engine_database_url(url, database="postgres") - elif dialect_name == "mssql": - url = _set_engine_database_url(url, database="master") - elif dialect_name == "cockroachdb": - url = _set_engine_database_url(url, database="defaultdb") - elif dialect_name != "sqlite": - url = _set_engine_database_url(url, database=None) + if isinstance(config, SQLAlchemySyncConfig): + self._engine = config.get_engine() + self._database = self._engine.url.database + else: + self._async_engine = config.get_engine() + self._database = self._async_engine.url.database - return url + async def create_async(self) -> None: ... + def create(self) -> None: ... + async def drop_async(self) -> None: ... + def drop(self) -> None: ... -async def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None: - url = config.get_engine().url - database = url.database - dialect_name = url.get_dialect().name - dialect_driver = url.get_dialect().driver - url = get_masterdb_url(url) +class SQLiteAdapter(Adapter): + supported_drivers: list[str] = ["pysqlite", "aiosqlite"] + dialect: str = "sqlite" - if (dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}) or ( - dialect_name == "postgresql" and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"} - ): - config.engine_config.isolation_level = "AUTOCOMMIT" + def create(self) -> None: + if self._engine is not None and self._database and self._database != ":memory:": + with self._engine.begin() as conn: + conn.execute(text("CREATE TABLE DB(id int)")) + conn.execute(text("DROP TABLE DB")) - engine = config.create_engine_callable(str(url)) - if isinstance(engine, AsyncEngine): - await create_database_async(engine, database, encoding) - else: - create_database_sync(engine, database, encoding) + async def create_async(self) -> None: + if self._async_engine is not None: + async with self._async_engine.begin() as conn: + await conn.execute(text("CREATE TABLE DB(id int)")) + await conn.execute(text("DROP TABLE DB")) + def drop(self) -> None: + return self._drop() -def create_database_sync(engine: Engine, database: str | None, encoding: str = "utf8") -> None: - dialect_name = engine.url.get_dialect().name - if dialect_name == "postgresql": - with engine.begin() as conn: - sql = f"CREATE DATABASE {database} ENCODING '{encoding}'" - conn.execute(text(sql)) + async def drop_async(self) -> None: + return self._drop() - elif dialect_name == "mysql": - with engine.begin() as conn: - sql = f"CREATE DATABASE {database} CHARACTER SET = '{encoding}'" - conn.execute(text(sql)) + def _drop(self) -> None: + if self._database and self._database != ":memory:": + Path(self._database).unlink() - elif dialect_name == "sqlite": - if database and database != ":memory:": - with engine.begin() as conn: - conn.execute(text("CREATE TABLE DB(id int)")) - conn.execute(text("DROP TABLE DB")) - else: - pass - else: - with engine.begin() as conn: - sql = f"CREATE DATABASE {database}" - conn.execute(text(sql)) - engine.dispose() +ADAPTERS = {"sqlite": SQLiteAdapter} -async def create_database_async(engine: AsyncEngine, database: str | None, encoding: str = "utf8") -> None: - dialect_name = engine.url.get_dialect().name - if dialect_name == "postgresql": - async with engine.begin() as conn: - sql = f"CREATE DATABASE {database} ENCODING '{encoding}'" - await conn.execute(text(sql)) +def get_adapter(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> Adapter: + dialect_name = config.get_engine().url.get_dialect().name + driver = config.get_engine().url.get_dialect().driver - elif dialect_name == "mysql": - async with engine.begin() as conn: - sql = f"CREATE DATABASE {database} CHARACTER SET = '{encoding}'" - await conn.execute(text(sql)) + adapter_class = ADAPTERS.get(dialect_name) - elif dialect_name == "sqlite": - if database and database != ":memory:": - async with engine.begin() as conn: - await conn.execute(text("CREATE TABLE DB(id int)")) - await conn.execute(text("DROP TABLE DB")) - else: - pass - else: - async with engine.begin() as conn: - await conn.execute(text(f"CREATE DATABASE {database}")) - - await engine.dispose() + if not adapter_class: + msg = f"No adapter available for {dialect_name}" + raise ValueError(msg) + if driver not in adapter_class.supported_drivers: + msg = f"{dialect_name} adapter does not support the {driver} driver" + raise ValueError(msg) -def drop_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None: - url = config.get_engine().url - database = url.database - dialect_name = url.get_dialect().name - dialect_driver = url.get_dialect().driver + return adapter_class(config, encoding=encoding) - url = get_masterdb_url(url) - if dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}: - if config.engine_config.connect_args is Empty: - config.engine_config.connect_args = {} - else: - config.engine_config.connect_args["autocommit"] = True # type: ignore # noqa: PGH003 - elif dialect_name == "postgresql" and dialect_driver in { - "asyncpg", - "pg8000", - "psycopg", - "psycopg2", - "psycopg2cffi", - }: - config.engine_config.isolation_level = "AUTOCOMMIT" - - engine = config.create_engine_callable(str(url)) +async def create_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig, encoding: str = "utf8") -> None: + adapter = get_adapter(config, encoding) + engine = config.get_engine() if isinstance(engine, AsyncEngine): - asyncio.run(drop_database_async(engine, database)) + await adapter.create_async() else: - drop_database_sync(engine, database) - - -def _disconnect_users_sql(version: tuple[int, int] | None, database: str | None) -> str: - pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid" - return f""" - SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{database}' - AND {pid_column} <> pg_backend_pid(); - """ # noqa: S608 - - -async def drop_database_async(engine: AsyncEngine, database: str | None) -> None: - dialect_name = engine.url.get_dialect().name - if dialect_name == "sqlite": - if database and database != ":memory:": - Path(database).unlink() - elif dialect_name == "postgresql": - async with engine.begin() as conn: - # Disconnect all users from the database we are dropping. - version = conn.dialect.server_version_info - sql = _disconnect_users_sql(version, database) - await conn.execute(text(sql)) - await conn.execute(text(f"DROP DATABASE {database}")) - else: - async with engine.begin() as conn: - await conn.execute(text(f"DROP DATABASE {database}")) - - await engine.dispose() - - -def drop_database_sync(engine: Engine, database: str | None) -> None: - dialect_name = engine.url.get_dialect().name - if dialect_name == "sqlite": - if database and database != ":memory:": - Path(database).unlink() - elif dialect_name == "postgresql": - with engine.begin() as conn: - # Disconnect all users from the database we are dropping. - version = conn.dialect.server_version_info - sql = _disconnect_users_sql(version, database) - conn.execute(text(sql)) - conn.execute(text(f"DROP DATABASE {database}")) - else: - with engine.begin() as conn: - conn.execute(text(f"DROP DATABASE {database}")) + adapter.create() + - engine.dispose() +async def drop_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig) -> None: + adapter = get_adapter(config) + engine = config.get_engine() + if isinstance(engine, AsyncEngine): + await adapter.drop_async() + else: + adapter.drop() diff --git a/tests/integration/test_database_commands.py b/tests/integration/test_database_commands.py index ff1e3fd9..e439bbf2 100644 --- a/tests/integration/test_database_commands.py +++ b/tests/integration/test_database_commands.py @@ -1,19 +1,17 @@ from __future__ import annotations -import asyncio from pathlib import Path from typing import AsyncGenerator, Generator, cast import pytest -from pytest import CaptureFixture, FixtureRequest -from pytest_lazy_fixtures import lf +from pytest import FixtureRequest from sqlalchemy import Engine, NullPool, create_engine, select, text from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker from advanced_alchemy import base from advanced_alchemy.extensions.litestar import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig -from advanced_alchemy.utils.databases import create_database +from advanced_alchemy.utils.databases import create_database, drop_database pytestmark = [ pytest.mark.integration, @@ -81,7 +79,13 @@ async def test_create_and_drop_sqlite_sync(sqlite_engine_cd: Engine, tmp_path: P try: await create_database(cfg) assert Path(file_path).exists() + with cfg.get_session() as sess: + result = sess.execute(select(text("1"))) + assert result.scalar_one() == 1 + await drop_database(cfg) + assert not Path(file_path).exists() finally: + # always clean up if Path(file_path).exists(): Path(file_path).unlink() @@ -101,6 +105,9 @@ async def test_create_and_drop_sqlite_async(aiosqlite_engine_cd: AsyncEngine, tm async with cfg.get_session() as sess: result = await sess.execute(select(text("1"))) assert result.scalar_one() == 1 + await drop_database(cfg) + assert not Path(file_path).exists() finally: + # always clean up if Path(file_path).exists(): Path(file_path).unlink() From a7f84e41e357566e149f074c1ba09bc031f0ffcc Mon Sep 17 00:00:00 2001 From: John Faucett Date: Sat, 1 Feb 2025 10:30:21 +0100 Subject: [PATCH 8/9] Add adapter and test for postgres --- advanced_alchemy/utils/databases.py | 105 ++++++++++++++++---- tests/integration/test_database_commands.py | 88 ++++++++++++---- 2 files changed, 151 insertions(+), 42 deletions(-) diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py index 99d2e128..8e04fed3 100644 --- a/advanced_alchemy/utils/databases.py +++ b/advanced_alchemy/utils/databases.py @@ -3,8 +3,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Protocol -from sqlalchemy import Engine, text -from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy import URL, Engine, create_engine, text +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from advanced_alchemy.config.sync import SQLAlchemySyncConfig @@ -13,27 +13,26 @@ class Adapter(Protocol): - supported_drivers: list[str] = [] + supported_drivers: set[str] = set() dialect: str config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig - encoding: str + encoding: str | None = None - _engine: Engine | None - _async_engine: AsyncEngine | None - _database: str | None + engine: Engine | None = None + async_engine: AsyncEngine | None = None + original_database_name: str | None = None def __init__(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None: self.config = config self.encoding = encoding - self._engine: Engine | None = None - self._async_engine: AsyncEngine | None = None if isinstance(config, SQLAlchemySyncConfig): - self._engine = config.get_engine() - self._database = self._engine.url.database + self.setup(config) else: - self._async_engine = config.get_engine() - self._database = self._async_engine.url.database + self.setup_async(config) + + def setup_async(self, config: SQLAlchemyAsyncConfig) -> None: ... + def setup(self, config: SQLAlchemySyncConfig) -> None: ... async def create_async(self) -> None: ... def create(self) -> None: ... @@ -43,18 +42,26 @@ def drop(self) -> None: ... class SQLiteAdapter(Adapter): - supported_drivers: list[str] = ["pysqlite", "aiosqlite"] + supported_drivers: set[str] = {"pysqlite", "aiosqlite"} dialect: str = "sqlite" + def setup(self, config: SQLAlchemySyncConfig) -> None: + self.engine = config.get_engine() + self.original_database_name = self.engine.url.database + + def setup_async(self, config: SQLAlchemyAsyncConfig) -> None: + self.async_engine = config.get_engine() + self.original_database_name = self.async_engine.url.database + def create(self) -> None: - if self._engine is not None and self._database and self._database != ":memory:": - with self._engine.begin() as conn: + if self.engine is not None and self.original_database_name and self.original_database_name != ":memory:": + with self.engine.begin() as conn: conn.execute(text("CREATE TABLE DB(id int)")) conn.execute(text("DROP TABLE DB")) async def create_async(self) -> None: - if self._async_engine is not None: - async with self._async_engine.begin() as conn: + if self.async_engine is not None: + async with self.async_engine.begin() as conn: await conn.execute(text("CREATE TABLE DB(id int)")) await conn.execute(text("DROP TABLE DB")) @@ -65,11 +72,67 @@ async def drop_async(self) -> None: return self._drop() def _drop(self) -> None: - if self._database and self._database != ":memory:": - Path(self._database).unlink() + if self.original_database_name and self.original_database_name != ":memory:": + Path(self.original_database_name).unlink() + + +class PostgresAdapter(Adapter): + supported_drivers: set[str] = {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"} + dialect: str = "postgresql" + + def _set_url(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> URL: + original_url = self.config.get_engine().url + self.original_database_name = original_url.database + return original_url._replace(database="postgres") + + def setup(self, config: SQLAlchemySyncConfig) -> None: + updated_url = self._set_url(config) + self.engine = create_engine(updated_url, isolation_level="AUTOCOMMIT") + def setup_async(self, config: SQLAlchemyAsyncConfig) -> None: + updated_url = self._set_url(config) + self.async_engine = create_async_engine(updated_url, isolation_level="AUTOCOMMIT") -ADAPTERS = {"sqlite": SQLiteAdapter} + def create(self) -> None: + if self.engine: + with self.engine.begin() as conn: + sql = f"CREATE DATABASE {self.original_database_name} ENCODING '{self.encoding}'" + conn.execute(text(sql)) + + async def create_async(self) -> None: + if self.async_engine: + async with self.async_engine.begin() as conn: + sql = f"CREATE DATABASE {self.original_database_name} ENCODING '{self.encoding}'" + await conn.execute(text(sql)) + + def drop(self) -> None: + if self.engine: + with self.engine.begin() as conn: + # Disconnect all users from the database we are dropping. + version = conn.dialect.server_version_info + sql = self._disconnect_users_sql(version, self.original_database_name) + conn.execute(text(sql)) + conn.execute(text(f"DROP DATABASE {self.original_database_name}")) + + async def drop_async(self) -> None: + if self.async_engine: + async with self.async_engine.begin() as conn: + # Disconnect all users from the database we are dropping. + version = conn.dialect.server_version_info + sql = self._disconnect_users_sql(version, self.original_database_name) + await conn.execute(text(sql)) + await conn.execute(text(f"DROP DATABASE {self.original_database_name}")) + + def _disconnect_users_sql(self, version: tuple[int, int] | None, database: str | None) -> str: + pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid" + return f""" + SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{database}' + AND {pid_column} <> pg_backend_pid();""" # noqa: S608 + + +ADAPTERS = {"sqlite": SQLiteAdapter, "postgresql": PostgresAdapter} def get_adapter(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> Adapter: diff --git a/tests/integration/test_database_commands.py b/tests/integration/test_database_commands.py index e439bbf2..43f93b6c 100644 --- a/tests/integration/test_database_commands.py +++ b/tests/integration/test_database_commands.py @@ -1,11 +1,10 @@ from __future__ import annotations from pathlib import Path -from typing import AsyncGenerator, Generator, cast +from typing import AsyncGenerator, Generator import pytest -from pytest import FixtureRequest -from sqlalchemy import Engine, NullPool, create_engine, select, text +from sqlalchemy import URL, Engine, NullPool, create_engine, select, text from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker @@ -46,24 +45,37 @@ async def aiosqlite_engine_cd(tmp_path: Path) -> AsyncGenerator[AsyncEngine, Non await engine.dispose() -@pytest.fixture( - params=[ - pytest.param( - "sqlite_engine_cd", - marks=[ - pytest.mark.sqlite, - pytest.mark.integration, - ], - ) - ] -) -def sync_sqlalchemy_config(request: FixtureRequest) -> Generator[SQLAlchemySyncConfig, None, None]: - engine = cast(Engine, request.getfixturevalue(request.param)) - orm_registry = base.create_registry() - yield SQLAlchemySyncConfig( - engine_instance=engine, - session_maker=sessionmaker(bind=engine, expire_on_commit=False), - metadata=orm_registry.metadata, +@pytest.fixture() +async def asyncpg_engine_cd(docker_ip: str, postgres_service: None) -> AsyncGenerator[AsyncEngine, None]: + """Postgresql instance for end-to-end testing.""" + yield create_async_engine( + URL( + drivername="postgresql+asyncpg", + username="postgres", + password="super-secret", + host=docker_ip, + port=5423, + database="testing_create_delete", + query={}, # type:ignore[arg-type] + ), + poolclass=NullPool, + ) + + +@pytest.fixture() +def psycopg_engine_cd(docker_ip: str, postgres_service: None) -> Generator[Engine, None, None]: + """Postgresql instance for end-to-end testing.""" + yield create_engine( + URL( + drivername="postgresql+psycopg", + username="postgres", + password="super-secret", + host=docker_ip, + port=5423, + database="postgres", + query={}, # type:ignore[arg-type] + ), + poolclass=NullPool, ) @@ -111,3 +123,37 @@ async def test_create_and_drop_sqlite_async(aiosqlite_engine_cd: AsyncEngine, tm # always clean up if Path(file_path).exists(): Path(file_path).unlink() + + +async def test_create_and_drop_postgres_async(asyncpg_engine_cd: AsyncEngine, asyncpg_engine: AsyncEngine) -> None: + orm_registry = base.create_registry() + cfg = SQLAlchemyAsyncConfig( + engine_instance=asyncpg_engine_cd, + session_maker=async_sessionmaker(bind=asyncpg_engine_cd, expire_on_commit=False), + metadata=orm_registry.metadata, + ) + + dbname = asyncpg_engine_cd.url.database + exists_sql = f""" + select exists( + SELECT datname FROM pg_catalog.pg_database WHERE lower(datname) = lower('{dbname}') + ); + """ + + # ensure database does not exist + async with asyncpg_engine.begin() as conn: + result = await conn.execute(text(exists_sql)) + assert not result.scalar_one() + + await create_database(cfg) + async with asyncpg_engine.begin() as conn: + result = await conn.execute(text(exists_sql)) + assert result.scalar_one() + + await drop_database(cfg) + + async with asyncpg_engine.begin() as conn: + result = await conn.execute(text(exists_sql)) + assert not result.scalar_one() + + await asyncpg_engine.dispose() From d43629c546571125628b8c9ac864818b6bf523f4 Mon Sep 17 00:00:00 2001 From: John Faucett Date: Sat, 1 Feb 2025 10:36:41 +0100 Subject: [PATCH 9/9] run the drop database awaitable --- advanced_alchemy/cli.py | 7 ++++++- advanced_alchemy/utils/databases.py | 9 +++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/advanced_alchemy/cli.py b/advanced_alchemy/cli.py index e578a36d..e5ad0047 100644 --- a/advanced_alchemy/cli.py +++ b/advanced_alchemy/cli.py @@ -436,6 +436,7 @@ async def _create_database_wrapper() -> None: @bind_key_option @no_prompt_option def drop_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] + from anyio import run from rich.prompt import Confirm from advanced_alchemy.utils.databases import drop_database as _drop_database @@ -450,6 +451,10 @@ def drop_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ig True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to drop database `{db_name}`?[/]") ) if input_confirmed: - _drop_database(sqlalchemy_config) + + async def _drop_database_wrapper() -> None: + await _drop_database(sqlalchemy_config) + + run(_drop_database_wrapper) return database_group diff --git a/advanced_alchemy/utils/databases.py b/advanced_alchemy/utils/databases.py index 8e04fed3..035822df 100644 --- a/advanced_alchemy/utils/databases.py +++ b/advanced_alchemy/utils/databases.py @@ -137,14 +137,13 @@ def _disconnect_users_sql(self, version: tuple[int, int] | None, database: str | def get_adapter(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> Adapter: dialect_name = config.get_engine().url.get_dialect().name - driver = config.get_engine().url.get_dialect().driver - adapter_class = ADAPTERS.get(dialect_name) if not adapter_class: msg = f"No adapter available for {dialect_name}" raise ValueError(msg) + driver = config.get_engine().url.get_dialect().driver if driver not in adapter_class.supported_drivers: msg = f"{dialect_name} adapter does not support the {driver} driver" raise ValueError(msg) @@ -154,8 +153,7 @@ def get_adapter(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: async def create_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig, encoding: str = "utf8") -> None: adapter = get_adapter(config, encoding) - engine = config.get_engine() - if isinstance(engine, AsyncEngine): + if isinstance(config.get_engine(), AsyncEngine): await adapter.create_async() else: adapter.create() @@ -163,8 +161,7 @@ async def create_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig, async def drop_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig) -> None: adapter = get_adapter(config) - engine = config.get_engine() - if isinstance(engine, AsyncEngine): + if isinstance(config.get_engine(), AsyncEngine): await adapter.drop_async() else: adapter.drop()