From aa79ba51ea557b1f6028ddb3f9859045497d40f3 Mon Sep 17 00:00:00 2001 From: Allex Veldman Date: Fri, 26 Jul 2024 13:53:39 +0200 Subject: [PATCH 1/3] feat: Use Axum for routing and handling requests This moves large parts of the request handling from cloudflare specific implementations to Axum, allowing reuse in non-cloudflare worker server setups in the future. --- Cargo.lock | 139 ++++++++++- Cargo.toml | 7 +- src/app.rs | 590 +++++++++++++++++++++++++++++++++++++++++++++++ src/cf.rs | 167 +++----------- src/lib.rs | 3 + src/otlp.rs | 57 ++--- src/package.rs | 69 +++--- src/pyoci.rs | 38 ++- src/transport.rs | 2 +- 9 files changed, 869 insertions(+), 203 deletions(-) create mode 100644 src/app.rs diff --git a/Cargo.lock b/Cargo.lock index a59014d..89a957b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -90,9 +90,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.80" +version = "0.1.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", @@ -111,6 +111,67 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "axum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +dependencies = [ + "async-trait", + "axum-core", + "axum-macros", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "itoa", + "matchit", + "memchr", + "mime", + "multer", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper 1.0.1", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.4.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 0.1.2", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00c055ee2d014ae5981ce1016374e8213682aa14d9bf40e48ab48b5f3ef20eaa" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "backtrace" version = "0.3.71" @@ -601,6 +662,18 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "http" version = "0.2.12" @@ -885,13 +958,31 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.11" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ + "hermit-abi", "libc", "wasi", - "windows-sys 0.48.0", + "windows-sys 0.52.0", +] + +[[package]] +name = "multer" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" +dependencies = [ + "bytes", + "encoding_rs", + 
"futures-util", + "http 1.1.0", + "httparse", + "memchr", + "mime", + "spin", + "version_check", ] [[package]] @@ -1197,12 +1288,15 @@ dependencies = [ "anyhow", "askama", "async-channel", + "async-trait", + "axum", "base16ct", "base64 0.22.1", "bytes", "console_error_panic_hook", "exitcode", "futures", + "http 1.1.0", "oci-spec", "opentelemetry-proto", "prost", @@ -1213,6 +1307,8 @@ dependencies = [ "sha2", "test-case", "time", + "tokio", + "tower", "tracing", "tracing-core", "tracing-subscriber", @@ -1308,7 +1404,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 1.0.1", "system-configuration", "tokio", "tokio-native-tls", @@ -1396,6 +1492,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "ryu" version = "1.0.17" @@ -1584,6 +1686,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "sync_wrapper" version = "1.0.1" @@ -1735,9 +1843,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.39.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a" dependencies = [ "backtrace", "bytes", @@ -1745,7 +1853,19 @@ dependencies = [ "mio", "pin-project-lite", "socket2", - "windows-sys 0.48.0", + "tokio-macros", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-macros" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", ] [[package]] @@ -2296,6 +2416,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d58b567e8b518469c73c8d7e4798d5abcd7db76a1b33121fffd36ac6fa62da05" dependencies = [ "async-trait", + "axum", "bytes", "chrono", "futures-channel", diff --git a/Cargo.toml b/Cargo.toml index 8896995..eec31f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,12 +35,14 @@ sha2 = "0.10.8" base16ct = { version = "0.2.0", features = ["alloc"] } urlencoding = "2.1.3" anyhow = "1.0.86" +http = "1.1.0" +axum = { version = "0.7.5", default-features = false, features = ["multipart","macros"] } # wasm dependencies time = { version = "0.3.36", features = ["wasm-bindgen"] } console_error_panic_hook = "0.1.7" wasm-bindgen = "0.2.92" -worker = {version = "0.3.0"} +worker = { version = "0.3.0", features = ["http", "axum"] } tracing-web = "0.1.3" futures = "0.3.30" @@ -49,6 +51,8 @@ opentelemetry-proto = { version ="0.6.0", features = ["gen-tonic-messages", "log tracing-core = "0.1.32" prost = "0.12.6" async-channel = "2.3.1" +tower = "0.4.13" +async-trait = "0.1.81" [dependencies.web-sys] version = "0.3.63" @@ -62,3 +66,4 @@ features = [ [dev-dependencies] test-case = "3.3.1" +tokio = { version = "1.39.1", features = ["macros"]} diff --git a/src/app.rs b/src/app.rs new file mode 100644 index 0000000..b53be53 --- /dev/null +++ b/src/app.rs @@ -0,0 +1,590 @@ +use askama::Template; +use async_trait::async_trait; +use axum::{ + debug_handler, + extract::{FromRequestParts, Multipart, Path}, + http::{header, request::Parts, HeaderMap}, + response::{Html, IntoResponse}, + routing::{get, post}, + Router, +}; +use http::StatusCode; +use url::Url; + +use crate::{package, pyoci::PyOciError, templates, PyOci}; + +#[derive(Debug)] +// Custom error type to translate between anyhow/axum +struct AppError(anyhow::Error); + +// Tell axum how to convert 
`AppError` into a response. +impl IntoResponse for AppError { + fn into_response(self) -> axum::response::Response { + match self.0.downcast_ref::<PyOciError>() { + Some(err) => (err.status, err.message.clone()).into_response(), + None => (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", self.0)).into_response(), + } + } +} + +// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into +// `Result<_, AppError>`. That way you don't need to do that manually. +impl<E> From<E> for AppError +where + E: Into<anyhow::Error>, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} +/// Request Router +pub fn router() -> Router { + // TODO: Validate HOST header against a list of allowed hosts + Router::new() + .route("/:registry/:namespace/:package/", get(list_package)) + .route( + "/:registry/:namespace/:package/:filename", + get(download_package), + ) + .route("/:registry/:namespace/", post(publish_package)) + .layer(axum::middleware::from_fn(trace_middleware)) +} + +/// Log incoming requests +async fn trace_middleware( + method: axum::http::Method, + uri: axum::http::Uri, + request: axum::extract::Request, + next: axum::middleware::Next, +) -> axum::response::Response { + let response = next.run(request).await; + + let status: u16 = response.status().into(); + tracing::info!(method = %method, status, path = %uri.path(), "type" = "request"); + response +} + +/// List package request handler +#[debug_handler] +// Mark the handler as Send when building a wasm target +// JsFuture, and most other JS objects are !Send +// Because the cloudflare worker runtime is single-threaded, we can safely mark this as Send +// https://docs.rs/worker/latest/worker/index.html#send-helpers +#[cfg_attr(target_arch = "wasm32", worker::send)] +async fn list_package( + headers: HeaderMap, + Host(host): Host, + path_params: Path<(String, String, String)>, +) -> Result<Html<String>, AppError> { + let auth = match headers.get("Authorization") { + Some(auth) => Some(auth.to_str()?.to_owned()), + None => None,
+ }; + + let package: package::Info = path_params.0.try_into()?; + + let client = PyOci::new(package.registry.clone(), auth); + // Fetch at most 45 packages + // https://developers.cloudflare.com/workers/platform/limits/#account-plan-limits + let files = client.list_package_files(&package, 45).await?; + + // TODO: swap to application/vnd.pypi.simple.v1+json + let template = templates::ListPackageTemplate { host, files }; + Ok(Html(template.render().expect("valid template"))) +} + +/// Download package request handler +// #[debug_handler] +// Mark the handler as Send when building a wasm target +// JsFuture, and most other JS objects are !Send +// Because the cloudflare worker runtime is single-threaded, we can safely mark this as Send +// https://docs.rs/worker/latest/worker/index.html#send-helpers +#[cfg_attr(target_arch = "wasm32", worker::send)] +async fn download_package( + path_params: Path<(String, String, Option<String>, String)>, + headers: HeaderMap, +) -> Result<impl IntoResponse, AppError> { + let auth = match headers.get("Authorization") { + Some(auth) => Some(auth.to_str()?.to_owned()), + None => None, + }; + let package: package::Info = path_params.0.try_into()?; + + let client = PyOci::new(package.registry.clone(), auth); + let data = client + .download_package_file(&package) + .await?
+ .bytes() + .await + .expect("valid bytes"); + + // TODO: With some trickery we could stream the data directly to the response + Ok(( + [( + header::CONTENT_DISPOSITION, + format!("attachment; filename=\"{}\"", package.filename()), + )], + data, + )) +} + +/// Publish package request handler +/// +/// ref: https://warehouse.pypa.io/api-reference/legacy.html#upload-api +// #[debug_handler] +// Mark the handler as Send when building a wasm target +// JsFuture, and most other JS objects are !Send +// Because the cloudflare worker runtime is single-threaded, we can safely mark this as Send +// https://docs.rs/worker/latest/worker/index.html#send-helpers +#[cfg_attr(target_arch = "wasm32", worker::send)] +async fn publish_package( + Path((registry, namespace)): Path<(String, String)>, + headers: HeaderMap, + multipart: Multipart, +) -> Result<String, AppError> { + let form_data = UploadForm::from_multipart(multipart).await?; + + let auth = match headers.get("Authorization") { + Some(auth) => Some(auth.to_str()?.to_owned()), + None => None, + }; + let package: package::Info = (registry, namespace, None, form_data.filename).try_into()?; + let client = PyOci::new(package.registry.clone(), auth); + + client + .publish_package_file(&package, form_data.content) + .await?; + Ok("Published".into()) +} + +/// Form data for the upload API +/// +/// ref: https://warehouse.pypa.io/api-reference/legacy.html#upload-api +#[derive(Debug)] +struct UploadForm { + filename: String, + content: Vec<u8>, +} + +impl UploadForm { + /// Convert a Multipart into an UploadForm + /// + /// Returns MultiPartError if the form can't be parsed + async fn from_multipart(mut multipart: Multipart) -> anyhow::Result<Self> { + let mut action = None; + let mut protocol_version = None; + let mut content = None; + let mut filename = None; + while let Some(field) = multipart.next_field().await?
{ + match field.name() { + Some(":action") => action = Some(field.text().await?), + Some("protocol_version") => protocol_version = Some(field.text().await?), + Some("content") => { + filename = field.file_name().map(|s| s.to_string()); + content = Some(field.bytes().await?) + } + _ => (), + } + } + + match action { + Some(action) if action == "file_upload" => (), + None => { + return Err(PyOciError::from(( + StatusCode::BAD_REQUEST, + "Missing ':action' form-field", + )) + .into()) + } + _ => { + return Err(PyOciError::from(( + StatusCode::BAD_REQUEST, + "Invalid ':action' form-field", + )) + .into()) + } + }; + + match protocol_version { + Some(protocol_version) if protocol_version == "1" => (), + None => { + return Err(PyOciError::from(( + StatusCode::BAD_REQUEST, + "Missing 'protocol_version' form-field", + )) + .into()) + } + _ => { + return Err(PyOciError::from(( + StatusCode::BAD_REQUEST, + "Invalid 'protocol_version' form-field", + )) + .into()) + } + }; + + let content = match content { + None => { + return Err(PyOciError::from(( + StatusCode::BAD_REQUEST, + "Missing 'content' form-field", + )) + .into()) + } + Some(content) if content.is_empty() => { + return Err( + PyOciError::from((StatusCode::BAD_REQUEST, "No 'content' provided")).into(), + ) + } + Some(content) => content, + }; + + let filename = match filename { + Some(filename) if filename.is_empty() => { + return Err( + PyOciError::from((StatusCode::BAD_REQUEST, "No 'filename' provided")).into(), + ) + } + Some(filename) => filename, + None => { + return Err(PyOciError::from(( + StatusCode::BAD_REQUEST, + "'content' form-field is missing a 'filename'", + )) + .into()) + } + }; + + Ok(Self { + filename, + content: content.into(), + }) + } +} + +/// Extract the host from the request as a URL +/// includes the scheme and port, the path, query, username and password are removed if present +struct Host(Url); + +#[async_trait] +impl FromRequestParts for Host +where + S: Send + Sync, +{ + type Rejection = 
(StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let uri = &parts.uri; + let mut url = + Url::parse(&uri.to_string()).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid URL"))?; + + url.set_path(""); + url.set_query(None); + url.set_username("") + .map_err(|_| (StatusCode::BAD_REQUEST, "Failed to clear URL username"))?; + url.set_password(None) + .map_err(|_| (StatusCode::BAD_REQUEST, "Failed to clear URL password"))?; + Ok(Host(url)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use axum::{body::to_bytes, http::Request}; + use tower::ServiceExt; + + #[tokio::test] + async fn publish_package_missing_action() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\"submit-name\"\r\n\ + \r\n\ + Larry\r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pypi/pytest/") + .header("Content-Type", "multipart/form-data; boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = String::from_utf8( + to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "Missing ':action' form-field"); + } + + #[tokio::test] + async fn publish_package_invalid_action() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\":action\"\r\n\ + \r\n\ + not-file_download\r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pypi/pytest/") + .header("Content-Type", "multipart/form-data; boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = String::from_utf8( + to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "Invalid ':action' 
form-field"); + } + + #[tokio::test] + async fn publish_package_missing_protocol_version() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\":action\"\r\n\ + \r\n\ + file_upload\r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pypi/pytest/") + .header("Content-Type", "multipart/form-data; boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = String::from_utf8( + to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "Missing 'protocol_version' form-field"); + } + + #[tokio::test] + async fn publish_package_invalid_protocol_version() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\":action\"\r\n\ + \r\n\ + file_upload\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"protocol_version\"\r\n\ + \r\n\ + 2\r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pypi/pytest/") + .header("Content-Type", "multipart/form-data; boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = String::from_utf8( + to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "Invalid 'protocol_version' form-field"); + } + + #[tokio::test] + async fn publish_package_missing_content() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\":action\"\r\n\ + \r\n\ + file_upload\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"protocol_version\"\r\n\ + \r\n\ + 1\r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pypi/pytest/") + .header("Content-Type", "multipart/form-data; 
boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = String::from_utf8( + to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "Missing 'content' form-field"); + } + + #[tokio::test] + async fn publish_package_empty_content() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\":action\"\r\n\ + \r\n\ + file_upload\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"protocol_version\"\r\n\ + \r\n\ + 1\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"content\"\r\n\ + \r\n\ + \r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pypi/pytest/") + .header("Content-Type", "multipart/form-data; boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = String::from_utf8( + to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "No 'content' provided"); + } + + #[tokio::test] + async fn publish_package_content_missing_filename() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\":action\"\r\n\ + \r\n\ + file_upload\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"protocol_version\"\r\n\ + \r\n\ + 1\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"content\"\r\n\ + \r\n\ + someawesomepackagedata\r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pypi/pytest/") + .header("Content-Type", "multipart/form-data; boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = String::from_utf8( + 
to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "'content' form-field is missing a 'filename'"); + } + + #[tokio::test] + async fn publish_package_content_filename_empty() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\":action\"\r\n\ + \r\n\ + file_upload\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"protocol_version\"\r\n\ + \r\n\ + 1\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"content\"; filename=\"\"\r\n\ + \r\n\ + someawesomepackagedata\r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pypi/pytest/") + .header("Content-Type", "multipart/form-data; boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = String::from_utf8( + to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "No 'filename' provided"); + } + + #[tokio::test] + async fn publish_package_url_encoded_registry() { + let router = router(); + + let form = "--foobar\r\n\ + Content-Disposition: form-data; name=\":action\"\r\n\ + \r\n\ + file_upload\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"protocol_version\"\r\n\ + \r\n\ + 1\r\n\ + --foobar\r\n\ + Content-Disposition: form-data; name=\"content\"; filename=\"foobar-1.0.0.tar.gz\"\r\n\ + \r\n\ + someawesomepackagedata\r\n\ + --foobar--\r\n"; + let req = Request::builder() + .method("POST") + .uri("/pyoci.allexveldman.nl/pyoci/") + .header("Content-Type", "multipart/form-data; boundary=foobar") + .body(form.to_string()) + .unwrap(); + let response = router.oneshot(req).await.unwrap(); + + let status = response.status(); + let body = String::from_utf8( + to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() + .into(), + ) + .unwrap(); + assert_eq!(&body, "Published"); 
+ assert_eq!(status, StatusCode::CREATED); + } +} diff --git a/src/cf.rs b/src/cf.rs index aaebc22..6402991 100644 --- a/src/cf.rs +++ b/src/cf.rs @@ -1,36 +1,33 @@ -use anyhow::{bail, Result}; -use askama::Template; +use http::{Request, Response}; use opentelemetry_proto::tonic::logs::v1::LogRecord; -use std::{str::FromStr, sync::OnceLock}; +use std::collections::HashMap; +use std::sync::OnceLock; +use tower::Service; +use tracing::{info_span, Instrument}; use tracing_subscriber::fmt::time::UtcTime; use tracing_subscriber::prelude::*; use tracing_subscriber::EnvFilter; use tracing_web::MakeWebConsoleWriter; -use worker::{ - console_log, event, Context, Env, FormEntry, Request, Response, ResponseBuilder, RouteContext, - Router, -}; - -use crate::{package, pyoci::OciError, templates, PyOci}; +use worker::{console_log, event, Body, Cf, Context, Env}; /// Wrap an async route handler into a closure that can be used in the router. /// /// Allows request handlers to return Result instead of worker::Result -macro_rules! wrap { - ($e:expr) => { - |req: Request, ctx: RouteContext<()>| async { wrap($e(req, ctx).await) } - }; -} - -fn wrap(res: Result) -> worker::Result { - match res { - Ok(response) => Ok(response), - Err(e) => match e.downcast_ref::() { - Some(err) => Response::error(err.to_string(), err.status().into()), - None => Response::error(e.to_string(), 400), - }, - } -} +// macro_rules! 
wrap { +// ($e:expr) => { +// |req: Request, ctx: RouteContext<()>| async { wrap($e(req, ctx).await) } +// }; +// } +// +// fn wrap(res: Result) -> worker::Result { +// match res { +// Ok(response) => Ok(response), +// Err(e) => match e.downcast_ref::() { +// Some(err) => Response::error(err.to_string(), err.status().into()), +// None => Response::error(e.to_string(), 400), +// }, +// } +// } /// Called once when the worker is started #[event(start)] @@ -74,9 +71,13 @@ fn init(env: &Env) -> &'static Option>> { /// Entrypoint for the fetch event #[event(fetch, respond_with_errors)] -async fn fetch(req: Request, env: Env, ctx: Context) -> worker::Result { +async fn fetch( + req: Request, + env: Env, + _ctx: Context, +) -> worker::Result> { let receiver = init(&env); - let cf = req.cf().expect("valid cf").clone(); + let cf = req.extensions().get::().unwrap().to_owned(); let otlp_endpoint = match env.secret("OTLP_ENDPOINT") { Ok(endpoint) => endpoint.to_string(), Err(_) => "".to_string(), @@ -86,111 +87,17 @@ async fn fetch(req: Request, env: Env, ctx: Context) -> worker::Result Err(_) => "".to_string(), }; - let result = _fetch(req, env, ctx).await; + let span = info_span!("fetch", path = %req.uri().path(), method = %req.method()); + let result = crate::app::router().call(req).instrument(span).await; if let Some(receiver) = receiver { - crate::otlp::flush(receiver, otlp_endpoint, otlp_auth, &cf).await; - } - result -} + let attributes = HashMap::from([ + ("service.name".to_string(), Some("pyoci".to_string())), + ("cloud.region".to_string(), cf.region()), + ("cloud.availability_zone".to_string(), Some(cf.colo())), + ]); -#[tracing::instrument( - name="fetch", - skip(req, env, _ctx), - fields(path = %req.path(), method = %req.method())) -] -async fn _fetch(req: Request, env: Env, _ctx: Context) -> worker::Result { - let method = req.method().to_string(); - let path = req.path(); - - let response = router().run(req, env).await; - - let status = match &response { - 
Ok(response) => response.status_code(), - Err(_) => 400, - }; - - tracing::info!(method, status, path, "type" = "request"); - response -} - -/// Request Router -fn router<'a>() -> Router<'a, ()> { - Router::new() - .get_async("/:registry/:namespace/:package/", wrap!(list_package)) - .get_async( - "/:registry/:namespace/:package/:filename", - wrap!(download_package), - ) - .post_async("/:registry/:namespace/", wrap!(publish_package)) -} - -/// List package request handler -async fn list_package(req: Request, _ctx: RouteContext<()>) -> Result { - let auth = req.headers().get("Authorization").expect("valid header"); - let package = package::Info::from_str(&req.path())?; - let client = PyOci::new(package.registry.clone(), auth); - // Fetch at most 45 packages - // https://developers.cloudflare.com/workers/platform/limits/#account-plan-limits - let files = client.list_package_files(&package, 45).await?; - let mut host = req.url().expect("valid url"); - host.set_path(""); - // TODO: swap to application/vnd.pypi.simple.v1+json - let template = templates::ListPackageTemplate { host, files }; - Ok( - Response::from_html(template.render().expect("valid template")) - .expect("valid html response"), - ) -} - -/// Download package request handler -async fn download_package(req: Request, _ctx: RouteContext<()>) -> Result { - let auth = req.headers().get("Authorization").expect("valid header"); - let package = package::Info::from_str(&req.path())?; - let client = PyOci::new(package.registry.clone(), auth); - let data = client - .download_package_file(&package) - .await? 
- .bytes() - .await - .expect("valid bytes"); - - // TODO: With some trickery we could stream the data directly to the response - let response = ResponseBuilder::new() - .with_header( - "Content-Disposition", - &format!("attachment; filename=\"{}\"", package.filename()), - ) - .expect("valid header") - .from_bytes(data.into()) - .expect("valid response"); - Ok(response) -} - -/// Publish package request handler -/// -/// ref: https://warehouse.pypa.io/api-reference/legacy.html#upload-api -async fn publish_package(mut req: Request, ctx: RouteContext<()>) -> Result { - let (Some(registry), Some(namespace)) = (ctx.param("registry"), ctx.param("namespace")) else { - bail!("Missing registry or namespace"); - }; - let Ok(form_data) = req.form_data().await else { - bail!("Invalid form data"); - }; - let Some(content) = form_data.get("content") else { - bail!("Missing file"); - }; - let FormEntry::File(file) = content else { - bail!("Expected file"); - }; - let auth = req.headers().get("Authorization").expect("valid header"); - let package = package::Info::new(registry, namespace, &file.name())?; - let client = PyOci::new(package.registry.clone(), auth); - - // FormEntry::File does not provide a streaming interface - // so we must read the entire file into memory - let data = file.bytes().await.expect("valid bytes"); - - client.publish_package_file(&package, data).await?; - Ok(Response::ok("Published").unwrap()) + crate::otlp::flush(receiver, otlp_endpoint, otlp_auth, &attributes).await; + } + Ok(result?) 
} diff --git a/src/lib.rs b/src/lib.rs index e5a9b6a..2da3eae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,10 @@ #![warn(unused_extern_crates)] +// Webserver request handlers +mod app; // Request handlers for the cloudflare worker mod cf; +// OTLP handlers mod otlp; // Helper for parsing and managing Python/OCI packages mod package; diff --git a/src/otlp.rs b/src/otlp.rs index 2a19b2b..ff33d1d 100644 --- a/src/otlp.rs +++ b/src/otlp.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::fmt::{self, Write}; use std::sync::RwLock; @@ -7,7 +8,7 @@ use time::OffsetDateTime; use tracing::Subscriber; use tracing_core::Event; use tracing_subscriber::{layer::Context, registry::LookupSpan, Layer}; -use worker::{console_log, Cf}; +use worker::console_log; use tracing::field::{Field, Visit}; @@ -21,38 +22,31 @@ use crate::USER_AGENT; /// Convert a batch of log records into a ExportLogsServiceRequest /// -fn build_logs_export_body(logs: Vec, cf: &Cf) -> ExportLogsServiceRequest { +fn build_logs_export_body( + logs: Vec, + attributes: &HashMap>, +) -> ExportLogsServiceRequest { let scope_logs = ScopeLogs { scope: None, log_records: logs, schema_url: "".to_string(), }; - let region = cf.region().map(|region| AnyValue { - value: Some(any_value::Value::StringValue(region)), - }); - let zone = Some(AnyValue { - value: Some(any_value::Value::StringValue(cf.colo())), - }); - + let mut attrs = vec![]; + for (key, value) in attributes { + let Some(value) = value else { + continue; + }; + attrs.push(KeyValue { + key: key.into(), + value: Some(AnyValue { + value: Some(any_value::Value::StringValue(value.into())), + }), + }); + } let resource_logs = ResourceLogs { resource: Some(Resource { - attributes: vec![ - KeyValue { - key: "service.name".to_string(), - value: Some(AnyValue { - value: Some(any_value::Value::StringValue("pyoci".to_string())), - }), - }, - KeyValue { - key: "cloud.region".to_string(), - value: region, - }, - KeyValue { - key: 
"cloud.availability_zone".to_string(), - value: zone, - }, - ], + attributes: attrs, dropped_attributes_count: 0, }), scope_logs: vec![scope_logs], @@ -89,7 +83,7 @@ pub async fn flush( receiver: &async_channel::Receiver>, otlp_endpoint: String, otlp_auth: String, - cf: &Cf, + attributes: &HashMap>, ) { console_log!("Flushing logs to OTLP"); let client = reqwest::Client::builder() @@ -107,7 +101,7 @@ pub async fn flush( break; } }; - let body = build_logs_export_body(log_records, cf).encode_to_vec(); + let body = build_logs_export_body(log_records, attributes).encode_to_vec(); let mut url = url::Url::parse(&otlp_endpoint).unwrap(); url.path_segments_mut().unwrap().extend(&["v1", "logs"]); // send to OTLP Collector @@ -141,6 +135,8 @@ pub async fn flush( // Private methods impl OtlpLogLayer { + /// Push all recorded log messages onto the channel + /// This is called at the end of every request fn flush(&self) -> Result<()> { let records: Vec = self.records.write().unwrap().drain(..).collect(); console_log!("Sending {} log records to OTLP", records.len()); @@ -216,7 +212,12 @@ where self.records.write().unwrap().push(log_record); } - fn on_close(&self, _id: tracing_core::span::Id, _ctx: Context<'_, S>) { + fn on_close(&self, id: tracing_core::span::Id, ctx: Context<'_, S>) { + let span = ctx.span(&id).expect("span not found"); + if !span.parent().is_none() { + // This is a sub-span, we'll flush all messages when the root span is closed + return; + } if let Err(err) = self.flush() { console_log!("Failed to flush log records: {:?}", err); } diff --git a/src/package.rs b/src/package.rs index 1d22469..2835e48 100644 --- a/src/package.rs +++ b/src/package.rs @@ -114,8 +114,8 @@ impl FromStr for File { type Err = Error; /// Parse a filename into the package name, version and architecture - fn from_str(value: &str) -> Result { + // TODO: No need to identify wheel vs sdist, only extract name and version if value.is_empty() { bail!("empty string"); }; @@ -205,32 +205,52 @@ 
impl FromStr for Info { let parts: Vec<&str> = value.split('/').collect(); match parts[..] { [registry, namespace, distribution] => { - let file = File { - name: distribution.replace('-', "_"), - ..File::default() - }; - Ok(Info { - registry: registry_url(registry)?, - namespace: namespace.to_string(), - file, - }) + Ok((registry.into(), namespace.into(), distribution.into()).try_into()?) } [registry, namespace, distribution, filename] => { - let file = File::from_str(filename)?; - if distribution != file.name { - bail!("Filename does not match distribution name"); - }; - Ok(Info { - registry: registry_url(registry)?, - namespace: namespace.to_string(), - file, - }) + Ok((registry.into(), namespace.into(), Some(distribution.into()), filename.into()).try_into()?) } _ => bail!("Expected '//' or '///', got '{}'", value), } } } +impl TryFrom<(String, String, String)> for Info { + type Error = Error; + + fn try_from((registry, namespace, distribution): (String, String, String)) -> Result { + Ok(Info { + registry: registry_url(&registry)?, + namespace, + file: File { + name: distribution.replace('-', "_"), + ..File::default() + }, + }) + } +} + +impl TryFrom<(String, String, Option, String)> for Info { + type Error = Error; + + fn try_from( + (registry, namespace, distribution, filename): (String, String, Option, String), + ) -> Result { + let file = File::from_str(&filename)?; + if let Some(distribution) = distribution { + if distribution != file.name { + bail!("Filename does not match distribution name"); + } + } + + Ok(Info { + registry: registry_url(&registry)?, + namespace, + file, + }) + } +} + /// Parse the registry URL /// /// If no scheme is provided, it will default to `https://` @@ -249,18 +269,7 @@ fn registry_url(registry: &str) -> Result { } impl Info { - pub fn new(registry: &str, namespace: &str, filename: &str) -> Result { - let info = Info { - registry: registry_url(registry)?, - namespace: namespace.to_string(), - file: File::from_str(filename)?, - }; -
Ok(info) - } - /// Replace the version of the package for an OCI tag - /// - /// /// as a tag MUST be at most 128 characters in length and MUST match the following regular expression: /// [a-zA-Z0-9_][a-zA-Z0-9._-]{0,127} diff --git a/src/pyoci.rs b/src/pyoci.rs index 1b89459..8fde760 100644 --- a/src/pyoci.rs +++ b/src/pyoci.rs @@ -116,6 +116,34 @@ fn digest(data: &[u8]) -> String { format!("sha256:{}", hex_encode(&sha)) } +#[derive(Debug)] +pub struct PyOciError { + pub status: StatusCode, + pub message: String, +} +impl std::error::Error for PyOciError {} + +impl std::fmt::Display for PyOciError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}: {}", self.status, self.message) + } +} + +impl From<(StatusCode, &str)> for PyOciError { + fn from((status, message): (StatusCode, &str)) -> Self { + PyOciError { + status, + message: message.to_string(), + } + } +} + +impl From<(StatusCode, String)> for PyOciError { + fn from((status, message): (StatusCode, String)) -> Self { + PyOciError { status, message } + } +} + /// Returned when a request has been authorized but the user has insufficient permissions #[derive(Debug)] pub enum OciError { @@ -206,6 +234,7 @@ impl WwwAuth { } /// Client to communicate with the OCI v2 registry +#[derive(Debug)] pub struct PyOci { registry: Url, transport: HttpTransport, @@ -232,7 +261,6 @@ impl PyOci { let name = package.oci_name()?; let result = self.list_tags(&name).await?; tracing::debug!("{:?}", result); - let tags = result.tags(); let mut files: Vec = Vec::new(); let futures = FuturesUnordered::new(); @@ -347,7 +375,7 @@ impl PyOci { }; // pull blob in first layer of manifest let [blob_descriptor] = &manifest.layers()[..] 
else { - bail!("Manifest should define exactly one layer"); + bail!("Image Manifest defines unexpected number of layers, was this package published by pyoci?"); }; self.pull_blob(package.oci_name()?, blob_descriptor.to_owned()) .await @@ -528,12 +556,14 @@ impl PyOci { }; Ok(response) } - async fn list_tags(&self, name: &str) -> Result { + + /// List the available tags for a package + async fn list_tags(&self, name: &str) -> anyhow::Result { let url = build_url!(&self, "/v2/{}/tags/list", name); let request = self.transport.get(url); let response = self.transport.send(request).await.expect("valid response"); if !response.status().is_success() { - bail!(response.json::().await?) + return Err(PyOciError::from((StatusCode::NOT_FOUND, response.text().await?)).into()); }; let tags = response .json::() diff --git a/src/transport.rs b/src/transport.rs index 7d5e0c4..585c76c 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -7,7 +7,7 @@ use crate::USER_AGENT; /// HTTP Transport /// /// This struct is responsible for sending HTTP requests to the upstream OCI registry. -#[derive(Default)] +#[derive(Debug, Default)] pub struct HttpTransport { /// HTTP client client: reqwest::Client, From eb82ba51b0f0f57d45885446cb208ea29b82ffb0 Mon Sep 17 00:00:00 2001 From: Allex Veldman Date: Thu, 1 Aug 2024 10:29:22 +0200 Subject: [PATCH 2/3] chore(test): Extend testing OCI registry responses are now mocked using `mockito`, allowing high-level test cases and registry integration tests. This commit only includes a happy-flow of the publish_package method and some smaller internal tests around issues I found during development. 
--- Cargo.lock | 222 +++++++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 1 + src/app.rs | 93 +++++++++++++++++++- src/pyoci.rs | 176 ++++++++++++++++++++++++++++++++----- src/transport.rs | 1 + 5 files changed, 461 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 89a957b..86a8342 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,16 @@ dependencies = [ "nom", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-channel" version = "2.3.1" @@ -241,6 +251,12 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.6.0" @@ -270,6 +286,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "colored" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" +dependencies = [ + "lazy_static", + "windows-sys 0.48.0", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -637,6 +663,25 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "h2" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + 
"indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "h2" version = "0.4.5" @@ -736,6 +781,12 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "humansize" version = "2.1.3" @@ -745,6 +796,29 @@ dependencies = [ "libm", ] +[[package]] +name = "hyper" +version = "0.14.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.3.1" @@ -754,7 +828,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2", + "h2 0.4.5", "http 1.1.0", "http-body 1.0.0", "httparse", @@ -773,7 +847,7 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper", + "hyper 1.3.1", "hyper-util", "rustls", "rustls-pki-types", @@ -790,7 +864,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper", + "hyper 1.3.1", "hyper-util", "native-tls", "tokio", @@ -809,7 +883,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper", + "hyper 1.3.1", "pin-project-lite", "socket2", "tokio", @@ -898,6 +972,16 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" 
+[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.21" @@ -968,6 +1052,25 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mockito" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2f6e023aa5bdf392aa06c78e4a4e6d498baab5138d0c993503350ebbc37bf1e" +dependencies = [ + "assert-json-diff", + "colored", + "futures-core", + "hyper 0.14.30", + "log", + "rand", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "multer" version = "3.1.0" @@ -1175,6 +1278,29 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.5", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -1225,6 +1351,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2288c0e17cc8d342c712bb43a257a80ebffce59cdb33d5000d8348f3ec02528b" +dependencies = [ + 
"zerocopy", + "zerocopy-derive", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -1297,6 +1433,7 @@ dependencies = [ "exitcode", "futures", "http 1.1.0", + "mockito", "oci-spec", "opentelemetry-proto", "prost", @@ -1329,6 +1466,45 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "redox_syscall" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +dependencies = [ + "bitflags 2.5.0", +] + [[package]] name = "regex" version = "1.10.4" @@ -1384,11 +1560,11 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.4.5", "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper", + "hyper 1.3.1", "hyper-rustls", "hyper-tls", "hyper-util", @@ -1513,6 +1689,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "security-framework" version = "2.11.0" @@ -1621,6 +1803,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "similar" +version = "2.6.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" + [[package]] name = "slab" version = "0.4.9" @@ -1851,6 +2039,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -2484,6 +2673,27 @@ dependencies = [ "web-sys", ] +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index eec31f9..09825dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,5 +65,6 @@ features = [ ] [dev-dependencies] +mockito = "1.4.0" test-case = "3.3.1" tokio = { version = "1.39.1", features = ["macros"]} diff --git a/src/app.rs b/src/app.rs index b53be53..1e6cfaf 100644 --- a/src/app.rs +++ b/src/app.rs @@ -553,6 +553,91 @@ mod tests { #[tokio::test] async fn publish_package_url_encoded_registry() { + let mut server = mockito::Server::new_async().await; + let url = server.url(); + let encoded_url = urlencoding::encode(&url).into_owned(); + + let mut mocks = vec![]; + // Mock the server, in order of expected requests + // IndexManifest does not yet exist + mocks.push( + server + .mock("GET", "/v2/mockserver/foobar/manifests/1.0.0") + .with_status(404) + .create_async() + .await, + ); + // HEAD request to check if blob exists for: + // - layer + // - config + mocks.push( + server + .mock( + "HEAD", + mockito::Matcher::Regex(r"/v2/mockserver/foobar/blobs/.+".to_string()), + ) + .expect(2) + .with_status(404) + 
.create_async() + .await, + ); + // POST request with blob for layer + mocks.push( + server + .mock("POST", "/v2/mockserver/foobar/blobs/uploads/") + .with_status(202) // ACCEPTED + .with_header( + "Location", + &format!("{url}/v2/mockserver/foobar/blobs/uploads/1?_state=uploading"), + ) + .create_async() + .await, + ); + mocks.push( + server + .mock("PUT", "/v2/mockserver/foobar/blobs/uploads/1?_state=uploading&digest=sha256%3Ab7513fb69106a855b69153582dec476677b3c79f4a13cfee6fb7a356cfa754c0") + .with_status(201) // CREATED + .create_async() + .await, + ); + // POST request with blob for config + mocks.push( + server + .mock("POST", "/v2/mockserver/foobar/blobs/uploads/") + .with_status(202) // ACCEPTED + .with_header( + "Location", + &format!("{url}/v2/mockserver/foobar/blobs/uploads/2?_state=uploading"), + ) + .create_async() + .await, + ); + mocks.push( + server + .mock("PUT", "/v2/mockserver/foobar/blobs/uploads/2?_state=uploading&digest=sha256%3A44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a") + .with_status(201) // CREATED + .create_async() + .await, + ); + // PUT request to create Manifest + mocks.push( + server + .mock("PUT", "/v2/mockserver/foobar/manifests/sha256:7ffd96d9eab411893eeacfa906e30956290a07b0141d7c1dd54c9fd5c7c48cf5") + .match_header("Content-Type", "application/vnd.oci.image.manifest.v1+json") + .with_status(201) // CREATED + .create_async() + .await, + ); + // PUT request to create Index + mocks.push( + server + .mock("PUT", "/v2/mockserver/foobar/manifests/1.0.0") + .match_header("Content-Type", "application/vnd.oci.image.index.v1+json") + .with_status(201) // CREATED + .create_async() + .await, + ); + let router = router(); let form = "--foobar\r\n\ @@ -570,7 +655,7 @@ mod tests { --foobar--\r\n"; let req = Request::builder() .method("POST") - .uri("/pyoci.allexveldman.nl/pyoci/") + .uri(format!("/{encoded_url}/mockserver/")) .header("Content-Type", "multipart/form-data; boundary=foobar") .body(form.to_string()) .unwrap(); 
@@ -584,7 +669,11 @@ mod tests { .into(), ) .unwrap(); + + for mock in mocks { + mock.assert_async().await; + } assert_eq!(&body, "Published"); - assert_eq!(status, StatusCode::CREATED); + assert_eq!(status, StatusCode::OK); } } diff --git a/src/pyoci.rs b/src/pyoci.rs index 8fde760..e331963 100644 --- a/src/pyoci.rs +++ b/src/pyoci.rs @@ -27,6 +27,9 @@ use crate::ARTIFACT_TYPE; /// Build an URL from a format string while sanitizing the parameters /// +/// Note that if the resulting path is an absolute URL, the registry URL is ignored. +/// For more info, see [`Url::join`] +/// /// Returns Err when a parameter fails sanitization macro_rules! build_url { ($pyoci:expr, $uri:literal, $($param:expr),+) => {{ @@ -35,8 +38,8 @@ macro_rules! build_url { $(sanitize($param)?,)* ); let mut new_url = $pyoci.registry.clone(); - new_url.set_path(&uri); - new_url + new_url.set_path(""); + new_url.join(&uri)? }} } @@ -478,9 +481,15 @@ impl PyOci { .await .expect("valid response"); - if response.status() == StatusCode::OK { - tracing::info!("Blob already exists: {name}:{digest}"); - return Ok(()); + match response.status() { + StatusCode::OK => { + tracing::info!("Blob already exists: {name}:{digest}"); + return Ok(()); + } + StatusCode::NOT_FOUND => {} + status => { + return Err(PyOciError::from((status, response.text().await?)).into()); + } } let url = build_url!(&self, "/v2/{}/blobs/uploads/", name); @@ -494,24 +503,17 @@ impl PyOci { StatusCode::ACCEPTED => response .headers() .get("Location") - .expect("a Location header") + .context("Registry response did not contain a Location header")? .to_str() - .expect("valid Location header value"), + .context("Failed to parse Location header as ASCII")?, status => { - bail!(response - .json::() - .await - .with_context(|| format!( - "Failed to upload blob, registry responded with '{}'", - status - ))?) 
+ return Err(PyOciError::from((status, response.text().await?)).into()); } }; - let mut url: Url = if location.starts_with('/') { - build_url!(&self, "{}", location) - } else { - location.parse().expect("valid url") - }; + let mut url: Url = build_url!(&self, "{}", location); + // `append_pair` percent-encodes the values as application/x-www-form-urlencoded. + // ghcr.io seems to be fine with a percent-encoded digest but this could be an issue with + // other registries. url.query_pairs_mut().append_pair("digest", digest); let request = self @@ -521,8 +523,11 @@ impl PyOci { .header("Content-Length", blob.data.len().to_string()) .body(blob.data); let response = self.transport.send(request).await.expect("valid response"); - if response.status() != StatusCode::CREATED { - bail!(response.json::().await?) + match response.status() { + StatusCode::CREATED => {} + status => { + return Err(PyOciError::from((status, response.text().await?)).into()); + } } tracing::debug!( "Blob-location: {}", @@ -584,7 +589,7 @@ impl PyOci { ) -> Result<()> { let (url, data, content_type) = match manifest { Manifest::Index(index) => { - let version = version.context("`version` required for pushing and ImageIndex")?; + let version = version.context("`version` required for pushing an ImageIndex")?; let url = build_url!(&self, "v2/{}/manifests/{}", name, version); let data = index.to_string().expect("valid json"); (url, data, "application/vnd.oci.image.index.v1+json") @@ -604,8 +609,9 @@ impl PyOci { .header("Content-Type", content_type) .body(data); let response = self.transport.send(request).await.expect("valid response"); - if !response.status().is_success() { - bail!(response.json::().await?) 
+ match response.status() { + StatusCode::CREATED => {} + status => return Err(PyOciError::from((status, response.text().await?)).into()), }; Ok(()) } @@ -674,6 +680,17 @@ mod tests { Ok(()) } + #[test] + fn test_build_url_absolute() -> Result<()> { + let client = PyOci { + registry: Url::parse("https://example.com").expect("valid url"), + transport: HttpTransport::new(None), + }; + let url = build_url!(&client, "{}/foo?bar=baz&qaz=sha:123", "http://pyoci.nl"); + assert_eq!(url.as_str(), "http://pyoci.nl/foo?bar=baz&qaz=sha:123"); + Ok(()) + } + #[test] fn test_build_url_double_period() { let client = PyOci { @@ -683,4 +700,115 @@ mod tests { let x = || -> Result { Ok(build_url!(&client, "/foo/{}/", "..")) }(); assert!(x.is_err()); } + + /// Test if a relative Location header is properly handled + #[tokio::test] + async fn test_push_blob_location_relative() { + let mut server = mockito::Server::new_async().await; + let url = server.url(); + + let mut mocks = vec![]; + // Mock the server, in order of expected requests + + // HEAD request to check if blob exists + mocks.push( + server + .mock( + "HEAD", + "/v2/mockserver/foobar/blobs/sha256:2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + ) + .with_status(404) + .create_async() + .await, + ); + // POST request initiating blob upload + mocks.push( + server + .mock("POST", "/v2/mockserver/foobar/blobs/uploads/") + .with_status(202) // ACCEPTED + .with_header( + "Location", + "/v2/mockserver/foobar/blobs/uploads/1?_state=uploading", + ) + .create_async() + .await, + ); + // PUT request to upload blob + mocks.push( + server + .mock( + "PUT", + "/v2/mockserver/foobar/blobs/uploads/1?_state=uploading&digest=sha256%3A2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + ) + .with_status(201) // CREATED + .create_async() + .await, + ); + + let client = PyOci { + registry: Url::parse(&url).expect("valid url"), + transport: HttpTransport::new(None), + }; + let blob = Blob::new("hello".into(), 
"application/octet-stream"); + assert!(client.push_blob("mockserver/foobar", blob).await.is_ok()); + + for mock in mocks { + mock.assert_async().await; + } + } + /// Test if an absolute Location header is properly handled + #[tokio::test] + async fn test_push_blob_location_absolute() { + let mut server = mockito::Server::new_async().await; + let url = server.url(); + + let mut mocks = vec![]; + // Mock the server, in order of expected requests + + // HEAD request to check if blob exists + mocks.push( + server + .mock( + "HEAD", + "/v2/mockserver/foobar/blobs/sha256:2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + ) + .with_status(404) + .create_async() + .await, + ); + // POST request initiating blob upload + mocks.push( + server + .mock("POST", "/v2/mockserver/foobar/blobs/uploads/") + .with_status(202) // ACCEPTED + .with_header( + "Location", + &format!("{url}/v2/mockserver/foobar/blobs/uploads/1?_state=uploading"), + ) + .create_async() + .await, + ); + // PUT request to upload blob + mocks.push( + server + .mock( + "PUT", + "/v2/mockserver/foobar/blobs/uploads/1?_state=uploading&digest=sha256%3A2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + ) + .with_status(201) // CREATED + .create_async() + .await, + ); + + let client = PyOci { + registry: Url::parse(&url).expect("valid url"), + transport: HttpTransport::new(None), + }; + let blob = Blob::new("hello".into(), "application/octet-stream"); + assert!(client.push_blob("mockserver/foobar", blob).await.is_ok()); + + for mock in mocks { + mock.assert_async().await; + } + } } diff --git a/src/transport.rs b/src/transport.rs index 585c76c..952f425 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -52,6 +52,7 @@ impl HttpTransport { }; let response = self._send(request).await.expect("valid response"); if response.status() != 401 { + // TODO: support returning 403 when the authentication is not sufficient return Ok(response); } let Some(org_request) = org_request else { 
From b3e23a4991a9140fb61baa1af9457373ec6ee16d Mon Sep 17 00:00:00 2001 From: Allex Veldman Date: Thu, 1 Aug 2024 15:28:23 +0200 Subject: [PATCH 3/3] chore: Remove the last use of OciError Error handling is now more generically handled through PyOciError. --- src/app.rs | 37 +++++++++++-------------------------- src/pyoci.rs | 50 +++++--------------------------------------------- 2 files changed, 16 insertions(+), 71 deletions(-) diff --git a/src/app.rs b/src/app.rs index 1e6cfaf..df5d659 100644 --- a/src/app.rs +++ b/src/app.rs @@ -557,20 +557,17 @@ mod tests { let url = server.url(); let encoded_url = urlencoding::encode(&url).into_owned(); - let mut mocks = vec![]; - // Mock the server, in order of expected requests - // IndexManifest does not yet exist - mocks.push( + let mocks = vec![ + // Mock the server, in order of expected requests + // IndexManifest does not yet exist server .mock("GET", "/v2/mockserver/foobar/manifests/1.0.0") .with_status(404) .create_async() .await, - ); - // HEAD request to check if blob exists for: - // - layer - // - config - mocks.push( + // HEAD request to check if blob exists for: + // - layer + // - config server .mock( "HEAD", @@ -580,9 +577,7 @@ mod tests { .with_status(404) .create_async() .await, - ); - // POST request with blob for layer - mocks.push( + // POST request with blob for layer server .mock("POST", "/v2/mockserver/foobar/blobs/uploads/") .with_status(202) // ACCEPTED @@ -592,16 +587,12 @@ mod tests { ) .create_async() .await, - ); - mocks.push( server .mock("PUT", "/v2/mockserver/foobar/blobs/uploads/1?_state=uploading&digest=sha256%3Ab7513fb69106a855b69153582dec476677b3c79f4a13cfee6fb7a356cfa754c0") .with_status(201) // CREATED .create_async() .await, - ); - // POST request with blob for config - mocks.push( + // POST request with blob for config server .mock("POST", "/v2/mockserver/foobar/blobs/uploads/") .with_status(202) // ACCEPTED @@ -611,32 +602,26 @@ mod tests { ) .create_async() .await, - ); - mocks.push( 
server .mock("PUT", "/v2/mockserver/foobar/blobs/uploads/2?_state=uploading&digest=sha256%3A44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a") .with_status(201) // CREATED .create_async() .await, - ); - // PUT request to create Manifest - mocks.push( + // PUT request to create Manifest server .mock("PUT", "/v2/mockserver/foobar/manifests/sha256:7ffd96d9eab411893eeacfa906e30956290a07b0141d7c1dd54c9fd5c7c48cf5") .match_header("Content-Type", "application/vnd.oci.image.manifest.v1+json") .with_status(201) // CREATED .create_async() .await, - ); - // PUT request to create Index - mocks.push( + // PUT request to create Index server .mock("PUT", "/v2/mockserver/foobar/manifests/1.0.0") .match_header("Content-Type", "application/vnd.oci.image.index.v1+json") .with_status(201) // CREATED .create_async() .await, - ); + ]; let router = router(); diff --git a/src/pyoci.rs b/src/pyoci.rs index e331963..613a886 100644 --- a/src/pyoci.rs +++ b/src/pyoci.rs @@ -147,42 +147,6 @@ impl From<(StatusCode, String)> for PyOciError { } } -/// Returned when a request has been authorized but the user has insufficient permissions -#[derive(Debug)] -pub enum OciError { - /// The user has insufficient permissions - Forbidden, - /// The user could not authorize - Unauthorized, -} - -impl OciError { - pub fn status(&self) -> StatusCode { - match self { - OciError::Forbidden => StatusCode::FORBIDDEN, - OciError::Unauthorized => StatusCode::UNAUTHORIZED, - } - } - - fn from_status(status: StatusCode) -> Option { - match status { - StatusCode::FORBIDDEN => Some(OciError::Forbidden), - StatusCode::UNAUTHORIZED => Some(OciError::Unauthorized), - _ => None, - } - } -} - -impl std::fmt::Display for OciError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - OciError::Forbidden => write!(f, "Forbidden"), - OciError::Unauthorized => write!(f, "Unauthorized"), - } - } -} -impl std::error::Error for OciError {} - #[derive(Deserialize)] pub struct 
AuthResponse { pub token: String, @@ -627,16 +591,12 @@ impl PyOci { "application/vnd.oci.image.manifest.v1+json, application/vnd.oci.image.index.v1+json", ); let response = self.transport.send(request).await.expect("valid response"); - let status = response.status(); - if status == StatusCode::NOT_FOUND { - return Ok(None); - }; - if let Some(err) = OciError::from_status(status) { - return Err(err.into()); - }; - if !status.is_success() { - bail!(response.json::().await?) + match response.status() { + StatusCode::NOT_FOUND => return Ok(None), + StatusCode::OK => {} + status => return Err(PyOciError::from((status, response.text().await?)).into()), }; + match response.headers().get("Content-Type") { Some(value) if value == "application/vnd.oci.image.index.v1+json" => { Ok(Some(Manifest::Index(Box::new(