From fb68f331b29be6cd7d4a334ca428dcbe42ecacc0 Mon Sep 17 00:00:00 2001 From: Rajdeep Sengupta Date: Sun, 5 May 2024 14:00:20 +0530 Subject: [PATCH] Added renew token feature --- src/core/session.rs | 200 ++++++++++++++++++++++++------- src/errors.rs | 9 +- src/handlers/session_handler.rs | 8 +- src/utils/refresh_token_utils.rs | 50 ++++++++ src/utils/session_utils.rs | 28 +++-- 5 files changed, 241 insertions(+), 54 deletions(-) create mode 100644 src/utils/refresh_token_utils.rs diff --git a/src/core/session.rs b/src/core/session.rs index fc40e21..cc2075f 100644 --- a/src/core/session.rs +++ b/src/core/session.rs @@ -1,5 +1,11 @@ use crate::{ - errors::{Error, Result}, models::session_model::SessionResponse, traits::{decryption::Decrypt, encryption::Encrypt}, utils::{encryption_utils::Encryption, session_utils::{IDToken, RefreshToken}} + errors::{Error, Result}, + models::session_model::SessionResponse, + traits::{decryption::Decrypt, encryption::Encrypt}, + utils::{ + encryption_utils::Encryption, + session_utils::{IDToken, RefreshToken}, + }, }; use bson::{doc, DateTime}; use futures::StreamExt; @@ -58,29 +64,34 @@ impl Session { } } - pub async fn verify( - mongo_client: &Client, - id_token: &str, - ) -> Result { + pub async fn verify(mongo_client: &Client, id_token: &str) -> Result<(IDToken, bool)> { let token_data = match IDToken::verify(&id_token) { - Ok(token_data) => { + Ok(token_verify_result) => { + // check if the session is expired using the boolean + if !token_verify_result.1 { + return Ok(token_verify_result); + } let db = mongo_client.database("test"); let collection_session: Collection = db.collection("sessions"); - let dek_data = match Dek::get(mongo_client, &token_data.uid).await { + let dek_data = match Dek::get(mongo_client, &token_verify_result.0.uid).await { Ok(dek) => dek, Err(e) => return Err(e), }; - let encrypted_id = Encryption::encrypt_data(&token_data.uid, &dek_data.dek); + let encrypted_id = + Encryption::encrypt_data(&token_verify_result.0.uid, &dek_data.dek); let encrypted_id_token = Encryption::encrypt_data(&id_token, &dek_data.dek); - + let session = match collection_session - .count_documents(doc! { - "uid": encrypted_id, - "id_token": encrypted_id_token, - "is_revoked": false, - }, None) + .count_documents( + doc! { + "uid": encrypted_id, + "id_token": encrypted_id_token, + "is_revoked": false, + }, + None, + ) .await { Ok(count) => { @@ -91,7 +102,7 @@ impl Session { message: "Invalid token".to_string(), }) } - }, + } Err(e) => Err(Error::ServerError { message: e.to_string(), }), @@ -101,15 +112,126 @@ impl Session { message: "Invalid token".to_string(), }); } else { - Ok(token_data) + Ok(token_verify_result) } - }, + } Err(e) => return Err(e), }; token_data } - pub async fn get_all_from_uid(mongo_client: &Client, uid: &str) -> Result> { + pub async fn refresh_session( + mongo_client: &Client, + id_token: &str, + refresh_token: &str, + ) -> Result<(String, String)> { + // verify refresh token + match RefreshToken::verify(&refresh_token) { + Ok(_) => {} + Err(e) => return Err(e), + } + match Self::verify(&mongo_client, &id_token).await { + Ok(token_verify_result) => { + if !token_verify_result.1 { + let db = mongo_client.database("test"); + let collection_session: Collection = db.collection("sessions"); + + let dek_data = match Dek::get(mongo_client, &token_verify_result.0.uid).await { + Ok(dek) => dek, + Err(e) => return Err(e), + }; + + let encrypted_id = Encryption::encrypt_data(&token_verify_result.0.uid, &dek_data.dek); + let encrypted_id_token = Encryption::encrypt_data(&id_token, &dek_data.dek); + let encrypted_refresh_token = + Encryption::encrypt_data(&refresh_token, &dek_data.dek); + + match collection_session + .count_documents( + doc! { + "uid": &encrypted_id, + "id_token": &encrypted_id_token, + "refresh_token": &encrypted_refresh_token, + "is_revoked": false, + }, + None, + ) + .await + { + Ok(count) => { + if count == 1 { + // generate a new id token and refresh token + let user = match User::get_from_uid(&mongo_client, &token_verify_result.0.uid).await { + Ok(user) => user, + Err(e) => return Err(e), + }; + let new_id_token = match IDToken::new(&user).sign() { + Ok(token) => token, + Err(_) => "".to_string(), + }; + + let new_refresh_token = match RefreshToken::new(&token_verify_result.0.uid).sign() { + Ok(token) => token, + Err(_) => "".to_string(), + }; + + // encrypt the new tokens + let new_id_token_encrypted = Encryption::encrypt_data(&new_id_token, &dek_data.dek); + let new_refresh_token_encrypted = Encryption::encrypt_data(&new_refresh_token, &dek_data.dek); + + match collection_session + .update_one( + doc! { + "uid": encrypted_id, + "id_token": encrypted_id_token, + "refresh_token": encrypted_refresh_token, + "is_revoked": false, + }, + doc! { + "$set": { + "id_token": new_id_token_encrypted, + "refresh_token": new_refresh_token_encrypted, + "updated_at": DateTime::now(), + } + }, + None, + ) + .await + { + Ok(_) => return Ok((new_id_token, new_refresh_token)), + Err(e) => return Err(Error::ServerError { + message: e.to_string(), + }), + }; + } else { + match Self::revoke_all(&mongo_client, &token_verify_result.0.uid).await { + Ok(_) => return Err(Error::SessionExpired { + message: "Invalid token".to_string(), + }), + Err(e) => return Err(e), + } + } + } + Err(e) => return Err(Error::ServerError { + message: e.to_string(), + }), + }; + } else { + return Err(Error::ActiveSessionExists { + message: "Active Session already exists".to_string(), + }); + } + } + Err(e) => { + return Err(e); + } + }; + } + + pub async fn get_all_from_uid( + mongo_client: &Client, + uid: &str, + ) -> Result> { let db = mongo_client.database("test"); let collection_session: Collection = db.collection("sessions"); @@ -139,21 +261,23 @@ impl Session { match IDToken::verify(&decrypted_session.id_token) { Ok(token) => { println!("{:?}", token); - sessions_res.push( - SessionResponse { - uid: decrypted_session.uid, - email: decrypted_session.email, - user_agent: decrypted_session.user_agent, - is_revoked: decrypted_session.is_revoked, - created_at: decrypted_session.created_at, - updated_at: decrypted_session.updated_at, - } - ); + sessions_res.push(SessionResponse { + uid: decrypted_session.uid, + email: decrypted_session.email, + user_agent: decrypted_session.user_agent, + is_revoked: decrypted_session.is_revoked, + created_at: decrypted_session.created_at, + updated_at: decrypted_session.updated_at, + }); } Err(_) => continue, } } - Err(e) => return Err(Error::ServerError { message: e.to_string() }), + Err(e) => { + return Err(Error::ServerError { + message: e.to_string(), + }) + } } } Ok(sessions_res) @@ -164,11 +288,7 @@ impl Session { let collection_session: Collection = db.collection("sessions"); match collection_session - .update_many( - doc! {"uid": uid}, - doc! {"$set": {"is_revoked": true}}, - None, - ) + .update_many(doc! {"uid": uid}, doc! {"$set": {"is_revoked": true}}, None) .await { Ok(_) => Ok(()), @@ -178,11 +298,7 @@ impl Session { } } - pub async fn revoke( - id_token: &str, - refresh_token: &str, - mongo_client: &Client, - ) -> Result<()> { + pub async fn revoke(id_token: &str, refresh_token: &str, mongo_client: &Client) -> Result<()> { let db = mongo_client.database("test"); let collection_session: Collection = db.collection("sessions"); @@ -219,16 +335,12 @@ impl Session { } } - pub async fn delete_all(mongo_client: &Client, uid: &str) -> Result<()> { let db = mongo_client.database("test"); let collection_session: Collection = db.collection("sessions"); match collection_session - .delete_many( - doc! {"uid": uid}, - None, - ) + .delete_many(doc! {"uid": uid}, None) .await { Ok(_) => Ok(()), diff --git a/src/errors.rs b/src/errors.rs index 631b3a8..18ad2e3 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -29,6 +29,7 @@ pub enum Error { SignatureVerificationError { message: String }, ExpiredSignature { message: String }, SessionExpired { message: String }, + ActiveSessionExists { message: String }, // -- Encryption Errors @@ -127,6 +128,11 @@ impl Error { ClientError::SERVICE_ERROR, ), + Self::ActiveSessionExists { message: _ } => ( + StatusCode::CONFLICT, + ClientError::ACTIVE_SESSION_EXISTS, + ), + _ => ( StatusCode::INTERNAL_SERVER_ERROR, ClientError::SERVICE_ERROR, @@ -147,7 +153,8 @@ pub enum ClientError { INVALID_TOKEN, SIGNATURE_VERIFICATION_ERROR, EXPIRED_SIGNATURE, - SESSION_EXPIRED + SESSION_EXPIRED, + ACTIVE_SESSION_EXISTS, } // region: --- Error Boilerplate diff --git a/src/handlers/session_handler.rs b/src/handlers/session_handler.rs index 1742e9d..37a4655 100644 --- a/src/handlers/session_handler.rs +++ b/src/handlers/session_handler.rs @@ -16,7 +16,13 @@ pub async fn verify_session( // verify the token match Session::verify(&state.mongo_client, &payload.token).await { Ok(data) => { - return Ok(Json(data)); + return { + if data.1 { + Ok(Json(data.0)) + } else { + Err(Error::SessionExpired { message: "Session expired".to_string() }) + } + } } Err(e) => return Err(e), diff --git a/src/utils/refresh_token_utils.rs b/src/utils/refresh_token_utils.rs new file mode 100644 index 0000000..0d56096 --- /dev/null +++ b/src/utils/refresh_token_utils.rs @@ -0,0 +1,50 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct RefreshToken { + uid: String, + iss: String, + iat: usize, + exp: usize, + scope: String, + data: Option>, +} + +impl RefreshToken { + pub fn new(uid: &str, scope: &str) -> Self { + let server_url = + std::env::var("SERVER_URL").unwrap_or_else(|_| "http://localhost:8080".to_string()); + Self { + uid: uid.to_string(), + iss: server_url, + iat: chrono::Utc::now().timestamp() as usize, + exp: chrono::Utc::now().timestamp() as usize + (3600 * 24 * 30), // 30 days + scope: scope.to_string(), + data: None, + } + } + + pub fn sign(&self) -> String { + let private_key = load_private_key().unwrap(); + let rsa = Rsa::private_key_from_pem(&private_key).unwrap(); + let private_key = PKey::from_rsa(rsa).unwrap(); + let mut claims = ClaimsSet:: { + registered: RegisteredClaims { + issuer: Some(From::from(self.iss.clone())), + subject: Some(From::from(self.uid.clone())), + issued_at: Some(From::from(self.iat)), + expiration: Some(From::from(self.exp)), + not_before: None, + ..Default::default() + }, + private: self.clone(), + }; + let header = Header { + algorithm: Algorithm::RS256, + ..Default::default() + }; + encode(&header, &claims, &private_key).unwrap() + } +} \ No newline at end of file diff --git a/src/utils/session_utils.rs b/src/utils/session_utils.rs index ce83c47..e96a020 100644 --- a/src/utils/session_utils.rs +++ b/src/utils/session_utils.rs @@ -96,7 +96,7 @@ impl IDToken { }; } - pub fn verify(token: &str) -> Result { + pub fn verify(token: &str) -> Result<(Self, bool), Error> { let public_key = load_public_key()?; let validation = Validation::new(jwt::Algorithm::RS256); // Try to create a DecodingKey from the public key @@ -113,17 +113,28 @@ impl IDToken { match jwt::decode::(&token, &decoding_key, &validation) { Ok(val) => { let token_data = val.claims; - Ok(token_data) + Ok((token_data, true)) } - Err(e) => match e { + Err(e) => match e.kind() { // check if ExpiredSignature - _ if e.to_string().contains("ExpiredSignature") => { - return Err(Error::ExpiredSignature { - message: "Expired signature".to_string(), - }) + jwt::errors::ErrorKind::ExpiredSignature => { + // get token claims even if it is expired to check the data by decoding it with exp flag set to false + let mut validation = Validation::new(jwt::Algorithm::RS256); + validation.validate_exp = false; + match jwt::decode::(&token, &decoding_key, &validation) { + Ok(val) => { + let token_data = val.claims; + Ok((token_data, false)) + } + Err(_) => { + return Err(Error::ServerError { + message: "Error decoding token".to_string(), + }) + } + } } // check if InvalidSignature - _ if e.to_string().contains("InvalidSignature") => { + jwt::errors::ErrorKind::InvalidSignature => { return Err(Error::SignatureVerificationError { message: "Invalid signature".to_string(), }) @@ -136,6 +147,7 @@ impl IDToken { }, } } + } // RefreshToken struct