diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..00642f837 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: seanmonstar diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0dab7b3dd..a23753531 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -13,21 +13,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - profile: minimal - toolchain: stable - override: true components: rustfmt - - name: cargo fmt --check - uses: actions-rs/cargo@v1 - with: - command: fmt - args: --all -- --check + - run: cargo fmt --all --check test: name: Test @@ -43,43 +36,67 @@ jobs: - stable steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 - name: Install Rust (${{ matrix.rust }}) - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal toolchain: ${{ matrix.rust }} - override: true - name: Install libssl-dev run: sudo apt-get update && sudo apt-get install libssl-dev - name: Build without unstable flag - uses: actions-rs/cargo@v1 - with: - command: build + run: cargo build - name: Check with unstable flag - uses: actions-rs/cargo@v1 - with: - command: check - args: --features unstable + run: cargo check --features unstable - name: Run lib tests and doc tests - uses: actions-rs/cargo@v1 - with: - command: test + run: cargo test - name: Run integration tests - uses: actions-rs/cargo@v1 - with: - command: test - args: -p h2-tests + run: cargo test -p h2-tests - name: Run h2spec run: ./ci/h2spec.sh if: matrix.rust == 'stable' - - name: Check minimal versions - run: cargo clean; cargo update -Zminimal-versions; cargo check - if: matrix.rust == 'nightly' + #clippy_check: + # runs-on: ubuntu-latest + # steps: + # - uses: actions/checkout@v4 + # - name: Run Clippy + # run: 
cargo clippy --all-targets --all-features + + msrv: + name: Check MSRV + needs: [style] + + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Get MSRV from package metadata + id: msrv + run: grep rust-version Cargo.toml | cut -d '"' -f2 | sed 's/^/version=/' >> $GITHUB_OUTPUT + + - name: Install Rust (${{ steps.metadata.outputs.msrv }}) + id: msrv-toolchain + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ steps.msrv.outputs.version }} + + - run: cargo check -p h2 + + minimal-versions: + runs-on: ubuntu-latest + needs: [style] + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + - uses: dtolnay/rust-toolchain@stable + - uses: taiki-e/install-action@cargo-hack + - uses: taiki-e/install-action@cargo-minimal-versions + - run: cargo minimal-versions --ignore-private check diff --git a/CHANGELOG.md b/CHANGELOG.md index 66a88460e..3b9663dbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,173 @@ +# 0.4.4 (April 3, 2024) + +* Limit number of CONTINUATION frames for misbehaving connections. + +# 0.4.3 (March 15, 2024) + +* Fix flow control limits to not apply until receiving SETTINGS ack. +* Fix not returning an error if IO ended without `close_notify`. +* Improve performance of decoding many headers. + +# 0.4.2 (January 17th, 2024) + +* Limit error resets for misbehaving connections. +* Fix selecting MAX_CONCURRENT_STREAMS value if no value is advertised initially. + +# 0.4.1 (January 8, 2024) + +* Fix assigning connection capacity which could starve streams in some instances. + +# 0.4.0 (November 15, 2023) + +* Update to `http` 1.0. +* Remove deprecated `Server::poll_close()`. + +# 0.3.22 (November 15, 2023) + +* Add `header_table_size(usize)` option to client and server builders. +* Improve throughput when vectored IO is not available. +* Update indexmap to 2. + +# 0.3.21 (August 21, 2023) + +* Fix opening of new streams over peer's max concurrent limit. 
+* Fix `RecvStream` to return data even if it has received a `CANCEL` stream error. +* Update MSRV to 1.63. + +# 0.3.20 (June 26, 2023) + +* Fix panic if a server received a request with a `:status` pseudo header in the 1xx range. +* Fix panic if a reset stream had pending push promises that were more than allowed. +* Fix potential flow control overflow by subtraction, instead returning a connection error. + +# 0.3.19 (May 12, 2023) + +* Fix counting reset streams when triggered by a GOAWAY. +* Send `too_many_resets` in opaque debug data of GOAWAY when too many resets received. + +# 0.3.18 (April 17, 2023) + +* Fix panic because of opposite check in `is_remote_local()`. + +# 0.3.17 (April 13, 2023) + +* Add `Error::is_library()` method to check if the originated inside `h2`. +* Add `max_pending_accept_reset_streams(usize)` option to client and server + builders. +* Fix theoretical memory growth when receiving too many HEADERS and then + RST_STREAM frames faster than an application can accept them off the queue. + (CVE-2023-26964) + +# 0.3.16 (February 27, 2023) + +* Set `Protocol` extension on requests when received Extended CONNECT requests. +* Remove `B: Unpin + 'static` bound requiremented of bufs +* Fix releasing of frames when stream is finished, reducing memory usage. +* Fix panic when trying to send data and connection window is available, but stream window is not. +* Fix spurious wakeups when stream capacity is not available. + +# 0.3.15 (October 21, 2022) + +* Remove `B: Buf` bound on `SendStream`'s parameter +* add accessor for `StreamId` u32 + +# 0.3.14 (August 16, 2022) + +* Add `Error::is_reset` function. +* Bump MSRV to Rust 1.56. +* Return `RST_STREAM(NO_ERROR)` when the server early responds. + +# 0.3.13 (March 31, 2022) + +* Update private internal `tokio-util` dependency. 
+ +# 0.3.12 (March 9, 2022) + +* Avoid time operations that can panic (#599) +* Bump MSRV to Rust 1.49 (#606) +* Fix header decoding error when a header name is contained at a continuation + header boundary (#589) +* Remove I/O type names from handshake `tracing` spans (#608) + +# 0.3.11 (January 26, 2022) + +* Make `SendStream::poll_capacity` never return `Ok(Some(0))` (#596) +* Fix panic when receiving already reset push promise (#597) + +# 0.3.10 (January 6, 2022) + +* Add `Error::is_go_away()` and `Error::is_remote()` methods. +* Fix panic if receiving malformed PUSH_PROMISE with stream ID of 0. + +# 0.3.9 (December 9, 2021) + +* Fix hang related to new `max_send_buffer_size`. + +# 0.3.8 (December 8, 2021) + +* Add "extended CONNECT support". Adds `h2::ext::Protocol`, which is used for request and response extensions to connect new protocols over an HTTP/2 stream. +* Add `max_send_buffer_size` options to client and server builders, and a default of ~400MB. This acts like a high-water mark for the `poll_capacity()` method. +* Fix panic if receiving malformed HEADERS with stream ID of 0. + +# 0.3.7 (October 22, 2021) + +* Fix panic if server sends a malformed frame on a stream client was about to open. +* Fix server to treat `:status` in a request as a stream error instead of connection error. + +# 0.3.6 (September 30, 2021) + +* Fix regression of `h2::Error` that were created via `From` not returning their reason code in `Error::reason()`. + +# 0.3.5 (September 29, 2021) + +* Fix sending of very large headers. Previously when a single header was too big to fit in a single `HEADERS` frame, an error was returned. Now it is broken up and sent correctly. +* Fix buffered data field to be a bigger integer size. +* Refactor error format to include what initiated the error (remote, local, or user), if it was a stream or connection-level error, and any received debug data. + +# 0.3.4 (August 20, 2021) + +* Fix panic when encoding header size update over a certain size. 
+* Fix `SendRequest` to wake up connection when dropped. +* Fix potential hang if `RecvStream` is placed in the request or response `extensions`. +* Stop calling `Instant::now` if zero reset streams are configured. + +# 0.3.3 (April 29, 2021) + +* Fix client being able to make `CONNECT` requests without a `:path`. +* Expose `RecvStream::poll_data`. +* Fix some docs. + +# 0.3.2 (March 24, 2021) + +* Fix incorrect handling of received 1xx responses on the client when the request body is still streaming. + +# 0.3.1 (February 26, 2021) + +* Add `Connection::max_concurrent_recv_streams()` getter. +* Add `Connection::max_concurrent_send_streams()` getter. +* Fix client to ignore receipt of 1xx headers frames. +* Fix incorrect calculation of pseudo header lengths when determining if a received header is too big. +* Reduce monomorphized code size of internal code. + +# 0.3.0 (December 23, 2020) + +* Update to Tokio v1 and Bytes v1. +* Disable `tracing`'s `log` feature. (It can still be enabled by a user in their own `Cargo.toml`.) + +# 0.2.7 (October 22, 2020) + +* Fix stream ref count when sending a push promise +* Fix receiving empty DATA frames in response to a HEAD request +* Fix handling of client disabling SERVER_PUSH + +# 0.2.6 (July 13, 2020) + +* Integrate `tracing` directly where `log` was used. (For 0.2.x, `log`s are still emitted by default.) + +# 0.2.5 (May 6, 2020) + +* Fix rare debug assert failure in store shutdown. + # 0.2.4 (March 30, 2020) * Fix when receiving `SETTINGS_HEADER_TABLE_SIZE` setting. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 10e74bf29..4b69dc699 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,7 +5,7 @@ ## Getting Help ## If you have a question about the h2 library or have encountered problems using it, you may -[file an issue][issue] or ask ask a question on the [Tokio Gitter][gitter]. +[file an issue][issue] or ask a question on the [Tokio Discord][discord]. 
## Submitting a Pull Request ## @@ -15,7 +15,7 @@ Do you have an improvement? 2. We will try to respond to your issue promptly. 3. Fork this repo, develop and test your code changes. See the project's [README](README.md) for further information about working in this repository. 4. Submit a pull request against this repo's `master` branch. -6. Your branch may be merged once all configured checks pass, including: +5. Your branch may be merged once all configured checks pass, including: - Code review has been completed. - The branch has passed tests in CI. @@ -81,4 +81,4 @@ Describe the testing you've done to validate your change. Performance-related changes should include before- and after- benchmark results. [issue]: https://github.com/hyperium/h2/issues/new -[gitter]: https://gitter.im/tokio-rs/tokio +[discord]: https://discord.gg/tokio diff --git a/Cargo.toml b/Cargo.toml index ede6b9143..c76b9ecf9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,24 +1,23 @@ [package] name = "h2" # When releasing to crates.io: -# - Update doc URL. -# - html_root_url. # - Update CHANGELOG.md. # - Create git tag -version = "0.2.4" +version = "0.4.4" license = "MIT" authors = [ "Carl Lerche ", "Sean McArthur ", ] -description = "An HTTP/2.0 client and server" -documentation = "https://docs.rs/h2/0.2.4/h2/" +description = "An HTTP/2 client and server" +documentation = "https://docs.rs/h2" repository = "https://github.com/hyperium/h2" readme = "README.md" keywords = ["http", "async", "non-blocking"] categories = ["asynchronous", "web-programming", "network-programming"] exclude = ["fixtures/**", "ci/**"] -edition = "2018" +edition = "2021" +rust-version = "1.63" [features] # Enables `futures::Stream` implementations for various types. 
@@ -43,31 +42,36 @@ members = [ futures-core = { version = "0.3", default-features = false } futures-sink = { version = "0.3", default-features = false } futures-util = { version = "0.3", default-features = false } -tokio-util = { version = "0.3.1", features = ["codec"] } -tokio = { version = "0.2", features = ["io-util"] } -bytes = "0.5.2" -http = "0.2" -log = "0.4.1" +tokio-util = { version = "0.7.1", features = ["codec", "io"] } +tokio = { version = "1", features = ["io-util"] } +bytes = "1" +http = "1" +tracing = { version = "0.1.35", default-features = false, features = ["std"] } fnv = "1.0.5" -slab = "0.4.0" -indexmap = "1.0" +slab = "0.4.2" +indexmap = { version = "2", features = ["std"] } [dev-dependencies] # Fuzzing -quickcheck = { version = "0.4.1", default-features = false } -rand = "0.3.15" +quickcheck = { version = "1.0.3", default-features = false } +rand = "0.8.4" # HPACK fixtures -hex = "0.2.0" -walkdir = "1.0.0" +hex = "0.4.3" +walkdir = "2.3.2" serde = "1.0.0" serde_json = "1.0.0" # Examples -tokio = { version = "0.2", features = ["dns", "macros", "rt-core", "sync", "tcp"] } -env_logger = { version = "0.5.3", default-features = false } -rustls = "0.16" -tokio-rustls = "0.12.0" -webpki = "0.21" -webpki-roots = "0.17" +tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "net"] } +env_logger = { version = "0.10", default-features = false } +tokio-rustls = "0.24" +webpki-roots = "0.25" + +[package.metadata.docs.rs] +features = ["stream"] + +[[bench]] +name = "main" +harness = false diff --git a/README.md b/README.md index 21f00a500..f83357d5d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # H2 -A Tokio aware, HTTP/2.0 client & server implementation for Rust. +A Tokio aware, HTTP/2 client & server implementation for Rust. 
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![Crates.io](https://img.shields.io/crates/v/h2.svg)](https://crates.io/crates/h2) @@ -12,24 +12,23 @@ More information about this crate can be found in the [crate documentation][dox] ## Features -* Client and server HTTP/2.0 implementation. -* Implements the full HTTP/2.0 specification. +* Client and server HTTP/2 implementation. +* Implements the full HTTP/2 specification. * Passes [h2spec](https://github.com/summerwind/h2spec). * Focus on performance and correctness. * Built on [Tokio](https://tokio.rs). ## Non goals -This crate is intended to only be an implementation of the HTTP/2.0 +This crate is intended to only be an implementation of the HTTP/2 specification. It does not handle: * Managing TCP connections * HTTP 1.0 upgrade * TLS -* Any feature not described by the HTTP/2.0 specification. +* Any feature not described by the HTTP/2 specification. -The intent is that this crate will eventually be used by -[hyper](https://github.com/hyperium/hyper), which will provide all of these features. +This crate is now used by [hyper](https://github.com/hyperium/hyper), which will provide all of these features. ## Usage @@ -37,7 +36,7 @@ To use `h2`, first add this to your `Cargo.toml`: ```toml [dependencies] -h2 = "0.2" +h2 = "0.4" ``` Next, add this to your crate: @@ -56,7 +55,7 @@ fn main() { **How does h2 compare to [solicit] or [rust-http2]?** -The h2 library has implemented more of the details of the HTTP/2.0 specification +The h2 library has implemented more of the details of the HTTP/2 specification than any other Rust library. It also passes the [h2spec] set of tests. The h2 library is rapidly approaching "production ready" quality. 
diff --git a/benches/main.rs b/benches/main.rs new file mode 100644 index 000000000..b1e64edf4 --- /dev/null +++ b/benches/main.rs @@ -0,0 +1,148 @@ +use bytes::Bytes; +use h2::{ + client, + server::{self, SendResponse}, + RecvStream, +}; +use http::Request; + +use std::{ + error::Error, + time::{Duration, Instant}, +}; + +use tokio::net::{TcpListener, TcpStream}; + +const NUM_REQUESTS_TO_SEND: usize = 100_000; + +// The actual server. +async fn server(addr: &str) -> Result<(), Box> { + let listener = TcpListener::bind(addr).await?; + + loop { + if let Ok((socket, _peer_addr)) = listener.accept().await { + tokio::spawn(async move { + if let Err(e) = serve(socket).await { + println!(" -> err={:?}", e); + } + }); + } + } +} + +async fn serve(socket: TcpStream) -> Result<(), Box> { + let mut connection = server::handshake(socket).await?; + while let Some(result) = connection.accept().await { + let (request, respond) = result?; + tokio::spawn(async move { + if let Err(e) = handle_request(request, respond).await { + println!("error while handling request: {}", e); + } + }); + } + Ok(()) +} + +async fn handle_request( + mut request: Request, + mut respond: SendResponse, +) -> Result<(), Box> { + let body = request.body_mut(); + while let Some(data) = body.data().await { + let data = data?; + let _ = body.flow_control().release_capacity(data.len()); + } + let response = http::Response::new(()); + let mut send = respond.send_response(response, false)?; + send.send_data(Bytes::from_static(b"pong"), true)?; + + Ok(()) +} + +// The benchmark +async fn send_requests(addr: &str) -> Result<(), Box> { + let tcp = loop { + let Ok(tcp) = TcpStream::connect(addr).await else { + continue; + }; + break tcp; + }; + let (client, h2) = client::handshake(tcp).await?; + // Spawn a task to run the conn... 
+ tokio::spawn(async move { + if let Err(e) = h2.await { + println!("GOT ERR={:?}", e); + } + }); + + let mut handles = Vec::with_capacity(NUM_REQUESTS_TO_SEND); + for _i in 0..NUM_REQUESTS_TO_SEND { + let mut client = client.clone(); + let task = tokio::spawn(async move { + let request = Request::builder().body(()).unwrap(); + + let instant = Instant::now(); + let (response, _) = client.send_request(request, true).unwrap(); + let response = response.await.unwrap(); + let mut body = response.into_body(); + while let Some(_chunk) = body.data().await {} + instant.elapsed() + }); + handles.push(task); + } + + let instant = Instant::now(); + let mut result = Vec::with_capacity(NUM_REQUESTS_TO_SEND); + for handle in handles { + result.push(handle.await.unwrap()); + } + let mut sum = Duration::new(0, 0); + for r in result.iter() { + sum = sum.checked_add(*r).unwrap(); + } + + println!("Overall: {}ms.", instant.elapsed().as_millis()); + println!("Fastest: {}ms", result.iter().min().unwrap().as_millis()); + println!("Slowest: {}ms", result.iter().max().unwrap().as_millis()); + println!( + "Avg : {}ms", + sum.div_f64(NUM_REQUESTS_TO_SEND as f64).as_millis() + ); + Ok(()) +} + +fn main() { + let _ = env_logger::try_init(); + let addr = "127.0.0.1:5928"; + println!("H2 running in current-thread runtime at {addr}:"); + std::thread::spawn(|| { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + rt.block_on(server(addr)).unwrap(); + }); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + rt.block_on(send_requests(addr)).unwrap(); + + let addr = "127.0.0.1:5929"; + println!("H2 running in multi-thread runtime at {addr}:"); + std::thread::spawn(|| { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + rt.block_on(server(addr)).unwrap(); + }); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + 
.build() + .unwrap(); + rt.block_on(send_requests(addr)).unwrap(); +} diff --git a/ci/h2spec.sh b/ci/h2spec.sh index c55af5513..ff7295515 100755 --- a/ci/h2spec.sh +++ b/ci/h2spec.sh @@ -13,7 +13,7 @@ SERVER_PID=$! # wait 'til the server is listening before running h2spec, and pipe server's # stdout to a log file. -sed '/listening on Ok(V4(127.0.0.1:5928))/q' <&3 ; cat <&3 > "${LOGFILE}" & +sed '/listening on Ok(127.0.0.1:5928)/q' <&3 ; cat <&3 > "${LOGFILE}" & # run h2spec against the server, printing the server log if h2spec failed ./h2spec -p 5928 diff --git a/examples/akamai.rs b/examples/akamai.rs index 29d8a9347..8d87b778e 100644 --- a/examples/akamai.rs +++ b/examples/akamai.rs @@ -3,8 +3,7 @@ use http::{Method, Request}; use tokio::net::TcpStream; use tokio_rustls::TlsConnector; -use rustls::Session; -use webpki::DNSNameRef; +use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore, ServerName}; use std::error::Error; use std::net::ToSocketAddrs; @@ -16,9 +15,19 @@ pub async fn main() -> Result<(), Box> { let _ = env_logger::try_init(); let tls_client_config = std::sync::Arc::new({ - let mut c = rustls::ClientConfig::new(); - c.root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + let mut root_store = RootCertStore::empty(); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + let mut c = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); c.alpn_protocols.push(ALPN_H2.as_bytes().to_owned()); c }); @@ -33,17 +42,14 @@ pub async fn main() -> Result<(), Box> { println!("ADDR: {:?}", addr); let tcp = TcpStream::connect(&addr).await?; - let dns_name = DNSNameRef::try_from_ascii_str("http2.akamai.com").unwrap(); + let dns_name = ServerName::try_from("http2.akamai.com").unwrap(); let connector = 
TlsConnector::from(tls_client_config); let res = connector.connect(dns_name, tcp).await; let tls = res.unwrap(); { let (_, session) = tls.get_ref(); - let negotiated_protocol = session.get_alpn_protocol(); - assert_eq!( - Some(ALPN_H2.as_bytes()), - negotiated_protocol.as_ref().map(|x| &**x) - ); + let negotiated_protocol = session.alpn_protocol(); + assert_eq!(Some(ALPN_H2.as_bytes()), negotiated_protocol); } println!("Starting client handshake"); diff --git a/examples/server.rs b/examples/server.rs index 1753b7a2e..6d6490db0 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,21 +1,23 @@ use std::error::Error; use bytes::Bytes; -use h2::server; +use h2::server::{self, SendResponse}; +use h2::RecvStream; +use http::Request; use tokio::net::{TcpListener, TcpStream}; #[tokio::main] async fn main() -> Result<(), Box> { let _ = env_logger::try_init(); - let mut listener = TcpListener::bind("127.0.0.1:5928").await?; + let listener = TcpListener::bind("127.0.0.1:5928").await?; println!("listening on {:?}", listener.local_addr()); loop { if let Ok((socket, _peer_addr)) = listener.accept().await { tokio::spawn(async move { - if let Err(e) = handle(socket).await { + if let Err(e) = serve(socket).await { println!(" -> err={:?}", e); } }); @@ -23,22 +25,41 @@ async fn main() -> Result<(), Box> { } } -async fn handle(socket: TcpStream) -> Result<(), Box> { +async fn serve(socket: TcpStream) -> Result<(), Box> { let mut connection = server::handshake(socket).await?; println!("H2 connection bound"); while let Some(result) = connection.accept().await { - let (request, mut respond) = result?; - println!("GOT request: {:?}", request); - let response = http::Response::new(()); + let (request, respond) = result?; + tokio::spawn(async move { + if let Err(e) = handle_request(request, respond).await { + println!("error while handling request: {}", e); + } + }); + } + + println!("~~~~~~~~~~~ H2 connection CLOSE !!!!!! 
~~~~~~~~~~~"); + Ok(()) +} - let mut send = respond.send_response(response, false)?; +async fn handle_request( + mut request: Request, + mut respond: SendResponse, +) -> Result<(), Box> { + println!("GOT request: {:?}", request); - println!(">>>> sending data"); - send.send_data(Bytes::from_static(b"hello world"), true)?; + let body = request.body_mut(); + while let Some(data) = body.data().await { + let data = data?; + println!("<<<< recv {:?}", data); + let _ = body.flow_control().release_capacity(data.len()); } - println!("~~~~~~~~~~~~~~~~~~~~~~~~~~~ H2 connection CLOSE !!!!!! ~~~~~~~~~~~"); + let response = http::Response::new(()); + let mut send = respond.send_response(response, false)?; + println!(">>>> send"); + send.send_data(Bytes::from_static(b"hello "), false)?; + send.send_data(Bytes::from_static(b"world\n"), true)?; Ok(()) } diff --git a/fuzz/.gitignore b/fuzz/.gitignore new file mode 100644 index 000000000..572e03bdf --- /dev/null +++ b/fuzz/.gitignore @@ -0,0 +1,4 @@ + +target +corpus +artifacts diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 000000000..922eca238 --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,41 @@ + +[package] +name = "h2-oss-fuzz" +version = "0.0.0" +authors = [ "David Korczynski " ] +publish = false +edition = "2018" + +[package.metadata] +cargo-fuzz = true + +[dependencies] +arbitrary = { version = "1", features = ["derive"] } +libfuzzer-sys = { version = "0.4.0", features = ["arbitrary-derive"] } +tokio = { version = "1", features = [ "full" ] } +h2 = { path = "../", features = [ "unstable" ] } +h2-support = { path = "../tests/h2-support" } +futures = { version = "0.3", default-features = false, features = ["std"] } +http = "1" + +# Prevent this from interfering with workspaces +[workspace] +members = ["."] + +[[bin]] +name = "fuzz_client" +path = "fuzz_targets/fuzz_client.rs" +test = false +doc = false + +[[bin]] +name = "fuzz_hpack" +path = "fuzz_targets/fuzz_hpack.rs" +test = false +doc = false + 
+[[bin]] +name = "fuzz_e2e" +path = "fuzz_targets/fuzz_e2e.rs" +test = false +doc = false diff --git a/fuzz/fuzz_targets/fuzz_client.rs b/fuzz/fuzz_targets/fuzz_client.rs new file mode 100644 index 000000000..0b4672653 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_client.rs @@ -0,0 +1,34 @@ +#![no_main] +use h2_support::prelude::*; +use libfuzzer_sys::{arbitrary::Arbitrary, fuzz_target}; + +#[derive(Debug, Arbitrary)] +struct HttpSpec { + uri: Vec, + header_name: Vec, + header_value: Vec, +} + +async fn fuzz_entry(inp: HttpSpec) { + if let Ok(req) = Request::builder() + .uri(&inp.uri[..]) + .header(&inp.header_name[..], &inp.header_value[..]) + .body(()) + { + let (io, mut _srv) = mock::new(); + let (mut client, _h2) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .unwrap(); + + // this could still trigger a user error: + // - if the uri isn't absolute + // - if the header name isn't allowed in http2 (like connection) + let _ = client.send_request(req, true); + } +} + +fuzz_target!(|inp: HttpSpec| { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(fuzz_entry(inp)); +}); diff --git a/fuzz/fuzz_targets/fuzz_e2e.rs b/fuzz/fuzz_targets/fuzz_e2e.rs new file mode 100644 index 000000000..02792c134 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_e2e.rs @@ -0,0 +1,129 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +use futures::future; +use futures::stream::FuturesUnordered; +use futures::Stream; +use http::{Method, Request}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +struct MockIo<'a> { + input: &'a [u8], +} + +impl<'a> MockIo<'a> { + fn next_byte(&mut self) -> Option { + if let Some(&c) = self.input.first() { + self.input = &self.input[1..]; + Some(c) + } else { + None + } + } + + fn next_u32(&mut self) -> u32 { + (self.next_byte().unwrap_or(0) as u32) << 8 | self.next_byte().unwrap_or(0) as u32 + } +} + +impl<'a> AsyncRead for MockIo<'a> { 
+ fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll> { + + + let mut len = self.next_u32() as usize; + if self.input.is_empty() { + Poll::Ready(Ok(())) + } else if len == 0 { + cx.waker().clone().wake(); + Poll::Pending + } else { + if len > self.input.len() { + len = self.input.len(); + } + + if len > buf.remaining() { + len = buf.remaining(); + } + buf.put_slice(&self.input[..len]); + self.input = &self.input[len..]; + Poll::Ready(Ok(())) + } + } +} + +impl<'a> AsyncWrite for MockIo<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let len = std::cmp::min(self.next_u32() as usize, buf.len()); + if len == 0 { + if self.input.is_empty() { + Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } else { + cx.waker().clone().wake(); + Poll::Pending + } + } else { + Poll::Ready(Ok(len)) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +async fn run(script: &[u8]) -> Result<(), h2::Error> { + let io = MockIo { input: script }; + let (mut h2, mut connection) = h2::client::handshake(io).await?; + let mut futs = FuturesUnordered::new(); + let future = future::poll_fn(|cx| { + if let Poll::Ready(()) = Pin::new(&mut connection).poll(cx)? 
{ + return Poll::Ready(Ok::<_, h2::Error>(())); + } + while futs.len() < 128 { + if !h2.poll_ready(cx)?.is_ready() { + break; + } + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + let (resp, mut send) = h2.send_request(request, false)?; + send.send_data(vec![0u8; 32769].into(), true).unwrap(); + drop(send); + futs.push(resp); + } + loop { + match Pin::new(&mut futs).poll_next(cx) { + Poll::Pending | Poll::Ready(None) => break, + r @ Poll::Ready(Some(Ok(_))) | r @ Poll::Ready(Some(Err(_))) => { + eprintln!("{:?}", r); + } + } + } + Poll::Pending + }); + future.await?; + Ok(()) +} + +fuzz_target!(|data: &[u8]| { + let rt = tokio::runtime::Runtime::new().unwrap(); + let _res = rt.block_on(run(data)); +}); + diff --git a/fuzz/fuzz_targets/fuzz_hpack.rs b/fuzz/fuzz_targets/fuzz_hpack.rs new file mode 100644 index 000000000..c2597bb94 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_hpack.rs @@ -0,0 +1,6 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data_: &[u8]| { + let _decoder_ = h2::fuzz_bridge::fuzz_logic::fuzz_hpack(data_); +}); diff --git a/src/client.rs b/src/client.rs index 63514e322..25b151f53 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,18 +1,18 @@ -//! Client implementation of the HTTP/2.0 protocol. +//! Client implementation of the HTTP/2 protocol. //! //! # Getting started //! -//! Running an HTTP/2.0 client requires the caller to establish the underlying +//! Running an HTTP/2 client requires the caller to establish the underlying //! connection as well as get the connection to a state that is ready to begin -//! the HTTP/2.0 handshake. See [here](../index.html#handshake) for more +//! the HTTP/2 handshake. See [here](../index.html#handshake) for more //! details. //! //! This could be as basic as using Tokio's [`TcpStream`] to connect to a remote //! host, but usually it means using either ALPN or HTTP/1.1 protocol upgrades. //! //! 
Once a connection is obtained, it is passed to [`handshake`], which will -//! begin the [HTTP/2.0 handshake]. This returns a future that completes once -//! the handshake process is performed and HTTP/2.0 streams may be initialized. +//! begin the [HTTP/2 handshake]. This returns a future that completes once +//! the handshake process is performed and HTTP/2 streams may be initialized. //! //! [`handshake`] uses default configuration values. There are a number of //! settings that can be changed by using [`Builder`] instead. @@ -26,16 +26,16 @@ //! # Making requests //! //! Requests are made using the [`SendRequest`] handle provided by the handshake -//! future. Once a request is submitted, an HTTP/2.0 stream is initialized and +//! future. Once a request is submitted, an HTTP/2 stream is initialized and //! the request is sent to the server. //! //! A request body and request trailers are sent using [`SendRequest`] and the //! server's response is returned once the [`ResponseFuture`] future completes. //! Both the [`SendStream`] and [`ResponseFuture`] instances are returned by -//! [`SendRequest::send_request`] and are tied to the HTTP/2.0 stream +//! [`SendRequest::send_request`] and are tied to the HTTP/2 stream //! initialized by the sent request. //! -//! The [`SendRequest::poll_ready`] function returns `Ready` when a new HTTP/2.0 +//! The [`SendRequest::poll_ready`] function returns `Ready` when a new HTTP/2 //! stream can be created, i.e. as long as the current number of active streams //! is below [`MAX_CONCURRENT_STREAMS`]. If a new stream cannot be created, the //! caller will be notified once an existing stream closes, freeing capacity for @@ -131,13 +131,14 @@ //! [`SendRequest`]: struct.SendRequest.html //! [`ResponseFuture`]: struct.ResponseFuture.html //! [`SendRequest::poll_ready`]: struct.SendRequest.html#method.poll_ready -//! [HTTP/2.0 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader +//! 
[HTTP/2 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader //! [`Builder`]: struct.Builder.html //! [`Error`]: ../struct.Error.html -use crate::codec::{Codec, RecvError, SendError, UserError}; +use crate::codec::{Codec, SendError, UserError}; +use crate::ext::Protocol; use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId}; -use crate::proto; +use crate::proto::{self, Error}; use crate::{FlowControl, PingPong, RecvStream, SendStream}; use bytes::{Buf, Bytes}; @@ -149,8 +150,9 @@ use std::task::{Context, Poll}; use std::time::Duration; use std::usize; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tracing::Instrument; -/// Initializes new HTTP/2.0 streams on a connection by sending a request. +/// Initializes new HTTP/2 streams on a connection by sending a request. /// /// This type does no work itself. Instead, it is a handle to the inner /// connection state held by [`Connection`]. If the associated connection @@ -160,7 +162,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; /// / threads than their associated [`Connection`] instance. Internally, there /// is a buffer used to stage requests before they get written to the /// connection. There is no guarantee that requests get written to the -/// connection in FIFO order as HTTP/2.0 prioritization logic can play a role. +/// connection in FIFO order as HTTP/2 prioritization logic can play a role. /// /// [`SendRequest`] implements [`Clone`], enabling the creation of many /// instances that are backed by a single connection. @@ -183,10 +185,10 @@ pub struct ReadySendRequest { inner: Option>, } -/// Manages all state associated with an HTTP/2.0 client connection. +/// Manages all state associated with an HTTP/2 client connection. /// /// A `Connection` is backed by an I/O resource (usually a TCP socket) and -/// implements the HTTP/2.0 client logic for that connection. It is responsible +/// implements the HTTP/2 client logic for that connection. 
It is responsible /// for driving the internal state forward, performing the work requested of the /// associated handles ([`SendRequest`], [`ResponseFuture`], [`SendStream`], /// [`RecvStream`]). @@ -219,7 +221,7 @@ pub struct ReadySendRequest { /// // Submit the connection handle to an executor. /// tokio::spawn(async { connection.await.expect("connection failed"); }); /// -/// // Now, use `send_request` to initialize HTTP/2.0 streams. +/// // Now, use `send_request` to initialize HTTP/2 streams. /// // ... /// # Ok(()) /// # } @@ -273,7 +275,7 @@ pub struct PushPromises { /// Methods can be chained in order to set the configuration values. /// /// The client is constructed by calling [`handshake`] and passing the I/O -/// handle that will back the HTTP/2.0 server. +/// handle that will back the HTTP/2 server. /// /// New instances of `Builder` are obtained via [`Builder::new`]. /// @@ -293,7 +295,7 @@ pub struct PushPromises { /// # async fn doc(my_io: T) /// -> Result<((SendRequest, Connection)), h2::Error> /// # { -/// // `client_fut` is a future representing the completion of the HTTP/2.0 +/// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .initial_window_size(1_000_000) @@ -310,23 +312,38 @@ pub struct Builder { reset_stream_duration: Duration, /// Initial maximum number of locally initiated (send) streams. - /// After receiving a Settings frame from the remote peer, + /// After receiving a SETTINGS frame from the remote peer, /// the connection will overwrite this value with the /// MAX_CONCURRENT_STREAMS specified in the frame. + /// If no value is advertised by the remote peer in the initial SETTINGS + /// frame, it will be set to usize::MAX. initial_max_send_streams: usize, /// Initial target window size for new connections. initial_target_connection_window_size: Option, + /// Maximum amount of bytes to "buffer" for writing per stream. 
+ max_send_buffer_size: usize, + /// Maximum number of locally reset streams to keep at a time. reset_stream_max: usize, + /// Maximum number of remotely reset streams to allow in the pending + /// accept queue. + pending_accept_reset_stream_max: usize, + /// Initial `Settings` frame to send as part of the handshake. settings: Settings, /// The stream ID of the first (lowest) stream. Subsequent streams will use /// monotonically increasing stream IDs. stream_id: StreamId, + + /// Maximum number of locally reset streams due to protocol error across + /// the lifetime of the connection. + /// + /// When this gets exceeded, we issue GOAWAYs. + local_max_error_reset_streams: Option, } #[derive(Debug)] @@ -336,9 +353,9 @@ pub(crate) struct Peer; impl SendRequest where - B: Buf + 'static, + B: Buf, { - /// Returns `Ready` when the connection can initialize a new HTTP/2.0 + /// Returns `Ready` when the connection can initialize a new HTTP/2 /// stream. /// /// This function must return `Ready` before `send_request` is called. When @@ -386,16 +403,16 @@ where ReadySendRequest { inner: Some(self) } } - /// Sends a HTTP/2.0 request to the server. + /// Sends a HTTP/2 request to the server. /// - /// `send_request` initializes a new HTTP/2.0 stream on the associated + /// `send_request` initializes a new HTTP/2 stream on the associated /// connection, then sends the given request using this new stream. Only the /// request head is sent. /// /// On success, a [`ResponseFuture`] instance and [`SendStream`] instance /// are returned. The [`ResponseFuture`] instance is used to get the /// server's response and the [`SendStream`] instance is used to send a - /// request body or trailers to the server over the same HTTP/2.0 stream. + /// request body or trailers to the server over the same HTTP/2 stream. /// /// To send a request body or trailers, set `end_of_stream` to `false`. 
     /// Then, use the returned [`SendStream`] instance to stream request body
@@ -501,8 +518,10 @@ where
         self.inner
             .send_request(request, end_of_stream, self.pending.as_ref())
             .map_err(Into::into)
-            .map(|stream| {
-                if stream.is_pending_open() {
+            .map(|(stream, is_full)| {
+                if stream.is_pending_open() && is_full {
+                    // Only prevent sending another request when the request queue
+                    // is full.
                     self.pending = Some(stream.clone_to_opaque());
                 }
@@ -516,6 +535,19 @@ where
             (response, stream)
         })
     }
+
+    /// Returns whether the [extended CONNECT protocol][1] is enabled or not.
+    ///
+    /// This setting is configured by the server peer by sending the
+    /// [`SETTINGS_ENABLE_CONNECT_PROTOCOL` parameter][2] in a `SETTINGS` frame.
+    /// This method returns the currently acknowledged value received from the
+    /// remote.
+    ///
+    /// [1]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
+    /// [2]: https://datatracker.ietf.org/doc/html/rfc8441#section-3
+    pub fn is_extended_connect_protocol_enabled(&self) -> bool {
+        self.inner.is_extended_connect_protocol_enabled()
+    }
 }

 impl fmt::Debug for SendRequest
 where
@@ -566,7 +598,7 @@ where

 impl Future for ReadySendRequest
 where
-    B: Buf + 'static,
+    B: Buf,
 {
     type Output = Result, crate::Error>;
@@ -600,7 +632,7 @@ impl Builder {
     /// # async fn doc(my_io: T)
     /// # -> Result<((SendRequest, Connection)), h2::Error>
     /// # {
-    /// // `client_fut` is a future representing the completion of the HTTP/2.0
+    /// // `client_fut` is a future representing the completion of the HTTP/2
     /// // handshake.
/// let client_fut = Builder::new() /// .initial_window_size(1_000_000) @@ -613,12 +645,15 @@ impl Builder { /// ``` pub fn new() -> Builder { Builder { + max_send_buffer_size: proto::DEFAULT_MAX_SEND_BUFFER_SIZE, reset_stream_duration: Duration::from_secs(proto::DEFAULT_RESET_STREAM_SECS), reset_stream_max: proto::DEFAULT_RESET_STREAM_MAX, + pending_accept_reset_stream_max: proto::DEFAULT_REMOTE_RESET_STREAM_MAX, initial_target_connection_window_size: None, initial_max_send_streams: usize::MAX, settings: Default::default(), stream_id: 1.into(), + local_max_error_reset_streams: Some(proto::DEFAULT_LOCAL_RESET_COUNT_MAX), } } @@ -642,7 +677,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .initial_window_size(1_000_000) @@ -677,7 +712,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .initial_connection_window_size(1_000_000) @@ -692,7 +727,7 @@ impl Builder { self } - /// Indicates the size (in octets) of the largest HTTP/2.0 frame payload that the + /// Indicates the size (in octets) of the largest HTTP/2 frame payload that the /// configured client is able to accept. 
/// /// The sender may send data frames that are **smaller** than this value, @@ -711,7 +746,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .max_frame_size(1_000_000) @@ -751,7 +786,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .max_header_list_size(16 * 1024) @@ -786,7 +821,7 @@ impl Builder { /// a protocol level error. Instead, the `h2` library will immediately reset /// the stream. /// - /// See [Section 5.1.2] in the HTTP/2.0 spec for more details. + /// See [Section 5.1.2] in the HTTP/2 spec for more details. /// /// [Section 5.1.2]: https://http2.github.io/http2-spec/#rfc.section.5.1.2 /// @@ -800,7 +835,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .max_concurrent_streams(1000) @@ -818,8 +853,10 @@ impl Builder { /// Sets the initial maximum of locally initiated (send) streams. /// /// The initial settings will be overwritten by the remote peer when - /// the Settings frame is received. The new value will be set to the - /// `max_concurrent_streams()` from the frame. + /// the SETTINGS frame is received. The new value will be set to the + /// `max_concurrent_streams()` from the frame. 
If no value is advertised in + /// the initial SETTINGS frame from the remote peer as part of + /// [HTTP/2 Connection Preface], `usize::MAX` will be set. /// /// This setting prevents the caller from exceeding this number of /// streams that are counted towards the concurrency limit. @@ -827,9 +864,12 @@ impl Builder { /// Sending streams past the limit returned by the peer will be treated /// as a stream error of type PROTOCOL_ERROR or REFUSED_STREAM. /// - /// See [Section 5.1.2] in the HTTP/2.0 spec for more details. + /// See [Section 5.1.2] in the HTTP/2 spec for more details. /// - /// [Section 5.1.2]: https://http2.github.io/http2-spec/#rfc.section.5.1.2 + /// The default value is `usize::MAX`. + /// + /// [HTTP/2 Connection Preface]: https://httpwg.org/specs/rfc9113.html#preface + /// [Section 5.1.2]: https://httpwg.org/specs/rfc9113.html#rfc.section.5.1.2 /// /// # Examples /// @@ -841,7 +881,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .initial_max_send_streams(1000) @@ -858,7 +898,7 @@ impl Builder { /// Sets the maximum number of concurrent locally reset streams. /// - /// When a stream is explicitly reset, the HTTP/2.0 specification requires + /// When a stream is explicitly reset, the HTTP/2 specification requires /// that any further frames received for that stream must be ignored for /// "some time". /// @@ -886,7 +926,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. 
/// let client_fut = Builder::new() /// .max_concurrent_reset_streams(1000) @@ -903,7 +943,7 @@ impl Builder { /// Sets the duration to remember locally reset streams. /// - /// When a stream is explicitly reset, the HTTP/2.0 specification requires + /// When a stream is explicitly reset, the HTTP/2 specification requires /// that any further frames received for that stream must be ignored for /// "some time". /// @@ -932,7 +972,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .reset_stream_duration(Duration::from_secs(10)) @@ -947,14 +987,92 @@ impl Builder { self } + /// Sets the maximum number of local resets due to protocol errors made by the remote end. + /// + /// Invalid frames and many other protocol errors will lead to resets being generated for those streams. + /// Too many of these often indicate a malicious client, and there are attacks which can abuse this to DOS servers. + /// This limit protects against these DOS attacks by limiting the amount of resets we can be forced to generate. + /// + /// When the number of local resets exceeds this threshold, the client will close the connection. + /// + /// If you really want to disable this, supply [`Option::None`] here. + /// Disabling this is not recommended and may expose you to DOS attacks. + /// + /// The default value is currently 1024, but could change. + pub fn max_local_error_reset_streams(&mut self, max: Option) -> &mut Self { + self.local_max_error_reset_streams = max; + self + } + + /// Sets the maximum number of pending-accept remotely-reset streams. + /// + /// Streams that have been received by the peer, but not accepted by the + /// user, can also receive a RST_STREAM. 
This is a legitimate pattern: one + /// could send a request and then shortly after, realize it is not needed, + /// sending a CANCEL. + /// + /// However, since those streams are now "closed", they don't count towards + /// the max concurrent streams. So, they will sit in the accept queue, + /// using memory. + /// + /// When the number of remotely-reset streams sitting in the pending-accept + /// queue reaches this maximum value, a connection error with the code of + /// `ENHANCE_YOUR_CALM` will be sent to the peer, and returned by the + /// `Future`. + /// + /// The default value is currently 20, but could change. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use h2::client::*; + /// # use bytes::Bytes; + /// # + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> + /// # { + /// // `client_fut` is a future representing the completion of the HTTP/2 + /// // handshake. + /// let client_fut = Builder::new() + /// .max_pending_accept_reset_streams(100) + /// .handshake(my_io); + /// # client_fut.await + /// # } + /// # + /// # pub fn main() {} + /// ``` + pub fn max_pending_accept_reset_streams(&mut self, max: usize) -> &mut Self { + self.pending_accept_reset_stream_max = max; + self + } + + /// Sets the maximum send buffer size per stream. + /// + /// Once a stream has buffered up to (or over) the maximum, the stream's + /// flow control will not "poll" additional capacity. Once bytes for the + /// stream have been written to the connection, the send buffer capacity + /// will be freed up again. + /// + /// The default is currently ~400KB, but may change. + /// + /// # Panics + /// + /// This function panics if `max` is larger than `u32::MAX`. + pub fn max_send_buffer_size(&mut self, max: usize) -> &mut Self { + assert!(max <= std::u32::MAX as usize); + self.max_send_buffer_size = max; + self + } + /// Enables or disables server push promises. 
     ///
-    /// This value is included in the initial SETTINGS handshake. When set, the
-    /// server MUST NOT send a push promise. Setting this value to value to
+    /// This value is included in the initial SETTINGS handshake.
+    /// Setting this value to
     /// false in the initial SETTINGS handshake guarantees that the remote server
     /// will never send a push promise.
     ///
-    /// This setting can be changed during the life of a single HTTP/2.0
+    /// This setting can be changed during the life of a single HTTP/2
     /// connection by sending another settings frame updating the value.
     ///
     /// Default value: `true`.
@@ -970,7 +1088,7 @@ impl Builder {
     /// # async fn doc(my_io: T)
     /// # -> Result<((SendRequest, Connection)), h2::Error>
     /// # {
-    /// // `client_fut` is a future representing the completion of the HTTP/2.0
+    /// // `client_fut` is a future representing the completion of the HTTP/2
     /// // handshake.
     /// let client_fut = Builder::new()
     ///     .enable_push(false)
     ///     .handshake(my_io);
     /// # client_fut.await
     /// # }
     /// #
     /// # pub fn main() {}
     /// ```
     pub fn enable_push(&mut self, enabled: bool) -> &mut Self {
         self.settings.set_enable_push(enabled);
         self
     }

+    /// Sets the header table size.
+    ///
+    /// This setting informs the peer of the maximum size of the header compression
+    /// table used to encode header blocks, in octets. The encoder may select any value
+    /// equal to or less than the header table size specified by the sender.
+    ///
+    /// The default value is 4,096.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use tokio::io::{AsyncRead, AsyncWrite};
+    /// # use h2::client::*;
+    /// # use bytes::Bytes;
+    /// #
+    /// # async fn doc(my_io: T)
+    /// # -> Result<((SendRequest, Connection)), h2::Error>
+    /// # {
+    /// // `client_fut` is a future representing the completion of the HTTP/2
+    /// // handshake.
+ /// let client_fut = Builder::new() + /// .header_table_size(1_000_000) + /// .handshake(my_io); + /// # client_fut.await + /// # } + /// # + /// # pub fn main() {} + /// ``` + pub fn header_table_size(&mut self, size: u32) -> &mut Self { + self.settings.set_header_table_size(Some(size)); + self + } + /// Sets the first stream ID to something other than 1. #[cfg(feature = "unstable")] pub fn initial_stream_id(&mut self, stream_id: u32) -> &mut Self { @@ -996,22 +1147,22 @@ impl Builder { self } - /// Creates a new configured HTTP/2.0 client backed by `io`. + /// Creates a new configured HTTP/2 client backed by `io`. /// /// It is expected that `io` already be in an appropriate state to commence - /// the [HTTP/2.0 handshake]. The handshake is completed once both the connection + /// the [HTTP/2 handshake]. The handshake is completed once both the connection /// preface and the initial settings frame is sent by the client. /// /// The handshake future does not wait for the initial settings frame from the /// server. /// /// Returns a future which resolves to the [`Connection`] / [`SendRequest`] - /// tuple once the HTTP/2.0 handshake has been completed. + /// tuple once the HTTP/2 handshake has been completed. /// /// This function also allows the caller to configure the send payload data /// type. See [Outbound data type] for more details. /// - /// [HTTP/2.0 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader + /// [HTTP/2 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader /// [`Connection`]: struct.Connection.html /// [`SendRequest`]: struct.SendRequest.html /// [Outbound data type]: ../index.html#outbound-data-type. @@ -1028,7 +1179,7 @@ impl Builder { /// # async fn doc(my_io: T) /// -> Result<((SendRequest, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. 
/// let client_fut = Builder::new() /// .handshake(my_io); @@ -1048,7 +1199,7 @@ impl Builder { /// # async fn doc(my_io: T) /// # -> Result<((SendRequest<&'static [u8]>, Connection)), h2::Error> /// # { - /// // `client_fut` is a future representing the completion of the HTTP/2.0 + /// // `client_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let client_fut = Builder::new() /// .handshake::<_, &'static [u8]>(my_io); @@ -1063,7 +1214,7 @@ impl Builder { ) -> impl Future, Connection), crate::Error>> where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { Connection::handshake2(io, self.clone()) } @@ -1075,19 +1226,19 @@ impl Default for Builder { } } -/// Creates a new configured HTTP/2.0 client with default configuration +/// Creates a new configured HTTP/2 client with default configuration /// values backed by `io`. /// /// It is expected that `io` already be in an appropriate state to commence -/// the [HTTP/2.0 handshake]. See [Handshake] for more details. +/// the [HTTP/2 handshake]. See [Handshake] for more details. /// /// Returns a future which resolves to the [`Connection`] / [`SendRequest`] -/// tuple once the HTTP/2.0 handshake has been completed. The returned +/// tuple once the HTTP/2 handshake has been completed. The returned /// [`Connection`] instance will be using default configuration values. Use /// [`Builder`] to customize the configuration values used by a [`Connection`] /// instance. 
/// -/// [HTTP/2.0 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader +/// [HTTP/2 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader /// [Handshake]: ../index.html#handshake /// [`Connection`]: struct.Connection.html /// [`SendRequest`]: struct.SendRequest.html @@ -1102,7 +1253,7 @@ impl Default for Builder { /// # async fn doc(my_io: T) -> Result<(), h2::Error> /// # { /// let (send_request, connection) = client::handshake(my_io).await?; -/// // The HTTP/2.0 handshake has completed, now start polling +/// // The HTTP/2 handshake has completed, now start polling /// // `connection` and use `send_request` to send requests to the /// // server. /// # Ok(()) @@ -1115,26 +1266,38 @@ where T: AsyncRead + AsyncWrite + Unpin, { let builder = Builder::new(); - builder.handshake(io).await + builder + .handshake(io) + .instrument(tracing::trace_span!("client_handshake")) + .await } // ===== impl Connection ===== +async fn bind_connection(io: &mut T) -> Result<(), crate::Error> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + tracing::debug!("binding client connection"); + + let msg: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + io.write_all(msg).await.map_err(crate::Error::from_io)?; + + tracing::debug!("client connection bound"); + + Ok(()) +} + impl Connection where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { async fn handshake2( mut io: T, builder: Builder, ) -> Result<(SendRequest, Connection), crate::Error> { - log::debug!("binding client connection"); - - let msg: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; - io.write_all(msg).await.map_err(crate::Error::from_io)?; - - log::debug!("client connection bound"); + bind_connection(&mut io).await?; // Create the codec let mut codec = Codec::new(io); @@ -1157,8 +1320,11 @@ where proto::Config { next_stream_id: builder.stream_id, initial_max_send_streams: builder.initial_max_send_streams, + max_send_buffer_size: builder.max_send_buffer_size, 
reset_stream_duration: builder.reset_stream_duration, reset_stream_max: builder.reset_stream_max, + remote_reset_stream_max: builder.pending_accept_reset_stream_max, + local_error_reset_streams_max: builder.local_max_error_reset_streams, settings: builder.settings.clone(), }, ); @@ -1224,12 +1390,39 @@ where pub fn ping_pong(&mut self) -> Option { self.inner.take_user_pings().map(PingPong::new) } + + /// Returns the maximum number of concurrent streams that may be initiated + /// by this client. + /// + /// This limit is configured by the server peer by sending the + /// [`SETTINGS_MAX_CONCURRENT_STREAMS` parameter][1] in a `SETTINGS` frame. + /// This method returns the currently acknowledged value received from the + /// remote. + /// + /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2 + pub fn max_concurrent_send_streams(&self) -> usize { + self.inner.max_send_streams() + } + /// Returns the maximum number of concurrent streams that may be initiated + /// by the server on this connection. + /// + /// This returns the value of the [`SETTINGS_MAX_CONCURRENT_STREAMS` + /// parameter][1] sent in a `SETTINGS` frame that has been + /// acknowledged by the remote peer. The value to be sent is configured by + /// the [`Builder::max_concurrent_streams`][2] method before handshaking + /// with the remote peer. + /// + /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2 + /// [2]: ../struct.Builder.html#method.max_concurrent_streams + pub fn max_concurrent_recv_streams(&self) -> usize { + self.inner.max_recv_streams() + } } impl Future for Connection where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { type Output = Result<(), crate::Error>; @@ -1375,6 +1568,7 @@ impl Peer { pub fn convert_send_message( id: StreamId, request: Request<()>, + protocol: Option, end_of_stream: bool, ) -> Result { use http::request::Parts; @@ -1394,7 +1588,7 @@ impl Peer { // Build the set pseudo header set. All requests will include `method` // and `path`. 
- let mut pseudo = Pseudo::request(method, uri); + let mut pseudo = Pseudo::request(method, uri, protocol); if pseudo.scheme.is_none() { // If the scheme is not set, then there are a two options. @@ -1414,7 +1608,7 @@ impl Peer { return Err(UserError::MissingUriSchemeAndAuthority.into()); } else { // This is acceptable as per the above comment. However, - // HTTP/2.0 requires that a scheme is set. Since we are + // HTTP/2 requires that a scheme is set. Since we are // forwarding an HTTP 1.1 request, the scheme is set to // "http". pseudo.set_scheme(uri::Scheme::HTTP); @@ -1438,19 +1632,23 @@ impl Peer { impl proto::Peer for Peer { type Poll = Response<()>; + const NAME: &'static str = "Client"; + fn r#dyn() -> proto::DynPeer { proto::DynPeer::Client } + /* fn is_server() -> bool { false } + */ fn convert_poll_message( pseudo: Pseudo, fields: HeaderMap, stream_id: StreamId, - ) -> Result { + ) -> Result { let mut b = Response::builder(); b = b.version(Version::HTTP_2); @@ -1464,10 +1662,7 @@ impl proto::Peer for Peer { Err(_) => { // TODO: Should there be more specialized handling for different // kinds of errors - return Err(RecvError::Stream { - id: stream_id, - reason: Reason::PROTOCOL_ERROR, - }); + return Err(Error::library_reset(stream_id, Reason::PROTOCOL_ERROR)); } }; diff --git a/src/codec/error.rs b/src/codec/error.rs index 87ab82bf9..c3dd9a772 100644 --- a/src/codec/error.rs +++ b/src/codec/error.rs @@ -1,26 +1,12 @@ -use crate::frame::{Reason, StreamId}; +use crate::proto::Error; use std::{fmt, io}; -/// Errors that are received -#[derive(Debug)] -pub enum RecvError { - Connection(Reason), - Stream { id: StreamId, reason: Reason }, - Io(io::Error), -} - /// Errors caused by sending a message #[derive(Debug)] pub enum SendError { - /// User error + Connection(Error), User(UserError), - - /// Connection error prevents sending. 
- Connection(Reason), - - /// I/O error - Io(io::Error), } /// Errors caused by users of the library @@ -35,9 +21,6 @@ pub enum UserError { /// The payload size is too big PayloadTooBig, - /// A header size is too big - HeaderTooBig, - /// The application attempted to initiate too many streams to remote. Rejected, @@ -63,45 +46,25 @@ pub enum UserError { /// Tries to update local SETTINGS while ACK has not been received. SendSettingsWhilePending, -} - -// ===== impl RecvError ===== - -impl From for RecvError { - fn from(src: io::Error) -> Self { - RecvError::Io(src) - } -} - -impl fmt::Display for RecvError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - use self::RecvError::*; - match *self { - Connection(ref reason) => reason.fmt(fmt), - Stream { ref reason, .. } => reason.fmt(fmt), - Io(ref e) => e.fmt(fmt), - } - } + /// Tries to send push promise to peer who has disabled server push + PeerDisabledServerPush, } // ===== impl SendError ===== impl fmt::Display for SendError { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - use self::SendError::*; - match *self { - User(ref e) => e.fmt(fmt), - Connection(ref reason) => reason.fmt(fmt), - Io(ref e) => e.fmt(fmt), + Self::Connection(ref e) => e.fmt(fmt), + Self::User(ref e) => e.fmt(fmt), } } } impl From for SendError { fn from(src: io::Error) -> Self { - SendError::Io(src) + Self::Connection(src.into()) } } @@ -121,7 +84,6 @@ impl fmt::Display for UserError { InactiveStreamId => "inactive stream", UnexpectedFrameType => "unexpected frame type", PayloadTooBig => "payload too big", - HeaderTooBig => "header too big", Rejected => "rejected", ReleaseCapacityTooBig => "release capacity too big", OverflowedStreamId => "stream ID overflowed", @@ -130,6 +92,7 @@ impl fmt::Display for UserError { PollResetAfterSendResponse => "poll_reset after send_response is illegal", SendPingWhilePending => "send_ping before received previous pong", SendSettingsWhilePending => "sending SETTINGS before received 
previous ACK", + PeerDisabledServerPush => "sending PUSH_PROMISE to peer who disabled server push", }) } } diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 76a236ed2..9270a8635 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -1,8 +1,8 @@ -use crate::codec::RecvError; use crate::frame::{self, Frame, Kind, Reason}; use crate::frame::{ DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE, }; +use crate::proto::Error; use crate::hpack; @@ -30,6 +30,8 @@ pub struct FramedRead { max_header_list_size: usize, + max_continuation_frames: usize, + partial: Option, } @@ -41,6 +43,8 @@ struct Partial { /// Partial header payload buf: BytesMut, + + continuation_frames_count: usize, } #[derive(Debug)] @@ -51,255 +55,18 @@ enum Continuable { impl FramedRead { pub fn new(inner: InnerFramedRead) -> FramedRead { + let max_header_list_size = DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE; + let max_continuation_frames = + calc_max_continuation_frames(max_header_list_size, inner.decoder().max_frame_length()); FramedRead { inner, hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), - max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, + max_header_list_size, + max_continuation_frames, partial: None, } } - fn decode_frame(&mut self, mut bytes: BytesMut) -> Result, RecvError> { - use self::RecvError::*; - - log::trace!("decoding frame from {}B", bytes.len()); - - // Parse the head - let head = frame::Head::parse(&bytes); - - if self.partial.is_some() && head.kind() != Kind::Continuation { - proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind()); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - - let kind = head.kind(); - - log::trace!(" -> kind={:?}", kind); - - macro_rules! 
header_block { - ($frame:ident, $head:ident, $bytes:ident) => ({ - // Drop the frame header - // TODO: Change to drain: carllerche/bytes#130 - let _ = $bytes.split_to(frame::HEADER_LEN); - - // Parse the header frame w/o parsing the payload - let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { - Ok(res) => res, - Err(frame::Error::InvalidDependencyId) => { - proto_err!(stream: "invalid HEADERS dependency ID"); - // A stream cannot depend on itself. An endpoint MUST - // treat this as a stream error (Section 5.4.2) of type - // `PROTOCOL_ERROR`. - return Err(Stream { - id: $head.stream_id(), - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - proto_err!(conn: "failed to load frame; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - }; - - let is_end_headers = frame.is_end_headers(); - - // Load the HPACK encoded headers - match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) { - Ok(_) => {}, - Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, - Err(frame::Error::MalformedMessage) => { - let id = $head.stream_id(); - proto_err!(stream: "malformed header block; stream={:?}", id); - return Err(Stream { - id, - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - proto_err!(conn: "failed HPACK decoding; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - - if is_end_headers { - frame.into() - } else { - log::trace!("loaded partial header block"); - // Defer returning the frame - self.partial = Some(Partial { - frame: Continuable::$frame(frame), - buf: payload, - }); - - return Ok(None); - } - }); - } - - let frame = match kind { - Kind::Settings => { - let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); - - res.map_err(|e| { - proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? 
- .into() - } - Kind::Ping => { - let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); - - res.map_err(|e| { - proto_err!(conn: "failed to load PING frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::WindowUpdate => { - let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); - - res.map_err(|e| { - proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::Data => { - let _ = bytes.split_to(frame::HEADER_LEN); - let res = frame::Data::load(head, bytes.freeze()); - - // TODO: Should this always be connection level? Probably not... - res.map_err(|e| { - proto_err!(conn: "failed to load DATA frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::Headers => header_block!(Headers, head, bytes), - Kind::Reset => { - let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); - res.map_err(|e| { - proto_err!(conn: "failed to load RESET frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::GoAway => { - let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); - res.map_err(|e| { - proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::PushPromise => header_block!(PushPromise, head, bytes), - Kind::Priority => { - if head.stream_id() == 0 { - // Invalid stream identifier - proto_err!(conn: "invalid stream ID 0"); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - - match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) { - Ok(frame) => frame.into(), - Err(frame::Error::InvalidDependencyId) => { - // A stream cannot depend on itself. An endpoint MUST - // treat this as a stream error (Section 5.4.2) of type - // `PROTOCOL_ERROR`. 
- let id = head.stream_id(); - proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id); - return Err(Stream { - id, - reason: Reason::PROTOCOL_ERROR, - }); - } - Err(e) => { - proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - } - Kind::Continuation => { - let is_end_headers = (head.flag() & 0x4) == 0x4; - - let mut partial = match self.partial.take() { - Some(partial) => partial, - None => { - proto_err!(conn: "received unexpected CONTINUATION frame"); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - }; - - // The stream identifiers must match - if partial.frame.stream_id() != head.stream_id() { - proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID"); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - - // Extend the buf - if partial.buf.is_empty() { - partial.buf = bytes.split_off(frame::HEADER_LEN); - } else { - if partial.frame.is_over_size() { - // If there was left over bytes previously, they may be - // needed to continue decoding, even though we will - // be ignoring this frame. This is done to keep the HPACK - // decoder state up-to-date. - // - // Still, we need to be careful, because if a malicious - // attacker were to try to send a gigantic string, such - // that it fits over multiple header blocks, we could - // grow memory uncontrollably again, and that'd be a shame. - // - // Instead, we use a simple heuristic to determine if - // we should continue to ignore decoding, or to tell - // the attacker to go away. 
- if partial.buf.len() + bytes.len() > self.max_header_list_size { - proto_err!(conn: "CONTINUATION frame header block size over ignorable limit"); - return Err(Connection(Reason::COMPRESSION_ERROR)); - } - } - partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); - } - - match partial.frame.load_hpack( - &mut partial.buf, - self.max_header_list_size, - &mut self.hpack, - ) { - Ok(_) => {} - Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) - if !is_end_headers => {} - Err(frame::Error::MalformedMessage) => { - let id = head.stream_id(); - proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id); - return Err(Stream { - id, - reason: Reason::PROTOCOL_ERROR, - }); - } - Err(e) => { - proto_err!(conn: "failed HPACK decoding; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - - if is_end_headers { - partial.frame.into() - } else { - self.partial = Some(partial); - return Ok(None); - } - } - Kind::Unknown => { - // Unknown frames are ignored - return Ok(None); - } - }; - - Ok(Some(frame)) - } - pub fn get_ref(&self) -> &T { self.inner.get_ref() } @@ -309,7 +76,6 @@ impl FramedRead { } /// Returns the current max frame size setting - #[cfg(feature = "unstable")] #[inline] pub fn max_frame_size(&self) -> usize { self.inner.decoder().max_frame_length() @@ -321,45 +87,333 @@ impl FramedRead { #[inline] pub fn set_max_frame_size(&mut self, val: usize) { assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); - self.inner.decoder_mut().set_max_frame_length(val) + self.inner.decoder_mut().set_max_frame_length(val); + // Update max CONTINUATION frames too, since its based on this + self.max_continuation_frames = calc_max_continuation_frames(self.max_header_list_size, val); } /// Update the max header list size setting. 
#[inline] pub fn set_max_header_list_size(&mut self, val: usize) { self.max_header_list_size = val; + // Update max CONTINUATION frames too, since its based on this + self.max_continuation_frames = calc_max_continuation_frames(val, self.max_frame_size()); } + + /// Update the header table size setting. + #[inline] + pub fn set_header_table_size(&mut self, val: usize) { + self.hpack.queue_size_update(val); + } +} + +fn calc_max_continuation_frames(header_max: usize, frame_max: usize) -> usize { + // At least this many frames needed to use max header list size + let min_frames_for_list = (header_max / frame_max).max(1); + // Some padding for imperfectly packed frames + // 25% without floats + let padding = min_frames_for_list >> 2; + min_frames_for_list.saturating_add(padding).max(5) +} + +/// Decodes a frame. +/// +/// This method is intentionally de-generified and outlined because it is very large. +fn decode_frame( + hpack: &mut hpack::Decoder, + max_header_list_size: usize, + max_continuation_frames: usize, + partial_inout: &mut Option, + mut bytes: BytesMut, +) -> Result, Error> { + let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len()); + let _e = span.enter(); + + tracing::trace!("decoding frame from {}B", bytes.len()); + + // Parse the head + let head = frame::Head::parse(&bytes); + + if partial_inout.is_some() && head.kind() != Kind::Continuation { + proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind()); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + + let kind = head.kind(); + + tracing::trace!(frame.kind = ?kind); + + macro_rules! 
header_block { + ($frame:ident, $head:ident, $bytes:ident) => ({ + // Drop the frame header + // TODO: Change to drain: carllerche/bytes#130 + let _ = $bytes.split_to(frame::HEADER_LEN); + + // Parse the header frame w/o parsing the payload + let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { + Ok(res) => res, + Err(frame::Error::InvalidDependencyId) => { + proto_err!(stream: "invalid HEADERS dependency ID"); + // A stream cannot depend on itself. An endpoint MUST + // treat this as a stream error (Section 5.4.2) of type + // `PROTOCOL_ERROR`. + return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR)); + }, + Err(e) => { + proto_err!(conn: "failed to load frame; err={:?}", e); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + }; + + let is_end_headers = frame.is_end_headers(); + + // Load the HPACK encoded headers + match frame.load_hpack(&mut payload, max_header_list_size, hpack) { + Ok(_) => {}, + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, + Err(frame::Error::MalformedMessage) => { + let id = $head.stream_id(); + proto_err!(stream: "malformed header block; stream={:?}", id); + return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); + }, + Err(e) => { + proto_err!(conn: "failed HPACK decoding; err={:?}", e); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + } + + if is_end_headers { + frame.into() + } else { + tracing::trace!("loaded partial header block"); + // Defer returning the frame + *partial_inout = Some(Partial { + frame: Continuable::$frame(frame), + buf: payload, + continuation_frames_count: 0, + }); + + return Ok(None); + } + }); + } + + let frame = match kind { + Kind::Settings => { + let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? 
+ .into() + } + Kind::Ping => { + let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load PING frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::WindowUpdate => { + let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::Data => { + let _ = bytes.split_to(frame::HEADER_LEN); + let res = frame::Data::load(head, bytes.freeze()); + + // TODO: Should this always be connection level? Probably not... + res.map_err(|e| { + proto_err!(conn: "failed to load DATA frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::Headers => header_block!(Headers, head, bytes), + Kind::Reset => { + let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); + res.map_err(|e| { + proto_err!(conn: "failed to load RESET frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::GoAway => { + let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); + res.map_err(|e| { + proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::PushPromise => header_block!(PushPromise, head, bytes), + Kind::Priority => { + if head.stream_id() == 0 { + // Invalid stream identifier + proto_err!(conn: "invalid stream ID 0"); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + + match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) { + Ok(frame) => frame.into(), + Err(frame::Error::InvalidDependencyId) => { + // A stream cannot depend on itself. An endpoint MUST + // treat this as a stream error (Section 5.4.2) of type + // `PROTOCOL_ERROR`. 
+ let id = head.stream_id(); + proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id); + return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); + } + Err(e) => { + proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + } + } + Kind::Continuation => { + let is_end_headers = (head.flag() & 0x4) == 0x4; + + let mut partial = match partial_inout.take() { + Some(partial) => partial, + None => { + proto_err!(conn: "received unexpected CONTINUATION frame"); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + }; + + // The stream identifiers must match + if partial.frame.stream_id() != head.stream_id() { + proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID"); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + + // Check for CONTINUATION flood + if is_end_headers { + partial.continuation_frames_count = 0; + } else { + let cnt = partial.continuation_frames_count + 1; + if cnt > max_continuation_frames { + tracing::debug!("too_many_continuations, max = {}", max_continuation_frames); + return Err(Error::library_go_away_data( + Reason::ENHANCE_YOUR_CALM, + "too_many_continuations", + )); + } else { + partial.continuation_frames_count = cnt; + } + } + + // Extend the buf + if partial.buf.is_empty() { + partial.buf = bytes.split_off(frame::HEADER_LEN); + } else { + if partial.frame.is_over_size() { + // If there was left over bytes previously, they may be + // needed to continue decoding, even though we will + // be ignoring this frame. This is done to keep the HPACK + // decoder state up-to-date. + // + // Still, we need to be careful, because if a malicious + // attacker were to try to send a gigantic string, such + // that it fits over multiple header blocks, we could + // grow memory uncontrollably again, and that'd be a shame. 
+ // + // Instead, we use a simple heuristic to determine if + // we should continue to ignore decoding, or to tell + // the attacker to go away. + if partial.buf.len() + bytes.len() > max_header_list_size { + proto_err!(conn: "CONTINUATION frame header block size over ignorable limit"); + return Err(Error::library_go_away(Reason::COMPRESSION_ERROR)); + } + } + partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); + } + + match partial + .frame + .load_hpack(&mut partial.buf, max_header_list_size, hpack) + { + Ok(_) => {} + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {} + Err(frame::Error::MalformedMessage) => { + let id = head.stream_id(); + proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id); + return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); + } + Err(e) => { + proto_err!(conn: "failed HPACK decoding; err={:?}", e); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + } + + if is_end_headers { + partial.frame.into() + } else { + *partial_inout = Some(partial); + return Ok(None); + } + } + Kind::Unknown => { + // Unknown frames are ignored + return Ok(None); + } + }; + + Ok(Some(frame)) } impl Stream for FramedRead where T: AsyncRead + Unpin, { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let span = tracing::trace_span!("FramedRead::poll_next"); + let _e = span.enter(); loop { - log::trace!("poll"); + tracing::trace!("poll"); let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) { Some(Ok(bytes)) => bytes, Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))), None => return Poll::Ready(None), }; - log::trace!("poll; bytes={}B", bytes.len()); - if let Some(frame) = self.decode_frame(bytes)? { - log::debug!("received; frame={:?}", frame); + tracing::trace!(read.bytes = bytes.len()); + let Self { + ref mut hpack, + max_header_list_size, + ref mut partial, + max_continuation_frames, + .. 
+ } = *self; + if let Some(frame) = decode_frame( + hpack, + max_header_list_size, + max_continuation_frames, + partial, + bytes, + )? { + tracing::debug!(?frame, "received"); return Poll::Ready(Some(Ok(frame))); } } } } -fn map_err(err: io::Error) -> RecvError { +fn map_err(err: io::Error) -> Error { if let io::ErrorKind::InvalidData = err.kind() { if let Some(custom) = err.get_ref() { if custom.is::() { - return RecvError::Connection(Reason::FRAME_SIZE_ERROR); + return Error::library_go_away(Reason::FRAME_SIZE_ERROR); } } } diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs index c63f12228..c88af02da 100644 --- a/src/codec/framed_write.rs +++ b/src/codec/framed_write.rs @@ -3,13 +3,11 @@ use crate::codec::UserError::*; use crate::frame::{self, Frame, FrameSize}; use crate::hpack; -use bytes::{ - buf::{BufExt, BufMutExt}, - Buf, BufMut, BytesMut, -}; +use bytes::{Buf, BufMut, BytesMut}; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_util::io::poll_write_buf; use std::io::{self, Cursor}; @@ -26,6 +24,11 @@ pub struct FramedWrite { /// Upstream `AsyncWrite` inner: T, + encoder: Encoder, +} + +#[derive(Debug)] +struct Encoder { /// HPACK encoder hpack: hpack::Encoder, @@ -42,6 +45,12 @@ pub struct FramedWrite { /// Max frame size, this is specified by the peer max_frame_size: FrameSize, + + /// Chain payloads bigger than this. + chain_threshold: usize, + + /// Min buffer required to attempt to write a frame + min_buffer_capacity: usize, } #[derive(Debug)] @@ -50,20 +59,22 @@ enum Next { Continuation(frame::Continuation), } -/// Initialze the connection with this amount of write buffer. +/// Initialize the connection with this amount of write buffer. /// /// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS /// frame that big. 
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024; -/// Min buffer required to attempt to write a frame -const MIN_BUFFER_CAPACITY: usize = frame::HEADER_LEN + CHAIN_THRESHOLD; - -/// Chain payloads bigger than this. The remote will never advertise a max frame -/// size less than this (well, the spec says the max frame size can't be less -/// than 16kb, so not even close). +/// Chain payloads bigger than this when vectored I/O is enabled. The remote +/// will never advertise a max frame size less than this (well, the spec says +/// the max frame size can't be less than 16kb, so not even close). const CHAIN_THRESHOLD: usize = 256; +/// Chain payloads bigger than this when vectored I/O is **not** enabled. +/// A larger value in this scenario will reduce the number of small and +/// fragmented data being sent, and hereby improve the throughput. +const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024; + // TODO: Make generic impl FramedWrite where @@ -71,13 +82,22 @@ where B: Buf, { pub fn new(inner: T) -> FramedWrite { + let chain_threshold = if inner.is_write_vectored() { + CHAIN_THRESHOLD + } else { + CHAIN_THRESHOLD_WITHOUT_VECTORED_IO + }; FramedWrite { inner, - hpack: hpack::Encoder::default(), - buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), - next: None, - last_data_frame: None, - max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, + encoder: Encoder { + hpack: hpack::Encoder::default(), + buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), + next: None, + last_data_frame: None, + max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, + chain_threshold, + min_buffer_capacity: chain_threshold + frame::HEADER_LEN, + }, } } @@ -86,11 +106,11 @@ where /// Calling this function may result in the current contents of the buffer /// to be flushed to `T`. 
pub fn poll_ready(&mut self, cx: &mut Context) -> Poll> { - if !self.has_capacity() { + if !self.encoder.has_capacity() { // Try flushing ready!(self.flush(cx))?; - if !self.has_capacity() { + if !self.encoder.has_capacity() { return Poll::Pending; } } @@ -103,10 +123,94 @@ where /// `poll_ready` must be called first to ensure that a frame may be /// accepted. pub fn buffer(&mut self, item: Frame) -> Result<(), UserError> { + self.encoder.buffer(item) + } + + /// Flush buffered data to the wire + pub fn flush(&mut self, cx: &mut Context) -> Poll> { + let span = tracing::trace_span!("FramedWrite::flush"); + let _e = span.enter(); + + loop { + while !self.encoder.is_empty() { + match self.encoder.next { + Some(Next::Data(ref mut frame)) => { + tracing::trace!(queued_data_frame = true); + let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut()); + ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))? + } + _ => { + tracing::trace!(queued_data_frame = false); + ready!(poll_write_buf( + Pin::new(&mut self.inner), + cx, + &mut self.encoder.buf + ))? 
+ } + }; + } + + match self.encoder.unset_frame() { + ControlFlow::Continue => (), + ControlFlow::Break => break, + } + } + + tracing::trace!("flushing buffer"); + // Flush the upstream + ready!(Pin::new(&mut self.inner).poll_flush(cx))?; + + Poll::Ready(Ok(())) + } + + /// Close the codec + pub fn shutdown(&mut self, cx: &mut Context) -> Poll> { + ready!(self.flush(cx))?; + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[must_use] +enum ControlFlow { + Continue, + Break, +} + +impl Encoder +where + B: Buf, +{ + fn unset_frame(&mut self) -> ControlFlow { + // Clear internal buffer + self.buf.set_position(0); + self.buf.get_mut().clear(); + + // The data frame has been written, so unset it + match self.next.take() { + Some(Next::Data(frame)) => { + self.last_data_frame = Some(frame); + debug_assert!(self.is_empty()); + ControlFlow::Break + } + Some(Next::Continuation(frame)) => { + // Buffer the continuation frame, then try to write again + let mut buf = limited_write_buf!(self); + if let Some(continuation) = frame.encode(&mut buf) { + self.next = Some(Next::Continuation(continuation)); + } + ControlFlow::Continue + } + None => ControlFlow::Break, + } + } + + fn buffer(&mut self, item: Frame) -> Result<(), UserError> { // Ensure that we have enough capacity to accept the write. 
assert!(self.has_capacity()); + let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item); + let _e = span.enter(); - log::debug!("send; frame={:?}", item); + tracing::debug!(frame = ?item, "send"); match item { Frame::Data(mut v) => { @@ -117,12 +221,17 @@ where return Err(PayloadTooBig); } - if len >= CHAIN_THRESHOLD { + if len >= self.chain_threshold { let head = v.head(); // Encode the frame head to the buffer head.encode(len, self.buf.get_mut()); + if self.buf.get_ref().remaining() < self.chain_threshold { + let extra_bytes = self.chain_threshold - self.buf.remaining(); + self.buf.get_mut().put(v.payload_mut().take(extra_bytes)); + } + // Save the data frame self.next = Some(Next::Data(v)); } else { @@ -150,105 +259,41 @@ where } Frame::Settings(v) => { v.encode(self.buf.get_mut()); - log::trace!("encoded settings; rem={:?}", self.buf.remaining()); + tracing::trace!(rem = self.buf.remaining(), "encoded settings"); } Frame::GoAway(v) => { v.encode(self.buf.get_mut()); - log::trace!("encoded go_away; rem={:?}", self.buf.remaining()); + tracing::trace!(rem = self.buf.remaining(), "encoded go_away"); } Frame::Ping(v) => { v.encode(self.buf.get_mut()); - log::trace!("encoded ping; rem={:?}", self.buf.remaining()); + tracing::trace!(rem = self.buf.remaining(), "encoded ping"); } Frame::WindowUpdate(v) => { v.encode(self.buf.get_mut()); - log::trace!("encoded window_update; rem={:?}", self.buf.remaining()); + tracing::trace!(rem = self.buf.remaining(), "encoded window_update"); } Frame::Priority(_) => { /* v.encode(self.buf.get_mut()); - log::trace!("encoded priority; rem={:?}", self.buf.remaining()); + tracing::trace!("encoded priority; rem={:?}", self.buf.remaining()); */ unimplemented!(); } Frame::Reset(v) => { v.encode(self.buf.get_mut()); - log::trace!("encoded reset; rem={:?}", self.buf.remaining()); + tracing::trace!(rem = self.buf.remaining(), "encoded reset"); } } Ok(()) } - /// Flush buffered data to the wire - pub fn flush(&mut self, cx: &mut 
Context) -> Poll> { - log::trace!("flush"); - - loop { - while !self.is_empty() { - match self.next { - Some(Next::Data(ref mut frame)) => { - log::trace!(" -> queued data frame"); - let mut buf = (&mut self.buf).chain(frame.payload_mut()); - ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut buf))?; - } - _ => { - log::trace!(" -> not a queued data frame"); - ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut self.buf))?; - } - } - } - - // Clear internal buffer - self.buf.set_position(0); - self.buf.get_mut().clear(); - - // The data frame has been written, so unset it - match self.next.take() { - Some(Next::Data(frame)) => { - self.last_data_frame = Some(frame); - debug_assert!(self.is_empty()); - break; - } - Some(Next::Continuation(frame)) => { - // Buffer the continuation frame, then try to write again - let mut buf = limited_write_buf!(self); - if let Some(continuation) = frame.encode(&mut self.hpack, &mut buf) { - // We previously had a CONTINUATION, and after encoding - // it, we got *another* one? Let's just double check - // that at least some progress is being made... - if self.buf.get_ref().len() == frame::HEADER_LEN { - // If *only* the CONTINUATION frame header was - // written, and *no* header fields, we're stuck - // in a loop... 
- panic!("CONTINUATION frame write loop; header value too big to encode"); - } - - self.next = Some(Next::Continuation(continuation)); - } - } - None => { - break; - } - } - } - - log::trace!("flushing buffer"); - // Flush the upstream - ready!(Pin::new(&mut self.inner).poll_flush(cx))?; - - Poll::Ready(Ok(())) - } - - /// Close the codec - pub fn shutdown(&mut self, cx: &mut Context) -> Poll> { - ready!(self.flush(cx))?; - Pin::new(&mut self.inner).poll_shutdown(cx) - } - fn has_capacity(&self) -> bool { - self.next.is_none() && self.buf.get_ref().remaining_mut() >= MIN_BUFFER_CAPACITY + self.next.is_none() + && (self.buf.get_ref().capacity() - self.buf.get_ref().len() + >= self.min_buffer_capacity) } fn is_empty(&self) -> bool { @@ -259,26 +304,32 @@ where } } +impl Encoder { + fn max_frame_size(&self) -> usize { + self.max_frame_size as usize + } +} + impl FramedWrite { /// Returns the max frame size that can be sent pub fn max_frame_size(&self) -> usize { - self.max_frame_size as usize + self.encoder.max_frame_size() } /// Set the peer's max frame size. pub fn set_max_frame_size(&mut self, val: usize) { assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); - self.max_frame_size = val as FrameSize; + self.encoder.max_frame_size = val as FrameSize; } /// Set the peer's header table size. 
pub fn set_header_table_size(&mut self, val: usize) { - self.hpack.update_max_size(val); + self.encoder.hpack.update_max_size(val); } /// Retrieve the last data frame that has been sent pub fn take_last_data_frame(&mut self) -> Option> { - self.last_data_frame.take() + self.encoder.last_data_frame.take() } pub fn get_mut(&mut self) -> &mut T { @@ -287,25 +338,13 @@ impl FramedWrite { } impl AsyncRead for FramedWrite { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { Pin::new(&mut self.inner).poll_read(cx, buf) } - - fn poll_read_buf( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut Buf, - ) -> Poll> { - Pin::new(&mut self.inner).poll_read_buf(cx, buf) - } } // We never project the Pin to `B`. diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 7d0ab73d8..6cbdc1e18 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -2,12 +2,13 @@ mod error; mod framed_read; mod framed_write; -pub use self::error::{RecvError, SendError, UserError}; +pub use self::error::{SendError, UserError}; use self::framed_read::FramedRead; use self::framed_write::FramedWrite; use crate::frame::{self, Data, Frame}; +use crate::proto::Error; use bytes::Buf; use futures_core::Stream; @@ -94,6 +95,11 @@ impl Codec { self.framed_write().set_header_table_size(val) } + /// Set the decoder header table size size. + pub fn set_recv_header_table_size(&mut self, val: usize) { + self.inner.set_header_table_size(val) + } + /// Set the max header list size that can be received. 
pub fn set_max_recv_header_list_size(&mut self, val: usize) { self.inner.set_max_header_list_size(val); @@ -155,7 +161,7 @@ impl Stream for Codec where T: AsyncRead + Unpin, { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_next(cx) diff --git a/src/error.rs b/src/error.rs index 71f119057..02d91cbec 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,11 +1,13 @@ use crate::codec::{SendError, UserError}; -use crate::proto; +use crate::frame::StreamId; +use crate::proto::{self, Initiator}; +use bytes::Bytes; use std::{error, fmt, io}; pub use crate::frame::Reason; -/// Represents HTTP/2.0 operation errors. +/// Represents HTTP/2 operation errors. /// /// `Error` covers error cases raised by protocol errors caused by the /// peer, I/O (transport) errors, and errors caused by the user of the library. @@ -22,11 +24,15 @@ pub struct Error { #[derive(Debug)] enum Kind { - /// An error caused by an action taken by the remote peer. - /// - /// This is either an error received by the peer or caused by an invalid - /// action taken by the peer (i.e. a protocol error). - Proto(Reason), + /// A RST_STREAM frame was received or sent. + #[allow(dead_code)] + Reset(StreamId, Reason, Initiator), + + /// A GO_AWAY frame was received or sent. + GoAway(Bytes, Reason, Initiator), + + /// The user created an error from a bare Reason. + Reason(Reason), /// An error resulting from an invalid action taken by the user of this /// library. @@ -45,17 +51,16 @@ impl Error { /// action taken by the peer (i.e. a protocol error). 
pub fn reason(&self) -> Option { match self.kind { - Kind::Proto(reason) => Some(reason), + Kind::Reset(_, reason, _) | Kind::GoAway(_, reason, _) | Kind::Reason(reason) => { + Some(reason) + } _ => None, } } - /// Returns the true if the error is an io::Error + /// Returns true if the error is an io::Error pub fn is_io(&self) -> bool { - match self.kind { - Kind::Io(_) => true, - _ => false, - } + matches!(self.kind, Kind::Io(..)) } /// Returns the error if the error is an io::Error @@ -79,6 +84,36 @@ impl Error { kind: Kind::Io(err), } } + + /// Returns true if the error is from a `GOAWAY`. + pub fn is_go_away(&self) -> bool { + matches!(self.kind, Kind::GoAway(..)) + } + + /// Returns true if the error is from a `RST_STREAM`. + pub fn is_reset(&self) -> bool { + matches!(self.kind, Kind::Reset(..)) + } + + /// Returns true if the error was received in a frame from the remote. + /// + /// Such as from a received `RST_STREAM` or `GOAWAY` frame. + pub fn is_remote(&self) -> bool { + matches!( + self.kind, + Kind::GoAway(_, _, Initiator::Remote) | Kind::Reset(_, _, Initiator::Remote) + ) + } + + /// Returns true if the error was created by `h2`. + /// + /// Such as noticing some protocol error and sending a GOAWAY or RST_STREAM. 
+ pub fn is_library(&self) -> bool { + matches!( + self.kind, + Kind::GoAway(_, _, Initiator::Library) | Kind::Reset(_, _, Initiator::Library) + ) + } } impl From for Error { @@ -87,8 +122,13 @@ impl From for Error { Error { kind: match src { - Proto(reason) => Kind::Proto(reason), - Io(e) => Kind::Io(e), + Reset(stream_id, reason, initiator) => Kind::Reset(stream_id, reason, initiator), + GoAway(debug_data, reason, initiator) => { + Kind::GoAway(debug_data, reason, initiator) + } + Io(kind, inner) => { + Kind::Io(inner.map_or_else(|| kind.into(), |inner| io::Error::new(kind, inner))) + } }, } } @@ -97,7 +137,7 @@ impl From for Error { impl From for Error { fn from(src: Reason) -> Error { Error { - kind: Kind::Proto(src), + kind: Kind::Reason(src), } } } @@ -106,8 +146,7 @@ impl From for Error { fn from(src: SendError) -> Error { match src { SendError::User(e) => e.into(), - SendError::Connection(reason) => reason.into(), - SendError::Io(e) => Error::from_io(e), + SendError::Connection(e) => e.into(), } } } @@ -122,13 +161,38 @@ impl From for Error { impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - use self::Kind::*; - - match self.kind { - Proto(ref reason) => write!(fmt, "protocol error: {}", reason), - User(ref e) => write!(fmt, "user error: {}", e), - Io(ref e) => fmt::Display::fmt(e, fmt), + let debug_data = match self.kind { + Kind::Reset(_, reason, Initiator::User) => { + return write!(fmt, "stream error sent by user: {}", reason) + } + Kind::Reset(_, reason, Initiator::Library) => { + return write!(fmt, "stream error detected: {}", reason) + } + Kind::Reset(_, reason, Initiator::Remote) => { + return write!(fmt, "stream error received: {}", reason) + } + Kind::GoAway(ref debug_data, reason, Initiator::User) => { + write!(fmt, "connection error sent by user: {}", reason)?; + debug_data + } + Kind::GoAway(ref debug_data, reason, Initiator::Library) => { + write!(fmt, "connection error detected: {}", reason)?; + 
debug_data + } + Kind::GoAway(ref debug_data, reason, Initiator::Remote) => { + write!(fmt, "connection error received: {}", reason)?; + debug_data + } + Kind::Reason(reason) => return write!(fmt, "protocol error: {}", reason), + Kind::User(ref e) => return write!(fmt, "user error: {}", e), + Kind::Io(ref e) => return e.fmt(fmt), + }; + + if !debug_data.is_empty() { + write!(fmt, " ({:?})", debug_data)?; } + + Ok(()) } } @@ -140,3 +204,26 @@ impl error::Error for Error { } } } + +#[cfg(test)] +mod tests { + use std::error::Error as _; + use std::io; + + use super::Error; + use crate::Reason; + + #[test] + fn error_from_reason() { + let err = Error::from(Reason::HTTP_1_1_REQUIRED); + assert_eq!(err.reason(), Some(Reason::HTTP_1_1_REQUIRED)); + } + + #[test] + fn io_error_source() { + let err = Error::from_io(io::Error::new(io::ErrorKind::BrokenPipe, "hi")); + let source = err.source().expect("io error should have source"); + let io_err = source.downcast_ref::().expect("should be io error"); + assert_eq!(io_err.kind(), io::ErrorKind::BrokenPipe); + } +} diff --git a/src/ext.rs b/src/ext.rs new file mode 100644 index 000000000..cf383a495 --- /dev/null +++ b/src/ext.rs @@ -0,0 +1,55 @@ +//! Extensions specific to the HTTP/2 protocol. + +use crate::hpack::BytesStr; + +use bytes::Bytes; +use std::fmt; + +/// Represents the `:protocol` pseudo-header used by +/// the [Extended CONNECT Protocol]. +/// +/// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 +#[derive(Clone, Eq, PartialEq)] +pub struct Protocol { + value: BytesStr, +} + +impl Protocol { + /// Converts a static string to a protocol name. + pub const fn from_static(value: &'static str) -> Self { + Self { + value: BytesStr::from_static(value), + } + } + + /// Returns a str representation of the header. 
+ pub fn as_str(&self) -> &str { + self.value.as_str() + } + + pub(crate) fn try_from(bytes: Bytes) -> Result { + Ok(Self { + value: BytesStr::try_from(bytes)?, + }) + } +} + +impl<'a> From<&'a str> for Protocol { + fn from(value: &'a str) -> Self { + Self { + value: BytesStr::from(value), + } + } +} + +impl AsRef<[u8]> for Protocol { + fn as_ref(&self) -> &[u8] { + self.value.as_ref() + } +} + +impl fmt::Debug for Protocol { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.value.fmt(f) + } +} diff --git a/src/frame/data.rs b/src/frame/data.rs index 91de52df9..5ed3c31b5 100644 --- a/src/frame/data.rs +++ b/src/frame/data.rs @@ -16,7 +16,7 @@ pub struct Data { pad_len: Option, } -#[derive(Copy, Clone, Eq, PartialEq)] +#[derive(Copy, Clone, Default, Eq, PartialEq)] struct DataFlags(u8); const END_STREAM: u8 = 0x1; @@ -36,7 +36,7 @@ impl Data { } } - /// Returns the stream identifer that this frame is associated with. + /// Returns the stream identifier that this frame is associated with. /// /// This cannot be a zero stream identifier. pub fn stream_id(&self) -> StreamId { @@ -63,7 +63,7 @@ impl Data { } } - /// Returns whther the `PADDED` flag is set on this frame. + /// Returns whether the `PADDED` flag is set on this frame. #[cfg(feature = "unstable")] pub fn is_padded(&self) -> bool { self.flags.is_padded() @@ -148,7 +148,7 @@ impl Data { /// /// Panics if `dst` cannot contain the data frame. 
pub(crate) fn encode_chunk(&mut self, dst: &mut U) { - let len = self.data.remaining() as usize; + let len = self.data.remaining(); assert!(dst.remaining_mut() >= len); @@ -211,12 +211,6 @@ impl DataFlags { } } -impl Default for DataFlags { - fn default() -> Self { - DataFlags(0) - } -} - impl From for u8 { fn from(src: DataFlags) -> u8 { src.0 diff --git a/src/frame/go_away.rs b/src/frame/go_away.rs index a46ba7a37..99330e981 100644 --- a/src/frame/go_away.rs +++ b/src/frame/go_away.rs @@ -8,7 +8,6 @@ use crate::frame::{self, Error, Head, Kind, Reason, StreamId}; pub struct GoAway { last_stream_id: StreamId, error_code: Reason, - #[allow(unused)] debug_data: Bytes, } @@ -21,6 +20,14 @@ impl GoAway { } } + pub fn with_debug_data(last_stream_id: StreamId, reason: Reason, debug_data: Bytes) -> Self { + Self { + last_stream_id, + error_code: reason, + debug_data, + } + } + pub fn last_stream_id(&self) -> StreamId { self.last_stream_id } @@ -29,8 +36,7 @@ impl GoAway { self.error_code } - #[cfg(feature = "unstable")] - pub fn debug_data(&self) -> &[u8] { + pub fn debug_data(&self) -> &Bytes { &self.debug_data } @@ -51,11 +57,12 @@ impl GoAway { } pub fn encode(&self, dst: &mut B) { - log::trace!("encoding GO_AWAY; code={:?}", self.error_code); + tracing::trace!("encoding GO_AWAY; code={:?}", self.error_code); let head = Head::new(Kind::GoAway, 0, StreamId::zero()); - head.encode(8, dst); + head.encode(8 + self.debug_data.len(), dst); dst.put_u32(self.last_stream_id.into()); dst.put_u32(self.error_code.into()); + dst.put(self.debug_data.slice(..)); } } diff --git a/src/frame/head.rs b/src/frame/head.rs index 2abc08e1d..38be2f697 100644 --- a/src/frame/head.rs +++ b/src/frame/head.rs @@ -36,7 +36,7 @@ impl Head { } } - /// Parse an HTTP/2.0 frame header + /// Parse an HTTP/2 frame header pub fn parse(header: &[u8]) -> Head { let (stream_id, _) = StreamId::parse(&header[5..]); diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 2491d8da0..e9b163e56 100644 --- 
a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -1,20 +1,17 @@ use super::{util, StreamDependency, StreamId}; +use crate::ext::Protocol; use crate::frame::{Error, Frame, Head, Kind}; use crate::hpack::{self, BytesStr}; use http::header::{self, HeaderName, HeaderValue}; use http::{uri, HeaderMap, Method, Request, StatusCode, Uri}; -use bytes::{Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use std::fmt; use std::io::Cursor; -type EncodeBuf<'a> = bytes::buf::ext::Limit<&'a mut BytesMut>; - -// Minimum MAX_FRAME_SIZE is 16kb, so save some arbitrary space for frame -// head and other header bits. -const MAX_HEADER_LENGTH: usize = 1024 * 16 - 100; +type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>; /// Header frame /// @@ -71,6 +68,7 @@ pub struct Pseudo { pub scheme: Option, pub authority: Option, pub path: Option, + pub protocol: Option, // Response pub status: Option, @@ -90,6 +88,9 @@ struct HeaderBlock { /// The decoded header fields fields: HeaderMap, + /// Precomputed size of all of our header fields, for perf reasons + field_size: usize, + /// Set to true if decoding went over the max header list size. 
is_over_size: bool, @@ -100,11 +101,7 @@ struct HeaderBlock { #[derive(Debug)] struct EncodingHeaderBlock { - /// Argument to pass to the HPACK encoder to resume encoding - hpack: Option, - - /// remaining headers to encode - headers: Iter, + hpack: Bytes, } const END_STREAM: u8 = 0x1; @@ -122,6 +119,7 @@ impl Headers { stream_id, stream_dep: None, header_block: HeaderBlock { + field_size: calculate_headermap_size(&fields), fields, is_over_size: false, pseudo, @@ -138,6 +136,7 @@ impl Headers { stream_id, stream_dep: None, header_block: HeaderBlock { + field_size: calculate_headermap_size(&fields), fields, is_over_size: false, pseudo: Pseudo::default(), @@ -153,7 +152,11 @@ impl Headers { let flags = HeadersFlag(head.flag()); let mut pad = 0; - log::trace!("loading headers; flags={:?}", flags); + tracing::trace!("loading headers; flags={:?}", flags); + + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } // Read the padding length if flags.is_padded() { @@ -199,6 +202,7 @@ impl Headers { stream_dep, header_block: HeaderBlock { fields: HeaderMap::new(), + field_size: 0, is_over_size: false, pseudo: Pseudo::default(), }, @@ -241,10 +245,6 @@ impl Headers { self.header_block.is_over_size } - pub(crate) fn has_too_big_field(&self) -> bool { - self.header_block.has_too_big_field() - } - pub fn into_parts(self) -> (Pseudo, HeaderMap) { (self.header_block.pseudo, self.header_block.fields) } @@ -254,6 +254,11 @@ impl Headers { &mut self.header_block.pseudo } + /// Whether it has status 1xx + pub(crate) fn is_informational(&self) -> bool { + self.header_block.pseudo.is_informational() + } + pub fn fields(&self) -> &HeaderMap { &self.header_block.fields } @@ -274,8 +279,8 @@ impl Headers { let head = self.head(); self.header_block - .into_encoding() - .encode(&head, encoder, dst, |_| {}) + .into_encoding(encoder) + .encode(&head, dst, |_| {}) } fn head(&self) -> Head { @@ -296,6 +301,10 @@ impl fmt::Debug for Headers { .field("stream_id", 
&self.stream_id) .field("flags", &self.flags); + if let Some(ref protocol) = self.header_block.pseudo.protocol { + builder.field("protocol", protocol); + } + if let Some(ref dep) = self.stream_dep { builder.field("stream_dep", dep); } @@ -307,17 +316,20 @@ impl fmt::Debug for Headers { // ===== util ===== -pub fn parse_u64(src: &[u8]) -> Result { +#[derive(Debug, PartialEq, Eq)] +pub struct ParseU64Error; + +pub fn parse_u64(src: &[u8]) -> Result { if src.len() > 19 { // At danger for overflow... - return Err(()); + return Err(ParseU64Error); } let mut ret = 0; for &d in src { if d < b'0' || d > b'9' { - return Err(()); + return Err(ParseU64Error); } ret *= 10; @@ -331,7 +343,7 @@ pub fn parse_u64(src: &[u8]) -> Result { #[derive(Debug)] pub enum PushPromiseHeaderError { - InvalidContentLength(Result), + InvalidContentLength(Result), NotSafeAndCacheable, } @@ -345,6 +357,7 @@ impl PushPromise { PushPromise { flags: PushPromiseFlag::default(), header_block: HeaderBlock { + field_size: calculate_headermap_size(&fields), fields, is_over_size: false, pseudo, @@ -379,7 +392,7 @@ impl PushPromise { fn safe_and_cacheable(method: &Method) -> bool { // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods - return method == Method::GET || method == Method::HEAD; + method == Method::GET || method == Method::HEAD } pub fn fields(&self) -> &HeaderMap { @@ -398,6 +411,10 @@ impl PushPromise { let flags = PushPromiseFlag(head.flag()); let mut pad = 0; + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + // Read the padding length if flags.is_padded() { if src.is_empty() { @@ -432,6 +449,7 @@ impl PushPromise { flags, header_block: HeaderBlock { fields: HeaderMap::new(), + field_size: 0, is_over_size: false, pseudo: Pseudo::default(), }, @@ -475,8 +493,6 @@ impl PushPromise { encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option { - use bytes::BufMut; - // At this point, 
the `is_end_headers` flag should always be set debug_assert!(self.flags.is_end_headers()); @@ -484,8 +500,8 @@ impl PushPromise { let promised_id = self.promised_id; self.header_block - .into_encoding() - .encode(&head, encoder, dst, |dst| { + .into_encoding(encoder) + .encode(&head, dst, |dst| { dst.put_u32(promised_id.into()); }) } @@ -524,38 +540,39 @@ impl Continuation { Head::new(Kind::Continuation, END_HEADERS, self.stream_id) } - pub fn encode( - self, - encoder: &mut hpack::Encoder, - dst: &mut EncodeBuf<'_>, - ) -> Option { + pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option { // Get the CONTINUATION frame head let head = self.head(); - self.header_block.encode(&head, encoder, dst, |_| {}) + self.header_block.encode(&head, dst, |_| {}) } } // ===== impl Pseudo ===== impl Pseudo { - pub fn request(method: Method, uri: Uri) -> Self { + pub fn request(method: Method, uri: Uri, protocol: Option) -> Self { let parts = uri::Parts::from(uri); let mut path = parts .path_and_query - .map(|v| Bytes::copy_from_slice(v.as_str().as_bytes())) - .unwrap_or_else(Bytes::new); + .map(|v| BytesStr::from(v.as_str())) + .unwrap_or(BytesStr::from_static("")); - if path.is_empty() && method != Method::OPTIONS { - path = Bytes::from_static(b"/"); + match method { + Method::OPTIONS | Method::CONNECT => {} + _ if path.is_empty() => { + path = BytesStr::from_static("/"); + } + _ => {} } let mut pseudo = Pseudo { method: Some(method), scheme: None, authority: None, - path: Some(unsafe { BytesStr::from_utf8_unchecked(path) }), + path: Some(path).filter(|p| !p.is_empty()), + protocol, status: None, }; @@ -569,9 +586,7 @@ impl Pseudo { // If the URI includes an authority component, add it to the pseudo // headers if let Some(authority) = parts.authority { - pseudo.set_authority(unsafe { - BytesStr::from_utf8_unchecked(Bytes::copy_from_slice(authority.as_str().as_bytes())) - }); + pseudo.set_authority(BytesStr::from(authority.as_str())); } pseudo @@ -583,34 +598,45 @@ impl Pseudo { 
scheme: None, authority: None, path: None, + protocol: None, status: Some(status), } } + #[cfg(feature = "unstable")] + pub fn set_status(&mut self, value: StatusCode) { + self.status = Some(value); + } + pub fn set_scheme(&mut self, scheme: uri::Scheme) { - let bytes = match scheme.as_str() { - "http" => Bytes::from_static(b"http"), - "https" => Bytes::from_static(b"https"), - s => Bytes::copy_from_slice(s.as_bytes()), + let bytes_str = match scheme.as_str() { + "http" => BytesStr::from_static("http"), + "https" => BytesStr::from_static("https"), + s => BytesStr::from(s), }; - self.scheme = Some(unsafe { BytesStr::from_utf8_unchecked(bytes) }); + self.scheme = Some(bytes_str); + } + + #[cfg(feature = "unstable")] + pub fn set_protocol(&mut self, protocol: Protocol) { + self.protocol = Some(protocol); } pub fn set_authority(&mut self, authority: BytesStr) { self.authority = Some(authority); } + + /// Whether it has status 1xx + pub(crate) fn is_informational(&self) -> bool { + self.status + .map_or(false, |status| status.is_informational()) + } } // ===== impl EncodingHeaderBlock ===== impl EncodingHeaderBlock { - fn encode( - mut self, - head: &Head, - encoder: &mut hpack::Encoder, - dst: &mut EncodeBuf<'_>, - f: F, - ) -> Option + fn encode(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option where F: FnOnce(&mut EncodeBuf<'_>), { @@ -626,15 +652,17 @@ impl EncodingHeaderBlock { f(dst); // Now, encode the header payload - let continuation = match encoder.encode(self.hpack, &mut self.headers, dst) { - hpack::Encode::Full => None, - hpack::Encode::Partial(state) => Some(Continuation { + let continuation = if self.hpack.len() > dst.remaining_mut() { + dst.put_slice(&self.hpack.split_to(dst.remaining_mut())); + + Some(Continuation { stream_id: head.stream_id(), - header_block: EncodingHeaderBlock { - hpack: Some(state), - headers: self.headers, - }, - }), + header_block: self, + }) + } else { + dst.put_slice(&self.hpack); + + None }; // Compute the header 
block length @@ -682,6 +710,10 @@ impl Iterator for Iter { return Some(Path(path)); } + if let Some(protocol) = pseudo.protocol.take() { + return Some(Protocol(protocol)); + } + if let Some(status) = pseudo.status.take() { return Some(Status(status)); } @@ -817,19 +849,19 @@ impl HeaderBlock { macro_rules! set_pseudo { ($field:ident, $val:expr) => {{ if reg { - log::trace!("load_hpack; header malformed -- pseudo not at head of block"); + tracing::trace!("load_hpack; header malformed -- pseudo not at head of block"); malformed = true; } else if self.pseudo.$field.is_some() { - log::trace!("load_hpack; header malformed -- repeated pseudo"); + tracing::trace!("load_hpack; header malformed -- repeated pseudo"); malformed = true; } else { let __val = $val; headers_size += - decoded_header_size(stringify!($ident).len() + 1, __val.as_str().len()); + decoded_header_size(stringify!($field).len() + 1, __val.as_str().len()); if headers_size < max_header_list_size { self.pseudo.$field = Some(__val); } else if !self.is_over_size { - log::trace!("load_hpack; header list size over max"); + tracing::trace!("load_hpack; header list size over max"); self.is_over_size = true; } } @@ -856,19 +888,24 @@ impl HeaderBlock { || name == "keep-alive" || name == "proxy-connection" { - log::trace!("load_hpack; connection level header"); + tracing::trace!("load_hpack; connection level header"); malformed = true; } else if name == header::TE && value != "trailers" { - log::trace!("load_hpack; TE header not set to trailers; val={:?}", value); + tracing::trace!( + "load_hpack; TE header not set to trailers; val={:?}", + value + ); malformed = true; } else { reg = true; headers_size += decoded_header_size(name.as_str().len(), value.len()); if headers_size < max_header_list_size { + self.field_size += + decoded_header_size(name.as_str().len(), value.len()); self.fields.append(name, value); } else if !self.is_over_size { - log::trace!("load_hpack; header list size over max"); + 
tracing::trace!("load_hpack; header list size over max"); self.is_over_size = true; } } @@ -877,30 +914,35 @@ impl HeaderBlock { Method(v) => set_pseudo!(method, v), Scheme(v) => set_pseudo!(scheme, v), Path(v) => set_pseudo!(path, v), + Protocol(v) => set_pseudo!(protocol, v), Status(v) => set_pseudo!(status, v), } }); if let Err(e) = res { - log::trace!("hpack decoding error; err={:?}", e); + tracing::trace!("hpack decoding error; err={:?}", e); return Err(e.into()); } if malformed { - log::trace!("malformed message"); + tracing::trace!("malformed message"); return Err(Error::MalformedMessage); } Ok(()) } - fn into_encoding(self) -> EncodingHeaderBlock { + fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock { + let mut hpack = BytesMut::new(); + let headers = Iter { + pseudo: Some(self.pseudo), + fields: self.fields.into_iter(), + }; + + encoder.encode(headers, &mut hpack); + EncodingHeaderBlock { - hpack: None, - headers: Iter { - pseudo: Some(self.pseudo), - fields: self.fields.into_iter(), - }, + hpack: hpack.freeze(), } } @@ -927,54 +969,83 @@ impl HeaderBlock { + pseudo_size!(status) + pseudo_size!(authority) + pseudo_size!(path) - + self - .fields - .iter() - .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len())) - .sum::() + + self.field_size } +} - /// Iterate over all pseudos and headers to see if any individual pair - /// would be too large to encode. - pub(crate) fn has_too_big_field(&self) -> bool { - macro_rules! 
pseudo_size { - ($name:ident) => {{ - self.pseudo - .$name - .as_ref() - .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len())) - .unwrap_or(0) - }}; - } - - if pseudo_size!(method) > MAX_HEADER_LENGTH { - return true; - } - - if pseudo_size!(scheme) > MAX_HEADER_LENGTH { - return true; - } - - if pseudo_size!(authority) > MAX_HEADER_LENGTH { - return true; - } - - if pseudo_size!(path) > MAX_HEADER_LENGTH { - return true; - } - - // skip :status, its never going to be too big - - for (name, value) in &self.fields { - if decoded_header_size(name.as_str().len(), value.len()) > MAX_HEADER_LENGTH { - return true; - } - } - - false - } +fn calculate_headermap_size(map: &HeaderMap) -> usize { + map.iter() + .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len())) + .sum::() } fn decoded_header_size(name: usize, value: usize) -> usize { name + value + 32 } + +#[cfg(test)] +mod test { + use super::*; + use crate::frame; + use crate::hpack::{huffman, Encoder}; + + #[test] + fn test_nameless_header_at_resume() { + let mut encoder = Encoder::default(); + let mut dst = BytesMut::new(); + + let headers = Headers::new( + StreamId::ZERO, + Default::default(), + HeaderMap::from_iter(vec![ + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("world"), + ), + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("zomg"), + ), + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("sup"), + ), + ]), + ); + + let continuation = headers + .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8)) + .unwrap(); + + assert_eq!(17, dst.len()); + assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]); + assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]); + assert_eq!("hello", huff_decode(&dst[11..15])); + assert_eq!(0x80 | 4, dst[15]); + + let mut world = dst[16..17].to_owned(); + + dst.clear(); + + assert!(continuation + .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16)) + .is_none()); + + 
world.extend_from_slice(&dst[9..12]); + assert_eq!("world", huff_decode(&world)); + + assert_eq!(24, dst.len()); + assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]); + + // // Next is not indexed + assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]); + assert_eq!("zomg", huff_decode(&dst[15..18])); + assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]); + assert_eq!("sup", huff_decode(&dst[21..])); + } + + fn huff_decode(src: &[u8]) -> BytesMut { + let mut buf = BytesMut::new(); + huffman::decode(src, &mut buf).unwrap() + } +} diff --git a/src/frame/mod.rs b/src/frame/mod.rs index 4c49d6bb1..0e8e7035c 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -11,11 +11,11 @@ use std::fmt; /// /// # Examples /// -/// ```rust +/// ```ignore +/// # // We ignore this doctest because the macro is not exported. /// let buf: [u8; 4] = [0, 0, 0, 1]; /// assert_eq!(1u32, unpack_octets_4!(buf, 0, u32)); /// ``` -#[macro_escape] macro_rules! unpack_octets_4 { // TODO: Get rid of this macro ($buf:expr, $offset:expr, $tip:ty) => { @@ -26,6 +26,15 @@ macro_rules! 
unpack_octets_4 { }; } +#[cfg(test)] +mod tests { + #[test] + fn test_unpack_octets_4() { + let buf: [u8; 4] = [0, 0, 0, 1]; + assert_eq!(1u32, unpack_octets_4!(buf, 0, u32)); + } +} + mod data; mod go_away; mod head; @@ -60,7 +69,7 @@ pub use crate::hpack::BytesStr; pub use self::settings::{ DEFAULT_INITIAL_WINDOW_SIZE, DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, - MAX_INITIAL_WINDOW_SIZE, MAX_MAX_FRAME_SIZE, + MAX_MAX_FRAME_SIZE, }; pub type FrameSize = u32; diff --git a/src/frame/ping.rs b/src/frame/ping.rs index 1802ec185..241d06ea1 100644 --- a/src/frame/ping.rs +++ b/src/frame/ping.rs @@ -85,7 +85,7 @@ impl Ping { pub fn encode(&self, dst: &mut B) { let sz = self.payload.len(); - log::trace!("encoding PING; ack={} len={}", self.ack, sz); + tracing::trace!("encoding PING; ack={} len={}", self.ack, sz); let flags = if self.ack { ACK_FLAG } else { 0 }; let head = Head::new(Kind::Ping, flags, StreamId::zero()); diff --git a/src/frame/reason.rs b/src/frame/reason.rs index 031b6cd92..ff5e2012f 100644 --- a/src/frame/reason.rs +++ b/src/frame/reason.rs @@ -1,6 +1,6 @@ use std::fmt; -/// HTTP/2.0 error codes. +/// HTTP/2 error codes. /// /// Error codes are used in `RST_STREAM` and `GOAWAY` frames to convey the /// reasons for the stream or connection error. 
For example, diff --git a/src/frame/reset.rs b/src/frame/reset.rs index 6edecf1a3..39f6ac202 100644 --- a/src/frame/reset.rs +++ b/src/frame/reset.rs @@ -2,7 +2,7 @@ use crate::frame::{self, Error, Head, Kind, Reason, StreamId}; use bytes::BufMut; -#[derive(Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub struct Reset { stream_id: StreamId, error_code: Reason, @@ -38,7 +38,7 @@ impl Reset { } pub fn encode(&self, dst: &mut B) { - log::trace!( + tracing::trace!( "encoding RESET; id={:?} code={:?}", self.stream_id, self.error_code diff --git a/src/frame/settings.rs b/src/frame/settings.rs index c70938144..484498a9d 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -13,6 +13,7 @@ pub struct Settings { initial_window_size: Option, max_frame_size: Option, max_header_list_size: Option, + enable_connect_protocol: Option, } /// An enum that lists all valid settings that can be sent in a SETTINGS @@ -27,6 +28,7 @@ pub enum Setting { InitialWindowSize(u32), MaxFrameSize(u32), MaxHeaderListSize(u32), + EnableConnectProtocol(u32), } #[derive(Copy, Clone, Eq, PartialEq, Default)] @@ -99,23 +101,29 @@ impl Settings { self.max_header_list_size = size; } - pub fn is_push_enabled(&self) -> bool { - self.enable_push.unwrap_or(1) != 0 + pub fn is_push_enabled(&self) -> Option { + self.enable_push.map(|val| val != 0) } pub fn set_enable_push(&mut self, enable: bool) { self.enable_push = Some(enable as u32); } + pub fn is_extended_connect_protocol_enabled(&self) -> Option { + self.enable_connect_protocol.map(|val| val != 0) + } + + pub fn set_enable_connect_protocol(&mut self, val: Option) { + self.enable_connect_protocol = val; + } + pub fn header_table_size(&self) -> Option { self.header_table_size } - /* pub fn set_header_table_size(&mut self, size: Option) { self.header_table_size = size; } - */ pub fn load(head: Head, payload: &[u8]) -> Result { use self::Setting::*; @@ -141,7 +149,7 @@ impl Settings { // Ensure the payload length is 
correct, each setting is 6 bytes long. if payload.len() % 6 != 0 { - log::debug!("invalid settings payload length; len={:?}", payload.len()); + tracing::debug!("invalid settings payload length; len={:?}", payload.len()); return Err(Error::InvalidPayloadAckSettings); } @@ -172,15 +180,23 @@ impl Settings { } } Some(MaxFrameSize(val)) => { - if val < DEFAULT_MAX_FRAME_SIZE || val > MAX_MAX_FRAME_SIZE { - return Err(Error::InvalidSettingValue); - } else { + if DEFAULT_MAX_FRAME_SIZE <= val && val <= MAX_MAX_FRAME_SIZE { settings.max_frame_size = Some(val); + } else { + return Err(Error::InvalidSettingValue); } } Some(MaxHeaderListSize(val)) => { settings.max_header_list_size = Some(val); } + Some(EnableConnectProtocol(val)) => match val { + 0 | 1 => { + settings.enable_connect_protocol = Some(val); + } + _ => { + return Err(Error::InvalidSettingValue); + } + }, None => {} } } @@ -199,13 +215,13 @@ impl Settings { let head = Head::new(Kind::Settings, self.flags.into(), StreamId::zero()); let payload_len = self.payload_len(); - log::trace!("encoding SETTINGS; len={}", payload_len); + tracing::trace!("encoding SETTINGS; len={}", payload_len); head.encode(payload_len, dst); // Encode the settings self.for_each(|setting| { - log::trace!("encoding setting; val={:?}", setting); + tracing::trace!("encoding setting; val={:?}", setting); setting.encode(dst) }); } @@ -236,6 +252,10 @@ impl Settings { if let Some(v) = self.max_header_list_size { f(MaxHeaderListSize(v)); } + + if let Some(v) = self.enable_connect_protocol { + f(EnableConnectProtocol(v)); + } } } @@ -269,6 +289,9 @@ impl fmt::Debug for Settings { Setting::MaxHeaderListSize(v) => { builder.field("max_header_list_size", &v); } + Setting::EnableConnectProtocol(v) => { + builder.field("enable_connect_protocol", &v); + } }); builder.finish() @@ -291,6 +314,7 @@ impl Setting { 4 => Some(InitialWindowSize(val)), 5 => Some(MaxFrameSize(val)), 6 => Some(MaxHeaderListSize(val)), + 8 => Some(EnableConnectProtocol(val)), _ => 
None, } } @@ -322,6 +346,7 @@ impl Setting { InitialWindowSize(v) => (4, v), MaxFrameSize(v) => (5, v), MaxHeaderListSize(v) => (6, v), + EnableConnectProtocol(v) => (8, v), }; dst.put_u16(kind); diff --git a/src/frame/window_update.rs b/src/frame/window_update.rs index 72c1c2581..eed2ce17e 100644 --- a/src/frame/window_update.rs +++ b/src/frame/window_update.rs @@ -48,7 +48,7 @@ impl WindowUpdate { } pub fn encode(&self, dst: &mut B) { - log::trace!("encoding WINDOW_UPDATE; id={:?}", self.stream_id); + tracing::trace!("encoding WINDOW_UPDATE; id={:?}", self.stream_id); let head = Head::new(Kind::WindowUpdate, 0, self.stream_id); head.encode(4, dst); dst.put_u32(self.size_increment); diff --git a/src/fuzz_bridge.rs b/src/fuzz_bridge.rs new file mode 100644 index 000000000..3ea8b591c --- /dev/null +++ b/src/fuzz_bridge.rs @@ -0,0 +1,28 @@ +#[cfg(fuzzing)] +pub mod fuzz_logic { + use crate::hpack; + use bytes::BytesMut; + use http::header::HeaderName; + use std::io::Cursor; + + pub fn fuzz_hpack(data_: &[u8]) { + let mut decoder_ = hpack::Decoder::new(0); + let mut buf = BytesMut::new(); + buf.extend(data_); + let _dec_res = decoder_.decode(&mut Cursor::new(&mut buf), |_h| {}); + + if let Ok(s) = std::str::from_utf8(data_) { + if let Ok(h) = http::Method::from_bytes(s.as_bytes()) { + let m_ = hpack::Header::Method(h); + let mut encoder = hpack::Encoder::new(0, 0); + let _res = encode(&mut encoder, vec![m_]); + } + } + } + + fn encode(e: &mut hpack::Encoder, hdrs: Vec>>) -> BytesMut { + let mut dst = BytesMut::with_capacity(1024); + e.encode(&mut hdrs.into_iter(), &mut dst); + dst + } +} diff --git a/src/hpack/decoder.rs b/src/hpack/decoder.rs index 4befa8702..e48976c36 100644 --- a/src/hpack/decoder.rs +++ b/src/hpack/decoder.rs @@ -142,6 +142,12 @@ struct Table { max_size: usize, } +struct StringMarker { + offset: usize, + len: usize, + string: Option, +} + // ===== impl Decoder ===== impl Decoder { @@ -183,7 +189,10 @@ impl Decoder { self.last_max_update = size; } 
- log::trace!("decode"); + let span = tracing::trace_span!("hpack::decode"); + let _e = span.enter(); + + tracing::trace!("decode"); while let Some(ty) = peek_u8(src) { // At this point we are always at the beginning of the next block @@ -191,14 +200,14 @@ impl Decoder { // determined from the first byte. match Representation::load(ty)? { Indexed => { - log::trace!(" Indexed; rem={:?}", src.remaining()); + tracing::trace!(rem = src.remaining(), kind = %"Indexed"); can_resize = false; let entry = self.decode_indexed(src)?; consume(src); f(entry); } LiteralWithIndexing => { - log::trace!(" LiteralWithIndexing; rem={:?}", src.remaining()); + tracing::trace!(rem = src.remaining(), kind = %"LiteralWithIndexing"); can_resize = false; let entry = self.decode_literal(src, true)?; @@ -209,14 +218,14 @@ impl Decoder { f(entry); } LiteralWithoutIndexing => { - log::trace!(" LiteralWithoutIndexing; rem={:?}", src.remaining()); + tracing::trace!(rem = src.remaining(), kind = %"LiteralWithoutIndexing"); can_resize = false; let entry = self.decode_literal(src, false)?; consume(src); f(entry); } LiteralNeverIndexed => { - log::trace!(" LiteralNeverIndexed; rem={:?}", src.remaining()); + tracing::trace!(rem = src.remaining(), kind = %"LiteralNeverIndexed"); can_resize = false; let entry = self.decode_literal(src, false)?; consume(src); @@ -226,7 +235,7 @@ impl Decoder { f(entry); } SizeUpdate => { - log::trace!(" SizeUpdate; rem={:?}", src.remaining()); + tracing::trace!(rem = src.remaining(), kind = %"SizeUpdate"); if !can_resize { return Err(DecoderError::InvalidMaxDynamicSize); } @@ -248,10 +257,10 @@ impl Decoder { return Err(DecoderError::InvalidMaxDynamicSize); } - log::debug!( - "Decoder changed max table size from {} to {}", - self.table.size(), - new_size + tracing::debug!( + from = self.table.size(), + to = new_size, + "Decoder changed max table size" ); self.table.set_max_size(new_size); @@ -276,10 +285,13 @@ impl Decoder { // First, read the header name if table_idx == 
0 { + let old_pos = buf.position(); + let name_marker = self.try_decode_string(buf)?; + let value_marker = self.try_decode_string(buf)?; + buf.set_position(old_pos); // Read the name as a literal - let name = self.decode_string(buf)?; - let value = self.decode_string(buf)?; - + let name = name_marker.consume(buf); + let value = value_marker.consume(buf); Header::new(name, value) } else { let e = self.table.get(table_idx)?; @@ -289,7 +301,11 @@ impl Decoder { } } - fn decode_string(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result { + fn try_decode_string( + &mut self, + buf: &mut Cursor<&mut BytesMut>, + ) -> Result { + let old_pos = buf.position(); const HUFF_FLAG: u8 = 0b1000_0000; // The first bit in the first byte contains the huffman encoded flag. @@ -302,25 +318,38 @@ impl Decoder { let len = decode_int(buf, 7)?; if len > buf.remaining() { - log::trace!( - "decode_string underflow; len={}; remaining={}", - len, - buf.remaining() - ); + tracing::trace!(len, remaining = buf.remaining(), "decode_string underflow",); return Err(DecoderError::NeedMore(NeedMore::StringUnderflow)); } + let offset = (buf.position() - old_pos) as usize; if huff { let ret = { - let raw = &buf.bytes()[..len]; - huffman::decode(raw, &mut self.buffer).map(BytesMut::freeze) + let raw = &buf.chunk()[..len]; + huffman::decode(raw, &mut self.buffer).map(|buf| StringMarker { + offset, + len, + string: Some(BytesMut::freeze(buf)), + }) }; buf.advance(len); - return ret; + ret + } else { + buf.advance(len); + Ok(StringMarker { + offset, + len, + string: None, + }) } + } - Ok(take(buf, len)) + fn decode_string(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result { + let old_pos = buf.position(); + let marker = self.try_decode_string(buf)?; + buf.set_position(old_pos); + Ok(marker.consume(buf)) } } @@ -418,9 +447,9 @@ fn decode_int(buf: &mut B, prefix_size: u8) -> Result(buf: &mut B) -> Option { +fn peek_u8(buf: &B) -> Option { if buf.has_remaining() { - Some(buf.bytes()[0]) + 
Some(buf.chunk()[0]) } else { None } @@ -434,6 +463,19 @@ fn take(buf: &mut Cursor<&mut BytesMut>, n: usize) -> Bytes { head.freeze() } +impl StringMarker { + fn consume(self, buf: &mut Cursor<&mut BytesMut>) -> Bytes { + buf.advance(self.offset); + match self.string { + Some(string) => { + buf.advance(self.len); + string + } + None => take(buf, self.len), + } + } +} + fn consume(buf: &mut Cursor<&mut BytesMut>) { // remove bytes from the internal BytesMut when they have been successfully // decoded. This is a more permanent cursor position, which will be @@ -578,13 +620,13 @@ pub fn get_static(idx: usize) -> Header { use http::header::HeaderValue; match idx { - 1 => Header::Authority(from_static("")), + 1 => Header::Authority(BytesStr::from_static("")), 2 => Header::Method(Method::GET), 3 => Header::Method(Method::POST), - 4 => Header::Path(from_static("/")), - 5 => Header::Path(from_static("/index.html")), - 6 => Header::Scheme(from_static("http")), - 7 => Header::Scheme(from_static("https")), + 4 => Header::Path(BytesStr::from_static("/")), + 5 => Header::Path(BytesStr::from_static("/index.html")), + 6 => Header::Scheme(BytesStr::from_static("http")), + 7 => Header::Scheme(BytesStr::from_static("https")), 8 => Header::Status(StatusCode::OK), 9 => Header::Status(StatusCode::NO_CONTENT), 10 => Header::Status(StatusCode::PARTIAL_CONTENT), @@ -784,22 +826,17 @@ pub fn get_static(idx: usize) -> Header { } } -fn from_static(s: &'static str) -> BytesStr { - unsafe { BytesStr::from_utf8_unchecked(Bytes::from_static(s.as_bytes())) } -} - #[cfg(test)] mod test { use super::*; - use crate::hpack::Header; #[test] fn test_peek_u8() { let b = 0xff; let mut buf = Cursor::new(vec![b]); - assert_eq!(peek_u8(&mut buf), Some(b)); + assert_eq!(peek_u8(&buf), Some(b)); assert_eq!(buf.get_u8(), b); - assert_eq!(peek_u8(&mut buf), None); + assert_eq!(peek_u8(&buf), None); } #[test] @@ -814,8 +851,7 @@ mod test { fn test_decode_empty() { let mut de = Decoder::new(0); let mut buf = 
BytesMut::new(); - let empty = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap(); - assert_eq!(empty, ()); + let _: () = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap(); } #[test] @@ -823,17 +859,16 @@ mod test { let mut de = Decoder::new(0); let mut buf = BytesMut::new(); - buf.extend(&[0b01000000, 0x80 | 2]); + buf.extend([0b01000000, 0x80 | 2]); buf.extend(huff_encode(b"foo")); - buf.extend(&[0x80 | 3]); + buf.extend([0x80 | 3]); buf.extend(huff_encode(b"bar")); let mut res = vec![]; - let _ = de - .decode(&mut Cursor::new(&mut buf), |h| { - res.push(h); - }) - .unwrap(); + de.decode(&mut Cursor::new(&mut buf), |h| { + res.push(h); + }) + .unwrap(); assert_eq!(res.len(), 1); assert_eq!(de.table.size(), 0); @@ -852,7 +887,50 @@ mod test { fn huff_encode(src: &[u8]) -> BytesMut { let mut buf = BytesMut::new(); - huffman::encode(src, &mut buf).unwrap(); + huffman::encode(src, &mut buf); buf } + + #[test] + fn test_decode_continuation_header_with_non_huff_encoded_name() { + let mut de = Decoder::new(0); + let value = huff_encode(b"bar"); + let mut buf = BytesMut::new(); + // header name is non_huff encoded + buf.extend([0b01000000, 3]); + buf.extend(b"foo"); + // header value is partial + buf.extend([0x80 | 3]); + buf.extend(&value[0..1]); + + let mut res = vec![]; + let e = de + .decode(&mut Cursor::new(&mut buf), |h| { + res.push(h); + }) + .unwrap_err(); + // decode error because the header value is partial + assert_eq!(e, DecoderError::NeedMore(NeedMore::StringUnderflow)); + + // extend buf with the remaining header value + buf.extend(&value[1..]); + de.decode(&mut Cursor::new(&mut buf), |h| { + res.push(h); + }) + .unwrap(); + + assert_eq!(res.len(), 1); + assert_eq!(de.table.size(), 0); + + match res[0] { + Header::Field { + ref name, + ref value, + } => { + assert_eq!(name, "foo"); + assert_eq!(value, "bar"); + } + _ => panic!(), + } + } } diff --git a/src/hpack/encoder.rs b/src/hpack/encoder.rs index ef177485f..bd49056f6 100644 --- 
a/src/hpack/encoder.rs +++ b/src/hpack/encoder.rs @@ -1,34 +1,15 @@ use super::table::{Index, Table}; use super::{huffman, Header}; -use bytes::{buf::ext::Limit, BufMut, BytesMut}; +use bytes::{BufMut, BytesMut}; use http::header::{HeaderName, HeaderValue}; -type DstBuf<'a> = Limit<&'a mut BytesMut>; - #[derive(Debug)] pub struct Encoder { table: Table, size_update: Option, } -#[derive(Debug)] -pub enum Encode { - Full, - Partial(EncodeState), -} - -#[derive(Debug)] -pub struct EncodeState { - index: Index, - value: Option, -} - -#[derive(Debug, PartialEq, Eq)] -pub enum EncoderError { - BufferOverflow, -} - #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum SizeUpdate { One(usize), @@ -77,56 +58,24 @@ impl Encoder { } /// Encode a set of headers into the provide buffer - pub fn encode( - &mut self, - resume: Option, - headers: &mut I, - dst: &mut DstBuf<'_>, - ) -> Encode + pub fn encode(&mut self, headers: I, dst: &mut BytesMut) where - I: Iterator>>, + I: IntoIterator>>, { - let pos = position(dst); + let span = tracing::trace_span!("hpack::encode"); + let _e = span.enter(); - if let Err(e) = self.encode_size_updates(dst) { - if e == EncoderError::BufferOverflow { - rewind(dst, pos); - } - - unreachable!("encode_size_updates errored"); - } + self.encode_size_updates(dst); let mut last_index = None; - if let Some(resume) = resume { - let pos = position(dst); - - let res = match resume.value { - Some(ref value) => self.encode_header_without_name(&resume.index, value, dst), - None => self.encode_header(&resume.index, dst), - }; - - if res.is_err() { - rewind(dst, pos); - return Encode::Partial(resume); - } - last_index = Some(resume.index); - } - for header in headers { - let pos = position(dst); - match header.reify() { // The header has an associated name. In which case, try to // index it in the table. 
Ok(header) => { let index = self.table.index(header); - let res = self.encode_header(&index, dst); - - if res.is_err() { - rewind(dst, pos); - return Encode::Partial(EncodeState { index, value: None }); - } + self.encode_header(&index, dst); last_index = Some(index); } @@ -135,99 +84,81 @@ impl Encoder { // which case, we skip table lookup and just use the same index // as the previous entry. Err(value) => { - let res = self.encode_header_without_name( + self.encode_header_without_name( last_index.as_ref().unwrap_or_else(|| { panic!("encoding header without name, but no previous index to use for name"); }), &value, dst, ); - - if res.is_err() { - rewind(dst, pos); - return Encode::Partial(EncodeState { - index: last_index.unwrap(), // checked just above - value: Some(value), - }); - } } - }; + } } - - Encode::Full } - fn encode_size_updates(&mut self, dst: &mut DstBuf<'_>) -> Result<(), EncoderError> { + fn encode_size_updates(&mut self, dst: &mut BytesMut) { match self.size_update.take() { Some(SizeUpdate::One(val)) => { self.table.resize(val); - encode_size_update(val, dst)?; + encode_size_update(val, dst); } Some(SizeUpdate::Two(min, max)) => { self.table.resize(min); self.table.resize(max); - encode_size_update(min, dst)?; - encode_size_update(max, dst)?; + encode_size_update(min, dst); + encode_size_update(max, dst); } None => {} } - - Ok(()) } - fn encode_header(&mut self, index: &Index, dst: &mut DstBuf<'_>) -> Result<(), EncoderError> { + fn encode_header(&mut self, index: &Index, dst: &mut BytesMut) { match *index { Index::Indexed(idx, _) => { - encode_int(idx, 7, 0x80, dst)?; + encode_int(idx, 7, 0x80, dst); } Index::Name(idx, _) => { - let header = self.table.resolve(&index); + let header = self.table.resolve(index); - encode_not_indexed(idx, header.value_slice(), header.is_sensitive(), dst)?; + encode_not_indexed(idx, header.value_slice(), header.is_sensitive(), dst); } Index::Inserted(_) => { - let header = self.table.resolve(&index); + let header = 
self.table.resolve(index); assert!(!header.is_sensitive()); - if !dst.has_remaining_mut() { - return Err(EncoderError::BufferOverflow); - } - dst.put_u8(0b0100_0000); - encode_str(header.name().as_slice(), dst)?; - encode_str(header.value_slice(), dst)?; + encode_str(header.name().as_slice(), dst); + encode_str(header.value_slice(), dst); } Index::InsertedValue(idx, _) => { - let header = self.table.resolve(&index); + let header = self.table.resolve(index); assert!(!header.is_sensitive()); - encode_int(idx, 6, 0b0100_0000, dst)?; - encode_str(header.value_slice(), dst)?; + encode_int(idx, 6, 0b0100_0000, dst); + encode_str(header.value_slice(), dst); } Index::NotIndexed(_) => { - let header = self.table.resolve(&index); + let header = self.table.resolve(index); encode_not_indexed2( header.name().as_slice(), header.value_slice(), header.is_sensitive(), dst, - )?; + ); } } - - Ok(()) } fn encode_header_without_name( &mut self, last: &Index, value: &HeaderValue, - dst: &mut DstBuf<'_>, - ) -> Result<(), EncoderError> { + dst: &mut BytesMut, + ) { match *last { Index::Indexed(..) | Index::Name(..) @@ -235,7 +166,7 @@ impl Encoder { | Index::InsertedValue(..) 
=> { let idx = self.table.resolve_idx(last); - encode_not_indexed(idx, value.as_ref(), value.is_sensitive(), dst)?; + encode_not_indexed(idx, value.as_ref(), value.is_sensitive(), dst); } Index::NotIndexed(_) => { let last = self.table.resolve(last); @@ -245,11 +176,9 @@ impl Encoder { value.as_ref(), value.is_sensitive(), dst, - )?; + ); } } - - Ok(()) } } @@ -259,52 +188,32 @@ impl Default for Encoder { } } -fn encode_size_update(val: usize, dst: &mut B) -> Result<(), EncoderError> { +fn encode_size_update(val: usize, dst: &mut BytesMut) { encode_int(val, 5, 0b0010_0000, dst) } -fn encode_not_indexed( - name: usize, - value: &[u8], - sensitive: bool, - dst: &mut DstBuf<'_>, -) -> Result<(), EncoderError> { +fn encode_not_indexed(name: usize, value: &[u8], sensitive: bool, dst: &mut BytesMut) { if sensitive { - encode_int(name, 4, 0b10000, dst)?; + encode_int(name, 4, 0b10000, dst); } else { - encode_int(name, 4, 0, dst)?; + encode_int(name, 4, 0, dst); } - encode_str(value, dst)?; - Ok(()) + encode_str(value, dst); } -fn encode_not_indexed2( - name: &[u8], - value: &[u8], - sensitive: bool, - dst: &mut DstBuf<'_>, -) -> Result<(), EncoderError> { - if !dst.has_remaining_mut() { - return Err(EncoderError::BufferOverflow); - } - +fn encode_not_indexed2(name: &[u8], value: &[u8], sensitive: bool, dst: &mut BytesMut) { if sensitive { dst.put_u8(0b10000); } else { dst.put_u8(0); } - encode_str(name, dst)?; - encode_str(value, dst)?; - Ok(()) + encode_str(name, dst); + encode_str(value, dst); } -fn encode_str(val: &[u8], dst: &mut DstBuf<'_>) -> Result<(), EncoderError> { - if !dst.has_remaining_mut() { - return Err(EncoderError::BufferOverflow); - } - +fn encode_str(val: &[u8], dst: &mut BytesMut) { if !val.is_empty() { let idx = position(dst); @@ -312,50 +221,43 @@ fn encode_str(val: &[u8], dst: &mut DstBuf<'_>) -> Result<(), EncoderError> { dst.put_u8(0); // Encode with huffman - huffman::encode(val, dst)?; + huffman::encode(val, dst); let huff_len = position(dst) - 
(idx + 1); if encode_int_one_byte(huff_len, 7) { // Write the string head - dst.get_mut()[idx] = 0x80 | huff_len as u8; + dst[idx] = 0x80 | huff_len as u8; } else { - // Write the head to a placeholer + // Write the head to a placeholder const PLACEHOLDER_LEN: usize = 8; let mut buf = [0u8; PLACEHOLDER_LEN]; let head_len = { let mut head_dst = &mut buf[..]; - encode_int(huff_len, 7, 0x80, &mut head_dst)?; + encode_int(huff_len, 7, 0x80, &mut head_dst); PLACEHOLDER_LEN - head_dst.remaining_mut() }; - if dst.remaining_mut() < head_len { - return Err(EncoderError::BufferOverflow); - } - // This is just done to reserve space in the destination dst.put_slice(&buf[1..head_len]); - let written = dst.get_mut(); // Shift the header forward for i in 0..huff_len { let src_i = idx + 1 + (huff_len - (i + 1)); let dst_i = idx + head_len + (huff_len - (i + 1)); - written[dst_i] = written[src_i]; + dst[dst_i] = dst[src_i]; } // Copy in the head for i in 0..head_len { - written[idx + i] = buf[i]; + dst[idx + i] = buf[i]; } } } else { // Write an empty string dst.put_u8(0); } - - Ok(()) } /// Encode an integer into the given destination buffer @@ -364,47 +266,25 @@ fn encode_int( prefix_bits: usize, // The number of bits in the prefix first_byte: u8, // The base upon which to start encoding the int dst: &mut B, -) -> Result<(), EncoderError> { - let mut rem = dst.remaining_mut(); - - if rem == 0 { - return Err(EncoderError::BufferOverflow); - } - +) { if encode_int_one_byte(value, prefix_bits) { dst.put_u8(first_byte | value as u8); - return Ok(()); + return; } let low = (1 << prefix_bits) - 1; value -= low; - if value > 0x0fff_ffff { - panic!("value out of range"); - } - dst.put_u8(first_byte | low as u8); - rem -= 1; while value >= 128 { - if rem == 0 { - return Err(EncoderError::BufferOverflow); - } - dst.put_u8(0b1000_0000 | value as u8); - rem -= 1; value >>= 7; } - if rem == 0 { - return Err(EncoderError::BufferOverflow); - } - dst.put_u8(value as u8); - - Ok(()) } /// Returns 
true if the in the int can be fully encoded in the first byte. @@ -412,19 +292,13 @@ fn encode_int_one_byte(value: usize, prefix_bits: usize) -> bool { value < (1 << prefix_bits) - 1 } -fn position(buf: &DstBuf<'_>) -> usize { - buf.get_ref().len() -} - -fn rewind(buf: &mut DstBuf<'_>, pos: usize) { - buf.get_mut().truncate(pos); +fn position(buf: &BytesMut) -> usize { + buf.len() } #[cfg(test)] mod test { use super::*; - use crate::hpack::Header; - use bytes::buf::BufMutExt; use http::*; #[test] @@ -802,49 +676,15 @@ mod test { } #[test] - fn test_nameless_header_at_resume() { + fn test_large_size_update() { let mut encoder = Encoder::default(); - let max_len = 15; - let mut dst = BytesMut::with_capacity(64); - - let mut input = vec![ - Header::Field { - name: Some("hello".parse().unwrap()), - value: HeaderValue::from_bytes(b"world").unwrap(), - }, - Header::Field { - name: None, - value: HeaderValue::from_bytes(b"zomg").unwrap(), - }, - Header::Field { - name: None, - value: HeaderValue::from_bytes(b"sup").unwrap(), - }, - ] - .into_iter(); - - let resume = match encoder.encode(None, &mut input, &mut (&mut dst).limit(max_len)) { - Encode::Partial(r) => r, - _ => panic!("encode should be partial"), - }; - - assert_eq!(&[0x40, 0x80 | 4], &dst[0..2]); - assert_eq!("hello", huff_decode(&dst[2..6])); - assert_eq!(0x80 | 4, dst[6]); - assert_eq!("world", huff_decode(&dst[7..11])); - dst.clear(); + encoder.update_max_size(1912930560); + assert_eq!(Some(SizeUpdate::One(1912930560)), encoder.size_update); - match encoder.encode(Some(resume), &mut input, &mut (&mut dst).limit(max_len)) { - Encode::Full => {} - unexpected => panic!("resume returned unexpected: {:?}", unexpected), - } - - // Next is not indexed - assert_eq!(&[15, 47, 0x80 | 3], &dst[0..3]); - assert_eq!("zomg", huff_decode(&dst[3..6])); - assert_eq!(&[15, 47, 0x80 | 3], &dst[6..9]); - assert_eq!("sup", huff_decode(&dst[9..])); + let mut dst = BytesMut::with_capacity(6); + encoder.encode_size_updates(&mut 
dst); + assert_eq!([63, 225, 129, 148, 144, 7], &dst[..]); } #[test] @@ -855,7 +695,7 @@ mod test { fn encode(e: &mut Encoder, hdrs: Vec>>) -> BytesMut { let mut dst = BytesMut::with_capacity(1024); - e.encode(None, &mut hdrs.into_iter(), &mut (&mut dst).limit(1024)); + e.encode(&mut hdrs.into_iter(), &mut dst); dst } diff --git a/src/hpack/header.rs b/src/hpack/header.rs index 74369506c..0b5d1fded 100644 --- a/src/hpack/header.rs +++ b/src/hpack/header.rs @@ -1,11 +1,12 @@ use super::{DecoderError, NeedMore}; +use crate::ext::Protocol; use bytes::Bytes; use http::header::{HeaderName, HeaderValue}; use http::{Method, StatusCode}; use std::fmt; -/// HTTP/2.0 Header +/// HTTP/2 Header #[derive(Debug, Clone, Eq, PartialEq)] pub enum Header { Field { name: T, value: HeaderValue }, @@ -14,6 +15,7 @@ pub enum Header { Method(Method), Scheme(BytesStr), Path(BytesStr), + Protocol(Protocol), Status(StatusCode), } @@ -25,6 +27,7 @@ pub enum Name<'a> { Method, Scheme, Path, + Protocol, Status, } @@ -51,6 +54,7 @@ impl Header> { Method(v) => Method(v), Scheme(v) => Scheme(v), Path(v) => Path(v), + Protocol(v) => Protocol(v), Status(v) => Status(v), }) } @@ -79,6 +83,10 @@ impl Header { let value = BytesStr::try_from(value)?; Ok(Header::Path(value)) } + b"protocol" => { + let value = Protocol::try_from(value)?; + Ok(Header::Protocol(value)) + } b"status" => { let status = StatusCode::from_bytes(&value)?; Ok(Header::Status(status)) @@ -104,6 +112,7 @@ impl Header { Header::Method(ref v) => 32 + 7 + v.as_ref().len(), Header::Scheme(ref v) => 32 + 7 + v.len(), Header::Path(ref v) => 32 + 5 + v.len(), + Header::Protocol(ref v) => 32 + 9 + v.as_str().len(), Header::Status(_) => 32 + 7 + 3, } } @@ -116,6 +125,7 @@ impl Header { Header::Method(..) => Name::Method, Header::Scheme(..) => Name::Scheme, Header::Path(..) => Name::Path, + Header::Protocol(..) => Name::Protocol, Header::Status(..) 
=> Name::Status, } } @@ -127,6 +137,7 @@ impl Header { Header::Method(ref v) => v.as_ref().as_ref(), Header::Scheme(ref v) => v.as_ref(), Header::Path(ref v) => v.as_ref(), + Header::Protocol(ref v) => v.as_ref(), Header::Status(ref v) => v.as_str().as_ref(), } } @@ -156,6 +167,10 @@ impl Header { Header::Path(ref b) => a == b, _ => false, }, + Header::Protocol(ref a) => match *other { + Header::Protocol(ref b) => a == b, + _ => false, + }, Header::Status(ref a) => match *other { Header::Status(ref b) => a == b, _ => false, @@ -175,18 +190,18 @@ impl Header { use http::header; match *self { - Header::Field { ref name, .. } => match *name { + Header::Field { ref name, .. } => matches!( + *name, header::AGE - | header::AUTHORIZATION - | header::CONTENT_LENGTH - | header::ETAG - | header::IF_MODIFIED_SINCE - | header::IF_NONE_MATCH - | header::LOCATION - | header::COOKIE - | header::SET_COOKIE => true, - _ => false, - }, + | header::AUTHORIZATION + | header::CONTENT_LENGTH + | header::ETAG + | header::IF_MODIFIED_SINCE + | header::IF_NONE_MATCH + | header::LOCATION + | header::COOKIE + | header::SET_COOKIE + ), Header::Path(..) => true, _ => false, } @@ -205,6 +220,7 @@ impl From
for Header> { Header::Method(v) => Header::Method(v), Header::Scheme(v) => Header::Scheme(v), Header::Path(v) => Header::Path(v), + Header::Protocol(v) => Header::Protocol(v), Header::Status(v) => Header::Status(v), } } @@ -215,12 +231,13 @@ impl<'a> Name<'a> { match self { Name::Field(name) => Ok(Header::Field { name: name.clone(), - value: HeaderValue::from_bytes(&*value)?, + value: HeaderValue::from_bytes(&value)?, }), Name::Authority => Ok(Header::Authority(BytesStr::try_from(value)?)), - Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)), + Name::Method => Ok(Header::Method(Method::from_bytes(&value)?)), Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)), Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)), + Name::Protocol => Ok(Header::Protocol(Protocol::try_from(value)?)), Name::Status => { match StatusCode::from_bytes(&value) { Ok(status) => Ok(Header::Status(status)), @@ -238,6 +255,7 @@ impl<'a> Name<'a> { Name::Method => b":method", Name::Scheme => b":scheme", Name::Path => b":path", + Name::Protocol => b":protocol", Name::Status => b":status", } } @@ -246,8 +264,12 @@ impl<'a> Name<'a> { // ===== impl BytesStr ===== impl BytesStr { - pub(crate) unsafe fn from_utf8_unchecked(bytes: Bytes) -> Self { - BytesStr(bytes) + pub(crate) const fn from_static(value: &'static str) -> Self { + BytesStr(Bytes::from_static(value.as_bytes())) + } + + pub(crate) fn from(value: &str) -> Self { + BytesStr(Bytes::copy_from_slice(value.as_bytes())) } #[doc(hidden)] diff --git a/src/hpack/huffman/mod.rs b/src/hpack/huffman/mod.rs index b8db8b4d3..86c97eb58 100644 --- a/src/hpack/huffman/mod.rs +++ b/src/hpack/huffman/mod.rs @@ -1,7 +1,7 @@ mod table; use self::table::{DECODE_TABLE, ENCODE_TABLE}; -use crate::hpack::{DecoderError, EncoderError}; +use crate::hpack::DecoderError; use bytes::{BufMut, BytesMut}; @@ -40,11 +40,9 @@ pub fn decode(src: &[u8], buf: &mut BytesMut) -> Result Ok(buf.split()) } -// TODO: return error when there is not 
enough room to encode the value -pub fn encode(src: &[u8], dst: &mut B) -> Result<(), EncoderError> { +pub fn encode(src: &[u8], dst: &mut BytesMut) { let mut bits: u64 = 0; let mut bits_left = 40; - let mut rem = dst.remaining_mut(); for &b in src { let (nbits, code) = ENCODE_TABLE[b as usize]; @@ -53,29 +51,18 @@ pub fn encode(src: &[u8], dst: &mut B) -> Result<(), EncoderError> { bits_left -= nbits; while bits_left <= 32 { - if rem == 0 { - return Err(EncoderError::BufferOverflow); - } - dst.put_u8((bits >> 32) as u8); bits <<= 8; bits_left += 8; - rem -= 1; } } if bits_left != 40 { - if rem == 0 { - return Err(EncoderError::BufferOverflow); - } - // This writes the EOS token bits |= (1 << bits_left) - 1; dst.put_u8((bits >> 32) as u8); } - - Ok(()) } impl Decoder { @@ -125,7 +112,7 @@ mod test { #[test] fn decode_single_byte() { assert_eq!("o", decode(&[0b00111111]).unwrap()); - assert_eq!("0", decode(&[0x0 + 7]).unwrap()); + assert_eq!("0", decode(&[7]).unwrap()); assert_eq!("A", decode(&[(0x21 << 2) + 3]).unwrap()); } @@ -144,23 +131,23 @@ mod test { #[test] fn encode_single_byte() { - let mut dst = Vec::with_capacity(1); + let mut dst = BytesMut::with_capacity(1); - encode(b"o", &mut dst).unwrap(); + encode(b"o", &mut dst); assert_eq!(&dst[..], &[0b00111111]); dst.clear(); - encode(b"0", &mut dst).unwrap(); - assert_eq!(&dst[..], &[0x0 + 7]); + encode(b"0", &mut dst); + assert_eq!(&dst[..], &[7]); dst.clear(); - encode(b"A", &mut dst).unwrap(); + encode(b"A", &mut dst); assert_eq!(&dst[..], &[(0x21 << 2) + 3]); } #[test] fn encode_decode_str() { - const DATA: &'static [&'static str] = &[ + const DATA: &[&str] = &[ "hello world", ":method", ":scheme", @@ -185,9 +172,9 @@ mod test { ]; for s in DATA { - let mut dst = Vec::with_capacity(s.len()); + let mut dst = BytesMut::with_capacity(s.len()); - encode(s.as_bytes(), &mut dst).unwrap(); + encode(s.as_bytes(), &mut dst); let decoded = decode(&dst).unwrap(); @@ -197,13 +184,12 @@ mod test { #[test] fn 
encode_decode_u8() { - const DATA: &'static [&'static [u8]] = - &[b"\0", b"\0\0\0", b"\0\x01\x02\x03\x04\x05", b"\xFF\xF8"]; + const DATA: &[&[u8]] = &[b"\0", b"\0\0\0", b"\0\x01\x02\x03\x04\x05", b"\xFF\xF8"]; for s in DATA { - let mut dst = Vec::with_capacity(s.len()); + let mut dst = BytesMut::with_capacity(s.len()); - encode(s, &mut dst).unwrap(); + encode(s, &mut dst); let decoded = decode(&dst).unwrap(); diff --git a/src/hpack/mod.rs b/src/hpack/mod.rs index 365b0057f..12c75d553 100644 --- a/src/hpack/mod.rs +++ b/src/hpack/mod.rs @@ -1,12 +1,12 @@ mod decoder; mod encoder; pub(crate) mod header; -mod huffman; +pub(crate) mod huffman; mod table; #[cfg(test)] mod test; pub use self::decoder::{Decoder, DecoderError, NeedMore}; -pub use self::encoder::{Encode, EncodeState, Encoder, EncoderError}; +pub use self::encoder::Encoder; pub use self::header::{BytesStr, Header}; diff --git a/src/hpack/table.rs b/src/hpack/table.rs index e7c8ce760..3e45f413b 100644 --- a/src/hpack/table.rs +++ b/src/hpack/table.rs @@ -319,7 +319,7 @@ impl Table { let mut probe = probe + 1; probe_loop!(probe < self.indices.len(), { - let pos = &mut self.indices[probe as usize]; + let pos = &mut self.indices[probe]; prev = match mem::replace(pos, Some(prev)) { Some(p) => p, @@ -404,7 +404,7 @@ impl Table { // Find the associated position probe_loop!(probe < self.indices.len(), { - debug_assert!(!self.indices[probe].is_none()); + debug_assert!(self.indices[probe].is_some()); let mut pos = self.indices[probe].unwrap(); @@ -597,7 +597,7 @@ impl Table { } assert!(dist <= their_dist, - "could not find entry; actual={}; desired={};" + + "could not find entry; actual={}; desired={}" + "probe={}, dist={}; their_dist={}; index={}; msg={}", actual, desired, probe, dist, their_dist, index.wrapping_sub(self.inserted), msg); @@ -656,12 +656,12 @@ fn to_raw_capacity(n: usize) -> usize { #[inline] fn desired_pos(mask: usize, hash: HashValue) -> usize { - (hash.0 & mask) as usize + hash.0 & mask } 
#[inline] fn probe_distance(mask: usize, hash: HashValue, current: usize) -> usize { - current.wrapping_sub(desired_pos(mask, hash)) & mask as usize + current.wrapping_sub(desired_pos(mask, hash)) & mask } fn hash_header(header: &Header) -> HashValue { @@ -751,6 +751,7 @@ fn index_static(header: &Header) -> Option<(usize, bool)> { "/index.html" => Some((5, true)), _ => Some((4, false)), }, + Header::Protocol(..) => None, Header::Status(ref v) => match u16::from(*v) { 200 => Some((8, true)), 204 => Some((9, true)), diff --git a/src/hpack/test/fixture.rs b/src/hpack/test/fixture.rs index 20ee1275b..d3f76e3bf 100644 --- a/src/hpack/test/fixture.rs +++ b/src/hpack/test/fixture.rs @@ -1,6 +1,6 @@ use crate::hpack::{Decoder, Encoder, Header}; -use bytes::{buf::BufMutExt, BytesMut}; +use bytes::BytesMut; use hex::FromHex; use serde_json::Value; @@ -52,8 +52,8 @@ fn test_story(story: Value) { Case { seqno: case.get("seqno").unwrap().as_u64().unwrap(), - wire: wire, - expect: expect, + wire, + expect, header_table_size: size, } }) @@ -100,18 +100,14 @@ fn test_story(story: Value) { let mut input: Vec<_> = case .expect .iter() - .map(|&(ref name, ref value)| { + .map(|(name, value)| { Header::new(name.clone().into(), value.clone().into()) .unwrap() .into() }) .collect(); - encoder.encode( - None, - &mut input.clone().into_iter(), - &mut (&mut buf).limit(limit), - ); + encoder.encode(&mut input.clone().into_iter(), &mut buf); decoder .decode(&mut Cursor::new(&mut buf), |e| { @@ -138,6 +134,7 @@ fn key_str(e: &Header) -> &str { Header::Method(..) => ":method", Header::Scheme(..) => ":scheme", Header::Path(..) => ":path", + Header::Protocol(..) => ":protocol", Header::Status(..) => ":status", } } @@ -145,10 +142,11 @@ fn key_str(e: &Header) -> &str { fn value_str(e: &Header) -> &str { match *e { Header::Field { ref value, .. 
} => value.to_str().unwrap(), - Header::Authority(ref v) => &**v, + Header::Authority(ref v) => v, Header::Method(ref m) => m.as_str(), - Header::Scheme(ref v) => &**v, - Header::Path(ref v) => &**v, + Header::Scheme(ref v) => v, + Header::Path(ref v) => v, + Header::Protocol(ref v) => v.as_str(), Header::Status(ref v) => v.as_str(), } } diff --git a/src/hpack/test/fuzz.rs b/src/hpack/test/fuzz.rs index dbf9b3c8f..af9e8ea23 100644 --- a/src/hpack/test/fuzz.rs +++ b/src/hpack/test/fuzz.rs @@ -1,14 +1,15 @@ -use crate::hpack::{Decoder, Encode, Encoder, Header}; +use crate::hpack::{Decoder, Encoder, Header}; use http::header::{HeaderName, HeaderValue}; -use bytes::{buf::BufMutExt, Bytes, BytesMut}; +use bytes::BytesMut; use quickcheck::{Arbitrary, Gen, QuickCheck, TestResult}; -use rand::{Rng, SeedableRng, StdRng}; +use rand::distributions::Slice; +use rand::rngs::StdRng; +use rand::{thread_rng, Rng, SeedableRng}; use std::io::Cursor; -const MIN_CHUNK: usize = 16; const MAX_CHUNK: usize = 2 * 1024; #[test] @@ -36,17 +37,8 @@ fn hpack_fuzz_seeded() { #[derive(Debug, Clone)] struct FuzzHpack { - // The magic seed that makes the test case reproducible - seed: [usize; 4], - // The set of headers to encode / decode frames: Vec, - - // The list of chunk sizes to do it in - chunks: Vec, - - // Number of times reduced - reduced: usize, } #[derive(Debug, Clone)] @@ -56,9 +48,9 @@ struct HeaderFrame { } impl FuzzHpack { - fn new(seed: [usize; 4]) -> FuzzHpack { + fn new(seed: [u8; 32]) -> FuzzHpack { // Seed the RNG - let mut rng = StdRng::from_seed(&seed); + let mut rng = StdRng::from_seed(seed); // Generates a bunch of source headers let mut source: Vec>> = vec![]; @@ -68,12 +60,12 @@ impl FuzzHpack { } // Actual test run headers - let num: usize = rng.gen_range(40, 500); + let num: usize = rng.gen_range(40..500); let mut frames: Vec = vec![]; let mut added = 0; - let skew: i32 = rng.gen_range(1, 5); + let skew: i32 = rng.gen_range(1..5); // Rough number of headers to add 
while added < num { @@ -82,24 +74,24 @@ impl FuzzHpack { headers: vec![], }; - match rng.gen_range(0, 20) { + match rng.gen_range(0..20) { 0 => { // Two resizes - let high = rng.gen_range(128, MAX_CHUNK * 2); - let low = rng.gen_range(0, high); + let high = rng.gen_range(128..MAX_CHUNK * 2); + let low = rng.gen_range(0..high); - frame.resizes.extend(&[low, high]); + frame.resizes.extend([low, high]); } 1..=3 => { - frame.resizes.push(rng.gen_range(128, MAX_CHUNK * 2)); + frame.resizes.push(rng.gen_range(128..MAX_CHUNK * 2)); } _ => {} } let mut is_name_required = true; - for _ in 0..rng.gen_range(1, (num - added) + 1) { - let x: f64 = rng.gen_range(0.0, 1.0); + for _ in 0..rng.gen_range(1..(num - added) + 1) { + let x: f64 = rng.gen_range(0.0..1.0); let x = x.powi(skew); let i = (x * source.len() as f64) as usize; @@ -128,23 +120,10 @@ impl FuzzHpack { frames.push(frame); } - // Now, generate the buffer sizes used to encode - let mut chunks = vec![]; - - for _ in 0..rng.gen_range(0, 100) { - chunks.push(rng.gen_range(MIN_CHUNK, MAX_CHUNK)); - } - - FuzzHpack { - seed: seed, - frames: frames, - chunks: chunks, - reduced: 0, - } + FuzzHpack { frames } } fn run(self) { - let mut chunks = self.chunks; let frames = self.frames; let mut expect = vec![]; @@ -173,11 +152,7 @@ impl FuzzHpack { } } - let mut input = frame.headers.into_iter(); - let mut index = None; - - let mut max_chunk = chunks.pop().unwrap_or(MAX_CHUNK); - let mut buf = BytesMut::with_capacity(max_chunk); + let mut buf = BytesMut::new(); if let Some(max) = frame.resizes.iter().max() { decoder.queue_size_update(*max); @@ -188,25 +163,7 @@ impl FuzzHpack { encoder.update_max_size(*resize); } - loop { - match encoder.encode(index.take(), &mut input, &mut (&mut buf).limit(max_chunk)) { - Encode::Full => break, - Encode::Partial(i) => { - index = Some(i); - - // Decode the chunk! 
- decoder - .decode(&mut Cursor::new(&mut buf), |h| { - let e = expect.remove(0); - assert_eq!(h, e); - }) - .expect("partial decode"); - - max_chunk = chunks.pop().unwrap_or(MAX_CHUNK); - buf = BytesMut::with_capacity(max_chunk); - } - } - } + encoder.encode(frame.headers, &mut buf); // Decode the chunk! decoder @@ -222,31 +179,31 @@ impl FuzzHpack { } impl Arbitrary for FuzzHpack { - fn arbitrary(g: &mut G) -> Self { - FuzzHpack::new(quickcheck::Rng::gen(g)) + fn arbitrary(_: &mut Gen) -> Self { + FuzzHpack::new(thread_rng().gen()) } } fn gen_header(g: &mut StdRng) -> Header> { use http::{Method, StatusCode}; - if g.gen_weighted_bool(10) { - match g.next_u32() % 5 { + if g.gen_ratio(1, 10) { + match g.gen_range(0u32..5) { 0 => { let value = gen_string(g, 4, 20); Header::Authority(to_shared(value)) } 1 => { - let method = match g.next_u32() % 6 { + let method = match g.gen_range(0u32..6) { 0 => Method::GET, 1 => Method::POST, 2 => Method::PUT, 3 => Method::PATCH, 4 => Method::DELETE, 5 => { - let n: usize = g.gen_range(3, 7); + let n: usize = g.gen_range(3..7); let bytes: Vec = (0..n) - .map(|_| g.choose(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ").unwrap().clone()) + .map(|_| *g.sample(Slice::new(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ").unwrap())) .collect(); Method::from_bytes(&bytes).unwrap() @@ -257,7 +214,7 @@ fn gen_header(g: &mut StdRng) -> Header> { Header::Method(method) } 2 => { - let value = match g.next_u32() % 2 { + let value = match g.gen_range(0u32..2) { 0 => "http", 1 => "https", _ => unreachable!(), @@ -266,7 +223,7 @@ fn gen_header(g: &mut StdRng) -> Header> { Header::Scheme(to_shared(value.to_string())) } 3 => { - let value = match g.next_u32() % 100 { + let value = match g.gen_range(0u32..100) { 0 => "/".to_string(), 1 => "/index.html".to_string(), _ => gen_string(g, 2, 20), @@ -282,14 +239,14 @@ fn gen_header(g: &mut StdRng) -> Header> { _ => unreachable!(), } } else { - let name = if g.gen_weighted_bool(10) { + let name = if g.gen_ratio(1, 10) { None } else { 
Some(gen_header_name(g)) }; let mut value = gen_header_value(g); - if g.gen_weighted_bool(30) { + if g.gen_ratio(1, 30) { value.set_sensitive(true); } @@ -300,84 +257,86 @@ fn gen_header(g: &mut StdRng) -> Header> { fn gen_header_name(g: &mut StdRng) -> HeaderName { use http::header; - if g.gen_weighted_bool(2) { - g.choose(&[ - header::ACCEPT, - header::ACCEPT_CHARSET, - header::ACCEPT_ENCODING, - header::ACCEPT_LANGUAGE, - header::ACCEPT_RANGES, - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - header::ACCESS_CONTROL_ALLOW_HEADERS, - header::ACCESS_CONTROL_ALLOW_METHODS, - header::ACCESS_CONTROL_ALLOW_ORIGIN, - header::ACCESS_CONTROL_EXPOSE_HEADERS, - header::ACCESS_CONTROL_MAX_AGE, - header::ACCESS_CONTROL_REQUEST_HEADERS, - header::ACCESS_CONTROL_REQUEST_METHOD, - header::AGE, - header::ALLOW, - header::ALT_SVC, - header::AUTHORIZATION, - header::CACHE_CONTROL, - header::CONNECTION, - header::CONTENT_DISPOSITION, - header::CONTENT_ENCODING, - header::CONTENT_LANGUAGE, - header::CONTENT_LENGTH, - header::CONTENT_LOCATION, - header::CONTENT_RANGE, - header::CONTENT_SECURITY_POLICY, - header::CONTENT_SECURITY_POLICY_REPORT_ONLY, - header::CONTENT_TYPE, - header::COOKIE, - header::DNT, - header::DATE, - header::ETAG, - header::EXPECT, - header::EXPIRES, - header::FORWARDED, - header::FROM, - header::HOST, - header::IF_MATCH, - header::IF_MODIFIED_SINCE, - header::IF_NONE_MATCH, - header::IF_RANGE, - header::IF_UNMODIFIED_SINCE, - header::LAST_MODIFIED, - header::LINK, - header::LOCATION, - header::MAX_FORWARDS, - header::ORIGIN, - header::PRAGMA, - header::PROXY_AUTHENTICATE, - header::PROXY_AUTHORIZATION, - header::PUBLIC_KEY_PINS, - header::PUBLIC_KEY_PINS_REPORT_ONLY, - header::RANGE, - header::REFERER, - header::REFERRER_POLICY, - header::REFRESH, - header::RETRY_AFTER, - header::SERVER, - header::SET_COOKIE, - header::STRICT_TRANSPORT_SECURITY, - header::TE, - header::TRAILER, - header::TRANSFER_ENCODING, - header::USER_AGENT, - header::UPGRADE, - 
header::UPGRADE_INSECURE_REQUESTS, - header::VARY, - header::VIA, - header::WARNING, - header::WWW_AUTHENTICATE, - header::X_CONTENT_TYPE_OPTIONS, - header::X_DNS_PREFETCH_CONTROL, - header::X_FRAME_OPTIONS, - header::X_XSS_PROTECTION, - ]) - .unwrap() + if g.gen_ratio(1, 2) { + g.sample( + Slice::new(&[ + header::ACCEPT, + header::ACCEPT_CHARSET, + header::ACCEPT_ENCODING, + header::ACCEPT_LANGUAGE, + header::ACCEPT_RANGES, + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + header::ACCESS_CONTROL_ALLOW_HEADERS, + header::ACCESS_CONTROL_ALLOW_METHODS, + header::ACCESS_CONTROL_ALLOW_ORIGIN, + header::ACCESS_CONTROL_EXPOSE_HEADERS, + header::ACCESS_CONTROL_MAX_AGE, + header::ACCESS_CONTROL_REQUEST_HEADERS, + header::ACCESS_CONTROL_REQUEST_METHOD, + header::AGE, + header::ALLOW, + header::ALT_SVC, + header::AUTHORIZATION, + header::CACHE_CONTROL, + header::CONNECTION, + header::CONTENT_DISPOSITION, + header::CONTENT_ENCODING, + header::CONTENT_LANGUAGE, + header::CONTENT_LENGTH, + header::CONTENT_LOCATION, + header::CONTENT_RANGE, + header::CONTENT_SECURITY_POLICY, + header::CONTENT_SECURITY_POLICY_REPORT_ONLY, + header::CONTENT_TYPE, + header::COOKIE, + header::DNT, + header::DATE, + header::ETAG, + header::EXPECT, + header::EXPIRES, + header::FORWARDED, + header::FROM, + header::HOST, + header::IF_MATCH, + header::IF_MODIFIED_SINCE, + header::IF_NONE_MATCH, + header::IF_RANGE, + header::IF_UNMODIFIED_SINCE, + header::LAST_MODIFIED, + header::LINK, + header::LOCATION, + header::MAX_FORWARDS, + header::ORIGIN, + header::PRAGMA, + header::PROXY_AUTHENTICATE, + header::PROXY_AUTHORIZATION, + header::PUBLIC_KEY_PINS, + header::PUBLIC_KEY_PINS_REPORT_ONLY, + header::RANGE, + header::REFERER, + header::REFERRER_POLICY, + header::REFRESH, + header::RETRY_AFTER, + header::SERVER, + header::SET_COOKIE, + header::STRICT_TRANSPORT_SECURITY, + header::TE, + header::TRAILER, + header::TRANSFER_ENCODING, + header::USER_AGENT, + header::UPGRADE, + header::UPGRADE_INSECURE_REQUESTS, + 
header::VARY, + header::VIA, + header::WARNING, + header::WWW_AUTHENTICATE, + header::X_CONTENT_TYPE_OPTIONS, + header::X_DNS_PREFETCH_CONTROL, + header::X_FRAME_OPTIONS, + header::X_XSS_PROTECTION, + ]) + .unwrap(), + ) .clone() } else { let value = gen_string(g, 1, 25); @@ -394,9 +353,7 @@ fn gen_string(g: &mut StdRng, min: usize, max: usize) -> String { let bytes: Vec<_> = (min..max) .map(|_| { // Chars to pick from - g.choose(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----") - .unwrap() - .clone() + *g.sample(Slice::new(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----").unwrap()) }) .collect(); @@ -404,6 +361,5 @@ fn gen_string(g: &mut StdRng, min: usize, max: usize) -> String { } fn to_shared(src: String) -> crate::hpack::BytesStr { - let b: Bytes = src.into(); - unsafe { crate::hpack::BytesStr::from_utf8_unchecked(b) } + crate::hpack::BytesStr::from(src.as_str()) } diff --git a/src/lib.rs b/src/lib.rs index f0bd67d63..fd7782f8e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ -//! An asynchronous, HTTP/2.0 server and client implementation. +//! An asynchronous, HTTP/2 server and client implementation. //! -//! This library implements the [HTTP/2.0] specification. The implementation is +//! This library implements the [HTTP/2] specification. The implementation is //! asynchronous, using [futures] as the basis for the API. The implementation //! is also decoupled from TCP or TLS details. The user must handle ALPN and //! HTTP/1.1 upgrades themselves. @@ -11,7 +11,7 @@ //! //! ```toml //! [dependencies] -//! h2 = "0.2" +//! h2 = "0.4" //! ``` //! //! # Layout @@ -24,19 +24,19 @@ //! # Handshake //! //! Both the client and the server require a connection to already be in a state -//! ready to start the HTTP/2.0 handshake. This library does not provide +//! ready to start the HTTP/2 handshake. This library does not provide //! facilities to do this. //! -//! There are three ways to reach an appropriate state to start the HTTP/2.0 +//! 
There are three ways to reach an appropriate state to start the HTTP/2 //! handshake. //! //! * Opening an HTTP/1.1 connection and performing an [upgrade]. //! * Opening a connection with TLS and use ALPN to negotiate the protocol. //! * Open a connection with prior knowledge, i.e. both the client and the //! server assume that the connection is immediately ready to start the -//! HTTP/2.0 handshake once opened. +//! HTTP/2 handshake once opened. //! -//! Once the connection is ready to start the HTTP/2.0 handshake, it can be +//! Once the connection is ready to start the HTTP/2 handshake, it can be //! passed to [`server::handshake`] or [`client::handshake`]. At this point, the //! library will start the handshake process, which consists of: //! @@ -48,10 +48,10 @@ //! //! # Flow control //! -//! [Flow control] is a fundamental feature of HTTP/2.0. The `h2` library +//! [Flow control] is a fundamental feature of HTTP/2. The `h2` library //! exposes flow control to the user. //! -//! An HTTP/2.0 client or server may not send unlimited data to the peer. When a +//! An HTTP/2 client or server may not send unlimited data to the peer. When a //! stream is initiated, both the client and the server are provided with an //! initial window size for that stream. A window size is the number of bytes //! the endpoint can send to the peer. At any point in time, the peer may @@ -66,7 +66,7 @@ //! Managing flow control for outbound data is done through [`SendStream`]. See //! the struct level documentation for those two types for more details. //! -//! [HTTP/2.0]: https://http2.github.io/ +//! [HTTP/2]: https://http2.github.io/ //! [futures]: https://docs.rs/futures/ //! [`client`]: client/index.html //! [`server`]: server/index.html @@ -78,16 +78,21 @@ //! [`server::handshake`]: server/fn.handshake.html //! 
[`client::handshake`]: client/fn.handshake.html -#![doc(html_root_url = "https://docs.rs/h2/0.2.4")] -#![deny(missing_debug_implementations, missing_docs)] +#![deny( + missing_debug_implementations, + missing_docs, + clippy::missing_safety_doc, + clippy::undocumented_unsafe_blocks +)] +#![allow(clippy::type_complexity, clippy::manual_range_contains)] #![cfg_attr(test, deny(warnings))] macro_rules! proto_err { (conn: $($msg:tt)+) => { - log::debug!("connection error PROTOCOL_ERROR -- {};", format_args!($($msg)+)) + tracing::debug!("connection error PROTOCOL_ERROR -- {};", format_args!($($msg)+)) }; (stream: $($msg:tt)+) => { - log::debug!("stream error PROTOCOL_ERROR -- {};", format_args!($($msg)+)) + tracing::debug!("stream error PROTOCOL_ERROR -- {};", format_args!($($msg)+)) }; } @@ -104,8 +109,14 @@ macro_rules! ready { mod codec; mod error; mod hpack; + +#[cfg(not(feature = "unstable"))] mod proto; +#[cfg(feature = "unstable")] +#[allow(missing_docs)] +pub mod proto; + #[cfg(not(feature = "unstable"))] mod frame; @@ -114,52 +125,16 @@ mod frame; pub mod frame; pub mod client; +pub mod ext; pub mod server; mod share; +#[cfg(fuzzing)] +#[cfg_attr(feature = "unstable", allow(missing_docs))] +pub mod fuzz_bridge; + pub use crate::error::{Error, Reason}; pub use crate::share::{FlowControl, Ping, PingPong, Pong, RecvStream, SendStream, StreamId}; #[cfg(feature = "unstable")] -pub use codec::{Codec, RecvError, SendError, UserError}; - -use std::task::Poll; - -// TODO: Get rid of this trait once https://github.com/rust-lang/rust/pull/63512 -// is stablized. -trait PollExt { - /// Changes the success value of this `Poll` with the closure provided. - fn map_ok_(self, f: F) -> Poll>> - where - F: FnOnce(T) -> U; - /// Changes the error value of this `Poll` with the closure provided. 
- fn map_err_(self, f: F) -> Poll>> - where - F: FnOnce(E) -> U; -} - -impl PollExt for Poll>> { - fn map_ok_(self, f: F) -> Poll>> - where - F: FnOnce(T) -> U, - { - match self { - Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(f(t)))), - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } - - fn map_err_(self, f: F) -> Poll>> - where - F: FnOnce(E) -> U, - { - match self { - Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(t))), - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(f(e)))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} +pub use codec::{Codec, SendError, UserError}; diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 49c123efa..5969bb841 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -1,22 +1,35 @@ -use crate::codec::{RecvError, UserError}; +use crate::codec::UserError; use crate::frame::{Reason, StreamId}; -use crate::{client, frame, proto, server}; +use crate::{client, server}; use crate::frame::DEFAULT_INITIAL_WINDOW_SIZE; use crate::proto::*; -use bytes::{Buf, Bytes}; +use bytes::Bytes; use futures_core::Stream; use std::io; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::AsyncRead; /// An H2 connection #[derive(Debug)] pub(crate) struct Connection +where + P: Peer, +{ + /// Read / write frame values + codec: Codec>, + + inner: ConnectionInner, +} + +// Extracted part of `Connection` which does not depend on `T`. Reduces the amount of duplicated +// method instantiations. +#[derive(Debug)] +struct ConnectionInner where P: Peer, { @@ -27,10 +40,7 @@ where /// /// This exists separately from State in order to support /// graceful shutdown. - error: Option, - - /// Read / write frame values - codec: Codec>, + error: Option, /// Pending GOAWAY frames to write. 
go_away: GoAway, @@ -44,16 +54,34 @@ where /// Stream state handler streams: Streams, + /// A `tracing` span tracking the lifetime of the connection. + span: tracing::Span, + /// Client or server _phantom: PhantomData

, } +struct DynConnection<'a, B: Buf = Bytes> { + state: &'a mut State, + + go_away: &'a mut GoAway, + + streams: DynStreams<'a, B>, + + error: &'a mut Option, + + ping_pong: &'a mut PingPong, +} + #[derive(Debug, Clone)] pub(crate) struct Config { pub next_stream_id: StreamId, pub initial_max_send_streams: usize, + pub max_send_buffer_size: usize, pub reset_stream_duration: Duration, pub reset_stream_max: usize, + pub remote_reset_stream_max: usize, + pub local_error_reset_streams_max: Option, pub settings: frame::Settings, } @@ -63,10 +91,10 @@ enum State { Open, /// The codec must be flushed - Closing(Reason), + Closing(Reason, Initiator), /// In a closed state - Closed(Reason), + Closed(Reason, Initiator), } impl Connection @@ -76,58 +104,97 @@ where B: Buf, { pub fn new(codec: Codec>, config: Config) -> Connection { - let streams = Streams::new(streams::Config { - local_init_window_sz: config - .settings - .initial_window_size() - .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), - initial_max_send_streams: config.initial_max_send_streams, - local_next_stream_id: config.next_stream_id, - local_push_enabled: config.settings.is_push_enabled(), - local_reset_duration: config.reset_stream_duration, - local_reset_max: config.reset_stream_max, - remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, - remote_max_initiated: config - .settings - .max_concurrent_streams() - .map(|max| max as usize), - }); + fn streams_config(config: &Config) -> streams::Config { + streams::Config { + initial_max_send_streams: config.initial_max_send_streams, + local_max_buffer_size: config.max_send_buffer_size, + local_next_stream_id: config.next_stream_id, + local_push_enabled: config.settings.is_push_enabled().unwrap_or(true), + extended_connect_protocol_enabled: config + .settings + .is_extended_connect_protocol_enabled() + .unwrap_or(false), + local_reset_duration: config.reset_stream_duration, + local_reset_max: config.reset_stream_max, + remote_reset_max: config.remote_reset_stream_max, + 
remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, + remote_max_initiated: config + .settings + .max_concurrent_streams() + .map(|max| max as usize), + local_max_error_reset_streams: config.local_error_reset_streams_max, + } + } + let streams = Streams::new(streams_config(&config)); Connection { - state: State::Open, - error: None, codec, - go_away: GoAway::new(), - ping_pong: PingPong::new(), - settings: Settings::new(config.settings), - streams, - _phantom: PhantomData, + inner: ConnectionInner { + state: State::Open, + error: None, + go_away: GoAway::new(), + ping_pong: PingPong::new(), + settings: Settings::new(config.settings), + streams, + span: tracing::debug_span!("Connection", peer = %P::NAME), + _phantom: PhantomData, + }, } } /// connection flow control pub(crate) fn set_target_window_size(&mut self, size: WindowSize) { - self.streams.set_target_connection_window_size(size); + let _res = self.inner.streams.set_target_connection_window_size(size); + // TODO: proper error handling + debug_assert!(_res.is_ok()); } /// Send a new SETTINGS frame with an updated initial window size. pub(crate) fn set_initial_window_size(&mut self, size: WindowSize) -> Result<(), UserError> { let mut settings = frame::Settings::default(); settings.set_initial_window_size(Some(size)); - self.settings.send_settings(settings) + self.inner.settings.send_settings(settings) + } + + /// Send a new SETTINGS frame with extended CONNECT protocol enabled. + pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> { + let mut settings = frame::Settings::default(); + settings.set_enable_connect_protocol(Some(1)); + self.inner.settings.send_settings(settings) + } + + /// Returns the maximum number of concurrent streams that may be initiated + /// by this peer. + pub(crate) fn max_send_streams(&self) -> usize { + self.inner.streams.max_send_streams() + } + + /// Returns the maximum number of concurrent streams that may be initiated + /// by the remote peer. 
+ pub(crate) fn max_recv_streams(&self) -> usize { + self.inner.streams.max_recv_streams() + } + + #[cfg(feature = "unstable")] + pub fn num_wired_streams(&self) -> usize { + self.inner.streams.num_wired_streams() } /// Returns `Ready` when the connection is ready to receive a frame. /// - /// Returns `RecvError` as this may raise errors that are caused by delayed + /// Returns `Error` as this may raise errors that are caused by delayed /// processing of received frames. - fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + let _e = self.inner.span.enter(); + let span = tracing::trace_span!("poll_ready"); + let _e = span.enter(); // The order of these calls don't really matter too much - ready!(self.ping_pong.send_pending_pong(cx, &mut self.codec))?; - ready!(self.ping_pong.send_pending_ping(cx, &mut self.codec))?; + ready!(self.inner.ping_pong.send_pending_pong(cx, &mut self.codec))?; + ready!(self.inner.ping_pong.send_pending_ping(cx, &mut self.codec))?; ready!(self + .inner .settings - .poll_send(cx, &mut self.codec, &mut self.streams))?; - ready!(self.streams.send_pending_refusal(cx, &mut self.codec))?; + .poll_send(cx, &mut self.codec, &mut self.inner.streams))?; + ready!(self.inner.streams.send_pending_refusal(cx, &mut self.codec))?; Poll::Ready(Ok(())) } @@ -137,50 +204,31 @@ where /// This will return `Some(reason)` if the connection should be closed /// afterwards. If this is a graceful shutdown, this returns `None`. 
fn poll_go_away(&mut self, cx: &mut Context) -> Poll>> { - self.go_away.send_pending_go_away(cx, &mut self.codec) - } - - fn go_away(&mut self, id: StreamId, e: Reason) { - let frame = frame::GoAway::new(id, e); - self.streams.send_go_away(id); - self.go_away.go_away(frame); - } - - fn go_away_now(&mut self, e: Reason) { - let last_processed_id = self.streams.last_processed_id(); - let frame = frame::GoAway::new(last_processed_id, e); - self.go_away.go_away_now(frame); + self.inner.go_away.send_pending_go_away(cx, &mut self.codec) } pub fn go_away_from_user(&mut self, e: Reason) { - let last_processed_id = self.streams.last_processed_id(); - let frame = frame::GoAway::new(last_processed_id, e); - self.go_away.go_away_from_user(frame); - - // Notify all streams of reason we're abruptly closing. - self.streams.recv_err(&proto::Error::Proto(e)); + self.inner.as_dyn().go_away_from_user(e) } - fn take_error(&mut self, ours: Reason) -> Poll> { - let reason = if let Some(theirs) = self.error.take() { - match (ours, theirs) { - // If either side reported an error, return that - // to the user. - (Reason::NO_ERROR, err) | (err, Reason::NO_ERROR) => err, - // If both sides reported an error, give their - // error back to th user. We assume our error - // was a consequence of their error, and less - // important. - (_, theirs) => theirs, - } - } else { - ours - }; - - if reason == Reason::NO_ERROR { - Poll::Ready(Ok(())) - } else { - Poll::Ready(Err(proto::Error::Proto(reason))) + fn take_error(&mut self, ours: Reason, initiator: Initiator) -> Result<(), Error> { + let (debug_data, theirs) = self + .inner + .error + .take() + .as_ref() + .map_or((Bytes::new(), Reason::NO_ERROR), |frame| { + (frame.debug_data().clone(), frame.reason()) + }); + + match (ours, theirs) { + (Reason::NO_ERROR, Reason::NO_ERROR) => Ok(()), + (ours, Reason::NO_ERROR) => Err(Error::GoAway(Bytes::new(), ours, initiator)), + // If both sides reported an error, give their + // error back to th user. 
We assume our error + // was a consequence of their error, and less + // important. + (_, theirs) => Err(Error::remote_go_away(debug_data, theirs)), } } @@ -189,102 +237,71 @@ where pub fn maybe_close_connection_if_no_streams(&mut self) { // If we poll() and realize that there are no streams or references // then we can close the connection by transitioning to GOAWAY - if !self.streams.has_streams_or_other_references() { - self.go_away_now(Reason::NO_ERROR); + if !self.inner.streams.has_streams_or_other_references() { + self.inner.as_dyn().go_away_now(Reason::NO_ERROR); } } pub(crate) fn take_user_pings(&mut self) -> Option { - self.ping_pong.take_user_pings() + self.inner.ping_pong.take_user_pings() } /// Advances the internal state of the connection. - pub fn poll(&mut self, cx: &mut Context) -> Poll> { - use crate::codec::RecvError::*; + pub fn poll(&mut self, cx: &mut Context) -> Poll> { + // XXX(eliza): cloning the span is unfortunately necessary here in + // order to placate the borrow checker — `self` is mutably borrowed by + // `poll2`, which means that we can't borrow `self.span` to enter it. + // The clone is just an atomic ref bump. + let span = self.inner.span.clone(); + let _e = span.enter(); + let span = tracing::trace_span!("poll"); + let _e = span.enter(); loop { + tracing::trace!(connection.state = ?self.inner.state); // TODO: probably clean up this glob of code - match self.state { + match self.inner.state { // When open, continue to poll a frame State::Open => { - match self.poll2(cx) { - // The connection has shutdown normally - Poll::Ready(Ok(())) => self.state = State::Closing(Reason::NO_ERROR), + let result = match self.poll2(cx) { + Poll::Ready(result) => result, // The connection is not ready to make progress Poll::Pending => { // Ensure all window updates have been sent. 
// // This will also handle flushing `self.codec` - ready!(self.streams.poll_complete(cx, &mut self.codec))?; + ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?; - if (self.error.is_some() || self.go_away.should_close_on_idle()) - && !self.streams.has_streams() + if (self.inner.error.is_some() + || self.inner.go_away.should_close_on_idle()) + && !self.inner.streams.has_streams() { - self.go_away_now(Reason::NO_ERROR); + self.inner.as_dyn().go_away_now(Reason::NO_ERROR); continue; } return Poll::Pending; } - // Attempting to read a frame resulted in a connection level - // error. This is handled by setting a GOAWAY frame followed by - // terminating the connection. - Poll::Ready(Err(Connection(e))) => { - log::debug!("Connection::poll; connection error={:?}", e); - - // We may have already sent a GOAWAY for this error, - // if so, don't send another, just flush and close up. - if let Some(reason) = self.go_away.going_away_reason() { - if reason == e { - log::trace!(" -> already going away"); - self.state = State::Closing(e); - continue; - } - } + }; - // Reset all active streams - self.streams.recv_err(&e.into()); - self.go_away_now(e); - } - // Attempting to read a frame resulted in a stream level error. - // This is handled by resetting the frame then trying to read - // another frame. - Poll::Ready(Err(Stream { id, reason })) => { - log::trace!("stream error; id={:?}; reason={:?}", id, reason); - self.streams.send_reset(id, reason); - } - // Attempting to read a frame resulted in an I/O error. All - // active streams must be reset. - // - // TODO: Are I/O errors recoverable? - Poll::Ready(Err(Io(e))) => { - log::debug!("Connection::poll; IO error={:?}", e); - let e = e.into(); - - // Reset all active streams - self.streams.recv_err(&e); - - // Return the error - return Poll::Ready(Err(e)); - } - } + self.inner.as_dyn().handle_poll2_result(result)? 
} - State::Closing(reason) => { - log::trace!("connection closing after flush"); + State::Closing(reason, initiator) => { + tracing::trace!("connection closing after flush"); // Flush/shutdown the codec ready!(self.codec.shutdown(cx))?; // Transition the state to error - self.state = State::Closed(reason); + self.inner.state = State::Closed(reason, initiator); + } + State::Closed(reason, initiator) => { + return Poll::Ready(self.take_error(reason, initiator)); } - State::Closed(reason) => return self.take_error(reason), } } } - fn poll2(&mut self, cx: &mut Context) -> Poll> { - use crate::frame::Frame::*; - + fn poll2(&mut self, cx: &mut Context) -> Poll> { // This happens outside of the loop to prevent needing to do a clock // check and then comparison of the queue possibly multiple times a // second (and thus, the clock wouldn't have changed enough to matter). @@ -297,13 +314,13 @@ where // - poll_go_away may buffer a graceful shutdown GOAWAY frame // - If it has, we've also added a PING to be sent in poll_ready if let Some(reason) = ready!(self.poll_go_away(cx)?) { - if self.go_away.should_close_now() { - if self.go_away.is_user_initiated() { + if self.inner.go_away.should_close_now() { + if self.inner.go_away.is_user_initiated() { // A user initiated abrupt shutdown shouldn't return // the same error back to the user. return Poll::Ready(Ok(())); } else { - return Poll::Ready(Err(RecvError::Connection(reason))); + return Poll::Ready(Err(Error::library_go_away(reason))); } } // Only NO_ERROR should be waiting for idle @@ -315,61 +332,20 @@ where } ready!(self.poll_ready(cx))?; - match ready!(Pin::new(&mut self.codec).poll_next(cx)?) 
{ - Some(Headers(frame)) => { - log::trace!("recv HEADERS; frame={:?}", frame); - self.streams.recv_headers(frame)?; - } - Some(Data(frame)) => { - log::trace!("recv DATA; frame={:?}", frame); - self.streams.recv_data(frame)?; - } - Some(Reset(frame)) => { - log::trace!("recv RST_STREAM; frame={:?}", frame); - self.streams.recv_reset(frame)?; - } - Some(PushPromise(frame)) => { - log::trace!("recv PUSH_PROMISE; frame={:?}", frame); - self.streams.recv_push_promise(frame)?; - } - Some(Settings(frame)) => { - log::trace!("recv SETTINGS; frame={:?}", frame); - self.settings - .recv_settings(frame, &mut self.codec, &mut self.streams)?; - } - Some(GoAway(frame)) => { - log::trace!("recv GOAWAY; frame={:?}", frame); - // This should prevent starting new streams, - // but should allow continuing to process current streams - // until they are all EOS. Once they are, State should - // transition to GoAway. - self.streams.recv_go_away(&frame)?; - self.error = Some(frame.reason()); - } - Some(Ping(frame)) => { - log::trace!("recv PING; frame={:?}", frame); - let status = self.ping_pong.recv_ping(frame); - if status.is_shutdown() { - assert!( - self.go_away.is_going_away(), - "received unexpected shutdown ping" - ); - - let last_processed_id = self.streams.last_processed_id(); - self.go_away(last_processed_id, Reason::NO_ERROR); - } - } - Some(WindowUpdate(frame)) => { - log::trace!("recv WINDOW_UPDATE; frame={:?}", frame); - self.streams.recv_window_update(frame)?; - } - Some(Priority(frame)) => { - log::trace!("recv PRIORITY; frame={:?}", frame); - // TODO: handle + match self + .inner + .as_dyn() + .recv_frame(ready!(Pin::new(&mut self.codec).poll_next(cx)?))? 
+ { + ReceivedFrame::Settings(frame) => { + self.inner.settings.recv_settings( + frame, + &mut self.codec, + &mut self.inner.streams, + )?; } - None => { - log::trace!("codec closed"); - self.streams.recv_eof(false).expect("mutex poisoned"); + ReceivedFrame::Continue => (), + ReceivedFrame::Done => { return Poll::Ready(Ok(())); } } @@ -377,17 +353,213 @@ where } fn clear_expired_reset_streams(&mut self) { - self.streams.clear_expired_reset_streams(); + self.inner.streams.clear_expired_reset_streams(); } } +impl ConnectionInner +where + P: Peer, + B: Buf, +{ + fn as_dyn(&mut self) -> DynConnection<'_, B> { + let ConnectionInner { + state, + go_away, + streams, + error, + ping_pong, + .. + } = self; + let streams = streams.as_dyn(); + DynConnection { + state, + go_away, + streams, + error, + ping_pong, + } + } +} + +impl DynConnection<'_, B> +where + B: Buf, +{ + fn go_away(&mut self, id: StreamId, e: Reason) { + let frame = frame::GoAway::new(id, e); + self.streams.send_go_away(id); + self.go_away.go_away(frame); + } + + fn go_away_now(&mut self, e: Reason) { + let last_processed_id = self.streams.last_processed_id(); + let frame = frame::GoAway::new(last_processed_id, e); + self.go_away.go_away_now(frame); + } + + fn go_away_now_data(&mut self, e: Reason, data: Bytes) { + let last_processed_id = self.streams.last_processed_id(); + let frame = frame::GoAway::with_debug_data(last_processed_id, e, data); + self.go_away.go_away_now(frame); + } + + fn go_away_from_user(&mut self, e: Reason) { + let last_processed_id = self.streams.last_processed_id(); + let frame = frame::GoAway::new(last_processed_id, e); + self.go_away.go_away_from_user(frame); + + // Notify all streams of reason we're abruptly closing. 
+ self.streams.handle_error(Error::user_go_away(e)); + } + + fn handle_poll2_result(&mut self, result: Result<(), Error>) -> Result<(), Error> { + match result { + // The connection has shutdown normally + Ok(()) => { + *self.state = State::Closing(Reason::NO_ERROR, Initiator::Library); + Ok(()) + } + // Attempting to read a frame resulted in a connection level + // error. This is handled by setting a GOAWAY frame followed by + // terminating the connection. + Err(Error::GoAway(debug_data, reason, initiator)) => { + let e = Error::GoAway(debug_data.clone(), reason, initiator); + tracing::debug!(error = ?e, "Connection::poll; connection error"); + + // We may have already sent a GOAWAY for this error, + // if so, don't send another, just flush and close up. + if self + .go_away + .going_away() + .map_or(false, |frame| frame.reason() == reason) + { + tracing::trace!(" -> already going away"); + *self.state = State::Closing(reason, initiator); + return Ok(()); + } + + // Reset all active streams + self.streams.handle_error(e); + self.go_away_now_data(reason, debug_data); + Ok(()) + } + // Attempting to read a frame resulted in a stream level error. + // This is handled by resetting the frame then trying to read + // another frame. + Err(Error::Reset(id, reason, initiator)) => { + debug_assert_eq!(initiator, Initiator::Library); + tracing::trace!(?id, ?reason, "stream error"); + self.streams.send_reset(id, reason); + Ok(()) + } + // Attempting to read a frame resulted in an I/O error. All + // active streams must be reset. + // + // TODO: Are I/O errors recoverable? 
+ Err(Error::Io(kind, inner)) => { + tracing::debug!(error = ?kind, "Connection::poll; IO error"); + let e = Error::Io(kind, inner); + + // Reset all active streams + self.streams.handle_error(e.clone()); + + // Some client implementations drop the connections without notifying its peer + // Attempting to read after the client dropped the connection results in UnexpectedEof + // If as a server, we don't have anything more to send, just close the connection + // without error + // + // See https://github.com/hyperium/hyper/issues/3427 + if self.streams.is_server() + && self.streams.is_buffer_empty() + && matches!(kind, io::ErrorKind::UnexpectedEof) + { + *self.state = State::Closed(Reason::NO_ERROR, Initiator::Library); + return Ok(()); + } + + // Return the error + Err(e) + } + } + } + + fn recv_frame(&mut self, frame: Option) -> Result { + use crate::frame::Frame::*; + match frame { + Some(Headers(frame)) => { + tracing::trace!(?frame, "recv HEADERS"); + self.streams.recv_headers(frame)?; + } + Some(Data(frame)) => { + tracing::trace!(?frame, "recv DATA"); + self.streams.recv_data(frame)?; + } + Some(Reset(frame)) => { + tracing::trace!(?frame, "recv RST_STREAM"); + self.streams.recv_reset(frame)?; + } + Some(PushPromise(frame)) => { + tracing::trace!(?frame, "recv PUSH_PROMISE"); + self.streams.recv_push_promise(frame)?; + } + Some(Settings(frame)) => { + tracing::trace!(?frame, "recv SETTINGS"); + return Ok(ReceivedFrame::Settings(frame)); + } + Some(GoAway(frame)) => { + tracing::trace!(?frame, "recv GOAWAY"); + // This should prevent starting new streams, + // but should allow continuing to process current streams + // until they are all EOS. Once they are, State should + // transition to GoAway. 
+ self.streams.recv_go_away(&frame)?; + *self.error = Some(frame); + } + Some(Ping(frame)) => { + tracing::trace!(?frame, "recv PING"); + let status = self.ping_pong.recv_ping(frame); + if status.is_shutdown() { + assert!( + self.go_away.is_going_away(), + "received unexpected shutdown ping" + ); + + let last_processed_id = self.streams.last_processed_id(); + self.go_away(last_processed_id, Reason::NO_ERROR); + } + } + Some(WindowUpdate(frame)) => { + tracing::trace!(?frame, "recv WINDOW_UPDATE"); + self.streams.recv_window_update(frame)?; + } + Some(Priority(frame)) => { + tracing::trace!(?frame, "recv PRIORITY"); + // TODO: handle + } + None => { + tracing::trace!("codec closed"); + self.streams.recv_eof(false).expect("mutex poisoned"); + return Ok(ReceivedFrame::Done); + } + } + Ok(ReceivedFrame::Continue) + } +} + +enum ReceivedFrame { + Settings(frame::Settings), + Continue, + Done, +} + impl Connection where T: AsyncRead + AsyncWrite, B: Buf, { pub(crate) fn streams(&self) -> &Streams { - &self.streams + &self.inner.streams } } @@ -397,12 +569,12 @@ where B: Buf, { pub fn next_incoming(&mut self) -> Option> { - self.streams.next_incoming() + self.inner.streams.next_incoming() } // Graceful shutdown only makes sense for server peers. pub fn go_away_gracefully(&mut self) { - if self.go_away.is_going_away() { + if self.inner.go_away.is_going_away() { // No reason to start a new one. return; } @@ -418,11 +590,11 @@ where // > send another GOAWAY frame with an updated last stream identifier. // > This ensures that a connection can be cleanly shut down without // > losing requests. - self.go_away(StreamId::MAX, Reason::NO_ERROR); + self.inner.as_dyn().go_away(StreamId::MAX, Reason::NO_ERROR); // We take the advice of waiting 1 RTT literally, and wait // for a pong before proceeding. 
- self.ping_pong.ping_shutdown(); + self.inner.ping_pong.ping_shutdown(); } } @@ -433,6 +605,6 @@ where { fn drop(&mut self) { // Ignore errors as this indicates that the mutex is poisoned. - let _ = self.streams.recv_eof(true); + let _ = self.inner.streams.recv_eof(true); } } diff --git a/src/proto/error.rs b/src/proto/error.rs index c3ee20d03..ad023317e 100644 --- a/src/proto/error.rs +++ b/src/proto/error.rs @@ -1,53 +1,91 @@ -use crate::codec::{RecvError, SendError}; -use crate::frame::Reason; +use crate::codec::SendError; +use crate::frame::{Reason, StreamId}; +use bytes::Bytes; +use std::fmt; use std::io; /// Either an H2 reason or an I/O error -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum Error { - Proto(Reason), - Io(io::Error), + Reset(StreamId, Reason, Initiator), + GoAway(Bytes, Reason, Initiator), + Io(io::ErrorKind, Option), +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Initiator { + User, + Library, + Remote, } impl Error { - /// Clone the error for internal purposes. - /// - /// `io::Error` is not `Clone`, so we only copy the `ErrorKind`. - pub(super) fn shallow_clone(&self) -> Error { + pub(crate) fn is_local(&self) -> bool { match *self { - Error::Proto(reason) => Error::Proto(reason), - Error::Io(ref io) => Error::Io(io::Error::from(io.kind())), + Self::Reset(_, _, initiator) | Self::GoAway(_, _, initiator) => initiator.is_local(), + Self::Io(..) 
=> true, } } -} -impl From for Error { - fn from(src: Reason) -> Self { - Error::Proto(src) + pub(crate) fn user_go_away(reason: Reason) -> Self { + Self::GoAway(Bytes::new(), reason, Initiator::User) + } + + pub(crate) fn library_reset(stream_id: StreamId, reason: Reason) -> Self { + Self::Reset(stream_id, reason, Initiator::Library) + } + + pub(crate) fn library_go_away(reason: Reason) -> Self { + Self::GoAway(Bytes::new(), reason, Initiator::Library) + } + + pub(crate) fn library_go_away_data(reason: Reason, debug_data: impl Into) -> Self { + Self::GoAway(debug_data.into(), reason, Initiator::Library) + } + + pub(crate) fn remote_reset(stream_id: StreamId, reason: Reason) -> Self { + Self::Reset(stream_id, reason, Initiator::Remote) + } + + pub(crate) fn remote_go_away(debug_data: Bytes, reason: Reason) -> Self { + Self::GoAway(debug_data, reason, Initiator::Remote) } } -impl From for Error { - fn from(src: io::Error) -> Self { - Error::Io(src) +impl Initiator { + fn is_local(&self) -> bool { + match *self { + Self::User | Self::Library => true, + Self::Remote => false, + } } } -impl From for RecvError { - fn from(src: Error) -> RecvError { - match src { - Error::Proto(reason) => RecvError::Connection(reason), - Error::Io(e) => RecvError::Io(e), +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + Self::Reset(_, reason, _) | Self::GoAway(_, reason, _) => reason.fmt(fmt), + Self::Io(_, Some(ref inner)) => inner.fmt(fmt), + Self::Io(kind, None) => io::Error::from(kind).fmt(fmt), } } } +impl From for Error { + fn from(src: io::ErrorKind) -> Self { + Error::Io(src, None) + } +} + +impl From for Error { + fn from(src: io::Error) -> Self { + Error::Io(src.kind(), src.get_ref().map(|inner| inner.to_string())) + } +} + impl From for SendError { - fn from(src: Error) -> SendError { - match src { - Error::Proto(reason) => SendError::Connection(reason), - Error::Io(e) => SendError::Io(e), - } + fn from(src: Error) -> 
Self { + Self::Connection(src) } } diff --git a/src/proto/go_away.rs b/src/proto/go_away.rs index 91d37b642..d52252cd7 100644 --- a/src/proto/go_away.rs +++ b/src/proto/go_away.rs @@ -26,12 +26,8 @@ pub(super) struct GoAway { /// were a `frame::GoAway`, it might appear like we eventually wanted to /// serialize it. We **only** want to be able to look up these fields at a /// later time. -/// -/// (Technically, `frame::GoAway` should gain an opaque_debug_data field as -/// well, and we wouldn't want to save that here to accidentally dump in logs, -/// or waste struct space.) #[derive(Debug)] -struct GoingAway { +pub(crate) struct GoingAway { /// Stores the highest stream ID of a GOAWAY that has been sent. /// /// It's illegal to send a subsequent GOAWAY with a higher ID. @@ -98,9 +94,9 @@ impl GoAway { self.is_user_initiated } - /// Return the last Reason we've sent. - pub fn going_away_reason(&self) -> Option { - self.going_away.as_ref().map(|g| g.reason) + /// Returns the going away info, if any. + pub fn going_away(&self) -> Option<&GoingAway> { + self.going_away.as_ref() } /// Returns if the connection should close now, or wait until idle. 
@@ -141,7 +137,7 @@ impl GoAway { return Poll::Ready(Some(Ok(reason))); } else if self.should_close_now() { - return match self.going_away_reason() { + return match self.going_away().map(|going_away| going_away.reason) { Some(reason) => Poll::Ready(Some(Ok(reason))), None => Poll::Ready(None), }; @@ -150,3 +146,9 @@ impl GoAway { Poll::Ready(None) } } + +impl GoingAway { + pub(crate) fn reason(&self) -> Reason { + self.reason + } +} diff --git a/src/proto/mod.rs b/src/proto/mod.rs index f9e068b58..560927598 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -7,10 +7,10 @@ mod settings; mod streams; pub(crate) use self::connection::{Config, Connection}; -pub(crate) use self::error::Error; +pub use self::error::{Error, Initiator}; pub(crate) use self::peer::{Dyn as DynPeer, Peer}; pub(crate) use self::ping_pong::UserPings; -pub(crate) use self::streams::{OpaqueStreamRef, StreamRef, Streams}; +pub(crate) use self::streams::{DynStreams, OpaqueStreamRef, StreamRef, Streams}; pub(crate) use self::streams::{Open, PollReset, Prioritized}; use crate::codec::Codec; @@ -30,6 +30,9 @@ pub type PingPayload = [u8; 8]; pub type WindowSize = u32; // Constants -pub const MAX_WINDOW_SIZE: WindowSize = (1 << 31) - 1; +pub const MAX_WINDOW_SIZE: WindowSize = (1 << 31) - 1; // i32::MAX as u32 +pub const DEFAULT_REMOTE_RESET_STREAM_MAX: usize = 20; +pub const DEFAULT_LOCAL_RESET_COUNT_MAX: usize = 1024; pub const DEFAULT_RESET_STREAM_MAX: usize = 10; pub const DEFAULT_RESET_STREAM_SECS: u64 = 30; +pub const DEFAULT_MAX_SEND_BUFFER_SIZE: usize = 1024 * 400; diff --git a/src/proto/peer.rs b/src/proto/peer.rs index 8d327fbfc..cbe7fb289 100644 --- a/src/proto/peer.rs +++ b/src/proto/peer.rs @@ -1,7 +1,6 @@ -use crate::codec::RecvError; use crate::error::Reason; use crate::frame::{Pseudo, StreamId}; -use crate::proto::Open; +use crate::proto::{Error, Open}; use http::{HeaderMap, Request, Response}; @@ -11,21 +10,24 @@ use std::fmt; pub(crate) trait Peer { /// Message type polled from 
the transport type Poll: fmt::Debug; + const NAME: &'static str; fn r#dyn() -> Dyn; - fn is_server() -> bool; + //fn is_server() -> bool; fn convert_poll_message( pseudo: Pseudo, fields: HeaderMap, stream_id: StreamId, - ) -> Result; + ) -> Result; + /* fn is_local_init(id: StreamId) -> bool { assert!(!id.is_zero()); Self::is_server() == id.is_server_initiated() } + */ } /// A dynamic representation of `Peer`. @@ -60,7 +62,7 @@ impl Dyn { pseudo: Pseudo, fields: HeaderMap, stream_id: StreamId, - ) -> Result { + ) -> Result { if self.is_server() { crate::server::Peer::convert_poll_message(pseudo, fields, stream_id) .map(PollMessage::Server) @@ -71,12 +73,12 @@ impl Dyn { } /// Returns true if the remote peer can initiate a stream with the given ID. - pub fn ensure_can_open(&self, id: StreamId, mode: Open) -> Result<(), RecvError> { + pub fn ensure_can_open(&self, id: StreamId, mode: Open) -> Result<(), Error> { if self.is_server() { // Ensure that the ID is a valid client initiated ID if mode.is_push_promise() || !id.is_client_initiated() { proto_err!(conn: "cannot open stream {:?} - not client initiated", id); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } Ok(()) @@ -84,7 +86,7 @@ impl Dyn { // Ensure that the ID is a valid server initiated ID if !mode.is_push_promise() || !id.is_server_initiated() { proto_err!(conn: "cannot open stream {:?} - not server initiated", id); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } Ok(()) diff --git a/src/proto/ping_pong.rs b/src/proto/ping_pong.rs index 0022d4a5b..59023e26a 100644 --- a/src/proto/ping_pong.rs +++ b/src/proto/ping_pong.rs @@ -107,7 +107,7 @@ impl PingPong { &Ping::SHUTDOWN, "pending_ping should be for shutdown", ); - log::trace!("recv PING SHUTDOWN ack"); + tracing::trace!("recv PING SHUTDOWN ack"); return ReceivedPing::Shutdown; } @@ -117,7 +117,7 @@ impl 
PingPong { if let Some(ref users) = self.user_pings { if ping.payload() == &Ping::USER && users.receive_pong() { - log::trace!("recv PING USER ack"); + tracing::trace!("recv PING USER ack"); return ReceivedPing::Unknown; } } @@ -125,7 +125,7 @@ impl PingPong { // else we were acked a ping we didn't send? // The spec doesn't require us to do anything about this, // so for resiliency, just ignore it for now. - log::warn!("recv PING ack that we never sent: {:?}", ping); + tracing::warn!("recv PING ack that we never sent: {:?}", ping); ReceivedPing::Unknown } else { // Save the ping's payload to be sent as an acknowledgement. @@ -200,10 +200,7 @@ impl PingPong { impl ReceivedPing { pub(crate) fn is_shutdown(&self) -> bool { - match *self { - ReceivedPing::Shutdown => true, - _ => false, - } + matches!(*self, Self::Shutdown) } } @@ -211,11 +208,16 @@ impl ReceivedPing { impl UserPings { pub(crate) fn send_ping(&self) -> Result<(), Option> { - let prev = self.0.state.compare_and_swap( - USER_STATE_EMPTY, // current - USER_STATE_PENDING_PING, // new - Ordering::AcqRel, - ); + let prev = self + .0 + .state + .compare_exchange( + USER_STATE_EMPTY, // current + USER_STATE_PENDING_PING, // new + Ordering::AcqRel, + Ordering::Acquire, + ) + .unwrap_or_else(|v| v); match prev { USER_STATE_EMPTY => { @@ -234,11 +236,16 @@ impl UserPings { // Must register before checking state, in case state were to change // before we could register, and then the ping would just be lost. 
self.0.pong_task.register(cx.waker()); - let prev = self.0.state.compare_and_swap( - USER_STATE_RECEIVED_PONG, // current - USER_STATE_EMPTY, // new - Ordering::AcqRel, - ); + let prev = self + .0 + .state + .compare_exchange( + USER_STATE_RECEIVED_PONG, // current + USER_STATE_EMPTY, // new + Ordering::AcqRel, + Ordering::Acquire, + ) + .unwrap_or_else(|v| v); match prev { USER_STATE_RECEIVED_PONG => Poll::Ready(Ok(())), @@ -252,11 +259,16 @@ impl UserPings { impl UserPingsRx { fn receive_pong(&self) -> bool { - let prev = self.0.state.compare_and_swap( - USER_STATE_PENDING_PONG, // current - USER_STATE_RECEIVED_PONG, // new - Ordering::AcqRel, - ); + let prev = self + .0 + .state + .compare_exchange( + USER_STATE_PENDING_PONG, // current + USER_STATE_RECEIVED_PONG, // new + Ordering::AcqRel, + Ordering::Acquire, + ) + .unwrap_or_else(|v| v); if prev == USER_STATE_PENDING_PONG { self.0.pong_task.wake(); diff --git a/src/proto/settings.rs b/src/proto/settings.rs index b1d91e652..d6155fc3d 100644 --- a/src/proto/settings.rs +++ b/src/proto/settings.rs @@ -1,6 +1,5 @@ -use crate::codec::{RecvError, UserError}; +use crate::codec::UserError; use crate::error::Reason; -use crate::frame; use crate::proto::*; use std::task::{Context, Poll}; @@ -12,6 +11,9 @@ pub(crate) struct Settings { /// the socket first then the settings applied **before** receiving any /// further frames. remote: Option, + /// Whether the connection has received the initial SETTINGS frame from the + /// remote peer. + has_received_remote_initial_settings: bool, } #[derive(Debug)] @@ -32,6 +34,7 @@ impl Settings { // the handshake process. 
local: Local::WaitingAck(local), remote: None, + has_received_remote_initial_settings: false, } } @@ -40,7 +43,7 @@ impl Settings { frame: frame::Settings, codec: &mut Codec, streams: &mut Streams, - ) -> Result<(), RecvError> + ) -> Result<(), Error> where T: AsyncWrite + Unpin, B: Buf, @@ -50,7 +53,7 @@ impl Settings { if frame.is_ack() { match &self.local { Local::WaitingAck(local) => { - log::debug!("received settings ACK; applying {:?}", local); + tracing::debug!("received settings ACK; applying {:?}", local); if let Some(max) = local.max_frame_size() { codec.set_max_recv_frame_size(max as usize); @@ -60,6 +63,10 @@ impl Settings { codec.set_max_recv_header_list_size(max as usize); } + if let Some(val) = local.header_table_size() { + codec.set_recv_header_table_size(val as usize); + } + streams.apply_local_settings(local)?; self.local = Local::Synced; Ok(()) @@ -68,7 +75,7 @@ impl Settings { // We haven't sent any SETTINGS frames to be ACKed, so // this is very bizarre! Remote is either buggy or malicious. proto_err!(conn: "received unexpected settings ack"); - Err(RecvError::Connection(Reason::PROTOCOL_ERROR)) + Err(Error::library_go_away(Reason::PROTOCOL_ERROR)) } } } else { @@ -85,26 +92,35 @@ impl Settings { match &self.local { Local::ToSend(..) | Local::WaitingAck(..) => Err(UserError::SendSettingsWhilePending), Local::Synced => { - log::trace!("queue to send local settings: {:?}", frame); + tracing::trace!("queue to send local settings: {:?}", frame); self.local = Local::ToSend(frame); Ok(()) } } } + /// Sets `true` to `self.has_received_remote_initial_settings`. + /// Returns `true` if this method is called for the first time. + /// (i.e. 
it is the initial SETTINGS frame from the remote peer) + fn mark_remote_initial_settings_as_received(&mut self) -> bool { + let has_received = self.has_received_remote_initial_settings; + self.has_received_remote_initial_settings = true; + !has_received + } + pub(crate) fn poll_send( &mut self, cx: &mut Context, dst: &mut Codec, streams: &mut Streams, - ) -> Poll> + ) -> Poll> where T: AsyncWrite + Unpin, B: Buf, C: Buf, P: Peer, { - if let Some(settings) = &self.remote { + if let Some(settings) = self.remote.clone() { if !dst.poll_ready(cx)?.is_ready() { return Poll::Pending; } @@ -115,7 +131,10 @@ impl Settings { // Buffer the settings frame dst.buffer(frame.into()).expect("invalid settings frame"); - log::trace!("ACK sent; applying settings"); + tracing::trace!("ACK sent; applying settings"); + + let is_initial = self.mark_remote_initial_settings_as_received(); + streams.apply_remote_settings(&settings, is_initial)?; if let Some(val) = settings.header_table_size() { dst.set_send_header_table_size(val as usize); @@ -124,8 +143,6 @@ impl Settings { if let Some(val) = settings.max_frame_size() { dst.set_max_send_frame_size(val as usize); } - - streams.apply_remote_settings(settings)?; } self.remote = None; @@ -139,7 +156,7 @@ impl Settings { // Buffer the settings frame dst.buffer(settings.clone().into()) .expect("invalid settings frame"); - log::trace!("local settings sent; waiting for ack: {:?}", settings); + tracing::trace!("local settings sent; waiting for ack: {:?}", settings); self.local = Local::WaitingAck(settings.clone()); } diff --git a/src/proto/streams/buffer.rs b/src/proto/streams/buffer.rs index 652f2eda1..02d265061 100644 --- a/src/proto/streams/buffer.rs +++ b/src/proto/streams/buffer.rs @@ -29,6 +29,10 @@ impl Buffer { pub fn new() -> Self { Buffer { slab: Slab::new() } } + + pub fn is_empty(&self) -> bool { + self.slab.is_empty() + } } impl Deque { @@ -92,13 +96,4 @@ impl Deque { None => None, } } - - /* - pub fn peek_front<'a, T>(&self, buf: &'a 
Buffer) -> Option<&'a T> { - match self.indices { - Some(idxs) => Some(&buf.slab[idxs.head].value), - None => None, - } - } - */ } diff --git a/src/proto/streams/counts.rs b/src/proto/streams/counts.rs index bcd07e814..fdb07f1cd 100644 --- a/src/proto/streams/counts.rs +++ b/src/proto/streams/counts.rs @@ -21,10 +21,26 @@ pub(super) struct Counts { num_recv_streams: usize, /// Maximum number of pending locally reset streams - max_reset_streams: usize, + max_local_reset_streams: usize, /// Current number of pending locally reset streams - num_reset_streams: usize, + num_local_reset_streams: usize, + + /// Max number of "pending accept" streams that were remotely reset + max_remote_reset_streams: usize, + + /// Current number of "pending accept" streams that were remotely reset + num_remote_reset_streams: usize, + + /// Maximum number of locally reset streams due to protocol error across + /// the lifetime of the connection. + /// + /// When this gets exceeded, we issue GOAWAYs. + max_local_error_reset_streams: Option, + + /// Total number of locally reset streams due to protocol error across the + /// lifetime of the connection. + num_local_error_reset_streams: usize, } impl Counts { @@ -36,11 +52,23 @@ impl Counts { num_send_streams: 0, max_recv_streams: config.remote_max_initiated.unwrap_or(usize::MAX), num_recv_streams: 0, - max_reset_streams: config.local_reset_max, - num_reset_streams: 0, + max_local_reset_streams: config.local_reset_max, + num_local_reset_streams: 0, + max_remote_reset_streams: config.remote_reset_max, + num_remote_reset_streams: 0, + max_local_error_reset_streams: config.local_max_error_reset_streams, + num_local_error_reset_streams: 0, } } + /// Returns true when the next opened stream will reach capacity of outbound streams + /// + /// The number of client send streams is incremented in prioritize; send_request has to guess if + /// it should wait before allowing another request to be sent. 
+ pub fn next_send_stream_will_reach_capacity(&self) -> bool { + self.max_send_streams <= (self.num_send_streams + 1) + } + /// Returns the current peer pub fn peer(&self) -> peer::Dyn { self.peer @@ -50,6 +78,26 @@ impl Counts { self.num_send_streams != 0 || self.num_recv_streams != 0 } + /// Returns true if we can issue another local reset due to protocol error. + pub fn can_inc_num_local_error_resets(&self) -> bool { + if let Some(max) = self.max_local_error_reset_streams { + max > self.num_local_error_reset_streams + } else { + true + } + } + + pub fn inc_num_local_error_resets(&mut self) { + assert!(self.can_inc_num_local_error_resets()); + + // Increment the number of remote initiated streams + self.num_local_error_reset_streams += 1; + } + + pub(crate) fn max_local_error_resets(&self) -> Option { + self.max_local_error_reset_streams + } + /// Returns true if the receive stream concurrency can be incremented pub fn can_inc_num_recv_streams(&self) -> bool { self.max_recv_streams > self.num_recv_streams @@ -90,7 +138,7 @@ impl Counts { /// Returns true if the number of pending reset streams can be incremented. pub fn can_inc_num_reset_streams(&self) -> bool { - self.max_reset_streams > self.num_reset_streams + self.max_local_reset_streams > self.num_local_reset_streams } /// Increments the number of pending reset streams. @@ -101,12 +149,41 @@ impl Counts { pub fn inc_num_reset_streams(&mut self) { assert!(self.can_inc_num_reset_streams()); - self.num_reset_streams += 1; + self.num_local_reset_streams += 1; } - pub fn apply_remote_settings(&mut self, settings: &frame::Settings) { - if let Some(val) = settings.max_concurrent_streams() { - self.max_send_streams = val as usize; + pub(crate) fn max_remote_reset_streams(&self) -> usize { + self.max_remote_reset_streams + } + + /// Returns true if the number of pending REMOTE reset streams can be + /// incremented. 
+ pub(crate) fn can_inc_num_remote_reset_streams(&self) -> bool { + self.max_remote_reset_streams > self.num_remote_reset_streams + } + + /// Increments the number of pending REMOTE reset streams. + /// + /// # Panics + /// + /// Panics on failure as this should have been validated before hand. + pub(crate) fn inc_num_remote_reset_streams(&mut self) { + assert!(self.can_inc_num_remote_reset_streams()); + + self.num_remote_reset_streams += 1; + } + + pub(crate) fn dec_num_remote_reset_streams(&mut self) { + assert!(self.num_remote_reset_streams > 0); + + self.num_remote_reset_streams -= 1; + } + + pub fn apply_remote_settings(&mut self, settings: &frame::Settings, is_initial: bool) { + match settings.max_concurrent_streams() { + Some(val) => self.max_send_streams = val as usize, + None if is_initial => self.max_send_streams = usize::MAX, + None => {} } } @@ -133,7 +210,7 @@ impl Counts { // TODO: move this to macro? pub fn transition_after(&mut self, mut stream: store::Ptr, is_reset_counted: bool) { - log::trace!( + tracing::trace!( "transition_after; stream={:?}; state={:?}; is_closed={:?}; \ pending_send_empty={:?}; buffered_send_data={}; \ num_recv={}; num_send={}", @@ -155,7 +232,7 @@ impl Counts { } if stream.is_counted { - log::trace!("dec_num_streams; stream={:?}", stream.id); + tracing::trace!("dec_num_streams; stream={:?}", stream.id); // Decrement the number of active streams. self.dec_num_streams(&mut stream); } @@ -167,6 +244,18 @@ impl Counts { } } + /// Returns the maximum number of streams that can be initiated by this + /// peer. + pub(crate) fn max_send_streams(&self) -> usize { + self.max_send_streams + } + + /// Returns the maximum number of streams that can be initiated by the + /// remote peer. 
+ pub(crate) fn max_recv_streams(&self) -> usize { + self.max_recv_streams + } + fn dec_num_streams(&mut self, stream: &mut store::Ptr) { assert!(stream.is_counted); @@ -182,8 +271,8 @@ impl Counts { } fn dec_num_reset_streams(&mut self) { - assert!(self.num_reset_streams > 0); - self.num_reset_streams -= 1; + assert!(self.num_local_reset_streams > 0); + self.num_local_reset_streams -= 1; } } diff --git a/src/proto/streams/flow_control.rs b/src/proto/streams/flow_control.rs index f3cea1699..57a935825 100644 --- a/src/proto/streams/flow_control.rs +++ b/src/proto/streams/flow_control.rs @@ -19,6 +19,7 @@ const UNCLAIMED_NUMERATOR: i32 = 1; const UNCLAIMED_DENOMINATOR: i32 = 2; #[test] +#[allow(clippy::assertions_on_constants)] fn sanity_unclaimed_ratio() { assert!(UNCLAIMED_NUMERATOR < UNCLAIMED_DENOMINATOR); assert!(UNCLAIMED_NUMERATOR >= 0); @@ -74,12 +75,12 @@ impl FlowControl { self.window_size > self.available } - pub fn claim_capacity(&mut self, capacity: WindowSize) { - self.available -= capacity; + pub fn claim_capacity(&mut self, capacity: WindowSize) -> Result<(), Reason> { + self.available.decrease_by(capacity) } - pub fn assign_capacity(&mut self, capacity: WindowSize) { - self.available += capacity; + pub fn assign_capacity(&mut self, capacity: WindowSize) -> Result<(), Reason> { + self.available.increase_by(capacity) } /// If a WINDOW_UPDATE frame should be sent, returns a positive number @@ -120,7 +121,7 @@ impl FlowControl { return Err(Reason::FLOW_CONTROL_ERROR); } - log::trace!( + tracing::trace!( "inc_window; sz={}; old={}; new={}", sz, self.window_size, @@ -135,49 +136,55 @@ impl FlowControl { /// /// This is called after receiving a SETTINGS frame with a lower /// INITIAL_WINDOW_SIZE value. 
- pub fn dec_send_window(&mut self, sz: WindowSize) { - log::trace!( + pub fn dec_send_window(&mut self, sz: WindowSize) -> Result<(), Reason> { + tracing::trace!( "dec_window; sz={}; window={}, available={}", sz, self.window_size, self.available ); - // This should not be able to overflow `window_size` from the bottom. - self.window_size -= sz; + // ~~This should not be able to overflow `window_size` from the bottom.~~ wrong. it can. + self.window_size.decrease_by(sz)?; + Ok(()) } /// Decrement the recv-side window size. /// /// This is called after receiving a SETTINGS ACK frame with a lower /// INITIAL_WINDOW_SIZE value. - pub fn dec_recv_window(&mut self, sz: WindowSize) { - log::trace!( + pub fn dec_recv_window(&mut self, sz: WindowSize) -> Result<(), Reason> { + tracing::trace!( "dec_recv_window; sz={}; window={}, available={}", sz, self.window_size, self.available ); // This should not be able to overflow `window_size` from the bottom. - self.window_size -= sz; - self.available -= sz; + self.window_size.decrease_by(sz)?; + self.available.decrease_by(sz)?; + Ok(()) } /// Decrements the window reflecting data has actually been sent. The caller /// must ensure that the window has capacity. 
- pub fn send_data(&mut self, sz: WindowSize) { - log::trace!( + pub fn send_data(&mut self, sz: WindowSize) -> Result<(), Reason> { + tracing::trace!( "send_data; sz={}; window={}; available={}", sz, self.window_size, self.available ); - // Ensure that the argument is correct - assert!(sz <= self.window_size); + // If send size is zero it's meaningless to update flow control window + if sz > 0 { + // Ensure that the argument is correct + assert!(self.window_size.0 >= sz as i32); - // Update values - self.window_size -= sz; - self.available -= sz; + // Update values + self.window_size.decrease_by(sz)?; + self.available.decrease_by(sz)?; + } + Ok(()) } } @@ -188,7 +195,7 @@ impl FlowControl { /// /// This type tries to centralize the knowledge of addition and subtraction /// to this capacity, instead of having integer casts throughout the source. -#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)] pub struct Window(i32); impl Window { @@ -204,60 +211,48 @@ impl Window { assert!(self.0 >= 0, "negative Window"); self.0 as WindowSize } -} -impl PartialEq for Window { - fn eq(&self, other: &WindowSize) -> bool { - if self.0 < 0 { - false + pub fn decrease_by(&mut self, other: WindowSize) -> Result<(), Reason> { + if let Some(v) = self.0.checked_sub(other as i32) { + self.0 = v; + Ok(()) } else { - (self.0 as WindowSize).eq(other) + Err(Reason::FLOW_CONTROL_ERROR) } } -} -impl PartialEq for WindowSize { - fn eq(&self, other: &Window) -> bool { - other.eq(self) + pub fn increase_by(&mut self, other: WindowSize) -> Result<(), Reason> { + let other = self.add(other)?; + self.0 = other.0; + Ok(()) } -} -impl PartialOrd for Window { - fn partial_cmp(&self, other: &WindowSize) -> Option<::std::cmp::Ordering> { - if self.0 < 0 { - Some(::std::cmp::Ordering::Less) + pub fn add(&self, other: WindowSize) -> Result { + if let Some(v) = self.0.checked_add(other as i32) { + Ok(Self(v)) } else { - (self.0 as 
WindowSize).partial_cmp(other) + Err(Reason::FLOW_CONTROL_ERROR) } } } -impl PartialOrd for WindowSize { - fn partial_cmp(&self, other: &Window) -> Option<::std::cmp::Ordering> { - if other.0 < 0 { - Some(::std::cmp::Ordering::Greater) +impl PartialEq for Window { + fn eq(&self, other: &usize) -> bool { + if self.0 < 0 { + false } else { - self.partial_cmp(&(other.0 as WindowSize)) + (self.0 as usize).eq(other) } } } -impl ::std::ops::SubAssign for Window { - fn sub_assign(&mut self, other: WindowSize) { - self.0 -= other as i32; - } -} - -impl ::std::ops::Add for Window { - type Output = Self; - fn add(self, other: WindowSize) -> Self::Output { - Window(self.0 + other as i32) - } -} - -impl ::std::ops::AddAssign for Window { - fn add_assign(&mut self, other: WindowSize) { - self.0 += other as i32; +impl PartialOrd for Window { + fn partial_cmp(&self, other: &usize) -> Option<::std::cmp::Ordering> { + if self.0 < 0 { + Some(::std::cmp::Ordering::Less) + } else { + (self.0 as usize).partial_cmp(other) + } } } diff --git a/src/proto/streams/mod.rs b/src/proto/streams/mod.rs index 508d9a1e3..c4a832342 100644 --- a/src/proto/streams/mod.rs +++ b/src/proto/streams/mod.rs @@ -7,12 +7,13 @@ mod send; mod state; mod store; mod stream; +#[allow(clippy::module_inception)] mod streams; pub(crate) use self::prioritize::Prioritized; pub(crate) use self::recv::Open; pub(crate) use self::send::PollReset; -pub(crate) use self::streams::{OpaqueStreamRef, StreamRef, Streams}; +pub(crate) use self::streams::{DynStreams, OpaqueStreamRef, StreamRef, Streams}; use self::buffer::Buffer; use self::counts::Counts; @@ -32,30 +33,43 @@ use std::time::Duration; #[derive(Debug)] pub struct Config { - /// Initial window size of locally initiated streams - pub local_init_window_sz: WindowSize, - /// Initial maximum number of locally initiated streams. 
/// After receiving a Settings frame from the remote peer, /// the connection will overwrite this value with the /// MAX_CONCURRENT_STREAMS specified in the frame. pub initial_max_send_streams: usize, + /// Max amount of DATA bytes to buffer per stream. + pub local_max_buffer_size: usize, + /// The stream ID to start the next local stream with pub local_next_stream_id: StreamId, /// If the local peer is willing to receive push promises pub local_push_enabled: bool, + /// If extended connect protocol is enabled. + pub extended_connect_protocol_enabled: bool, + /// How long a locally reset stream should ignore frames pub local_reset_duration: Duration, /// Maximum number of locally reset streams to keep at a time pub local_reset_max: usize, + /// Maximum number of remotely reset "pending accept" streams to keep at a + /// time. Going over this number results in a connection error. + pub remote_reset_max: usize, + /// Initial window size of remote initiated streams pub remote_init_window_sz: WindowSize, /// Maximum number of remote initiated streams pub remote_max_initiated: Option, + + /// Maximum number of locally reset streams due to protocol error across + /// the lifetime of the connection. + /// + /// When this gets exceeded, we issue GOAWAYs. 
+ pub local_max_error_reset_streams: Option, } diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index a13393282..14b37e223 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -1,15 +1,17 @@ use super::store::Resolve; use super::*; -use crate::frame::{Reason, StreamId}; +use crate::frame::Reason; use crate::codec::UserError; use crate::codec::UserError::*; -use bytes::buf::ext::{BufExt, Take}; -use std::io; -use std::task::{Context, Poll, Waker}; -use std::{cmp, fmt, mem}; +use bytes::buf::Take; +use std::{ + cmp::{self, Ordering}, + fmt, io, mem, + task::{Context, Poll, Waker}, +}; /// # Warning /// @@ -18,7 +20,7 @@ use std::{cmp, fmt, mem}; /// This is because "idle" stream IDs – those which have been initiated but /// have yet to receive frames – will be implicitly closed on receipt of a /// frame on a higher stream ID. If these queues was not ordered by stream -/// IDs, some mechanism would be necessary to ensure that the lowest-numberedh] +/// IDs, some mechanism would be necessary to ensure that the lowest-numbered] /// idle stream is opened first. #[derive(Debug)] pub(super) struct Prioritize { @@ -51,6 +53,9 @@ pub(super) struct Prioritize { /// What `DATA` frame is currently being sent in the codec. in_flight_data_frame: InFlightData, + + /// The maximum amount of bytes a stream should buffer. 
+ max_buffer_size: usize, } #[derive(Debug, Eq, PartialEq)] @@ -82,9 +87,11 @@ impl Prioritize { flow.inc_window(config.remote_init_window_sz) .expect("invalid initial window size"); - flow.assign_capacity(config.remote_init_window_sz); + // TODO: proper error handling + let _res = flow.assign_capacity(config.remote_init_window_sz); + debug_assert!(_res.is_ok()); - log::trace!("Prioritize::new; flow={:?}", flow); + tracing::trace!("Prioritize::new; flow={:?}", flow); Prioritize { pending_send: store::Queue::new(), @@ -93,9 +100,14 @@ impl Prioritize { flow, last_opened_id: StreamId::ZERO, in_flight_data_frame: InFlightData::Nothing, + max_buffer_size: config.local_max_buffer_size, } } + pub(crate) fn max_buffer_size(&self) -> usize { + self.max_buffer_size + } + /// Queue a frame to be sent to the remote pub fn queue_frame( &mut self, @@ -104,6 +116,8 @@ impl Prioritize { stream: &mut store::Ptr, task: &mut Option, ) { + let span = tracing::trace_span!("Prioritize::queue_frame", ?stream.id); + let _e = span.enter(); // Queue the frame in the buffer stream.pending_send.push_back(buffer, frame); self.schedule_send(stream, task); @@ -112,7 +126,7 @@ impl Prioritize { pub fn schedule_send(&mut self, stream: &mut store::Ptr, task: &mut Option) { // If the stream is waiting to be opened, nothing more to do. 
if stream.is_send_ready() { - log::trace!("schedule_send; {:?}", stream.id); + tracing::trace!(?stream.id, "schedule_send"); // Queue the stream self.pending_send.push(stream); @@ -156,22 +170,29 @@ impl Prioritize { } // Update the buffered data counter - stream.buffered_send_data += sz; + stream.buffered_send_data += sz as usize; - log::trace!( - "send_data; sz={}; buffered={}; requested={}", - sz, - stream.buffered_send_data, - stream.requested_send_capacity - ); + let span = + tracing::trace_span!("send_data", sz, requested = stream.requested_send_capacity); + let _e = span.enter(); + tracing::trace!(buffered = stream.buffered_send_data); // Implicitly request more send capacity if not enough has been // requested yet. - if stream.requested_send_capacity < stream.buffered_send_data { + if (stream.requested_send_capacity as usize) < stream.buffered_send_data { // Update the target requested capacity - stream.requested_send_capacity = stream.buffered_send_data; + stream.requested_send_capacity = + cmp::min(stream.buffered_send_data, WindowSize::MAX as usize) as WindowSize; - self.try_assign_capacity(stream); + // `try_assign_capacity` will queue the stream to `pending_capacity` if the capcaity + // cannot be assigned at the time it is called. 
+ // + // Streams over the max concurrent count will still call `send_data` so we should be + // careful not to put it into `pending_capacity` as it will starve the connection + // capacity for other streams + if !stream.is_pending_open { + self.try_assign_capacity(stream); + } } if frame.is_end_stream() { @@ -179,10 +200,9 @@ impl Prioritize { self.reserve_capacity(0, stream, counts); } - log::trace!( - "send_data (2); available={}; buffered={}", - stream.send_flow.available(), - stream.buffered_send_data + tracing::trace!( + available = %stream.send_flow.available(), + buffered = stream.buffered_send_data, ); // The `stream.buffered_send_data == 0` check is here so that, if a zero @@ -214,50 +234,58 @@ impl Prioritize { stream: &mut store::Ptr, counts: &mut Counts, ) { - log::trace!( - "reserve_capacity; stream={:?}; requested={:?}; effective={:?}; curr={:?}", - stream.id, - capacity, - capacity + stream.buffered_send_data, - stream.requested_send_capacity + let span = tracing::trace_span!( + "reserve_capacity", + ?stream.id, + requested = capacity, + effective = (capacity as usize) + stream.buffered_send_data, + curr = stream.requested_send_capacity ); + let _e = span.enter(); // Actual capacity is `capacity` + the current amount of buffered data. // If it were less, then we could never send out the buffered data. 
- let capacity = capacity + stream.buffered_send_data; + let capacity = (capacity as usize) + stream.buffered_send_data; - if capacity == stream.requested_send_capacity { - // Nothing to do - } else if capacity < stream.requested_send_capacity { - // Update the target requested capacity - stream.requested_send_capacity = capacity; + match capacity.cmp(&(stream.requested_send_capacity as usize)) { + Ordering::Equal => { + // Nothing to do + } + Ordering::Less => { + // Update the target requested capacity + stream.requested_send_capacity = capacity as WindowSize; - // Currently available capacity assigned to the stream - let available = stream.send_flow.available().as_size(); + // Currently available capacity assigned to the stream + let available = stream.send_flow.available().as_size(); - // If the stream has more assigned capacity than requested, reclaim - // some for the connection - if available > capacity { - let diff = available - capacity; + // If the stream has more assigned capacity than requested, reclaim + // some for the connection + if available as usize > capacity { + let diff = available - capacity as WindowSize; - stream.send_flow.claim_capacity(diff); + // TODO: proper error handling + let _res = stream.send_flow.claim_capacity(diff); + debug_assert!(_res.is_ok()); - self.assign_connection_capacity(diff, stream, counts); - } - } else { - // If trying to *add* capacity, but the stream send side is closed, - // there's nothing to be done. - if stream.state.is_send_closed() { - return; + self.assign_connection_capacity(diff, stream, counts); + } } + Ordering::Greater => { + // If trying to *add* capacity, but the stream send side is closed, + // there's nothing to be done. 
+ if stream.state.is_send_closed() { + return; + } - // Update the target requested capacity - stream.requested_send_capacity = capacity; + // Update the target requested capacity + stream.requested_send_capacity = + cmp::min(capacity, WindowSize::MAX as usize) as WindowSize; - // Try to assign additional capacity to the stream. If none is - // currently available, the stream will be queued to receive some - // when more becomes available. - self.try_assign_capacity(stream); + // Try to assign additional capacity to the stream. If none is + // currently available, the stream will be queued to receive some + // when more becomes available. + self.try_assign_capacity(stream); + } } } @@ -266,13 +294,14 @@ impl Prioritize { inc: WindowSize, stream: &mut store::Ptr, ) -> Result<(), Reason> { - log::trace!( - "recv_stream_window_update; stream={:?}; state={:?}; inc={}; flow={:?}", - stream.id, - stream.state, + let span = tracing::trace_span!( + "recv_stream_window_update", + ?stream.id, + ?stream.state, inc, - stream.send_flow + flow = ?stream.send_flow ); + let _e = span.enter(); if stream.state.is_send_closed() && stream.buffered_send_data == 0 { // We can't send any data, so don't bother doing anything else. 
@@ -306,27 +335,35 @@ impl Prioritize { /// connection pub fn reclaim_all_capacity(&mut self, stream: &mut store::Ptr, counts: &mut Counts) { let available = stream.send_flow.available().as_size(); - stream.send_flow.claim_capacity(available); - // Re-assign all capacity to the connection - self.assign_connection_capacity(available, stream, counts); + if available > 0 { + // TODO: proper error handling + let _res = stream.send_flow.claim_capacity(available); + debug_assert!(_res.is_ok()); + // Re-assign all capacity to the connection + self.assign_connection_capacity(available, stream, counts); + } } /// Reclaim just reserved capacity, not buffered capacity, and re-assign /// it to the connection pub fn reclaim_reserved_capacity(&mut self, stream: &mut store::Ptr, counts: &mut Counts) { // only reclaim requested capacity that isn't already buffered - if stream.requested_send_capacity > stream.buffered_send_data { - let reserved = stream.requested_send_capacity - stream.buffered_send_data; + if stream.requested_send_capacity as usize > stream.buffered_send_data { + let reserved = stream.requested_send_capacity - stream.buffered_send_data as WindowSize; - stream.send_flow.claim_capacity(reserved); + // TODO: proper error handling + let _res = stream.send_flow.claim_capacity(reserved); + debug_assert!(_res.is_ok()); self.assign_connection_capacity(reserved, stream, counts); } } pub fn clear_pending_capacity(&mut self, store: &mut Store, counts: &mut Counts) { + let span = tracing::trace_span!("clear_pending_capacity"); + let _e = span.enter(); while let Some(stream) = self.pending_capacity.pop(store) { counts.transition(stream, |_, stream| { - log::trace!("clear_pending_capacity; stream={:?}", stream.id); + tracing::trace!(?stream.id, "clear_pending_capacity"); }) } } @@ -339,9 +376,12 @@ impl Prioritize { ) where R: Resolve, { - log::trace!("assign_connection_capacity; inc={}", inc); + let span = tracing::trace_span!("assign_connection_capacity", inc); + let _e = 
span.enter(); - self.flow.assign_capacity(inc); + // TODO: proper error handling + let _res = self.flow.assign_capacity(inc); + debug_assert!(_res.is_ok()); // Assign newly acquired capacity to streams pending capacity. while self.flow.available() > 0 { @@ -358,11 +398,11 @@ impl Prioritize { continue; } - counts.transition(stream, |_, mut stream| { + counts.transition(stream, |_, stream| { // Try to assign capacity to the stream. This will also re-queue the // stream if there isn't enough connection level capacity to fulfill // the capacity request. - self.try_assign_capacity(&mut stream); + self.try_assign_capacity(stream); }) } } @@ -373,7 +413,7 @@ impl Prioritize { // Total requested should never go below actual assigned // (Note: the window size can go lower than assigned) - debug_assert!(total_requested >= stream.send_flow.available()); + debug_assert!(stream.send_flow.available() <= total_requested as usize); // The amount of additional capacity that the stream requests. // Don't assign more than the window has available! @@ -382,15 +422,14 @@ impl Prioritize { // Can't assign more than what is available stream.send_flow.window_size() - stream.send_flow.available().as_size(), ); - - log::trace!( - "try_assign_capacity; stream={:?}, requested={}; additional={}; buffered={}; window={}; conn={}", - stream.id, - total_requested, + let span = tracing::trace_span!("try_assign_capacity", ?stream.id); + let _e = span.enter(); + tracing::trace!( + requested = total_requested, additional, - stream.buffered_send_data, - stream.send_flow.window_size(), - self.flow.available() + buffered = stream.buffered_send_data, + window = stream.send_flow.window_size(), + conn = %self.flow.available() ); if additional == 0 { @@ -416,24 +455,25 @@ impl Prioritize { // TODO: Should prioritization factor into this? 
let assign = cmp::min(conn_available, additional); - log::trace!(" assigning; stream={:?}, capacity={}", stream.id, assign,); + tracing::trace!(capacity = assign, "assigning"); // Assign the capacity to the stream - stream.assign_capacity(assign); + stream.assign_capacity(assign, self.max_buffer_size); // Claim the capacity from the connection - self.flow.claim_capacity(assign); + // TODO: proper error handling + let _res = self.flow.claim_capacity(assign); + debug_assert!(_res.is_ok()); } - log::trace!( - "try_assign_capacity(2); available={}; requested={}; buffered={}; has_unavailable={:?}", - stream.send_flow.available(), - stream.requested_send_capacity, - stream.buffered_send_data, - stream.send_flow.has_unavailable() + tracing::trace!( + available = %stream.send_flow.available(), + requested = stream.requested_send_capacity, + buffered = stream.buffered_send_data, + has_unavailable = %stream.send_flow.has_unavailable() ); - if stream.send_flow.available() < stream.requested_send_capacity + if stream.send_flow.available() < stream.requested_send_capacity as usize && stream.send_flow.has_unavailable() { // The stream requires additional capacity and the stream's @@ -485,14 +525,17 @@ impl Prioritize { // The max frame length let max_frame_len = dst.max_send_frame_size(); - log::trace!("poll_complete"); + tracing::trace!("poll_complete"); loop { - self.schedule_pending_open(store, counts); + if let Some(mut stream) = self.pop_pending_open(store, counts) { + self.pending_send.push_front(&mut stream); + self.try_assign_capacity(&mut stream); + } match self.pop_frame(buffer, store, max_frame_len, counts) { Some(frame) => { - log::trace!("writing frame={:?}", frame); + tracing::trace!(?frame, "writing"); debug_assert_eq!(self.in_flight_data_frame, InFlightData::Nothing); if let Frame::Data(ref frame) = frame { @@ -538,47 +581,62 @@ impl Prioritize { where B: Buf, { - log::trace!("try reclaim frame"); + let span = tracing::trace_span!("try_reclaim_frame"); + let _e = 
span.enter(); // First check if there are any data chunks to take back if let Some(frame) = dst.take_last_data_frame() { - log::trace!( - " -> reclaimed; frame={:?}; sz={}", - frame, - frame.payload().inner.get_ref().remaining() - ); - - let mut eos = false; - let key = frame.payload().stream; - - match mem::replace(&mut self.in_flight_data_frame, InFlightData::Nothing) { - InFlightData::Nothing => panic!("wasn't expecting a frame to reclaim"), - InFlightData::Drop => { - log::trace!("not reclaiming frame for cancelled stream"); - return false; - } - InFlightData::DataFrame(k) => { - debug_assert_eq!(k, key); - } - } + self.reclaim_frame_inner(buffer, store, frame) + } else { + false + } + } + + fn reclaim_frame_inner( + &mut self, + buffer: &mut Buffer>, + store: &mut Store, + frame: frame::Data>, + ) -> bool + where + B: Buf, + { + tracing::trace!( + ?frame, + sz = frame.payload().inner.get_ref().remaining(), + "reclaimed" + ); - let mut frame = frame.map(|prioritized| { - // TODO: Ensure fully written - eos = prioritized.end_of_stream; - prioritized.inner.into_inner() - }); + let mut eos = false; + let key = frame.payload().stream; - if frame.payload().has_remaining() { - let mut stream = store.resolve(key); + match mem::replace(&mut self.in_flight_data_frame, InFlightData::Nothing) { + InFlightData::Nothing => panic!("wasn't expecting a frame to reclaim"), + InFlightData::Drop => { + tracing::trace!("not reclaiming frame for cancelled stream"); + return false; + } + InFlightData::DataFrame(k) => { + debug_assert_eq!(k, key); + } + } - if eos { - frame.set_end_stream(true); - } + let mut frame = frame.map(|prioritized| { + // TODO: Ensure fully written + eos = prioritized.end_of_stream; + prioritized.inner.into_inner() + }); - self.push_back_frame(frame.into(), buffer, &mut stream); + if frame.payload().has_remaining() { + let mut stream = store.resolve(key); - return true; + if eos { + frame.set_end_stream(true); } + + self.push_back_frame(frame.into(), buffer, 
&mut stream); + + return true; } false @@ -603,11 +661,12 @@ impl Prioritize { } pub fn clear_queue(&mut self, buffer: &mut Buffer>, stream: &mut store::Ptr) { - log::trace!("clear_queue; stream={:?}", stream.id); + let span = tracing::trace_span!("clear_queue", ?stream.id); + let _e = span.enter(); // TODO: make this more efficient? while let Some(frame) = stream.pending_send.pop_front(buffer) { - log::trace!("dropping; frame={:?}", frame); + tracing::trace!(?frame, "dropping"); } stream.buffered_send_data = 0; @@ -644,16 +703,14 @@ impl Prioritize { where B: Buf, { - log::trace!("pop_frame"); + let span = tracing::trace_span!("pop_frame"); + let _e = span.enter(); loop { match self.pending_send.pop(store) { Some(mut stream) => { - log::trace!( - "pop_frame; stream={:?}; stream.state={:?}", - stream.id, - stream.state - ); + let span = tracing::trace_span!("popped", ?stream.id, ?stream.state); + let _e = span.enter(); // It's possible that this stream, besides having data to send, // is also queued to send a reset, and thus is already in the queue @@ -662,11 +719,7 @@ impl Prioritize { // To be safe, we just always ask the stream. 
let is_pending_reset = stream.is_pending_reset_expiration(); - log::trace!( - " --> stream={:?}; is_pending_reset={:?};", - stream.id, - is_pending_reset - ); + tracing::trace!(is_pending_reset); let frame = match stream.pending_send.pop_front(buffer) { Some(Frame::Data(mut frame)) => { @@ -675,25 +728,20 @@ impl Prioritize { let stream_capacity = stream.send_flow.available(); let sz = frame.payload().remaining(); - log::trace!( - " --> data frame; stream={:?}; sz={}; eos={:?}; window={}; \ - available={}; requested={}; buffered={};", - frame.stream_id(), + tracing::trace!( sz, - frame.is_end_stream(), - stream_capacity, - stream.send_flow.available(), - stream.requested_send_capacity, - stream.buffered_send_data, + eos = frame.is_end_stream(), + window = %stream_capacity, + available = %stream.send_flow.available(), + requested = stream.requested_send_capacity, + buffered = stream.buffered_send_data, + "data frame" ); // Zero length data frames always have capacity to // be sent. if sz > 0 && stream_capacity == 0 { - log::trace!( - " --> stream capacity is 0; requested={}", - stream.requested_send_capacity - ); + tracing::trace!("stream capacity is 0"); // Ensure that the stream is waiting for // connection level capacity @@ -721,34 +769,45 @@ impl Prioritize { // capacity at this point. debug_assert!(len <= self.flow.window_size()); - log::trace!(" --> sending data frame; len={}", len); - - // Update the flow control - log::trace!(" -- updating stream flow --"); - stream.send_flow.send_data(len); - - // Decrement the stream's buffered data counter - debug_assert!(stream.buffered_send_data >= len); - stream.buffered_send_data -= len; - stream.requested_send_capacity -= len; - - // Assign the capacity back to the connection that - // was just consumed from the stream in the previous - // line. 
- self.flow.assign_capacity(len); - - log::trace!(" -- updating connection flow --"); - self.flow.send_data(len); - - // Wrap the frame's data payload to ensure that the - // correct amount of data gets written. + // Check if the stream level window the peer knows is available. In some + // scenarios, maybe the window we know is available but the window which + // peer knows is not. + if len > 0 && len > stream.send_flow.window_size() { + stream.pending_send.push_front(buffer, frame.into()); + continue; + } - let eos = frame.is_end_stream(); - let len = len as usize; + tracing::trace!(len, "sending data frame"); - if frame.payload().remaining() > len { - frame.set_end_stream(false); - } + // Update the flow control + tracing::trace_span!("updating stream flow").in_scope(|| { + stream.send_data(len, self.max_buffer_size); + + // Assign the capacity back to the connection that + // was just consumed from the stream in the previous + // line. + // TODO: proper error handling + let _res = self.flow.assign_capacity(len); + debug_assert!(_res.is_ok()); + }); + + let (eos, len) = tracing::trace_span!("updating connection flow") + .in_scope(|| { + // TODO: proper error handling + let _res = self.flow.send_data(len); + debug_assert!(_res.is_ok()); + + // Wrap the frame's data payload to ensure that the + // correct amount of data gets written. 
+ + let eos = frame.is_end_stream(); + let len = len as usize; + + if frame.payload().remaining() > len { + frame.set_end_stream(false); + } + (eos, len) + }); Frame::Data(frame.map(|buf| Prioritized { inner: buf.take(len), @@ -780,7 +839,10 @@ impl Prioritize { }), None => { if let Some(reason) = stream.state.get_scheduled_reset() { - stream.state.set_reset(reason); + let stream_id = stream.id; + stream + .state + .set_reset(stream_id, reason, Initiator::Library); let frame = frame::Reset::new(stream.id, reason); Frame::Reset(frame) @@ -789,7 +851,7 @@ impl Prioritize { // had data buffered to be sent, but all the frames are cleared // in clear_queue(). Instead of doing O(N) traversal through queue // to remove, lets just ignore the stream here. - log::trace!("removing dangling stream from pending_send"); + tracing::trace!("removing dangling stream from pending_send"); // Since this should only happen as a consequence of `clear_queue`, // we must be in a closed state of some kind. debug_assert!(stream.state.is_closed()); @@ -799,7 +861,7 @@ impl Prioritize { } }; - log::trace!("pop_frame; frame={:?}", frame); + tracing::trace!("pop_frame; frame={:?}", frame); if cfg!(debug_assertions) && stream.state.is_idle() { debug_assert!(stream.id > self.last_opened_id); @@ -823,20 +885,24 @@ impl Prioritize { } } - fn schedule_pending_open(&mut self, store: &mut Store, counts: &mut Counts) { - log::trace!("schedule_pending_open"); + fn pop_pending_open<'s>( + &mut self, + store: &'s mut Store, + counts: &mut Counts, + ) -> Option> { + tracing::trace!("schedule_pending_open"); // check for any pending open streams - while counts.can_inc_num_send_streams() { + if counts.can_inc_num_send_streams() { if let Some(mut stream) = self.pending_open.pop(store) { - log::trace!("schedule_pending_open; stream={:?}", stream.id); + tracing::trace!("schedule_pending_open; stream={:?}", stream.id); counts.inc_num_send_streams(&mut stream); - self.pending_send.push(&mut stream); 
stream.notify_send(); - } else { - return; + return Some(stream); } } + + None } } @@ -850,8 +916,12 @@ where self.inner.remaining() } - fn bytes(&self) -> &[u8] { - self.inner.bytes() + fn chunk(&self) -> &[u8] { + self.inner.chunk() + } + + fn chunks_vectored<'a>(&'a self, dst: &mut [std::io::IoSlice<'a>]) -> usize { + self.inner.chunks_vectored(dst) } fn advance(&mut self, cnt: usize) { diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index f0e23a4ad..46cb87cd0 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -1,14 +1,14 @@ use super::*; -use crate::codec::{RecvError, UserError}; +use crate::codec::UserError; use crate::frame::{PushPromiseHeaderError, Reason, DEFAULT_INITIAL_WINDOW_SIZE}; -use crate::{frame, proto}; -use std::task::Context; +use crate::proto; use http::{HeaderMap, Request, Response}; +use std::cmp::Ordering; use std::io; -use std::task::{Poll, Waker}; -use std::time::{Duration, Instant}; +use std::task::{Context, Poll, Waker}; +use std::time::Instant; #[derive(Debug)] pub(super) struct Recv { @@ -54,8 +54,11 @@ pub(super) struct Recv { /// Refused StreamId, this represents a frame that must be sent out. refused: Option, - /// If push promises are allowed to be recevied. + /// If push promises are allowed to be received. is_push_enabled: bool, + + /// If extended connect protocol is enabled. 
+ is_extended_connect_protocol_enabled: bool, } #[derive(Debug)] @@ -68,7 +71,7 @@ pub(super) enum Event { #[derive(Debug)] pub(super) enum RecvHeaderBlockError { Oversize(T), - State(RecvError), + State(Error), } #[derive(Debug)] @@ -77,12 +80,6 @@ pub(crate) enum Open { Headers, } -#[derive(Debug, Clone, Copy)] -struct Indices { - head: store::Key, - tail: store::Key, -} - impl Recv { pub fn new(peer: peer::Dyn, config: &Config) -> Self { let next_stream_id = if peer.is_server() { 1 } else { 2 }; @@ -93,10 +90,10 @@ impl Recv { // settings flow.inc_window(DEFAULT_INITIAL_WINDOW_SIZE) .expect("invalid initial remote window size"); - flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE); + flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE).unwrap(); Recv { - init_window_sz: config.local_init_window_sz, + init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, flow, in_flight_data: 0 as WindowSize, next_stream_id: Ok(next_stream_id.into()), @@ -109,6 +106,7 @@ impl Recv { buffer: Buffer::new(), refused: None, is_push_enabled: config.local_push_enabled, + is_extended_connect_protocol_enabled: config.extended_connect_protocol_enabled, } } @@ -130,7 +128,7 @@ impl Recv { id: StreamId, mode: Open, counts: &mut Counts, - ) -> Result, RecvError> { + ) -> Result, Error> { assert!(self.refused.is_none()); counts.peer().ensure_can_open(id, mode)?; @@ -138,7 +136,7 @@ impl Recv { let next_id = self.next_stream_id()?; if id < next_id { proto_err!(conn: "id ({:?}) < next_id ({:?})", id, next_id); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } self.next_stream_id = id.next_id(); @@ -160,8 +158,8 @@ impl Recv { stream: &mut store::Ptr, counts: &mut Counts, ) -> Result<(), RecvHeaderBlockError>> { - log::trace!("opening stream; init_window={}", self.init_window_sz); - let is_initial = stream.state.recv_open(frame.is_end_stream())?; + tracing::trace!("opening stream; init_window={}", self.init_window_sz); + let 
is_initial = stream.state.recv_open(&frame)?; if is_initial { // TODO: be smarter about this logic @@ -180,13 +178,9 @@ impl Recv { if let Some(content_length) = frame.fields().get(header::CONTENT_LENGTH) { let content_length = match frame::parse_u64(content_length.as_bytes()) { Ok(v) => v, - Err(()) => { + Err(_) => { proto_err!(stream: "could not parse content-length; stream={:?}", stream.id); - return Err(RecvError::Stream { - id: stream.id, - reason: Reason::PROTOCOL_ERROR, - } - .into()); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); } }; @@ -206,7 +200,7 @@ impl Recv { // So, if peer is a server, we'll send a 431. In either case, // an error is recorded, which will send a REFUSED_STREAM, // since we don't want any of the data frames either. - log::debug!( + tracing::debug!( "stream error REQUEST_HEADER_FIELDS_TOO_LARGE -- \ recv_headers: frame is over size; stream={:?}", stream.id @@ -226,20 +220,38 @@ impl Recv { let stream_id = frame.stream_id(); let (pseudo, fields) = frame.into_parts(); - let message = counts - .peer() - .convert_poll_message(pseudo, fields, stream_id)?; - // Push the frame onto the stream's recv buffer - stream - .pending_recv - .push_back(&mut self.buffer, Event::Headers(message)); - stream.notify_recv(); + if pseudo.protocol.is_some() + && counts.peer().is_server() + && !self.is_extended_connect_protocol_enabled + { + proto_err!(stream: "cannot use :protocol if extended connect protocol is disabled; stream={:?}", stream.id); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); + } - // Only servers can receive a headers frame that initiates the stream. - // This is verified in `Streams` before calling this function. 
- if counts.peer().is_server() { - self.pending_accept.push(stream); + if pseudo.status.is_some() && counts.peer().is_server() { + proto_err!(stream: "cannot use :status header for requests; stream={:?}", stream.id); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); + } + + if !pseudo.is_informational() { + let message = counts + .peer() + .convert_poll_message(pseudo, fields, stream_id)?; + + // Push the frame onto the stream's recv buffer + stream + .pending_recv + .push_back(&mut self.buffer, Event::Headers(message)); + stream.notify_recv(); + + // Only servers can receive a headers frame that initiates the stream. + // This is verified in `Streams` before calling this function. + if counts.peer().is_server() { + // Correctness: never push a stream to `pending_accept` without having the + // corresponding headers frame pushed to `stream.pending_recv`. + self.pending_accept.push(stream); + } } Ok(()) @@ -247,13 +259,16 @@ impl Recv { /// Called by the server to get the request /// - /// TODO: Should this fn return `Result`? + /// # Panics + /// + /// Panics if `stream.pending_recv` has no `Event::Headers` queued. + /// pub fn take_request(&mut self, stream: &mut store::Ptr) -> Request<()> { use super::peer::PollMessage::*; match stream.pending_recv.pop_front(&mut self.buffer) { Some(Event::Headers(Server(request))) => request, - _ => panic!(), + _ => unreachable!("server stream queue must start with Headers"), } } @@ -303,7 +318,13 @@ impl Recv { Some(Event::Headers(Client(response))) => Poll::Ready(Ok(response)), Some(_) => panic!("poll_response called after response returned"), None => { - stream.state.ensure_recv_open()?; + if !stream.state.ensure_recv_open()? 
{ + proto_err!(stream: "poll_response: stream={:?} is not opened;", stream.id); + return Poll::Ready(Err(Error::library_reset( + stream.id, + Reason::PROTOCOL_ERROR, + ))); + } stream.recv_task = Some(cx.waker().clone()); Poll::Pending @@ -316,16 +337,13 @@ impl Recv { &mut self, frame: frame::Headers, stream: &mut store::Ptr, - ) -> Result<(), RecvError> { + ) -> Result<(), Error> { // Transition the state stream.state.recv_close()?; if stream.ensure_content_length_zero().is_err() { proto_err!(stream: "recv_trailers: content-length is not zero; stream={:?};", stream.id); - return Err(RecvError::Stream { - id: stream.id, - reason: Reason::PROTOCOL_ERROR, - }); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR)); } let trailers = frame.into_fields(); @@ -341,7 +359,7 @@ impl Recv { /// Releases capacity of the connection pub fn release_connection_capacity(&mut self, capacity: WindowSize, task: &mut Option) { - log::trace!( + tracing::trace!( "release_connection_capacity; size={}, connection in_flight_data={}", capacity, self.in_flight_data, @@ -351,7 +369,9 @@ impl Recv { self.in_flight_data -= capacity; // Assign capacity to connection - self.flow.assign_capacity(capacity); + // TODO: proper error handling + let _res = self.flow.assign_capacity(capacity); + debug_assert!(_res.is_ok()); if self.flow.unclaimed_capacity().is_some() { if let Some(task) = task.take() { @@ -367,7 +387,7 @@ impl Recv { stream: &mut store::Ptr, task: &mut Option, ) -> Result<(), UserError> { - log::trace!("release_capacity; size={}", capacity); + tracing::trace!("release_capacity; size={}", capacity); if capacity > stream.in_flight_recv_data { return Err(UserError::ReleaseCapacityTooBig); @@ -379,7 +399,9 @@ impl Recv { stream.in_flight_recv_data -= capacity; // Assign capacity to stream - stream.recv_flow.assign_capacity(capacity); + // TODO: proper error handling + let _res = stream.recv_flow.assign_capacity(capacity); + debug_assert!(_res.is_ok()); if 
stream.recv_flow.unclaimed_capacity().is_some() { // Queue the stream for sending the WINDOW_UPDATE frame. @@ -401,7 +423,7 @@ impl Recv { return; } - log::trace!( + tracing::trace!( "auto-release closed stream ({:?}) capacity: {:?}", stream.id, stream.in_flight_recv_data, @@ -425,8 +447,12 @@ impl Recv { /// /// The `task` is an optional parked task for the `Connection` that might /// be blocked on needing more window capacity. - pub fn set_target_connection_window(&mut self, target: WindowSize, task: &mut Option) { - log::trace!( + pub fn set_target_connection_window( + &mut self, + target: WindowSize, + task: &mut Option, + ) -> Result<(), Reason> { + tracing::trace!( "set_target_connection_window; target={}; available={}, reserved={}", target, self.flow.available(), @@ -438,11 +464,15 @@ impl Recv { // // Update the flow controller with the difference between the new // target and the current target. - let current = (self.flow.available() + self.in_flight_data).checked_size(); + let current = self + .flow + .available() + .add(self.in_flight_data)? 
+ .checked_size(); if target > current { - self.flow.assign_capacity(target - current); + self.flow.assign_capacity(target - current)?; } else { - self.flow.claim_capacity(current - target); + self.flow.claim_capacity(current - target)?; } // If changing the target capacity means we gained a bunch of capacity, @@ -453,67 +483,77 @@ impl Recv { task.wake(); } } + Ok(()) } pub(crate) fn apply_local_settings( &mut self, settings: &frame::Settings, store: &mut Store, - ) -> Result<(), RecvError> { - let target = if let Some(val) = settings.initial_window_size() { - val - } else { - return Ok(()); - }; + ) -> Result<(), proto::Error> { + if let Some(val) = settings.is_extended_connect_protocol_enabled() { + self.is_extended_connect_protocol_enabled = val; + } - let old_sz = self.init_window_sz; - self.init_window_sz = target; + if let Some(target) = settings.initial_window_size() { + let old_sz = self.init_window_sz; + self.init_window_sz = target; - log::trace!("update_initial_window_size; new={}; old={}", target, old_sz,); + tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,); - // Per RFC 7540 §6.9.2: - // - // In addition to changing the flow-control window for streams that are - // not yet active, a SETTINGS frame can alter the initial flow-control - // window size for streams with active flow-control windows (that is, - // streams in the "open" or "half-closed (remote)" state). When the - // value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust - // the size of all stream flow-control windows that it maintains by the - // difference between the new value and the old value. - // - // A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available - // space in a flow-control window to become negative. A sender MUST - // track the negative flow-control window and MUST NOT send new - // flow-controlled frames until it receives WINDOW_UPDATE frames that - // cause the flow-control window to become positive. 
- - if target < old_sz { - // We must decrease the (local) window on every open stream. - let dec = old_sz - target; - log::trace!("decrementing all windows; dec={}", dec); - - store.for_each(|mut stream| { - stream.recv_flow.dec_recv_window(dec); - Ok(()) - }) - } else if target > old_sz { - // We must increase the (local) window on every open stream. - let inc = target - old_sz; - log::trace!("incrementing all windows; inc={}", inc); - store.for_each(|mut stream| { - // XXX: Shouldn't the peer have already noticed our - // overflow and sent us a GOAWAY? - stream - .recv_flow - .inc_window(inc) - .map_err(RecvError::Connection)?; - stream.recv_flow.assign_capacity(inc); - Ok(()) - }) - } else { - // size is the same... so do nothing - Ok(()) + // Per RFC 7540 §6.9.2: + // + // In addition to changing the flow-control window for streams that are + // not yet active, a SETTINGS frame can alter the initial flow-control + // window size for streams with active flow-control windows (that is, + // streams in the "open" or "half-closed (remote)" state). When the + // value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust + // the size of all stream flow-control windows that it maintains by the + // difference between the new value and the old value. + // + // A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available + // space in a flow-control window to become negative. A sender MUST + // track the negative flow-control window and MUST NOT send new + // flow-controlled frames until it receives WINDOW_UPDATE frames that + // cause the flow-control window to become positive. + + match target.cmp(&old_sz) { + Ordering::Less => { + // We must decrease the (local) window on every open stream. 
+ let dec = old_sz - target; + tracing::trace!("decrementing all windows; dec={}", dec); + + store.try_for_each(|mut stream| { + stream + .recv_flow + .dec_recv_window(dec) + .map_err(proto::Error::library_go_away)?; + Ok::<_, proto::Error>(()) + })?; + } + Ordering::Greater => { + // We must increase the (local) window on every open stream. + let inc = target - old_sz; + tracing::trace!("incrementing all windows; inc={}", inc); + store.try_for_each(|mut stream| { + // XXX: Shouldn't the peer have already noticed our + // overflow and sent us a GOAWAY? + stream + .recv_flow + .inc_window(inc) + .map_err(proto::Error::library_go_away)?; + stream + .recv_flow + .assign_capacity(inc) + .map_err(proto::Error::library_go_away)?; + Ok::<_, proto::Error>(()) + })?; + } + Ordering::Equal => (), + } } + + Ok(()) } pub fn is_end_stream(&self, stream: &store::Ptr) -> bool { @@ -524,11 +564,7 @@ impl Recv { stream.pending_recv.is_empty() } - pub fn recv_data( - &mut self, - frame: frame::Data, - stream: &mut store::Ptr, - ) -> Result<(), RecvError> { + pub fn recv_data(&mut self, frame: frame::Data, stream: &mut store::Ptr) -> Result<(), Error> { let sz = frame.payload().len(); // This should have been enforced at the codec::FramedRead layer, so @@ -537,7 +573,7 @@ impl Recv { let sz = sz as WindowSize; - let is_ignoring_frame = stream.state.is_local_reset(); + let is_ignoring_frame = stream.state.is_local_error(); if !is_ignoring_frame && !stream.state.is_recv_streaming() { // TODO: There are cases where this can be a stream error of @@ -546,10 +582,10 @@ impl Recv { // Receiving a DATA frame when not expecting one is a protocol // error. 
proto_err!(conn: "unexpected DATA frame; stream={:?}", stream.id); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } - log::trace!( + tracing::trace!( "recv_data; size={}; connection={}; stream={}", sz, self.flow.window_size(), @@ -557,7 +593,7 @@ impl Recv { ); if is_ignoring_frame { - log::trace!( + tracing::trace!( "recv_data; frame ignored on locally reset {:?} for some time", stream.id, ); @@ -577,10 +613,7 @@ impl Recv { // So, for violating the **stream** window, we can send either a // stream or connection error. We've opted to send a stream // error. - return Err(RecvError::Stream { - id: stream.id, - reason: Reason::FLOW_CONTROL_ERROR, - }); + return Err(Error::library_reset(stream.id, Reason::FLOW_CONTROL_ERROR)); } if stream.dec_content_length(frame.payload().len()).is_err() { @@ -589,10 +622,7 @@ impl Recv { stream.id, frame.payload().len(), ); - return Err(RecvError::Stream { - id: stream.id, - reason: Reason::PROTOCOL_ERROR, - }); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR)); } if frame.is_end_stream() { @@ -602,20 +632,30 @@ impl Recv { stream.id, frame.payload().len(), ); - return Err(RecvError::Stream { - id: stream.id, - reason: Reason::PROTOCOL_ERROR, - }); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR)); } if stream.state.recv_close().is_err() { proto_err!(conn: "recv_data: failed to transition to closed state; stream={:?}", stream.id); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } } + // Received a frame, but no one cared about it. 
fix issue#648 + if !stream.is_recv { + tracing::trace!( + "recv_data; frame ignored on stream release {:?} for some time", + stream.id, + ); + self.release_connection_capacity(sz, &mut None); + return Ok(()); + } + // Update stream level flow control - stream.recv_flow.send_data(sz); + stream + .recv_flow + .send_data(sz) + .map_err(proto::Error::library_go_away)?; // Track the data as in-flight stream.in_flight_recv_data += sz; @@ -629,7 +669,7 @@ impl Recv { Ok(()) } - pub fn ignore_data(&mut self, sz: WindowSize) -> Result<(), RecvError> { + pub fn ignore_data(&mut self, sz: WindowSize) -> Result<(), Error> { // Ensure that there is enough capacity on the connection... self.consume_connection_window(sz)?; @@ -645,18 +685,18 @@ impl Recv { Ok(()) } - pub fn consume_connection_window(&mut self, sz: WindowSize) -> Result<(), RecvError> { + pub fn consume_connection_window(&mut self, sz: WindowSize) -> Result<(), Error> { if self.flow.window_size() < sz { - log::debug!( + tracing::debug!( "connection error FLOW_CONTROL_ERROR -- window_size ({:?}) < sz ({:?});", self.flow.window_size(), sz, ); - return Err(RecvError::Connection(Reason::FLOW_CONTROL_ERROR)); + return Err(Error::library_go_away(Reason::FLOW_CONTROL_ERROR)); } // Update connection level flow control - self.flow.send_data(sz); + self.flow.send_data(sz).map_err(Error::library_go_away)?; // Track the data as in-flight self.in_flight_data += sz; @@ -667,7 +707,7 @@ impl Recv { &mut self, frame: frame::PushPromise, stream: &mut store::Ptr, - ) -> Result<(), RecvError> { + ) -> Result<(), Error> { stream.state.reserve_remote()?; if frame.is_over_size() { // A frame is over size if the decoded header block was bigger than @@ -681,15 +721,15 @@ impl Recv { // So, if peer is a server, we'll send a 431. In either case, // an error is recorded, which will send a REFUSED_STREAM, // since we don't want any of the data frames either. 
- log::debug!( + tracing::debug!( "stream error REFUSED_STREAM -- recv_push_promise: \ headers frame is over size; promised_id={:?};", frame.promised_id(), ); - return Err(RecvError::Stream { - id: frame.promised_id(), - reason: Reason::REFUSED_STREAM, - }); + return Err(Error::library_reset( + frame.promised_id(), + Reason::REFUSED_STREAM, + )); } let promised_id = frame.promised_id(); @@ -712,10 +752,7 @@ impl Recv { promised_id, ), } - return Err(RecvError::Stream { - id: promised_id, - reason: Reason::PROTOCOL_ERROR, - }); + return Err(Error::library_reset(promised_id, Reason::PROTOCOL_ERROR)); } use super::peer::PollMessage::*; @@ -730,7 +767,7 @@ impl Recv { pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> { if let Ok(next) = self.next_stream_id { if id >= next { - log::debug!( + tracing::debug!( "stream ID implicitly closed, PROTOCOL_ERROR; stream={:?}", id ); @@ -743,20 +780,48 @@ impl Recv { } /// Handle remote sending an explicit RST_STREAM. - pub fn recv_reset(&mut self, frame: frame::Reset, stream: &mut Stream) { + pub fn recv_reset( + &mut self, + frame: frame::Reset, + stream: &mut Stream, + counts: &mut Counts, + ) -> Result<(), Error> { + // Reseting a stream that the user hasn't accepted is possible, + // but should be done with care. These streams will continue + // to take up memory in the accept queue, but will no longer be + // counted as "concurrent" streams. + // + // So, we have a separate limit for these. 
+ // + // See https://github.com/hyperium/hyper/issues/2877 + if stream.is_pending_accept { + if counts.can_inc_num_remote_reset_streams() { + counts.inc_num_remote_reset_streams(); + } else { + tracing::warn!( + "recv_reset; remotely-reset pending-accept streams reached limit ({:?})", + counts.max_remote_reset_streams(), + ); + return Err(Error::library_go_away_data( + Reason::ENHANCE_YOUR_CALM, + "too_many_resets", + )); + } + } + // Notify the stream - stream - .state - .recv_reset(frame.reason(), stream.is_pending_send); + stream.state.recv_reset(frame, stream.is_pending_send); stream.notify_send(); stream.notify_recv(); + + Ok(()) } - /// Handle a received error - pub fn recv_err(&mut self, err: &proto::Error, stream: &mut Stream) { + /// Handle a connection-level error + pub fn handle_error(&mut self, err: &proto::Error, stream: &mut Stream) { // Receive an error - stream.state.recv_err(err); + stream.state.handle_error(err); // If a receiver is waiting, notify it stream.notify_send(); @@ -775,7 +840,7 @@ impl Recv { } pub(super) fn clear_recv_buffer(&mut self, stream: &mut Stream) { - while let Some(_) = stream.pending_recv.pop_front(&mut self.buffer) { + while stream.pending_recv.pop_front(&mut self.buffer).is_some() { // drop it } } @@ -787,11 +852,11 @@ impl Recv { self.max_stream_id } - pub fn next_stream_id(&self) -> Result { + pub fn next_stream_id(&self) -> Result { if let Ok(id) = self.next_stream_id { Ok(id) } else { - Err(RecvError::Connection(Reason::PROTOCOL_ERROR)) + Err(Error::library_go_away(Reason::PROTOCOL_ERROR)) } } @@ -805,11 +870,21 @@ impl Recv { } } + pub(super) fn maybe_reset_next_stream_id(&mut self, id: StreamId) { + if let Ok(next_id) = self.next_stream_id { + // !Peer::is_local_init should have been called beforehand + debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated()); + if id >= next_id { + self.next_stream_id = id.next_id(); + } + } + } + /// Returns true if the remote peer can reserve a stream with the 
given ID. - pub fn ensure_can_reserve(&self) -> Result<(), RecvError> { + pub fn ensure_can_reserve(&self) -> Result<(), Error> { if !self.is_push_enabled { proto_err!(conn: "recv_push_promise: push is disabled"); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } Ok(()) @@ -817,20 +892,11 @@ impl Recv { /// Add a locally reset stream to queue to be eventually reaped. pub fn enqueue_reset_expiration(&mut self, stream: &mut store::Ptr, counts: &mut Counts) { - if !stream.state.is_local_reset() || stream.is_pending_reset_expiration() { + if !stream.state.is_local_error() || stream.is_pending_reset_expiration() { return; } - log::trace!("enqueue_reset_expiration; {:?}", stream.id); - - if !counts.can_inc_num_reset_streams() { - // try to evict 1 stream if possible - // if max allow is 0, this won't be able to evict, - // and then we'll just bail after - if let Some(evicted) = self.pending_reset_expired.pop(stream.store_mut()) { - counts.transition_after(evicted, true); - } - } + tracing::trace!("enqueue_reset_expiration; {:?}", stream.id); if counts.can_inc_num_reset_streams() { counts.inc_num_reset_streams(); @@ -864,13 +930,18 @@ impl Recv { } pub fn clear_expired_reset_streams(&mut self, store: &mut Store, counts: &mut Counts) { - let now = Instant::now(); - let reset_duration = self.reset_duration; - while let Some(stream) = self.pending_reset_expired.pop_if(store, |stream| { - let reset_at = stream.reset_at.expect("reset_at must be set if in queue"); - now - reset_at > reset_duration - }) { - counts.transition_after(stream, true); + if !self.pending_reset_expired.is_empty() { + let now = Instant::now(); + let reset_duration = self.reset_duration; + while let Some(stream) = self.pending_reset_expired.pop_if(store, |stream| { + let reset_at = stream.reset_at.expect("reset_at must be set if in queue"); + // rust-lang/rust#86470 tracks a bug in the standard library where `Instant` + // 
subtraction can panic (because, on some platforms, `Instant` isn't actually + // monotonic). We use a saturating operation to avoid this panic here. + now.saturating_duration_since(reset_at) > reset_duration + }) { + counts.transition_after(stream, true); + } } } @@ -891,7 +962,7 @@ impl Recv { fn clear_stream_window_update_queue(&mut self, store: &mut Store, counts: &mut Counts) { while let Some(stream) = self.pending_window_updates.pop(store) { counts.transition(stream, |_, stream| { - log::trace!("clear_stream_window_update_queue; stream={:?}", stream.id); + tracing::trace!("clear_stream_window_update_queue; stream={:?}", stream.id); }) } } @@ -981,7 +1052,7 @@ impl Recv { }; counts.transition(stream, |_, stream| { - log::trace!("pending_window_updates -- pop; stream={:?}", stream.id); + tracing::trace!("pending_window_updates -- pop; stream={:?}", stream.id); debug_assert!(!stream.is_pending_window_update); if !stream.state.is_recv_streaming() { @@ -1022,7 +1093,6 @@ impl Recv { cx: &Context, stream: &mut Stream, ) -> Poll>> { - // TODO: Return error when the stream is reset match stream.pending_recv.pop_front(&mut self.buffer) { Some(Event::Data(payload)) => Poll::Ready(Some(Ok(payload))), Some(event) => { @@ -1083,19 +1153,14 @@ impl Recv { impl Open { pub fn is_push_promise(&self) -> bool { - use self::Open::*; - - match *self { - PushPromise => true, - _ => false, - } + matches!(*self, Self::PushPromise) } } // ===== impl RecvHeaderBlockError ===== -impl From for RecvHeaderBlockError { - fn from(err: RecvError) -> Self { +impl From for RecvHeaderBlockError { + fn from(err: Error) -> Self { RecvHeaderBlockError::State(err) } } diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 4d38593ec..626e61a33 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -2,15 +2,16 @@ use super::{ store, Buffer, Codec, Config, Counts, Frame, Prioritize, Prioritized, Store, Stream, StreamId, StreamIdOverflow, WindowSize, }; -use 
crate::codec::{RecvError, UserError}; +use crate::codec::UserError; use crate::frame::{self, Reason}; +use crate::proto::{self, Error, Initiator}; use bytes::Buf; -use http; -use std::task::{Context, Poll, Waker}; use tokio::io::AsyncWrite; +use std::cmp::Ordering; use std::io; +use std::task::{Context, Poll, Waker}; /// Manages state transitions related to outbound frames. #[derive(Debug)] @@ -32,6 +33,11 @@ pub(super) struct Send { /// Prioritization layer prioritize: Prioritize, + + is_push_enabled: bool, + + /// If extended connect protocol is enabled. + is_extended_connect_protocol_enabled: bool, } /// A value to detect which public API has called `poll_reset`. @@ -49,6 +55,8 @@ impl Send { max_stream_id: StreamId::MAX, next_stream_id: Ok(config.local_next_stream_id), prioritize: Prioritize::new(config), + is_push_enabled: true, + is_extended_connect_protocol_enabled: false, } } @@ -77,11 +85,11 @@ impl Send { || fields.contains_key("keep-alive") || fields.contains_key("proxy-connection") { - log::debug!("illegal connection-specific headers found"); + tracing::debug!("illegal connection-specific headers found"); return Err(UserError::MalformedHeaders); } else if let Some(te) = fields.get(http::header::TE) { if te != "trailers" { - log::debug!("illegal connection-specific headers found"); + tracing::debug!("illegal connection-specific headers found"); return Err(UserError::MalformedHeaders); } } @@ -95,7 +103,11 @@ impl Send { stream: &mut store::Ptr, task: &mut Option, ) -> Result<(), UserError> { - log::trace!( + if !self.is_push_enabled { + return Err(UserError::PeerDisabledServerPush); + } + + tracing::trace!( "send_push_promise; frame={:?}; init_window={:?}", frame, self.init_window_sz @@ -118,7 +130,7 @@ impl Send { counts: &mut Counts, task: &mut Option, ) -> Result<(), UserError> { - log::trace!( + tracing::trace!( "send_headers; frame={:?}; init_window={:?}", frame, self.init_window_sz @@ -126,31 +138,32 @@ impl Send { 
Self::check_headers(frame.fields())?; - if frame.has_too_big_field() { - return Err(UserError::HeaderTooBig); - } - let end_stream = frame.is_end_stream(); // Update the state stream.state.send_open(end_stream)?; - if counts.peer().is_local_init(frame.stream_id()) { - // If we're waiting on a PushPromise anyway - // handle potentially queueing the stream at that point - if !stream.is_pending_push { - if counts.can_inc_num_send_streams() { - counts.inc_num_send_streams(stream); - } else { - self.prioritize.queue_open(stream); - } - } + let mut pending_open = false; + if counts.peer().is_local_init(frame.stream_id()) && !stream.is_pending_push { + self.prioritize.queue_open(stream); + pending_open = true; } // Queue the frame for sending + // + // This call expects that, since new streams are in the open queue, new + // streams won't be pushed on pending_send. self.prioritize .queue_frame(frame.into(), buffer, stream, task); + // Need to notify the connection when pushing onto pending_open since + // queue_frame only notifies for pending_send. 
+ if pending_open { + if let Some(task) = task.take() { + task.wake(); + } + } + Ok(()) } @@ -158,6 +171,7 @@ impl Send { pub fn send_reset( &mut self, reason: Reason, + initiator: Initiator, buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, @@ -166,14 +180,16 @@ impl Send { let is_reset = stream.state.is_reset(); let is_closed = stream.state.is_closed(); let is_empty = stream.pending_send.is_empty(); + let stream_id = stream.id; - log::trace!( - "send_reset(..., reason={:?}, stream={:?}, ..., \ + tracing::trace!( + "send_reset(..., reason={:?}, initiator={:?}, stream={:?}, ..., \ is_reset={:?}; is_closed={:?}; pending_send.is_empty={:?}; \ state={:?} \ ", reason, - stream.id, + initiator, + stream_id, is_reset, is_closed, is_empty, @@ -182,23 +198,23 @@ impl Send { if is_reset { // Don't double reset - log::trace!( + tracing::trace!( " -> not sending RST_STREAM ({:?} is already reset)", - stream.id + stream_id ); return; } // Transition the state to reset no matter what. - stream.state.set_reset(reason); + stream.state.set_reset(stream_id, reason, initiator); // If closed AND the send queue is flushed, then the stream cannot be // reset explicitly, either. Implicit resets can still be queued. 
if is_closed && is_empty { - log::trace!( + tracing::trace!( " -> not sending explicit RST_STREAM ({:?} was closed \ and send queue was flushed)", - stream.id + stream_id ); return; } @@ -211,7 +227,7 @@ impl Send { let frame = frame::Reset::new(stream.id, reason); - log::trace!("send_reset -- queueing; frame={:?}", frame); + tracing::trace!("send_reset -- queueing; frame={:?}", frame); self.prioritize .queue_frame(frame.into(), buffer, stream, task); self.prioritize.reclaim_all_capacity(stream, counts); @@ -263,13 +279,9 @@ impl Send { return Err(UserError::UnexpectedFrameType); } - if frame.has_too_big_field() { - return Err(UserError::HeaderTooBig); - } - stream.state.send_close(); - log::trace!("send_trailers -- queuing; frame={:?}", frame); + tracing::trace!("send_trailers -- queuing; frame={:?}", frame); self.prioritize .queue_frame(frame.into(), buffer, stream, task); @@ -326,14 +338,7 @@ impl Send { /// Current available stream send capacity pub fn capacity(&self, stream: &mut store::Ptr) -> WindowSize { - let available = stream.send_flow.available().as_size(); - let buffered = stream.buffered_send_data; - - if available <= buffered { - 0 - } else { - available - buffered - } + stream.capacity(self.prioritize.max_buffer_size()) } pub fn poll_reset( @@ -370,9 +375,16 @@ impl Send { task: &mut Option, ) -> Result<(), Reason> { if let Err(e) = self.prioritize.recv_stream_window_update(sz, stream) { - log::debug!("recv_stream_window_update !!; err={:?}", e); - - self.send_reset(Reason::FLOW_CONTROL_ERROR, buffer, stream, counts, task); + tracing::debug!("recv_stream_window_update !!; err={:?}", e); + + self.send_reset( + Reason::FLOW_CONTROL_ERROR, + Initiator::Library, + buffer, + stream, + counts, + task, + ); return Err(e); } @@ -380,7 +392,7 @@ impl Send { Ok(()) } - pub(super) fn recv_go_away(&mut self, last_stream_id: StreamId) -> Result<(), RecvError> { + pub(super) fn recv_go_away(&mut self, last_stream_id: StreamId) -> Result<(), Error> { if 
last_stream_id > self.max_stream_id { // The remote endpoint sent a `GOAWAY` frame indicating a stream // that we never sent, or that we have already terminated on account @@ -393,14 +405,14 @@ impl Send { "recv_go_away: last_stream_id ({:?}) > max_stream_id ({:?})", last_stream_id, self.max_stream_id, ); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } self.max_stream_id = last_stream_id; Ok(()) } - pub fn recv_err( + pub fn handle_error( &mut self, buffer: &mut Buffer>, stream: &mut store::Ptr, @@ -418,7 +430,11 @@ impl Send { store: &mut Store, counts: &mut Counts, task: &mut Option, - ) -> Result<(), RecvError> { + ) -> Result<(), Error> { + if let Some(val) = settings.is_extended_connect_protocol_enabled() { + self.is_extended_connect_protocol_enabled = val; + } + // Applies an update to the remote endpoint's initial window size. // // Per RFC 7540 §6.9.2: @@ -440,62 +456,84 @@ impl Send { let old_val = self.init_window_sz; self.init_window_sz = val; - if val < old_val { - // We must decrease the (remote) window on every open stream. - let dec = old_val - val; - log::trace!("decrementing all windows; dec={}", dec); - - let mut total_reclaimed = 0; - store.for_each(|mut stream| { - let stream = &mut *stream; - - stream.send_flow.dec_send_window(dec); - - // It's possible that decreasing the window causes - // `window_size` (the stream-specific window) to fall below - // `available` (the portion of the connection-level window - // that we have allocated to the stream). - // In this case, we should take that excess allocation away - // and reassign it to other streams. - let window_size = stream.send_flow.window_size(); - let available = stream.send_flow.available().as_size(); - let reclaimed = if available > window_size { - // Drop down to `window_size`. 
- let reclaim = available - window_size; - stream.send_flow.claim_capacity(reclaim); - total_reclaimed += reclaim; - reclaim - } else { - 0 - }; - - log::trace!( - "decremented stream window; id={:?}; decr={}; reclaimed={}; flow={:?}", - stream.id, - dec, - reclaimed, - stream.send_flow - ); - - // TODO: Should this notify the producer when the capacity - // of a stream is reduced? Maybe it should if the capacity - // is reduced to zero, allowing the producer to stop work. - - Ok::<_, RecvError>(()) - })?; - - self.prioritize - .assign_connection_capacity(total_reclaimed, store, counts); - } else if val > old_val { - let inc = val - old_val; - - store.for_each(|mut stream| { - self.recv_stream_window_update(inc, buffer, &mut stream, counts, task) - .map_err(RecvError::Connection) - })?; + match val.cmp(&old_val) { + Ordering::Less => { + // We must decrease the (remote) window on every open stream. + let dec = old_val - val; + tracing::trace!("decrementing all windows; dec={}", dec); + + let mut total_reclaimed = 0; + store.try_for_each(|mut stream| { + let stream = &mut *stream; + + tracing::trace!( + "decrementing stream window; id={:?}; decr={}; flow={:?}", + stream.id, + dec, + stream.send_flow + ); + + // TODO: this decrement can underflow based on received frames! + stream + .send_flow + .dec_send_window(dec) + .map_err(proto::Error::library_go_away)?; + + // It's possible that decreasing the window causes + // `window_size` (the stream-specific window) to fall below + // `available` (the portion of the connection-level window + // that we have allocated to the stream). + // In this case, we should take that excess allocation away + // and reassign it to other streams. + let window_size = stream.send_flow.window_size(); + let available = stream.send_flow.available().as_size(); + let reclaimed = if available > window_size { + // Drop down to `window_size`. 
+ let reclaim = available - window_size; + stream + .send_flow + .claim_capacity(reclaim) + .map_err(proto::Error::library_go_away)?; + total_reclaimed += reclaim; + reclaim + } else { + 0 + }; + + tracing::trace!( + "decremented stream window; id={:?}; decr={}; reclaimed={}; flow={:?}", + stream.id, + dec, + reclaimed, + stream.send_flow + ); + + // TODO: Should this notify the producer when the capacity + // of a stream is reduced? Maybe it should if the capacity + // is reduced to zero, allowing the producer to stop work. + + Ok::<_, proto::Error>(()) + })?; + + self.prioritize + .assign_connection_capacity(total_reclaimed, store, counts); + } + Ordering::Greater => { + let inc = val - old_val; + + store.try_for_each(|mut stream| { + self.recv_stream_window_update(inc, buffer, &mut stream, counts, task) + .map_err(Error::library_go_away) + })?; + } + Ordering::Equal => (), } } + if let Some(val) = settings.is_push_enabled() { + self.is_push_enabled = val + } + Ok(()) } @@ -530,4 +568,18 @@ impl Send { true } } + + pub(super) fn maybe_reset_next_stream_id(&mut self, id: StreamId) { + if let Ok(next_id) = self.next_stream_id { + // Peer::is_local_init should have been called beforehand + debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated()); + if id >= next_id { + self.next_stream_id = id.next_id(); + } + } + } + + pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { + self.is_extended_connect_protocol_enabled + } } diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 26323124d..5256f09cf 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -1,9 +1,8 @@ use std::io; -use crate::codec::UserError::*; -use crate::codec::{RecvError, UserError}; -use crate::frame::Reason; -use crate::proto::{self, PollReset}; +use crate::codec::UserError; +use crate::frame::{self, Reason, StreamId}; +use crate::proto::{self, Error, Initiator, PollReset}; use self::Inner::*; use self::Peer::*; @@ -53,7 +52,7 
@@ pub struct State { inner: Inner, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] enum Inner { Idle, // TODO: these states shouldn't count against concurrency limits: @@ -65,18 +64,17 @@ enum Inner { Closed(Cause), } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Default)] enum Peer { + #[default] AwaitingHeaders, Streaming, } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] enum Cause { EndStream, - Proto(Reason), - LocallyReset(Reason), - Io, + Error(Error), /// This indicates to the connection that a reset frame must be sent out /// once the send queue has been flushed. @@ -85,7 +83,7 @@ enum Cause { /// - User drops all references to a stream, so we want to CANCEL the it. /// - Header block size was too large, so we want to REFUSE, possibly /// after sending a 431 response frame. - Scheduled(Reason), + ScheduledLibraryReset(Reason), } impl State { @@ -123,7 +121,7 @@ impl State { } _ => { // All other transitions result in a protocol error - return Err(UnexpectedFrameType); + return Err(UserError::UnexpectedFrameType); } }; @@ -133,9 +131,9 @@ impl State { /// Opens the receive-half of the stream when a HEADERS frame is received. /// /// Returns true if this transitions the state to Open. 
- pub fn recv_open(&mut self, eos: bool) -> Result { - let remote = Streaming; + pub fn recv_open(&mut self, frame: &frame::Headers) -> Result { let mut initial = false; + let eos = frame.is_end_stream(); self.inner = match self.inner { Idle => { @@ -146,7 +144,12 @@ impl State { } else { Open { local: AwaitingHeaders, - remote, + remote: if frame.is_informational() { + tracing::trace!("skipping 1xx response headers"); + AwaitingHeaders + } else { + Streaming + }, } } } @@ -155,6 +158,9 @@ impl State { if eos { Closed(Cause::EndStream) + } else if frame.is_informational() { + tracing::trace!("skipping 1xx response headers"); + ReservedRemote } else { HalfClosedLocal(Streaming) } @@ -166,20 +172,31 @@ impl State { if eos { HalfClosedRemote(local) } else { - Open { local, remote } + Open { + local, + remote: if frame.is_informational() { + tracing::trace!("skipping 1xx response headers"); + AwaitingHeaders + } else { + Streaming + }, + } } } HalfClosedLocal(AwaitingHeaders) => { if eos { Closed(Cause::EndStream) + } else if frame.is_informational() { + tracing::trace!("skipping 1xx response headers"); + HalfClosedLocal(AwaitingHeaders) } else { - HalfClosedLocal(remote) + HalfClosedLocal(Streaming) } } - state => { + ref state => { // All other transitions result in a protocol error proto_err!(conn: "recv_open: in unexpected state {:?}", state); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } }; @@ -187,15 +204,15 @@ impl State { } /// Transition from Idle -> ReservedRemote - pub fn reserve_remote(&mut self) -> Result<(), RecvError> { + pub fn reserve_remote(&mut self) -> Result<(), Error> { match self.inner { Idle => { self.inner = ReservedRemote; Ok(()) } - state => { + ref state => { proto_err!(conn: "reserve_remote: in unexpected state {:?}", state); - Err(RecvError::Connection(Reason::PROTOCOL_ERROR)) + Err(Error::library_go_away(Reason::PROTOCOL_ERROR)) } } } @@ -212,22 +229,22 @@ impl 
State { } /// Indicates that the remote side will not send more data to the local. - pub fn recv_close(&mut self) -> Result<(), RecvError> { + pub fn recv_close(&mut self) -> Result<(), Error> { match self.inner { Open { local, .. } => { // The remote side will continue to receive data. - log::trace!("recv_close: Open => HalfClosedRemote({:?})", local); + tracing::trace!("recv_close: Open => HalfClosedRemote({:?})", local); self.inner = HalfClosedRemote(local); Ok(()) } HalfClosedLocal(..) => { - log::trace!("recv_close: HalfClosedLocal => Closed"); + tracing::trace!("recv_close: HalfClosedLocal => Closed"); self.inner = Closed(Cause::EndStream); Ok(()) } - state => { + ref state => { proto_err!(conn: "recv_close: in unexpected state {:?}", state); - Err(RecvError::Connection(Reason::PROTOCOL_ERROR)) + Err(Error::library_go_away(Reason::PROTOCOL_ERROR)) } } } @@ -235,9 +252,9 @@ impl State { /// The remote explicitly sent a RST_STREAM. /// /// # Arguments - /// - `reason`: the reason field of the received RST_STREAM frame. + /// - `frame`: the received RST_STREAM frame. /// - `queued`: true if this stream has frames in the pending send queue. - pub fn recv_reset(&mut self, reason: Reason, queued: bool) { + pub fn recv_reset(&mut self, frame: frame::Reset, queued: bool) { match self.inner { // If the stream is already in a `Closed` state, do nothing, // provided that there are no frames still in the send queue. @@ -256,30 +273,28 @@ impl State { // In either of these cases, we want to overwrite the stream's // previous state with the received RST_STREAM, so that the queue // will be cleared by `Prioritize::pop_frame`. 
- state => { - log::trace!( - "recv_reset; reason={:?}; state={:?}; queued={:?}", - reason, + ref state => { + tracing::trace!( + "recv_reset; frame={:?}; state={:?}; queued={:?}", + frame, state, queued ); - self.inner = Closed(Cause::Proto(reason)); + self.inner = Closed(Cause::Error(Error::remote_reset( + frame.stream_id(), + frame.reason(), + ))); } } } - /// We noticed a protocol error. - pub fn recv_err(&mut self, err: &proto::Error) { - use crate::proto::Error::*; - + /// Handle a connection-level error. + pub fn handle_error(&mut self, err: &proto::Error) { match self.inner { Closed(..) => {} _ => { - log::trace!("recv_err; err={:?}", err); - self.inner = Closed(match *err { - Proto(reason) => Cause::LocallyReset(reason), - Io(..) => Cause::Io, - }); + tracing::trace!("handle_error; err={:?}", err); + self.inner = Closed(Cause::Error(err.clone())); } } } @@ -287,9 +302,15 @@ impl State { pub fn recv_eof(&mut self) { match self.inner { Closed(..) => {} - s => { - log::trace!("recv_eof; state={:?}", s); - self.inner = Closed(Cause::Io); + ref state => { + tracing::trace!("recv_eof; state={:?}", state); + self.inner = Closed(Cause::Error( + io::Error::new( + io::ErrorKind::BrokenPipe, + "stream closed because of a broken pipe", + ) + .into(), + )); } } } @@ -299,50 +320,54 @@ impl State { match self.inner { Open { remote, .. } => { // The remote side will continue to receive data. - log::trace!("send_close: Open => HalfClosedLocal({:?})", remote); + tracing::trace!("send_close: Open => HalfClosedLocal({:?})", remote); self.inner = HalfClosedLocal(remote); } HalfClosedRemote(..) => { - log::trace!("send_close: HalfClosedRemote => Closed"); + tracing::trace!("send_close: HalfClosedRemote => Closed"); self.inner = Closed(Cause::EndStream); } - state => panic!("send_close: unexpected state {:?}", state), + ref state => panic!("send_close: unexpected state {:?}", state), } } /// Set the stream state to reset locally. 
- pub fn set_reset(&mut self, reason: Reason) { - self.inner = Closed(Cause::LocallyReset(reason)); + pub fn set_reset(&mut self, stream_id: StreamId, reason: Reason, initiator: Initiator) { + self.inner = Closed(Cause::Error(Error::Reset(stream_id, reason, initiator))); } /// Set the stream state to a scheduled reset. pub fn set_scheduled_reset(&mut self, reason: Reason) { debug_assert!(!self.is_closed()); - self.inner = Closed(Cause::Scheduled(reason)); + self.inner = Closed(Cause::ScheduledLibraryReset(reason)); } pub fn get_scheduled_reset(&self) -> Option { match self.inner { - Closed(Cause::Scheduled(reason)) => Some(reason), + Closed(Cause::ScheduledLibraryReset(reason)) => Some(reason), _ => None, } } pub fn is_scheduled_reset(&self) -> bool { - match self.inner { - Closed(Cause::Scheduled(..)) => true, - _ => false, - } + matches!(self.inner, Closed(Cause::ScheduledLibraryReset(..))) } - pub fn is_local_reset(&self) -> bool { + pub fn is_local_error(&self) -> bool { match self.inner { - Closed(Cause::LocallyReset(_)) => true, - Closed(Cause::Scheduled(..)) => true, + Closed(Cause::Error(ref e)) => e.is_local(), + Closed(Cause::ScheduledLibraryReset(..)) => true, _ => false, } } + pub fn is_remote_reset(&self) -> bool { + matches!( + self.inner, + Closed(Cause::Error(Error::Reset(_, _, Initiator::Remote))) + ) + } + /// Returns true if the stream is already reset. pub fn is_reset(&self) -> bool { match self.inner { @@ -353,74 +378,66 @@ impl State { } pub fn is_send_streaming(&self) -> bool { - match self.inner { + matches!( + self.inner, Open { - local: Streaming, .. - } => true, - HalfClosedRemote(Streaming) => true, - _ => false, - } + local: Streaming, + .. + } | HalfClosedRemote(Streaming) + ) } /// Returns true when the stream is in a state to receive headers pub fn is_recv_headers(&self) -> bool { - match self.inner { - Idle => true, - Open { + matches!( + self.inner, + Idle | Open { remote: AwaitingHeaders, .. 
- } => true, - HalfClosedLocal(AwaitingHeaders) => true, - ReservedRemote => true, - _ => false, - } + } | HalfClosedLocal(AwaitingHeaders) + | ReservedRemote + ) } pub fn is_recv_streaming(&self) -> bool { - match self.inner { + matches!( + self.inner, Open { - remote: Streaming, .. - } => true, - HalfClosedLocal(Streaming) => true, - _ => false, - } + remote: Streaming, + .. + } | HalfClosedLocal(Streaming) + ) } pub fn is_closed(&self) -> bool { - match self.inner { - Closed(_) => true, - _ => false, - } + matches!(self.inner, Closed(_)) } pub fn is_recv_closed(&self) -> bool { - match self.inner { - Closed(..) | HalfClosedRemote(..) | ReservedLocal => true, - _ => false, - } + matches!( + self.inner, + Closed(..) | HalfClosedRemote(..) | ReservedLocal + ) } pub fn is_send_closed(&self) -> bool { - match self.inner { - Closed(..) | HalfClosedLocal(..) | ReservedRemote => true, - _ => false, - } + matches!( + self.inner, + Closed(..) | HalfClosedLocal(..) | ReservedRemote + ) } pub fn is_idle(&self) -> bool { - match self.inner { - Idle => true, - _ => false, - } + matches!(self.inner, Idle) } pub fn ensure_recv_open(&self) -> Result { // TODO: Is this correct? match self.inner { - Closed(Cause::Proto(reason)) - | Closed(Cause::LocallyReset(reason)) - | Closed(Cause::Scheduled(reason)) => Err(proto::Error::Proto(reason)), - Closed(Cause::Io) => Err(proto::Error::Io(io::ErrorKind::BrokenPipe.into())), + Closed(Cause::Error(ref e)) => Err(e.clone()), + Closed(Cause::ScheduledLibraryReset(reason)) => { + Err(proto::Error::library_go_away(reason)) + } Closed(Cause::EndStream) | HalfClosedRemote(..) | ReservedLocal => Ok(false), _ => Ok(true), } @@ -429,10 +446,10 @@ impl State { /// Returns a reason if the stream has been reset. 
pub(super) fn ensure_reason(&self, mode: PollReset) -> Result, crate::Error> { match self.inner { - Closed(Cause::Proto(reason)) - | Closed(Cause::LocallyReset(reason)) - | Closed(Cause::Scheduled(reason)) => Ok(Some(reason)), - Closed(Cause::Io) => Err(proto::Error::Io(io::ErrorKind::BrokenPipe.into()).into()), + Closed(Cause::Error(Error::Reset(_, reason, _))) + | Closed(Cause::Error(Error::GoAway(_, reason, _))) + | Closed(Cause::ScheduledLibraryReset(reason)) => Ok(Some(reason)), + Closed(Cause::Error(ref e)) => Err(e.clone().into()), Open { local: Streaming, .. } @@ -450,9 +467,3 @@ impl Default for State { State { inner: Inner::Idle } } } - -impl Default for Peer { - fn default() -> Self { - AwaitingHeaders - } -} diff --git a/src/proto/streams/store.rs b/src/proto/streams/store.rs index ebb1cd712..35fd6f25e 100644 --- a/src/proto/streams/store.rs +++ b/src/proto/streams/store.rs @@ -1,9 +1,8 @@ use super::*; -use slab; - use indexmap::{self, IndexMap}; +use std::convert::Infallible; use std::fmt; use std::marker::PhantomData; use std::ops; @@ -128,7 +127,21 @@ impl Store { } } - pub fn for_each(&mut self, mut f: F) -> Result<(), E> + #[allow(clippy::blocks_in_conditions)] + pub(crate) fn for_each(&mut self, mut f: F) + where + F: FnMut(Ptr), + { + match self.try_for_each(|ptr| { + f(ptr); + Ok::<_, Infallible>(()) + }) { + Ok(()) => (), + Err(infallible) => match infallible {}, + } + } + + pub fn try_for_each(&mut self, mut f: F) -> Result<(), E> where F: FnMut(Ptr) -> Result<(), E>, { @@ -204,6 +217,12 @@ impl Store { } } +// While running h2 unit/integration tests, enable this debug assertion. +// +// In practice, we don't need to ensure this. But the integration tests +// help to make sure we've cleaned up in cases where we could (like, the +// runtime isn't suddenly dropping the task for unknown reasons). 
+#[cfg(feature = "unstable")] impl Drop for Store { fn drop(&mut self) { use std::thread; @@ -238,10 +257,10 @@ where /// /// If the stream is already contained by the list, return `false`. pub fn push(&mut self, stream: &mut store::Ptr) -> bool { - log::trace!("Queue::push"); + tracing::trace!("Queue::push_back"); if N::is_queued(stream) { - log::trace!(" -> already queued"); + tracing::trace!(" -> already queued"); return false; } @@ -253,7 +272,7 @@ where // Queue the stream match self.indices { Some(ref mut idxs) => { - log::trace!(" -> existing entries"); + tracing::trace!(" -> existing entries"); // Update the current tail node to point to `stream` let key = stream.key(); @@ -263,7 +282,47 @@ where idxs.tail = stream.key(); } None => { - log::trace!(" -> first entry"); + tracing::trace!(" -> first entry"); + self.indices = Some(store::Indices { + head: stream.key(), + tail: stream.key(), + }); + } + } + + true + } + + /// Queue the stream + /// + /// If the stream is already contained by the list, return `false`. 
+ pub fn push_front(&mut self, stream: &mut store::Ptr) -> bool { + tracing::trace!("Queue::push_front"); + + if N::is_queued(stream) { + tracing::trace!(" -> already queued"); + return false; + } + + N::set_queued(stream, true); + + // The next pointer shouldn't be set + debug_assert!(N::next(stream).is_none()); + + // Queue the stream + match self.indices { + Some(ref mut idxs) => { + tracing::trace!(" -> existing entries"); + + // Update the provided stream to point to the head node + let head_key = stream.resolve(idxs.head).key(); + N::set_next(stream, Some(head_key)); + + // Update the head pointer + idxs.head = stream.key(); + } + None => { + tracing::trace!(" -> first entry"); self.indices = Some(store::Indices { head: stream.key(), tail: stream.key(), @@ -282,15 +341,15 @@ where let mut stream = store.resolve(idxs.head); if idxs.head == idxs.tail { - assert!(N::next(&*stream).is_none()); + assert!(N::next(&stream).is_none()); self.indices = None; } else { - idxs.head = N::take_next(&mut *stream).unwrap(); + idxs.head = N::take_next(&mut stream).unwrap(); self.indices = Some(idxs); } - debug_assert!(N::is_queued(&*stream)); - N::set_queued(&mut *stream, false); + debug_assert!(N::is_queued(&stream)); + N::set_queued(&mut stream, false); return Some(stream); } @@ -298,6 +357,10 @@ where None } + pub fn is_empty(&self) -> bool { + self.indices.is_none() + } + pub fn pop_if<'a, R, F>(&mut self, store: &'a mut R, f: F) -> Option> where R: Resolve, @@ -323,7 +386,7 @@ impl<'a> Ptr<'a> { } pub fn store_mut(&mut self) -> &mut Store { - &mut self.store + self.store } /// Remove the stream from the store diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index 398672049..43e313647 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -45,7 +45,7 @@ pub(super) struct Stream { /// Amount of data buffered at the prioritization layer. /// TODO: Technically this could be greater than the window size... 
- pub buffered_send_data: WindowSize, + pub buffered_send_data: usize, /// Task tracking additional send capacity (i.e. window updates). send_task: Option, @@ -99,6 +99,9 @@ pub(super) struct Stream { /// Frames pending for this stream to read pub pending_recv: buffer::Deque, + /// When the RecvStream drop occurs, no data should be received. + pub is_recv: bool, + /// Task tracking receiving frames pub recv_task: Option, @@ -143,7 +146,9 @@ impl Stream { recv_flow .inc_window(init_recv_window) .expect("invalid initial receive window"); - recv_flow.assign_capacity(init_recv_window); + // TODO: proper error handling? + let _res = recv_flow.assign_capacity(init_recv_window); + debug_assert!(_res.is_ok()); send_flow .inc_window(init_send_window) @@ -180,6 +185,7 @@ impl Stream { reset_at: None, next_reset_expire: None, pending_recv: buffer::Deque::new(), + is_recv: true, recv_task: None, pending_push_promises: store::Queue::new(), content_length: ContentLength::Omitted, @@ -248,7 +254,7 @@ impl Stream { // The stream is not in any queue !self.is_pending_send && !self.is_pending_send_capacity && !self.is_pending_accept && !self.is_pending_window_update && - !self.is_pending_open && !self.reset_at.is_some() + !self.is_pending_open && self.reset_at.is_none() } /// Returns true when the consumer of the stream has dropped all handles @@ -260,25 +266,69 @@ impl Stream { self.ref_count == 0 && !self.state.is_closed() } - pub fn assign_capacity(&mut self, capacity: WindowSize) { + /// Current available stream send capacity + pub fn capacity(&self, max_buffer_size: usize) -> WindowSize { + let available = self.send_flow.available().as_size() as usize; + let buffered = self.buffered_send_data; + + available.min(max_buffer_size).saturating_sub(buffered) as WindowSize + } + + pub fn assign_capacity(&mut self, capacity: WindowSize, max_buffer_size: usize) { + let prev_capacity = self.capacity(max_buffer_size); debug_assert!(capacity > 0); - self.send_capacity_inc = true; - 
self.send_flow.assign_capacity(capacity); + // TODO: proper error handling + let _res = self.send_flow.assign_capacity(capacity); + debug_assert!(_res.is_ok()); - log::trace!( - " assigned capacity to stream; available={}; buffered={}; id={:?}", + tracing::trace!( + " assigned capacity to stream; available={}; buffered={}; id={:?}; max_buffer_size={} prev={}", self.send_flow.available(), self.buffered_send_data, - self.id + self.id, + max_buffer_size, + prev_capacity, ); - // Only notify if the capacity exceeds the amount of buffered data - if self.send_flow.available() > self.buffered_send_data { - log::trace!(" notifying task"); - self.notify_send(); + if prev_capacity < self.capacity(max_buffer_size) { + self.notify_capacity(); } } + pub fn send_data(&mut self, len: WindowSize, max_buffer_size: usize) { + let prev_capacity = self.capacity(max_buffer_size); + + // TODO: proper error handling + let _res = self.send_flow.send_data(len); + debug_assert!(_res.is_ok()); + + // Decrement the stream's buffered data counter + debug_assert!(self.buffered_send_data >= len as usize); + self.buffered_send_data -= len as usize; + self.requested_send_capacity -= len; + + tracing::trace!( + " sent stream data; available={}; buffered={}; id={:?}; max_buffer_size={} prev={}", + self.send_flow.available(), + self.buffered_send_data, + self.id, + max_buffer_size, + prev_capacity, + ); + + if prev_capacity < self.capacity(max_buffer_size) { + self.notify_capacity(); + } + } + + /// If the capacity was limited because of the max_send_buffer_size, + /// then consider waking the send task again... + pub fn notify_capacity(&mut self) { + self.send_capacity_inc = true; + tracing::trace!(" notifying task"); + self.notify_send(); + } + /// Returns `Err` when the decrement cannot be completed due to overflow. 
pub fn dec_content_length(&mut self, len: usize) -> Result<(), ()> { match self.content_length { @@ -286,7 +336,11 @@ impl Stream { Some(val) => *rem = val, None => return Err(()), }, - ContentLength::Head => return Err(()), + ContentLength::Head => { + if len != 0 { + return Err(()); + } + } _ => {} } @@ -361,7 +415,7 @@ impl store::Next for NextSend { if val { // ensure that stream is not queued for being opened // if it's being put into queue for sending data - debug_assert_eq!(stream.is_pending_open, false); + debug_assert!(!stream.is_pending_open); } stream.is_pending_send = val; } @@ -432,7 +486,7 @@ impl store::Next for NextOpen { if val { // ensure that stream is not queued for being sent // if it's being put into queue for opening the stream - debug_assert_eq!(stream.is_pending_send, false); + debug_assert!(!stream.is_pending_send); } stream.is_pending_open = val; } @@ -468,9 +522,6 @@ impl store::Next for NextResetExpire { impl ContentLength { pub fn is_head(&self) -> bool { - match *self { - ContentLength::Head => true, - _ => false, - } + matches!(*self, Self::Head) } } diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 8f6186194..fa8e6843b 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1,9 +1,10 @@ use super::recv::RecvHeaderBlockError; use super::store::{self, Entry, Resolve, Store}; use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; -use crate::codec::{Codec, RecvError, SendError, UserError}; +use crate::codec::{Codec, SendError, UserError}; +use crate::ext::Protocol; use crate::frame::{self, Frame, Reason}; -use crate::proto::{peer, Open, Peer, WindowSize}; +use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize}; use crate::{client, proto, server}; use bytes::{Buf, Bytes}; @@ -11,7 +12,6 @@ use http::{HeaderMap, Request, Response}; use std::task::{Context, Poll, Waker}; use tokio::io::AsyncWrite; -use crate::PollExt; use std::sync::{Arc, Mutex}; 
use std::{fmt, io}; @@ -21,7 +21,7 @@ where P: Peer, { /// Holds most of the connection and stream related state for processing - /// HTTP/2.0 frames associated with streams. + /// HTTP/2 frames associated with streams. inner: Arc>, /// This is the queue of frames to be written to the wire. This is split out @@ -37,6 +37,17 @@ where _p: ::std::marker::PhantomData

, } +// Like `Streams` but with a `peer::Dyn` field instead of a static `P: Peer` type parameter. +// Ensures that the methods only get one instantiation, instead of two (client and server) +#[derive(Debug)] +pub(crate) struct DynStreams<'a, B> { + inner: &'a Mutex, + + send_buffer: &'a SendBuffer, + + peer: peer::Dyn, +} + /// Reference to the stream state #[derive(Debug)] pub(crate) struct StreamRef { @@ -101,23 +112,13 @@ where let peer = P::r#dyn(); Streams { - inner: Arc::new(Mutex::new(Inner { - counts: Counts::new(peer, &config), - actions: Actions { - recv: Recv::new(peer, &config), - send: Send::new(&config), - task: None, - conn_error: None, - }, - store: Store::new(), - refs: 1, - })), + inner: Inner::new(peer, config), send_buffer: Arc::new(SendBuffer::new()), _p: ::std::marker::PhantomData, } } - pub fn set_target_connection_window_size(&mut self, size: WindowSize) { + pub fn set_target_connection_window_size(&mut self, size: WindowSize) -> Result<(), Reason> { let mut me = self.inner.lock().unwrap(); let me = &mut *me; @@ -126,24 +127,303 @@ where .set_target_connection_window(size, &mut me.actions.task) } - /// Process inbound headers - pub fn recv_headers(&mut self, frame: frame::Headers) -> Result<(), RecvError> { - let id = frame.stream_id(); + pub fn next_incoming(&mut self) -> Option> { + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + me.actions.recv.next_incoming(&mut me.store).map(|key| { + let stream = &mut me.store.resolve(key); + tracing::trace!( + "next_incoming; id={:?}, state={:?}", + stream.id, + stream.state + ); + // TODO: ideally, OpaqueStreamRefs::new would do this, but we're holding + // the lock, so it can't. + me.refs += 1; + + // Pending-accepted remotely-reset streams are counted. 
+ if stream.state.is_remote_reset() { + me.counts.dec_num_remote_reset_streams(); + } + + StreamRef { + opaque: OpaqueStreamRef::new(self.inner.clone(), stream), + send_buffer: self.send_buffer.clone(), + } + }) + } + + pub fn send_pending_refusal( + &mut self, + cx: &mut Context, + dst: &mut Codec>, + ) -> Poll> + where + T: AsyncWrite + Unpin, + { + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + me.actions.recv.send_pending_refusal(cx, dst) + } + + pub fn clear_expired_reset_streams(&mut self) { + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + me.actions + .recv + .clear_expired_reset_streams(&mut me.store, &mut me.counts); + } + + pub fn poll_complete( + &mut self, + cx: &mut Context, + dst: &mut Codec>, + ) -> Poll> + where + T: AsyncWrite + Unpin, + { + let mut me = self.inner.lock().unwrap(); + me.poll_complete(&self.send_buffer, cx, dst) + } + + pub fn apply_remote_settings( + &mut self, + frame: &frame::Settings, + is_initial: bool, + ) -> Result<(), Error> { let mut me = self.inner.lock().unwrap(); let me = &mut *me; + let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let send_buffer = &mut *send_buffer; + + me.counts.apply_remote_settings(frame, is_initial); + + me.actions.send.apply_remote_settings( + frame, + send_buffer, + &mut me.store, + &mut me.counts, + &mut me.actions.task, + ) + } + + pub fn apply_local_settings(&mut self, frame: &frame::Settings) -> Result<(), Error> { + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + + me.actions.recv.apply_local_settings(frame, &mut me.store) + } + + pub fn send_request( + &mut self, + mut request: Request<()>, + end_of_stream: bool, + pending: Option<&OpaqueStreamRef>, + ) -> Result<(StreamRef, bool), SendError> { + use super::stream::ContentLength; + use http::Method; + + let protocol = request.extensions_mut().remove::(); + + // Clear before taking lock, incase extensions contain a StreamRef. 
+ request.extensions_mut().clear(); + + // TODO: There is a hazard with assigning a stream ID before the + // prioritize layer. If prioritization reorders new streams, this + // implicitly closes the earlier stream IDs. + // + // See: hyperium/h2#11 + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + + let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let send_buffer = &mut *send_buffer; + + me.actions.ensure_no_conn_error()?; + me.actions.send.ensure_next_stream_id()?; + + // The `pending` argument is provided by the `Client`, and holds + // a store `Key` of a `Stream` that may have been not been opened + // yet. + // + // If that stream is still pending, the Client isn't allowed to + // queue up another pending stream. They should use `poll_ready`. + if let Some(stream) = pending { + if me.store.resolve(stream.key).is_pending_open { + return Err(UserError::Rejected.into()); + } + } + + if me.counts.peer().is_server() { + // Servers cannot open streams. PushPromise must first be reserved. + return Err(UserError::UnexpectedFrameType.into()); + } + + let stream_id = me.actions.send.open()?; + + let mut stream = Stream::new( + stream_id, + me.actions.send.init_window_sz(), + me.actions.recv.init_window_sz(), + ); + + if *request.method() == Method::HEAD { + stream.content_length = ContentLength::Head; + } + + // Convert the message + let headers = + client::Peer::convert_send_message(stream_id, request, protocol, end_of_stream)?; + + let mut stream = me.store.insert(stream.id, stream); + + let sent = me.actions.send.send_headers( + headers, + send_buffer, + &mut stream, + &mut me.counts, + &mut me.actions.task, + ); + + // send_headers can return a UserError, if it does, + // we should forget about this stream. + if let Err(err) = sent { + stream.unlink(); + stream.remove(); + return Err(err.into()); + } + + // Given that the stream has been initialized, it should not be in the + // closed state. 
+ debug_assert!(!stream.state.is_closed()); + + // TODO: ideally, OpaqueStreamRefs::new would do this, but we're holding + // the lock, so it can't. + me.refs += 1; + + let is_full = me.counts.next_send_stream_will_reach_capacity(); + Ok(( + StreamRef { + opaque: OpaqueStreamRef::new(self.inner.clone(), &mut stream), + send_buffer: self.send_buffer.clone(), + }, + is_full, + )) + } + + pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { + self.inner + .lock() + .unwrap() + .actions + .send + .is_extended_connect_protocol_enabled() + } +} + +impl DynStreams<'_, B> { + pub fn is_buffer_empty(&self) -> bool { + self.send_buffer.is_empty() + } + + pub fn is_server(&self) -> bool { + self.peer.is_server() + } + + pub fn recv_headers(&mut self, frame: frame::Headers) -> Result<(), Error> { + let mut me = self.inner.lock().unwrap(); + + me.recv_headers(self.peer, self.send_buffer, frame) + } + + pub fn recv_data(&mut self, frame: frame::Data) -> Result<(), Error> { + let mut me = self.inner.lock().unwrap(); + me.recv_data(self.peer, self.send_buffer, frame) + } + + pub fn recv_reset(&mut self, frame: frame::Reset) -> Result<(), Error> { + let mut me = self.inner.lock().unwrap(); + + me.recv_reset(self.send_buffer, frame) + } + + /// Notify all streams that a connection-level error happened. 
+ pub fn handle_error(&mut self, err: proto::Error) -> StreamId { + let mut me = self.inner.lock().unwrap(); + me.handle_error(self.send_buffer, err) + } + + pub fn recv_go_away(&mut self, frame: &frame::GoAway) -> Result<(), Error> { + let mut me = self.inner.lock().unwrap(); + me.recv_go_away(self.send_buffer, frame) + } + + pub fn last_processed_id(&self) -> StreamId { + self.inner.lock().unwrap().actions.recv.last_processed_id() + } + + pub fn recv_window_update(&mut self, frame: frame::WindowUpdate) -> Result<(), Error> { + let mut me = self.inner.lock().unwrap(); + me.recv_window_update(self.send_buffer, frame) + } + + pub fn recv_push_promise(&mut self, frame: frame::PushPromise) -> Result<(), Error> { + let mut me = self.inner.lock().unwrap(); + me.recv_push_promise(self.send_buffer, frame) + } + + pub fn recv_eof(&mut self, clear_pending_accept: bool) -> Result<(), ()> { + let mut me = self.inner.lock().map_err(|_| ())?; + me.recv_eof(self.send_buffer, clear_pending_accept) + } + + pub fn send_reset(&mut self, id: StreamId, reason: Reason) { + let mut me = self.inner.lock().unwrap(); + me.send_reset(self.send_buffer, id, reason) + } + + pub fn send_go_away(&mut self, last_processed_id: StreamId) { + let mut me = self.inner.lock().unwrap(); + me.actions.recv.go_away(last_processed_id); + } +} + +impl Inner { + fn new(peer: peer::Dyn, config: Config) -> Arc> { + Arc::new(Mutex::new(Inner { + counts: Counts::new(peer, &config), + actions: Actions { + recv: Recv::new(peer, &config), + send: Send::new(&config), + task: None, + conn_error: None, + }, + store: Store::new(), + refs: 1, + })) + } + + fn recv_headers( + &mut self, + peer: peer::Dyn, + send_buffer: &SendBuffer, + frame: frame::Headers, + ) -> Result<(), Error> { + let id = frame.stream_id(); + // The GOAWAY process has begun. All streams with a greater ID than // specified as part of GOAWAY should be ignored. 
- if id > me.actions.recv.max_stream_id() { - log::trace!( + if id > self.actions.recv.max_stream_id() { + tracing::trace!( "id ({:?}) > max_stream_id ({:?}), ignoring HEADERS", id, - me.actions.recv.max_stream_id() + self.actions.recv.max_stream_id() ); return Ok(()); } - let key = match me.store.find_entry(id) { + let key = match self.store.find_entry(id) { Entry::Occupied(e) => e.key(), Entry::Vacant(e) => { // Client: it's possible to send a request, and then send @@ -151,27 +431,28 @@ where // // Server: we can't reset a stream before having received // the request headers, so don't allow. - if !P::is_server() { + if !peer.is_server() { // This may be response headers for a stream we've already // forgotten about... - if me.actions.may_have_forgotten_stream::

(id) { - log::debug!( + if self.actions.may_have_forgotten_stream(peer, id) { + tracing::debug!( "recv_headers for old stream={:?}, sending STREAM_CLOSED", id, ); - return Err(RecvError::Stream { - id, - reason: Reason::STREAM_CLOSED, - }); + return Err(Error::library_reset(id, Reason::STREAM_CLOSED)); } } - match me.actions.recv.open(id, Open::Headers, &mut me.counts)? { + match self + .actions + .recv + .open(id, Open::Headers, &mut self.counts)? + { Some(stream_id) => { let stream = Stream::new( stream_id, - me.actions.send.init_window_sz(), - me.actions.recv.init_window_sz(), + self.actions.send.init_window_sz(), + self.actions.recv.init_window_sz(), ); e.insert(stream) @@ -181,22 +462,22 @@ where } }; - let stream = me.store.resolve(key); + let stream = self.store.resolve(key); - if stream.state.is_local_reset() { + if stream.state.is_local_error() { // Locally reset streams must ignore frames "for some time". // This is because the remote may have sent trailers before // receiving the RST_STREAM frame. - log::trace!("recv_headers; ignoring trailers on {:?}", stream.id); + tracing::trace!("recv_headers; ignoring trailers on {:?}", stream.id); return Ok(()); } - let actions = &mut me.actions; - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let actions = &mut self.actions; + let mut send_buffer = send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - me.counts.transition(stream, |counts, stream| { - log::trace!( + self.counts.transition(stream, |counts, stream| { + tracing::trace!( "recv_headers; stream={:?}; state={:?}", stream.id, stream.state @@ -221,10 +502,7 @@ where Ok(()) } else { - Err(RecvError::Stream { - id: stream.id, - reason: Reason::REFUSED_STREAM, - }) + Err(Error::library_reset(stream.id, Reason::REFUSED_STREAM)) } }, Err(RecvHeaderBlockError::State(err)) => Err(err), @@ -234,10 +512,7 @@ where // Receiving trailers that don't set EOS is a "malformed" // message. Malformed messages are a stream error. 
proto_err!(stream: "recv_headers: trailers frame was not EOS; stream={:?}", stream.id); - return Err(RecvError::Stream { - id: stream.id, - reason: Reason::PROTOCOL_ERROR, - }); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR)); } actions.recv.recv_trailers(frame, stream) @@ -247,28 +522,30 @@ where }) } - pub fn recv_data(&mut self, frame: frame::Data) -> Result<(), RecvError> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - + fn recv_data( + &mut self, + peer: peer::Dyn, + send_buffer: &SendBuffer, + frame: frame::Data, + ) -> Result<(), Error> { let id = frame.stream_id(); - let stream = match me.store.find_mut(&id) { + let stream = match self.store.find_mut(&id) { Some(stream) => stream, None => { // The GOAWAY process has begun. All streams with a greater ID // than specified as part of GOAWAY should be ignored. - if id > me.actions.recv.max_stream_id() { - log::trace!( + if id > self.actions.recv.max_stream_id() { + tracing::trace!( "id ({:?}) > max_stream_id ({:?}), ignoring DATA", id, - me.actions.recv.max_stream_id() + self.actions.recv.max_stream_id() ); return Ok(()); } - if me.actions.may_have_forgotten_stream::

(id) { - log::debug!("recv_data for old stream={:?}, sending STREAM_CLOSED", id,); + if self.actions.may_have_forgotten_stream(peer, id) { + tracing::debug!("recv_data for old stream={:?}, sending STREAM_CLOSED", id,); let sz = frame.payload().len(); // This should have been enforced at the codec::FramedRead layer, so @@ -276,30 +553,27 @@ where assert!(sz <= super::MAX_WINDOW_SIZE as usize); let sz = sz as WindowSize; - me.actions.recv.ignore_data(sz)?; - return Err(RecvError::Stream { - id, - reason: Reason::STREAM_CLOSED, - }); + self.actions.recv.ignore_data(sz)?; + return Err(Error::library_reset(id, Reason::STREAM_CLOSED)); } proto_err!(conn: "recv_data: stream not found; id={:?}", id); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } }; - let actions = &mut me.actions; - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let actions = &mut self.actions; + let mut send_buffer = send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - me.counts.transition(stream, |counts, stream| { + self.counts.transition(stream, |counts, stream| { let sz = frame.payload().len(); let res = actions.recv.recv_data(frame, stream); // Any stream error after receiving a DATA frame means // we won't give the data to the user, and so they can't // release the capacity. We do it automatically. - if let Err(RecvError::Stream { .. 
}) = res { + if let Err(Error::Reset(..)) = res { actions .recv .release_connection_capacity(sz as WindowSize, &mut None); @@ -308,183 +582,176 @@ where }) } - pub fn recv_reset(&mut self, frame: frame::Reset) -> Result<(), RecvError> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - + fn recv_reset( + &mut self, + send_buffer: &SendBuffer, + frame: frame::Reset, + ) -> Result<(), Error> { let id = frame.stream_id(); if id.is_zero() { proto_err!(conn: "recv_reset: invalid stream ID 0"); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } // The GOAWAY process has begun. All streams with a greater ID than // specified as part of GOAWAY should be ignored. - if id > me.actions.recv.max_stream_id() { - log::trace!( + if id > self.actions.recv.max_stream_id() { + tracing::trace!( "id ({:?}) > max_stream_id ({:?}), ignoring RST_STREAM", id, - me.actions.recv.max_stream_id() + self.actions.recv.max_stream_id() ); return Ok(()); } - let stream = match me.store.find_mut(&id) { + let stream = match self.store.find_mut(&id) { Some(stream) => stream, None => { // TODO: Are there other error cases? 
- me.actions - .ensure_not_idle(me.counts.peer(), id) - .map_err(RecvError::Connection)?; + self.actions + .ensure_not_idle(self.counts.peer(), id) + .map_err(Error::library_go_away)?; return Ok(()); } }; - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let mut send_buffer = send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - let actions = &mut me.actions; + let actions = &mut self.actions; - me.counts.transition(stream, |counts, stream| { - actions.recv.recv_reset(frame, stream); - actions.send.recv_err(send_buffer, stream, counts); + self.counts.transition(stream, |counts, stream| { + actions.recv.recv_reset(frame, stream, counts)?; + actions.send.handle_error(send_buffer, stream, counts); assert!(stream.state.is_closed()); Ok(()) }) } - /// Handle a received error and return the ID of the last processed stream. - pub fn recv_err(&mut self, err: &proto::Error) -> StreamId { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; + fn recv_window_update( + &mut self, + send_buffer: &SendBuffer, + frame: frame::WindowUpdate, + ) -> Result<(), Error> { + let id = frame.stream_id(); - let actions = &mut me.actions; - let counts = &mut me.counts; - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let mut send_buffer = send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - let last_processed_id = actions.recv.last_processed_id(); - - me.store - .for_each(|stream| { - counts.transition(stream, |counts, stream| { - actions.recv.recv_err(err, &mut *stream); - actions.send.recv_err(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) - }) - .unwrap(); - - actions.conn_error = Some(err.shallow_clone()); + if id.is_zero() { + self.actions + .send + .recv_connection_window_update(frame, &mut self.store, &mut self.counts) + .map_err(Error::library_go_away)?; + } else { + // The remote may send window updates for streams that the local now + // considers closed. It's ok... 
+ if let Some(mut stream) = self.store.find_mut(&id) { + // This result is ignored as there is nothing to do when there + // is an error. The stream is reset by the function on error and + // the error is informational. + let _ = self.actions.send.recv_stream_window_update( + frame.size_increment(), + send_buffer, + &mut stream, + &mut self.counts, + &mut self.actions.task, + ); + } else { + self.actions + .ensure_not_idle(self.counts.peer(), id) + .map_err(Error::library_go_away)?; + } + } - last_processed_id + Ok(()) } - pub fn recv_go_away(&mut self, frame: &frame::GoAway) -> Result<(), RecvError> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - - let actions = &mut me.actions; - let counts = &mut me.counts; - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + fn handle_error(&mut self, send_buffer: &SendBuffer, err: proto::Error) -> StreamId { + let actions = &mut self.actions; + let counts = &mut self.counts; + let mut send_buffer = send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - let last_stream_id = frame.last_stream_id(); - - actions.send.recv_go_away(last_stream_id)?; - - let err = frame.reason().into(); + let last_processed_id = actions.recv.last_processed_id(); - me.store - .for_each(|stream| { - if stream.id > last_stream_id { - counts.transition(stream, |counts, stream| { - actions.recv.recv_err(&err, &mut *stream); - actions.send.recv_err(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) - } else { - Ok::<_, ()>(()) - } + self.store.for_each(|stream| { + counts.transition(stream, |counts, stream| { + actions.recv.handle_error(&err, &mut *stream); + actions.send.handle_error(send_buffer, stream, counts); }) - .unwrap(); + }); actions.conn_error = Some(err); - Ok(()) - } - - pub fn last_processed_id(&self) -> StreamId { - self.inner.lock().unwrap().actions.recv.last_processed_id() - } - - pub fn recv_window_update(&mut self, frame: frame::WindowUpdate) -> Result<(), RecvError> { - let id = 
frame.stream_id(); - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; + last_processed_id + } - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + fn recv_go_away( + &mut self, + send_buffer: &SendBuffer, + frame: &frame::GoAway, + ) -> Result<(), Error> { + let actions = &mut self.actions; + let counts = &mut self.counts; + let mut send_buffer = send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - if id.is_zero() { - me.actions - .send - .recv_connection_window_update(frame, &mut me.store, &mut me.counts) - .map_err(RecvError::Connection)?; - } else { - // The remote may send window updates for streams that the local now - // considers closed. It's ok... - if let Some(mut stream) = me.store.find_mut(&id) { - // This result is ignored as there is nothing to do when there - // is an error. The stream is reset by the function on error and - // the error is informational. - let _ = me.actions.send.recv_stream_window_update( - frame.size_increment(), - send_buffer, - &mut stream, - &mut me.counts, - &mut me.actions.task, - ); - } else { - me.actions - .ensure_not_idle(me.counts.peer(), id) - .map_err(RecvError::Connection)?; + let last_stream_id = frame.last_stream_id(); + + actions.send.recv_go_away(last_stream_id)?; + + let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason()); + + self.store.for_each(|stream| { + if stream.id > last_stream_id { + counts.transition(stream, |counts, stream| { + actions.recv.handle_error(&err, &mut *stream); + actions.send.handle_error(send_buffer, stream, counts); + }) } - } + }); + + actions.conn_error = Some(err); Ok(()) } - pub fn recv_push_promise(&mut self, frame: frame::PushPromise) -> Result<(), RecvError> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - + fn recv_push_promise( + &mut self, + send_buffer: &SendBuffer, + frame: frame::PushPromise, + ) -> Result<(), Error> { let id = frame.stream_id(); let promised_id = frame.promised_id(); // 
First, ensure that the initiating stream is still in a valid state. - let parent_key = match me.store.find_mut(&id) { + let parent_key = match self.store.find_mut(&id) { Some(stream) => { // The GOAWAY process has begun. All streams with a greater ID // than specified as part of GOAWAY should be ignored. - if id > me.actions.recv.max_stream_id() { - log::trace!( + if id > self.actions.recv.max_stream_id() { + tracing::trace!( "id ({:?}) > max_stream_id ({:?}), ignoring PUSH_PROMISE", id, - me.actions.recv.max_stream_id() + self.actions.recv.max_stream_id() ); return Ok(()); } // The stream must be receive open - stream.state.ensure_recv_open()?; + if !stream.state.ensure_recv_open()? { + proto_err!(conn: "recv_push_promise: initiating stream is not opened"); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + stream.key() } None => { proto_err!(conn: "recv_push_promise: initiating stream is in an invalid state"); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } }; @@ -493,16 +760,16 @@ where // could grow in memory indefinitely. // Ensure that we can reserve streams - me.actions.recv.ensure_can_reserve()?; + self.actions.recv.ensure_can_reserve()?; // Next, open the stream. // // If `None` is returned, then the stream is being refused. There is no // further work to be done. - if me + if self .actions .recv - .open(promised_id, Open::PushPromise, &mut me.counts)? + .open(promised_id, Open::PushPromise, &mut self.counts)? .is_none() { return Ok(()); @@ -512,23 +779,23 @@ where // this requires a bit of indirection to make the borrow checker happy. 
let child_key: Option = { // Create state for the stream - let stream = me.store.insert(promised_id, { + let stream = self.store.insert(promised_id, { Stream::new( promised_id, - me.actions.send.init_window_sz(), - me.actions.recv.init_window_sz(), + self.actions.send.init_window_sz(), + self.actions.recv.init_window_sz(), ) }); - let actions = &mut me.actions; + let actions = &mut self.actions; - me.counts.transition(stream, |counts, stream| { + self.counts.transition(stream, |counts, stream| { let stream_valid = actions.recv.recv_push_promise(frame, stream); match stream_valid { Ok(()) => Ok(Some(stream.key())), _ => { - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let mut send_buffer = send_buffer.inner.lock().unwrap(); actions .reset_on_recv_stream_err( &mut *send_buffer, @@ -543,10 +810,10 @@ where }; // If we're successful, push the headers and stream... if let Some(child) = child_key { - let mut ppp = me.store[parent_key].pending_push_promises.take(); - ppp.push(&mut me.store.resolve(child)); + let mut ppp = self.store[parent_key].pending_push_promises.take(); + ppp.push(&mut self.store.resolve(child)); - let parent = &mut me.store.resolve(parent_key); + let parent = &mut self.store.resolve(parent_key); parent.pending_push_promises = ppp; parent.notify_recv(); }; @@ -554,220 +821,121 @@ where Ok(()) } - pub fn next_incoming(&mut self) -> Option> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - let key = me.actions.recv.next_incoming(&mut me.store); - // TODO: ideally, OpaqueStreamRefs::new would do this, but we're holding - // the lock, so it can't. 
- me.refs += 1; - key.map(|key| { - let stream = &mut me.store.resolve(key); - log::trace!( - "next_incoming; id={:?}, state={:?}", - stream.id, - stream.state + fn recv_eof( + &mut self, + send_buffer: &SendBuffer, + clear_pending_accept: bool, + ) -> Result<(), ()> { + let actions = &mut self.actions; + let counts = &mut self.counts; + let mut send_buffer = send_buffer.inner.lock().unwrap(); + let send_buffer = &mut *send_buffer; + + if actions.conn_error.is_none() { + actions.conn_error = Some( + io::Error::new( + io::ErrorKind::BrokenPipe, + "connection closed because of a broken pipe", + ) + .into(), ); - StreamRef { - opaque: OpaqueStreamRef::new(self.inner.clone(), stream), - send_buffer: self.send_buffer.clone(), - } - }) - } + } - pub fn send_pending_refusal( - &mut self, - cx: &mut Context, - dst: &mut Codec>, - ) -> Poll> - where - T: AsyncWrite + Unpin, - { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - me.actions.recv.send_pending_refusal(cx, dst) - } + tracing::trace!("Streams::recv_eof"); - pub fn clear_expired_reset_streams(&mut self) { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - me.actions - .recv - .clear_expired_reset_streams(&mut me.store, &mut me.counts); + self.store.for_each(|stream| { + counts.transition(stream, |counts, stream| { + actions.recv.recv_eof(stream); + + // This handles resetting send state associated with the + // stream + actions.send.handle_error(send_buffer, stream, counts); + }) + }); + + actions.clear_queues(clear_pending_accept, &mut self.store, counts); + Ok(()) } - pub fn poll_complete( + fn poll_complete( &mut self, + send_buffer: &SendBuffer, cx: &mut Context, dst: &mut Codec>, ) -> Poll> where T: AsyncWrite + Unpin, + B: Buf, { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let mut send_buffer = send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; // Send WINDOW_UPDATE 
frames first // // TODO: It would probably be better to interleave updates w/ data // frames. - ready!(me + ready!(self .actions .recv - .poll_complete(cx, &mut me.store, &mut me.counts, dst))?; + .poll_complete(cx, &mut self.store, &mut self.counts, dst))?; // Send any other pending frames - ready!(me - .actions - .send - .poll_complete(cx, send_buffer, &mut me.store, &mut me.counts, dst))?; + ready!(self.actions.send.poll_complete( + cx, + send_buffer, + &mut self.store, + &mut self.counts, + dst + ))?; // Nothing else to do, track the task - me.actions.task = Some(cx.waker().clone()); + self.actions.task = Some(cx.waker().clone()); Poll::Ready(Ok(())) } - pub fn apply_remote_settings(&mut self, frame: &frame::Settings) -> Result<(), RecvError> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); - let send_buffer = &mut *send_buffer; - - me.counts.apply_remote_settings(frame); - - me.actions.send.apply_remote_settings( - frame, - send_buffer, - &mut me.store, - &mut me.counts, - &mut me.actions.task, - ) - } - - pub fn apply_local_settings(&mut self, frame: &frame::Settings) -> Result<(), RecvError> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - - me.actions.recv.apply_local_settings(frame, &mut me.store) - } - - pub fn send_request( - &mut self, - request: Request<()>, - end_of_stream: bool, - pending: Option<&OpaqueStreamRef>, - ) -> Result, SendError> { - use super::stream::ContentLength; - use http::Method; - - // TODO: There is a hazard with assigning a stream ID before the - // prioritize layer. If prioritization reorders new streams, this - // implicitly closes the earlier stream IDs. 
- // - // See: hyperium/h2#11 - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); - let send_buffer = &mut *send_buffer; - - me.actions.ensure_no_conn_error()?; - me.actions.send.ensure_next_stream_id()?; - - // The `pending` argument is provided by the `Client`, and holds - // a store `Key` of a `Stream` that may have been not been opened - // yet. - // - // If that stream is still pending, the Client isn't allowed to - // queue up another pending stream. They should use `poll_ready`. - if let Some(stream) = pending { - if me.store.resolve(stream.key).is_pending_open { - return Err(UserError::Rejected.into()); - } - } - - if me.counts.peer().is_server() { - // Servers cannot open streams. PushPromise must first be reserved. - return Err(UserError::UnexpectedFrameType.into()); - } - - let stream_id = me.actions.send.open()?; - - let mut stream = Stream::new( - stream_id, - me.actions.send.init_window_sz(), - me.actions.recv.init_window_sz(), - ); - - if *request.method() == Method::HEAD { - stream.content_length = ContentLength::Head; - } - - // Convert the message - let headers = client::Peer::convert_send_message(stream_id, request, end_of_stream)?; - - let mut stream = me.store.insert(stream.id, stream); - - let sent = me.actions.send.send_headers( - headers, - send_buffer, - &mut stream, - &mut me.counts, - &mut me.actions.task, - ); - - // send_headers can return a UserError, if it does, - // we should forget about this stream. - if let Err(err) = sent { - stream.unlink(); - stream.remove(); - return Err(err.into()); - } - - // Given that the stream has been initialized, it should not be in the - // closed state. - debug_assert!(!stream.state.is_closed()); - - // TODO: ideally, OpaqueStreamRefs::new would do this, but we're holding - // the lock, so it can't. 
- me.refs += 1; - - Ok(StreamRef { - opaque: OpaqueStreamRef::new(self.inner.clone(), &mut stream), - send_buffer: self.send_buffer.clone(), - }) - } - - pub fn send_reset(&mut self, id: StreamId, reason: Reason) { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - - let key = match me.store.find_entry(id) { + fn send_reset(&mut self, send_buffer: &SendBuffer, id: StreamId, reason: Reason) { + let key = match self.store.find_entry(id) { Entry::Occupied(e) => e.key(), Entry::Vacant(e) => { + // Resetting a stream we don't know about? That could be OK... + // + // 1. As a server, we just received a request, but that request + // was bad, so we're resetting before even accepting it. + // This is totally fine. + // + // 2. The remote may have sent us a frame on new stream that + // it's *not* supposed to have done, and thus, we don't know + // the stream. In that case, sending a reset will "open" the + // stream in our store. Maybe that should be a connection + // error instead? At least for now, we need to update what + // our vision of the next stream is. + if self.counts.peer().is_local_init(id) { + // We normally would open this stream, so update our + // next-send-id record. + self.actions.send.maybe_reset_next_stream_id(id); + } else { + // We normally would recv this stream, so update our + // next-recv-id record. 
+ self.actions.recv.maybe_reset_next_stream_id(id); + } + let stream = Stream::new(id, 0, 0); e.insert(stream) } }; - let stream = me.store.resolve(key); - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + let stream = self.store.resolve(key); + let mut send_buffer = send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - me.actions - .send_reset(stream, reason, &mut me.counts, send_buffer); - } - - pub fn send_go_away(&mut self, last_processed_id: StreamId) { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - let actions = &mut me.actions; - actions.recv.go_away(last_processed_id); + self.actions.send_reset( + stream, + reason, + Initiator::Library, + &mut self.counts, + send_buffer, + ); } } @@ -788,7 +956,7 @@ where if let Some(pending) = pending { let mut stream = me.store.resolve(pending.key); - log::trace!("poll_pending_open; stream = {:?}", stream.is_pending_open); + tracing::trace!("poll_pending_open; stream = {:?}", stream.is_pending_open); if stream.is_pending_open { stream.wait_send(cx); return Poll::Pending; @@ -802,39 +970,32 @@ impl Streams where P: Peer, { + pub fn as_dyn(&self) -> DynStreams { + let Self { + inner, + send_buffer, + _p, + } = self; + DynStreams { + inner, + send_buffer, + peer: P::r#dyn(), + } + } + /// This function is safe to call multiple times. /// /// A `Result` is returned to avoid panicking if the mutex is poisoned. 
pub fn recv_eof(&mut self, clear_pending_accept: bool) -> Result<(), ()> { - let mut me = self.inner.lock().map_err(|_| ())?; - let me = &mut *me; - - let actions = &mut me.actions; - let counts = &mut me.counts; - let mut send_buffer = self.send_buffer.inner.lock().unwrap(); - let send_buffer = &mut *send_buffer; - - if actions.conn_error.is_none() { - actions.conn_error = Some(io::Error::from(io::ErrorKind::BrokenPipe).into()); - } - - log::trace!("Streams::recv_eof"); - - me.store - .for_each(|stream| { - counts.transition(stream, |counts, stream| { - actions.recv.recv_eof(stream); + self.as_dyn().recv_eof(clear_pending_accept) + } - // This handles resetting send state associated with the - // stream - actions.send.recv_err(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) - }) - .expect("recv_eof"); + pub(crate) fn max_send_streams(&self) -> usize { + self.inner.lock().unwrap().counts.max_send_streams() + } - actions.clear_queues(clear_pending_accept, &mut me.store, counts); - Ok(()) + pub(crate) fn max_recv_streams(&self) -> usize { + self.inner.lock().unwrap().counts.max_recv_streams() } #[cfg(feature = "unstable")] @@ -880,7 +1041,14 @@ where P: Peer, { fn drop(&mut self) { - let _ = self.inner.lock().map(|mut inner| inner.refs -= 1); + if let Ok(mut inner) = self.inner.lock() { + inner.refs -= 1; + if inner.refs == 1 { + if let Some(task) = inner.actions.task.take() { + task.wake(); + } + } + } } } @@ -940,14 +1108,16 @@ impl StreamRef { let send_buffer = &mut *send_buffer; me.actions - .send_reset(stream, reason, &mut me.counts, send_buffer); + .send_reset(stream, reason, Initiator::User, &mut me.counts, send_buffer); } pub fn send_response( &mut self, - response: Response<()>, + mut response: Response<()>, end_of_stream: bool, ) -> Result<(), UserError> { + // Clear before taking lock, incase extensions contain a StreamRef. 
+ response.extensions_mut().clear(); let mut me = self.opaque.inner.lock().unwrap(); let me = &mut *me; @@ -965,7 +1135,12 @@ impl StreamRef { }) } - pub fn send_push_promise(&mut self, request: Request<()>) -> Result, UserError> { + pub fn send_push_promise( + &mut self, + mut request: Request<()>, + ) -> Result, UserError> { + // Clear before taking lock, incase extensions contain a StreamRef. + request.extensions_mut().clear(); let mut me = self.opaque.inner.lock().unwrap(); let me = &mut *me; @@ -1003,9 +1178,10 @@ impl StreamRef { let mut child_stream = me.store.resolve(child_key); child_stream.unlink(); child_stream.remove(); - return Err(err.into()); + return Err(err); } + me.refs += 1; let opaque = OpaqueStreamRef::new(self.opaque.inner.clone(), &mut me.store.resolve(child_key)); @@ -1085,10 +1261,7 @@ impl StreamRef { .map_err(From::from) } - pub fn clone_to_opaque(&self) -> OpaqueStreamRef - where - B: 'static, - { + pub fn clone_to_opaque(&self) -> OpaqueStreamRef { self.opaque.clone() } @@ -1137,7 +1310,7 @@ impl OpaqueStreamRef { me.actions .recv .poll_pushed(cx, &mut stream) - .map_ok_(|(h, key)| { + .map_ok(|(h, key)| { me.refs += 1; let opaque_ref = OpaqueStreamRef::new(self.inner.clone(), &mut me.store.resolve(key)); @@ -1201,12 +1374,13 @@ impl OpaqueStreamRef { .release_capacity(capacity, &mut stream, &mut me.actions.task) } + /// Clear the receive queue and set the status to no longer receive data frames. 
pub(crate) fn clear_recv_buffer(&mut self) { let mut me = self.inner.lock().unwrap(); let me = &mut *me; let mut stream = me.store.resolve(self.key); - + stream.is_recv = false; me.actions.recv.clear_recv_buffer(&mut stream); } @@ -1248,7 +1422,7 @@ impl Clone for OpaqueStreamRef { OpaqueStreamRef { inner: self.inner.clone(), - key: self.key.clone(), + key: self.key, } } } @@ -1265,7 +1439,7 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { Ok(inner) => inner, Err(_) => { if ::std::thread::panicking() { - log::trace!("StreamRef::drop; mutex poisoned"); + tracing::trace!("StreamRef::drop; mutex poisoned"); return; } else { panic!("StreamRef::drop; mutex poisoned"); @@ -1277,7 +1451,7 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { me.refs -= 1; let mut stream = me.store.resolve(key); - log::trace!("drop_stream_ref; stream={:?}", stream); + tracing::trace!("drop_stream_ref; stream={:?}", stream); // decrement the stream's ref count by 1. stream.ref_dec(); @@ -1317,9 +1491,21 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { fn maybe_cancel(stream: &mut store::Ptr, actions: &mut Actions, counts: &mut Counts) { if stream.is_canceled_interest() { + // Server is allowed to early respond without fully consuming the client input stream + // But per the RFC, must send a RST_STREAM(NO_ERROR) in such cases. 
https://www.rfc-editor.org/rfc/rfc7540#section-8.1 + // Some other http2 implementation may interpret other error code as fatal if not respected (i.e: nginx https://trac.nginx.org/nginx/ticket/2376) + let reason = if counts.peer().is_server() + && stream.state.is_send_closed() + && stream.state.is_recv_streaming() + { + Reason::NO_ERROR + } else { + Reason::CANCEL + }; + actions .send - .schedule_implicit_reset(stream, Reason::CANCEL, counts, &mut actions.task); + .schedule_implicit_reset(stream, reason, counts, &mut actions.task); actions.recv.enqueue_reset_expiration(stream, counts); } } @@ -1331,6 +1517,11 @@ impl SendBuffer { let inner = Mutex::new(Buffer::new()); SendBuffer { inner } } + + pub fn is_empty(&self) -> bool { + let buf = self.inner.lock().unwrap(); + buf.is_empty() + } } // ===== impl Actions ===== @@ -1340,12 +1531,19 @@ impl Actions { &mut self, stream: store::Ptr, reason: Reason, + initiator: Initiator, counts: &mut Counts, send_buffer: &mut Buffer>, ) { counts.transition(stream, |counts, stream| { - self.send - .send_reset(reason, send_buffer, stream, counts, &mut self.task); + self.send.send_reset( + reason, + initiator, + send_buffer, + stream, + counts, + &mut self.task, + ); self.recv.enqueue_reset_expiration(stream, counts); // if a RecvStream is parked, ensure it's notified stream.notify_recv(); @@ -1357,13 +1555,28 @@ impl Actions { buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, - res: Result<(), RecvError>, - ) -> Result<(), RecvError> { - if let Err(RecvError::Stream { reason, .. }) = res { - // Reset the stream. - self.send - .send_reset(reason, buffer, stream, counts, &mut self.task); - Ok(()) + res: Result<(), Error>, + ) -> Result<(), Error> { + if let Err(Error::Reset(stream_id, reason, initiator)) = res { + debug_assert_eq!(stream_id, stream.id); + + if counts.can_inc_num_local_error_resets() { + counts.inc_num_local_error_resets(); + + // Reset the stream. 
+ self.send + .send_reset(reason, initiator, buffer, stream, counts, &mut self.task); + Ok(()) + } else { + tracing::warn!( + "reset_on_recv_stream_err; locally-reset streams reached limit ({:?})", + counts.max_local_error_resets().unwrap(), + ); + Err(Error::library_go_away_data( + Reason::ENHANCE_YOUR_CALM, + "too_many_internal_resets", + )) + } } else { res } @@ -1379,7 +1592,7 @@ impl Actions { fn ensure_no_conn_error(&self) -> Result<(), proto::Error> { if let Some(ref err) = self.conn_error { - Err(err.shallow_clone()) + Err(err.clone()) } else { Ok(()) } @@ -1394,11 +1607,11 @@ impl Actions { /// is more likely to be latency/memory constraints that caused this, /// and not a bad actor. So be less catastrophic, the spec allows /// us to send another RST_STREAM of STREAM_CLOSED. - fn may_have_forgotten_stream(&self, id: StreamId) -> bool { + fn may_have_forgotten_stream(&self, peer: peer::Dyn, id: StreamId) -> bool { if id.is_zero() { return false; } - if P::is_local_init(id) { + if peer.is_local_init(id) { self.send.may_have_created_stream(id) } else { self.recv.may_have_created_stream(id) diff --git a/src/server.rs b/src/server.rs index 59247b596..4f8722269 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,10 +1,10 @@ -//! Server implementation of the HTTP/2.0 protocol. +//! Server implementation of the HTTP/2 protocol. //! //! # Getting started //! -//! Running an HTTP/2.0 server requires the caller to manage accepting the +//! Running an HTTP/2 server requires the caller to manage accepting the //! connections as well as getting the connections to a state that is ready to -//! begin the HTTP/2.0 handshake. See [here](../index.html#handshake) for more +//! begin the HTTP/2 handshake. See [here](../index.html#handshake) for more //! details. //! //! This could be as basic as using Tokio's [`TcpListener`] to accept @@ -12,8 +12,8 @@ //! upgrades. //! //! Once a connection is obtained, it is passed to [`handshake`], -//! 
which will begin the [HTTP/2.0 handshake]. This returns a future that -//! completes once the handshake process is performed and HTTP/2.0 streams may +//! which will begin the [HTTP/2 handshake]. This returns a future that +//! completes once the handshake process is performed and HTTP/2 streams may //! be received. //! //! [`handshake`] uses default configuration values. There are a number of @@ -21,12 +21,12 @@ //! //! # Inbound streams //! -//! The [`Connection`] instance is used to accept inbound HTTP/2.0 streams. It +//! The [`Connection`] instance is used to accept inbound HTTP/2 streams. It //! does this by implementing [`futures::Stream`]. When a new stream is -//! received, a call to [`Connection::poll`] will return `(request, response)`. +//! received, a call to [`Connection::accept`] will return `(request, response)`. //! The `request` handle (of type [`http::Request`]) contains the //! HTTP request head as well as provides a way to receive the inbound data -//! stream and the trailers. The `response` handle (of type [`SendStream`]) +//! stream and the trailers. The `response` handle (of type [`SendResponse`]) //! allows responding to the request, stream the response payload, send //! trailers, and send push promises. //! @@ -36,19 +36,19 @@ //! # Managing the connection //! //! The [`Connection`] instance is used to manage connection state. The caller -//! is required to call either [`Connection::poll`] or +//! is required to call either [`Connection::accept`] or //! [`Connection::poll_close`] in order to advance the connection state. Simply //! operating on [`SendStream`] or [`RecvStream`] will have no effect unless the //! connection state is advanced. //! -//! It is not required to call **both** [`Connection::poll`] and +//! It is not required to call **both** [`Connection::accept`] and //! [`Connection::poll_close`]. If the caller is ready to accept a new stream, -//! then only [`Connection::poll`] should be called. When the caller **does +//! 
then only [`Connection::accept`] should be called. When the caller **does //! not** want to accept a new stream, [`Connection::poll_close`] should be //! called. //! //! The [`Connection`] instance should only be dropped once -//! [`Connection::poll_close`] returns `Ready`. Once [`Connection::poll`] +//! [`Connection::poll_close`] returns `Ready`. Once [`Connection::accept`] //! returns `Ready(None)`, there will no longer be any more inbound streams. At //! this point, only [`Connection::poll_close`] should be called. //! @@ -59,9 +59,9 @@ //! //! # Example //! -//! A basic HTTP/2.0 server example that runs over TCP and assumes [prior +//! A basic HTTP/2 server example that runs over TCP and assumes [prior //! knowledge], i.e. both the client and the server assume that the TCP socket -//! will use the HTTP/2.0 protocol without prior negotiation. +//! will use the HTTP/2 protocol without prior negotiation. //! //! ```no_run //! use h2::server; @@ -77,9 +77,9 @@ //! if let Ok((socket, _peer_addr)) = listener.accept().await { //! // Spawn a new task to process each connection. //! tokio::spawn(async { -//! // Start the HTTP/2.0 connection handshake +//! // Start the HTTP/2 connection handshake //! let mut h2 = server::handshake(socket).await.unwrap(); -//! // Accept all inbound HTTP/2.0 streams sent over the +//! // Accept all inbound HTTP/2 streams sent over the //! // connection. //! while let Some(request) = h2.accept().await { //! let (request, mut respond) = request.unwrap(); @@ -104,7 +104,7 @@ //! //! [prior knowledge]: http://httpwg.org/specs/rfc7540.html#known-http //! [`handshake`]: fn.handshake.html -//! [HTTP/2.0 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader +//! [HTTP/2 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader //! [`Builder`]: struct.Builder.html //! [`Connection`]: struct.Connection.html //! [`Connection::poll`]: struct.Connection.html#method.poll @@ -115,9 +115,9 @@ //! 
[`SendStream`]: ../struct.SendStream.html //! [`TcpListener`]: https://docs.rs/tokio-core/0.1/tokio_core/net/struct.TcpListener.html -use crate::codec::{Codec, RecvError, UserError}; +use crate::codec::{Codec, UserError}; use crate::frame::{self, Pseudo, PushPromiseHeaderError, Reason, Settings, StreamId}; -use crate::proto::{self, Config, Prioritized}; +use crate::proto::{self, Config, Error, Prioritized}; use crate::{FlowControl, PingPong, RecvStream, SendStream}; use bytes::{Buf, Bytes}; @@ -126,10 +126,11 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use std::{convert, fmt, io, mem}; -use tokio::io::{AsyncRead, AsyncWrite}; +use std::{fmt, io}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::instrument::{Instrument, Instrumented}; -/// In progress HTTP/2.0 connection handshake future. +/// In progress HTTP/2 connection handshake future. /// /// This type implements `Future`, yielding a `Connection` instance once the /// handshake has completed. @@ -149,12 +150,14 @@ pub struct Handshake { builder: Builder, /// The current state of the handshake. state: Handshaking, + /// Span tracking the handshake + span: tracing::Span, } -/// Accepts inbound HTTP/2.0 streams on a connection. +/// Accepts inbound HTTP/2 streams on a connection. /// /// A `Connection` is backed by an I/O resource (usually a TCP socket) and -/// implements the HTTP/2.0 server logic for that connection. It is responsible +/// implements the HTTP/2 server logic for that connection. It is responsible /// for receiving inbound streams initiated by the client as well as driving the /// internal state forward. 
/// @@ -179,9 +182,11 @@ pub struct Handshake { /// # async fn doc(my_io: T) { /// let mut server = server::handshake(my_io).await.unwrap(); /// while let Some(request) = server.accept().await { -/// let (request, respond) = request.unwrap(); -/// // Process the request and send the response back to the client -/// // using `respond`. +/// tokio::spawn(async move { +/// let (request, respond) = request.unwrap(); +/// // Process the request and send the response back to the client +/// // using `respond`. +/// }); /// } /// # } /// # @@ -197,7 +202,7 @@ pub struct Connection { /// Methods can be chained in order to set the configuration values. /// /// The server is constructed by calling [`handshake`] and passing the I/O -/// handle that will back the HTTP/2.0 server. +/// handle that will back the HTTP/2 server. /// /// New instances of `Builder` are obtained via [`Builder::new`]. /// @@ -216,7 +221,7 @@ pub struct Connection { /// # fn doc(my_io: T) /// # -> Handshake /// # { -/// // `server_fut` is a future representing the completion of the HTTP/2.0 +/// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let server_fut = Builder::new() /// .initial_window_size(1_000_000) @@ -235,11 +240,24 @@ pub struct Builder { /// Maximum number of locally reset streams to keep at a time. reset_stream_max: usize, + /// Maximum number of remotely reset streams to allow in the pending + /// accept queue. + pending_accept_reset_stream_max: usize, + /// Initial `Settings` frame to send as part of the handshake. settings: Settings, /// Initial target window size for new connections. initial_target_connection_window_size: Option, + + /// Maximum amount of bytes to "buffer" for writing per stream. + max_send_buffer_size: usize, + + /// Maximum number of locally reset streams due to protocol error across + /// the lifetime of the connection. + /// + /// When this gets exceeded, we issue GOAWAYs. 
+ local_max_error_reset_streams: Option, } /// Send a response back to the client @@ -252,7 +270,7 @@ pub struct Builder { /// stream. /// /// If the `SendResponse` instance is dropped without sending a response, then -/// the HTTP/2.0 stream will be reset. +/// the HTTP/2 stream will be reset. /// /// See [module] level docs for more details. /// @@ -271,7 +289,7 @@ pub struct SendResponse { /// It can not be used to initiate push promises. /// /// If the `SendPushedResponse` instance is dropped without sending a response, then -/// the HTTP/2.0 stream will be reset. +/// the HTTP/2 stream will be reset. /// /// See [module] level docs for more details. /// @@ -290,11 +308,11 @@ impl fmt::Debug for SendPushedResponse { /// Stages of an in-progress handshake. enum Handshaking { /// State 1. Connection is flushing pending SETTINGS frame. - Flushing(Flush>), + Flushing(Instrumented>>), /// State 2. Connection is waiting for the client preface. - ReadingPreface(ReadPreface>), - /// Dummy state for `mem::replace`. - Empty, + ReadingPreface(Instrumented>>), + /// State 3. Handshake is done, polling again would panic. + Done, } /// Flush a Sink @@ -313,18 +331,18 @@ pub(crate) struct Peer; const PREFACE: [u8; 24] = *b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; -/// Creates a new configured HTTP/2.0 server with default configuration +/// Creates a new configured HTTP/2 server with default configuration /// values backed by `io`. /// /// It is expected that `io` already be in an appropriate state to commence -/// the [HTTP/2.0 handshake]. See [Handshake] for more details. +/// the [HTTP/2 handshake]. See [Handshake] for more details. /// /// Returns a future which resolves to the [`Connection`] instance once the -/// HTTP/2.0 handshake has been completed. The returned [`Connection`] +/// HTTP/2 handshake has been completed. The returned [`Connection`] /// instance will be using default configuration values. 
Use [`Builder`] to /// customize the configuration values used by a [`Connection`] instance. /// -/// [HTTP/2.0 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader +/// [HTTP/2 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader /// [Handshake]: ../index.html#handshake /// [`Connection`]: struct.Connection.html /// @@ -338,8 +356,8 @@ const PREFACE: [u8; 24] = *b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; /// # async fn doc(my_io: T) /// # { /// let connection = server::handshake(my_io).await.unwrap(); -/// // The HTTP/2.0 handshake has completed, now use `connection` to -/// // accept inbound HTTP/2.0 streams. +/// // The HTTP/2 handshake has completed, now use `connection` to +/// // accept inbound HTTP/2 streams. /// # } /// # /// # pub fn main() {} @@ -356,9 +374,12 @@ where impl Connection where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { fn handshake2(io: T, builder: Builder) -> Handshake { + let span = tracing::trace_span!("server_handshake"); + let entered = span.enter(); + // Create the codec. let mut codec = Codec::new(io); @@ -376,9 +397,16 @@ where .expect("invalid SETTINGS frame"); // Create the handshake future. - let state = Handshaking::from(codec); + let state = + Handshaking::Flushing(Flush::new(codec).instrument(tracing::trace_span!("flush"))); - Handshake { builder, state } + drop(entered); + + Handshake { + builder, + state, + span, + } } /// Accept the next incoming request on this connection. @@ -395,14 +423,14 @@ where ) -> Poll, SendResponse), crate::Error>>> { // Always try to advance the internal state. Getting Pending also is // needed to allow this function to return Pending. - if let Poll::Ready(_) = self.poll_closed(cx)? 
{ + if self.poll_closed(cx)?.is_ready() { // If the socket is closed, don't return anything // TODO: drop any pending streams return Poll::Ready(None); } if let Some(inner) = self.connection.next_incoming() { - log::trace!("received incoming"); + tracing::trace!("received incoming"); let (head, _) = inner.take_request().into_parts(); let body = RecvStream::new(FlowControl::new(inner.clone_to_opaque())); @@ -456,6 +484,19 @@ where Ok(()) } + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + /// + /// # Errors + /// + /// Returns an error if a previous call is still pending acknowledgement + /// from the remote endpoint. + pub fn enable_connect_protocol(&mut self) -> Result<(), crate::Error> { + self.connection.set_enable_connect_protocol()?; + Ok(()) + } + /// Returns `Ready` when the underlying connection has closed. /// /// If any new inbound streams are received during a call to `poll_closed`, @@ -473,12 +514,6 @@ where self.connection.poll(cx).map_err(Into::into) } - #[doc(hidden)] - #[deprecated(note = "renamed to poll_closed")] - pub fn poll_close(&mut self, cx: &mut Context) -> Poll> { - self.poll_closed(cx) - } - /// Sets the connection to a GOAWAY state. /// /// Does not terminate the connection. Must continue being polled to close @@ -517,13 +552,48 @@ where pub fn ping_pong(&mut self) -> Option { self.connection.take_user_pings().map(PingPong::new) } + + /// Returns the maximum number of concurrent streams that may be initiated + /// by the server on this connection. + /// + /// This limit is configured by the client peer by sending the + /// [`SETTINGS_MAX_CONCURRENT_STREAMS` parameter][1] in a `SETTINGS` frame. + /// This method returns the currently acknowledged value received from the + /// remote. 
+ /// + /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2 + pub fn max_concurrent_send_streams(&self) -> usize { + self.connection.max_send_streams() + } + + /// Returns the maximum number of concurrent streams that may be initiated + /// by the client on this connection. + /// + /// This returns the value of the [`SETTINGS_MAX_CONCURRENT_STREAMS` + /// parameter][1] sent in a `SETTINGS` frame that has been + /// acknowledged by the remote peer. The value to be sent is configured by + /// the [`Builder::max_concurrent_streams`][2] method before handshaking + /// with the remote peer. + /// + /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2 + /// [2]: ../struct.Builder.html#method.max_concurrent_streams + pub fn max_concurrent_recv_streams(&self) -> usize { + self.connection.max_recv_streams() + } + + // Could disappear at anytime. + #[doc(hidden)] + #[cfg(feature = "unstable")] + pub fn num_wired_streams(&self) -> usize { + self.connection.num_wired_streams() + } } #[cfg(feature = "stream")] impl futures_core::Stream for Connection where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { type Item = Result<(Request, SendResponse), crate::Error>; @@ -561,7 +631,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. 
/// let server_fut = Builder::new() /// .initial_window_size(1_000_000) @@ -576,8 +646,12 @@ impl Builder { Builder { reset_stream_duration: Duration::from_secs(proto::DEFAULT_RESET_STREAM_SECS), reset_stream_max: proto::DEFAULT_RESET_STREAM_MAX, + pending_accept_reset_stream_max: proto::DEFAULT_REMOTE_RESET_STREAM_MAX, settings: Settings::default(), initial_target_connection_window_size: None, + max_send_buffer_size: proto::DEFAULT_MAX_SEND_BUFFER_SIZE, + + local_max_error_reset_streams: Some(proto::DEFAULT_LOCAL_RESET_COUNT_MAX), } } @@ -600,7 +674,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let server_fut = Builder::new() /// .initial_window_size(1_000_000) @@ -634,7 +708,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let server_fut = Builder::new() /// .initial_connection_window_size(1_000_000) @@ -649,7 +723,7 @@ impl Builder { self } - /// Indicates the size (in octets) of the largest HTTP/2.0 frame payload that the + /// Indicates the size (in octets) of the largest HTTP/2 frame payload that the /// configured server is able to accept. /// /// The sender may send data frames that are **smaller** than this value, @@ -667,7 +741,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. 
/// let server_fut = Builder::new() /// .max_frame_size(1_000_000) @@ -706,7 +780,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let server_fut = Builder::new() /// .max_header_list_size(16 * 1024) @@ -741,7 +815,7 @@ impl Builder { /// a protocol level error. Instead, the `h2` library will immediately reset /// the stream. /// - /// See [Section 5.1.2] in the HTTP/2.0 spec for more details. + /// See [Section 5.1.2] in the HTTP/2 spec for more details. /// /// [Section 5.1.2]: https://http2.github.io/http2-spec/#rfc.section.5.1.2 /// @@ -754,7 +828,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let server_fut = Builder::new() /// .max_concurrent_streams(1000) @@ -773,7 +847,7 @@ impl Builder { /// /// When a stream is explicitly reset by either calling /// [`SendResponse::send_reset`] or by dropping a [`SendResponse`] instance - /// before completing the stream, the HTTP/2.0 specification requires that + /// before completing the stream, the HTTP/2 specification requires that /// any further frames received for that stream must be ignored for "some /// time". /// @@ -800,7 +874,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let server_fut = Builder::new() /// .max_concurrent_reset_streams(1000) @@ -815,11 +889,90 @@ impl Builder { self } + /// Sets the maximum number of local resets due to protocol errors made by the remote end. 
+ /// + /// Invalid frames and many other protocol errors will lead to resets being generated for those streams. + /// Too many of these often indicate a malicious client, and there are attacks which can abuse this to DOS servers. + /// This limit protects against these DOS attacks by limiting the amount of resets we can be forced to generate. + /// + /// When the number of local resets exceeds this threshold, the server will issue GOAWAYs with an error code of + /// `ENHANCE_YOUR_CALM` to the client. + /// + /// If you really want to disable this, supply [`Option::None`] here. + /// Disabling this is not recommended and may expose you to DOS attacks. + /// + /// The default value is currently 1024, but could change. + pub fn max_local_error_reset_streams(&mut self, max: Option) -> &mut Self { + self.local_max_error_reset_streams = max; + self + } + + /// Sets the maximum number of pending-accept remotely-reset streams. + /// + /// Streams that have been received by the peer, but not accepted by the + /// user, can also receive a RST_STREAM. This is a legitimate pattern: one + /// could send a request and then shortly after, realize it is not needed, + /// sending a CANCEL. + /// + /// However, since those streams are now "closed", they don't count towards + /// the max concurrent streams. So, they will sit in the accept queue, + /// using memory. + /// + /// When the number of remotely-reset streams sitting in the pending-accept + /// queue reaches this maximum value, a connection error with the code of + /// `ENHANCE_YOUR_CALM` will be sent to the peer, and returned by the + /// `Future`. + /// + /// The default value is currently 20, but could change. + /// + /// # Examples + /// + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use h2::server::*; + /// # + /// # fn doc(my_io: T) + /// # -> Handshake + /// # { + /// // `server_fut` is a future representing the completion of the HTTP/2 + /// // handshake. 
+ /// let server_fut = Builder::new() + /// .max_pending_accept_reset_streams(100) + /// .handshake(my_io); + /// # server_fut + /// # } + /// # + /// # pub fn main() {} + /// ``` + pub fn max_pending_accept_reset_streams(&mut self, max: usize) -> &mut Self { + self.pending_accept_reset_stream_max = max; + self + } + + /// Sets the maximum send buffer size per stream. + /// + /// Once a stream has buffered up to (or over) the maximum, the stream's + /// flow control will not "poll" additional capacity. Once bytes for the + /// stream have been written to the connection, the send buffer capacity + /// will be freed up again. + /// + /// The default is currently ~400KB, but may change. + /// + /// # Panics + /// + /// This function panics if `max` is larger than `u32::MAX`. + pub fn max_send_buffer_size(&mut self, max: usize) -> &mut Self { + assert!(max <= std::u32::MAX as usize); + self.max_send_buffer_size = max; + self + } + /// Sets the maximum number of concurrent locally reset streams. /// /// When a stream is explicitly reset by either calling /// [`SendResponse::send_reset`] or by dropping a [`SendResponse`] instance - /// before completing the stream, the HTTP/2.0 specification requires that + /// before completing the stream, the HTTP/2 specification requires that /// any further frames received for that stream must be ignored for "some /// time". /// @@ -847,7 +1000,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let server_fut = Builder::new() /// .reset_stream_duration(Duration::from_secs(10)) @@ -862,18 +1015,26 @@ impl Builder { self } - /// Creates a new configured HTTP/2.0 server backed by `io`. + /// Enables the [extended CONNECT protocol]. 
+ /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + pub fn enable_connect_protocol(&mut self) -> &mut Self { + self.settings.set_enable_connect_protocol(Some(1)); + self + } + + /// Creates a new configured HTTP/2 server backed by `io`. /// /// It is expected that `io` already be in an appropriate state to commence - /// the [HTTP/2.0 handshake]. See [Handshake] for more details. + /// the [HTTP/2 handshake]. See [Handshake] for more details. /// /// Returns a future which resolves to the [`Connection`] instance once the - /// HTTP/2.0 handshake has been completed. + /// HTTP/2 handshake has been completed. /// /// This function also allows the caller to configure the send payload data /// type. See [Outbound data type] for more details. /// - /// [HTTP/2.0 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader + /// [HTTP/2 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader /// [Handshake]: ../index.html#handshake /// [`Connection`]: struct.Connection.html /// [Outbound data type]: ../index.html#outbound-data-type. @@ -889,7 +1050,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. /// let server_fut = Builder::new() /// .handshake(my_io); @@ -909,7 +1070,7 @@ impl Builder { /// # fn doc(my_io: T) /// # -> Handshake /// # { - /// // `server_fut` is a future representing the completion of the HTTP/2.0 + /// // `server_fut` is a future representing the completion of the HTTP/2 /// // handshake. 
/// let server_fut: Handshake<_, &'static [u8]> = Builder::new() /// .handshake(my_io); @@ -921,7 +1082,7 @@ impl Builder { pub fn handshake(&self, io: T) -> Handshake where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { Connection::handshake2(io, self.clone()) } @@ -1019,7 +1180,7 @@ impl SendResponse { /// /// # Panics /// - /// If the lock on the strean store has been poisoned. + /// If the lock on the stream store has been poisoned. pub fn stream_id(&self) -> crate::StreamId { crate::StreamId::from_internal(self.inner.stream_id()) } @@ -1091,7 +1252,7 @@ impl SendPushedResponse { /// /// # Panics /// - /// If the lock on the strean store has been poisoned. + /// If the lock on the stream store has been poisoned. pub fn stream_id(&self) -> crate::StreamId { self.inner.stream_id() } @@ -1146,8 +1307,10 @@ where let mut rem = PREFACE.len() - self.pos; while rem > 0 { - let n = ready!(Pin::new(self.inner_mut()).poll_read(cx, &mut buf[..rem])) + let mut buf = ReadBuf::new(&mut buf[..rem]); + ready!(Pin::new(self.inner_mut()).poll_read(cx, &mut buf)) .map_err(crate::Error::from_io)?; + let n = buf.filled().len(); if n == 0 { return Poll::Ready(Err(crate::Error::from_io(io::Error::new( io::ErrorKind::UnexpectedEof, @@ -1155,10 +1318,10 @@ where )))); } - if PREFACE[self.pos..self.pos + n] != buf[..n] { + if &PREFACE[self.pos..self.pos + n] != buf.filled() { proto_err!(conn: "read_preface: invalid preface"); // TODO: Should this just write the GO_AWAY frame directly? 
- return Poll::Ready(Err(Reason::PROTOCOL_ERROR.into())); + return Poll::Ready(Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into())); } self.pos += n; @@ -1174,68 +1337,70 @@ where impl Future for Handshake where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { type Output = Result, crate::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - log::trace!("Handshake::poll(); state={:?};", self.state); - use crate::server::Handshaking::*; - - self.state = if let Flushing(ref mut flush) = self.state { - // We're currently flushing a pending SETTINGS frame. Poll the - // flush future, and, if it's completed, advance our state to wait - // for the client preface. - let codec = match Pin::new(flush).poll(cx)? { - Poll::Pending => { - log::trace!("Handshake::poll(); flush.poll()=Pending"); - return Poll::Pending; + let span = self.span.clone(); // XXX(eliza): T_T + let _e = span.enter(); + tracing::trace!(state = ?self.state); + + loop { + match &mut self.state { + Handshaking::Flushing(flush) => { + // We're currently flushing a pending SETTINGS frame. Poll the + // flush future, and, if it's completed, advance our state to wait + // for the client preface. + let codec = match Pin::new(flush).poll(cx)? 
{ + Poll::Pending => { + tracing::trace!(flush.poll = %"Pending"); + return Poll::Pending; + } + Poll::Ready(flushed) => { + tracing::trace!(flush.poll = %"Ready"); + flushed + } + }; + self.state = Handshaking::ReadingPreface( + ReadPreface::new(codec).instrument(tracing::trace_span!("read_preface")), + ); } - Poll::Ready(flushed) => { - log::trace!("Handshake::poll(); flush.poll()=Ready"); - flushed + Handshaking::ReadingPreface(read) => { + let codec = ready!(Pin::new(read).poll(cx)?); + + self.state = Handshaking::Done; + + let connection = proto::Connection::new( + codec, + Config { + next_stream_id: 2.into(), + // Server does not need to locally initiate any streams + initial_max_send_streams: 0, + max_send_buffer_size: self.builder.max_send_buffer_size, + reset_stream_duration: self.builder.reset_stream_duration, + reset_stream_max: self.builder.reset_stream_max, + remote_reset_stream_max: self.builder.pending_accept_reset_stream_max, + local_error_reset_streams_max: self + .builder + .local_max_error_reset_streams, + settings: self.builder.settings.clone(), + }, + ); + + tracing::trace!("connection established!"); + let mut c = Connection { connection }; + if let Some(sz) = self.builder.initial_target_connection_window_size { + c.set_target_window_size(sz); + } + + return Poll::Ready(Ok(c)); + } + Handshaking::Done => { + panic!("Handshaking::poll() called again after handshaking was complete") } - }; - Handshaking::from(ReadPreface::new(codec)) - } else { - // Otherwise, we haven't actually advanced the state, but we have - // to replace it with itself, because we have to return a value. - // (note that the assignment to `self.state` has to be outside of - // the `if let` block above in order to placate the borrow checker). - mem::replace(&mut self.state, Handshaking::Empty) - }; - let poll = if let ReadingPreface(ref mut read) = self.state { - // We're now waiting for the client preface. Poll the `ReadPreface` - // future. 
If it has completed, we will create a `Connection` handle - // for the connection. - Pin::new(read).poll(cx) - // Actually creating the `Connection` has to occur outside of this - // `if let` block, because we've borrowed `self` mutably in order - // to poll the state and won't be able to borrow the SETTINGS frame - // as well until we release the borrow for `poll()`. - } else { - unreachable!("Handshake::poll() state was not advanced completely!") - }; - poll?.map(|codec| { - let connection = proto::Connection::new( - codec, - Config { - next_stream_id: 2.into(), - // Server does not need to locally initiate any streams - initial_max_send_streams: 0, - reset_stream_duration: self.builder.reset_stream_duration, - reset_stream_max: self.builder.reset_stream_max, - settings: self.builder.settings.clone(), - }, - ); - - log::trace!("Handshake::poll(); connection established!"); - let mut c = Connection { connection }; - if let Some(sz) = self.builder.initial_target_connection_window_size { - c.set_target_window_size(sz); } - Ok(c) - }) + } } } @@ -1289,15 +1454,15 @@ impl Peer { if let Err(e) = frame::PushPromise::validate_request(&request) { use PushPromiseHeaderError::*; match e { - NotSafeAndCacheable => log::debug!( - "convert_push_message: method {} is not safe and cacheable; promised_id={:?}", + NotSafeAndCacheable => tracing::debug!( + ?promised_id, + "convert_push_message: method {} is not safe and cacheable", request.method(), - promised_id, ), - InvalidContentLength(e) => log::debug!( - "convert_push_message; promised request has invalid content-length {:?}; promised_id={:?}", + InvalidContentLength(e) => tracing::debug!( + ?promised_id, + "convert_push_message; promised request has invalid content-length {:?}", e, - promised_id, ), } return Err(UserError::MalformedHeaders); @@ -1314,7 +1479,7 @@ impl Peer { _, ) = request.into_parts(); - let pseudo = Pseudo::request(method, uri); + let pseudo = Pseudo::request(method, uri, None); Ok(frame::PushPromise::new( 
stream_id, @@ -1328,9 +1493,13 @@ impl Peer { impl proto::Peer for Peer { type Poll = Request<()>; + const NAME: &'static str = "Server"; + + /* fn is_server() -> bool { true } + */ fn r#dyn() -> proto::DynPeer { proto::DynPeer::Server @@ -1340,20 +1509,17 @@ impl proto::Peer for Peer { pseudo: Pseudo, fields: HeaderMap, stream_id: StreamId, - ) -> Result { + ) -> Result { use http::{uri, Version}; let mut b = Request::builder(); macro_rules! malformed { ($($arg:tt)*) => {{ - log::debug!($($arg)*); - return Err(RecvError::Stream { - id: stream_id, - reason: Reason::PROTOCOL_ERROR, - }); + tracing::debug!($($arg)*); + return Err(Error::library_reset(stream_id, Reason::PROTOCOL_ERROR)); }} - }; + } b = b.version(Version::HTTP_2); @@ -1365,10 +1531,18 @@ impl proto::Peer for Peer { malformed!("malformed headers: missing method"); } - // Specifying :status for a request is a protocol error + let has_protocol = pseudo.protocol.is_some(); + if has_protocol { + if is_connect { + // Assert that we have the right type. + b = b.extension::(pseudo.protocol.unwrap()); + } else { + malformed!("malformed headers: :protocol on non-CONNECT request"); + } + } + if pseudo.status.is_some() { - log::trace!("malformed headers: :status field on request; PROTOCOL_ERROR"); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + malformed!("malformed headers: :status field on request"); } // Convert the URI @@ -1389,8 +1563,8 @@ impl proto::Peer for Peer { // A :scheme is required, except CONNECT. 
if let Some(scheme) = pseudo.scheme { - if is_connect { - malformed!(":scheme in CONNECT"); + if is_connect && !has_protocol { + malformed!("malformed headers: :scheme in CONNECT"); } let maybe_scheme = scheme.parse(); let scheme = maybe_scheme.or_else(|why| { @@ -1407,13 +1581,13 @@ impl proto::Peer for Peer { if parts.authority.is_some() { parts.scheme = Some(scheme); } - } else if !is_connect { + } else if !is_connect || has_protocol { malformed!("malformed headers: missing scheme"); } if let Some(path) = pseudo.path { - if is_connect { - malformed!(":path in CONNECT"); + if is_connect && !has_protocol { + malformed!("malformed headers: :path in CONNECT"); } // This cannot be empty @@ -1425,6 +1599,8 @@ impl proto::Peer for Peer { parts.path_and_query = Some(maybe_path.or_else(|why| { malformed!("malformed headers: malformed path ({:?}): {}", path, why,) })?); + } else if is_connect && has_protocol { + malformed!("malformed headers: missing path in extended CONNECT"); } b = b.uri(parts); @@ -1435,10 +1611,7 @@ impl proto::Peer for Peer { // TODO: Should there be more specialized handling for different // kinds of errors proto_err!(stream: "error building request: {}; stream={:?}", e, stream_id); - return Err(RecvError::Stream { - id: stream_id, - reason: Reason::PROTOCOL_ERROR, - }); + return Err(Error::library_reset(stream_id, Reason::PROTOCOL_ERROR)); } }; @@ -1457,42 +1630,9 @@ where #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { - Handshaking::Flushing(_) => write!(f, "Handshaking::Flushing(_)"), - Handshaking::ReadingPreface(_) => write!(f, "Handshaking::ReadingPreface(_)"), - Handshaking::Empty => write!(f, "Handshaking::Empty"), + Handshaking::Flushing(_) => f.write_str("Flushing(_)"), + Handshaking::ReadingPreface(_) => f.write_str("ReadingPreface(_)"), + Handshaking::Done => f.write_str("Done"), } } } - -impl convert::From>> for Handshaking -where - T: AsyncRead + AsyncWrite, - B: Buf, -{ - #[inline] - fn 
from(flush: Flush>) -> Self { - Handshaking::Flushing(flush) - } -} - -impl convert::From>> for Handshaking -where - T: AsyncRead + AsyncWrite, - B: Buf, -{ - #[inline] - fn from(read: ReadPreface>) -> Self { - Handshaking::ReadingPreface(read) - } -} - -impl convert::From>> for Handshaking -where - T: AsyncRead + AsyncWrite, - B: Buf, -{ - #[inline] - fn from(codec: Codec>) -> Self { - Handshaking::from(Flush::new(codec)) - } -} diff --git a/src/share.rs b/src/share.rs index 06291068d..26b428797 100644 --- a/src/share.rs +++ b/src/share.rs @@ -5,7 +5,6 @@ use crate::proto::{self, WindowSize}; use bytes::{Buf, Bytes}; use http::HeaderMap; -use crate::PollExt; use std::fmt; #[cfg(feature = "stream")] use std::pin::Pin; @@ -16,7 +15,7 @@ use std::task::{Context, Poll}; /// # Overview /// /// A `SendStream` is provided by [`SendRequest`] and [`SendResponse`] once the -/// HTTP/2.0 message header has been sent sent. It is used to stream the message +/// HTTP/2 message header has been sent sent. It is used to stream the message /// body and send the message trailers. See method level documentation for more /// details. /// @@ -35,7 +34,7 @@ use std::task::{Context, Poll}; /// /// # Flow control /// -/// In HTTP/2.0, data cannot be sent to the remote peer unless there is +/// In HTTP/2, data cannot be sent to the remote peer unless there is /// available window capacity on both the stream and the connection. When a data /// frame is sent, both the stream window and the connection window are /// decremented. When the stream level window reaches zero, no further data can @@ -44,7 +43,7 @@ use std::task::{Context, Poll}; /// /// When the remote peer is ready to receive more data, it sends `WINDOW_UPDATE` /// frames. These frames increment the windows. See the [specification] for more -/// details on the principles of HTTP/2.0 flow control. +/// details on the principles of HTTP/2 flow control. 
/// /// The implications for sending data are that the caller **should** ensure that /// both the stream and the connection has available window capacity before @@ -95,7 +94,7 @@ use std::task::{Context, Poll}; /// [`send_trailers`]: #method.send_trailers /// [`send_reset`]: #method.send_reset #[derive(Debug)] -pub struct SendStream { +pub struct SendStream { inner: proto::StreamRef, } @@ -109,13 +108,19 @@ pub struct SendStream { /// new stream. /// /// [Section 5.1.1]: https://tools.ietf.org/html/rfc7540#section-5.1.1 -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub struct StreamId(u32); +impl From for u32 { + fn from(src: StreamId) -> Self { + src.0 + } +} + /// Receives the body stream and trailers from the remote peer. /// /// A `RecvStream` is provided by [`client::ResponseFuture`] and -/// [`server::Connection`] with the received HTTP/2.0 message head (the response +/// [`server::Connection`] with the received HTTP/2 message head (the response /// and request head respectively). /// /// A `RecvStream` instance is used to receive the streaming message body and @@ -125,11 +130,6 @@ pub struct StreamId(u32); /// See method level documentation for more details on receiving data. See /// [`FlowControl`] for more details on inbound flow control. /// -/// Note that this type implements [`Stream`], yielding the received data frames. -/// When this implementation is used, the capacity is immediately released when -/// the data is yielded. It is recommended to only use this API when the data -/// will not be retained in memory for extended periods of time. 
-/// /// [`client::ResponseFuture`]: client/struct.ResponseFuture.html /// [`server::Connection`]: server/struct.Connection.html /// [`FlowControl`]: struct.FlowControl.html @@ -173,12 +173,12 @@ pub struct RecvStream { /// /// # Scenarios /// -/// Following is a basic scenario with an HTTP/2.0 connection containing a +/// Following is a basic scenario with an HTTP/2 connection containing a /// single active stream. /// /// * A new stream is activated. The receive window is initialized to 1024 (the /// value of the initial window size for this connection). -/// * A `DATA` frame is received containing a payload of 400 bytes. +/// * A `DATA` frame is received containing a payload of 600 bytes. /// * The receive window size is reduced to 424 bytes. /// * [`release_capacity`] is called with 200. /// * The receive window size is now 624 bytes. The peer may send no more than @@ -312,8 +312,8 @@ impl SendStream { pub fn poll_capacity(&mut self, cx: &mut Context) -> Poll>> { self.inner .poll_capacity(cx) - .map_ok_(|w| w as usize) - .map_err_(Into::into) + .map_ok(|w| w as usize) + .map_err(Into::into) } /// Sends a single data frame to the remote peer. @@ -388,6 +388,18 @@ impl StreamId { pub(crate) fn from_internal(id: crate::frame::StreamId) -> Self { StreamId(id.into()) } + + /// Returns the `u32` corresponding to this `StreamId` + /// + /// # Note + /// + /// This is the same as the `From` implementation, but + /// included as an inherent method because that implementation doesn't + /// appear in rustdocs, as well as a way to force the type instead of + /// relying on inference. + pub fn as_u32(&self) -> u32 { + (*self).into() + } } // ===== impl RecvStream ===== @@ -406,9 +418,9 @@ impl RecvStream { futures_util::future::poll_fn(move |cx| self.poll_trailers(cx)).await } - #[doc(hidden)] + /// Poll for the next data frame. 
pub fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll>> { - self.inner.inner.poll_data(cx).map_err_(Into::into) + self.inner.inner.poll_data(cx).map_err(Into::into) } #[doc(hidden)] @@ -544,8 +556,8 @@ impl PingPong { pub fn send_ping(&mut self, ping: Ping) -> Result<(), crate::Error> { // Passing a `Ping` here is just to be forwards-compatible with // eventually allowing choosing a ping payload. For now, we can - // just drop it. - drop(ping); + // just ignore it. + let _ = ping; self.inner.send_ping().map_err(|err| match err { Some(err) => err.into(), diff --git a/tests/h2-fuzz/Cargo.toml b/tests/h2-fuzz/Cargo.toml index d119aedf4..b0f9599e9 100644 --- a/tests/h2-fuzz/Cargo.toml +++ b/tests/h2-fuzz/Cargo.toml @@ -8,8 +8,8 @@ edition = "2018" [dependencies] h2 = { path = "../.." } -env_logger = { version = "0.5.3", default-features = false } -futures = { version = "0.3", default-features = false } +env_logger = { version = "0.9", default-features = false } +futures = { version = "0.3", default-features = false, features = ["std"] } honggfuzz = "0.5" -http = "0.2" -tokio = { version = "0.2", features = [] } +http = "1" +tokio = { version = "1", features = [ "full" ] } diff --git a/tests/h2-fuzz/src/main.rs b/tests/h2-fuzz/src/main.rs index a57fb76a5..28905524b 100644 --- a/tests/h2-fuzz/src/main.rs +++ b/tests/h2-fuzz/src/main.rs @@ -1,132 +1,128 @@ -use futures::future; -use futures::stream::FuturesUnordered; -use futures::Stream; -use http::{Method, Request}; -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; - -struct MockIo<'a> { - input: &'a [u8], -} - -impl<'a> MockIo<'a> { - fn next_byte(&mut self) -> Option { - if let Some(&c) = self.input.first() { - self.input = &self.input[1..]; - Some(c) - } else { - None - } - } - - fn next_u32(&mut self) -> u32 { - (self.next_byte().unwrap_or(0) as u32) << 8 | self.next_byte().unwrap_or(0) as u32 - } -} - -impl<'a> AsyncRead for 
MockIo<'a> { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - false - } - - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let mut len = self.next_u32() as usize; - if self.input.is_empty() { - Poll::Ready(Ok(0)) - } else if len == 0 { - cx.waker().clone().wake(); - Poll::Pending - } else { - if len > self.input.len() { - len = self.input.len(); - } - - if len > buf.len() { - len = buf.len(); - } - buf[0..len].copy_from_slice(&self.input[0..len]); - self.input = &self.input[len..]; - Poll::Ready(Ok(len)) - } - } -} - -impl<'a> AsyncWrite for MockIo<'a> { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let len = std::cmp::min(self.next_u32() as usize, buf.len()); - if len == 0 { - if self.input.is_empty() { - Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) - } else { - cx.waker().clone().wake(); - Poll::Pending - } - } else { - Poll::Ready(Ok(len)) - } - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -async fn run(script: &[u8]) -> Result<(), h2::Error> { - let io = MockIo { input: script }; - let (mut h2, mut connection) = h2::client::handshake(io).await?; - let mut futs = FuturesUnordered::new(); - let future = future::poll_fn(|cx| { - if let Poll::Ready(()) = Pin::new(&mut connection).poll(cx)? 
{
-            return Poll::Ready(Ok::<_, h2::Error>(()));
-        }
-        while futs.len() < 128 {
-            if !h2.poll_ready(cx)?.is_ready() {
-                break;
-            }
-            let request = Request::builder()
-                .method(Method::POST)
-                .uri("https://example.com/")
-                .body(())
-                .unwrap();
-            let (resp, mut send) = h2.send_request(request, false)?;
-            send.send_data(vec![0u8; 32769].into(), true).unwrap();
-            drop(send);
-            futs.push(resp);
-        }
-        loop {
-            match Pin::new(&mut futs).poll_next(cx) {
-                Poll::Pending | Poll::Ready(None) => break,
-                r @ Poll::Ready(Some(Ok(_))) | r @ Poll::Ready(Some(Err(_))) => {
-                    eprintln!("{:?}", r);
-                }
-            }
-        }
-        Poll::Pending
-    });
-    future.await?;
-    Ok(())
-}
-
-fn main() {
-    env_logger::init();
-    let mut rt = tokio::runtime::Runtime::new().unwrap();
-    loop {
-        honggfuzz::fuzz!(|data: &[u8]| {
-            eprintln!("{:?}", rt.block_on(run(data)));
-        });
-    }
-}
+use futures::future;
+use futures::stream::FuturesUnordered;
+use futures::Stream;
+use http::{Method, Request};
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+struct MockIo<'a> {
+    input: &'a [u8],
+}
+
+impl<'a> MockIo<'a> {
+    fn next_byte(&mut self) -> Option<u8> {
+        if let Some(&c) = self.input.first() {
+            self.input = &self.input[1..];
+            Some(c)
+        } else {
+            None
+        }
+    }
+
+    fn next_u32(&mut self) -> u32 {
+        (self.next_byte().unwrap_or(0) as u32) << 8 | self.next_byte().unwrap_or(0) as u32
+    }
+}
+
+impl<'a> AsyncRead for MockIo<'a> {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        let mut len = self.next_u32() as usize;
+        if self.input.is_empty() {
+            Poll::Ready(Ok(()))
+        } else if len == 0 {
+            cx.waker().clone().wake();
+            Poll::Pending
+        } else {
+            if len > self.input.len() {
+                len = self.input.len();
+            }
+
+            if len > buf.remaining() {
+                len = buf.remaining();
+            }
+            buf.put_slice(&self.input[..len]);
+            self.input = &self.input[len..];
+            Poll::Ready(Ok(()))
+        }
+    }
+}
+
+impl<'a> AsyncWrite for MockIo<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let len = std::cmp::min(self.next_u32() as usize, buf.len()); + if len == 0 { + if self.input.is_empty() { + Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } else { + cx.waker().clone().wake(); + Poll::Pending + } + } else { + Poll::Ready(Ok(len)) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +async fn run(script: &[u8]) -> Result<(), h2::Error> { + let io = MockIo { input: script }; + let (mut h2, mut connection) = h2::client::handshake(io).await?; + let mut futs = FuturesUnordered::new(); + let future = future::poll_fn(|cx| { + if let Poll::Ready(()) = Pin::new(&mut connection).poll(cx)? { + return Poll::Ready(Ok::<_, h2::Error>(())); + } + while futs.len() < 128 { + if !h2.poll_ready(cx)?.is_ready() { + break; + } + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + let (resp, mut send) = h2.send_request(request, false)?; + send.send_data(vec![0u8; 32769].into(), true).unwrap(); + drop(send); + futs.push(resp); + } + loop { + match Pin::new(&mut futs).poll_next(cx) { + Poll::Pending | Poll::Ready(None) => break, + r @ Poll::Ready(Some(Ok(_))) | r @ Poll::Ready(Some(Err(_))) => { + eprintln!("{:?}", r); + } + } + } + Poll::Pending + }); + future.await?; + Ok(()) +} + +fn main() { + env_logger::init(); + let rt = tokio::runtime::Runtime::new().unwrap(); + loop { + honggfuzz::fuzz!(|data: &[u8]| { + eprintln!("{:?}", rt.block_on(run(data))); + }); + } +} diff --git a/tests/h2-support/Cargo.toml b/tests/h2-support/Cargo.toml index b48dc36a6..970648d5a 100644 --- a/tests/h2-support/Cargo.toml +++ b/tests/h2-support/Cargo.toml @@ -2,14 +2,18 @@ name = "h2-support" version = "0.1.0" authors = ["Carl 
Lerche "] +publish = false edition = "2018" [dependencies] h2 = { path = "../..", features = ["stream", "unstable"] } -bytes = "0.5" -env_logger = "0.5.9" +atty = "0.2" +bytes = "1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] } +tracing-tree = "0.2" futures = { version = "0.3", default-features = false } -http = "0.2" -tokio = { version = "0.2", features = ["time"] } -tokio-test = "0.2" +http = "1" +tokio = { version = "1", features = ["time"] } +tokio-test = "0.4" diff --git a/tests/h2-support/src/assert.rs b/tests/h2-support/src/assert.rs index 8bc6d25c7..88e3d4f7c 100644 --- a/tests/h2-support/src/assert.rs +++ b/tests/h2-support/src/assert.rs @@ -47,6 +47,16 @@ macro_rules! assert_settings { }}; } +#[macro_export] +macro_rules! assert_go_away { + ($frame:expr) => {{ + match $frame { + h2::frame::Frame::GoAway(v) => v, + f => panic!("expected GO_AWAY; actual={:?}", f), + } + }}; +} + #[macro_export] macro_rules! poll_err { ($transport:expr) => {{ @@ -80,6 +90,7 @@ macro_rules! 
assert_default_settings { use h2::frame::Frame; +#[track_caller] pub fn assert_frame_eq, U: Into>(t: T, u: U) { let actual: Frame = t.into(); let expected: Frame = u.into(); diff --git a/tests/h2-support/src/client_ext.rs b/tests/h2-support/src/client_ext.rs index a9ab71d99..eebbae98b 100644 --- a/tests/h2-support/src/client_ext.rs +++ b/tests/h2-support/src/client_ext.rs @@ -11,7 +11,7 @@ pub trait SendRequestExt { impl SendRequestExt for SendRequest where - B: Buf + Unpin + 'static, + B: Buf, { fn get(&mut self, uri: &str) -> ResponseFuture { let req = Request::builder() diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index b9393b2b5..858bf770b 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -2,12 +2,15 @@ use std::convert::TryInto; use std::fmt; use bytes::Bytes; -use http::{self, HeaderMap}; +use http::{HeaderMap, StatusCode}; -use h2::frame::{self, Frame, StreamId}; +use h2::{ + ext::Protocol, + frame::{self, Frame, StreamId}, +}; -pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; -pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; +pub const SETTINGS: &[u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; +pub const SETTINGS_ACK: &[u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; // ==== helper functions to easily construct h2 Frames ==== @@ -109,7 +112,9 @@ impl Mock { let method = method.try_into().unwrap(); let uri = uri.try_into().unwrap(); let (id, _, fields) = self.into_parts(); - let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields); + let extensions = Default::default(); + let pseudo = frame::Pseudo::request(method, uri, extensions); + let frame = frame::Headers::new(id, pseudo, fields); Mock(frame) } @@ -162,6 +167,14 @@ impl Mock { Mock(frame) } + pub fn status(self, value: StatusCode) -> Self { + let (id, mut pseudo, fields) = self.into_parts(); + + pseudo.set_status(value); + + Mock(frame::Headers::new(id, pseudo, fields)) + } + pub fn scheme(self, 
value: &str) -> Self { let (id, mut pseudo, fields) = self.into_parts(); let value = value.parse().unwrap(); @@ -171,6 +184,15 @@ impl Mock { Mock(frame::Headers::new(id, pseudo, fields)) } + pub fn protocol(self, value: &str) -> Self { + let (id, mut pseudo, fields) = self.into_parts(); + let value = Protocol::from(value); + + pseudo.set_protocol(value); + + Mock(frame::Headers::new(id, pseudo, fields)) + } + pub fn eos(mut self) -> Self { self.0.set_end_stream(); self @@ -222,8 +244,9 @@ impl Mock { let method = method.try_into().unwrap(); let uri = uri.try_into().unwrap(); let (id, promised, _, fields) = self.into_parts(); - let frame = - frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields); + let extensions = Default::default(); + let pseudo = frame::Pseudo::request(method, uri, extensions); + let frame = frame::PushPromise::new(id, promised, pseudo, fields); Mock(frame) } @@ -274,12 +297,31 @@ impl Mock { self.reason(frame::Reason::FRAME_SIZE_ERROR) } + pub fn calm(self) -> Self { + self.reason(frame::Reason::ENHANCE_YOUR_CALM) + } + pub fn no_error(self) -> Self { self.reason(frame::Reason::NO_ERROR) } + pub fn data(self, debug_data: I) -> Self + where + I: Into, + { + Mock(frame::GoAway::with_debug_data( + self.0.last_stream_id(), + self.0.reason(), + debug_data.into(), + )) + } + pub fn reason(self, reason: frame::Reason) -> Self { - Mock(frame::GoAway::new(self.0.last_stream_id(), reason)) + Mock(frame::GoAway::with_debug_data( + self.0.last_stream_id(), + reason, + self.0.debug_data().clone(), + )) } } @@ -339,6 +381,21 @@ impl Mock { self.0.set_max_header_list_size(Some(val)); self } + + pub fn disable_push(mut self) -> Self { + self.0.set_enable_push(false); + self + } + + pub fn enable_connect_protocol(mut self, val: u32) -> Self { + self.0.set_enable_connect_protocol(Some(val)); + self + } + + pub fn header_table_size(mut self, val: u32) -> Self { + self.0.set_header_table_size(Some(val)); + self + } } impl From> for 
frame::Settings { diff --git a/tests/h2-support/src/lib.rs b/tests/h2-support/src/lib.rs index d88f6cabf..3c13c0afe 100644 --- a/tests/h2-support/src/lib.rs +++ b/tests/h2-support/src/lib.rs @@ -8,6 +8,7 @@ pub mod raw; pub mod frames; pub mod mock; pub mod prelude; +pub mod trace; pub mod util; mod client_ext; @@ -24,3 +25,19 @@ pub type Codec = h2::Codec; // This is the frame type that is sent pub type SendFrame = h2::frame::Frame; + +#[macro_export] +macro_rules! trace_init { + () => { + let _guard = $crate::trace::init(); + let span = $crate::prelude::tracing::info_span!( + "test", + "{}", + // get the name of the test thread to generate a unique span for the test + std::thread::current() + .name() + .expect("test threads must be named") + ); + let _e = span.enter(); + }; +} diff --git a/tests/h2-support/src/mock.rs b/tests/h2-support/src/mock.rs index 08837fa56..9ec5ba379 100644 --- a/tests/h2-support/src/mock.rs +++ b/tests/h2-support/src/mock.rs @@ -1,12 +1,13 @@ use crate::SendFrame; use h2::frame::{self, Frame}; -use h2::{self, RecvError, SendError}; +use h2::proto::Error; +use h2::SendError; use futures::future::poll_fn; use futures::{ready, Stream, StreamExt}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use super::assert::assert_frame_eq; use std::pin::Pin; @@ -53,9 +54,12 @@ struct Inner { /// True when the pipe is closed. 
closed: bool, + + /// Trigger an `UnexpectedEof` error on read + unexpected_eof: bool, } -const PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; +const PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; /// Create a new mock and handle pub fn new() -> (Mock, Handle) { @@ -72,6 +76,7 @@ pub fn new_with_write_capacity(cap: usize) -> (Mock, Handle) { tx_rem: cap, tx_rem_task: None, closed: false, + unexpected_eof: false, })); let mock = Mock { @@ -95,6 +100,11 @@ impl Handle { &mut self.codec } + pub fn close_without_notify(&mut self) { + let mut me = self.codec.get_mut().inner.lock().unwrap(); + me.unexpected_eof = true; + } + /// Send a frame pub async fn send(&mut self, item: SendFrame) -> Result<(), SendError> { // Queue the frame @@ -147,10 +157,11 @@ impl Handle { poll_fn(move |cx| { while buf.has_remaining() { let res = Pin::new(self.codec.get_mut()) - .poll_write_buf(cx, &mut buf) + .poll_write(cx, buf.chunk()) .map_err(|e| panic!("write err={:?}", e)); - ready!(res).unwrap(); + let n = ready!(res).unwrap(); + buf.advance(n); } Poll::Ready(()) @@ -219,22 +230,15 @@ impl Handle { let settings = settings.into(); self.send(settings.into()).await.unwrap(); - let frame = self.next().await; - let settings = match frame { - Some(frame) => match frame.unwrap() { - Frame::Settings(settings) => { - // Send the ACK - let ack = frame::Settings::ack(); + let frame = self.next().await.expect("unexpected EOF").unwrap(); + let settings = assert_settings!(frame); - // TODO: Don't unwrap? - self.send(ack.into()).await.unwrap(); + // Send the ACK + let ack = frame::Settings::ack(); + + // TODO: Don't unwrap? 
+ self.send(ack.into()).await.unwrap(); - settings - } - frame => panic!("unexpected frame; frame={:?}", frame), - }, - None => panic!("unexpected EOF"), - }; let frame = self.next().await; let f = assert_settings!(frame.unwrap().unwrap()); @@ -283,7 +287,7 @@ impl Handle { } impl Stream for Handle { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.codec).poll_next(cx) @@ -294,8 +298,8 @@ impl AsyncRead for Handle { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { Pin::new(self.codec.get_mut()).poll_read(cx, buf) } } @@ -344,29 +348,36 @@ impl AsyncRead for Mock { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { assert!( - buf.len() > 0, + buf.remaining() > 0, "attempted read with zero length buffer... wut?" ); let mut me = self.pipe.inner.lock().unwrap(); + if me.unexpected_eof { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "Simulate an unexpected eof error", + ))); + } + if me.rx.is_empty() { if me.closed { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } me.rx_task = Some(cx.waker().clone()); return Poll::Pending; } - let n = cmp::min(buf.len(), me.rx.len()); - buf[..n].copy_from_slice(&me.rx[..n]); + let n = cmp::min(buf.remaining(), me.rx.len()); + buf.put_slice(&me.rx[..n]); me.rx.drain(..n); - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) } } @@ -427,10 +438,10 @@ impl AsyncRead for Pipe { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { assert!( - buf.len() > 0, + buf.remaining() > 0, "attempted read with zero length buffer... wut?" 
); @@ -438,18 +449,18 @@ impl AsyncRead for Pipe { if me.tx.is_empty() { if me.closed { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } me.tx_task = Some(cx.waker().clone()); return Poll::Pending; } - let n = cmp::min(buf.len(), me.tx.len()); - buf[..n].copy_from_slice(&me.tx[..n]); + let n = cmp::min(buf.remaining(), me.tx.len()); + buf.put_slice(&me.tx[..n]); me.tx.drain(..n); - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) } } @@ -479,5 +490,5 @@ impl AsyncWrite for Pipe { } pub async fn idle_ms(ms: u64) { - tokio::time::delay_for(Duration::from_millis(ms)).await + tokio::time::sleep(Duration::from_millis(ms)).await } diff --git a/tests/h2-support/src/prelude.rs b/tests/h2-support/src/prelude.rs index 2e95b68b0..c40a518da 100644 --- a/tests/h2-support/src/prelude.rs +++ b/tests/h2-support/src/prelude.rs @@ -2,6 +2,7 @@ pub use h2; pub use h2::client; +pub use h2::ext::Protocol; pub use h2::frame::StreamId; pub use h2::server; pub use h2::*; @@ -20,15 +21,15 @@ pub use super::{Codec, SendFrame}; // Re-export macros pub use super::{ - assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err, - poll_frame, raw_codec, + assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers, + assert_ping, assert_settings, poll_err, poll_frame, raw_codec, }; pub use super::assert::assert_frame_eq; // Re-export useful crates pub use tokio_test::io as mock_io; -pub use {bytes, env_logger, futures, http, tokio::io as tokio_io}; +pub use {bytes, futures, http, tokio::io as tokio_io, tracing, tracing_subscriber}; // Re-export primary future types pub use futures::{Future, Sink, Stream}; @@ -42,10 +43,7 @@ pub use super::client_ext::SendRequestExt; // Re-export HTTP types pub use http::{uri, HeaderMap, Method, Request, Response, StatusCode, Version}; -pub use bytes::{ - buf::{BufExt, BufMutExt}, - Buf, BufMut, Bytes, BytesMut, -}; +pub use bytes::{Buf, BufMut, Bytes, BytesMut}; pub use tokio::io::{AsyncRead, 
AsyncWrite}; @@ -92,7 +90,7 @@ pub trait ClientExt { impl ClientExt for client::Connection where T: AsyncRead + AsyncWrite + Unpin + 'static, - B: Buf + Unpin + 'static, + B: Buf, { fn run<'a, F: Future + Unpin + 'a>( &'a mut self, @@ -105,7 +103,7 @@ where // Connection is done... b.await } - Right((v, _)) => return v, + Right((v, _)) => v, Left((Err(e), _)) => panic!("err: {:?}", e), } }) @@ -124,6 +122,7 @@ pub fn build_large_headers() -> Vec<(&'static str, String)> { ("eight", build_large_string('8', 4 * 1024)), ("nine", "nine".to_string()), ("ten", build_large_string('0', 4 * 1024)), + ("eleven", build_large_string('1', 32 * 1024)), ] } diff --git a/tests/h2-support/src/trace.rs b/tests/h2-support/src/trace.rs new file mode 100644 index 000000000..87038c350 --- /dev/null +++ b/tests/h2-support/src/trace.rs @@ -0,0 +1,17 @@ +pub use tracing; +pub use tracing_subscriber; + +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +pub fn init() -> tracing::dispatcher::DefaultGuard { + let use_colors = atty::is(atty::Stream::Stdout); + let layer = tracing_tree::HierarchicalLayer::default() + .with_writer(tracing_subscriber::fmt::writer::TestWriter::default()) + .with_indent_lines(true) + .with_ansi(use_colors) + .with_targets(true) + .with_indent_amount(2); + + tracing_subscriber::registry().with(layer).set_default() +} diff --git a/tests/h2-support/src/util.rs b/tests/h2-support/src/util.rs index ec768badc..02b6450d0 100644 --- a/tests/h2-support/src/util.rs +++ b/tests/h2-support/src/util.rs @@ -1,5 +1,3 @@ -use h2; - use bytes::{BufMut, Bytes}; use futures::ready; use std::future::Future; @@ -32,10 +30,11 @@ pub async fn yield_once() { .await; } +/// Should only be called after a non-0 capacity was requested for the stream. 
pub fn wait_for_capacity(stream: h2::SendStream, target: usize) -> WaitForCapacity { WaitForCapacity { stream: Some(stream), - target: target, + target, } } @@ -54,14 +53,19 @@ impl Future for WaitForCapacity { type Output = h2::SendStream; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let _ = ready!(self.stream().poll_capacity(cx)).unwrap(); + loop { + let _ = ready!(self.stream().poll_capacity(cx)).unwrap(); - let act = self.stream().capacity(); + let act = self.stream().capacity(); - if act >= self.target { - return Poll::Ready(self.stream.take().unwrap().into()); - } + // If a non-0 capacity was requested for the stream before calling + // wait_for_capacity, then poll_capacity should return Pending + // until there is a non-0 capacity. + assert_ne!(act, 0); - Poll::Pending + if act >= self.target { + return Poll::Ready(self.stream.take().unwrap()); + } + } } } diff --git a/tests/h2-tests/Cargo.toml b/tests/h2-tests/Cargo.toml index 3e9d130f3..6afdf9053 100644 --- a/tests/h2-tests/Cargo.toml +++ b/tests/h2-tests/Cargo.toml @@ -9,6 +9,6 @@ edition = "2018" [dev-dependencies] h2-support = { path = "../h2-support" } -log = "0.4.1" +tracing = "0.1.13" futures = { version = "0.3", default-features = false, features = ["alloc"] } -tokio = { version = "0.2", features = ["macros", "tcp"] } +tokio = { version = "1", features = ["macros", "net", "rt", "io-util", "rt-multi-thread"] } diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index b156d97c4..7bd223e3c 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -2,12 +2,13 @@ use futures::future::{join, ready, select, Either}; use futures::stream::FuturesUnordered; use futures::StreamExt; use h2_support::prelude::*; +use std::io; use std::pin::Pin; use std::task::Context; #[tokio::test] async fn handshake() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let mock = mock_io::Builder::new() 
.handshake() @@ -16,7 +17,7 @@ async fn handshake() { let (_client, h2) = client::handshake(mock).await.unwrap(); - log::trace!("hands have been shook"); + tracing::trace!("hands have been shook"); // At this point, the connection should be closed h2.await.unwrap(); @@ -24,7 +25,7 @@ async fn handshake() { #[tokio::test] async fn client_other_thread() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -60,7 +61,7 @@ async fn client_other_thread() { #[tokio::test] async fn recv_invalid_server_stream_id() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let mock = mock_io::Builder::new() .handshake() @@ -84,7 +85,7 @@ async fn recv_invalid_server_stream_id() { .body(()) .unwrap(); - log::info!("sending request"); + tracing::info!("sending request"); let (response, _) = client.send_request(request, true).unwrap(); // The connection errors @@ -96,7 +97,7 @@ async fn recv_invalid_server_stream_id() { #[tokio::test] async fn request_stream_id_overflows() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let h2 = async move { @@ -149,7 +150,7 @@ async fn request_stream_id_overflows() { #[tokio::test] async fn client_builder_max_concurrent_streams() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mut settings = frame::Settings::default(); @@ -187,7 +188,7 @@ async fn client_builder_max_concurrent_streams() { #[tokio::test] async fn request_over_max_concurrent_streams_errors() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -239,6 +240,8 @@ async fn request_over_max_concurrent_streams_errors() { // first request is allowed let (resp1, mut stream1) = client.send_request(request, false).unwrap(); + // as long as we let the connection internals tick + client = h2.drive(client.ready()).await.unwrap(); let request = 
Request::builder() .method(Method::POST) @@ -284,9 +287,93 @@ async fn request_over_max_concurrent_streams_errors() { join(srv, h2).await; } +#[tokio::test] +async fn recv_decrement_max_concurrent_streams_when_requests_queued() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( + frames::headers(1) + .request("POST", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + + srv.ping_pong([0; 8]).await; + + // limit this server later in life + srv.send_frame(frames::settings().max_concurrent_streams(1)) + .await; + srv.recv_frame(frames::settings_ack()).await; + srv.recv_frame( + frames::headers(3) + .request("POST", "https://example.com/") + .eos(), + ) + .await; + srv.ping_pong([1; 8]).await; + srv.send_frame(frames::headers(3).response(200).eos()).await; + + srv.recv_frame( + frames::headers(5) + .request("POST", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(5).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.expect("handshake"); + // we send a simple req here just to drive the connection so we can + // receive the server settings. 
+ let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + // first request is allowed + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // first request is allowed + let (resp1, _) = client.send_request(request, true).unwrap(); + + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // second request is put into pending_open + let (resp2, _) = client.send_request(request, true).unwrap(); + + h2.drive(async move { + resp1.await.expect("req"); + }) + .await; + join(async move { h2.await.unwrap() }, async move { + resp2.await.unwrap() + }) + .await; + }; + + join(srv, h2).await; +} + #[tokio::test] async fn send_request_poll_ready_when_connection_error() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -336,6 +423,8 @@ async fn send_request_poll_ready_when_connection_error() { // first request is allowed let (resp1, _) = client.send_request(request, true).unwrap(); + // as long as we let the connection internals tick + client = h2.drive(client.ready()).await.unwrap(); let request = Request::builder() .method(Method::POST) @@ -371,7 +460,7 @@ async fn send_request_poll_ready_when_connection_error() { resp2.await.expect_err("req2"); })); - while let Some(_) = unordered.next().await {} + while unordered.next().await.is_some() {} }; join(srv, h2).await; @@ -379,7 +468,7 @@ async fn send_request_poll_ready_when_connection_error() { #[tokio::test] async fn send_reset_notifies_recv_stream() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -410,7 +499,11 @@ async fn send_reset_notifies_recv_stream() { }; let rx = async { let mut body = 
res.into_body(); - body.next().await.unwrap().expect_err("RecvBody"); + let err = body.next().await.unwrap().expect_err("RecvBody"); + assert_eq!( + err.to_string(), + "stream error sent by user: refused stream before processing any application logic" + ); }; // a FuturesUnordered is used on purpose! @@ -432,7 +525,7 @@ async fn send_reset_notifies_recv_stream() { #[tokio::test] async fn http_11_request_without_scheme_or_authority() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -462,7 +555,7 @@ async fn http_11_request_without_scheme_or_authority() { #[tokio::test] async fn http_2_request_without_scheme_or_authority() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -485,9 +578,8 @@ async fn http_2_request_without_scheme_or_authority() { client .send_request(request, true) .expect_err("should be UserError"); - let ret = h2.await.expect("h2"); + let _: () = h2.await.expect("h2"); drop(client); - ret }; join(srv, h2).await; @@ -499,7 +591,7 @@ fn request_with_h1_version() {} #[tokio::test] async fn request_with_connection_headers() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); // can't assert full handshake, since client never sends a request, and @@ -517,7 +609,7 @@ async fn request_with_connection_headers() { ("keep-alive", "5"), ("proxy-connection", "bar"), ("transfer-encoding", "chunked"), - ("upgrade", "HTTP/2.0"), + ("upgrade", "HTTP/2"), ("te", "boom"), ]; @@ -542,7 +634,7 @@ async fn request_with_connection_headers() { #[tokio::test] async fn connection_close_notifies_response_future() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { let settings = srv.assert_client_handshake().await; @@ -571,7 +663,7 @@ async fn connection_close_notifies_response_future() { .0 .await; let err = 
res.expect_err("response"); - assert_eq!(err.to_string(), "broken pipe"); + assert_eq!(err.to_string(), "stream closed because of a broken pipe"); }; join(async move { conn.await.expect("conn") }, req).await; }; @@ -581,7 +673,7 @@ async fn connection_close_notifies_response_future() { #[tokio::test] async fn connection_close_notifies_client_poll_ready() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -610,7 +702,7 @@ async fn connection_close_notifies_client_poll_ready() { .0 .await; let err = res.expect_err("response"); - assert_eq!(err.to_string(), "broken pipe"); + assert_eq!(err.to_string(), "stream closed because of a broken pipe"); }; conn.drive(req).await; @@ -618,7 +710,10 @@ async fn connection_close_notifies_client_poll_ready() { let err = poll_fn(move |cx| client.poll_ready(cx)) .await .expect_err("poll_ready"); - assert_eq!(err.to_string(), "broken pipe"); + assert_eq!( + err.to_string(), + "connection closed because of a broken pipe" + ); }; join(srv, client).await; @@ -626,7 +721,7 @@ async fn connection_close_notifies_client_poll_ready() { #[tokio::test] async fn sending_request_on_closed_connection() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -672,7 +767,7 @@ async fn sending_request_on_closed_connection() { }; let poll_err = poll_fn(|cx| client.poll_ready(cx)).await.unwrap_err(); - let msg = "protocol error: unspecific protocol error detected"; + let msg = "connection error detected: unspecific protocol error detected"; assert_eq!(poll_err.to_string(), msg); let request = Request::builder() @@ -688,7 +783,7 @@ async fn sending_request_on_closed_connection() { #[tokio::test] async fn recv_too_big_headers() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -708,7 +803,7 @@ async fn recv_too_big_headers() { .await; 
srv.send_frame(frames::headers(1).response(200).eos()).await; srv.send_frame(frames::headers(3).response(200)).await; - // no reset for 1, since it's closed anyways + // no reset for 1, since it's closed anyway // but reset for 3, since server hasn't closed stream srv.recv_frame(frames::reset(3).refused()).await; idle_ms(10).await; @@ -751,7 +846,7 @@ async fn recv_too_big_headers() { #[tokio::test] async fn pending_send_request_gets_reset_by_peer_properly() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let payload = Bytes::from(vec![0; (frame::DEFAULT_INITIAL_WINDOW_SIZE * 2) as usize]); @@ -823,7 +918,7 @@ async fn pending_send_request_gets_reset_by_peer_properly() { #[tokio::test] async fn request_without_path() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -854,7 +949,7 @@ async fn request_without_path() { #[tokio::test] async fn request_options_with_star() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); // Note the lack of trailing slash. @@ -899,7 +994,7 @@ async fn notify_on_send_capacity() { // stream, the client is notified. use tokio::sync::oneshot; - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let (done_tx, done_rx) = oneshot::channel(); @@ -979,7 +1074,7 @@ async fn notify_on_send_capacity() { #[tokio::test] async fn send_stream_poll_reset() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -1017,7 +1112,7 @@ async fn drop_pending_open() { // This test checks that a stream queued for pending open behaves correctly when its // client drops. 
use tokio::sync::oneshot; - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let (init_tx, init_rx) = oneshot::channel(); @@ -1105,7 +1200,7 @@ async fn malformed_response_headers_dont_unlink_stream() { // no remaining references correctly resets the stream, without prematurely // unlinking it. use tokio::sync::oneshot; - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let (drop_tx, drop_rx) = oneshot::channel(); @@ -1171,8 +1266,605 @@ async fn malformed_response_headers_dont_unlink_stream() { join(srv, client).await; } -const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; -const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; +#[tokio::test] +async fn allow_empty_data_for_head() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( + frames::headers(1) + .request("HEAD", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame( + frames::headers(1) + .response(200) + .field("content-length", 100), + ) + .await; + srv.send_frame(frames::data(1, "").eos()).await; + }; + + let h2 = async move { + let (mut client, h2) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .unwrap(); + tokio::spawn(async { + h2.await.expect("connection failed"); + }); + let request = Request::builder() + .method(Method::HEAD) + .uri("https://example.com/") + .body(()) + .unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + let (_, mut body) = response.await.unwrap().into_parts(); + assert_eq!(body.data().await.unwrap().unwrap(), ""); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn early_hints() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + 
assert_default_settings!(settings); + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(103)).await; + srv.send_frame(frames::headers(1).response(200).field("content-length", 2)) + .await; + srv.send_frame(frames::data(1, "ok").eos()).await; + }; + + let h2 = async move { + let (mut client, h2) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .unwrap(); + tokio::spawn(async { + h2.await.expect("connection failed"); + }); + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + let (ha, mut body) = response.await.unwrap().into_parts(); + eprintln!("{:?}", ha); + assert_eq!(body.data().await.unwrap().unwrap(), "ok"); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn informational_while_local_streaming() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + srv.send_frame(frames::headers(1).response(103)).await; + srv.send_frame(frames::headers(1).response(200).field("content-length", 2)) + .await; + srv.recv_frame(frames::data(1, "hello").eos()).await; + srv.send_frame(frames::data(1, "ok").eos()).await; + }; + + let h2 = async move { + let (mut client, h2) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .unwrap(); + tokio::spawn(async { + h2.await.expect("connection failed"); + }); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + // don't EOS stream yet.. 
+ let (response, mut body_tx) = client.send_request(request, false).unwrap(); + // eventual response is 200, not 103 + let resp = response.await.expect("response"); + // assert_eq!(resp.status(), 200); + // now we can end the stream + body_tx.send_data("hello".into(), true).expect("send_data"); + let mut body = resp.into_body(); + assert_eq!(body.data().await.unwrap().unwrap(), "ok"); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn extended_connect_protocol_disabled_by_default() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. 
+ let request = Request::get("https://example.com/").body(()).unwrap(); + // first request is allowed + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + + assert!(!client.is_extended_connect_protocol_enabled()); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn extended_connect_protocol_enabled_during_handshake() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) + .await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. + let request = Request::get("https://example.com/").body(()).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + + assert!(client.is_extended_connect_protocol_enabled()); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn invalid_connect_protocol_enabled_setting() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let srv = async move { + // Send a settings frame + srv.send(frames::settings().enable_connect_protocol(2).into()) + .await + .unwrap(); + srv.read_preface().await.unwrap(); + + let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap()); + assert_default_settings!(settings); + + // Send the ACK + let ack = frame::Settings::ack(); + + // TODO: Don't unwrap? 
+ srv.send(ack.into()).await.unwrap(); + + let frame = srv.next().await.unwrap().unwrap(); + let go_away = assert_go_away!(frame); + assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR); + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. + let request = Request::get("https://example.com/").body(()).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + + let error = h2.drive(response).await.unwrap_err(); + assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR)); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn extended_connect_request() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) + .await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + let request = Request::connect("http://bread/baguette") + .extension(Protocol::from("the-bread-protocol")) + .body(()) + .unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn rogue_server_odd_headers() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.send_frame(frames::headers(1)).await; + srv.recv_frame(frames::go_away(0).protocol_error()).await; + }; + + let h2 = async move { + let (_client, h2) = 
client::handshake(io).await.unwrap(); + + let err = h2.await.unwrap_err(); + assert!(err.is_go_away()); + assert_eq!(err.reason(), Some(Reason::PROTOCOL_ERROR)); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn rogue_server_even_headers() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.send_frame(frames::headers(2)).await; + srv.recv_frame(frames::go_away(0).protocol_error()).await; + }; + + let h2 = async move { + let (_client, h2) = client::handshake(io).await.unwrap(); + + let err = h2.await.unwrap_err(); + assert!(err.is_go_away()); + assert_eq!(err.reason(), Some(Reason::PROTOCOL_ERROR)); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn rogue_server_reused_headers() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://camembert.fromage") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.send_frame(frames::headers(1)).await; + srv.recv_frame(frames::reset(1).stream_closed()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + h2.drive(async { + let request = Request::builder() + .method(Method::GET) + .uri("https://camembert.fromage") + .body(()) + .unwrap(); + let _res = client.send_request(request, true).unwrap().0.await.unwrap(); + }) + .await; + + h2.await.unwrap(); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn client_builder_header_table_size() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + let mut settings = frame::Settings::default(); + + settings.set_header_table_size(Some(10000)); + + let srv = async move { + let recv_settings = 
srv.assert_client_handshake().await; + assert_frame_eq(recv_settings, settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let mut builder = client::Builder::new(); + builder.header_table_size(10000); + + let h2 = async move { + let (mut client, mut h2) = builder.handshake::<_, Bytes>(io).await.unwrap(); + let request = Request::get("https://example.com/").body(()).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn configured_max_concurrent_send_streams_and_update_it_based_on_empty_settings_frame() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + // Send empty SETTINGS frame (no MAX_CONCURRENT_STREAMS is provided) + srv.send_frame(frames::settings()).await; + }; + + let h2 = async move { + let (_client, h2) = client::Builder::new() + // Configure the initial value to 2024 + .initial_max_send_streams(2024) + .handshake::<_, bytes::Bytes>(io) + .await + .unwrap(); + let mut h2 = std::pin::pin!(h2); + // It should be pre-configured value before it receives the initial + // SETTINGS frame from the server + assert_eq!(h2.max_concurrent_send_streams(), 2024); + h2.as_mut().await.unwrap(); + // If the server's initial SETTINGS frame does not include + // MAX_CONCURRENT_STREAMS, this should be updated to usize::MAX. 
+ assert_eq!(h2.max_concurrent_send_streams(), usize::MAX); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn configured_max_concurrent_send_streams_and_update_it_based_on_non_empty_settings_frame() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + // Send SETTINGS frame with MAX_CONCURRENT_STREAMS set to 42 + srv.send_frame(frames::settings().max_concurrent_streams(42)) + .await; + }; + + let h2 = async move { + let (_client, h2) = client::Builder::new() + // Configure the initial value to 2024 + .initial_max_send_streams(2024) + .handshake::<_, bytes::Bytes>(io) + .await + .unwrap(); + let mut h2 = std::pin::pin!(h2); + // It should be pre-configured value before it receives the initial + // SETTINGS frame from the server + assert_eq!(h2.max_concurrent_send_streams(), 2024); + h2.as_mut().await.unwrap(); + // Now the client has received the initial SETTINGS frame from the + // server, which should update the value accordingly + assert_eq!(h2.max_concurrent_send_streams(), 42); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn receive_settings_frame_twice_with_second_one_empty() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + // Send the initial SETTINGS frame with MAX_CONCURRENT_STREAMS set to 42 + srv.send_frame(frames::settings().max_concurrent_streams(42)) + .await; + + // Handle the client's connection preface + srv.read_preface().await.unwrap(); + match srv.next().await { + Some(frame) => match frame.unwrap() { + h2::frame::Frame::Settings(_) => { + let ack = frame::Settings::ack(); + srv.send(ack.into()).await.unwrap(); + } + frame => { + panic!("unexpected frame: {:?}", frame); + } + }, + None => { + panic!("unexpected EOF"); + } + } + + // Should receive the ack for the server's initial SETTINGS frame + let frame = assert_settings!(srv.next().await.unwrap().unwrap()); + assert!(frame.is_ack()); + + // Send another SETTINGS frame with no 
MAX_CONCURRENT_STREAMS + // This should not update the max_concurrent_send_streams value that + // the client manages. + srv.send_frame(frames::settings()).await; + }; + + let h2 = async move { + let (_client, h2) = client::handshake(io).await.unwrap(); + let mut h2 = std::pin::pin!(h2); + assert_eq!(h2.max_concurrent_send_streams(), usize::MAX); + h2.as_mut().await.unwrap(); + // Even though the second SETTINGS frame contained no value for + // MAX_CONCURRENT_STREAMS, update to usize::MAX should not happen + assert_eq!(h2.max_concurrent_send_streams(), 42); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn receive_settings_frame_twice_with_second_one_non_empty() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + // Send the initial SETTINGS frame with MAX_CONCURRENT_STREAMS set to 42 + srv.send_frame(frames::settings().max_concurrent_streams(42)) + .await; + + // Handle the client's connection preface + srv.read_preface().await.unwrap(); + match srv.next().await { + Some(frame) => match frame.unwrap() { + h2::frame::Frame::Settings(_) => { + let ack = frame::Settings::ack(); + srv.send(ack.into()).await.unwrap(); + } + frame => { + panic!("unexpected frame: {:?}", frame); + } + }, + None => { + panic!("unexpected EOF"); + } + } + + // Should receive the ack for the server's initial SETTINGS frame + let frame = assert_settings!(srv.next().await.unwrap().unwrap()); + assert!(frame.is_ack()); + + // Send another SETTINGS frame with no MAX_CONCURRENT_STREAMS + // This should not update the max_concurrent_send_streams value that + // the client manages. 
+ srv.send_frame(frames::settings().max_concurrent_streams(2024)) + .await; + }; + + let h2 = async move { + let (_client, h2) = client::handshake(io).await.unwrap(); + let mut h2 = std::pin::pin!(h2); + assert_eq!(h2.max_concurrent_send_streams(), usize::MAX); + h2.as_mut().await.unwrap(); + // The most-recently advertised value should be used + assert_eq!(h2.max_concurrent_send_streams(), 2024); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn server_drop_connection_unexpectedly_return_unexpected_eof_err() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .await; + srv.close_without_notify(); + }; + + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); + tokio::spawn(async move { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let _res = client + .send_request(request, true) + .unwrap() + .0 + .await + .expect("request"); + }); + let err = h2.await.expect_err("should receive UnexpectedEof"); + assert_eq!( + err.get_io().expect("should be UnexpectedEof").kind(), + io::ErrorKind::UnexpectedEof, + ); + }; + join(srv, h2).await; +} + +const SETTINGS: &[u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; +const SETTINGS_ACK: &[u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; trait MockH2 { fn handshake(&mut self) -> &mut Self; diff --git a/tests/h2-tests/tests/codec_read.rs b/tests/h2-tests/tests/codec_read.rs index 6ebe54d6e..d955e186b 100644 --- a/tests/h2-tests/tests/codec_read.rs +++ b/tests/h2-tests/tests/codec_read.rs @@ -130,7 +130,7 @@ fn read_headers_empty_payload() {} #[tokio::test] async fn read_continuation_frames() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let large = build_large_headers(); @@ -190,8 
+190,9 @@ async fn read_continuation_frames() { #[tokio::test] async fn update_max_frame_len_at_rest() { use futures::StreamExt; + use tokio::io::AsyncReadExt; - let _ = env_logger::try_init(); + h2_support::trace_init!(); // TODO: add test for updating max frame length in flight as well? let mut codec = raw_codec! { read => [ @@ -211,6 +212,10 @@ async fn update_max_frame_len_at_rest() { codec.next().await.unwrap().unwrap_err().to_string(), "frame with invalid size" ); + + // drain codec buffer + let mut buf = Vec::new(); + codec.get_mut().read_to_end(&mut buf).await.unwrap(); } #[tokio::test] @@ -231,7 +236,7 @@ async fn read_goaway_with_debug_data() { let data = poll_frame!(GoAway, codec); assert_eq!(data.reason(), Reason::ENHANCE_YOUR_CALM); assert_eq!(data.last_stream_id(), 1); - assert_eq!(data.debug_data(), b"too_many_pings"); + assert_eq!(&**data.debug_data(), b"too_many_pings"); assert_closed!(codec); } diff --git a/tests/h2-tests/tests/codec_write.rs b/tests/h2-tests/tests/codec_write.rs index 2347f63b2..0b85a2238 100644 --- a/tests/h2-tests/tests/codec_write.rs +++ b/tests/h2-tests/tests/codec_write.rs @@ -5,7 +5,7 @@ use h2_support::prelude::*; async fn write_continuation_frames() { // An invalid dependency ID results in a stream level error. The hpack // payload should still be decoded. - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let large = build_large_headers(); @@ -56,7 +56,7 @@ async fn write_continuation_frames() { async fn client_settings_header_table_size() { // A server sets the SETTINGS_HEADER_TABLE_SIZE to 0, test that the // client doesn't send indexed headers. 
- let _ = env_logger::try_init(); + h2_support::trace_init!(); let io = mock_io::Builder::new() // Read SETTINGS_HEADER_TABLE_SIZE = 0 @@ -99,7 +99,7 @@ async fn client_settings_header_table_size() { async fn server_settings_header_table_size() { // A client sets the SETTINGS_HEADER_TABLE_SIZE to 0, test that the // server doesn't send indexed headers. - let _ = env_logger::try_init(); + h2_support::trace_init!(); let io = mock_io::Builder::new() .read(MAGIC_PREFACE) diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs index f03404130..dbb933286 100644 --- a/tests/h2-tests/tests/flow_control.rs +++ b/tests/h2-tests/tests/flow_control.rs @@ -7,7 +7,7 @@ use h2_support::util::yield_once; // explicitly requested. #[tokio::test] async fn send_data_without_requesting_capacity() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let payload = vec![0; 1024]; @@ -53,7 +53,7 @@ async fn send_data_without_requesting_capacity() { #[tokio::test] async fn release_capacity_sends_window_update() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let payload = vec![0u8; 16_384]; let payload_len = payload.len(); @@ -120,7 +120,7 @@ async fn release_capacity_sends_window_update() { #[tokio::test] async fn release_capacity_of_small_amount_does_not_send_window_update() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let payload = [0; 16]; @@ -175,7 +175,7 @@ fn expand_window_calls_are_coalesced() {} #[tokio::test] async fn recv_data_overflows_connection_window() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); @@ -217,7 +217,7 @@ async fn recv_data_overflows_connection_window() { let err = res.unwrap_err(); assert_eq!( err.to_string(), - "protocol error: flow-control protocol violated" + "connection error detected: flow-control protocol violated" ); }; @@ -227,7 +227,7 @@ async fn recv_data_overflows_connection_window() { let err = res.unwrap_err(); 
assert_eq!( err.to_string(), - "protocol error: flow-control protocol violated" + "connection error detected: flow-control protocol violated" ); }; join(conn, req).await; @@ -238,7 +238,7 @@ async fn recv_data_overflows_connection_window() { #[tokio::test] async fn recv_data_overflows_stream_window() { // this tests for when streams have smaller windows than their connection - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); @@ -278,7 +278,7 @@ async fn recv_data_overflows_stream_window() { let err = res.unwrap_err(); assert_eq!( err.to_string(), - "protocol error: flow-control protocol violated" + "stream error detected: flow-control protocol violated" ); }; @@ -295,7 +295,7 @@ fn recv_window_update_causes_overflow() { #[tokio::test] async fn stream_error_release_connection_capacity() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -350,7 +350,7 @@ async fn stream_error_release_connection_capacity() { should_recv_bytes -= bytes.len(); should_recv_frames -= 1; if should_recv_bytes == 0 { - assert_eq!(should_recv_bytes, 0); + assert_eq!(should_recv_frames, 0); } Ok(()) }) @@ -358,7 +358,7 @@ async fn stream_error_release_connection_capacity() { .expect_err("body"); assert_eq!( err.to_string(), - "protocol error: unspecific protocol error detected" + "stream error detected: unspecific protocol error detected" ); cap.release_capacity(to_release).expect("release_capacity"); }; @@ -371,7 +371,7 @@ async fn stream_error_release_connection_capacity() { #[tokio::test] async fn stream_close_by_data_frame_releases_capacity() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let window_size = frame::DEFAULT_INITIAL_WINDOW_SIZE as usize; @@ -443,7 +443,7 @@ async fn stream_close_by_data_frame_releases_capacity() { #[tokio::test] async fn stream_close_by_trailers_frame_releases_capacity() { - let _ = 
env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let window_size = frame::DEFAULT_INITIAL_WINDOW_SIZE as usize; @@ -516,7 +516,7 @@ async fn stream_close_by_trailers_frame_releases_capacity() { #[tokio::test] async fn stream_close_by_send_reset_frame_releases_capacity() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -575,7 +575,7 @@ fn stream_close_by_recv_reset_frame_releases_capacity() {} #[tokio::test] async fn recv_window_update_on_stream_closed_by_data_frame() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let h2 = async move { @@ -620,7 +620,7 @@ async fn recv_window_update_on_stream_closed_by_data_frame() { #[tokio::test] async fn reserved_capacity_assigned_in_multi_window_updates() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let h2 = async move { @@ -685,7 +685,7 @@ async fn reserved_capacity_assigned_in_multi_window_updates() { async fn connection_notified_on_released_capacity() { use tokio::sync::{mpsc, oneshot}; - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); // We're going to run the connection on a thread in order to isolate task @@ -794,7 +794,7 @@ async fn connection_notified_on_released_capacity() { #[tokio::test] async fn recv_settings_removes_available_capacity() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mut settings = frame::Settings::default(); @@ -841,7 +841,7 @@ async fn recv_settings_removes_available_capacity() { #[tokio::test] async fn recv_settings_keeps_assigned_capacity() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let (sent_settings, sent_settings_rx) = futures::channel::oneshot::channel(); @@ -886,7 +886,7 @@ async fn recv_settings_keeps_assigned_capacity() { 
#[tokio::test] async fn recv_no_init_window_then_receive_some_init_window() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mut settings = frame::Settings::default(); @@ -940,9 +940,8 @@ async fn recv_no_init_window_then_receive_some_init_window() { #[tokio::test] async fn settings_lowered_capacity_returns_capacity_to_connection() { use futures::channel::oneshot; - use futures::future::{select, Either}; - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let (tx1, rx1) = oneshot::channel(); let (tx2, rx2) = oneshot::channel(); @@ -972,10 +971,9 @@ async fn settings_lowered_capacity_returns_capacity_to_connection() { // // A timeout is used here to avoid blocking forever if there is a // failure - let result = select(rx2, tokio::time::delay_for(Duration::from_secs(5))).await; - if let Either::Right((_, _)) = result { - panic!("Timed out"); - } + let _ = tokio::time::timeout(Duration::from_secs(5), rx2) + .await + .unwrap(); idle_ms(500).await; @@ -1004,10 +1002,9 @@ async fn settings_lowered_capacity_returns_capacity_to_connection() { }); // Wait for server handshake to complete. 
- let result = select(rx1, tokio::time::delay_for(Duration::from_secs(5))).await; - if let Either::Right((_, _)) = result { - panic!("Timed out"); - } + let _ = tokio::time::timeout(Duration::from_secs(5), rx1) + .await + .unwrap(); let request = Request::post("https://example.com/one").body(()).unwrap(); @@ -1049,7 +1046,7 @@ async fn settings_lowered_capacity_returns_capacity_to_connection() { #[tokio::test] async fn client_increase_target_window_size() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -1069,7 +1066,7 @@ async fn client_increase_target_window_size() { #[tokio::test] async fn increase_target_window_size_after_using_some() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -1110,7 +1107,7 @@ async fn increase_target_window_size_after_using_some() { #[tokio::test] async fn decrease_target_window_size() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -1155,7 +1152,7 @@ async fn decrease_target_window_size() { #[tokio::test] async fn client_update_initial_window_size() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let window_size = frame::DEFAULT_INITIAL_WINDOW_SIZE * 2; @@ -1230,7 +1227,7 @@ async fn client_update_initial_window_size() { #[tokio::test] async fn client_decrease_initial_window_size() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -1355,7 +1352,7 @@ async fn client_decrease_initial_window_size() { #[tokio::test] async fn server_target_window_size() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -1377,7 +1374,7 @@ async fn server_target_window_size() { #[tokio::test] async fn 
recv_settings_increase_window_size_after_using_some() { // See https://github.com/hyperium/h2/issues/208 - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let new_win_size = 16_384 * 4; // 1 bigger than default @@ -1419,7 +1416,7 @@ async fn recv_settings_increase_window_size_after_using_some() { #[tokio::test] async fn reserve_capacity_after_peer_closes() { // See https://github.com/hyperium/h2/issues/300 - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -1456,7 +1453,7 @@ async fn reserve_capacity_after_peer_closes() { async fn reset_stream_waiting_for_capacity() { // This tests that receiving a reset on a stream that has some available // connection-level window reassigns that window to another stream. - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); @@ -1517,7 +1514,7 @@ async fn reset_stream_waiting_for_capacity() { #[tokio::test] async fn data_padding() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mut body = Vec::new(); @@ -1564,3 +1561,436 @@ async fn data_padding() { join(srv, h2).await; } + +#[tokio::test] +async fn poll_capacity_after_send_data_and_reserve() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().initial_window_size(5)) + .await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::data(1, &b"abcde"[..])).await; + srv.send_frame(frames::window_update(1, 5)).await; + srv.recv_frame(frames::data(1, &b""[..]).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + let request = 
Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + stream.send_data("abcde".into(), false).unwrap(); + + stream.reserve_capacity(5); + + // Initial window size was 5 so current capacity is 0 even if we just reserved. + assert_eq!(stream.capacity(), 0); + + // This will panic if there is a bug causing h2 to return Ok(0) from poll_capacity. + let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + + stream.send_data("".into(), true).unwrap(); + + // Wait for the connection to close + h2.await.unwrap(); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn poll_capacity_after_send_data_and_reserve_with_max_send_buffer_size() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().initial_window_size(10)) + .await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::data(1, &b"abcde"[..])).await; + srv.send_frame(frames::window_update(1, 10)).await; + srv.recv_frame(frames::data(1, &b""[..]).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + stream.send_data("abcde".into(), false).unwrap(); + + 
stream.reserve_capacity(5); + + // Initial window size was 10 but with a max send buffer size of 10 in the client, + // so current capacity is 0 even if we just reserved. + assert_eq!(stream.capacity(), 0); + + // This will panic if there is a bug causing h2 to return Ok(0) from poll_capacity. + let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + + stream.send_data("".into(), true).unwrap(); + + // Wait for the connection to close + h2.await.unwrap(); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn max_send_buffer_size_overflow() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame(frames::data(1, &[0; 10][..])).await; + srv.recv_frame(frames::data(1, &[][..]).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = conn.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!(stream.capacity(), 0); + stream.reserve_capacity(10); + assert_eq!( + stream.capacity(), + 5, + "polled capacity not over max buffer size" + ); + + stream.send_data([0; 10][..].into(), false).unwrap(); + + stream.reserve_capacity(15); + assert_eq!( + stream.capacity(), + 0, + "now with buffered over the max, don't overflow" + ); + stream.send_data([0; 0][..].into(), true).unwrap(); + + // Wait for the connection to close + conn.await.unwrap(); + }; + + join(srv, client).await; +} + 
+#[tokio::test] +async fn max_send_buffer_size_poll_capacity_wakes_task() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[][..]).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = conn.drive(response).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!(stream.capacity(), 0); + const TO_SEND: usize = 20; + stream.reserve_capacity(TO_SEND); + assert_eq!( + stream.capacity(), + 5, + "polled capacity not over max buffer size" + ); + + let t1 = tokio::spawn(async move { + let mut sent = 0; + let buf = [0; TO_SEND]; + loop { + match poll_fn(|cx| stream.poll_capacity(cx)).await { + None => panic!("no cap"), + Some(Err(e)) => panic!("cap error: {:?}", e), + Some(Ok(cap)) => { + stream + .send_data(buf[sent..(sent + cap)].to_vec().into(), false) + .unwrap(); + sent += cap; + if sent >= TO_SEND { + break; + } + } + } + } + stream.send_data(Bytes::new(), true).unwrap(); + }); + + // Wait for the connection to close + conn.await.unwrap(); + t1.await.unwrap(); + }; + + join(srv, client).await; +} + +#[tokio::test] +async fn poll_capacity_wakeup_after_window_update() { + 
h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().initial_window_size(10)) + .await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::data(1, &b"abcde"[..])).await; + srv.send_frame(frames::window_update(1, 5)).await; + srv.send_frame(frames::window_update(1, 5)).await; + srv.recv_frame(frames::data(1, &b"abcde"[..])).await; + srv.recv_frame(frames::data(1, &b""[..]).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + stream.send_data("abcde".into(), false).unwrap(); + + stream.reserve_capacity(10); + assert_eq!(stream.capacity(), 0); + + let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + h2.drive(idle_ms(10)).await; + stream.send_data("abcde".into(), false).unwrap(); + + stream.reserve_capacity(5); + assert_eq!(stream.capacity(), 0); + + // This will panic if there is a bug causing h2 to return Ok(0) from poll_capacity. 
+ let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + + stream.send_data("".into(), true).unwrap(); + + // Wait for the connection to close + h2.await.unwrap(); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn window_size_decremented_past_zero() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + // let _ = client.assert_server_handshake().await; + + // preface + client.write_preface().await; + + // the following http 2 bytes are fuzzer-generated + client.send_bytes(&[0, 0, 0, 4, 0, 0, 0, 0, 0]).await; + client + .send_bytes(&[ + 0, 0, 23, 1, 1, 0, 249, 255, 191, 131, 1, 1, 1, 70, 1, 1, 1, 1, 65, 1, 1, 65, 1, 1, + 65, 1, 1, 1, 1, 1, 1, 190, + ]) + .await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client + .send_bytes(&[ + 0, 0, 9, 247, 0, 121, 255, 255, 184, 1, 65, 1, 1, 1, 1, 1, 1, 190, + ]) + .await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client + .send_bytes(&[0, 0, 3, 0, 1, 0, 249, 255, 191, 1, 1, 190]) + .await; + client + .send_bytes(&[0, 0, 2, 50, 107, 0, 0, 0, 1, 0, 0]) + .await; + client + .send_bytes(&[0, 0, 5, 2, 0, 0, 0, 0, 1, 128, 0, 55, 0, 0]) + .await; + client + .send_bytes(&[ + 0, 0, 12, 4, 0, 0, 0, 0, 0, 126, 4, 39, 184, 171, 125, 33, 0, 3, 107, 50, 98, + ]) + .await; + client + .send_bytes(&[0, 0, 6, 4, 0, 0, 0, 0, 0, 3, 4, 76, 255, 71, 131]) + .await; + client + .send_bytes(&[ + 0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 4, 39, 184, 171, 74, 33, 0, 3, 107, 50, 98, + ]) + .await; + client + .send_bytes(&[ + 0, 0, 30, 4, 0, 0, 0, 0, 0, 0, 4, 56, 184, 
171, 125, 65, 0, 35, 65, 65, 65, 61, + 232, 87, 115, 89, 116, 0, 4, 0, 58, 33, 125, 33, 79, 3, 107, 49, 98, + ]) + .await; + client + .send_bytes(&[ + 0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 4, 39, 184, 171, 125, 33, 0, 3, 107, 50, 98, + ]) + .await; + client.send_bytes(&[0, 0, 0, 4, 0, 0, 0, 0, 0]).await; + client + .send_bytes(&[ + 0, 0, 12, 4, 0, 0, 0, 0, 0, 126, 4, 39, 184, 171, 125, 33, 0, 3, 107, 50, 98, + ]) + .await; + client + .send_bytes(&[ + 0, 0, 177, 1, 44, 0, 0, 0, 1, 67, 67, 67, 67, 67, 67, 131, 134, 5, 61, 67, 67, 67, + 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, + 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, + 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 115, 102, 1, 3, 48, 43, + 101, 64, 31, 37, 99, 99, 97, 97, 97, 97, 49, 97, 54, 97, 97, 97, 97, 49, 97, 54, + 97, 99, 54, 53, 53, 51, 53, 99, 99, 97, 97, 99, 97, 97, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, + ]) + .await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client + .send_bytes(&[ + 0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 4, 0, 58, 171, 125, 33, 79, 3, 107, 49, 98, + ]) + .await; + client + .send_bytes(&[0, 0, 6, 4, 0, 0, 0, 0, 0, 0, 4, 87, 115, 89, 116]) + .await; + client + .send_bytes(&[ + 0, 0, 12, 4, 0, 0, 0, 0, 0, 126, 4, 39, 184, 171, 125, 33, 0, 3, 107, 50, 98, + ]) + .await; + client + .send_bytes(&[ + 0, 0, 129, 1, 44, 0, 0, 0, 1, 67, 67, 67, 67, 67, 67, 131, 134, 5, 18, 67, 67, 61, + 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 48, 54, 53, 55, 114, 1, 4, 97, 49, 51, 116, + 64, 2, 117, 115, 4, 103, 101, 110, 116, 64, 8, 57, 111, 110, 116, 101, 110, 115, + 102, 7, 43, 43, 49, 48, 48, 43, 101, 192, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]) + .await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await; + client + .send_bytes(&[ + 0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 4, 0, 58, 171, 125, 33, 79, 3, 107, 49, 98, + ]) + .await; + + // TODO: is CANCEL the right error code to expect here? + // client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let builder = server::Builder::new(); + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + // just keep it open + let res = poll_fn(move |cx| srv.poll_closed(cx)).await; + tracing::debug!("{:?}", res); + }; + + join(client, srv).await; +} diff --git a/tests/h2-tests/tests/hammer.rs b/tests/h2-tests/tests/hammer.rs index cf7051814..4b5d04341 100644 --- a/tests/h2-tests/tests/hammer.rs +++ b/tests/h2-tests/tests/hammer.rs @@ -8,7 +8,6 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, - thread, }; use tokio::net::{TcpListener, TcpStream}; @@ -26,8 +25,8 @@ impl Server { { let mk_data = Arc::new(mk_data); - let mut rt = tokio::runtime::Runtime::new().unwrap(); - let mut listener = rt + let rt = tokio::runtime::Runtime::new().unwrap(); + let listener = rt .block_on(TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))) .unwrap(); let addr = listener.local_addr().unwrap(); @@ -58,7 +57,7 @@ impl Server { } fn addr(&self) -> SocketAddr { - self.addr.clone() + self.addr } fn request_count(&self) -> usize { @@ -140,7 +139,7 @@ fn hammer_client_concurrency() { }) }); - let mut rt = tokio::runtime::Runtime::new().unwrap(); + let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(tcp); println!("...done"); } diff --git a/tests/h2-tests/tests/ping_pong.rs b/tests/h2-tests/tests/ping_pong.rs index 
f093b43f6..0f93578cc 100644 --- a/tests/h2-tests/tests/ping_pong.rs +++ b/tests/h2-tests/tests/ping_pong.rs @@ -6,14 +6,13 @@ use h2_support::prelude::*; #[tokio::test] async fn recv_single_ping() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (m, mut mock) = mock::new(); // Create the handshake let h2 = async move { - let (client, conn) = client::handshake(m).await.unwrap(); - let c = conn.await.unwrap(); - (client, c) + let (_client, conn) = client::handshake(m).await.unwrap(); + let _: () = conn.await.unwrap(); }; let mock = async move { @@ -36,7 +35,7 @@ async fn recv_single_ping() { #[tokio::test] async fn recv_multiple_pings() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -58,7 +57,7 @@ async fn recv_multiple_pings() { #[tokio::test] async fn pong_has_highest_priority() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let data = Bytes::from(vec![0; 16_384]); @@ -96,7 +95,7 @@ async fn pong_has_highest_priority() { #[tokio::test] async fn user_ping_pong() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -138,7 +137,7 @@ async fn user_ping_pong() { #[tokio::test] async fn user_notifies_when_connection_closes() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { let settings = srv.assert_client_handshake().await; @@ -146,6 +145,7 @@ async fn user_notifies_when_connection_closes() { srv }; + #[allow(clippy::async_yields_async)] let client = async move { let (_client, mut conn) = client::handshake(io).await.expect("client handshake"); // yield once so we can ack server settings diff --git a/tests/h2-tests/tests/prioritization.rs b/tests/h2-tests/tests/prioritization.rs index 18084d91d..11d2c2ccf 100644 --- a/tests/h2-tests/tests/prioritization.rs +++ 
b/tests/h2-tests/tests/prioritization.rs @@ -1,12 +1,13 @@ -use futures::future::join; -use futures::{FutureExt, StreamExt}; +use futures::future::{join, select}; +use futures::{pin_mut, FutureExt, StreamExt}; + use h2_support::prelude::*; use h2_support::DEFAULT_WINDOW_SIZE; use std::task::Context; #[tokio::test] async fn single_stream_send_large_body() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let payload = vec![0; 1024]; @@ -66,7 +67,7 @@ async fn single_stream_send_large_body() { #[tokio::test] async fn multiple_streams_with_payload_greater_than_default_window() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let payload = vec![0; 16384 * 5 - 1]; let payload_clone = payload.clone(); @@ -129,7 +130,7 @@ async fn multiple_streams_with_payload_greater_than_default_window() { #[tokio::test] async fn single_stream_send_extra_large_body_multi_frames_one_buffer() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let payload = vec![0; 32_768]; @@ -193,7 +194,7 @@ async fn single_stream_send_extra_large_body_multi_frames_one_buffer() { #[tokio::test] async fn single_stream_send_body_greater_than_default_window() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let payload = vec![0; 16384 * 5 - 1]; @@ -279,7 +280,7 @@ async fn single_stream_send_body_greater_than_default_window() { #[tokio::test] async fn single_stream_send_extra_large_body_multi_frames_multi_buffer() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let payload = vec![0; 32_768]; @@ -341,7 +342,7 @@ async fn single_stream_send_extra_large_body_multi_frames_multi_buffer() { #[tokio::test] async fn send_data_receive_window_update() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (m, mut mock) = mock::new(); let h2 = async move { @@ -408,3 +409,95 @@ async fn send_data_receive_window_update() { join(mock, h2).await; } + +#[tokio::test] +async fn 
stream_count_over_max_stream_limit_does_not_starve_capacity() { + use tokio::sync::oneshot; + + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let (tx, rx) = oneshot::channel(); + + let srv = async move { + let _ = srv + .assert_client_handshake_with_settings( + frames::settings() + // super tiny server + .max_concurrent_streams(1), + ) + .await; + srv.recv_frame(frames::headers(1).request("POST", "http://example.com/")) + .await; + + srv.recv_frame(frames::data(1, vec![0; 16384])).await; + srv.recv_frame(frames::data(1, vec![0; 16384])).await; + srv.recv_frame(frames::data(1, vec![0; 16384])).await; + srv.recv_frame(frames::data(1, vec![0; 16383]).eos()).await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + + // All of these connection capacities should be assigned to stream 3 + srv.send_frame(frames::window_update(0, 16384)).await; + srv.send_frame(frames::window_update(0, 16384)).await; + srv.send_frame(frames::window_update(0, 16384)).await; + srv.send_frame(frames::window_update(0, 16383)).await; + + // StreamId(3) should be able to send all of its request with the conn capacity + srv.recv_frame(frames::headers(3).request("POST", "http://example.com/")) + .await; + srv.recv_frame(frames::data(3, vec![0; 16384])).await; + srv.recv_frame(frames::data(3, vec![0; 16384])).await; + srv.recv_frame(frames::data(3, vec![0; 16384])).await; + srv.recv_frame(frames::data(3, vec![0; 16383]).eos()).await; + srv.send_frame(frames::headers(3).response(200).eos()).await; + + // Then all the future stream is guaranteed to be send-able by induction + tx.send(()).unwrap(); + }; + + fn request() -> Request<()> { + Request::builder() + .method(Method::POST) + .uri("http://example.com/") + .body(()) + .unwrap() + } + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let (req1, mut send1) = client.send_request(request(), false).unwrap(); + let 
(req2, mut send2) = client.send_request(request(), false).unwrap(); + + // Use up the connection window. + send1.send_data(vec![0; 65535].into(), true).unwrap(); + // Queue up for more connection window. + send2.send_data(vec![0; 65535].into(), true).unwrap(); + + // Queue up more pending open streams + for _ in 0..5 { + let (_, mut send) = client.send_request(request(), false).unwrap(); + send.send_data(vec![0; 65535].into(), true).unwrap(); + } + + let response = conn.drive(req1).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let response = conn.drive(req2).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let _ = rx.await; + }; + + let task = join(srv, client); + pin_mut!(task); + + let t = tokio::time::sleep(Duration::from_secs(5)).map(|_| panic!("time out")); + pin_mut!(t); + + select(task, t).await; +} diff --git a/tests/h2-tests/tests/push_promise.rs b/tests/h2-tests/tests/push_promise.rs index f786a72b7..94c1154ef 100644 --- a/tests/h2-tests/tests/push_promise.rs +++ b/tests/h2-tests/tests/push_promise.rs @@ -4,7 +4,7 @@ use h2_support::prelude::*; #[tokio::test] async fn recv_push_works() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mock = async move { @@ -62,7 +62,7 @@ async fn recv_push_works() { #[tokio::test] async fn pushed_streams_arent_dropped_too_early() { // tests that by default, received push promises work - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mock = async move { @@ -128,7 +128,7 @@ async fn pushed_streams_arent_dropped_too_early() { #[tokio::test] async fn recv_push_when_push_disabled_is_conn_error() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mock = async move { @@ -164,7 +164,7 @@ async fn recv_push_when_push_disabled_is_conn_error() { let err = res.unwrap_err(); assert_eq!( err.to_string(), - "protocol error: unspecific 
protocol error detected" + "connection error detected: unspecific protocol error detected" ); }; @@ -174,7 +174,7 @@ async fn recv_push_when_push_disabled_is_conn_error() { let err = res.unwrap_err(); assert_eq!( err.to_string(), - "protocol error: unspecific protocol error detected" + "connection error detected: unspecific protocol error detected" ); }; @@ -186,7 +186,7 @@ async fn recv_push_when_push_disabled_is_conn_error() { #[tokio::test] async fn pending_push_promises_reset_when_dropped() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -223,7 +223,7 @@ async fn pending_push_promises_reset_when_dropped() { assert_eq!(resp.status(), StatusCode::OK); }; - let _ = conn.drive(req).await; + conn.drive(req).await; conn.await.expect("client"); drop(client); }; @@ -233,7 +233,7 @@ async fn pending_push_promises_reset_when_dropped() { #[tokio::test] async fn recv_push_promise_over_max_header_list_size() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -284,7 +284,7 @@ async fn recv_push_promise_over_max_header_list_size() { #[tokio::test] async fn recv_invalid_push_promise_headers_is_stream_protocol_error() { // Unsafe method or content length is stream protocol error - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mock = async move { @@ -348,7 +348,7 @@ fn recv_push_promise_with_wrong_authority_is_stream_error() { #[tokio::test] async fn recv_push_promise_skipped_stream_id() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mock = async move { @@ -380,8 +380,16 @@ async fn recv_push_promise_skipped_stream_id() { .unwrap(); let req = async move { - let res = client.send_request(request, true).unwrap().0.await; - assert!(res.is_err()); + let err = client + .send_request(request, true) + .unwrap() + .0 + .await + 
.unwrap_err(); + assert_eq!( + err.to_string(), + "connection error detected: unspecific protocol error detected" + ); }; // client should see a protocol error @@ -390,7 +398,7 @@ async fn recv_push_promise_skipped_stream_id() { let err = res.unwrap_err(); assert_eq!( err.to_string(), - "protocol error: unspecific protocol error detected" + "connection error detected: unspecific protocol error detected" ); }; @@ -402,7 +410,7 @@ async fn recv_push_promise_skipped_stream_id() { #[tokio::test] async fn recv_push_promise_dup_stream_id() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let mock = async move { @@ -435,7 +443,11 @@ async fn recv_push_promise_dup_stream_id() { let req = async move { let res = client.send_request(request, true).unwrap().0.await; - assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "connection error detected: unspecific protocol error detected" + ); }; // client should see a protocol error @@ -444,7 +456,7 @@ async fn recv_push_promise_dup_stream_id() { let err = res.unwrap_err(); assert_eq!( err.to_string(), - "protocol error: unspecific protocol error detected" + "connection error detected: unspecific protocol error detected" ); }; diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index 1916138b3..a4b983a0a 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -1,16 +1,16 @@ #![deny(warnings)] -use futures::future::{join, poll_fn}; +use futures::future::join; use futures::StreamExt; use h2_support::prelude::*; use tokio::io::AsyncWriteExt; -const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; -const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; +const SETTINGS: &[u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; +const SETTINGS_ACK: &[u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; #[tokio::test] async fn read_preface_in_multiple_frames() { - let _ = env_logger::try_init(); + 
h2_support::trace_init!(); let mock = mock_io::Builder::new() .read(b"PRI * HTTP/2.0") @@ -28,7 +28,7 @@ async fn read_preface_in_multiple_frames() { #[tokio::test] async fn server_builder_set_max_concurrent_streams() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let mut settings = frame::Settings::default(); @@ -72,7 +72,7 @@ async fn server_builder_set_max_concurrent_streams() { #[tokio::test] async fn serve_request() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -107,14 +107,14 @@ async fn serve_request() { #[tokio::test] async fn serve_connect() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { let settings = client.assert_server_handshake().await; assert_default_settings!(settings); client - .send_frame(frames::headers(1).method("CONNECT").eos()) + .send_frame(frames::headers(1).request("CONNECT", "localhost").eos()) .await; client .recv_frame(frames::headers(1).response(200).eos()) @@ -138,7 +138,7 @@ async fn serve_connect() { #[tokio::test] async fn push_request() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -220,9 +220,56 @@ async fn push_request() { join(client, srv).await; } +#[tokio::test] +async fn push_request_disabled() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + client + .assert_server_handshake_with_settings(frames::settings().disable_push()) + .await; + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + 
assert_eq!(req.method(), &http::Method::GET); + + // attempt to push - expect failure + let req = http::Request::builder() + .method("GET") + .uri("https://http2.akamai.com/style.css") + .body(()) + .unwrap(); + stream + .push_request(req) + .expect_err("push_request should error"); + + // send normal response + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); + + assert!(srv.next().await.is_none()); + }; + + join(client, srv).await; +} + #[tokio::test] async fn push_request_against_concurrency() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -249,10 +296,10 @@ async fn push_request_against_concurrency() { .await; client.recv_frame(frames::data(2, &b""[..]).eos()).await; client - .recv_frame(frames::headers(1).response(200).eos()) + .recv_frame(frames::headers(4).response(200).eos()) .await; client - .recv_frame(frames::headers(4).response(200).eos()) + .recv_frame(frames::headers(1).response(200).eos()) .await; }; @@ -306,7 +353,7 @@ async fn push_request_against_concurrency() { #[tokio::test] async fn push_request_with_data() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -372,7 +419,7 @@ async fn push_request_with_data() { #[tokio::test] async fn push_request_between_data() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -443,7 +490,7 @@ fn accept_with_pending_connections_after_socket_close() {} #[tokio::test] async fn recv_invalid_authority() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let bad_auth = util::byte_str("not:a/good authority"); @@ -470,7 +517,7 @@ async fn recv_invalid_authority() { #[tokio::test] async fn recv_connection_header() { - let _ = env_logger::try_init(); + 
h2_support::trace_init!(); let (io, mut client) = mock::new(); let req = |id, name, val| { @@ -489,7 +536,7 @@ async fn recv_connection_header() { client .send_frame(req(7, "transfer-encoding", "chunked")) .await; - client.send_frame(req(9, "upgrade", "HTTP/2.0")).await; + client.send_frame(req(9, "upgrade", "HTTP/2")).await; client.recv_frame(frames::reset(1).protocol_error()).await; client.recv_frame(frames::reset(3).protocol_error()).await; client.recv_frame(frames::reset(5).protocol_error()).await; @@ -506,8 +553,8 @@ async fn recv_connection_header() { } #[tokio::test] -async fn sends_reset_cancel_when_req_body_is_dropped() { - let _ = env_logger::try_init(); +async fn sends_reset_no_error_when_req_body_is_dropped() { + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -516,10 +563,15 @@ async fn sends_reset_cancel_when_req_body_is_dropped() { client .send_frame(frames::headers(1).request("POST", "https://example.com/")) .await; + // server responded with data before consuming POST-request's body, resulting in `RST_STREAM(NO_ERROR)`. 
+ client.recv_frame(frames::headers(1).response(200)).await; + client.recv_frame(frames::data(1, vec![0; 16384])).await; client - .recv_frame(frames::headers(1).response(200).eos()) + .recv_frame(frames::data(1, vec![0; 16384]).eos()) + .await; + client + .recv_frame(frames::reset(1).reason(Reason::NO_ERROR)) .await; - client.recv_frame(frames::reset(1).cancel()).await; }; let srv = async move { @@ -529,7 +581,8 @@ async fn sends_reset_cancel_when_req_body_is_dropped() { assert_eq!(req.method(), &http::Method::POST); let rsp = http::Response::builder().status(200).body(()).unwrap(); - stream.send_response(rsp, true).unwrap(); + let mut tx = stream.send_response(rsp, false).unwrap(); + tx.send_data(vec![0; 16384 * 2].into(), true).unwrap(); } assert!(srv.next().await.is_none()); }; @@ -539,7 +592,7 @@ async fn sends_reset_cancel_when_req_body_is_dropped() { #[tokio::test] async fn abrupt_shutdown() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -583,7 +636,7 @@ async fn abrupt_shutdown() { #[tokio::test] async fn graceful_shutdown() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -658,7 +711,7 @@ async fn graceful_shutdown() { #[tokio::test] async fn goaway_even_if_client_sent_goaway() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -707,7 +760,7 @@ async fn goaway_even_if_client_sent_goaway() { #[tokio::test] async fn sends_reset_cancel_when_res_body_is_dropped() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -761,7 +814,7 @@ async fn sends_reset_cancel_when_res_body_is_dropped() { #[tokio::test] async fn too_big_headers_sends_431() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let 
client = async move { @@ -797,7 +850,7 @@ async fn too_big_headers_sends_431() { #[tokio::test] async fn too_big_headers_sends_reset_after_431_if_not_eos() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -830,9 +883,92 @@ async fn too_big_headers_sends_reset_after_431_if_not_eos() { join(client, srv).await; } +#[tokio::test] +async fn too_many_continuation_frames_sends_goaway() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_frame_eq(settings, frames::settings().max_header_list_size(1024 * 32)); + + // the mock impl automatically splits into CONTINUATION frames if the + // headers are too big for one frame. So without a max header list size + // set, we'll send a bunch of headers that will eventually get nuked. + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .field("a".repeat(10_000), "b".repeat(10_000)) + .field("c".repeat(10_000), "d".repeat(10_000)) + .field("e".repeat(10_000), "f".repeat(10_000)) + .field("g".repeat(10_000), "h".repeat(10_000)) + .field("i".repeat(10_000), "j".repeat(10_000)) + .field("k".repeat(10_000), "l".repeat(10_000)) + .field("m".repeat(10_000), "n".repeat(10_000)) + .field("o".repeat(10_000), "p".repeat(10_000)) + .field("y".repeat(10_000), "z".repeat(10_000)), + ) + .await; + client + .recv_frame(frames::go_away(0).calm().data("too_many_continuations")) + .await; + }; + + let srv = async move { + let mut srv = server::Builder::new() + // should mean ~3 continuation + .max_header_list_size(1024 * 32) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let err = srv.next().await.unwrap().expect_err("server"); + assert!(err.is_go_away()); + assert!(err.is_library()); + assert_eq!(err.reason(), Some(Reason::ENHANCE_YOUR_CALM)); + }; + + join(client, srv).await; +} + +#[tokio::test] +async 
fn pending_accept_recv_illegal_content_length_data() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("POST", "https://a.b") + .field("content-length", "1"), + ) + .await; + client + .send_frame(frames::data(1, &b"hello"[..]).eos()) + .await; + client.recv_frame(frames::reset(1).protocol_error()).await; + idle_ms(10).await; + }; + + let srv = async move { + let mut srv = server::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let _req = srv.next().await.expect("req").expect("is_ok"); + }; + + join(client, srv).await; +} + #[tokio::test] async fn poll_reset() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -872,7 +1008,7 @@ async fn poll_reset() { #[tokio::test] async fn poll_reset_io_error() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -913,7 +1049,7 @@ async fn poll_reset_io_error() { #[tokio::test] async fn poll_reset_after_send_response_is_user_error() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -967,7 +1103,7 @@ async fn poll_reset_after_send_response_is_user_error() { #[tokio::test] async fn server_error_on_unclean_shutdown() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let srv = server::Builder::new().handshake::<_, Bytes>(io); @@ -978,9 +1114,33 @@ async fn server_error_on_unclean_shutdown() { srv.await.expect_err("should error"); } +#[tokio::test] +async fn server_error_on_status_in_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = 
client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame(frames::headers(1).status(StatusCode::OK)) + .await; + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + }; + + join(client, srv).await; +} + #[tokio::test] async fn request_without_authority() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let client = async move { @@ -1012,3 +1172,378 @@ async fn request_without_authority() { join(client, srv).await; } + +#[tokio::test] +async fn serve_when_request_in_response_extensions() { + use std::sync::Arc; + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + let mut rsp = http::Response::new(()); + rsp.extensions_mut().insert(Arc::new(req)); + stream.send_response(rsp, true).unwrap(); + + assert!(srv.next().await.is_none()); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn send_reset_explicitly() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client + .recv_frame(frames::reset(1).reason(Reason::ENHANCE_YOUR_CALM)) + .await; + }; + + let srv = async move { + let mut srv = 
server::handshake(io).await.expect("handshake"); + let (_req, mut stream) = srv.next().await.unwrap().unwrap(); + + stream.send_reset(Reason::ENHANCE_YOUR_CALM); + + assert!(srv.next().await.is_none()); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn extended_connect_protocol_disabled_by_default() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), None); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn extended_connect_protocol_enabled_during_handshake() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::headers(1).response(200)).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + assert_eq!( + req.extensions().get::(), + Some(&crate::ext::Protocol::from_static("the-bread-protocol")) + ); + + let rsp = Response::new(()); + stream.send_response(rsp, false).unwrap(); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + 
.expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_pseudo_protocol_on_non_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("GET", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_authority_target_on_extended_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "bread:80") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_non_authority_target_on_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = 
client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette")) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_informational_status_header_in_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let _ = client.assert_server_handshake().await; + + let status_code = 128; + assert!(StatusCode::from_u16(status_code) + .unwrap() + .is_informational()); + + client + .send_frame(frames::headers(1).response(status_code)) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let builder = server::Builder::new(); + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn client_drop_connection_without_close_notify() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + let client = async move { + let _recv_settings = client.assert_server_handshake().await; + client + .send_frame(frames::headers(1).request("GET", "https://example.com/")) + .await; + client.send_frame(frames::data(1, &b"hello"[..])).await; + client.recv_frame(frames::headers(1).response(200)).await; + + client.close_without_notify(); // Client closed without notify causing UnexpectedEof + }; + + let mut builder = server::Builder::new(); + builder.max_concurrent_streams(1); + + let h2 = 
async move { + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + assert_eq!(req.method(), &http::Method::GET); + + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, false).unwrap(); + + // Step the conn state forward and hitting the EOF + // But we have no outstanding request from client to be satisfied, so we should not return + // an error + let _ = poll_fn(|cx| srv.poll_closed(cx)).await.unwrap(); + }; + + join(client, h2).await; +} + +#[tokio::test] +async fn init_window_size_smaller_than_default_should_use_default_before_ack() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + let client = async move { + // Client can send in some data before ACK; + // Server needs to make sure the Recv stream has default window size + // as per https://datatracker.ietf.org/doc/html/rfc9113#name-initial-flow-control-window + client.write_preface().await; + client + .send(frame::Settings::default().into()) + .await + .unwrap(); + client.next().await.expect("unexpected EOF").unwrap(); + client + .send_frame(frames::headers(1).request("GET", "https://example.com/")) + .await; + client.send_frame(frames::data(1, &b"hello"[..])).await; + client.send(frame::Settings::ack().into()).await.unwrap(); + client.next().await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; + + let mut builder = server::Builder::new(); + builder.max_concurrent_streams(1); + builder.initial_window_size(1); + let h2 = async move { + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + assert_eq!(req.method(), &http::Method::GET); + + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); + + // Drive the state forward + let _ = poll_fn(|cx| srv.poll_closed(cx)).await.unwrap(); + }; + 
+ join(client, h2).await; +} diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index dd0316ca0..05a96a0f5 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -9,7 +9,7 @@ use tokio::sync::oneshot; #[tokio::test] async fn send_recv_headers_only() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let mock = mock_io::Builder::new() .handshake() @@ -31,7 +31,7 @@ async fn send_recv_headers_only() { .body(()) .unwrap(); - log::info!("sending request"); + tracing::info!("sending request"); let (response, _) = client.send_request(request, true).unwrap(); let resp = h2.run(response).await.unwrap(); @@ -42,7 +42,7 @@ async fn send_recv_headers_only() { #[tokio::test] async fn send_recv_data() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let mock = mock_io::Builder::new() .handshake() @@ -72,7 +72,7 @@ async fn send_recv_data() { .body(()) .unwrap(); - log::info!("sending request"); + tracing::info!("sending request"); let (response, mut stream) = client.send_request(request, false).unwrap(); // Reserve send capacity @@ -104,7 +104,7 @@ async fn send_recv_data() { #[tokio::test] async fn send_headers_recv_data_single_frame() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let mock = mock_io::Builder::new() .handshake() @@ -129,7 +129,7 @@ async fn send_headers_recv_data_single_frame() { .body(()) .unwrap(); - log::info!("sending request"); + tracing::info!("sending request"); let (response, _) = client.send_request(request, true).unwrap(); let resp = h2.run(response).await.unwrap(); @@ -153,7 +153,7 @@ async fn send_headers_recv_data_single_frame() { #[tokio::test] async fn closed_streams_are_released() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let h2 = async move { @@ -194,9 +194,143 @@ async fn closed_streams_are_released() { join(srv, h2).await; } +#[tokio::test] +async fn 
reset_streams_dont_grow_memory_continuously() { + //h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + const N: u32 = 50; + const MAX: usize = 20; + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + for n in (1..(N * 2)).step_by(2) { + client + .send_frame(frames::headers(n).request("GET", "https://a.b/").eos()) + .await; + client.send_frame(frames::reset(n).protocol_error()).await; + } + + tokio::time::timeout( + std::time::Duration::from_secs(1), + client.recv_frame( + frames::go_away((MAX * 2 + 1) as u32) + .data("too_many_resets") + .calm(), + ), + ) + .await + .expect("client goaway"); + }; + + let srv = async move { + let mut srv = server::Builder::new() + .max_pending_accept_reset_streams(MAX) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + poll_fn(|cx| srv.poll_closed(cx)) + .await + .expect_err("server should error"); + // specifically, not 50; + assert_eq!(21, srv.num_wired_streams()); + }; + join(srv, client).await; +} + +#[tokio::test] +async fn go_away_with_pending_accepting() { + // h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let (sent_go_away_tx, sent_go_away_rx) = oneshot::channel(); + let (recv_go_away_tx, recv_go_away_rx) = oneshot::channel(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + + client + .send_frame(frames::headers(1).request("GET", "https://baguette/").eos()) + .await; + + client + .send_frame(frames::headers(3).request("GET", "https://campagne/").eos()) + .await; + client.send_frame(frames::go_away(1).protocol_error()).await; + + sent_go_away_tx.send(()).unwrap(); + + recv_go_away_rx.await.unwrap(); + }; + + let srv = async move { + let mut srv = server::Builder::new() + .max_pending_accept_reset_streams(1) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let (_req_1, _send_response_1) = 
srv.accept().await.unwrap().unwrap(); + + poll_fn(|cx| srv.poll_closed(cx)) + .drive(sent_go_away_rx) + .await + .unwrap(); + + let (_req_2, _send_response_2) = srv.accept().await.unwrap().unwrap(); + + recv_go_away_tx.send(()).unwrap(); + }; + join(srv, client).await; +} + +#[tokio::test] +async fn pending_accept_reset_streams_decrement_too() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + // If it didn't decrement internally, this would eventually get + // the count over MAX. + const M: usize = 2; + const N: usize = 5; + const MAX: usize = 6; + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + let mut id = 1; + for _ in 0..M { + for _ in 0..N { + client + .send_frame(frames::headers(id).request("GET", "https://a.b/").eos()) + .await; + client.send_frame(frames::reset(id).protocol_error()).await; + id += 2; + } + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + }; + + let srv = async move { + let mut srv = server::Builder::new() + .max_pending_accept_reset_streams(MAX) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + while let Some(Ok(_)) = srv.accept().await {} + + poll_fn(|cx| srv.poll_closed(cx)).await.expect("server"); + }; + join(srv, client).await; +} + #[tokio::test] async fn errors_if_recv_frame_exceeds_max_frame_size() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let h2 = async move { @@ -207,13 +341,19 @@ async fn errors_if_recv_frame_exceeds_max_frame_size() { let body = resp.into_parts().1; let res = util::concat(body).await; let err = res.unwrap_err(); - assert_eq!(err.to_string(), "protocol error: frame with invalid size"); + assert_eq!( + err.to_string(), + "connection error detected: frame with invalid size" + ); }; // client should see a conn error let conn = async move { let err = h2.await.unwrap_err(); - assert_eq!(err.to_string(), "protocol error: frame 
with invalid size"); + assert_eq!( + err.to_string(), + "connection error detected: frame with invalid size" + ); }; join(conn, req).await; }; @@ -239,7 +379,7 @@ async fn errors_if_recv_frame_exceeds_max_frame_size() { #[tokio::test] async fn configure_max_frame_size() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let h2 = async move { @@ -278,7 +418,7 @@ async fn configure_max_frame_size() { #[tokio::test] async fn recv_goaway_finishes_processed_streams() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -321,7 +461,10 @@ async fn recv_goaway_finishes_processed_streams() { // this request will trigger a goaway let req2 = async move { let err = client.get("https://example.com/").await.unwrap_err(); - assert_eq!(err.to_string(), "protocol error: not a result of an error"); + assert_eq!( + err.to_string(), + "connection error received: not a result of an error" + ); }; join3(async move { h2.await.expect("client") }, req1, req2).await; @@ -332,7 +475,7 @@ async fn recv_goaway_finishes_processed_streams() { #[tokio::test] async fn recv_goaway_with_higher_last_processed_id() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -366,7 +509,7 @@ async fn recv_goaway_with_higher_last_processed_id() { #[tokio::test] async fn recv_next_stream_id_updated_by_malformed_headers() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut client) = mock::new(); let bad_auth = util::byte_str("not:a/good authority"); @@ -404,7 +547,7 @@ async fn recv_next_stream_id_updated_by_malformed_headers() { #[tokio::test] async fn skipped_stream_ids_are_implicitly_closed() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -445,7 +588,7 @@ async fn skipped_stream_ids_are_implicitly_closed() { #[tokio::test] 
async fn send_rst_stream_allows_recv_data() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -460,7 +603,7 @@ async fn send_rst_stream_allows_recv_data() { srv.send_frame(frames::headers(1).response(200)).await; srv.recv_frame(frames::reset(1).cancel()).await; // sending frames after canceled! - // note: sending 2 to cosume 50% of connection window + // note: sending 2 to consume 50% of connection window srv.send_frame(frames::data(1, vec![0; 16_384])).await; srv.send_frame(frames::data(1, vec![0; 16_384]).eos()).await; // make sure we automatically free the connection window @@ -490,7 +633,7 @@ async fn send_rst_stream_allows_recv_data() { #[tokio::test] async fn send_rst_stream_allows_recv_trailers() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -531,7 +674,7 @@ async fn send_rst_stream_allows_recv_trailers() { #[tokio::test] async fn rst_stream_expires() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -582,7 +725,7 @@ async fn rst_stream_expires() { #[tokio::test] async fn rst_stream_max() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -607,14 +750,14 @@ async fn rst_stream_max() { srv.recv_frame(frames::reset(1).cancel()).await; srv.recv_frame(frames::reset(3).cancel()).await; // sending frame after canceled! 
- // newer streams trump older streams - // 3 is still being ignored - srv.send_frame(frames::data(3, vec![0; 16]).eos()).await; + // olders streams trump newer streams + // 1 is still being ignored + srv.send_frame(frames::data(1, vec![0; 16]).eos()).await; // ping pong to be sure of no goaway srv.ping_pong([1; 8]).await; - // 1 has been evicted, will get a reset - srv.send_frame(frames::data(1, vec![0; 16]).eos()).await; - srv.recv_frame(frames::reset(1).stream_closed()).await; + // 3 has been evicted, will get a reset + srv.send_frame(frames::data(3, vec![0; 16]).eos()).await; + srv.recv_frame(frames::reset(3).stream_closed()).await; }; let client = async move { @@ -653,7 +796,7 @@ async fn rst_stream_max() { #[tokio::test] async fn reserved_state_recv_window_update() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -692,7 +835,7 @@ async fn reserved_state_recv_window_update() { /* #[test] fn send_data_after_headers_eos() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let mock = mock_io::Builder::new() .handshake() @@ -733,10 +876,10 @@ async fn rst_while_closing() { // Test to reproduce panic in issue #246 --- receipt of a RST_STREAM frame // on a stream in the Half Closed (remote) state with a queued EOS causes // a panic. - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); - // Rendevous when we've queued a trailers frame + // Rendezvous when we've queued a trailers frame let (tx, rx) = oneshot::channel(); let srv = async move { @@ -777,7 +920,7 @@ async fn rst_while_closing() { // Enqueue trailers frame. 
let _ = stream.send_trailers(HeaderMap::new()); // Signal the server mock to send RST_FRAME - let _ = tx.send(()).unwrap(); + let _: () = tx.send(()).unwrap(); drop(stream); yield_once().await; // yield once to allow the server mock to be polled @@ -794,7 +937,7 @@ async fn rst_with_buffered_data() { // the data is fully flushed. Given that resetting a stream requires // clearing all associated state for that stream, this test ensures that the // buffered up frame is correctly handled. - let _ = env_logger::try_init(); + h2_support::trace_init!(); // This allows the settings + headers frame through let (io, mut srv) = mock::new_with_write_capacity(73); @@ -846,7 +989,7 @@ async fn err_with_buffered_data() { // the data is fully flushed. Given that resetting a stream requires // clearing all associated state for that stream, this test ensures that the // buffered up frame is correctly handled. - let _ = env_logger::try_init(); + h2_support::trace_init!(); // This allows the settings + headers frame through let (io, mut srv) = mock::new_with_write_capacity(73); @@ -901,7 +1044,7 @@ async fn send_err_with_buffered_data() { // the data is fully flushed. Given that resetting a stream requires // clearing all associated state for that stream, this test ensures that the // buffered up frame is correctly handled. 
- let _ = env_logger::try_init(); + h2_support::trace_init!(); // This allows the settings + headers frame through let (io, mut srv) = mock::new_with_write_capacity(73); @@ -963,7 +1106,7 @@ async fn send_err_with_buffered_data() { #[tokio::test] async fn srv_window_update_on_lower_stream_id() { // See https://github.com/hyperium/h2/issues/208 - let _ = env_logger::try_init(); + h2_support::trace_init!(); let (io, mut srv) = mock::new(); let srv = async move { @@ -1013,3 +1156,60 @@ async fn srv_window_update_on_lower_stream_id() { }; join(srv, client).await; } + +// See https://github.com/hyperium/h2/issues/570 +#[tokio::test] +async fn reset_new_stream_before_send() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + // Send unexpected headers, that depends on itself, causing a framing error. 
+ srv.send_bytes(&[ + 0, 0, 0x6, // len + 0x1, // type (headers) + 0x25, // flags (eos, eoh, pri) + 0, 0, 0, 0x3, // stream id + 0, 0, 0, 0x3, // dependency + 2, // weight + 0x88, // HPACK :status=200 + ]) + .await; + srv.recv_frame(frames::reset(3).protocol_error()).await; + srv.recv_frame( + frames::headers(5) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(5).response(200).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + let resp = conn + .drive(client.get("https://example.com/")) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // req number 2 + let resp = conn + .drive(client.get("https://example.com/")) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + conn.await.expect("client"); + }; + + join(srv, client).await; +} diff --git a/tests/h2-tests/tests/trailers.rs b/tests/h2-tests/tests/trailers.rs index 513b65d82..08a463ab7 100644 --- a/tests/h2-tests/tests/trailers.rs +++ b/tests/h2-tests/tests/trailers.rs @@ -3,7 +3,7 @@ use h2_support::prelude::*; #[tokio::test] async fn recv_trailers_only() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let mock = mock_io::Builder::new() .handshake() @@ -28,7 +28,7 @@ async fn recv_trailers_only() { .body(()) .unwrap(); - log::info!("sending request"); + tracing::info!("sending request"); let (response, _) = client.send_request(request, true).unwrap(); let response = h2.run(response).await.unwrap(); @@ -53,7 +53,7 @@ async fn recv_trailers_only() { #[tokio::test] async fn send_trailers_immediately() { - let _ = env_logger::try_init(); + h2_support::trace_init!(); let mock = mock_io::Builder::new() .handshake() @@ -79,7 +79,7 @@ async fn send_trailers_immediately() { .body(()) .unwrap(); - log::info!("sending request"); + tracing::info!("sending request"); let (response, mut stream) = client.send_request(request, false).unwrap(); let 
mut trailers = HeaderMap::new(); diff --git a/util/genfixture/Cargo.toml b/util/genfixture/Cargo.toml index 694a99496..cce7eb1b1 100644 --- a/util/genfixture/Cargo.toml +++ b/util/genfixture/Cargo.toml @@ -6,4 +6,4 @@ publish = false edition = "2018" [dependencies] -walkdir = "1.0.0" +walkdir = "2.3.2" diff --git a/util/genfixture/src/main.rs b/util/genfixture/src/main.rs index a6d730761..9dc7b00f9 100644 --- a/util/genfixture/src/main.rs +++ b/util/genfixture/src/main.rs @@ -10,7 +10,7 @@ fn main() { let path = args.get(1).expect("usage: genfixture [PATH]"); let path = Path::new(path); - let mut tests = HashMap::new(); + let mut tests: HashMap> = HashMap::new(); for entry in WalkDir::new(path) { let entry = entry.unwrap(); @@ -28,21 +28,21 @@ fn main() { let fixture_path = path.split("fixtures/hpack/").last().unwrap(); // Now, split that into the group and the name - let module = fixture_path.split("/").next().unwrap(); + let module = fixture_path.split('/').next().unwrap(); tests .entry(module.to_string()) - .or_insert(vec![]) + .or_default() .push(fixture_path.to_string()); } let mut one = false; for (module, tests) in tests { - let module = module.replace("-", "_"); + let module = module.replace('-', "_"); if one { - println!(""); + println!(); } one = true; @@ -51,7 +51,7 @@ fn main() { println!(" {} => {{", module); for test in tests { - let ident = test.split("/").nth(1).unwrap().split(".").next().unwrap(); + let ident = test.split('/').nth(1).unwrap().split('.').next().unwrap(); println!(" ({}, {:?});", ident, test); } diff --git a/util/genhuff/src/main.rs b/util/genhuff/src/main.rs index 2d5b0ba75..6418fab8b 100644 --- a/util/genhuff/src/main.rs +++ b/util/genhuff/src/main.rs @@ -112,8 +112,8 @@ impl Node { }; start.transitions.borrow_mut().push(Transition { - target: target, - byte: byte, + target, + byte, maybe_eos: self.maybe_eos, }); @@ -238,7 +238,7 @@ pub fn main() { let (encode, decode) = load_table(); println!("// !!! DO NOT EDIT !!! 
Generated by util/genhuff/src/main.rs"); - println!(""); + println!(); println!("// (num-bits, bits)"); println!("pub const ENCODE_TABLE: [(usize, u64); 257] = ["); @@ -247,7 +247,7 @@ pub fn main() { } println!("];"); - println!(""); + println!(); println!("// (next-state, byte, flags)"); println!("pub const DECODE_TABLE: [[(usize, u8, u8); 16]; 256] = ["); @@ -256,7 +256,7 @@ pub fn main() { println!("];"); } -const TABLE: &'static str = r##" +const TABLE: &str = r##" ( 0) |11111111|11000 1ff8 [13] ( 1) |11111111|11111111|1011000 7fffd8 [23] ( 2) |11111111|11111111|11111110|0010 fffffe2 [28]