Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to async-session v4 #50

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Replace RwLock with Mutex
Co-authored-by: Max Countryman <maxc@me.com>
  • Loading branch information
sbihel and maxcountryman committed Sep 11, 2023
commit 079da64729ddaed5d238bcec1abbcfb83a0f6411
68 changes: 18 additions & 50 deletions src/extractors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,46 @@
use std::ops::{Deref, DerefMut};

use axum::{async_trait, extract::FromRequestParts, http::request::Parts, Extension};
use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard};
use tokio::sync::OwnedMutexGuard;

use crate::SessionHandle;

/// An extractor which provides a readable session. Sessions may have many
/// readers.
#[derive(Debug)]
pub struct ReadableSession {
session: OwnedRwLockReadGuard<async_session::Session>,
pub struct Session {
session_guard: OwnedMutexGuard<async_session::Session>,
}

impl Deref for ReadableSession {
type Target = OwnedRwLockReadGuard<async_session::Session>;
impl Deref for Session {
type Target = OwnedMutexGuard<async_session::Session>;

fn deref(&self) -> &Self::Target {
&self.session
&self.session_guard
}
}

#[async_trait]
impl<S> FromRequestParts<S> for ReadableSession
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Extension(session_handle): Extension<SessionHandle> =
Extension::from_request_parts(parts, state)
.await
.expect("Session extension missing. Is the session layer installed?");
let session = session_handle.read_owned().await;

Ok(Self { session })
}
}

/// An extractor which provides a writable session. Sessions may have only one
/// writer.
#[derive(Debug)]
pub struct WritableSession {
session: OwnedRwLockWriteGuard<async_session::Session>,
}

impl Deref for WritableSession {
type Target = OwnedRwLockWriteGuard<async_session::Session>;

fn deref(&self) -> &Self::Target {
&self.session
}
}

impl DerefMut for WritableSession {
impl DerefMut for Session {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.session
&mut self.session_guard
}
}

#[async_trait]
impl<S> FromRequestParts<S> for WritableSession
impl<S> FromRequestParts<S> for Session
where
S: Send + Sync,
S: Send + Sync + Clone,
{
type Rejection = std::convert::Infallible;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Extension(session_handle): Extension<SessionHandle> =
Extension::from_request_parts(parts, state)
.await
.expect("Session extension missing. Is the session layer installed?");
let session = session_handle.write_owned().await;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
use axum::RequestPartsExt;
let Extension(session) = parts
.extract::<Extension<SessionHandle>>()
.await
.expect("Session extension missing. Is the session layer installed?");

Ok(Self { session })
let session_guard = session.lock_owned().await;
Ok(Self { session_guard })
}
}
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
//! ```rust,no_run
//! use async_session_memory_store::MemoryStore;
//! use axum::{routing::get, Router};
//! use axum_sessions::{extractors::WritableSession, PersistencePolicy, SessionLayer};
//! use axum_sessions::{extractors::Session, PersistencePolicy, SessionLayer};
//!
//! #[tokio::main]
//! async fn main() {
//! let store = MemoryStore::new();
//! let secret = b"..."; // MUST be at least 64 bytes!
//! let session_layer = SessionLayer::new(store, secret);
//!
//! async fn handler(mut session: WritableSession) {
//! async fn handler(mut session: Session) {
//! session
//! .insert("foo", 42)
//! .expect("Could not store the answer.");
Expand Down Expand Up @@ -56,7 +56,7 @@
//!
//! async fn handle(request: Request<Body>) -> Result<Response<Body>, Infallible> {
//! let session_handle = request.extensions().get::<SessionHandle>().unwrap();
//! let session = session_handle.read().await;
//! let session = session_handle.lock().await;
//! // Use the session as you'd like.
//!
//! Ok(Response::new(Body::empty()))
Expand Down
21 changes: 10 additions & 11 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use futures::future::BoxFuture;
use hmac::{Hmac, Mac};
use sha2::{digest::generic_array::GenericArray, Sha256};
use tokio::sync::RwLock;
use tokio::sync::Mutex;
use tower::{Layer, Service};

const BASE64_DIGEST_LEN: usize = 44;
Expand All @@ -34,7 +34,7 @@ const BASE64_DIGEST_LEN: usize = 44;
/// than using the handle directly. A notable exception is when using this
/// library as a generic Tower middleware: such use cases will consume the
/// handle directly.
pub type SessionHandle = Arc<RwLock<async_session::Session>>;
pub type SessionHandle = Arc<Mutex<async_session::Session>>;

/// Controls how the session data is persisted and created.
#[derive(Clone)]
Expand Down Expand Up @@ -189,7 +189,7 @@ where
None => None,
};

Arc::new(RwLock::new(
Arc::new(Mutex::new(
session
.and_then(async_session::Session::validate)
.unwrap_or_default(),
Expand Down Expand Up @@ -331,22 +331,21 @@ where
let mut inner = std::mem::replace(&mut self.inner, inner);
Box::pin(async move {
let session_handle = session_layer.load_or_create(cookie_value.clone()).await;
let mut session = session_handle.lock().await;

let mut session = session_handle.write().await;
if let Some(ttl) = session_layer.session_ttl {
(*session).expire_in(ttl);
}
drop(session);

request.extensions_mut().insert(session_handle.clone());

let mut response = inner.call(request).await?;

let session = session_handle.read().await;
let mut session = session_handle.lock().await;
let (session_is_destroyed, session_data_changed) =
(session.is_destroyed(), session.data_changed());
drop(session);

let mut session = session_handle.write().await;
if session_is_destroyed {
if let Err(e) = session_layer.store.destroy_session(&mut session).await {
tracing::error!("Failed to destroy session: {:?}", e);
Expand Down Expand Up @@ -701,7 +700,7 @@ mod tests {
async fn echo_read_session(req: Request<Body>) -> Result<Response<Body>, BoxError> {
{
let session_handle = req.extensions().get::<SessionHandle>().unwrap();
let session = session_handle.write().await;
let session = session_handle.lock().await;
let _ = session.get::<String>("signed_in").unwrap_or_default();
}
Ok(Response::new(req.into_body()))
Expand All @@ -710,7 +709,7 @@ mod tests {
async fn echo_with_session_change(req: Request<Body>) -> Result<Response<Body>, BoxError> {
{
let session_handle = req.extensions().get::<SessionHandle>().unwrap();
let mut session = session_handle.write().await;
let mut session = session_handle.lock().await;
session.insert("signed_in", true).unwrap();
}
Ok(Response::new(req.into_body()))
Expand All @@ -720,7 +719,7 @@ mod tests {
// Destroy the session if we received a session cookie.
if req.headers().get(COOKIE).is_some() {
let session_handle = req.extensions().get::<SessionHandle>().unwrap();
let mut session = session_handle.write().await;
let mut session = session_handle.lock().await;
session.destroy();
}

Expand All @@ -732,7 +731,7 @@ mod tests {

{
let session_handle = req.extensions().get::<SessionHandle>().unwrap();
let mut session = session_handle.write().await;
let mut session = session_handle.lock().await;
counter = session
.get("counter")
.map(|count: i32| count + 1)
Expand Down