Skip to content

Commit

Permalink
feat: allow pluggable tower layers in connector service stack
Browse files Browse the repository at this point in the history
  • Loading branch information
JESS IZEN committed Dec 12, 2024
1 parent d36c0f5 commit f04f579
Show file tree
Hide file tree
Showing 10 changed files with 1,332 additions and 221 deletions.
13 changes: 11 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = ["Sean McArthur <[email protected]>"]
readme = "README.md"
license = "MIT OR Apache-2.0"
edition = "2021"
rust-version = "1.63.0"
rust-version = "1.64.0"
autotests = true

[package.metadata.docs.rs]
Expand Down Expand Up @@ -105,6 +105,7 @@ url = "2.4"
bytes = "1.0"
serde = "1.0"
serde_urlencoded = "0.7.1"
tower = { version = "0.5.2", default-features = false, features = ["timeout", "util"] }
tower-service = "0.3"
futures-core = { version = "0.3.28", default-features = false }
futures-util = { version = "0.3.28", default-features = false }
Expand Down Expand Up @@ -169,7 +170,6 @@ quinn = { version = "0.11.1", default-features = false, features = ["rustls", "r
slab = { version = "0.4.9", optional = true } # just to get minimal versions working with quinn
futures-channel = { version = "0.3", optional = true }


[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
env_logger = "0.10"
hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] }
Expand Down Expand Up @@ -222,6 +222,11 @@ features = [
wasm-bindgen = { version = "0.2.89", features = ["serde-serialize"] }
wasm-bindgen-test = "0.3"

[dev-dependencies]
tower = { version = "0.5.2", default-features = false, features = ["limit"] }
num_cpus = "1.0"
libc = "0"

[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(reqwest_unstable)'] }

Expand Down Expand Up @@ -253,6 +258,10 @@ path = "examples/form.rs"
name = "simple"
path = "examples/simple.rs"

[[example]]
name = "connect_via_lower_priority_tokio_runtime"
path = "examples/connect_via_lower_priority_tokio_runtime.rs"

[[test]]
name = "blocking"
path = "tests/blocking.rs"
Expand Down
272 changes: 272 additions & 0 deletions examples/connect_via_lower_priority_tokio_runtime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
#![deny(warnings)]
// This example demonstrates how to delegate the connect calls, which contain TLS handshakes,
// to a secondary tokio runtime of lower OS thread priority using a custom tower layer.
// This helps to ensure that long-running futures during handshake crypto operations don't block other I/O futures.
//
// This does introduce overhead of additional threads, channels, extra vtables, etc,
// so it is best suited to services with large numbers of incoming connections or that
// are otherwise very sensitive to any blocking futures. Or, you might want fewer threads
// and/or to use the current_thread runtime.
//
// This is using the `tokio` runtime and certain other dependencies:
//
// `tokio = { version = "1", features = ["full"] }`
// `num_cpus = "1.0"`
// `libc = "0"`
// `pin-project-lite = "0.2"`
// `tower = { version = "0.5", default-features = false}`

#[cfg(not(target_arch = "wasm32"))]
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
background_threadpool::init_background_runtime();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;

let client = reqwest::Client::builder()
.connector_layer(background_threadpool::BackgroundProcessorLayer::<http::Uri>::new())
.build()
.expect("should be able to build reqwest client");

let url = if let Some(url) = std::env::args().nth(1) {
url
} else {
println!("No CLI URL provided, using default.");
"https://hyper.rs".into()
};

eprintln!("Fetching {url:?}...");

let res = client.get(url).send().await?;

eprintln!("Response: {:?} {}", res.version(), res.status());
eprintln!("Headers: {:#?}\n", res.headers());

let body = res.text().await?;

println!("{body}");

Ok(())
}

// separating out for convenience to avoid a million #[cfg(not(target_arch = "wasm32"))]
#[cfg(not(target_arch = "wasm32"))]
mod background_threadpool {
use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::OnceLock,
task::{Context, Poll},
};

use futures_util::TryFutureExt;
use pin_project_lite::pin_project;
use tokio::{runtime::Handle, select, sync::mpsc::error::TrySendError};
use tower::{BoxError, Layer, Service};

static CPU_HEAVY_THREAD_POOL: OnceLock<
tokio::sync::mpsc::Sender<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
> = OnceLock::new();

pub(crate) fn init_background_runtime() {
std::thread::Builder::new()
.name("cpu-heavy-background-threadpool".to_string())
.spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.thread_name("cpu-heavy-background-pool-thread")
.worker_threads(num_cpus::get() as usize)
// ref: https://github.com/tokio-rs/tokio/issues/4941
// consider uncommenting if seeing heavy task contention
// .disable_lifo_slot()
.on_thread_start(move || unsafe {
// Reduce thread pool thread niceness, so they are lower priority
// than the foreground executor and don't interfere with I/O tasks
#[cfg(target_os = "linux")]
{
*libc::__errno_location() = 0;
if libc::nice(10) == -1 && *libc::__errno_location() != 0 {
let error = std::io::Error::last_os_error();
log::error!("failed to set threadpool niceness: {}", error);
}
}
})
.enable_all()
.build()
.unwrap_or_else(|e| panic!("cpu heavy runtime failed_to_initialize: {}", e));
rt.block_on(async {
log::debug!("starting background cpu-heavy work");
process_cpu_work().await;
});
})
.unwrap_or_else(|e| panic!("cpu heavy thread failed_to_initialize: {}", e));
}

#[cfg(not(target_arch = "wasm32"))]
async fn process_cpu_work() {
// we only use this channel for routing work, it should move pretty quick, it can be small
let (tx, mut rx) = tokio::sync::mpsc::channel(10);
// share the handle to the background channel globally
CPU_HEAVY_THREAD_POOL.set(tx).unwrap();

while let Some(work) = rx.recv().await {
tokio::task::spawn(work);
}
}

// retrieve the sender to the background channel, and send the future over to it for execution
fn send_to_background_runtime(future: impl Future<Output = ()> + Send + 'static) {
let tx = CPU_HEAVY_THREAD_POOL.get().expect(
"start up the secondary tokio runtime before sending to `CPU_HEAVY_THREAD_POOL`",
);

match tx.try_send(Box::pin(future)) {
Ok(_) => (),
Err(TrySendError::Closed(_)) => {
panic!("background cpu heavy runtime channel is closed")
}
Err(TrySendError::Full(msg)) => {
log::warn!(
"background cpu heavy runtime channel is full, task spawning loop delayed"
);
let tx = tx.clone();
Handle::current().spawn(async move {
tx.send(msg)
.await
.expect("background cpu heavy runtime channel is closed")
});
}
}
}

// This tower layer injects futures with a oneshot channel, and then sends them to the background runtime for processing.
// We don't use the Buffer service because that is intended to process sequentially on a single task, whereas we want to
// spawn a new task per call.
pub struct BackgroundProcessorLayer<Request> {
_p: PhantomData<fn(Request)>,
}
impl<Request> BackgroundProcessorLayer<Request> {
pub fn new() -> Self {
Self { _p: PhantomData }
}
}
impl<S, Request> Layer<S> for BackgroundProcessorLayer<Request> {
type Service = BackgroundProcessor<S>;
fn layer(&self, service: S) -> Self::Service {
BackgroundProcessor::new(service)
}
}

impl<Request> std::fmt::Debug for BackgroundProcessorLayer<Request> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("BackgroundProcessorLayer").finish()
}
}

impl<Request> Clone for BackgroundProcessorLayer<Request> {
fn clone(&self) -> Self {
Self { _p: PhantomData }
}
}

impl<Request> Copy for BackgroundProcessorLayer<Request> {}

// This tower service injects futures with a oneshot channel, and then sends them to the background runtime for processing.
#[derive(Debug, Clone)]
pub struct BackgroundProcessor<S> {
inner: S,
}

impl<S> BackgroundProcessor<S> {
pub fn new(inner: S) -> Self {
BackgroundProcessor { inner }
}
}

impl<S, Request> Service<Request> for BackgroundProcessor<S>
where
S: Service<Request>,
S::Response: Send + 'static,
S::Error: Into<BoxError> + Send,
S::Future: Send + 'static,
{
type Response = S::Response;

type Error = BoxError;

type Future = BackgroundResponseFuture<S::Response>;

fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
match self.inner.poll_ready(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)),
}
}

fn call(&mut self, req: Request) -> Self::Future {
let response = self.inner.call(req);

// wrap our inner service's future with a future that writes to this oneshot channel
let (mut tx, rx) = tokio::sync::oneshot::channel();
let future = async move {
select!(
_ = tx.closed() => {
// receiver already dropped, don't need to do anything
}
result = response.map_err(|err| Into::<BoxError>::into(err)) => {
// if this fails, the receiver already dropped, so we don't need to do anything
let _ = tx.send(result);
}
)
};
// send the wrapped future to the background
send_to_background_runtime(future);

BackgroundResponseFuture::new(rx)
}
}

// `BackgroundProcessor` response future
pin_project! {
#[derive(Debug)]
pub struct BackgroundResponseFuture<S> {
#[pin]
rx: tokio::sync::oneshot::Receiver<Result<S, BoxError>>,
}
}

impl<S> BackgroundResponseFuture<S> {
pub(crate) fn new(rx: tokio::sync::oneshot::Receiver<Result<S, BoxError>>) -> Self {
BackgroundResponseFuture { rx }
}
}

impl<S> Future for BackgroundResponseFuture<S>
where
S: Send + 'static,
{
type Output = Result<S, BoxError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

// now poll on the receiver end of the oneshot to get the result
match this.rx.poll(cx) {
Poll::Ready(v) => match v {
Ok(v) => Poll::Ready(v.map_err(Into::into)),
Err(err) => Poll::Ready(Err(Box::new(err) as BoxError)),
},
Poll::Pending => Poll::Pending,
}
}
}
}

// The [cfg(not(target_arch = "wasm32"))] above prevent building the tokio::main function
// for wasm32 target, because tokio isn't compatible with wasm32.
// If you aren't building for wasm32, you don't need that line.
// The two lines below avoid the "'main' function not found" error when building for wasm32 target.
#[cfg(any(target_arch = "wasm32"))]
fn main() {}
Loading

0 comments on commit f04f579

Please sign in to comment.