From 82a9398f2d96391c4b75a32d52a7d60db0e79e6d Mon Sep 17 00:00:00 2001 From: avdb13 Date: Sun, 24 Nov 2024 04:47:13 +0000 Subject: [PATCH] Merge branch 'feature/agent-rework' into oauth-session --- atrium-api/src/agent/atp_agent.rs | 40 ++--- atrium-api/src/agent/atp_agent/inner.rs | 21 +-- atrium-common/src/lib.rs | 3 - atrium-common/src/store.rs | 2 +- atrium-common/src/store/memory.rs | 20 +-- atrium-oauth/identity/src/error.rs | 16 -- atrium-oauth/oauth-client/examples/main.rs | 31 ++-- atrium-oauth/oauth-client/src/error.rs | 2 +- .../oauth-client/src/http_client/dpop.rs | 39 ++--- atrium-oauth/oauth-client/src/oauth_client.rs | 91 ++++++----- .../oauth-client/src/oauth_session.rs | 145 +++++++++-------- atrium-oauth/oauth-client/src/server_agent.rs | 150 ++++++++---------- atrium-oauth/oauth-client/src/store.rs | 1 + .../oauth-client/src/store/session.rs | 11 +- .../oauth-client/src/store/session_getter.rs | 49 ++++++ atrium-oauth/oauth-client/src/store/state.rs | 6 +- atrium-oauth/oauth-client/src/types.rs | 6 +- .../oauth-client/src/types/request.rs | 5 +- atrium-oauth/oauth-client/src/types/token.rs | 14 +- bsky-sdk/src/agent.rs | 20 +-- bsky-sdk/src/agent/builder.rs | 20 +-- bsky-sdk/src/record.rs | 10 +- bsky-sdk/src/record/agent.rs | 4 +- 23 files changed, 362 insertions(+), 344 deletions(-) create mode 100644 atrium-oauth/oauth-client/src/store/session_getter.rs diff --git a/atrium-api/src/agent/atp_agent.rs b/atrium-api/src/agent/atp_agent.rs index f68e61c0..092f92a6 100644 --- a/atrium-api/src/agent/atp_agent.rs +++ b/atrium-api/src/agent/atp_agent.rs @@ -7,7 +7,7 @@ use crate::{ did_doc::DidDocument, types::{string::Did, TryFromUnknown}, }; -use atrium_common::store::MapStore; +use atrium_common::store::Store; use atrium_xrpc::{Error, XrpcClient}; use std::{ops::Deref, sync::Arc}; @@ -16,7 +16,7 @@ pub type AtpSession = crate::com::atproto::server::create_session::Output; pub struct CredentialSession where - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, T: XrpcClient + Send + Sync, { @@ -27,7 +27,7 @@ where impl CredentialSession where - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, T: XrpcClient + Send + Sync, { @@ -152,7 +152,7 @@ where /// Manages session token lifecycles and provides convenience methods. pub struct AtpAgent where - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, T: XrpcClient + Send + Sync, { @@ -161,7 +161,7 @@ where impl AtpAgent where - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, T: XrpcClient + Send + Sync, { @@ -173,7 +173,7 @@ where impl Deref for AtpAgent where - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, T: XrpcClient + Send + Sync, { @@ -191,7 +191,7 @@ mod tests { use crate::com::atproto::server::create_session::OutputData; use crate::did_doc::{DidDocument, Service, VerificationMethod}; use crate::types::TryIntoUnknown; - use atrium_common::store::memory::MemoryMapStore; + use atrium_common::store::memory::MemoryStore; use atrium_xrpc::HttpClient; use http::{HeaderMap, HeaderName, HeaderValue, Request, Response}; use std::collections::HashMap; @@ -319,7 +319,7 @@ mod tests { #[tokio::test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] async fn test_new() { - let agent = AtpAgent::new(MockClient::default(), MemoryMapStore::default()); + let agent = AtpAgent::new(MockClient::default(), MemoryStore::default()); assert_eq!(agent.get_session().await, None); } @@ -338,7 +338,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent.login("test", "pass").await.expect("login should be succeeded"); assert_eq!(agent.get_session().await, Some(session_data.into())); } @@ -348,7 +348,7 @@ mod tests { responses: MockResponses { ..Default::default() }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent.login("test", "bad").await.expect_err("login should be failed"); assert_eq!(agent.get_session().await, None); } @@ -374,7 +374,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent .store .set((), session_data.clone().into()) @@ -412,7 +412,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent .store .set((), session_data.clone().into()) @@ -460,7 +460,7 @@ mod tests { ..Default::default() }; let counts = Arc::clone(&client.counts); - let agent = Arc::new(AtpAgent::new(client, MemoryMapStore::default())); + let agent = Arc::new(AtpAgent::new(client, MemoryStore::default())); agent .store .set((), session_data.clone().into()) @@ -519,7 +519,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); assert_eq!(agent.get_session().await, None); agent .resume_session( @@ -539,7 +539,7 @@ mod tests { responses: MockResponses { ..Default::default() }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); assert_eq!(agent.get_session().await, None); agent .resume_session(session_data.clone().into()) @@ -569,7 +569,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent .resume_session( OutputData { access_jwt: "expired".into(), ..session_data.clone() }.into(), @@ -618,7 +618,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent.login("test", "pass").await.expect("login should be succeeded"); assert_eq!(agent.get_endpoint().await, "https://bsky.social"); assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "https://bsky.social"); @@ -653,7 +653,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent.login("test", "pass").await.expect("login should be succeeded"); // not updated assert_eq!(agent.get_endpoint().await, "http://localhost:8080"); @@ -666,7 +666,7 @@ mod tests { async fn test_configure_labelers_header() { let client = MockClient::default(); let headers = Arc::clone(&client.headers); - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent .api @@ -729,7 +729,7 @@ mod tests { async fn test_configure_proxy_header() { let client = MockClient::default(); let headers = Arc::clone(&client.headers); - let agent = AtpAgent::new(client, MemoryMapStore::default()); + let agent = AtpAgent::new(client, MemoryStore::default()); agent .api diff --git a/atrium-api/src/agent/atp_agent/inner.rs b/atrium-api/src/agent/atp_agent/inner.rs index 962badbd..ba801f77 100644 --- a/atrium-api/src/agent/atp_agent/inner.rs +++ b/atrium-api/src/agent/atp_agent/inner.rs @@ -1,12 +1,13 @@ use crate::did_doc::DidDocument; use crate::types::string::Did; use crate::types::TryFromUnknown; -use atrium_common::store::MapStore; +use atrium_common::store::Store as StoreTrait; use atrium_xrpc::error::{Error, Result, XrpcErrorKind}; use atrium_xrpc::types::AuthorizationToken; use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest}; use http::{Method, Request, Response}; use serde::{de::DeserializeOwned, Serialize}; +use std::hash::Hash; use std::{ fmt::Debug, sync::{Arc, RwLock}, @@ -71,14 +72,14 @@ where impl XrpcClient for WrapperClient where - S: MapStore<(), AtpSession> + Send + Sync, + S: StoreTrait<(), AtpSession> + Send + Sync, T: XrpcClient + Send + Sync, { fn base_uri(&self) -> String { self.store.get_endpoint() } - async fn authorization_token(&self, is_refresh: bool) -> Option { - self.store.get_session().await.map(|session| { + async fn authorization_token(&self, is_refresh: bool) -> Option { + self.store.get(&()).await.transpose().and_then(core::result::Result::ok).map(|session| { AuthorizationToken::Bearer(if is_refresh { session.data.refresh_jwt } else { @@ -103,7 +104,7 @@ pub struct Client { impl Client where - S: MapStore<(), AtpSession> + Send + Sync, + S: StoreTrait<(), AtpSession> + Send + Sync, T: XrpcClient + Send + Sync, { pub fn new(store: Arc>, xrpc: T) -> Self { @@ -218,7 +219,7 @@ where impl Clone for Client where - S: MapStore<(), AtpSession> + Send + Sync, + S: StoreTrait<(), AtpSession> + Send + Sync, T: XrpcClient + Send + Sync, { fn clone(&self) -> Self { @@ -247,7 +248,7 @@ where impl XrpcClient for Client where - S: MapStore<(), AtpSession> + Send + Sync, + S: StoreTrait<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, T: XrpcClient + Send + Sync, { @@ -294,11 +295,11 @@ impl Store { } } -impl MapStore for Store +impl StoreTrait for Store where K: Eq + Hash + Send + Sync, - V: Clone + Send + Sync, - S: MapStore + Send + Sync, + V: Clone + Send, + S: StoreTrait + Sync, { type Error = S::Error; diff --git a/atrium-common/src/lib.rs b/atrium-common/src/lib.rs index 97195bdf..8a69602e 100644 --- a/atrium-common/src/lib.rs +++ b/atrium-common/src/lib.rs @@ -1,6 +1,3 @@ pub mod resolver; pub mod store; pub mod types; - -pub mod resolver; -pub mod store; diff --git a/atrium-common/src/store.rs b/atrium-common/src/store.rs index 97f7a3e4..d2d8a30a 100644 --- a/atrium-common/src/store.rs +++ b/atrium-common/src/store.rs @@ -5,7 +5,7 @@ use std::future::Future; use std::hash::Hash; #[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] -pub trait MapStore +pub trait Store where K: Eq + Hash, V: Clone, diff --git a/atrium-common/src/store/memory.rs b/atrium-common/src/store/memory.rs index ed6d9971..b792bf4d 100644 --- a/atrium-common/src/store/memory.rs +++ b/atrium-common/src/store/memory.rs @@ -1,27 +1,27 @@ -use super::MapStore; +use super::Store; use std::collections::HashMap; use std::fmt::Debug; use std::hash::Hash; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use thiserror::Error; +use tokio::sync::Mutex; #[derive(Error, Debug)] #[error("memory store error")] pub struct Error; -// TODO: LRU cache? #[derive(Clone)] -pub struct MemoryMapStore { +pub struct MemoryStore { store: Arc>>, } -impl Default for MemoryMapStore { +impl Default for MemoryStore { fn default() -> Self { Self { store: Arc::new(Mutex::new(HashMap::new())) } } } -impl MapStore for MemoryMapStore +impl Store for MemoryStore where K: Debug + Eq + Hash + Send + Sync + 'static, V: Debug + Clone + Send + Sync + 'static, @@ -29,18 +29,18 @@ where type Error = Error; async fn get(&self, key: &K) -> Result, Self::Error> { - Ok(self.store.lock().unwrap().get(key).cloned()) + Ok(self.store.lock().await.get(key).cloned()) } async fn set(&self, key: K, value: V) -> Result<(), Self::Error> { - self.store.lock().unwrap().insert(key, value); + self.store.lock().await.insert(key, value); Ok(()) } async fn del(&self, key: &K) -> Result<(), Self::Error> { - self.store.lock().unwrap().remove(key); + self.store.lock().await.remove(key); Ok(()) } async fn clear(&self) -> Result<(), Self::Error> { - self.store.lock().unwrap().clear(); + self.store.lock().await.clear(); Ok(()) } } diff --git a/atrium-oauth/identity/src/error.rs b/atrium-oauth/identity/src/error.rs index cdb6769b..8dc0dc6f 100644 --- a/atrium-oauth/identity/src/error.rs +++ b/atrium-oauth/identity/src/error.rs @@ -1,5 +1,4 @@ use atrium_api::types::string::Did; -use atrium_common::resolver; use atrium_xrpc::http::uri::InvalidUri; use atrium_xrpc::http::StatusCode; use thiserror::Error; @@ -36,19 +35,4 @@ pub enum Error { Uri(#[from] InvalidUri), } -impl From for Error { - fn from(error: resolver::Error) -> Self { - match error { - resolver::Error::DnsResolver(error) => Error::DnsResolver(error), - resolver::Error::Http(error) => Error::Http(error), - resolver::Error::HttpClient(error) => Error::HttpClient(error), - resolver::Error::HttpStatus(error) => Error::HttpStatus(error), - resolver::Error::SerdeJson(error) => Error::SerdeJson(error), - resolver::Error::SerdeHtmlForm(error) => Error::SerdeHtmlForm(error), - resolver::Error::Uri(error) => Error::Uri(error), - resolver::Error::NotFound => Error::NotFound, - } - } -} - pub type Result = core::result::Result; diff --git a/atrium-oauth/oauth-client/examples/main.rs b/atrium-oauth/oauth-client/examples/main.rs index 655db944..af0f18e7 100644 --- a/atrium-oauth/oauth-client/examples/main.rs +++ b/atrium-oauth/oauth-client/examples/main.rs @@ -1,8 +1,7 @@ use atrium_api::agent::Agent; -use atrium_common::store::memory::MemoryMapStore; use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver}; -use atrium_oauth_client::store::session::{MemorySessionStore, Session}; +use atrium_oauth_client::store::session::MemorySessionStore; use atrium_oauth_client::store::state::MemoryStateStore; use atrium_oauth_client::{ AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient, @@ -80,7 +79,7 @@ async fn main() -> Result<(), Box> { ); // Click the URL and sign in, - // then copy and paste the URL like “http://127.0.0.1/?iss=...&code=...” after it is redirected. + // then copy and paste the URL like “http://127.0.0.1/callback?iss=...&code=...” after it is redirected. print!("Redirected url: "); stdout().lock().flush()?; @@ -90,13 +89,25 @@ async fn main() -> Result<(), Box> { let uri = url.trim().parse::()?; let params = serde_html_form::from_str(uri.query().unwrap())?; - let session_manager = client.callback::>(params).await?; - let session = session_manager.get_session(false).await?; - println!("{}", serde_json::to_string_pretty(&session)?); - - let agent = Agent::new(session_manager); - let session = agent.api.com.atproto.server.get_session().await?; - println!("{:?}", &session.data); + let (session, _) = client.callback(params).await?; + let agent = Agent::new(session); + let output = agent + .api + .app + .bsky + .feed + .get_timeline( + atrium_api::app::bsky::feed::get_timeline::ParametersData { + algorithm: None, + cursor: None, + limit: 3.try_into().ok(), + } + .into(), + ) + .await?; + for feed in &output.feed { + println!("{feed:?}"); + } Ok(()) } diff --git a/atrium-oauth/oauth-client/src/error.rs b/atrium-oauth/oauth-client/src/error.rs index ca2301b5..ba1bd5ce 100644 --- a/atrium-oauth/oauth-client/src/error.rs +++ b/atrium-oauth/oauth-client/src/error.rs @@ -16,7 +16,7 @@ pub enum Error { Authorize(String), #[error("callback error: {0}")] Callback(String), - #[error("state store error: {0:?}")] + #[error("state store error: {0}")] StateStore(Box), #[error("session store error: {0}")] SessionStore(Box), diff --git a/atrium-oauth/oauth-client/src/http_client/dpop.rs b/atrium-oauth/oauth-client/src/http_client/dpop.rs index be242e5a..2ba8f287 100644 --- a/atrium-oauth/oauth-client/src/http_client/dpop.rs +++ b/atrium-oauth/oauth-client/src/http_client/dpop.rs @@ -1,8 +1,8 @@ use crate::jose::create_signed_jwt; use crate::jose::jws::RegisteredHeader; use crate::jose::jwt::{Claims, PublicClaims, RegisteredClaims}; -use atrium_common::store::memory::MemoryMapStore; -use atrium_common::store::MapStore; +use atrium_common::store::memory::MemoryStore; +use atrium_common::store::Store; use atrium_xrpc::http::{Request, Response}; use atrium_xrpc::HttpClient; use base64::engine::general_purpose::URL_SAFE_NO_PAD; @@ -38,19 +38,21 @@ pub enum Error { type Result = core::result::Result; -pub struct DpopClient> +pub struct DpopClient> where - S: MapStore, + S: Store, { inner: Arc, pub(crate) key: Key, nonces: S, + is_auth_server: bool, } impl DpopClient { pub fn new( key: Key, http_client: Arc, + is_auth_server: bool, supported_algs: &Option>, ) -> Result { if let Some(algs) = supported_algs { @@ -65,14 +67,14 @@ impl DpopClient { return Err(Error::UnsupportedKey); } } - let nonces = MemoryMapStore::::default(); - Ok(Self { inner: http_client, key, iss, nonces }) + let nonces = MemoryStore::::default(); + Ok(Self { inner: http_client, key, nonces, is_auth_server }) } } impl DpopClient where - S: MapStore, + S: Store, { fn build_proof( &self, @@ -102,16 +104,18 @@ where _ => unimplemented!(), } } - fn is_use_dpop_nonce_error(&self, response: &Response>, is_auth_server: bool) -> bool { + fn is_use_dpop_nonce_error(&self, response: &Response>) -> bool { // https://datatracker.ietf.org/doc/html/rfc9449#name-authorization-server-provid - if is_auth_server && response.status() == 400 { - if let Ok(res) = serde_json::from_slice::(response.body()) { - return res.error == "use_dpop_nonce"; - }; + if self.is_auth_server { + if response.status() == 400 { + if let Ok(res) = serde_json::from_slice::(response.body()) { + return res.error == "use_dpop_nonce"; + }; + } } + // https://datatracker.ietf.org/doc/html/rfc6750#section-3 // https://datatracker.ietf.org/doc/html/rfc9449#name-resource-server-provided-no - if !is_auth_server && response.status() == 401 { - // https://datatracker.ietf.org/doc/html/rfc6750#section-3 + else if response.status() == 401 { if let Some(www_auth) = response.headers().get("WWW-Authenticate").and_then(|v| v.to_str().ok()) { @@ -133,8 +137,8 @@ where impl HttpClient for DpopClient where T: HttpClient + Send + Sync + 'static, - S: MapStore + Send + Sync + 'static, - S::Error: Send + Sync + 'static, + S: Store + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, { async fn send_http( &self, @@ -146,7 +150,6 @@ where let htm = request.method().to_string(); let htu = uri.to_string(); - let is_auth_server = uri.path().starts_with("/oauth"); let ath = match request.headers().get("Authorization").and_then(|v| v.to_str().ok()) { Some(s) if s.starts_with("DPoP ") => { Some(URL_SAFE_NO_PAD.encode(Sha256::digest(s.strip_prefix("DPoP ").unwrap()))) @@ -178,7 +181,7 @@ where } } - if !self.is_use_dpop_nonce_error(&response, is_auth_server) { + if !self.is_use_dpop_nonce_error(&response) { return Ok(response); } let next_proof = self.build_proof(htm, htu, ath, next_nonce)?; diff --git a/atrium-oauth/oauth-client/src/oauth_client.rs b/atrium-oauth/oauth-client/src/oauth_client.rs index 198e2958..ba2a5de1 100644 --- a/atrium-oauth/oauth-client/src/oauth_client.rs +++ b/atrium-oauth/oauth-client/src/oauth_client.rs @@ -5,6 +5,7 @@ use crate::oauth_session::OAuthSession; use crate::resolver::{OAuthResolver, OAuthResolverConfig}; use crate::server_agent::{OAuthRequest, OAuthServerAgent}; use crate::store::session::{Session, SessionStore}; +use crate::store::session_getter::SessionGetter; use crate::store::state::{InternalStateData, StateStore}; use crate::types::{ AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions, CallbackParams, @@ -13,8 +14,9 @@ use crate::types::{ TryIntoOAuthClientMetadata, }; use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values}; +use atrium_api::types::string::Did; use atrium_common::resolver::Resolver; -use atrium_common::store::MapStore; +use atrium_common::store::Store; use atrium_identity::{did::DidResolver, handle::HandleResolver}; use atrium_xrpc::HttpClient; use base64::engine::general_purpose::URL_SAFE_NO_PAD; @@ -60,23 +62,19 @@ where #[cfg(feature = "default-client")] pub struct OAuthClient where - S0: StateStore, - S1: SessionStore, T: HttpClient + Send + Sync + 'static, { pub client_metadata: OAuthClientMetadata, keyset: Option, resolver: Arc>, state_store: S0, - session_store: S1, + session_getter: SessionGetter, http_client: Arc, } #[cfg(not(feature = "default-client"))] pub struct OAuthClient where - S0: StateStore, - S1: SessionStore, T: HttpClient + Send + Sync + 'static, { pub client_metadata: OAuthClientMetadata, @@ -88,11 +86,7 @@ where } #[cfg(feature = "default-client")] -impl OAuthClient -where - S0: StateStore, - S1: SessionStore, -{ +impl OAuthClient { pub fn new(config: OAuthClientConfig) -> Result where M: TryIntoOAuthClientMetadata, @@ -105,7 +99,7 @@ where keyset, resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())), state_store: config.state_store, - session_store: config.session_store, + session_getter: SessionGetter::new(config.session_store), http_client, }) } @@ -138,11 +132,13 @@ where impl OAuthClient where - S0: StateStore, - S1: SessionStore, + S0: StateStore + Send + Sync + 'static, + S1: SessionStore + Send + Sync + 'static, D: DidResolver + Send + Sync + 'static, H: HandleResolver + Send + Sync + 'static, T: HttpClient + Send + Sync + 'static, + S0::Error: std::error::Error + Send + Sync + 'static, + S1::Error: std::error::Error + Send + Sync + 'static, { pub fn jwks(&self) -> JwkSet { self.keyset.as_ref().map(|keyset| keyset.public_jwks()).unwrap_or_default() @@ -186,14 +182,7 @@ where prompt: options.prompt.map(String::from), }; if metadata.pushed_authorization_request_endpoint.is_some() { - let server = OAuthServerAgent::new( - dpop_key, - metadata.clone(), - self.client_metadata.clone(), - self.resolver.clone(), - self.http_client.clone(), - self.keyset.clone(), - )?; + let server = self.create_server_agent(dpop_key, metadata.clone())?; let par_response = server .request::( OAuthRequest::PushedAuthorizationRequest(parameters), @@ -220,11 +209,7 @@ where todo!() } } - pub async fn callback(&self, params: CallbackParams) -> Result> - where - S: MapStore<(), Session> + Default + Send + Sync + 'static, - S::Error: Send + Sync + 'static, - { + pub async fn callback(&self, params: CallbackParams) -> Result<(OAuthSession, Option)> { let Some(state_key) = params.state else { return Err(Error::Callback("missing `state` parameter".into())); }; @@ -247,29 +232,43 @@ where } else if metadata.authorization_response_iss_parameter_supported == Some(true) { return Err(Error::Callback("missing `iss` parameter".into())); } - let server = OAuthServerAgent::new( - state.dpop_key.clone(), - metadata.clone(), + let server = self.create_server_agent(state.dpop_key.clone(), metadata.clone())?; + match server.exchange_code(¶ms.code, &state.verifier).await { + Ok(token_set) => { + let sub = token_set.sub.clone(); + self.session_getter + .set(sub.clone(), Session { dpop_key: state.dpop_key.clone(), token_set }) + .await + .map_err(|e| Error::SessionStore(Box::new(e)))?; + Ok((self.create_session(server, sub).await?, state.app_state)) + } + Err(_) => { + todo!() + } + } + } + async fn create_session( + &self, + server: OAuthServerAgent, + sub: Did, + ) -> Result> { + Ok(server + .create_session(sub, self.http_client.clone(), self.session_getter.clone()) + .await?) + } + fn create_server_agent( + &self, + dpop_key: Key, + server_metadata: OAuthAuthorizationServerMetadata, + ) -> Result> { + Ok(OAuthServerAgent::new( + dpop_key, + server_metadata, self.client_metadata.clone(), self.resolver.clone(), self.http_client.clone(), self.keyset.clone(), - )?; - let token_set = server.exchange_code(¶ms.code, &state.verifier).await?; - - let session = Session { dpop_key: state.dpop_key.clone(), token_set: token_set.clone() }; - self.session_store.set(token_set.sub.clone(), session.clone()).await.unwrap(); - - let session_store = S::default(); - session_store - .set((), session.clone()) - .await - .map_err(|e| crate::Error::SessionStore(Box::new(e)))?; - - Ok(OAuthSession::new( - session_store, - server.from_metadata(metadata.clone(), state.dpop_key.clone())?, - )) + )?) } fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option { let mut algs = diff --git a/atrium-oauth/oauth-client/src/oauth_session.rs b/atrium-oauth/oauth-client/src/oauth_session.rs index f9829188..8db608e1 100644 --- a/atrium-oauth/oauth-client/src/oauth_session.rs +++ b/atrium-oauth/oauth-client/src/oauth_session.rs @@ -1,125 +1,122 @@ -use std::fmt::Debug; +use std::sync::Arc; use atrium_api::{agent::SessionManager, types::string::Did}; -use atrium_common::store::MapStore; -use atrium_identity::{did::DidResolver, handle::HandleResolver}; +use atrium_common::store::{memory::MemoryStore, Store}; use atrium_xrpc::{ http::{Request, Response}, types::AuthorizationToken, HttpClient, XrpcClient, }; -use chrono::TimeDelta; -use thiserror::Error; +use jose_jwk::Key; -use crate::{server_agent::OAuthServerAgent, store::session::Session}; +use crate::{http_client::dpop::Error, server_agent::OAuthServerAgent, DpopClient, TokenSet}; -#[derive(Clone, Debug, Error)] -pub enum Error {} - -pub struct OAuthSession +pub struct OAuthSession> where - S: MapStore<(), Session> + Default, T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + S: Store, { - session_store: S, - server: OAuthServerAgent, + #[allow(dead_code)] + server_agent: OAuthServerAgent, + dpop_client: DpopClient, + token_set: TokenSet, } -impl OAuthSession +impl OAuthSession where - S: MapStore<(), Session> + Default, - S::Error: Send + Sync + 'static, T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, { - pub fn new(session_store: S, server: OAuthServerAgent) -> Self { - Self { session_store, server } + pub(crate) fn new( + server_agent: OAuthServerAgent, + dpop_key: Key, + http_client: Arc, + token_set: TokenSet, + ) -> Result { + let dpop_client = DpopClient::new( + dpop_key, + http_client.clone(), + false, + &server_agent.server_metadata.token_endpoint_auth_signing_alg_values_supported, + )?; + Ok(Self { server_agent, dpop_client, token_set }) } - pub async fn get_session(&self, refresh: bool) -> crate::Result { - let Some(session) = self - .session_store - .get(&()) - .await - .map_err(|e| crate::Error::SessionStore(Box::new(e)))? - else { - panic!("a session should always exist"); - }; - if session.expires_in().expect("no expires_at") == TimeDelta::zero() && refresh { - let token_set = self.server.refresh(session.token_set.clone()).await?; - Ok(Session { dpop_key: session.dpop_key.clone(), token_set }) - } else { - Ok(session) - } + pub fn dpop_key(&self) -> Key { + self.dpop_client.key.clone() } - pub async fn logout(&self) -> crate::Result<()> { - let session = self.get_session(false).await?; + pub fn token_set(&self) -> TokenSet { + self.token_set.clone() + } + // pub async fn get_session(&self, refresh: bool) -> crate::Result { + // let Some(session) = self + // .session_store + // .get(&()) + // .await + // .map_err(|e| crate::Error::SessionStore(Box::new(e)))? + // else { + // panic!("a session should always exist"); + // }; + // if session.expires_in().expect("no expires_at") == TimeDelta::zero() && refresh { + // let token_set = self.server.refresh(session.token_set.clone()).await?; + // Ok(Session { dpop_key: session.dpop_key.clone(), token_set }) + // } else { + // Ok(session) + // } + // } + // pub async fn logout(&self) -> crate::Result<()> { + // let session = self.get_session(false).await?; - self.server.revoke(&session.token_set.access_token).await; - self.session_store.clear().await.map_err(|e| crate::Error::SessionStore(Box::new(e)))?; + // self.server.revoke(&session.token_set.access_token).await; + // self.session_store.clear().await.map_err(|e| crate::Error::SessionStore(Box::new(e)))?; - Ok(()) - } + // Ok(()) + // } } -impl HttpClient for OAuthSession +impl HttpClient for OAuthSession where - S: MapStore<(), Session> + Default + Sync, T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + D: Send + Sync + 'static, + H: Send + Sync + 'static, + S: Store + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, { async fn send_http( &self, request: Request>, ) -> Result>, Box> { - self.server.send_http(request).await + self.dpop_client.send_http(request).await } } -impl XrpcClient for OAuthSession +impl XrpcClient for OAuthSession where - S: MapStore<(), Session> + Default + Sync, - S::Error: Send + Sync + 'static, T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + D: Send + Sync + 'static, + H: Send + Sync + 'static, + S: Store + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, { fn base_uri(&self) -> String { - // let Ok(Some(Session { dpop_key: _, token_set })) = - // futures::FutureExt::now_or_never(self.get_session(false)).transpose() - // else { - // panic!("session, now or never"); - // }; - - todo!() + self.token_set.aud.clone() } async fn authorization_token(&self, is_refresh: bool) -> Option { - let Ok(Session { dpop_key: _, token_set }) = self.get_session(false).await else { - return None; - }; if is_refresh { - token_set.refresh_token.as_ref().cloned().map(AuthorizationToken::Dpop) + self.token_set.refresh_token.as_ref().cloned().map(AuthorizationToken::Dpop) } else { - Some(AuthorizationToken::Bearer(token_set.access_token.clone())) + Some(AuthorizationToken::Dpop(self.token_set.access_token.clone())) } } } -impl SessionManager for OAuthSession +impl SessionManager for OAuthSession where - S: MapStore<(), Session> + Default + Sync, - S::Error: Send + Sync + 'static, T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + D: Send + Sync + 'static, + H: Send + Sync + 'static, + S: Store + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, { async fn did(&self) -> Option { - let Ok(Some(session)) = self.session_store.get(&()).await else { - return None; - }; - Some(session.token_set.sub.parse().expect("TokenSet contains valid sub")) + Some(self.token_set.sub.clone()) } } diff --git a/atrium-oauth/oauth-client/src/server_agent.rs b/atrium-oauth/oauth-client/src/server_agent.rs index 229aff0c..4c5e15f7 100644 --- a/atrium-oauth/oauth-client/src/server_agent.rs +++ b/atrium-oauth/oauth-client/src/server_agent.rs @@ -3,15 +3,19 @@ use crate::http_client::dpop::DpopClient; use crate::jose::jwt::{RegisteredClaims, RegisteredClaimsAud}; use crate::keyset::Keyset; use crate::resolver::OAuthResolver; +use crate::store::session::SessionStore; +use crate::store::session_getter::SessionGetter; use crate::types::{ OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse, - PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, - TokenRequestParameters, TokenSet, + PushedAuthorizationRequestParameters, RefreshRequestParameters, RevocationRequestParameters, + TokenGrantType, TokenRequestParameters, TokenSet, }; use crate::utils::{compare_algos, generate_nonce}; -use atrium_api::types::string::Datetime; +use crate::OAuthSession; +use atrium_api::types::string::{Datetime, Did}; +use atrium_common::store::Store; use atrium_identity::{did::DidResolver, handle::HandleResolver}; -use atrium_xrpc::http::{Method, Request, Response, StatusCode}; +use atrium_xrpc::http::{Method, Request, StatusCode}; use atrium_xrpc::HttpClient; use chrono::{TimeDelta, Utc}; use jose_jwk::Key; @@ -32,6 +36,8 @@ pub enum Error { Token(String), #[error("unsupported authentication method")] UnsupportedAuthMethod, + #[error("failed to parse DID: {0}")] + InvalidDid(&'static str), #[error("no refresh token available for {0}")] NoRefreshToken(String), #[error(transparent)] @@ -102,10 +108,9 @@ pub struct OAuthServerAgent where T: HttpClient + Send + Sync + 'static, { - server_metadata: OAuthAuthorizationServerMetadata, - client_metadata: OAuthClientMetadata, + pub(crate) server_metadata: OAuthAuthorizationServerMetadata, + pub(crate) client_metadata: OAuthClientMetadata, dpop_client: DpopClient, - http_client: Arc, resolver: Arc>, keyset: Option, } @@ -126,11 +131,11 @@ where ) -> Result { let dpop_client = DpopClient::new( dpop_key, - client_metadata.client_id.clone(), http_client.clone(), + true, &server_metadata.token_endpoint_auth_signing_alg_values_supported, )?; - Ok(Self { server_metadata, client_metadata, dpop_client, http_client, resolver, keyset }) + Ok(Self { server_metadata, client_metadata, dpop_client, resolver, keyset }) } /** * VERY IMPORTANT ! Always call this to process token responses. @@ -158,7 +163,7 @@ where .map(Datetime::new) }); Ok(TokenSet { - sub: sub.clone(), + sub: sub.parse().map_err(Error::InvalidDid)?, aud: identity.pds, iss: metadata.issuer, scope: token_response.scope, @@ -170,13 +175,12 @@ where } pub async fn exchange_code(&self, code: &str, verifier: &str) -> Result { self.verify_token_response( - self.request(OAuthRequest::Token(TokenRequestParameters::AuthorizationCode( - AuthorizationCodeParameters { - code: code.into(), - redirect_uri: self.client_metadata.redirect_uris[0].clone(), // ? - code_verifier: verifier.into(), - }, - ))) + self.request(OAuthRequest::Token(TokenRequestParameters { + grant_type: TokenGrantType::AuthorizationCode, + code: code.into(), + redirect_uri: self.client_metadata.redirect_uris[0].clone(), // ? + code_verifier: verifier.into(), + })) .await?, ) .await @@ -188,35 +192,40 @@ where })) .await; } - /** - * /!\ IMPORTANT /!\ - * - * The "sub" MUST be a DID, whose issuer authority is indeed the server we - * are trying to obtain credentials from. Note that we are doing this - * *before* we actually try to refresh the token: - * 1) To avoid unnecessary refresh - * 2) So that the refresh is the last async operation, ensuring as few - * async operations happen before the result gets a chance to be stored. - */ - pub async fn refresh(&self, token_set: TokenSet) -> Result { - let Some(refresh_token) = token_set.refresh_token else { - return Err(Error::NoRefreshToken(token_set.sub.clone())); + #[allow(dead_code)] + pub async fn refresh(&self, token_set: &TokenSet) { + let Some(refresh_token) = token_set.refresh_token.as_ref() else { + // TODO + return; }; - let (metadata, atrium_identity::identity_resolver::ResolvedIdentity { pds: aud, .. }) = - self.resolver.resolve_from_identity(&token_set.sub).await?; - if metadata.issuer != self.server_metadata.issuer { - let _ = self.revoke(&token_set.access_token).await; - return Err(Error::Token("issuer mismatch".into())); - } - let token_set = self - .verify_token_response( - self.request(OAuthRequest::Token(TokenRequestParameters::RefreshToken( - RefreshTokenParameters { refresh_token, scope: token_set.scope.clone() }, - ))) - .await?, - ) - .await?; - Ok(TokenSet { aud, ..token_set }) + // TODO + let result = self + .request::(OAuthRequest::Refresh(RefreshRequestParameters { + grant_type: TokenGrantType::RefreshToken, + refresh_token: refresh_token.clone(), + scope: None, + })) + .await; + println!("{result:?}"); + + // let Some(refresh_token) = token_set.refresh_token else { + // return Err(Error::NoRefreshToken(token_set.sub.clone())); + // }; + // let (metadata, atrium_identity::identity_resolver::ResolvedIdentity { pds: aud, .. }) = + // self.resolver.resolve_from_identity(&token_set.sub).await?; + // if metadata.issuer != self.server_metadata.issuer { + // let _ = self.revoke(&token_set.access_token).await; + // return Err(Error::Token("issuer mismatch".into())); + // } + // let token_set = self + // .verify_token_response( + // self.request(OAuthRequest::Token(TokenRequestParameters::RefreshToken( + // RefreshTokenParameters { refresh_token, scope: token_set.scope.clone() }, + // ))) + // .await?, + // ) + // .await?; + // Ok(TokenSet { aud, ..token_set }) } pub async fn request(&self, request: OAuthRequest) -> Result where @@ -323,44 +332,19 @@ where } } } - #[allow(clippy::wrong_self_convention)] - pub async fn from_issuer( - &self, - issuer: &str, - dpop_key: Key, - ) -> Result> { - let server_metadata = self.resolver.get_authorization_server_metadata(issuer).await?; - self.from_metadata(server_metadata, dpop_key) - } - #[allow(clippy::wrong_self_convention)] - pub fn from_metadata( - &self, - server_metadata: OAuthAuthorizationServerMetadata, - dpop_key: Key, - ) -> Result> { - let server = OAuthServerAgent::new( - dpop_key, - server_metadata, - self.client_metadata.clone(), - self.resolver.clone(), - self.http_client.clone(), - self.keyset.clone(), - )?; - Ok(server) - } -} - -impl HttpClient for OAuthServerAgent -where - T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, -{ - async fn send_http( - &self, - request: Request>, - ) -> core::result::Result>, Box> + pub(crate) async fn create_session( + self, + sub: Did, + http_client: Arc, + session_getter: SessionGetter, + ) -> Result> + where + S: SessionStore + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, { - self.dpop_client.send_http(request).await + let dpop_key = self.dpop_client.key.clone(); + // TODO + let session = session_getter.get(&sub).await.expect("").unwrap(); + Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set)?) } } diff --git a/atrium-oauth/oauth-client/src/store.rs b/atrium-oauth/oauth-client/src/store.rs index f7247255..a06b3710 100644 --- a/atrium-oauth/oauth-client/src/store.rs +++ b/atrium-oauth/oauth-client/src/store.rs @@ -1,2 +1,3 @@ pub mod session; +pub mod session_getter; pub mod state; diff --git a/atrium-oauth/oauth-client/src/store/session.rs b/atrium-oauth/oauth-client/src/store/session.rs index a15d7f8d..0dd73f92 100644 --- a/atrium-oauth/oauth-client/src/store/session.rs +++ b/atrium-oauth/oauth-client/src/store/session.rs @@ -1,11 +1,10 @@ -use atrium_api::types::string::Datetime; -use atrium_common::store::{memory::MemoryMapStore, MapStore}; +use crate::types::TokenSet; +use atrium_api::types::string::{Datetime, Did}; +use atrium_common::store::{memory::MemoryStore, Store}; use chrono::TimeDelta; use jose_jwk::Key; use serde::{Deserialize, Serialize}; -use crate::TokenSet; - #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Session { pub dpop_key: Key, @@ -20,8 +19,8 @@ impl Session { } } -pub trait SessionStore: MapStore {} +pub trait SessionStore: Store {} -pub type MemorySessionStore = MemoryMapStore; +pub type MemorySessionStore = MemoryStore; impl SessionStore for MemorySessionStore {} diff --git a/atrium-oauth/oauth-client/src/store/session_getter.rs b/atrium-oauth/oauth-client/src/store/session_getter.rs new file mode 100644 index 00000000..183ab913 --- /dev/null +++ b/atrium-oauth/oauth-client/src/store/session_getter.rs @@ -0,0 +1,49 @@ +use crate::store::session::{Session, SessionStore}; +use atrium_api::types::string::Did; +use atrium_common::store::Store; +use std::sync::Arc; + +#[derive(Debug)] +pub struct SessionGetter { + store: Arc, +} + +impl SessionGetter { + pub fn new(store: S) -> Self { + Self { store: Arc::new(store) } + } + // TODO: extended store methods? +} + +impl Clone for SessionGetter { + fn clone(&self) -> Self { + Self { store: self.store.clone() } + } +} + +impl Store for SessionGetter +where + S: SessionStore + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, +{ + type Error = S::Error; + async fn get(&self, key: &Did) -> Result, Self::Error> { + self.store.get(key).await + } + async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> { + self.store.set(key, value).await + } + async fn del(&self, key: &Did) -> Result<(), Self::Error> { + self.store.del(key).await + } + async fn clear(&self) -> Result<(), Self::Error> { + self.store.clear().await + } +} + +impl SessionStore for SessionGetter +where + S: SessionStore + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, +{ +} diff --git a/atrium-oauth/oauth-client/src/store/state.rs b/atrium-oauth/oauth-client/src/store/state.rs index 3adeefee..a39a2cb4 100644 --- a/atrium-oauth/oauth-client/src/store/state.rs +++ b/atrium-oauth/oauth-client/src/store/state.rs @@ -1,4 +1,4 @@ -use atrium_common::store::{memory::MemoryMapStore, MapStore}; +use atrium_common::store::{memory::MemoryStore, Store}; use jose_jwk::Key; use serde::{Deserialize, Serialize}; @@ -10,8 +10,8 @@ pub struct InternalStateData { pub app_state: Option, } -pub trait StateStore: MapStore {} +pub trait StateStore: Store {} -pub type MemoryStateStore = MemoryMapStore; +pub type MemoryStateStore = MemoryStore; impl StateStore for MemoryStateStore {} diff --git a/atrium-oauth/oauth-client/src/types.rs b/atrium-oauth/oauth-client/src/types.rs index 24693a62..4d84a806 100644 --- a/atrium-oauth/oauth-client/src/types.rs +++ b/atrium-oauth/oauth-client/src/types.rs @@ -9,13 +9,13 @@ pub use client_metadata::{OAuthClientMetadata, TryIntoOAuthClientMetadata}; pub use metadata::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; pub use request::{ AuthorizationCodeChallengeMethod, AuthorizationResponseType, - PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, - TokenRequestParameters, + PushedAuthorizationRequestParameters, RefreshRequestParameters, RevocationRequestParameters, + TokenGrantType, TokenRequestParameters, }; pub use response::{OAuthPusehedAuthorizationRequestResponse, OAuthTokenResponse}; use serde::Deserialize; #[allow(unused_imports)] -pub use token::{TokenInfo, TokenSet}; +pub use token::TokenSet; #[derive(Debug, Deserialize)] pub enum AuthorizeOptionPrompt { diff --git a/atrium-oauth/oauth-client/src/types/request.rs b/atrium-oauth/oauth-client/src/types/request.rs index d361c5f7..80d44a55 100644 --- a/atrium-oauth/oauth-client/src/types/request.rs +++ b/atrium-oauth/oauth-client/src/types/request.rs @@ -55,7 +55,9 @@ pub enum TokenGrantType { } #[derive(Serialize)] -pub struct AuthorizationCodeParameters { +pub struct TokenRequestParameters { + // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 + pub grant_type: TokenGrantType, pub code: String, pub redirect_uri: String, // https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 @@ -70,6 +72,7 @@ pub struct RefreshRequestParameters { pub scope: Option, } +#[allow(dead_code)] #[derive(Serialize)] pub struct RevocationRequestParameters { pub token: String, diff --git a/atrium-oauth/oauth-client/src/types/token.rs b/atrium-oauth/oauth-client/src/types/token.rs index 9504015c..d09736e0 100644 --- a/atrium-oauth/oauth-client/src/types/token.rs +++ b/atrium-oauth/oauth-client/src/types/token.rs @@ -1,11 +1,11 @@ use super::response::OAuthTokenType; -use atrium_api::types::string::Datetime; +use atrium_api::types::string::{Datetime, Did}; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct TokenSet { pub iss: String, - pub sub: String, + pub sub: Did, pub aud: String, pub scope: Option, @@ -15,13 +15,3 @@ pub struct TokenSet { pub expires_at: Option, } - -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct TokenInfo { - pub iss: String, - pub sub: String, - pub aud: String, - pub scope: Option, - - pub expires_at: Option, -} diff --git a/bsky-sdk/src/agent.rs b/bsky-sdk/src/agent.rs index 3c0ddda4..a5385955 100644 --- a/bsky-sdk/src/agent.rs +++ b/bsky-sdk/src/agent.rs @@ -12,8 +12,8 @@ use atrium_api::agent::atp_agent::{AtpAgent, AtpSession}; use atrium_api::app::bsky::actor::defs::PreferencesItem; use atrium_api::types::{Object, Union}; use atrium_api::xrpc::XrpcClient; -use atrium_common::store::memory::MemoryMapStore; -use atrium_common::store::MapStore; +use atrium_common::store::memory::MemoryStore; +use atrium_common::store::Store; #[cfg(feature = "default-client")] use atrium_xrpc_client::reqwest::ReqwestClient; use std::collections::HashMap; @@ -38,20 +38,20 @@ use std::sync::Arc; #[cfg(feature = "default-client")] #[derive(Clone)] -pub struct BskyAgent> +pub struct BskyAgent> where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, { inner: Arc>, } #[cfg(not(feature = "default-client"))] -pub struct BskyAgent +pub struct BskyAgent where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, { inner: Arc>, } @@ -60,7 +60,7 @@ where #[cfg(feature = "default-client")] impl BskyAgent { /// Create a new [`BskyAtpAgentBuilder`] with the default client and session store. - pub fn builder() -> BskyAtpAgentBuilder> { + pub fn builder() -> BskyAtpAgentBuilder> { BskyAtpAgentBuilder::default() } } @@ -68,7 +68,7 @@ impl BskyAgent { impl BskyAgent where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, { /// Get the agent's current state as a [`Config`]. @@ -251,7 +251,7 @@ where impl Deref for BskyAgent where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, { type Target = AtpAgent; @@ -269,7 +269,7 @@ mod tests { #[derive(Clone)] struct NoopStore; - impl MapStore<(), AtpSession> for NoopStore { + impl Store<(), AtpSession> for NoopStore { type Error = std::convert::Infallible; async fn get(&self, _key: &()) -> core::result::Result, Self::Error> { diff --git a/bsky-sdk/src/agent/builder.rs b/bsky-sdk/src/agent/builder.rs index 9a2324ba..7d3c4485 100644 --- a/bsky-sdk/src/agent/builder.rs +++ b/bsky-sdk/src/agent/builder.rs @@ -3,17 +3,17 @@ use super::BskyAgent; use crate::error::Result; use atrium_api::agent::atp_agent::{AtpAgent, AtpSession}; use atrium_api::xrpc::XrpcClient; -use atrium_common::store::memory::MemoryMapStore; -use atrium_common::store::MapStore; +use atrium_common::store::memory::MemoryStore; +use atrium_common::store::Store; #[cfg(feature = "default-client")] use atrium_xrpc_client::reqwest::ReqwestClient; use std::sync::Arc; /// A builder for creating a [`BskyAtpAgent`]. -pub struct BskyAtpAgentBuilder> +pub struct BskyAtpAgentBuilder> where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, { config: Config, store: S, @@ -26,14 +26,14 @@ where { /// Create a new builder with the given XRPC client. pub fn new(client: T) -> Self { - Self { config: Config::default(), store: MemoryMapStore::default(), client } + Self { config: Config::default(), store: MemoryStore::default(), client } } } impl BskyAtpAgentBuilder where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, { /// Set the configuration for the agent. @@ -46,7 +46,7 @@ where /// Returns a new builder with the session store set. pub fn store(self, store: S0) -> BskyAtpAgentBuilder where - S0: MapStore<(), AtpSession> + Send + Sync, + S0: Store<(), AtpSession> + Send + Sync, { BskyAtpAgentBuilder { config: self.config, store, client: self.client } } @@ -93,10 +93,10 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "default-client")))] #[cfg(feature = "default-client")] -impl Default for BskyAtpAgentBuilder> { +impl Default for BskyAtpAgentBuilder> { /// Create a new builder with the default client and session store. /// - /// Default client is [`ReqwestClient`] and default session store is [`MemoryMapStore`]. + /// Default client is [`ReqwestClient`] and default session store is [`MemoryStore`]. fn default() -> Self { Self::new(ReqwestClient::new(Config::default().endpoint)) } @@ -126,7 +126,7 @@ mod tests { struct MockSessionStore; - impl MapStore<(), AtpSession> for MockSessionStore { + impl Store<(), AtpSession> for MockSessionStore { type Error = std::convert::Infallible; async fn get(&self, _key: &()) -> core::result::Result, Self::Error> { diff --git a/bsky-sdk/src/record.rs b/bsky-sdk/src/record.rs index 4d590ee4..7a7bba1d 100644 --- a/bsky-sdk/src/record.rs +++ b/bsky-sdk/src/record.rs @@ -11,13 +11,13 @@ use atrium_api::com::atproto::repo::{ }; use atrium_api::types::{Collection, LimitedNonZeroU8, TryIntoUnknown}; use atrium_api::xrpc::XrpcClient; -use atrium_common::store::MapStore; +use atrium_common::store::Store; #[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub trait Record where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, { fn list( @@ -47,7 +47,7 @@ macro_rules! record_impl { impl Record for $record where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, { async fn list( @@ -165,7 +165,7 @@ macro_rules! record_impl { impl Record for $record_data where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, { async fn list( @@ -325,7 +325,7 @@ mod tests { struct MockSessionStore; - impl MapStore<(), AtpSession> for MockSessionStore { + impl Store<(), AtpSession> for MockSessionStore { type Error = std::convert::Infallible; async fn get(&self, _key: &()) -> core::result::Result, Self::Error> { diff --git a/bsky-sdk/src/record/agent.rs b/bsky-sdk/src/record/agent.rs index 7237f76e..00fac1ae 100644 --- a/bsky-sdk/src/record/agent.rs +++ b/bsky-sdk/src/record/agent.rs @@ -6,12 +6,12 @@ use atrium_api::com::atproto::repo::{create_record, delete_record}; use atrium_api::record::KnownRecord; use atrium_api::types::string::RecordKey; use atrium_api::xrpc::XrpcClient; -use atrium_common::store::MapStore; +use atrium_common::store::Store; impl BskyAgent where T: XrpcClient + Send + Sync, - S: MapStore<(), AtpSession> + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, S::Error: Send + Sync + 'static, { /// Create a record with various types of data.