diff --git a/atrium-api/Cargo.toml b/atrium-api/Cargo.toml index 61a8effe..494782d2 100644 --- a/atrium-api/Cargo.toml +++ b/atrium-api/Cargo.toml @@ -27,7 +27,7 @@ agent = ["tokio/sync"] atrium-xrpc-client = "0.2.0" futures = "0.3.28" serde_json = "1.0.107" -tokio = { version = "1.33.0", features = ["test-util", "macros"] } +tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread"] } [package.metadata.docs.rs] all-features = true diff --git a/atrium-api/README.md b/atrium-api/README.md index a3f85ffe..e7dab192 100644 --- a/atrium-api/README.md +++ b/atrium-api/README.md @@ -17,7 +17,7 @@ use atrium_api::client::AtpServiceClient; use atrium_api::com::atproto::server::create_session::Input; use atrium_xrpc_client::reqwest::ReqwestClient; -#[tokio::main(flavor = "current_thread")] +#[tokio::main] async fn main() -> Result<(), Box> { let client = AtpServiceClient::new(ReqwestClient::new("https://bsky.social")); let result = client @@ -34,3 +34,32 @@ async fn main() -> Result<(), Box> { Ok(()) } ``` + +### `AtpAgent` + +While `AtpServiceClient` can be used for simple XRPC calls, it is better to use `AtpAgent`, which has practical features such as session management. + +```rust,no_run +use atrium_api::agent::{store::MemorySessionStore, AtpAgent}; +use atrium_xrpc_client::reqwest::ReqwestClient; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let agent = AtpAgent::new( + ReqwestClient::new("https://bsky.social"), + MemorySessionStore::default(), + ); + agent.login("alice@mail.com", "hunter2").await?; + let result = agent + .api + .app + .bsky + .actor + .get_profile(atrium_api::app::bsky::actor::get_profile::Parameters { + actor: "bsky.app".into(), + }) + .await?; + println!("{:?}", result); + Ok(()) +} +``` diff --git a/atrium-api/src/agent.rs b/atrium-api/src/agent.rs index 3244af30..61182787 100644 --- a/atrium-api/src/agent.rs +++ b/atrium-api/src/agent.rs @@ -1,219 +1,43 @@ -//! An ATP "Agent". -//! Manages session token lifecycles and provides all XRPC methods. +//! Implementation of [`AtpAgent`] and definitions of [`SessionStore`] for it. +mod inner; +pub mod store; + +use self::store::SessionStore; use crate::client::Service; -use async_trait::async_trait; -use atrium_xrpc::error::{Error, XrpcErrorKind}; -use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, XrpcResult}; -use http::{Method, Request, Response}; -use serde::{de::DeserializeOwned, Serialize}; +use atrium_xrpc::error::Error; +use atrium_xrpc::XrpcClient; use std::sync::Arc; -use tokio::sync::{Mutex, Notify, RwLock}; /// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output) pub type Session = crate::com::atproto::server::create_session::Output; -pub struct SessionAuthWrapper -where - T: XrpcClient + Send + Sync, -{ - session: Arc>>, - inner: T, -} - -#[async_trait] -impl HttpClient for SessionAuthWrapper -where - T: XrpcClient + Send + Sync, -{ - async fn send_http( - &self, - req: Request>, - ) -> Result>, Box> { - self.inner.send_http(req).await - } -} - -#[async_trait] -impl XrpcClient for SessionAuthWrapper -where - T: XrpcClient + Send + Sync, -{ - fn base_uri(&self) -> String { - self.inner.base_uri() - } - async fn auth(&self, is_refresh: bool) -> Option { - self.session.read().await.as_ref().map(|session| { - if is_refresh { - session.refresh_jwt.clone() - } else { - session.access_jwt.clone() - } - }) - } -} - -pub struct RefreshWrapper +/// An ATP "Agent". +/// Manages session token lifecycles and provides convenience methods. +pub struct AtpAgent where + S: SessionStore + Send + Sync, T: XrpcClient + Send + Sync, { - session: Arc>>, - inner: T, - is_refreshing: Arc>, - notify: Arc, + store: Arc>, + pub api: Service>, } -impl RefreshWrapper +impl AtpAgent where + S: SessionStore + Send + Sync, T: XrpcClient + Send + Sync, { - // Internal helper to refresh sessions - // - Wraps the actual implementation to ensure only one refresh is attempted at a time. - async fn refresh_session(&self) { - { - let mut is_refreshing = self.is_refreshing.lock().await; - if *is_refreshing { - drop(is_refreshing); - return self.notify.notified().await; - } - *is_refreshing = true; - } - // TODO: Ensure `is_refreshing` is reliably set to false even in the event of unexpected errors within `refresh_session_inner()`. - self.refresh_session_inner().await; - *self.is_refreshing.lock().await = false; - self.notify.notify_waiters(); - } - async fn refresh_session_inner(&self) { - if let Ok(output) = self.call_refresh_session().await { - let mut session = self.session.write().await; - let did_doc = session.as_ref().and_then(|s| s.did_doc.clone()); - let email = session.as_ref().and_then(|s| s.email.clone()); - let email_confirmed = session.as_ref().and_then(|s| s.email_confirmed); - session.replace(Session { - access_jwt: output.access_jwt, - did: output.did, - did_doc, - email, - email_confirmed, - handle: output.handle, - refresh_jwt: output.refresh_jwt, - }); - } else { - self.session.write().await.take(); - } - } - // same as `crate::client::com::atproto::server::Service::refresh_session()` - async fn call_refresh_session( - &self, - ) -> Result< - crate::com::atproto::server::refresh_session::Output, - Error, - > { - let response = self - .inner - .send_xrpc::<(), (), _, _>(&XrpcRequest { - method: Method::POST, - path: "com.atproto.server.refreshSession".into(), - parameters: None, - input: None, - encoding: None, - }) - .await?; - match response { - OutputDataOrBytes::Data(data) => Ok(data), - _ => Err(Error::UnexpectedResponseType), - } - } - fn is_expired(result: &XrpcResult) -> bool - where - O: DeserializeOwned + Send + Sync, - E: DeserializeOwned + Send + Sync, - { - if let Err(Error::XrpcResponse(response)) = &result { - if let Some(XrpcErrorKind::Undefined(body)) = &response.error { - if let Some("ExpiredToken") = &body.error.as_deref() { - return true; - } - } - } - false - } -} - -#[async_trait] -impl HttpClient for RefreshWrapper -where - T: XrpcClient + Send + Sync, -{ - async fn send_http( - &self, - req: Request>, - ) -> Result>, Box> { - self.inner.send_http(req).await - } -} - -#[async_trait] -impl XrpcClient for RefreshWrapper -where - T: XrpcClient + Send + Sync, -{ - fn base_uri(&self) -> String { - self.inner.base_uri() - } - async fn auth(&self, is_refresh: bool) -> Option { - self.inner.auth(is_refresh).await - } - async fn send_xrpc(&self, request: &XrpcRequest) -> XrpcResult - where - P: Serialize + Send + Sync, - I: Serialize + Send + Sync, - O: DeserializeOwned + Send + Sync, - E: DeserializeOwned + Send + Sync, - { - let result = self.inner.send_xrpc(request).await; - // handle session-refreshes as needed - if Self::is_expired(&result) { - self.refresh_session().await; - self.inner.send_xrpc(request).await - } else { - result - } - } -} - -pub struct AtpAgent -where - T: XrpcClient + Send + Sync, -{ - pub api: Service>>, - session: Arc>>, -} - -impl AtpAgent -where - T: XrpcClient + Send + Sync, -{ - pub fn new(xrpc: T) -> Self { - let session = Arc::new(RwLock::new(None)); - let api = Service::new(Arc::new(RefreshWrapper { - session: Arc::clone(&session), - inner: SessionAuthWrapper { - session: Arc::clone(&session), - inner: xrpc, - }, - is_refreshing: Arc::new(Mutex::new(false)), - notify: Arc::new(Notify::new()), - })); - Self { api, session } - } - pub async fn get_session(&self) -> Option { - self.session.read().await.clone() + /// Create a new agent. + pub fn new(xrpc: T, store: S) -> Self { + let store = Arc::new(inner::Store::new(store, xrpc.base_uri())); + let api = Service::new(Arc::new(inner::Client::new(Arc::clone(&store), xrpc))); + Self { store, api } } /// Start a new session with this agent. pub async fn login( &self, - identifier: &str, - password: &str, + identifier: impl AsRef, + password: impl AsRef, ) -> Result> { let result = self .api @@ -221,11 +45,14 @@ where .atproto .server .create_session(crate::com::atproto::server::create_session::Input { - identifier: identifier.into(), - password: password.into(), + identifier: identifier.as_ref().into(), + password: password.as_ref().into(), }) .await?; - self.session.write().await.replace(result.clone()); + self.store.set_session(result.clone()).await; + if let Some(did_doc) = &result.did_doc { + self.store.update_endpoint(did_doc); + } Ok(result) } /// Resume a pre-existing session with this agent. @@ -233,20 +60,25 @@ where &self, session: Session, ) -> Result<(), Error> { - self.session.write().await.replace(session.clone()); + self.store.set_session(session.clone()).await; let result = self.api.com.atproto.server.get_session().await; match result { Ok(output) => { assert_eq!(output.did, session.did); - if let Some(session) = self.session.write().await.as_mut() { + if let Some(mut session) = self.store.get_session().await { + session.did_doc = output.did_doc.clone(); session.email = output.email; session.email_confirmed = output.email_confirmed; session.handle = output.handle; + self.store.set_session(session).await; + } + if let Some(did_doc) = &output.did_doc { + self.store.update_endpoint(did_doc); } Ok(()) } Err(err) => { - self.session.write().await.take(); + self.store.clear_session().await; Err(err) } } @@ -256,8 +88,14 @@ where #[cfg(test)] mod tests { use super::*; + use crate::agent::store::MemorySessionStore; + use crate::did_doc::{DidDocument, Service, VerificationMethod}; + use async_trait::async_trait; + use atrium_xrpc::HttpClient; use futures::future::join_all; + use http::{Request, Response}; use std::collections::HashMap; + use tokio::sync::RwLock; #[derive(Default)] struct DummyResponses { @@ -360,8 +198,8 @@ mod tests { #[tokio::test] async fn test_new() { - let agent = AtpAgent::new(DummyClient::default()); - assert_eq!(agent.get_session().await, None); + let agent = AtpAgent::new(DummyClient::default(), MemorySessionStore::default()); + assert_eq!(agent.store.get_session().await, None); } #[tokio::test] @@ -378,12 +216,12 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client); + let agent = AtpAgent::new(client, MemorySessionStore::default()); agent .login("test", "pass") .await .expect("login should be succeeded"); - assert_eq!(agent.get_session().await, Some(session)); + assert_eq!(agent.store.get_session().await, Some(session)); } // failure with `createSession` error { @@ -393,12 +231,12 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client); + let agent = AtpAgent::new(client, MemorySessionStore::default()); agent .login("test", "bad") .await .expect_err("login should be failed"); - assert_eq!(agent.get_session().await, None); + assert_eq!(agent.store.get_session().await, None); } } @@ -418,8 +256,8 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client); - agent.session.write().await.replace(session); + let agent = AtpAgent::new(client, MemorySessionStore::default()); + agent.store.set_session(session).await; let output = agent .api .com @@ -448,8 +286,8 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client); - agent.session.write().await.replace(session); + let agent = AtpAgent::new(client, MemorySessionStore::default()); + agent.store.set_session(session).await; let output = agent .api .com @@ -460,7 +298,11 @@ mod tests { .expect("get session should be succeeded"); assert_eq!(output.did, "did"); assert_eq!( - agent.get_session().await.map(|session| session.access_jwt), + agent + .store + .get_session() + .await + .map(|session| session.access_jwt), Some("access".into()) ); } @@ -483,8 +325,8 @@ mod tests { ..Default::default() }; let counts = Arc::clone(&client.counts); - let agent = Arc::new(AtpAgent::new(client)); - agent.session.write().await.replace(session); + let agent = Arc::new(AtpAgent::new(client, MemorySessionStore::default())); + agent.store.set_session(session).await; let handles = (0..3).map(|_| { let agent = Arc::clone(&agent); tokio::spawn(async move { agent.api.com.atproto.server.get_session().await }) @@ -499,7 +341,11 @@ mod tests { assert_eq!(output.did, "did"); } assert_eq!( - agent.get_session().await.map(|session| session.access_jwt), + agent + .store + .get_session() + .await + .map(|session| session.access_jwt), Some("access".into()) ); assert_eq!( @@ -529,8 +375,8 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client); - assert_eq!(agent.get_session().await, None); + let agent = AtpAgent::new(client, MemorySessionStore::default()); + assert_eq!(agent.store.get_session().await, None); agent .resume_session(Session { email: Some(String::from("test@example.com")), @@ -538,7 +384,7 @@ mod tests { }) .await .expect("resume_session should be succeeded"); - assert_eq!(agent.get_session().await, Some(session.clone())); + assert_eq!(agent.store.get_session().await, Some(session.clone())); } // failure with `getSession` error { @@ -548,13 +394,13 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client); - assert_eq!(agent.get_session().await, None); + let agent = AtpAgent::new(client, MemorySessionStore::default()); + assert_eq!(agent.store.get_session().await, None); agent .resume_session(session) .await .expect_err("resume_session should be failed"); - assert_eq!(agent.get_session().await, None); + assert_eq!(agent.store.get_session().await, None); } } @@ -574,7 +420,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client); + let agent = AtpAgent::new(client, MemorySessionStore::default()); agent .resume_session(Session { access_jwt: "expired".into(), @@ -582,6 +428,89 @@ mod tests { }) .await .expect("resume_session should be succeeded"); - assert_eq!(agent.get_session().await, Some(session)); + assert_eq!(agent.store.get_session().await, Some(session)); + } + + #[tokio::test] + async fn test_login_with_diddoc() { + let session = session(); + let did_doc = DidDocument { + id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(), + also_known_as: Some(vec!["at://atproto.com".into()]), + verification_method: Some(vec![VerificationMethod { + id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz#atproto".into(), + r#type: "Multikey".into(), + controller: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(), + public_key_multibase: Some( + "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9pribSF".into(), + ), + }]), + service: Some(vec![Service { + id: "#atproto_pds".into(), + r#type: "AtprotoPersonalDataServer".into(), + service_endpoint: "https://bsky.social".into(), + }]), + }; + // success + { + let client = DummyClient { + responses: DummyResponses { + create_session: Some(crate::com::atproto::server::create_session::Output { + did_doc: Some(did_doc.clone()), + ..session.clone() + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemorySessionStore::default()); + agent + .login("test", "pass") + .await + .expect("login should be succeeded"); + assert_eq!(agent.store.get_endpoint(), "https://bsky.social"); + assert_eq!( + agent.api.com.atproto.server.xrpc.base_uri(), + "https://bsky.social" + ); + } + // invalid services + { + let client = DummyClient { + responses: DummyResponses { + create_session: Some(crate::com::atproto::server::create_session::Output { + did_doc: Some(DidDocument { + service: Some(vec![ + Service { + id: "#pds".into(), // not `#atproto_pds` + r#type: "AtprotoPersonalDataServer".into(), + service_endpoint: "https://bsky.social".into(), + }, + Service { + id: "#atproto_pds".into(), + r#type: "AtprotoPersonalDataServer".into(), + service_endpoint: "htps://bsky.social".into(), // invalid url (not `https`) + }, + ]), + ..did_doc.clone() + }), + ..session.clone() + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemorySessionStore::default()); + agent + .login("test", "pass") + .await + .expect("login should be succeeded"); + // not updated + assert_eq!(agent.store.get_endpoint(), "http://localhost:8080"); + assert_eq!( + agent.api.com.atproto.server.xrpc.base_uri(), + "http://localhost:8080" + ); + } } } diff --git a/atrium-api/src/agent/inner.rs b/atrium-api/src/agent/inner.rs new file mode 100644 index 00000000..f590ec54 --- /dev/null +++ b/atrium-api/src/agent/inner.rs @@ -0,0 +1,252 @@ +use super::{Session, SessionStore}; +use crate::did_doc::DidDocument; +use async_trait::async_trait; +use atrium_xrpc::error::{Error, XrpcErrorKind}; +use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, XrpcResult}; +use http::{Method, Request, Response, Uri}; +use serde::{de::DeserializeOwned, Serialize}; +use std::sync::{Arc, RwLock}; +use tokio::sync::{Mutex, Notify}; + +const REFRESH_SESSION: &str = "com.atproto.server.refreshSession"; + +struct SessionStoreClient { + store: Arc>, + inner: T, +} + +#[async_trait] +impl HttpClient for SessionStoreClient +where + S: Send + Sync, + T: HttpClient + Send + Sync, +{ + async fn send_http( + &self, + request: Request>, + ) -> Result>, Box> { + self.inner.send_http(request).await + } +} + +#[async_trait] +impl XrpcClient for SessionStoreClient +where + S: SessionStore + Send + Sync, + T: XrpcClient + Send + Sync, +{ + fn base_uri(&self) -> String { + self.store.get_endpoint() + } + async fn auth(&self, is_refresh: bool) -> Option { + self.store.get_session().await.map(|session| { + if is_refresh { + session.refresh_jwt.clone() + } else { + session.access_jwt.clone() + } + }) + } +} + +pub struct Client { + store: Arc>, + inner: SessionStoreClient, + is_refreshing: Mutex, + notify: Notify, +} + +impl Client +where + S: SessionStore + Send + Sync, + T: XrpcClient + Send + Sync, +{ + pub(crate) fn new(store: Arc>, xrpc: T) -> Self { + let inner = SessionStoreClient { + store: Arc::clone(&store), + inner: xrpc, + }; + Self { + store, + inner, + is_refreshing: Mutex::new(false), + notify: Notify::new(), + } + } + // Internal helper to refresh sessions + // - Wraps the actual implementation to ensure only one refresh is attempted at a time. + async fn refresh_session(&self) { + { + let mut is_refreshing = self.is_refreshing.lock().await; + if *is_refreshing { + drop(is_refreshing); + return self.notify.notified().await; + } + *is_refreshing = true; + } + // TODO: Ensure `is_refreshing` is reliably set to false even in the event of unexpected errors within `refresh_session_inner()`. + self.refresh_session_inner().await; + *self.is_refreshing.lock().await = false; + self.notify.notify_waiters(); + } + async fn refresh_session_inner(&self) { + if let Ok(output) = self.call_refresh_session().await { + if let Some(mut session) = self.store.get_session().await { + session.access_jwt = output.access_jwt; + session.did = output.did; + session.did_doc = output.did_doc.clone(); + session.handle = output.handle; + session.refresh_jwt = output.refresh_jwt; + self.store.set_session(session).await; + } + if let Some(did_doc) = &output.did_doc { + self.store.update_endpoint(did_doc); + } + } else { + self.store.clear_session().await; + } + } + // same as `crate::client::com::atproto::server::Service::refresh_session()` + async fn call_refresh_session( + &self, + ) -> Result< + crate::com::atproto::server::refresh_session::Output, + Error, + > { + let response = self + .inner + .send_xrpc::<(), (), _, _>(&XrpcRequest { + method: Method::POST, + path: REFRESH_SESSION.into(), + parameters: None, + input: None, + encoding: None, + }) + .await?; + match response { + OutputDataOrBytes::Data(data) => Ok(data), + _ => Err(Error::UnexpectedResponseType), + } + } + fn is_expired(result: &XrpcResult) -> bool + where + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync, + { + if let Err(Error::XrpcResponse(response)) = &result { + if let Some(XrpcErrorKind::Undefined(body)) = &response.error { + if let Some("ExpiredToken") = &body.error.as_deref() { + return true; + } + } + } + false + } +} + +#[async_trait] +impl HttpClient for Client +where + S: Send + Sync, + T: HttpClient + Send + Sync, +{ + async fn send_http( + &self, + request: Request>, + ) -> Result>, Box> { + self.inner.send_http(request).await + } +} + +#[async_trait] +impl XrpcClient for Client +where + S: SessionStore + Send + Sync, + T: XrpcClient + Send + Sync, +{ + fn base_uri(&self) -> String { + self.inner.base_uri() + } + async fn send_xrpc(&self, request: &XrpcRequest) -> XrpcResult + where + P: Serialize + Send + Sync, + I: Serialize + Send + Sync, + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync, + { + let result = self.inner.send_xrpc(request).await; + // handle session-refreshes as needed + if Self::is_expired(&result) { + self.refresh_session().await; + self.inner.send_xrpc(request).await + } else { + result + } + } +} + +pub struct Store { + inner: S, + endpoint: RwLock, +} + +impl Store { + pub fn new(inner: S, initial_endpoint: String) -> Self { + Self { + inner, + endpoint: RwLock::new(initial_endpoint), + } + } + pub fn get_endpoint(&self) -> String { + self.endpoint + .read() + .expect("failed to read endpoint") + .clone() + } + pub fn update_endpoint(&self, did_doc: &DidDocument) { + if let Some(endpoint) = Self::get_pds_endpoint(did_doc) { + *self.endpoint.write().expect("failed to write endpoint") = endpoint; + } + } + fn get_pds_endpoint(did_doc: &DidDocument) -> Option { + Self::get_service_endpoint(did_doc, ("#atproto_pds", "AtprotoPersonalDataServer")) + } + fn get_service_endpoint(did_doc: &DidDocument, (id, r#type): (&str, &str)) -> Option { + let full_id = did_doc.id.clone() + id; + if let Some(services) = &did_doc.service { + let service = services + .iter() + .find(|service| service.id == id || service.id == full_id)?; + if service.r#type == r#type && Self::validate_url(&service.service_endpoint) { + return Some(service.service_endpoint.clone()); + } + } + None + } + fn validate_url(url: &str) -> bool { + if let Ok(uri) = url.parse::() { + if let Some(scheme) = uri.scheme() { + if (scheme == "https" || scheme == "http") && uri.host().is_some() { + return true; + } + } + } + false + } +} + +#[async_trait] +impl SessionStore for Store +where + S: SessionStore + Send + Sync, +{ + async fn get_session(&self) -> Option { + self.inner.get_session().await + } + async fn set_session(&self, session: Session) { + self.inner.set_session(session).await; + } + async fn clear_session(&self) { + self.inner.clear_session().await; + } +} diff --git a/atrium-api/src/agent/store.rs b/atrium-api/src/agent/store.rs new file mode 100644 index 00000000..732b7262 --- /dev/null +++ b/atrium-api/src/agent/store.rs @@ -0,0 +1,15 @@ +mod memory; + +pub use self::memory::MemorySessionStore; +pub(crate) use super::Session; +use async_trait::async_trait; + +#[async_trait] +pub trait SessionStore { + #[must_use] + async fn get_session(&self) -> Option; + #[must_use] + async fn set_session(&self, session: Session); + #[must_use] + async fn clear_session(&self); +} diff --git a/atrium-api/src/agent/store/memory.rs b/atrium-api/src/agent/store/memory.rs new file mode 100644 index 00000000..642d861b --- /dev/null +++ b/atrium-api/src/agent/store/memory.rs @@ -0,0 +1,22 @@ +use super::{Session, SessionStore}; +use async_trait::async_trait; +use std::sync::Arc; +use tokio::sync::RwLock; + +#[derive(Default)] +pub struct MemorySessionStore { + session: Arc>>, +} + +#[async_trait] +impl SessionStore for MemorySessionStore { + async fn get_session(&self) -> Option { + self.session.read().await.clone() + } + async fn set_session(&self, session: Session) { + self.session.write().await.replace(session); + } + async fn clear_session(&self) { + self.session.write().await.take(); + } +}