diff --git a/src/error.rs b/src/error.rs index 96a471bc..b08ad33e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,6 +3,7 @@ use crate::frame::StreamId; use crate::proto::{self, Initiator}; use bytes::Bytes; +use http::StatusCode; use std::{error, fmt, io}; pub use crate::frame::Reason; @@ -26,7 +27,7 @@ pub struct Error { enum Kind { /// A RST_STREAM frame was received or sent. #[allow(dead_code)] - Reset(StreamId, Reason, Initiator), + Reset(StreamId, Reason, Initiator, Option), /// A GO_AWAY frame was received or sent. GoAway(Bytes, Reason, Initiator), @@ -51,7 +52,7 @@ impl Error { /// action taken by the peer (i.e. a protocol error). pub fn reason(&self) -> Option { match self.kind { - Kind::Reset(_, reason, _) | Kind::GoAway(_, reason, _) | Kind::Reason(reason) => { + Kind::Reset(_, reason, _, _) | Kind::GoAway(_, reason, _) | Kind::Reason(reason) => { Some(reason) } _ => None, @@ -101,7 +102,7 @@ impl Error { pub fn is_remote(&self) -> bool { matches!( self.kind, - Kind::GoAway(_, _, Initiator::Remote) | Kind::Reset(_, _, Initiator::Remote) + Kind::GoAway(_, _, Initiator::Remote) | Kind::Reset(_, _, Initiator::Remote, _) ) } @@ -111,7 +112,7 @@ impl Error { pub fn is_library(&self) -> bool { matches!( self.kind, - Kind::GoAway(_, _, Initiator::Library) | Kind::Reset(_, _, Initiator::Library) + Kind::GoAway(_, _, Initiator::Library) | Kind::Reset(_, _, Initiator::Library, _) ) } } @@ -122,7 +123,9 @@ impl From for Error { Error { kind: match src { - Reset(stream_id, reason, initiator) => Kind::Reset(stream_id, reason, initiator), + Reset(stream_id, reason, initiator, status_code) => { + Kind::Reset(stream_id, reason, initiator, status_code) + } GoAway(debug_data, reason, initiator) => { Kind::GoAway(debug_data, reason, initiator) } @@ -162,13 +165,13 @@ impl From for Error { impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { let debug_data = match self.kind { - Kind::Reset(_, reason, Initiator::User) => { + Kind::Reset(_, reason, Initiator::User, _) => { return write!(fmt, "stream error sent by user: {}", reason) } - Kind::Reset(_, reason, Initiator::Library) => { + Kind::Reset(_, reason, Initiator::Library, _) => { return write!(fmt, "stream error detected: {}", reason) } - Kind::Reset(_, reason, Initiator::Remote) => { + Kind::Reset(_, reason, Initiator::Remote, _) => { return write!(fmt, "stream error received: {}", reason) } Kind::GoAway(ref debug_data, reason, Initiator::User) => { diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 5589fabc..4f32e6fe 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -454,7 +454,7 @@ where // Attempting to read a frame resulted in a stream level error. // This is handled by resetting the frame then trying to read // another frame. - Err(Error::Reset(id, reason, initiator)) => { + Err(Error::Reset(id, reason, initiator, _)) => { debug_assert_eq!(initiator, Initiator::Library); tracing::trace!(?id, ?reason, "stream error"); self.streams.send_reset(id, reason); diff --git a/src/proto/error.rs b/src/proto/error.rs index ad023317..70e3e496 100644 --- a/src/proto/error.rs +++ b/src/proto/error.rs @@ -2,13 +2,14 @@ use crate::codec::SendError; use crate::frame::{Reason, StreamId}; use bytes::Bytes; +use http::StatusCode; use std::fmt; use std::io; /// Either an H2 reason or an I/O error #[derive(Clone, Debug)] pub enum Error { - Reset(StreamId, Reason, Initiator), + Reset(StreamId, Reason, Initiator, Option), GoAway(Bytes, Reason, Initiator), Io(io::ErrorKind, Option), } @@ -23,7 +24,7 @@ pub enum Initiator { impl Error { pub(crate) fn is_local(&self) -> bool { match *self { - Self::Reset(_, _, initiator) | Self::GoAway(_, _, initiator) => initiator.is_local(), + Self::Reset(_, _, initiator, _) | Self::GoAway(_, _, initiator) => initiator.is_local(), Self::Io(..) => true, } } @@ -33,7 +34,15 @@ impl Error { } pub(crate) fn library_reset(stream_id: StreamId, reason: Reason) -> Self { - Self::Reset(stream_id, reason, Initiator::Library) + Self::Reset(stream_id, reason, Initiator::Library, None) + } + + pub(crate) fn library_reset_with_status_code( + stream_id: StreamId, + reason: Reason, + status_code: StatusCode, + ) -> Self { + Self::Reset(stream_id, reason, Initiator::Library, Some(status_code)) } pub(crate) fn library_go_away(reason: Reason) -> Self { @@ -45,7 +54,7 @@ impl Error { } pub(crate) fn remote_reset(stream_id: StreamId, reason: Reason) -> Self { - Self::Reset(stream_id, reason, Initiator::Remote) + Self::Reset(stream_id, reason, Initiator::Remote, None) } pub(crate) fn remote_go_away(debug_data: Bytes, reason: Reason) -> Self { @@ -65,7 +74,7 @@ impl Initiator { impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { match *self { - Self::Reset(_, reason, _) | Self::GoAway(_, reason, _) => reason.fmt(fmt), + Self::Reset(_, reason, _, _) | Self::GoAway(_, reason, _) => reason.fmt(fmt), Self::Io(_, Some(ref inner)) => inner.fmt(fmt), Self::Io(kind, None) => io::Error::from(kind).fmt(fmt), } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 56092759..a705487e 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -11,7 +11,7 @@ pub use self::error::{Error, Initiator}; pub(crate) use self::peer::{Dyn as DynPeer, Peer}; pub(crate) use self::ping_pong::UserPings; pub(crate) use self::streams::{DynStreams, OpaqueStreamRef, StreamRef, Streams}; -pub(crate) use self::streams::{Open, PollReset, Prioritized}; +pub(crate) use self::streams::{Open, PollReset, Prioritized, SendResetContext}; use crate::codec::Codec; diff --git a/src/proto/streams/mod.rs b/src/proto/streams/mod.rs index c4a83234..d266be33 100644 --- a/src/proto/streams/mod.rs +++ b/src/proto/streams/mod.rs @@ -12,7 +12,7 @@ mod streams; pub(crate) use self::prioritize::Prioritized; pub(crate) use self::recv::Open; -pub(crate) use self::send::PollReset; +pub(crate) use self::send::{PollReset, SendResetContext}; pub(crate) use self::streams::{DynStreams, OpaqueStreamRef, StreamRef, Streams}; use self::buffer::Buffer; diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 2a7abba0..2a529f20 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -47,6 +47,33 @@ pub(crate) enum PollReset { Streaming, } +/// Context for `send_reset`. +pub(crate) struct SendResetContext { + reason: Reason, + initiator: Initiator, + status_code: Option, +} + +impl SendResetContext { + /// Create a new `SendResetContext` with and optional `http::StatusCode`. + pub(crate) fn with_status_code( + reason: Reason, + initiator: Initiator, + status_code: Option, + ) -> Self { + Self { + reason, + initiator, + status_code, + } + } + + /// Create a new `SendResetContext` + pub(crate) fn new(reason: Reason, initiator: Initiator) -> Self { + Self::with_status_code(reason, initiator, None) + } +} + impl Send { /// Create a new `Send` pub fn new(config: &Config) -> Self { @@ -170,8 +197,7 @@ impl Send { /// Send an explicit RST_STREAM frame pub fn send_reset( &mut self, - reason: Reason, - initiator: Initiator, + context: SendResetContext, buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, @@ -182,10 +208,16 @@ impl Send { let is_empty = stream.pending_send.is_empty(); let stream_id = stream.id; + let SendResetContext { + reason, + initiator, + status_code, + } = context; + tracing::trace!( "send_reset(..., reason={:?}, initiator={:?}, stream={:?}, ..., \ is_reset={:?}; is_closed={:?}; pending_send.is_empty={:?}; \ - state={:?} \ + state={:?}; status_code={:?} \ ", reason, initiator, @@ -193,7 +225,8 @@ impl Send { is_reset, is_closed, is_empty, - stream.state + stream.state, + status_code ); if is_reset { @@ -225,6 +258,19 @@ impl Send { // `reclaim_all_capacity`. self.prioritize.clear_queue(buffer, stream); + // For malformed requests, a server may send an HTTP response prior to resetting the stream. + if let Some(status_code) = status_code { + tracing::trace!("send_reset -- sending response with status code: {status_code}"); + let pseudo = frame::Pseudo::response(status_code); + let fields = http::HeaderMap::default(); + let mut frame = frame::Headers::new(stream.id, pseudo, fields); + frame.set_end_stream(); + + tracing::trace!("send_reset -- queueing response; frame={:?}", frame); + self.prioritize + .queue_frame(frame.into(), buffer, stream, task); + } + let frame = frame::Reset::new(stream.id, reason); tracing::trace!("send_reset -- queueing; frame={:?}", frame); @@ -378,8 +424,7 @@ impl Send { tracing::debug!("recv_stream_window_update !!; err={:?}", e); self.send_reset( - Reason::FLOW_CONTROL_ERROR, - Initiator::Library, + SendResetContext::new(Reason::FLOW_CONTROL_ERROR, Initiator::Library), buffer, stream, counts, diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 5256f09c..f8181c49 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -333,7 +333,9 @@ impl State { /// Set the stream state to reset locally. pub fn set_reset(&mut self, stream_id: StreamId, reason: Reason, initiator: Initiator) { - self.inner = Closed(Cause::Error(Error::Reset(stream_id, reason, initiator))); + self.inner = Closed(Cause::Error(Error::Reset( + stream_id, reason, initiator, None, + ))); } /// Set the stream state to a scheduled reset. @@ -364,7 +366,7 @@ impl State { pub fn is_remote_reset(&self) -> bool { matches!( self.inner, - Closed(Cause::Error(Error::Reset(_, _, Initiator::Remote))) + Closed(Cause::Error(Error::Reset(_, _, Initiator::Remote, _))) ) } @@ -446,7 +448,7 @@ impl State { /// Returns a reason if the stream has been reset. pub(super) fn ensure_reason(&self, mode: PollReset) -> Result, crate::Error> { match self.inner { - Closed(Cause::Error(Error::Reset(_, reason, _))) + Closed(Cause::Error(Error::Reset(_, reason, _, _))) | Closed(Cause::Error(Error::GoAway(_, reason, _))) | Closed(Cause::ScheduledLibraryReset(reason)) => Ok(Some(reason)), Closed(Cause::Error(ref e)) => Err(e.clone().into()), diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 132d91bd..aa97e7c7 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1547,8 +1547,7 @@ impl Actions { ) { counts.transition(stream, |counts, stream| { self.send.send_reset( - reason, - initiator, + proto::SendResetContext::new(reason, initiator), send_buffer, stream, counts, @@ -1567,15 +1566,20 @@ impl Actions { counts: &mut Counts, res: Result<(), Error>, ) -> Result<(), Error> { - if let Err(Error::Reset(stream_id, reason, initiator)) = res { + if let Err(Error::Reset(stream_id, reason, initiator, status_code)) = res { debug_assert_eq!(stream_id, stream.id); if counts.can_inc_num_local_error_resets() { counts.inc_num_local_error_resets(); // Reset the stream. - self.send - .send_reset(reason, initiator, buffer, stream, counts, &mut self.task); + self.send.send_reset( + proto::SendResetContext::with_status_code(reason, initiator, status_code), + buffer, + stream, + counts, + &mut self.task, + ); Ok(()) } else { tracing::warn!( diff --git a/src/server.rs b/src/server.rs index b00bc086..4956a7e4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -121,7 +121,7 @@ use crate::proto::{self, Config, Error, Prioritized}; use crate::{FlowControl, PingPong, RecvStream, SendStream}; use bytes::{Buf, Bytes}; -use http::{HeaderMap, Method, Request, Response}; +use http::{HeaderMap, Method, Request, Response, StatusCode}; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -1517,7 +1517,7 @@ impl proto::Peer for Peer { macro_rules! malformed { ($($arg:tt)*) => {{ tracing::debug!($($arg)*); - return Err(Error::library_reset(stream_id, Reason::PROTOCOL_ERROR)); + return Err(Error::library_reset_with_status_code(stream_id, Reason::PROTOCOL_ERROR, StatusCode::BAD_REQUEST)); }} } diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index c1af5419..06b6fac9 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -503,6 +503,9 @@ async fn recv_invalid_authority() { let settings = client.assert_server_handshake().await; assert_default_settings!(settings); client.send_frame(bad_headers).await; + client + .recv_frame(frames::headers(1).status(StatusCode::BAD_REQUEST).eos()) + .await; client.recv_frame(frames::reset(1).protocol_error()).await; }; @@ -1338,6 +1341,9 @@ async fn reject_pseudo_protocol_on_non_connect_request() { ))) .await; + client + .recv_frame(frames::headers(1).status(StatusCode::BAD_REQUEST).eos()) + .await; client.recv_frame(frames::reset(1).protocol_error()).await; }; @@ -1378,6 +1384,9 @@ async fn reject_extended_connect_request_without_scheme() { })) .await; + client + .recv_frame(frames::headers(1).status(StatusCode::BAD_REQUEST).eos()) + .await; client.recv_frame(frames::reset(1).protocol_error()).await; }; @@ -1418,6 +1427,9 @@ async fn reject_extended_connect_request_without_path() { })) .await; + client + .recv_frame(frames::headers(1).status(StatusCode::BAD_REQUEST).eos()) + .await; client.recv_frame(frames::reset(1).protocol_error()).await; }; diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index 9a377d79..a85f03ec 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -524,6 +524,9 @@ async fn recv_next_stream_id_updated_by_malformed_headers() { assert_default_settings!(settings); // bad headers -- should error. client.send_frame(bad_headers).await; + client + .recv_frame(frames::headers(1).status(StatusCode::BAD_REQUEST).eos()) + .await; client.recv_frame(frames::reset(1).protocol_error()).await; // this frame is good, but the stream id should already have been incr'd client