diff --git a/fasthtml/oauth.py b/fasthtml/oauth.py
index 8b16f1ae..c5f59998 100644
--- a/fasthtml/oauth.py
+++ b/fasthtml/oauth.py
@@ -138,8 +138,8 @@ def url_match(url, patterns=http_patterns):
# %% ../nbs/api/08_oauth.ipynb
class OAuth:
- def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):
- if not skip: skip = [redir_path,login_path]
+ def __init__(self, app, cli, skip=None, redir_path='/redirect', error_path='/error', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):
+ if not skip: skip = [redir_path,error_path,login_path]
store_attr()
def before(req, session):
auth = req.scope['auth'] = session.get('auth')
@@ -150,8 +150,8 @@ def before(req, session):
app.before.append(Beforeware(before, skip=skip))
@app.get(redir_path)
- def redirect(code:str, req, session, state:str=None):
- if not code: return "No code provided!"
+ def redirect(req, session, code:str=None, error:str=None, state:str=None):
+ if not code: session['oauth_error']=error; return RedirectResponse(self.error_path, status_code=303)
scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'
base_url = f"{scheme}://{req.url.netloc}"
info = AttrDictDefault(cli.retr_info(code, base_url+redir_path))
diff --git a/nbs/api/08_oauth.ipynb b/nbs/api/08_oauth.ipynb
index dbe96210..797ec534 100644
--- a/nbs/api/08_oauth.ipynb
+++ b/nbs/api/08_oauth.ipynb
@@ -417,8 +417,8 @@
"source": [
"#| export\n",
"class OAuth:\n",
- " def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):\n",
- " if not skip: skip = [redir_path,login_path]\n",
+ " def __init__(self, app, cli, skip=None, redir_path='/redirect', error_path='/error', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):\n",
+ " if not skip: skip = [redir_path,error_path,login_path]\n",
" store_attr()\n",
" def before(req, session):\n",
" auth = req.scope['auth'] = session.get('auth')\n",
@@ -429,8 +429,8 @@
" app.before.append(Beforeware(before, skip=skip))\n",
"\n",
" @app.get(redir_path)\n",
- " def redirect(code:str, req, session, state:str=None):\n",
- " if not code: return \"No code provided!\"\n",
+ " def redirect(req, session, code:str=None, error:str=None, state:str=None):\n",
+ " if not code: session['oauth_error']=error; return RedirectResponse(self.error_path, status_code=303)\n",
" scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'\n",
" base_url = f\"{scheme}://{req.url.netloc}\"\n",
" info = AttrDictDefault(cli.retr_info(code, base_url+redir_path))\n",