Skip to content

Commit

Permalink
Add support for Snowflake token authentication and environment variab…
Browse files Browse the repository at this point in the history
…le mapping
  • Loading branch information
Aleksandr Movchan committed Dec 11, 2024
1 parent f0ca3f5 commit ed776d3
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
28 changes: 27 additions & 1 deletion aana/configs/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from os import PathLike

from pydantic import model_validator
from pydantic_settings import BaseSettings
from sqlalchemy.engine import Engine
from typing_extensions import TypedDict
Expand Down Expand Up @@ -35,12 +37,14 @@ class PostgreSQLConfig(TypedDict):
database: str


class SnowflakeConfig(TypedDict):
class SnowflakeConfig(TypedDict, total=False):
"""Config values for Snowflake.
Attributes:
account (str): The account name.
user (str): The user to connect to the Snowflake server.
host (str): The host of the Snowflake server.
token (str): The token to connect to the Snowflake server.
password (str): The password to connect to the Snowflake server.
database (str): The database name.
schema (str): The schema name.
Expand All @@ -50,6 +54,8 @@ class SnowflakeConfig(TypedDict):

account: str
user: str
host: str
token: str
password: str
database: str
schema: str
Expand Down Expand Up @@ -100,3 +106,23 @@ def __setstate__(self, state):
# We don't need to do anything special here, since the engine will be recreated
# if needed.
self.__dict__.update(state)

@model_validator(mode="after")
def update_from_alias_env_vars(self):
"""Update the database configuration from alias environment variables."""
if self.datastore_type == DbType.SNOWFLAKE:
mapping = {
"SNOWFLAKE_ACCOUNT": "account",
"SNOWFLAKE_DATABASE": "database",
"SNOWFLAKE_HOST": "host",
"SNOWFLAKE_SCHEMA": "schema",
"SNOWFLAKE_USER": "user",
"SNOWFLAKE_PASSWORD": "password",
"SNOWFLAKE_WAREHOUSE": "warehouse",
"SNOWFLAKE_ROLE": "role",
"SNOWFLAKE_TOKEN": "token",
}
for env_var, key in mapping.items():
if not self.datastore_config.get(key) and os.environ.get(env_var):
self.datastore_config[key] = os.environ[env_var]
return self
13 changes: 12 additions & 1 deletion aana/storage/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def create_sqlite_engine(db_config: "DbSettings"):
)


def create_snowflake_engine(db_config: "DbSettings"):
def create_snowflake_engine(db_config: "DbSettings"): # noqa: C901
"""Create a Snowflake SQLAlchemy engine based on the provided configuration.
Args:
Expand All @@ -94,6 +94,17 @@ def create_snowflake_engine(db_config: "DbSettings"):
sqlalchemy.engine.Engine: SQLAlchemy engine instance.
"""
datastore_config = db_config.datastore_config

# If token is not provided, check if token file exists
SNOWFLAKE_TOKEN_PATH = Path("/snowflake/session/token")
if SNOWFLAKE_TOKEN_PATH.exists() and "token" not in datastore_config:
token = SNOWFLAKE_TOKEN_PATH.read_text()
datastore_config["token"] = token

# Set authenticator to oauth if token is provided
if "token" in datastore_config:
datastore_config["authenticator"] = "oauth"

connection_string = SNOWFLAKE_URL(**datastore_config)
engine = create_engine(
connection_string,
Expand Down
26 changes: 25 additions & 1 deletion aana/tests/db/datastore/test_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# ruff: noqa: S101
import json
import os

import pytest

from aana.configs.db import DbSettings, PostgreSQLConfig, SQLiteConfig
from aana.configs.db import DbSettings, PostgreSQLConfig, SnowflakeConfig, SQLiteConfig


@pytest.fixture
Expand All @@ -28,6 +31,20 @@ def sqlite_settings():
)


@pytest.fixture
def snowflake_settings():
"""Fixture for working Snowflake settings."""
SNOWFLAKE_TEST_PARAMETERS = os.environ.get("SNOWFLAKE_TEST_PARAMETERS")
if not SNOWFLAKE_TEST_PARAMETERS:
pytest.skip("Snowflake test parameters not found")
SNOWFLAKE_TEST_PARAMETERS = json.loads(SNOWFLAKE_TEST_PARAMETERS)

return DbSettings(
datastore_type="snowflake",
datastore_config=SnowflakeConfig(**SNOWFLAKE_TEST_PARAMETERS),
)


def test_get_engine_idempotent(pg_settings, sqlite_settings):
"""Tests that get_engine returns the same engine on subsequent calls."""
for db_settings in (pg_settings, sqlite_settings):
Expand All @@ -52,6 +69,13 @@ def test_sqlite_datastore_config(sqlite_settings):
assert str(engine.url) == f"sqlite:///{sqlite_settings.datastore_config['path']}"


def test_snowflake_datastore_config(snowflake_settings):
"""Tests datastore config for Snowflake."""
engine = snowflake_settings.get_engine()

assert engine.name == "snowflake"


def test_nonexistent_datastore_config():
"""Tests that datastore config errors on unsupported DB types."""
db_settings = DbSettings(
Expand Down

0 comments on commit ed776d3

Please sign in to comment.