diff --git a/src/worker/listener.rs b/src/worker/listener.rs index 59f2385..bd780b1 100644 --- a/src/worker/listener.rs +++ b/src/worker/listener.rs @@ -51,6 +51,7 @@ struct ActionInput { #[allow(clippy::too_many_arguments)] async fn handle_start_step_run( action_function_task_join_set: &mut tokio::task::JoinSet>, + local_set: &tokio::task::LocalSet, abort_handles: &mut HashMap, dispatcher: &mut DispatcherClient< tonic::service::interceptor::InterceptedService< @@ -103,43 +104,56 @@ async fn handle_start_step_run( let worker_id = worker_id.to_string(); let step_run_id = action.step_run_id.clone(); - let abort_handle = action_function_task_join_set.spawn_local(async move { - let context = Context::new( - input.parents, - workflow_run_id, - workflow_step_run_id, - workflow_service_client, - data, - ); - let action_event = match action_callable(context, input.input).catch_unwind().await { - Ok(Ok(output_value)) => step_action_event( - &worker_id, - &action, - StepActionEventType::StepEventTypeCompleted, - serde_json::to_string(&output_value).expect("must succeed"), - ), - Ok(Err(error)) => step_action_event( - &worker_id, - &action, - StepActionEventType::StepEventTypeFailed, - error.to_string(), - ), - Err(_) => step_action_event( - &worker_id, - &action, - StepActionEventType::StepEventTypeFailed, - "action panicked".to_owned(), - ), - }; - - dispatcher - .send_step_action_event(action_event) - .await - .map_err(crate::InternalError::CouldNotSendStepStatus)? - .into_inner(); - - Ok(()) - }); + let abort_handle = action_function_task_join_set.spawn_local_on( + async move { + let context = Context::new( + input.parents, + workflow_run_id, + workflow_step_run_id, + workflow_service_client, + data, + ); + let action_event = match action_callable(context, input.input).catch_unwind().await { + Ok(Ok(output_value)) => step_action_event( + &worker_id, + &action, + StepActionEventType::StepEventTypeCompleted, + serde_json::to_string(&output_value).expect("must succeed"), + ), + Ok(Err(error)) => step_action_event( + &worker_id, + &action, + StepActionEventType::StepEventTypeFailed, + error.to_string(), + ), + Err(error) => { + let message = error + .downcast_ref::<&str>() + .map(|value| value.to_string()) + .or_else(|| error.downcast_ref::().map(|value| value.clone())); + step_action_event( + &worker_id, + &action, + StepActionEventType::StepEventTypeFailed, + message.unwrap_or_else(|| { + String::from( + "task panicked with a payload that was not a `&str` nor a `String`", + ) + }), + ) + } + }; + + dispatcher + .send_step_action_event(action_event) + .await + .map_err(crate::InternalError::CouldNotSendStepStatus)? + .into_inner(); + + Ok(()) + }, + local_set, + ); abort_handles.insert(step_run_id, abort_handle); Ok(()) @@ -182,7 +196,7 @@ pub(crate) async fn run( listener_v2_timeout: Option, mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>, data: Arc, -) -> crate::InternalResult<()> { +) -> crate::InternalResult { use futures_util::StreamExt; let mut retries: usize = 0; @@ -190,6 +204,7 @@ pub(crate) async fn run( let connection_attempt = tokio::time::Instant::now(); + let local_set = tokio::task::LocalSet::new(); let mut abort_handles = HashMap::new(); 'main_loop: loop { @@ -252,7 +267,7 @@ pub(crate) async fn run( let action = match result { Err(status) => match status.code() { tonic::Code::Cancelled => { - return Ok(()); + return Ok(local_set); } tonic::Code::DeadlineExceeded => { continue 'main_loop; @@ -279,7 +294,7 @@ pub(crate) async fn run( match action_type { ActionType::StartStepRun => { - handle_start_step_run(action_function_task_join_set, &mut abort_handles, &mut dispatcher, workflow_service_client.clone(), namespace, worker_id, &workflows, action, data.clone()).await?; + handle_start_step_run(action_function_task_join_set, &local_set, &mut abort_handles, &mut dispatcher, workflow_service_client.clone(), namespace, worker_id, &workflows, action, data.clone()).await?; } ActionType::CancelStepRun => { handle_cancel_step_run(&mut abort_handles, action).await?; @@ -298,5 +313,5 @@ pub(crate) async fn run( } } - Ok(()) + Ok(local_set) } diff --git a/src/worker/mod.rs b/src/worker/mod.rs index 17a3bcb..cf092f8 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -326,12 +326,13 @@ impl Worker<'_> { let mut action_function_task_join_set = tokio::task::JoinSet::>::new(); - futures_util::try_join! { + let (_, local_set) = futures_util::try_join! { heartbeat::run(dispatcher.clone(), &worker_id, heartbeat_interrupt_receiver), listener::run(&mut action_function_task_join_set, dispatcher, workflow_service_client, namespace, &worker_id, workflows, *listener_v2_timeout, listening_interrupt_receiver, data) }?; action_function_task_join_set.shutdown().await; + local_set.await; Ok(()) }