Skip to content

Commit

Permalink
Merge pull request #33 from stakpak/feat/add-sync
Browse files Browse the repository at this point in the history
Feat: Add Sync
  • Loading branch information
kajogo777 authored Feb 13, 2025
2 parents 4caac82 + 3db1440 commit 1ddbc19
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 41 deletions.
3 changes: 3 additions & 0 deletions src/client/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ impl FlowRef {
pub struct AgentSession {
pub id: Uuid,
pub agent_id: AgentID,
pub flow_ref: Option<FlowRef>,
pub visibility: AgentSessionVisibility,
pub checkpoints: Vec<AgentCheckpointListItem>,
pub created_at: DateTime<Utc>,
Expand Down Expand Up @@ -385,6 +386,7 @@ pub struct AgentCheckpointListItem {
pub struct AgentSessionListItem {
pub id: Uuid,
pub agent_id: AgentID,
pub flow_ref: Option<FlowRef>,
pub visibility: AgentSessionVisibility,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
Expand All @@ -395,6 +397,7 @@ impl From<AgentSession> for AgentSessionListItem {
Self {
id: item.id,
agent_id: item.agent_id,
flow_ref: item.flow_ref,
visibility: item.visibility,
created_at: item.created_at,
updated_at: item.updated_at,
Expand Down
51 changes: 51 additions & 0 deletions src/commands/agent/get_or_create_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use uuid::Uuid;

use crate::client::{
models::{
AgentCheckpointListItem, AgentID, AgentInput, AgentSessionListItem, AgentSessionVisibility,
},
Client,
};

pub async fn get_or_create_session(
client: &Client,
agent_id: AgentID,
checkpoint_id: Option<String>,
input: Option<AgentInput>,
) -> Result<(AgentID, AgentSessionListItem, AgentCheckpointListItem), String> {
match checkpoint_id {
Some(checkpoint_id) => {
let checkpoint_uuid = Uuid::parse_str(&checkpoint_id).map_err(|_| {
format!(
"Invalid checkpoint ID '{}' - must be a valid UUID",
checkpoint_id
)
})?;

let output = client.get_agent_checkpoint(checkpoint_uuid).await?;

Ok((
output.output.get_agent_id(),
output.session,
output.checkpoint,
))
}
None => {
let session = client
.create_agent_session(
agent_id.clone(),
AgentSessionVisibility::Private,
input.clone(),
)
.await?;

let checkpoint = session
.checkpoints
.first()
.ok_or("No checkpoint found in new session")?
.clone();

Ok((agent_id, session.into(), checkpoint))
}
}
}
21 changes: 20 additions & 1 deletion src/commands/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,17 @@ use crate::{
mod get_next_input;
pub use get_next_input::*;

mod get_or_create_session;
pub use get_or_create_session::*;

mod run_actions;
pub use run_actions::*;

mod run_agent;
pub use run_agent::*;

use super::flow;

#[derive(Subcommand)]
pub enum AgentCommands {
/// List agent sessions
Expand Down Expand Up @@ -128,11 +133,25 @@ impl AgentCommands {

input.set_user_prompt(user_prompt);

let (agent_id, session, checkpoint) =
get_or_create_session(&client, agent_id, checkpoint_id, Some(input.clone()))
.await?;

if let Some(flow_ref) = &session.flow_ref {
let config_clone = config.clone();
let client_clone = Client::new(&config_clone).map_err(|e| e.to_string())?;
let flow_ref = flow_ref.clone();
tokio::spawn(async move {
flow::sync(&config_clone, &client_clone, &flow_ref, None).await
});
}

run_agent(
&config,
&client,
agent_id,
checkpoint_id,
session,
checkpoint,
Some(input),
short_circuit_actions,
interactive,
Expand Down
45 changes: 7 additions & 38 deletions src/commands/agent/run_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use uuid::Uuid;

use crate::{
client::{
models::{AgentID, AgentInput, AgentSessionVisibility, AgentStatus, RunAgentInput},
models::{
AgentCheckpointListItem, AgentID, AgentInput, AgentSessionListItem, AgentStatus,
RunAgentInput,
},
Client,
},
commands::agent::get_next_input,
Expand All @@ -12,51 +15,17 @@ use crate::{

use super::{get_next_input_interactive, AgentOutputListener};

#[allow(clippy::too_many_arguments)]
pub async fn run_agent(
config: &AppConfig,
client: &Client,
agent_id: AgentID,
checkpoint_id: Option<String>,
session: AgentSessionListItem,
checkpoint: AgentCheckpointListItem,
input: Option<AgentInput>,
short_circuit_actions: bool,
interactive: bool,
) -> Result<Uuid, String> {
let (agent_id, session, checkpoint) = match checkpoint_id {
Some(checkpoint_id) => {
let checkpoint_uuid = Uuid::parse_str(&checkpoint_id).map_err(|_| {
format!(
"Invalid checkpoint ID '{}' - must be a valid UUID",
checkpoint_id
)
})?;

let output = client.get_agent_checkpoint(checkpoint_uuid).await?;

(
output.output.get_agent_id(),
output.session,
output.checkpoint,
)
}
None => {
let session = client
.create_agent_session(
agent_id.clone(),
AgentSessionVisibility::Private,
input.clone(),
)
.await?;

let checkpoint = session
.checkpoints
.first()
.ok_or("No checkpoint found in new session")?
.clone();

(agent_id, session.into(), checkpoint)
}
};

let print = setup_output_handler(config, session.id.to_string()).await?;

let mut input = RunAgentInput {
Expand Down
24 changes: 22 additions & 2 deletions src/commands/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use agent::{run_agent, AgentCommands};
use agent::{get_or_create_session, run_agent, AgentCommands};
use clap::Subcommand;
use flow::{clone, get_flow_ref, push, sync};
use termimad::MadSkin;
Expand Down Expand Up @@ -360,6 +360,21 @@ impl Commands {
if path_map.is_empty() {
return Err("No configurations found to apply".into());
}

let config_clone = config.clone();
let client_clone = Client::new(&config_clone).map_err(|e| e.to_string())?;
let flow_ref_clone = flow_ref.clone();
let dir_clone = dir.clone();
tokio::spawn(async move {
flow::sync(
&config_clone,
&client_clone,
&flow_ref_clone,
dir_clone.as_deref(),
)
.await
});

let agent_id = AgentID::KevinV1;

let agent_input = match provisioner {
Expand All @@ -377,11 +392,16 @@ impl Commands {
}
}?;

let (agent_id, session, checkpoint) =
get_or_create_session(&client, agent_id, None, Some(agent_input.clone()))
.await?;

let checkpoint_id = run_agent(
&config,
&client,
agent_id,
None,
session,
checkpoint,
Some(agent_input),
true,
true,
Expand Down

0 comments on commit 1ddbc19

Please sign in to comment.