From ed776d30e3e279ea777fe18ac6084b75dcbda67f Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Wed, 11 Dec 2024 14:56:32 +0000
Subject: [PATCH] Add support for Snowflake token authentication and environment variable mapping

---
Reviewer notes (this region is ignored by `git am`):

* Line structure restored from a whitespace-collapsed copy of this patch.
  Hunk contents are unchanged; leading indentation was reconstructed with the
  4-space convention. NOTE(review): run `git apply --check` against the
  parent tree to confirm the context lines match.
* aana/configs/db.py: `SnowflakeConfig` becomes `total=False` and gains
  `host` and `token`. The new `update_from_alias_env_vars` validator copies
  SNOWFLAKE_* environment variables into `datastore_config` only for keys
  that are missing or falsy there, so explicitly configured values win.
* aana/storage/op.py: `create_snowflake_engine` reads
  /snowflake/session/token when no "token" key is configured and sets
  `authenticator` to "oauth" whenever a "token" key is present. Both values
  are written into `db_config.datastore_config` itself (no copy is made), and
  the file content is used exactly as returned by `read_text()`.
  NOTE(review): the added lines use `Path`, but this patch adds no import to
  op.py - presumably it is already imported there; confirm.
* aana/tests/db/datastore/test_config.py: the Snowflake fixture calls
  `pytest.skip` unless SNOWFLAKE_TEST_PARAMETERS is set; its value is parsed
  with `json.loads` and passed to `SnowflakeConfig`.

 aana/configs/db.py                     | 28 +++++++++++++++++++++++++-
 aana/storage/op.py                     | 13 +++++++++++-
 aana/tests/db/datastore/test_config.py | 26 +++++++++++++++++++++++++-
 3 files changed, 64 insertions(+), 3 deletions(-)

diff --git a/aana/configs/db.py b/aana/configs/db.py
index bfb01b55..4f2df17d 100644
--- a/aana/configs/db.py
+++ b/aana/configs/db.py
@@ -1,5 +1,7 @@
+import os
 from os import PathLike
 
+from pydantic import model_validator
 from pydantic_settings import BaseSettings
 from sqlalchemy.engine import Engine
 from typing_extensions import TypedDict
@@ -35,12 +37,14 @@ class PostgreSQLConfig(TypedDict):
     database: str
 
 
-class SnowflakeConfig(TypedDict):
+class SnowflakeConfig(TypedDict, total=False):
     """Config values for Snowflake.
 
     Attributes:
         account (str): The account name.
         user (str): The user to connect to the Snowflake server.
+        host (str): The host of the Snowflake server.
+        token (str): The token to connect to the Snowflake server.
         password (str): The password to connect to the Snowflake server.
         database (str): The database name.
         schema (str): The schema name.
@@ -50,6 +54,8 @@ class SnowflakeConfig(TypedDict):
 
     account: str
     user: str
+    host: str
+    token: str
     password: str
     database: str
     schema: str
@@ -100,3 +106,23 @@ def __setstate__(self, state):
         # We don't need to do anything special here, since the engine will be recreated
         # if needed.
         self.__dict__.update(state)
+
+    @model_validator(mode="after")
+    def update_from_alias_env_vars(self):
+        """Update the database configuration from alias environment variables."""
+        if self.datastore_type == DbType.SNOWFLAKE:
+            mapping = {
+                "SNOWFLAKE_ACCOUNT": "account",
+                "SNOWFLAKE_DATABASE": "database",
+                "SNOWFLAKE_HOST": "host",
+                "SNOWFLAKE_SCHEMA": "schema",
+                "SNOWFLAKE_USER": "user",
+                "SNOWFLAKE_PASSWORD": "password",
+                "SNOWFLAKE_WAREHOUSE": "warehouse",
+                "SNOWFLAKE_ROLE": "role",
+                "SNOWFLAKE_TOKEN": "token",
+            }
+            for env_var, key in mapping.items():
+                if not self.datastore_config.get(key) and os.environ.get(env_var):
+                    self.datastore_config[key] = os.environ[env_var]
+        return self
diff --git a/aana/storage/op.py b/aana/storage/op.py
index 47271d3f..b7d4f227 100644
--- a/aana/storage/op.py
+++ b/aana/storage/op.py
@@ -84,7 +84,7 @@ def create_sqlite_engine(db_config: "DbSettings"):
     )
 
 
-def create_snowflake_engine(db_config: "DbSettings"):
+def create_snowflake_engine(db_config: "DbSettings"):  # noqa: C901
     """Create a Snowflake SQLAlchemy engine based on the provided configuration.
 
     Args:
@@ -94,6 +94,17 @@ def create_snowflake_engine(db_config: "DbSettings"):
         sqlalchemy.engine.Engine: SQLAlchemy engine instance.
     """
     datastore_config = db_config.datastore_config
+
+    # If token is not provided, check if token file exists
+    SNOWFLAKE_TOKEN_PATH = Path("/snowflake/session/token")
+    if SNOWFLAKE_TOKEN_PATH.exists() and "token" not in datastore_config:
+        token = SNOWFLAKE_TOKEN_PATH.read_text()
+        datastore_config["token"] = token
+
+    # Set authenticator to oauth if token is provided
+    if "token" in datastore_config:
+        datastore_config["authenticator"] = "oauth"
+
     connection_string = SNOWFLAKE_URL(**datastore_config)
     engine = create_engine(
         connection_string,
diff --git a/aana/tests/db/datastore/test_config.py b/aana/tests/db/datastore/test_config.py
index 4f3a4464..7ddd22a8 100644
--- a/aana/tests/db/datastore/test_config.py
+++ b/aana/tests/db/datastore/test_config.py
@@ -1,7 +1,10 @@
 # ruff: noqa: S101
+import json
+import os
+
 import pytest
 
-from aana.configs.db import DbSettings, PostgreSQLConfig, SQLiteConfig
+from aana.configs.db import DbSettings, PostgreSQLConfig, SnowflakeConfig, SQLiteConfig
 
 
 @pytest.fixture
@@ -28,6 +31,20 @@ def sqlite_settings():
     )
 
 
+@pytest.fixture
+def snowflake_settings():
+    """Fixture for working Snowflake settings."""
+    SNOWFLAKE_TEST_PARAMETERS = os.environ.get("SNOWFLAKE_TEST_PARAMETERS")
+    if not SNOWFLAKE_TEST_PARAMETERS:
+        pytest.skip("Snowflake test parameters not found")
+    SNOWFLAKE_TEST_PARAMETERS = json.loads(SNOWFLAKE_TEST_PARAMETERS)
+
+    return DbSettings(
+        datastore_type="snowflake",
+        datastore_config=SnowflakeConfig(**SNOWFLAKE_TEST_PARAMETERS),
+    )
+
+
 def test_get_engine_idempotent(pg_settings, sqlite_settings):
     """Tests that get_engine returns the same engine on subsequent calls."""
     for db_settings in (pg_settings, sqlite_settings):
@@ -52,6 +69,13 @@ def test_sqlite_datastore_config(sqlite_settings):
     assert str(engine.url) == f"sqlite:///{sqlite_settings.datastore_config['path']}"
 
 
+def test_snowflake_datastore_config(snowflake_settings):
+    """Tests datastore config for Snowflake."""
+    engine = snowflake_settings.get_engine()
+
+    assert engine.name == "snowflake"
+
+
 def test_nonexistent_datastore_config():
     """Tests that datastore config errors on unsupported DB types."""
     db_settings = DbSettings(