Skip to content

Commit

Permalink
A little clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
DataDaoDe committed Jan 27, 2025
1 parent 61fefca commit fbef24e
Showing 1 changed file with 30 additions and 32 deletions.
62 changes: 30 additions & 32 deletions advanced_alchemy/utils/databases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import os
from copy import copy
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -18,7 +17,7 @@
from advanced_alchemy.config.sync import SQLAlchemySyncConfig


def set_engine_database_url(url: URL, database: str | None) -> URL:
def _set_engine_database_url(url: URL, database: str | None) -> URL:
if hasattr(url, "_replace"):
new_url = url._replace(database=database)
else: # SQLAlchemy <1.4
Expand All @@ -27,20 +26,28 @@ def set_engine_database_url(url: URL, database: str | None) -> URL:
return new_url


def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None:
url = config.get_engine().url
database = url.database
def get_masterdb_url(url: URL) -> URL:
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver

if dialect_name == "postgresql":
url = set_engine_database_url(url, database="postgres")
url = _set_engine_database_url(url, database="postgres")
elif dialect_name == "mssql":
url = set_engine_database_url(url, database="master")
url = _set_engine_database_url(url, database="master")
elif dialect_name == "cockroachdb":
url = set_engine_database_url(url, database="defaultdb")
url = _set_engine_database_url(url, database="defaultdb")
elif dialect_name != "sqlite":
url = set_engine_database_url(url, database=None)
url = _set_engine_database_url(url, database=None)

return url


def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None:
url = config.get_engine().url
database = url.database
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver

url = get_masterdb_url(url)

if (dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}) or (
dialect_name == "postgresql" and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
Expand Down Expand Up @@ -112,14 +119,7 @@ def drop_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None:
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver

if dialect_name == "postgresql":
url = set_engine_database_url(url, database="postgres")
elif dialect_name == "mssql":
url = set_engine_database_url(url, database="master")
elif dialect_name == "cockroachdb":
url = set_engine_database_url(url, database="defaultdb")
elif dialect_name != "sqlite":
url = set_engine_database_url(url, database=None)
url = get_masterdb_url(url)

if dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}:
if config.engine_config.connect_args is Empty:
Expand All @@ -142,6 +142,16 @@ def drop_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None:
drop_database_sync(engine, database)


def _disconnect_users_sql(version: tuple[int, int] | None, database: str | None) -> str:
pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid"
return f"""
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
""" # noqa: S608


async def drop_database_async(engine: AsyncEngine, database: str | None) -> None:
dialect_name = engine.url.get_dialect().name
if dialect_name == "sqlite" and database != ":memory:":
Expand All @@ -151,13 +161,7 @@ async def drop_database_async(engine: AsyncEngine, database: str | None) -> None
async with engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid"
sql = f"""
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
""" # noqa: S608
sql = _disconnect_users_sql(version, database)
await conn.execute(text(sql))

# Drop the database.
Expand All @@ -180,13 +184,7 @@ def drop_database_sync(engine: Engine, database: str | None) -> None:
with engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid"
sql = f"""
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
""" # noqa: S608
sql = _disconnect_users_sql(version, database)
conn.execute(text(sql))

# Drop the database.
Expand Down

0 comments on commit fbef24e

Please sign in to comment.