diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index b0e6335e..d5483ebb 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -42,5 +42,5 @@ jobs:
         uses: ikalnytskyi/action-setup-postgres@v6
       - name: Test with pytest
        env:
-          USE_DEPLOYMENT_CACHE: True
+          SNOWFLAKE_TEST_PARAMETERS: ${{ secrets.SNOWFLAKE_TEST_PARAMETERS }}
         run: poetry run pytest -vv
diff --git a/.github/workflows/run_tests_with_gpu.yml b/.github/workflows/run_tests_with_gpu.yml
index fe75ea7a..b18bf8cb 100644
--- a/.github/workflows/run_tests_with_gpu.yml
+++ b/.github/workflows/run_tests_with_gpu.yml
@@ -34,7 +34,7 @@ jobs:
 
       - name: Run tests with GPU
         run: |
-          HF_TOKEN=${{ secrets.HF_TOKEN }} dstack apply -f tests.dstack.yml --force -y
+          SNOWFLAKE_TEST_PARAMETERS=${{ secrets.SNOWFLAKE_TEST_PARAMETERS }} HF_TOKEN=${{ secrets.HF_TOKEN }} dstack apply -f tests.dstack.yml --force -y
 
       - name: Extract pytest logs
         if: ${{ always() }}
diff --git a/aana/alembic/env.py b/aana/alembic/env.py
index 7f5e658f..80772a49 100644
--- a/aana/alembic/env.py
+++ b/aana/alembic/env.py
@@ -9,6 +9,8 @@
 
 
 class SnowflakeImpl(DefaultImpl):
+    """Custom implementation for Snowflake."""
+
     __dialect__ = "snowflake"
 
 
diff --git a/aana/alembic/versions/5ad873484aa3_init.py b/aana/alembic/versions/5ad873484aa3_init.py
index c3a6033d..c62abda4 100644
--- a/aana/alembic/versions/5ad873484aa3_init.py
+++ b/aana/alembic/versions/5ad873484aa3_init.py
@@ -23,12 +23,15 @@ def upgrade() -> None:
     # ### commands auto generated by Alembic - please adjust! ###
     # id_seq = Sequence('my_sequence_name', start=1, increment=1)
-    op.execute(CreateSequence(Sequence("caption_id_seq")))
-    op.execute(CreateSequence(Sequence("transcript_id_seq")))
+    if op.get_context().dialect.name != "sqlite":
+        op.execute(CreateSequence(Sequence("caption_id_seq")))
+        op.execute(CreateSequence(Sequence("transcript_id_seq")))
     op.create_table(
         "caption",
-        sa.Column("id", sa.Integer(), Sequence("caption_id_seq"), nullable=False),
+        sa.Column("id", sa.Integer(), Sequence("caption_id_seq"), nullable=False)
+        if op.get_context().dialect.name == "snowflake"
+        else sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
         sa.Column(
             "model",
             sa.String(),
@@ -96,7 +99,12 @@ def upgrade() -> None:
     )
     op.create_table(
         "tasks",
-        sa.Column("id", sa.String(), nullable=False, comment="Task ID"),
+        sa.Column(
+            "id",
+            sa.String() if op.get_context().dialect.name == "snowflake" else sa.UUID(),
+            nullable=False,
+            comment="Task ID",
+        ),
         sa.Column(
             "endpoint",
             sa.String(),
@@ -165,11 +173,13 @@ def upgrade() -> None:
             comment="Timestamp when row is updated",
         ),
         sa.PrimaryKeyConstraint("id", name=op.f("pk_tasks")),
-        prefixes=["HYBRID"],
+        prefixes=["HYBRID"] if op.get_context().dialect.name == "snowflake" else [],
     )
     op.create_table(
         "transcript",
-        sa.Column("id", sa.Integer(), Sequence("transcript_id_seq"), nullable=False),
+        sa.Column("id", sa.Integer(), Sequence("transcript_id_seq"), nullable=False)
+        if op.get_context().dialect.name == "snowflake"
+        else sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
         sa.Column(
             "model",
             sa.String(),
diff --git a/aana/storage/custom_types.py b/aana/storage/custom_types.py
index 11afcdbd..70015813 100644
--- a/aana/storage/custom_types.py
+++ b/aana/storage/custom_types.py
@@ -26,35 +26,15 @@ def process(value):
         return process
 
 
-JSON = VARIANT
-
-# class JSON(TypeDecorator):
-#     """Custom JSON type that supports Snowflake-specific and standard dialects."""
-
-#     impl = SqlAlchemyJSON  # Default to standard SQLAlchemy JSON
-
-#     def load_dialect_impl(self, dialect):
-#         """Load dialect-specific implementation."""
-#         if dialect.name == "snowflake":
-#             return SnowflakeVariantType()
-#         return self.impl
-
-#     def bind_expression(self, bindvalue):
-#         """Handle binding expressions dynamically."""
-#         if hasattr(
-#             bindvalue.type, "bind_expression"
-#         ):  # Check if impl has bind_expression
-#             return bindvalue.type.bind_expression(bindvalue)
-#         return bindvalue  # Default binding behavior
-
-#     def process_result_value(self, value, dialect):
-#         """Process the result based on dialect."""
-#         if dialect.name == "snowflake":
-#             if value is None:
-#                 return None
-#             try:
-#                 return orjson.loads(value)
-#             except (ValueError, TypeError):
-#                 return value  # Return raw value if not valid JSON
-#         # For other dialects, call the default implementation
-#         return self.impl.process_result_value(value, dialect)
+class JSON(TypeDecorator):
+    """Custom JSON type that supports Snowflake-specific and standard dialects."""
+
+    impl = SqlAlchemyJSON  # Default to standard SQLAlchemy JSON
+    # impl = VARIANT  # Default to Snowflake VARIANT
+
+    def load_dialect_impl(self, dialect):
+        """Load dialect-specific implementation."""
+        if dialect.name == "snowflake":
+            return VARIANT()
+        else:
+            return SqlAlchemyJSON()
diff --git a/aana/storage/models/caption.py b/aana/storage/models/caption.py
index 63b6fa84..f7213796 100644
--- a/aana/storage/models/caption.py
+++ b/aana/storage/models/caption.py
@@ -11,7 +11,7 @@
 
 from aana.core.models.captions import Caption
 
-caption_id_seq = Sequence("caption_id_seq", start=1, increment=1)
+caption_id_seq = Sequence("caption_id_seq", start=1, increment=1, optional=True)
 
 
 class CaptionEntity(BaseEntity, TimeStampEntity):
@@ -28,7 +28,9 @@ class CaptionEntity(BaseEntity, TimeStampEntity):
 
     __tablename__ = "caption"
 
-    id: Mapped[int] = mapped_column(caption_id_seq, primary_key=True)
+    id: Mapped[int] = mapped_column(
+        caption_id_seq, autoincrement=True, primary_key=True
+    )
     model: Mapped[str] = mapped_column(
         nullable=False, comment="Name of model used to generate the caption"
     )
diff --git a/aana/storage/models/task.py b/aana/storage/models/task.py
index 18204051..1ad6db12 100644
--- a/aana/storage/models/task.py
+++ b/aana/storage/models/task.py
@@ -1,17 +1,11 @@
-import re
 import uuid
 from enum import Enum
 
 from sqlalchemy import (
     UUID,
     PickleType,
-    event,
-    insert,
-    select,
 )
-from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.orm import Mapped, mapped_column
-from sqlalchemy.sql import Insert
 
 from aana.storage.custom_types import JSON
 from aana.storage.models.base import BaseEntity, TimeStampEntity, timestamp
@@ -32,7 +26,6 @@ class TaskEntity(BaseEntity, TimeStampEntity):
     """Table for task items."""
 
     __tablename__ = "tasks"
-    __table_args__ = {"prefixes": ["HYBRID"]}
     id: Mapped[uuid.UUID] = mapped_column(
         UUID, primary_key=True, default=uuid.uuid4, comment="Task ID"
     )
diff --git a/aana/storage/models/transcript.py b/aana/storage/models/transcript.py
index 625edb7d..33d55695 100644
--- a/aana/storage/models/transcript.py
+++ b/aana/storage/models/transcript.py
@@ -15,7 +15,7 @@
     AsrTranscriptionInfo,
 )
 
-transcript_id_seq = Sequence("transcript_id_seq", start=1, increment=1)
+transcript_id_seq = Sequence("transcript_id_seq", start=1, increment=1, optional=True)
 
 
 class TranscriptEntity(BaseEntity, TimeStampEntity):
@@ -33,7 +33,9 @@ class TranscriptEntity(BaseEntity, TimeStampEntity):
 
     __tablename__ = "transcript"
 
-    id: Mapped[int] = mapped_column(transcript_id_seq, primary_key=True)
+    id: Mapped[int] = mapped_column(
+        transcript_id_seq, autoincrement=True, primary_key=True
+    )
     model: Mapped[str] = mapped_column(
         nullable=False, comment="Name of model used to generate transcript"
     )
diff --git a/aana/storage/op.py b/aana/storage/op.py
index eecff7e7..47271d3f 100644
--- a/aana/storage/op.py
+++ b/aana/storage/op.py
@@ -1,4 +1,3 @@
-import json
 import typing
 from enum import Enum
 from pathlib import Path
@@ -19,7 +18,6 @@
 
 import re
 
-from snowflake.sqlalchemy.custom_types import OBJECT, VARIANT
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import Insert
 
diff --git a/aana/tests/conftest.py b/aana/tests/conftest.py
index f6692bf7..2243af1b 100644
--- a/aana/tests/conftest.py
+++ b/aana/tests/conftest.py
@@ -1,5 +1,6 @@
 # This file is used to define fixtures that are used in the integration tests.
 # ruff: noqa: S101
+import json
 import os
 import tempfile
 from pathlib import Path
@@ -10,7 +11,7 @@
 from pytest_postgresql import factories
 from sqlalchemy.orm import Session
 
-from aana.configs.db import DbSettings, PostgreSQLConfig, SQLiteConfig
+from aana.configs.db import DbSettings, PostgreSQLConfig, SnowflakeConfig, SQLiteConfig
 from aana.configs.settings import settings as aana_settings
 from aana.sdk import AanaSDK
 from aana.storage.op import DbType, run_alembic_migrations
@@ -211,7 +212,36 @@ def postgres_db_session(postgresql):
         yield session
 
 
-@pytest.fixture(params=["sqlite_db_session", "postgres_db_session"])
+@pytest.fixture(scope="function")
+def snowflake_db_session():
+    """Creates a new snowflake database and session for each test."""
+    SNOWFLAKE_TEST_PARAMETERS = os.environ.get("SNOWFLAKE_TEST_PARAMETERS")
+    if not SNOWFLAKE_TEST_PARAMETERS:
+        pytest.skip("Snowflake test parameters not found")
+
+    SNOWFLAKE_TEST_PARAMETERS = json.loads(SNOWFLAKE_TEST_PARAMETERS)
+
+    aana_settings.db_config.datastore_type = DbType.SNOWFLAKE
+    aana_settings.db_config.datastore_config = SnowflakeConfig(
+        **SNOWFLAKE_TEST_PARAMETERS
+    )
+    os.environ["DB_CONFIG"] = jsonify(aana_settings.db_config)
+
+    # Reset the engine
+    aana_settings.db_config._engine = None
+
+    # Run migrations to set up the schema
+    run_alembic_migrations(aana_settings)
+
+    # Create a new session
+    engine = aana_settings.db_config.get_engine()
+    with Session(engine) as session:
+        yield session
+
+
+@pytest.fixture(
+    params=["sqlite_db_session", "postgres_db_session", "snowflake_db_session"]
+)
 def db_session(request):
     """Iterate over different database type for db tests."""
     return request.getfixturevalue(request.param)
diff --git a/tests.dstack.yml b/tests.dstack.yml
index 8e43d8ee..303aba6d 100644
--- a/tests.dstack.yml
+++ b/tests.dstack.yml
@@ -8,6 +8,7 @@
 image: nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
 
 env:
   - HF_TOKEN
+  - SNOWFLAKE_TEST_PARAMETERS
 commands:
   - apt-get update