Skip to content

Commit

Permalink
Merge branch 'feature/agent-rework' into oauth-session
Browse files Browse the repository at this point in the history
  • Loading branch information
avdb13 committed Nov 24, 2024
1 parent fc11bb8 commit 82a9398
Show file tree
Hide file tree
Showing 23 changed files with 362 additions and 344 deletions.
40 changes: 20 additions & 20 deletions atrium-api/src/agent/atp_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
did_doc::DidDocument,
types::{string::Did, TryFromUnknown},
};
use atrium_common::store::MapStore;
use atrium_common::store::Store;
use atrium_xrpc::{Error, XrpcClient};
use std::{ops::Deref, sync::Arc};

Expand All @@ -16,7 +16,7 @@ pub type AtpSession = crate::com::atproto::server::create_session::Output;

pub struct CredentialSession<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: Store<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
Expand All @@ -27,7 +27,7 @@ where

impl<S, T> CredentialSession<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: Store<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
Expand Down Expand Up @@ -152,7 +152,7 @@ where
/// Manages session token lifecycles and provides convenience methods.
pub struct AtpAgent<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: Store<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
Expand All @@ -161,7 +161,7 @@ where

impl<S, T> AtpAgent<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: Store<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
Expand All @@ -173,7 +173,7 @@ where

impl<S, T> Deref for AtpAgent<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: Store<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
Expand All @@ -191,7 +191,7 @@ mod tests {
use crate::com::atproto::server::create_session::OutputData;
use crate::did_doc::{DidDocument, Service, VerificationMethod};
use crate::types::TryIntoUnknown;
use atrium_common::store::memory::MemoryMapStore;
use atrium_common::store::memory::MemoryStore;
use atrium_xrpc::HttpClient;
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
use std::collections::HashMap;
Expand Down Expand Up @@ -319,7 +319,7 @@ mod tests {
#[tokio::test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
async fn test_new() {
let agent = AtpAgent::new(MockClient::default(), MemoryMapStore::default());
let agent = AtpAgent::new(MockClient::default(), MemoryStore::default());
assert_eq!(agent.get_session().await, None);
}

Expand All @@ -338,7 +338,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
agent.login("test", "pass").await.expect("login should be succeeded");
assert_eq!(agent.get_session().await, Some(session_data.into()));
}
Expand All @@ -348,7 +348,7 @@ mod tests {
responses: MockResponses { ..Default::default() },
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
agent.login("test", "bad").await.expect_err("login should be failed");
assert_eq!(agent.get_session().await, None);
}
Expand All @@ -374,7 +374,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
agent
.store
.set((), session_data.clone().into())
Expand Down Expand Up @@ -412,7 +412,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
agent
.store
.set((), session_data.clone().into())
Expand Down Expand Up @@ -460,7 +460,7 @@ mod tests {
..Default::default()
};
let counts = Arc::clone(&client.counts);
let agent = Arc::new(AtpAgent::new(client, MemoryMapStore::default()));
let agent = Arc::new(AtpAgent::new(client, MemoryStore::default()));
agent
.store
.set((), session_data.clone().into())
Expand Down Expand Up @@ -519,7 +519,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
assert_eq!(agent.get_session().await, None);
agent
.resume_session(
Expand All @@ -539,7 +539,7 @@ mod tests {
responses: MockResponses { ..Default::default() },
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
assert_eq!(agent.get_session().await, None);
agent
.resume_session(session_data.clone().into())
Expand Down Expand Up @@ -569,7 +569,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
agent
.resume_session(
OutputData { access_jwt: "expired".into(), ..session_data.clone() }.into(),
Expand Down Expand Up @@ -618,7 +618,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
agent.login("test", "pass").await.expect("login should be succeeded");
assert_eq!(agent.get_endpoint().await, "https://bsky.social");
assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "https://bsky.social");
Expand Down Expand Up @@ -653,7 +653,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());
agent.login("test", "pass").await.expect("login should be succeeded");
// not updated
assert_eq!(agent.get_endpoint().await, "http://localhost:8080");
Expand All @@ -666,7 +666,7 @@ mod tests {
async fn test_configure_labelers_header() {
let client = MockClient::default();
let headers = Arc::clone(&client.headers);
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());

agent
.api
Expand Down Expand Up @@ -729,7 +729,7 @@ mod tests {
async fn test_configure_proxy_header() {
let client = MockClient::default();
let headers = Arc::clone(&client.headers);
let agent = AtpAgent::new(client, MemoryMapStore::default());
let agent = AtpAgent::new(client, MemoryStore::default());

agent
.api
Expand Down
21 changes: 11 additions & 10 deletions atrium-api/src/agent/atp_agent/inner.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::did_doc::DidDocument;
use crate::types::string::Did;
use crate::types::TryFromUnknown;
use atrium_common::store::MapStore;
use atrium_common::store::Store as StoreTrait;
use atrium_xrpc::error::{Error, Result, XrpcErrorKind};
use atrium_xrpc::types::AuthorizationToken;
use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
use http::{Method, Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use std::hash::Hash;
use std::{
fmt::Debug,
sync::{Arc, RwLock},
Expand Down Expand Up @@ -71,14 +72,14 @@ where

impl<S, T> XrpcClient for WrapperClient<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: StoreTrait<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
fn base_uri(&self) -> String {
self.store.get_endpoint()
}
async fn authorization_token(&self, is_refresh: bool) -> Option<String> {
self.store.get_session().await.map(|session| {
async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
self.store.get(&()).await.transpose().and_then(core::result::Result::ok).map(|session| {
AuthorizationToken::Bearer(if is_refresh {
session.data.refresh_jwt
} else {
Expand All @@ -103,7 +104,7 @@ pub struct Client<S, T> {

impl<S, T> Client<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: StoreTrait<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
pub fn new(store: Arc<Store<S>>, xrpc: T) -> Self {
Expand Down Expand Up @@ -218,7 +219,7 @@ where

impl<S, T> Clone for Client<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: StoreTrait<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
fn clone(&self) -> Self {
Expand Down Expand Up @@ -247,7 +248,7 @@ where

impl<S, T> XrpcClient for Client<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S: StoreTrait<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
Expand Down Expand Up @@ -294,11 +295,11 @@ impl<S> Store<S> {
}
}

impl<S, K, V> MapStore<K, V> for Store<S>
impl<S, K, V> StoreTrait<K, V> for Store<S>
where
K: Eq + Hash + Send + Sync,
V: Clone + Send + Sync,
S: MapStore<K, V> + Send + Sync,
V: Clone + Send,
S: StoreTrait<K, V> + Sync,
{
type Error = S::Error;

Expand Down
3 changes: 0 additions & 3 deletions atrium-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
pub mod resolver;
pub mod store;
pub mod types;

pub mod resolver;
pub mod store;
2 changes: 1 addition & 1 deletion atrium-common/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::future::Future;
use std::hash::Hash;

#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait MapStore<K, V>
pub trait Store<K, V>
where
K: Eq + Hash,
V: Clone,
Expand Down
20 changes: 10 additions & 10 deletions atrium-common/src/store/memory.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,46 @@
use super::MapStore;
use super::Store;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::Mutex;

#[derive(Error, Debug)]
#[error("memory store error")]
pub struct Error;

// TODO: LRU cache?
#[derive(Clone)]
pub struct MemoryMapStore<K, V> {
pub struct MemoryStore<K, V> {
store: Arc<Mutex<HashMap<K, V>>>,
}

impl<K, V> Default for MemoryMapStore<K, V> {
impl<K, V> Default for MemoryStore<K, V> {
fn default() -> Self {
Self { store: Arc::new(Mutex::new(HashMap::new())) }
}
}

impl<K, V> MapStore<K, V> for MemoryMapStore<K, V>
impl<K, V> Store<K, V> for MemoryStore<K, V>
where
K: Debug + Eq + Hash + Send + Sync + 'static,
V: Debug + Clone + Send + Sync + 'static,
{
type Error = Error;

async fn get(&self, key: &K) -> Result<Option<V>, Self::Error> {
Ok(self.store.lock().unwrap().get(key).cloned())
Ok(self.store.lock().await.get(key).cloned())
}
async fn set(&self, key: K, value: V) -> Result<(), Self::Error> {
self.store.lock().unwrap().insert(key, value);
self.store.lock().await.insert(key, value);
Ok(())
}
async fn del(&self, key: &K) -> Result<(), Self::Error> {
self.store.lock().unwrap().remove(key);
self.store.lock().await.remove(key);
Ok(())
}
async fn clear(&self) -> Result<(), Self::Error> {
self.store.lock().unwrap().clear();
self.store.lock().await.clear();
Ok(())
}
}
16 changes: 0 additions & 16 deletions atrium-oauth/identity/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use atrium_api::types::string::Did;
use atrium_common::resolver;
use atrium_xrpc::http::uri::InvalidUri;
use atrium_xrpc::http::StatusCode;
use thiserror::Error;
Expand Down Expand Up @@ -36,19 +35,4 @@ pub enum Error {
Uri(#[from] InvalidUri),
}

impl From<resolver::Error> for Error {
fn from(error: resolver::Error) -> Self {
match error {
resolver::Error::DnsResolver(error) => Error::DnsResolver(error),
resolver::Error::Http(error) => Error::Http(error),
resolver::Error::HttpClient(error) => Error::HttpClient(error),
resolver::Error::HttpStatus(error) => Error::HttpStatus(error),
resolver::Error::SerdeJson(error) => Error::SerdeJson(error),
resolver::Error::SerdeHtmlForm(error) => Error::SerdeHtmlForm(error),
resolver::Error::Uri(error) => Error::Uri(error),
resolver::Error::NotFound => Error::NotFound,
}
}
}

pub type Result<T> = core::result::Result<T, Error>;
Loading

0 comments on commit 82a9398

Please sign in to comment.