Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to async-session v4 #50

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ repository = "https://github.com/maxcountryman/axum-sessions"
documentation = "https://docs.rs/axum-sessions"

[dependencies]
async-session = "3.0.0"
async-session = { git = "https://github.com/http-rs/async-session", rev = "35cb0998f91b81b133c3314414adae4019a62741", default-features = false }
base64 = "0.21.0"
futures = "0.3.21"
hmac = { version = "0.12.1", features = ["std"] }
http-body = "0.4.5"
sha2 = "0.10.6"
tower = "0.4.12"
tracing = "0.1"

Expand All @@ -34,6 +37,7 @@ features = ["sync"]
http = "0.2.8"
hyper = "0.14.19"
serde = "1.0.147"
serde_json = "1.0.93"

[dev-dependencies.rand]
version = "0.8.5"
Expand All @@ -43,3 +47,7 @@ features = ["min_const_gen"]
version = "1.20.1"
default-features = false
features = ["macros", "rt-multi-thread"]

[dev-dependencies.async-session-memory-store]
git = "https://github.com/http-rs/async-session"
rev = "35cb0998f91b81b133c3314414adae4019a62741"
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
//! Using the middleware with axum is straightforward:
//!
//! ```rust,no_run
//! use async_session_memory_store::MemoryStore;
//! use axum::{routing::get, Router};
//! use axum_sessions::{
//! async_session::MemoryStore, extractors::WritableSession, PersistencePolicy, SessionLayer,
//! };
//! use axum_sessions::{extractors::WritableSession, PersistencePolicy, SessionLayer};
//!
//! #[tokio::main]
//! async fn main() {
Expand Down Expand Up @@ -47,8 +46,9 @@
//! ```rust
//! use std::convert::Infallible;
//!
//! use async_session_memory_store::MemoryStore;
//! use axum::http::header::SET_COOKIE;
//! use axum_sessions::{async_session::MemoryStore, SessionHandle, SessionLayer};
//! use axum_sessions::{SessionHandle, SessionLayer};
//! use http::{Request, Response};
//! use hyper::Body;
//! use rand::Rng;
Expand Down
53 changes: 26 additions & 27 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@ use std::{
time::Duration,
};

use async_session::{
base64,
hmac::{Hmac, Mac, NewMac},
sha2::Sha256,
SessionStore,
};
use async_session::SessionStore;
use axum::{
http::{
header::{HeaderValue, COOKIE, SET_COOKIE},
Expand All @@ -21,7 +16,10 @@ use axum::{
response::Response,
};
use axum_extra::extract::cookie::{Cookie, Key, SameSite};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use futures::future::BoxFuture;
use hmac::{Hmac, Mac};
use sha2::{digest::generic_array::GenericArray, Sha256};
use tokio::sync::RwLock;
use tower::{Layer, Service};

Expand Down Expand Up @@ -66,7 +64,10 @@ pub struct SessionLayer<Store> {
key: Key,
}

impl<Store: SessionStore> SessionLayer<Store> {
impl<Store> SessionLayer<Store>
where
Store: SessionStore + Clone + Send + Sync + 'static,
{
/// Creates a layer which will attach a [`SessionHandle`] to requests via an
/// extension. This session is derived from a cryptographically signed
/// cookie. When the client sends a valid, known cookie then the session is
Expand All @@ -86,7 +87,8 @@ impl<Store: SessionStore> SessionLayer<Store> {
/// of your application:
///
/// ```rust
/// # use axum_sessions::{PersistencePolicy, SessionLayer, async_session::MemoryStore, SameSite};
/// # use axum_sessions::{PersistencePolicy, SessionLayer, SameSite};
/// # use async_session_memory_store::MemoryStore;
/// # use std::time::Duration;
/// SessionLayer::new(
/// MemoryStore::new(),
Expand Down Expand Up @@ -183,7 +185,7 @@ impl<Store: SessionStore> SessionLayer<Store> {

async fn load_or_create(&self, cookie_value: Option<String>) -> SessionHandle {
let session = match cookie_value {
Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
Some(cookie_value) => self.store.load_session(&cookie_value).await.ok().flatten(),
None => None,
};

Expand Down Expand Up @@ -243,7 +245,7 @@ impl<Store: SessionStore> SessionLayer<Store> {
mac.update(cookie.value().as_bytes());

// Cookie's new value is [MAC | original-value].
let mut new_value = base64::encode(mac.finalize().into_bytes());
let mut new_value = BASE64.encode(mac.finalize().into_bytes());
new_value.push_str(cookie.value());
cookie.set_value(new_value);
}
Expand All @@ -260,18 +262,21 @@ impl<Store: SessionStore> SessionLayer<Store> {

// Split [MAC | original-value] into its two parts.
let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
let digest = base64::decode(digest_str).map_err(|_| "bad base64 digest")?;
let digest = BASE64.decode(digest_str).map_err(|_| "bad base64 digest")?;

// Perform the verification.
let mut mac = Hmac::<Sha256>::new_from_slice(self.key.signing()).expect("good key");
mac.update(value.as_bytes());
mac.verify(&digest)
mac.verify(GenericArray::from_slice(&digest))
.map(|_| value.to_string())
.map_err(|_| "value did not verify")
}
}

impl<Inner, Store: SessionStore> Layer<Inner> for SessionLayer<Store> {
impl<Inner, Store> Layer<Inner> for SessionLayer<Store>
where
Store: SessionStore + Clone + Send + Sync + 'static,
{
type Service = Session<Inner, Store>;

fn layer(&self, inner: Inner) -> Self::Service {
Expand All @@ -289,13 +294,13 @@ pub struct Session<Inner, Store: SessionStore> {
layer: SessionLayer<Store>,
}

impl<Inner, ReqBody, ResBody, Store: SessionStore> Service<Request<ReqBody>>
for Session<Inner, Store>
impl<Inner, ReqBody, ResBody, Store> Service<Request<ReqBody>> for Session<Inner, Store>
where
Inner: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ResBody: Send + 'static,
ReqBody: Send + 'static,
Inner::Future: Send + 'static,
Store: SessionStore + Clone + Send + Sync + 'static,
{
type Response = Inner::Response;
type Error = Inner::Error;
Expand Down Expand Up @@ -341,13 +346,9 @@ where
(session.is_destroyed(), session.data_changed());
drop(session);
sbihel marked this conversation as resolved.
Show resolved Hide resolved

// Pull out the session so we can pass it to the store without `Clone` blowing
// away the `cookie_value`.
let session = RwLock::into_inner(
Arc::try_unwrap(session_handle).expect("Session handle still has owners."),
);
let mut session = session_handle.write().await;
if session_is_destroyed {
if let Err(e) = session_layer.store.destroy_session(session).await {
if let Err(e) = session_layer.store.destroy_session(&mut session).await {
tracing::error!("Failed to destroy session: {:?}", e);
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
}
Expand All @@ -366,7 +367,7 @@ where
// - If we use the `ChangedOnly` policy, only
// `session.data_changed()` should trigger this branch.
} else if session_layer.should_store(&cookie_value, session_data_changed) {
match session_layer.store.store_session(session).await {
match session_layer.store.store_session(&mut session).await {
Ok(Some(cookie_value)) => {
let cookie = session_layer.build_cookie(cookie_value);
response.headers_mut().append(
Expand All @@ -391,21 +392,19 @@ where

#[cfg(test)]
mod tests {
use async_session::{
serde::{Deserialize, Serialize},
serde_json,
};
use async_session_memory_store::MemoryStore;
use axum::http::{Request, Response};
use http::{
header::{COOKIE, SET_COOKIE},
HeaderValue, StatusCode,
};
use hyper::Body;
use rand::Rng;
use serde::{Deserialize, Serialize};
use tower::{BoxError, Service, ServiceBuilder, ServiceExt};

use super::PersistencePolicy;
use crate::{async_session::MemoryStore, SessionHandle, SessionLayer};
use crate::{SessionHandle, SessionLayer};

#[derive(Deserialize, Serialize, PartialEq, Debug)]
struct Counter {
Expand Down