diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index cd96dce2..98de1bfa 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -229,6 +229,11 @@ impl Recv { return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); } + if pseudo.status.is_some() && counts.peer().is_server() { + proto_err!(stream: "cannot use :status header for requests; stream={:?}", stream.id); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); + } + if !pseudo.is_informational() { let message = counts .peer() @@ -239,27 +244,31 @@ impl Recv { .pending_recv .push_back(&mut self.buffer, Event::Headers(message)); stream.notify_recv(); - } - // Only servers can receive a headers frame that initiates the stream. - // This is verified in `Streams` before calling this function. - if counts.peer().is_server() { - self.pending_accept.push(stream); + // Only servers can receive a headers frame that initiates the stream. + // This is verified in `Streams` before calling this function. + if counts.peer().is_server() { + // Correctness: never push a stream to `pending_accept` without having the + // corresponding headers frame pushed to `stream.pending_recv`. + self.pending_accept.push(stream); + } } Ok(()) } /// Called by the server to get the request - pub fn take_request(&mut self, stream: &mut store::Ptr) -> Result, proto::Error> { + /// + /// # Panics + /// + /// Panics if `stream.pending_recv` has no `Event::Headers` queued. + /// + pub fn take_request(&mut self, stream: &mut store::Ptr) -> Request<()> { use super::peer::PollMessage::*; match stream.pending_recv.pop_front(&mut self.buffer) { - Some(Event::Headers(Server(request))) => Ok(request), - _ => { - proto_err!(stream: "received invalid request; stream={:?}", stream.id); - Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR)) - } + Some(Event::Headers(Server(request))) => request, + _ => unreachable!("server stream queue must start with Headers"), } } diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index d64e0097..dfc5c768 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1178,7 +1178,7 @@ impl StreamRef { /// # Panics /// /// This function panics if the request isn't present. - pub fn take_request(&self) -> Result, proto::Error> { + pub fn take_request(&self) -> Request<()> { let mut me = self.opaque.inner.lock().unwrap(); let me = &mut *me; diff --git a/src/server.rs b/src/server.rs index 148cad51..f1f4cf47 100644 --- a/src/server.rs +++ b/src/server.rs @@ -425,20 +425,13 @@ where if let Some(inner) = self.connection.next_incoming() { tracing::trace!("received incoming"); - match inner.take_request() { - Ok(req) => { - let (head, _) = req.into_parts(); - let body = RecvStream::new(FlowControl::new(inner.clone_to_opaque())); + let (head, _) = inner.take_request().into_parts(); + let body = RecvStream::new(FlowControl::new(inner.clone_to_opaque())); - let request = Request::from_parts(head, body); - let respond = SendResponse { inner }; + let request = Request::from_parts(head, body); + let respond = SendResponse { inner }; - return Poll::Ready(Some(Ok((request, respond)))); - } - Err(e) => { - return Poll::Ready(Some(Err(e.into()))); - } - } + return Poll::Ready(Some(Ok((request, respond)))); } Poll::Pending diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index 2637011f..0d7bb61c 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -1380,7 +1380,7 @@ async fn reject_non_authority_target_on_connect_request() { } #[tokio::test] -async fn reject_response_headers_in_request() { +async fn reject_informational_status_header_in_request() { h2_support::trace_init!(); let (io, mut client) = mock::new(); @@ -1388,21 +1388,22 @@ async fn reject_response_headers_in_request() { let client = async move { let _ = client.assert_server_handshake().await; - client.send_frame(frames::headers(1).response(128)).await; + let status_code = 128; + assert!(StatusCode::from_u16(status_code) + .unwrap() + .is_informational()); - // TODO: is CANCEL the right error code to expect here? - client.recv_frame(frames::reset(1).cancel()).await; + client + .send_frame(frames::headers(1).response(status_code)) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; }; let srv = async move { let builder = server::Builder::new(); let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); - let res = srv.next().await; - tracing::warn!("{:?}", res); - assert!(res.is_some()); - assert!(res.unwrap().is_err()); - poll_fn(move |cx| srv.poll_closed(cx)) .await .expect("server");