Skip to content

Commit

Permalink
Add outputs into submission get one
Browse files Browse the repository at this point in the history
  • Loading branch information
evanjt committed Oct 29, 2024
1 parent ba54498 commit 630eafe
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 13 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ thiserror = "1.0.64"
tokio-util = "0.7.12"
rand = "0.8.5"
schemars = "0.8.21"
aws-smithy-types-convert = { version = "0.60.8", features = ["convert-chrono"] }
1 change: 1 addition & 0 deletions src/external/s3/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod models;
pub mod services;
22 changes: 22 additions & 0 deletions src/external/s3/models.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use aws_smithy_types_convert::date_time::DateTimeExt;
use chrono::{DateTime, Utc};
use sea_orm::FromQueryResult;
use serde::Serialize;
use utoipa::ToSchema;

#[derive(ToSchema, Serialize, FromQueryResult, Debug)]
pub struct OutputObject {
key: String,
last_modified: DateTime<Utc>,
size_bytes: i64,
}

impl From<aws_sdk_s3::types::Object> for OutputObject {
fn from(model: aws_sdk_s3::types::Object) -> Self {
Self {
key: model.key.unwrap(),
last_modified: model.last_modified.unwrap().to_chrono_utc().unwrap(),
size_bytes: model.size.unwrap(),
}
}
}
25 changes: 25 additions & 0 deletions src/external/s3/services.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::config::Config;
use anyhow::Result;
use aws_config::BehaviorVersion;
use aws_sdk_s3::config::Credentials;
use aws_sdk_s3::{config::Region, Client as S3Client};
use std::sync::Arc;
use uuid::Uuid;

pub async fn get_client(config: &Config) -> Arc<S3Client> {
let region = Region::new("us-east-1");
Expand All @@ -22,3 +24,26 @@ pub async fn get_client(config: &Config) -> Arc<S3Client> {

Arc::new(S3Client::new(&shared_config))
}

pub async fn get_outputs_from_id(
client: Arc<S3Client>,
id: Uuid,
) -> Result<Vec<super::models::OutputObject>, Box<dyn std::error::Error>> {
let config = crate::config::Config::from_env();
let prefix = format!("{}/outputs/{}/", config.s3_prefix, id);
let mut outputs: Vec<super::models::OutputObject> = vec![];
let list = client
.list_objects()
.bucket(config.s3_bucket)
.prefix(prefix.clone())
.send()
.await?;

if let Some(contents) = list.contents {
for object in contents {
outputs.push(object.into());
}
}

Ok(outputs)
}
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ async fn main() {
.with_state(db.clone())
.nest(
"/api/submissions",
submissions::views::router(db.clone(), keycloak_auth_instance.clone()),
submissions::views::router(
db.clone(),
keycloak_auth_instance.clone(),
s3_client.clone(),
),
)
.nest(
"/api/uploads",
Expand Down
10 changes: 10 additions & 0 deletions src/submissions/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ pub struct Model {
pub created_on: NaiveDateTime,
pub last_updated: NaiveDateTime,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "crate::uploads::db::Entity")]
FileObjectAssociations,
#[sea_orm(has_many = "crate::submissions::run_status::db::Entity")]
RunStatus,
}

impl Related<crate::uploads::db::Entity> for Entity {
Expand All @@ -37,4 +40,11 @@ impl Related<crate::uploads::db::Entity> for Entity {
}
}

// Implement the Related trait for RunStatus to complete the relationship
impl Related<crate::submissions::run_status::db::Entity> for Entity {
fn to() -> RelationDef {
Relation::RunStatus.def()
}
}

impl ActiveModelBehavior for ActiveModel {}
21 changes: 18 additions & 3 deletions src/submissions/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct Submission {
created_on: NaiveDateTime,
last_updated: NaiveDateTime,
pub(super) associations: Option<Vec<crate::uploads::models::UploadRead>>,
outputs: Option<Vec<crate::external::s3::models::OutputObject>>,
}

impl From<super::db::Model> for Submission {
Expand All @@ -28,15 +29,28 @@ impl From<super::db::Model> for Submission {
created_on: model.created_on,
last_updated: model.last_updated,
associations: None,
outputs: None,
}
}
}

impl From<(super::db::Model, Option<Vec<crate::uploads::db::Model>>)> for Submission {
fn from(model_tuple: (super::db::Model, Option<Vec<crate::uploads::db::Model>>)) -> Self {
impl
From<(
super::db::Model,
Option<Vec<crate::uploads::db::Model>>,
Vec<crate::external::s3::models::OutputObject>,
)> for Submission
{
fn from(
model_tuple: (
super::db::Model,
Option<Vec<crate::uploads::db::Model>>,
Vec<crate::external::s3::models::OutputObject>,
),
) -> Self {
let submission = model_tuple.0;
let uploads = model_tuple.1;

let outputs = model_tuple.2;
Self {
id: submission.id,
name: submission.name,
Expand All @@ -52,6 +66,7 @@ impl From<(super::db::Model, Option<Vec<crate::uploads::db::Model>>)> for Submis
.map(|association| association.into())
.collect(),
),
outputs: Some(outputs),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/submissions/run_status/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub enum Relation {
Submissions,
}

// Implement the relationship back to Submission
impl Related<crate::submissions::db::Entity> for Entity {
fn to() -> RelationDef {
Relation::Submissions.def()
Expand Down
30 changes: 21 additions & 9 deletions src/submissions/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::external::k8s::crd::{
Environment, EnvironmentItems, TrainingWorkload, TrainingWorkloadSpec, ValueField,
};
use anyhow::Result;
use aws_sdk_s3::Client as S3Client;
use axum::{
debug_handler,
extract::{Path, Query, State},
Expand All @@ -26,7 +27,11 @@ use sea_orm::{
use std::sync::Arc;
use uuid::Uuid;

pub fn router(db: DatabaseConnection, keycloak_auth_instance: Arc<KeycloakAuthInstance>) -> Router {
pub fn router(
db: DatabaseConnection,
keycloak_auth_instance: Arc<KeycloakAuthInstance>,
s3: Arc<S3Client>,
) -> Router {
Router::new()
.route("/", routing::get(get_all).post(create_one))
.route(
Expand All @@ -36,7 +41,7 @@ pub fn router(db: DatabaseConnection, keycloak_auth_instance: Arc<KeycloakAuthIn
.delete(delete_one)
.post(execute_workflow),
)
.with_state(db)
.with_state((db, s3))
.layer(
KeycloakAuthLayer::<Role>::builder()
.instance(keycloak_auth_instance)
Expand All @@ -57,7 +62,7 @@ const RESOURCE_NAME: &str = "submissions";
)]
pub async fn get_all(
Query(params): Query<FilterOptions>,
State(db): State<DatabaseConnection>,
State((db, _s3)): State<(DatabaseConnection, Arc<S3Client>)>,
) -> impl IntoResponse {
let (offset, limit) = parse_range(params.range.clone());

Expand Down Expand Up @@ -110,7 +115,7 @@ pub async fn get_all(
responses((status = CREATED, body = super::models::Submission))
)]
pub async fn create_one(
State(db): State<DatabaseConnection>,
State((db, _s3)): State<(DatabaseConnection, Arc<S3Client>)>,
Json(payload): Json<super::models::SubmissionCreate>,
) -> Result<(StatusCode, Json<super::models::Submission>), (StatusCode, Json<String>)> {
let new_obj = super::db::Model {
Expand Down Expand Up @@ -158,7 +163,7 @@ pub async fn create_one(
responses((status = OK, body = super::models::Submission))
)]
pub async fn get_one(
State(db): State<DatabaseConnection>,
State((db, s3)): State<(DatabaseConnection, Arc<S3Client>)>,
Path(id): Path<Uuid>,
) -> Result<Json<super::models::Submission>, (StatusCode, Json<String>)> {
let obj = match super::db::Entity::find_by_id(id)
Expand All @@ -169,6 +174,10 @@ pub async fn get_one(
Ok(obj) => obj.unwrap(),
_ => return Err((StatusCode::NOT_FOUND, Json("Not Found".to_string()))),
};
let outputs: Vec<crate::external::s3::models::OutputObject> =
crate::external::s3::services::get_outputs_from_id(s3, obj.id)
.await
.unwrap();

let uploads = match obj.find_related(crate::uploads::db::Entity).all(&db).await {
// Return all or none. If any fail, return an error
Expand All @@ -181,7 +190,7 @@ pub async fn get_one(
}
};

let submission: super::models::Submission = (obj, uploads).into();
let submission: super::models::Submission = (obj, uploads, outputs).into();

Ok(Json(submission))
}
Expand All @@ -192,7 +201,7 @@ pub async fn get_one(
responses((status = OK, body = super::models::Submission))
)]
pub async fn update_one(
State(db): State<DatabaseConnection>,
State((db, _s3)): State<(DatabaseConnection, Arc<S3Client>)>,
Path(id): Path<Uuid>,
Json(payload): Json<super::models::SubmissionUpdate>,
) -> impl IntoResponse {
Expand All @@ -218,7 +227,10 @@ pub async fn update_one(
path = format!("/api/{}/{{id}}", RESOURCE_NAME),
responses((status = NO_CONTENT))
)]
pub async fn delete_one(State(db): State<DatabaseConnection>, Path(id): Path<Uuid>) -> StatusCode {
pub async fn delete_one(
State((db, _s3)): State<(DatabaseConnection, Arc<S3Client>)>,
Path(id): Path<Uuid>,
) -> StatusCode {
let obj = super::db::Entity::find_by_id(id)
.one(&db)
.await
Expand All @@ -236,7 +248,7 @@ pub async fn delete_one(State(db): State<DatabaseConnection>, Path(id): Path<Uui

#[debug_handler]
pub async fn execute_workflow(
State(db): State<DatabaseConnection>,
State((db, _s3)): State<(DatabaseConnection, Arc<S3Client>)>,
Path(id): Path<Uuid>,
) -> StatusCode {
// Generate a unique job name
Expand Down

0 comments on commit 630eafe

Please sign in to comment.