diff --git a/fixbackend/all_models.py b/fixbackend/all_models.py
index 7a190582..2f6bb207 100644
--- a/fixbackend/all_models.py
+++ b/fixbackend/all_models.py
@@ -22,3 +22,4 @@
from fixbackend.organizations.models.orm import Organization, OrganizationInvite # noqa
from fixbackend.graph_db.service import GraphDatabaseAccessEntity # noqa
from fixbackend.cloud_accounts.models.orm import CloudAccount # noqa
+from fixbackend.dispatcher.next_run_repository import NextRun # noqa
diff --git a/fixbackend/app.py b/fixbackend/app.py
index f0329710..4a62082a 100644
--- a/fixbackend/app.py
+++ b/fixbackend/app.py
@@ -24,18 +24,24 @@
from fastapi.exception_handlers import http_exception_handler
from fastapi.staticfiles import StaticFiles
from prometheus_fastapi_instrumentator import Instrumentator
+from redis.asyncio import Redis
+from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine
from starlette.exceptions import HTTPException
from fixbackend import config, dependencies
from fixbackend.auth.oauth import github_client, google_client
from fixbackend.auth.router import auth_router, users_router
+from fixbackend.cloud_accounts.repository import CloudAccountRepositoryImpl
from fixbackend.cloud_accounts.router import cloud_accounts_router, cloud_accounts_callback_router
from fixbackend.collect.collect_queue import RedisCollectQueue
from fixbackend.config import Config
from fixbackend.dependencies import FixDependencies
from fixbackend.dependencies import ServiceNames as SN
+from fixbackend.dispatcher.dispatcher_service import DispatcherService
+from fixbackend.dispatcher.next_run_repository import NextRunRepository
from fixbackend.events.router import websocket_router
+from fixbackend.graph_db.service import GraphDatabaseAccessManager
from fixbackend.inventory.inventory_client import InventoryClient
from fixbackend.inventory.inventory_service import InventoryService
from fixbackend.inventory.router import inventory_router
@@ -53,8 +59,14 @@ def fast_api_app(cfg: Config) -> FastAPI:
@asynccontextmanager
async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
arq_redis = deps.add(SN.arg_redis, await create_pool(RedisSettings.from_dsn(cfg.redis_queue_url)))
- deps.add(SN.async_engine, create_async_engine(cfg.database_url, pool_size=10))
+ deps.add(SN.readonly_redis, Redis.from_url(cfg.redis_readonly_url))
+ deps.add(SN.readwrite_redis, Redis.from_url(cfg.redis_readwrite_url))
+ engine = deps.add(SN.async_engine, create_async_engine(cfg.database_url, pool_size=10))
+ session_maker = deps.add(SN.session_maker, async_sessionmaker(engine))
+ deps.add(SN.cloud_account_repo, CloudAccountRepositoryImpl(session_maker))
+ deps.add(SN.next_run_repo, NextRunRepository(session_maker))
deps.add(SN.collect_queue, RedisCollectQueue(arq_redis))
+ deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker))
client = deps.add(SN.inventory_client, InventoryClient(cfg.inventory_url))
deps.add(SN.inventory, InventoryService(client))
if not cfg.static_assets:
@@ -67,7 +79,22 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
@asynccontextmanager
async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]:
- yield None
+ arq_redis = deps.add(SN.arg_redis, await create_pool(RedisSettings.from_dsn(cfg.redis_queue_url)))
+ deps.add(SN.readonly_redis, Redis.from_url(cfg.redis_readonly_url))
+ rw_redis = deps.add(SN.readwrite_redis, Redis.from_url(cfg.redis_readwrite_url))
+ engine = deps.add(SN.async_engine, create_async_engine(cfg.database_url, pool_size=10))
+ session_maker = deps.add(SN.session_maker, async_sessionmaker(engine))
+ cloud_accounts = deps.add(SN.cloud_account_repo, CloudAccountRepositoryImpl(session_maker))
+ next_run_repo = deps.add(SN.next_run_repo, NextRunRepository(session_maker))
+ collect_queue = deps.add(SN.collect_queue, RedisCollectQueue(arq_redis))
+ db_access = deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker))
+ deps.add(SN.dispatching, DispatcherService(rw_redis, cloud_accounts, next_run_repo, collect_queue, db_access))
+
+ async with deps:
+ log.info("Application services started.")
+ yield None
+ await arq_redis.close()
+ log.info("Application services stopped.")
app = FastAPI(
title="Fix Backend",
diff --git a/fixbackend/auth/current_user_dependencies.py b/fixbackend/auth/current_user_dependencies.py
index 6ab4067f..8f2fb6a7 100644
--- a/fixbackend/auth/current_user_dependencies.py
+++ b/fixbackend/auth/current_user_dependencies.py
@@ -29,7 +29,7 @@
from fixbackend.auth.jwt import get_auth_backend
from fixbackend.auth.models import User
from fixbackend.config import get_config
-from fixbackend.graph_db.dependencies import GraphDatabaseAccessManagerDependency
+from fixbackend.dependencies import FixDependency
from fixbackend.graph_db.models import GraphDatabaseAccess
from fixbackend.ids import TenantId
from fixbackend.organizations.dependencies import OrganizationServiceDependency
@@ -82,10 +82,8 @@ async def get_tenant(
TenantDependency = Annotated[TenantId, Depends(get_tenant)]
-async def get_current_graph_db(
- manager: GraphDatabaseAccessManagerDependency, tenant: TenantDependency
-) -> GraphDatabaseAccess:
- access = await manager.get_database_access(tenant)
+async def get_current_graph_db(fix: FixDependency, tenant: TenantDependency) -> GraphDatabaseAccess:
+ access = await fix.graph_database_access.get_database_access(tenant)
if access is None:
raise AttributeError("No database access found for tenant")
return access
diff --git a/fixbackend/cloud_accounts/repository.py b/fixbackend/cloud_accounts/repository.py
index 316fbbed..20675e1f 100644
--- a/fixbackend/cloud_accounts/repository.py
+++ b/fixbackend/cloud_accounts/repository.py
@@ -17,8 +17,9 @@
from sqlalchemy import select
from fixbackend.cloud_accounts.models import orm, CloudAccount, AwsCloudAccess
-from fixbackend.db import AsyncSessionMaker, AsyncSessionMakerDependency
+from fixbackend.db import AsyncSessionMakerDependency
from fixbackend.ids import CloudAccountId, TenantId
+from fixbackend.types import AsyncSessionMaker
from abc import ABC, abstractmethod
@@ -41,10 +42,7 @@ async def delete(self, id: CloudAccountId) -> None:
class CloudAccountRepositoryImpl(CloudAccountRepository):
- def __init__(
- self,
- session_maker: AsyncSessionMaker,
- ) -> None:
+ def __init__(self, session_maker: AsyncSessionMaker) -> None:
self.session_maker = session_maker
async def create(self, cloud_account: CloudAccount) -> CloudAccount:
diff --git a/fixbackend/db.py b/fixbackend/db.py
index f20b583c..4e358e59 100644
--- a/fixbackend/db.py
+++ b/fixbackend/db.py
@@ -12,14 +12,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import AsyncGenerator, Annotated, Callable
+from typing import AsyncGenerator, Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from fixbackend.dependencies import FixDependency
-
-AsyncSessionMaker = Callable[[], AsyncSession]
+from fixbackend.types import AsyncSessionMaker
async def get_async_session_maker(fix: FixDependency) -> AsyncSessionMaker:
diff --git a/fixbackend/dependencies.py b/fixbackend/dependencies.py
index cdf1c904..521332bd 100644
--- a/fixbackend/dependencies.py
+++ b/fixbackend/dependencies.py
@@ -16,19 +16,29 @@
from arq import ArqRedis
from fastapi.params import Depends
from fixcloudutils.service import Dependencies
+from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncEngine
from fixbackend.collect.collect_queue import RedisCollectQueue
+from fixbackend.graph_db.service import GraphDatabaseAccessManager
from fixbackend.inventory.inventory_client import InventoryClient
from fixbackend.inventory.inventory_service import InventoryService
+from fixbackend.types import AsyncSessionMaker
class ServiceNames:
arg_redis = "arq_redis"
+ readonly_redis = "readonly_redis"
+ readwrite_redis = "readwrite_redis"
collect_queue = "collect_queue"
async_engine = "async_engine"
+ session_maker = "session_maker"
+ cloud_account_repo = "cloud_account_repo"
+ next_run_repo = "next_run_repo"
+ graph_db_access = "graph_db_access"
inventory = "inventory"
inventory_client = "inventory_client"
+ dispatching = "dispatching"
class FixDependencies(Dependencies):
@@ -44,6 +54,10 @@ def collect_queue(self) -> RedisCollectQueue:
def async_engine(self) -> AsyncEngine:
return self.service(ServiceNames.async_engine, AsyncEngine)
+ @property
+ def session_maker(self) -> AsyncSessionMaker:
+ return self.service(ServiceNames.async_engine, AsyncSessionMaker) # type: ignore
+
@property
def inventory(self) -> InventoryService:
return self.service(ServiceNames.inventory, InventoryService)
@@ -52,6 +66,14 @@ def inventory(self) -> InventoryService:
def inventory_client(self) -> InventoryClient:
return self.service(ServiceNames.inventory, InventoryClient)
+ @property
+ def readonly_redis(self) -> Redis:
+ return self.service(ServiceNames.readonly_redis, Redis)
+
+ @property
+ def graph_database_access(self) -> GraphDatabaseAccessManager:
+ return self.service(ServiceNames.graph_db_access, GraphDatabaseAccessManager)
+
# placeholder for dependencies, will be replaced during the app initialization
def fix_dependencies() -> FixDependencies:
diff --git a/fixbackend/dispatcher/__init__.py b/fixbackend/dispatcher/__init__.py
new file mode 100644
index 00000000..84021a8f
--- /dev/null
+++ b/fixbackend/dispatcher/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
diff --git a/fixbackend/dispatcher/dispatcher_service.py b/fixbackend/dispatcher/dispatcher_service.py
new file mode 100644
index 00000000..fceaea2b
--- /dev/null
+++ b/fixbackend/dispatcher/dispatcher_service.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import logging
+from datetime import timedelta, datetime
+from typing import Any, Optional
+
+from fixcloudutils.asyncio.periodic import Periodic
+from fixcloudutils.redis.event_stream import RedisStreamListener, Json, MessageContext
+from fixcloudutils.service import Service
+from fixcloudutils.util import utc
+from redis.asyncio import Redis
+
+from fixbackend.cloud_accounts.models import AwsCloudAccess, CloudAccount
+from fixbackend.cloud_accounts.repository import CloudAccountRepository
+from fixbackend.collect.collect_queue import CollectQueue, AccountInformation, AwsAccountInformation
+from fixbackend.dispatcher.next_run_repository import NextRunRepository
+from fixbackend.graph_db.service import GraphDatabaseAccessManager
+from fixbackend.ids import CloudAccountId, TenantId
+
+log = logging.getLogger(__name__)
+
+
+class DispatcherService(Service):
+ def __init__(
+ self,
+ readwrite_redis: Redis,
+ cloud_account_repo: CloudAccountRepository,
+ next_run_repo: NextRunRepository,
+ collect_queue: CollectQueue,
+ access_manager: GraphDatabaseAccessManager,
+ ) -> None:
+ self.cloud_account_repo = cloud_account_repo
+ self.next_run_repo = next_run_repo
+ self.collect_queue = collect_queue
+ self.access_manager = access_manager
+ self.periodic = Periodic("schedule_next_runs", self.schedule_next_runs, timedelta(minutes=1))
+ self.listener = RedisStreamListener(
+ readwrite_redis,
+ "fixbackend::cloudaccount",
+ group="dispatching",
+ listener="dispatching",
+ message_processor=self.process_message,
+ consider_failed_after=timedelta(minutes=5),
+ batch_size=1,
+ )
+
+ async def start(self) -> Any:
+ await self.listener.start()
+ await self.periodic.start()
+
+ async def stop(self) -> None:
+ await self.periodic.stop()
+ await self.listener.stop()
+
+ async def process_message(self, message: Json, context: MessageContext) -> None:
+ match context.kind:
+ case "cloud_account_created":
+ await self.cloud_account_created(CloudAccountId(message["id"]))
+ case "cloud_account_deleted":
+ await self.cloud_account_deleted(CloudAccountId(message["id"]))
+ case _:
+ log.error(f"Don't know how to handle messages of kind {context.kind}")
+
+ async def cloud_account_created(self, cid: CloudAccountId) -> None:
+ if account := await self.cloud_account_repo.get(cid):
+ await self.trigger_collect(account)
+ # store an entry in the next_run table
+ next_run_at = await self.compute_next_run(account.tenant_id)
+ await self.next_run_repo.create(cid, next_run_at)
+ else:
+ log.error("Received a message, that a cloud account is created, but it does not exist in the database")
+
+ async def cloud_account_deleted(self, cid: CloudAccountId) -> None:
+ # delete the entry from the scheduler table
+ await self.next_run_repo.delete(cid)
+
+ async def compute_next_run(self, tenant: TenantId) -> datetime:
+ # compute next run time dependent on the tenant.
+ result = datetime.now() + timedelta(hours=1)
+ log.info(f"Next run for tenant: {tenant} is {result}")
+ return result
+
+ async def trigger_collect(self, account: CloudAccount) -> None:
+ def account_information() -> Optional[AccountInformation]:
+ match account.access:
+ case AwsCloudAccess(account_id=account_id, role_name=role_name, external_id=external_id):
+ return AwsAccountInformation(
+ aws_account_id=account_id,
+ aws_account_name=None,
+ aws_role_arn=f"arn:aws:iam::{account_id}:role/{role_name}",
+ external_id=str(external_id),
+ )
+ case _:
+ log.error(f"Don't know how to handle this cloud access {account.access}. Ignore it.")
+ return None
+
+ if (ai := account_information()) and (db := await self.access_manager.get_database_access(account.tenant_id)):
+ await self.collect_queue.enqueue(db, ai) # TODO: create a unique identifier for this run
+
+ async def schedule_next_runs(self) -> None:
+ now = utc()
+ async for cid in self.next_run_repo.older_than(now):
+ if account := await self.cloud_account_repo.get(cid):
+ await self.trigger_collect(account)
+ next_run_at = await self.compute_next_run(account.tenant_id)
+ await self.next_run_repo.update_next_run_at(cid, next_run_at)
+ else:
+ log.error("Received a message, that a cloud account is created, but it does not exist in the database")
+ continue
diff --git a/fixbackend/dispatcher/next_run_repository.py b/fixbackend/dispatcher/next_run_repository.py
new file mode 100644
index 00000000..51954f3e
--- /dev/null
+++ b/fixbackend/dispatcher/next_run_repository.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from datetime import datetime
+from typing import AsyncIterator
+
+from fastapi_users_db_sqlalchemy.generics import GUID
+from sqlalchemy import DATETIME, select
+from sqlalchemy.orm import Mapped, mapped_column
+
+from fixbackend.base_model import Base
+from fixbackend.ids import CloudAccountId
+from fixbackend.types import AsyncSessionMaker
+
+
+class NextRun(Base):
+ __tablename__ = "next_run"
+
+ cloud_account_id: Mapped[CloudAccountId] = mapped_column(GUID, primary_key=True)
+ at: Mapped[datetime] = mapped_column(DATETIME, nullable=False, index=True)
+
+
+class NextRunRepository:
+ def __init__(self, session_maker: AsyncSessionMaker) -> None:
+ self.session_maker = session_maker
+
+ async def create(self, cid: CloudAccountId, next_run: datetime) -> None:
+ async with self.session_maker() as session:
+ session.add(NextRun(cloud_account_id=cid, at=next_run))
+ await session.commit()
+
+ async def update_next_run_at(self, cid: CloudAccountId, next_run: datetime) -> None:
+ async with self.session_maker() as session:
+ if nxt := await session.get(NextRun, cid):
+ nxt.at = next_run
+ await session.commit()
+
+ async def delete(self, cid: CloudAccountId) -> None:
+ async with self.session_maker() as session:
+ results = await session.execute(select(NextRun).where(NextRun.cloud_account_id == cid))
+ if run := results.unique().scalar():
+ await session.delete(run)
+ await session.commit()
+
+ async def older_than(self, at: datetime) -> AsyncIterator[CloudAccountId]:
+ async with self.session_maker() as session:
+ async for (entry,) in await session.stream(select(NextRun).where(NextRun.at < at)):
+ yield entry.cloud_account_id
diff --git a/fixbackend/events/websocket_event_handler.py b/fixbackend/events/websocket_event_handler.py
index 59bc2e80..1bf935c9 100644
--- a/fixbackend/events/websocket_event_handler.py
+++ b/fixbackend/events/websocket_event_handler.py
@@ -1,18 +1,12 @@
+from datetime import datetime
from typing import Dict, Any, Annotated
-from fastapi import WebSocket, Depends
-from datetime import datetime
-from redis.asyncio import Redis
+from fastapi import WebSocket, Depends
from fixcloudutils.redis.pub_sub import RedisPubSubListener
-from fixbackend.config import ConfigDependency
-from fixbackend.ids import TenantId
-
-
-def get_readonly_redis(config: ConfigDependency) -> Redis:
- return Redis.from_url(config.redis_readonly_url) # type: ignore
-
+from redis.asyncio import Redis
-ReadonlyRedisDependency = Annotated[Redis, Depends(get_readonly_redis)]
+from fixbackend.dependencies import FixDependency
+from fixbackend.ids import TenantId
class WebsocketEventHandler:
@@ -40,8 +34,8 @@ async def ignore_incoming_messages(websocket: WebSocket) -> None:
pass
-def get_websocket_event_handler(readonly_redis: ReadonlyRedisDependency) -> WebsocketEventHandler:
- return WebsocketEventHandler(readonly_redis)
+def get_websocket_event_handler(fix: FixDependency) -> WebsocketEventHandler:
+ return WebsocketEventHandler(fix.readonly_redis)
WebsockedtEventHandlerDependency = Annotated[WebsocketEventHandler, Depends(get_websocket_event_handler)]
diff --git a/fixbackend/graph_db/service.py b/fixbackend/graph_db/service.py
index 6d9e8f5b..87a811cc 100644
--- a/fixbackend/graph_db/service.py
+++ b/fixbackend/graph_db/service.py
@@ -33,9 +33,9 @@
from fixbackend.base_model import Base
from fixbackend.config import Config
-from fixbackend.db import AsyncSessionMaker
from fixbackend.graph_db.models import GraphDatabaseAccess
from fixbackend.ids import TenantId
+from fixbackend.types import AsyncSessionMaker
log = logging.getLogger(__name__)
PasswordLength = 20
diff --git a/fixbackend/organizations/dependencies.py b/fixbackend/organizations/dependencies.py
index 06e75538..eda4d399 100644
--- a/fixbackend/organizations/dependencies.py
+++ b/fixbackend/organizations/dependencies.py
@@ -3,14 +3,12 @@
from fastapi import Depends
from fixbackend.db import AsyncSessionDependency
-from fixbackend.graph_db.dependencies import GraphDatabaseAccessManagerDependency
+from fixbackend.dependencies import FixDependency
from fixbackend.organizations.service import OrganizationService
-async def get_organization_service(
- session: AsyncSessionDependency, graph_db_access_manager: GraphDatabaseAccessManagerDependency
-) -> OrganizationService:
- return OrganizationService(session, graph_db_access_manager)
+async def get_organization_service(session: AsyncSessionDependency, fix: FixDependency) -> OrganizationService:
+ return OrganizationService(session, fix.graph_database_access)
OrganizationServiceDependency = Annotated[OrganizationService, Depends(get_organization_service)]
diff --git a/fixbackend/types.py b/fixbackend/types.py
new file mode 100644
index 00000000..b533d159
--- /dev/null
+++ b/fixbackend/types.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from typing import Callable
+from sqlalchemy.ext.asyncio import AsyncSession
+
+AsyncSessionMaker = Callable[[], AsyncSession]
diff --git a/migrations/versions/2023-09-28T13:38:59Z_add_next_run_table.py b/migrations/versions/2023-09-28T13:38:59Z_add_next_run_table.py
new file mode 100644
index 00000000..ee2c0159
--- /dev/null
+++ b/migrations/versions/2023-09-28T13:38:59Z_add_next_run_table.py
@@ -0,0 +1,30 @@
+"""add next_run table
+
+Revision ID: e3ddf05cd115
+Revises: 9f0f5d8ec3d5
+Create Date: 2023-09-28 13:38:59.635966+00:00
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from fastapi_users_db_sqlalchemy import GUID
+
+# revision identifiers, used by Alembic.
+revision: str = "e3ddf05cd115"
+down_revision: Union[str, None] = "9f0f5d8ec3d5"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table(
+ "next_run",
+ sa.Column("cloud_account_id", GUID(), nullable=False),
+ sa.Column("at", sa.DATETIME(), nullable=False),
+ sa.PrimaryKeyConstraint("cloud_account_id"),
+ )
+ op.create_index("idx_at", "next_run", ["at"], unique=False)
+ # ### end Alembic commands ###
diff --git a/tests/fixbackend/auth/router_test.py b/tests/fixbackend/auth/router_test.py
index 6f4140cc..27dd77e6 100644
--- a/tests/fixbackend/auth/router_test.py
+++ b/tests/fixbackend/auth/router_test.py
@@ -12,20 +12,14 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import AsyncIterator, List, Optional, Tuple
+from typing import List, Optional, Tuple
import pytest
-from fastapi import Request
+from fastapi import Request, FastAPI
from httpx import AsyncClient
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
-from fixbackend.app import fast_api_app
from fixbackend.auth.models import User
from fixbackend.auth.user_verifier import UserVerifier, get_user_verifier
-from fixbackend.config import Config
-from fixbackend.config import config as get_config
-from fixbackend.db import get_async_session
-from fixbackend.dependencies import FixDependencies, fix_dependencies
class InMemoryVerifier(UserVerifier):
@@ -36,33 +30,18 @@ async def verify(self, user: User, token: str, request: Optional[Request]) -> No
return self.verification_requests.append((user, token))
-verifier = InMemoryVerifier()
-
-
-@pytest.fixture
-async def client(
- db_engine: AsyncEngine, session: AsyncSession, default_config: Config
-) -> AsyncIterator[AsyncClient]: # noqa: F811
- app = fast_api_app(default_config)
- deps = FixDependencies(async_engine=db_engine)
- app.dependency_overrides[get_async_session] = lambda: session
- app.dependency_overrides[get_user_verifier] = lambda: verifier
- app.dependency_overrides[get_config] = lambda: default_config
- app.dependency_overrides[fix_dependencies] = lambda: deps
-
- async with AsyncClient(app=app, base_url="http://test") as ac:
- yield ac
-
-
@pytest.mark.asyncio
-async def test_registration_flow(client: AsyncClient) -> None:
+async def test_registration_flow(api_client: AsyncClient, fast_api: FastAPI) -> None:
+ verifier = InMemoryVerifier()
+ fast_api.dependency_overrides[get_user_verifier] = lambda: verifier
+
registration_json = {
"email": "user@example.com",
"password": "changeme",
}
# register user
- response = await client.post("/api/auth/register", json=registration_json)
+ response = await api_client.post("/api/auth/register", json=registration_json)
assert response.status_code == 201
login_json = {
@@ -71,7 +50,7 @@ async def test_registration_flow(client: AsyncClient) -> None:
}
# non_verified can't login
- response = await client.post("/api/auth/jwt/login", data=login_json)
+ response = await api_client.post("/api/auth/jwt/login", data=login_json)
assert response.status_code == 400
# verify user
@@ -79,7 +58,7 @@ async def test_registration_flow(client: AsyncClient) -> None:
verification_json = {
"token": token,
}
- response = await client.post("/api/auth/verify", json=verification_json)
+ response = await api_client.post("/api/auth/verify", json=verification_json)
assert response.status_code == 200
response_json = response.json()
assert response_json["email"] == user.email
@@ -89,11 +68,11 @@ async def test_registration_flow(client: AsyncClient) -> None:
assert response_json["id"] == str(user.id)
# verified can login
- response = await client.post("/api/auth/jwt/login", data=login_json)
+ response = await api_client.post("/api/auth/jwt/login", data=login_json)
assert response.status_code == 204
auth_cookie = response.cookies.get("fix.auth")
assert auth_cookie is not None
# organization is created by default
- response = await client.get("/api/organizations/", cookies={"fix.auth": auth_cookie})
+ response = await api_client.get("/api/organizations/", cookies={"fix.auth": auth_cookie})
assert response.json()[0].get("name") == user.email
diff --git a/tests/fixbackend/cloud_accounts/repository_test.py b/tests/fixbackend/cloud_accounts/repository_test.py
index 376a269e..633f2ac5 100644
--- a/tests/fixbackend/cloud_accounts/repository_test.py
+++ b/tests/fixbackend/cloud_accounts/repository_test.py
@@ -18,7 +18,7 @@
from fixbackend.ids import CloudAccountId, ExternalId
from fixbackend.cloud_accounts.repository import CloudAccountRepositoryImpl
-from fixbackend.db import AsyncSessionMaker
+from fixbackend.types import AsyncSessionMaker
from fixbackend.cloud_accounts.models import CloudAccount, AwsCloudAccess
from fixbackend.organizations.service import OrganizationService
from fixbackend.auth.models import User
diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py
index 4412c80d..157d45ad 100644
--- a/tests/fixbackend/conftest.py
+++ b/tests/fixbackend/conftest.py
@@ -23,20 +23,28 @@
from alembic.config import Config as AlembicConfig
from arq import ArqRedis, create_pool
from arq.connections import RedisSettings
+from fastapi import FastAPI
from fixcloudutils.types import Json
from httpx import AsyncClient, MockTransport, Request, Response
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy_utils import create_database, database_exists, drop_database
+from fixbackend.app import fast_api_app
+from fixbackend.cloud_accounts.repository import CloudAccountRepository, CloudAccountRepositoryImpl
from fixbackend.auth.db import get_user_repository
from fixbackend.auth.models import User
from fixbackend.collect.collect_queue import RedisCollectQueue
-from fixbackend.config import Config
-from fixbackend.db import AsyncSessionMaker
+from fixbackend.config import Config, get_config
+from fixbackend.db import get_async_session
+from fixbackend.dependencies import FixDependencies, ServiceNames, fix_dependencies
+from fixbackend.dispatcher.dispatcher_service import DispatcherService
+from fixbackend.dispatcher.next_run_repository import NextRunRepository
from fixbackend.graph_db.service import GraphDatabaseAccessManager
from fixbackend.inventory.inventory_client import InventoryClient
from fixbackend.inventory.inventory_service import InventoryService
+from fixbackend.organizations.models import Organization
from fixbackend.organizations.service import OrganizationService
+from fixbackend.types import AsyncSessionMaker
DATABASE_URL = "mysql+aiomysql://root@127.0.0.1:3306/fixbackend-testdb"
# only used to create/drop the database
@@ -158,9 +166,12 @@ async def user(session: AsyncSession) -> User:
"hashed_password": "notreallyhashed",
"is_verified": True,
}
- user = await user_db.create(user_dict)
+ return await user_db.create(user_dict)
- return user
+
+@pytest.fixture
+async def organization(organization_repository: OrganizationService, user: User) -> Organization:
+ return await organization_repository.create_organization("foo", "foo", user)
@pytest.fixture
@@ -235,3 +246,57 @@ async def app(request: Request) -> Response:
async def inventory_service(inventory_client: InventoryClient) -> AsyncIterator[InventoryService]:
async with InventoryService(inventory_client) as service:
yield service
+
+
+@pytest.fixture
+async def next_run_repository(async_session_maker: AsyncSessionMaker) -> NextRunRepository:
+ return NextRunRepository(async_session_maker)
+
+
+@pytest.fixture
+async def cloud_account_repository(async_session_maker: AsyncSessionMaker) -> CloudAccountRepository:
+ return CloudAccountRepositoryImpl(async_session_maker)
+
+
+@pytest.fixture
+async def organization_repository(
+ session: AsyncSession, graph_database_access_manager: GraphDatabaseAccessManager
+) -> OrganizationService:
+ return OrganizationService(session, graph_database_access_manager)
+
+
+@pytest.fixture
+async def dispatcher(
+ arq_redis: ArqRedis,
+ cloud_account_repository: CloudAccountRepository,
+ next_run_repository: NextRunRepository,
+ collect_queue: RedisCollectQueue,
+ graph_database_access_manager: GraphDatabaseAccessManager,
+) -> DispatcherService:
+ return DispatcherService(
+ arq_redis, cloud_account_repository, next_run_repository, collect_queue, graph_database_access_manager
+ )
+
+
+@pytest.fixture
+async def fix_deps(
+ db_engine: AsyncEngine, graph_database_access_manager: GraphDatabaseAccessManager
+) -> FixDependencies:
+ return FixDependencies(
+ **{ServiceNames.async_engine: db_engine, ServiceNames.graph_db_access: graph_database_access_manager}
+ )
+
+
+@pytest.fixture
+async def fast_api(fix_deps: FixDependencies, session: AsyncSession, default_config: Config) -> FastAPI:
+ app = fast_api_app(default_config)
+ app.dependency_overrides[get_async_session] = lambda: session
+ app.dependency_overrides[get_config] = lambda: default_config
+ app.dependency_overrides[fix_dependencies] = lambda: fix_deps
+ return app
+
+
+@pytest.fixture
+async def api_client(fast_api: FastAPI) -> AsyncIterator[AsyncClient]: # noqa: F811
+ async with AsyncClient(app=fast_api, base_url="http://test") as ac:
+ yield ac
diff --git a/fixbackend/graph_db/dependencies.py b/tests/fixbackend/dispatcher/__init__.py
similarity index 66%
rename from fixbackend/graph_db/dependencies.py
rename to tests/fixbackend/dispatcher/__init__.py
index 6d2694a7..4a5be1e9 100644
--- a/fixbackend/graph_db/dependencies.py
+++ b/tests/fixbackend/dispatcher/__init__.py
@@ -19,20 +19,11 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-
-from typing import Annotated
-
-from fastapi import Depends
-
-from fixbackend.config import ConfigDependency
-from fixbackend.db import AsyncSessionMakerDependency
-from fixbackend.graph_db.service import GraphDatabaseAccessManager
-
-
-def get_graph_database_access_manager(
- config: ConfigDependency, session_maker: AsyncSessionMakerDependency
-) -> GraphDatabaseAccessManager:
- return GraphDatabaseAccessManager(config, session_maker)
-
-
-GraphDatabaseAccessManagerDependency = Annotated[GraphDatabaseAccessManager, Depends(get_graph_database_access_manager)]
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
diff --git a/tests/fixbackend/dispatcher/dispatcher_service_test.py b/tests/fixbackend/dispatcher/dispatcher_service_test.py
new file mode 100644
index 00000000..c38fd598
--- /dev/null
+++ b/tests/fixbackend/dispatcher/dispatcher_service_test.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import uuid
+from datetime import datetime, timedelta
+
+import pytest
+from fixcloudutils.redis.event_stream import MessageContext
+from fixcloudutils.util import utc
+from redis.asyncio import Redis
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from fixbackend.cloud_accounts.models import CloudAccount, AwsCloudAccess
+from fixbackend.cloud_accounts.repository import CloudAccountRepository
+from fixbackend.dispatcher.dispatcher_service import DispatcherService
+from fixbackend.dispatcher.next_run_repository import NextRunRepository, NextRun
+from fixbackend.ids import CloudAccountId
+from fixbackend.organizations.models import Organization
+
+
+@pytest.mark.asyncio
+async def test_receive_created(
+ dispatcher: DispatcherService,
+ session: AsyncSession,
+ cloud_account_repository: CloudAccountRepository,
+ organization: Organization,
+ arq_redis: Redis,
+) -> None:
+ # create a cloud account
+ cloud_account_id = CloudAccountId(uuid.uuid1())
+ await cloud_account_repository.create(
+ CloudAccount(cloud_account_id, organization.id, AwsCloudAccess("123", organization.external_id, "test"))
+ )
+ # signal to the dispatcher that the cloud account was created
+ await dispatcher.process_message(
+ {"id": str(cloud_account_id)}, MessageContext("test", "cloud_account_created", "test", utc(), utc())
+ )
+ # check that a new entry was created in the next_run table
+ next_run = await session.get(NextRun, cloud_account_id)
+ assert next_run is not None
+ assert next_run.at > datetime.now() # next run is in the future
+ # check that two new entries are created in the work queue: (e.g.: arq:queue, arq:job:xxx)
+ assert len(await arq_redis.keys()) == 2
+
+
+@pytest.mark.asyncio
+async def test_receive_deleted(
+ dispatcher: DispatcherService, session: AsyncSession, next_run_repository: NextRunRepository
+) -> None:
+ # create cloud
+ cloud_account_id = CloudAccountId(uuid.uuid1())
+ # create a next run entry
+ await next_run_repository.create(cloud_account_id, utc())
+ # signal to the dispatcher that the cloud account was created
+ await dispatcher.process_message(
+ {"id": str(cloud_account_id)}, MessageContext("test", "cloud_account_deleted", "test", utc(), utc())
+ )
+ # check that a new entry was created in the next_run table
+ next_run = await session.get(NextRun, cloud_account_id)
+ assert next_run is None
+
+
+@pytest.mark.asyncio
+async def test_trigger_collect(
+ dispatcher: DispatcherService,
+ session: AsyncSession,
+ cloud_account_repository: CloudAccountRepository,
+ next_run_repository: NextRunRepository,
+ organization: Organization,
+ arq_redis: Redis,
+) -> None:
+ # create a cloud account and next_run entry
+ cloud_account_id = CloudAccountId(uuid.uuid1())
+ account = CloudAccount(cloud_account_id, organization.id, AwsCloudAccess("123", organization.external_id, "test"))
+ await cloud_account_repository.create(account)
+ # Create a next run entry scheduled in the past - it should be picked up for collect
+ await next_run_repository.create(cloud_account_id, utc() - timedelta(hours=1))
+
+ # schedule runs: make sure a collect is triggered and the next_run is updated
+ await dispatcher.schedule_next_runs()
+ next_run = await session.get(NextRun, cloud_account_id)
+ assert next_run is not None
+ assert next_run.at > datetime.now() # next run is in the future
+ # check that two new entries are created in the work queue: (e.g.: arq:queue, arq:job:xxx)
+ assert len(await arq_redis.keys()) == 2
+
+ # another run should not change anything
+ await dispatcher.schedule_next_runs()
+ again = await session.get(NextRun, cloud_account_id)
+ assert again is not None
+ assert again.at == next_run.at
+ assert len(await arq_redis.keys()) == 2
diff --git a/tests/fixbackend/dispatcher/next_run_repository_test.py b/tests/fixbackend/dispatcher/next_run_repository_test.py
new file mode 100644
index 00000000..0feec909
--- /dev/null
+++ b/tests/fixbackend/dispatcher/next_run_repository_test.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import uuid
+from datetime import timedelta
+
+import pytest
+from fixcloudutils.util import utc
+
+from fixbackend.dispatcher.next_run_repository import NextRunRepository
+from fixbackend.ids import CloudAccountId
+
+
+@pytest.mark.asyncio
+async def test_create(next_run_repository: NextRunRepository) -> None:
+ cid = CloudAccountId(uuid.uuid1())
+ now = utc()
+ now_minus_1 = now - timedelta(minutes=1)
+ await next_run_repository.create(cid, now_minus_1)
+ # no entries that are older than 1 hour
+ entries = [entry async for entry in next_run_repository.older_than(now - timedelta(hours=1))]
+ assert len(entries) == 0
+ # one entry that is older than now
+ assert [entry async for entry in next_run_repository.older_than(now)] == [cid]
+ # update the entry to run in 1 hour
+ await next_run_repository.update_next_run_at(cid, now + timedelta(hours=1))
+ assert [entry async for entry in next_run_repository.older_than(now)] == []
+ assert [entry async for entry in next_run_repository.older_than(now + timedelta(hours=2))] == [cid]
+ # delete the entry
+ await next_run_repository.delete(cid)
+ assert [entry async for entry in next_run_repository.older_than(now + timedelta(days=365))] == []
diff --git a/tests/fixbackend/graph_db/graph_db_access_test.py b/tests/fixbackend/graph_db/graph_db_access_test.py
index 62eeb27d..29e88113 100644
--- a/tests/fixbackend/graph_db/graph_db_access_test.py
+++ b/tests/fixbackend/graph_db/graph_db_access_test.py
@@ -23,9 +23,9 @@
import pytest
-from fixbackend.db import AsyncSessionMaker
-from fixbackend.ids import TenantId
from fixbackend.graph_db.service import GraphDatabaseAccessManager
+from fixbackend.ids import TenantId
+from fixbackend.types import AsyncSessionMaker
@pytest.mark.asyncio