diff --git a/fixbackend/app.py b/fixbackend/app.py
index cc7938e1..0aca2950 100644
--- a/fixbackend/app.py
+++ b/fixbackend/app.py
@@ -88,6 +88,7 @@
from fixbackend.subscription.billing import BillingService
from fixbackend.subscription.router import subscription_router
from fixbackend.subscription.subscription_repository import SubscriptionRepository
+from fixbackend.workspaces.invitation_repository import InvitationRepositoryImpl
from fixbackend.workspaces.repository import WorkspaceRepositoryImpl
from fixbackend.workspaces.router import workspaces_router
from fixbackend.domain_events.subscriber import DomainEventSubscriber
@@ -160,7 +161,20 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events))
workspace_repo = deps.add(
SN.workspace_repo,
- WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher),
+ WorkspaceRepositoryImpl(
+ session_maker,
+ graph_db_access,
+ domain_event_publisher,
+ RedisPubSubPublisher(
+ redis=readwrite_redis,
+ channel="workspaces",
+ publisher_name="workspace_service",
+ ),
+ ),
+ )
+ deps.add(
+ SN.invitation_repository,
+ InvitationRepositoryImpl(session_maker, workspace_repo),
)
subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker))
deps.add(
@@ -245,8 +259,18 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]:
domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events))
workspace_repo = deps.add(
SN.workspace_repo,
- WorkspaceRepositoryImpl(session_maker, db_access, domain_event_publisher),
+ WorkspaceRepositoryImpl(
+ session_maker,
+ db_access,
+ domain_event_publisher,
+ RedisPubSubPublisher(
+ redis=rw_redis,
+ channel="workspaces",
+ publisher_name="workspace_service",
+ ),
+ ),
)
+
cloud_accounts_redis_publisher = RedisPubSubPublisher(
redis=rw_redis,
channel="cloud_accounts",
@@ -306,7 +330,16 @@ async def setup_teardown_billing(_: FastAPI) -> AsyncIterator[None]:
metering_repo = deps.add(SN.metering_repo, MeteringRepository(session_maker))
workspace_repo = deps.add(
SN.workspace_repo,
- WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher),
+ WorkspaceRepositoryImpl(
+ session_maker,
+ graph_db_access,
+ domain_event_publisher,
+ RedisPubSubPublisher(
+ redis=readwrite_redis,
+ channel="workspaces",
+ publisher_name="workspace_service",
+ ),
+ ),
)
subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker))
aws_marketplace = deps.add(
diff --git a/fixbackend/auth/user_manager.py b/fixbackend/auth/user_manager.py
index 25f96f6e..472af768 100644
--- a/fixbackend/auth/user_manager.py
+++ b/fixbackend/auth/user_manager.py
@@ -20,13 +20,15 @@
from fastapi_users import BaseUserManager, UUIDIDMixin
from fastapi_users.password import PasswordHelperProtocol
-from fixbackend.auth.db import UserRepository, UserRepositoryDependency
+from fixbackend.auth.user_repository import UserRepository, UserRepositoryDependency
from fixbackend.auth.models import User
from fixbackend.auth.user_verifier import UserVerifier, UserVerifierDependency
from fixbackend.config import Config, ConfigDependency
from fixbackend.domain_events.events import UserRegistered
from fixbackend.domain_events.publisher import DomainEventPublisher
from fixbackend.domain_events.dependencies import DomainEventPublisherDependency
+from fixbackend.workspaces.invitation_repository import InvitationRepository, InvitationRepositoryDependency
+from fixbackend.workspaces.models import Workspace
from fixbackend.workspaces.repository import WorkspaceRepository, WorkspaceRepositoryDependency
@@ -39,6 +41,7 @@ def __init__(
user_verifier: UserVerifier,
workspace_repository: WorkspaceRepository,
domain_events_publisher: DomainEventPublisher,
+ invitation_repository: InvitationRepository,
):
super().__init__(user_repository, password_helper)
self.user_verifier = user_verifier
@@ -46,10 +49,11 @@ def __init__(
self.verification_token_secret = config.secret
self.workspace_repository = workspace_repository
self.domain_events_publisher = domain_events_publisher
+ self.invitation_repository = invitation_repository
async def on_after_register(self, user: User, request: Request | None = None) -> None:
if user.is_verified: # oauth2 users are already verified
- await self.create_default_workspace(user)
+ await self.add_to_workspace(user)
else:
await self.request_verify(user, request)
@@ -57,12 +61,28 @@ async def on_after_request_verify(self, user: User, token: str, request: Optiona
await self.user_verifier.verify(user, token, request)
async def on_after_verify(self, user: User, request: Request | None = None) -> None:
- await self.create_default_workspace(user)
+ await self.add_to_workspace(user)
- async def create_default_workspace(self, user: User) -> None:
+ async def add_to_workspace(self, user: User) -> None:
+ if (
+ pending_invitation := await self.invitation_repository.get_invitation_by_email(user.email)
+ ) and pending_invitation.accepted_at:
+ if workspace := await self.workspace_repository.get_workspace(pending_invitation.workspace_id):
+ await self.workspace_repository.add_to_workspace(workspace.id, user.id)
+ else:
+ # wtf?
+ workspace = await self.create_default_workspace(user)
+ await self.invitation_repository.delete_invitation(pending_invitation.id)
+ else:
+ workspace = await self.create_default_workspace(user)
+
+ await self.domain_events_publisher.publish(
+ UserRegistered(user_id=user.id, email=user.email, tenant_id=workspace.id)
+ )
+
+ async def create_default_workspace(self, user: User) -> Workspace:
org_slug = re.sub("[^a-zA-Z0-9-]", "-", user.email)
- org = await self.workspace_repository.create_workspace(user.email, org_slug, user)
- await self.domain_events_publisher.publish(UserRegistered(user_id=user.id, email=user.email, tenant_id=org.id))
+ return await self.workspace_repository.create_workspace(user.email, org_slug, user)
async def get_user_manager(
@@ -71,8 +91,17 @@ async def get_user_manager(
user_verifier: UserVerifierDependency,
workspace_repository: WorkspaceRepositoryDependency,
domain_event_publisher: DomainEventPublisherDependency,
+ invitation_repository: InvitationRepositoryDependency,
) -> AsyncIterator[UserManager]:
- yield UserManager(config, user_repository, None, user_verifier, workspace_repository, domain_event_publisher)
+ yield UserManager(
+ config,
+ user_repository,
+ None,
+ user_verifier,
+ workspace_repository,
+ domain_event_publisher,
+ invitation_repository,
+ )
UserManagerDependency = Annotated[UserManager, Depends(get_user_manager)]
diff --git a/fixbackend/auth/db.py b/fixbackend/auth/user_repository.py
similarity index 100%
rename from fixbackend/auth/db.py
rename to fixbackend/auth/user_repository.py
diff --git a/fixbackend/auth/user_verifier.py b/fixbackend/auth/user_verifier.py
index 013c9211..418056fc 100644
--- a/fixbackend/auth/user_verifier.py
+++ b/fixbackend/auth/user_verifier.py
@@ -12,15 +12,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import asyncio
from abc import ABC, abstractmethod
from typing import Annotated, Optional
-import boto3
from fastapi import Depends, Request
from fixbackend.auth.models import User
-from fixbackend.config import Config, ConfigDependency
+from fixbackend.notification.service import NotificationService, NotificationServiceDependency
class UserVerifier(ABC):
@@ -41,59 +39,24 @@ async def verify(self, user: User, token: str, request: Optional[Request]) -> No
pass
-class ConsoleUserVerifier(UserVerifier):
+class UserVerifierImpl(UserVerifier):
+ def __init__(self, notification_service: NotificationService) -> None:
+ self.notification_service = notification_service
+
async def verify(self, user: User, token: str, request: Optional[Request]) -> None:
assert request
+ body_text = self.plaintext_email_content(request, token)
- email_body = self.plaintext_email_content(request, token)
-
- print(email_body)
-
-
-class EMailUserVerifier(UserVerifier):
- def __init__(self, config: Config) -> None:
- self.client = boto3.client(
- "ses",
- config.aws_region,
- aws_access_key_id=config.aws_access_key_id,
- aws_secret_access_key=config.aws_secret_access_key,
+ await self.notification_service.send_email(
+ to=user.email,
+ subject="FIX: verify your e-mail address",
+ text=body_text,
+ html=None,
)
- async def verify(self, user: User, token: str, request: Optional[Request]) -> None:
- destination = user.email
- assert request
- def send_email(destination: str, token: str) -> None:
- body_text = self.plaintext_email_content(request, token)
-
- self.client.send_email(
- Destination={
- "ToAddresses": [
- destination,
- ],
- },
- Message={
- "Body": {
- "Text": {
- "Charset": "UTF-8",
- "Data": body_text,
- },
- },
- "Subject": {
- "Charset": "UTF-8",
- "Data": "FIX: verify your e-mail address",
- },
- },
- Source="noreply@fix.tt",
- )
-
- await asyncio.to_thread(lambda: send_email(destination, token))
-
-
-def get_user_verifier(config: ConfigDependency) -> UserVerifier:
- if config.aws_access_key_id and config.aws_secret_access_key:
- return EMailUserVerifier(config)
- return ConsoleUserVerifier()
+def get_user_verifier(notification_service: NotificationServiceDependency) -> UserVerifier:
+ return UserVerifierImpl(notification_service)
UserVerifierDependency = Annotated[UserVerifier, Depends(get_user_verifier)]
diff --git a/fixbackend/dependencies.py b/fixbackend/dependencies.py
index ea3ee303..a6ae7934 100644
--- a/fixbackend/dependencies.py
+++ b/fixbackend/dependencies.py
@@ -53,6 +53,7 @@ class ServiceNames:
billing = "billing"
cloud_account_service = "cloud_account_service"
domain_event_subscriber = "domain_event_subscriber"
+ invitation_repository = "invitation_repository"
class FixDependencies(Dependencies):
diff --git a/fixbackend/domain_events/events.py b/fixbackend/domain_events/events.py
index 500b59f2..bde6a236 100644
--- a/fixbackend/domain_events/events.py
+++ b/fixbackend/domain_events/events.py
@@ -125,3 +125,19 @@ class WorkspaceCreated(Event):
kind: ClassVar[str] = "workspace_created"
workspace_id: WorkspaceId
+
+
+@frozen
+class InvitationAccepted(Event):
+ kind: ClassVar[str] = "workspace_invitation_accepted"
+
+ workspace_id: WorkspaceId
+ user_email: str
+
+
+@frozen
+class UserJoinedWorkspace(Event):
+ kind: ClassVar[str] = "user_joined_workspace"
+
+ workspace_id: WorkspaceId
+ user_id: UserId
diff --git a/fixbackend/notification/service.py b/fixbackend/notification/service.py
new file mode 100644
index 00000000..b7bad819
--- /dev/null
+++ b/fixbackend/notification/service.py
@@ -0,0 +1,94 @@
+import asyncio
+from abc import ABC, abstractmethod
+from typing import Annotated, Optional
+
+import boto3
+from fastapi import Depends
+
+from fixbackend.config import Config, ConfigDependency
+
+
+class NotificationService(ABC):
+ @abstractmethod
+ async def send_email(
+ self,
+ *,
+ to: str,
+ subject: str,
+ text: str,
+ html: Optional[str],
+ ) -> None:
+ """Send an email to the given address."""
+ raise NotImplementedError()
+
+
+class ConsoleNotificationService(NotificationService):
+ async def send_email(
+ self,
+ to: str,
+ subject: str,
+ text: str,
+ html: Optional[str],
+ ) -> None:
+ print(f"Sending email to {to} with subject {subject}")
+ print(f"text: {text}")
+ if html:
+ print(f"html: {html}")
+
+
+class NotificationServiceImpl(NotificationService):
+ def __init__(self, config: Config) -> None:
+ self.ses = boto3.client(
+ "ses",
+ config.aws_region,
+ aws_access_key_id=config.aws_access_key_id,
+ aws_secret_access_key=config.aws_secret_access_key,
+ )
+
+ async def send_email(
+ self,
+ *,
+ to: str,
+ subject: str,
+ text: str,
+ html: Optional[str],
+ ) -> None:
+ def send_email() -> None:
+ body_section = {
+ "Text": {
+ "Charset": "UTF-8",
+ "Data": text,
+ },
+ }
+ if html:
+ body_section["Html"] = {
+ "Charset": "UTF-8",
+ "Data": html,
+ }
+
+ self.ses.send_email(
+ Destination={
+ "ToAddresses": [
+ to,
+ ],
+ },
+ Message={
+ "Body": body_section,
+ "Subject": {
+ "Charset": "UTF-8",
+ "Data": subject,
+ },
+ },
+ Source="noreply@fix.tt",
+ )
+
+ await asyncio.to_thread(send_email)
+
+
+def get_notification_service(config: ConfigDependency) -> NotificationService:
+ if config.aws_access_key_id and config.aws_secret_access_key:
+ return NotificationServiceImpl(config)
+ return ConsoleNotificationService()
+
+
+NotificationServiceDependency = Annotated[NotificationService, Depends(get_notification_service)]
diff --git a/fixbackend/workspaces/invitation_repository.py b/fixbackend/workspaces/invitation_repository.py
new file mode 100644
index 00000000..9035f54b
--- /dev/null
+++ b/fixbackend/workspaces/invitation_repository.py
@@ -0,0 +1,164 @@
+from abc import ABC, abstractmethod
+from datetime import timedelta
+from typing import Annotated, Callable, Optional, Sequence
+
+from fastapi import Depends
+from fixcloudutils.util import utc
+from sqlalchemy import select
+from sqlalchemy.orm.exc import StaleDataError
+
+from fixbackend.auth.user_repository import UserRepository
+from fixbackend.dependencies import FixDependency, ServiceNames
+from fixbackend.errors import ResourceNotFound
+from fixbackend.ids import InvitationId, WorkspaceId
+from fixbackend.types import AsyncSessionMaker
+from fixbackend.workspaces.models import WorkspaceInvitation, orm
+from fixbackend.workspaces.repository import WorkspaceRepository
+
+
+class InvitationRepository(ABC):
+ @abstractmethod
+ async def create_invitation(self, workspace_id: WorkspaceId, email: str) -> WorkspaceInvitation:
+ """Create an invite for a workspace."""
+ raise NotImplementedError
+
+ @abstractmethod
+ async def get_invitation(self, invitation_id: InvitationId) -> Optional[WorkspaceInvitation]:
+ """Get an invitation by ID."""
+ raise NotImplementedError
+
+ @abstractmethod
+ async def get_invitation_by_email(self, email: str) -> Optional[WorkspaceInvitation]:
+ """Get an invitation by email."""
+ raise NotImplementedError
+
+ @abstractmethod
+ async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]:
+ """List all invitations for a workspace."""
+ raise NotImplementedError
+
+ @abstractmethod
+ async def update_invitation(
+ self,
+ invitation_id: InvitationId,
+ update_fn: Callable[[WorkspaceInvitation], WorkspaceInvitation],
+ ) -> WorkspaceInvitation:
+ """Update an invitation."""
+ raise NotImplementedError
+
+ @abstractmethod
+ async def delete_invitation(self, invitation_id: InvitationId) -> None:
+ """Delete an invitation."""
+ raise NotImplementedError
+
+
+class InvitationRepositoryImpl(InvitationRepository):
+ def __init__(
+ self,
+ session_maker: AsyncSessionMaker,
+ workspace_repository: WorkspaceRepository,
+ ) -> None:
+ self.session_maker = session_maker
+ self.workspace_repository = workspace_repository
+
+ async def create_invitation(self, workspace_id: WorkspaceId, email: str) -> WorkspaceInvitation:
+ async with self.session_maker() as session:
+ existing_invitation = (
+ await session.execute(
+ select(orm.OrganizationInvite)
+ .where(orm.OrganizationInvite.organization_id == workspace_id)
+ .where(orm.OrganizationInvite.user_email == email)
+ )
+ ).scalar_one_or_none()
+ if existing_invitation:
+ return existing_invitation.to_model()
+
+ user_repository = UserRepository(session)
+
+ workspace = await self.workspace_repository.get_workspace(workspace_id, session=session)
+ if workspace is None:
+ raise ValueError(f"Workspace {workspace_id} does not exist.")
+
+ user = await user_repository.get_by_email(email)
+
+ if user:
+ if user.id in workspace.all_users():
+ raise ValueError(f"User {user.id} is already a member of workspace {workspace_id}")
+
+ invite = orm.OrganizationInvite(
+ organization_id=workspace_id,
+ user_email=email,
+ expires_at=utc() + timedelta(days=7),
+ )
+ session.add(invite)
+ await session.commit()
+ await session.refresh(invite)
+ return invite.to_model()
+
+ async def get_invitation(self, invitation_id: InvitationId) -> Optional[WorkspaceInvitation]:
+ async with self.session_maker() as session:
+ statement = select(orm.OrganizationInvite).where(orm.OrganizationInvite.id == invitation_id)
+ results = await session.execute(statement)
+ invite = results.unique().scalar_one_or_none()
+ return invite.to_model() if invite else None
+
+ async def get_invitation_by_email(self, email: str) -> Optional[WorkspaceInvitation]:
+ async with self.session_maker() as session:
+ statement = select(orm.OrganizationInvite).where(orm.OrganizationInvite.user_email == email)
+ results = await session.execute(statement)
+ invite = results.unique().scalar_one_or_none()
+ return invite.to_model() if invite else None
+
+ async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]:
+ async with self.session_maker() as session:
+ statement = select(orm.OrganizationInvite).where(orm.OrganizationInvite.organization_id == workspace_id)
+ results = await session.execute(statement)
+ invites = results.scalars().all()
+ return [invite.to_model() for invite in invites]
+
+ async def update_invitation(
+ self,
+ invitation_id: InvitationId,
+ update_fn: Callable[[WorkspaceInvitation], WorkspaceInvitation],
+ ) -> WorkspaceInvitation:
+ async def do_updade() -> WorkspaceInvitation:
+ async with self.session_maker() as session:
+ stored_invite = await session.get(orm.OrganizationInvite, invitation_id)
+ if stored_invite is None:
+ raise ResourceNotFound(f"Cloud account {invitation_id} not found")
+
+ invite = update_fn(stored_invite.to_model())
+
+ if stored_invite.to_model() == invite:
+ # nothing to update
+ return invite
+
+ stored_invite.organization_id = invite.workspace_id
+ stored_invite.user_email = invite.email
+ stored_invite.expires_at = invite.expires_at
+ stored_invite.accepted_at = invite.accepted_at
+
+ await session.commit()
+ await session.refresh(stored_invite)
+ return stored_invite.to_model()
+
+ while True:
+ try:
+ return await do_updade()
+ except StaleDataError: # in case of concurrent update
+ pass
+
+ async def delete_invitation(self, invitation_id: InvitationId) -> None:
+ async with self.session_maker() as session:
+ invite = await session.get(orm.OrganizationInvite, invitation_id)
+ if invite is None:
+ raise ValueError(f"Invitation {invitation_id} does not exist.")
+ await session.delete(invite)
+ await session.commit()
+
+
+async def get_invitation_repository(fix: FixDependency) -> InvitationRepository:
+ return fix.service(ServiceNames.invitation_repository, InvitationRepositoryImpl)
+
+
+InvitationRepositoryDependency = Annotated[InvitationRepository, Depends(get_invitation_repository)]
diff --git a/fixbackend/workspaces/invitation_service.py b/fixbackend/workspaces/invitation_service.py
new file mode 100644
index 00000000..671214c0
--- /dev/null
+++ b/fixbackend/workspaces/invitation_service.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+
+from abc import ABC, abstractmethod
+from datetime import timedelta
+from logging import getLogger
+from typing import Annotated, Dict, Sequence, Tuple
+from attrs import evolve
+from fastapi import Depends
+
+import jwt
+from fastapi_users.jwt import decode_jwt, generate_jwt
+
+from fixbackend.auth.models import User
+from fixbackend.auth.user_repository import UserRepository, UserRepositoryDependency
+from fixbackend.config import Config, ConfigDependency
+from fixbackend.domain_events.dependencies import DomainEventPublisherDependency
+from fixbackend.domain_events.events import InvitationAccepted
+from fixbackend.domain_events.publisher import DomainEventPublisher
+from fixbackend.ids import InvitationId, WorkspaceId
+from fixbackend.notification.service import NotificationService, NotificationServiceDependency
+from fixbackend.workspaces.invitation_repository import InvitationRepository, InvitationRepositoryDependency
+from fixbackend.workspaces.models import WorkspaceInvitation
+from fixbackend.workspaces.repository import WorkspaceRepository, WorkspaceRepositoryDependency
+from fixcloudutils.util import utc
+
+
+log = getLogger(__name__)
+
+STATE_TOKEN_AUDIENCE = "fix:invitation-state"
+
+
+def generate_state_token(data: Dict[str, str], secret: str) -> str:
+ data["aud"] = STATE_TOKEN_AUDIENCE
+ return generate_jwt(data, secret, int(timedelta(days=7).total_seconds()))
+
+
+class InvitationService(ABC):
+ @abstractmethod
+ async def invite_user(
+ self, workspace_id: WorkspaceId, inviter: User, invitee_email: str, accept_invite_base_url: str
+ ) -> Tuple[WorkspaceInvitation, str]:
+ """Create an invitation to a workspace."""
+ raise NotImplementedError()
+
+ @abstractmethod
+ async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]:
+ """List all invitations to a workspace."""
+ raise NotImplementedError()
+
+ @abstractmethod
+ async def accept_invitation(self, token: str) -> WorkspaceInvitation:
+ """Accept an invitation to a workspace."""
+ raise NotImplementedError()
+
+ @abstractmethod
+ async def revoke_invitation(self, invitation_id: InvitationId) -> None:
+ """Revoke an invitation to a workspace."""
+ raise NotImplementedError()
+
+
+class InvitationServiceImpl(InvitationService):
+ def __init__(
+ self,
+ workspace_repository: WorkspaceRepository,
+ invitation_repository: InvitationRepository,
+ notification_service: NotificationService,
+ user_repository: UserRepository,
+ domain_events: DomainEventPublisher,
+ config: Config,
+ ) -> None:
+ self.invitation_repository = invitation_repository
+ self.notification_service = notification_service
+ self.workspace_repository = workspace_repository
+ self.user_repository = user_repository
+ self.domain_events = domain_events
+ self.config = config
+
+ async def invite_user(
+ self, workspace_id: WorkspaceId, inviter: User, invitee_email: str, accept_invite_base_url: str
+ ) -> Tuple[WorkspaceInvitation, str]:
+ workspace = await self.workspace_repository.get_workspace(workspace_id)
+ if workspace is None:
+ raise ValueError(f"Workspace {workspace_id} does not exist.")
+
+ # this is idempotent and will return the existing invitation if it exists
+ invitation = await self.invitation_repository.create_invitation(workspace_id, invitee_email)
+
+ state_data: Dict[str, str] = {
+ "invitation_id": str(invitation.id),
+ }
+ token = generate_state_token(state_data, secret=self.config.secret)
+
+ subject = f"FIX Cloud {inviter.email} has invited you to FIX workspace"
+ invite_link = f"{accept_invite_base_url}?token={token}"
+ text = (
+ f"{inviter.email} has invited you to join the workspace {workspace.name}. "
+ "Please click on the link below to accept the invitation. \n\n"
+ f"{invite_link}"
+ )
+ await self.notification_service.send_email(to=invitee_email, subject=subject, text=text, html=None)
+ return invitation, token
+
+ async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]:
+ return await self.invitation_repository.list_invitations(workspace_id)
+
+ async def accept_invitation(self, token: str) -> WorkspaceInvitation:
+ try:
+ decoded_state = decode_jwt(token, self.config.secret, [STATE_TOKEN_AUDIENCE])
+ except (jwt.ExpiredSignatureError, jwt.DecodeError) as ex:
+ log.info(f"accept invitation callback: invalid state token: {token}, {ex}")
+ raise ValueError("Invalid state token.", ex)
+
+ invitation_id = decoded_state["invitation_id"]
+ invitation = await self.invitation_repository.get_invitation(invitation_id)
+ if invitation is None:
+ raise ValueError(f"Invitation {invitation_id} does not exist.")
+
+ updated = await self.invitation_repository.update_invitation(
+ invitation_id, lambda invite: evolve(invite, accepted_at=utc())
+ )
+
+ # in case the user already exists, add it to the workspace and delete the invitation
+ if user := await self.user_repository.get_by_email(invitation.email):
+ await self.workspace_repository.add_to_workspace(invitation.workspace_id, user.id)
+ await self.invitation_repository.delete_invitation(invitation_id)
+
+ event = InvitationAccepted(invitation.workspace_id, invitation.email)
+ await self.domain_events.publish(event)
+
+ return updated
+
+ async def revoke_invitation(self, invitation_id: InvitationId) -> None:
+ await self.invitation_repository.delete_invitation(invitation_id)
+
+
+def get_invitation_service(
+ workspace_repository: WorkspaceRepositoryDependency,
+ invitation_repository: InvitationRepositoryDependency,
+ notification_service: NotificationServiceDependency,
+ user_repository: UserRepositoryDependency,
+ domain_events: DomainEventPublisherDependency,
+ config: ConfigDependency,
+) -> InvitationService:
+ return InvitationServiceImpl(
+ workspace_repository=workspace_repository,
+ invitation_repository=invitation_repository,
+ notification_service=notification_service,
+ user_repository=user_repository,
+ domain_events=domain_events,
+ config=config,
+ )
+
+
+InvitationServiceDependency = Annotated[InvitationService, Depends(get_invitation_service)]
diff --git a/fixbackend/workspaces/models/__init__.py b/fixbackend/workspaces/models/__init__.py
index 21ed340d..990d383f 100644
--- a/fixbackend/workspaces/models/__init__.py
+++ b/fixbackend/workspaces/models/__init__.py
@@ -13,12 +13,11 @@
# along with this program. If not, see .
from datetime import datetime
-from typing import List
-from uuid import UUID
+from typing import List, Optional
from attrs import frozen
-from fixbackend.ids import WorkspaceId, UserId, ExternalId
+from fixbackend.ids import InvitationId, WorkspaceId, UserId, ExternalId
@frozen
@@ -35,8 +34,9 @@ def all_users(self) -> List[UserId]:
@frozen
-class WorkspaceInvite:
- id: UUID
+class WorkspaceInvitation:
+ id: InvitationId
workspace_id: WorkspaceId
- user_id: UserId
+ email: str
expires_at: datetime
+ accepted_at: Optional[datetime]
diff --git a/fixbackend/workspaces/models/orm.py b/fixbackend/workspaces/models/orm.py
index 3b73742b..fa639618 100644
--- a/fixbackend/workspaces/models/orm.py
+++ b/fixbackend/workspaces/models/orm.py
@@ -14,15 +14,15 @@
import uuid
from datetime import datetime
-from typing import List
+from typing import List, Optional
from fastapi_users_db_sqlalchemy.generics import GUID
-from sqlalchemy import ForeignKey, String, DateTime
+from sqlalchemy import ForeignKey, String, DateTime, Integer
from sqlalchemy.orm import Mapped, relationship, mapped_column
from fixbackend.auth.models import orm
from fixbackend.base_model import Base
-from fixbackend.ids import WorkspaceId, UserId, ExternalId
+from fixbackend.ids import InvitationId, WorkspaceId, UserId, ExternalId
from fixbackend.workspaces import models
@@ -50,19 +50,22 @@ def to_model(self) -> models.Workspace:
class OrganizationInvite(Base):
__tablename__ = "organization_invite"
- id: Mapped[uuid.UUID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4)
- organization_id: Mapped[uuid.UUID] = mapped_column(GUID, ForeignKey("organization.id"), nullable=False)
- organization: Mapped[Organization] = relationship()
- user_id: Mapped[uuid.UUID] = mapped_column(GUID, ForeignKey("user.id"), nullable=False)
- user: Mapped[orm.User] = relationship()
+ id: Mapped[InvitationId] = mapped_column(GUID, primary_key=True, default=uuid.uuid4)
+ organization_id: Mapped[WorkspaceId] = mapped_column(GUID, ForeignKey("organization.id"), nullable=False)
+ user_email: Mapped[str] = mapped_column(String(length=320), nullable=False, unique=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+ accepted_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+ version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
+
+ __mapper_args__ = {"version_id_col": version_id} # for optimistic locking
- def to_model(self) -> models.WorkspaceInvite:
- return models.WorkspaceInvite(
+ def to_model(self) -> models.WorkspaceInvitation:
+ return models.WorkspaceInvitation(
id=self.id,
- workspace_id=WorkspaceId(self.organization_id),
- user_id=UserId(self.user_id),
+ workspace_id=self.organization_id,
+ email=self.user_email,
expires_at=self.expires_at,
+ accepted_at=self.accepted_at,
)
diff --git a/fixbackend/workspaces/repository.py b/fixbackend/workspaces/repository.py
index 82ea14ac..946016fe 100644
--- a/fixbackend/workspaces/repository.py
+++ b/fixbackend/workspaces/repository.py
@@ -14,23 +14,23 @@
import uuid
from abc import ABC, abstractmethod
-from datetime import datetime, timedelta
from typing import Annotated, Optional, Sequence
from fastapi import Depends
-from sqlalchemy import select
+from fixcloudutils.redis.pub_sub import RedisPubSubPublisher
+from sqlalchemy import select, or_
from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from fixbackend.auth.models import User
-from fixbackend.auth.models import orm as auth_orm
from fixbackend.dependencies import FixDependency, ServiceNames
from fixbackend.graph_db.service import GraphDatabaseAccessManager
from fixbackend.ids import ExternalId, WorkspaceId, UserId
from fixbackend.types import AsyncSessionMaker
-from fixbackend.workspaces.models import Workspace, WorkspaceInvite, orm
+from fixbackend.workspaces.models import Workspace, orm
from fixbackend.domain_events.publisher import DomainEventPublisher
-from fixbackend.domain_events.events import WorkspaceCreated
+from fixbackend.domain_events.events import UserJoinedWorkspace, WorkspaceCreated
class WorkspaceRepository(ABC):
@@ -40,7 +40,9 @@ async def create_workspace(self, name: str, slug: str, owner: User) -> Workspace
raise NotImplementedError
@abstractmethod
- async def get_workspace(self, workspace_id: WorkspaceId) -> Optional[Workspace]:
+ async def get_workspace(
+ self, workspace_id: WorkspaceId, *, session: Optional[AsyncSession] = None
+ ) -> Optional[Workspace]:
"""Get a workspace."""
raise NotImplementedError
@@ -64,31 +66,6 @@ async def remove_from_workspace(self, workspace_id: WorkspaceId, user_id: UserId
"""Remove a user from a workspace."""
raise NotImplementedError
- @abstractmethod
- async def create_invitation(self, workspace_id: WorkspaceId, user_id: UserId) -> WorkspaceInvite:
- """Create an invite for a workspace."""
- raise NotImplementedError
-
- @abstractmethod
- async def get_invitation(self, invitation_id: uuid.UUID) -> Optional[WorkspaceInvite]:
- """Get an invitation by ID."""
- raise NotImplementedError
-
- @abstractmethod
- async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvite]:
- """List all invitations for a workspace."""
- raise NotImplementedError
-
- @abstractmethod
- async def accept_invitation(self, invitation_id: uuid.UUID) -> None:
- """Accept an invitation to a workspace."""
- raise NotImplementedError
-
- @abstractmethod
- async def delete_invitation(self, invitation_id: uuid.UUID) -> None:
- """Delete an invitation."""
- raise NotImplementedError
-
class WorkspaceRepositoryImpl(WorkspaceRepository):
def __init__(
@@ -96,10 +73,12 @@ def __init__(
session_maker: AsyncSessionMaker,
graph_db_access_manager: GraphDatabaseAccessManager,
domain_event_sender: DomainEventPublisher,
+ pubsub_publisher: RedisPubSubPublisher,
) -> None:
self.session_maker = session_maker
self.graph_db_access_manager = graph_db_access_manager
self.domain_event_sender = domain_event_sender
+ self.pubsub_publisher = pubsub_publisher
async def create_workspace(self, name: str, slug: str, owner: User) -> Workspace:
async with self.session_maker() as session:
@@ -123,13 +102,21 @@ async def create_workspace(self, name: str, slug: str, owner: User) -> Workspace
org = results.unique().scalar_one()
return org.to_model()
- async def get_workspace(self, workspace_id: WorkspaceId) -> Optional[Workspace]:
- async with self.session_maker() as session:
+ async def get_workspace(
+ self, workspace_id: WorkspaceId, *, session: Optional[AsyncSession] = None
+ ) -> Optional[Workspace]:
+ async def get_ws(session: AsyncSession) -> Optional[Workspace]:
statement = select(orm.Organization).where(orm.Organization.id == workspace_id)
results = await session.execute(statement)
org = results.unique().scalar_one_or_none()
return org.to_model() if org else None
+ if session is not None:
+ return await get_ws(session)
+ else:
+ async with self.session_maker() as session:
+ return await get_ws(session)
+
async def update_workspace(self, workspace_id: WorkspaceId, name: str, generate_external_id: bool) -> Workspace:
"""Update a workspace."""
async with self.session_maker() as session:
@@ -148,7 +135,16 @@ async def update_workspace(self, workspace_id: WorkspaceId, name: str, generate_
async def list_workspaces(self, user_id: UserId) -> Sequence[Workspace]:
async with self.session_maker() as session:
statement = (
- select(orm.Organization).join(orm.OrganizationOwners).where(orm.OrganizationOwners.user_id == user_id)
+ select(orm.Organization)
+ .join(
+ orm.OrganizationOwners, orm.Organization.id == orm.OrganizationOwners.organization_id, isouter=True
+ )
+ .join(
+ orm.OrganizationMembers,
+ orm.Organization.id == orm.OrganizationMembers.organization_id,
+ isouter=True,
+ )
+ .where(or_(orm.OrganizationOwners.user_id == user_id, orm.OrganizationMembers.user_id == user_id))
)
results = await session.execute(statement)
orgs = results.unique().scalars().all()
@@ -168,6 +164,10 @@ async def add_to_workspace(self, workspace_id: WorkspaceId, user_id: UserId) ->
except IntegrityError:
raise ValueError("Can't add user to workspace.")
+ event = UserJoinedWorkspace(workspace_id, user_id)
+ await self.domain_event_sender.publish(event)
+ await self.pubsub_publisher.publish(event.kind, event.to_json(), f"tenant-events::{event.workspace_id}")
+
async def remove_from_workspace(self, workspace_id: WorkspaceId, user_id: UserId) -> None:
async with self.session_maker() as session:
membership = await session.get(orm.OrganizationMembers, (workspace_id, user_id))
@@ -176,70 +176,6 @@ async def remove_from_workspace(self, workspace_id: WorkspaceId, user_id: UserId
await session.delete(membership)
await session.commit()
- async def create_invitation(self, workspace_id: WorkspaceId, user_id: UserId) -> WorkspaceInvite:
- async with self.session_maker() as session:
- user = await session.get(auth_orm.User, user_id)
- organization = await self.get_workspace(workspace_id)
-
- if user is None or organization is None:
- raise ValueError(f"User {user_id} or organization {workspace_id} does not exist.")
-
- if user.id in [owner for owner in organization.owners]:
- raise ValueError(f"User {user_id} is already an owner of workspace {workspace_id}")
-
- if user.id in [member for member in organization.members]:
- raise ValueError(f"User {user_id} is already a member of workspace {workspace_id}")
-
- invite = orm.OrganizationInvite(
- user_id=user_id, organization_id=workspace_id, expires_at=datetime.utcnow() + timedelta(days=7)
- )
- session.add(invite)
- await session.commit()
- await session.refresh(invite)
- return invite.to_model()
-
- async def get_invitation(self, invitation_id: uuid.UUID) -> Optional[WorkspaceInvite]:
- async with self.session_maker() as session:
- statement = (
- select(orm.OrganizationInvite)
- .where(orm.OrganizationInvite.id == invitation_id)
- .options(selectinload(orm.OrganizationInvite.user))
- )
- results = await session.execute(statement)
- invite = results.unique().scalar_one_or_none()
- return invite.to_model() if invite else None
-
- async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvite]:
- async with self.session_maker() as session:
- statement = (
- select(orm.OrganizationInvite)
- .where(orm.OrganizationInvite.organization_id == workspace_id)
- .options(selectinload(orm.OrganizationInvite.user), selectinload(orm.OrganizationInvite.organization))
- )
- results = await session.execute(statement)
- invites = results.scalars().all()
- return [invite.to_model() for invite in invites]
-
- async def accept_invitation(self, invitation_id: uuid.UUID) -> None:
- async with self.session_maker() as session:
- invite = await session.get(orm.OrganizationInvite, invitation_id)
- if invite is None:
- raise ValueError(f"Invitation {invitation_id} does not exist.")
- if invite.expires_at < datetime.utcnow():
- raise ValueError(f"Invitation {invitation_id} has expired.")
- membership = orm.OrganizationMembers(user_id=invite.user_id, organization_id=invite.organization_id)
- session.add(membership)
- await session.delete(invite)
- await session.commit()
-
- async def delete_invitation(self, invitation_id: uuid.UUID) -> None:
- async with self.session_maker() as session:
- invite = await session.get(orm.OrganizationInvite, invitation_id)
- if invite is None:
- raise ValueError(f"Invitation {invitation_id} does not exist.")
- await session.delete(invite)
- await session.commit()
-
async def get_workspace_repository(fix: FixDependency) -> WorkspaceRepository:
return fix.service(ServiceNames.workspace_repo, WorkspaceRepositoryImpl)
diff --git a/fixbackend/workspaces/router.py b/fixbackend/workspaces/router.py
index 28a81e35..38a9176b 100644
--- a/fixbackend/workspaces/router.py
+++ b/fixbackend/workspaces/router.py
@@ -12,32 +12,38 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import List
-from uuid import UUID
+from typing import List, Optional
-from fastapi import APIRouter, HTTPException
-from pydantic import EmailStr
+from fastapi import APIRouter, HTTPException, Request, Response
+from fastapi.responses import RedirectResponse
from sqlalchemy.exc import IntegrityError
from fixbackend.auth.depedencies import AuthenticatedUser
-from fixbackend.auth.user_manager import UserManagerDependency
+from fixbackend.auth.models import User
+from fixbackend.auth.user_repository import UserRepositoryDependency
from fixbackend.config import ConfigDependency
-from fixbackend.ids import WorkspaceId
+from fixbackend.ids import InvitationId, UserId, WorkspaceId
+from fixbackend.workspaces.invitation_service import InvitationServiceDependency
from fixbackend.workspaces.repository import WorkspaceRepositoryDependency
from fixbackend.workspaces.dependencies import UserWorkspaceDependency
from fixbackend.workspaces.schemas import (
ExternalIdRead,
+ UserInvite,
WorkspaceCreate,
WorkspaceInviteRead,
WorkspaceRead,
WorkspaceSettingsRead,
WorkspaceSettingsUpdate,
+ WorkspaceUserRead,
)
+import asyncio
def workspaces_router() -> APIRouter:
router = APIRouter()
+ ACCEPT_INVITE_ROUTE_NAME = "accept_invitation"
+
@router.get("/")
async def list_workspaces(
user: AuthenticatedUser, workspace_repository: WorkspaceRepositoryDependency
@@ -56,10 +62,10 @@ async def get_workspace(
"""Get a workspace."""
org = await workspace_repository.get_workspace(workspace_id)
if org is None:
- raise HTTPException(status_code=404, detail="Organization not found")
+ raise HTTPException(status_code=404, detail="Workspace not found")
if user.id not in org.all_users():
- raise HTTPException(status_code=403, detail="You are not an owner of this organization")
+ raise HTTPException(status_code=403, detail="You are not a member of this workspace")
return WorkspaceRead.from_model(org)
@@ -103,71 +109,72 @@ async def create_workspace(
@router.get("/{workspace_id}/invites/")
async def list_invites(
workspace: UserWorkspaceDependency,
- workspace_repository: WorkspaceRepositoryDependency,
+ invitation_service: InvitationServiceDependency,
) -> List[WorkspaceInviteRead]:
- invites = await workspace_repository.list_invitations(workspace_id=workspace.id)
+ invites = await invitation_service.list_invitations(workspace_id=workspace.id)
- return [
- WorkspaceInviteRead(
- organization_slug=workspace.slug,
- user_id=invite.user_id,
- expires_at=invite.expires_at,
- )
- for invite in invites
- ]
+ return [WorkspaceInviteRead.from_model(invite, workspace) for invite in invites]
+
+ @router.get("/{workspace_id}/users/")
+ async def list_users(
+ workspace: UserWorkspaceDependency,
+ user_repository: UserRepositoryDependency,
+ ) -> List[WorkspaceUserRead]:
+ user_ids = workspace.all_users()
+ users: List[Optional[User]] = await asyncio.gather(*[user_repository.get(user_id) for user_id in user_ids])
+ return [WorkspaceUserRead.from_model(user) for user in users if user]
@router.post("/{workspace_id}/invites/")
async def invite_to_organization(
workspace: UserWorkspaceDependency,
- user_email: EmailStr,
- workspace_repository: WorkspaceRepositoryDependency,
- user_manager: UserManagerDependency,
+ user: AuthenticatedUser,
+ user_invite: UserInvite,
+ invitation_service: InvitationServiceDependency,
+ request: Request,
) -> WorkspaceInviteRead:
"""Invite a user to the workspace."""
- user = await user_manager.get_by_email(user_email)
- if user is None:
- raise HTTPException(status_code=404, detail="User not found")
-
- invite = await workspace_repository.create_invitation(workspace_id=workspace.id, user_id=user.id)
+ accept_invite_url = str(request.url_for(ACCEPT_INVITE_ROUTE_NAME, workspace_id=workspace.id))
- return WorkspaceInviteRead(
- organization_slug=workspace.slug,
- user_id=user.id,
- expires_at=invite.expires_at,
+ invite, _ = await invitation_service.invite_user(
+ workspace_id=workspace.id,
+ inviter=user,
+ invitee_email=user_invite.email,
+ accept_invite_base_url=accept_invite_url,
)
+ return WorkspaceInviteRead.from_model(invite, workspace)
+
+ @router.delete("/{workspace_id}/users/{user_id}/")
+ async def remove_user(
+ workspace: UserWorkspaceDependency,
+ user_id: UserId,
+ workspace_repository: WorkspaceRepositoryDependency,
+ user_repository: UserRepositoryDependency,
+ ) -> None:
+ """Delete a user from the workspace."""
+ user = await user_repository.get(user_id)
+ if user is None:
+ raise HTTPException(status_code=404, detail="User not found")
+ await workspace_repository.remove_from_workspace(workspace_id=workspace.id, user_id=user.id)
+
@router.delete("/{workspace_id}/invites/{invite_id}")
async def delete_invite(
workspace: UserWorkspaceDependency,
- invite_id: UUID,
- workspace_repository: WorkspaceRepositoryDependency,
+ invite_id: InvitationId,
+ invitation_service: InvitationServiceDependency,
) -> None:
"""Delete invite."""
- await workspace_repository.delete_invitation(invite_id)
+ await invitation_service.revoke_invitation(invite_id)
- @router.get("{workspace_id}/invites/{invite_id}/accept")
+ @router.get("{workspace_id}/accept_invite", name=ACCEPT_INVITE_ROUTE_NAME)
async def accept_invitation(
- workspace_id: WorkspaceId,
- invite_id: UUID,
- user: AuthenticatedUser,
- workspace_repository: WorkspaceRepositoryDependency,
- ) -> None:
+ token: str, invitation_service: InvitationServiceDependency, request: Request
+ ) -> Response:
"""Accept an invitation to the workspace."""
- org = await workspace_repository.get_workspace(workspace_id)
- if org is None:
- raise HTTPException(status_code=404, detail="Organization not found")
-
- invite = await workspace_repository.get_invitation(invite_id)
- if invite is None:
- raise HTTPException(status_code=404, detail="Invitation not found")
-
- if user.id != invite.user_id:
- raise HTTPException(status_code=403, detail="You can only accept invitations for your own account")
-
- await workspace_repository.accept_invitation(invite_id)
-
- return None
+ invitation = await invitation_service.accept_invitation(token)
+ url = request.base_url.replace_query_params(message="invitation-accepted", workspace_id=invitation.workspace_id)
+ return RedirectResponse(url)
@router.get("/{workspace_id}/cf_url")
async def get_cf_url(
diff --git a/fixbackend/workspaces/schemas.py b/fixbackend/workspaces/schemas.py
index e6bb8ade..dd9122e3 100644
--- a/fixbackend/workspaces/schemas.py
+++ b/fixbackend/workspaces/schemas.py
@@ -13,12 +13,13 @@
# along with this program. If not, see .
from datetime import datetime
-from typing import List
+from typing import List, Optional
+from fixbackend.auth.models import User
from fixbackend.ids import WorkspaceId, UserId, ExternalId
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, EmailStr, Field
-from fixbackend.workspaces.models import Workspace
+from fixbackend.workspaces.models import Workspace, WorkspaceInvitation
class WorkspaceRead(BaseModel):
@@ -104,16 +105,28 @@ class WorkspaceCreate(BaseModel):
class WorkspaceInviteRead(BaseModel):
- organization_slug: str = Field(description="The slug of the workspace to invite the user to")
- user_id: UserId = Field(description="The id of the user to invite")
+ workspace_id: WorkspaceId = Field(description="The unique identifier of the workspace to invite the user to")
+ workspace_name: str = Field(description="The name of the workspace to invite the user to")
+ user_email: str = Field(description="The email of the user to invite")
expires_at: datetime = Field(description="The time at which the invitation expires")
+ accepted_at: Optional[datetime] = Field(description="The time at which the invitation was accepted, if any")
+
+ @staticmethod
+ def from_model(invite: WorkspaceInvitation, workspace: Workspace) -> "WorkspaceInviteRead":
+ return WorkspaceInviteRead(
+ workspace_id=invite.workspace_id,
+ workspace_name=workspace.name,
+ user_email=invite.email,
+ expires_at=invite.expires_at,
+ accepted_at=invite.accepted_at,
+ )
model_config = {
"json_schema_extra": {
"examples": [
{
"organization_slug": "my-org",
- "user_id": "00000000-0000-0000-0000-000000000000",
+ "user_email": "foo@bar.com",
"expires_at": "2021-01-01T00:00:00Z",
}
]
@@ -133,3 +146,54 @@ class ExternalIdRead(BaseModel):
]
}
}
+
+
+class UserInvite(BaseModel):
+ name: str = Field(description="The name of the user")
+ email: EmailStr = Field(description="The email of the user")
+ roles: List[str] = Field(description="The role of the user")
+
+ model_config = {
+ "json_schema_extra": {
+ "examples": [
+ {
+ "name": "Foo Bar",
+ "email": "foo@example.com",
+ "roles": ["admin"],
+ }
+ ]
+ }
+ }
+
+
+class WorkspaceUserRead(BaseModel):
+ id: UserId = Field(description="The user's unique identifier")
+ sources: List[str] = Field(description="Where the user is found")
+ name: str = Field(description="The user's name")
+ email: str = Field(description="The user's email")
+ roles: List[str] = Field(description="The user's roles")
+ last_login: Optional[datetime] = Field(description="The user's last login time, if any")
+
+ @staticmethod
+ def from_model(user: User) -> "WorkspaceUserRead":
+ return WorkspaceUserRead(
+ id=user.id,
+ sources=[],
+ name=user.email,
+ email=user.email,
+ roles=[],
+ last_login=None,
+ )
+
+ model_config = {
+ "json_schema_extra": {
+ "examples": [
+ {
+ "sources": ["organization"],
+ "name": "Foo Bar",
+ "email": "foo@example.com",
+ "roles": ["admin"],
+ }
+ ]
+ }
+ }
diff --git a/migrations/versions/2023-12-07T12:14:35Z_update_workspace_invites.py b/migrations/versions/2023-12-07T12:14:35Z_update_workspace_invites.py
new file mode 100644
index 00000000..ebd51916
--- /dev/null
+++ b/migrations/versions/2023-12-07T12:14:35Z_update_workspace_invites.py
@@ -0,0 +1,34 @@
+"""add created_at/updated_at to cloud_account
+
+Revision ID: 1ccf5fc88e67
+Revises: d294f6e4b5dc
+Create Date: 2023-12-07 12:14:35.061355+00:00
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "1ccf5fc88e67"
+down_revision: Union[str, None] = "d294f6e4b5dc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+
+ # user_email is not nullable, so we need to drop all existing invites
+ op.execute("TRUNCATE TABLE organization_invite")
+
+ op.add_column("organization_invite", sa.Column("user_email", sa.String(length=320), nullable=False))
+ op.add_column("organization_invite", sa.Column("accepted_at", sa.DateTime(timezone=True), nullable=True))
+ op.add_column(
+ "organization_invite", sa.Column("version_id", sa.Integer(), nullable=False, server_default=sa.text("0"))
+ )
+ op.create_unique_constraint(None, "organization_invite", ["user_email"])
+ op.drop_constraint("organization_invite_ibfk_2", "organization_invite", type_="foreignkey")
+ op.drop_column("organization_invite", "user_id")
+ # ### end Alembic commands ###
diff --git a/noxfile.py b/noxfile.py
index 274a3f87..ec53a9f7 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -36,7 +36,7 @@ def flake8(session: Session) -> None:
@session(python=["3.11"]) # type: ignore
def test(session: Session) -> None:
args = session.posargs or ["--cov"]
- session.run_always("poetry", "install", external=True)
+ session.run_always("poetry", "install", "--quiet", external=True)
session.run("pytest", *args)
@@ -44,7 +44,7 @@ def test(session: Session) -> None:
def mypy(session: Session) -> None:
opts = ["--strict"]
args = session.posargs or [] + opts + locations
- session.run_always("poetry", "install", external=True)
+ session.run_always("poetry", "install", "--quiet", external=True)
session.install("mypy", ".")
print(args)
session.run("mypy", *args)
diff --git a/tests/fixbackend/auth/jwt_strategy_test.py b/tests/fixbackend/auth/jwt_strategy_test.py
index 9c567645..b1f769df 100644
--- a/tests/fixbackend/auth/jwt_strategy_test.py
+++ b/tests/fixbackend/auth/jwt_strategy_test.py
@@ -12,21 +12,21 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from fastapi import Request
-
import pytest
+from cryptography.hazmat.primitives.asymmetric import rsa
+from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
-from fixbackend.auth.db import get_user_repository
-from fixbackend.auth.models import User
-from fixbackend.workspaces.repository import WorkspaceRepository
from fixbackend.auth.auth_backend import FixJWTStrategy
-from cryptography.hazmat.primitives.asymmetric import rsa
+from fixbackend.auth.models import User
from fixbackend.auth.user_manager import UserManager
-from fixbackend.config import Config
+from fixbackend.auth.user_repository import get_user_repository
from fixbackend.auth.user_verifier import UserVerifier
-from fixbackend.domain_events.publisher import DomainEventPublisher
+from fixbackend.config import Config
from fixbackend.domain_events.events import Event
+from fixbackend.domain_events.publisher import DomainEventPublisher
+from fixbackend.workspaces.invitation_repository import InvitationRepository
+from fixbackend.workspaces.repository import WorkspaceRepository
@pytest.fixture
@@ -54,7 +54,11 @@ async def publish(self, event: Event) -> None:
@pytest.mark.asyncio
async def test_token_validation(
- workspace_repository: WorkspaceRepository, user: User, default_config: Config, session: AsyncSession
+ workspace_repository: WorkspaceRepository,
+ user: User,
+ default_config: Config,
+ session: AsyncSession,
+ invitation_repository: InvitationRepository,
) -> None:
private_key_1 = rsa.generate_private_key(65537, 2048)
private_key_2 = rsa.generate_private_key(65537, 2048)
@@ -64,7 +68,13 @@ async def test_token_validation(
user_repo = await anext(get_user_repository(session))
user_manager = UserManager(
- default_config, user_repo, None, UserVerifierMock(), workspace_repository, DomainEventSenderMock()
+ default_config,
+ user_repo,
+ None,
+ UserVerifierMock(),
+ workspace_repository,
+ DomainEventSenderMock(),
+ invitation_repository,
)
token1 = await strategy1.write_token(user)
diff --git a/tests/fixbackend/auth/router_test.py b/tests/fixbackend/auth/router_test.py
index 7ecb3fa7..95068e35 100644
--- a/tests/fixbackend/auth/router_test.py
+++ b/tests/fixbackend/auth/router_test.py
@@ -12,7 +12,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Sequence, Tuple
import pytest
from fastapi import Request, FastAPI
@@ -24,6 +24,9 @@
from fixbackend.domain_events.publisher import DomainEventPublisher
from fixbackend.domain_events.dependencies import get_domain_event_publisher
from fixbackend.domain_events.events import Event, UserRegistered, WorkspaceCreated
+from fixbackend.ids import InvitationId, WorkspaceId
+from fixbackend.workspaces.invitation_repository import InvitationRepository, get_invitation_repository
+from fixbackend.workspaces.models import WorkspaceInvitation
from tests.fixbackend.conftest import InMemoryDomainEventPublisher
@@ -44,13 +47,44 @@ async def publish(self, event: Event) -> None:
return self.events.append(event)
+class InMemoryInvitationRepo(InvitationRepository):
+ async def get_invitation_by_email(self, email: str) -> Optional[WorkspaceInvitation]:
+ return None
+
+ async def create_invitation(self, workspace_id: WorkspaceId, email: str) -> WorkspaceInvitation:
+ """Create an invite for a workspace."""
+ raise NotImplementedError
+
+ async def get_invitation(self, invitation_id: InvitationId) -> Optional[WorkspaceInvitation]:
+ """Get an invitation by ID."""
+ raise NotImplementedError
+
+ async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]:
+ """List all invitations for a workspace."""
+ raise NotImplementedError
+
+ async def update_invitation(
+ self,
+ invitation_id: InvitationId,
+ update_fn: Callable[[WorkspaceInvitation], WorkspaceInvitation],
+ ) -> WorkspaceInvitation:
+ """Update an invitation."""
+ raise NotImplementedError
+
+ async def delete_invitation(self, invitation_id: InvitationId) -> None:
+ """Delete an invitation."""
+ raise NotImplementedError
+
+
@pytest.mark.asyncio
async def test_registration_flow(
api_client: AsyncClient, fast_api: FastAPI, domain_event_sender: InMemoryDomainEventPublisher
) -> None:
verifier = InMemoryVerifier()
+ invitation_repo = InMemoryInvitationRepo()
fast_api.dependency_overrides[get_user_verifier] = lambda: verifier
fast_api.dependency_overrides[get_domain_event_publisher] = lambda: domain_event_sender
+ fast_api.dependency_overrides[get_invitation_repository] = lambda: invitation_repo
registration_json = {
"email": "user@example.com",
diff --git a/tests/fixbackend/cloud_accounts/service_test.py b/tests/fixbackend/cloud_accounts/service_test.py
index 4650178d..54b88be6 100644
--- a/tests/fixbackend/cloud_accounts/service_test.py
+++ b/tests/fixbackend/cloud_accounts/service_test.py
@@ -24,6 +24,7 @@
from fixcloudutils.types import Json
from httpx import AsyncClient, Request, Response
from redis.asyncio import Redis
+from sqlalchemy.ext.asyncio import AsyncSession
from fixbackend.cloud_accounts.account_setup import AssumeRoleResult, AssumeRoleResults, AwsAccountSetupHelper
from fixbackend.cloud_accounts.models import AwsCloudAccess, CloudAccount, CloudAccountStates
@@ -120,12 +121,14 @@ async def list_all_discovered_accounts(self) -> List[CloudAccount]:
)
-class OrganizationServiceMock(WorkspaceRepositoryImpl):
+class WorkspaceServiceMock(WorkspaceRepositoryImpl):
# noinspection PyMissingConstructor
def __init__(self) -> None:
pass
- async def get_workspace(self, workspace_id: WorkspaceId, with_users: bool = False) -> Workspace | None:
+ async def get_workspace(
+ self, workspace_id: WorkspaceId, with_users: bool = False, *, session: Optional[AsyncSession] = None
+ ) -> Workspace | None:
if workspace_id != test_workspace_id:
return None
return organization
@@ -194,8 +197,8 @@ def repository() -> CloudAccountRepositoryMock:
@pytest.fixture
-def organization_repository() -> OrganizationServiceMock:
- return OrganizationServiceMock()
+def organization_repository() -> WorkspaceServiceMock:
+ return WorkspaceServiceMock()
@pytest.fixture
@@ -215,7 +218,7 @@ def account_setup_helper() -> AwsAccountSetupHelperMock:
@pytest.fixture
def service(
- organization_repository: OrganizationServiceMock,
+ organization_repository: WorkspaceServiceMock,
repository: CloudAccountRepositoryMock,
pubsub_publisher: RedisPubSubPublisherMock,
domain_sender: DomainEventSenderMock,
diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py
index 461d9519..7d107503 100644
--- a/tests/fixbackend/conftest.py
+++ b/tests/fixbackend/conftest.py
@@ -20,6 +20,7 @@
from datetime import datetime, timezone
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Sequence, Tuple, Optional
from unittest.mock import patch
+from attrs import frozen
import pytest
from alembic.command import upgrade as alembic_upgrade
@@ -35,7 +36,7 @@
from sqlalchemy_utils import create_database, database_exists, drop_database
from fixbackend.app import fast_api_app
-from fixbackend.auth.db import get_user_repository
+from fixbackend.auth.user_repository import get_user_repository, UserRepository
from fixbackend.auth.models import User
from fixbackend.cloud_accounts.repository import CloudAccountRepository, CloudAccountRepositoryImpl
from fixbackend.collect.collect_queue import RedisCollectQueue
@@ -58,8 +59,10 @@
from fixbackend.subscription.subscription_repository import SubscriptionRepository
from fixbackend.types import AsyncSessionMaker
from fixbackend.utils import start_of_next_month, uid
+from fixbackend.workspaces.invitation_repository import InvitationRepository, InvitationRepositoryImpl
from fixbackend.workspaces.models import Workspace
from fixbackend.workspaces.repository import WorkspaceRepository, WorkspaceRepositoryImpl
+from fixcloudutils.redis.pub_sub import RedisPubSubPublisher
DATABASE_URL = "mysql+aiomysql://root@127.0.0.1:3306/fixbackend-testdb"
# only used to create/drop the database
@@ -211,15 +214,21 @@ def graph_database_access_manager(
return GraphDatabaseAccessManager(default_config, async_session_maker)
+@pytest.fixture
+async def user_repository(session: AsyncSession) -> UserRepository:
+ repo = await anext(get_user_repository(session))
+ return repo
+
+
@pytest.fixture
async def user(session: AsyncSession) -> User:
- user_db = await anext(get_user_repository(session))
+ user_repository = await anext(get_user_repository(session))
user_dict = {
"email": "foo@bar.com",
"hashed_password": "notreallyhashed",
"is_verified": True,
}
- return await user_db.create(user_dict)
+ return await user_repository.create(user_dict)
@pytest.fixture
@@ -428,13 +437,44 @@ async def domain_event_sender() -> InMemoryDomainEventPublisher:
return InMemoryDomainEventPublisher()
+@frozen
+class PubSubMessage:
+ kind: str
+ message: Json
+ channel: Optional[str]
+
+
+class InMemoryRedisPubSubPublisher(RedisPubSubPublisher):
+ def __init__(self) -> None:
+ self.events: List[PubSubMessage] = []
+
+ async def publish(self, kind: str, message: Json, channel: Optional[str] = None) -> None:
+ self.events.append(PubSubMessage(kind, message, channel))
+
+
+@pytest.fixture
+def pubsub_publisher() -> InMemoryRedisPubSubPublisher:
+ return InMemoryRedisPubSubPublisher()
+
+
@pytest.fixture
async def workspace_repository(
async_session_maker: AsyncSessionMaker,
graph_database_access_manager: GraphDatabaseAccessManager,
domain_event_sender: DomainEventPublisher,
+ pubsub_publisher: InMemoryRedisPubSubPublisher,
) -> WorkspaceRepository:
- return WorkspaceRepositoryImpl(async_session_maker, graph_database_access_manager, domain_event_sender)
+ return WorkspaceRepositoryImpl(
+ async_session_maker, graph_database_access_manager, domain_event_sender, pubsub_publisher
+ )
+
+
+@pytest.fixture
+async def invitation_repository(
+ async_session_maker: AsyncSessionMaker,
+ workspace_repository: WorkspaceRepository,
+) -> InvitationRepository:
+ return InvitationRepositoryImpl(async_session_maker, workspace_repository)
@pytest.fixture
diff --git a/tests/fixbackend/workspaces/invitation_repository_test.py b/tests/fixbackend/workspaces/invitation_repository_test.py
new file mode 100644
index 00000000..fda80a03
--- /dev/null
+++ b/tests/fixbackend/workspaces/invitation_repository_test.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+
+from attrs import evolve
+from fixcloudutils.util import utc
+import pytest
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from fixbackend.auth.user_repository import get_user_repository
+from fixbackend.auth.models import User
+from fixbackend.workspaces.invitation_repository import InvitationRepository
+from fixbackend.workspaces.repository import WorkspaceRepository
+
+
+async def create_user(email: str, session: AsyncSession) -> User:
+ user_db = await anext(get_user_repository(session))
+ user_dict = {
+ "email": email,
+ "hashed_password": "notreallyhashed",
+ "is_verified": True,
+ }
+ user = await user_db.create(user_dict)
+
+ return user
+
+
+@pytest.mark.asyncio
+async def test_create_invitation(
+ workspace_repository: WorkspaceRepository,
+ invitation_repository: InvitationRepository,
+ session: AsyncSession,
+) -> None:
+ user = await create_user("foo@bar.com", session)
+ organization = await workspace_repository.create_workspace(
+ name="Test Organization", slug="test-organization", owner=user
+ )
+ org_id = organization.id
+
+ user2 = await create_user("123foo@bar.com", session)
+
+ invitation = await invitation_repository.create_invitation(workspace_id=org_id, email=user2.email)
+ assert invitation.workspace_id == org_id
+ assert invitation.email == user2.email
+
+ # create invitation is idempotent
+ invitation2 = await invitation_repository.create_invitation(workspace_id=org_id, email=user2.email)
+ assert invitation2 == invitation
+
+ external_email = "i_do_not_exist@bar.com"
+ non_user_invitation = await invitation_repository.create_invitation(workspace_id=org_id, email=external_email)
+ assert non_user_invitation.workspace_id == org_id
+ assert non_user_invitation.email == external_email
+
+
+@pytest.mark.asyncio
+async def test_list_invitations(
+ workspace_repository: WorkspaceRepository,
+ invitation_repository: InvitationRepository,
+ session: AsyncSession,
+) -> None:
+ user = await create_user("foo@bar.com", session)
+ workspace = await workspace_repository.create_workspace(
+ name="Test Organization", slug="test-organization", owner=user
+ )
+
+ user_db = await anext(get_user_repository(session))
+ user_dict = {
+ "email": "bar@bar.com",
+ "hashed_password": "notreallyhashed",
+ "is_verified": True,
+ }
+ new_user = await user_db.create(user_dict)
+
+ invitation = await invitation_repository.create_invitation(workspace_id=workspace.id, email=new_user.email)
+
+ # list the invitations
+ invitations = await invitation_repository.list_invitations(workspace_id=workspace.id)
+ assert len(invitations) == 1
+ assert invitations[0] == invitation
+
+
+@pytest.mark.asyncio
+async def test_get_invitation(
+ workspace_repository: WorkspaceRepository,
+ invitation_repository: InvitationRepository,
+ session: AsyncSession,
+) -> None:
+ user = await create_user("foo@bar.com", session)
+ workspace = await workspace_repository.create_workspace(
+ name="Test Organization", slug="test-organization", owner=user
+ )
+ user_db = await anext(get_user_repository(session))
+ user_dict = {
+ "email": "bar@bar.com",
+ "hashed_password": "notreallyhashed",
+ "is_verified": True,
+ }
+ new_user = await user_db.create(user_dict)
+
+ invitation = await invitation_repository.create_invitation(workspace_id=workspace.id, email=new_user.email)
+
+ stored_invitation = await invitation_repository.get_invitation(invitation_id=invitation.id)
+ assert stored_invitation == invitation
+
+
+@pytest.mark.asyncio
+async def test_get_invitation_by_email(
+ workspace_repository: WorkspaceRepository,
+ invitation_repository: InvitationRepository,
+ session: AsyncSession,
+) -> None:
+ user = await create_user("foo@bar.com", session)
+ workspace = await workspace_repository.create_workspace(
+ name="Test Organization", slug="test-organization", owner=user
+ )
+ user_db = await anext(get_user_repository(session))
+ user_dict = {
+ "email": "bar@bar.com",
+ "hashed_password": "notreallyhashed",
+ "is_verified": True,
+ }
+ new_user = await user_db.create(user_dict)
+
+ invitation = await invitation_repository.create_invitation(workspace_id=workspace.id, email=new_user.email)
+
+ stored_invitation = await invitation_repository.get_invitation_by_email(email=new_user.email)
+ assert stored_invitation == invitation
+
+
+@pytest.mark.asyncio
+async def test_update_invitation(
+ workspace_repository: WorkspaceRepository,
+ invitation_repository: InvitationRepository,
+ session: AsyncSession,
+) -> None:
+ user = await create_user("foo@bar.com", session)
+ workspace = await workspace_repository.create_workspace(
+ name="Test Organization", slug="test-organization", owner=user
+ )
+ user_db = await anext(get_user_repository(session))
+ user_dict = {
+ "email": "bar@bar.com",
+ "hashed_password": "notreallyhashed",
+ "is_verified": True,
+ }
+ new_user = await user_db.create(user_dict)
+
+ invitation = await invitation_repository.create_invitation(workspace_id=workspace.id, email=new_user.email)
+ assert invitation.accepted_at is None
+
+ now = utc()
+
+ updated = await invitation_repository.update_invitation(invitation.id, lambda i: evolve(i, accepted_at=now))
+ assert updated.accepted_at is not None
+
+
+@pytest.mark.asyncio
+async def test_delete_invitation(
+ workspace_repository: WorkspaceRepository, invitation_repository: InvitationRepository, session: AsyncSession
+) -> None:
+ user = await create_user("foo@bar.com", session)
+ organization = await workspace_repository.create_workspace(
+ name="Test Organization", slug="test-organization", owner=user
+ )
+ org_id = organization.id
+
+ user_db = await anext(get_user_repository(session))
+ user_dict = {
+ "email": "bar@bar.com",
+ "hashed_password": "notreallyhashed",
+ "is_verified": True,
+ }
+ new_user = await user_db.create(user_dict)
+
+ invitation = await invitation_repository.create_invitation(workspace_id=org_id, email=new_user.email)
+
+ # delete the invitation
+ await invitation_repository.delete_invitation(invitation_id=invitation.id)
+
+ # the invitation should not exist anymore
+ invitations = await invitation_repository.list_invitations(workspace_id=org_id)
+ assert len(invitations) == 0
diff --git a/tests/fixbackend/workspaces/invitation_service_test.py b/tests/fixbackend/workspaces/invitation_service_test.py
new file mode 100644
index 00000000..1df26e5c
--- /dev/null
+++ b/tests/fixbackend/workspaces/invitation_service_test.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+
+from typing import Optional, List
+from attrs import frozen
+import pytest
+from fixbackend.domain_events.events import InvitationAccepted, UserJoinedWorkspace
+from fixbackend.workspaces.invitation_service import InvitationService, InvitationServiceImpl
+
+
+from fixbackend.workspaces.repository import WorkspaceRepository
+from fixbackend.workspaces.invitation_repository import InvitationRepository
+from fixbackend.notification.service import NotificationService
+from fixbackend.auth.user_repository import UserRepository
+from fixbackend.config import Config
+from fixbackend.auth.models import User
+from tests.fixbackend.conftest import InMemoryDomainEventPublisher
+
+
+@frozen
+class NotificationEmail:
+ to: str
+ subject: str
+ text: str
+ html: Optional[str]
+
+
+class InMemoryNotificationService(NotificationService):
+ def __init__(self) -> None:
+ self.call_args: List[NotificationEmail] = []
+
+ async def send_email(self, *, to: str, subject: str, text: str, html: str | None) -> None:
+ self.call_args.append(NotificationEmail(to, subject, text, html))
+
+
+@pytest.fixture
+def notification_service() -> InMemoryNotificationService:
+ return InMemoryNotificationService()
+
+
+@pytest.fixture
+def service(
+ workspace_repository: WorkspaceRepository,
+ invitation_repository: InvitationRepository,
+ notification_service: NotificationService,
+ user_repository: UserRepository,
+ domain_event_sender: InMemoryDomainEventPublisher,
+ default_config: Config,
+) -> InvitationService:
+ return InvitationServiceImpl(
+ workspace_repository=workspace_repository,
+ invitation_repository=invitation_repository,
+ notification_service=notification_service,
+ user_repository=user_repository,
+ domain_events=domain_event_sender,
+ config=default_config,
+ )
+
+
+@pytest.mark.asyncio
+async def test_invite_accept_user(
+ service: InvitationService,
+ workspace_repository: WorkspaceRepository,
+ invitation_repository: InvitationRepository,
+ notification_service: InMemoryNotificationService,
+ user_repository: UserRepository,
+ domain_event_sender: InMemoryDomainEventPublisher,
+ user: User,
+) -> None:
+ workspace = await workspace_repository.create_workspace(
+ name="Test Organization", slug="test-organization", owner=user
+ )
+
+ new_user_email = "new@foo.com"
+
+ # invite new user
+ invite, _ = await service.invite_user(workspace.id, user, new_user_email, "https://example.com")
+ assert await invitation_repository.list_invitations(workspace.id) == [invite]
+
+ # idempotency
+ second_invite, _ = await service.invite_user(workspace.id, user, new_user_email, "https://example.com")
+ assert second_invite == invite
+
+ # list invitations
+ assert await invitation_repository.list_invitations(workspace.id) == [invite]
+
+ # check email
+ email = notification_service.call_args[0]
+ assert email.to == new_user_email
+ assert email.subject == f"FIX Cloud {user.email} has invited you to FIX workspace"
+ assert email.text.startswith(f"{user.email} has invited you to join the workspace {workspace.name}.")
+ assert "https://example.com?token=" in email.text
+
+ # existing user
+ existing_user = await user_repository.create(
+ {
+ "email": "existing@foo.com",
+ "hashed_password": "notreallyhashed",
+ "is_verified": True,
+ }
+ )
+ existing_invite, token = await service.invite_user(workspace.id, user, existing_user.email, "https://example.com")
+
+ # when the existinng user accepts the invite, they should be added to the workspace automatically
+ # and the invitation should be deleted
+ await service.accept_invitation(token)
+ assert list(map(lambda w: w.id, await workspace_repository.list_workspaces(existing_user.id))) == [workspace.id]
+ assert await service.list_invitations(workspace.id) == [invite]
+ assert len(domain_event_sender.events) == 3
+ assert domain_event_sender.events[1] == UserJoinedWorkspace(workspace.id, existing_user.id)
+ assert domain_event_sender.events[2] == InvitationAccepted(workspace.id, existing_user.email)
+
+ # invite can be revoked
+ await service.revoke_invitation(invite.id)
+ assert await service.list_invitations(workspace.id) == []
+
+ # invlid token is rejected
+ with pytest.raises(ValueError):
+ await service.accept_invitation("invalid token")
diff --git a/tests/fixbackend/workspaces/repository_test.py b/tests/fixbackend/workspaces/repository_test.py
index 23c7d372..04e5f870 100644
--- a/tests/fixbackend/workspaces/repository_test.py
+++ b/tests/fixbackend/workspaces/repository_test.py
@@ -17,7 +17,7 @@
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
-from fixbackend.auth.db import get_user_repository
+from fixbackend.auth.user_repository import get_user_repository
from fixbackend.auth.models import User
from fixbackend.ids import WorkspaceId, UserId
from fixbackend.workspaces.repository import WorkspaceRepository
@@ -84,7 +84,7 @@ async def test_update_workspace(workspace_repository: WorkspaceRepository, user:
@pytest.mark.asyncio
-async def test_list_organizations(workspace_repository: WorkspaceRepository, user: User) -> None:
+async def test_list_workspaces(workspace_repository: WorkspaceRepository, user: User, session: AsyncSession) -> None:
workspace1 = await workspace_repository.create_workspace(
name="Test Organization 1", slug="test-organization-1", owner=user
)
@@ -93,16 +93,22 @@ async def test_list_organizations(workspace_repository: WorkspaceRepository, use
name="Test Organization 2", slug="test-organization-2", owner=user
)
+ user_db = await anext(get_user_repository(session))
+ new_user_dict = {"email": "bar@bar.com", "hashed_password": "notreallyhashed", "is_verified": True}
+ new_user = await user_db.create(new_user_dict)
+ member_only_workspace = await workspace_repository.create_workspace(
+ name="Test Organization 3", slug="test-organization-3", owner=new_user
+ )
+ await workspace_repository.add_to_workspace(workspace_id=member_only_workspace.id, user_id=user.id)
+
# the user should be the owner of the organization
workspaces = await workspace_repository.list_workspaces(user.id)
- assert len(workspaces) == 2
- assert set([o.id for o in workspaces]) == {workspace1.id, workspace2.id}
+ assert len(workspaces) == 3
+ assert set([o.id for o in workspaces]) == {workspace1.id, workspace2.id, member_only_workspace.id}
@pytest.mark.asyncio
-async def test_add_to_organization(
- workspace_repository: WorkspaceRepository, session: AsyncSession, user: User
-) -> None:
+async def test_add_to_workspace(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None:
# add an existing user to the organization
organization = await workspace_repository.create_workspace(
name="Test Organization", slug="test-organization", owner=user
@@ -120,104 +126,11 @@ async def test_add_to_organization(
assert len(retrieved_organization.members) == 1
assert retrieved_organization.members[0] == new_user.id
+ assert retrieved_organization.owners[0] == user.id
+
# when adding a user which is already a member of the organization, nothing should happen
await workspace_repository.add_to_workspace(workspace_id=org_id, user_id=new_user_id)
# when adding a non-existing user to the organization, an exception should be raised
with pytest.raises(Exception):
await workspace_repository.add_to_workspace(workspace_id=org_id, user_id=UserId(uuid.uuid4()))
-
-
-@pytest.mark.asyncio
-async def test_create_invitation(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None:
- organization = await workspace_repository.create_workspace(
- name="Test Organization", slug="test-organization", owner=user
- )
- org_id = organization.id
-
- user_db = await anext(get_user_repository(session))
- user_dict = {
- "email": "123foo@bar.com",
- "hashed_password": "notreallyhashed",
- "is_verified": True,
- }
- new_user = await user_db.create(user_dict)
- new_user_id = new_user.id
-
- invitation = await workspace_repository.create_invitation(workspace_id=org_id, user_id=new_user.id)
- assert invitation.workspace_id == org_id
- assert invitation.user_id == new_user_id
-
-
-@pytest.mark.asyncio
-async def test_accept_invitation(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None:
- organization = await workspace_repository.create_workspace(
- name="Test Organization", slug="test-organization", owner=user
- )
- org_id = organization.id
-
- user_db = await anext(get_user_repository(session))
- user_dict = {
- "email": "123foo@bar.com",
- "hashed_password": "notreallyhashed",
- "is_verified": True,
- }
- new_user = await user_db.create(user_dict)
-
- invitation = await workspace_repository.create_invitation(workspace_id=org_id, user_id=new_user.id)
-
- # accept the invitation
- await workspace_repository.accept_invitation(invitation_id=invitation.id)
-
- retrieved_organization = await workspace_repository.get_workspace(org_id)
- assert retrieved_organization
- assert len(retrieved_organization.members) == 1
- assert retrieved_organization.members[0] == new_user.id
-
-
-@pytest.mark.asyncio
-async def test_list_invitations(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None:
- organization = await workspace_repository.create_workspace(
- name="Test Organization", slug="test-organization", owner=user
- )
- org_id = organization.id
-
- user_db = await anext(get_user_repository(session))
- user_dict = {
- "email": "bar@bar.com",
- "hashed_password": "notreallyhashed",
- "is_verified": True,
- }
- new_user = await user_db.create(user_dict)
-
- invitation = await workspace_repository.create_invitation(workspace_id=org_id, user_id=new_user.id)
-
- # list the invitations
- invitations = await workspace_repository.list_invitations(workspace_id=org_id)
- assert len(invitations) == 1
- assert invitations[0] == invitation
-
-
-@pytest.mark.asyncio
-async def test_delete_invitation(workspace_repository: WorkspaceRepository, session: AsyncSession, user: User) -> None:
- organization = await workspace_repository.create_workspace(
- name="Test Organization", slug="test-organization", owner=user
- )
- org_id = organization.id
-
- user_db = await anext(get_user_repository(session))
- user_dict = {
- "email": "bar@bar.com",
- "hashed_password": "notreallyhashed",
- "is_verified": True,
- }
- new_user = await user_db.create(user_dict)
-
- invitation = await workspace_repository.create_invitation(workspace_id=org_id, user_id=new_user.id)
-
- # delete the invitation
- await workspace_repository.delete_invitation(invitation_id=invitation.id)
-
- # the invitation should not exist anymore
- invitations = await workspace_repository.list_invitations(workspace_id=org_id)
- assert len(invitations) == 0
diff --git a/tests/fixbackend/workspaces/router_test.py b/tests/fixbackend/workspaces/router_test.py
index b76552d1..52636a57 100644
--- a/tests/fixbackend/workspaces/router_test.py
+++ b/tests/fixbackend/workspaces/router_test.py
@@ -14,8 +14,7 @@
import uuid
-from typing import AsyncIterator, Sequence
-from uuid import UUID
+from typing import AsyncIterator, Optional, Sequence
from attrs import evolve
import pytest
@@ -59,10 +58,12 @@ class WorkspaceRepositoryMock(WorkspaceRepositoryImpl):
def __init__(self) -> None:
pass
- async def get_workspace(self, workspace_id: UUID) -> Workspace | None:
+ async def get_workspace(
+ self, workspace_id: WorkspaceId, *, session: Optional[AsyncSession] = None
+ ) -> Workspace | None:
return workspace
- async def list_workspaces(self, owner_id: UUID) -> Sequence[Workspace]:
+ async def list_workspaces(self, owner_id: UserId) -> Sequence[Workspace]:
return [workspace]
async def update_workspace(self, workspace_id: WorkspaceId, name: str, generate_external_id: bool) -> Workspace: