Skip to content

Commit

Permalink
feat: add JWT manager and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
narenaryan committed Feb 3, 2025
1 parent ebaf032 commit 0a20738
Show file tree
Hide file tree
Showing 3 changed files with 386 additions and 0 deletions.
Empty file.
222 changes: 222 additions & 0 deletions src/whispr/server/auth/jwt_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import os
import json
import logging
from datetime import datetime, timedelta, timezone
from uuid import uuid4

import jwt

from whispr.server.config import (
TOKEN_EXPIRATION_SECONDS,
TOKEN_ALGORITHM,
TOKEN_DIR,
)
import whispr.server.utils.file as file_utils

# Setup logging
logger = logging.getLogger(__name__)

# Define file paths for token persistence and revocation list
TOKEN_STORAGE_FILE = os.path.join(TOKEN_DIR, "tokens.json")
REVOCATION_FILE = os.path.join(TOKEN_DIR, "revoked_tokens.json")

# In-memory cache for revoked token IDs (jti)
_revoked_tokens = set()


def _load_revoked_tokens():
"""Load the revoked tokens list from persistent storage."""
global _revoked_tokens
if os.path.exists(REVOCATION_FILE):
try:
data = file_utils.read_file(REVOCATION_FILE)
# Expecting a JSON list of revoked jti strings.
_revoked_tokens = set(json.loads(data))
logger.info("Loaded revoked tokens from %s", REVOCATION_FILE)
except Exception as e:
logger.error("Failed to load revoked tokens: %s", e)
_revoked_tokens = set()
else:
_revoked_tokens = set()


def _save_revoked_tokens():
"""Persist the revoked tokens list to storage."""
try:
# Convert the set to a list before saving.
revoked_t_bytes = json.dumps(list(_revoked_tokens)).encode("utf-8")
file_utils.write_file_atomic(REVOCATION_FILE, revoked_t_bytes)
logger.info("Saved revoked tokens to %s", REVOCATION_FILE)
except Exception as e:
logger.error("Failed to save revoked tokens: %s", e)


# Initialize revoked tokens cache on module load.
_load_revoked_tokens()


def generate_token(payload: dict, secret: str) -> str:
"""
Generate a new JWT token with the provided payload and secret.
:param payload: JWT Payload
:param secret: Secret to encode JWT
Returns:
The encoded JWT token as a string.
"""
# Use a timezone-aware UTC datetime
now = datetime.now(timezone.utc)
expiration_time = now + timedelta(seconds=TOKEN_EXPIRATION_SECONDS)
payload["exp"] = int(
expiration_time.timestamp()
) # Ensure it's an integer timestamp

# Optionally add a token identifier for revocation tracking if not already present.
if "jti" not in payload:
payload["jti"] = str(uuid4())

try:
token = jwt.encode(payload, secret, algorithm=TOKEN_ALGORITHM)
logger.info("Generated token for jti: %s", payload["jti"])
return token
except Exception as e:
logger.error("Error generating token: %s", e)
raise


def renew_token(old_token: str, secret: str) -> str:
"""
Renew an existing token if it is still valid.
Validates the provided token and, if unexpired, creates a new token with an extended expiration.
Raises:
Exception: if the token is invalid or expired.
Returns:
A new JWT token string.
"""
# Validate old token first. If invalid, validate_token will raise an exception.
payload = validate_token(old_token, secret)

# Remove the old expiration; you might also choose to update other session data if needed.
payload.pop("exp", None)
# Optionally, generate a new token id for the renewed token.
payload["jti"] = str(uuid4())

new_token = generate_token(payload, secret)
logger.info("Renewed token; new jti: %s", payload["jti"])
return new_token


def validate_token(token: str, secret: str) -> dict:
"""
Validate the provided JWT token.
Decodes the token, verifies its signature, expiration, and checks whether it has been revoked.
Raises:
jwt.ExpiredSignatureError: if the token has expired.
jwt.InvalidTokenError: if token validation fails.
Exception: if the token has been revoked.
Returns:
The decoded token payload as a dictionary.
"""
try:
# Decode and verify the token. This raises exceptions on errors.
payload = jwt.decode(token, secret, algorithms=[TOKEN_ALGORITHM])
except jwt.ExpiredSignatureError as e:
logger.error("Token expired: %s", e)
raise
except jwt.InvalidTokenError as e:
logger.error("Invalid token: %s", e)
raise

# Check if the token is revoked based on its unique identifier (jti).
token_jti = payload.get("jti")
if token_jti and is_token_revoked(token_jti):
msg = "Token has been revoked"
logger.error(msg)
raise Exception(msg)

logger.info("Validated token with jti: %s", token_jti)
return payload


def revoke_token(token: str) -> None:
"""
Revoke a token via administrative action.
Validates the token to extract its unique identifier (jti) and adds it to the revocation list.
Persists the updated revocation state.
Raises:
Exception: if token is invalid or if persistence fails.
"""
try:
# Decode without verifying expiration (if needed) to get the jti.
# Here we use the secret if available, assuming the token structure contains a valid jti.
# Alternatively, you can decode without verifying signature (jwt.decode(..., options={"verify_signature": False}))
payload = jwt.decode(token, options={"verify_signature": False})
token_jti = payload.get("jti")
if not token_jti:
raise Exception("Token does not contain a jti claim and cannot be revoked.")
except Exception as e:
logger.error("Failed to decode token for revocation: %s", e)
raise

_revoked_tokens.add(token_jti)
_save_revoked_tokens()
logger.info("Token revoked, jti: %s", token_jti)


def is_token_revoked(jti: str) -> bool:
"""
Check if a token (by its unique identifier) has been revoked.
Returns:
True if the token has been revoked, False otherwise.
"""
return jti in _revoked_tokens


def persist_token(token: str) -> None:
"""
Persist token data to the token storage directory.
Writes the token string (and optionally its expiration) to a file under TOKEN_STORAGE_FILE.
Raises:
Exception: if the persistence fails.
"""
# Load existing tokens (if any)
tokens = []
if os.path.exists(TOKEN_STORAGE_FILE):
try:
data = file_utils.read_file(TOKEN_STORAGE_FILE)
tokens = json.loads(data)
except Exception as e:
logger.warning("Could not load existing tokens, starting fresh: %s", e)

tokens.append(token)
try:
t_bytes = json.dumps(tokens).encode("utf-8")
file_utils.write_file_atomic(TOKEN_STORAGE_FILE, t_bytes)
logger.info("Persisted token to %s", TOKEN_STORAGE_FILE)
except Exception as e:
logger.error("Failed to persist token: %s", e)
raise


def load_persisted_tokens() -> list:
"""
Load persisted tokens from the token storage directory.
Returns:
A list of token strings read from the storage file.
"""
if os.path.exists(TOKEN_STORAGE_FILE):
try:
data = file_utils.read_file(TOKEN_STORAGE_FILE)
tokens = json.loads(data)
logger.info("Loaded %d persisted tokens", len(tokens))
return tokens
except Exception as e:
logger.error("Failed to load persisted tokens: %s", e)
return []
return []
164 changes: 164 additions & 0 deletions tests/test_server_jwt_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Tests for auth/jwt_manager module"""

import json
import unittest
from datetime import timedelta, timezone
import datetime as dt
from freezegun import freeze_time

import jwt
from jwt import ExpiredSignatureError, InvalidTokenError
from unittest.mock import patch

# Import the module to test
from whispr.server.auth import jwt_manager
from whispr.server.config import TOKEN_EXPIRATION_SECONDS, TOKEN_ALGORITHM

# A sample secret for testing purposes.
TEST_SECRET = "test_secret_key"

class JWTManagerTestCase(unittest.TestCase):
"""Unit tests for the JWT manager functions."""

def setUp(self):
"""Reset the revoked tokens set before each test."""
jwt_manager._revoked_tokens.clear()

@freeze_time("2024-01-01 00:00:00+00:00")
def test_generate_token_includes_exp_and_jti(self):
"""Test that generate_token adds expiration and jti to the payload with UTC-aware datetimes."""
from datetime import datetime

# datetime.now() will now return the frozen time.
fixed_now = datetime.now(timezone.utc)
payload = {"user_id": 123}
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET)
decoded = jwt.decode(token, TEST_SECRET, algorithms=[TOKEN_ALGORITHM])
self.assertIn("exp", decoded)
self.assertIn("jti", decoded)

# Verify jti is a valid UUID4.
from uuid import UUID
try:
UUID(decoded["jti"], version=4)
except ValueError:
self.fail("jti is not a valid UUID4 string")

# Calculate expected expiration timestamp.
expected_exp = fixed_now + timedelta(seconds=TOKEN_EXPIRATION_SECONDS)
self.assertAlmostEqual(
decoded["exp"], int(expected_exp.timestamp()), delta=5,
msg=f"Expected {int(expected_exp.timestamp())} but got {decoded['exp']}"
)

@freeze_time("2024-01-01 00:00:00+00:00")
def test_renew_token_valid(self):
"""Test that renew_token creates a new token with extended expiration and new jti using UTC-aware time."""
from datetime import datetime

# Freeze time at the initial moment.
fixed_now = datetime.now(timezone.utc)
payload = {"user_id": 456}
original_token = jwt_manager.generate_token(payload.copy(), TEST_SECRET)

# Advance time by 1 second for renewal.
new_time = fixed_now + timedelta(seconds=1)
with freeze_time(new_time):
new_token = jwt_manager.renew_token(original_token, TEST_SECRET)
decoded_new = jwt.decode(new_token, TEST_SECRET, algorithms=[TOKEN_ALGORITHM])

decoded_original = jwt.decode(original_token, TEST_SECRET, algorithms=[TOKEN_ALGORITHM])
self.assertNotEqual(decoded_original.get("jti"), decoded_new.get("jti"))

expected_new_exp = new_time + timedelta(seconds=TOKEN_EXPIRATION_SECONDS)
self.assertAlmostEqual(
decoded_new["exp"], int(expected_new_exp.timestamp()), delta=5,
msg=f"Expected {int(expected_new_exp.timestamp())} but got {decoded_new['exp']}"
)

def test_validate_token_valid(self):
"""Test that a valid token is decoded correctly."""
payload = {"user_id": 789}
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET)
decoded = jwt_manager.validate_token(token, TEST_SECRET)
self.assertEqual(decoded.get("user_id"), payload["user_id"])
self.assertIn("jti", decoded)

def test_validate_token_expired(self):
"""Test that validate_token raises ExpiredSignatureError for expired tokens."""
payload = {"user_id": 101}
# Manually set expiration to past time.
past = dt.datetime.utcnow() - timedelta(seconds=10)
payload["exp"] = past
payload["jti"] = "dummy-jti"
token = jwt.encode(payload, TEST_SECRET, algorithm=TOKEN_ALGORITHM)
with self.assertRaises(ExpiredSignatureError):
jwt_manager.validate_token(token, TEST_SECRET)

def test_validate_token_invalid_signature(self):
"""Test that validate_token raises InvalidTokenError for tampered tokens."""
payload = {"user_id": 202}
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET)

# Split the token into header, payload, and signature.
parts = token.split('.')
self.assertEqual(len(parts), 3, "Token must have three parts separated by dots")

# Tamper with the signature part explicitly.
sig = parts[2]
# Change the last character in the signature part.
tampered_sig = sig[:-1] + ('a' if sig[-1] != 'a' else 'b')
tampered_token = f"{parts[0]}.{parts[1]}.{tampered_sig}"

with self.assertRaises(InvalidTokenError):
jwt_manager.validate_token(tampered_token, TEST_SECRET)

@patch("whispr.server.auth.jwt_manager.is_token_revoked", return_value=True)
def test_validate_token_revoked(self, mock_revoked):
"""Test that validate_token raises an exception when token is revoked."""
payload = {"user_id": 303}
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET)
with self.assertRaises(Exception) as context:
jwt_manager.validate_token(token, TEST_SECRET)
self.assertIn("revoked", str(context.exception).lower())
mock_revoked.assert_called()

@patch("whispr.server.auth.jwt_manager.file_utils.write_file_atomic")
def test_revoke_token(self, mock_write_atomic):
"""Test that revoke_token adds token jti to the revocation list and calls atomic file write."""
payload = {"user_id": 404}
token = jwt_manager.generate_token(payload.copy(), TEST_SECRET)
decoded = jwt.decode(token, TEST_SECRET, algorithms=[TOKEN_ALGORITHM])
jti = decoded.get("jti")
self.assertNotIn(jti, jwt_manager._revoked_tokens)
jwt_manager.revoke_token(token)
self.assertIn(jti, jwt_manager._revoked_tokens)
# Ensure atomic write was called for persistence.
self.assertTrue(mock_write_atomic.called)

@patch("whispr.server.auth.jwt_manager.file_utils.write_file_atomic")
@patch("whispr.server.auth.jwt_manager.file_utils.read_file")
def test_persist_and_load_tokens(self, mock_read_file, mock_write_atomic):
"""Test persist_token and load_persisted_tokens functions using file_utils mocks."""
# Simulate no tokens file existing initially.
with patch("os.path.exists", return_value=False):
# Persist a token.
test_token = "sample.token.value"
jwt_manager.persist_token(test_token)
self.assertTrue(mock_write_atomic.called)

# Now simulate loading tokens from a file.
tokens_list = ["token1", "token2"]
mock_read_file.return_value = json.dumps(tokens_list)
with patch("os.path.exists", return_value=True):
loaded_tokens = jwt_manager.load_persisted_tokens()
self.assertEqual(loaded_tokens, tokens_list)

@patch("whispr.server.auth.jwt_manager.file_utils.write_file_atomic")
def test_persist_token_file_error(self, mock_write_atomic):
"""Test persist_token raises an error when file_utils.write_file_atomic fails."""
mock_write_atomic.side_effect = Exception("Write error")
with patch("os.path.exists", return_value=False):
with self.assertRaises(Exception) as context:
jwt_manager.persist_token("token_error")
self.assertIn("Write error", str(context.exception))

0 comments on commit 0a20738

Please sign in to comment.