From 33f14ec2e8ed1ea38b7db70fdbb7d1c635ae5cd6 Mon Sep 17 00:00:00 2001 From: birenroy Date: Fri, 26 Jan 2024 16:14:40 -0500 Subject: [PATCH] http2: refactoring to store stream headers state locally (#32013) This change tracks the HEADERS frame state of each stream internally to the StreamImpl data structure. This is intended to ease the removal of nghttp2 callbacks, as the newer Http2VisitorInterface API does not communicate the nghttp2 hcat state. Risk Level: low; refactoring, no functional change intended Testing: ran unit and integration tests locally Docs Changes: Release Notes: Platform Specific Features: Signed-off-by: Biren Roy --- source/common/http/http2/codec_impl.cc | 76 ++++++++++++++++---------- source/common/http/http2/codec_impl.h | 24 ++++++-- 2 files changed, 68 insertions(+), 32 deletions(-) diff --git a/source/common/http/http2/codec_impl.cc b/source/common/http/http2/codec_impl.cc index 2a1009e8b282..f766871ae3de 100644 --- a/source/common/http/http2/codec_impl.cc +++ b/source/common/http/http2/codec_impl.cc @@ -675,6 +675,20 @@ void ConnectionImpl::ClientStreamImpl::submitHeaders(const HeaderMap& headers, b ASSERT(stream_id_ > 0); } +Status ConnectionImpl::ClientStreamImpl::onBeginHeaders() { + if (headers_state_ == HeadersState::Headers) { + allocTrailers(); + } + + return okStatus(); +} + +void ConnectionImpl::ClientStreamImpl::advanceHeadersState() { + RELEASE_ASSERT( + headers_state_ == HeadersState::Response || headers_state_ == HeadersState::Headers, ""); + headers_state_ = HeadersState::Headers; +} + void ConnectionImpl::ServerStreamImpl::submitHeaders(const HeaderMap& headers, bool end_stream) { ASSERT(stream_id_ != -1); parent_.adapter_->SubmitResponse(stream_id_, buildHeaders(headers), @@ -682,6 +696,23 @@ void ConnectionImpl::ServerStreamImpl::submitHeaders(const HeaderMap& headers, b : std::make_unique(*this)); } +Status ConnectionImpl::ServerStreamImpl::onBeginHeaders() { + if (headers_state_ != HeadersState::Request) { + parent_.stats_.trailers_.inc(); + ASSERT(headers_state_ == HeadersState::Headers); + + allocTrailers(); + } + + return okStatus(); +} + +void ConnectionImpl::ServerStreamImpl::advanceHeadersState() { + RELEASE_ASSERT(headers_state_ == HeadersState::Request || headers_state_ == HeadersState::Headers, + ""); + headers_state_ = HeadersState::Headers; +} + void ConnectionImpl::StreamImpl::onPendingFlushTimer() { ENVOY_CONN_LOG(debug, "pending stream flush timeout", parent_.connection_); MultiplexedStreamImplBase::onPendingFlushTimer(); @@ -1111,8 +1142,7 @@ Status ConnectionImpl::onGoAway(uint32_t error_code) { return okStatus(); } -Status ConnectionImpl::onHeaders(int32_t stream_id, size_t length, uint8_t flags, - int headers_category) { +Status ConnectionImpl::onHeaders(int32_t stream_id, size_t length, uint8_t flags) { StreamImpl* stream = getStreamUnchecked(stream_id); if (!stream) { return okStatus(); @@ -1127,14 +1157,15 @@ Status ConnectionImpl::onHeaders(int32_t stream_id, size_t length, uint8_t flags stream->headers().addViaMove(std::move(key), std::move(stream->cookies_)); } - switch (headers_category) { - case NGHTTP2_HCAT_RESPONSE: - case NGHTTP2_HCAT_REQUEST: { + StreamImpl::HeadersState headers_state = stream->headersState(); + switch (headers_state) { + case StreamImpl::HeadersState::Response: + case StreamImpl::HeadersState::Request: { stream->decodeHeaders(); break; } - case NGHTTP2_HCAT_HEADERS: { + case StreamImpl::HeadersState::Headers: { // It's possible that we are waiting to send a deferred reset, so only raise headers/trailers // if local is not complete. if (!stream->deferred_reset_) { @@ -1154,6 +1185,7 @@ Status ConnectionImpl::onHeaders(int32_t stream_id, size_t length, uint8_t flags ENVOY_BUG(false, "push not supported"); } + stream->advanceHeadersState(); return okStatus(); } @@ -1209,7 +1241,7 @@ Status ConnectionImpl::onFrameReceived(const nghttp2_frame* frame) { return okStatus(); } if (frame->hd.type == NGHTTP2_HEADERS) { - return onHeaders(frame->hd.stream_id, frame->hd.length, frame->hd.flags, frame->headers.cat); + return onHeaders(frame->hd.stream_id, frame->hd.length, frame->hd.flags); } if (frame->hd.type == NGHTTP2_RST_STREAM) { return onRstStream(frame->hd.stream_id, frame->rst_stream.error_code); @@ -1693,8 +1725,7 @@ ConnectionImpl::Http2Callbacks::Http2Callbacks() { nghttp2_session_callbacks_set_on_begin_headers_callback( callbacks_, [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { - auto status = static_cast(user_data)->onBeginHeaders(frame->hd.stream_id, - frame->headers.cat); + auto status = static_cast(user_data)->onBeginHeaders(frame->hd.stream_id); return static_cast(user_data)->setAndCheckCodecCallbackStatus( std::move(status)); }); @@ -2028,15 +2059,9 @@ RequestEncoder& ClientConnectionImpl::newStream(ResponseDecoder& decoder) { return stream_ref; } -Status ClientConnectionImpl::onBeginHeaders(int32_t stream_id, int headers_category) { - RELEASE_ASSERT( - headers_category == NGHTTP2_HCAT_RESPONSE || headers_category == NGHTTP2_HCAT_HEADERS, ""); - if (headers_category == NGHTTP2_HCAT_HEADERS) { - StreamImpl* stream = getStream(stream_id); - stream->allocTrailers(); - } - - return okStatus(); +Status ClientConnectionImpl::onBeginHeaders(int32_t stream_id) { + StreamImpl* stream = getStream(stream_id); + return stream->onBeginHeaders(); } int ClientConnectionImpl::onHeader(int32_t stream_id, HeaderString&& name, HeaderString&& value) { @@ -2111,18 +2136,13 @@ ServerConnectionImpl::ServerConnectionImpl( allow_metadata_ = http2_options.allow_metadata(); } -Status ServerConnectionImpl::onBeginHeaders(int32_t stream_id, int headers_category) { +Status ServerConnectionImpl::onBeginHeaders(int32_t stream_id) { ASSERT(connection_.state() == Network::Connection::State::Open); - if (headers_category != NGHTTP2_HCAT_REQUEST) { - stats_.trailers_.inc(); - ASSERT(headers_category == NGHTTP2_HCAT_HEADERS); - - StreamImpl* stream = getStream(stream_id); - stream->allocTrailers(); - return okStatus(); + StreamImpl* stream_ptr = getStream(stream_id); + if (stream_ptr != nullptr) { + return stream_ptr->onBeginHeaders(); } - ServerStreamImplPtr stream(new ServerStreamImpl(*this, per_stream_buffer_limit_)); if (connection_.aboveHighWatermark()) { stream->runHighWatermarkCallbacks(); @@ -2132,7 +2152,7 @@ Status ServerConnectionImpl::onBeginHeaders(int32_t stream_id, int headers_categ LinkedList::moveIntoList(std::move(stream), active_streams_); adapter_->SetStreamUserData(stream_id, active_streams_.front().get()); protocol_constraints_.incrementOpenedStreamCount(); - return okStatus(); + return active_streams_.front()->onBeginHeaders(); } int ServerConnectionImpl::onHeader(int32_t stream_id, HeaderString&& name, HeaderString&& value) { diff --git a/source/common/http/http2/codec_impl.h b/source/common/http/http2/codec_impl.h index 8b59234ea833..23154c1ea50e 100644 --- a/source/common/http/http2/codec_impl.h +++ b/source/common/http/http2/codec_impl.h @@ -206,6 +206,11 @@ class ConnectionImpl : public virtual Connection, public Event::DeferredDeletable, public Http::MultiplexedStreamImplBase, public ScopeTrackedObject { + enum class HeadersState { + Request, + Response, + Headers, // Signifies additional headers after the initial request/response set. + }; StreamImpl(ConnectionImpl& parent, uint32_t buffer_limit); @@ -221,6 +226,9 @@ class ConnectionImpl : public virtual Connection, StreamImpl* base() { return this; } void resetStreamWorker(StreamResetReason reason); static std::vector buildHeaders(const HeaderMap& headers); + virtual Status onBeginHeaders() PURE; + virtual void advanceHeadersState() PURE; + virtual HeadersState headersState() const PURE; void saveHeader(HeaderString&& name, HeaderString&& value); void encodeHeadersBase(const HeaderMap& headers, bool end_stream); virtual void submitHeaders(const HeaderMap& headers, bool end_stream) PURE; @@ -452,6 +460,9 @@ class ConnectionImpl : public virtual Connection, } // StreamImpl void submitHeaders(const HeaderMap& headers, bool end_stream) override; + Status onBeginHeaders() override; + void advanceHeadersState() override; + HeadersState headersState() const override { return headers_state_; } // Do not use deferred reset on upstream connections. bool useDeferredReset() const override { return false; } StreamDecoder& decoder() override { return response_decoder_; } @@ -492,6 +503,7 @@ class ConnectionImpl : public virtual Connection, ResponseDecoder& response_decoder_; absl::variant headers_or_trailers_; std::string upgrade_type_; + HeadersState headers_state_ = HeadersState::Response; }; using ClientStreamImplPtr = std::unique_ptr; @@ -508,6 +520,9 @@ class ConnectionImpl : public virtual Connection, // StreamImpl void destroy() override; void submitHeaders(const HeaderMap& headers, bool end_stream) override; + Status onBeginHeaders() override; + void advanceHeadersState() override; + HeadersState headersState() const override { return headers_state_; } // Enable deferred reset on downstream connections so outbound HTTP internal error replies are // written out before force resetting the stream, assuming there is enough H2 connection flow // control window is available. @@ -554,6 +569,7 @@ class ConnectionImpl : public virtual Connection, private: RequestDecoder* request_decoder_{}; + HeadersState headers_state_ = HeadersState::Request; }; using ServerStreamImplPtr = std::unique_ptr; @@ -687,13 +703,13 @@ class ConnectionImpl : public virtual Connection, friend class Http2CodecImplTestFixture; virtual ConnectionCallbacks& callbacks() PURE; - virtual Status onBeginHeaders(int32_t stream_id, int headers_category) PURE; + virtual Status onBeginHeaders(int32_t stream_id) PURE; int onData(int32_t stream_id, const uint8_t* data, size_t len); Status onBeforeFrameReceived(int32_t stream_id, size_t length, uint8_t type, uint8_t flags); Status onPing(uint64_t opaque_data, bool is_ack); Status onBeginData(int32_t stream_id, size_t length, uint8_t type, uint8_t flags, size_t padding); Status onGoAway(uint32_t error_code); - Status onHeaders(int32_t stream_id, size_t length, uint8_t flags, int headers_category); + Status onHeaders(int32_t stream_id, size_t length, uint8_t flags); Status onRstStream(int32_t stream_id, uint32_t error_code); Status onFrameReceived(const nghttp2_frame* frame); int onBeforeFrameSend(int32_t stream_id, size_t length, uint8_t type, uint8_t flags); @@ -757,7 +773,7 @@ class ClientConnectionImpl : public ClientConnection, public ConnectionImpl { private: // ConnectionImpl ConnectionCallbacks& callbacks() override { return callbacks_; } - Status onBeginHeaders(int32_t stream_id, int headers_category) override; + Status onBeginHeaders(int32_t stream_id) override; int onHeader(int32_t stream_id, HeaderString&& name, HeaderString&& value) override; Status trackInboundFrames(int32_t stream_id, size_t length, uint8_t type, uint8_t flags, uint32_t) override; @@ -784,7 +800,7 @@ class ServerConnectionImpl : public ServerConnection, public ConnectionImpl { private: // ConnectionImpl ConnectionCallbacks& callbacks() override { return callbacks_; } - Status onBeginHeaders(int32_t stream_id, int headers_category) override; + Status onBeginHeaders(int32_t stream_id) override; int onHeader(int32_t stream_id, HeaderString&& name, HeaderString&& value) override; Status trackInboundFrames(int32_t stream_id, size_t length, uint8_t type, uint8_t flags, uint32_t padding_length) override;