Skip to content

Commit

Permalink
refine auto scaler set/get logic
Browse files Browse the repository at this point in the history
  • Loading branch information
smtmfft committed Aug 13, 2024
1 parent 7d497d6 commit 88c9e29
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 31 deletions.
65 changes: 35 additions & 30 deletions provers/risc0/driver/src/bonsai/auto_scaling.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use anyhow::Result;
use anyhow::{Error, Result};
use lazy_static::lazy_static;
use reqwest::{header::HeaderMap, header::HeaderValue, header::CONTENT_TYPE, Client};
use serde::Deserialize;
use std::env;
use tracing::{debug, error};
use tracing::{debug, error as trace_err};

#[derive(Debug, Deserialize, Default)]
struct ScalerResponse {
Expand All @@ -22,38 +22,34 @@ impl BonsaiAutoScaler {
Self { url, api_key }
}

async fn get_bonsai_gpu_num(&self) -> u32 {
async fn get_bonsai_gpu_num(&self) -> Result<ScalerResponse> {
// Create a new client
let client = Client::new();
let url = self.url.clone() + "/workers";

// Create custom headers
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert("x-api-key", HeaderValue::from_str(&self.api_key).unwrap());

println!("Requesting scaler status from: {}", url);
debug!("Requesting scaler status from: {}", self.url);
// Make the GET request
let response = client.get(url).headers(headers).send().await.unwrap();
let response = client.get(self.url.clone()).headers(headers).send().await?;

// Check if the request was successful
if response.status().is_success() {
// Parse the JSON response
let data: ScalerResponse = response.json().await.unwrap_or_default();

// Use the parsed data
println!("Scaler status: {:?}", data);
data.current
debug!("Scaler status: {:?}", data);
Ok(data)
} else {
error!("Request failed with status: {}", response.status());
0
trace_err!("Request failed with status: {}", response.status());
Err(Error::msg("Failed to get bonsai gpu num".to_string()))
}
}

async fn set_bonsai_gpu_num(&self, gpu_num: u32) -> Result<()> {
// Create a new client
let client = Client::new();
let url = self.url.clone() + "/workers";

// Create custom headers
let mut headers = HeaderMap::new();
Expand All @@ -62,7 +58,7 @@ impl BonsaiAutoScaler {

// Make the POST request
let response = client
.post(url)
.post(self.url.clone())
.headers(headers)
.body(gpu_num.to_string())
.send()
Expand All @@ -77,7 +73,7 @@ impl BonsaiAutoScaler {
debug!("Scaler status: {:?}", data);
assert_eq!(data.desired, gpu_num);
} else {
error!("Request failed with status: {}", response.status());
trace_err!("Request failed with status: {}", response.status());
}

Ok(())
Expand All @@ -91,32 +87,37 @@ lazy_static! {
env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set");
}

const MAX_BONSAI_GPU_NUM: u32 = 15;

pub(crate) async fn maxpower_bonsai() -> Result<()> {
let auto_scaler = BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string());

let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await;
if current_gpu_num == 15 {
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?;
// either already maxed out or pending to be maxed out
if current_gpu_num.current == MAX_BONSAI_GPU_NUM
|| (current_gpu_num.current + current_gpu_num.pending == MAX_BONSAI_GPU_NUM)
{
Ok(())
} else {
auto_scaler.set_bonsai_gpu_num(15).await?;
// wait 3 minute for the bonsai to heat up
auto_scaler.set_bonsai_gpu_num(MAX_BONSAI_GPU_NUM).await?;
// wait for the bonsai to heat up
tokio::time::sleep(tokio::time::Duration::from_secs(180)).await;
assert!(auto_scaler.get_bonsai_gpu_num().await == 15);
let scaler_status = auto_scaler.get_bonsai_gpu_num().await?;
assert!(scaler_status.current == MAX_BONSAI_GPU_NUM);
Ok(())
}
}

pub(crate) async fn shutdown_bonsai() -> Result<()> {
let auto_scaler = BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string());
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await;
if current_gpu_num == 15 {
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?;
if current_gpu_num.current == 0 {
Ok(())
} else {
auto_scaler.set_bonsai_gpu_num(0).await?;

// wait 1 minute for the bonsai to cool down
// wait a minute for the bonsai to cool down
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
assert!(auto_scaler.get_bonsai_gpu_num().await == 0);
let scaler_status = auto_scaler.get_bonsai_gpu_num().await?;
assert!(scaler_status.current == 0);
Ok(())
}
}
Expand All @@ -132,8 +133,12 @@ mod test {
let bonsai_url = env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set");
let bonsai_key = env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set");
let auto_scaler = BonsaiAutoScaler::new(bonsai_url, bonsai_key);
let gpu_num = auto_scaler.get_bonsai_gpu_num().await;
assert!(gpu_num >= 0 && gpu_num <= 15);
let scalar_status = auto_scaler.get_bonsai_gpu_num().await.unwrap();
assert!(scalar_status.current <= MAX_BONSAI_GPU_NUM);
assert_eq!(
scalar_status.desired,
scalar_status.current + scalar_status.pending
);
}

#[ignore]
Expand All @@ -149,7 +154,7 @@ mod test {
.expect("Failed to set bonsai gpu num");
// wait a few minutes for the bonsai to heat up
tokio::time::sleep(tokio::time::Duration::from_secs(200)).await;
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await;
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current;
assert_eq!(current_gpu_num, 7);

auto_scaler
Expand All @@ -158,7 +163,7 @@ mod test {
.expect("Failed to set bonsai gpu num");
// wait 30 seconds for the bonsai to cool down
tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await;
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current;
assert_eq!(current_gpu_num, 0);
}
}
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use risc0_zkvm::{serde::to_vec, sha::Digest};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::fmt::Debug;
use tracing::{debug, error, info as traicing_info};
use tracing::{debug, info as traicing_info};

use crate::{
bonsai::auto_scaling::{maxpower_bonsai, shutdown_bonsai},
Expand Down

0 comments on commit 88c9e29

Please sign in to comment.