From aa0a73d9d79cf6b495c5f7d948872b62e6a52ab6 Mon Sep 17 00:00:00 2001 From: nanoqsh Date: Mon, 6 Jan 2025 09:16:18 +0500 Subject: [PATCH 1/3] Update `Sse::keep_alive` --- axum/src/response/mod.rs | 2 - axum/src/response/sse.rs | 189 +++++++++++++++++++++++++-------------- 2 files changed, 121 insertions(+), 70 deletions(-) diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index dd616dff57..70be745200 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -4,7 +4,6 @@ use http::{header, HeaderValue, StatusCode}; mod redirect; -#[cfg(feature = "tokio")] pub mod sse; #[doc(no_inline)] @@ -27,7 +26,6 @@ pub use axum_core::response::{ pub use self::redirect::Redirect; #[doc(inline)] -#[cfg(feature = "tokio")] pub use sse::Sse; /// An HTML response. diff --git a/axum/src/response/sse.rs b/axum/src/response/sse.rs index 54ec2b46a1..881357ec86 100644 --- a/axum/src/response/sse.rs +++ b/axum/src/response/sse.rs @@ -40,6 +40,7 @@ use pin_project_lite::pin_project; use std::{ fmt, future::Future, + mem, pin::Pin, task::{ready, Context, Poll}, time::Duration, @@ -52,7 +53,6 @@ use tokio::time::Sleep; #[must_use] pub struct Sse { stream: S, - keep_alive: Option, } impl Sse { @@ -65,18 +65,15 @@ impl Sse { S: TryStream + Send + 'static, S::Error: Into, { - Sse { - stream, - keep_alive: None, - } + Sse { stream } } /// Configure the interval between keep-alive messages. - /// - /// Defaults to no keep-alive messages. - pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self { - self.keep_alive = Some(keep_alive); - self + #[cfg(feature = "tokio")] + pub fn keep_alive(self, keep_alive: KeepAlive) -> Sse> { + Sse { + stream: KeepAliveStream::new(keep_alive, self.stream), + } } } @@ -84,7 +81,6 @@ impl fmt::Debug for Sse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Sse") .field("stream", &format_args!("{}", std::any::type_name::())) - .field("keep_alive", &self.keep_alive) .finish() } } @@ -102,7 +98,6 @@ where ], Body::new(SseBody { event_stream: SyncWrapper::new(self.stream), - keep_alive: self.keep_alive.map(KeepAliveStream::new), }), ) .into_response() @@ -113,8 +108,6 @@ pin_project! { struct SseBody { #[pin] event_stream: SyncWrapper, - #[pin] - keep_alive: Option, } } @@ -131,35 +124,54 @@ where ) -> Poll, Self::Error>>> { let this = self.project(); - match this.event_stream.get_pin_mut().poll_next(cx) { - Poll::Pending => { - if let Some(keep_alive) = this.keep_alive.as_pin_mut() { - keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e)))) - } else { - Poll::Pending - } - } - Poll::Ready(Some(Ok(event))) => { - if let Some(keep_alive) = this.keep_alive.as_pin_mut() { - keep_alive.reset(); + match ready!(this.event_stream.get_pin_mut().poll_next(cx)) { + Some(Ok(event)) => Poll::Ready(Some(Ok(Frame::data(event.finalize())))), + Some(Err(error)) => Poll::Ready(Some(Err(error))), + None => Poll::Ready(None), + } + } +} + +#[derive(Debug, Clone)] +enum Buffer { + Active(BytesMut), + Finalized(Bytes), +} + +impl Buffer { + fn as_mut(&mut self) -> &mut BytesMut { + match self { + Buffer::Active(bytes_mut) => bytes_mut, + Buffer::Finalized(bytes) => { + *self = Buffer::Active(BytesMut::from(mem::take(bytes))); + match self { + Buffer::Active(bytes_mut) => bytes_mut, + Buffer::Finalized(_) => unreachable!(), } - Poll::Ready(Some(Ok(Frame::data(event.finalize())))) } - Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))), - Poll::Ready(None) => Poll::Ready(None), } } } /// Server-sent event -#[derive(Debug, Default, Clone)] +#[derive(Debug, Clone)] #[must_use] pub struct Event { - buffer: BytesMut, + buffer: Buffer, flags: EventFlags, } impl Event { + /// Default keep-alive event + pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n")); + + const fn finalized(bytes: Bytes) -> Self { + Self { + buffer: Buffer::Finalized(bytes), + flags: EventFlags::from_bits(0), + } + } + /// Set the event's data data field(s) (`data: `) /// /// Newlines in `data` will automatically be broken across `data: ` fields. @@ -179,7 +191,7 @@ impl Event { T: AsRef, { if self.flags.contains(EventFlags::HAS_DATA) { - panic!("Called `EventBuilder::data` multiple times"); + panic!("Called `Event::data` multiple times"); } for line in memchr_split(b'\n', data.as_ref().as_bytes()) { @@ -222,13 +234,14 @@ impl Event { } } if self.flags.contains(EventFlags::HAS_DATA) { - panic!("Called `EventBuilder::json_data` multiple times"); + panic!("Called `Event::json_data` multiple times"); } - self.buffer.extend_from_slice(b"data: "); - serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data) + let buffer = self.buffer.as_mut(); + buffer.extend_from_slice(b"data: "); + serde_json::to_writer(IgnoreNewLines(buffer.writer()), &data) .map_err(axum_core::Error::new)?; - self.buffer.put_u8(b'\n'); + buffer.put_u8(b'\n'); self.flags.insert(EventFlags::HAS_DATA); @@ -272,7 +285,7 @@ impl Event { T: AsRef, { if self.flags.contains(EventFlags::HAS_EVENT) { - panic!("Called `EventBuilder::event` multiple times"); + panic!("Called `Event::event` multiple times"); } self.flags.insert(EventFlags::HAS_EVENT); @@ -292,33 +305,32 @@ impl Event { /// Panics if this function has already been called on this event. pub fn retry(mut self, duration: Duration) -> Event { if self.flags.contains(EventFlags::HAS_RETRY) { - panic!("Called `EventBuilder::retry` multiple times"); + panic!("Called `Event::retry` multiple times"); } self.flags.insert(EventFlags::HAS_RETRY); - self.buffer.extend_from_slice(b"retry:"); + let buffer = self.buffer.as_mut(); + buffer.extend_from_slice(b"retry:"); let secs = duration.as_secs(); let millis = duration.subsec_millis(); if secs > 0 { // format seconds - self.buffer - .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes()); + buffer.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes()); // pad milliseconds if millis < 10 { - self.buffer.extend_from_slice(b"00"); + buffer.extend_from_slice(b"00"); } else if millis < 100 { - self.buffer.extend_from_slice(b"0"); + buffer.extend_from_slice(b"0"); } } // format milliseconds - self.buffer - .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes()); + buffer.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes()); - self.buffer.put_u8(b'\n'); + buffer.put_u8(b'\n'); self } @@ -340,7 +352,7 @@ impl Event { T: AsRef, { if self.flags.contains(EventFlags::HAS_ID) { - panic!("Called `EventBuilder::id` multiple times"); + panic!("Called `Event::id` multiple times"); } self.flags.insert(EventFlags::HAS_ID); @@ -362,20 +374,36 @@ impl Event { None, "SSE field value cannot contain newlines or carriage returns", ); - self.buffer.extend_from_slice(name.as_bytes()); - self.buffer.put_u8(b':'); - self.buffer.put_u8(b' '); - self.buffer.extend_from_slice(value); - self.buffer.put_u8(b'\n'); + + let buffer = self.buffer.as_mut(); + buffer.extend_from_slice(name.as_bytes()); + buffer.put_u8(b':'); + buffer.put_u8(b' '); + buffer.extend_from_slice(value); + buffer.put_u8(b'\n'); } - fn finalize(mut self) -> Bytes { - self.buffer.put_u8(b'\n'); - self.buffer.freeze() + fn finalize(self) -> Bytes { + match self.buffer { + Buffer::Finalized(bytes) => bytes, + Buffer::Active(mut bytes_mut) => { + bytes_mut.put_u8(b'\n'); + bytes_mut.freeze() + } + } } } -#[derive(Default, Debug, Copy, Clone, PartialEq)] +impl Default for Event { + fn default() -> Self { + Self { + buffer: Buffer::Active(BytesMut::new()), + flags: EventFlags::from_bits(0), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq)] struct EventFlags(u8); impl EventFlags { @@ -406,7 +434,7 @@ impl EventFlags { #[derive(Debug, Clone)] #[must_use] pub struct KeepAlive { - event: Bytes, + event: Event, max_interval: Duration, } @@ -414,7 +442,7 @@ impl KeepAlive { /// Create a new `KeepAlive`. pub fn new() -> Self { Self { - event: Bytes::from_static(b":\n\n"), + event: Event::DEFAULT_KEEP_ALIVE, max_interval: Duration::from_secs(15), } } @@ -451,7 +479,7 @@ impl KeepAlive { /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE /// comments. pub fn event(mut self, event: Event) -> Self { - self.event = event.finalize(); + self.event = Event::finalized(event.finalize()); self } } @@ -462,19 +490,25 @@ impl Default for KeepAlive { } } +#[cfg(feature = "tokio")] pin_project! { + /// A wrapper around a stream that produces keep-alive events #[derive(Debug)] - struct KeepAliveStream { - keep_alive: KeepAlive, + pub struct KeepAliveStream { #[pin] alive_timer: Sleep, + #[pin] + inner: S, + keep_alive: KeepAlive, } } -impl KeepAliveStream { - fn new(keep_alive: KeepAlive) -> Self { +#[cfg(feature = "tokio")] +impl KeepAliveStream { + fn new(keep_alive: KeepAlive, inner: S) -> Self { Self { alive_timer: tokio::time::sleep(keep_alive.max_interval), + inner, keep_alive, } } @@ -484,17 +518,36 @@ impl KeepAliveStream { this.alive_timer .reset(tokio::time::Instant::now() + this.keep_alive.max_interval); } +} - fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); +#[cfg(feature = "tokio")] +impl Stream for KeepAliveStream +where + S: Stream>, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.as_mut().project(); + + match this.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(event))) => { + self.reset(); - ready!(this.alive_timer.poll(cx)); + Poll::Ready(Some(Ok(event))) + } + Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => { + ready!(this.alive_timer.poll(cx)); - let event = this.keep_alive.event.clone(); + let event = this.keep_alive.event.clone(); - self.reset(); + self.reset(); - Poll::Ready(event) + Poll::Ready(Some(Ok(event))) + } + } } } From 3a8d4620d80cda04b9da81fe721ad38d9c1d27f8 Mon Sep 17 00:00:00 2001 From: nanoqsh Date: Tue, 7 Jan 2025 09:40:50 +0500 Subject: [PATCH 2/3] Specify full path for `Sleep` --- axum/src/response/sse.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/axum/src/response/sse.rs b/axum/src/response/sse.rs index 881357ec86..5db5fd5c28 100644 --- a/axum/src/response/sse.rs +++ b/axum/src/response/sse.rs @@ -46,7 +46,6 @@ use std::{ time::Duration, }; use sync_wrapper::SyncWrapper; -use tokio::time::Sleep; /// An SSE response #[derive(Clone)] @@ -496,7 +495,7 @@ pin_project! { #[derive(Debug)] pub struct KeepAliveStream { #[pin] - alive_timer: Sleep, + alive_timer: tokio::time::Sleep, #[pin] inner: S, keep_alive: KeepAlive, From 69f4e81c55ba7b6d4c637ba58f40af27705f1a9a Mon Sep 17 00:00:00 2001 From: nanoqsh Date: Tue, 7 Jan 2025 09:44:31 +0500 Subject: [PATCH 3/3] Remove global `Future` import --- axum/src/response/sse.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/axum/src/response/sse.rs b/axum/src/response/sse.rs index 5db5fd5c28..a429c1bf4d 100644 --- a/axum/src/response/sse.rs +++ b/axum/src/response/sse.rs @@ -38,9 +38,7 @@ use futures_util::stream::{Stream, TryStream}; use http_body::Frame; use pin_project_lite::pin_project; use std::{ - fmt, - future::Future, - mem, + fmt, mem, pin::Pin, task::{ready, Context, Poll}, time::Duration, @@ -527,6 +525,8 @@ where type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use std::future::Future; + let mut this = self.as_mut().project(); match this.inner.as_mut().poll_next(cx) {