Skip to content

Commit

Permalink
Replace Rocket with Axum. (#58)
Browse files Browse the repository at this point in the history
Axum gives us more control over the flow of requests and allows us to
use all the tower ecosystem, which has a lot of nice middlewares.

The first immediate benefit we get is better logging using the tracing
ecosystem and support for a request ID. Clients can send a
`x-request-id` header, which will be included in all log messages
printed by the request. If no ID is provided one is generated randomly
by the server. The ID, generated or provided by the client, is also
included in the response.

Most of our business code and endpoints are untouched, since Rocket and
Axum use fairly similar paradigms anyway. Most of the changes are in
auxilary code, such as testing, metrics and response formatting.
  • Loading branch information
plietar authored Feb 22, 2024
1 parent 914336f commit 2e832a4
Show file tree
Hide file tree
Showing 13 changed files with 1,300 additions and 1,052 deletions.
828 changes: 306 additions & 522 deletions Cargo.lock

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ rust-version = "1.70"
crate-type = ["rlib", "cdylib"]

[dependencies]
rocket = { version = "0.5.0", features = ["json"] }
regex = "1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
Expand All @@ -26,10 +25,16 @@ clap = { version = "4.4.8", features = ["derive"] }
anyhow = "1.0.75"
thiserror = "1.0.50"
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"], optional = true }
rocket_prometheus = "0.10.0"
prometheus = "0.13.3"
log = "0.4.20"
tokio = "1.35.1"
tokio = { version = "1.35.1", features = ["fs", "rt-multi-thread"] }
axum = "0.7.4"
tracing-subscriber = "0.3.18"
tracing = "0.1.40"
tower-http = { version = "0.5.1", features = ["trace", "catch-panic", "request-id", "util"] }
tokio-util = { version = "0.7.10", features = ["io"] }
futures = "0.3.30"
tower = "0.4.13"
mime = "0.3.17"

[dev-dependencies]
assert_cmd = "2.0.6"
Expand All @@ -40,6 +45,7 @@ tempdir = "0.3.7"
tar = "0.4.38"
chrono = "0.4.33"
rand = "0.8.5"
tracing-capture = "0.1.0"

[features]
python = [ "dep:pyo3" ]
1 change: 0 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ RUN apt-get -yq update && \
apt-get -yqq install openssh-client git

COPY --from=builder /usr/local/cargo/bin/* /usr/local/bin/
COPY --from=builder /usr/src/outpack_server/Rocket.toml .
COPY start-with-wait /usr/local/bin
EXPOSE 8000
ENTRYPOINT ["start-with-wait"]
4 changes: 0 additions & 4 deletions Rocket.toml

This file was deleted.

225 changes: 128 additions & 97 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
use anyhow::{bail, Context};
use rocket::fs::TempFile;
use rocket::serde::json::{Error, Json};
use rocket::serde::{Deserialize, Serialize};
use rocket::State;
use rocket::{catch, catchers, routes, Build, Request, Rocket};
use axum::extract::rejection::JsonRejection;
use axum::extract::{self, Query, State};
use axum::response::IntoResponse;
use axum::response::Response;
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::trace::TraceLayer;

use crate::config;
use crate::hash;
use crate::location;
use crate::metadata;
use crate::metrics;
use crate::responses;
use crate::store;
use rocket_prometheus::PrometheusMetrics;

use crate::outpack_file::OutpackFile;
use responses::{FailResponse, OutpackError, OutpackSuccess};
use crate::responses::{OutpackError, OutpackSuccess};
use crate::upload::{Upload, UploadLayer};

type OutpackResult<T> = Result<OutpackSuccess<T>, OutpackError>;

Expand All @@ -29,142 +33,136 @@ pub struct ApiRoot {
pub schema_version: String,
}

#[catch(500)]
fn internal_error(_req: &Request) -> Json<FailResponse> {
Json(FailResponse::from(OutpackError {
fn internal_error(_err: Box<dyn Any + Send + 'static>) -> Response {
OutpackError {
error: String::from("UNKNOWN_ERROR"),
detail: String::from("Something went wrong"),
kind: Some(ErrorKind::Other),
}))
}
.into_response()
}

#[catch(404)]
fn not_found(_req: &Request) -> Json<FailResponse> {
Json(FailResponse::from(OutpackError {
async fn not_found() -> OutpackError {
OutpackError {
error: String::from("NOT_FOUND"),
detail: String::from("This route does not exist"),
kind: Some(ErrorKind::NotFound),
}))
}

#[catch(400)]
fn bad_request(_req: &Request) -> Json<FailResponse> {
Json(FailResponse::from(OutpackError {
error: String::from("BAD_REQUEST"),
detail: String::from(
"The request could not be understood by the server due to malformed syntax",
),
kind: Some(ErrorKind::InvalidInput),
}))
}
}

#[rocket::get("/")]
fn index() -> OutpackResult<ApiRoot> {
Ok(ApiRoot {
async fn index() -> OutpackResult<ApiRoot> {
Ok(OutpackSuccess::from(ApiRoot {
schema_version: String::from("0.1.1"),
}
.into())
}))
}

#[rocket::get("/metadata/list")]
fn list_location_metadata(root: &State<PathBuf>) -> OutpackResult<Vec<location::LocationEntry>> {
location::read_locations(root)
async fn list_location_metadata(
root: State<PathBuf>,
) -> OutpackResult<Vec<location::LocationEntry>> {
location::read_locations(&root)
.map_err(OutpackError::from)
.map(OutpackSuccess::from)
}

#[rocket::get("/packit/metadata?<known_since>")]
fn get_metadata(
root: &State<PathBuf>,
#[derive(Deserialize)]
struct KnownSince {
known_since: Option<f64>,
}
async fn get_metadata_since(
root: State<PathBuf>,
query: Query<KnownSince>,
) -> OutpackResult<Vec<metadata::PackitPacket>> {
metadata::get_packit_metadata_from_date(root, known_since)
metadata::get_packit_metadata_from_date(&root, query.known_since)
.map_err(OutpackError::from)
.map(OutpackSuccess::from)
}

#[rocket::get("/metadata/<id>/json")]
fn get_metadata_by_id(root: &State<PathBuf>, id: &str) -> OutpackResult<serde_json::Value> {
metadata::get_metadata_by_id(root, id)
async fn get_metadata_by_id(
root: State<PathBuf>,
id: extract::Path<String>,
) -> OutpackResult<serde_json::Value> {
metadata::get_metadata_by_id(&root, &id)
.map_err(OutpackError::from)
.map(OutpackSuccess::from)
}

#[rocket::get("/metadata/<id>/text")]
fn get_metadata_raw(root: &State<PathBuf>, id: &str) -> Result<String, OutpackError> {
metadata::get_metadata_text(root, id).map_err(OutpackError::from)
async fn get_metadata_raw(
root: State<PathBuf>,
id: extract::Path<String>,
) -> Result<String, OutpackError> {
metadata::get_metadata_text(&root, &id).map_err(OutpackError::from)
}

#[rocket::get("/file/<hash>")]
async fn get_file(root: &State<PathBuf>, hash: &str) -> Result<OutpackFile, OutpackError> {
let path = store::file_path(root, hash);
async fn get_file(
root: State<PathBuf>,
hash: extract::Path<String>,
) -> Result<OutpackFile, OutpackError> {
let path = store::file_path(&root, &hash);
OutpackFile::open(hash.to_owned(), path?)
.await
.map_err(OutpackError::from)
}

#[rocket::get("/checksum?<alg>")]
async fn get_checksum(root: &State<PathBuf>, alg: Option<String>) -> OutpackResult<String> {
metadata::get_ids_digest(root, alg)
#[derive(Deserialize)]
struct Algorithm {
alg: Option<String>,
}

async fn get_checksum(root: State<PathBuf>, query: Query<Algorithm>) -> OutpackResult<String> {
metadata::get_ids_digest(&root, query.0.alg)
.map_err(OutpackError::from)
.map(OutpackSuccess::from)
}

#[rocket::post("/packets/missing", format = "json", data = "<ids>")]
async fn get_missing_packets(
root: &State<PathBuf>,
ids: Result<Json<Ids>, Error<'_>>,
root: State<PathBuf>,
ids: Result<Json<Ids>, JsonRejection>,
) -> OutpackResult<Vec<String>> {
let ids = ids?;
metadata::get_missing_ids(root, &ids.ids, ids.unpacked)
metadata::get_missing_ids(&root, &ids.ids, ids.unpacked)
.map_err(OutpackError::from)
.map(OutpackSuccess::from)
}

#[rocket::post("/files/missing", format = "json", data = "<hashes>")]
async fn get_missing_files(
root: &State<PathBuf>,
hashes: Result<Json<Hashes>, Error<'_>>,
root: State<PathBuf>,
hashes: Result<Json<Hashes>, JsonRejection>,
) -> OutpackResult<Vec<String>> {
let hashes = hashes?;
store::get_missing_files(root, &hashes.hashes)
store::get_missing_files(&root, &hashes.hashes)
.map_err(OutpackError::from)
.map(OutpackSuccess::from)
}

#[rocket::post("/file/<hash>", format = "binary", data = "<file>")]
async fn add_file(
root: &State<PathBuf>,
hash: &str,
file: TempFile<'_>,
root: State<PathBuf>,
hash: extract::Path<String>,
file: Upload,
) -> Result<OutpackSuccess<()>, OutpackError> {
store::put_file(root, file, hash)
store::put_file(&root, file, &hash)
.await
.map_err(OutpackError::from)
.map(OutpackSuccess::from)
}

#[rocket::post("/packet/<hash>", format = "plain", data = "<packet>")]
async fn add_packet(
root: &State<PathBuf>,
hash: &str,
packet: &str,
root: State<PathBuf>,
hash: extract::Path<String>,
packet: String,
) -> Result<OutpackSuccess<()>, OutpackError> {
let hash = hash.parse::<hash::Hash>().map_err(OutpackError::from)?;
metadata::add_packet(root, packet, &hash)
metadata::add_packet(&root, &packet, &hash)
.map_err(OutpackError::from)
.map(OutpackSuccess::from)
}

#[derive(Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
struct Ids {
ids: Vec<String>,
unpacked: bool,
}

#[derive(Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
struct Hashes {
hashes: Vec<String>,
}
Expand Down Expand Up @@ -205,35 +203,68 @@ pub fn preflight(root: &Path) -> anyhow::Result<()> {
Ok(())
}

fn api_build(root: &Path) -> Rocket<Build> {
let prometheus = PrometheusMetrics::new();
metrics::register(prometheus.registry(), root).expect("metrics registered");
rocket::build()
.manage(root.to_owned())
.attach(prometheus.clone())
.mount("/metrics", prometheus)
.register("/", catchers![internal_error, not_found, bad_request])
.mount(
"/",
routes![
index,
list_location_metadata,
get_metadata,
get_metadata_by_id,
get_metadata_raw,
get_file,
get_checksum,
get_missing_packets,
get_missing_files,
add_file,
add_packet
],
)
fn make_request_span(request: &axum::extract::Request) -> tracing::span::Span {
let request_id = String::from_utf8_lossy(request.headers()["x-request-id"].as_bytes());
tracing::span!(
tracing::Level::DEBUG,
"request",
method = tracing::field::display(request.method()),
uri = tracing::field::display(request.uri()),
version = tracing::field::debug(request.version()),
request_id = tracing::field::display(request_id)
)
}

pub fn api(root: &Path) -> anyhow::Result<Rocket<Build>> {
pub fn api(root: &Path) -> anyhow::Result<Router> {
use axum::routing::{get, post};

let registry = prometheus::Registry::new();

metrics::RepositoryMetrics::register(&registry, root).expect("repository metrics registered");
let http_metrics = metrics::HttpMetrics::register(&registry).expect("http metrics registered");

preflight(root)?;
Ok(api_build(root))

let routes = Router::new()
.route("/", get(index))
.route("/metadata/list", get(list_location_metadata))
.route("/metadata/:id/json", get(get_metadata_by_id))
.route("/metadata/:id/text", get(get_metadata_raw))
.route("/checksum", get(get_checksum))
.route("/packets/missing", post(get_missing_packets))
.route("/files/missing", post(get_missing_files))
.route("/packit/metadata", get(get_metadata_since))
.route("/file/:hash", get(get_file).post(add_file))
.route("/packet/:hash", post(add_packet))
.route("/metrics", get(|| async move { metrics::render(registry) }))
.fallback(not_found)
.with_state(root.to_owned());

Ok(routes
.layer(UploadLayer::new(root.join(".outpack").join("files")))
.layer(TraceLayer::new_for_http().make_span_with(make_request_span))
.layer(PropagateRequestIdLayer::x_request_id())
.layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
.layer(CatchPanicLayer::custom(internal_error))
.layer(http_metrics.layer()))
}

pub fn serve(root: &Path) -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::TRACE)
.init();

let app = api(root)?;

tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async {
let listener = tokio::net::TcpListener::bind("0.0.0.0:8000").await?;
tracing::info!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await?;
Ok(())
})
}

#[cfg(test)]
Expand Down
3 changes: 1 addition & 2 deletions src/bin/outpack/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ fn main() -> anyhow::Result<()> {
}

Command::StartServer { root } => {
let server = outpack::api::api(&root)?;
rocket::execute(server.launch())?;
outpack::api::serve(&root)?;
}
}
Ok(())
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ mod metrics;
mod outpack_file;
mod responses;
mod store;
mod upload;
mod utils;
Loading

0 comments on commit 2e832a4

Please sign in to comment.