Skip to content

Commit

Permalink
Use a tokio::task::LocalSet.
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubadamw committed Nov 3, 2024
1 parent f66b357 commit f7f4fd9
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 42 deletions.
97 changes: 56 additions & 41 deletions src/worker/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ struct ActionInput<T> {
#[allow(clippy::too_many_arguments)]
async fn handle_start_step_run(
action_function_task_join_set: &mut tokio::task::JoinSet<crate::InternalResult<()>>,
local_set: &tokio::task::LocalSet,
abort_handles: &mut HashMap<String, tokio::task::AbortHandle>,
dispatcher: &mut DispatcherClient<
tonic::service::interceptor::InterceptedService<
Expand Down Expand Up @@ -103,43 +104,56 @@ async fn handle_start_step_run(

let worker_id = worker_id.to_string();
let step_run_id = action.step_run_id.clone();
let abort_handle = action_function_task_join_set.spawn_local(async move {
let context = Context::new(
input.parents,
workflow_run_id,
workflow_step_run_id,
workflow_service_client,
data,
);
let action_event = match action_callable(context, input.input).catch_unwind().await {
Ok(Ok(output_value)) => step_action_event(
&worker_id,
&action,
StepActionEventType::StepEventTypeCompleted,
serde_json::to_string(&output_value).expect("must succeed"),
),
Ok(Err(error)) => step_action_event(
&worker_id,
&action,
StepActionEventType::StepEventTypeFailed,
error.to_string(),
),
Err(_) => step_action_event(
&worker_id,
&action,
StepActionEventType::StepEventTypeFailed,
"action panicked".to_owned(),
),
};

dispatcher
.send_step_action_event(action_event)
.await
.map_err(crate::InternalError::CouldNotSendStepStatus)?
.into_inner();

Ok(())
});
let abort_handle = action_function_task_join_set.spawn_local_on(
async move {
let context = Context::new(
input.parents,
workflow_run_id,
workflow_step_run_id,
workflow_service_client,
data,
);
let action_event = match action_callable(context, input.input).catch_unwind().await {
Ok(Ok(output_value)) => step_action_event(
&worker_id,
&action,
StepActionEventType::StepEventTypeCompleted,
serde_json::to_string(&output_value).expect("must succeed"),
),
Ok(Err(error)) => step_action_event(
&worker_id,
&action,
StepActionEventType::StepEventTypeFailed,
error.to_string(),
),
Err(error) => {
let message = error
.downcast_ref::<&str>()
.map(|value| value.to_string())
.or_else(|| error.downcast_ref::<String>().map(|value| value.clone()));
step_action_event(
&worker_id,
&action,
StepActionEventType::StepEventTypeFailed,
message.unwrap_or_else(|| {
String::from(
"task panicked with a payload that was not a `&str` nor a `String`",
)
}),
)
}
};

dispatcher
.send_step_action_event(action_event)
.await
.map_err(crate::InternalError::CouldNotSendStepStatus)?
.into_inner();

Ok(())
},
local_set,
);
abort_handles.insert(step_run_id, abort_handle);

Ok(())
Expand Down Expand Up @@ -182,14 +196,15 @@ pub(crate) async fn run(
listener_v2_timeout: Option<u64>,
mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>,
data: Arc<DataMap>,
) -> crate::InternalResult<()> {
) -> crate::InternalResult<tokio::task::LocalSet> {
use futures_util::StreamExt;

let mut retries: usize = 0;
let mut listen_strategy = ListenStrategy::V2;

let connection_attempt = tokio::time::Instant::now();

let local_set = tokio::task::LocalSet::new();
let mut abort_handles = HashMap::new();

'main_loop: loop {
Expand Down Expand Up @@ -252,7 +267,7 @@ pub(crate) async fn run(
let action = match result {
Err(status) => match status.code() {
tonic::Code::Cancelled => {
return Ok(());
return Ok(local_set);
}
tonic::Code::DeadlineExceeded => {
continue 'main_loop;
Expand All @@ -279,7 +294,7 @@ pub(crate) async fn run(

match action_type {
ActionType::StartStepRun => {
handle_start_step_run(action_function_task_join_set, &mut abort_handles, &mut dispatcher, workflow_service_client.clone(), namespace, worker_id, &workflows, action, data.clone()).await?;
handle_start_step_run(action_function_task_join_set, &local_set, &mut abort_handles, &mut dispatcher, workflow_service_client.clone(), namespace, worker_id, &workflows, action, data.clone()).await?;
}
ActionType::CancelStepRun => {
handle_cancel_step_run(&mut abort_handles, action).await?;
Expand All @@ -298,5 +313,5 @@ pub(crate) async fn run(
}
}

Ok(())
Ok(local_set)
}
3 changes: 2 additions & 1 deletion src/worker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,13 @@ impl Worker<'_> {
let mut action_function_task_join_set =
tokio::task::JoinSet::<crate::InternalResult<()>>::new();

futures_util::try_join! {
let (_, local_set) = futures_util::try_join! {
heartbeat::run(dispatcher.clone(), &worker_id, heartbeat_interrupt_receiver),
listener::run(&mut action_function_task_join_set, dispatcher, workflow_service_client, namespace, &worker_id, workflows, *listener_v2_timeout, listening_interrupt_receiver, data)
}?;

action_function_task_join_set.shutdown().await;
local_set.await;

Ok(())
}
Expand Down

0 comments on commit f7f4fd9

Please sign in to comment.