diff --git a/routers/account.py b/routers/account.py
index 598ada290..c09c83c67 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -8,7 +8,7 @@
from furl import furl
from loguru import logger
from requests.models import HTTPError
-from starlette.responses import Response
+from starlette.exceptions import HTTPException
from bots.models import PublishedRun, PublishedRunVisibility, Workflow
from daras_ai_v2 import icons, paypal
@@ -197,12 +197,7 @@ def invitation_route(
workspace_slug: str | None,
email: str | None,
):
- try:
- invite_id = WorkspaceInvite.api_hashids.decode(invite_id)[0]
- invite = WorkspaceInvite.objects.select_related("workspace").get(id=invite_id)
- except (IndexError, WorkspaceInvite.DoesNotExist):
- return Response(status_code=404)
-
+ invite = load_invite_from_hashid_or_404(invite_id)
invitation_page(current_user=request.user, session=request.session, invite=invite)
description = invite.created_by.full_name()
@@ -223,6 +218,14 @@ def invitation_route(
)
+def load_invite_from_hashid_or_404(invite_id: str) -> WorkspaceInvite:
+ try:
+ invite_id = WorkspaceInvite.api_hashids.decode(invite_id)[0]
+ return WorkspaceInvite.objects.select_related("workspace").get(id=invite_id)
+ except (IndexError, WorkspaceInvite.DoesNotExist):
+ raise HTTPException(status_code=404)
+
+
class TabData(typing.NamedTuple):
title: str
route: typing.Callable
diff --git a/routers/root.py b/routers/root.py
index 8ddd9fca8..06398d3b6 100644
--- a/routers/root.py
+++ b/routers/root.py
@@ -1,12 +1,14 @@
import datetime
import json
import tempfile
+import traceback
import typing
from contextlib import contextmanager
from enum import Enum
from time import time
import gooey_gui as gui
+import sentry_sdk
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import RedirectResponse
@@ -34,6 +36,7 @@
fastapi_request_json,
fastapi_request_form,
get_route_path,
+ resolve_url,
)
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
from daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags
@@ -99,6 +102,9 @@ async def favicon():
@app.get("/login/")
def login(request: Request):
+ from routers.account import invitation_route
+ from routers.account import load_invite_from_hashid_or_404
+
if request.user and not request.user.is_anonymous:
return RedirectResponse(
request.query_params.get("next", DEFAULT_LOGIN_REDIRECT)
@@ -106,6 +112,19 @@ def login(request: Request):
context = {
"request": request,
}
+
+ try:
+ if (
+ (next_url := request.query_params.get("next"))
+ and (match := resolve_url(next_url))
+ and match.route.name == invitation_route.__name__
+ and (invite_id := match.matched_params.get("invite_id"))
+ ):
+ context["invite"] = load_invite_from_hashid_or_404(invite_id)
+ except Exception as e:
+ traceback.print_exc()
+ sentry_sdk.capture_exception(e)
+
return templates.TemplateResponse(
"login_options.html",
context=context,
diff --git a/server.py b/server.py
index 50fbdb5dc..cdc84ceda 100644
--- a/server.py
+++ b/server.py
@@ -153,5 +153,8 @@ async def _exc_handler(request: Request, exc: Exception, template_name: str):
from gooey_gui.core.reloader import runserver
runserver(
- "server:app", port=8080, reload=True, reload_excludes=["models.py", "api.py"]
+ "server:app",
+ port=8080,
+ reload=True,
+ reload_excludes=["models.py", "admin.py", "api.py"],
)
diff --git a/templates/login_options.html b/templates/login_options.html
index c9365e079..2ea779519 100644
--- a/templates/login_options.html
+++ b/templates/login_options.html
@@ -10,7 +10,14 @@
{% block content %}
Sign in to Gooey.AI
-
Sign in to access your run history, API key & more. New verified accounts receive {{ settings.VERIFIED_EMAIL_USER_FREE_CREDITS }} credits. 💰
+ t
+
+ {% if invite %}
+ Sign in with {{ invite.email }}
to join {{ invite.workspace.display_name() }} and collaborate with your team.
+ {% else %}
+ Sign in to access your run history, API key & more. New verified accounts receive {{ settings.VERIFIED_EMAIL_USER_FREE_CREDITS }} credits. 💰
+ {% endif %}
+
Loading...
diff --git a/workspaces/models.py b/workspaces/models.py
index b5ffd711e..11375509e 100644
--- a/workspaces/models.py
+++ b/workspaces/models.py
@@ -353,6 +353,26 @@ def get_photo(self) -> str | None:
else:
return self.photo_url or DEFAULT_WORKSPACE_PHOTO_URL
+ def add_domain_members(self):
+ from app_users.models import AppUser
+
+ if not self.domain_name:
+ return
+ current_user = self.get_owners().first()
+ if not current_user:
+ return
+ for user_email in (
+ AppUser.objects.filter(email__iendswith=self.domain_name)
+ .exclude(workspace_memberships__workspace=self)
+ .values_list("email", flat=True)
+ )[:50]:
+ WorkspaceInvite.objects.create_and_send_invite(
+ workspace=self,
+ email=user_email,
+ current_user=current_user,
+ defaults=dict(role=WorkspaceRole.MEMBER),
+ )
+
class WorkspaceMembership(SafeDeleteModel):
workspace = models.ForeignKey(
@@ -583,7 +603,7 @@ def accept(
self,
invitee: AppUser,
*,
- updated_by: AppUser | None,
+ updated_by: AppUser | None = None,
auto_accepted: bool = False,
) -> tuple[WorkspaceMembership, bool]:
"""
diff --git a/workspaces/signals.py b/workspaces/signals.py
index b7188e21e..db63ca9ee 100644
--- a/workspaces/signals.py
+++ b/workspaces/signals.py
@@ -1,10 +1,12 @@
import traceback
+
+import sentry_sdk
from django.core.exceptions import ValidationError
-from django.db.models.signals import post_save
+from django.db import transaction
+from django.db.models.signals import post_save, pre_save
from django.dispatch import receiver
from loguru import logger
from safedelete.signals import post_softdelete
-import sentry_sdk
from app_users.models import AppUser
from .models import Workspace, WorkspaceInvite, WorkspaceMembership, WorkspaceRole
@@ -17,6 +19,7 @@ def add_user_existing_workspace(instance: AppUser, **kwargs):
"""
if not instance.email:
return
+
email_domain = instance.email.split("@")[-1].lower()
for workspace in Workspace.objects.filter(domain_name=email_domain):
try:
@@ -30,6 +33,12 @@ def add_user_existing_workspace(instance: AppUser, **kwargs):
traceback.print_exc()
sentry_sdk.capture_exception(e)
+ for invite in WorkspaceInvite.objects.filter(email__iexact=instance.email):
+ try:
+ invite.accept(instance, auto_accepted=True)
+ except ValidationError:
+ traceback.print_exc()
+
@receiver(post_softdelete, sender=WorkspaceMembership)
def delete_workspace_if_no_members_left(instance: WorkspaceMembership, **kwargs):
@@ -39,3 +48,17 @@ def delete_workspace_if_no_members_left(instance: WorkspaceMembership, **kwargs)
f"Deleting workspace {instance.workspace} because it has no members left"
)
instance.workspace.delete()
+
+
+@receiver(pre_save, sender=Workspace)
+def add_members_on_workspace_domain_change(instance: Workspace, **kwargs):
+ if instance.id:
+ old_workspace_domain = (
+ Workspace.objects.filter(id=instance.id)
+ .values_list("domain_name", flat=True)
+ .first()
+ )
+ else:
+ old_workspace_domain = None
+ if instance.domain_name and instance.domain_name != old_workspace_domain:
+ transaction.on_commit(instance.add_domain_members)
diff --git a/workspaces/views.py b/workspaces/views.py
index 3b88a7a3b..a2e426e6a 100644
--- a/workspaces/views.py
+++ b/workspaces/views.py
@@ -27,6 +27,11 @@ def invitation_page(
current_user: AppUser | None, session: dict, invite: WorkspaceInvite
):
from routers.root import login
+ from routers.account import members_route
+
+ if invite.status == WorkspaceInvite.Status.ACCEPTED:
+ set_current_workspace(session, int(invite.workspace_id))
+ raise gui.RedirectException(get_route_path(members_route))
with (
gui.div(
@@ -313,11 +318,11 @@ def edit_workspace_button_with_dialog(membership: WorkspaceMembership):
return
try:
workspace_copy.full_clean()
+ workspace_copy.save()
except ValidationError as e:
# newlines in markdown
gui.write("\n".join(e.messages), className="text-danger")
else:
- workspace_copy.save()
membership.workspace.refresh_from_db()
ref.set_open(False)
gui.rerun()
@@ -355,7 +360,7 @@ def render_invite_creation_form(workspace: Workspace) -> tuple[str, str]:
"###### Role",
options=WorkspaceRole,
format_func=WorkspaceRole.display_html,
- value=WorkspaceRole.ADMIN.value,
+ value=WorkspaceRole.MEMBER.value,
key="invite-form-role",
)