diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 862f5edd8b..e147d11856 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -102,6 +102,16 @@ jobs: - name: Run doctests run: cargo test --doc + - name: Install valgrind + if: ${{ matrix.os == 'ubuntu-latest' }} + run: sudo apt-get install -y valgrind + shell: bash + + - name: Run memory leaks check + if: ${{ matrix.os == 'ubuntu-latest' }} + run: ci/valgrind-check/run.sh + shell: bash + # NOTE: In GitHub repository settings, the "Require status checks to pass # before merging" branch protection rule ensures that commits are only merged # from branches where specific status checks have passed. These checks are diff --git a/.gitignore b/.gitignore index 695d0464b1..105dae1aa7 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,5 @@ .vscode cargo-timing*.html + +ci/valgrind-check/*.log diff --git a/Cargo.toml b/Cargo.toml index 33dd067d63..da99cb1fdc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ members = [ "commons/zenoh-result", "commons/zenoh-shm", "commons/zenoh-sync", + "commons/zenoh-task", "commons/zenoh-util", "commons/zenoh-runtime", "examples", @@ -51,7 +52,7 @@ members = [ "zenoh-ext", "zenohd", ] -exclude = ["ci/nostd-check"] +exclude = ["ci/nostd-check", "ci/valgrind-check"] [workspace.package] rust-version = "1.66.1" @@ -197,6 +198,7 @@ zenoh-link = { version = "0.11.0-dev", path = "io/zenoh-link" } zenoh-link-commons = { version = "0.11.0-dev", path = "io/zenoh-link-commons" } zenoh = { version = "0.11.0-dev", path = "zenoh", default-features = false } zenoh-runtime = { version = "0.11.0-dev", path = "commons/zenoh-runtime" } +zenoh-task = { version = "0.11.0-dev", path = "commons/zenoh-task" } [profile.dev] debug = true @@ -215,4 +217,4 @@ debug = false # If you want debug symbol in release mode, set the env variab lto = "fat" codegen-units = 1 opt-level = 3 -panic = "abort" +panic = "abort" \ No newline at end of file diff --git a/ci/valgrind-check/Cargo.toml b/ci/valgrind-check/Cargo.toml new file mode 100644 index 0000000000..cf6f6a844b --- /dev/null +++ b/ci/valgrind-check/Cargo.toml @@ -0,0 +1,37 @@ +# +# Copyright (c) 2024 ZettaScale Technology +# +# This program and the accompanying materials are made available under the +# terms of the Eclipse Public License 2.0 which is available at +# http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +# which is available at https://www.apache.org/licenses/LICENSE-2.0. +# +# SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +# +# Contributors: +# ZettaScale Zenoh Team, +# +[package] +name = "valgrind-check" +version = "0.1.0" +repository = "https://github.com/eclipse-zenoh/zenoh" +homepage = "http://zenoh.io" +license = "EPL-2.0 OR Apache-2.0" +edition = "2021" +categories = ["network-programming"] +description = "Internal crate for zenoh." 
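The harness binaries this package builds (shown below) all follow one idiom that the CI step depends on: a `ZRuntimePoolGuard` is created first in `main` so that it is dropped last, forcing the static runtime pool to be freed before valgrind takes its final snapshot. A minimal sketch of that idiom, assuming the zenoh 0.11 async prelude used by the harness binaries below:

```rust
// Sketch of a leak-check harness binary (hypothetical; the real binaries,
// z_pub_sub and z_queryable_get, follow below). The guard is created first
// so it is dropped last: its Drop force-drops the static ZRUNTIME_POOL,
// so valgrind sees no still-reachable runtime allocations.
use zenoh::prelude::r#async::*;

#[tokio::main]
async fn main() {
    let _guard = zenoh_runtime::ZRuntimePoolGuard;
    let session = zenoh::open(zenoh::config::Config::default())
        .res()
        .await
        .unwrap();
    session.close().res().await.unwrap();
    // _guard dropped here, after the session: ZRuntimePool::drop joins all runtimes
}
```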
+
+[dependencies]
+tokio = { version = "1.35.1", features = ["rt-multi-thread", "time", "io-std"] }
+env_logger = "0.11.0"
+futures = "0.3.25"
+zenoh = { path = "../../zenoh/" }
+zenoh-runtime = { path = "../../commons/zenoh-runtime/" }
+
+[[bin]]
+name = "pub_sub"
+path = "src/pub_sub/bin/z_pub_sub.rs"
+
+[[bin]]
+name = "queryable_get"
+path = "src/queryable_get/bin/z_queryable_get.rs"
diff --git a/ci/valgrind-check/run.sh b/ci/valgrind-check/run.sh
new file mode 100755
index 0000000000..7e2a7dd1a8
--- /dev/null
+++ b/ci/valgrind-check/run.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+set -e
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+function check_leaks {
+  echo "Checking $1 for memory leaks"
+  valgrind --leak-check=full --num-callers=50 --log-file="$SCRIPT_DIR/$1_leaks.log" $SCRIPT_DIR/target/debug/$1
+  num_leaks=$(grep 'ERROR SUMMARY: [0-9]+' -Eo "$SCRIPT_DIR/$1_leaks.log" | grep '[0-9]+' -Eo)
+  echo "Detected $num_leaks memory leaks"
+  if (( num_leaks == 0 ))
+  then
+    return 0
+  else
+    cat "$SCRIPT_DIR/$1_leaks.log"
+    return 1
+  fi
+}
+
+cargo build --manifest-path=$SCRIPT_DIR/Cargo.toml
+check_leaks "queryable_get"
+check_leaks "pub_sub"
\ No newline at end of file
diff --git a/ci/valgrind-check/src/pub_sub/bin/z_pub_sub.rs b/ci/valgrind-check/src/pub_sub/bin/z_pub_sub.rs
new file mode 100644
index 0000000000..fac3437f39
--- /dev/null
+++ b/ci/valgrind-check/src/pub_sub/bin/z_pub_sub.rs
@@ -0,0 +1,58 @@
+//
+// Copyright (c) 2023 ZettaScale Technology
+//
+// This program and the accompanying materials are made available under the
+// terms of the Eclipse Public License 2.0 which is available at
+// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
+// which is available at https://www.apache.org/licenses/LICENSE-2.0.
+// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// +use std::time::Duration; +use zenoh::config::Config; +use zenoh::prelude::r#async::*; + +#[tokio::main] +async fn main() { + let _z = zenoh_runtime::ZRuntimePoolGuard; + env_logger::init(); + + let pub_key_expr = KeyExpr::try_from("test/valgrind/data").unwrap(); + let sub_key_expr = KeyExpr::try_from("test/valgrind/**").unwrap(); + + println!("Declaring Publisher on '{pub_key_expr}'..."); + let pub_session = zenoh::open(Config::default()).res().await.unwrap(); + let publisher = pub_session + .declare_publisher(&pub_key_expr) + .res() + .await + .unwrap(); + + println!("Declaring Subscriber on '{sub_key_expr}'..."); + let sub_session = zenoh::open(Config::default()).res().await.unwrap(); + let _subscriber = sub_session + .declare_subscriber(&sub_key_expr) + .callback(|sample| { + println!( + ">> [Subscriber] Received {} ('{}': '{}')", + sample.kind, + sample.key_expr.as_str(), + sample.value + ); + }) + .res() + .await + .unwrap(); + + for idx in 0..5 { + tokio::time::sleep(Duration::from_secs(1)).await; + let buf = format!("[{idx:4}] data"); + println!("Putting Data ('{}': '{}')...", &pub_key_expr, buf); + publisher.put(buf).res().await.unwrap(); + } + + tokio::time::sleep(Duration::from_secs(1)).await; +} diff --git a/ci/valgrind-check/src/queryable_get/bin/z_queryable_get.rs b/ci/valgrind-check/src/queryable_get/bin/z_queryable_get.rs new file mode 100644 index 0000000000..102b6a036c --- /dev/null +++ b/ci/valgrind-check/src/queryable_get/bin/z_queryable_get.rs @@ -0,0 +1,71 @@ +// +// Copyright (c) 2023 ZettaScale Technology +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. 
+// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// +use std::convert::TryFrom; +use std::time::Duration; +use zenoh::config::Config; +use zenoh::prelude::r#async::*; + +#[tokio::main] +async fn main() { + let _z = zenoh_runtime::ZRuntimePoolGuard; + env_logger::init(); + + let queryable_key_expr = KeyExpr::try_from("test/valgrind/data").unwrap(); + let get_selector = Selector::try_from("test/valgrind/**").unwrap(); + + println!("Declaring Queryable on '{queryable_key_expr}'..."); + let queryable_session = zenoh::open(Config::default()).res().await.unwrap(); + let _queryable = queryable_session + .declare_queryable(&queryable_key_expr.clone()) + .callback(move |query| { + println!(">> Handling query '{}'", query.selector()); + let reply = Ok(Sample::new( + queryable_key_expr.clone(), + query.value().unwrap().clone(), + )); + zenoh_runtime::ZRuntime::Application.block_in_place( + async move { query.reply(reply).res().await.unwrap(); } + ); + }) + .complete(true) + .res() + .await + .unwrap(); + + println!("Declaring Get session for '{get_selector}'..."); + let get_session = zenoh::open(Config::default()).res().await.unwrap(); + + for idx in 0..5 { + tokio::time::sleep(Duration::from_secs(1)).await; + println!("Sending Query '{get_selector}'..."); + let replies = get_session + .get(&get_selector) + .with_value(idx) + .target(QueryTarget::All) + .res() + .await + .unwrap(); + while let Ok(reply) = replies.recv_async().await { + match reply.sample { + Ok(sample) => println!( + ">> Received ('{}': '{}')", + sample.key_expr.as_str(), + sample.value, + ), + Err(err) => println!(">> Received (ERROR: '{}')", String::try_from(&err).unwrap()), + } + } + } + tokio::time::sleep(Duration::from_secs(1)).await; +} diff --git a/commons/zenoh-runtime/Cargo.toml b/commons/zenoh-runtime/Cargo.toml index b7aa15d634..e5bd64b8c5 100644 --- a/commons/zenoh-runtime/Cargo.toml +++ b/commons/zenoh-runtime/Cargo.toml @@ -13,6 +13,7 @@ description = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +futures = { workspace = true } lazy_static = { workspace = true } zenoh-result = { workspace = true, features = ["std"] } zenoh-collections = { workspace = true, features = ["std"] } diff --git a/commons/zenoh-runtime/src/lib.rs b/commons/zenoh-runtime/src/lib.rs index ac040af838..6b62fdf7b7 100644 --- a/commons/zenoh-runtime/src/lib.rs +++ b/commons/zenoh-runtime/src/lib.rs @@ -22,6 +22,7 @@ use std::{ atomic::{AtomicUsize, Ordering}, OnceLock, }, + time::Duration, }; use tokio::runtime::{Handle, Runtime, RuntimeFlavor}; use zenoh_collections::Properties; @@ -147,6 +148,41 @@ impl ZRuntimePool { } } +// If there are any blocking tasks spawned by ZRuntimes, the function will block until they return. +impl Drop for ZRuntimePool { + fn drop(&mut self) { + let handles: Vec<_> = self + .0 + .drain() + .map(|(name, mut rt)| { + std::thread::spawn(move || { + rt.take() + .unwrap_or_else(|| panic!("ZRuntime {name:?} failed to shutdown.")) + .shutdown_timeout(Duration::from_secs(1)) + }) + }) + .collect(); + + for hd in handles { + let _ = hd.join(); + } + } +} + +/// In order to prevent valgrind reporting memory leaks, +/// we use this guard to force drop ZRUNTIME_POOL since Rust does not drop static variables. 
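The `Drop` implementation above moves each runtime onto a dedicated thread before calling `shutdown_timeout`, because a tokio `Runtime` must not be shut down from within an async context. A standalone sketch of the same pattern (the `shutdown_all` helper is hypothetical):

```rust
// Standalone sketch of the shutdown pattern used by Drop for ZRuntimePool:
// each runtime is moved to its own thread, shut down with a bounded timeout,
// and the threads are joined so the caller blocks until shutdown completes.
use std::time::Duration;
use tokio::runtime::Runtime;

fn shutdown_all(runtimes: Vec<(String, Runtime)>) {
    let handles: Vec<_> = runtimes
        .into_iter()
        .map(|(name, rt)| {
            std::thread::spawn(move || {
                println!("shutting down {name}");
                // Waits up to 1s for blocking tasks, mirroring the Drop impl above.
                rt.shutdown_timeout(Duration::from_secs(1));
            })
        })
        .collect();
    for h in handles {
        let _ = h.join();
    }
}
```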
+#[doc(hidden)]
+pub struct ZRuntimePoolGuard;
+
+impl Drop for ZRuntimePoolGuard {
+    fn drop(&mut self) {
+        unsafe {
+            let ptr = &(*ZRUNTIME_POOL) as *const ZRuntimePool;
+            std::mem::drop(ptr.read());
+        }
+    }
+}
+
 #[derive(Debug, Copy, Clone)]
 pub struct ZRuntimeConfig {
     pub application_threads: usize,
diff --git a/commons/zenoh-task/Cargo.toml b/commons/zenoh-task/Cargo.toml
new file mode 100644
index 0000000000..bf52f13735
--- /dev/null
+++ b/commons/zenoh-task/Cargo.toml
@@ -0,0 +1,33 @@
+#
+# Copyright (c) 2024 ZettaScale Technology
+#
+# This program and the accompanying materials are made available under the
+# terms of the Eclipse Public License 2.0 which is available at
+# http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
+# which is available at https://www.apache.org/licenses/LICENSE-2.0.
+#
+# SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
+#
+# Contributors:
+#   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
+#
+[package]
+rust-version = { workspace = true }
+name = "zenoh-task"
+version = { workspace = true }
+repository = { workspace = true }
+homepage = { workspace = true }
+authors = { workspace = true }
+edition = { workspace = true }
+license = { workspace = true }
+categories = { workspace = true }
+description = "Internal crate for zenoh."
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+tokio = { workspace = true, features = ["default", "sync"] }
+futures = { workspace = true }
+log = { workspace = true }
+zenoh-core = { workspace = true }
+zenoh-runtime = { workspace = true }
+tokio-util = { workspace = true, features = ["rt"] }
\ No newline at end of file
diff --git a/commons/zenoh-task/src/lib.rs b/commons/zenoh-task/src/lib.rs
new file mode 100644
index 0000000000..7b305cee75
--- /dev/null
+++ b/commons/zenoh-task/src/lib.rs
@@ -0,0 +1,191 @@
+//
+// Copyright (c) 2024 ZettaScale Technology
+//
+// This program and the accompanying materials are made available under the
+// terms of the Eclipse Public License 2.0 which is available at
+// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
+// which is available at https://www.apache.org/licenses/LICENSE-2.0.
+//
+// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
+//
+// Contributors:
+//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
+//
+
+//! ⚠️ WARNING ⚠️
+//!
+//! This module is intended for Zenoh's internal use.
+//!
+//! [Click here for Zenoh's documentation](../zenoh/index.html)
+
+use futures::future::FutureExt;
+use std::future::Future;
+use std::time::Duration;
+use tokio::task::JoinHandle;
+use tokio_util::sync::CancellationToken;
+use tokio_util::task::TaskTracker;
+use zenoh_core::{ResolveFuture, SyncResolve};
+use zenoh_runtime::ZRuntime;
+
+#[derive(Clone)]
+pub struct TaskController {
+    tracker: TaskTracker,
+    token: CancellationToken,
+}
+
+impl Default for TaskController {
+    fn default() -> Self {
+        TaskController {
+            tracker: TaskTracker::new(),
+            token: CancellationToken::new(),
+        }
+    }
+}
+
+impl TaskController {
+    /// Spawns a task that can be later terminated by a call to [`TaskController::terminate_all()`].
+    /// Task output is ignored.
+    pub fn spawn_abortable<F, T>(&self, future: F) -> JoinHandle<()>
+    where
+        F: Future<Output = T> + Send + 'static,
+        T: Send + 'static,
+    {
+        let token = self.token.child_token();
+        let task = async move {
+            tokio::select! {
+                _ = token.cancelled() => {},
+                _ = future => {}
+            }
+        };
+        self.tracker.spawn(task)
+    }
+
+    /// Spawns a task using a specified runtime that can be later terminated by a call to [`TaskController::terminate_all()`].
+    /// Task output is ignored.
+    pub fn spawn_abortable_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
+    where
+        F: Future<Output = T> + Send + 'static,
+        T: Send + 'static,
+    {
+        let token = self.token.child_token();
+        let task = async move {
+            tokio::select! {
+                _ = token.cancelled() => {},
+                _ = future => {}
+            }
+        };
+        self.tracker.spawn_on(task, &rt)
+    }
+
+    pub fn get_cancellation_token(&self) -> CancellationToken {
+        self.token.child_token()
+    }
+
+    /// Spawns a task that can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
+    /// or that can run to completion in a finite amount of time.
+    /// It can be later terminated by a call to [`TaskController::terminate_all()`].
+    pub fn spawn<F, T>(&self, future: F) -> JoinHandle<()>
+    where
+        F: Future<Output = T> + Send + 'static,
+        T: Send + 'static,
+    {
+        self.tracker.spawn(future.map(|_f| ()))
+    }
+
+    /// Spawns a task that can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
+    /// or that can run to completion in a finite amount of time, using a specified runtime.
+    /// It can be later aborted by a call to [`TaskController::terminate_all()`].
+    pub fn spawn_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
+    where
+        F: Future<Output = T> + Send + 'static,
+        T: Send + 'static,
+    {
+        self.tracker.spawn_on(future.map(|_f| ()), &rt)
+    }
+
+    /// Attempts to terminate all previously spawned tasks.
+    /// The caller must ensure that all tasks spawned with [`TaskController::spawn()`]
+    /// or [`TaskController::spawn_with_rt()`] can yield in a finite amount of time, either because they will run to completion
+    /// or due to cancellation of a token acquired via [`TaskController::get_cancellation_token()`].
+    /// Tasks spawned with [`TaskController::spawn_abortable()`] or [`TaskController::spawn_abortable_with_rt()`] will be aborted (i.e. terminated upon the next await call).
+    /// The call blocks until all tasks yield or the timeout duration expires.
+    /// Returns 0 on success, or the number of non-terminated tasks otherwise.
+    pub fn terminate_all(&self, timeout: Duration) -> usize {
+        ResolveFuture::new(async move { self.terminate_all_async(timeout).await }).res_sync()
+    }
+
+    /// Async version of [`TaskController::terminate_all()`].
+    pub async fn terminate_all_async(&self, timeout: Duration) -> usize {
+        self.tracker.close();
+        self.token.cancel();
+        if tokio::time::timeout(timeout, self.tracker.wait())
+            .await
+            .is_err()
+        {
+            log::error!("Failed to terminate {} tasks", self.tracker.len());
+            return self.tracker.len();
+        }
+        0
+    }
+}
+
+pub struct TerminatableTask {
+    handle: JoinHandle<()>,
+    token: CancellationToken,
+}
+
+impl TerminatableTask {
+    pub fn create_cancellation_token() -> CancellationToken {
+        CancellationToken::new()
+    }
+
+    /// Spawns a task that can be later terminated by [`TerminatableTask::terminate()`].
+    /// Prior to the termination attempt, the specified cancellation token will be cancelled.
+    pub fn spawn<F, T>(rt: ZRuntime, future: F, token: CancellationToken) -> TerminatableTask
+    where
+        F: Future<Output = T> + Send + 'static,
+        T: Send + 'static,
+    {
+        TerminatableTask {
+            handle: rt.spawn(future.map(|_f| ())),
+            token,
+        }
+    }
+
+    /// Spawns a task that can be later aborted by [`TerminatableTask::terminate()`].
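Taken together, the two spawn families differ in who reacts to `terminate_all()`: `spawn_abortable*` tasks are raced against the controller's token inside the `select!` above, while plain `spawn*` tasks must watch a token themselves. A usage sketch, assuming `zenoh-task` is on the dependency path (names as introduced in this diff):

```rust
// Hypothetical usage of the TaskController API defined above: one task is
// aborted at its next await point, the other exits cooperatively.
use std::time::Duration;
use zenoh_task::TaskController;

#[tokio::main]
async fn main() {
    let tc = TaskController::default();

    // Aborted by terminate_all(): the select! inside spawn_abortable wins on cancellation.
    tc.spawn_abortable(async {
        loop {
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    });

    // Cooperative task: it must observe the token itself.
    let token = tc.get_cancellation_token();
    tc.spawn(async move { token.cancelled().await });

    // Returns the number of tasks that failed to yield within the timeout (0 here).
    let remaining = tc.terminate_all_async(Duration::from_secs(10)).await;
    assert_eq!(remaining, 0);
}
```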
+    pub fn spawn_abortable<F, T>(rt: ZRuntime, future: F) -> TerminatableTask
+    where
+        F: Future<Output = T> + Send + 'static,
+        T: Send + 'static,
+    {
+        let token = CancellationToken::new();
+        let token2 = token.clone();
+        let task = async move {
+            tokio::select! {
+                _ = token2.cancelled() => {},
+                _ = future => {}
+            }
+        };
+
+        TerminatableTask {
+            handle: rt.spawn(task),
+            token,
+        }
+    }
+
+    /// Attempts to terminate the task.
+    /// Returns true if the task completed or was aborted within the timeout duration, false otherwise.
+    pub fn terminate(self, timeout: Duration) -> bool {
+        ResolveFuture::new(async move { self.terminate_async(timeout).await }).res_sync()
+    }
+
+    /// Async version of [`TerminatableTask::terminate()`].
+    pub async fn terminate_async(self, timeout: Duration) -> bool {
+        self.token.cancel();
+        if tokio::time::timeout(timeout, self.handle).await.is_err() {
+            log::error!("Failed to terminate the task");
+            return false;
+        };
+        true
+    }
+}
diff --git a/io/zenoh-links/zenoh-link-unixpipe/src/unix/unicast.rs b/io/zenoh-links/zenoh-link-unixpipe/src/unix/unicast.rs
index 0a0aebe730..eb8ee05d87 100644
--- a/io/zenoh-links/zenoh-link-unixpipe/src/unix/unicast.rs
+++ b/io/zenoh-links/zenoh-link-unixpipe/src/unix/unicast.rs
@@ -30,8 +30,9 @@ use std::sync::Arc;
 use tokio::fs::remove_file;
 use tokio::io::unix::AsyncFd;
 use tokio::io::Interest;
+use tokio::task::JoinHandle;
 use tokio_util::sync::CancellationToken;
-use zenoh_core::{zasyncread, zasyncwrite};
+use zenoh_core::{zasyncread, zasyncwrite, ResolveFuture, SyncResolve};
 use zenoh_protocol::core::{EndPoint, Locator};
 use zenoh_runtime::ZRuntime;
 
@@ -285,6 +286,7 @@ async fn handle_incoming_connections(
 struct UnicastPipeListener {
     uplink_locator: Locator,
     token: CancellationToken,
+    handle: JoinHandle<()>,
 }
 impl UnicastPipeListener {
     async fn listen(endpoint: EndPoint, manager: Arc<NewLinkChannelSender>) -> ZResult<Self> {
@@ -300,7 +302,7 @@ impl UnicastPipeListener {
 
         // WARN: The spawn_blocking is mandatory verified by the ping/pong test
         // create listening task
-        tokio::task::spawn_blocking(move || {
+        let handle = tokio::task::spawn_blocking(move || {
             ZRuntime::Acceptor.block_on(async move {
                 loop {
                     tokio::select!
{ @@ -322,11 +324,13 @@ impl UnicastPipeListener { Ok(Self { uplink_locator: local, token, + handle, }) } fn stop_listening(self) { self.token.cancel(); + let _ = ResolveFuture::new(self.handle).res_sync(); } } diff --git a/io/zenoh-transport/Cargo.toml b/io/zenoh-transport/Cargo.toml index 6f18f7cc5c..5304a9fa17 100644 --- a/io/zenoh-transport/Cargo.toml +++ b/io/zenoh-transport/Cargo.toml @@ -83,6 +83,7 @@ zenoh-shm = { workspace = true, optional = true } zenoh-sync = { workspace = true } zenoh-util = { workspace = true } zenoh-runtime = { workspace = true } +zenoh-task = { workspace = true } diff --git a/io/zenoh-transport/src/manager.rs b/io/zenoh-transport/src/manager.rs index f16a68cfba..f97c126f8b 100644 --- a/io/zenoh-transport/src/manager.rs +++ b/io/zenoh-transport/src/manager.rs @@ -24,7 +24,6 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex as AsyncMutex; -use tokio_util::sync::CancellationToken; use zenoh_config::{Config, LinkRxConf, QueueConf, QueueSizeConf}; use zenoh_crypto::{BlockCipher, PseudoRng}; use zenoh_link::NewLinkChannelSender; @@ -34,6 +33,7 @@ use zenoh_protocol::{ VERSION, }; use zenoh_result::{bail, ZResult}; +use zenoh_task::TaskController; /// # Examples /// ``` @@ -335,7 +335,7 @@ pub struct TransportManager { pub(crate) new_unicast_link_sender: NewLinkChannelSender, #[cfg(feature = "stats")] pub(crate) stats: Arc, - pub(crate) token: CancellationToken, + pub(crate) task_controller: TaskController, } impl TransportManager { @@ -357,32 +357,27 @@ impl TransportManager { new_unicast_link_sender, #[cfg(feature = "stats")] stats: std::sync::Arc::new(crate::stats::TransportStats::default()), - token: CancellationToken::new(), + task_controller: TaskController::default(), }; // @TODO: this should be moved into the unicast module - zenoh_runtime::ZRuntime::Net.spawn({ - let this = this.clone(); - let token = this.token.clone(); - async move { - // while let Ok(link) = new_unicast_link_receiver.recv_async().await { - // this.handle_new_link_unicast(link).await; - // } - loop { - tokio::select! { - res = new_unicast_link_receiver.recv_async() => { - if let Ok(link) = res { - this.handle_new_link_unicast(link).await; - } - } - - _ = token.cancelled() => { - break; + let cancellation_token = this.task_controller.get_cancellation_token(); + this.task_controller + .spawn_with_rt(zenoh_runtime::ZRuntime::Net, { + let this = this.clone(); + async move { + loop { + tokio::select! 
{ + res = new_unicast_link_receiver.recv_async() => { + if let Ok(link) = res { + this.handle_new_link_unicast(link).await; + } + } + _ = cancellation_token.cancelled() => { break; } } } } - } - }); + }); this } @@ -402,10 +397,9 @@ impl TransportManager { pub async fn close(&self) { self.close_unicast().await; - // TODO: Check this - self.token.cancel(); - // WARN: depends on the auto-close of tokio runtime after dropped - // self.tx_executor.runtime.shutdown_background(); + self.task_controller + .terminate_all_async(Duration::from_secs(10)) + .await; } /*************************************/ diff --git a/io/zenoh-transport/src/multicast/transport.rs b/io/zenoh-transport/src/multicast/transport.rs index c647730390..2e7f54098d 100644 --- a/io/zenoh-transport/src/multicast/transport.rs +++ b/io/zenoh-transport/src/multicast/transport.rs @@ -39,6 +39,7 @@ use zenoh_protocol::{ transport::{close, Join}, }; use zenoh_result::{bail, ZResult}; +use zenoh_task::TaskController; // use zenoh_util::{Timed, TimedEvent, TimedHandle, Timer}; /*************************************/ @@ -82,8 +83,8 @@ pub(crate) struct TransportMulticastInner { pub(super) link: Arc>>, // The callback pub(super) callback: Arc>>>, - // token for safe cancellation - token: CancellationToken, + // Task controller for safe task cancellation + task_controller: TaskController, // Transport statistics #[cfg(feature = "stats")] pub(super) stats: Arc, @@ -115,7 +116,7 @@ impl TransportMulticastInner { locator: config.link.link.get_dst().to_owned(), link: Arc::new(RwLock::new(None)), callback: Arc::new(RwLock::new(None)), - token: CancellationToken::new(), + task_controller: TaskController::default(), #[cfg(feature = "stats")] stats, }; @@ -183,8 +184,9 @@ impl TransportMulticastInner { cb.closed(); } - // TODO(yuyuan): use CancellationToken to unify the termination with the above - self.token.cancel(); + self.task_controller + .terminate_all_async(Duration::from_secs(10)) + .await; Ok(()) } @@ -369,7 +371,7 @@ impl TransportMulticastInner { // TODO(yuyuan): refine the clone behaviors let is_active = Arc::new(AtomicBool::new(false)); let c_is_active = is_active.clone(); - let token = self.token.child_token(); + let token = self.task_controller.get_cancellation_token(); let c_token = token.clone(); let c_self = self.clone(); let c_locator = locator.clone(); @@ -389,8 +391,8 @@ impl TransportMulticastInner { let _ = c_self.del_peer(&c_locator, close::reason::EXPIRED); }; - // TODO(yuyuan): Put it into TaskTracker or store as JoinHandle - zenoh_runtime::ZRuntime::Acceptor.spawn(task); + self.task_controller + .spawn_with_rt(zenoh_runtime::ZRuntime::Acceptor, task); // TODO(yuyuan): Integrate the above async task into TransportMulticastPeer // Store the new peer diff --git a/io/zenoh-transport/src/unicast/manager.rs b/io/zenoh-transport/src/unicast/manager.rs index eaf25cd2a3..8a63f4f630 100644 --- a/io/zenoh-transport/src/unicast/manager.rs +++ b/io/zenoh-transport/src/unicast/manager.rs @@ -744,17 +744,18 @@ impl TransportManager { // Spawn a task to accept the link let c_manager = self.clone(); - zenoh_runtime::ZRuntime::Acceptor.spawn(async move { - if let Err(e) = tokio::time::timeout( - c_manager.config.unicast.accept_timeout, - super::establishment::accept::accept_link(link, &c_manager), - ) - .await - { - log::debug!("{}", e); - } - incoming_counter.fetch_sub(1, SeqCst); - }); + self.task_controller + .spawn_with_rt(zenoh_runtime::ZRuntime::Acceptor, async move { + if let Err(e) = tokio::time::timeout( + 
c_manager.config.unicast.accept_timeout, + super::establishment::accept::accept_link(link, &c_manager), + ) + .await + { + log::debug!("{}", e); + } + incoming_counter.fetch_sub(1, SeqCst); + }); } } diff --git a/zenoh-ext/Cargo.toml b/zenoh-ext/Cargo.toml index 372eaf234a..6a0488cb54 100644 --- a/zenoh-ext/Cargo.toml +++ b/zenoh-ext/Cargo.toml @@ -45,6 +45,7 @@ zenoh-result = { workspace = true } zenoh-sync = { workspace = true } zenoh-util = { workspace = true } zenoh-runtime = { workspace = true } +zenoh-task = { workspace = true } [dev-dependencies] clap = { workspace = true, features = ["derive"] } diff --git a/zenoh-ext/src/publication_cache.rs b/zenoh-ext/src/publication_cache.rs index c8c5679c91..eec398d592 100644 --- a/zenoh-ext/src/publication_cache.rs +++ b/zenoh-ext/src/publication_cache.rs @@ -11,16 +11,17 @@ // Contributors: // ZettaScale Zenoh Team, // -use flume::{bounded, Sender}; use std::collections::{HashMap, VecDeque}; use std::convert::TryInto; use std::future::Ready; +use std::time::Duration; use zenoh::prelude::r#async::*; use zenoh::queryable::{Query, Queryable}; use zenoh::subscriber::FlumeSubscriber; use zenoh::SessionRef; use zenoh_core::{AsyncResolve, Resolvable, SyncResolve}; use zenoh_result::{bail, ZResult}; +use zenoh_task::TerminatableTask; use zenoh_util::core::ResolveFuture; /// The builder of PublicationCache, allowing to configure it. @@ -110,7 +111,7 @@ impl<'a> AsyncResolve for PublicationCacheBuilder<'a, '_, '_> { pub struct PublicationCache<'a> { local_sub: FlumeSubscriber<'a>, _queryable: Queryable<'a, flume::Receiver>, - _stoptx: Sender, + task: TerminatableTask, } impl<'a> PublicationCache<'a> { @@ -166,58 +167,46 @@ impl<'a> PublicationCache<'a> { let history = conf.history; // TODO(yuyuan): use CancellationToken to manage it - let (stoptx, stoprx) = bounded::(1); - zenoh_runtime::ZRuntime::TX.spawn(async move { - let mut cache: HashMap> = - HashMap::with_capacity(resources_limit.unwrap_or(32)); - let limit = resources_limit.unwrap_or(usize::MAX); + let token = TerminatableTask::create_cancellation_token(); + let token2 = token.clone(); + let task = TerminatableTask::spawn( + zenoh_runtime::ZRuntime::TX, + async move { + let mut cache: HashMap> = + HashMap::with_capacity(resources_limit.unwrap_or(32)); + let limit = resources_limit.unwrap_or(usize::MAX); + loop { + tokio::select! { + // on publication received by the local subscriber, store it + sample = sub_recv.recv_async() => { + if let Ok(sample) = sample { + let queryable_key_expr: KeyExpr<'_> = if let Some(prefix) = &queryable_prefix { + prefix.join(&sample.key_expr).unwrap().into() + } else { + sample.key_expr.clone() + }; - loop { - tokio::select! 
{ - // on publication received by the local subscriber, store it - sample = sub_recv.recv_async() => { - if let Ok(sample) = sample { - let queryable_key_expr: KeyExpr<'_> = if let Some(prefix) = &queryable_prefix { - prefix.join(&sample.key_expr).unwrap().into() - } else { - sample.key_expr.clone() - }; - - if let Some(queue) = cache.get_mut(queryable_key_expr.as_keyexpr()) { - if queue.len() >= history { - queue.pop_front(); + if let Some(queue) = cache.get_mut(queryable_key_expr.as_keyexpr()) { + if queue.len() >= history { + queue.pop_front(); + } + queue.push_back(sample); + } else if cache.len() >= limit { + log::error!("PublicationCache on {}: resource_limit exceeded - can't cache publication for a new resource", + pub_key_expr); + } else { + let mut queue: VecDeque = VecDeque::new(); + queue.push_back(sample); + cache.insert(queryable_key_expr.into(), queue); } - queue.push_back(sample); - } else if cache.len() >= limit { - log::error!("PublicationCache on {}: resource_limit exceeded - can't cache publication for a new resource", - pub_key_expr); - } else { - let mut queue: VecDeque = VecDeque::new(); - queue.push_back(sample); - cache.insert(queryable_key_expr.into(), queue); } - } - }, + }, - // on query, reply with cach content - query = quer_recv.recv_async() => { - if let Ok(query) = query { - if !query.selector().key_expr.as_str().contains('*') { - if let Some(queue) = cache.get(query.selector().key_expr.as_keyexpr()) { - for sample in queue { - if let (Ok(Some(time_range)), Some(timestamp)) = (query.selector().time_range(), sample.timestamp) { - if !time_range.contains(timestamp.get_time().to_system_time()){ - continue; - } - } - if let Err(e) = query.reply(Ok(sample.clone())).res_async().await { - log::warn!("Error replying to query: {}", e); - } - } - } - } else { - for (key_expr, queue) in cache.iter() { - if query.selector().key_expr.intersects(unsafe{ keyexpr::from_str_unchecked(key_expr) }) { + // on query, reply with cach content + query = quer_recv.recv_async() => { + if let Ok(query) = query { + if !query.selector().key_expr.as_str().contains('*') { + if let Some(queue) = cache.get(query.selector().key_expr.as_keyexpr()) { for sample in queue { if let (Ok(Some(time_range)), Some(timestamp)) = (query.selector().time_range(), sample.timestamp) { if !time_range.contains(timestamp.get_time().to_system_time()){ @@ -229,21 +218,35 @@ impl<'a> PublicationCache<'a> { } } } + } else { + for (key_expr, queue) in cache.iter() { + if query.selector().key_expr.intersects(unsafe{ keyexpr::from_str_unchecked(key_expr) }) { + for sample in queue { + if let (Ok(Some(time_range)), Some(timestamp)) = (query.selector().time_range(), sample.timestamp) { + if !time_range.contains(timestamp.get_time().to_system_time()){ + continue; + } + } + if let Err(e) = query.reply(Ok(sample.clone())).res_async().await { + log::warn!("Error replying to query: {}", e); + } + } + } + } } } - } - }, - - // When stoptx is dropped, stop the task - _ = stoprx.recv_async() => return + }, + _ = token2.cancelled() => return + } } - } - }); + }, + token, + ); Ok(PublicationCache { local_sub, _queryable: queryable, - _stoptx: stoptx, + task, }) } @@ -254,11 +257,11 @@ impl<'a> PublicationCache<'a> { let PublicationCache { _queryable, local_sub, - _stoptx, + task, } = self; _queryable.undeclare().res_async().await?; local_sub.undeclare().res_async().await?; - drop(_stoptx); + task.terminate(Duration::from_secs(10)); Ok(()) }) } diff --git a/zenoh/Cargo.toml b/zenoh/Cargo.toml index 955e362bc7..144e5dbf72 100644 
--- a/zenoh/Cargo.toml +++ b/zenoh/Cargo.toml @@ -106,6 +106,7 @@ zenoh-sync = { workspace = true } zenoh-transport = { workspace = true } zenoh-util = { workspace = true } zenoh-runtime = { workspace = true } +zenoh-task = { workspace = true } [build-dependencies] rustc_version = { workspace = true } diff --git a/zenoh/src/net/primitives/mux.rs b/zenoh/src/net/primitives/mux.rs index 442c040624..5c473e8ad8 100644 --- a/zenoh/src/net/primitives/mux.rs +++ b/zenoh/src/net/primitives/mux.rs @@ -13,7 +13,7 @@ // use super::{EPrimitives, Primitives}; use crate::net::routing::{ - dispatcher::face::Face, + dispatcher::face::{Face, WeakFace}, interceptor::{InterceptorTrait, InterceptorsChain}, RoutingContext, }; @@ -25,7 +25,7 @@ use zenoh_transport::{multicast::TransportMulticast, unicast::TransportUnicast}; pub struct Mux { pub handler: TransportUnicast, - pub(crate) face: OnceLock, + pub(crate) face: OnceLock, pub(crate) interceptor: InterceptorsChain, } @@ -48,14 +48,14 @@ impl Primitives for Mux { }; if self.interceptor.interceptors.is_empty() { let _ = self.handler.schedule(msg); - } else if let Some(face) = self.face.get() { + } else if let Some(face) = self.face.get().and_then(|f| f.upgrade()) { let ctx = RoutingContext::new_out(msg, face.clone()); let prefix = ctx .wire_expr() .and_then(|we| (!we.has_suffix()).then(|| ctx.prefix())) .flatten() .cloned(); - let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(face)); + let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(&face)); if let Some(ctx) = self.interceptor.intercept(ctx, cache) { let _ = self.handler.schedule(ctx.msg); } @@ -72,14 +72,14 @@ impl Primitives for Mux { }; if self.interceptor.interceptors.is_empty() { let _ = self.handler.schedule(msg); - } else if let Some(face) = self.face.get() { + } else if let Some(face) = self.face.get().and_then(|f| f.upgrade()) { let ctx = RoutingContext::new_out(msg, face.clone()); let prefix = ctx .wire_expr() .and_then(|we| (!we.has_suffix()).then(|| ctx.prefix())) .flatten() .cloned(); - let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(face)); + let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(&face)); if let Some(ctx) = self.interceptor.intercept(ctx, cache) { let _ = self.handler.schedule(ctx.msg); } @@ -96,14 +96,14 @@ impl Primitives for Mux { }; if self.interceptor.interceptors.is_empty() { let _ = self.handler.schedule(msg); - } else if let Some(face) = self.face.get() { + } else if let Some(face) = self.face.get().and_then(|f| f.upgrade()) { let ctx = RoutingContext::new_out(msg, face.clone()); let prefix = ctx .wire_expr() .and_then(|we| (!we.has_suffix()).then(|| ctx.prefix())) .flatten() .cloned(); - let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(face)); + let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(&face)); if let Some(ctx) = self.interceptor.intercept(ctx, cache) { let _ = self.handler.schedule(ctx.msg); } @@ -120,14 +120,14 @@ impl Primitives for Mux { }; if self.interceptor.interceptors.is_empty() { let _ = self.handler.schedule(msg); - } else if let Some(face) = self.face.get() { + } else if let Some(face) = self.face.get().and_then(|f| f.upgrade()) { let ctx = RoutingContext::new_out(msg, face.clone()); let prefix = ctx .wire_expr() .and_then(|we| (!we.has_suffix()).then(|| ctx.prefix())) .flatten() .cloned(); - let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(face)); + let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(&face)); if let Some(ctx) = self.interceptor.intercept(ctx, cache) { 
let _ = self.handler.schedule(ctx.msg); } @@ -144,14 +144,14 @@ impl Primitives for Mux { }; if self.interceptor.interceptors.is_empty() { let _ = self.handler.schedule(msg); - } else if let Some(face) = self.face.get() { + } else if let Some(face) = self.face.get().and_then(|f| f.upgrade()) { let ctx = RoutingContext::new_out(msg, face.clone()); let prefix = ctx .wire_expr() .and_then(|we| (!we.has_suffix()).then(|| ctx.prefix())) .flatten() .cloned(); - let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(face)); + let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(&face)); if let Some(ctx) = self.interceptor.intercept(ctx, cache) { let _ = self.handler.schedule(ctx.msg); } @@ -199,14 +199,14 @@ impl EPrimitives for Mux { }; if self.interceptor.interceptors.is_empty() { let _ = self.handler.schedule(msg); - } else if let Some(face) = self.face.get() { + } else if let Some(face) = self.face.get().and_then(|f| f.upgrade()) { let ctx = RoutingContext::new_out(msg, face.clone()); let prefix = ctx .wire_expr() .and_then(|we| (!we.has_suffix()).then(|| ctx.prefix())) .flatten() .cloned(); - let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(face)); + let cache = prefix.as_ref().and_then(|p| p.get_egress_cache(&face)); if let Some(ctx) = self.interceptor.intercept(ctx, cache) { let _ = self.handler.schedule(ctx.msg); } diff --git a/zenoh/src/net/routing/dispatcher/face.rs b/zenoh/src/net/routing/dispatcher/face.rs index 6ef5c063d0..3d7d850e6b 100644 --- a/zenoh/src/net/routing/dispatcher/face.rs +++ b/zenoh/src/net/routing/dispatcher/face.rs @@ -20,13 +20,14 @@ use crate::KeyExpr; use std::any::Any; use std::collections::HashMap; use std::fmt; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use zenoh_protocol::zenoh::RequestBody; use zenoh_protocol::{ core::{ExprId, WhatAmI, ZenohId}, network::{Mapping, Push, Request, RequestId, Response, ResponseFinal}, }; use zenoh_sync::get_mut_unchecked; +use zenoh_task::TaskController; use zenoh_transport::multicast::TransportMulticast; #[cfg(feature = "stats")] use zenoh_transport::stats::TransportStats; @@ -45,6 +46,7 @@ pub struct FaceState { pub(crate) mcast_group: Option, pub(crate) in_interceptors: Option>, pub(crate) hat: Box, + pub(crate) task_controller: TaskController, } impl FaceState { @@ -73,6 +75,7 @@ impl FaceState { mcast_group, in_interceptors, hat, + task_controller: TaskController::default(), }) } @@ -150,12 +153,36 @@ impl fmt::Display for FaceState { } } +#[derive(Clone)] +pub struct WeakFace { + pub(crate) tables: Weak, + pub(crate) state: Weak, +} + +impl WeakFace { + pub fn upgrade(&self) -> Option { + Some(Face { + tables: self.tables.upgrade()?, + state: self.state.upgrade()?, + }) + } +} + #[derive(Clone)] pub struct Face { pub(crate) tables: Arc, pub(crate) state: Arc, } +impl Face { + pub fn downgrade(&self) -> WeakFace { + WeakFace { + tables: Arc::downgrade(&self.tables), + state: Arc::downgrade(&self.state), + } + } +} + impl Primitives for Face { fn send_declare(&self, msg: zenoh_protocol::network::Declare) { let ctrl_lock = zlock!(self.tables.ctrl_lock); diff --git a/zenoh/src/net/routing/dispatcher/queries.rs b/zenoh/src/net/routing/dispatcher/queries.rs index 3105c6195f..01c681e0d6 100644 --- a/zenoh/src/net/routing/dispatcher/queries.rs +++ b/zenoh/src/net/routing/dispatcher/queries.rs @@ -599,10 +599,16 @@ pub fn route_query( face: Arc::downgrade(outface), qid: *qid, }; - zenoh_runtime::ZRuntime::Net.spawn(async move { - tokio::time::sleep(timeout).await; - cleanup.run().await - }); + let 
cancellation_token = face.task_controller.get_cancellation_token(); + face.task_controller.spawn_with_rt( + zenoh_runtime::ZRuntime::Net, + async move { + tokio::select! { + _ = tokio::time::sleep(timeout) => { cleanup.run().await } + _ = cancellation_token.cancelled() => {} + } + }, + ); #[cfg(feature = "stats")] if !admin { inc_req_stats!(outface, tx, user, body) @@ -636,10 +642,16 @@ pub fn route_query( face: Arc::downgrade(outface), qid: *qid, }; - zenoh_runtime::ZRuntime::Net.spawn(async move { - tokio::time::sleep(timeout).await; - cleanup.run().await - }); + let cancellation_token = face.task_controller.get_cancellation_token(); + face.task_controller.spawn_with_rt( + zenoh_runtime::ZRuntime::Net, + async move { + tokio::select! { + _ = tokio::time::sleep(timeout) => { cleanup.run().await } + _ = cancellation_token.cancelled() => {} + } + }, + ); #[cfg(feature = "stats")] if !admin { inc_req_stats!(outface, tx, user, body) diff --git a/zenoh/src/net/routing/dispatcher/resource.rs b/zenoh/src/net/routing/dispatcher/resource.rs index 26a498461a..1762ff2cb4 100644 --- a/zenoh/src/net/routing/dispatcher/resource.rs +++ b/zenoh/src/net/routing/dispatcher/resource.rs @@ -293,6 +293,7 @@ impl Resource { let mutres = get_mut_unchecked(&mut resclone); if let Some(ref mut parent) = mutres.parent { if Arc::strong_count(res) <= 3 && res.childs.is_empty() { + // consider only childless resource held by only one external object (+ 1 strong count for resclone, + 1 strong count for res.parent to a total of 3 ) log::debug!("Unregister resource {}", res.expr()); if let Some(context) = mutres.context.as_mut() { for match_ in &mut context.matches { @@ -306,6 +307,7 @@ impl Resource { } } } + mutres.nonwild_prefix.take(); { get_mut_unchecked(parent).childs.remove(&res.suffix); } @@ -314,6 +316,17 @@ impl Resource { } } + pub fn close(self: &mut Arc) { + let r = get_mut_unchecked(self); + for c in r.childs.values_mut() { + Self::close(c); + } + r.parent.take(); + r.childs.clear(); + r.nonwild_prefix.take(); + r.session_ctxs.clear(); + } + #[cfg(test)] pub fn print_tree(from: &Arc) -> String { let mut result = from.expr(); diff --git a/zenoh/src/net/routing/dispatcher/tables.rs b/zenoh/src/net/routing/dispatcher/tables.rs index 73338cc79d..10605b25b1 100644 --- a/zenoh/src/net/routing/dispatcher/tables.rs +++ b/zenoh/src/net/routing/dispatcher/tables.rs @@ -172,6 +172,7 @@ pub fn close_face(tables: &TablesLock, face: &Weak) { match face.upgrade() { Some(mut face) => { log::debug!("Close {}", face); + face.task_controller.terminate_all(Duration::from_secs(10)); finalize_pending_queries(tables, &mut face); zlock!(tables.ctrl_lock).close_face(tables, &mut face); } diff --git a/zenoh/src/net/routing/hat/linkstate_peer/mod.rs b/zenoh/src/net/routing/hat/linkstate_peer/mod.rs index 020d796a1a..35afaf30d7 100644 --- a/zenoh/src/net/routing/hat/linkstate_peer/mod.rs +++ b/zenoh/src/net/routing/hat/linkstate_peer/mod.rs @@ -47,8 +47,8 @@ use std::{ any::Any, collections::{HashMap, HashSet}, sync::Arc, + time::Duration, }; -use tokio::task::JoinHandle; use zenoh_config::{unwrap_or_default, ModeDependent, WhatAmI, WhatAmIMatcher, ZenohId}; use zenoh_protocol::{ common::ZExtBody, @@ -56,6 +56,7 @@ use zenoh_protocol::{ }; use zenoh_result::ZResult; use zenoh_sync::get_mut_unchecked; +use zenoh_task::TerminatableTask; use zenoh_transport::unicast::TransportUnicast; mod network; @@ -112,7 +113,16 @@ struct HatTables { peer_subs: HashSet>, peer_qabls: HashSet>, peers_net: Option, - peers_trees_task: Option>, + 
peers_trees_task: Option, +} + +impl Drop for HatTables { + fn drop(&mut self) { + if self.peers_trees_task.is_some() { + let task = self.peers_trees_task.take().unwrap(); + task.terminate(Duration::from_secs(10)); + } + } } impl HatTables { @@ -128,24 +138,28 @@ impl HatTables { fn schedule_compute_trees(&mut self, tables_ref: Arc) { log::trace!("Schedule computations"); if self.peers_trees_task.is_none() { - let task = Some(zenoh_runtime::ZRuntime::Net.spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis( - *TREES_COMPUTATION_DELAY_MS, - )) - .await; - let mut tables = zwrite!(tables_ref.tables); - - log::trace!("Compute trees"); - let new_childs = hat_mut!(tables).peers_net.as_mut().unwrap().compute_trees(); - - log::trace!("Compute routes"); - pubsub::pubsub_tree_change(&mut tables, &new_childs); - queries::queries_tree_change(&mut tables, &new_childs); - - log::trace!("Computations completed"); - hat_mut!(tables).peers_trees_task = None; - })); - self.peers_trees_task = task; + let task = TerminatableTask::spawn( + zenoh_runtime::ZRuntime::Net, + async move { + tokio::time::sleep(std::time::Duration::from_millis( + *TREES_COMPUTATION_DELAY_MS, + )) + .await; + let mut tables = zwrite!(tables_ref.tables); + + log::trace!("Compute trees"); + let new_childs = hat_mut!(tables).peers_net.as_mut().unwrap().compute_trees(); + + log::trace!("Compute routes"); + pubsub::pubsub_tree_change(&mut tables, &new_childs); + queries::queries_tree_change(&mut tables, &new_childs); + + log::trace!("Computations completed"); + hat_mut!(tables).peers_trees_task = None; + }, + TerminatableTask::create_cancellation_token(), + ); + self.peers_trees_task = Some(task); } } } diff --git a/zenoh/src/net/routing/hat/linkstate_peer/network.rs b/zenoh/src/net/routing/hat/linkstate_peer/network.rs index 182a721a27..4d3497c861 100644 --- a/zenoh/src/net/routing/hat/linkstate_peer/network.rs +++ b/zenoh/src/net/routing/hat/linkstate_peer/network.rs @@ -15,6 +15,7 @@ use crate::net::codec::Zenoh080Routing; use crate::net::protocol::linkstate::{LinkState, LinkStateList}; use crate::net::routing::dispatcher::tables::NodeId; use crate::net::runtime::Runtime; +use crate::runtime::WeakRuntime; use petgraph::graph::NodeIndex; use petgraph::visit::{VisitMap, Visitable}; use std::convert::TryInto; @@ -115,7 +116,7 @@ pub(super) struct Network { pub(super) trees: Vec, pub(super) distances: Vec, pub(super) graph: petgraph::stable_graph::StableUnGraph, - pub(super) runtime: Runtime, + pub(super) runtime: WeakRuntime, } impl Network { @@ -155,7 +156,7 @@ impl Network { }], distances: vec![0.0], graph, - runtime, + runtime: Runtime::downgrade(&runtime), } } @@ -247,7 +248,7 @@ impl Network { whatami: self.graph[idx].whatami, locators: if details.locators { if idx == self.idx { - Some(self.runtime.get_locators()) + Some(self.runtime.upgrade().unwrap().get_locators()) } else { self.graph[idx].locators.clone() } @@ -336,6 +337,7 @@ impl Network { pub(super) fn link_states(&mut self, link_states: Vec, src: ZenohId) -> Changes { log::trace!("{} Received from {} raw: {:?}", self.name, src, link_states); + let strong_runtime = self.runtime.upgrade().unwrap(); let graph = &self.graph; let links = &mut self.links; @@ -487,13 +489,15 @@ impl Network { if !self.autoconnect.is_empty() { // Connect discovered peers if zenoh_runtime::ZRuntime::Net - .block_in_place(self.runtime.manager().get_transport_unicast(&zid)) + .block_in_place( + strong_runtime.manager().get_transport_unicast(&zid), + ) .is_none() && 
self.autoconnect.matches(whatami) { if let Some(locators) = locators { - let runtime = self.runtime.clone(); - self.runtime.spawn(async move { + let runtime = strong_runtime.clone(); + strong_runtime.spawn(async move { // random backoff tokio::time::sleep(std::time::Duration::from_millis( rand::random::() % 100, @@ -607,15 +611,15 @@ impl Network { let node = &self.graph[*idx]; if let Some(whatami) = node.whatami { if zenoh_runtime::ZRuntime::Net - .block_in_place(self.runtime.manager().get_transport_unicast(&node.zid)) + .block_in_place(strong_runtime.manager().get_transport_unicast(&node.zid)) .is_none() && self.autoconnect.matches(whatami) { if let Some(locators) = &node.locators { - let runtime = self.runtime.clone(); + let runtime = strong_runtime.clone(); let zid = node.zid; let locators = locators.clone(); - self.runtime.spawn(async move { + strong_runtime.spawn(async move { // random backoff tokio::time::sleep(std::time::Duration::from_millis( rand::random::() % 100, diff --git a/zenoh/src/net/routing/hat/p2p_peer/gossip.rs b/zenoh/src/net/routing/hat/p2p_peer/gossip.rs index 247412bfdf..f651ccdc0d 100644 --- a/zenoh/src/net/routing/hat/p2p_peer/gossip.rs +++ b/zenoh/src/net/routing/hat/p2p_peer/gossip.rs @@ -14,6 +14,7 @@ use crate::net::codec::Zenoh080Routing; use crate::net::protocol::linkstate::{LinkState, LinkStateList}; use crate::net::runtime::Runtime; +use crate::runtime::WeakRuntime; use petgraph::graph::NodeIndex; use std::convert::TryInto; use vec_map::VecMap; @@ -93,7 +94,7 @@ pub(super) struct Network { pub(super) idx: NodeIndex, pub(super) links: VecMap, pub(super) graph: petgraph::stable_graph::StableUnGraph, - pub(super) runtime: Runtime, + pub(super) runtime: WeakRuntime, } impl Network { @@ -124,7 +125,7 @@ impl Network { idx, links: VecMap::new(), graph, - runtime, + runtime: Runtime::downgrade(&runtime), } } @@ -191,7 +192,7 @@ impl Network { whatami: self.graph[idx].whatami, locators: if details.locators { if idx == self.idx { - Some(self.runtime.get_locators()) + Some(self.runtime.upgrade().unwrap().get_locators()) } else { self.graph[idx].locators.clone() } @@ -266,6 +267,7 @@ impl Network { pub(super) fn link_states(&mut self, link_states: Vec, src: ZenohId) { log::trace!("{} Received from {} raw: {:?}", self.name, src, link_states); + let strong_runtime = self.runtime.upgrade().unwrap(); let graph = &self.graph; let links = &mut self.links; @@ -407,13 +409,13 @@ impl Network { if !self.autoconnect.is_empty() { // Connect discovered peers if zenoh_runtime::ZRuntime::Net - .block_in_place(self.runtime.manager().get_transport_unicast(&zid)) + .block_in_place(strong_runtime.manager().get_transport_unicast(&zid)) .is_none() && self.autoconnect.matches(whatami) { if let Some(locators) = locators { - let runtime = self.runtime.clone(); - self.runtime.spawn(async move { + let runtime = strong_runtime.clone(); + strong_runtime.spawn(async move { // random backoff tokio::time::sleep(std::time::Duration::from_millis( rand::random::() % 100, diff --git a/zenoh/src/net/routing/hat/router/mod.rs b/zenoh/src/net/routing/hat/router/mod.rs index 5497afc9b8..030b8da4b4 100644 --- a/zenoh/src/net/routing/hat/router/mod.rs +++ b/zenoh/src/net/routing/hat/router/mod.rs @@ -52,8 +52,8 @@ use std::{ collections::{hash_map::DefaultHasher, HashMap, HashSet}, hash::Hasher, sync::Arc, + time::Duration, }; -use tokio::task::JoinHandle; use zenoh_config::{unwrap_or_default, ModeDependent, WhatAmI, WhatAmIMatcher, ZenohId}; use zenoh_protocol::{ common::ZExtBody, @@ -61,6 +61,7 @@ use 
zenoh_protocol::{ }; use zenoh_result::ZResult; use zenoh_sync::get_mut_unchecked; +use zenoh_task::TerminatableTask; use zenoh_transport::unicast::TransportUnicast; mod network; @@ -121,11 +122,24 @@ struct HatTables { routers_net: Option, peers_net: Option, shared_nodes: Vec, - routers_trees_task: Option>, - peers_trees_task: Option>, + routers_trees_task: Option, + peers_trees_task: Option, router_peers_failover_brokering: bool, } +impl Drop for HatTables { + fn drop(&mut self) { + if self.peers_trees_task.is_some() { + let task = self.peers_trees_task.take().unwrap(); + task.terminate(Duration::from_secs(10)); + } + if self.routers_trees_task.is_some() { + let task = self.routers_trees_task.take().unwrap(); + task.terminate(Duration::from_secs(10)); + } + } +} + impl HatTables { fn new(router_peers_failover_brokering: bool) -> Self { Self { @@ -243,36 +257,40 @@ impl HatTables { if (net_type == WhatAmI::Router && self.routers_trees_task.is_none()) || (net_type == WhatAmI::Peer && self.peers_trees_task.is_none()) { - let task = Some(zenoh_runtime::ZRuntime::Net.spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis( - *TREES_COMPUTATION_DELAY_MS, - )) - .await; - let mut tables = zwrite!(tables_ref.tables); - - log::trace!("Compute trees"); - let new_childs = match net_type { - WhatAmI::Router => hat_mut!(tables) - .routers_net - .as_mut() - .unwrap() - .compute_trees(), - _ => hat_mut!(tables).peers_net.as_mut().unwrap().compute_trees(), - }; + let task = TerminatableTask::spawn( + zenoh_runtime::ZRuntime::Net, + async move { + tokio::time::sleep(std::time::Duration::from_millis( + *TREES_COMPUTATION_DELAY_MS, + )) + .await; + let mut tables = zwrite!(tables_ref.tables); + + log::trace!("Compute trees"); + let new_childs = match net_type { + WhatAmI::Router => hat_mut!(tables) + .routers_net + .as_mut() + .unwrap() + .compute_trees(), + _ => hat_mut!(tables).peers_net.as_mut().unwrap().compute_trees(), + }; - log::trace!("Compute routes"); - pubsub::pubsub_tree_change(&mut tables, &new_childs, net_type); - queries::queries_tree_change(&mut tables, &new_childs, net_type); + log::trace!("Compute routes"); + pubsub::pubsub_tree_change(&mut tables, &new_childs, net_type); + queries::queries_tree_change(&mut tables, &new_childs, net_type); - log::trace!("Computations completed"); - match net_type { - WhatAmI::Router => hat_mut!(tables).routers_trees_task = None, - _ => hat_mut!(tables).peers_trees_task = None, - }; - })); + log::trace!("Computations completed"); + match net_type { + WhatAmI::Router => hat_mut!(tables).routers_trees_task = None, + _ => hat_mut!(tables).peers_trees_task = None, + }; + }, + TerminatableTask::create_cancellation_token(), + ); match net_type { - WhatAmI::Router => self.routers_trees_task = task, - _ => self.peers_trees_task = task, + WhatAmI::Router => self.routers_trees_task = Some(task), + _ => self.peers_trees_task = Some(task), }; } } diff --git a/zenoh/src/net/routing/router.rs b/zenoh/src/net/routing/router.rs index d67a2baa9d..c80d3bdc09 100644 --- a/zenoh/src/net/routing/router.rs +++ b/zenoh/src/net/routing/router.rs @@ -155,7 +155,7 @@ impl Router { state: newface, }; - let _ = mux.face.set(face.clone()); + let _ = mux.face.set(Face::downgrade(&face)); ctrl_lock.new_transport_unicast_face(&mut tables, &self.tables, &mut face, &transport)?; diff --git a/zenoh/src/net/runtime/mod.rs b/zenoh/src/net/runtime/mod.rs index 282c45f66c..9314186b2e 100644 --- a/zenoh/src/net/runtime/mod.rs +++ b/zenoh/src/net/runtime/mod.rs @@ -29,7 +29,8 @@ 
pub use adminspace::AdminSpace; use futures::stream::StreamExt; use futures::Future; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, Weak}; +use std::time::Duration; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use uhlc::{HLCBuilder, HLC}; @@ -39,6 +40,7 @@ use zenoh_protocol::core::{Locator, WhatAmI, ZenohId}; use zenoh_protocol::network::NetworkMessage; use zenoh_result::{bail, ZResult}; use zenoh_sync::get_mut_unchecked; +use zenoh_task::TaskController; use zenoh_transport::{ multicast::TransportMulticast, unicast::TransportUnicast, TransportEventHandler, TransportManager, TransportMulticastEventHandler, TransportPeer, TransportPeerEventHandler, @@ -54,7 +56,17 @@ struct RuntimeState { transport_handlers: std::sync::RwLock>>, locators: std::sync::RwLock>, hlc: Option>, - token: CancellationToken, + task_controller: TaskController, +} + +pub struct WeakRuntime { + state: Weak, +} + +impl WeakRuntime { + pub fn upgrade(&self) -> Option { + self.state.upgrade().map(|state| Runtime { state }) + } } #[derive(Clone)] @@ -97,7 +109,7 @@ impl Runtime { let router = Arc::new(Router::new(zid, whatami, hlc.clone(), &config)?); let handler = Arc::new(RuntimeTransportEventHandler { - runtime: std::sync::RwLock::new(None), + runtime: std::sync::RwLock::new(WeakRuntime { state: Weak::new() }), }); let transport_manager = TransportManager::builder() @@ -120,22 +132,33 @@ impl Runtime { transport_handlers: std::sync::RwLock::new(vec![]), locators: std::sync::RwLock::new(vec![]), hlc, - token: CancellationToken::new(), + task_controller: TaskController::default(), }), }; - *handler.runtime.write().unwrap() = Some(runtime.clone()); + *handler.runtime.write().unwrap() = Runtime::downgrade(&runtime); get_mut_unchecked(&mut runtime.state.router.clone()).init_link_state(runtime.clone()); let receiver = config.subscribe(); + let token = runtime.get_cancellation_token(); runtime.spawn({ let runtime2 = runtime.clone(); async move { let mut stream = receiver.into_stream(); - while let Some(event) = stream.next().await { - if &*event == "connect/endpoints" { - if let Err(e) = runtime2.update_peers().await { - log::error!("Error updating peers: {}", e); + loop { + tokio::select! { + res = stream.next() => { + match res { + Some(event) => { + if &*event == "connect/endpoints" { + if let Err(e) = runtime2.update_peers().await { + log::error!("Error updating peers: {}", e); + } + } + }, + None => { break; } + } } + _ = token.cancelled() => { break; } } } } @@ -156,8 +179,24 @@ impl Runtime { pub async fn close(&self) -> ZResult<()> { log::trace!("Runtime::close())"); // TODO: Check this whether is able to terminate all spawned task by Runtime::spawn - self.state.token.cancel(); + self.state + .task_controller + .terminate_all(Duration::from_secs(10)); self.manager().close().await; + // clean up to break cyclic reference of self.state to itself + self.state.transport_handlers.write().unwrap().clear(); + // TODO: the call below is needed to prevent intermittent leak + // due to not freed resource Arc, that apparently happens because + // the task responsible for resource clean up was aborted earlier than expected. + // This should be resolved by identfying correspodning task, and placing + // cancellation token manually inside it. + self.router() + .tables + .tables + .write() + .unwrap() + .root_res + .close(); Ok(()) } @@ -169,18 +208,28 @@ impl Runtime { self.state.locators.read().unwrap().clone() } + /// Spawns a task within runtime. 
+ /// Upon close runtime will block until this task completes pub(crate) fn spawn(&self, future: F) -> JoinHandle<()> where F: Future + Send + 'static, T: Send + 'static, { - let token = self.state.token.clone(); - zenoh_runtime::ZRuntime::Net.spawn(async move { - tokio::select! { - _ = token.cancelled() => {} - _ = future => {} - } - }) + self.state + .task_controller + .spawn_with_rt(zenoh_runtime::ZRuntime::Net, future) + } + + /// Spawns a task within runtime. + /// Upon runtime close the task will be automatically aborted. + pub(crate) fn spawn_abortable(&self, future: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + T: Send + 'static, + { + self.state + .task_controller + .spawn_abortable_with_rt(zenoh_runtime::ZRuntime::Net, future) } pub(crate) fn router(&self) -> Arc { @@ -202,10 +251,20 @@ impl Runtime { pub fn whatami(&self) -> WhatAmI { self.state.whatami } + + pub fn downgrade(this: &Runtime) -> WeakRuntime { + WeakRuntime { + state: Arc::downgrade(&this.state), + } + } + + pub fn get_cancellation_token(&self) -> CancellationToken { + self.state.task_controller.get_cancellation_token() + } } struct RuntimeTransportEventHandler { - runtime: std::sync::RwLock>, + runtime: std::sync::RwLock, } impl TransportEventHandler for RuntimeTransportEventHandler { @@ -214,7 +273,7 @@ impl TransportEventHandler for RuntimeTransportEventHandler { peer: TransportPeer, transport: TransportUnicast, ) -> ZResult> { - match zread!(self.runtime).as_ref() { + match zread!(self.runtime).upgrade().as_ref() { Some(runtime) => { let slave_handlers: Vec> = zread!(runtime.state.transport_handlers) @@ -242,7 +301,7 @@ impl TransportEventHandler for RuntimeTransportEventHandler { &self, transport: TransportMulticast, ) -> ZResult> { - match zread!(self.runtime).as_ref() { + match zread!(self.runtime).upgrade().as_ref() { Some(runtime) => { let slave_handlers: Vec> = zread!(runtime.state.transport_handlers) diff --git a/zenoh/src/net/runtime/orchestrator.rs b/zenoh/src/net/runtime/orchestrator.rs index 3feee6fb1b..3f1026268a 100644 --- a/zenoh/src/net/runtime/orchestrator.rs +++ b/zenoh/src/net/runtime/orchestrator.rs @@ -194,7 +194,7 @@ impl Runtime { let this = self.clone(); match (listen, autoconnect.is_empty()) { (true, false) => { - self.spawn(async move { + self.spawn_abortable(async move { tokio::select! 
diff --git a/zenoh/src/net/runtime/orchestrator.rs b/zenoh/src/net/runtime/orchestrator.rs
index 3feee6fb1b..3f1026268a 100644
--- a/zenoh/src/net/runtime/orchestrator.rs
+++ b/zenoh/src/net/runtime/orchestrator.rs
@@ -194,7 +194,7 @@ impl Runtime {
         let this = self.clone();
         match (listen, autoconnect.is_empty()) {
             (true, false) => {
-                self.spawn(async move {
+                self.spawn_abortable(async move {
                     tokio::select! {
                         _ = this.responder(&mcast_socket, &sockets) => {},
                         _ = this.connect_all(&sockets, autoconnect, &addr) => {},
@@ -202,14 +202,14 @@ impl Runtime {
                 });
             }
             (true, true) => {
-                self.spawn(async move {
+                self.spawn_abortable(async move {
                     this.responder(&mcast_socket, &sockets).await;
                 });
             }
             (false, false) => {
-                self.spawn(
-                    async move { this.connect_all(&sockets, autoconnect, &addr).await },
-                );
+                self.spawn_abortable(async move {
+                    this.connect_all(&sockets, autoconnect, &addr).await
+                });
             }
             _ => {}
         }
@@ -658,43 +658,44 @@ impl Runtime {
     async fn peer_connector_retry(&self, peer: EndPoint) {
         let retry_config = self.get_connect_retry_config(&peer);
         let mut period = retry_config.period();
+        let cancellation_token = self.get_cancellation_token();
         loop {
             log::trace!("Trying to connect to configured peer {}", peer);
             let endpoint = peer.clone();
-            match tokio::time::timeout(
-                retry_config.timeout(),
-                self.manager().open_transport_unicast(endpoint),
-            )
-            .await
-            {
-                Ok(Ok(transport)) => {
-                    log::debug!("Successfully connected to configured peer {}", peer);
-                    if let Ok(Some(orch_transport)) = transport.get_callback() {
-                        if let Some(orch_transport) = orch_transport
-                            .as_any()
-                            .downcast_ref::<super::RuntimeSession>()
-                        {
-                            *zwrite!(orch_transport.endpoint) = Some(peer);
+            tokio::select! {
+                res = tokio::time::timeout(retry_config.timeout(), self.manager().open_transport_unicast(endpoint)) => {
+                    match res {
+                        Ok(Ok(transport)) => {
+                            log::debug!("Successfully connected to configured peer {}", peer);
+                            if let Ok(Some(orch_transport)) = transport.get_callback() {
+                                if let Some(orch_transport) = orch_transport
+                                    .as_any()
+                                    .downcast_ref::<super::RuntimeSession>()
+                                {
+                                    *zwrite!(orch_transport.endpoint) = Some(peer);
+                                }
+                            }
+                            break;
+                        }
+                        Ok(Err(e)) => {
+                            log::debug!(
+                                "Unable to connect to configured peer {}! {}. Retry in {:?}.",
+                                peer,
+                                e,
+                                period.duration()
+                            );
+                        }
+                        Err(e) => {
+                            log::debug!(
+                                "Unable to connect to configured peer {}! {}. Retry in {:?}.",
+                                peer,
+                                e,
+                                period.duration()
+                            );
                         }
                     }
-                    break;
-                }
-                Ok(Err(e)) => {
-                    log::debug!(
-                        "Unable to connect to configured peer {}! {}. Retry in {:?}.",
-                        peer,
-                        e,
-                        period.duration()
-                    );
-                }
-                Err(e) => {
-                    log::debug!(
-                        "Unable to connect to configured peer {}! {}. Retry in {:?}.",
-                        peer,
-                        e,
-                        period.duration()
-                    );
                 }
+                _ = cancellation_token.cancelled() => { break; }
             }
             tokio::time::sleep(period.next_duration()).await;
         }
@@ -1018,11 +1019,15 @@ impl Runtime {
         match session.runtime.whatami() {
             WhatAmI::Client => {
                 let runtime = session.runtime.clone();
+                let cancellation_token = runtime.get_cancellation_token();
                 session.runtime.spawn(async move {
                     let retry_config = runtime.get_global_connect_retry_config();
                     let mut period = retry_config.period();
                     while runtime.start_client().await.is_err() {
-                        tokio::time::sleep(period.next_duration()).await;
+                        tokio::select! {
+                            _ = tokio::time::sleep(period.next_duration()) => {}
+                            _ = cancellation_token.cancelled() => { break; }
+                        }
                     }
                 });
             }
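Both orchestrator changes apply the same rule: any await point that can pend indefinitely inside a spawned task (the connect attempt, the back-off sleep) must be raced against the runtime's cancellation token, otherwise the task pins the runtime's resources past `close()`. A reduced, self-contained version of the pattern — `try_connect` is an illustrative stand-in, not zenoh API:

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

// Stand-in for an operation that may fail and be retried (not zenoh API).
async fn try_connect(attempt: u32) -> Result<(), &'static str> {
    if attempt < 3 { Err("refused") } else { Ok(()) }
}

async fn connect_with_retry(token: CancellationToken) -> bool {
    let mut attempt = 0;
    loop {
        tokio::select! {
            // The attempt itself is bounded by a timeout *and* by cancellation.
            res = tokio::time::timeout(Duration::from_secs(1), try_connect(attempt)) => {
                match res {
                    Ok(Ok(())) => return true,           // connected
                    Ok(Err(_)) | Err(_) => attempt += 1, // failed or timed out: retry
                }
            }
            // Without this arm the loop would outlive shutdown.
            _ = token.cancelled() => return false,
        }
        // The back-off sleep is raced too (as in the client-retry hunk above).
        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(100)) => {}
            _ = token.cancelled() => return false,
        }
    }
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let connected = connect_with_retry(token.clone()).await;
    println!("connected: {connected}");
}
```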
diff --git a/zenoh/src/net/tests/tables.rs b/zenoh/src/net/tests/tables.rs
index e8b6f6ac9f..1b02a5964f 100644
--- a/zenoh/src/net/tests/tables.rs
+++ b/zenoh/src/net/tests/tables.rs
@@ -166,8 +166,8 @@ fn match_test() {
     }
 }
 
-#[test]
-fn clean_test() {
+#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+async fn clean_test() {
     let config = Config::default();
     let router = Router::new(
         ZenohId::try_from([1]).unwrap(),
diff --git a/zenoh/src/scouting.rs b/zenoh/src/scouting.rs
index ab5866e388..cd25393754 100644
--- a/zenoh/src/scouting.rs
+++ b/zenoh/src/scouting.rs
@@ -14,12 +14,13 @@
 use crate::handlers::{locked, Callback, DefaultHandler};
 use crate::net::runtime::{orchestrator::Loop, Runtime};
-use futures::StreamExt;
+use std::time::Duration;
 use std::{fmt, future::Ready, net::SocketAddr, ops::Deref};
 use tokio::net::UdpSocket;
 use zenoh_core::{AsyncResolve, Resolvable, SyncResolve};
 use zenoh_protocol::core::WhatAmIMatcher;
 use zenoh_result::ZResult;
+use zenoh_task::TerminatableTask;
 
 /// Constants and helpers for zenoh `whatami` flags.
 pub use zenoh_protocol::core::WhatAmI;
@@ -204,7 +205,7 @@ where
 /// ```
 pub(crate) struct ScoutInner {
     #[allow(dead_code)]
-    pub(crate) stop_sender: flume::Sender<()>,
+    pub(crate) scout_task: Option<TerminatableTask>,
 }
 
 impl ScoutInner {
@@ -226,11 +227,19 @@ impl ScoutInner {
     /// # }
     /// ```
     pub fn stop(self) {
-        // This drops the inner `stop_sender` and hence stops the scouting receiver
         std::mem::drop(self);
     }
 }
 
+impl Drop for ScoutInner {
+    fn drop(&mut self) {
+        if self.scout_task.is_some() {
+            let task = self.scout_task.take();
+            task.unwrap().terminate(Duration::from_secs(10));
+        }
+    }
+}
+
 impl fmt::Debug for ScoutInner {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("CallbackScout").finish()
@@ -307,7 +316,6 @@ fn scout(
         zenoh_config::defaults::scouting::multicast::interface,
         |s| s.as_ref(),
     );
-    let (stop_sender, stop_receiver) = flume::bounded::<()>(1);
     let ifaces = Runtime::get_interfaces(ifaces);
     if !ifaces.is_empty() {
         let sockets: Vec<UdpSocket> = ifaces
@@ -315,25 +323,29 @@ fn scout(
             .filter_map(|iface| Runtime::bind_ucast_port(iface).ok())
             .collect();
         if !sockets.is_empty() {
-            zenoh_runtime::ZRuntime::Net.spawn(async move {
-                let mut stop_receiver = stop_receiver.stream();
-                let scout = Runtime::scout(&sockets, what, &addr, move |hello| {
-                    let callback = callback.clone();
-                    async move {
-                        callback(hello);
-                        Loop::Continue
+            let cancellation_token = TerminatableTask::create_cancellation_token();
+            let cancellation_token_clone = cancellation_token.clone();
+            let task = TerminatableTask::spawn(
+                zenoh_runtime::ZRuntime::Net,
+                async move {
+                    let scout = Runtime::scout(&sockets, what, &addr, move |hello| {
+                        let callback = callback.clone();
+                        async move {
+                            callback(hello);
+                            Loop::Continue
+                        }
+                    });
+                    tokio::select! {
+                        _ = scout => {},
+                        _ = cancellation_token_clone.cancelled() => { log::trace!("stop scout({}, {})", what, &config); },
                     }
-                });
-                let stop = async move {
-                    stop_receiver.next().await;
-                    log::trace!("stop scout({}, {})", what, &config);
-                };
-                tokio::select! {
-                    _ = scout => {},
-                    _ = stop => {},
-                }
+                },
+                cancellation_token.clone(),
+            );
+            return Ok(ScoutInner {
+                scout_task: Some(task),
             });
         }
     }
-    Ok(ScoutInner { stop_sender })
+    Ok(ScoutInner { scout_task: None })
 }
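`TerminatableTask` (also from `zenoh-task`) bundles a join handle with a cancellation token so that `Drop` can request a graceful stop and wait for it, which is exactly what the new `ScoutInner::drop` does. A plausible shape for it, assuming tokio primitives — this is a sketch, not the crate's actual code; the real `spawn` also takes a `ZRuntime` and the real `terminate` is blocking:

```rust
use std::{future::Future, time::Duration};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

struct TerminatableTask {
    handle: JoinHandle<()>,
    token: CancellationToken,
}

impl TerminatableTask {
    fn create_cancellation_token() -> CancellationToken {
        CancellationToken::new()
    }

    // The future is expected to exit once `token` is cancelled.
    fn spawn<F>(future: F, token: CancellationToken) -> Self
    where
        F: Future<Output = ()> + Send + 'static,
    {
        TerminatableTask { handle: tokio::spawn(future), token }
    }

    // Graceful stop: signal the token, wait up to `timeout`, then abort.
    async fn terminate(mut self, timeout: Duration) {
        self.token.cancel();
        if tokio::time::timeout(timeout, &mut self.handle).await.is_err() {
            self.handle.abort();
        }
    }
}

#[tokio::main]
async fn main() {
    let token = TerminatableTask::create_cancellation_token();
    let t = token.clone();
    let task = TerminatableTask::spawn(async move { t.cancelled().await }, token);
    task.terminate(Duration::from_secs(1)).await;
    println!("scout task stopped");
}
```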
diff --git a/zenoh/src/session.rs b/zenoh/src/session.rs
index 7290d0aeac..7cd495d378 100644
--- a/zenoh/src/session.rs
+++ b/zenoh/src/session.rs
@@ -81,6 +81,7 @@ use zenoh_protocol::{
     },
 };
 use zenoh_result::ZResult;
+use zenoh_task::TaskController;
 use zenoh_util::core::AsyncResolve;
 
 zconfigurable! {
@@ -392,6 +393,7 @@ pub struct Session {
     pub(crate) state: Arc<RwLock<SessionState>>,
     pub(crate) id: u16,
     pub(crate) alive: bool,
+    task_controller: TaskController,
 }
 
 static SESSION_ID_COUNTER: AtomicU16 = AtomicU16::new(0);
@@ -412,6 +414,7 @@ impl Session {
                     state: state.clone(),
                     id: SESSION_ID_COUNTER.fetch_add(1, Ordering::SeqCst),
                     alive: true,
+                    task_controller: TaskController::default(),
                 };
 
                 runtime.new_handler(Arc::new(admin::Handler::new(session.clone())));
@@ -515,14 +518,18 @@ impl Session {
     ///     session.close().res().await.unwrap();
     /// # }
     /// ```
-    pub fn close(self) -> impl Resolve<ZResult<()>> {
+    pub fn close(mut self) -> impl Resolve<ZResult<()>> {
         ResolveFuture::new(async move {
             trace!("close()");
+            self.task_controller.terminate_all(Duration::from_secs(10));
             self.runtime.close().await?;
-            let primitives = zwrite!(self.state).primitives.as_ref().unwrap().clone();
-            primitives.send_close();
-
+            let mut state = zwrite!(self.state);
+            state.primitives.as_ref().unwrap().send_close();
+            // clean up to break cyclic references from self.state to itself
+            state.primitives.take();
+            state.queryables.clear();
+            self.alive = false;
             Ok(())
         })
     }
@@ -803,6 +810,7 @@ impl Session {
             state: self.state.clone(),
             id: self.id,
             alive: false,
+            task_controller: self.task_controller.clone(),
         }
     }
@@ -1499,30 +1507,34 @@ impl Session {
             if key_expr.intersects(&msub.key_expr) {
                 // Cannot hold session lock when calling tables (matching_status())
                 // TODO: check which ZRuntime should be used
-                zenoh_runtime::ZRuntime::RX.spawn({
-                    let session = self.clone();
-                    let msub = msub.clone();
-                    async move {
-                        match msub.current.lock() {
-                            Ok(mut current) => {
-                                if !*current {
-                                    if let Ok(status) =
-                                        session.matching_status(&msub.key_expr, msub.destination)
-                                    {
-                                        if status.matching_subscribers() {
-                                            *current = true;
-                                            let callback = msub.callback.clone();
-                                            (callback)(status)
+                self.task_controller
+                    .spawn_with_rt(zenoh_runtime::ZRuntime::RX, {
+                        let session = self.clone();
+                        let msub = msub.clone();
+                        async move {
+                            match msub.current.lock() {
+                                Ok(mut current) => {
+                                    if !*current {
+                                        if let Ok(status) = session
+                                            .matching_status(&msub.key_expr, msub.destination)
+                                        {
+                                            if status.matching_subscribers() {
+                                                *current = true;
+                                                let callback = msub.callback.clone();
+                                                (callback)(status)
+                                            }
                                         }
                                     }
                                 }
-                            }
-                            Err(e) => {
-                                log::error!("Error trying to acquire MathginListener lock: {}", e);
+                                Err(e) => {
+                                    log::error!(
+                                        "Error trying to acquire MatchingListener lock: {}",
+                                        e
+                                    );
+                                }
                             }
                         }
-                    }
-                });
+                    });
             }
         }
     }
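The `close()` change above is the session-side counterpart of the runtime clean-up: callbacks stored in the session state may capture a clone of the `Session`, creating an `Arc` cycle (`state -> callback -> session -> state`) that valgrind reports as definitely lost. A toy reproduction of that shape, with hypothetical names:

```rust
use std::sync::{Arc, Mutex};

type Callback = Box<dyn Fn() + Send + Sync>;

#[derive(Default)]
struct State {
    queryables: Mutex<Vec<Arc<Callback>>>,
}

#[derive(Clone, Default)]
struct Session {
    state: Arc<State>,
}

impl Session {
    fn declare_queryable(&self, cb: Callback) {
        self.state.queryables.lock().unwrap().push(Arc::new(cb));
    }

    fn close(self) {
        // Without this, a callback capturing a Session clone keeps `state`
        // alive forever: state -> callback -> session -> state.
        self.state.queryables.lock().unwrap().clear();
    }
}

fn main() {
    let session = Session::default();
    let captured = session.clone();
    session.declare_queryable(Box::new(move || {
        let _ = &captured; // the callback owns a Session clone
    }));
    let weak = Arc::downgrade(&session.state);
    session.close();
    assert!(weak.upgrade().is_none()); // state actually freed
    println!("no cycle");
}
```

Clearing `primitives` and `queryables` in `close()` drops the stored callbacks and breaks the cycle without requiring every callback to hold a weak session reference.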
@@ -1533,30 +1545,34 @@ impl Session {
             if key_expr.intersects(&msub.key_expr) {
                 // Cannot hold session lock when calling tables (matching_status())
                 // TODO: check which ZRuntime should be used
-                zenoh_runtime::ZRuntime::RX.spawn({
-                    let session = self.clone();
-                    let msub = msub.clone();
-                    async move {
-                        match msub.current.lock() {
-                            Ok(mut current) => {
-                                if *current {
-                                    if let Ok(status) =
-                                        session.matching_status(&msub.key_expr, msub.destination)
-                                    {
-                                        if !status.matching_subscribers() {
-                                            *current = false;
-                                            let callback = msub.callback.clone();
-                                            (callback)(status)
+                self.task_controller
+                    .spawn_with_rt(zenoh_runtime::ZRuntime::RX, {
+                        let session = self.clone();
+                        let msub = msub.clone();
+                        async move {
+                            match msub.current.lock() {
+                                Ok(mut current) => {
+                                    if *current {
+                                        if let Ok(status) = session
+                                            .matching_status(&msub.key_expr, msub.destination)
+                                        {
+                                            if !status.matching_subscribers() {
+                                                *current = false;
+                                                let callback = msub.callback.clone();
+                                                (callback)(status)
+                                            }
                                         }
                                     }
                                 }
-                            }
-                            Err(e) => {
-                                log::error!("Error trying to acquire MathginListener lock: {}", e);
+                                Err(e) => {
+                                    log::error!(
+                                        "Error trying to acquire MatchingListener lock: {}",
+                                        e
+                                    );
+                                }
                             }
                         }
-                    }
-                });
+                    });
             }
         }
     }
@@ -1752,27 +1768,33 @@ impl Session {
             _ => 1,
         };
 
-        zenoh_runtime::ZRuntime::Net.spawn({
-            let state = self.state.clone();
-            let zid = self.runtime.zid();
-            async move {
-                tokio::time::sleep(timeout).await;
-                let mut state = zwrite!(state);
-                if let Some(query) = state.queries.remove(&qid) {
-                    std::mem::drop(state);
-                    log::debug!("Timeout on query {}! Send error and close.", qid);
-                    if query.reception_mode == ConsolidationMode::Latest {
-                        for (_, reply) in query.replies.unwrap().into_iter() {
-                            (query.callback)(reply);
+        let token = self.task_controller.get_cancellation_token();
+        self.task_controller
+            .spawn_with_rt(zenoh_runtime::ZRuntime::Net, {
+                let state = self.state.clone();
+                let zid = self.runtime.zid();
+                async move {
+                    tokio::select! {
+                        _ = tokio::time::sleep(timeout) => {
+                            let mut state = zwrite!(state);
+                            if let Some(query) = state.queries.remove(&qid) {
+                                std::mem::drop(state);
+                                log::debug!("Timeout on query {}! Send error and close.", qid);
+                                if query.reception_mode == ConsolidationMode::Latest {
+                                    for (_, reply) in query.replies.unwrap().into_iter() {
+                                        (query.callback)(reply);
+                                    }
+                                }
+                                (query.callback)(Reply {
+                                    sample: Err("Timeout".into()),
+                                    replier_id: zid,
+                                });
+                            }
                         }
+                        _ = token.cancelled() => {}
                     }
-                    (query.callback)(Reply {
-                        sample: Err("Timeout".into()),
-                        replier_id: zid,
-                    });
                 }
-            }
-        });
+            });
 
         let selector = match scope {
             Some(scope) => Selector {
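The query-timeout hunk is the same select-against-token move applied to a one-shot timer: previously the sleeping task kept the captured session state alive for the full timeout even after close. In reduced form, with illustrative names:

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

async fn query_timeout_guard(
    timeout: Duration,
    token: CancellationToken,
    on_timeout: impl FnOnce(),
) {
    tokio::select! {
        _ = tokio::time::sleep(timeout) => on_timeout(), // query really timed out
        _ = token.cancelled() => {}                      // session closing: just exit
    }
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let guard = tokio::spawn(query_timeout_guard(
        Duration::from_secs(60),
        token.clone(),
        || println!("query timed out"),
    ));
    // Closing the session cancels the token; the guard task ends immediately
    // instead of keeping the runtime (and its captured state) alive for 60 s.
    token.cancel();
    guard.await.unwrap();
}
```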