Skip to content

Commit

Permalink
Janaka/update/ux reloading web pages (#113)
Browse files Browse the repository at this point in the history
* refactor: auth token based logic

* fix cookie setting

* refactor: remove AUTO_LOGIN as feature type

adding ability to toggle token based auth per org is making it more brittle and risky

* refactor: rename env vars to prefix "DOCQ_"

* refactor: rename session cookie var to be more specific

* tests: fix auth utils tests after refactor

* fix: remove _auto_login_feature_enabled

* chore: tweak debug logs in auth utils
  • Loading branch information
janaka authored Oct 1, 2023
1 parent cc91247 commit 2ec4826
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 237 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ cython_debug/
# Exported from Poetry, used in docker build only
requirements.txt
# Used for file storage in local development only
.persisted/
.persisted*/
# Used for running Streamlit by storing configs and secrets locally
.streamlit/
# Used by GitHub Pages local build before uploading to GitHub
Expand Down
4 changes: 2 additions & 2 deletions misc/docker.env.template
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
STREAMLIT_SERVER_ADDRESS=0.0.0.0
STREAMLIT_SERVER_PORT=8501 #default
DOCQ_DATA=./.persisted/
OPENAI_API_KEY # ideally set value on shell, don't insert a value here because it's a secret.
COOKIE_SECRET_KEY=cookie_password
DOCQ_OPENAI_API_KEY # ideally set value on shell, don't insert a value here because it's a secret.
DOCQ_COOKIE_HMAC_SECRET_KEY=cookie_password
4 changes: 2 additions & 2 deletions misc/secrets.toml.template
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
DOCQ_DATA = "./.persisted/"
OPENAI_API_KEY = "YOUR-OPENAI-API-KEY"
COOKIE_SECRET_KEY = "cookies_password"
DOCQ_OPENAI_API_KEY = "YOUR-OPENAI-API-KEY"
DOCQ_COOKIE_HMAC_SECRET_KEY = "32_char_secret_used_to_encrypt"
7 changes: 3 additions & 4 deletions source/docq/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

ENV_VAR_DOCQ_DATA = "DOCQ_DATA"
ENV_VAR_DOCQ_DEMO = "DOCQ_DEMO"
ENV_VAR_OPENAI_API_KEY = "OPENAI_API_KEY"
ENV_VAR_COOKIE_SECRET_KEY = "COOKIE_SECRET_KEY"
COOKIE_NAME = "docqai/_docq"
ENV_VAR_OPENAI_API_KEY = "DOCQ_OPENAI_API_KEY"
ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY = "DOCQ_COOKIE_HMAC_SECRET_KEY"
SESSION_COOKIE_NAME = "docqai/_docq"


class SpaceType(Enum):
Expand All @@ -24,7 +24,6 @@ class FeatureType(Enum):
ASK_SHARED = "Ask Shared Documents"
ASK_PUBLIC = "Ask Public Documents"
CHAT_PRIVATE = "General Chat"
AUTO_LOGIN = "Auto Login"


class LogType(Enum):
Expand Down
2 changes: 1 addition & 1 deletion source/docq/manage_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def _get_settings(org_id: int, user_id: int = None) -> dict:


def _update_settings(settings: dict, org_id: int, user_id: int = None) -> bool:
log.debug("Updating settings for user %d", user_id)
with closing(
sqlite3.connect(_get_sqlite_file(user_id), detect_types=sqlite3.PARSE_DECLTYPES)
) as connection, closing(connection.cursor()) as cursor:
user_id = user_id or USER_ID_AS_SYSTEM
log.debug("Updating settings for user %d", user_id)
cursor.executemany(
"INSERT OR REPLACE INTO settings (user_id, org_id, key, val) VALUES (?, ?, ?, ?)",
[(user_id, org_id, key, json.dumps(val)) for key, val in settings.items()],
Expand Down
218 changes: 115 additions & 103 deletions source/docq/support/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,70 @@
import os
from datetime import datetime, timedelta
from secrets import token_hex
from typing import Any, Dict, Optional
from typing import Dict, Optional

from cachetools import TTLCache
from cryptography.fernet import Fernet
from streamlit.components.v1 import html
from streamlit.web.server.websocket_headers import _get_websocket_headers

from ..config import COOKIE_NAME, ENV_VAR_COOKIE_SECRET_KEY, FeatureType
from ..manage_settings import SystemSettingsKey, get_organisation_settings
from ..config import SESSION_COOKIE_NAME, ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY

EXPIRY_HOURS = 4
CACHE_CONFIG = (1024 * 1, 60 * 60 * EXPIRY_HOURS)
AUTH_KEY = Fernet.generate_key()
AUTH_SESSION_SECRET_KEY: str = os.environ.get(ENV_VAR_COOKIE_SECRET_KEY)
AUTH_SESSION_SECRET_KEY: str = os.environ.get(ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY)

# Cache of session data keyed by session id
cached_sessions: TTLCache[str, bytes] = TTLCache(*CACHE_CONFIG)

# Cache of session ids keyed by hmac hash
session_data: TTLCache[str, str] = TTLCache(*CACHE_CONFIG)


# TODO: the code that handles the cookie should move to the web side. session state tracking is in the backend but not a public API as it's just cross cutting.

# Session Cache.
cached_sessions:TTLCache[str, bytes] = TTLCache(*CACHE_CONFIG)
session_data:TTLCache[str, str]= TTLCache(*CACHE_CONFIG)

def init_session_cache() -> None:
"""Initialize session cache."""
if AUTH_SESSION_SECRET_KEY is None:
log.fatal("Failed to initialize session cache: COOKIE_SECRET_KEY not set")
raise ValueError("COOKIE_SECRET_KEY must be set")
log.fatal("Failed to initialize session cache: DOCQ_COOKIE_HMAC_SECRET_KEY not set")
raise ValueError("DOCQ_COOKIE_HMAC_SECRET_KEY must be set")
if len(AUTH_SESSION_SECRET_KEY) < 32:
log.fatal("Failed to initialize session cache: COOKIE_SECRET_KEY must be 32 or more characters")
raise ValueError("COOKIE_SECRET_KEY must be 32 or more characters")
log.fatal("Failed to initialize session cache: DOCQ_COOKIE_HMAC_SECRET_KEY must be 32 or more characters")
raise ValueError("DOCQ_COOKIE_HMAC_SECRET_KEY must be 32 or more characters")


def _set_cookie(cookie: str) -> None:
"""Set client cookie for authentication."""
try:
expiry = datetime.now() + timedelta(hours=EXPIRY_HOURS)
html(f"""
html(
f"""
<script>
const secure = location.protocol === "https:" ? " secure;" : "";
document.cookie = "{COOKIE_NAME}={cookie}; expires={expiry.strftime('%a, %d %b %Y %H:%M:%S GMT')}; path=/; SameSite=Secure;" + secure;
document.cookie = "{SESSION_COOKIE_NAME}={cookie}; expires={expiry.strftime('%a, %d %b %Y %H:%M:%S GMT')}; path=/; SameSite=Secure;" + secure;
</script>
""", width=0, height=0)
""",
width=0,
height=0,
)
except Exception as e:
log.error("Failed to set cookie: %s", e)


def _clear_cookie() -> None:
def _clear_cookie(cookie_name: str) -> None:
"""Clear client cookie."""
html(f"""
html(
f"""
<script>
document.cookie = "{COOKIE_NAME}=; expires=Thu, 01 Jan 1970 00:00:00 GMT; path=/;";
document.cookie = "{cookie_name}=; expires=Thu, 01 Jan 1970 00:00:00 GMT; path=/;";
</script>
""", width=0, height=0)
""",
width=0,
height=0,
)
log.debug("Clear client cookie: %s", cookie_name)


def _get_cookies() -> Optional[Dict[str, str]]:
Expand All @@ -75,148 +89,146 @@ def _get_cookies() -> Optional[Dict[str, str]]:
return None


def _create_hmac( msg: str) -> str:
def _create_hmac(msg: str) -> str:
"""Create a HMAC hash."""
return hmac.new(
AUTH_SESSION_SECRET_KEY.encode(),
msg.encode(),
hashlib.sha256
).hexdigest()
return hmac.new(AUTH_SESSION_SECRET_KEY.encode(), msg.encode(), hashlib.sha256).hexdigest()


def _verify_hmac(msg: str, digest: str) -> bool:
"""Verify credibility of HMAC hash."""
return hmac.compare_digest(
_create_hmac(msg),
digest
)
return hmac.compare_digest(_create_hmac(msg), digest)


def generate_session_id(length: int = 32) -> str:
"""Generate a secure and unique session_id."""
def generate_hmac_session_id(length: int = 32) -> str:
"""Generate a secure (HMAC) and unique session_id then track in session cache."""
id_ = token_hex(length // 2)
hmac_ = _create_hmac(id_)
session_data[hmac_] = id_
log.debug("Generated new hmac session id: %s", hmac_)
return hmac_


def _set_session_id(session_id: str) -> None:
"""Set the session_id in the cookie."""
_set_cookie(session_id)
def _set_cookie_session_id(hmac_session_id: str) -> None:
"""Set the encrypted session_id in the cookie."""
_set_cookie(hmac_session_id)
log.debug("_set_cookie_session_id() - hmac session id: %s", hmac_session_id)


def _get_session_id() -> Optional[str]:
"""Return the session_id from the cookie."""
def _get_cookie_session_id() -> str | None:
"""Return the Docq encrypted HMAC session_id from the cookie."""
try:
hmac_session_id = None
cookies = _get_cookies()
if cookies is not None:
cookie = cookies.get(COOKIE_NAME)
if cookie is None:
return None
if cookie not in cached_sessions:
return None
if not _verify_hmac(session_data[cookie], cookie):
log.warning("Session ID not verified: %s", cookie)
return None
return cookie
hmac_session_id = cookies.get(SESSION_COOKIE_NAME)
return hmac_session_id
except Exception as e:
log.error("Failed to get session id: %s", e)
return None


def _encrypt_auth(*args: tuple, **kwargs: dict) -> bytes:
"""Encrypt the auth data."""
def verify_cookie_hmac_session_id() -> str | None:
    """Verify the HMAC session id carried in the session cookie.

    The id is accepted only when it is present in the cookie, still has an
    entry in ``cached_sessions``, and matches the HMAC of the value stored for
    it in ``session_data``.

    Returns:
        str: The hmac_session_id if verified.
        None: If the cookie holds no session id, the id is not in the session
            cache, or the HMAC check fails.
    """
    hmac_session_id = _get_cookie_session_id()
    if hmac_session_id is None:
        log.debug("No session id in cookie found")
        return None
    if hmac_session_id not in cached_sessions:
        # The message has a %s placeholder, so the id must be passed as an
        # argument; otherwise the literal "%s" is what ends up in the log.
        log.debug(
            "verify_cookie_hmac_session_id(): HMAC Session ID not found in cache. Session expired or was explicitly removed: %s",
            hmac_session_id,
        )
        return None
    # Check membership first so a missing entry is treated as a failed
    # verification rather than raising KeyError.
    if hmac_session_id not in session_data or not _verify_hmac(session_data[hmac_session_id], hmac_session_id):
        log.warning("verify_cookie_hmac_session_id(): HMAC Session ID failed verification: %s", hmac_session_id)
        return None
    return hmac_session_id


def _encrypt(payload: dict) -> bytes:
"""Encrypt some data."""
try:
data = json.dumps([args, kwargs]).encode()
data = json.dumps(payload).encode()
cipher = Fernet(AUTH_KEY)
return cipher.encrypt(data)
except Exception as e:
log.error("Failed to encrypt auth data: %s", e)
return None


def _decrypt_auth(configs: bytes) -> tuple[tuple, dict]:
"""Decrypt the auth data."""
def _decrypt(encrypted_payload: bytes) -> dict:
"""Decrypt some data."""
try:
cipher = Fernet(AUTH_KEY)
data = cipher.decrypt(configs)
data_ = list(json.loads(data))
return tuple(data_[0]), data_[1]
data = cipher.decrypt(encrypted_payload)
result = json.loads(data)
return result
except Exception as e:
log.error("Failed to decrypt auth data: %s", e)
return None


def _update_auth_expiry(session_id: str) -> None:
def _reset_expiry_cache_auth_session(session_id: str) -> None:
"""Update the auth expiry time."""
try:
cached_sessions[session_id] = cached_sessions[session_id]
session_data[session_id] = session_data[session_id]
_set_session_id(session_id)
# _set_cookie_session_id(session_id)
except Exception as e:
log.error("Failed to update auth expiry: %s", e)


def _auto_login_enabled(org_id: int) -> bool:
"""Check if auto login feature is enabled."""
try:
system_settings = get_organisation_settings(org_id=org_id, key=SystemSettingsKey.ENABLED_FEATURES)
if system_settings: # Only enable feature when explicitly enabled (dafault to Disabled)
return FeatureType.AUTO_LOGIN.name in system_settings
return False
except Exception as e:
log.error("Failed to check if auto login is enabled: %s", e)
return False


def cache_session_state_configs(*args: tuple, **kwargs: dict[str, Any]) -> None:
"""Caches the session state configs for auth.
This will cache any arguments and keyword arguments passed to it and can be retrived
by calling the get_auth_configs function:
>>> docq.support.auth_utils.get_auth_configs()
def set_cache_auth_session(val: dict) -> None:
"""Caches the session state configs for auth, persisting across connections.
Args:
args: Arguments to be passed to the auth function.
kwargs: Keyword arguments to be passed to the auth function.
val (dict): The session state for auth.
"""
try:
session_id = _get_session_id()
if not session_id:
session_id = generate_session_id()
_set_session_id(session_id)
cached_sessions[session_id] = _encrypt_auth(*args, **kwargs)
_update_auth_expiry(session_id)
hmac_session_id = _get_cookie_session_id()
if hmac_session_id is None:
hmac_session_id = generate_hmac_session_id()
_set_cookie_session_id(hmac_session_id)
cached_sessions[hmac_session_id] = _encrypt(val)
_reset_expiry_cache_auth_session(hmac_session_id)
except Exception as e:
log.error("Error caching auth session state: %s", e)
log.error("Error caching auth session: %s", e)


def get_auth_configs() -> Optional[tuple[tuple, dict]]:
"""Get cached session state configs for auth."""
def get_cache_auth_session() -> dict | None:
"""Verify the session auth token and get the cached session state for the current session. The current session is identified by a session_id wrapped in a auth token in a browser session cookie."""
try:
session_id = _get_session_id()
if session_id in cached_sessions:
configs = cached_sessions[session_id]
args, kwargs = _decrypt_auth(configs)
selected_org_id = kwargs.get("selected_org_id") or args[1]
if not _auto_login_enabled(selected_org_id):
return None
return _decrypt_auth(configs)
else:
return None
decrypted_auth_session_data = None
hmac_session_id = _get_cookie_session_id()
if hmac_session_id in cached_sessions:
encrypted_auth_session_data = cached_sessions[hmac_session_id]
decrypted_auth_session_data = _decrypt(encrypted_auth_session_data)
return decrypted_auth_session_data
except Exception as e:
log.error("Failed to get auth result: %s", e)
log.error("Failed to get auth session from cache: %s", e)
return None


def session_logout() -> None:
"""Clear all the data used to remember user session."""
def remove_cache_auth_session() -> None:
    """Drop the current session's entries from the auth session caches.

    The current session is identified by the session id read from the session
    cookie. Any error is logged rather than raised.
    """
    try:
        cookie_session_id = _get_cookie_session_id()
        # Session state cache first, then the hmac-keyed cache.
        for cache in (cached_sessions, session_data):
            if cookie_session_id in cache:
                del cache[cookie_session_id]
    except Exception as e:
        log.error("Failed to remove auth session from cache: %s", e)


def reset_cache_and_cookie_auth_session() -> None:
"""Clear all the data used to remember user session (auth session cache and session cookie). This must be called at login and cookie."""
try:
session_id = _get_session_id()
if session_id in cached_sessions:
del cached_sessions[session_id]
if session_id in session_data:
del session_data[session_id]
_clear_cookie()
remove_cache_auth_session()
_clear_cookie(SESSION_COOKIE_NAME)
except Exception as e:
log.error("Failed to logout: %s", e)
log.error("Failed to clear session data caches (hmac, session data, and session cookie ): %s", e)
Loading

0 comments on commit 2ec4826

Please sign in to comment.