diff --git a/fixbackend/all_models.py b/fixbackend/all_models.py index 7a190582..2f6bb207 100644 --- a/fixbackend/all_models.py +++ b/fixbackend/all_models.py @@ -22,3 +22,4 @@ from fixbackend.organizations.models.orm import Organization, OrganizationInvite # noqa from fixbackend.graph_db.service import GraphDatabaseAccessEntity # noqa from fixbackend.cloud_accounts.models.orm import CloudAccount # noqa +from fixbackend.dispatcher.next_run_repository import NextRun # noqa diff --git a/fixbackend/app.py b/fixbackend/app.py index f0329710..4a62082a 100644 --- a/fixbackend/app.py +++ b/fixbackend/app.py @@ -24,18 +24,24 @@ from fastapi.exception_handlers import http_exception_handler from fastapi.staticfiles import StaticFiles from prometheus_fastapi_instrumentator import Instrumentator +from redis.asyncio import Redis +from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine from starlette.exceptions import HTTPException from fixbackend import config, dependencies from fixbackend.auth.oauth import github_client, google_client from fixbackend.auth.router import auth_router, users_router +from fixbackend.cloud_accounts.repository import CloudAccountRepositoryImpl from fixbackend.cloud_accounts.router import cloud_accounts_router, cloud_accounts_callback_router from fixbackend.collect.collect_queue import RedisCollectQueue from fixbackend.config import Config from fixbackend.dependencies import FixDependencies from fixbackend.dependencies import ServiceNames as SN +from fixbackend.dispatcher.dispatcher_service import DispatcherService +from fixbackend.dispatcher.next_run_repository import NextRunRepository from fixbackend.events.router import websocket_router +from fixbackend.graph_db.service import GraphDatabaseAccessManager from fixbackend.inventory.inventory_client import InventoryClient from fixbackend.inventory.inventory_service import InventoryService from fixbackend.inventory.router import inventory_router @@ -53,8 +59,14 @@ def fast_api_app(cfg: Config) -> FastAPI: @asynccontextmanager async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: arq_redis = deps.add(SN.arg_redis, await create_pool(RedisSettings.from_dsn(cfg.redis_queue_url))) - deps.add(SN.async_engine, create_async_engine(cfg.database_url, pool_size=10)) + deps.add(SN.readonly_redis, Redis.from_url(cfg.redis_readonly_url)) + deps.add(SN.readwrite_redis, Redis.from_url(cfg.redis_readwrite_url)) + engine = deps.add(SN.async_engine, create_async_engine(cfg.database_url, pool_size=10)) + session_maker = deps.add(SN.session_maker, async_sessionmaker(engine)) + deps.add(SN.cloud_account_repo, CloudAccountRepositoryImpl(session_maker)) + deps.add(SN.next_run_repo, NextRunRepository(session_maker)) deps.add(SN.collect_queue, RedisCollectQueue(arq_redis)) + deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker)) client = deps.add(SN.inventory_client, InventoryClient(cfg.inventory_url)) deps.add(SN.inventory, InventoryService(client)) if not cfg.static_assets: @@ -67,7 +79,22 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: @asynccontextmanager async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]: - yield None + arq_redis = deps.add(SN.arg_redis, await create_pool(RedisSettings.from_dsn(cfg.redis_queue_url))) + deps.add(SN.readonly_redis, Redis.from_url(cfg.redis_readonly_url)) + rw_redis = deps.add(SN.readwrite_redis, Redis.from_url(cfg.redis_readwrite_url)) + engine = deps.add(SN.async_engine, create_async_engine(cfg.database_url, pool_size=10)) + session_maker = deps.add(SN.session_maker, async_sessionmaker(engine)) + cloud_accounts = deps.add(SN.cloud_account_repo, CloudAccountRepositoryImpl(session_maker)) + next_run_repo = deps.add(SN.next_run_repo, NextRunRepository(session_maker)) + collect_queue = deps.add(SN.collect_queue, RedisCollectQueue(arq_redis)) + db_access = deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker)) + deps.add(SN.dispatching, DispatcherService(rw_redis, cloud_accounts, next_run_repo, collect_queue, db_access)) + + async with deps: + log.info("Application services started.") + yield None + await arq_redis.close() + log.info("Application services stopped.") app = FastAPI( title="Fix Backend", diff --git a/fixbackend/auth/current_user_dependencies.py b/fixbackend/auth/current_user_dependencies.py index 6ab4067f..8f2fb6a7 100644 --- a/fixbackend/auth/current_user_dependencies.py +++ b/fixbackend/auth/current_user_dependencies.py @@ -29,7 +29,7 @@ from fixbackend.auth.jwt import get_auth_backend from fixbackend.auth.models import User from fixbackend.config import get_config -from fixbackend.graph_db.dependencies import GraphDatabaseAccessManagerDependency +from fixbackend.dependencies import FixDependency from fixbackend.graph_db.models import GraphDatabaseAccess from fixbackend.ids import TenantId from fixbackend.organizations.dependencies import OrganizationServiceDependency @@ -82,10 +82,8 @@ async def get_tenant( TenantDependency = Annotated[TenantId, Depends(get_tenant)] -async def get_current_graph_db( - manager: GraphDatabaseAccessManagerDependency, tenant: TenantDependency -) -> GraphDatabaseAccess: - access = await manager.get_database_access(tenant) +async def get_current_graph_db(fix: FixDependency, tenant: TenantDependency) -> GraphDatabaseAccess: + access = await fix.graph_database_access.get_database_access(tenant) if access is None: raise AttributeError("No database access found for tenant") return access diff --git a/fixbackend/cloud_accounts/repository.py b/fixbackend/cloud_accounts/repository.py index 316fbbed..20675e1f 100644 --- a/fixbackend/cloud_accounts/repository.py +++ b/fixbackend/cloud_accounts/repository.py @@ -17,8 +17,9 @@ from sqlalchemy import select from fixbackend.cloud_accounts.models import orm, CloudAccount, AwsCloudAccess -from fixbackend.db import AsyncSessionMaker, AsyncSessionMakerDependency +from fixbackend.db import AsyncSessionMakerDependency from fixbackend.ids import CloudAccountId, TenantId +from fixbackend.types import AsyncSessionMaker from abc import ABC, abstractmethod @@ -41,10 +42,7 @@ async def delete(self, id: CloudAccountId) -> None: class CloudAccountRepositoryImpl(CloudAccountRepository): - def __init__( - self, - session_maker: AsyncSessionMaker, - ) -> None: + def __init__(self, session_maker: AsyncSessionMaker) -> None: self.session_maker = session_maker async def create(self, cloud_account: CloudAccount) -> CloudAccount: diff --git a/fixbackend/db.py b/fixbackend/db.py index f20b583c..4e358e59 100644 --- a/fixbackend/db.py +++ b/fixbackend/db.py @@ -12,14 +12,13 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import AsyncGenerator, Annotated, Callable +from typing import AsyncGenerator, Annotated from fastapi import Depends from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from fixbackend.dependencies import FixDependency - -AsyncSessionMaker = Callable[[], AsyncSession] +from fixbackend.types import AsyncSessionMaker async def get_async_session_maker(fix: FixDependency) -> AsyncSessionMaker: diff --git a/fixbackend/dependencies.py b/fixbackend/dependencies.py index cdf1c904..521332bd 100644 --- a/fixbackend/dependencies.py +++ b/fixbackend/dependencies.py @@ -16,19 +16,29 @@ from arq import ArqRedis from fastapi.params import Depends from fixcloudutils.service import Dependencies +from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncEngine from fixbackend.collect.collect_queue import RedisCollectQueue +from fixbackend.graph_db.service import GraphDatabaseAccessManager from fixbackend.inventory.inventory_client import InventoryClient from fixbackend.inventory.inventory_service import InventoryService +from fixbackend.types import AsyncSessionMaker class ServiceNames: arg_redis = "arq_redis" + readonly_redis = "readonly_redis" + readwrite_redis = "readwrite_redis" collect_queue = "collect_queue" async_engine = "async_engine" + session_maker = "session_maker" + cloud_account_repo = "cloud_account_repo" + next_run_repo = "next_run_repo" + graph_db_access = "graph_db_access" inventory = "inventory" inventory_client = "inventory_client" + dispatching = "dispatching" class FixDependencies(Dependencies): @@ -44,6 +54,10 @@ def collect_queue(self) -> RedisCollectQueue: def async_engine(self) -> AsyncEngine: return self.service(ServiceNames.async_engine, AsyncEngine) + @property + def session_maker(self) -> AsyncSessionMaker: + return self.service(ServiceNames.async_engine, AsyncSessionMaker) # type: ignore + @property def inventory(self) -> InventoryService: return self.service(ServiceNames.inventory, InventoryService) @@ -52,6 +66,14 @@ def inventory(self) -> InventoryService: def inventory_client(self) -> InventoryClient: return self.service(ServiceNames.inventory, InventoryClient) + @property + def readonly_redis(self) -> Redis: + return self.service(ServiceNames.readonly_redis, Redis) + + @property + def graph_database_access(self) -> GraphDatabaseAccessManager: + return self.service(ServiceNames.graph_db_access, GraphDatabaseAccessManager) + # placeholder for dependencies, will be replaced during the app initialization def fix_dependencies() -> FixDependencies: diff --git a/fixbackend/dispatcher/__init__.py b/fixbackend/dispatcher/__init__.py new file mode 100644 index 00000000..84021a8f --- /dev/null +++ b/fixbackend/dispatcher/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . diff --git a/fixbackend/dispatcher/dispatcher_service.py b/fixbackend/dispatcher/dispatcher_service.py new file mode 100644 index 00000000..fceaea2b --- /dev/null +++ b/fixbackend/dispatcher/dispatcher_service.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import logging +from datetime import timedelta, datetime +from typing import Any, Optional + +from fixcloudutils.asyncio.periodic import Periodic +from fixcloudutils.redis.event_stream import RedisStreamListener, Json, MessageContext +from fixcloudutils.service import Service +from fixcloudutils.util import utc +from redis.asyncio import Redis + +from fixbackend.cloud_accounts.models import AwsCloudAccess, CloudAccount +from fixbackend.cloud_accounts.repository import CloudAccountRepository +from fixbackend.collect.collect_queue import CollectQueue, AccountInformation, AwsAccountInformation +from fixbackend.dispatcher.next_run_repository import NextRunRepository +from fixbackend.graph_db.service import GraphDatabaseAccessManager +from fixbackend.ids import CloudAccountId, TenantId + +log = logging.getLogger(__name__) + + +class DispatcherService(Service): + def __init__( + self, + readwrite_redis: Redis, + cloud_account_repo: CloudAccountRepository, + next_run_repo: NextRunRepository, + collect_queue: CollectQueue, + access_manager: GraphDatabaseAccessManager, + ) -> None: + self.cloud_account_repo = cloud_account_repo + self.next_run_repo = next_run_repo + self.collect_queue = collect_queue + self.access_manager = access_manager + self.periodic = Periodic("schedule_next_runs", self.schedule_next_runs, timedelta(minutes=1)) + self.listener = RedisStreamListener( + readwrite_redis, + "fixbackend::cloudaccount", + group="dispatching", + listener="dispatching", + message_processor=self.process_message, + consider_failed_after=timedelta(minutes=5), + batch_size=1, + ) + + async def start(self) -> Any: + await self.listener.start() + await self.periodic.start() + + async def stop(self) -> None: + await self.periodic.stop() + await self.listener.stop() + + async def process_message(self, message: Json, context: MessageContext) -> None: + match context.kind: + case "cloud_account_created": + await self.cloud_account_created(CloudAccountId(message["id"])) + case "cloud_account_deleted": + await self.cloud_account_deleted(CloudAccountId(message["id"])) + case _: + log.error(f"Don't know how to handle messages of kind {context.kind}") + + async def cloud_account_created(self, cid: CloudAccountId) -> None: + if account := await self.cloud_account_repo.get(cid): + await self.trigger_collect(account) + # store an entry in the next_run table + next_run_at = await self.compute_next_run(account.tenant_id) + await self.next_run_repo.create(cid, next_run_at) + else: + log.error("Received a message, that a cloud account is created, but it does not exist in the database") + + async def cloud_account_deleted(self, cid: CloudAccountId) -> None: + # delete the entry from the scheduler table + await self.next_run_repo.delete(cid) + + async def compute_next_run(self, tenant: TenantId) -> datetime: + # compute next run time dependent on the tenant. + result = datetime.now() + timedelta(hours=1) + log.info(f"Next run for tenant: {tenant} is {result}") + return result + + async def trigger_collect(self, account: CloudAccount) -> None: + def account_information() -> Optional[AccountInformation]: + match account.access: + case AwsCloudAccess(account_id=account_id, role_name=role_name, external_id=external_id): + return AwsAccountInformation( + aws_account_id=account_id, + aws_account_name=None, + aws_role_arn=f"arn:aws:iam::{account_id}:role/{role_name}", + external_id=str(external_id), + ) + case _: + log.error(f"Don't know how to handle this cloud access {account.access}. Ignore it.") + return None + + if (ai := account_information()) and (db := await self.access_manager.get_database_access(account.tenant_id)): + await self.collect_queue.enqueue(db, ai) # TODO: create a unique identifier for this run + + async def schedule_next_runs(self) -> None: + now = utc() + async for cid in self.next_run_repo.older_than(now): + if account := await self.cloud_account_repo.get(cid): + await self.trigger_collect(account) + next_run_at = await self.compute_next_run(account.tenant_id) + await self.next_run_repo.update_next_run_at(cid, next_run_at) + else: + log.error("Received a message, that a cloud account is created, but it does not exist in the database") + continue diff --git a/fixbackend/dispatcher/next_run_repository.py b/fixbackend/dispatcher/next_run_repository.py new file mode 100644 index 00000000..51954f3e --- /dev/null +++ b/fixbackend/dispatcher/next_run_repository.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from datetime import datetime +from typing import AsyncIterator + +from fastapi_users_db_sqlalchemy.generics import GUID +from sqlalchemy import DATETIME, select +from sqlalchemy.orm import Mapped, mapped_column + +from fixbackend.base_model import Base +from fixbackend.ids import CloudAccountId +from fixbackend.types import AsyncSessionMaker + + +class NextRun(Base): + __tablename__ = "next_run" + + cloud_account_id: Mapped[CloudAccountId] = mapped_column(GUID, primary_key=True) + at: Mapped[datetime] = mapped_column(DATETIME, nullable=False, index=True) + + +class NextRunRepository: + def __init__(self, session_maker: AsyncSessionMaker) -> None: + self.session_maker = session_maker + + async def create(self, cid: CloudAccountId, next_run: datetime) -> None: + async with self.session_maker() as session: + session.add(NextRun(cloud_account_id=cid, at=next_run)) + await session.commit() + + async def update_next_run_at(self, cid: CloudAccountId, next_run: datetime) -> None: + async with self.session_maker() as session: + if nxt := await session.get(NextRun, cid): + nxt.at = next_run + await session.commit() + + async def delete(self, cid: CloudAccountId) -> None: + async with self.session_maker() as session: + results = await session.execute(select(NextRun).where(NextRun.cloud_account_id == cid)) + if run := results.unique().scalar(): + await session.delete(run) + await session.commit() + + async def older_than(self, at: datetime) -> AsyncIterator[CloudAccountId]: + async with self.session_maker() as session: + async for (entry,) in await session.stream(select(NextRun).where(NextRun.at < at)): + yield entry.cloud_account_id diff --git a/fixbackend/events/websocket_event_handler.py b/fixbackend/events/websocket_event_handler.py index 59bc2e80..1bf935c9 100644 --- a/fixbackend/events/websocket_event_handler.py +++ b/fixbackend/events/websocket_event_handler.py @@ -1,18 +1,12 @@ +from datetime import datetime from typing import Dict, Any, Annotated -from fastapi import WebSocket, Depends -from datetime import datetime -from redis.asyncio import Redis +from fastapi import WebSocket, Depends from fixcloudutils.redis.pub_sub import RedisPubSubListener -from fixbackend.config import ConfigDependency -from fixbackend.ids import TenantId - - -def get_readonly_redis(config: ConfigDependency) -> Redis: - return Redis.from_url(config.redis_readonly_url) # type: ignore - +from redis.asyncio import Redis -ReadonlyRedisDependency = Annotated[Redis, Depends(get_readonly_redis)] +from fixbackend.dependencies import FixDependency +from fixbackend.ids import TenantId class WebsocketEventHandler: @@ -40,8 +34,8 @@ async def ignore_incoming_messages(websocket: WebSocket) -> None: pass -def get_websocket_event_handler(readonly_redis: ReadonlyRedisDependency) -> WebsocketEventHandler: - return WebsocketEventHandler(readonly_redis) +def get_websocket_event_handler(fix: FixDependency) -> WebsocketEventHandler: + return WebsocketEventHandler(fix.readonly_redis) WebsockedtEventHandlerDependency = Annotated[WebsocketEventHandler, Depends(get_websocket_event_handler)] diff --git a/fixbackend/graph_db/service.py b/fixbackend/graph_db/service.py index 6d9e8f5b..87a811cc 100644 --- a/fixbackend/graph_db/service.py +++ b/fixbackend/graph_db/service.py @@ -33,9 +33,9 @@ from fixbackend.base_model import Base from fixbackend.config import Config -from fixbackend.db import AsyncSessionMaker from fixbackend.graph_db.models import GraphDatabaseAccess from fixbackend.ids import TenantId +from fixbackend.types import AsyncSessionMaker log = logging.getLogger(__name__) PasswordLength = 20 diff --git a/fixbackend/organizations/dependencies.py b/fixbackend/organizations/dependencies.py index 06e75538..eda4d399 100644 --- a/fixbackend/organizations/dependencies.py +++ b/fixbackend/organizations/dependencies.py @@ -3,14 +3,12 @@ from fastapi import Depends from fixbackend.db import AsyncSessionDependency -from fixbackend.graph_db.dependencies import GraphDatabaseAccessManagerDependency +from fixbackend.dependencies import FixDependency from fixbackend.organizations.service import OrganizationService -async def get_organization_service( - session: AsyncSessionDependency, graph_db_access_manager: GraphDatabaseAccessManagerDependency -) -> OrganizationService: - return OrganizationService(session, graph_db_access_manager) +async def get_organization_service(session: AsyncSessionDependency, fix: FixDependency) -> OrganizationService: + return OrganizationService(session, fix.graph_database_access) OrganizationServiceDependency = Annotated[OrganizationService, Depends(get_organization_service)] diff --git a/fixbackend/types.py b/fixbackend/types.py new file mode 100644 index 00000000..b533d159 --- /dev/null +++ b/fixbackend/types.py @@ -0,0 +1,17 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import Callable +from sqlalchemy.ext.asyncio import AsyncSession + +AsyncSessionMaker = Callable[[], AsyncSession] diff --git a/migrations/versions/2023-09-28T13:38:59Z_add_next_run_table.py b/migrations/versions/2023-09-28T13:38:59Z_add_next_run_table.py new file mode 100644 index 00000000..ee2c0159 --- /dev/null +++ b/migrations/versions/2023-09-28T13:38:59Z_add_next_run_table.py @@ -0,0 +1,30 @@ +"""add next_run table + +Revision ID: e3ddf05cd115 +Revises: 9f0f5d8ec3d5 +Create Date: 2023-09-28 13:38:59.635966+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from fastapi_users_db_sqlalchemy import GUID + +# revision identifiers, used by Alembic. +revision: str = "e3ddf05cd115" +down_revision: Union[str, None] = "9f0f5d8ec3d5" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "next_run", + sa.Column("cloud_account_id", GUID(), nullable=False), + sa.Column("at", sa.DATETIME(), nullable=False), + sa.PrimaryKeyConstraint("cloud_account_id"), + ) + op.create_index("idx_at", "next_run", ["at"], unique=False) + # ### end Alembic commands ### diff --git a/tests/fixbackend/auth/router_test.py b/tests/fixbackend/auth/router_test.py index 6f4140cc..27dd77e6 100644 --- a/tests/fixbackend/auth/router_test.py +++ b/tests/fixbackend/auth/router_test.py @@ -12,20 +12,14 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import AsyncIterator, List, Optional, Tuple +from typing import List, Optional, Tuple import pytest -from fastapi import Request +from fastapi import Request, FastAPI from httpx import AsyncClient -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession -from fixbackend.app import fast_api_app from fixbackend.auth.models import User from fixbackend.auth.user_verifier import UserVerifier, get_user_verifier -from fixbackend.config import Config -from fixbackend.config import config as get_config -from fixbackend.db import get_async_session -from fixbackend.dependencies import FixDependencies, fix_dependencies class InMemoryVerifier(UserVerifier): @@ -36,33 +30,18 @@ async def verify(self, user: User, token: str, request: Optional[Request]) -> No return self.verification_requests.append((user, token)) -verifier = InMemoryVerifier() - - -@pytest.fixture -async def client( - db_engine: AsyncEngine, session: AsyncSession, default_config: Config -) -> AsyncIterator[AsyncClient]: # noqa: F811 - app = fast_api_app(default_config) - deps = FixDependencies(async_engine=db_engine) - app.dependency_overrides[get_async_session] = lambda: session - app.dependency_overrides[get_user_verifier] = lambda: verifier - app.dependency_overrides[get_config] = lambda: default_config - app.dependency_overrides[fix_dependencies] = lambda: deps - - async with AsyncClient(app=app, base_url="http://test") as ac: - yield ac - - @pytest.mark.asyncio -async def test_registration_flow(client: AsyncClient) -> None: +async def test_registration_flow(api_client: AsyncClient, fast_api: FastAPI) -> None: + verifier = InMemoryVerifier() + fast_api.dependency_overrides[get_user_verifier] = lambda: verifier + registration_json = { "email": "user@example.com", "password": "changeme", } # register user - response = await client.post("/api/auth/register", json=registration_json) + response = await api_client.post("/api/auth/register", json=registration_json) assert response.status_code == 201 login_json = { @@ -71,7 +50,7 @@ async def test_registration_flow(client: AsyncClient) -> None: } # non_verified can't login - response = await client.post("/api/auth/jwt/login", data=login_json) + response = await api_client.post("/api/auth/jwt/login", data=login_json) assert response.status_code == 400 # verify user @@ -79,7 +58,7 @@ async def test_registration_flow(client: AsyncClient) -> None: verification_json = { "token": token, } - response = await client.post("/api/auth/verify", json=verification_json) + response = await api_client.post("/api/auth/verify", json=verification_json) assert response.status_code == 200 response_json = response.json() assert response_json["email"] == user.email @@ -89,11 +68,11 @@ async def test_registration_flow(client: AsyncClient) -> None: assert response_json["id"] == str(user.id) # verified can login - response = await client.post("/api/auth/jwt/login", data=login_json) + response = await api_client.post("/api/auth/jwt/login", data=login_json) assert response.status_code == 204 auth_cookie = response.cookies.get("fix.auth") assert auth_cookie is not None # organization is created by default - response = await client.get("/api/organizations/", cookies={"fix.auth": auth_cookie}) + response = await api_client.get("/api/organizations/", cookies={"fix.auth": auth_cookie}) assert response.json()[0].get("name") == user.email diff --git a/tests/fixbackend/cloud_accounts/repository_test.py b/tests/fixbackend/cloud_accounts/repository_test.py index 376a269e..633f2ac5 100644 --- a/tests/fixbackend/cloud_accounts/repository_test.py +++ b/tests/fixbackend/cloud_accounts/repository_test.py @@ -18,7 +18,7 @@ from fixbackend.ids import CloudAccountId, ExternalId from fixbackend.cloud_accounts.repository import CloudAccountRepositoryImpl -from fixbackend.db import AsyncSessionMaker +from fixbackend.types import AsyncSessionMaker from fixbackend.cloud_accounts.models import CloudAccount, AwsCloudAccess from fixbackend.organizations.service import OrganizationService from fixbackend.auth.models import User diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py index 4412c80d..157d45ad 100644 --- a/tests/fixbackend/conftest.py +++ b/tests/fixbackend/conftest.py @@ -23,20 +23,28 @@ from alembic.config import Config as AlembicConfig from arq import ArqRedis, create_pool from arq.connections import RedisSettings +from fastapi import FastAPI from fixcloudutils.types import Json from httpx import AsyncClient, MockTransport, Request, Response from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy_utils import create_database, database_exists, drop_database +from fixbackend.app import fast_api_app +from fixbackend.cloud_accounts.repository import CloudAccountRepository, CloudAccountRepositoryImpl from fixbackend.auth.db import get_user_repository from fixbackend.auth.models import User from fixbackend.collect.collect_queue import RedisCollectQueue -from fixbackend.config import Config -from fixbackend.db import AsyncSessionMaker +from fixbackend.config import Config, get_config +from fixbackend.db import get_async_session +from fixbackend.dependencies import FixDependencies, ServiceNames, fix_dependencies +from fixbackend.dispatcher.dispatcher_service import DispatcherService +from fixbackend.dispatcher.next_run_repository import NextRunRepository from fixbackend.graph_db.service import GraphDatabaseAccessManager from fixbackend.inventory.inventory_client import InventoryClient from fixbackend.inventory.inventory_service import InventoryService +from fixbackend.organizations.models import Organization from fixbackend.organizations.service import OrganizationService +from fixbackend.types import AsyncSessionMaker DATABASE_URL = "mysql+aiomysql://root@127.0.0.1:3306/fixbackend-testdb" # only used to create/drop the database @@ -158,9 +166,12 @@ async def user(session: AsyncSession) -> User: "hashed_password": "notreallyhashed", "is_verified": True, } - user = await user_db.create(user_dict) + return await user_db.create(user_dict) - return user + +@pytest.fixture +async def organization(organization_repository: OrganizationService, user: User) -> Organization: + return await organization_repository.create_organization("foo", "foo", user) @pytest.fixture @@ -235,3 +246,57 @@ async def app(request: Request) -> Response: async def inventory_service(inventory_client: InventoryClient) -> AsyncIterator[InventoryService]: async with InventoryService(inventory_client) as service: yield service + + +@pytest.fixture +async def next_run_repository(async_session_maker: AsyncSessionMaker) -> NextRunRepository: + return NextRunRepository(async_session_maker) + + +@pytest.fixture +async def cloud_account_repository(async_session_maker: AsyncSessionMaker) -> CloudAccountRepository: + return CloudAccountRepositoryImpl(async_session_maker) + + +@pytest.fixture +async def organization_repository( + session: AsyncSession, graph_database_access_manager: GraphDatabaseAccessManager +) -> OrganizationService: + return OrganizationService(session, graph_database_access_manager) + + +@pytest.fixture +async def dispatcher( + arq_redis: ArqRedis, + cloud_account_repository: CloudAccountRepository, + next_run_repository: NextRunRepository, + collect_queue: RedisCollectQueue, + graph_database_access_manager: GraphDatabaseAccessManager, +) -> DispatcherService: + return DispatcherService( + arq_redis, cloud_account_repository, next_run_repository, collect_queue, graph_database_access_manager + ) + + +@pytest.fixture +async def fix_deps( + db_engine: AsyncEngine, graph_database_access_manager: GraphDatabaseAccessManager +) -> FixDependencies: + return FixDependencies( + **{ServiceNames.async_engine: db_engine, ServiceNames.graph_db_access: graph_database_access_manager} + ) + + +@pytest.fixture +async def fast_api(fix_deps: FixDependencies, session: AsyncSession, default_config: Config) -> FastAPI: + app = fast_api_app(default_config) + app.dependency_overrides[get_async_session] = lambda: session + app.dependency_overrides[get_config] = lambda: default_config + app.dependency_overrides[fix_dependencies] = lambda: fix_deps + return app + + +@pytest.fixture +async def api_client(fast_api: FastAPI) -> AsyncIterator[AsyncClient]: # noqa: F811 + async with AsyncClient(app=fast_api, base_url="http://test") as ac: + yield ac diff --git a/fixbackend/graph_db/dependencies.py b/tests/fixbackend/dispatcher/__init__.py similarity index 66% rename from fixbackend/graph_db/dependencies.py rename to tests/fixbackend/dispatcher/__init__.py index 6d2694a7..4a5be1e9 100644 --- a/fixbackend/graph_db/dependencies.py +++ b/tests/fixbackend/dispatcher/__init__.py @@ -19,20 +19,11 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - -from typing import Annotated - -from fastapi import Depends - -from fixbackend.config import ConfigDependency -from fixbackend.db import AsyncSessionMakerDependency -from fixbackend.graph_db.service import GraphDatabaseAccessManager - - -def get_graph_database_access_manager( - config: ConfigDependency, session_maker: AsyncSessionMakerDependency -) -> GraphDatabaseAccessManager: - return GraphDatabaseAccessManager(config, session_maker) - - -GraphDatabaseAccessManagerDependency = Annotated[GraphDatabaseAccessManager, Depends(get_graph_database_access_manager)] +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . diff --git a/tests/fixbackend/dispatcher/dispatcher_service_test.py b/tests/fixbackend/dispatcher/dispatcher_service_test.py new file mode 100644 index 00000000..c38fd598 --- /dev/null +++ b/tests/fixbackend/dispatcher/dispatcher_service_test.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import uuid +from datetime import datetime, timedelta + +import pytest +from fixcloudutils.redis.event_stream import MessageContext +from fixcloudutils.util import utc +from redis.asyncio import Redis +from sqlalchemy.ext.asyncio import AsyncSession + +from fixbackend.cloud_accounts.models import CloudAccount, AwsCloudAccess +from fixbackend.cloud_accounts.repository import CloudAccountRepository +from fixbackend.dispatcher.dispatcher_service import DispatcherService +from fixbackend.dispatcher.next_run_repository import NextRunRepository, NextRun +from fixbackend.ids import CloudAccountId +from fixbackend.organizations.models import Organization + + +@pytest.mark.asyncio +async def test_receive_created( + dispatcher: DispatcherService, + session: AsyncSession, + cloud_account_repository: CloudAccountRepository, + organization: Organization, + arq_redis: Redis, +) -> None: + # create a cloud account + cloud_account_id = CloudAccountId(uuid.uuid1()) + await cloud_account_repository.create( + CloudAccount(cloud_account_id, organization.id, AwsCloudAccess("123", organization.external_id, "test")) + ) + # signal to the dispatcher that the cloud account was created + await dispatcher.process_message( + {"id": str(cloud_account_id)}, MessageContext("test", "cloud_account_created", "test", utc(), utc()) + ) + # check that a new entry was created in the next_run table + next_run = await session.get(NextRun, cloud_account_id) + assert next_run is not None + assert next_run.at > datetime.now() # next run is in the future + # check that two new entries are created in the work queue: (e.g.: arq:queue, arq:job:xxx) + assert len(await arq_redis.keys()) == 2 + + +@pytest.mark.asyncio +async def test_receive_deleted( + dispatcher: DispatcherService, session: AsyncSession, next_run_repository: NextRunRepository +) -> None: + # create cloud + cloud_account_id = CloudAccountId(uuid.uuid1()) + # create a next run entry + await next_run_repository.create(cloud_account_id, utc()) + # signal to the dispatcher that the cloud account was created + await dispatcher.process_message( + {"id": str(cloud_account_id)}, MessageContext("test", "cloud_account_deleted", "test", utc(), utc()) + ) + # check that a new entry was created in the next_run table + next_run = await session.get(NextRun, cloud_account_id) + assert next_run is None + + +@pytest.mark.asyncio +async def test_trigger_collect( + dispatcher: DispatcherService, + session: AsyncSession, + cloud_account_repository: CloudAccountRepository, + next_run_repository: NextRunRepository, + organization: Organization, + arq_redis: Redis, +) -> None: + # create a cloud account and next_run entry + cloud_account_id = CloudAccountId(uuid.uuid1()) + account = CloudAccount(cloud_account_id, organization.id, AwsCloudAccess("123", organization.external_id, "test")) + await cloud_account_repository.create(account) + # Create a next run entry scheduled in the past - it should be picked up for collect + await next_run_repository.create(cloud_account_id, utc() - timedelta(hours=1)) + + # schedule runs: make sure a collect is triggered and the next_run is updated + await dispatcher.schedule_next_runs() + next_run = await session.get(NextRun, cloud_account_id) + assert next_run is not None + assert next_run.at > datetime.now() # next run is in the future + # check that two new entries are created in the work queue: (e.g.: arq:queue, arq:job:xxx) + assert len(await arq_redis.keys()) == 2 + + # another run should not change anything + await dispatcher.schedule_next_runs() + again = await session.get(NextRun, cloud_account_id) + assert again is not None + assert again.at == next_run.at + assert len(await arq_redis.keys()) == 2 diff --git a/tests/fixbackend/dispatcher/next_run_repository_test.py b/tests/fixbackend/dispatcher/next_run_repository_test.py new file mode 100644 index 00000000..0feec909 --- /dev/null +++ b/tests/fixbackend/dispatcher/next_run_repository_test.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import uuid +from datetime import timedelta + +import pytest +from fixcloudutils.util import utc + +from fixbackend.dispatcher.next_run_repository import NextRunRepository +from fixbackend.ids import CloudAccountId + + +@pytest.mark.asyncio +async def test_create(next_run_repository: NextRunRepository) -> None: + cid = CloudAccountId(uuid.uuid1()) + now = utc() + now_minus_1 = now - timedelta(minutes=1) + await next_run_repository.create(cid, now_minus_1) + # no entries that are older than 1 hour + entries = [entry async for entry in next_run_repository.older_than(now - timedelta(hours=1))] + assert len(entries) == 0 + # one entry that is older than now + assert [entry async for entry in next_run_repository.older_than(now)] == [cid] + # update the entry to run in 1 hour + await next_run_repository.update_next_run_at(cid, now + timedelta(hours=1)) + assert [entry async for entry in next_run_repository.older_than(now)] == [] + assert [entry async for entry in next_run_repository.older_than(now + timedelta(hours=2))] == [cid] + # delete the entry + await next_run_repository.delete(cid) + assert [entry async for entry in next_run_repository.older_than(now + timedelta(days=365))] == [] diff --git a/tests/fixbackend/graph_db/graph_db_access_test.py b/tests/fixbackend/graph_db/graph_db_access_test.py index 62eeb27d..29e88113 100644 --- a/tests/fixbackend/graph_db/graph_db_access_test.py +++ b/tests/fixbackend/graph_db/graph_db_access_test.py @@ -23,9 +23,9 @@ import pytest -from fixbackend.db import AsyncSessionMaker -from fixbackend.ids import TenantId from fixbackend.graph_db.service import GraphDatabaseAccessManager +from fixbackend.ids import TenantId +from fixbackend.types import AsyncSessionMaker @pytest.mark.asyncio