Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the invitation system #201

Merged
merged 20 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions fixbackend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from fixbackend.subscription.billing import BillingService
from fixbackend.subscription.router import subscription_router
from fixbackend.subscription.subscription_repository import SubscriptionRepository
from fixbackend.workspaces.invitation_repository import InvitationRepositoryImpl
from fixbackend.workspaces.repository import WorkspaceRepositoryImpl
from fixbackend.workspaces.router import workspaces_router
from fixbackend.domain_events.subscriber import DomainEventSubscriber
Expand Down Expand Up @@ -160,7 +161,20 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events))
workspace_repo = deps.add(
SN.workspace_repo,
WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher),
WorkspaceRepositoryImpl(
session_maker,
graph_db_access,
domain_event_publisher,
RedisPubSubPublisher(
redis=readwrite_redis,
channel="workspaces",
publisher_name="workspace_service",
),
),
)
deps.add(
SN.invitation_repository,
InvitationRepositoryImpl(session_maker, workspace_repo),
)
subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker))
deps.add(
Expand Down Expand Up @@ -245,8 +259,18 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]:
domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events))
workspace_repo = deps.add(
SN.workspace_repo,
WorkspaceRepositoryImpl(session_maker, db_access, domain_event_publisher),
WorkspaceRepositoryImpl(
session_maker,
db_access,
domain_event_publisher,
RedisPubSubPublisher(
redis=rw_redis,
channel="workspaces",
publisher_name="workspace_service",
),
),
)

cloud_accounts_redis_publisher = RedisPubSubPublisher(
redis=rw_redis,
channel="cloud_accounts",
Expand Down Expand Up @@ -306,7 +330,16 @@ async def setup_teardown_billing(_: FastAPI) -> AsyncIterator[None]:
metering_repo = deps.add(SN.metering_repo, MeteringRepository(session_maker))
workspace_repo = deps.add(
SN.workspace_repo,
WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher),
WorkspaceRepositoryImpl(
session_maker,
graph_db_access,
domain_event_publisher,
RedisPubSubPublisher(
redis=readwrite_redis,
channel="workspaces",
publisher_name="workspace_service",
),
),
)
subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker))
aws_marketplace = deps.add(
Expand Down
43 changes: 36 additions & 7 deletions fixbackend/auth/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
from fastapi_users import BaseUserManager, UUIDIDMixin
from fastapi_users.password import PasswordHelperProtocol

from fixbackend.auth.db import UserRepository, UserRepositoryDependency
from fixbackend.auth.user_repository import UserRepository, UserRepositoryDependency
from fixbackend.auth.models import User
from fixbackend.auth.user_verifier import UserVerifier, UserVerifierDependency
from fixbackend.config import Config, ConfigDependency
from fixbackend.domain_events.events import UserRegistered
from fixbackend.domain_events.publisher import DomainEventPublisher
from fixbackend.domain_events.dependencies import DomainEventPublisherDependency
from fixbackend.workspaces.invitation_repository import InvitationRepository, InvitationRepositoryDependency
from fixbackend.workspaces.models import Workspace
from fixbackend.workspaces.repository import WorkspaceRepository, WorkspaceRepositoryDependency


Expand All @@ -39,30 +41,48 @@ def __init__(
user_verifier: UserVerifier,
workspace_repository: WorkspaceRepository,
domain_events_publisher: DomainEventPublisher,
invitation_repository: InvitationRepository,
):
super().__init__(user_repository, password_helper)
self.user_verifier = user_verifier
self.reset_password_token_secret = config.secret
self.verification_token_secret = config.secret
self.workspace_repository = workspace_repository
self.domain_events_publisher = domain_events_publisher
self.invitation_repository = invitation_repository

async def on_after_register(self, user: User, request: Request | None = None) -> None:
if user.is_verified: # oauth2 users are already verified
await self.create_default_workspace(user)
await self.add_to_workspace(user)
else:
await self.request_verify(user, request)

async def on_after_request_verify(self, user: User, token: str, request: Optional[Request] = None) -> None:
await self.user_verifier.verify(user, token, request)

async def on_after_verify(self, user: User, request: Request | None = None) -> None:
await self.create_default_workspace(user)
await self.add_to_workspace(user)

async def create_default_workspace(self, user: User) -> None:
async def add_to_workspace(self, user: User) -> None:
if (
pending_invitation := await self.invitation_repository.get_invitation_by_email(user.email)
) and pending_invitation.accepted_at:
if workspace := await self.workspace_repository.get_workspace(pending_invitation.workspace_id):
await self.workspace_repository.add_to_workspace(workspace.id, user.id)
else:
# wtf?
workspace = await self.create_default_workspace(user)
await self.invitation_repository.delete_invitation(pending_invitation.id)
else:
workspace = await self.create_default_workspace(user)

await self.domain_events_publisher.publish(
UserRegistered(user_id=user.id, email=user.email, tenant_id=workspace.id)
)

async def create_default_workspace(self, user: User) -> Workspace:
org_slug = re.sub("[^a-zA-Z0-9-]", "-", user.email)
org = await self.workspace_repository.create_workspace(user.email, org_slug, user)
await self.domain_events_publisher.publish(UserRegistered(user_id=user.id, email=user.email, tenant_id=org.id))
return await self.workspace_repository.create_workspace(user.email, org_slug, user)


async def get_user_manager(
Expand All @@ -71,8 +91,17 @@ async def get_user_manager(
user_verifier: UserVerifierDependency,
workspace_repository: WorkspaceRepositoryDependency,
domain_event_publisher: DomainEventPublisherDependency,
invitation_repository: InvitationRepositoryDependency,
) -> AsyncIterator[UserManager]:
yield UserManager(config, user_repository, None, user_verifier, workspace_repository, domain_event_publisher)
yield UserManager(
config,
user_repository,
None,
user_verifier,
workspace_repository,
domain_event_publisher,
invitation_repository,
)


UserManagerDependency = Annotated[UserManager, Depends(get_user_manager)]
File renamed without changes.
63 changes: 13 additions & 50 deletions fixbackend/auth/user_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import asyncio
from abc import ABC, abstractmethod
from typing import Annotated, Optional

import boto3
from fastapi import Depends, Request

from fixbackend.auth.models import User
from fixbackend.config import Config, ConfigDependency
from fixbackend.notification.service import NotificationService, NotificationServiceDependency


class UserVerifier(ABC):
Expand All @@ -41,59 +39,24 @@ async def verify(self, user: User, token: str, request: Optional[Request]) -> No
pass


class ConsoleUserVerifier(UserVerifier):
class UserVerifierImpl(UserVerifier):
def __init__(self, notification_service: NotificationService) -> None:
self.notification_service = notification_service

async def verify(self, user: User, token: str, request: Optional[Request]) -> None:
assert request
body_text = self.plaintext_email_content(request, token)

email_body = self.plaintext_email_content(request, token)

print(email_body)


class EMailUserVerifier(UserVerifier):
def __init__(self, config: Config) -> None:
self.client = boto3.client(
"ses",
config.aws_region,
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
await self.notification_service.send_email(
to=user.email,
subject="FIX: verify your e-mail address",
text=body_text,
html=None,
)

async def verify(self, user: User, token: str, request: Optional[Request]) -> None:
destination = user.email
assert request

def send_email(destination: str, token: str) -> None:
body_text = self.plaintext_email_content(request, token)

self.client.send_email(
Destination={
"ToAddresses": [
destination,
],
},
Message={
"Body": {
"Text": {
"Charset": "UTF-8",
"Data": body_text,
},
},
"Subject": {
"Charset": "UTF-8",
"Data": "FIX: verify your e-mail address",
},
},
Source="[email protected]",
)

await asyncio.to_thread(lambda: send_email(destination, token))


def get_user_verifier(config: ConfigDependency) -> UserVerifier:
if config.aws_access_key_id and config.aws_secret_access_key:
return EMailUserVerifier(config)
return ConsoleUserVerifier()
def get_user_verifier(notification_service: NotificationServiceDependency) -> UserVerifier:
return UserVerifierImpl(notification_service)


UserVerifierDependency = Annotated[UserVerifier, Depends(get_user_verifier)]
1 change: 1 addition & 0 deletions fixbackend/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ServiceNames:
billing = "billing"
cloud_account_service = "cloud_account_service"
domain_event_subscriber = "domain_event_subscriber"
invitation_repository = "invitation_repository"


class FixDependencies(Dependencies):
Expand Down
16 changes: 16 additions & 0 deletions fixbackend/domain_events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,19 @@ class WorkspaceCreated(Event):
kind: ClassVar[str] = "workspace_created"

workspace_id: WorkspaceId


@frozen
class InvitationAccepted(Event):
kind: ClassVar[str] = "workspace_invitation_accepted"

workspace_id: WorkspaceId
user_email: str


@frozen
class UserJoinedWorkspace(Event):
kind: ClassVar[str] = "user_joined_workspace"

workspace_id: WorkspaceId
user_id: UserId
94 changes: 94 additions & 0 deletions fixbackend/notification/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Annotated, Optional

import boto3
from fastapi import Depends

from fixbackend.config import Config, ConfigDependency


class NotificationService(ABC):
@abstractmethod
async def send_email(
self,
*,
to: str,
subject: str,
text: str,
html: Optional[str],
) -> None:
"""Send an email to the given address."""
raise NotImplementedError()


class ConsoleNotificationService(NotificationService):
async def send_email(
self,
to: str,
subject: str,
text: str,
html: Optional[str],
) -> None:
print(f"Sending email to {to} with subject {subject}")
print(f"text: {text}")
if html:
print(f"html: {html}")


class NotificationServiceImpl(NotificationService):
def __init__(self, config: Config) -> None:
self.ses = boto3.client(
"ses",
config.aws_region,
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
)

async def send_email(
self,
*,
to: str,
subject: str,
text: str,
html: Optional[str],
) -> None:
def send_email() -> None:
body_section = {
"Text": {
"Charset": "UTF-8",
"Data": text,
},
}
if html:
body_section["Html"] = {
"Charset": "UTF-8",
"Data": html,
}

self.ses.send_email(
Destination={
"ToAddresses": [
to,
],
},
Message={
"Body": body_section,
"Subject": {
"Charset": "UTF-8",
"Data": subject,
},
},
Source="[email protected]",
)

await asyncio.to_thread(send_email)


def get_notification_service(config: ConfigDependency) -> NotificationService:
if config.aws_access_key_id and config.aws_secret_access_key:
return NotificationServiceImpl(config)
return ConsoleNotificationService()


NotificationServiceDependency = Annotated[NotificationService, Depends(get_notification_service)]
Loading