diff --git a/fasthtml/oauth.py b/fasthtml/oauth.py index 8b16f1ae..c5f59998 100644 --- a/fasthtml/oauth.py +++ b/fasthtml/oauth.py @@ -138,8 +138,8 @@ def url_match(url, patterns=http_patterns): # %% ../nbs/api/08_oauth.ipynb class OAuth: - def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns): - if not skip: skip = [redir_path,login_path] + def __init__(self, app, cli, skip=None, redir_path='/redirect', error_path='/error', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns): + if not skip: skip = [redir_path,error_path,login_path] store_attr() def before(req, session): auth = req.scope['auth'] = session.get('auth') @@ -150,8 +150,8 @@ def before(req, session): app.before.append(Beforeware(before, skip=skip)) @app.get(redir_path) - def redirect(code:str, req, session, state:str=None): - if not code: return "No code provided!" + def redirect(req, session, code:str=None, error:str=None, state:str=None): + if not code: session['oauth_error']=error; return RedirectResponse(self.error_path, status_code=303) scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https' base_url = f"{scheme}://{req.url.netloc}" info = AttrDictDefault(cli.retr_info(code, base_url+redir_path)) diff --git a/nbs/api/08_oauth.ipynb b/nbs/api/08_oauth.ipynb index dbe96210..797ec534 100644 --- a/nbs/api/08_oauth.ipynb +++ b/nbs/api/08_oauth.ipynb @@ -417,8 +417,8 @@ "source": [ "#| export\n", "class OAuth:\n", - " def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):\n", - " if not skip: skip = [redir_path,login_path]\n", + " def __init__(self, app, cli, skip=None, redir_path='/redirect', error_path='/error', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):\n", + " if not skip: skip = [redir_path,error_path,login_path]\n", " store_attr()\n", " def before(req, session):\n", " auth = req.scope['auth'] = session.get('auth')\n", @@ -429,8 +429,8 @@ " app.before.append(Beforeware(before, skip=skip))\n", "\n", " @app.get(redir_path)\n", - " def redirect(code:str, req, session, state:str=None):\n", - " if not code: return \"No code provided!\"\n", + " def redirect(req, session, code:str=None, error:str=None, state:str=None):\n", + " if not code: session['oauth_error']=error; return RedirectResponse(self.error_path, status_code=303)\n", " scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'\n", " base_url = f\"{scheme}://{req.url.netloc}\"\n", " info = AttrDictDefault(cli.retr_info(code, base_url+redir_path))\n",