diff --git a/fixbackend/app.py b/fixbackend/app.py index cc7938e1..0aca2950 100644 --- a/fixbackend/app.py +++ b/fixbackend/app.py @@ -88,6 +88,7 @@ from fixbackend.subscription.billing import BillingService from fixbackend.subscription.router import subscription_router from fixbackend.subscription.subscription_repository import SubscriptionRepository +from fixbackend.workspaces.invitation_repository import InvitationRepositoryImpl from fixbackend.workspaces.repository import WorkspaceRepositoryImpl from fixbackend.workspaces.router import workspaces_router from fixbackend.domain_events.subscriber import DomainEventSubscriber @@ -160,7 +161,20 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events)) workspace_repo = deps.add( SN.workspace_repo, - WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher), + WorkspaceRepositoryImpl( + session_maker, + graph_db_access, + domain_event_publisher, + RedisPubSubPublisher( + redis=readwrite_redis, + channel="workspaces", + publisher_name="workspace_service", + ), + ), + ) + deps.add( + SN.invitation_repository, + InvitationRepositoryImpl(session_maker, workspace_repo), ) subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker)) deps.add( @@ -245,8 +259,18 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]: domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events)) workspace_repo = deps.add( SN.workspace_repo, - WorkspaceRepositoryImpl(session_maker, db_access, domain_event_publisher), + WorkspaceRepositoryImpl( + session_maker, + db_access, + domain_event_publisher, + RedisPubSubPublisher( + redis=rw_redis, + channel="workspaces", + publisher_name="workspace_service", + ), + ), ) + cloud_accounts_redis_publisher = RedisPubSubPublisher( redis=rw_redis, channel="cloud_accounts", @@ -306,7 +330,16 @@ async def setup_teardown_billing(_: FastAPI) -> AsyncIterator[None]: metering_repo = deps.add(SN.metering_repo, MeteringRepository(session_maker)) workspace_repo = deps.add( SN.workspace_repo, - WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher), + WorkspaceRepositoryImpl( + session_maker, + graph_db_access, + domain_event_publisher, + RedisPubSubPublisher( + redis=readwrite_redis, + channel="workspaces", + publisher_name="workspace_service", + ), + ), ) subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker)) aws_marketplace = deps.add( diff --git a/fixbackend/auth/user_manager.py b/fixbackend/auth/user_manager.py index 25f96f6e..472af768 100644 --- a/fixbackend/auth/user_manager.py +++ b/fixbackend/auth/user_manager.py @@ -20,13 +20,15 @@ from fastapi_users import BaseUserManager, UUIDIDMixin from fastapi_users.password import PasswordHelperProtocol -from fixbackend.auth.db import UserRepository, UserRepositoryDependency +from fixbackend.auth.user_repository import UserRepository, UserRepositoryDependency from fixbackend.auth.models import User from fixbackend.auth.user_verifier import UserVerifier, UserVerifierDependency from fixbackend.config import Config, ConfigDependency from fixbackend.domain_events.events import UserRegistered from fixbackend.domain_events.publisher import DomainEventPublisher from fixbackend.domain_events.dependencies import DomainEventPublisherDependency +from fixbackend.workspaces.invitation_repository import InvitationRepository, InvitationRepositoryDependency +from fixbackend.workspaces.models import Workspace from fixbackend.workspaces.repository import WorkspaceRepository, WorkspaceRepositoryDependency @@ -39,6 +41,7 @@ def __init__( user_verifier: UserVerifier, workspace_repository: WorkspaceRepository, domain_events_publisher: DomainEventPublisher, + invitation_repository: InvitationRepository, ): super().__init__(user_repository, password_helper) self.user_verifier = user_verifier @@ -46,10 +49,11 @@ def __init__( self.verification_token_secret = config.secret self.workspace_repository = workspace_repository self.domain_events_publisher = domain_events_publisher + self.invitation_repository = invitation_repository async def on_after_register(self, user: User, request: Request | None = None) -> None: if user.is_verified: # oauth2 users are already verified - await self.create_default_workspace(user) + await self.add_to_workspace(user) else: await self.request_verify(user, request) @@ -57,12 +61,28 @@ async def on_after_request_verify(self, user: User, token: str, request: Optiona await self.user_verifier.verify(user, token, request) async def on_after_verify(self, user: User, request: Request | None = None) -> None: - await self.create_default_workspace(user) + await self.add_to_workspace(user) - async def create_default_workspace(self, user: User) -> None: + async def add_to_workspace(self, user: User) -> None: + if ( + pending_invitation := await self.invitation_repository.get_invitation_by_email(user.email) + ) and pending_invitation.accepted_at: + if workspace := await self.workspace_repository.get_workspace(pending_invitation.workspace_id): + await self.workspace_repository.add_to_workspace(workspace.id, user.id) + else: + # wtf? + workspace = await self.create_default_workspace(user) + await self.invitation_repository.delete_invitation(pending_invitation.id) + else: + workspace = await self.create_default_workspace(user) + + await self.domain_events_publisher.publish( + UserRegistered(user_id=user.id, email=user.email, tenant_id=workspace.id) + ) + + async def create_default_workspace(self, user: User) -> Workspace: org_slug = re.sub("[^a-zA-Z0-9-]", "-", user.email) - org = await self.workspace_repository.create_workspace(user.email, org_slug, user) - await self.domain_events_publisher.publish(UserRegistered(user_id=user.id, email=user.email, tenant_id=org.id)) + return await self.workspace_repository.create_workspace(user.email, org_slug, user) async def get_user_manager( @@ -71,8 +91,17 @@ async def get_user_manager( user_verifier: UserVerifierDependency, workspace_repository: WorkspaceRepositoryDependency, domain_event_publisher: DomainEventPublisherDependency, + invitation_repository: InvitationRepositoryDependency, ) -> AsyncIterator[UserManager]: - yield UserManager(config, user_repository, None, user_verifier, workspace_repository, domain_event_publisher) + yield UserManager( + config, + user_repository, + None, + user_verifier, + workspace_repository, + domain_event_publisher, + invitation_repository, + ) UserManagerDependency = Annotated[UserManager, Depends(get_user_manager)] diff --git a/fixbackend/auth/db.py b/fixbackend/auth/user_repository.py similarity index 100% rename from fixbackend/auth/db.py rename to fixbackend/auth/user_repository.py diff --git a/fixbackend/auth/user_verifier.py b/fixbackend/auth/user_verifier.py index 013c9211..418056fc 100644 --- a/fixbackend/auth/user_verifier.py +++ b/fixbackend/auth/user_verifier.py @@ -12,15 +12,13 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import asyncio from abc import ABC, abstractmethod from typing import Annotated, Optional -import boto3 from fastapi import Depends, Request from fixbackend.auth.models import User -from fixbackend.config import Config, ConfigDependency +from fixbackend.notification.service import NotificationService, NotificationServiceDependency class UserVerifier(ABC): @@ -41,59 +39,24 @@ async def verify(self, user: User, token: str, request: Optional[Request]) -> No pass -class ConsoleUserVerifier(UserVerifier): +class UserVerifierImpl(UserVerifier): + def __init__(self, notification_service: NotificationService) -> None: + self.notification_service = notification_service + async def verify(self, user: User, token: str, request: Optional[Request]) -> None: assert request + body_text = self.plaintext_email_content(request, token) - email_body = self.plaintext_email_content(request, token) - - print(email_body) - - -class EMailUserVerifier(UserVerifier): - def __init__(self, config: Config) -> None: - self.client = boto3.client( - "ses", - config.aws_region, - aws_access_key_id=config.aws_access_key_id, - aws_secret_access_key=config.aws_secret_access_key, + await self.notification_service.send_email( + to=user.email, + subject="FIX: verify your e-mail address", + text=body_text, + html=None, ) - async def verify(self, user: User, token: str, request: Optional[Request]) -> None: - destination = user.email - assert request - def send_email(destination: str, token: str) -> None: - body_text = self.plaintext_email_content(request, token) - - self.client.send_email( - Destination={ - "ToAddresses": [ - destination, - ], - }, - Message={ - "Body": { - "Text": { - "Charset": "UTF-8", - "Data": body_text, - }, - }, - "Subject": { - "Charset": "UTF-8", - "Data": "FIX: verify your e-mail address", - }, - }, - Source="noreply@fix.tt", - ) - - await asyncio.to_thread(lambda: send_email(destination, token)) - - -def get_user_verifier(config: ConfigDependency) -> UserVerifier: - if config.aws_access_key_id and config.aws_secret_access_key: - return EMailUserVerifier(config) - return ConsoleUserVerifier() +def get_user_verifier(notification_service: NotificationServiceDependency) -> UserVerifier: + return UserVerifierImpl(notification_service) UserVerifierDependency = Annotated[UserVerifier, Depends(get_user_verifier)] diff --git a/fixbackend/dependencies.py b/fixbackend/dependencies.py index ea3ee303..a6ae7934 100644 --- a/fixbackend/dependencies.py +++ b/fixbackend/dependencies.py @@ -53,6 +53,7 @@ class ServiceNames: billing = "billing" cloud_account_service = "cloud_account_service" domain_event_subscriber = "domain_event_subscriber" + invitation_repository = "invitation_repository" class FixDependencies(Dependencies): diff --git a/fixbackend/domain_events/events.py b/fixbackend/domain_events/events.py index 500b59f2..bde6a236 100644 --- a/fixbackend/domain_events/events.py +++ b/fixbackend/domain_events/events.py @@ -125,3 +125,19 @@ class WorkspaceCreated(Event): kind: ClassVar[str] = "workspace_created" workspace_id: WorkspaceId + + +@frozen +class InvitationAccepted(Event): + kind: ClassVar[str] = "workspace_invitation_accepted" + + workspace_id: WorkspaceId + user_email: str + + +@frozen +class UserJoinedWorkspace(Event): + kind: ClassVar[str] = "user_joined_workspace" + + workspace_id: WorkspaceId + user_id: UserId diff --git a/fixbackend/notification/service.py b/fixbackend/notification/service.py new file mode 100644 index 00000000..b7bad819 --- /dev/null +++ b/fixbackend/notification/service.py @@ -0,0 +1,94 @@ +import asyncio +from abc import ABC, abstractmethod +from typing import Annotated, Optional + +import boto3 +from fastapi import Depends + +from fixbackend.config import Config, ConfigDependency + + +class NotificationService(ABC): + @abstractmethod + async def send_email( + self, + *, + to: str, + subject: str, + text: str, + html: Optional[str], + ) -> None: + """Send an email to the given address.""" + raise NotImplementedError() + + +class ConsoleNotificationService(NotificationService): + async def send_email( + self, + to: str, + subject: str, + text: str, + html: Optional[str], + ) -> None: + print(f"Sending email to {to} with subject {subject}") + print(f"text: {text}") + if html: + print(f"html: {html}") + + +class NotificationServiceImpl(NotificationService): + def __init__(self, config: Config) -> None: + self.ses = boto3.client( + "ses", + config.aws_region, + aws_access_key_id=config.aws_access_key_id, + aws_secret_access_key=config.aws_secret_access_key, + ) + + async def send_email( + self, + *, + to: str, + subject: str, + text: str, + html: Optional[str], + ) -> None: + def send_email() -> None: + body_section = { + "Text": { + "Charset": "UTF-8", + "Data": text, + }, + } + if html: + body_section["Html"] = { + "Charset": "UTF-8", + "Data": html, + } + + self.ses.send_email( + Destination={ + "ToAddresses": [ + to, + ], + }, + Message={ + "Body": body_section, + "Subject": { + "Charset": "UTF-8", + "Data": subject, + }, + }, + Source="noreply@fix.tt", + ) + + await asyncio.to_thread(send_email) + + +def get_notification_service(config: ConfigDependency) -> NotificationService: + if config.aws_access_key_id and config.aws_secret_access_key: + return NotificationServiceImpl(config) + return ConsoleNotificationService() + + +NotificationServiceDependency = Annotated[NotificationService, Depends(get_notification_service)] diff --git a/fixbackend/workspaces/invitation_repository.py b/fixbackend/workspaces/invitation_repository.py new file mode 100644 index 00000000..9035f54b --- /dev/null +++ b/fixbackend/workspaces/invitation_repository.py @@ -0,0 +1,164 @@ +from abc import ABC, abstractmethod +from datetime import timedelta +from typing import Annotated, Callable, Optional, Sequence + +from fastapi import Depends +from fixcloudutils.util import utc +from sqlalchemy import select +from sqlalchemy.orm.exc import StaleDataError + +from fixbackend.auth.user_repository import UserRepository +from fixbackend.dependencies import FixDependency, ServiceNames +from fixbackend.errors import ResourceNotFound +from fixbackend.ids import InvitationId, WorkspaceId +from fixbackend.types import AsyncSessionMaker +from fixbackend.workspaces.models import WorkspaceInvitation, orm +from fixbackend.workspaces.repository import WorkspaceRepository + + +class InvitationRepository(ABC): + @abstractmethod + async def create_invitation(self, workspace_id: WorkspaceId, email: str) -> WorkspaceInvitation: + """Create an invite for a workspace.""" + raise NotImplementedError + + @abstractmethod + async def get_invitation(self, invitation_id: InvitationId) -> Optional[WorkspaceInvitation]: + """Get an invitation by ID.""" + raise NotImplementedError + + @abstractmethod + async def get_invitation_by_email(self, email: str) -> Optional[WorkspaceInvitation]: + """Get an invitation by email.""" + raise NotImplementedError + + @abstractmethod + async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]: + """List all invitations for a workspace.""" + raise NotImplementedError + + @abstractmethod + async def update_invitation( + self, + invitation_id: InvitationId, + update_fn: Callable[[WorkspaceInvitation], WorkspaceInvitation], + ) -> WorkspaceInvitation: + """Update an invitation.""" + raise NotImplementedError + + @abstractmethod + async def delete_invitation(self, invitation_id: InvitationId) -> None: + """Delete an invitation.""" + raise NotImplementedError + + +class InvitationRepositoryImpl(InvitationRepository): + def __init__( + self, + session_maker: AsyncSessionMaker, + workspace_repository: WorkspaceRepository, + ) -> None: + self.session_maker = session_maker + self.workspace_repository = workspace_repository + + async def create_invitation(self, workspace_id: WorkspaceId, email: str) -> WorkspaceInvitation: + async with self.session_maker() as session: + existing_invitation = ( + await session.execute( + select(orm.OrganizationInvite) + .where(orm.OrganizationInvite.organization_id == workspace_id) + .where(orm.OrganizationInvite.user_email == email) + ) + ).scalar_one_or_none() + if existing_invitation: + return existing_invitation.to_model() + + user_repository = UserRepository(session) + + workspace = await self.workspace_repository.get_workspace(workspace_id, session=session) + if workspace is None: + raise ValueError(f"Workspace {workspace_id} does not exist.") + + user = await user_repository.get_by_email(email) + + if user: + if user.id in workspace.all_users(): + raise ValueError(f"User {user.id} is already a member of workspace {workspace_id}") + + invite = orm.OrganizationInvite( + organization_id=workspace_id, + user_email=email, + expires_at=utc() + timedelta(days=7), + ) + session.add(invite) + await session.commit() + await session.refresh(invite) + return invite.to_model() + + async def get_invitation(self, invitation_id: InvitationId) -> Optional[WorkspaceInvitation]: + async with self.session_maker() as session: + statement = select(orm.OrganizationInvite).where(orm.OrganizationInvite.id == invitation_id) + results = await session.execute(statement) + invite = results.unique().scalar_one_or_none() + return invite.to_model() if invite else None + + async def get_invitation_by_email(self, email: str) -> Optional[WorkspaceInvitation]: + async with self.session_maker() as session: + statement = select(orm.OrganizationInvite).where(orm.OrganizationInvite.user_email == email) + results = await session.execute(statement) + invite = results.unique().scalar_one_or_none() + return invite.to_model() if invite else None + + async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]: + async with self.session_maker() as session: + statement = select(orm.OrganizationInvite).where(orm.OrganizationInvite.organization_id == workspace_id) + results = await session.execute(statement) + invites = results.scalars().all() + return [invite.to_model() for invite in invites] + + async def update_invitation( + self, + invitation_id: InvitationId, + update_fn: Callable[[WorkspaceInvitation], WorkspaceInvitation], + ) -> WorkspaceInvitation: + async def do_updade() -> WorkspaceInvitation: + async with self.session_maker() as session: + stored_invite = await session.get(orm.OrganizationInvite, invitation_id) + if stored_invite is None: + raise ResourceNotFound(f"Cloud account {invitation_id} not found") + + invite = update_fn(stored_invite.to_model()) + + if stored_invite.to_model() == invite: + # nothing to update + return invite + + stored_invite.organization_id = invite.workspace_id + stored_invite.user_email = invite.email + stored_invite.expires_at = invite.expires_at + stored_invite.accepted_at = invite.accepted_at + + await session.commit() + await session.refresh(stored_invite) + return stored_invite.to_model() + + while True: + try: + return await do_updade() + except StaleDataError: # in case of concurrent update + pass + + async def delete_invitation(self, invitation_id: InvitationId) -> None: + async with self.session_maker() as session: + invite = await session.get(orm.OrganizationInvite, invitation_id) + if invite is None: + raise ValueError(f"Invitation {invitation_id} does not exist.") + await session.delete(invite) + await session.commit() + + +async def get_invitation_repository(fix: FixDependency) -> InvitationRepository: + return fix.service(ServiceNames.invitation_repository, InvitationRepositoryImpl) + + +InvitationRepositoryDependency = Annotated[InvitationRepository, Depends(get_invitation_repository)] diff --git a/fixbackend/workspaces/invitation_service.py b/fixbackend/workspaces/invitation_service.py new file mode 100644 index 00000000..671214c0 --- /dev/null +++ b/fixbackend/workspaces/invitation_service.py @@ -0,0 +1,167 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +from abc import ABC, abstractmethod +from datetime import timedelta +from logging import getLogger +from typing import Annotated, Dict, Sequence, Tuple +from attrs import evolve +from fastapi import Depends + +import jwt +from fastapi_users.jwt import decode_jwt, generate_jwt + +from fixbackend.auth.models import User +from fixbackend.auth.user_repository import UserRepository, UserRepositoryDependency +from fixbackend.config import Config, ConfigDependency +from fixbackend.domain_events.dependencies import DomainEventPublisherDependency +from fixbackend.domain_events.events import InvitationAccepted +from fixbackend.domain_events.publisher import DomainEventPublisher +from fixbackend.ids import InvitationId, WorkspaceId +from fixbackend.notification.service import NotificationService, NotificationServiceDependency +from fixbackend.workspaces.invitation_repository import InvitationRepository, InvitationRepositoryDependency +from fixbackend.workspaces.models import WorkspaceInvitation +from fixbackend.workspaces.repository import WorkspaceRepository, WorkspaceRepositoryDependency +from fixcloudutils.util import utc + + +log = getLogger(__name__) + +STATE_TOKEN_AUDIENCE = "fix:invitation-state" + + +def generate_state_token(data: Dict[str, str], secret: str) -> str: + data["aud"] = STATE_TOKEN_AUDIENCE + return generate_jwt(data, secret, int(timedelta(days=7).total_seconds())) + + +class InvitationService(ABC): + @abstractmethod + async def invite_user( + self, workspace_id: WorkspaceId, inviter: User, invitee_email: str, accept_invite_base_url: str + ) -> Tuple[WorkspaceInvitation, str]: + """Create an invitation to a workspace.""" + raise NotImplementedError() + + @abstractmethod + async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]: + """List all invitations to a workspace.""" + raise NotImplementedError() + + @abstractmethod + async def accept_invitation(self, token: str) -> WorkspaceInvitation: + """Accept an invitation to a workspace.""" + raise NotImplementedError() + + @abstractmethod + async def revoke_invitation(self, invitation_id: InvitationId) -> None: + """Revoke an invitation to a workspace.""" + raise NotImplementedError() + + +class InvitationServiceImpl(InvitationService): + def __init__( + self, + workspace_repository: WorkspaceRepository, + invitation_repository: InvitationRepository, + notification_service: NotificationService, + user_repository: UserRepository, + domain_events: DomainEventPublisher, + config: Config, + ) -> None: + self.invitation_repository = invitation_repository + self.notification_service = notification_service + self.workspace_repository = workspace_repository + self.user_repository = user_repository + self.domain_events = domain_events + self.config = config + + async def invite_user( + self, workspace_id: WorkspaceId, inviter: User, invitee_email: str, accept_invite_base_url: str + ) -> Tuple[WorkspaceInvitation, str]: + workspace = await self.workspace_repository.get_workspace(workspace_id) + if workspace is None: + raise ValueError(f"Workspace {workspace_id} does not exist.") + + # this is idempotent and will return the existing invitation if it exists + invitation = await self.invitation_repository.create_invitation(workspace_id, invitee_email) + + state_data: Dict[str, str] = { + "invitation_id": str(invitation.id), + } + token = generate_state_token(state_data, secret=self.config.secret) + + subject = f"FIX Cloud {inviter.email} has invited you to FIX workspace" + invite_link = f"{accept_invite_base_url}?token={token}" + text = ( + f"{inviter.email} has invited you to join the workspace {workspace.name}. " + "Please click on the link below to accept the invitation. \n\n" + f"{invite_link}" + ) + await self.notification_service.send_email(to=invitee_email, subject=subject, text=text, html=None) + return invitation, token + + async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]: + return await self.invitation_repository.list_invitations(workspace_id) + + async def accept_invitation(self, token: str) -> WorkspaceInvitation: + try: + decoded_state = decode_jwt(token, self.config.secret, [STATE_TOKEN_AUDIENCE]) + except (jwt.ExpiredSignatureError, jwt.DecodeError) as ex: + log.info(f"accept invitation callback: invalid state token: {token}, {ex}") + raise ValueError("Invalid state token.", ex) + + invitation_id = decoded_state["invitation_id"] + invitation = await self.invitation_repository.get_invitation(invitation_id) + if invitation is None: + raise ValueError(f"Invitation {invitation_id} does not exist.") + + updated = await self.invitation_repository.update_invitation( + invitation_id, lambda invite: evolve(invite, accepted_at=utc()) + ) + + # in case the user already exists, add it to the workspace and delete the invitation + if user := await self.user_repository.get_by_email(invitation.email): + await self.workspace_repository.add_to_workspace(invitation.workspace_id, user.id) + await self.invitation_repository.delete_invitation(invitation_id) + + event = InvitationAccepted(invitation.workspace_id, invitation.email) + await self.domain_events.publish(event) + + return updated + + async def revoke_invitation(self, invitation_id: InvitationId) -> None: + await self.invitation_repository.delete_invitation(invitation_id) + + +def get_invitation_service( + workspace_repository: WorkspaceRepositoryDependency, + invitation_repository: InvitationRepositoryDependency, + notification_service: NotificationServiceDependency, + user_repository: UserRepositoryDependency, + domain_events: DomainEventPublisherDependency, + config: ConfigDependency, +) -> InvitationService: + return InvitationServiceImpl( + workspace_repository=workspace_repository, + invitation_repository=invitation_repository, + notification_service=notification_service, + user_repository=user_repository, + domain_events=domain_events, + config=config, + ) + + +InvitationServiceDependency = Annotated[InvitationService, Depends(get_invitation_service)] diff --git a/fixbackend/workspaces/models/__init__.py b/fixbackend/workspaces/models/__init__.py index 21ed340d..990d383f 100644 --- a/fixbackend/workspaces/models/__init__.py +++ b/fixbackend/workspaces/models/__init__.py @@ -13,12 +13,11 @@ # along with this program. If not, see . from datetime import datetime -from typing import List -from uuid import UUID +from typing import List, Optional from attrs import frozen -from fixbackend.ids import WorkspaceId, UserId, ExternalId +from fixbackend.ids import InvitationId, WorkspaceId, UserId, ExternalId @frozen @@ -35,8 +34,9 @@ def all_users(self) -> List[UserId]: @frozen -class WorkspaceInvite: - id: UUID +class WorkspaceInvitation: + id: InvitationId workspace_id: WorkspaceId - user_id: UserId + email: str expires_at: datetime + accepted_at: Optional[datetime] diff --git a/fixbackend/workspaces/models/orm.py b/fixbackend/workspaces/models/orm.py index 3b73742b..fa639618 100644 --- a/fixbackend/workspaces/models/orm.py +++ b/fixbackend/workspaces/models/orm.py @@ -14,15 +14,15 @@ import uuid from datetime import datetime -from typing import List +from typing import List, Optional from fastapi_users_db_sqlalchemy.generics import GUID -from sqlalchemy import ForeignKey, String, DateTime +from sqlalchemy import ForeignKey, String, DateTime, Integer from sqlalchemy.orm import Mapped, relationship, mapped_column from fixbackend.auth.models import orm from fixbackend.base_model import Base -from fixbackend.ids import WorkspaceId, UserId, ExternalId +from fixbackend.ids import InvitationId, WorkspaceId, UserId, ExternalId from fixbackend.workspaces import models @@ -50,19 +50,22 @@ def to_model(self) -> models.Workspace: class OrganizationInvite(Base): __tablename__ = "organization_invite" - id: Mapped[uuid.UUID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) - organization_id: Mapped[uuid.UUID] = mapped_column(GUID, ForeignKey("organization.id"), nullable=False) - organization: Mapped[Organization] = relationship() - user_id: Mapped[uuid.UUID] = mapped_column(GUID, ForeignKey("user.id"), nullable=False) - user: Mapped[orm.User] = relationship() + id: Mapped[InvitationId] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) + organization_id: Mapped[WorkspaceId] = mapped_column(GUID, ForeignKey("organization.id"), nullable=False) + user_email: Mapped[str] = mapped_column(String(length=320), nullable=False, unique=True) expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + accepted_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + __mapper_args__ = {"version_id_col": version_id} # for optimistic locking - def to_model(self) -> models.WorkspaceInvite: - return models.WorkspaceInvite( + def to_model(self) -> models.WorkspaceInvitation: + return models.WorkspaceInvitation( id=self.id, - workspace_id=WorkspaceId(self.organization_id), - user_id=UserId(self.user_id), + workspace_id=self.organization_id, + email=self.user_email, expires_at=self.expires_at, + accepted_at=self.accepted_at, ) diff --git a/fixbackend/workspaces/repository.py b/fixbackend/workspaces/repository.py index 82ea14ac..946016fe 100644 --- a/fixbackend/workspaces/repository.py +++ b/fixbackend/workspaces/repository.py @@ -14,23 +14,23 @@ import uuid from abc import ABC, abstractmethod -from datetime import datetime, timedelta from typing import Annotated, Optional, Sequence from fastapi import Depends -from sqlalchemy import select +from fixcloudutils.redis.pub_sub import RedisPubSubPublisher +from sqlalchemy import select, or_ from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from fixbackend.auth.models import User -from fixbackend.auth.models import orm as auth_orm from fixbackend.dependencies import FixDependency, ServiceNames from fixbackend.graph_db.service import GraphDatabaseAccessManager from fixbackend.ids import ExternalId, WorkspaceId, UserId from fixbackend.types import AsyncSessionMaker -from fixbackend.workspaces.models import Workspace, WorkspaceInvite, orm +from fixbackend.workspaces.models import Workspace, orm from fixbackend.domain_events.publisher import DomainEventPublisher -from fixbackend.domain_events.events import WorkspaceCreated +from fixbackend.domain_events.events import UserJoinedWorkspace, WorkspaceCreated class WorkspaceRepository(ABC): @@ -40,7 +40,9 @@ async def create_workspace(self, name: str, slug: str, owner: User) -> Workspace raise NotImplementedError @abstractmethod - async def get_workspace(self, workspace_id: WorkspaceId) -> Optional[Workspace]: + async def get_workspace( + self, workspace_id: WorkspaceId, *, session: Optional[AsyncSession] = None + ) -> Optional[Workspace]: """Get a workspace.""" raise NotImplementedError @@ -64,31 +66,6 @@ async def remove_from_workspace(self, workspace_id: WorkspaceId, user_id: UserId """Remove a user from a workspace.""" raise NotImplementedError - @abstractmethod - async def create_invitation(self, workspace_id: WorkspaceId, user_id: UserId) -> WorkspaceInvite: - """Create an invite for a workspace.""" - raise NotImplementedError - - @abstractmethod - async def get_invitation(self, invitation_id: uuid.UUID) -> Optional[WorkspaceInvite]: - """Get an invitation by ID.""" - raise NotImplementedError - - @abstractmethod - async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvite]: - """List all invitations for a workspace.""" - raise NotImplementedError - - @abstractmethod - async def accept_invitation(self, invitation_id: uuid.UUID) -> None: - """Accept an invitation to a workspace.""" - raise NotImplementedError - - @abstractmethod - async def delete_invitation(self, invitation_id: uuid.UUID) -> None: - """Delete an invitation.""" - raise NotImplementedError - class WorkspaceRepositoryImpl(WorkspaceRepository): def __init__( @@ -96,10 +73,12 @@ def __init__( session_maker: AsyncSessionMaker, graph_db_access_manager: GraphDatabaseAccessManager, domain_event_sender: DomainEventPublisher, + pubsub_publisher: RedisPubSubPublisher, ) -> None: self.session_maker = session_maker self.graph_db_access_manager = graph_db_access_manager self.domain_event_sender = domain_event_sender + self.pubsub_publisher = pubsub_publisher async def create_workspace(self, name: str, slug: str, owner: User) -> Workspace: async with self.session_maker() as session: @@ -123,13 +102,21 @@ async def create_workspace(self, name: str, slug: str, owner: User) -> Workspace org = results.unique().scalar_one() return org.to_model() - async def get_workspace(self, workspace_id: WorkspaceId) -> Optional[Workspace]: - async with self.session_maker() as session: + async def get_workspace( + self, workspace_id: WorkspaceId, *, session: Optional[AsyncSession] = None + ) -> Optional[Workspace]: + async def get_ws(session: AsyncSession) -> Optional[Workspace]: statement = select(orm.Organization).where(orm.Organization.id == workspace_id) results = await session.execute(statement) org = results.unique().scalar_one_or_none() return org.to_model() if org else None + if session is not None: + return await get_ws(session) + else: + async with self.session_maker() as session: + return await get_ws(session) + async def update_workspace(self, workspace_id: WorkspaceId, name: str, generate_external_id: bool) -> Workspace: """Update a workspace.""" async with self.session_maker() as session: @@ -148,7 +135,16 @@ async def update_workspace(self, workspace_id: WorkspaceId, name: str, generate_ async def list_workspaces(self, user_id: UserId) -> Sequence[Workspace]: async with self.session_maker() as session: statement = ( - select(orm.Organization).join(orm.OrganizationOwners).where(orm.OrganizationOwners.user_id == user_id) + select(orm.Organization) + .join( + orm.OrganizationOwners, orm.Organization.id == orm.OrganizationOwners.organization_id, isouter=True + ) + .join( + orm.OrganizationMembers, + orm.Organization.id == orm.OrganizationMembers.organization_id, + isouter=True, + ) + .where(or_(orm.OrganizationOwners.user_id == user_id, orm.OrganizationMembers.user_id == user_id)) ) results = await session.execute(statement) orgs = results.unique().scalars().all() @@ -168,6 +164,10 @@ async def add_to_workspace(self, workspace_id: WorkspaceId, user_id: UserId) -> except IntegrityError: raise ValueError("Can't add user to workspace.") + event = UserJoinedWorkspace(workspace_id, user_id) + await self.domain_event_sender.publish(event) + await self.pubsub_publisher.publish(event.kind, event.to_json(), f"tenant-events::{event.workspace_id}") + async def remove_from_workspace(self, workspace_id: WorkspaceId, user_id: UserId) -> None: async with self.session_maker() as session: membership = await session.get(orm.OrganizationMembers, (workspace_id, user_id)) @@ -176,70 +176,6 @@ async def remove_from_workspace(self, workspace_id: WorkspaceId, user_id: UserId await session.delete(membership) await session.commit() - async def create_invitation(self, workspace_id: WorkspaceId, user_id: UserId) -> WorkspaceInvite: - async with self.session_maker() as session: - user = await session.get(auth_orm.User, user_id) - organization = await self.get_workspace(workspace_id) - - if user is None or organization is None: - raise ValueError(f"User {user_id} or organization {workspace_id} does not exist.") - - if user.id in [owner for owner in organization.owners]: - raise ValueError(f"User {user_id} is already an owner of workspace {workspace_id}") - - if user.id in [member for member in organization.members]: - raise ValueError(f"User {user_id} is already a member of workspace {workspace_id}") - - invite = orm.OrganizationInvite( - user_id=user_id, organization_id=workspace_id, expires_at=datetime.utcnow() + timedelta(days=7) - ) - session.add(invite) - await session.commit() - await session.refresh(invite) - return invite.to_model() - - async def get_invitation(self, invitation_id: uuid.UUID) -> Optional[WorkspaceInvite]: - async with self.session_maker() as session: - statement = ( - select(orm.OrganizationInvite) - .where(orm.OrganizationInvite.id == invitation_id) - .options(selectinload(orm.OrganizationInvite.user)) - ) - results = await session.execute(statement) - invite = results.unique().scalar_one_or_none() - return invite.to_model() if invite else None - - async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvite]: - async with self.session_maker() as session: - statement = ( - select(orm.OrganizationInvite) - .where(orm.OrganizationInvite.organization_id == workspace_id) - .options(selectinload(orm.OrganizationInvite.user), selectinload(orm.OrganizationInvite.organization)) - ) - results = await session.execute(statement) - invites = results.scalars().all() - return [invite.to_model() for invite in invites] - - async def accept_invitation(self, invitation_id: uuid.UUID) -> None: - async with self.session_maker() as session: - invite = await session.get(orm.OrganizationInvite, invitation_id) - if invite is None: - raise ValueError(f"Invitation {invitation_id} does not exist.") - if invite.expires_at < datetime.utcnow(): - raise ValueError(f"Invitation {invitation_id} has expired.") - membership = orm.OrganizationMembers(user_id=invite.user_id, organization_id=invite.organization_id) - session.add(membership) - await session.delete(invite) - await session.commit() - - async def delete_invitation(self, invitation_id: uuid.UUID) -> None: - async with self.session_maker() as session: - invite = await session.get(orm.OrganizationInvite, invitation_id) - if invite is None: - raise ValueError(f"Invitation {invitation_id} does not exist.") - await session.delete(invite) - await session.commit() - async def get_workspace_repository(fix: FixDependency) -> WorkspaceRepository: return fix.service(ServiceNames.workspace_repo, WorkspaceRepositoryImpl) diff --git a/fixbackend/workspaces/router.py b/fixbackend/workspaces/router.py index 28a81e35..38a9176b 100644 --- a/fixbackend/workspaces/router.py +++ b/fixbackend/workspaces/router.py @@ -12,32 +12,38 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List -from uuid import UUID +from typing import List, Optional -from fastapi import APIRouter, HTTPException -from pydantic import EmailStr +from fastapi import APIRouter, HTTPException, Request, Response +from fastapi.responses import RedirectResponse from sqlalchemy.exc import IntegrityError from fixbackend.auth.depedencies import AuthenticatedUser -from fixbackend.auth.user_manager import UserManagerDependency +from fixbackend.auth.models import User +from fixbackend.auth.user_repository import UserRepositoryDependency from fixbackend.config import ConfigDependency -from fixbackend.ids import WorkspaceId +from fixbackend.ids import InvitationId, UserId, WorkspaceId +from fixbackend.workspaces.invitation_service import InvitationServiceDependency from fixbackend.workspaces.repository import WorkspaceRepositoryDependency from fixbackend.workspaces.dependencies import UserWorkspaceDependency from fixbackend.workspaces.schemas import ( ExternalIdRead, + UserInvite, WorkspaceCreate, WorkspaceInviteRead, WorkspaceRead, WorkspaceSettingsRead, WorkspaceSettingsUpdate, + WorkspaceUserRead, ) +import asyncio def workspaces_router() -> APIRouter: router = APIRouter() + ACCEPT_INVITE_ROUTE_NAME = "accept_invitation" + @router.get("/") async def list_workspaces( user: AuthenticatedUser, workspace_repository: WorkspaceRepositoryDependency @@ -56,10 +62,10 @@ async def get_workspace( """Get a workspace.""" org = await workspace_repository.get_workspace(workspace_id) if org is None: - raise HTTPException(status_code=404, detail="Organization not found") + raise HTTPException(status_code=404, detail="Workspace not found") if user.id not in org.all_users(): - raise HTTPException(status_code=403, detail="You are not an owner of this organization") + raise HTTPException(status_code=403, detail="You are not a member of this workspace") return WorkspaceRead.from_model(org) @@ -103,71 +109,72 @@ async def create_workspace( @router.get("/{workspace_id}/invites/") async def list_invites( workspace: UserWorkspaceDependency, - workspace_repository: WorkspaceRepositoryDependency, + invitation_service: InvitationServiceDependency, ) -> List[WorkspaceInviteRead]: - invites = await workspace_repository.list_invitations(workspace_id=workspace.id) + invites = await invitation_service.list_invitations(workspace_id=workspace.id) - return [ - WorkspaceInviteRead( - organization_slug=workspace.slug, - user_id=invite.user_id, - expires_at=invite.expires_at, - ) - for invite in invites - ] + return [WorkspaceInviteRead.from_model(invite, workspace) for invite in invites] + + @router.get("/{workspace_id}/users/") + async def list_users( + workspace: UserWorkspaceDependency, + user_repository: UserRepositoryDependency, + ) -> List[WorkspaceUserRead]: + user_ids = workspace.all_users() + users: List[Optional[User]] = await asyncio.gather(*[user_repository.get(user_id) for user_id in user_ids]) + return [WorkspaceUserRead.from_model(user) for user in users if user] @router.post("/{workspace_id}/invites/") async def invite_to_organization( workspace: UserWorkspaceDependency, - user_email: EmailStr, - workspace_repository: WorkspaceRepositoryDependency, - user_manager: UserManagerDependency, + user: AuthenticatedUser, + user_invite: UserInvite, + invitation_service: InvitationServiceDependency, + request: Request, ) -> WorkspaceInviteRead: """Invite a user to the workspace.""" - user = await user_manager.get_by_email(user_email) - if user is None: - raise HTTPException(status_code=404, detail="User not found") - - invite = await workspace_repository.create_invitation(workspace_id=workspace.id, user_id=user.id) + accept_invite_url = str(request.url_for(ACCEPT_INVITE_ROUTE_NAME, workspace_id=workspace.id)) - return WorkspaceInviteRead( - organization_slug=workspace.slug, - user_id=user.id, - expires_at=invite.expires_at, + invite, _ = await invitation_service.invite_user( + workspace_id=workspace.id, + inviter=user, + invitee_email=user_invite.email, + accept_invite_base_url=accept_invite_url, ) + return WorkspaceInviteRead.from_model(invite, workspace) + + @router.delete("/{workspace_id}/users/{user_id}/") + async def remove_user( + workspace: UserWorkspaceDependency, + user_id: UserId, + workspace_repository: WorkspaceRepositoryDependency, + user_repository: UserRepositoryDependency, + ) -> None: + """Delete a user from the workspace.""" + user = await user_repository.get(user_id) + if user is None: + raise HTTPException(status_code=404, detail="User not found") + await workspace_repository.remove_from_workspace(workspace_id=workspace.id, user_id=user.id) + @router.delete("/{workspace_id}/invites/{invite_id}") async def delete_invite( workspace: UserWorkspaceDependency, - invite_id: UUID, - workspace_repository: WorkspaceRepositoryDependency, + invite_id: InvitationId, + invitation_service: InvitationServiceDependency, ) -> None: """Delete invite.""" - await workspace_repository.delete_invitation(invite_id) + await invitation_service.revoke_invitation(invite_id) - @router.get("{workspace_id}/invites/{invite_id}/accept") + @router.get("{workspace_id}/accept_invite", name=ACCEPT_INVITE_ROUTE_NAME) async def accept_invitation( - workspace_id: WorkspaceId, - invite_id: UUID, - user: AuthenticatedUser, - workspace_repository: WorkspaceRepositoryDependency, - ) -> None: + token: str, invitation_service: InvitationServiceDependency, request: Request + ) -> Response: """Accept an invitation to the workspace.""" - org = await workspace_repository.get_workspace(workspace_id) - if org is None: - raise HTTPException(status_code=404, detail="Organization not found") - - invite = await workspace_repository.get_invitation(invite_id) - if invite is None: - raise HTTPException(status_code=404, detail="Invitation not found") - - if user.id != invite.user_id: - raise HTTPException(status_code=403, detail="You can only accept invitations for your own account") - - await workspace_repository.accept_invitation(invite_id) - - return None + invitation = await invitation_service.accept_invitation(token) + url = request.base_url.replace_query_params(message="invitation-accepted", workspace_id=invitation.workspace_id) + return RedirectResponse(url) @router.get("/{workspace_id}/cf_url") async def get_cf_url( diff --git a/fixbackend/workspaces/schemas.py b/fixbackend/workspaces/schemas.py index e6bb8ade..dd9122e3 100644 --- a/fixbackend/workspaces/schemas.py +++ b/fixbackend/workspaces/schemas.py @@ -13,12 +13,13 @@ # along with this program. If not, see . from datetime import datetime -from typing import List +from typing import List, Optional +from fixbackend.auth.models import User from fixbackend.ids import WorkspaceId, UserId, ExternalId -from pydantic import BaseModel, Field +from pydantic import BaseModel, EmailStr, Field -from fixbackend.workspaces.models import Workspace +from fixbackend.workspaces.models import Workspace, WorkspaceInvitation class WorkspaceRead(BaseModel): @@ -104,16 +105,28 @@ class WorkspaceCreate(BaseModel): class WorkspaceInviteRead(BaseModel): - organization_slug: str = Field(description="The slug of the workspace to invite the user to") - user_id: UserId = Field(description="The id of the user to invite") + workspace_id: WorkspaceId = Field(description="The unique identifier of the workspace to invite the user to") + workspace_name: str = Field(description="The name of the workspace to invite the user to") + user_email: str = Field(description="The email of the user to invite") expires_at: datetime = Field(description="The time at which the invitation expires") + accepted_at: Optional[datetime] = Field(description="The time at which the invitation was accepted, if any") + + @staticmethod + def from_model(invite: WorkspaceInvitation, workspace: Workspace) -> "WorkspaceInviteRead": + return WorkspaceInviteRead( + workspace_id=invite.workspace_id, + workspace_name=workspace.name, + user_email=invite.email, + expires_at=invite.expires_at, + accepted_at=invite.accepted_at, + ) model_config = { "json_schema_extra": { "examples": [ { "organization_slug": "my-org", - "user_id": "00000000-0000-0000-0000-000000000000", + "user_email": "foo@bar.com", "expires_at": "2021-01-01T00:00:00Z", } ] @@ -133,3 +146,54 @@ class ExternalIdRead(BaseModel): ] } } + + +class UserInvite(BaseModel): + name: str = Field(description="The name of the user") + email: EmailStr = Field(description="The email of the user") + roles: List[str] = Field(description="The role of the user") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "name": "Foo Bar", + "email": "foo@example.com", + "roles": ["admin"], + } + ] + } + } + + +class WorkspaceUserRead(BaseModel): + id: UserId = Field(description="The user's unique identifier") + sources: List[str] = Field(description="Where the user is found") + name: str = Field(description="The user's name") + email: str = Field(description="The user's email") + roles: List[str] = Field(description="The user's roles") + last_login: Optional[datetime] = Field(description="The user's last login time, if any") + + @staticmethod + def from_model(user: User) -> "WorkspaceUserRead": + return WorkspaceUserRead( + id=user.id, + sources=[], + name=user.email, + email=user.email, + roles=[], + last_login=None, + ) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "sources": ["organization"], + "name": "Foo Bar", + "email": "foo@example.com", + "roles": ["admin"], + } + ] + } + } diff --git a/migrations/versions/2023-12-07T12:14:35Z_update_workspace_invites.py b/migrations/versions/2023-12-07T12:14:35Z_update_workspace_invites.py new file mode 100644 index 00000000..ebd51916 --- /dev/null +++ b/migrations/versions/2023-12-07T12:14:35Z_update_workspace_invites.py @@ -0,0 +1,34 @@ +"""add created_at/updated_at to cloud_account + +Revision ID: 1ccf5fc88e67 +Revises: d294f6e4b5dc +Create Date: 2023-12-07 12:14:35.061355+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "1ccf5fc88e67" +down_revision: Union[str, None] = "d294f6e4b5dc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + # user_email is not nullable, so we need to drop all existing invites + op.execute("TRUNCATE TABLE organization_invite") + + op.add_column("organization_invite", sa.Column("user_email", sa.String(length=320), nullable=False)) + op.add_column("organization_invite", sa.Column("accepted_at", sa.DateTime(timezone=True), nullable=True)) + op.add_column( + "organization_invite", sa.Column("version_id", sa.Integer(), nullable=False, server_default=sa.text("0")) + ) + op.create_unique_constraint(None, "organization_invite", ["user_email"]) + op.drop_constraint("organization_invite_ibfk_2", "organization_invite", type_="foreignkey") + op.drop_column("organization_invite", "user_id") + # ### end Alembic commands ### diff --git a/noxfile.py b/noxfile.py index 274a3f87..ec53a9f7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -36,7 +36,7 @@ def flake8(session: Session) -> None: @session(python=["3.11"]) # type: ignore def test(session: Session) -> None: args = session.posargs or ["--cov"] - session.run_always("poetry", "install", external=True) + session.run_always("poetry", "install", "--quiet", external=True) session.run("pytest", *args) @@ -44,7 +44,7 @@ def test(session: Session) -> None: def mypy(session: Session) -> None: opts = ["--strict"] args = session.posargs or [] + opts + locations - session.run_always("poetry", "install", external=True) + session.run_always("poetry", "install", "--quiet", external=True) session.install("mypy", ".") print(args) session.run("mypy", *args) diff --git a/tests/fixbackend/auth/jwt_strategy_test.py b/tests/fixbackend/auth/jwt_strategy_test.py index 9c567645..b1f769df 100644 --- a/tests/fixbackend/auth/jwt_strategy_test.py +++ b/tests/fixbackend/auth/jwt_strategy_test.py @@ -12,21 +12,21 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from fastapi import Request - import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from fastapi import Request from sqlalchemy.ext.asyncio import AsyncSession -from fixbackend.auth.db import get_user_repository -from fixbackend.auth.models import User -from fixbackend.workspaces.repository import WorkspaceRepository from fixbackend.auth.auth_backend import FixJWTStrategy -from cryptography.hazmat.primitives.asymmetric import rsa +from fixbackend.auth.models import User from fixbackend.auth.user_manager import UserManager -from fixbackend.config import Config +from fixbackend.auth.user_repository import get_user_repository from fixbackend.auth.user_verifier import UserVerifier -from fixbackend.domain_events.publisher import DomainEventPublisher +from fixbackend.config import Config from fixbackend.domain_events.events import Event +from fixbackend.domain_events.publisher import DomainEventPublisher +from fixbackend.workspaces.invitation_repository import InvitationRepository +from fixbackend.workspaces.repository import WorkspaceRepository @pytest.fixture @@ -54,7 +54,11 @@ async def publish(self, event: Event) -> None: @pytest.mark.asyncio async def test_token_validation( - workspace_repository: WorkspaceRepository, user: User, default_config: Config, session: AsyncSession + workspace_repository: WorkspaceRepository, + user: User, + default_config: Config, + session: AsyncSession, + invitation_repository: InvitationRepository, ) -> None: private_key_1 = rsa.generate_private_key(65537, 2048) private_key_2 = rsa.generate_private_key(65537, 2048) @@ -64,7 +68,13 @@ async def test_token_validation( user_repo = await anext(get_user_repository(session)) user_manager = UserManager( - default_config, user_repo, None, UserVerifierMock(), workspace_repository, DomainEventSenderMock() + default_config, + user_repo, + None, + UserVerifierMock(), + workspace_repository, + DomainEventSenderMock(), + invitation_repository, ) token1 = await strategy1.write_token(user) diff --git a/tests/fixbackend/auth/router_test.py b/tests/fixbackend/auth/router_test.py index 7ecb3fa7..95068e35 100644 --- a/tests/fixbackend/auth/router_test.py +++ b/tests/fixbackend/auth/router_test.py @@ -12,7 +12,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Sequence, Tuple import pytest from fastapi import Request, FastAPI @@ -24,6 +24,9 @@ from fixbackend.domain_events.publisher import DomainEventPublisher from fixbackend.domain_events.dependencies import get_domain_event_publisher from fixbackend.domain_events.events import Event, UserRegistered, WorkspaceCreated +from fixbackend.ids import InvitationId, WorkspaceId +from fixbackend.workspaces.invitation_repository import InvitationRepository, get_invitation_repository +from fixbackend.workspaces.models import WorkspaceInvitation from tests.fixbackend.conftest import InMemoryDomainEventPublisher @@ -44,13 +47,44 @@ async def publish(self, event: Event) -> None: return self.events.append(event) +class InMemoryInvitationRepo(InvitationRepository): + async def get_invitation_by_email(self, email: str) -> Optional[WorkspaceInvitation]: + return None + + async def create_invitation(self, workspace_id: WorkspaceId, email: str) -> WorkspaceInvitation: + """Create an invite for a workspace.""" + raise NotImplementedError + + async def get_invitation(self, invitation_id: InvitationId) -> Optional[WorkspaceInvitation]: + """Get an invitation by ID.""" + raise NotImplementedError + + async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]: + """List all invitations for a workspace.""" + raise NotImplementedError + + async def update_invitation( + self, + invitation_id: InvitationId, + update_fn: Callable[[WorkspaceInvitation], WorkspaceInvitation], + ) -> WorkspaceInvitation: + """Update an invitation.""" + raise NotImplementedError + + async def delete_invitation(self, invitation_id: InvitationId) -> None: + """Delete an invitation.""" + raise NotImplementedError + + @pytest.mark.asyncio async def test_registration_flow( api_client: AsyncClient, fast_api: FastAPI, domain_event_sender: InMemoryDomainEventPublisher ) -> None: verifier = InMemoryVerifier() + invitation_repo = InMemoryInvitationRepo() fast_api.dependency_overrides[get_user_verifier] = lambda: verifier fast_api.dependency_overrides[get_domain_event_publisher] = lambda: domain_event_sender + fast_api.dependency_overrides[get_invitation_repository] = lambda: invitation_repo registration_json = { "email": "user@example.com", diff --git a/tests/fixbackend/cloud_accounts/service_test.py b/tests/fixbackend/cloud_accounts/service_test.py index 4650178d..54b88be6 100644 --- a/tests/fixbackend/cloud_accounts/service_test.py +++ b/tests/fixbackend/cloud_accounts/service_test.py @@ -24,6 +24,7 @@ from fixcloudutils.types import Json from httpx import AsyncClient, Request, Response from redis.asyncio import Redis +from sqlalchemy.ext.asyncio import AsyncSession from fixbackend.cloud_accounts.account_setup import AssumeRoleResult, AssumeRoleResults, AwsAccountSetupHelper from fixbackend.cloud_accounts.models import AwsCloudAccess, CloudAccount, CloudAccountStates @@ -120,12 +121,14 @@ async def list_all_discovered_accounts(self) -> List[CloudAccount]: ) -class OrganizationServiceMock(WorkspaceRepositoryImpl): +class WorkspaceServiceMock(WorkspaceRepositoryImpl): # noinspection PyMissingConstructor def __init__(self) -> None: pass - async def get_workspace(self, workspace_id: WorkspaceId, with_users: bool = False) -> Workspace | None: + async def get_workspace( + self, workspace_id: WorkspaceId, with_users: bool = False, *, session: Optional[AsyncSession] = None + ) -> Workspace | None: if workspace_id != test_workspace_id: return None return organization @@ -194,8 +197,8 @@ def repository() -> CloudAccountRepositoryMock: @pytest.fixture -def organization_repository() -> OrganizationServiceMock: - return OrganizationServiceMock() +def organization_repository() -> WorkspaceServiceMock: + return WorkspaceServiceMock() @pytest.fixture @@ -215,7 +218,7 @@ def account_setup_helper() -> AwsAccountSetupHelperMock: @pytest.fixture def service( - organization_repository: OrganizationServiceMock, + organization_repository: WorkspaceServiceMock, repository: CloudAccountRepositoryMock, pubsub_publisher: RedisPubSubPublisherMock, domain_sender: DomainEventSenderMock, diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py index 461d9519..7d107503 100644 --- a/tests/fixbackend/conftest.py +++ b/tests/fixbackend/conftest.py @@ -20,6 +20,7 @@ from datetime import datetime, timezone from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Sequence, Tuple, Optional from unittest.mock import patch +from attrs import frozen import pytest from alembic.command import upgrade as alembic_upgrade @@ -35,7 +36,7 @@ from sqlalchemy_utils import create_database, database_exists, drop_database from fixbackend.app import fast_api_app -from fixbackend.auth.db import get_user_repository +from fixbackend.auth.user_repository import get_user_repository, UserRepository from fixbackend.auth.models import User from fixbackend.cloud_accounts.repository import CloudAccountRepository, CloudAccountRepositoryImpl from fixbackend.collect.collect_queue import RedisCollectQueue @@ -58,8 +59,10 @@ from fixbackend.subscription.subscription_repository import SubscriptionRepository from fixbackend.types import AsyncSessionMaker from fixbackend.utils import start_of_next_month, uid +from fixbackend.workspaces.invitation_repository import InvitationRepository, InvitationRepositoryImpl from fixbackend.workspaces.models import Workspace from fixbackend.workspaces.repository import WorkspaceRepository, WorkspaceRepositoryImpl +from fixcloudutils.redis.pub_sub import RedisPubSubPublisher DATABASE_URL = "mysql+aiomysql://root@127.0.0.1:3306/fixbackend-testdb" # only used to create/drop the database @@ -211,15 +214,21 @@ def graph_database_access_manager( return GraphDatabaseAccessManager(default_config, async_session_maker) +@pytest.fixture +async def user_repository(session: AsyncSession) -> UserRepository: + repo = await anext(get_user_repository(session)) + return repo + + @pytest.fixture async def user(session: AsyncSession) -> User: - user_db = await anext(get_user_repository(session)) + user_repository = await anext(get_user_repository(session)) user_dict = { "email": "foo@bar.com", "hashed_password": "notreallyhashed", "is_verified": True, } - return await user_db.create(user_dict) + return await user_repository.create(user_dict) @pytest.fixture @@ -428,13 +437,44 @@ async def domain_event_sender() -> InMemoryDomainEventPublisher: return InMemoryDomainEventPublisher() +@frozen +class PubSubMessage: + kind: str + message: Json + channel: Optional[str] + + +class InMemoryRedisPubSubPublisher(RedisPubSubPublisher): + def __init__(self) -> None: + self.events: List[PubSubMessage] = [] + + async def publish(self, kind: str, message: Json, channel: Optional[str] = None) -> None: + self.events.append(PubSubMessage(kind, message, channel)) + + +@pytest.fixture +def pubsub_publisher() -> InMemoryRedisPubSubPublisher: + return InMemoryRedisPubSubPublisher() + + @pytest.fixture async def workspace_repository( async_session_maker: AsyncSessionMaker, graph_database_access_manager: GraphDatabaseAccessManager, domain_event_sender: DomainEventPublisher, + pubsub_publisher: InMemoryRedisPubSubPublisher, ) -> WorkspaceRepository: - return WorkspaceRepositoryImpl(async_session_maker, graph_database_access_manager, domain_event_sender) + return WorkspaceRepositoryImpl( + async_session_maker, graph_database_access_manager, domain_event_sender, pubsub_publisher + ) + + +@pytest.fixture +async def invitation_repository( + async_session_maker: AsyncSessionMaker, + workspace_repository: WorkspaceRepository, +) -> InvitationRepository: + return InvitationRepositoryImpl(async_session_maker, workspace_repository) @pytest.fixture diff --git a/tests/fixbackend/workspaces/invitation_repository_test.py b/tests/fixbackend/workspaces/invitation_repository_test.py new file mode 100644 index 00000000..fda80a03 --- /dev/null +++ b/tests/fixbackend/workspaces/invitation_repository_test.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +from attrs import evolve +from fixcloudutils.util import utc +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from fixbackend.auth.user_repository import get_user_repository +from fixbackend.auth.models import User +from fixbackend.workspaces.invitation_repository import InvitationRepository +from fixbackend.workspaces.repository import WorkspaceRepository + + +async def create_user(email: str, session: AsyncSession) -> User: + user_db = await anext(get_user_repository(session)) + user_dict = { + "email": email, + "hashed_password": "notreallyhashed", + "is_verified": True, + } + user = await user_db.create(user_dict) + + return user + + +@pytest.mark.asyncio +async def test_create_invitation( + workspace_repository: WorkspaceRepository, + invitation_repository: InvitationRepository, + session: AsyncSession, +) -> None: + user = await create_user("foo@bar.com", session) + organization = await workspace_repository.create_workspace( + name="Test Organization", slug="test-organization", owner=user + ) + org_id = organization.id + + user2 = await create_user("123foo@bar.com", session) + + invitation = await invitation_repository.create_invitation(workspace_id=org_id, email=user2.email) + assert invitation.workspace_id == org_id + assert invitation.email == user2.email + + # create invitation is idempotent + invitation2 = await invitation_repository.create_invitation(workspace_id=org_id, email=user2.email) + assert invitation2 == invitation + + external_email = "i_do_not_exist@bar.com" + non_user_invitation = await invitation_repository.create_invitation(workspace_id=org_id, email=external_email) + assert non_user_invitation.workspace_id == org_id + assert non_user_invitation.email == external_email + + +@pytest.mark.asyncio +async def test_list_invitations( + workspace_repository: WorkspaceRepository, + invitation_repository: InvitationRepository, + session: AsyncSession, +) -> None: + user = await create_user("foo@bar.com", session) + workspace = await workspace_repository.create_workspace( + name="Test Organization", slug="test-organization", owner=user + ) + + user_db = await anext(get_user_repository(session)) + user_dict = { + "email": "bar@bar.com", + "hashed_password": "notreallyhashed", + "is_verified": True, + } + new_user = await user_db.create(user_dict) + + invitation = await invitation_repository.create_invitation(workspace_id=workspace.id, email=new_user.email) + + # list the invitations + invitations = await invitation_repository.list_invitations(workspace_id=workspace.id) + assert len(invitations) == 1 + assert invitations[0] == invitation + + +@pytest.mark.asyncio +async def test_get_invitation( + workspace_repository: WorkspaceRepository, + invitation_repository: InvitationRepository, + session: AsyncSession, +) -> None: + user = await create_user("foo@bar.com", session) + workspace = await workspace_repository.create_workspace( + name="Test Organization", slug="test-organization", owner=user + ) + user_db = await anext(get_user_repository(session)) + user_dict = { + "email": "bar@bar.com", + "hashed_password": "notreallyhashed", + "is_verified": True, + } + new_user = await user_db.create(user_dict) + + invitation = await invitation_repository.create_invitation(workspace_id=workspace.id, email=new_user.email) + + stored_invitation = await invitation_repository.get_invitation(invitation_id=invitation.id) + assert stored_invitation == invitation + + +@pytest.mark.asyncio +async def test_get_invitation_by_email( + workspace_repository: WorkspaceRepository, + invitation_repository: InvitationRepository, + session: AsyncSession, +) -> None: + user = await create_user("foo@bar.com", session) + workspace = await workspace_repository.create_workspace( + name="Test Organization", slug="test-organization", owner=user + ) + user_db = await anext(get_user_repository(session)) + user_dict = { + "email": "bar@bar.com", + "hashed_password": "notreallyhashed", + "is_verified": True, + } + new_user = await user_db.create(user_dict) + + invitation = await invitation_repository.create_invitation(workspace_id=workspace.id, email=new_user.email) + + stored_invitation = await invitation_repository.get_invitation_by_email(email=new_user.email) + assert stored_invitation == invitation + + +@pytest.mark.asyncio +async def test_update_invitation( + workspace_repository: WorkspaceRepository, + invitation_repository: InvitationRepository, + session: AsyncSession, +) -> None: + user = await create_user("foo@bar.com", session) + workspace = await workspace_repository.create_workspace( + name="Test Organization", slug="test-organization", owner=user + ) + user_db = await anext(get_user_repository(session)) + user_dict = { + "email": "bar@bar.com", + "hashed_password": "notreallyhashed", + "is_verified": True, + } + new_user = await user_db.create(user_dict) + + invitation = await invitation_repository.create_invitation(workspace_id=workspace.id, email=new_user.email) + assert invitation.accepted_at is None + + now = utc() + + updated = await invitation_repository.update_invitation(invitation.id, lambda i: evolve(i, accepted_at=now)) + assert updated.accepted_at is not None + + +@pytest.mark.asyncio +async def test_delete_invitation( + workspace_repository: WorkspaceRepository, invitation_repository: InvitationRepository, session: AsyncSession +) -> None: + user = await create_user("foo@bar.com", session) + organization = await workspace_repository.create_workspace( + name="Test Organization", slug="test-organization", owner=user + ) + org_id = organization.id + + user_db = await anext(get_user_repository(session)) + user_dict = { + "email": "bar@bar.com", + "hashed_password": "notreallyhashed", + "is_verified": True, + } + new_user = await user_db.create(user_dict) + + invitation = await invitation_repository.create_invitation(workspace_id=org_id, email=new_user.email) + + # delete the invitation + await invitation_repository.delete_invitation(invitation_id=invitation.id) + + # the invitation should not exist anymore + invitations = await invitation_repository.list_invitations(workspace_id=org_id) + assert len(invitations) == 0 diff --git a/tests/fixbackend/workspaces/invitation_service_test.py b/tests/fixbackend/workspaces/invitation_service_test.py new file mode 100644 index 00000000..1df26e5c --- /dev/null +++ b/tests/fixbackend/workspaces/invitation_service_test.py @@ -0,0 +1,131 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +from typing import Optional, List +from attrs import frozen +import pytest +from fixbackend.domain_events.events import InvitationAccepted, UserJoinedWorkspace +from fixbackend.workspaces.invitation_service import InvitationService, InvitationServiceImpl + + +from fixbackend.workspaces.repository import WorkspaceRepository +from fixbackend.workspaces.invitation_repository import InvitationRepository +from fixbackend.notification.service import NotificationService +from fixbackend.auth.user_repository import UserRepository +from fixbackend.config import Config +from fixbackend.auth.models import User +from tests.fixbackend.conftest import InMemoryDomainEventPublisher + + +@frozen +class NotificationEmail: + to: str + subject: str + text: str + html: Optional[str] + + +class InMemoryNotificationService(NotificationService): + def __init__(self) -> None: + self.call_args: List[NotificationEmail] = [] + + async def send_email(self, *, to: str, subject: str, text: str, html: str | None) -> None: + self.call_args.append(NotificationEmail(to, subject, text, html)) + + +@pytest.fixture +def notification_service() -> InMemoryNotificationService: + return InMemoryNotificationService() + + +@pytest.fixture +def service( + workspace_repository: WorkspaceRepository, + invitation_repository: InvitationRepository, + notification_service: NotificationService, + user_repository: UserRepository, + domain_event_sender: InMemoryDomainEventPublisher, + default_config: Config, +) -> InvitationService: + return InvitationServiceImpl( + workspace_repository=workspace_repository, + invitation_repository=invitation_repository, + notification_service=notification_service, + user_repository=user_repository, + domain_events=domain_event_sender, + config=default_config, + ) + + +@pytest.mark.asyncio +async def test_invite_accept_user( + service: InvitationService, + workspace_repository: WorkspaceRepository, + invitation_repository: InvitationRepository, + notification_service: InMemoryNotificationService, + user_repository: UserRepository, + domain_event_sender: InMemoryDomainEventPublisher, + user: User, +) -> None: + workspace = await workspace_repository.create_workspace( + name="Test Organization", slug="test-organization", owner=user + ) + + new_user_email = "new@foo.com" + + # invite new user + invite, _ = await service.invite_user(workspace.id, user, new_user_email, "https://example.com") + assert await invitation_repository.list_invitations(workspace.id) == [invite] + + # idempotency + second_invite, _ = await service.invite_user(workspace.id, user, new_user_email, "https://example.com") + assert second_invite == invite + + # list invitations + assert await invitation_repository.list_invitations(workspace.id) == [invite] + + # check email + email = notification_service.call_args[0] + assert email.to == new_user_email + assert email.subject == f"FIX Cloud {user.email} has invited you to FIX workspace" + assert email.text.startswith(f"{user.email} has invited you to join the workspace {workspace.name}.") + assert "https://example.com?token=" in email.text + + # existing user + existing_user = await user_repository.create( + { + "email": "existing@foo.com", + "hashed_password": "notreallyhashed", + "is_verified": True, + } + ) + existing_invite, token = await service.invite_user(workspace.id, user, existing_user.email, "https://example.com") + + # when the existinng user accepts the invite, they should be added to the workspace automatically + # and the invitation should be deleted + await service.accept_invitation(token) + assert list(map(lambda w: w.id, await workspace_repository.list_workspaces(existing_user.id))) == [workspace.id] + assert await service.list_invitations(workspace.id) == [invite] + assert len(domain_event_sender.events) == 3 + assert domain_event_sender.events[1] == UserJoinedWorkspace(workspace.id, existing_user.id) + assert domain_event_sender.events[2] == InvitationAccepted(workspace.id, existing_user.email) + + # invite can be revoked + await service.revoke_invitation(invite.id) + assert await service.list_invitations(workspace.id) == [] + + # invlid token is rejected + with pytest.raises(ValueError): + await service.accept_invitation("invalid token") diff --git a/tests/fixbackend/workspaces/repository_test.py b/tests/fixbackend/workspaces/repository_test.py index 23c7d372..04e5f870 100644 --- a/tests/fixbackend/workspaces/repository_test.py +++ b/tests/fixbackend/workspaces/repository_test.py @@ -17,7 +17,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession -from fixbackend.auth.db import get_user_repository +from fixbackend.auth.user_repository import get_user_repository from fixbackend.auth.models import User from fixbackend.ids import WorkspaceId, UserId from fixbackend.workspaces.repository import WorkspaceRepository @@ -84,7 +84,7 @@ async def test_update_workspace(workspace_repository: WorkspaceRepository, user: @pytest.mark.asyncio -async def test_list_organizations(workspace_repository: WorkspaceRepository, user: User) -> None: +async def test_list_workspaces(workspace_repository: WorkspaceRepository, user: User, session: AsyncSession) -> None: workspace1 = await workspace_repository.create_workspace( name="Test Organization 1", slug="test-organization-1", owner=user ) @@ -93,16 +93,22 @@ async def test_list_organizations(workspace_repository: WorkspaceRepository, use name="Test Organization 2", slug="test-organization-2", owner=user ) + user_db = await anext(get_user_repository(session)) + new_user_dict = {"email": "bar@bar.com", "hashed_password": "notreallyhashed", "is_verified": True} + new_user = await user_db.create(new_user_dict) + member_only_workspace = await workspace_repository.create_workspace( + name="Test Organization 3", slug="test-organization-3", owner=new_user + ) + await workspace_repository.add_to_workspace(workspace_id=member_only_workspace.id, user_id=user.id) + # the user should be the owner of the organization workspaces = await workspace_repository.list_workspaces(user.id) - assert len(workspaces) == 2 - assert set([o.id for o in workspaces]) == {workspace1.id, workspace2.id} + assert len(workspaces) == 3 + assert set([o.id for o in workspaces]) == {workspace1.id, workspace2.id, member_only_workspace.id} @pytest.mark.asyncio -async def test_add_to_organization( - workspace_repository: WorkspaceRepository, session: AsyncSession, user: User -) -> None: +async def test_add_to_workspace(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None: # add an existing user to the organization organization = await workspace_repository.create_workspace( name="Test Organization", slug="test-organization", owner=user @@ -120,104 +126,11 @@ async def test_add_to_organization( assert len(retrieved_organization.members) == 1 assert retrieved_organization.members[0] == new_user.id + assert retrieved_organization.owners[0] == user.id + # when adding a user which is already a member of the organization, nothing should happen await workspace_repository.add_to_workspace(workspace_id=org_id, user_id=new_user_id) # when adding a non-existing user to the organization, an exception should be raised with pytest.raises(Exception): await workspace_repository.add_to_workspace(workspace_id=org_id, user_id=UserId(uuid.uuid4())) - - -@pytest.mark.asyncio -async def test_create_invitation(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None: - organization = await workspace_repository.create_workspace( - name="Test Organization", slug="test-organization", owner=user - ) - org_id = organization.id - - user_db = await anext(get_user_repository(session)) - user_dict = { - "email": "123foo@bar.com", - "hashed_password": "notreallyhashed", - "is_verified": True, - } - new_user = await user_db.create(user_dict) - new_user_id = new_user.id - - invitation = await workspace_repository.create_invitation(workspace_id=org_id, user_id=new_user.id) - assert invitation.workspace_id == org_id - assert invitation.user_id == new_user_id - - -@pytest.mark.asyncio -async def test_accept_invitation(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None: - organization = await workspace_repository.create_workspace( - name="Test Organization", slug="test-organization", owner=user - ) - org_id = organization.id - - user_db = await anext(get_user_repository(session)) - user_dict = { - "email": "123foo@bar.com", - "hashed_password": "notreallyhashed", - "is_verified": True, - } - new_user = await user_db.create(user_dict) - - invitation = await workspace_repository.create_invitation(workspace_id=org_id, user_id=new_user.id) - - # accept the invitation - await workspace_repository.accept_invitation(invitation_id=invitation.id) - - retrieved_organization = await workspace_repository.get_workspace(org_id) - assert retrieved_organization - assert len(retrieved_organization.members) == 1 - assert retrieved_organization.members[0] == new_user.id - - -@pytest.mark.asyncio -async def test_list_invitations(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None: - organization = await workspace_repository.create_workspace( - name="Test Organization", slug="test-organization", owner=user - ) - org_id = organization.id - - user_db = await anext(get_user_repository(session)) - user_dict = { - "email": "bar@bar.com", - "hashed_password": "notreallyhashed", - "is_verified": True, - } - new_user = await user_db.create(user_dict) - - invitation = await workspace_repository.create_invitation(workspace_id=org_id, user_id=new_user.id) - - # list the invitations - invitations = await workspace_repository.list_invitations(workspace_id=org_id) - assert len(invitations) == 1 - assert invitations[0] == invitation - - -@pytest.mark.asyncio -async def test_delete_invitation(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None: - organization = await workspace_repository.create_workspace( - name="Test Organization", slug="test-organization", owner=user - ) - org_id = organization.id - - user_db = await anext(get_user_repository(session)) - user_dict = { - "email": "bar@bar.com", - "hashed_password": "notreallyhashed", - "is_verified": True, - } - new_user = await user_db.create(user_dict) - - invitation = await workspace_repository.create_invitation(workspace_id=org_id, user_id=new_user.id) - - # delete the invitation - await workspace_repository.delete_invitation(invitation_id=invitation.id) - - # the invitation should not exist anymore - invitations = await workspace_repository.list_invitations(workspace_id=org_id) - assert len(invitations) == 0 diff --git a/tests/fixbackend/workspaces/router_test.py b/tests/fixbackend/workspaces/router_test.py index b76552d1..52636a57 100644 --- a/tests/fixbackend/workspaces/router_test.py +++ b/tests/fixbackend/workspaces/router_test.py @@ -14,8 +14,7 @@ import uuid -from typing import AsyncIterator, Sequence -from uuid import UUID +from typing import AsyncIterator, Optional, Sequence from attrs import evolve import pytest @@ -59,10 +58,12 @@ class WorkspaceRepositoryMock(WorkspaceRepositoryImpl): def __init__(self) -> None: pass - async def get_workspace(self, workspace_id: UUID) -> Workspace | None: + async def get_workspace( + self, workspace_id: WorkspaceId, *, session: Optional[AsyncSession] = None + ) -> Workspace | None: return workspace - async def list_workspaces(self, owner_id: UUID) -> Sequence[Workspace]: + async def list_workspaces(self, owner_id: UserId) -> Sequence[Workspace]: return [workspace] async def update_workspace(self, workspace_id: WorkspaceId, name: str, generate_external_id: bool) -> Workspace: