Skip to content

Commit

Permalink
fix: session caching bug
Browse files Browse the repository at this point in the history
refactor: cache var names
  • Loading branch information
janaka committed Oct 1, 2023
1 parent 2ec4826 commit f7675b9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 33 deletions.
62 changes: 37 additions & 25 deletions source/docq/support/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
from streamlit.components.v1 import html
from streamlit.web.server.websocket_headers import _get_websocket_headers

from ..config import SESSION_COOKIE_NAME, ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY
from ..config import ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY, SESSION_COOKIE_NAME

EXPIRY_HOURS = 4
CACHE_CONFIG = (1024 * 1, 60 * 60 * EXPIRY_HOURS)
TTL = 60 * 60 * EXPIRY_HOURS
CACHE_CONFIG = (1024 * 1, TTL)
AUTH_KEY = Fernet.generate_key()
AUTH_SESSION_SECRET_KEY: str = os.environ.get(ENV_VAR_DOCQ_COOKIE_HMAC_SECRET_KEY)

# Chase of session data keyed by session id
cached_sessions: TTLCache[str, bytes] = TTLCache(*CACHE_CONFIG)
# Cache of session data keyed by hmac hash (hmac of session id)
cached_session_data: TTLCache[str, bytes] = TTLCache(*CACHE_CONFIG)

# Cache of session id's keyed by hmac hash
session_data: TTLCache[str, str] = TTLCache(*CACHE_CONFIG)
# Cache of session id's keyed by hmac hash (hmac of session id)
cached_session_ids: TTLCache[str, str] = TTLCache(*CACHE_CONFIG)


# TODO: the code that handles the cookie should move to the web side. session state tracking is in the backend but not a public API as it's just cross cutting.
Expand Down Expand Up @@ -103,7 +104,7 @@ def generate_hmac_session_id(length: int = 32) -> str:
"""Generate a secure (HMAC) and unique session_id then track in session cache."""
id_ = token_hex(length // 2)
hmac_ = _create_hmac(id_)
session_data[hmac_] = id_
cached_session_ids[hmac_] = id_
log.debug("Generated new hmac session id: %s", hmac_)
return hmac_

Expand Down Expand Up @@ -136,15 +137,18 @@ def verify_cookie_hmac_session_id() -> str | None:
"""
hmac_session_id = None
hmac_session_id = _get_cookie_session_id()

if hmac_session_id is None:
log.debug("No session id in cookie found")
elif hmac_session_id not in cached_sessions:
log.debug(
"verify_cookie_hmac_session_id(): HMAC Session ID not found in cache. Session expired or was explicitly removed: %s"
log.debug("verify_cookie_hmac_session_id(): No session id (auth token) cookie found.")
elif hmac_session_id not in cached_session_ids:
log.warning(
"verify_cookie_hmac_session_id(): item with key=hmac_session_id `cached_session_ids`. The auth session either expired or explicitly removed."
)
log.debug("cached session ids : %s", cached_session_ids.keys())
log.debug("cached session data: %s", cached_session_data.keys())
hmac_session_id = None
elif hmac_session_id not in session_data or not _verify_hmac(session_data[hmac_session_id], hmac_session_id):
log.warning("verify_cookie_hmac_session_id(): HMAC Session ID failed verification: %s")
elif not _verify_hmac(cached_session_ids[hmac_session_id], hmac_session_id):
log.warning("verify_cookie_hmac_session_id(): HMAC Session ID failed verification.")
hmac_session_id = None
return hmac_session_id

Expand Down Expand Up @@ -172,11 +176,11 @@ def _decrypt(encrypted_payload: bytes) -> dict:
return None


def _reset_expiry_cache_auth_session(session_id: str) -> None:
def _reset_expiry_cache_auth_session(hmac_session_id: str) -> None:
"""Update the auth expiry time."""
try:
cached_sessions[session_id] = cached_sessions[session_id]
session_data[session_id] = session_data[session_id]
cached_session_data[hmac_session_id] = cached_session_data[hmac_session_id]
cached_session_ids[hmac_session_id] = cached_session_ids[hmac_session_id]
# _set_cookie_session_id(session_id)
except Exception as e:
log.error("Failed to update auth expiry: %s", e)
Expand All @@ -190,10 +194,16 @@ def set_cache_auth_session(val: dict) -> None:
"""
try:
hmac_session_id = _get_cookie_session_id()
if hmac_session_id is None:
log.debug("set_cache_auth_session() - hmac session id: %s", hmac_session_id)

if hmac_session_id is None or hmac_session_id not in cached_session_ids:
log.debug(
"set_cache_auth_session() - Valid session id (auth token) not found. session_data: %s",
cached_session_ids.keys(),
)
hmac_session_id = generate_hmac_session_id()
_set_cookie_session_id(hmac_session_id)
cached_sessions[hmac_session_id] = _encrypt(val)
_set_cookie_session_id(hmac_session_id)
cached_session_data[hmac_session_id] = _encrypt(val)
_reset_expiry_cache_auth_session(hmac_session_id)
except Exception as e:
log.error("Error caching auth session: %s", e)
Expand All @@ -204,8 +214,8 @@ def get_cache_auth_session() -> dict | None:
try:
decrypted_auth_session_data = None
hmac_session_id = _get_cookie_session_id()
if hmac_session_id in cached_sessions:
encrypted_auth_session_data = cached_sessions[hmac_session_id]
if hmac_session_id in cached_session_data:
encrypted_auth_session_data = cached_session_data[hmac_session_id]
decrypted_auth_session_data = _decrypt(encrypted_auth_session_data)
return decrypted_auth_session_data
except Exception as e:
Expand All @@ -217,10 +227,12 @@ def remove_cache_auth_session() -> None:
"""Remove the cached session state for the current session. The current session is identified by a session_id in a particular browsersession cookie."""
try:
hmac_session_id = _get_cookie_session_id()
if hmac_session_id in cached_sessions:
del cached_sessions[hmac_session_id]
if hmac_session_id in session_data:
del session_data[hmac_session_id]
if hmac_session_id in cached_session_data:
del cached_session_data[hmac_session_id]
log.debug("Removed from cached_session: %s", hmac_session_id)
if hmac_session_id in cached_session_ids:
del cached_session_ids[hmac_session_id]
log.debug("Removed from session_data: %s", hmac_session_id)
except Exception as e:
log.error("Failed to remove auth session from cache: %s", e)

Expand Down
16 changes: 8 additions & 8 deletions tests/docq/support/auth_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
_set_cookie,
_set_cookie_session_id,
_verify_hmac,
cached_sessions,
cached_session_data,
generate_hmac_session_id,
get_cache_auth_session,
reset_cache_and_cookie_auth_session,
session_data,
cached_session_ids,
set_cache_auth_session,
)

Expand Down Expand Up @@ -81,7 +81,7 @@ def test_set_session_id(self: Self, mock_set_cookie: Mock) -> None:
def test_get_cookie_session_id(self: Self, mock_get_cookies: Mock) -> None:
"""Test get session id."""
session_id = generate_hmac_session_id()
cached_sessions[session_id] = _encrypt(("9999", "user", 1))
cached_session_data[session_id] = _encrypt(("9999", "user", 1))
mock_get_cookies.return_value = {SESSION_COOKIE_NAME: session_id}
result = _get_cookie_session_id()
assert result == session_id
Expand All @@ -100,7 +100,7 @@ def test_cache_auth(self: Self, mock_get_cookie_session_id: Mock) -> None:
session_id = generate_hmac_session_id()
mock_get_cookie_session_id.return_value = session_id
set_cache_auth_session(payload)
assert session_id in cached_sessions
assert session_id in cached_session_data

@patch("docq.support.auth_utils._get_cookie_session_id")
def test_auth_result(
Expand All @@ -120,9 +120,9 @@ def test_auth_result(
def test_session_logout(self: Self, mock_get_cookie_session_id: Mock) -> None:
"""Test session logout."""
session_id = generate_hmac_session_id()
cached_sessions[session_id] = _encrypt(("9999", "user", 1))
session_data[session_id] = session_id
cached_session_data[session_id] = _encrypt(("9999", "user", 1))
cached_session_ids[session_id] = session_id
mock_get_cookie_session_id.return_value = session_id
reset_cache_and_cookie_auth_session()
assert session_id not in cached_sessions, "Cached session should be deleted on logout"
assert session_id not in session_data, "Session data should be deleted on logout"
assert session_id not in cached_session_data, "Cached session should be deleted on logout"
assert session_id not in cached_session_ids, "Session data should be deleted on logout"

0 comments on commit f7675b9

Please sign in to comment.