diff --git a/Cargo.lock b/Cargo.lock index 89bd1b5..091c449 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1432,10 +1432,12 @@ dependencies = [ "console_error_panic_hook", "exitcode", "futures", + "futures-util", "http 1.1.0", "mockito", "oci-spec", "opentelemetry-proto", + "pin-project", "prost", "regex", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 71db177..b06ae31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ prost = "0.12.6" async-channel = "2.3.1" tower = "0.4.13" async-trait = "0.1.81" +pin-project = "1.1.5" [dependencies.web-sys] version = "0.3.63" @@ -65,6 +66,7 @@ features = [ ] [dev-dependencies] +futures-util = "0.3.30" mockito = "1.4.0" test-case = "3.3.1" tokio = { version = "1.39.1", features = ["macros"]} diff --git a/src/app.rs b/src/app.rs index d403c32..fd66c60 100644 --- a/src/app.rs +++ b/src/app.rs @@ -9,6 +9,7 @@ use axum::{ Router, }; use http::StatusCode; +use time::OffsetDateTime; use url::Url; use crate::{package, pyoci::PyOciError, templates, PyOci}; @@ -50,24 +51,32 @@ pub fn router() -> Router { "/:registry/:namespace/", post(publish_package).layer(DefaultBodyLimit::max(50 * 1024 * 1024)), ) - .layer(axum::middleware::from_fn(trace_middleware)) + .layer(axum::middleware::from_fn(accesslog_middleware)) } /// Log incoming requests -async fn trace_middleware( +async fn accesslog_middleware( method: axum::http::Method, uri: axum::http::Uri, headers: axum::http::HeaderMap, request: axum::extract::Request, next: axum::middleware::Next, ) -> axum::response::Response { + let start = OffsetDateTime::now_utc(); let response = next.run(request).await; let status: u16 = response.status().into(); let user_agent = headers .get("user-agent") .map(|ua| ua.to_str().unwrap_or("")); - tracing::info!(method = %method, status, path = %uri.path(), user_agent, "type" = "request"); + tracing::info!( + elapsed_ms = (OffsetDateTime::now_utc() - start).whole_milliseconds(), + method = method.to_string(), + status, + path = uri.path(), + user_agent, + "type" = "request" + ); response } @@ -90,7 +99,7 @@ async fn list_package( let package: package::Info = path_params.0.try_into()?; - let client = PyOci::new(package.registry.clone(), auth); + let mut client = PyOci::new(package.registry.clone(), auth)?; // Fetch at most 45 packages // https://developers.cloudflare.com/workers/platform/limits/#account-plan-limits let files = client.list_package_files(&package, 45).await?; @@ -101,7 +110,7 @@ async fn list_package( } /// Download package request handler -// #[debug_handler] +#[debug_handler] // Mark the handler as Send when building a wasm target // JsFuture, and most other JS objects are !Send // Because the cloudflare worker runtime is single-threaded, we can safely mark this as Send @@ -117,7 +126,7 @@ async fn download_package( }; let package: package::Info = path_params.0.try_into()?; - let client = PyOci::new(package.registry.clone(), auth); + let mut client = PyOci::new(package.registry.clone(), auth)?; let data = client .download_package_file(&package) .await? @@ -125,7 +134,6 @@ async fn download_package( .await .expect("valid bytes"); - // TODO: With some trickery we could stream the data directly to the response Ok(( [( header::CONTENT_DISPOSITION, @@ -138,7 +146,7 @@ async fn download_package( /// Publish package request handler /// /// ref: https://warehouse.pypa.io/api-reference/legacy.html#upload-api -// #[debug_handler] +#[debug_handler] // Mark the handler as Send when building a wasm target // JsFuture, and most other JS objects are !Send // Because the cloudflare worker runtime is single-threaded, we can safely mark this as Send @@ -156,7 +164,7 @@ async fn publish_package( None => None, }; let package: package::Info = (registry, namespace, None, form_data.filename).try_into()?; - let client = PyOci::new(package.registry.clone(), auth); + let mut client = PyOci::new(package.registry.clone(), auth)?; client .publish_package_file(&package, form_data.content) diff --git a/src/lib.rs b/src/lib.rs index 2da3eae..6c3c980 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,8 @@ mod pyoci; mod templates; // HTTP Transport mod transport; +// Services +mod service; // Re-export the PyOci client pub use pyoci::PyOci; diff --git a/src/otlp.rs b/src/otlp.rs index ff33d1d..8b8d3eb 100644 --- a/src/otlp.rs +++ b/src/otlp.rs @@ -214,7 +214,7 @@ where fn on_close(&self, id: tracing_core::span::Id, ctx: Context<'_, S>) { let span = ctx.span(&id).expect("span not found"); - if !span.parent().is_none() { + if span.parent().is_some() { // This is a sub-span, we'll flush all messages when the root span is closed return; } diff --git a/src/pyoci.rs b/src/pyoci.rs index 0d9e3c3..89554e1 100644 --- a/src/pyoci.rs +++ b/src/pyoci.rs @@ -2,6 +2,7 @@ use anyhow::{bail, Context, Error, Result}; use base16ct::lower::encode_string as hex_encode; use futures::stream::FuturesUnordered; use futures::stream::StreamExt; +use http::HeaderValue; use oci_spec::image::Arch; use oci_spec::image::DescriptorBuilder; use oci_spec::image::ImageIndexBuilder; @@ -155,14 +156,17 @@ pub struct AuthResponse { /// WWW-Authenticate header /// ref: pub struct WwwAuth { - pub realm: String, + pub realm: Url, pub service: String, // scope: String, } impl WwwAuth { /// Parse a WWW-Authenticate header - pub fn parse(value: &str) -> Result { + pub fn parse(header: &HeaderValue) -> Result { + let value = header + .to_str() + .context("Failed to parse WWW-Authenticate header")?; let value = match value.strip_prefix("Bearer ") { None => bail!("Not a Bearer token"), Some(value) => value, @@ -171,15 +175,16 @@ impl WwwAuth { .unwrap() .captures(value) { - Some(value) => value.name("realm").unwrap().as_str().to_string(), - None => bail!("`realm` key missing from WWW-Authenticate header"), + Some(value) => value.name("realm").unwrap().as_str(), + None => bail!("`realm` key missing"), }; + let realm = Url::parse(realm).context("Failed to parse realm URL")?; let service = match Regex::new(r#"service="(?P[^"\s]*)"#) .expect("valid regex") .captures(value) { Some(value) => value.name("service").unwrap().as_str().to_string(), - None => bail!("`service` key missing from WWW-Authenticate header"), + None => bail!("`service` key missing"), }; // let scope = match Regex::new(r#"scope="(?P[^"]*)"#) // .expect("valid regex") @@ -201,7 +206,7 @@ impl WwwAuth { } /// Client to communicate with the OCI v2 registry -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PyOci { registry: Url, transport: HttpTransport, @@ -209,11 +214,11 @@ pub struct PyOci { impl PyOci { /// Create a new Client - pub fn new(registry: Url, auth: Option) -> Self { - PyOci { + pub fn new(registry: Url, auth: Option) -> Result { + Ok(PyOci { registry, - transport: HttpTransport::new(auth), - } + transport: HttpTransport::new(auth)?, + }) } /// List all files for the given package @@ -221,7 +226,7 @@ impl PyOci { /// Limits the number of files to `n` /// ref: https://github.com/opencontainers/distribution-spec/blob/main/spec.md#listing-tags pub async fn list_package_files( - &self, + &mut self, package: &package::Info, n: usize, ) -> Result> { @@ -239,7 +244,8 @@ impl PyOci { // Even for non-spec registries the last-added seems to be at the end of the list // so this will result in the wanted list of tags in most cases. for tag in tags.iter().rev().take(n) { - futures.push(self.package_info_for_ref(package, &name, tag)); + let pyoci = self.clone(); + futures.push(pyoci.package_info_for_ref(package, &name, tag)); } for result in futures .collect::, Error>>>() @@ -251,7 +257,7 @@ impl PyOci { } async fn package_info_for_ref( - &self, + mut self, package: &package::Info, name: &str, reference: &str, @@ -290,7 +296,10 @@ impl PyOci { Ok(files) } - pub async fn download_package_file(&self, package: &crate::package::Info) -> Result { + pub async fn download_package_file( + &mut self, + package: &crate::package::Info, + ) -> Result { // Pull index let index = match self .pull_manifest(&package.oci_name()?, &package.oci_tag()?) @@ -349,7 +358,7 @@ impl PyOci { } pub async fn publish_package_file( - &self, + &mut self, package: &crate::package::Info, file: Vec, ) -> Result<()> { @@ -430,7 +439,7 @@ impl PyOci { /// /// https://github.com/opencontainers/distribution-spec/blob/main/spec.md#post-then-put async fn push_blob( - &self, + &mut self, // Name of the package, including namespace. e.g. "library/alpine" name: &str, blob: Blob, @@ -442,8 +451,7 @@ impl PyOci { self.transport .head(build_url!(&self, "/v2/{}/blobs/{}", name, digest)), ) - .await - .expect("valid response"); + .await?; match response.status() { StatusCode::OK => { @@ -461,7 +469,7 @@ impl PyOci { .transport .post(url) .header("Content-Type", "application/octet-stream"); - let response = self.transport.send(request).await.expect("valid response"); + let response = self.transport.send(request).await?; let location = match response.status() { StatusCode::CREATED => return Ok(()), StatusCode::ACCEPTED => response @@ -486,7 +494,7 @@ impl PyOci { .header("Content-Type", "application/octet-stream") .header("Content-Length", blob.data.len().to_string()) .body(blob.data); - let response = self.transport.send(request).await.expect("valid response"); + let response = self.transport.send(request).await?; match response.status() { StatusCode::CREATED => {} status => { @@ -509,7 +517,7 @@ impl PyOci { /// /// This returns the raw response so the caller can handle the blob as needed async fn pull_blob( - &self, + &mut self, // Name of the package, including namespace. e.g. "library/alpine" name: String, // Descriptor of the blob to pull @@ -518,7 +526,7 @@ impl PyOci { let digest = descriptor.digest(); let url = build_url!(&self, "/v2/{}/blobs/{}", &name, digest); let request = self.transport.get(url); - let response = self.transport.send(request).await.expect("valid response"); + let response = self.transport.send(request).await?; match response.status() { StatusCode::OK => Ok(response), @@ -527,10 +535,10 @@ impl PyOci { } /// List the available tags for a package - async fn list_tags(&self, name: &str) -> anyhow::Result { + async fn list_tags(&mut self, name: &str) -> anyhow::Result { let url = build_url!(&self, "/v2/{}/tags/list", name); let request = self.transport.get(url); - let response = self.transport.send(request).await.expect("valid response"); + let response = self.transport.send(request).await?; match response.status() { StatusCode::OK => {} status => return Err(PyOciError::from((status, response.text().await?)).into()), @@ -547,7 +555,7 @@ impl PyOci { /// ImageIndex will be pushed with a version tag if version is set /// ImageManifest will always be pushed with a digest reference async fn push_manifest( - &self, + &mut self, name: &str, manifest: Manifest, version: Option<&str>, @@ -573,7 +581,7 @@ impl PyOci { .put(url) .header("Content-Type", content_type) .body(data); - let response = self.transport.send(request).await.expect("valid response"); + let response = self.transport.send(request).await?; match response.status() { StatusCode::CREATED => {} status => return Err(PyOciError::from((status, response.text().await?)).into()), @@ -585,13 +593,13 @@ impl PyOci { /// /// If the manifest does not exist, Ok is returned /// If any other error happens, an Err is returned - async fn pull_manifest(&self, name: &str, reference: &str) -> Result> { + async fn pull_manifest(&mut self, name: &str, reference: &str) -> Result> { let url = build_url!(&self, "/v2/{}/manifests/{}", name, reference); let request = self.transport.get(url).header( "Accept", "application/vnd.oci.image.manifest.v1+json, application/vnd.oci.image.index.v1+json", ); - let response = self.transport.send(request).await.expect("valid response"); + let response = self.transport.send(request).await?; match response.status() { StatusCode::NOT_FOUND => return Ok(None), StatusCode::OK => {} @@ -634,7 +642,7 @@ mod tests { fn test_build_url() -> Result<()> { let client = PyOci { registry: Url::parse("https://example.com").expect("valid url"), - transport: HttpTransport::new(None), + transport: HttpTransport::new(None).unwrap(), }; let url = build_url!(&client, "/foo/{}/", "latest"); assert_eq!(url.as_str(), "https://example.com/foo/latest/"); @@ -645,7 +653,7 @@ mod tests { fn test_build_url_absolute() -> Result<()> { let client = PyOci { registry: Url::parse("https://example.com").expect("valid url"), - transport: HttpTransport::new(None), + transport: HttpTransport::new(None).unwrap(), }; let url = build_url!(&client, "{}/foo?bar=baz&qaz=sha:123", "http://pyoci.nl"); assert_eq!(url.as_str(), "http://pyoci.nl/foo?bar=baz&qaz=sha:123"); @@ -656,7 +664,7 @@ mod tests { fn test_build_url_double_period() { let client = PyOci { registry: Url::parse("https://example.com").expect("valid url"), - transport: HttpTransport::new(None), + transport: HttpTransport::new(None).unwrap(), }; let x = || -> Result { Ok(build_url!(&client, "/foo/{}/", "..")) }(); assert!(x.is_err()); @@ -706,9 +714,9 @@ mod tests { .await, ); - let client = PyOci { + let mut client = PyOci { registry: Url::parse(&url).expect("valid url"), - transport: HttpTransport::new(None), + transport: HttpTransport::new(None).unwrap(), }; let blob = Blob::new("hello".into(), "application/octet-stream"); assert!(client.push_blob("mockserver/foobar", blob).await.is_ok()); @@ -761,9 +769,9 @@ mod tests { .await, ); - let client = PyOci { + let mut client = PyOci { registry: Url::parse(&url).expect("valid url"), - transport: HttpTransport::new(None), + transport: HttpTransport::new(None).unwrap(), }; let blob = Blob::new("hello".into(), "application/octet-stream"); assert!(client.push_blob("mockserver/foobar", blob).await.is_ok()); diff --git a/src/service/auth.rs b/src/service/auth.rs new file mode 100644 index 0000000..0cbff1b --- /dev/null +++ b/src/service/auth.rs @@ -0,0 +1,571 @@ +use anyhow::{anyhow, Context as _, Result}; +use futures::{ready, FutureExt}; +use http::StatusCode; +use pin_project::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, RwLock}; +use std::task::{Context, Poll}; +use tower::{Layer, Service}; +use url::Url; + +use crate::pyoci::{AuthResponse, PyOciError, WwwAuth}; + +/// Authentication layer for the OCI registry +/// This layer will handle [token authentication](https://distribution.github.io/distribution/spec/auth/token/) +/// based on the authentication header of the original request. +#[derive(Debug, Default, Clone)] +pub struct AuthLayer { + // The Basic token to trade for a Bearer token + basic: Option, + // The Bearer token to use for authentication + // Will be set after successful authentication + bearer: Arc>>, +} + +impl AuthLayer { + pub fn new(basic_token: Option) -> Result { + let basic_token = match basic_token { + None => None, + Some(token) => { + let mut token = http::HeaderValue::try_from(token)?; + token.set_sensitive(true); + Some(token) + } + }; + + Ok(Self { + basic: basic_token, + bearer: Arc::new(RwLock::new(None)), + }) + } +} + +impl Layer for AuthLayer { + type Service = AuthService; + + fn layer(&self, service: S) -> Self::Service { + AuthService::new(self.basic.clone(), self.bearer.clone(), service) + } +} + +#[derive(Debug, Clone)] +pub struct AuthService { + basic: Option, + bearer: Arc>>, + service: S, +} + +impl AuthService { + fn new( + basic: Option, + bearer: Arc>>, + service: S, + ) -> Self { + Self { + bearer, + basic, + service, + } + } +} + +impl Service for AuthService +where + S: Service + Clone + Send + 'static, + >::Future: Send, + >::Error: Into, +{ + type Response = S::Response; + type Error = anyhow::Error; + type Future = AuthFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, mut request: reqwest::Request) -> Self::Future { + if let Some(bearer) = self.bearer.read().expect("Failed to get read lock").clone() { + // If we have a bearer token, add it to the request + request + .headers_mut() + .insert(http::header::AUTHORIZATION, bearer); + } + AuthFuture::new( + request.try_clone(), + self.clone(), + self.service.call(request), + ) + } +} + +/// The Future returned by AuthService +/// Implements the actual authentication logic +#[pin_project] +pub struct AuthFuture +where + S: Service, +{ + // Clone of the original request to retry after authentication + request: Option, + // Clone of the original service, used to do the authentication request and retry + // the original request + auth: AuthService, + // State of this Future + #[pin] + state: AuthState, +} + +/// State machine for AuthFuture +#[pin_project(project = AuthStateProj)] +enum AuthState { + // Polling the original request or the retry after authentication + Called { + #[pin] + future: F, + }, + // Polling the authentication request + Authenticating { + #[pin] + future: Pin> + Send>>, + }, +} + +impl AuthFuture +where + S: Service, +{ + fn new(request: Option, inner: AuthService, future: S::Future) -> Self { + Self { + request, + auth: inner, + state: AuthState::Called { future }, + } + } +} + +impl Future for AuthFuture +where + // Service being called that we might need to authenticate for + S: Service + Clone + Send + 'static, + >::Future: Send, + >::Error: Into, +{ + type Output = anyhow::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + match this.state.as_mut().project() { + // Polling original request + AuthStateProj::Called { future } => { + let response = ready!(future.poll(cx)).map_err(Into::into)?; + + if response.status() != StatusCode::UNAUTHORIZED { + return Poll::Ready(Ok(response)); + } + tracing::debug!("Received 401 response, authenticating"); + if this.request.is_none() { + // No clone of the original request, can't retry after authentication + tracing::debug!("No request to retry, skipping authentication"); + return Poll::Ready(Ok(response)); + } + // Take the basic token, we are only expected to trade it once + let Some(basic_token) = this.auth.basic.take() else { + // No basic token to trade for a bearer token + tracing::debug!("No basic token, skipping authentication"); + return Poll::Ready(Ok(response)); + }; + + let www_auth = match response.headers().get("WWW-Authenticate") { + None => { + return Poll::Ready(Err(PyOciError::from(( + StatusCode::BAD_GATEWAY, + "Registry did not provide a WWW-Authenticate header", + )) + .into())); + } + Some(value) => { + match WwwAuth::parse(value) { + Ok(value) => value, + Err(err) => { + return Poll::Ready(Err(PyOciError::from(( + StatusCode::BAD_GATEWAY, + format!("Registry returned invalid WWW-Authenticate header: {err}"), + )) + .into())); + } + } + } + }; + let srv = this.auth.clone(); + this.state.set(AuthState::Authenticating { + // No idea how to type this Future, lets just Pin it + future: authenticate(basic_token, www_auth, srv).boxed(), + }); + } + // Polling authentication request + AuthStateProj::Authenticating { future } => match ready!(future.poll(cx)) { + Ok(bearer_token) => { + // Take the original request, this prevents infinitely retrying if the + // server keeps returning 401 + let mut request = this + .request + .take() + .ok_or_else(|| anyhow!("Tried to retry twice after authentication"))?; + request + .headers_mut() + .insert(http::header::AUTHORIZATION, bearer_token.clone()); + this.auth + .bearer + .write() + .map_err(|_| { + anyhow!("Another thread panicked while writing bearer token") + })? + .replace(bearer_token); + // Retry the original request with the new bearer token + this.state.set(AuthState::Called { + future: this.auth.service.call(request), + }); + } + Err(err) => match err { + // Error during authentication, return the authentication response + AuthError::AuthResponse(auth_response) => { + return Poll::Ready(Ok(auth_response)) + } + // Other error, return it + AuthError::Error(err) => return Poll::Ready(Err(err)), + }, + }, + }; + } + } +} + +enum AuthError { + AuthResponse(reqwest::Response), + Error(anyhow::Error), +} + +impl From for AuthError +where + E: Into, +{ + fn from(err: E) -> Self { + AuthError::Error(err.into()) + } +} + +// Returns the bearer token if successful. +// Returns the upstream response of not. +#[cfg_attr(target_arch = "wasm32", worker::send)] +async fn authenticate( + basic_token: http::HeaderValue, + www_auth: WwwAuth, + mut service: S, +) -> Result +where + S: Service, + >::Future: Send, + >::Error: Into, +{ + let mut auth_url = www_auth.realm; + auth_url + .query_pairs_mut() + .append_pair("grant_type", "password") + .append_pair("service", &www_auth.service); + let mut auth_request = reqwest::Request::new(http::Method::GET, auth_url); + auth_request + .headers_mut() + .append("Authorization", basic_token); + let response = service.call(auth_request).await?; + if response.status() != StatusCode::OK { + return Err(AuthError::AuthResponse(response)); + } + let auth = response.json::().await.map_err(|err| { + PyOciError::from(( + StatusCode::BAD_GATEWAY, + format!("Failed to parse authentication response: {err}"), + )) + })?; + let mut token = http::HeaderValue::try_from(format!("Bearer {}", auth.token)) + .context("Failed to create bearer token header")?; + token.set_sensitive(true); + Ok(token) +} + +/// The high-level tests for this Service are part of `src/transport.rs`. +/// This module tests some of the error cases +#[cfg(test)] +mod test { + use super::*; + use mockito::Server; + use reqwest::{Body, Client}; + use tower::ServiceBuilder; + + // Happy-flow + #[tokio::test] + async fn auth_service() { + let mut server = Server::new_async().await; + let url = server.url(); + let mocks = vec![ + // Response to unauthenticated request + server + .mock("GET", "/foobar") + .with_status(401) + .with_header( + "WWW-Authenticate", + &format!("Bearer realm=\"{url}/token\",service=\"pyoci.fakeservice\""), + ) + .create_async() + .await, + // Token exchange + server + .mock( + "GET", + "/token?grant_type=password&service=pyoci.fakeservice", + ) + .match_header("Authorization", "Basic mybasicauth") + .with_status(200) + .with_body(r#"{"token":"mytoken"}"#) + .create_async() + .await, + // Re-submitted request, with bearer auth + server + .mock("GET", "/foobar") + .match_header("Authorization", "Bearer mytoken") + .with_status(200) + .with_body("Hello, world!") + .create_async() + .await, + ]; + + let mut service = ServiceBuilder::new() + .layer(AuthLayer::new(Some("Basic mybasicauth".into())).unwrap()) + .service(Client::default()); + let request = reqwest::Request::new( + http::Method::GET, + Url::parse(&format!("{url}/foobar")).unwrap(), + ); + + let response = service.call(request).await.unwrap(); + for mock in mocks { + mock.assert_async().await; + } + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.text().await.unwrap(), "Hello, world!"); + } + + // The if the original response it returned if the request can't be cloned. + // Without a clone we can't retry after authentication. + #[tokio::test] + async fn auth_service_missing_clone() { + let mut server = Server::new_async().await; + let url = server.url(); + let mocks = vec![ + // Response to unauthenticated request + server + .mock("GET", "/foobar") + .with_status(401) + .with_header( + "WWW-Authenticate", + &format!("Bearer realm=\"{url}/token\",service=\"pyoci.fakeservice\""), + ) + .create_async() + .await, + ]; + + let mut service = ServiceBuilder::new() + .layer(AuthLayer::new(Some("Basic mybasicauth".into())).unwrap()) + .service(Client::default()); + + // Construct a request that can't be cloned + let mut request = reqwest::Request::new( + http::Method::GET, + Url::parse(&format!("{url}/foobar")).unwrap(), + ); + let chunks: Vec> = vec![Ok("hello"), Ok("world")]; + let stream = futures_util::stream::iter(chunks); + let body = Body::wrap_stream(stream); + *request.body_mut() = Some(body); + + let response = service.call(request).await.unwrap(); + for mock in mocks { + mock.assert_async().await; + } + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + // Test if the original response is returned if there is no basic token to exchange. + #[tokio::test] + async fn auth_service_missing_basic_token() { + let mut server = Server::new_async().await; + let url = server.url(); + let mocks = vec![ + // Response to unauthenticated request + server + .mock("GET", "/foobar") + .with_status(401) + .with_header( + "WWW-Authenticate", + &format!("Bearer realm=\"{url}/token\",service=\"pyoci.fakeservice\""), + ) + .create_async() + .await, + ]; + + let mut service = ServiceBuilder::new() + .layer(AuthLayer::new(None).unwrap()) + .service(Client::default()); + + let request = reqwest::Request::new( + http::Method::GET, + Url::parse(&format!("{url}/foobar")).unwrap(), + ); + + let response = service.call(request).await.unwrap(); + for mock in mocks { + mock.assert_async().await; + } + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + // Test if BAD_GATEWAY is returned on response of the upsteam server without a + // WWW-Authenticate header. + #[tokio::test] + async fn auth_service_missing_www_auth_header() { + let mut server = Server::new_async().await; + let url = server.url(); + let mocks = vec![ + // invalid response to unauthenticated request + server + .mock("GET", "/foobar") + .with_status(401) + .create_async() + .await, + ]; + + let mut service = ServiceBuilder::new() + .layer(AuthLayer::new(Some("Basic mybasictoken".into())).unwrap()) + .service(Client::default()); + + let request = reqwest::Request::new( + http::Method::GET, + Url::parse(&format!("{url}/foobar")).unwrap(), + ); + + let error = service + .call(request) + .await + .unwrap_err() + .downcast::() + .unwrap(); + for mock in mocks { + mock.assert_async().await; + } + assert_eq!(error.status, StatusCode::BAD_GATEWAY); + assert_eq!( + error.message, + "Registry did not provide a WWW-Authenticate header".to_string() + ); + } + + // Test if BAD_GATEWAY is returned when the server responds with an invalid + // WWW-authenticate header + #[tokio::test] + async fn auth_service_invalid_www_auth_header() { + let mut server = Server::new_async().await; + let url = server.url(); + let mocks = vec![ + // Response to unauthenticated request + server + .mock("GET", "/foobar") + .with_status(401) + .with_header( + "WWW-Authenticate", + &format!("Bearer unknown=\"{url}/token\",service=\"pyoci.fakeservice\""), + ) + .create_async() + .await, + ]; + + let mut service = ServiceBuilder::new() + .layer(AuthLayer::new(Some("Basic mybasictoken".into())).unwrap()) + .service(Client::default()); + + let request = reqwest::Request::new( + http::Method::GET, + Url::parse(&format!("{url}/foobar")).unwrap(), + ); + + let error = service + .call(request) + .await + .unwrap_err() + .downcast::() + .unwrap(); + for mock in mocks { + mock.assert_async().await; + } + assert_eq!(error.status, StatusCode::BAD_GATEWAY); + assert_eq!( + error.message, + "Registry returned invalid WWW-Authenticate header: `realm` key missing".to_string() + ); + } + + // Test if we return BAD_GATEWAY if the server responds with a malformed token response + #[tokio::test] + async fn auth_service_malformed_auth_response() { + let mut server = mockito::Server::new_async().await; + let url = server.url(); + let mocks = vec![ + // Response to unauthenticated request + server + .mock("GET", "/foobar") + .with_status(401) + .with_header( + "WWW-Authenticate", + &format!("Bearer realm=\"{url}/token\",service=\"pyoci.fakeservice\""), + ) + .create_async() + .await, + // Token exchange + server + .mock( + "GET", + "/token?grant_type=password&service=pyoci.fakeservice", + ) + .match_header("Authorization", "Basic mybasictoken") + .with_status(200) + .with_body(r#"{"notatoken":"mytoken"}"#) + .create_async() + .await, + ]; + + let mut service = ServiceBuilder::new() + .layer(AuthLayer::new(Some("Basic mybasictoken".into())).unwrap()) + .service(Client::default()); + + let request = reqwest::Request::new( + http::Method::GET, + Url::parse(&format!("{url}/foobar")).unwrap(), + ); + + let error = service + .call(request) + .await + .unwrap_err() + .downcast::() + .unwrap(); + for mock in mocks { + mock.assert_async().await; + } + assert_eq!(error.status, StatusCode::BAD_GATEWAY); + assert_eq!( + error.message, + "Failed to parse authentication response: error decoding response body".to_string() + ); + } +} diff --git a/src/service/log.rs b/src/service/log.rs new file mode 100644 index 0000000..4c43f99 --- /dev/null +++ b/src/service/log.rs @@ -0,0 +1,97 @@ +use futures::ready; +use pin_project::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use time::OffsetDateTime; +use tower::{Layer, Service}; + +#[derive(Debug, Default, Clone)] +pub struct RequestLogLayer { + request_type: &'static str, +} + +impl RequestLogLayer { + pub fn new(request_type: &'static str) -> Self { + Self { request_type } + } +} + +impl Layer for RequestLogLayer { + type Service = RequestLog; + + fn layer(&self, service: S) -> Self::Service { + RequestLog::new(self.request_type, service) + } +} + +#[derive(Debug, Clone)] +pub struct RequestLog { + request_type: &'static str, + inner: S, +} + +impl RequestLog { + pub fn new(request_type: &'static str, service: S) -> Self { + Self { + request_type, + inner: service, + } + } +} + +impl Service for RequestLog +where + S: Service, +{ + type Response = S::Response; + type Error = S::Error; + type Future = LogFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: reqwest::Request) -> Self::Future { + LogFuture { + method: request.method().to_string(), + url: request.url().to_string(), + inner_fut: self.inner.call(request), + request_type: self.request_type, + start: OffsetDateTime::now_utc(), + } + } +} + +#[pin_project] +pub struct LogFuture { + #[pin] + inner_fut: F, + method: String, + url: String, + request_type: &'static str, + start: OffsetDateTime, +} + +impl Future for LogFuture +where + F: Future>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let result = ready!(this.inner_fut.poll(cx)); + if let Ok(response) = &result { + let status: u16 = response.status().into(); + tracing::info!( + elapsed_ms = (OffsetDateTime::now_utc() - *this.start).whole_milliseconds(), + method = this.method, + status, + url = this.url, + "type" = this.request_type, + ); + } + Poll::Ready(result) + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 0000000..d18257c --- /dev/null +++ b/src/service/mod.rs @@ -0,0 +1,5 @@ +mod auth; +mod log; + +pub use auth::AuthLayer; +pub use log::RequestLogLayer; diff --git a/src/transport.rs b/src/transport.rs index 4f8c882..6f366a8 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -1,22 +1,49 @@ use anyhow::Result; -use http::StatusCode; -use std::sync::{Arc, Mutex}; -use url::Url; +use std::boxed::Box; +use std::future::poll_fn; +use std::future::Future; +use std::pin::Pin; +use tower::{Service, ServiceBuilder}; -use crate::pyoci::{AuthResponse, WwwAuth}; +use crate::service::AuthLayer; +use crate::service::RequestLogLayer; use crate::USER_AGENT; /// HTTP Transport /// /// This struct is responsible for sending HTTP requests to the upstream OCI registry. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct HttpTransport { /// HTTP client client: reqwest::Client, - /// Basic auth string, including the "Basic " prefix - basic: Option, - /// Bearer token, including the "Bearer " prefix - bearer: Arc>>, + /// Authentication layer + auth_layer: AuthLayer, +} + +// Wraps the reqwest client so we can implement Service. +// reqwest implements Service normally but not for the WASM target. +// This allows us to use other Service implementations to wrap the reqwest client. +impl Service for HttpTransport { + type Response = reqwest::Response; + type Error = reqwest::Error; + // we need to box the future as we currently can't express the anonymous `impl Future` type + // returned by reqwest::Client::execute + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, request: reqwest::Request) -> Self::Future { + #[cfg(target_arch = "wasm32")] + let fut = Box::pin(worker::send::SendFuture::new(self.client.execute(request))); + #[cfg(not(target_arch = "wasm32"))] + let fut = Box::pin(self.client.execute(request)); + fut + } } impl HttpTransport { @@ -24,13 +51,12 @@ impl HttpTransport { /// /// auth: Basic auth string /// Will be swapped for a Bearer token if needed - pub fn new(auth: Option) -> Self { + pub fn new(auth: Option) -> Result { let client = reqwest::Client::builder().user_agent(USER_AGENT); - Self { - client: client.build().unwrap(), - basic: auth, - bearer: Arc::new(Mutex::new(None)), - } + Ok(Self { + client: client.build()?, + auth_layer: AuthLayer::new(auth)?, + }) } /// Send a request @@ -38,76 +64,17 @@ impl HttpTransport { /// When authentication is required, this method will automatically authenticate /// using the provided Basic auth string and caches the Bearer token for future requests within /// this session. - pub async fn send(&self, request: reqwest::RequestBuilder) -> Result { - let org_request = request.try_clone(); - let bearer_token = { - // Local scope the bearer lock - let token = self.bearer.lock().unwrap(); - token.clone() - }; - let request = match bearer_token { - Some(token) => request.header("Authorization", sens_header(&token)?), - None => request, - }; - let response = self._send(request).await?; - if response.status() != StatusCode::UNAUTHORIZED { - // No authentication needed or some error happened - return Ok(response); - } - let Some(org_request) = org_request else { - return Ok(response); - }; - - // Authenticate - let www_auth: WwwAuth = match response.headers().get("WWW-Authenticate") { - None => return Ok(response), - Some(value) => match WwwAuth::parse(value.to_str()?) { - Ok(value) => value, - Err(_) => return Ok(response), - }, - }; - let Some(basic_token) = &self.basic else { - // No credentials provided - return Ok(response); - }; - - let mut auth_url = Url::parse(&www_auth.realm)?; - auth_url - .query_pairs_mut() - .append_pair("grant_type", "password") - // if client_id is needed, add it here, - // although GitHub does not seem to need a valid client_id - // .append_pair("client_id", username) - .append_pair("service", &www_auth.service); - let auth_request = self.get(auth_url).header("Authorization", basic_token); - let auth_response = self._send(auth_request).await?; - - if auth_response.status() != StatusCode::OK { - // Authentication failed - return Ok(auth_response); - } - - let auth_response: AuthResponse = auth_response.json().await?; - let bearer_token = { - // Local scope the bearer lock and update the token - let mut token = self.bearer.lock().unwrap(); - let new_token = format!("Bearer {}", auth_response.token); - *token = Some(new_token.clone()); - new_token - }; - self._send(org_request.header("Authorization", sens_header(&bearer_token)?)) - .await - } - - /// Send a request - async fn _send(&self, request: reqwest::RequestBuilder) -> Result { + pub async fn send(&mut self, request: reqwest::RequestBuilder) -> Result { let request = request.build()?; tracing::debug!("Request: {:#?}", request); - let method = request.method().as_str().to_string(); - let url = request.url().to_owned().to_string(); - let response = self.client.execute(request).await?; - let status: u16 = response.status().into(); - tracing::info!(method, status, url, "type" = "subrequest"); + + let mut service = ServiceBuilder::new() + .layer(self.auth_layer.clone()) + .layer(RequestLogLayer::new("subrequest")) + .service(self.clone()); + poll_fn(|ctx| service.poll_ready(ctx)).await?; + let response = service.call(request).await?; + tracing::debug!("Response Headers: {:#?}", response.headers()); Ok(response) } @@ -130,16 +97,11 @@ impl HttpTransport { } } -/// Create a new HeaderValue with sensitive data -fn sens_header(value: &str) -> Result { - let mut header = reqwest::header::HeaderValue::from_str(value)?; - header.set_sensitive(true); - Ok(header) -} - #[cfg(test)] mod tests { use super::*; + use http::StatusCode; + use url::Url; /// Test happy-flow, no auth needed #[tokio::test] @@ -154,7 +116,7 @@ mod tests { .await, ]; - let transport = HttpTransport::new(None); + let mut transport = HttpTransport::new(None).unwrap(); let request = transport.get(Url::parse(&format!("{}/foobar", &server.url())).unwrap()); let response = transport.send(request).await.unwrap(); for mock in mocks { @@ -201,8 +163,8 @@ mod tests { .await, ]; - let transport = HttpTransport::new(Some("Basic mybasicauth".to_string())); - let request = transport.get(Url::parse(&format!("{}/foobar", &server.url())).unwrap()); + let mut transport = HttpTransport::new(Some("Basic mybasicauth".to_string())).unwrap(); + let request = transport.get(Url::parse(&format!("{url}/foobar")).unwrap()); let response = transport.send(request).await.unwrap(); for mock in mocks { mock.assert_async().await; @@ -211,6 +173,72 @@ mod tests { assert_eq!(response.text().await.unwrap(), "Hello, world!"); } + /// Test happy-flow, with authentication, multiple requests + /// Subsequent requests should have their bearer token set without authenticating again + #[tokio::test] + async fn http_transport_send_auth_multiple_requests() { + let mut server = mockito::Server::new_async().await; + let url = server.url(); + let mocks = vec![ + // Response to unauthenticated request + server + .mock("GET", "/foobar") + .with_status(401) + .with_header( + "WWW-Authenticate", + &format!("Bearer realm=\"{url}/token\",service=\"pyoci.fakeservice\""), + ) + .create_async() + .await, + // Token exchange + server + .mock( + "GET", + "/token?grant_type=password&service=pyoci.fakeservice", + ) + .match_header("Authorization", "Basic mybasicauth") + .with_status(200) + .with_body(r#"{"token":"mytoken"}"#) + .create_async() + .await, + // Re-submitted request, with bearer auth + server + .mock("GET", "/foobar") + .match_header("Authorization", "Bearer mytoken") + .with_status(200) + .with_body("Hello, world!") + .create_async() + .await, + // Second call to Send, should contain Bearer auth from last request + server + .mock("GET", "/bazqaz") + .match_header("Authorization", "Bearer mytoken") + .with_status(200) + .with_body("Hello, again!") + .create_async() + .await, + ]; + + let mut transport = HttpTransport::new(Some("Basic mybasicauth".to_string())).unwrap(); + // clone the transport to check if they share the bearer token state + let mut transport2 = transport.clone(); + + // First request, initiating authentication + let request = transport.get(Url::parse(&format!("{url}/foobar")).unwrap()); + let response = transport.send(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.text().await.unwrap(), "Hello, world!"); + + // Second request, reusing the previous authentication + let request = transport2.get(Url::parse(&format!("{url}/bazqaz")).unwrap()); + let response = transport2.send(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.text().await.unwrap(), "Hello, again!"); + + for mock in mocks { + mock.assert_async().await; + } + } /// Test missing authentication #[tokio::test] async fn http_transport_send_missing_auth() { @@ -229,8 +257,8 @@ mod tests { .await, ]; - let transport = HttpTransport::new(None); - let request = transport.get(Url::parse(&format!("{}/foobar", &server.url())).unwrap()); + let mut transport = HttpTransport::new(None).unwrap(); + let request = transport.get(Url::parse(&format!("{url}/foobar")).unwrap()); let response = transport.send(request).await.unwrap(); for mock in mocks { mock.assert_async().await; @@ -264,8 +292,8 @@ mod tests { .await, ]; - let transport = HttpTransport::new(Some("Basic mybasicauth".to_string())); - let request = transport.get(Url::parse(&format!("{}/foobar", &server.url())).unwrap()); + let mut transport = HttpTransport::new(Some("Basic mybasicauth".to_string())).unwrap(); + let request = transport.get(Url::parse(&format!("{url}/foobar")).unwrap()); let response = transport.send(request).await.unwrap(); for mock in mocks { mock.assert_async().await; @@ -310,8 +338,8 @@ mod tests { .await, ]; - let transport = HttpTransport::new(Some("Basic mybasicauth".to_string())); - let request = transport.get(Url::parse(&format!("{}/foobar", &server.url())).unwrap()); + let mut transport = HttpTransport::new(Some("Basic mybasicauth".to_string())).unwrap(); + let request = transport.get(Url::parse(&format!("{url}/foobar")).unwrap()); let response = transport.send(request).await.unwrap(); for mock in mocks { mock.assert_async().await;