From 2ec4826f575cde8e2de6ab9a0bb43d2e4ce6147a Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sun, 1 Oct 2023 15:37:16 +0100 Subject: [PATCH] Janaka/update/ux reloading web pages (#113) * refactor: auth token based logic * fix cookie setting * refactor: remove AUTO_LOGIN as feature type adding ability to toggle token based auth per org is making it more brittle and risky * refactor: rename env vars to prefix "DOCQ_" * refactor: rename session cookie var to be more specific * tests: fix auth utils tests after refactor * fix: remove _auto_login_feature_enabled * chore: tweak debug logs in auth utils --- .gitignore | 2 +- misc/docker.env.template | 4 +- misc/secrets.toml.template | 4 +- source/docq/config.py | 7 +- source/docq/manage_settings.py | 2 +- source/docq/support/auth_utils.py | 218 ++++++++++++++------------ tests/docq/support/auth_utils_test.py | 120 ++++++-------- web/index.py | 4 +- web/utils/handlers.py | 71 +++++---- web/utils/layout.py | 53 +++++-- web/utils/sessions.py | 21 ++- 11 files changed, 269 insertions(+), 237 deletions(-) diff --git a/.gitignore b/.gitignore index e24be176..fa6f90f1 100644 --- a/.gitignore +++ b/.gitignore @@ -176,7 +176,7 @@ cython_debug/ # Exported from Poetry, used in docker build only requirements.txt # Used for file storage in local development only -.persisted/ +.persisted*/ # Used for running Streamlit by storing configs and secrets locally .streamlit/ # Used by GitHub Pages local build before uploading to GitHub diff --git a/misc/docker.env.template b/misc/docker.env.template index eb739d2b..efddb56b 100644 --- a/misc/docker.env.template +++ b/misc/docker.env.template @@ -1,5 +1,5 @@ STREAMLIT_SERVER_ADDRESS=0.0.0.0 STREAMLIT_SERVER_PORT=8501 #default DOCQ_DATA=./.persisted/ -OPENAI_API_KEY # ideally set value on shell, don't insert a value here becuase it's a secret. -COOKIE_SECRET_KEY=cookie_password \ No newline at end of file +DOCQ_OPENAI_API_KEY # ideally set value on shell, don't insert a value here becuase it's a secret. +DOCQ_COOKIE_HMAC_SECRET_KEY=cookie_password \ No newline at end of file diff --git a/misc/secrets.toml.template b/misc/secrets.toml.template index 5bb3007a..9b6b7136 100644 --- a/misc/secrets.toml.template +++ b/misc/secrets.toml.template @@ -1,3 +1,3 @@ DOCQ_DATA = "./.persisted/" -OPENAI_API_KEY = "YOUR-OPENAI-API-KEY" -COOKIE_SECRET_KEY = "cookies_password" \ No newline at end of file +DOCQ_OPENAI_API_KEY = "YOUR-OPENAI-API-KEY" +DOCQ_COOKIE_HMAC_SECRET_KEY = "32_char_secret_used_to_encrypt" \ No newline at end of file diff --git a/source/docq/config.py b/source/docq/config.py index 05beecfb..1568a48e 100644 --- a/source/docq/config.py +++ b/source/docq/config.py @@ -4,9 +4,9 @@ ENV_VAR_DOCQ_DATA = "DOCQ_DATA" ENV_VAR_DOCQ_DEMO = "DOCQ_DEMO" -ENV_VAR_OPENAI_API_KEY = "OPENAI_API_KEY" -ENV_VAR_COOKIE_SECRET_KEY = "COOKIE_SECRET_KEY" -COOKIE_NAME = "docqai/_docq" +ENV_VAR_OPENAI_API_KEY = "DOCQ_OPENAI_API_KEY" +ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY = "DOCQ_COOKIE_HMAC_SECRET_KEY" +SESSION_COOKIE_NAME = "docqai/_docq" class SpaceType(Enum): @@ -24,7 +24,6 @@ class FeatureType(Enum): ASK_SHARED = "Ask Shared Documents" ASK_PUBLIC = "Ask Public Documents" CHAT_PRIVATE = "General Chat" - AUTO_LOGIN = "Auto Login" class LogType(Enum): diff --git a/source/docq/manage_settings.py b/source/docq/manage_settings.py index 7c365dda..0eaf5ade 100644 --- a/source/docq/manage_settings.py +++ b/source/docq/manage_settings.py @@ -59,11 +59,11 @@ def _get_settings(org_id: int, user_id: int = None) -> dict: def _update_settings(settings: dict, org_id: int, user_id: int = None) -> bool: - log.debug("Updating settings for user %d", user_id) with closing( sqlite3.connect(_get_sqlite_file(user_id), detect_types=sqlite3.PARSE_DECLTYPES) ) as connection, closing(connection.cursor()) as cursor: user_id = user_id or USER_ID_AS_SYSTEM + log.debug("Updating settings for user %d", user_id) cursor.executemany( "INSERT OR REPLACE INTO settings (user_id, org_id, key, val) VALUES (?, ?, ?, ?)", [(user_id, org_id, key, json.dumps(val)) for key, val in settings.items()], diff --git a/source/docq/support/auth_utils.py b/source/docq/support/auth_utils.py index 9f9e359b..90856860 100644 --- a/source/docq/support/auth_utils.py +++ b/source/docq/support/auth_utils.py @@ -6,56 +6,70 @@ import os from datetime import datetime, timedelta from secrets import token_hex -from typing import Any, Dict, Optional +from typing import Dict, Optional from cachetools import TTLCache from cryptography.fernet import Fernet from streamlit.components.v1 import html from streamlit.web.server.websocket_headers import _get_websocket_headers -from ..config import COOKIE_NAME, ENV_VAR_COOKIE_SECRET_KEY, FeatureType -from ..manage_settings import SystemSettingsKey, get_organisation_settings +from ..config import SESSION_COOKIE_NAME, ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY EXPIRY_HOURS = 4 CACHE_CONFIG = (1024 * 1, 60 * 60 * EXPIRY_HOURS) AUTH_KEY = Fernet.generate_key() -AUTH_SESSION_SECRET_KEY: str = os.environ.get(ENV_VAR_COOKIE_SECRET_KEY) +AUTH_SESSION_SECRET_KEY: str = os.environ.get(ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY) + +# Chase of session data keyed by session id +cached_sessions: TTLCache[str, bytes] = TTLCache(*CACHE_CONFIG) + +# Cache of session id's keyed by hmac hash +session_data: TTLCache[str, str] = TTLCache(*CACHE_CONFIG) + + +# TODO: the code that handles the cookie should move to the web side. session state tracking is in the backend but not a public API as it's just cross cutting. -# Session Cache. -cached_sessions:TTLCache[str, bytes] = TTLCache(*CACHE_CONFIG) -session_data:TTLCache[str, str]= TTLCache(*CACHE_CONFIG) def init_session_cache() -> None: """Initialize session cache.""" if AUTH_SESSION_SECRET_KEY is None: - log.fatal("Failed to initialize session cache: COOKIE_SECRET_KEY not set") - raise ValueError("COOKIE_SECRET_KEY must be set") + log.fatal("Failed to initialize session cache: DOCQ_COOKIE_HMAC_SECRET_KEY not set") + raise ValueError("DOCQ_COOKIE_HMAC_SECRET_KEY must be set") if len(AUTH_SESSION_SECRET_KEY) < 32: - log.fatal("Failed to initialize session cache: COOKIE_SECRET_KEY must be 32 or more characters") - raise ValueError("COOKIE_SECRET_KEY must be 32 or more characters") + log.fatal("Failed to initialize session cache: DOCQ_COOKIE_HMAC_SECRET_KEY must be 32 or more characters") + raise ValueError("DOCQ_COOKIE_HMAC_SECRET_KEY must be 32 or more characters") def _set_cookie(cookie: str) -> None: """Set client cookie for authentication.""" try: expiry = datetime.now() + timedelta(hours=EXPIRY_HOURS) - html(f""" + html( + f""" - """, width=0, height=0) + """, + width=0, + height=0, + ) except Exception as e: log.error("Failed to set cookie: %s", e) -def _clear_cookie() -> None: +def _clear_cookie(cookie_name: str) -> None: """Clear client cookie.""" - html(f""" + html( + f""" - """, width=0, height=0) + """, + width=0, + height=0, + ) + log.debug("Clear client cookie: %s", cookie_name) def _get_cookies() -> Optional[Dict[str, str]]: @@ -75,59 +89,70 @@ def _get_cookies() -> Optional[Dict[str, str]]: return None -def _create_hmac( msg: str) -> str: +def _create_hmac(msg: str) -> str: """Create a HMAC hash.""" - return hmac.new( - AUTH_SESSION_SECRET_KEY.encode(), - msg.encode(), - hashlib.sha256 - ).hexdigest() + return hmac.new(AUTH_SESSION_SECRET_KEY.encode(), msg.encode(), hashlib.sha256).hexdigest() def _verify_hmac(msg: str, digest: str) -> bool: """Verify credibility of HMAC hash.""" - return hmac.compare_digest( - _create_hmac(msg), - digest - ) + return hmac.compare_digest(_create_hmac(msg), digest) -def generate_session_id(length: int = 32) -> str: - """Generate a secure and unique session_id.""" +def generate_hmac_session_id(length: int = 32) -> str: + """Generate a secure (HMAC) and unique session_id then track in session cache.""" id_ = token_hex(length // 2) hmac_ = _create_hmac(id_) session_data[hmac_] = id_ + log.debug("Generated new hmac session id: %s", hmac_) return hmac_ -def _set_session_id(session_id: str) -> None: - """Set the session_id in the cookie.""" - _set_cookie(session_id) +def _set_cookie_session_id(hmac_session_id: str) -> None: + """Set the encrypted session_id in the cookie.""" + _set_cookie(hmac_session_id) + log.debug("_set_cookie_session_id() - hmac session id: %s", hmac_session_id) -def _get_session_id() -> Optional[str]: - """Return the session_id from the cookie.""" +def _get_cookie_session_id() -> str | None: + """Return the Docq encrypted HMAC session_id from the cookie.""" try: + hmac_session_id = None cookies = _get_cookies() if cookies is not None: - cookie = cookies.get(COOKIE_NAME) - if cookie is None: - return None - if cookie not in cached_sessions: - return None - if not _verify_hmac(session_data[cookie], cookie): - log.warning("Session ID not verified: %s", cookie) - return None - return cookie + hmac_session_id = cookies.get(SESSION_COOKIE_NAME) + return hmac_session_id except Exception as e: log.error("Failed to get session id: %s", e) return None -def _encrypt_auth(*args: tuple, **kwargs: dict) -> bytes: - """Encrypt the auth data.""" +def verify_cookie_hmac_session_id() -> str | None: + """Verify the encrypted session_id from the cookie. + + Return: + str: The hmac_session_id if verified. + None: If not verified. + """ + hmac_session_id = None + hmac_session_id = _get_cookie_session_id() + if hmac_session_id is None: + log.debug("No session id in cookie found") + elif hmac_session_id not in cached_sessions: + log.debug( + "verify_cookie_hmac_session_id(): HMAC Session ID not found in cache. Session expired or was explicitly removed: %s" + ) + hmac_session_id = None + elif hmac_session_id not in session_data or not _verify_hmac(session_data[hmac_session_id], hmac_session_id): + log.warning("verify_cookie_hmac_session_id(): HMAC Session ID failed verification: %s") + hmac_session_id = None + return hmac_session_id + + +def _encrypt(payload: dict) -> bytes: + """Encrypt some data.""" try: - data = json.dumps([args, kwargs]).encode() + data = json.dumps(payload).encode() cipher = Fernet(AUTH_KEY) return cipher.encrypt(data) except Exception as e: @@ -135,88 +160,75 @@ def _encrypt_auth(*args: tuple, **kwargs: dict) -> bytes: return None -def _decrypt_auth(configs: bytes) -> tuple[tuple, dict]: - """Decrypt the auth data.""" +def _decrypt(encrypted_payload: bytes) -> dict: + """Decrypt some data.""" try: cipher = Fernet(AUTH_KEY) - data = cipher.decrypt(configs) - data_ = list(json.loads(data)) - return tuple(data_[0]), data_[1] + data = cipher.decrypt(encrypted_payload) + result = json.loads(data) + return result except Exception as e: log.error("Failed to decrypt auth data: %s", e) return None -def _update_auth_expiry(session_id: str) -> None: +def _reset_expiry_cache_auth_session(session_id: str) -> None: """Update the auth expiry time.""" try: cached_sessions[session_id] = cached_sessions[session_id] session_data[session_id] = session_data[session_id] - _set_session_id(session_id) + # _set_cookie_session_id(session_id) except Exception as e: log.error("Failed to update auth expiry: %s", e) -def _auto_login_enabled(org_id: int) -> bool: - """Check if auto login feature is enabled.""" - try: - system_settings = get_organisation_settings(org_id=org_id, key=SystemSettingsKey.ENABLED_FEATURES) - if system_settings: # Only enable feature when explicitly enabled (dafault to Disabled) - return FeatureType.AUTO_LOGIN.name in system_settings - return False - except Exception as e: - log.error("Failed to check if auto login is enabled: %s", e) - return False - - -def cache_session_state_configs(*args: tuple, **kwargs: dict[str, Any]) -> None: - """Caches the session state configs for auth. - - This will cache any arguments and keyword arguments passed to it and can be retrived - by calling the get_auth_configs function: - >>> docq.support.auth_utils.get_auth_configs() +def set_cache_auth_session(val: dict) -> None: + """Caches the session state configs for auth, persisting across connections. Args: - args: Arguments to be passed to the auth function. - kwargs: Keyword arguments to be passed to the auth function. + val (dict): The session state for auth. """ try: - session_id = _get_session_id() - if not session_id: - session_id = generate_session_id() - _set_session_id(session_id) - cached_sessions[session_id] = _encrypt_auth(*args, **kwargs) - _update_auth_expiry(session_id) + hmac_session_id = _get_cookie_session_id() + if hmac_session_id is None: + hmac_session_id = generate_hmac_session_id() + _set_cookie_session_id(hmac_session_id) + cached_sessions[hmac_session_id] = _encrypt(val) + _reset_expiry_cache_auth_session(hmac_session_id) except Exception as e: - log.error("Error caching auth session state: %s", e) + log.error("Error caching auth session: %s", e) -def get_auth_configs() -> Optional[tuple[tuple, dict]]: - """Get cached session state configs for auth.""" +def get_cache_auth_session() -> dict | None: + """Verify the session auth token and get the cached session state for the current session. The current session is identified by a session_id wrapped in a auth token in a browser session cookie.""" try: - session_id = _get_session_id() - if session_id in cached_sessions: - configs = cached_sessions[session_id] - args, kwargs = _decrypt_auth(configs) - selected_org_id = kwargs.get("selected_org_id") or args[1] - if not _auto_login_enabled(selected_org_id): - return None - return _decrypt_auth(configs) - else: - return None + decrypted_auth_session_data = None + hmac_session_id = _get_cookie_session_id() + if hmac_session_id in cached_sessions: + encrypted_auth_session_data = cached_sessions[hmac_session_id] + decrypted_auth_session_data = _decrypt(encrypted_auth_session_data) + return decrypted_auth_session_data except Exception as e: - log.error("Failed to get auth result: %s", e) + log.error("Failed to get auth session from cache: %s", e) return None -def session_logout() -> None: - """Clear all the data used to remember user session.""" +def remove_cache_auth_session() -> None: + """Remove the cached session state for the current session. The current session is identified by a session_id in a particular browsersession cookie.""" + try: + hmac_session_id = _get_cookie_session_id() + if hmac_session_id in cached_sessions: + del cached_sessions[hmac_session_id] + if hmac_session_id in session_data: + del session_data[hmac_session_id] + except Exception as e: + log.error("Failed to remove auth session from cache: %s", e) + + +def reset_cache_and_cookie_auth_session() -> None: + """Clear all the data used to remember user session (auth session cache and session cookie). This must be called at login and cookie.""" try: - session_id = _get_session_id() - if session_id in cached_sessions: - del cached_sessions[session_id] - if session_id in session_data: - del session_data[session_id] - _clear_cookie() + remove_cache_auth_session() + _clear_cookie(SESSION_COOKIE_NAME) except Exception as e: - log.error("Failed to logout: %s", e) + log.error("Failed to clear session data caches (hmac, session data, and session cookie ): %s", e) diff --git a/tests/docq/support/auth_utils_test.py b/tests/docq/support/auth_utils_test.py index 6b4cbf47..78b6d91c 100644 --- a/tests/docq/support/auth_utils_test.py +++ b/tests/docq/support/auth_utils_test.py @@ -4,25 +4,24 @@ from typing import Self from unittest.mock import Mock, patch -from docq.config import FeatureType from docq.support import auth_utils from docq.support.auth_utils import ( - _auto_login_enabled, + SESSION_COOKIE_NAME, _clear_cookie, _create_hmac, - _decrypt_auth, - _encrypt_auth, + _decrypt, + _encrypt, + _get_cookie_session_id, _get_cookies, - _get_session_id, _set_cookie, - _set_session_id, + _set_cookie_session_id, _verify_hmac, - cache_session_state_configs, cached_sessions, - generate_session_id, - get_auth_configs, + generate_hmac_session_id, + get_cache_auth_session, + reset_cache_and_cookie_auth_session, session_data, - session_logout, + set_cache_auth_session, ) @@ -33,21 +32,18 @@ def setUp(self: Self) -> None: """Setup module.""" auth_utils.AUTH_SESSION_SECRET_KEY = token_hex(32) - @patch("docq.support.auth_utils.html") def test_set_cookie(self: Self, mock_html: Mock) -> None: """Test set cookie.""" _set_cookie("cookie") mock_html.assert_called_once() - @patch("docq.support.auth_utils.html") def test_clear_cookie(self: Self, mock_html: Mock) -> None: """Test clear cookie.""" - _clear_cookie() + _clear_cookie(SESSION_COOKIE_NAME) mock_html.assert_called_once() - @patch("docq.support.auth_utils._get_websocket_headers") def test_get_cookies(self: Self, mock_headers: Mock) -> None: """Test get cookies.""" @@ -55,14 +51,12 @@ def test_get_cookies(self: Self, mock_headers: Mock) -> None: result = _get_cookies() assert result == {"key": "value"} - def test_create_hmac(self: Self) -> None: """Test create hmac.""" msg = "test" digest = _create_hmac(msg) assert isinstance(digest, str) - def test_verify_hmac(self: Self) -> None: """Test verify hmac.""" msg = "test" @@ -70,85 +64,65 @@ def test_verify_hmac(self: Self) -> None: result = _verify_hmac(msg, digest) assert result - def test_generate_session_id(self: Self) -> None: """Test generate session id.""" - id_ = generate_session_id() + id_ = generate_hmac_session_id() assert isinstance(id_, str) assert len(id_) == 64 - @patch("docq.support.auth_utils._set_cookie") def test_set_session_id(self: Self, mock_set_cookie: Mock) -> None: """Test set session id.""" session_id = "test" - _set_session_id(session_id) + _set_cookie_session_id(session_id) mock_set_cookie.assert_called_once_with(session_id) - @patch("docq.support.auth_utils._get_cookies") - def test_get_session_id(self: Self, mock_get_cookies: Mock) -> None: + def test_get_cookie_session_id(self: Self, mock_get_cookies: Mock) -> None: """Test get session id.""" - session_id = generate_session_id() - cached_sessions[session_id] = _encrypt_auth(("9999", "user", 1)) - mock_get_cookies.return_value = {"docqai/_docq": session_id} - result = _get_session_id() + session_id = generate_hmac_session_id() + cached_sessions[session_id] = _encrypt(("9999", "user", 1)) + mock_get_cookies.return_value = {SESSION_COOKIE_NAME: session_id} + result = _get_cookie_session_id() assert result == session_id - def test_encrypt_decrypt_auth(self: Self) -> None: """Test encrypt decrypt auth.""" - auth, kwargs = ("9999", "user", 1), {} - encrypted_auth = _encrypt_auth(*auth, **kwargs) - decrypted_auth = _decrypt_auth(encrypted_auth) - assert (auth, kwargs) == decrypted_auth + payload = {"org_id": "9999", "username": "user name", "user_id": 1} + encrypted_auth = _encrypt(payload) + decrypted_auth = _decrypt(encrypted_auth) + assert payload == decrypted_auth - - @patch("docq.support.auth_utils._get_session_id") - @patch("docq.support.auth_utils._auto_login_enabled") - def test_cache_auth( - self: Self, - mock_auto_login_enabled: Mock, - mock_get_session_id: Mock - ) -> None: + @patch("docq.support.auth_utils._get_cookie_session_id") + def test_cache_auth(self: Self, mock_get_cookie_session_id: Mock) -> None: """Test cache auth.""" - args, kwargs = ("9999", "user", 1), {} - session_id = generate_session_id() - mock_get_session_id.return_value = session_id - mock_auto_login_enabled.return_value = True - cache_session_state_configs(*args, **kwargs) + payload = {"org_id": "9999", "username": "user name", "user_id": 1} + session_id = generate_hmac_session_id() + mock_get_cookie_session_id.return_value = session_id + set_cache_auth_session(payload) assert session_id in cached_sessions - - @patch("docq.support.auth_utils._auto_login_enabled") - @patch("docq.support.auth_utils._get_session_id") - def test_auth_result(self: Self, mock_get_session_id: Mock, mock_auto_login_enabled: Mock) -> None: + @patch("docq.support.auth_utils._get_cookie_session_id") + def test_auth_result( + self: Self, + mock_get_cookie_session_id: Mock, + ) -> None: """Test auth result.""" - args, kwargs = (("9999", "user", 1), {}) - session_id = generate_session_id() - mock_get_session_id.return_value = session_id - mock_auto_login_enabled.return_value = True - cache_session_state_configs(*args, **kwargs) - result = get_auth_configs() - assert result == (args, kwargs), "Auth result should be same as input" - - - @patch("docq.support.auth_utils._get_session_id") - def test_session_logout(self: Self, mock_get_session_id: Mock) -> None: + payload = {"org_id": "9999", "username": "user name", "user_id": 1} + session_id = generate_hmac_session_id() + mock_get_cookie_session_id.return_value = session_id + # mock_auto_login_enabled.return_value = True + set_cache_auth_session(payload) + result = get_cache_auth_session() + assert result == {"org_id": "9999", "username": "user name", "user_id": 1} + + @patch("docq.support.auth_utils._get_cookie_session_id") + def test_session_logout(self: Self, mock_get_cookie_session_id: Mock) -> None: """Test session logout.""" - session_id = generate_session_id() - cached_sessions[session_id] = _encrypt_auth(("9999", "user", 1)) + session_id = generate_hmac_session_id() + cached_sessions[session_id] = _encrypt(("9999", "user", 1)) session_data[session_id] = session_id - mock_get_session_id.return_value = session_id - session_logout() - assert session_id not in cached_sessions , "Cached session should be deleted on logout" + mock_get_cookie_session_id.return_value = session_id + reset_cache_and_cookie_auth_session() + assert session_id not in cached_sessions, "Cached session should be deleted on logout" assert session_id not in session_data, "Session data should be deleted on logout" - - - @patch("docq.support.auth_utils.get_organisation_settings") - def test_auto_login_enabled(self: Self, mock_get_system_settings: Mock) -> None: - """Test auto login enabled.""" - mock_get_system_settings.return_value = [FeatureType.AUTO_LOGIN.name] - result = _auto_login_enabled(9999) - assert mock_get_system_settings.call_count == 1 - assert result, "Auto login should be enabled" diff --git a/web/index.py b/web/index.py index 794c0e35..710fd1a1 100644 --- a/web/index.py +++ b/web/index.py @@ -2,9 +2,9 @@ import streamlit as st from st_pages import Page, Section, add_page_title, show_pages -from utils.layout import load_setup_ui, org_selection_ui, production_layout, public_access +from utils.layout import init_with_pretty_error_ui, org_selection_ui, production_layout, public_access -load_setup_ui() +init_with_pretty_error_ui() production_layout() diff --git a/web/utils/handlers.py b/web/utils/handlers.py index a88c6310..4e6d8229 100644 --- a/web/utils/handlers.py +++ b/web/utils/handlers.py @@ -24,7 +24,7 @@ from docq.access_control.main import SpaceAccessor, SpaceAccessType from docq.data_source.list import SpaceDataSources from docq.domain import DocumentListItem, SpaceKey -from docq.support.auth_utils import cache_session_state_configs, get_auth_configs, session_logout +from docq.support.auth_utils import get_cache_auth_session, reset_cache_and_cookie_auth_session, set_cache_auth_session from .constants import ( MAX_NUMBER_OF_PERSONAL_DOCS, @@ -41,6 +41,7 @@ get_chat_session, get_public_space_group_id, get_selected_org_id, + get_settings_session, get_username, reset_session_state, set_auth_session, @@ -60,7 +61,8 @@ def _set_session_state_configs( super_admin: bool = False, selected_org_admin: bool = False, space_group_id: Optional[int] = None, - public_session_id: Optional[str] = None ) -> None: + public_session_id: Optional[str] = None, +) -> None: """Set the session state for the configs. Args: @@ -86,17 +88,18 @@ def _set_session_state_configs( SessionKeyNameForAuth.PUBLIC_SESSION_ID.name: public_session_id, SessionKeyNameForAuth.PUBLIC_SPACE_GROUP_ID.name: space_group_id, SessionKeyNameForAuth.ANONYMOUS.name: anonymous, - } + }, + True, ) else: - cache_session_state_configs( - user_id=user_id, - selected_org_id=selected_org_id, - name=name, - username=username, - super_admin=super_admin, - selected_org_admin=selected_org_admin, - ) + # cache_session_state_configs( + # user_id=user_id, + # selected_org_id=selected_org_id, + # name=name, + # username=username, + # super_admin=super_admin, + # selected_org_admin=selected_org_admin, + # ) set_auth_session( { SessionKeyNameForAuth.ID.name: user_id, @@ -106,7 +109,8 @@ def _set_session_state_configs( SessionKeyNameForAuth.SELECTED_ORG_ID.name: selected_org_id, SessionKeyNameForAuth.SELECTED_ORG_ADMIN.name: selected_org_admin, SessionKeyNameForAuth.ANONYMOUS.name: anonymous, - } + }, + True, ) set_settings_session( { @@ -121,15 +125,17 @@ def _set_session_state_configs( def handle_login(username: str, password: str) -> bool: """Handle login.""" reset_session_state() + reset_cache_and_cookie_auth_session() result = manage_users.authenticate(username, password) - current_user_id = result[0] - member_orgs = manage_organisations.list_organisations( - user_id=current_user_id - ) # we can't use handle_list_orgs() here - default_org_id = member_orgs[0][0] - selected_org_admin = current_user_id in [x[0] for x in member_orgs[0][2]] - log.info("Login result: %s", result) + if result: + current_user_id = result[0] + member_orgs = manage_organisations.list_organisations( + user_id=current_user_id + ) # we can't use handle_list_orgs() here + default_org_id = member_orgs[0][0] + selected_org_admin = current_user_id in [x[0] for x in member_orgs[0][2]] + log.info("Login result: %s", result) _set_session_state_configs( user_id=current_user_id, selected_org_id=default_org_id, @@ -145,18 +151,11 @@ def handle_login(username: str, password: str) -> bool: return False -def handle_set_cached_session_configs() -> None: - """Set cached auth configs.""" - auth_configs = get_auth_configs() - if auth_configs and len(auth_configs) == 2: - _args, _kwargs = auth_configs - _set_session_state_configs(*_args, **_kwargs) - - def handle_logout() -> None: """Handle logout.""" reset_session_state() - session_logout() + reset_cache_and_cookie_auth_session() + log.info("Logout") def handle_create_user() -> int: @@ -354,10 +353,7 @@ def _get_chat_spaces(feature: domain.FeatureKey) -> tuple[Optional[SpaceKey], Li if feature.type_ == config.FeatureType.ASK_PUBLIC: personal_space = None - shared_spaces = [ - domain.SpaceKey(config.SpaceType.SHARED, s_[0], select_org_id) - for s_ in list_public_spaces() - ] + shared_spaces = [domain.SpaceKey(config.SpaceType.SHARED, s_[0], select_org_id) for s_ in list_public_spaces()] return personal_space, shared_spaces shared_spaces = None @@ -533,6 +529,7 @@ def get_enabled_features() -> list[domain.FeatureKey]: def handle_update_system_settings() -> None: current_org_id = get_selected_org_id() + manage_settings.update_organisation_settings( { config.SystemSettingsKey.ENABLED_FEATURES.name: [ @@ -541,6 +538,14 @@ def handle_update_system_settings() -> None: }, org_id=current_org_id, ) + set_settings_session( + { + config.SystemSettingsKey.ENABLED_FEATURES.name: [ + f.name for f in st.session_state[f"system_settings_{config.SystemSettingsKey.ENABLED_FEATURES.name}"] + ], + }, + SessionKeyNameForSettings.SYSTEM, + ) def get_max_number_of_documents(type_: config.SpaceType): @@ -636,7 +641,7 @@ def handle_public_session() -> None: space_group_id=space_group_id, public_session_id=session_id, ) - else: # if no query params are provided, set space_group_id and public_session_id to -1 to disable ASK_PUBLIC feature + else: # if no query params are provided, set space_group_id and public_session_id to -1 to disable ASK_PUBLIC feature _set_session_state_configs( user_id=None, selected_org_id=None, diff --git a/web/utils/layout.py b/web/utils/layout.py index 0e8c3dfb..a55a182c 100644 --- a/web/utils/layout.py +++ b/web/utils/layout.py @@ -8,7 +8,11 @@ from docq.access_control.main import SpaceAccessType from docq.config import FeatureType, LogType, SpaceType, SystemSettingsKey from docq.domain import DocumentListItem, FeatureKey, SpaceKey -from docq.manage_users import list_users_by_org +from docq.support.auth_utils import ( + get_cache_auth_session, + reset_cache_and_cookie_auth_session, + verify_cookie_hmac_session_id, +) from st_pages import hide_pages from streamlit.components.v1 import html from streamlit.delta_generator import DeltaGenerator @@ -16,6 +20,7 @@ from .constants import ALLOWED_DOC_EXTS, SessionKeyNameForAuth, SessionKeyNameForChat from .formatters import format_archived, format_datetime, format_filesize, format_timestamp from .handlers import ( + _set_session_state_configs, get_enabled_features, get_max_number_of_documents, get_shared_space, @@ -44,7 +49,6 @@ handle_org_selection_change, handle_public_session, handle_reindex_space, - handle_set_cached_session_configs, handle_update_org, handle_update_space_details, handle_update_space_group, @@ -70,7 +74,8 @@ get_public_space_group_id, get_selected_org_id, is_current_user_super_admin, - set_selected_org_id, + reset_session_state, + session_state_exists, ) _chat_ui_script = """ @@ -200,6 +205,7 @@ def __no_admin_menu() -> None: ] ) + def __embed_page_config() -> None: st.markdown( """ @@ -239,7 +245,7 @@ def __login_form() -> None: if handle_login(username, password): st.experimental_rerun() else: - st.error("Invalid username or password.") + st.error("The Username and Password you entered doesn't match what we have.") st.stop() else: st.stop() @@ -268,10 +274,33 @@ def public_access() -> None: def auth_required(show_login_form: bool = True, requiring_admin: bool = False, show_logout_button: bool = True) -> bool: """Decide layout based on current user's access.""" - handle_set_cached_session_configs() - auth = get_auth_session() + log.debug("auth_required() called") + auth = None __always_hidden_pages() + + session_state_existed = session_state_exists() + log.debug("auth_required(): session_state_existed: %s", session_state_existed) + if session_state_existed: + auth = get_auth_session() + elif verify_cookie_hmac_session_id() is not None: + # there's a valid auth session token. Let's get session state from cache. + auth = get_cache_auth_session() + log.debug("auth_required(): Got auth session state from cache: %s", auth) + if auth: + log.debug("auth_required(): Valid auth session found: %s", auth) + if not session_state_existed: + # the user probably refreshed the page resetting Streamlit session state because it's bound to a browser session connection. + _set_session_state_configs( + user_id=auth[SessionKeyNameForAuth.ID.name], + selected_org_id=auth[SessionKeyNameForAuth.SELECTED_ORG_ID.name], + name=auth[SessionKeyNameForAuth.NAME.name], + username=auth[SessionKeyNameForAuth.USERNAME.name], + anonymous=False, + super_admin=auth[SessionKeyNameForAuth.SUPER_ADMIN.name], + selected_org_admin=auth[SessionKeyNameForAuth.SELECTED_ORG_ADMIN.name], + ) + if show_logout_button: __logout_button() @@ -283,6 +312,9 @@ def auth_required(show_login_form: bool = True, requiring_admin: bool = False, s return True else: + log.debug("auth_required(): No valid auth session found. User needs to re-authenticate.") + reset_session_state() + reset_cache_and_cookie_auth_session() if show_login_form: __login_form() return False @@ -314,7 +346,7 @@ def public_space_enabled(feature: FeatureKey) -> None: feature_is_ready, spaces = (space_group_id != -1 or session_id != -1), None if feature_is_ready: spaces = list_public_spaces() - if not feature_is_ready or not spaces: # Stop the app if there are no public spaces. + if not feature_is_ready or not spaces: # Stop the app if there are no public spaces. st.error("This feature is not ready.") st.info("Please contact your administrator to configure this feature.") st.stop() @@ -881,12 +913,11 @@ def org_selection_ui() -> None: handle_org_selection_change(selected[0]) - -def load_setup_ui() -> None: +def init_with_pretty_error_ui() -> None: """UI to run setup and prevent showing errors to the user.""" try: setup.init() except Exception as e: - st.error("Docq encountered an error while initializing please refer to logs for more details.") - log.exception("Error while setting up the app: %s", e) + st.error("Something went wrong starting Docq.") + log.fatal("Error: setup.init() failed with %s", e) st.stop() diff --git a/web/utils/sessions.py b/web/utils/sessions.py index bca975d7..aee05b7d 100644 --- a/web/utils/sessions.py +++ b/web/utils/sessions.py @@ -1,9 +1,11 @@ """Session utilities.""" +import logging from typing import Any import streamlit as st from docq import config, manage_users +from docq.support.auth_utils import set_cache_auth_session from .constants import ( SESSION_KEY_NAME_DOCQ, @@ -25,9 +27,15 @@ def _init_session_state() -> None: st.session_state[SESSION_KEY_NAME_DOCQ][SessionKeySubName.CHAT.name][n.name] = {} +def session_state_exists() -> bool: + """Check if any session state exists.""" + return SESSION_KEY_NAME_DOCQ in st.session_state + + def reset_session_state() -> None: """Reset the session state. This must be called for user login and logout.""" st.session_state[SESSION_KEY_NAME_DOCQ] = {} + logging.debug("called reset_session_state()") def _get_session_value(name: SessionKeySubName, key_: str = None, subkey_: str = None) -> Any | None: @@ -72,6 +80,14 @@ def set_chat_session(val: Any | None, type_: config.FeatureType = None, key_: Se ) +def set_auth_session(val: dict = None, cache: bool = False) -> None: + """Set the auth session value.""" + _set_session_value(val, SessionKeySubName.AUTH) + if cache: + # this persists the auth session across browser session in Streamlit i.e. when the user hits refresh. + set_cache_auth_session(val) + + def get_auth_session() -> dict: """Get the auth session value.""" return _get_session_value(SessionKeySubName.AUTH) @@ -96,11 +112,6 @@ def set_if_current_user_is_selected_org_admin(selected_org_id: int) -> None: _set_session_value(is_org_admin, SessionKeySubName.AUTH, SessionKeyNameForAuth.SELECTED_ORG_ADMIN.name) -def set_auth_session(val: dict = None) -> None: - """Set the auth session value.""" - _set_session_value(val, SessionKeySubName.AUTH) - - def get_authenticated_user_id() -> int | None: """Get the authenticated user id.""" return _get_session_value(SessionKeySubName.AUTH, SessionKeyNameForAuth.ID.name)