Skip to content

Commit

Permalink
More snowflake related fixes. Added tests for snowflake DB.
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Dec 10, 2024
1 parent 2b7d70d commit f4fceee
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ jobs:
uses: ikalnytskyi/action-setup-postgres@v6
- name: Test with pytest
env:
USE_DEPLOYMENT_CACHE: True
SNOWFLAKE_TEST_PARAMETERS: ${{ secrets.SNOWFLAKE_TEST_PARAMETERS }}
run: poetry run pytest -vv
2 changes: 1 addition & 1 deletion .github/workflows/run_tests_with_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
- name: Run tests with GPU
run: |
HF_TOKEN=${{ secrets.HF_TOKEN }} dstack apply -f tests.dstack.yml --force -y
SNOWFLAKE_TEST_PARAMETERS=${{ secrets.SNOWFLAKE_TEST_PARAMETERS }} HF_TOKEN=${{ secrets.HF_TOKEN }} dstack apply -f tests.dstack.yml --force -y
- name: Extract pytest logs
if: ${{ always() }}
Expand Down
2 changes: 2 additions & 0 deletions aana/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


class SnowflakeImpl(DefaultImpl):
"""Custom implementation for Snowflake."""

__dialect__ = "snowflake"


Expand Down
22 changes: 16 additions & 6 deletions aana/alembic/versions/5ad873484aa3_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# id_seq = Sequence('my_sequence_name', start=1, increment=1)

op.execute(CreateSequence(Sequence("caption_id_seq")))
op.execute(CreateSequence(Sequence("transcript_id_seq")))
if op.get_context().dialect.name != "sqlite":
op.execute(CreateSequence(Sequence("caption_id_seq")))
op.execute(CreateSequence(Sequence("transcript_id_seq")))

op.create_table(
"caption",
sa.Column("id", sa.Integer(), Sequence("caption_id_seq"), nullable=False),
sa.Column("id", sa.Integer(), Sequence("caption_id_seq"), nullable=False)
if op.get_context().dialect.name == "snowflake"
else sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"model",
sa.String(),
Expand Down Expand Up @@ -96,7 +99,12 @@ def upgrade() -> None:
)
op.create_table(
"tasks",
sa.Column("id", sa.String(), nullable=False, comment="Task ID"),
sa.Column(
"id",
sa.String() if op.get_context().dialect.name == "snowflake" else sa.UUID(),
nullable=False,
comment="Task ID",
),
sa.Column(
"endpoint",
sa.String(),
Expand Down Expand Up @@ -165,11 +173,13 @@ def upgrade() -> None:
comment="Timestamp when row is updated",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_tasks")),
prefixes=["HYBRID"],
prefixes=["HYBRID"] if op.get_context().dialect.name == "snowflake" else [],
)
op.create_table(
"transcript",
sa.Column("id", sa.Integer(), Sequence("transcript_id_seq"), nullable=False),
sa.Column("id", sa.Integer(), Sequence("transcript_id_seq"), nullable=False)
if op.get_context().dialect.name == "snowflake"
else sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"model",
sa.String(),
Expand Down
44 changes: 12 additions & 32 deletions aana/storage/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,15 @@ def process(value):
return process


JSON = VARIANT

# class JSON(TypeDecorator):
# """Custom JSON type that supports Snowflake-specific and standard dialects."""

# impl = SqlAlchemyJSON # Default to standard SQLAlchemy JSON

# def load_dialect_impl(self, dialect):
# """Load dialect-specific implementation."""
# if dialect.name == "snowflake":
# return SnowflakeVariantType()
# return self.impl

# def bind_expression(self, bindvalue):
# """Handle binding expressions dynamically."""
# if hasattr(
# bindvalue.type, "bind_expression"
# ): # Check if impl has bind_expression
# return bindvalue.type.bind_expression(bindvalue)
# return bindvalue # Default binding behavior

# def process_result_value(self, value, dialect):
# """Process the result based on dialect."""
# if dialect.name == "snowflake":
# if value is None:
# return None
# try:
# return orjson.loads(value)
# except (ValueError, TypeError):
# return value # Return raw value if not valid JSON
# # For other dialects, call the default implementation
# return self.impl.process_result_value(value, dialect)
class JSON(TypeDecorator):
"""Custom JSON type that supports Snowflake-specific and standard dialects."""

impl = SqlAlchemyJSON # Default to standard SQLAlchemy JSON
# impl = VARIANT # Default to Snowflake VARIANT

def load_dialect_impl(self, dialect):
"""Load dialect-specific implementation."""
if dialect.name == "snowflake":
return VARIANT()
else:
return SqlAlchemyJSON()
6 changes: 4 additions & 2 deletions aana/storage/models/caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aana.core.models.captions import Caption


caption_id_seq = Sequence("caption_id_seq", start=1, increment=1)
caption_id_seq = Sequence("caption_id_seq", start=1, increment=1, optional=True)


class CaptionEntity(BaseEntity, TimeStampEntity):
Expand All @@ -28,7 +28,9 @@ class CaptionEntity(BaseEntity, TimeStampEntity):

__tablename__ = "caption"

id: Mapped[int] = mapped_column(caption_id_seq, primary_key=True)
id: Mapped[int] = mapped_column(
caption_id_seq, autoincrement=True, primary_key=True
)
model: Mapped[str] = mapped_column(
nullable=False, comment="Name of model used to generate the caption"
)
Expand Down
7 changes: 0 additions & 7 deletions aana/storage/models/task.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import re
import uuid
from enum import Enum

from sqlalchemy import (
UUID,
PickleType,
event,
insert,
select,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.sql import Insert

from aana.storage.custom_types import JSON
from aana.storage.models.base import BaseEntity, TimeStampEntity, timestamp
Expand All @@ -32,7 +26,6 @@ class TaskEntity(BaseEntity, TimeStampEntity):
"""Table for task items."""

__tablename__ = "tasks"
__table_args__ = {"prefixes": ["HYBRID"]}

id: Mapped[uuid.UUID] = mapped_column(
UUID, primary_key=True, default=uuid.uuid4, comment="Task ID"
Expand Down
6 changes: 4 additions & 2 deletions aana/storage/models/transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
AsrTranscriptionInfo,
)

transcript_id_seq = Sequence("transcript_id_seq", start=1, increment=1)
transcript_id_seq = Sequence("transcript_id_seq", start=1, increment=1, optional=True)


class TranscriptEntity(BaseEntity, TimeStampEntity):
Expand All @@ -33,7 +33,9 @@ class TranscriptEntity(BaseEntity, TimeStampEntity):

__tablename__ = "transcript"

id: Mapped[int] = mapped_column(transcript_id_seq, primary_key=True)
id: Mapped[int] = mapped_column(
transcript_id_seq, autoincrement=True, primary_key=True
)
model: Mapped[str] = mapped_column(
nullable=False, comment="Name of model used to generate transcript"
)
Expand Down
2 changes: 0 additions & 2 deletions aana/storage/op.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import typing
from enum import Enum
from pathlib import Path
Expand All @@ -19,7 +18,6 @@

import re

from snowflake.sqlalchemy.custom_types import OBJECT, VARIANT
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import Insert

Expand Down
34 changes: 32 additions & 2 deletions aana/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# This file is used to define fixtures that are used in the integration tests.
# ruff: noqa: S101
import json
import os
import tempfile
from pathlib import Path
Expand All @@ -10,7 +11,7 @@
from pytest_postgresql import factories
from sqlalchemy.orm import Session

from aana.configs.db import DbSettings, PostgreSQLConfig, SQLiteConfig
from aana.configs.db import DbSettings, PostgreSQLConfig, SnowflakeConfig, SQLiteConfig
from aana.configs.settings import settings as aana_settings
from aana.sdk import AanaSDK
from aana.storage.op import DbType, run_alembic_migrations
Expand Down Expand Up @@ -211,7 +212,36 @@ def postgres_db_session(postgresql):
yield session


@pytest.fixture(params=["sqlite_db_session", "postgres_db_session"])
@pytest.fixture(scope="function")
def snowflake_db_session():
"""Creates a new snowflake database and session for each test."""
SNOWFLAKE_TEST_PARAMETERS = os.environ.get("SNOWFLAKE_TEST_PARAMETERS")
if not SNOWFLAKE_TEST_PARAMETERS:
pytest.skip("Snowflake test parameters not found")

SNOWFLAKE_TEST_PARAMETERS = json.loads(SNOWFLAKE_TEST_PARAMETERS)

aana_settings.db_config.datastore_type = DbType.SNOWFLAKE
aana_settings.db_config.datastore_config = SnowflakeConfig(
**SNOWFLAKE_TEST_PARAMETERS
)
os.environ["DB_CONFIG"] = jsonify(aana_settings.db_config)

# Reset the engine
aana_settings.db_config._engine = None

# Run migrations to set up the schema
run_alembic_migrations(aana_settings)

# Create a new session
engine = aana_settings.db_config.get_engine()
with Session(engine) as session:
yield session


@pytest.fixture(
params=["sqlite_db_session", "postgres_db_session", "snowflake_db_session"]
)
def db_session(request):
"""Iterate over different database type for db tests."""
return request.getfixturevalue(request.param)
1 change: 1 addition & 0 deletions tests.dstack.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ image: nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04

env:
- HF_TOKEN
- SNOWFLAKE_TEST_PARAMETERS

commands:
- apt-get update
Expand Down

0 comments on commit f4fceee

Please sign in to comment.