diff --git a/routers/account.py b/routers/account.py index 598ada290..c09c83c67 100644 --- a/routers/account.py +++ b/routers/account.py @@ -8,7 +8,7 @@ from furl import furl from loguru import logger from requests.models import HTTPError -from starlette.responses import Response +from starlette.exceptions import HTTPException from bots.models import PublishedRun, PublishedRunVisibility, Workflow from daras_ai_v2 import icons, paypal @@ -197,12 +197,7 @@ def invitation_route( workspace_slug: str | None, email: str | None, ): - try: - invite_id = WorkspaceInvite.api_hashids.decode(invite_id)[0] - invite = WorkspaceInvite.objects.select_related("workspace").get(id=invite_id) - except (IndexError, WorkspaceInvite.DoesNotExist): - return Response(status_code=404) - + invite = load_invite_from_hashid_or_404(invite_id) invitation_page(current_user=request.user, session=request.session, invite=invite) description = invite.created_by.full_name() @@ -223,6 +218,14 @@ def invitation_route( ) +def load_invite_from_hashid_or_404(invite_id: str) -> WorkspaceInvite: + try: + invite_id = WorkspaceInvite.api_hashids.decode(invite_id)[0] + return WorkspaceInvite.objects.select_related("workspace").get(id=invite_id) + except (IndexError, WorkspaceInvite.DoesNotExist): + raise HTTPException(status_code=404) + + class TabData(typing.NamedTuple): title: str route: typing.Callable diff --git a/routers/root.py b/routers/root.py index 8ddd9fca8..06398d3b6 100644 --- a/routers/root.py +++ b/routers/root.py @@ -1,12 +1,14 @@ import datetime import json import tempfile +import traceback import typing from contextlib import contextmanager from enum import Enum from time import time import gooey_gui as gui +import sentry_sdk from fastapi import Depends from fastapi import HTTPException from fastapi.responses import RedirectResponse @@ -34,6 +36,7 @@ fastapi_request_json, fastapi_request_form, get_route_path, + resolve_url, ) from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags @@ -99,6 +102,9 @@ async def favicon(): @app.get("/login/") def login(request: Request): + from routers.account import invitation_route + from routers.account import load_invite_from_hashid_or_404 + if request.user and not request.user.is_anonymous: return RedirectResponse( request.query_params.get("next", DEFAULT_LOGIN_REDIRECT) @@ -106,6 +112,19 @@ def login(request: Request): context = { "request": request, } + + try: + if ( + (next_url := request.query_params.get("next")) + and (match := resolve_url(next_url)) + and match.route.name == invitation_route.__name__ + and (invite_id := match.matched_params.get("invite_id")) + ): + context["invite"] = load_invite_from_hashid_or_404(invite_id) + except Exception as e: + traceback.print_exc() + sentry_sdk.capture_exception(e) + return templates.TemplateResponse( "login_options.html", context=context, diff --git a/server.py b/server.py index 50fbdb5dc..cdc84ceda 100644 --- a/server.py +++ b/server.py @@ -153,5 +153,8 @@ async def _exc_handler(request: Request, exc: Exception, template_name: str): from gooey_gui.core.reloader import runserver runserver( - "server:app", port=8080, reload=True, reload_excludes=["models.py", "api.py"] + "server:app", + port=8080, + reload=True, + reload_excludes=["models.py", "admin.py", "api.py"], ) diff --git a/templates/login_options.html b/templates/login_options.html index c9365e079..2ea779519 100644 --- a/templates/login_options.html +++ b/templates/login_options.html @@ -10,7 +10,14 @@ {% block content %}

Sign in to Gooey.AI

-

Sign in to access your run history, API key & more. New verified accounts receive {{ settings.VERIFIED_EMAIL_USER_FREE_CREDITS }} credits. 💰

+ t
+

+ {% if invite %} + Sign in with {{ invite.email }} to join {{ invite.workspace.display_name() }} and collaborate with your team. + {% else %} + Sign in to access your run history, API key & more. New verified accounts receive {{ settings.VERIFIED_EMAIL_USER_FREE_CREDITS }} credits. 💰 + {% endif %} +

Loading...

diff --git a/workspaces/models.py b/workspaces/models.py index b5ffd711e..11375509e 100644 --- a/workspaces/models.py +++ b/workspaces/models.py @@ -353,6 +353,26 @@ def get_photo(self) -> str | None: else: return self.photo_url or DEFAULT_WORKSPACE_PHOTO_URL + def add_domain_members(self): + from app_users.models import AppUser + + if not self.domain_name: + return + current_user = self.get_owners().first() + if not current_user: + return + for user_email in ( + AppUser.objects.filter(email__iendswith=self.domain_name) + .exclude(workspace_memberships__workspace=self) + .values_list("email", flat=True) + )[:50]: + WorkspaceInvite.objects.create_and_send_invite( + workspace=self, + email=user_email, + current_user=current_user, + defaults=dict(role=WorkspaceRole.MEMBER), + ) + class WorkspaceMembership(SafeDeleteModel): workspace = models.ForeignKey( @@ -583,7 +603,7 @@ def accept( self, invitee: AppUser, *, - updated_by: AppUser | None, + updated_by: AppUser | None = None, auto_accepted: bool = False, ) -> tuple[WorkspaceMembership, bool]: """ diff --git a/workspaces/signals.py b/workspaces/signals.py index b7188e21e..db63ca9ee 100644 --- a/workspaces/signals.py +++ b/workspaces/signals.py @@ -1,10 +1,12 @@ import traceback + +import sentry_sdk from django.core.exceptions import ValidationError -from django.db.models.signals import post_save +from django.db import transaction +from django.db.models.signals import post_save, pre_save from django.dispatch import receiver from loguru import logger from safedelete.signals import post_softdelete -import sentry_sdk from app_users.models import AppUser from .models import Workspace, WorkspaceInvite, WorkspaceMembership, WorkspaceRole @@ -17,6 +19,7 @@ def add_user_existing_workspace(instance: AppUser, **kwargs): """ if not instance.email: return + email_domain = instance.email.split("@")[-1].lower() for workspace in Workspace.objects.filter(domain_name=email_domain): try: @@ -30,6 +33,12 @@ def add_user_existing_workspace(instance: AppUser, **kwargs): traceback.print_exc() sentry_sdk.capture_exception(e) + for invite in WorkspaceInvite.objects.filter(email__iexact=instance.email): + try: + invite.accept(instance, auto_accepted=True) + except ValidationError: + traceback.print_exc() + @receiver(post_softdelete, sender=WorkspaceMembership) def delete_workspace_if_no_members_left(instance: WorkspaceMembership, **kwargs): @@ -39,3 +48,17 @@ def delete_workspace_if_no_members_left(instance: WorkspaceMembership, **kwargs) f"Deleting workspace {instance.workspace} because it has no members left" ) instance.workspace.delete() + + +@receiver(pre_save, sender=Workspace) +def add_members_on_workspace_domain_change(instance: Workspace, **kwargs): + if instance.id: + old_workspace_domain = ( + Workspace.objects.filter(id=instance.id) + .values_list("domain_name", flat=True) + .first() + ) + else: + old_workspace_domain = None + if instance.domain_name and instance.domain_name != old_workspace_domain: + transaction.on_commit(instance.add_domain_members) diff --git a/workspaces/views.py b/workspaces/views.py index 3b88a7a3b..a2e426e6a 100644 --- a/workspaces/views.py +++ b/workspaces/views.py @@ -27,6 +27,11 @@ def invitation_page( current_user: AppUser | None, session: dict, invite: WorkspaceInvite ): from routers.root import login + from routers.account import members_route + + if invite.status == WorkspaceInvite.Status.ACCEPTED: + set_current_workspace(session, int(invite.workspace_id)) + raise gui.RedirectException(get_route_path(members_route)) with ( gui.div( @@ -313,11 +318,11 @@ def edit_workspace_button_with_dialog(membership: WorkspaceMembership): return try: workspace_copy.full_clean() + workspace_copy.save() except ValidationError as e: # newlines in markdown gui.write("\n".join(e.messages), className="text-danger") else: - workspace_copy.save() membership.workspace.refresh_from_db() ref.set_open(False) gui.rerun() @@ -355,7 +360,7 @@ def render_invite_creation_form(workspace: Workspace) -> tuple[str, str]: "###### Role", options=WorkspaceRole, format_func=WorkspaceRole.display_html, - value=WorkspaceRole.ADMIN.value, + value=WorkspaceRole.MEMBER.value, key="invite-form-role", )