Skip to content

Commit

Permalink
http2: refactoring to store stream headers state locally (envoyproxy#…
Browse files Browse the repository at this point in the history
…32013)

This change tracks the HEADERS frame state of each stream internally to the StreamImpl data structure. This is intended to ease the removal of nghttp2 callbacks, as the newer Http2VisitorInterface API does not communicate the nghttp2 hcat state.

Risk Level: low; refactoring, no functional change intended
Testing: ran unit and integration tests locally
Docs Changes:
Release Notes:
Platform Specific Features:

Signed-off-by: Biren Roy <[email protected]>
  • Loading branch information
birenroy authored Jan 26, 2024
1 parent 462fef1 commit 33f14ec
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 32 deletions.
76 changes: 48 additions & 28 deletions source/common/http/http2/codec_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -675,13 +675,44 @@ void ConnectionImpl::ClientStreamImpl::submitHeaders(const HeaderMap& headers, b
ASSERT(stream_id_ > 0);
}

Status ConnectionImpl::ClientStreamImpl::onBeginHeaders() {
if (headers_state_ == HeadersState::Headers) {
allocTrailers();
}

return okStatus();
}

void ConnectionImpl::ClientStreamImpl::advanceHeadersState() {
RELEASE_ASSERT(
headers_state_ == HeadersState::Response || headers_state_ == HeadersState::Headers, "");
headers_state_ = HeadersState::Headers;
}

void ConnectionImpl::ServerStreamImpl::submitHeaders(const HeaderMap& headers, bool end_stream) {
ASSERT(stream_id_ != -1);
parent_.adapter_->SubmitResponse(stream_id_, buildHeaders(headers),
end_stream ? nullptr
: std::make_unique<StreamDataFrameSource>(*this));
}

Status ConnectionImpl::ServerStreamImpl::onBeginHeaders() {
if (headers_state_ != HeadersState::Request) {
parent_.stats_.trailers_.inc();
ASSERT(headers_state_ == HeadersState::Headers);

allocTrailers();
}

return okStatus();
}

void ConnectionImpl::ServerStreamImpl::advanceHeadersState() {
RELEASE_ASSERT(headers_state_ == HeadersState::Request || headers_state_ == HeadersState::Headers,
"");
headers_state_ = HeadersState::Headers;
}

void ConnectionImpl::StreamImpl::onPendingFlushTimer() {
ENVOY_CONN_LOG(debug, "pending stream flush timeout", parent_.connection_);
MultiplexedStreamImplBase::onPendingFlushTimer();
Expand Down Expand Up @@ -1111,8 +1142,7 @@ Status ConnectionImpl::onGoAway(uint32_t error_code) {
return okStatus();
}

Status ConnectionImpl::onHeaders(int32_t stream_id, size_t length, uint8_t flags,
int headers_category) {
Status ConnectionImpl::onHeaders(int32_t stream_id, size_t length, uint8_t flags) {
StreamImpl* stream = getStreamUnchecked(stream_id);
if (!stream) {
return okStatus();
Expand All @@ -1127,14 +1157,15 @@ Status ConnectionImpl::onHeaders(int32_t stream_id, size_t length, uint8_t flags
stream->headers().addViaMove(std::move(key), std::move(stream->cookies_));
}

switch (headers_category) {
case NGHTTP2_HCAT_RESPONSE:
case NGHTTP2_HCAT_REQUEST: {
StreamImpl::HeadersState headers_state = stream->headersState();
switch (headers_state) {
case StreamImpl::HeadersState::Response:
case StreamImpl::HeadersState::Request: {
stream->decodeHeaders();
break;
}

case NGHTTP2_HCAT_HEADERS: {
case StreamImpl::HeadersState::Headers: {
// It's possible that we are waiting to send a deferred reset, so only raise headers/trailers
// if local is not complete.
if (!stream->deferred_reset_) {
Expand All @@ -1154,6 +1185,7 @@ Status ConnectionImpl::onHeaders(int32_t stream_id, size_t length, uint8_t flags
ENVOY_BUG(false, "push not supported");
}

stream->advanceHeadersState();
return okStatus();
}

Expand Down Expand Up @@ -1209,7 +1241,7 @@ Status ConnectionImpl::onFrameReceived(const nghttp2_frame* frame) {
return okStatus();
}
if (frame->hd.type == NGHTTP2_HEADERS) {
return onHeaders(frame->hd.stream_id, frame->hd.length, frame->hd.flags, frame->headers.cat);
return onHeaders(frame->hd.stream_id, frame->hd.length, frame->hd.flags);
}
if (frame->hd.type == NGHTTP2_RST_STREAM) {
return onRstStream(frame->hd.stream_id, frame->rst_stream.error_code);
Expand Down Expand Up @@ -1693,8 +1725,7 @@ ConnectionImpl::Http2Callbacks::Http2Callbacks() {

nghttp2_session_callbacks_set_on_begin_headers_callback(
callbacks_, [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int {
auto status = static_cast<ConnectionImpl*>(user_data)->onBeginHeaders(frame->hd.stream_id,
frame->headers.cat);
auto status = static_cast<ConnectionImpl*>(user_data)->onBeginHeaders(frame->hd.stream_id);
return static_cast<ConnectionImpl*>(user_data)->setAndCheckCodecCallbackStatus(
std::move(status));
});
Expand Down Expand Up @@ -2028,15 +2059,9 @@ RequestEncoder& ClientConnectionImpl::newStream(ResponseDecoder& decoder) {
return stream_ref;
}

Status ClientConnectionImpl::onBeginHeaders(int32_t stream_id, int headers_category) {
RELEASE_ASSERT(
headers_category == NGHTTP2_HCAT_RESPONSE || headers_category == NGHTTP2_HCAT_HEADERS, "");
if (headers_category == NGHTTP2_HCAT_HEADERS) {
StreamImpl* stream = getStream(stream_id);
stream->allocTrailers();
}

return okStatus();
Status ClientConnectionImpl::onBeginHeaders(int32_t stream_id) {
StreamImpl* stream = getStream(stream_id);
return stream->onBeginHeaders();
}

int ClientConnectionImpl::onHeader(int32_t stream_id, HeaderString&& name, HeaderString&& value) {
Expand Down Expand Up @@ -2111,18 +2136,13 @@ ServerConnectionImpl::ServerConnectionImpl(
allow_metadata_ = http2_options.allow_metadata();
}

Status ServerConnectionImpl::onBeginHeaders(int32_t stream_id, int headers_category) {
Status ServerConnectionImpl::onBeginHeaders(int32_t stream_id) {
ASSERT(connection_.state() == Network::Connection::State::Open);

if (headers_category != NGHTTP2_HCAT_REQUEST) {
stats_.trailers_.inc();
ASSERT(headers_category == NGHTTP2_HCAT_HEADERS);

StreamImpl* stream = getStream(stream_id);
stream->allocTrailers();
return okStatus();
StreamImpl* stream_ptr = getStream(stream_id);
if (stream_ptr != nullptr) {
return stream_ptr->onBeginHeaders();
}

ServerStreamImplPtr stream(new ServerStreamImpl(*this, per_stream_buffer_limit_));
if (connection_.aboveHighWatermark()) {
stream->runHighWatermarkCallbacks();
Expand All @@ -2132,7 +2152,7 @@ Status ServerConnectionImpl::onBeginHeaders(int32_t stream_id, int headers_categ
LinkedList::moveIntoList(std::move(stream), active_streams_);
adapter_->SetStreamUserData(stream_id, active_streams_.front().get());
protocol_constraints_.incrementOpenedStreamCount();
return okStatus();
return active_streams_.front()->onBeginHeaders();
}

int ServerConnectionImpl::onHeader(int32_t stream_id, HeaderString&& name, HeaderString&& value) {
Expand Down
24 changes: 20 additions & 4 deletions source/common/http/http2/codec_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ class ConnectionImpl : public virtual Connection,
public Event::DeferredDeletable,
public Http::MultiplexedStreamImplBase,
public ScopeTrackedObject {
enum class HeadersState {
Request,
Response,
Headers, // Signifies additional headers after the initial request/response set.
};

StreamImpl(ConnectionImpl& parent, uint32_t buffer_limit);

Expand All @@ -221,6 +226,9 @@ class ConnectionImpl : public virtual Connection,
StreamImpl* base() { return this; }
void resetStreamWorker(StreamResetReason reason);
static std::vector<http2::adapter::Header> buildHeaders(const HeaderMap& headers);
virtual Status onBeginHeaders() PURE;
virtual void advanceHeadersState() PURE;
virtual HeadersState headersState() const PURE;
void saveHeader(HeaderString&& name, HeaderString&& value);
void encodeHeadersBase(const HeaderMap& headers, bool end_stream);
virtual void submitHeaders(const HeaderMap& headers, bool end_stream) PURE;
Expand Down Expand Up @@ -452,6 +460,9 @@ class ConnectionImpl : public virtual Connection,
}
// StreamImpl
void submitHeaders(const HeaderMap& headers, bool end_stream) override;
Status onBeginHeaders() override;
void advanceHeadersState() override;
HeadersState headersState() const override { return headers_state_; }
// Do not use deferred reset on upstream connections.
bool useDeferredReset() const override { return false; }
StreamDecoder& decoder() override { return response_decoder_; }
Expand Down Expand Up @@ -492,6 +503,7 @@ class ConnectionImpl : public virtual Connection,
ResponseDecoder& response_decoder_;
absl::variant<ResponseHeaderMapPtr, ResponseTrailerMapPtr> headers_or_trailers_;
std::string upgrade_type_;
HeadersState headers_state_ = HeadersState::Response;
};

using ClientStreamImplPtr = std::unique_ptr<ClientStreamImpl>;
Expand All @@ -508,6 +520,9 @@ class ConnectionImpl : public virtual Connection,
// StreamImpl
void destroy() override;
void submitHeaders(const HeaderMap& headers, bool end_stream) override;
Status onBeginHeaders() override;
void advanceHeadersState() override;
HeadersState headersState() const override { return headers_state_; }
// Enable deferred reset on downstream connections so outbound HTTP internal error replies are
// written out before force resetting the stream, assuming there is enough H2 connection flow
// control window is available.
Expand Down Expand Up @@ -554,6 +569,7 @@ class ConnectionImpl : public virtual Connection,

private:
RequestDecoder* request_decoder_{};
HeadersState headers_state_ = HeadersState::Request;
};

using ServerStreamImplPtr = std::unique_ptr<ServerStreamImpl>;
Expand Down Expand Up @@ -687,13 +703,13 @@ class ConnectionImpl : public virtual Connection,
friend class Http2CodecImplTestFixture;

virtual ConnectionCallbacks& callbacks() PURE;
virtual Status onBeginHeaders(int32_t stream_id, int headers_category) PURE;
virtual Status onBeginHeaders(int32_t stream_id) PURE;
int onData(int32_t stream_id, const uint8_t* data, size_t len);
Status onBeforeFrameReceived(int32_t stream_id, size_t length, uint8_t type, uint8_t flags);
Status onPing(uint64_t opaque_data, bool is_ack);
Status onBeginData(int32_t stream_id, size_t length, uint8_t type, uint8_t flags, size_t padding);
Status onGoAway(uint32_t error_code);
Status onHeaders(int32_t stream_id, size_t length, uint8_t flags, int headers_category);
Status onHeaders(int32_t stream_id, size_t length, uint8_t flags);
Status onRstStream(int32_t stream_id, uint32_t error_code);
Status onFrameReceived(const nghttp2_frame* frame);
int onBeforeFrameSend(int32_t stream_id, size_t length, uint8_t type, uint8_t flags);
Expand Down Expand Up @@ -757,7 +773,7 @@ class ClientConnectionImpl : public ClientConnection, public ConnectionImpl {
private:
// ConnectionImpl
ConnectionCallbacks& callbacks() override { return callbacks_; }
Status onBeginHeaders(int32_t stream_id, int headers_category) override;
Status onBeginHeaders(int32_t stream_id) override;
int onHeader(int32_t stream_id, HeaderString&& name, HeaderString&& value) override;
Status trackInboundFrames(int32_t stream_id, size_t length, uint8_t type, uint8_t flags,
uint32_t) override;
Expand All @@ -784,7 +800,7 @@ class ServerConnectionImpl : public ServerConnection, public ConnectionImpl {
private:
// ConnectionImpl
ConnectionCallbacks& callbacks() override { return callbacks_; }
Status onBeginHeaders(int32_t stream_id, int headers_category) override;
Status onBeginHeaders(int32_t stream_id) override;
int onHeader(int32_t stream_id, HeaderString&& name, HeaderString&& value) override;
Status trackInboundFrames(int32_t stream_id, size_t length, uint8_t type, uint8_t flags,
uint32_t padding_length) override;
Expand Down

0 comments on commit 33f14ec

Please sign in to comment.