Skip to content

Commit

Permalink
IAM Auth for RDS (#3479)
Browse files Browse the repository at this point in the history
* k

* functional iam auth

* k

* k

* improve typing

* add deployment options

* cleanup

* quick clean up

* minor cleanup

* additional clarity for db session operations

* nit

* k

* k

* update configs

* docker compose spacing
  • Loading branch information
pablonyx authored Dec 17, 2024
1 parent 2859869 commit 8db6d49
Show file tree
Hide file tree
Showing 15 changed files with 281 additions and 138 deletions.
96 changes: 58 additions & 38 deletions backend/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,49 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine.base import Connection
from typing import Literal
import os
import ssl
import asyncio
from logging.config import fileConfig
import logging
from logging.config import fileConfig

from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem

from shared_configs.configs import MULTI_TENANT
from onyx.db.engine import build_connection_string
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine import get_all_tenant_ids
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

# Alembic Config object
config = context.config

# Interpret the config file for Python logging.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)

# Add your model's MetaData object here for 'autogenerate' support
target_metadata = [Base.metadata, ResultModelBase.metadata]

EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

# Set up logging
logger = logging.getLogger(__name__)

ssl_context: ssl.SSLContext | None = None
if USE_IAM_AUTH:
if not os.path.exists(SSL_CERT_FILE):
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)


def include_object(
object: SchemaItem,
Expand All @@ -49,20 +59,12 @@ def include_object(
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
"""
Determines whether a database object should be included in migrations.
Excludes specified tables from migrations.
"""
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True


def get_schema_options() -> tuple[str, bool, bool]:
"""
Parses command-line options passed via '-x' in Alembic commands.
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
"""
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
Expand Down Expand Up @@ -90,16 +92,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
"""
Executes migrations in the specified schema.
"""
logger.info(f"About to migrate schema: {schema_name}")

if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))

# Set search_path to the target schema
connection.execute(text(f'SET search_path TO "{schema_name}"'))

context.configure(
Expand All @@ -117,22 +115,42 @@ def do_run_migrations(
context.run_migrations()


def provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
if USE_IAM_AUTH:
# Database connection settings
region = AWS_REGION
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER

# Get IAM authentication token
token = get_iam_auth_token(host, port, user, region)

# For Alembic / SQLAlchemy in this context, set SSL and password
cparams["password"] = token
cparams["ssl"] = ssl_context


async def run_async_migrations() -> None:
"""
Determines whether to run migrations for a single schema or all schemas,
and executes migrations accordingly.
"""
schema_name, create_schema, upgrade_all_tenants = get_schema_options()

engine = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
)

if USE_IAM_AUTH:

@event.listens_for(engine.sync_engine, "do_connect")
def event_provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)

if upgrade_all_tenants:
# Run migrations for all tenant schemas sequentially
tenant_schemas = get_all_tenant_ids()

for schema in tenant_schemas:
try:
logger.info(f"Migrating schema: {schema}")
Expand Down Expand Up @@ -162,15 +180,20 @@ async def run_async_migrations() -> None:


def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode.
"""
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()

if upgrade_all_tenants:
# Run offline migrations for all tenant schemas
engine = create_async_engine(url)

if USE_IAM_AUTH:

@event.listens_for(engine.sync_engine, "do_connect")
def event_provide_iam_token_for_alembic_offline(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)

tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()

Expand Down Expand Up @@ -207,9 +230,6 @@ def run_migrations_offline() -> None:


def run_migrations_online() -> None:
"""
Runs migrations in 'online' mode using an asynchronous engine.
"""
asyncio.run(run_async_migrations())


Expand Down
4 changes: 4 additions & 0 deletions backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"

POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
Expand Down Expand Up @@ -174,6 +175,9 @@
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT

USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"


REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/configs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"

SSL_CERT_FILE = "bundle.pem"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
Expand Down
Loading

0 comments on commit 8db6d49

Please sign in to comment.