-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ebaf032
commit 0a20738
Showing
3 changed files
with
386 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
import os | ||
import json | ||
import logging | ||
from datetime import datetime, timedelta, timezone | ||
from uuid import uuid4 | ||
|
||
import jwt | ||
|
||
from whispr.server.config import ( | ||
TOKEN_EXPIRATION_SECONDS, | ||
TOKEN_ALGORITHM, | ||
TOKEN_DIR, | ||
) | ||
import whispr.server.utils.file as file_utils | ||
|
||
# Setup logging | ||
logger = logging.getLogger(__name__) | ||
|
||
# Define file paths for token persistence and revocation list | ||
TOKEN_STORAGE_FILE = os.path.join(TOKEN_DIR, "tokens.json") | ||
REVOCATION_FILE = os.path.join(TOKEN_DIR, "revoked_tokens.json") | ||
|
||
# In-memory cache for revoked token IDs (jti) | ||
_revoked_tokens = set() | ||
|
||
|
||
def _load_revoked_tokens(): | ||
"""Load the revoked tokens list from persistent storage.""" | ||
global _revoked_tokens | ||
if os.path.exists(REVOCATION_FILE): | ||
try: | ||
data = file_utils.read_file(REVOCATION_FILE) | ||
# Expecting a JSON list of revoked jti strings. | ||
_revoked_tokens = set(json.loads(data)) | ||
logger.info("Loaded revoked tokens from %s", REVOCATION_FILE) | ||
except Exception as e: | ||
logger.error("Failed to load revoked tokens: %s", e) | ||
_revoked_tokens = set() | ||
else: | ||
_revoked_tokens = set() | ||
|
||
|
||
def _save_revoked_tokens(): | ||
"""Persist the revoked tokens list to storage.""" | ||
try: | ||
# Convert the set to a list before saving. | ||
revoked_t_bytes = json.dumps(list(_revoked_tokens)).encode("utf-8") | ||
file_utils.write_file_atomic(REVOCATION_FILE, revoked_t_bytes) | ||
logger.info("Saved revoked tokens to %s", REVOCATION_FILE) | ||
except Exception as e: | ||
logger.error("Failed to save revoked tokens: %s", e) | ||
|
||
|
||
# Initialize revoked tokens cache on module load. | ||
_load_revoked_tokens() | ||
|
||
|
||
def generate_token(payload: dict, secret: str) -> str: | ||
""" | ||
Generate a new JWT token with the provided payload and secret. | ||
:param payload: JWT Payload | ||
:param secret: Secret to encode JWT | ||
Returns: | ||
The encoded JWT token as a string. | ||
""" | ||
# Use a timezone-aware UTC datetime | ||
now = datetime.now(timezone.utc) | ||
expiration_time = now + timedelta(seconds=TOKEN_EXPIRATION_SECONDS) | ||
payload["exp"] = int( | ||
expiration_time.timestamp() | ||
) # Ensure it's an integer timestamp | ||
|
||
# Optionally add a token identifier for revocation tracking if not already present. | ||
if "jti" not in payload: | ||
payload["jti"] = str(uuid4()) | ||
|
||
try: | ||
token = jwt.encode(payload, secret, algorithm=TOKEN_ALGORITHM) | ||
logger.info("Generated token for jti: %s", payload["jti"]) | ||
return token | ||
except Exception as e: | ||
logger.error("Error generating token: %s", e) | ||
raise | ||
|
||
|
||
def renew_token(old_token: str, secret: str) -> str: | ||
""" | ||
Renew an existing token if it is still valid. | ||
Validates the provided token and, if unexpired, creates a new token with an extended expiration. | ||
Raises: | ||
Exception: if the token is invalid or expired. | ||
Returns: | ||
A new JWT token string. | ||
""" | ||
# Validate old token first. If invalid, validate_token will raise an exception. | ||
payload = validate_token(old_token, secret) | ||
|
||
# Remove the old expiration; you might also choose to update other session data if needed. | ||
payload.pop("exp", None) | ||
# Optionally, generate a new token id for the renewed token. | ||
payload["jti"] = str(uuid4()) | ||
|
||
new_token = generate_token(payload, secret) | ||
logger.info("Renewed token; new jti: %s", payload["jti"]) | ||
return new_token | ||
|
||
|
||
def validate_token(token: str, secret: str) -> dict: | ||
""" | ||
Validate the provided JWT token. | ||
Decodes the token, verifies its signature, expiration, and checks whether it has been revoked. | ||
Raises: | ||
jwt.ExpiredSignatureError: if the token has expired. | ||
jwt.InvalidTokenError: if token validation fails. | ||
Exception: if the token has been revoked. | ||
Returns: | ||
The decoded token payload as a dictionary. | ||
""" | ||
try: | ||
# Decode and verify the token. This raises exceptions on errors. | ||
payload = jwt.decode(token, secret, algorithms=[TOKEN_ALGORITHM]) | ||
except jwt.ExpiredSignatureError as e: | ||
logger.error("Token expired: %s", e) | ||
raise | ||
except jwt.InvalidTokenError as e: | ||
logger.error("Invalid token: %s", e) | ||
raise | ||
|
||
# Check if the token is revoked based on its unique identifier (jti). | ||
token_jti = payload.get("jti") | ||
if token_jti and is_token_revoked(token_jti): | ||
msg = "Token has been revoked" | ||
logger.error(msg) | ||
raise Exception(msg) | ||
|
||
logger.info("Validated token with jti: %s", token_jti) | ||
return payload | ||
|
||
|
||
def revoke_token(token: str) -> None: | ||
""" | ||
Revoke a token via administrative action. | ||
Validates the token to extract its unique identifier (jti) and adds it to the revocation list. | ||
Persists the updated revocation state. | ||
Raises: | ||
Exception: if token is invalid or if persistence fails. | ||
""" | ||
try: | ||
# Decode without verifying expiration (if needed) to get the jti. | ||
# Here we use the secret if available, assuming the token structure contains a valid jti. | ||
# Alternatively, you can decode without verifying signature (jwt.decode(..., options={"verify_signature": False})) | ||
payload = jwt.decode(token, options={"verify_signature": False}) | ||
token_jti = payload.get("jti") | ||
if not token_jti: | ||
raise Exception("Token does not contain a jti claim and cannot be revoked.") | ||
except Exception as e: | ||
logger.error("Failed to decode token for revocation: %s", e) | ||
raise | ||
|
||
_revoked_tokens.add(token_jti) | ||
_save_revoked_tokens() | ||
logger.info("Token revoked, jti: %s", token_jti) | ||
|
||
|
||
def is_token_revoked(jti: str) -> bool: | ||
""" | ||
Check if a token (by its unique identifier) has been revoked. | ||
Returns: | ||
True if the token has been revoked, False otherwise. | ||
""" | ||
return jti in _revoked_tokens | ||
|
||
|
||
def persist_token(token: str) -> None: | ||
""" | ||
Persist token data to the token storage directory. | ||
Writes the token string (and optionally its expiration) to a file under TOKEN_STORAGE_FILE. | ||
Raises: | ||
Exception: if the persistence fails. | ||
""" | ||
# Load existing tokens (if any) | ||
tokens = [] | ||
if os.path.exists(TOKEN_STORAGE_FILE): | ||
try: | ||
data = file_utils.read_file(TOKEN_STORAGE_FILE) | ||
tokens = json.loads(data) | ||
except Exception as e: | ||
logger.warning("Could not load existing tokens, starting fresh: %s", e) | ||
|
||
tokens.append(token) | ||
try: | ||
t_bytes = json.dumps(tokens).encode("utf-8") | ||
file_utils.write_file_atomic(TOKEN_STORAGE_FILE, t_bytes) | ||
logger.info("Persisted token to %s", TOKEN_STORAGE_FILE) | ||
except Exception as e: | ||
logger.error("Failed to persist token: %s", e) | ||
raise | ||
|
||
|
||
def load_persisted_tokens() -> list: | ||
""" | ||
Load persisted tokens from the token storage directory. | ||
Returns: | ||
A list of token strings read from the storage file. | ||
""" | ||
if os.path.exists(TOKEN_STORAGE_FILE): | ||
try: | ||
data = file_utils.read_file(TOKEN_STORAGE_FILE) | ||
tokens = json.loads(data) | ||
logger.info("Loaded %d persisted tokens", len(tokens)) | ||
return tokens | ||
except Exception as e: | ||
logger.error("Failed to load persisted tokens: %s", e) | ||
return [] | ||
return [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
"""Tests for auth/jwt_manager module""" | ||
|
||
import json | ||
import unittest | ||
from datetime import timedelta, timezone | ||
import datetime as dt | ||
from freezegun import freeze_time | ||
|
||
import jwt | ||
from jwt import ExpiredSignatureError, InvalidTokenError | ||
from unittest.mock import patch | ||
|
||
# Import the module to test | ||
from whispr.server.auth import jwt_manager | ||
from whispr.server.config import TOKEN_EXPIRATION_SECONDS, TOKEN_ALGORITHM | ||
|
||
# A sample secret for testing purposes. | ||
TEST_SECRET = "test_secret_key" | ||
|
||
class JWTManagerTestCase(unittest.TestCase): | ||
"""Unit tests for the JWT manager functions.""" | ||
|
||
def setUp(self): | ||
"""Reset the revoked tokens set before each test.""" | ||
jwt_manager._revoked_tokens.clear() | ||
|
||
@freeze_time("2024-01-01 00:00:00+00:00") | ||
def test_generate_token_includes_exp_and_jti(self): | ||
"""Test that generate_token adds expiration and jti to the payload with UTC-aware datetimes.""" | ||
from datetime import datetime | ||
|
||
# datetime.now() will now return the frozen time. | ||
fixed_now = datetime.now(timezone.utc) | ||
payload = {"user_id": 123} | ||
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET) | ||
decoded = jwt.decode(token, TEST_SECRET, algorithms=[TOKEN_ALGORITHM]) | ||
self.assertIn("exp", decoded) | ||
self.assertIn("jti", decoded) | ||
|
||
# Verify jti is a valid UUID4. | ||
from uuid import UUID | ||
try: | ||
UUID(decoded["jti"], version=4) | ||
except ValueError: | ||
self.fail("jti is not a valid UUID4 string") | ||
|
||
# Calculate expected expiration timestamp. | ||
expected_exp = fixed_now + timedelta(seconds=TOKEN_EXPIRATION_SECONDS) | ||
self.assertAlmostEqual( | ||
decoded["exp"], int(expected_exp.timestamp()), delta=5, | ||
msg=f"Expected {int(expected_exp.timestamp())} but got {decoded['exp']}" | ||
) | ||
|
||
@freeze_time("2024-01-01 00:00:00+00:00") | ||
def test_renew_token_valid(self): | ||
"""Test that renew_token creates a new token with extended expiration and new jti using UTC-aware time.""" | ||
from datetime import datetime | ||
|
||
# Freeze time at the initial moment. | ||
fixed_now = datetime.now(timezone.utc) | ||
payload = {"user_id": 456} | ||
original_token = jwt_manager.generate_token(payload.copy(), TEST_SECRET) | ||
|
||
# Advance time by 1 second for renewal. | ||
new_time = fixed_now + timedelta(seconds=1) | ||
with freeze_time(new_time): | ||
new_token = jwt_manager.renew_token(original_token, TEST_SECRET) | ||
decoded_new = jwt.decode(new_token, TEST_SECRET, algorithms=[TOKEN_ALGORITHM]) | ||
|
||
decoded_original = jwt.decode(original_token, TEST_SECRET, algorithms=[TOKEN_ALGORITHM]) | ||
self.assertNotEqual(decoded_original.get("jti"), decoded_new.get("jti")) | ||
|
||
expected_new_exp = new_time + timedelta(seconds=TOKEN_EXPIRATION_SECONDS) | ||
self.assertAlmostEqual( | ||
decoded_new["exp"], int(expected_new_exp.timestamp()), delta=5, | ||
msg=f"Expected {int(expected_new_exp.timestamp())} but got {decoded_new['exp']}" | ||
) | ||
|
||
def test_validate_token_valid(self): | ||
"""Test that a valid token is decoded correctly.""" | ||
payload = {"user_id": 789} | ||
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET) | ||
decoded = jwt_manager.validate_token(token, TEST_SECRET) | ||
self.assertEqual(decoded.get("user_id"), payload["user_id"]) | ||
self.assertIn("jti", decoded) | ||
|
||
def test_validate_token_expired(self): | ||
"""Test that validate_token raises ExpiredSignatureError for expired tokens.""" | ||
payload = {"user_id": 101} | ||
# Manually set expiration to past time. | ||
past = dt.datetime.utcnow() - timedelta(seconds=10) | ||
payload["exp"] = past | ||
payload["jti"] = "dummy-jti" | ||
token = jwt.encode(payload, TEST_SECRET, algorithm=TOKEN_ALGORITHM) | ||
with self.assertRaises(ExpiredSignatureError): | ||
jwt_manager.validate_token(token, TEST_SECRET) | ||
|
||
def test_validate_token_invalid_signature(self): | ||
"""Test that validate_token raises InvalidTokenError for tampered tokens.""" | ||
payload = {"user_id": 202} | ||
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET) | ||
|
||
# Split the token into header, payload, and signature. | ||
parts = token.split('.') | ||
self.assertEqual(len(parts), 3, "Token must have three parts separated by dots") | ||
|
||
# Tamper with the signature part explicitly. | ||
sig = parts[2] | ||
# Change the last character in the signature part. | ||
tampered_sig = sig[:-1] + ('a' if sig[-1] != 'a' else 'b') | ||
tampered_token = f"{parts[0]}.{parts[1]}.{tampered_sig}" | ||
|
||
with self.assertRaises(InvalidTokenError): | ||
jwt_manager.validate_token(tampered_token, TEST_SECRET) | ||
|
||
@patch("whispr.server.auth.jwt_manager.is_token_revoked", return_value=True) | ||
def test_validate_token_revoked(self, mock_revoked): | ||
"""Test that validate_token raises an exception when token is revoked.""" | ||
payload = {"user_id": 303} | ||
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET) | ||
with self.assertRaises(Exception) as context: | ||
jwt_manager.validate_token(token, TEST_SECRET) | ||
self.assertIn("revoked", str(context.exception).lower()) | ||
mock_revoked.assert_called() | ||
|
||
@patch("whispr.server.auth.jwt_manager.file_utils.write_file_atomic") | ||
def test_revoke_token(self, mock_write_atomic): | ||
"""Test that revoke_token adds token jti to the revocation list and calls atomic file write.""" | ||
payload = {"user_id": 404} | ||
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET) | ||
decoded = jwt.decode(token, TEST_SECRET, algorithms=[TOKEN_ALGORITHM]) | ||
jti = decoded.get("jti") | ||
self.assertNotIn(jti, jwt_manager._revoked_tokens) | ||
jwt_manager.revoke_token(token) | ||
self.assertIn(jti, jwt_manager._revoked_tokens) | ||
# Ensure atomic write was called for persistence. | ||
self.assertTrue(mock_write_atomic.called) | ||
|
||
@patch("whispr.server.auth.jwt_manager.file_utils.write_file_atomic") | ||
@patch("whispr.server.auth.jwt_manager.file_utils.read_file") | ||
def test_persist_and_load_tokens(self, mock_read_file, mock_write_atomic): | ||
"""Test persist_token and load_persisted_tokens functions using file_utils mocks.""" | ||
# Simulate no tokens file existing initially. | ||
with patch("os.path.exists", return_value=False): | ||
# Persist a token. | ||
test_token = "sample.token.value" | ||
jwt_manager.persist_token(test_token) | ||
self.assertTrue(mock_write_atomic.called) | ||
|
||
# Now simulate loading tokens from a file. | ||
tokens_list = ["token1", "token2"] | ||
mock_read_file.return_value = json.dumps(tokens_list) | ||
with patch("os.path.exists", return_value=True): | ||
loaded_tokens = jwt_manager.load_persisted_tokens() | ||
self.assertEqual(loaded_tokens, tokens_list) | ||
|
||
@patch("whispr.server.auth.jwt_manager.file_utils.write_file_atomic") | ||
def test_persist_token_file_error(self, mock_write_atomic): | ||
"""Test persist_token raises an error when file_utils.write_file_atomic fails.""" | ||
mock_write_atomic.side_effect = Exception("Write error") | ||
with patch("os.path.exists", return_value=False): | ||
with self.assertRaises(Exception) as context: | ||
jwt_manager.persist_token("token_error") | ||
self.assertIn("Write error", str(context.exception)) |