diff --git a/fixbackend/workspaces/invitation_service.py b/fixbackend/workspaces/invitation_service.py index a9b3b20c..67b7ac0e 100644 --- a/fixbackend/workspaces/invitation_service.py +++ b/fixbackend/workspaces/invitation_service.py @@ -58,7 +58,7 @@ async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[Workspac raise NotImplementedError() @abstractmethod - async def accept_invitation(self, token: str) -> None: + async def accept_invitation(self, token: str) -> WorkspaceInvitation: """Accept an invitation to a workspace.""" raise NotImplementedError() @@ -111,7 +111,7 @@ async def invite_user( async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvitation]: return await self.invitation_repository.list_invitations(workspace_id) - async def accept_invitation(self, token: str) -> None: + async def accept_invitation(self, token: str) -> WorkspaceInvitation: try: decoded_state = decode_jwt(token, self.config.secret, [STATE_TOKEN_AUDIENCE]) except (jwt.ExpiredSignatureError, jwt.DecodeError) as ex: @@ -123,7 +123,7 @@ async def accept_invitation(self, token: str) -> None: if invitation is None: raise ValueError(f"Invitation {invitation_id} does not exist.") - await self.invitation_repository.update_invitation( + updated = await self.invitation_repository.update_invitation( invitation_id, lambda invite: evolve(invite, accepted_at=utc()) ) @@ -132,6 +132,8 @@ async def accept_invitation(self, token: str) -> None: await self.workspace_repository.add_to_workspace(invitation.workspace_id, user.id) await self.invitation_repository.delete_invitation(invitation_id) + return updated + async def revoke_invitation(self, invitation_id: InvitationId) -> None: await self.invitation_repository.delete_invitation(invitation_id) diff --git a/fixbackend/workspaces/router.py b/fixbackend/workspaces/router.py index 755a3581..469c81e5 100644 --- a/fixbackend/workspaces/router.py +++ b/fixbackend/workspaces/router.py @@ -14,8 +14,8 @@ from typing import List -from fastapi import APIRouter, HTTPException, Request -from pydantic import EmailStr +from fastapi import APIRouter, HTTPException, Request, Response +from fastapi.responses import RedirectResponse from sqlalchemy.exc import IntegrityError from fixbackend.auth.depedencies import AuthenticatedUser @@ -26,6 +26,7 @@ from fixbackend.workspaces.dependencies import UserWorkspaceDependency from fixbackend.workspaces.schemas import ( ExternalIdRead, + InviteEmail, WorkspaceCreate, WorkspaceInviteRead, WorkspaceRead, @@ -108,20 +109,13 @@ async def list_invites( ) -> List[WorkspaceInviteRead]: invites = await invitation_service.list_invitations(workspace_id=workspace.id) - return [ - WorkspaceInviteRead( - organization_slug=workspace.slug, - user_email=invite.email, - expires_at=invite.expires_at, - ) - for invite in invites - ] + return [WorkspaceInviteRead.from_model(invite, workspace) for invite in invites] @router.post("/{workspace_id}/invites/") async def invite_to_organization( workspace: UserWorkspaceDependency, user: AuthenticatedUser, - user_email: EmailStr, + email: InviteEmail, invitation_service: InvitationServiceDependency, request: Request, ) -> WorkspaceInviteRead: @@ -130,14 +124,10 @@ async def invite_to_organization( accept_invite_url = str(request.url_for(ACCEPT_INVITE_ROUTE_NAME, workspace_id=workspace.id)) invite, _ = await invitation_service.invite_user( - workspace_id=workspace.id, inviter=user, invitee_email=user_email, accept_invite_base_url=accept_invite_url + workspace_id=workspace.id, inviter=user, invitee_email=email.email, accept_invite_base_url=accept_invite_url ) - return WorkspaceInviteRead( - organization_slug=workspace.slug, - user_email=invite.email, - expires_at=invite.expires_at, - ) + return WorkspaceInviteRead.from_model(invite, workspace) @router.delete("/{workspace_id}/invites/{invite_id}") async def delete_invite( @@ -150,11 +140,12 @@ async def delete_invite( @router.get("{workspace_id}/accept_invite", name=ACCEPT_INVITE_ROUTE_NAME) async def accept_invitation( - token: str, - invitation_service: InvitationServiceDependency, - ) -> None: + token: str, invitation_service: InvitationServiceDependency, request: Request + ) -> Response: """Accept an invitation to the workspace.""" - await invitation_service.accept_invitation(token) + invitation = await invitation_service.accept_invitation(token) + url = request.base_url.replace_query_params(message="invitation-accepted", workspace_id=invitation.workspace_id) + return RedirectResponse(url) @router.get("/{workspace_id}/cf_url") async def get_cf_url( diff --git a/fixbackend/workspaces/schemas.py b/fixbackend/workspaces/schemas.py index 14271f38..7b426e45 100644 --- a/fixbackend/workspaces/schemas.py +++ b/fixbackend/workspaces/schemas.py @@ -13,12 +13,12 @@ # along with this program. If not, see . from datetime import datetime -from typing import List +from typing import List, Optional from fixbackend.ids import WorkspaceId, UserId, ExternalId -from pydantic import BaseModel, Field +from pydantic import BaseModel, EmailStr, Field -from fixbackend.workspaces.models import Workspace +from fixbackend.workspaces.models import Workspace, WorkspaceInvitation class WorkspaceRead(BaseModel): @@ -104,9 +104,21 @@ class WorkspaceCreate(BaseModel): class WorkspaceInviteRead(BaseModel): - organization_slug: str = Field(description="The slug of the workspace to invite the user to") + workspace_id: WorkspaceId = Field(description="The unique identifier of the workspace to invite the user to") + workspace_name: str = Field(description="The name of the workspace to invite the user to") user_email: str = Field(description="The email of the user to invite") expires_at: datetime = Field(description="The time at which the invitation expires") + accepted_at: Optional[datetime] = Field(description="The time at which the invitation was accepted, if any") + + @staticmethod + def from_model(invite: WorkspaceInvitation, workspace: Workspace) -> "WorkspaceInviteRead": + return WorkspaceInviteRead( + workspace_id=invite.workspace_id, + workspace_name=workspace.name, + user_email=invite.email, + expires_at=invite.expires_at, + accepted_at=invite.accepted_at, + ) model_config = { "json_schema_extra": { @@ -133,3 +145,7 @@ class ExternalIdRead(BaseModel): ] } } + + +class InviteEmail(BaseModel): + email: EmailStr = Field(description="The email of the user to invite")