Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SSE less dependent on tokio #3154

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions axum/src/response/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use http::{header, HeaderValue, StatusCode};

mod redirect;

#[cfg(feature = "tokio")]
pub mod sse;

#[doc(no_inline)]
Expand All @@ -27,7 +26,6 @@ pub use axum_core::response::{
pub use self::redirect::Redirect;

#[doc(inline)]
#[cfg(feature = "tokio")]
pub use sse::Sse;

/// An HTML response.
Expand Down
196 changes: 124 additions & 72 deletions axum/src/response/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,18 @@ use futures_util::stream::{Stream, TryStream};
use http_body::Frame;
use pin_project_lite::pin_project;
use std::{
fmt,
future::Future,
fmt, mem,
pin::Pin,
task::{ready, Context, Poll},
time::Duration,
};
use sync_wrapper::SyncWrapper;
use tokio::time::Sleep;

/// An SSE response
#[derive(Clone)]
#[must_use]
pub struct Sse<S> {
stream: S,
keep_alive: Option<KeepAlive>,
}

impl<S> Sse<S> {
Expand All @@ -65,26 +62,22 @@ impl<S> Sse<S> {
S: TryStream<Ok = Event> + Send + 'static,
S::Error: Into<BoxError>,
{
Sse {
stream,
keep_alive: None,
}
Sse { stream }
}

/// Configure the interval between keep-alive messages.
///
/// Defaults to no keep-alive messages.
pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
self.keep_alive = Some(keep_alive);
self
#[cfg(feature = "tokio")]
pub fn keep_alive(self, keep_alive: KeepAlive) -> Sse<KeepAliveStream<S>> {
Sse {
stream: KeepAliveStream::new(keep_alive, self.stream),
}
}
}

impl<S> fmt::Debug for Sse<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Sse")
.field("stream", &format_args!("{}", std::any::type_name::<S>()))
.field("keep_alive", &self.keep_alive)
.finish()
}
}
Expand All @@ -102,7 +95,6 @@ where
],
Body::new(SseBody {
event_stream: SyncWrapper::new(self.stream),
keep_alive: self.keep_alive.map(KeepAliveStream::new),
}),
)
.into_response()
Expand All @@ -113,8 +105,6 @@ pin_project! {
struct SseBody<S> {
#[pin]
event_stream: SyncWrapper<S>,
#[pin]
keep_alive: Option<KeepAliveStream>,
}
}

Expand All @@ -131,35 +121,54 @@ where
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.project();

match this.event_stream.get_pin_mut().poll_next(cx) {
Poll::Pending => {
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
} else {
Poll::Pending
}
}
Poll::Ready(Some(Ok(event))) => {
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive.reset();
match ready!(this.event_stream.get_pin_mut().poll_next(cx)) {
Some(Ok(event)) => Poll::Ready(Some(Ok(Frame::data(event.finalize())))),
Some(Err(error)) => Poll::Ready(Some(Err(error))),
None => Poll::Ready(None),
}
}
}

#[derive(Debug, Clone)]
enum Buffer {
Active(BytesMut),
Finalized(Bytes),
}

impl Buffer {
fn as_mut(&mut self) -> &mut BytesMut {
match self {
Buffer::Active(bytes_mut) => bytes_mut,
Buffer::Finalized(bytes) => {
*self = Buffer::Active(BytesMut::from(mem::take(bytes)));
match self {
Buffer::Active(bytes_mut) => bytes_mut,
Buffer::Finalized(_) => unreachable!(),
}
Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
}
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
Poll::Ready(None) => Poll::Ready(None),
}
}
}

/// Server-sent event
#[derive(Debug, Default, Clone)]
#[derive(Debug, Clone)]
#[must_use]
pub struct Event {
buffer: BytesMut,
buffer: Buffer,
flags: EventFlags,
}

impl Event {
/// Default keep-alive event
pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n"));

const fn finalized(bytes: Bytes) -> Self {
Self {
buffer: Buffer::Finalized(bytes),
flags: EventFlags::from_bits(0),
}
}

/// Set the event's data data field(s) (`data: <content>`)
///
/// Newlines in `data` will automatically be broken across `data: ` fields.
Expand All @@ -179,7 +188,7 @@ impl Event {
T: AsRef<str>,
{
if self.flags.contains(EventFlags::HAS_DATA) {
panic!("Called `EventBuilder::data` multiple times");
panic!("Called `Event::data` multiple times");
}

for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
Expand Down Expand Up @@ -222,13 +231,14 @@ impl Event {
}
}
if self.flags.contains(EventFlags::HAS_DATA) {
panic!("Called `EventBuilder::json_data` multiple times");
panic!("Called `Event::json_data` multiple times");
}

self.buffer.extend_from_slice(b"data: ");
serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data)
let buffer = self.buffer.as_mut();
buffer.extend_from_slice(b"data: ");
serde_json::to_writer(IgnoreNewLines(buffer.writer()), &data)
.map_err(axum_core::Error::new)?;
self.buffer.put_u8(b'\n');
buffer.put_u8(b'\n');

self.flags.insert(EventFlags::HAS_DATA);

Expand Down Expand Up @@ -272,7 +282,7 @@ impl Event {
T: AsRef<str>,
{
if self.flags.contains(EventFlags::HAS_EVENT) {
panic!("Called `EventBuilder::event` multiple times");
panic!("Called `Event::event` multiple times");
}
self.flags.insert(EventFlags::HAS_EVENT);

Expand All @@ -292,33 +302,32 @@ impl Event {
/// Panics if this function has already been called on this event.
pub fn retry(mut self, duration: Duration) -> Event {
if self.flags.contains(EventFlags::HAS_RETRY) {
panic!("Called `EventBuilder::retry` multiple times");
panic!("Called `Event::retry` multiple times");
}
self.flags.insert(EventFlags::HAS_RETRY);

self.buffer.extend_from_slice(b"retry:");
let buffer = self.buffer.as_mut();
buffer.extend_from_slice(b"retry:");

let secs = duration.as_secs();
let millis = duration.subsec_millis();

if secs > 0 {
// format seconds
self.buffer
.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
buffer.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());

// pad milliseconds
if millis < 10 {
self.buffer.extend_from_slice(b"00");
buffer.extend_from_slice(b"00");
} else if millis < 100 {
self.buffer.extend_from_slice(b"0");
buffer.extend_from_slice(b"0");
}
}

// format milliseconds
self.buffer
.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
buffer.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());

self.buffer.put_u8(b'\n');
buffer.put_u8(b'\n');

self
}
Expand All @@ -340,7 +349,7 @@ impl Event {
T: AsRef<str>,
{
if self.flags.contains(EventFlags::HAS_ID) {
panic!("Called `EventBuilder::id` multiple times");
panic!("Called `Event::id` multiple times");
}
self.flags.insert(EventFlags::HAS_ID);

Expand All @@ -362,20 +371,36 @@ impl Event {
None,
"SSE field value cannot contain newlines or carriage returns",
);
self.buffer.extend_from_slice(name.as_bytes());
self.buffer.put_u8(b':');
self.buffer.put_u8(b' ');
self.buffer.extend_from_slice(value);
self.buffer.put_u8(b'\n');

let buffer = self.buffer.as_mut();
buffer.extend_from_slice(name.as_bytes());
buffer.put_u8(b':');
buffer.put_u8(b' ');
buffer.extend_from_slice(value);
buffer.put_u8(b'\n');
}

fn finalize(mut self) -> Bytes {
self.buffer.put_u8(b'\n');
self.buffer.freeze()
fn finalize(self) -> Bytes {
match self.buffer {
Buffer::Finalized(bytes) => bytes,
Buffer::Active(mut bytes_mut) => {
bytes_mut.put_u8(b'\n');
bytes_mut.freeze()
}
}
}
}

#[derive(Default, Debug, Copy, Clone, PartialEq)]
impl Default for Event {
fn default() -> Self {
Self {
buffer: Buffer::Active(BytesMut::new()),
flags: EventFlags::from_bits(0),
}
}
}

#[derive(Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
Expand Down Expand Up @@ -406,15 +431,15 @@ impl EventFlags {
#[derive(Debug, Clone)]
#[must_use]
pub struct KeepAlive {
event: Bytes,
event: Event,
max_interval: Duration,
}

impl KeepAlive {
/// Create a new `KeepAlive`.
pub fn new() -> Self {
Self {
event: Bytes::from_static(b":\n\n"),
event: Event::DEFAULT_KEEP_ALIVE,
max_interval: Duration::from_secs(15),
}
}
Expand Down Expand Up @@ -451,7 +476,7 @@ impl KeepAlive {
/// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
/// comments.
pub fn event(mut self, event: Event) -> Self {
self.event = event.finalize();
self.event = Event::finalized(event.finalize());
self
}
}
Expand All @@ -462,19 +487,25 @@ impl Default for KeepAlive {
}
}

#[cfg(feature = "tokio")]
pin_project! {
/// A wrapper around a stream that produces keep-alive events
#[derive(Debug)]
struct KeepAliveStream {
keep_alive: KeepAlive,
pub struct KeepAliveStream<S> {
#[pin]
alive_timer: Sleep,
alive_timer: tokio::time::Sleep,
#[pin]
inner: S,
keep_alive: KeepAlive,
}
}

impl KeepAliveStream {
fn new(keep_alive: KeepAlive) -> Self {
#[cfg(feature = "tokio")]
impl<S> KeepAliveStream<S> {
fn new(keep_alive: KeepAlive, inner: S) -> Self {
Self {
alive_timer: tokio::time::sleep(keep_alive.max_interval),
inner,
keep_alive,
}
}
Expand All @@ -484,17 +515,38 @@ impl KeepAliveStream {
this.alive_timer
.reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
}
}

fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
let this = self.as_mut().project();
#[cfg(feature = "tokio")]
impl<S, E> Stream for KeepAliveStream<S>
where
S: Stream<Item = Result<Event, E>>,
{
type Item = Result<Event, E>;

ready!(this.alive_timer.poll(cx));
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use std::future::Future;

let event = this.keep_alive.event.clone();
let mut this = self.as_mut().project();

self.reset();
match this.inner.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(event))) => {
self.reset();

Poll::Ready(event)
Poll::Ready(Some(Ok(event)))
}
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => {
ready!(this.alive_timer.poll(cx));

let event = this.keep_alive.event.clone();

self.reset();

Poll::Ready(Some(Ok(event)))
}
}
}
}

Expand Down