diff --git a/aana/alembic/versions/5ad873484aa3_init.py b/aana/alembic/versions/5ad873484aa3_init.py index 89b77d13..0a200515 100644 --- a/aana/alembic/versions/5ad873484aa3_init.py +++ b/aana/alembic/versions/5ad873484aa3_init.py @@ -10,6 +10,7 @@ from sqlalchemy.schema import CreateSequence, Sequence from aana.storage.types import JSON +from aana.storage.utcnow import utcnow # revision identifiers, used by Alembic. revision: str = "5ad873484aa3" @@ -57,14 +58,14 @@ def upgrade() -> None: sa.Column( "created_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=False, comment="Timestamp when row is inserted", ), sa.Column( "updated_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=False, comment="Timestamp when row is updated", ), @@ -84,14 +85,14 @@ def upgrade() -> None: sa.Column( "created_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=False, comment="Timestamp when row is inserted", ), sa.Column( "updated_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=False, comment="Timestamp when row is updated", ), @@ -135,14 +136,14 @@ def upgrade() -> None: sa.Column( "assigned_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=True, comment="Timestamp when the task was assigned", ), sa.Column( "completed_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=True, comment="Timestamp when the task was completed", ), @@ -161,14 +162,14 @@ def upgrade() -> None: sa.Column( "created_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=False, comment="Timestamp when row is inserted", ), sa.Column( "updated_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=False, comment="Timestamp when row is updated", ), @@ -216,14 +217,14 @@ def upgrade() -> None: sa.Column( "created_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=False, comment="Timestamp when row is inserted", ), sa.Column( "updated_at", sa.DateTime(timezone=True), - server_default=sa.text("(CURRENT_TIMESTAMP)"), + server_default=utcnow(), nullable=False, comment="Timestamp when row is updated", ), diff --git a/aana/storage/models/base.py b/aana/storage/models/base.py index 616066b1..390dd484 100644 --- a/aana/storage/models/base.py +++ b/aana/storage/models/base.py @@ -1,7 +1,7 @@ import datetime from typing import Annotated, Any, TypeVar -from sqlalchemy import DateTime, MetaData, String, func +from sqlalchemy import DateTime, MetaData, String from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -11,6 +11,7 @@ ) from aana.core.models.media import MediaId +from aana.storage.utcnow import utcnow timestamp = Annotated[ datetime.datetime, @@ -79,11 +80,11 @@ class TimeStampEntity: """Mixin for database entities that will have create/update timestamps.""" created_at: Mapped[timestamp] = mapped_column( - server_default=func.now(), + server_default=utcnow(), comment="Timestamp when row is inserted", ) updated_at: Mapped[timestamp] = mapped_column( - onupdate=func.now(), - server_default=func.now(), + onupdate=utcnow(), + server_default=utcnow(), comment="Timestamp when row is updated", ) diff --git a/aana/storage/utcnow.py b/aana/storage/utcnow.py new file mode 100644 index 00000000..6e87ec50 --- /dev/null +++ b/aana/storage/utcnow.py @@ -0,0 +1,35 @@ +from sqlalchemy import DateTime +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import FunctionElement + + +class utcnow(FunctionElement): + """UTCNOW() expression for multiple dialects.""" + + inherit_cache = True + type = DateTime() + + +@compiles(utcnow) +def default_sql_utcnow(element, compiler, **kw): + """Assume, by default, time zones work correctly. + + Note: + This is a valid assumption for PostgreSQL and Oracle. + """ + return "CURRENT_TIMESTAMP" + + +@compiles(utcnow, "sqlite") +def sqlite_sql_utcnow(element, compiler, **kw): + """SQLite DATETIME('NOW') returns a correct `datetime.datetime` but does not add milliseconds to it. + + Directly call STRFTIME with the final %f modifier in order to get those. + """ + return r"(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))" + + +@compiles(utcnow, "snowflake") +def snowflake_sql_utcnow(element, compiler, **kw): + """In Snowflake, SYSDATE() returns the current timestamp for the system in the UTC time zone.""" + return "SYSDATE()" diff --git a/aana/tests/db/test_utcnow.py b/aana/tests/db/test_utcnow.py new file mode 100644 index 00000000..ee1153d0 --- /dev/null +++ b/aana/tests/db/test_utcnow.py @@ -0,0 +1,13 @@ +# ruff: noqa: S101 +from datetime import datetime, timedelta, timezone + +from aana.storage.utcnow import utcnow + + +def test_utcnow(db_session): + """Tests the utcnow() function.""" + current_time_utc = datetime.now(tz=timezone.utc) + result = db_session.execute(utcnow()).scalar() + result = result.replace(tzinfo=timezone.utc) # Make result offset-aware + assert isinstance(result, datetime) + assert current_time_utc - result < timedelta(seconds=1)