Skip to content

Commit

Permalink
feat: flatten (de)serialization of custom user claims (#1159)
Browse files Browse the repository at this point in the history
* feat: initial implementation

* test: implement more jwt tests

* test: reduce code duplication with rstest

* test: rename variables

* test: remove unreferenced snapshot

* fix: examples compilation

* test: add missing snapshot

* test: add missing snapshot

* test: fix broken template

* fix: fix starters compilation

* Update examples/demo/src/models/users.rs

Co-authored-by: Jorge Hermo <[email protected]>

* chore: remove todos

* feat: only derive eq and partial eq in test target

* chore: undo changes in starters

* fix: CI compilation

* Added `total_items` to pagination view & response (#1197)

Co-authored-by: Elad Kaplan <[email protected]>

---------

Co-authored-by: Elad Kaplan <[email protected]>
Co-authored-by: Timon Klinkert <[email protected]>
  • Loading branch information
3 people authored Jan 26, 2025
1 parent fcc60c7 commit 7f2e5f8
Show file tree
Hide file tree
Showing 12 changed files with 281 additions and 42 deletions.
4 changes: 2 additions & 2 deletions examples/demo/src/controllers/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async fn register(
let jwt_secret = ctx.config.get_jwt_config()?;

let token = user
.generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
.generate_jwt(&jwt_secret.secret, jwt_secret.expiration)
.or_else(|_| unauthorized("unauthorized!"))?;
format::json(UserSession::new(&user, &token))
}
Expand Down Expand Up @@ -130,7 +130,7 @@ async fn login(State(ctx): State<AppContext>, Json(params): Json<LoginParams>) -
let jwt_secret = ctx.config.get_jwt_config()?;

let token = user
.generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
.generate_jwt(&jwt_secret.secret, jwt_secret.expiration)
.or_else(|_| unauthorized("unauthorized!"))?;

format::json(UserSession::new(&user, &token))
Expand Down
12 changes: 5 additions & 7 deletions examples/demo/src/models/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use async_trait::async_trait;
use chrono::offset::Local;
use loco_rs::{auth::jwt, hash, prelude::*};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Map;
use uuid::Uuid;

pub use super::_entities::users::{self, ActiveModel, Entity, Model};
Expand Down Expand Up @@ -216,12 +216,10 @@ impl super::_entities::users::Model {
/// # Errors
///
/// when could not convert user claims to jwt token
pub fn generate_jwt(&self, secret: &str, expiration: &u64) -> ModelResult<String> {
Ok(jwt::JWT::new(secret).generate_token(
expiration,
self.pid.to_string(),
Some(json!({"Roll": "Administrator"})),
)?)
pub fn generate_jwt(&self, secret: &str, expiration: u64) -> ModelResult<String> {
let mut claims = Map::new();
claims.insert("Role".to_string(), "Administrator".into());
Ok(jwt::JWT::new(secret).generate_token(expiration, self.pid.to_string(), claims)?)
}
}

Expand Down
21 changes: 9 additions & 12 deletions loco-new/base_template/src/controllers/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ async fn register(
/// Verify register user. if the user not verified his email, he can't login to
/// the system.
#[debug_handler]
async fn verify(
State(ctx): State<AppContext>,
Path(token): Path<String>,
) -> Result<Response> {
async fn verify(State(ctx): State<AppContext>, Path(token): Path<String>) -> Result<Response> {
let user = users::Model::find_by_verification_token(&ctx.db, &token).await?;

if user.email_verified_at.is_some() {
Expand Down Expand Up @@ -143,7 +140,7 @@ async fn login(State(ctx): State<AppContext>, Json(params): Json<LoginParams>) -
let jwt_secret = ctx.config.get_jwt_config()?;

let token = user
.generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
.generate_jwt(&jwt_secret.secret, jwt_secret.expiration)
.or_else(|_| unauthorized("unauthorized!"))?;

format::json(LoginResponse::new(&user, &token))
Expand All @@ -158,14 +155,14 @@ async fn current(auth: auth::JWT, State(ctx): State<AppContext>) -> Result<Respo
/// Magic link authentication provides a secure and passwordless way to log in to the application.
///
/// # Flow
/// 1. **Request a Magic Link**:
/// A registered user sends a POST request to `/magic-link` with their email.
/// If the email exists, a short-lived, one-time-use token is generated and sent to the user's email.
/// 1. **Request a Magic Link**:
/// A registered user sends a POST request to `/magic-link` with their email.
/// If the email exists, a short-lived, one-time-use token is generated and sent to the user's email.
/// For security and to avoid exposing whether an email exists, the response always returns 200, even if the email is invalid.
///
/// 2. **Click the Magic Link**:
/// The user clicks the link (/magic-link/{token}), which validates the token and its expiration.
/// If valid, the server generates a JWT and responds with a [`LoginResponse`].
/// 2. **Click the Magic Link**:
/// The user clicks the link (/magic-link/{token}), which validates the token and its expiration.
/// If valid, the server generates a JWT and responds with a [`LoginResponse`].
/// If invalid or expired, an unauthorized response is returned.
///
/// This flow enhances security by avoiding traditional passwords and providing a seamless login experience.
Expand Down Expand Up @@ -211,7 +208,7 @@ async fn magic_link_verify(
let jwt_secret = ctx.config.get_jwt_config()?;

let token = user
.generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
.generate_jwt(&jwt_secret.secret, jwt_secret.expiration)
.or_else(|_| unauthorized("unauthorized!"))?;

format::json(LoginResponse::new(&user, &token))
Expand Down
5 changes: 3 additions & 2 deletions loco-new/base_template/src/models/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use async_trait::async_trait;
use chrono::{offset::Local, Duration};
use loco_rs::{auth::jwt, hash, prelude::*};
use serde::{Deserialize, Serialize};
use serde_json::Map;
use uuid::Uuid;

pub use super::_entities::users::{self, ActiveModel, Entity, Model};
Expand Down Expand Up @@ -258,8 +259,8 @@ impl Model {
/// # Errors
///
/// when could not convert user claims to jwt token
pub fn generate_jwt(&self, secret: &str, expiration: &u64) -> ModelResult<String> {
Ok(jwt::JWT::new(secret).generate_token(expiration, self.pid.to_string(), None)?)
pub fn generate_jwt(&self, secret: &str, expiration: u64) -> ModelResult<String> {
Ok(jwt::JWT::new(secret).generate_token(expiration, self.pid.to_string(), Map::new())?)
}
}

Expand Down
111 changes: 97 additions & 14 deletions src/auth/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@
//!
//! This module provides functionality for working with JSON Web Tokens (JWTs)
//! and password hashing.
use jsonwebtoken::{
decode, encode, errors::Result as JWTResult, get_current_timestamp, Algorithm, DecodingKey,
EncodingKey, Header, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::{Map, Value};

/// Represents the default JWT algorithm used by the [`JWT`] struct.
const JWT_ALGORITHM: Algorithm = Algorithm::HS512;

/// Represents the claims associated with a user JWT.
#[cfg_attr(test, derive(Eq, PartialEq))]
#[derive(Debug, Serialize, Deserialize)]
pub struct UserClaims {
pub pid: String,
exp: u64,
pub claims: Option<Value>,
#[serde(default, flatten)]
pub claims: Map<String, Value>,
}

/// Represents the JWT configuration and operations.
Expand Down Expand Up @@ -61,17 +62,18 @@ impl JWT {
///
/// # Example
/// ```rust
/// use serde_json::Map;
/// use loco_rs::auth;
///
/// auth::jwt::JWT::new("PqRwLF2rhHe8J22oBeHy").generate_token(&604800, "PID".to_string(), None);
/// auth::jwt::JWT::new("PqRwLF2rhHe8J22oBeHy").generate_token(604800, "PID".to_string(), Map::new());
/// ```
pub fn generate_token(
&self,
expiration: &u64,
expiration: u64,
pid: String,
claims: Option<Value>,
claims: Map<String, Value>,
) -> JWTResult<String> {
let exp = get_current_timestamp().saturating_add(*expiration);
let exp = get_current_timestamp().saturating_add(expiration);

let claims = UserClaims { pid, exp, claims };

Expand Down Expand Up @@ -119,18 +121,27 @@ mod tests {
use super::*;

#[rstest]
#[case("valid token", 60, None)]
#[case("token expired", 1, None)]
#[case("valid token and custom claims", 60, Some(json!({})))]
#[tokio::test]
async fn can_generate_token(
#[case("valid token", 60, json!({}))]
#[case("token expired", 1, json!({}))]
#[case("valid token and custom string claims", 60, json!({ "custom": "claim",}))]
#[case("valid token and custom boolean claims",60, json!({ "custom": true,}))]
#[case("valid token and custom number claims",60, json!({ "custom": 123,}))]
#[case("valid token and custom nested claims",60, json!({ "level1": { "level2": { "level3": "claim" } } }))]
#[case("valid token and custom array claims",60, json!({ "array": [1, 2, 3] }))]
#[case("valid token and custom nested array claims",60, json!({ "level1": { "level2": { "level3": [1, 2, 3] } } }))]
fn can_generate_token(
#[case] test_name: &str,
#[case] expiration: u64,
#[case] claims: Option<Value>,
#[case] json_claims: Value,
) {
let claims = json_claims
.as_object()
.expect("case input claims must be an object")
.clone();
let jwt = JWT::new("PqRwLF2rhHe8J22oBeHy");

let token = jwt
.generate_token(&expiration, "pid".to_string(), claims)
.generate_token(expiration, "pid".to_string(), claims)
.unwrap();

std::thread::sleep(std::time::Duration::from_secs(3));
Expand All @@ -140,4 +151,76 @@ mod tests {
assert_debug_snapshot!(test_name, jwt.validate(&token));
});
}

#[rstest]
#[case::without_custom_claims(json!({}))]
#[case::with_custom_string_claims(json!({ "custom": "claim",}))]
#[case::with_custom_boolean_claims(json!({ "custom": true,}))]
#[case::with_custom_number_claims(json!({ "custom": 123,}))]
#[case::with_custom_nested_claims(json!({ "level1": { "level2": { "level3": "claim" } } }))]
#[case::with_custom_array_claims(json!({ "array": [1, 2, 3] }))]
#[case::with_custom_nested_array_claims(json!({ "level1": { "level2": { "level3": [1, 2, 3] } } }))]
// we use `Value` to reduce code duplicity in the case inputs
fn serialize_user_claims(#[case] json_claims: Value) {
let claims = json_claims
.as_object()
.expect("case input claims must be an object")
.clone();
let input_user_claims = UserClaims {
pid: "pid".to_string(),
exp: 60,
claims: claims.clone(),
};

let mut expected_claim = Map::new();
expected_claim.insert("pid".to_string(), "pid".into());
expected_claim.insert("exp".to_string(), 60.into());
// we add the claims in a flattened way
expected_claim.extend(claims);
let expected_value = Value::from(expected_claim);

// We check between `Value` instead of `String` to avoid key ordering issues when serializing.
// It is because `expected_value` has all the keys in alphabetical order, as the `Value` serialization ensures that.
// But when serializing `input_user_claims`, first the `pid` and `exp` fields are serialized (in that order),
// and then the claims are serialized in alfabetic order. So, the resulting JSON string from the `input_user_claims` serialization
// may have the `pid` and `exp` fields unordered which differs from the `Value` serialization.
assert_eq!(
expected_value,
serde_json::to_value(&input_user_claims).unwrap()
);
}

#[rstest]
#[case::without_custom_claims(json!({}))]
#[case::with_custom_string_claims(json!({ "custom": "claim",}))]
#[case::with_custom_boolean_claims(json!({ "custom": true,}))]
#[case::with_custom_number_claims(json!({ "custom": 123,}))]
#[case::with_custom_nested_claims(json!({ "level1": { "level2": { "level3": "claim" } } }))]
#[case::with_custom_array_claims(json!({ "array": [1, 2, 3] }))]
#[case::with_custom_nested_array_claims(json!({ "level1": { "level2": { "level3": [1, 2, 3] } } }))]
// we use `Value` to reduce code duplicity in the case inputs
fn deserialize_user_claims(#[case] json_claims: Value) {
let claims = json_claims
.as_object()
.expect("case input claims must be an object")
.clone();

let mut input_claims = Map::new();
input_claims.insert("pid".to_string(), "pid".into());
input_claims.insert("exp".to_string(), 60.into());
// we add the claims in a flattened way
input_claims.extend(claims.clone());
let input_json = Value::from(input_claims).to_string();

let expected_user_claims = UserClaims {
pid: "pid".to_string(),
exp: 60,
claims,
};

assert_eq!(
expected_user_claims,
serde_json::from_str(&input_json).unwrap()
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
---
source: src/auth/jwt.rs
expression: jwt.validate(&token)
---
Ok(
TokenData {
header: Header {
typ: Some(
"JWT",
),
alg: HS512,
cty: None,
jku: None,
jwk: None,
kid: None,
x5u: None,
x5c: None,
x5t: None,
x5t_s256: None,
},
claims: UserClaims {
pid: "pid",
exp: EXP,
claims: {
"array": Array [
Number(1),
Number(2),
Number(3),
],
},
},
},
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
source: src/auth/jwt.rs
assertion_line: 133
expression: jwt.validate(&token)
---
Ok(
Expand All @@ -22,9 +21,9 @@ Ok(
claims: UserClaims {
pid: "pid",
exp: EXP,
claims: Some(
Object {},
),
claims: {
"custom": Bool(true),
},
},
},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
---
source: src/auth/jwt.rs
expression: jwt.validate(&token)
---
Ok(
TokenData {
header: Header {
typ: Some(
"JWT",
),
alg: HS512,
cty: None,
jku: None,
jwk: None,
kid: None,
x5u: None,
x5c: None,
x5t: None,
x5t_s256: None,
},
claims: UserClaims {
pid: "pid",
exp: EXP,
claims: {
"level1": Object {
"level2": Object {
"level3": Array [
Number(1),
Number(2),
Number(3),
],
},
},
},
},
},
)
Loading

0 comments on commit 7f2e5f8

Please sign in to comment.