Skip to content

Commit

Permalink
Add Snowflake DB support
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Dec 9, 2024
1 parent d9bc745 commit 53f85b2
Show file tree
Hide file tree
Showing 8 changed files with 374 additions and 17 deletions.
6 changes: 6 additions & 0 deletions aana/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from logging.config import fileConfig

from alembic import context
from alembic.ddl.impl import DefaultImpl
from sqlalchemy import engine_from_config, pool

from aana.configs.settings import settings
from aana.storage.models.base import BaseEntity


class SnowflakeImpl(DefaultImpl):
    """Alembic implementation class bound to the ``snowflake`` dialect name.

    The class adds no behavior of its own; everything is inherited from
    ``DefaultImpl``. Only the dialect name is set.

    NOTE(review): presumably this exists so Alembic can resolve a DDL
    implementation when migrations run against a Snowflake connection —
    confirm against the Alembic documentation.
    """

    __dialect__ = "snowflake"


# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
Expand Down
8 changes: 5 additions & 3 deletions aana/alembic/versions/5ad873484aa3_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import sqlalchemy as sa
from alembic import op

from aana.storage.custom_types import JSON, VARIANT

# revision identifiers, used by Alembic.
revision: str = "5ad873484aa3"
down_revision: str | None = None
Expand Down Expand Up @@ -90,7 +92,7 @@ def upgrade() -> None:
)
op.create_table(
"tasks",
sa.Column("id", sa.UUID(), nullable=False, comment="Task ID"),
sa.Column("id", sa.String(), nullable=False, comment="Task ID"),
sa.Column(
"endpoint",
sa.String(),
Expand Down Expand Up @@ -140,7 +142,7 @@ def upgrade() -> None:
),
sa.Column(
"result",
sa.JSON(),
JSON(),
nullable=True,
comment="Result of the task in JSON format",
),
Expand Down Expand Up @@ -176,7 +178,7 @@ def upgrade() -> None:
comment="Full text transcript of media",
),
sa.Column(
"segments", sa.JSON(), nullable=False, comment="Segments of the transcript"
"segments", JSON(), nullable=False, comment="Segments of the transcript"
),
sa.Column(
"language",
Expand Down
26 changes: 24 additions & 2 deletions aana/configs/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,34 @@ class PostgreSQLConfig(TypedDict):
database: str


class SnowflakeConfig(TypedDict):
    """Config values for Snowflake.

    All keys are declared as required strings.

    Attributes:
        account (str): The account name.
        user (str): The user to connect to the Snowflake server.
        password (str): The password to connect to the Snowflake server.
        database (str): The database name.
        schema (str): The schema name.
        warehouse (str): The warehouse name.
        role (str): The role name.
    """

    account: str
    user: str
    password: str
    database: str
    schema: str
    warehouse: str
    role: str


class DbSettings(BaseSettings):
"""Database configuration.
Attributes:
datastore_type (DbType | str): The type of the datastore. Default is DbType.SQLITE.
datastore_config (SQLiteConfig | PostgreSQLConfig): The configuration for the datastore.
datastore_config (SQLiteConfig | PostgreSQLConfig | SnowflakeConfig): The configuration for the datastore.
Default is SQLiteConfig(path="/var/lib/aana_data").
pool_size (int): The number of connections to keep in the pool. Default is 5.
max_overflow (int): The number of connections that can be created when the pool is exhausted.
Expand All @@ -50,7 +72,7 @@ class DbSettings(BaseSettings):
"""

datastore_type: DbType | str = DbType.SQLITE
datastore_config: SQLiteConfig | PostgreSQLConfig = SQLiteConfig(
datastore_config: SQLiteConfig | PostgreSQLConfig | SnowflakeConfig = SQLiteConfig(
path="/var/lib/aana_data"
)
pool_size: int = 5
Expand Down
60 changes: 60 additions & 0 deletions aana/storage/custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import orjson
from snowflake.sqlalchemy.custom_types import VARIANT as SnowflakeVariantType
from sqlalchemy import func
from sqlalchemy.types import JSON as SqlAlchemyJSON
from sqlalchemy.types import TypeDecorator


class VARIANT(SnowflakeVariantType):
    """Snowflake VARIANT column type with JSON-aware binding and loading."""

    def bind_expression(self, bindvalue):
        """Return the bound value wrapped in a ``PARSE_JSON(...)`` call."""
        wrapped = func.PARSE_JSON(bindvalue)
        return wrapped

    def result_processor(self, dialect, coltype):
        """Build the callable that decodes fetched column values.

        The returned callable passes ``None`` through, parses anything else
        with ``orjson.loads``, and falls back to the raw value when parsing
        raises ``ValueError`` or ``TypeError``.
        """

        def _decode(raw):
            if raw is not None:
                try:
                    return orjson.loads(raw)
                except (ValueError, TypeError):
                    # Not valid JSON: hand the fetched value back unchanged.
                    return raw
            return None

        return _decode


# Public JSON column type: currently a plain alias of the VARIANT class above.
# NOTE(review): because this is an alias, every use of JSON is the
# Snowflake VARIANT type (including its PARSE_JSON bind expression) regardless
# of the active dialect; the commented-out TypeDecorator below looks like the
# intended dialect-aware variant — confirm before relying on non-Snowflake
# backends.
JSON = VARIANT

# class JSON(TypeDecorator):
# """Custom JSON type that supports Snowflake-specific and standard dialects."""

# impl = SqlAlchemyJSON # Default to standard SQLAlchemy JSON

# def load_dialect_impl(self, dialect):
# """Load dialect-specific implementation."""
# if dialect.name == "snowflake":
# return SnowflakeVariantType()
# return self.impl

# def bind_expression(self, bindvalue):
# """Handle binding expressions dynamically."""
# if hasattr(
# bindvalue.type, "bind_expression"
# ): # Check if impl has bind_expression
# return bindvalue.type.bind_expression(bindvalue)
# return bindvalue # Default binding behavior

# def process_result_value(self, value, dialect):
# """Process the result based on dialect."""
# if dialect.name == "snowflake":
# if value is None:
# return None
# try:
# return orjson.loads(value)
# except (ValueError, TypeError):
# return value # Return raw value if not valid JSON
# # For other dialects, call the default implementation
# return self.impl.process_result_value(value, dialect)
8 changes: 7 additions & 1 deletion aana/storage/models/task.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import re
import uuid
from enum import Enum

from sqlalchemy import (
JSON,
UUID,
PickleType,
event,
insert,
select,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.sql import Insert

from aana.storage.custom_types import JSON
from aana.storage.models.base import BaseEntity, TimeStampEntity, timestamp


Expand Down
76 changes: 75 additions & 1 deletion aana/storage/op.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
import typing
from enum import Enum
from pathlib import Path

import orjson
from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine
from snowflake.sqlalchemy import URL as SNOWFLAKE_URL
from sqlalchemy import create_engine, event

from aana.exceptions.runtime import EmptyMigrationsException
from aana.utils.core import get_module_dir
Expand All @@ -14,6 +16,12 @@
if typing.TYPE_CHECKING:
from aana.configs.db import DbSettings

import re

from snowflake.sqlalchemy.custom_types import OBJECT, VARIANT
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import Insert


class DbType(str, Enum):
"""Engine types for relational database.
Expand All @@ -25,6 +33,7 @@ class DbType(str, Enum):

POSTGRESQL = "postgresql"
SQLITE = "sqlite"
SNOWFLAKE = "snowflake"


def create_postgresql_engine(db_config: "DbSettings"):
Expand Down Expand Up @@ -76,6 +85,69 @@ def create_sqlite_engine(db_config: "DbSettings"):
)


def create_snowflake_engine(db_config: "DbSettings"):
    """Create a Snowflake SQLAlchemy engine based on the provided configuration.

    Besides building the engine, this function installs two hooks:

    * a ``before_cursor_execute`` listener on the returned engine that
      serializes dict parameters bound to VARIANT columns with ``jsonify``;
    * an ``Insert`` compiler hook that rewrites ``VALUES (...)`` into
      ``SELECT ...`` whenever the compiled statement contains ``PARSE_JSON``.

    Args:
        db_config (DbSettings): Database configuration.

    Returns:
        sqlalchemy.engine.Engine: SQLAlchemy engine instance.
    """
    datastore_config = db_config.datastore_config
    connection_string = SNOWFLAKE_URL(**datastore_config)
    engine = create_engine(
        connection_string,
        pool_size=db_config.pool_size,
        max_overflow=db_config.max_overflow,
        pool_recycle=db_config.pool_recycle,
    )

    # Matches `VALUES (...)` with parentheses balanced up to three levels deep.
    # Compiled once here instead of on every INSERT compilation.
    values_pattern = re.compile(
        r"VALUES\s*(\((?:[^)(]+|\((?:[^)(]+|\([^)(]*\))*\))*\))"
    )

    @event.listens_for(engine, "before_cursor_execute")
    def preprocess_parameters(
        conn, cursor, statement, parameters, context, executemany
    ):
        """Serialize dict parameters that are bound to VARIANT columns."""
        compiled = context.compiled
        # Only dict-style parameters of compiled statements carry bind metadata.
        if not isinstance(parameters, dict) or not compiled:
            return
        binds = compiled.binds
        for key, value in parameters.items():
            if not isinstance(value, dict):
                continue
            # Look up the bind for *this* parameter. The previous code always
            # inspected the "result" bind, which mis-detected other VARIANT
            # columns and raised KeyError when no "result" bind existed.
            bind = binds.get(key)
            if bind is not None and isinstance(bind.type, VARIANT):
                # Replacing the value of an existing key is safe while iterating.
                parameters[key] = jsonify(value)

    # NOTE(review): this hook is registered for the "default" dialect and is
    # re-registered on every call, so it also affects INSERTs compiled for
    # other dialects in the same process — confirm that is intended.
    @compiles(Insert, "default")
    def compile_insert(insert_stmt, compiler, **kwargs):
        """Compile INSERT statements to use SELECT instead of VALUES for Snowflake PARSE_JSON."""
        sql = compiler.visit_insert(insert_stmt, **kwargs)

        # Only transform if PARSE_JSON is present in the SQL.
        if "PARSE_JSON" not in sql:
            return sql

        match = values_pattern.search(sql)
        if match is None:
            return sql

        values_clause = match.group(1)  # The `(...)` that follows VALUES
        # Splice on the match span so the rewrite works for any whitespace
        # between VALUES and the parenthesis. The previous str.replace assumed
        # exactly one space and otherwise left the statement unchanged.
        return f"{sql[:match.start()]}SELECT {values_clause[1:-1]}{sql[match.end():]}"

    return engine


def create_database_engine(db_config: "DbSettings"):
"""Create SQLAlchemy engine based on the provided configuration.
Expand All @@ -91,6 +163,8 @@ def create_database_engine(db_config: "DbSettings"):
return create_postgresql_engine(db_config)
elif db_type == DbType.SQLITE:
return create_sqlite_engine(db_config)
elif db_type == DbType.SNOWFLAKE:
return create_snowflake_engine(db_config)
else:
raise ValueError(f"Unsupported database type: {db_type}") # noqa: TRY003

Expand Down
Loading

0 comments on commit 53f85b2

Please sign in to comment.