Skip to content

Commit

Permalink
feat: task management (#20)
Browse files Browse the repository at this point in the history
* feat: task management

* fix: get a handle

* cleanup: docs and pub

* doc: expand slightly

* refactor: clean up all tasks before returning

* refactor: use a future instead of the token

* nit: remove newline

* doc: readme update

* fix: remove unnecessary generic

* chore: more docs and bubble up more API to shutdown

* fix: prevent deadlock of task waiting on itself

* chore: bump version
  • Loading branch information
prestwich authored Jan 31, 2025
1 parent 2dfffec commit ba0eb69
Show file tree
Hide file tree
Showing 13 changed files with 698 additions and 103 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ description = "Simple, modern, ergonomic JSON-RPC 2.0 router built with tower an
keywords = ["json-rpc", "jsonrpc", "json"]
categories = ["web-programming::http-server", "web-programming::websocket"]

version = "0.2.0"
version = "0.3.0"
edition = "2021"
rust-version = "1.81"
authors = ["init4", "James Prestwich"]
Expand All @@ -31,7 +31,7 @@ tokio-stream = { version = "0.1.17", optional = true }

# ipc
interprocess = { version = "2.2.2", features = ["async", "tokio"], optional = true }
tokio-util = { version = "0.7.13", optional = true, features = ["io"] }
tokio-util = { version = "0.7.13", optional = true, features = ["io", "rt"] }

# ws
tokio-tungstenite = { version = "0.26.1", features = ["rustls-tls-webpki-roots"], optional = true }
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ ajj aims to provide simple, flexible, and ergonomic routing for JSON-RPC.
- Support for pubsub-style notifications.
- Built-in support for axum, and tower's middleware and service ecosystem.
- Basic built-in pubsub server implementations for WS and IPC.
- Connection-oriented task management automatically cancels tasks on client
disconnect.

## Concepts

Expand Down
48 changes: 45 additions & 3 deletions src/axum.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,50 @@
use crate::types::{InboundData, Response};
use crate::{
types::{InboundData, Response},
HandlerCtx, TaskSet,
};
use axum::{extract::FromRequest, response::IntoResponse};
use bytes::Bytes;
use std::{future::Future, pin::Pin};
use tokio::runtime::Handle;

impl<S> axum::handler::Handler<Bytes, S> for crate::Router<S>
/// A wrapper around an [`Router`] that implements the
/// [`axum::handler::Handler`] trait. This struct is an implementation detail
/// of the [`Router::into_axum`] and [`Router::into_axum_with_handle`] methods.
///
/// [`Router`]: crate::Router
/// [`Router::into_axum`]: crate::Router::into_axum
/// [`Router::into_axum_with_handle`]: crate::Router::into_axum_with_handle
#[derive(Debug, Clone)]
pub(crate) struct IntoAxum<S> {
pub(crate) router: crate::Router<S>,
pub(crate) task_set: TaskSet,
}

impl<S> From<crate::Router<S>> for IntoAxum<S> {
fn from(router: crate::Router<S>) -> Self {
Self {
router,
task_set: Default::default(),
}
}
}

impl<S> IntoAxum<S> {
/// Create a new `IntoAxum` from a router and task set.
pub(crate) fn new(router: crate::Router<S>, handle: Handle) -> Self {
Self {
router,
task_set: handle.into(),
}
}

/// Get a new context, built from the task set.
fn ctx(&self) -> HandlerCtx {
self.task_set.clone().into()
}
}

impl<S> axum::handler::Handler<Bytes, S> for IntoAxum<S>
where
S: Clone + Send + Sync + 'static,
{
Expand All @@ -21,7 +62,8 @@ where
let req = InboundData::try_from(bytes).unwrap_or_default();

if let Some(response) = self
.call_batch_with_state(Default::default(), req, state)
.router
.call_batch_with_state(self.ctx(), req, state)
.await
{
Box::<str>::from(response).into_response()
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ pub(crate) use routes::{BoxedIntoRoute, ErasedIntoRoute, Method, Route};
mod router;
pub use router::Router;

mod tasks;
pub(crate) use tasks::TaskSet;

mod types;
pub use types::{ErrorPayload, ResponsePayload};

Expand Down
5 changes: 4 additions & 1 deletion src/pubsub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ mod ipc;
pub use ipc::ReadJsonStream;

mod shared;
pub use shared::{ConnectionId, ServerShutdown, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT};
pub use shared::{ConnectionId, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT};

mod shutdown;
pub use shutdown::ServerShutdown;

mod r#trait;
pub use r#trait::{Connect, In, JsonReqStream, JsonSink, Listener, Out};
Expand Down
104 changes: 49 additions & 55 deletions src/pubsub/shared.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
use core::fmt;

use crate::{
pubsub::{In, JsonSink, Listener, Out},
types::InboundData,
HandlerCtx, TaskSet,
};
use core::fmt;
use serde_json::value::RawValue;
use tokio::{
select,
sync::{mpsc, oneshot, watch},
task::JoinHandle,
};
use tokio::{pin, select, sync::mpsc, task::JoinHandle};
use tokio_stream::StreamExt;
use tokio_util::sync::WaitForCancellationFutureOwned;
use tracing::{debug, debug_span, error, instrument, trace, Instrument};

/// Default notification buffer size per task.
Expand All @@ -19,18 +16,6 @@ pub const DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT: usize = 16;
/// Type alias for identifying connections.
pub type ConnectionId = u64;

/// Holds the shutdown signal for some server.
#[derive(Debug)]
pub struct ServerShutdown {
pub(crate) _shutdown: watch::Sender<()>,
}

impl From<watch::Sender<()>> for ServerShutdown {
fn from(sender: watch::Sender<()>) -> Self {
Self { _shutdown: sender }
}
}

/// The `ListenerTask` listens for new connections, and spawns `RouteTask`s for
/// each.
pub(crate) struct ListenerTask<T: Listener> {
Expand Down Expand Up @@ -67,16 +52,17 @@ where
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> JoinHandle<()> {
pub(crate) fn spawn(self) -> JoinHandle<Option<()>> {
let tasks = self.manager.root_tasks.clone();
let future = self.task_future();
tokio::spawn(future)
tasks.spawn_cancellable(future)
}
}

/// The `ConnectionManager` provides connections with IDs, and handles spawning
/// the [`RouteTask`] for each connection.
pub(crate) struct ConnectionManager {
pub(crate) shutdown: watch::Receiver<()>,
pub(crate) root_tasks: TaskSet,

pub(crate) next_id: ConnectionId,

Expand Down Expand Up @@ -107,19 +93,18 @@ impl ConnectionManager {
) -> (RouteTask<T>, WriteTask<T>) {
let (tx, rx) = mpsc::channel(self.notification_buffer_per_task);

let (gone_tx, gone_rx) = oneshot::channel();
let tasks = self.root_tasks.child();

let rt = RouteTask {
router: self.router(),
conn_id,
write_task: tx,
requests,
gone: gone_tx,
tasks: tasks.clone(),
};

let wt = WriteTask {
shutdown: self.shutdown.clone(),
gone: gone_rx,
tasks,
conn_id,
json: rx,
connection,
Expand Down Expand Up @@ -156,8 +141,8 @@ struct RouteTask<T: crate::pubsub::Listener> {
pub(crate) write_task: mpsc::Sender<Box<RawValue>>,
/// Stream of requests.
pub(crate) requests: In<T>,
/// Sender to the [`WriteTask`], to notify it that this task is done.
pub(crate) gone: oneshot::Sender<()>,
/// The task set for this connection
pub(crate) tasks: TaskSet,
}

impl<T: crate::pubsub::Listener> fmt::Debug for RouteTask<T> {
Expand All @@ -179,18 +164,27 @@ where
/// to handle the request, and given a sender to the [`WriteTask`]. This
/// ensures that requests can be handled concurrently.
#[instrument(name = "RouteTask", skip(self), fields(conn_id = self.conn_id))]
pub async fn task_future(self) {
pub async fn task_future(self, cancel: WaitForCancellationFutureOwned) {
let RouteTask {
router,
mut requests,
write_task,
gone,
tasks,
..
} = self;

// The write task is responsible for waiting for its children
let children = tasks.child();

pin!(cancel);

loop {
select! {
biased;
_ = &mut cancel => {
debug!("RouteTask cancelled");
break;
}
_ = write_task.closed() => {
debug!("WriteTask has gone away");
break;
Expand All @@ -208,7 +202,11 @@ where

let span = debug_span!("pubsub request handling", reqs = reqs.len());

let ctx = write_task.clone().into();
let ctx =
HandlerCtx::new(
Some(write_task.clone()),
children.clone(),
);

let fut = router.handle_request_batch(ctx, reqs);
let write_task = write_task.clone();
Expand All @@ -223,7 +221,7 @@ where
};

// Run the future in a new task.
tokio::spawn(
children.spawn_cancellable(
async move {
// Send the response to the write task.
// we don't care if the receiver has gone away,
Expand All @@ -239,27 +237,23 @@ where
}
}
}
// No funny business. Drop the gone signal.
drop(gone);
children.shutdown().await;
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<()> {
let future = self.task_future();
tokio::spawn(future)
let tasks = self.tasks.clone();

let future = move |cancel| self.task_future(cancel);

tasks.spawn_graceful(future)
}
}

/// The Write Task is responsible for writing JSON to the outbound connection.
struct WriteTask<T: Listener> {
/// Shutdown signal.
///
/// Shutdowns bubble back up to [`RouteTask`] when the write task is
/// dropped, via the closed `json` channel.
pub(crate) shutdown: watch::Receiver<()>,

/// Signal that the connection has gone away.
pub(crate) gone: oneshot::Receiver<()>,
/// Task set
pub(crate) tasks: TaskSet,

/// ID of the connection.
pub(crate) conn_id: ConnectionId,
Expand All @@ -281,25 +275,23 @@ impl<T: Listener> WriteTask<T> {
/// channel, and acts on them. It handles JSON messages, and going away
/// instructions. It also listens for the global shutdown signal from the
/// [`ServerShutdown`] struct.
///
/// [`ServerShutdown`]: crate::pubsub::ServerShutdown
#[instrument(skip(self), fields(conn_id = self.conn_id))]
pub(crate) async fn task_future(self) {
let WriteTask {
mut shutdown,
mut gone,
tasks,
mut json,
mut connection,
..
} = self;
shutdown.mark_unchanged();

loop {
select! {
biased;
_ = &mut gone => {
debug!("Connection has gone away");
break;
}
_ = shutdown.changed() => {
debug!("shutdown signal received");

_ = tasks.cancelled() => {
debug!("Shutdown signal received");
break;
}
json = json.recv() => {
Expand All @@ -317,7 +309,9 @@ impl<T: Listener> WriteTask<T> {
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> JoinHandle<()> {
tokio::spawn(self.task_future())
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<Option<()>> {
let tasks = self.tasks.clone();
let future = self.task_future();
tasks.spawn_cancellable(future)
}
}
Loading

0 comments on commit ba0eb69

Please sign in to comment.