Skip to content

Commit

Permalink
Functional TrainingWorkload for submitting to RCP - with hardcoded values
Browse files Browse the repository at this point in the history
  • Loading branch information
evanjt committed Oct 29, 2024
1 parent 639bf60 commit 95a8a3c
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ secrecy = "0.8.0"
anyhow = "1.0.89"
thiserror = "1.0.64"
tokio-util = "0.7.12"
rand = "0.8.5"
schemars = "0.8.21"
48 changes: 48 additions & 0 deletions src/external/k8s/crd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use kube::CustomResource;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Spec of the `TrainingWorkload` custom resource (`run.ai/v2alpha1`).
///
/// Deriving `CustomResource` generates the namespaced `TrainingWorkload` kind
/// that wraps this spec.
///
/// All fields are serialized in camelCase. Previously only `imagePullPolicy`
/// was renamed, so the remaining multi-word fields went out as snake_case keys
/// (e.g. `allow_privilege_escalation`), inconsistent with the rest of the spec.
/// Optional fields that are `None` are omitted instead of being sent as `null`.
///
/// NOTE(review): confirm the camelCase key names against the CRD installed on
/// the target cluster before relying on the optional fields.
#[derive(CustomResource, Debug, Serialize, Deserialize, Clone, JsonSchema)]
#[kube(
    group = "run.ai",
    version = "v2alpha1",
    kind = "TrainingWorkload",
    namespaced
)]
#[serde(rename_all = "camelCase")]
pub struct TrainingWorkloadSpec {
    /// Serialized as `allowPrivilegeEscalation`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_privilege_escalation: Option<ValueField<bool>>,
    /// Environment items handed to the workload.
    pub environment: Environment,
    /// GPU request, carried as a string inside a `{ "value": ... }` wrapper.
    pub gpu: ValueField<String>,
    /// Container image to run.
    pub image: ValueField<String>,
    /// Serialized as `imagePullPolicy` (now via `rename_all`).
    pub image_pull_policy: ValueField<String>,
    /// Workload name.
    pub name: ValueField<String>,
    /// Serialized as `runAsGid`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run_as_gid: Option<ValueField<u32>>,
    /// Serialized as `runAsUid`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run_as_uid: Option<ValueField<u32>>,
    /// Serialized as `runAsUser`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run_as_user: Option<ValueField<bool>>,
    /// Serialized as `serviceType`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_type: Option<ValueField<String>>,
    /// Plain (unwrapped) string; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<String>,
}

// Wrapper for the `environment` section of the workload spec.
// It holds a single `items` object containing the individual entries.
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct Environment {
    // The environment entries passed along with the workload.
    pub items: EnvironmentItems,
}

// The fixed set of environment entries sent with a workload.
// Every entry is a `ValueField<String>`, i.e. it serializes as an object
// with a single `value` key. Field names are serialized unchanged.
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct EnvironmentItems {
    // Identifiers of the input objects, as a single string.
    pub input_object_ids: ValueField<String>,
    // S3 connection settings.
    pub s3_access_key: ValueField<String>,
    pub s3_bucket_id: ValueField<String>,
    pub s3_prefix: ValueField<String>,
    pub s3_secret_key: ValueField<String>,
    pub s3_url: ValueField<String>,
    // Identifier of the submission this workload belongs to.
    pub submission_id: ValueField<String>,
    // Image reference exposed to the workload as an environment entry.
    pub base_image: ValueField<String>,
}

// Generic single-field wrapper: a `T` is serialized as an object with one
// `value` key holding it.
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct ValueField<T> {
    // The wrapped value.
    pub value: T,
}
1 change: 1 addition & 0 deletions src/external/k8s/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod crd;
pub mod models;
pub mod services;
2 changes: 1 addition & 1 deletion src/external/k8s/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub async fn get_pods() -> Result<Option<Vec<PodName>>> {
Ok(Some(pods))
}

async fn refresh_token_and_get_client() -> Result<Client> {
pub async fn refresh_token_and_get_client() -> Result<Client> {
let app_config = Config::from_env();

// Read and parse the kubeconfig file
Expand Down
4 changes: 2 additions & 2 deletions src/submissions/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;

#[derive(ToSchema, Serialize)]
#[derive(ToSchema, Serialize, Debug)]
pub struct Submission {
id: Uuid,
name: String,
Expand All @@ -14,7 +14,7 @@ pub struct Submission {
comment: Option<String>,
created_on: NaiveDateTime,
last_updated: NaiveDateTime,
associations: Option<Vec<crate::uploads::models::UploadRead>>,
pub(super) associations: Option<Vec<crate::uploads::models::UploadRead>>,
}

impl From<super::db::Model> for Submission {
Expand Down
134 changes: 133 additions & 1 deletion src/submissions/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use crate::common::filter::{apply_filters, parse_range};
use crate::common::models::FilterOptions;
use crate::common::pagination::calculate_content_range;
use crate::common::sort::generic_sort;
// use crate::external::k8s::crd::DictionaryField;
use anyhow::Result;
use axum::debug_handler;
use axum::{
extract::{Path, Query, State},
http::StatusCode,
Expand All @@ -12,10 +15,14 @@ use axum::{
use axum_keycloak_auth::{
instance::KeycloakAuthInstance, layer::KeycloakAuthLayer, PassthroughMode,
};
use kube::api::PostParams;
use kube::Api;
use rand::Rng;
use sea_orm::{
query::*, ActiveModelTrait, DatabaseConnection, DeleteResult, EntityTrait, IntoActiveModel,
ModelTrait, SqlErr,
};
use serde_json::json;
use std::sync::Arc;
use uuid::Uuid;

Expand All @@ -24,7 +31,10 @@ pub fn router(db: DatabaseConnection, keycloak_auth_instance: Arc<KeycloakAuthIn
.route("/", routing::get(get_all).post(create_one))
.route(
"/:id",
routing::get(get_one).put(update_one).delete(delete_one),
routing::get(get_one)
.put(update_one)
.delete(delete_one)
.post(execute_workflow),
)
.with_state(db)
.layer(
Expand Down Expand Up @@ -223,3 +233,125 @@ pub async fn delete_one(State(db): State<DatabaseConnection>, Path(id): Path<Uui

StatusCode::NO_CONTENT
}

// use axum::{extract::{Path, State}, http::StatusCode};
// use rand::Rng;
// use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter};
// use uuid::Uuid;
// use kube::{Api, Client};
// use kube::api::PostParams;
// use crate::config::Config;
use crate::external::k8s::crd::{
Environment,
EnvironmentItems,
// TrainingEnvironment, TrainingEnvironmentItems,
TrainingWorkload,
TrainingWorkloadSpec,
ValueField,
};
// use crate::external::k8s::services::PodName;
use kube::api::ListParams;

/// Submits a `TrainingWorkload` custom resource for the submission `id`.
///
/// The handler looks up the submission, collects the ids of its related
/// uploads, and creates a `TrainingWorkload` in the namespace given by
/// `Config::kube_namespace`. S3 settings come from `Config::from_env()`;
/// the image, GPU count and pull policy are currently hardcoded.
///
/// # Returns
/// * `201 Created` - the resource was accepted by the Kubernetes API.
/// * `404 Not Found` - no submission with this id exists.
/// * `500 Internal Server Error` - a database query failed, the Kubernetes
///   client could not be obtained, or the create call was rejected.
#[debug_handler]
pub async fn execute_workflow(
    State(db): State<DatabaseConnection>,
    Path(id): Path<Uuid>,
) -> StatusCode {
    // Used both as the workload image and as the `base_image` environment item.
    const WORKLOAD_IMAGE: &str = "registry.rcp.epfl.ch/rcp-test-ejthomas/dorado:0.2";

    // Generate a unique job name; the random suffix lets the same submission
    // be executed more than once without a name clash.
    let random_number: u32 = rand::thread_rng().gen_range(10000..99999);
    let job_name = format!("labcaller-{}-{}", id, random_number);

    // Fetch the submission. A database failure is reported as 500 rather than
    // being mistaken for a missing record.
    let obj = match super::db::Entity::find_by_id(id).one(&db).await {
        Ok(Some(submission)) => submission,
        Ok(None) => return StatusCode::NOT_FOUND,
        Err(e) => {
            eprintln!("Failed to fetch submission {}: {:?}", id, e);
            return StatusCode::INTERNAL_SERVER_ERROR;
        }
    };

    // Collect the ids of the uploads related to this submission. A query error
    // must not panic the handler.
    let input_object_ids: Vec<Uuid> = match obj
        .find_related(crate::uploads::db::Entity)
        .all(&db)
        .await
    {
        Ok(uploads) => uploads.into_iter().map(|assoc| assoc.id).collect(),
        Err(e) => {
            eprintln!("Failed to fetch uploads for submission {}: {:?}", id, e);
            return StatusCode::INTERNAL_SERVER_ERROR;
        }
    };

    // The ids are passed to the workload as one JSON-encoded string.
    let input_object_ids_json = match serde_json::to_string(&input_object_ids) {
        Ok(encoded) => encoded,
        Err(e) => {
            eprintln!("Failed to encode input object ids: {:?}", e);
            return StatusCode::INTERNAL_SERVER_ERROR;
        }
    };

    // Set up Kubernetes client and configuration
    let config = crate::config::Config::from_env();
    let client = match crate::external::k8s::services::refresh_token_and_get_client().await {
        Ok(client) => client,
        Err(_) => {
            eprintln!("Failed to obtain Kubernetes client");
            return StatusCode::INTERNAL_SERVER_ERROR;
        }
    };

    // Create the `TrainingWorkload` custom resource instance
    let training_workload = TrainingWorkload::new(
        &job_name,
        TrainingWorkloadSpec {
            allow_privilege_escalation: Some(ValueField { value: true }),
            environment: Environment {
                items: EnvironmentItems {
                    input_object_ids: ValueField {
                        value: input_object_ids_json,
                    },
                    s3_access_key: ValueField {
                        value: config.s3_access_key.to_string(),
                    },
                    s3_bucket_id: ValueField {
                        value: config.s3_bucket.to_string(),
                    },
                    s3_prefix: ValueField {
                        value: config.s3_prefix.to_string(),
                    },
                    s3_secret_key: ValueField {
                        value: config.s3_secret_key.to_string(),
                    },
                    s3_url: ValueField {
                        value: config.s3_url.to_string(),
                    },
                    submission_id: ValueField {
                        value: id.to_string(),
                    },
                    base_image: ValueField {
                        value: WORKLOAD_IMAGE.to_string(),
                    },
                },
            },
            gpu: ValueField {
                value: "1".to_string(),
            },
            image: ValueField {
                value: WORKLOAD_IMAGE.to_string(),
            },
            image_pull_policy: ValueField {
                value: "Always".to_string(),
            },
            name: ValueField {
                value: job_name.clone(),
            },
            run_as_gid: None,
            run_as_uid: None,
            run_as_user: None,
            service_type: None,
            usage: Some("Submit".to_string()),
        },
    );

    // Log only the job name: the full resource contains the S3 access and
    // secret keys and must not be written to stdout.
    println!("Submitting TrainingWorkload: {}", job_name);

    // Submit the custom resource to Kubernetes
    let api: Api<TrainingWorkload> = Api::namespaced(client, &config.kube_namespace);

    match api.create(&PostParams::default(), &training_workload).await {
        Ok(_) => {
            println!("Submitted TrainingWorkload: {}", job_name);
            StatusCode::CREATED
        }
        Err(e) => {
            eprintln!("Failed to submit TrainingWorkload: {:?}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        }
    }
}
4 changes: 2 additions & 2 deletions src/uploads/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use serde::Serialize;
use utoipa::ToSchema;
use uuid::Uuid;

#[derive(ToSchema, Serialize, FromQueryResult)]
#[derive(ToSchema, Serialize, FromQueryResult, Debug)]
pub struct UploadRead {
id: Uuid,
pub id: Uuid,
created_on: NaiveDateTime,
filename: String,
size_bytes: i64,
Expand Down

0 comments on commit 95a8a3c

Please sign in to comment.