From 1f9f9d8b5f06faddf13b798d9e1dbe30171dfa5d Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 12 Nov 2024 05:42:17 +0000 Subject: [PATCH] extend request types --- atrium-common/src/store.rs | 4 +- atrium-oauth/oauth-client/src/lib.rs | 2 +- atrium-oauth/oauth-client/src/oauth_client.rs | 2 +- .../oauth-client/src/oauth_session.rs | 15 +---- atrium-oauth/oauth-client/src/server_agent.rs | 57 +++++++++++++++---- atrium-oauth/oauth-client/src/store.rs | 2 +- atrium-oauth/oauth-client/src/types.rs | 8 ++- .../oauth-client/src/types/request.rs | 22 +++++-- 8 files changed, 73 insertions(+), 39 deletions(-) diff --git a/atrium-common/src/store.rs b/atrium-common/src/store.rs index b7a3177d..6b4b3be7 100644 --- a/atrium-common/src/store.rs +++ b/atrium-common/src/store.rs @@ -36,10 +36,10 @@ where T: MapStore<(), S> + Send + Sync, { async fn get_session(&self) -> Option { - self.get(&Default::default()).await.expect("Infallible") + self.get(&()).await.expect("Infallible") } async fn set_session(&self, session: S) { - self.set(Default::default(), session).await.expect("Infallible") + self.set((), session).await.expect("Infallible") } async fn clear_session(&self) { self.clear().await.expect("Infallible") diff --git a/atrium-oauth/oauth-client/src/lib.rs b/atrium-oauth/oauth-client/src/lib.rs index 07de79d0..1271ff57 100644 --- a/atrium-oauth/oauth-client/src/lib.rs +++ b/atrium-oauth/oauth-client/src/lib.rs @@ -5,12 +5,12 @@ mod http_client; mod jose; mod keyset; mod oauth_client; +mod oauth_session; mod resolver; mod server_agent; pub mod store; mod types; mod utils; -mod oauth_session; pub use atproto::{ AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, GrantType, Scope, diff --git a/atrium-oauth/oauth-client/src/oauth_client.rs b/atrium-oauth/oauth-client/src/oauth_client.rs index 8a8ea45e..e6a62b1a 100644 --- a/atrium-oauth/oauth-client/src/oauth_client.rs +++ b/atrium-oauth/oauth-client/src/oauth_client.rs @@ -9,7 +9,7 @@ use crate::store::state::{InternalStateData, StateStore}; use crate::types::{ AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions, CallbackParams, OAuthAuthorizationServerMetadata, OAuthClientMetadata, - OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters, TokenSet, + OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters, TryIntoOAuthClientMetadata, }; use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values}; diff --git a/atrium-oauth/oauth-client/src/oauth_session.rs b/atrium-oauth/oauth-client/src/oauth_session.rs index a89bfaa8..84a43075 100644 --- a/atrium-oauth/oauth-client/src/oauth_session.rs +++ b/atrium-oauth/oauth-client/src/oauth_session.rs @@ -18,16 +18,15 @@ use crate::{server_agent::OAuthServerAgent, store::session::Session}; #[derive(Clone, Debug, Error)] pub enum Error {} +#[allow(dead_code)] pub struct OAuthSession where S: SessionStore, - // N: DpopStore + Send + Sync, T: XrpcClient + Send + Sync + 'static, D: DidResolver + Send + Sync, H: HandleResolver + Send + Sync, { did: Did, - // dpop: DpopClient, server: Arc>, store: Arc>, } @@ -35,30 +34,22 @@ where impl OAuthSession where S: SessionStore, - // N: DpopStore + Send + Sync, T: XrpcClient + Send + Sync, D: DidResolver + Send + Sync, H: HandleResolver + Send + Sync, { pub fn new( did: Did, - // dpop: DpopClient, server: Arc>, store: Arc>, ) -> Self { - Self { - did, - // dpop, - server, - store, - } + Self { did, server, store } } } impl HttpClient for OAuthSession where S: SessionStore + Send + Sync, - // N: DpopStore + Send + Sync, T: XrpcClient + Send + Sync, D: DidResolver + Send + Sync, H: HandleResolver + Send + Sync, @@ -74,7 +65,6 @@ where impl XrpcClient for OAuthSession where S: SessionStore + Send + Sync, - // N: DpopStore + Send + Sync, T: XrpcClient + Send + Sync, D: DidResolver + Send + Sync, H: HandleResolver + Send + Sync, @@ -99,7 +89,6 @@ where impl SessionManager for OAuthSession where S: SessionStore + Send + Sync, - // N: DpopStore + Send + Sync, T: XrpcClient + Send + Sync, D: DidResolver + Send + Sync, H: HandleResolver + Send + Sync, diff --git a/atrium-oauth/oauth-client/src/server_agent.rs b/atrium-oauth/oauth-client/src/server_agent.rs index 2a05beff..b949c471 100644 --- a/atrium-oauth/oauth-client/src/server_agent.rs +++ b/atrium-oauth/oauth-client/src/server_agent.rs @@ -4,8 +4,9 @@ use crate::jose::jwt::{RegisteredClaims, RegisteredClaimsAud}; use crate::keyset::Keyset; use crate::resolver::OAuthResolver; use crate::types::{ - OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse, - PushedAuthorizationRequestParameters, TokenGrantType, TokenRequestParameters, TokenSet, + AuthorizationCodeParameters, OAuthAuthorizationServerMetadata, OAuthClientMetadata, + OAuthTokenResponse, PushedAuthorizationRequestParameters, RefreshTokenParameters, + RevocationRequestParameters, TokenRequestParameters, TokenSet, }; use crate::utils::{compare_algos, generate_nonce}; use atrium_api::types::string::Datetime; @@ -56,7 +57,7 @@ pub type Result = core::result::Result; #[allow(dead_code)] pub enum OAuthRequest { Token(TokenRequestParameters), - Revocation, + Revocation(RevocationRequestParameters), Introspection, PushedAuthorizationRequest(PushedAuthorizationRequestParameters), } @@ -65,14 +66,14 @@ impl OAuthRequest { fn name(&self) -> String { String::from(match self { Self::Token(_) => "token", - Self::Revocation => "revocation", + Self::Revocation(_) => "revocation", Self::Introspection => "introspection", Self::PushedAuthorizationRequest(_) => "pushed_authorization_request", }) } fn expected_status(&self) -> StatusCode { match self { - Self::Token(_) => StatusCode::OK, + Self::Token(_) | Self::Revocation(_) => StatusCode::OK, Self::PushedAuthorizationRequest(_) => StatusCode::CREATED, _ => unimplemented!(), } @@ -162,12 +163,44 @@ where } pub async fn exchange_code(&self, code: &str, verifier: &str) -> Result { self.verify_token_response( - self.request(OAuthRequest::Token(TokenRequestParameters { - grant_type: TokenGrantType::AuthorizationCode, - code: code.into(), - redirect_uri: self.client_metadata.redirect_uris[0].clone(), // ? - code_verifier: verifier.into(), - })) + self.request(OAuthRequest::Token(TokenRequestParameters::AuthorizationCode( + AuthorizationCodeParameters { + code: code.into(), + redirect_uri: self.client_metadata.redirect_uris[0].clone(), // ? + code_verifier: verifier.into(), + }, + ))) + .await?, + ) + .await + } + pub async fn revoke_session(&self, token: &str) -> Result<()> { + self.request(OAuthRequest::Revocation(RevocationRequestParameters { token: token.into() })) + .await + } + pub async fn refresh_session(&self, token_set: TokenSet) -> Result { + let TokenSet { sub, scope, refresh_token, access_token, token_type, expires_at, .. } = + token_set; + let expires_in = expires_at.map(|expires_at| { + expires_at.as_ref().signed_duration_since(Datetime::now().as_ref()).num_seconds() + }); + let token_response = OAuthTokenResponse { + access_token, + token_type, + expires_in, + refresh_token, + scope, + sub: Some(sub), + }; + let TokenSet { scope, refresh_token: Some(refresh_token), .. } = + self.verify_token_response(token_response).await? + else { + todo!(); + }; + self.verify_token_response( + self.request(OAuthRequest::Token(TokenRequestParameters::RefreshToken( + RefreshTokenParameters { refresh_token, scope }, + ))) .await?, ) .await @@ -267,7 +300,7 @@ where fn endpoint(&self, request: &OAuthRequest) -> Option<&String> { match request { OAuthRequest::Token(_) => Some(&self.server_metadata.token_endpoint), - OAuthRequest::Revocation => self.server_metadata.revocation_endpoint.as_ref(), + OAuthRequest::Revocation(_) => self.server_metadata.revocation_endpoint.as_ref(), OAuthRequest::Introspection => self.server_metadata.introspection_endpoint.as_ref(), OAuthRequest::PushedAuthorizationRequest(_) => { self.server_metadata.pushed_authorization_request_endpoint.as_ref() diff --git a/atrium-oauth/oauth-client/src/store.rs b/atrium-oauth/oauth-client/src/store.rs index bb7b109c..f7247255 100644 --- a/atrium-oauth/oauth-client/src/store.rs +++ b/atrium-oauth/oauth-client/src/store.rs @@ -1,2 +1,2 @@ -pub mod state; pub mod session; +pub mod state; diff --git a/atrium-oauth/oauth-client/src/types.rs b/atrium-oauth/oauth-client/src/types.rs index 7d09739f..7db747c4 100644 --- a/atrium-oauth/oauth-client/src/types.rs +++ b/atrium-oauth/oauth-client/src/types.rs @@ -7,12 +7,14 @@ mod token; pub use client_metadata::{OAuthClientMetadata, TryIntoOAuthClientMetadata}; pub use metadata::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; pub use request::{ - AuthorizationCodeChallengeMethod, AuthorizationResponseType, - PushedAuthorizationRequestParameters, TokenGrantType, TokenRequestParameters, + AuthorizationCodeChallengeMethod, AuthorizationCodeParameters, AuthorizationResponseType, + PushedAuthorizationRequestParameters, RefreshTokenParameters, RevocationRequestParameters, + TokenRequestParameters, }; pub use response::{OAuthPusehedAuthorizationRequestResponse, OAuthTokenResponse}; use serde::Deserialize; -pub use token::{TokenSet, TokenInfo}; +#[allow(unused_imports)] +pub use token::{TokenInfo, TokenSet}; #[derive(Debug, Deserialize)] pub enum AuthorizeOptionPrompt { diff --git a/atrium-oauth/oauth-client/src/types/request.rs b/atrium-oauth/oauth-client/src/types/request.rs index a5b71474..cb110b51 100644 --- a/atrium-oauth/oauth-client/src/types/request.rs +++ b/atrium-oauth/oauth-client/src/types/request.rs @@ -45,18 +45,28 @@ pub struct PushedAuthorizationRequestParameters { pub prompt: Option, } +// https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 #[derive(Serialize)] -#[serde(rename_all = "snake_case")] -pub enum TokenGrantType { - AuthorizationCode, +#[serde(tag = "grant_type", rename_all = "snake_case")] +pub enum TokenRequestParameters { + AuthorizationCode(AuthorizationCodeParameters), + RefreshToken(RefreshTokenParameters), } #[derive(Serialize)] -pub struct TokenRequestParameters { - // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 - pub grant_type: TokenGrantType, +pub struct AuthorizationCodeParameters { pub code: String, pub redirect_uri: String, // https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 pub code_verifier: String, } + +#[derive(Serialize)] +pub struct RefreshTokenParameters { + pub refresh_token: String, + pub scope: Option, +} +#[derive(Serialize)] +pub struct RevocationRequestParameters { + pub token: String, +}