Skip to content

Commit

Permalink
Upgrade Auth & Sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
lucemans committed Jul 30, 2024
1 parent 3c7f3f2 commit f7818bc
Show file tree
Hide file tree
Showing 21 changed files with 425 additions and 337 deletions.
31 changes: 11 additions & 20 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 23 additions & 3 deletions engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,34 @@ chrono = "0.4.38"
color-eyre = "0.6.3"
dotenv = "0.15.0"
dotenvy = "0.15.7"
hex = "0.4.3"
hmac = "0.12.1"
openid = "0.14.0"
poem = "3.0.4"
poem-openapi = { version = "5.0.3", features = ["chrono", "uuid", "email", "email_address", "redoc", "static-files"] }
poem = { version = "3.0.4", git = "https://github.com/poem-web/poem", branch = "master" }
poem-openapi = { version = "5", git = "https://github.com/poem-web/poem", branch = "master", features = [
"chrono",
"uuid",
"sqlx",
"url",
"email",
"email_address",
"redoc",
"static-files",
] }
reqwest = "0.12.5"
serde = "1.0.204"
serde_json = "1.0.120"
serde_with = { version = "3.9.0", features = ["json", "chrono"] }
sqlx = { version = "0.7.4", features = ["runtime-async-std", "tls-rustls", "postgres", "uuid", "chrono", "json", "ipnetwork"] }
sha2 = "0.10.8"
sqlx = { version = "0.7.4", features = [
"runtime-async-std",
"tls-rustls",
"postgres",
"uuid",
"chrono",
"json",
"ipnetwork",
] }
terminal-banner = "0.4.1"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
Expand Down
4 changes: 2 additions & 2 deletions engine/migrations/0002_sessions.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
CREATE TABLE IF NOT EXISTS sessions
(
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
id TEXT PRIMARY KEY NOT NULL,
user_id INT NOT NULL,
user_agent VARCHAR(255) NOT NULL,
user_agent TEXT NOT NULL,
user_ip INET NOT NULL,
last_access TIMESTAMPTZ NOT NULL DEFAULT NOW(),
valid BOOLEAN NOT NULL DEFAULT TRUE
Expand Down
12 changes: 9 additions & 3 deletions engine/src/auth/middleware.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::sync::Arc;

use hmac::{Hmac, Mac};
use poem::{web::Data, Error, FromRequest, Request, RequestBody, Result};
use reqwest::StatusCode;
use uuid::Uuid;
use sha2::Sha256;

use crate::state::AppState;

Expand All @@ -26,12 +27,17 @@ impl<'a> FromRequest<'a> for AuthToken {
.headers()
.get("Authorization")
.and_then(|x| x.to_str().ok())
.and_then(|x| Uuid::parse_str(&x.replace("Bearer ", "")).ok());
.map(|x| x.replace("Bearer ", ""));

match token {
Some(token) => {
// Hash the token
let mut hash = Hmac::<Sha256>::new_from_slice(b"").unwrap();
hash.update(token.as_bytes());
let hash = hex::encode(hash.finalize().into_bytes());

// Check if active session exists with token
let session = SessionState::get_by_id(token, &state.database)
let session = SessionState::try_access(&hash, &state.database)
.await
.unwrap()
.ok_or(Error::from_string(
Expand Down
65 changes: 30 additions & 35 deletions engine/src/auth/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,12 @@ use std::net::IpAddr;
use poem_openapi::Object;
use serde::{Deserialize, Serialize};
use sqlx::types::chrono;
use uuid::Uuid;

use crate::database::Database;

#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct SafeSession {
pub id: String,
pub user_id: i32,
pub user_agent: String,
pub user_ip: IpAddr,
pub last_access: chrono::DateTime<chrono::Utc>,
pub valid: bool,
}

#[derive(sqlx::FromRow, Debug, Clone, Serialize, Deserialize)]
#[derive(sqlx::FromRow, Debug, Clone, Serialize, Deserialize, Object)]
pub struct SessionState {
pub id: Uuid,
pub id: String,
pub user_id: i32,
pub user_agent: String,
pub user_ip: IpAddr,
Expand All @@ -29,14 +18,16 @@ pub struct SessionState {

impl SessionState {
pub async fn new(
session_id: &str,
user_id: i32,
user_agent: &str,
user_ip: &IpAddr,
database: &Database,
) -> Result<Self, sqlx::Error> {
let session = sqlx::query_as::<_, SessionState>(
"INSERT INTO sessions (user_id, user_agent, user_ip) VALUES ($1, $2, $3) RETURNING *",
"INSERT INTO sessions (id, user_id, user_agent, user_ip) VALUES ($1, $2, $3, $4) RETURNING *",
)
.bind(session_id)
.bind(user_id)
.bind(user_agent)
.bind(user_ip)
Expand All @@ -45,7 +36,7 @@ impl SessionState {
Ok(session)
}

pub async fn get_by_id(id: Uuid, database: &Database) -> Result<Option<Self>, sqlx::Error> {
pub async fn get_by_id(id: &str, database: &Database) -> Result<Option<Self>, sqlx::Error> {
let session = sqlx::query_as::<_, SessionState>(
"SELECT * FROM sessions WHERE id = $1 AND valid = TRUE",
)
Expand All @@ -56,8 +47,22 @@ impl SessionState {
Ok(session)
}

pub async fn try_access(id: &str, database: &Database) -> Result<Option<Self>, sqlx::Error> {
let session = sqlx::query_as::<_, SessionState>(
"UPDATE sessions SET last_access = NOW() WHERE id = $1 AND valid = TRUE RETURNING *",
)
.bind(id)
.fetch_optional(&database.pool)
.await?;

Ok(session)
}

/// Get all sessions for a user that are valid
pub async fn get_by_user_id(user_id: i32, database: &Database) -> Result<Vec<Self>, sqlx::Error> {
pub async fn get_by_user_id(
user_id: i32,
database: &Database,
) -> Result<Vec<Self>, sqlx::Error> {
let sessions = sqlx::query_as::<_, SessionState>(
"SELECT * FROM sessions WHERE user_id = $1 AND valid = TRUE",
)
Expand All @@ -69,7 +74,10 @@ impl SessionState {
}

/// Set every session to invalid
pub async fn invalidate_by_user_id(user_id: i32, database: &Database) -> Result<Vec<Self>, sqlx::Error> {
pub async fn invalidate_by_user_id(
user_id: i32,
database: &Database,
) -> Result<Vec<Self>, sqlx::Error> {
let sessions = sqlx::query_as::<_, SessionState>(
"UPDATE sessions SET valid = FALSE WHERE user_id = $1",
)
Expand All @@ -81,7 +89,11 @@ impl SessionState {
}

/// Invalidate all sessions for a user that are older than the given time
pub async fn invalidate_by_user_id_by_time(user_id: i32, database: &Database, invalidate_before: chrono::DateTime<chrono::Utc>) -> Result<Vec<Self>, sqlx::Error> {
pub async fn _invalidate_by_user_id_by_time(
user_id: i32,
database: &Database,
_invalidate_before: chrono::DateTime<chrono::Utc>,
) -> Result<Vec<Self>, sqlx::Error> {
let sessions = sqlx::query_as::<_, SessionState>(
"UPDATE sessions SET valid = FALSE WHERE user_id = $1 AND last_access < $2",
)
Expand All @@ -92,20 +104,3 @@ impl SessionState {
Ok(sessions)
}
}

impl Into<SafeSession> for SessionState {
fn into(self) -> SafeSession {
let id = self.id.to_string();
let id = id[0..6].to_string() + &id[30..];

SafeSession {
// Strip uuid to be abc...xyz
id,
user_id: self.user_id,
user_agent: self.user_agent,
user_ip: self.user_ip,
last_access: self.last_access,
valid: self.valid,
}
}
}
6 changes: 3 additions & 3 deletions engine/src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod user_data;
pub mod property;
pub mod product;
pub mod media;
pub mod product;
pub mod property;
pub mod user_data;
31 changes: 28 additions & 3 deletions engine/src/models/user_data.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
use openid::Userinfo;
use poem_openapi::Object;
use serde::{Deserialize, Serialize};
use sqlx::types::Json;
use tracing::info;
use url::Url;

use crate::database::Database;

#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct User {
pub id: i32,
pub oauth_sub: String,
pub name: String,
pub picture: Option<Url>,
}

#[derive(sqlx::FromRow, Debug, Clone, Serialize, Deserialize)]
pub struct UserData {
pub struct UserEntry {
pub id: i32,
pub oauth_sub: String,
pub oauth_data: Json<Userinfo>,
pub nickname: Option<String>,
}

impl UserData {
impl UserEntry {
pub async fn new(oauth_userinfo: &Userinfo, database: &Database) -> Result<Self, sqlx::Error> {
let sub = oauth_userinfo.sub.as_deref().unwrap();

info!("Initializing new User {:?}", oauth_userinfo);
info!("Initializing new User {:?}", oauth_userinfo.sub);

sqlx::query_as::<_, Self>(
"INSERT INTO users (oauth_sub, oauth_data) VALUES ($1, $2) ON CONFLICT (oauth_sub) DO UPDATE SET oauth_data = $2 RETURNING *"
Expand All @@ -36,3 +46,18 @@ impl UserData {
Ok(user)
}
}

impl From<UserEntry> for User {
fn from(user: UserEntry) -> Self {
Self {
id: user.id,
oauth_sub: user.oauth_sub,
name: user
.nickname
.or(user.oauth_data.nickname.clone())
.or(user.oauth_data.name.clone())
.unwrap_or("Unknown".to_string()),
picture: user.oauth_data.picture.clone(),
}
}
}
Loading

0 comments on commit f7818bc

Please sign in to comment.