From 687127f362d50459e7ef8c228872de00428a0cf2 Mon Sep 17 00:00:00 2001 From: Jakub Wieczorek Date: Sun, 3 Nov 2024 21:42:10 +0100 Subject: [PATCH] Use `tokio_util::task::LocalPoolHandle`. --- Cargo.lock | 352 ++++++++++++++++++++++++++++++++++++ Cargo.toml | 3 + examples/non_send_future.rs | 39 ++++ src/step_function.rs | 11 +- src/worker/listener.rs | 118 ++++++------ src/worker/mod.rs | 10 +- src/workflow.rs | 4 +- 7 files changed, 462 insertions(+), 75 deletions(-) create mode 100644 examples/non_send_future.rs diff --git a/Cargo.lock b/Cargo.lock index 2278a50..f4570f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -322,6 +322,15 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "envy" version = "0.4.2" @@ -365,6 +374,30 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "futures" version = "0.3.31" @@ -510,6 +543,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.0" @@ -528,8 +567,10 @@ dependencies = [ "futures-util", "http", "jsonwebtoken", + "num_cpus", "prost", "prost-types", + "reqwest", "rstest", "secrecy", "serde", @@ -537,6 +578,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-util", "tonic", "tonic-build", "tracing", @@ -623,6 +665,23 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "hyper-timeout" version = "0.5.1" @@ -636,6 +695,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.9" @@ -661,6 +736,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -681,6 +766,12 @@ dependencies = [ "hashbrown 0.15.0", ] +[[package]] +name = "ipnet" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" + [[package]] name = "itertools" version = "0.13.0" @@ -808,6 +899,23 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nix" version = "0.29.0" @@ -864,6 +972,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.36.5" @@ -879,12 +997,50 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "openssl" +version = "0.10.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-sys" +version = "0.9.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "overload" version = "0.1.1" @@ -972,6 +1128,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + [[package]] name = "powerfmt" version = "0.2.0" @@ -1166,6 +1328,49 @@ version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" +[[package]] +name = "reqwest" +version = "0.12.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "system-configuration", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "windows-registry", +] + [[package]] name = "ring" version = "0.17.8" @@ -1391,6 +1596,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1492,6 +1709,30 @@ name = "sync_wrapper" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] [[package]] name = "tempfile" @@ -1567,6 +1808,21 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.40.0" @@ -1595,6 +1851,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.0" @@ -1626,6 +1892,8 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", ] @@ -1807,18 +2075,44 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "unicode-bidi" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" + [[package]] name = "unicode-ident" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "ustr" version = "1.1.0" @@ -1838,6 +2132,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -1885,6 +2185,18 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.95" @@ -1914,6 +2226,16 @@ version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +[[package]] +name = "web-sys" +version = "0.3.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1936,6 +2258,36 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result", + "windows-strings", + "windows-targets", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index d1c733e..445c08e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ envy = "0.4" futures-util = "0.3" http = "1" jsonwebtoken = "9" +num_cpus = "1" prost = "0.13" prost-types = "0.13" secrecy = { version = "0.10", features = ["serde"] } @@ -18,6 +19,7 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = "1" tokio = { version = "1", features = ["fs", "macros", "rt-multi-thread", "sync"] } +tokio-util = { version = "0.7", default-features = false, features = ["rt"] } tonic = { version = "0.12", features = ["tls", "tls-native-roots"] } tracing = "0.1" ustr = { version = "1", features = ["serde"] } @@ -28,5 +30,6 @@ tempfile = "3" [dev-dependencies] dotenv = "0.15" +reqwest = "0.12" rstest = "0.23" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/non_send_future.rs b/examples/non_send_future.rs new file mode 100644 index 0000000..07d0fd9 --- /dev/null +++ b/examples/non_send_future.rs @@ -0,0 +1,39 @@ +use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder}; + +#[derive(serde::Serialize, serde::Deserialize)] +struct HelloOutput { + text: String, +} + +async fn execute_hello(_context: Context, _: serde_json::Value) -> anyhow::Result { + let text = reqwest::get("https://example.org").await?.text().await?; + Ok(HelloOutput { text }) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + dotenv::dotenv().ok(); + tracing_subscriber::fmt() + .with_target(false) + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("hatchet_sdk=debug".parse()?), + ) + .init(); + + let client = Client::new()?; + let mut worker = client.worker("example_spawn_workflow").build(); + worker.register_workflow( + WorkflowBuilder::default() + .name("hello-panic") + .step( + StepBuilder::default() + .name("hello") + .function(&execute_hello) + .build()?, + ) + .build()?, + ); + worker.start().await?; + Ok(()) +} diff --git a/src/step_function.rs b/src/step_function.rs index 4fd2827..c5fe77c 100644 --- a/src/step_function.rs +++ b/src/step_function.rs @@ -105,8 +105,9 @@ impl Context { } pub(crate) type StepFunction = dyn Fn( - Context, - serde_json::Value, -) -> std::panic::AssertUnwindSafe< - futures_util::future::LocalBoxFuture<'static, anyhow::Result>, ->; + Context, + serde_json::Value, + ) -> std::panic::AssertUnwindSafe< + futures_util::future::LocalBoxFuture<'static, anyhow::Result>, + > + Send + + Sync; diff --git a/src/worker/listener.rs b/src/worker/listener.rs index bd780b1..a55aa6c 100644 --- a/src/worker/listener.rs +++ b/src/worker/listener.rs @@ -50,9 +50,8 @@ struct ActionInput { #[allow(clippy::too_many_arguments)] async fn handle_start_step_run( - action_function_task_join_set: &mut tokio::task::JoinSet>, - local_set: &tokio::task::LocalSet, - abort_handles: &mut HashMap, + local_pool_handle: &tokio_util::task::LocalPoolHandle, + join_handles: &mut HashMap>>, dispatcher: &mut DispatcherClient< tonic::service::interceptor::InterceptedService< tonic::transport::Channel, @@ -104,67 +103,65 @@ async fn handle_start_step_run( let worker_id = worker_id.to_string(); let step_run_id = action.step_run_id.clone(); - let abort_handle = action_function_task_join_set.spawn_local_on( - async move { - let context = Context::new( - input.parents, - workflow_run_id, - workflow_step_run_id, - workflow_service_client, - data, - ); - let action_event = match action_callable(context, input.input).catch_unwind().await { - Ok(Ok(output_value)) => step_action_event( - &worker_id, - &action, - StepActionEventType::StepEventTypeCompleted, - serde_json::to_string(&output_value).expect("must succeed"), - ), - Ok(Err(error)) => step_action_event( + let join_handle = local_pool_handle.spawn_pinned(|| async move { + let context = Context::new( + input.parents, + workflow_run_id, + workflow_step_run_id, + workflow_service_client, + data, + ); + let action_event = match action_callable(context, input.input).catch_unwind().await { + Ok(Ok(output_value)) => step_action_event( + &worker_id, + &action, + StepActionEventType::StepEventTypeCompleted, + serde_json::to_string(&output_value).expect("must succeed"), + ), + Ok(Err(error)) => step_action_event( + &worker_id, + &action, + StepActionEventType::StepEventTypeFailed, + error.to_string(), + ), + Err(error) => { + let message = error + .downcast_ref::<&str>() + .map(|value| value.to_string()) + .or_else(|| error.downcast_ref::().map(|value| value.clone())); + step_action_event( &worker_id, &action, StepActionEventType::StepEventTypeFailed, - error.to_string(), - ), - Err(error) => { - let message = error - .downcast_ref::<&str>() - .map(|value| value.to_string()) - .or_else(|| error.downcast_ref::().map(|value| value.clone())); - step_action_event( - &worker_id, - &action, - StepActionEventType::StepEventTypeFailed, - message.unwrap_or_else(|| { - String::from( - "task panicked with a payload that was not a `&str` nor a `String`", - ) - }), - ) - } - }; + message.unwrap_or_else(|| { + String::from( + "task panicked with a payload that was not a `&str` nor a `String`", + ) + }), + ) + } + }; + + dispatcher + .send_step_action_event(action_event) + .await + .map_err(crate::InternalError::CouldNotSendStepStatus)? + .into_inner(); - dispatcher - .send_step_action_event(action_event) - .await - .map_err(crate::InternalError::CouldNotSendStepStatus)? - .into_inner(); + Ok(()) + }); - Ok(()) - }, - local_set, - ); - abort_handles.insert(step_run_id, abort_handle); + join_handles.insert(step_run_id, join_handle); Ok(()) } async fn handle_cancel_step_run( - abort_handles: &mut HashMap, + join_handles: &mut HashMap>>, action: AssignedAction, ) -> crate::InternalResult<()> { - if let Some(abort_handle) = abort_handles.remove(&action.step_run_id) { - abort_handle.abort(); + if let Some(join_handle) = join_handles.remove(&action.step_run_id) { + join_handle.abort(); } else { warn!( "Could not find the abort handle for the workflow run ID: {}", @@ -177,7 +174,7 @@ async fn handle_cancel_step_run( #[allow(clippy::too_many_arguments)] pub(crate) async fn run( - action_function_task_join_set: &mut tokio::task::JoinSet>, + local_pool_handle: &tokio_util::task::LocalPoolHandle, mut dispatcher: DispatcherClient< tonic::service::interceptor::InterceptedService< tonic::transport::Channel, @@ -196,16 +193,15 @@ pub(crate) async fn run( listener_v2_timeout: Option, mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>, data: Arc, -) -> crate::InternalResult { +) -> crate::InternalResult<()> { use futures_util::StreamExt; let mut retries: usize = 0; let mut listen_strategy = ListenStrategy::V2; let connection_attempt = tokio::time::Instant::now(); - - let local_set = tokio::task::LocalSet::new(); - let mut abort_handles = HashMap::new(); + let mut join_handles: HashMap>> = + Default::default(); 'main_loop: loop { info!("Listening…"); @@ -267,7 +263,7 @@ pub(crate) async fn run( let action = match result { Err(status) => match status.code() { tonic::Code::Cancelled => { - return Ok(local_set); + return Ok(()); } tonic::Code::DeadlineExceeded => { continue 'main_loop; @@ -294,10 +290,10 @@ pub(crate) async fn run( match action_type { ActionType::StartStepRun => { - handle_start_step_run(action_function_task_join_set, &local_set, &mut abort_handles, &mut dispatcher, workflow_service_client.clone(), namespace, worker_id, &workflows, action, data.clone()).await?; + handle_start_step_run(local_pool_handle, &mut join_handles, &mut dispatcher, workflow_service_client.clone(), namespace, worker_id, &workflows, action, data.clone()).await?; } ActionType::CancelStepRun => { - handle_cancel_step_run(&mut abort_handles, action).await?; + handle_cancel_step_run(&mut join_handles, action).await?; } ActionType::StartGetGroupKey => { todo!() @@ -313,5 +309,5 @@ pub(crate) async fn run( } } - Ok(local_set) + Ok(()) } diff --git a/src/worker/mod.rs b/src/worker/mod.rs index cf092f8..d39f10d 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -323,17 +323,13 @@ impl Worker<'_> { .map_err(crate::InternalError::CouldNotRegisterWorker)? .into_inner(); - let mut action_function_task_join_set = - tokio::task::JoinSet::>::new(); + let local_pool_handle = tokio_util::task::LocalPoolHandle::new(num_cpus::get()); - let (_, local_set) = futures_util::try_join! { + futures_util::try_join! { heartbeat::run(dispatcher.clone(), &worker_id, heartbeat_interrupt_receiver), - listener::run(&mut action_function_task_join_set, dispatcher, workflow_service_client, namespace, &worker_id, workflows, *listener_v2_timeout, listening_interrupt_receiver, data) + listener::run(&local_pool_handle, dispatcher, workflow_service_client, namespace, &worker_id, workflows, *listener_v2_timeout, listening_interrupt_receiver, data) }?; - action_function_task_join_set.shutdown().await; - local_set.await; - Ok(()) } } diff --git a/src/workflow.rs b/src/workflow.rs index 56f3863..87f18a9 100644 --- a/src/workflow.rs +++ b/src/workflow.rs @@ -41,7 +41,7 @@ where I: serde::de::DeserializeOwned, O: serde::Serialize, Fut: std::future::Future> + 'static, - F: Fn(Context, I) -> Fut, + F: Fn(Context, I) -> Fut + Send + Sync, { fn to_step_function(self) -> Arc { use futures_util::FutureExt; @@ -60,7 +60,7 @@ where I: serde::de::DeserializeOwned, O: serde::ser::Serialize, Fut: std::future::Future> + 'static, - F: Fn(I) -> Fut, + F: Fn(I) -> Fut + Send + Sync, { fn to_step_function(self) -> Arc { use futures_util::FutureExt;