Skip to content

Commit

Permalink
Avoid proxying data request by forwarding signed S3 URL to client
Browse files Browse the repository at this point in the history
  • Loading branch information
evanjt committed Nov 6, 2024
1 parent d8c6ab1 commit b4efbc9
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 118 deletions.
3 changes: 0 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ pub struct Config {
pub interval_external_services: u64,
pub submission_base_image: String,
pub submission_base_image_tag: String,
pub serializer_secret_key: String,

pub s3_prefix: String, // Prefix within the bucket, ie. labcaller-dev
pub pod_prefix: String, // What is prefixed to the pod name, ie. labcaller-dev}
Expand Down Expand Up @@ -81,8 +80,6 @@ impl Config {
.expect("SUBMISSION_BASE_IMAGE must be set"),
submission_base_image_tag: env::var("SUBMISSION_BASE_IMAGE_TAG")
.expect("SUBMISSION_BASE_IMAGE_TAG must be set"),
serializer_secret_key: env::var("SERIALIZER_SECRET_KEY")
.expect("SERIALIZER_SECRET_KEY must be set"),
db_prefix,
db_url,
s3_prefix,
Expand Down
4 changes: 0 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ async fn main() {
.route("/healthz", get(common::views::healthz))
.route("/api/config", get(common::views::get_ui_config))
.route("/api/status", get(common::views::get_status))
.route(
"/api/submissions/download/:token",
get(submissions::views::download_file),
)
.with_state(db.clone())
.nest(
"/api/submissions",
Expand Down
4 changes: 2 additions & 2 deletions src/submissions/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ impl SubmissionUpdate {
}

#[derive(Debug, Serialize, Deserialize)]
pub struct DownloadToken {
pub token: String,
pub struct DownloadPath {
pub url: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct Claims {
Expand Down
129 changes: 20 additions & 109 deletions src/submissions/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,26 @@ use crate::external::k8s::crd::{
Environment, EnvironmentItems, TrainingWorkload, TrainingWorkloadSpec, ValueField,
};
use anyhow::Result;
use aws_sdk_s3::presigning::PresigningConfig;
use aws_sdk_s3::Client as S3Client;
use axum::body::Body;
use axum::http::HeaderMap;
use axum::{
debug_handler,
extract::{Path, Query, State},
http::{header, StatusCode},
http::StatusCode,
response::IntoResponse,
routing, Json, Router,
};
use axum_keycloak_auth::{
instance::KeycloakAuthInstance, layer::KeycloakAuthLayer, PassthroughMode,
};
use bytes::Bytes;
use chrono::{Duration, Utc};
use futures::StreamExt;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use kube::{api::PostParams, Api};
use rand::Rng;
use sea_orm::{
query::*, ActiveModelTrait, DatabaseConnection, DeleteResult, EntityTrait, IntoActiveModel,
ModelTrait, SqlErr,
};
use std::io::ErrorKind;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use std::time::Duration;
use uuid::Uuid;

pub fn router(
Expand All @@ -50,7 +43,7 @@ pub fn router(
.delete(delete_one)
.post(execute_workflow),
)
.route("/:id/:filename", routing::get(generate_download_token))
.route("/:id/:filename", routing::get(generate_download_url))
.with_state((db, s3))
.layer(
KeycloakAuthLayer::<Role>::builder()
Expand Down Expand Up @@ -371,115 +364,33 @@ pub async fn execute_workflow(
}
}

pub async fn generate_download_token(
pub async fn generate_download_url(
Path((submission_id, filename)): Path<(Uuid, String)>,
) -> Result<Json<super::models::DownloadToken>, (StatusCode, String)> {
let config: crate::config::Config = crate::config::Config::from_env();
let expiration = Utc::now() + Duration::hours(1); // Token expiry in 1 hour
let claims = super::models::Claims {
submission_id,
filename: filename.clone(),
exp: expiration.timestamp() as usize,
};

let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(config.serializer_secret_key.as_ref()),
)
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Token creation failed".to_string(),
)
})?;

Ok(Json(super::models::DownloadToken { token }))
}
use tokio_util::io::ReaderStream;

pub async fn download_file(
Path(token): Path<String>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
) -> Result<Json<super::models::DownloadPath>, (StatusCode, String)> {
// Returns a presigned URL from S3. Assumes the client has access to the
// S3 domain (EPFL network in this case).
let config = crate::config::Config::from_env();
let s3 = crate::external::s3::services::get_client(&config).await;

let token_data = decode::<super::models::Claims>(
&token,
&DecodingKey::from_secret(config.serializer_secret_key.as_ref()),
&Validation::default(),
)
.map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?;

let claims = token_data.claims;
if claims.exp < Utc::now().timestamp() as usize {
return Err((StatusCode::UNAUTHORIZED, "Token expired".to_string()));
}

let key = format!(
"{}/outputs/{}/{}",
config.s3_prefix, claims.submission_id, claims.filename
config.s3_prefix, submission_id, filename
);

let object = s3
// Get presigned URL to give to client
let presigned_request = s3
.get_object()
.bucket(&config.s3_bucket)
.key(&key)
.send()
.presigned(
PresigningConfig::builder()
.expires_in(Duration::from_secs(60 * 60)) // One hour
.build()
.expect("Duration is invalid"),
)
.await
.map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {}", e)))?;

// Get the ByteStream from the object
let s3_body = object.body;

// Convert the ByteStream into an AsyncRead
let s3_body_async = s3_body.into_async_read();

// Convert the AsyncRead into a Stream using ReaderStream
let s3_stream = ReaderStream::new(s3_body_async);

// Create a bounded mpsc channel with capacity 10 (adjust as needed)
let (tx, rx) = mpsc::channel::<Result<Bytes, std::io::Error>>(10);

// Spawn a task to read from the s3_stream and send into the channel
tokio::spawn(async move {
tokio::pin!(s3_stream);
while let Some(result) = s3_stream.next().await {
match result {
Ok(bytes) => {
if tx.send(Ok(bytes)).await.is_err() {
// Receiver has dropped
break;
}
}
Err(e) => {
// Map the error into std::io::Error
let io_err = std::io::Error::new(ErrorKind::Other, e);
let _ = tx.send(Err(io_err)).await;
break;
}
}
}
});

// Create a stream from the receiver side of the channel
let stream = ReceiverStream::new(rx);

// Use Body::from_stream to create the response body
let body = Body::from_stream(stream);

// Build the response with headers
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_DISPOSITION,
format!("attachment; filename=\"{}\"", claims.filename)
.parse()
.unwrap(),
);
headers.insert(
header::CONTENT_TYPE,
"application/octet-stream".parse().unwrap(),
);
.unwrap();

Ok((headers, body))
let presigned_url = presigned_request.uri().to_string();
Ok(Json(super::models::DownloadPath { url: presigned_url }))
}

0 comments on commit b4efbc9

Please sign in to comment.