diff --git a/chaindexing-tests/Cargo.toml b/chaindexing-tests/Cargo.toml index 722ddce..db25c85 100644 --- a/chaindexing-tests/Cargo.toml +++ b/chaindexing-tests/Cargo.toml @@ -12,5 +12,5 @@ dotenvy = "0.15" diesel = { version = "2", features = ["postgres", "chrono"] } rand = "0.8.5" serde = { version = "1.0", features = ["derive"] } -tokio = { version = "1.29", features = ["full"] } +tokio = { version = "1.37", features = ["full"] } diff --git a/chaindexing/Cargo.toml b/chaindexing/Cargo.toml index f7086b0..b583545 100644 --- a/chaindexing/Cargo.toml +++ b/chaindexing/Cargo.toml @@ -23,7 +23,7 @@ ethers = "2.0" serde = "1" serde_json = "1" tokio-postgres = { version = "0.7", features = ["with-serde_json-1"]} -tokio = "1.29" +tokio = { version = "1.37", features = ["full"] } uuid = { version = "1", features = ["v4", "serde"] } futures-core = { version = "0.3", features = ["alloc"] } futures-util = "0.3" diff --git a/chaindexing/src/handlers.rs b/chaindexing/src/handlers.rs index 18cf34f..bc129d4 100644 --- a/chaindexing/src/handlers.rs +++ b/chaindexing/src/handlers.rs @@ -25,7 +25,7 @@ pub async fn start(config: &Config) let config = config.clone(); node_task - .add_task(tokio::spawn({ + .add_subtask(&tokio::spawn({ let node_task = node_task.clone(); // MultiChainStates are indexed in an order-agnostic fashion, so no need for txn client @@ -42,7 +42,7 @@ pub async fn start(config: &Config) node_task .clone() - .add_task(tokio::spawn(async move { + .add_subtask(&tokio::spawn(async move { let mut interval = interval(Duration::from_millis(config.handler_rate_ms)); diff --git a/chaindexing/src/ingester.rs b/chaindexing/src/ingester.rs index 9c9ab06..b96deae 100644 --- a/chaindexing/src/ingester.rs +++ b/chaindexing/src/ingester.rs @@ -37,7 +37,7 @@ pub async fn start(config: &Config) -> Node let config = config.clone(); node_task - .add_task(tokio::spawn(async move { + .add_subtask(&tokio::spawn(async move { let mut interval = interval(Duration::from_millis(config.ingestion_rate_ms)); let mut last_pruned_at_per_chain_id = HashMap::new(); diff --git a/chaindexing/src/nodes/node_task.rs b/chaindexing/src/nodes/node_task.rs index ff56ce4..33ceef9 100644 --- a/chaindexing/src/nodes/node_task.rs +++ b/chaindexing/src/nodes/node_task.rs @@ -1,9 +1,14 @@ use std::sync::Arc; use tokio::sync::Mutex; -#[derive(Clone)] +#[derive(Clone, PartialEq, Debug)] +struct NodeSubTask(*const tokio::task::JoinHandle<()>); + +unsafe impl Send for NodeSubTask {} + +#[derive(Clone, Debug)] pub struct NodeTask { - tasks: Arc>>>, + subtasks: Arc>>, } impl Default for NodeTask { @@ -15,17 +20,108 @@ impl Default for NodeTask { impl NodeTask { pub fn new() -> Self { NodeTask { - tasks: Arc::new(Mutex::new(Vec::new())), + subtasks: Arc::new(Mutex::new(Vec::new())), } } - pub async fn add_task(&self, task: tokio::task::JoinHandle<()>) { - let mut tasks = self.tasks.lock().await; - tasks.push(task); + pub async fn add_subtask(&self, task: &tokio::task::JoinHandle<()>) { + let mut subtasks = self.subtasks.lock().await; + subtasks.push(NodeSubTask(task)); } pub async fn stop(&self) { - let tasks = self.tasks.lock().await; - for task in tasks.iter() { - task.abort(); + let subtasks = self.subtasks.lock().await; + for subtask in subtasks.iter() { + if let Some(subtask) = unsafe { (*subtask).0.as_ref() } { + subtask.abort(); + } + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[tokio::test] + async fn adds_a_tokio_task() { + let node_task = NodeTask::new(); + let subtask = tokio::spawn(async {}); + + node_task.add_subtask(&subtask).await; + + let subtask = NodeSubTask(&subtask); + + let added_subtasks = node_task.subtasks.lock().await; + + assert_eq!(subtask, *added_subtasks.first().unwrap()); + } + + #[tokio::test] + async fn adds_multiple_flattened_tokio_tasks() { + let node_task = NodeTask::new(); + let subtasks = [ + tokio::spawn(async {}), + tokio::spawn(async {}), + tokio::spawn(async {}), + ]; + + for subtask in subtasks.iter() { + node_task.add_subtask(subtask).await; + } + + let subtasks: Vec<_> = subtasks.iter().map(|t| NodeSubTask(t)).collect(); + + let added_subtasks = node_task.subtasks.lock().await; + + for (index, subtask) in subtasks.iter().enumerate() { + assert_eq!(subtask, added_subtasks.get(index).unwrap()); + } + } + + #[tokio::test] + async fn adds_multiple_nested_subtasks() { + let node_task = NodeTask::new(); + + let subtask = tokio::spawn({ + let node_task = node_task.clone(); + + async move { + let subtask = tokio::spawn({ + let node_task = node_task.clone(); + + async move { + let subtask = tokio::spawn({ + let node_task = node_task.clone(); + + async move { + let subtask = tokio::spawn(async move {}); + node_task.add_subtask(&subtask).await; + assert_is_added(&subtask, &node_task).await; + } + }); + node_task.add_subtask(&subtask).await; + assert_is_added(&subtask, &node_task).await; + } + }); + + node_task.add_subtask(&subtask).await; + + assert_is_added(&subtask, &node_task).await; + } + }); + + node_task.add_subtask(&subtask).await; + assert_is_added(&subtask, &node_task).await; + + async fn assert_is_added(subtask: &tokio::task::JoinHandle<()>, node_task: &NodeTask) { + let subtask = NodeSubTask(subtask); + + let added_subtasks = node_task.subtasks.lock().await; + + assert!(added_subtasks + .iter() + .find(|added_subtask| **added_subtask == subtask) + .is_some()); } } }