feat(s2n-quic-dc): only poll accepted streams that are ready
camshaft committed Dec 7, 2024
1 parent 645c36a commit 2eeb553
Showing 4 changed files with 364 additions and 38 deletions.
142 changes: 104 additions & 38 deletions dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs
@@ -15,6 +15,7 @@ use crate::{
server,
socket::Socket,
},
task::waker,
};
use core::{
future::poll_fn,
@@ -31,7 +32,11 @@ use s2n_quic_core::{
recovery::RttEstimator,
time::{Clock, Timestamp},
};
use std::{collections::VecDeque, io};
use std::{
collections::{BTreeSet, VecDeque},
io,
task::Waker,
};
use tokio::{
io::AsyncWrite as _,
net::{TcpListener, TcpStream},
@@ -49,6 +54,7 @@
backlog: usize,
accept_flavor: accept::Flavor,
subscriber: Sub,
waker_set: waker::Set,
}

impl<Sub> Acceptor<Sub>
@@ -74,6 +80,7 @@
backlog,
accept_flavor,
subscriber,
waker_set: Default::default(),
};

if let Ok(addr) = acceptor.socket.local_addr() {
@@ -90,23 +97,43 @@
acceptor
}

pub async fn run(self) {
pub async fn run(mut self) {
let drop_guard = DropLog;
let mut fresh = FreshQueue::new(&self);
let mut workers = WorkerSet::new(&self);
let mut workers = WorkerSet::new(&mut self);
let mut context = WorkerContext::new(&self);

let mut prev_waker: Option<Waker> = None;

poll_fn(move |cx| {
let waker_needs_update = if let Some(prev_waker) = prev_waker.as_mut() {
!prev_waker.will_wake(cx.waker())
} else {
true
};

if waker_needs_update {
prev_waker = Some(cx.waker().clone());
self.waker_set.update_root(cx.waker());
}

let now = self.env.clock().get_time();
let publisher = publisher(&self.subscriber, &now);

fresh.fill(cx, &self.socket, &publisher);

for (socket, remote_address) in fresh.drain() {
workers.push(socket, remote_address, now, &self.subscriber, &publisher);
workers.push(
socket,
remote_address,
&mut context,
now,
&self.subscriber,
&publisher,
);
}

let res = workers.poll(cx, &mut context, now, &publisher);
let res = workers.poll(&mut self.waker_set, &mut context, now, &publisher);

publisher.on_acceptor_tcp_loop_iteration_completed(
event::builder::AcceptorTcpLoopIterationCompleted {
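
Note on the waker bookkeeping at the top of `run`: cloning the task waker on every loop iteration would be wasted work, so the loop caches the previous waker and refreshes the set's root only when `Waker::will_wake` reports a different one. A minimal sketch of that caching pattern (the `RootWakerCache` name is illustrative, not from the crate):

    use std::task::{Context, Waker};

    struct RootWakerCache {
        prev: Option<Waker>,
    }

    impl RootWakerCache {
        /// Returns `true` when the executor handed us a different waker
        /// and downstream copies need to be refreshed.
        fn refresh(&mut self, cx: &Context<'_>) -> bool {
            let stale = match &self.prev {
                Some(prev) => !prev.will_wake(cx.waker()),
                None => true,
            };
            if stale {
                self.prev = Some(cx.waker().clone());
            }
            stale
        }
    }
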
@@ -261,6 +288,8 @@
/// This list is ordered by sojourn time, where the front of the list is the oldest. The front
/// will be the first to be reclaimed in the case of overload.
working: VecDeque<usize>,
/// A list of [`Worker`] entries that have transitioned from working to free
pending_removal: BTreeSet<usize>,
/// Tracks the [sojourn time](https://en.wikipedia.org/wiki/Mean_sojourn_time) of processing
/// streams in worker entries.
sojourn_time: RttEstimator,
@@ -271,19 +300,20 @@
Sub: event::Subscriber + Clone,
{
#[inline]
pub fn new(acceptor: &Acceptor<Sub>) -> Self {
pub fn new(acceptor: &mut Acceptor<Sub>) -> Self {
let backlog = acceptor.backlog;
let mut workers = Vec::with_capacity(backlog);
let mut free = VecDeque::with_capacity(backlog);
let now = acceptor.env.clock().get_time();
for idx in 0..backlog {
workers.push(Worker::new(now));
workers.push(Worker::new(now, acceptor.waker_set.waker(idx)));
free.push_back(idx);
}
Self {
workers: workers.into(),
free,
working: VecDeque::with_capacity(backlog),
pending_removal: BTreeSet::new(),
// set the initial estimate high to avoid backlog churn before we get stable samples
sojourn_time: RttEstimator::new(Duration::from_secs(30)),
}
@@ -294,6 +324,7 @@
&mut self,
stream: TcpStream,
remote_address: SocketAddress,
worker_cx: &mut WorkerContext<Sub>,
now: Timestamp,
subscriber: &Sub,
publisher: &Pub,
@@ -315,51 +346,85 @@
};
self.workers[idx].push(stream, remote_address, now, subscriber, publisher);
self.working.push_back(idx);

// kick off the initial poll to register wakers with the socket
self.poll_worker(worker_cx, now, publisher, idx);
}
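
The "initial poll" comment above is load-bearing: in this readiness-driven design a worker is only re-polled after its waker fires, and a future only registers its waker with the reactor when it is polled. Skipping the first poll would leave a freshly pushed worker permanently idle. A sketch of the registration mechanics (the `Armed` type is illustrative):

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};

    struct Armed {
        registered: Option<Waker>,
    }

    impl Future for Armed {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            // Capturing the waker is the "registration"; real I/O futures
            // hand it to the reactor (epoll/kqueue via tokio) instead.
            self.registered = Some(cx.waker().clone());
            Poll::Pending
        }
    }
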

#[inline]
pub fn poll<Pub>(
&mut self,
cx: &mut Context,
waker_set: &mut waker::Set,
worker_cx: &mut WorkerContext<Sub>,
now: Timestamp,
publisher: &Pub,
) -> ControlFlow<()>
where
Pub: EndpointPublisher,
{
// poll any workers that are in our set
for idx in waker_set.drain() {
if self.poll_worker(worker_cx, now, publisher, idx).is_break() {
return ControlFlow::Break(());
}
}

// check if there are any workers that are pending removal
if self.pending_removal.is_empty() {
return ControlFlow::Continue(());
}

// remove any workers that were freed up
self.working.retain(|idx| {
if self.pending_removal.remove(idx) {
self.free.push_back(*idx);
false
} else {
true
}
});

ControlFlow::Continue(())
}
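
Why the `pending_removal` indirection instead of freeing slots inline: `poll_worker` takes `&mut self` per index, so it cannot also restructure `self.working` mid-iteration. Finished workers are recorded in the set and the queue is compacted afterwards in a single `retain` pass, preserving the sojourn-time ordering of the entries that remain. A standalone sketch of that compaction:

    use std::collections::{BTreeSet, VecDeque};

    fn compact(
        working: &mut VecDeque<usize>,
        free: &mut VecDeque<usize>,
        pending_removal: &mut BTreeSet<usize>,
    ) {
        working.retain(|idx| {
            if pending_removal.remove(idx) {
                // worker finished; recycle its slot
                free.push_back(*idx);
                false
            } else {
                true
            }
        });
    }
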

#[inline]
fn poll_worker<Pub>(
&mut self,
worker_cx: &mut WorkerContext<Sub>,
now: Timestamp,
publisher: &Pub,
idx: usize,
) -> ControlFlow<()>
where
Pub: EndpointPublisher,
{
let mut cf = ControlFlow::Continue(());

self.working.retain(|&idx| {
let worker = &mut self.workers[idx];
let Poll::Ready(res) = worker.poll(cx, worker_cx, now, publisher) else {
// keep processing it
return true;
};
let worker = &mut self.workers[idx];
let Poll::Ready(res) = worker.poll(worker_cx, now, publisher) else {
return cf;
};

match res {
Ok(ControlFlow::Continue(())) => {
// update the accept_time estimate
self.sojourn_time.update_rtt(
Duration::ZERO,
worker.sojourn(now),
now,
true,
PacketNumberSpace::ApplicationData,
);
}
Ok(ControlFlow::Break(())) => {
cf = ControlFlow::Break(());
}
Err(Some(err)) => publisher
.on_acceptor_tcp_io_error(event::builder::AcceptorTcpIoError { error: &err }),
Err(None) => {}
match res {
Ok(ControlFlow::Continue(())) => {
// update the accept_time estimate
self.sojourn_time.update_rtt(
Duration::ZERO,
worker.sojourn(now),
now,
true,
PacketNumberSpace::ApplicationData,
);
}
Ok(ControlFlow::Break(())) => {
cf = ControlFlow::Break(());
}
Err(Some(err)) => publisher
.on_acceptor_tcp_io_error(event::builder::AcceptorTcpIoError { error: &err }),
Err(None) => {}
}

// the worker is done so remove it from the working queue
self.free.push_back(idx);
false
});
self.pending_removal.insert(idx);

cf
}
@@ -474,18 +539,20 @@
stream: Option<(TcpStream, SocketAddress)>,
subscriber_ctx: Option<Sub::ConnectionContext>,
state: WorkerState,
waker: Waker,
}

impl<Sub> Worker<Sub>
where
Sub: event::Subscriber + Clone,
{
pub fn new(now: Timestamp) -> Self {
pub fn new(now: Timestamp, waker: Waker) -> Self {
Self {
queue_time: now,
stream: None,
subscriber_ctx: None,
state: WorkerState::Init,
waker,
}
}

Expand Down Expand Up @@ -539,7 +606,6 @@ where
#[inline]
pub fn poll<Pub>(
&mut self,
cx: &mut Context,
context: &mut WorkerContext<Sub>,
now: Timestamp,
publisher: &Pub,
@@ -561,7 +627,7 @@
context.recv_buffer.clear();

let res = ready!(self.state.poll(
cx,
&mut core::task::Context::from_waker(&self.waker),
context,
&mut self.stream,
&mut self.subscriber_ctx,
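
The signature change above is the heart of the commit: `Worker::poll` no longer threads the accept loop's `Context` into every stream, but builds a `Context` from its own stored waker. When the underlying socket becomes ready it wakes that per-worker waker, which marks only this worker's bit in the set. A reduced sketch of the pattern:

    use core::future::Future;
    use core::pin::Pin;
    use core::task::{Context, Poll, Waker};

    fn poll_with_own_waker<F>(fut: &mut F, waker: &Waker) -> Poll<F::Output>
    where
        F: Future + Unpin,
    {
        // The reactor sees `waker`, not the root task's waker, so a
        // readiness event re-polls exactly one worker.
        Pin::new(fut).poll(&mut Context::from_waker(waker))
    }
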
3 changes: 3 additions & 0 deletions dc/s2n-quic-dc/src/task/waker.rs
@@ -1,4 +1,7 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

pub mod set;
pub mod worker;

pub use set::Set;
89 changes: 89 additions & 0 deletions dc/s2n-quic-dc/src/task/waker/set.rs
@@ -0,0 +1,89 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::worker;
use std::{
sync::{Arc, Mutex},
task::{Wake, Waker},
};

mod bitset;
use bitset::BitSet;

#[derive(Default)]
pub struct Set {
state: Arc<State>,
ready: BitSet,
}

impl Set {
/// Updates the root waker
pub fn update_root(&self, waker: &Waker) {
self.state.root.update(waker);
}

/// Registers a waker with the given ID
pub fn waker(&mut self, id: usize) -> Waker {
// reserve space in the locally ready set
self.ready.resize_for_id(id);
let state = self.state.clone();
state.ready.lock().unwrap().resize_for_id(id);
Waker::from(Arc::new(Slot { id, state }))
}

/// Returns all of the IDs that are woken
pub fn drain(&mut self) -> impl Iterator<Item = usize> + '_ {
core::mem::swap(&mut self.ready, &mut self.state.ready.lock().unwrap());
self.ready.drain()
}
}
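
`drain` uses a double-buffering trick: the locally owned `ready` set (empty after the previous drain) is swapped with the shared one while the lock is held, so the critical section is just a swap and iteration happens without contention from waking threads. The same shape, with `Vec<usize>` standing in for the crate's `BitSet`:

    use std::sync::Mutex;

    fn drain<'a>(
        local: &'a mut Vec<usize>,
        shared: &Mutex<Vec<usize>>,
    ) -> impl Iterator<Item = usize> + 'a {
        // Trade the empty local buffer for the shared ready list; the
        // lock is held only for the duration of the swap.
        core::mem::swap(local, &mut *shared.lock().unwrap());
        local.drain(..)
    }
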

#[derive(Default)]
struct State {
root: worker::Waker,
ready: Mutex<BitSet>,
}

struct Slot {
id: usize,
state: Arc<State>,
}

impl Wake for Slot {
#[inline]
fn wake(self: Arc<Self>) {
let mut ready = self.state.ready.lock().unwrap();
ready.insert(self.id);
drop(ready);
self.state.root.wake();
}
}
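
Two things are worth noting in `Slot::wake`: the std `Wake` trait gives a `Waker` for free via `Waker::from(Arc<Slot>)`, and a wake is two steps — mark the slot's bit, then forward a single wake to the root task, which drains the set on its next poll. A usage sketch of the resulting API (assuming this module's `Set`):

    use std::collections::BTreeSet;

    fn demo(set: &mut Set) {
        let a = set.waker(0);
        let b = set.waker(3);

        // Waking marks bits 0 and 3 and pokes the root waker twice.
        b.wake_by_ref();
        a.wake_by_ref();

        let woken: BTreeSet<usize> = set.drain().collect();
        assert_eq!(woken, BTreeSet::from([0, 3]));
    }
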

#[cfg(test)]
mod tests {
use super::*;
use std::collections::BTreeSet;

#[test]
fn waker_set_test() {
bolero::check!().with_type::<Vec<u8>>().for_each(|ops| {
let mut root = Set::default();
let mut wakers = vec![];

if let Some(max) = ops.iter().cloned().max() {
let len = max as usize + 1;
for i in 0..len {
wakers.push(root.waker(i));
}
}

for idx in ops {
wakers[*idx as usize].wake_by_ref();
}

let actual = root.drain().collect::<BTreeSet<_>>();
let expected = ops.iter().map(|v| *v as usize).collect::<BTreeSet<_>>();
assert_eq!(actual, expected);
})
}
}
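
The fourth changed file, the `bitset` module that backs `BitSet`, is not shown above. For orientation only, a minimal bitset satisfying the calls made in `Set` (`resize_for_id`, `insert`, and a consuming `drain`) might look like the following sketch — not the commit's actual implementation:

    #[derive(Default)]
    pub struct BitSet {
        words: Vec<u64>,
    }

    impl BitSet {
        /// Ensure the set can hold `id`.
        pub fn resize_for_id(&mut self, id: usize) {
            let len = id / 64 + 1;
            if self.words.len() < len {
                self.words.resize(len, 0);
            }
        }

        pub fn insert(&mut self, id: usize) {
            self.words[id / 64] |= 1u64 << (id % 64);
        }

        /// Yields each set ID once and clears the set.
        pub fn drain(&mut self) -> impl Iterator<Item = usize> + '_ {
            self.words.iter_mut().enumerate().flat_map(|(w, word)| {
                let bits = core::mem::take(word);
                (0..64).filter(move |b| bits & (1u64 << b) != 0).map(move |b| w * 64 + b)
            })
        }
    }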
