From 44b94c9c6abffc152cc56333a714f8045979c94b Mon Sep 17 00:00:00 2001 From: Alex Ostrovski Date: Thu, 28 Sep 2023 11:33:34 +0300 Subject: [PATCH] Move TCP metrics to `network` crate --- node/actors/network/src/metrics.rs | 94 ++++++++++++- node/actors/network/src/noise/stream.rs | 154 ++++++++++++---------- node/actors/network/src/noise/testonly.rs | 19 ++- node/actors/network/src/preface.rs | 8 +- node/libs/concurrency/src/metrics.rs | 32 +---- node/libs/concurrency/src/net/tcp/mod.rs | 76 +---------- 6 files changed, 199 insertions(+), 184 deletions(-) diff --git a/node/actors/network/src/metrics.rs b/node/actors/network/src/metrics.rs index ec9ec80ff..441ae41ae 100644 --- a/node/actors/network/src/metrics.rs +++ b/node/actors/network/src/metrics.rs @@ -1,8 +1,98 @@ //! General-purpose network metrics. use crate::state::State; -use std::sync::Weak; -use vise::{Collector, Gauge, Metrics}; +use concurrency::{io, metrics::GaugeGuard, net}; +use std::{ + pin::Pin, + sync::Weak, + task::{ready, Context, Poll}, +}; +use vise::{Collector, Counter, EncodeLabelSet, EncodeLabelValue, Family, Gauge, Metrics, Unit}; + +/// Metered TCP stream. +#[pin_project::pin_project] +pub(crate) struct MeteredStream { + #[pin] + stream: net::tcp::Stream, + _active: GaugeGuard, +} + +impl MeteredStream { + /// Creates a new stream with the specified `direction`. + pub(crate) fn new(stream: net::tcp::Stream, direction: Direction) -> Self { + TCP_METRICS.established[&direction].inc(); + Self { + stream, + _active: GaugeGuard::from(TCP_METRICS.active[&direction].clone()), + } + } +} + +impl io::AsyncRead for MeteredStream { + #[inline(always)] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + let this = self.project(); + let before = buf.remaining(); + let res = this.stream.poll_read(cx, buf); + let after = buf.remaining(); + TCP_METRICS.received.inc_by((before - after) as u64); + res + } +} + +impl io::AsyncWrite for MeteredStream { + #[inline(always)] + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let this = self.project(); + let res = ready!(this.stream.poll_write(cx, buf))?; + TCP_METRICS.sent.inc_by(res as u64); + Poll::Ready(Ok(res)) + } + + #[inline(always)] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().stream.poll_flush(cx) + } + + #[inline(always)] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().stream.poll_shutdown(cx) + } +} + +/// Direction of a TCP connection. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EncodeLabelSet, EncodeLabelValue)] +#[metrics(label = "direction", rename_all = "snake_case")] +pub(crate) enum Direction { + /// Inbound connection. + Inbound, + /// Outbound connection. + Outbound, +} + +/// Metrics reported for TCP connections. +#[derive(Debug, Metrics)] +#[metrics(prefix = "concurrency_net_tcp")] +struct TcpMetrics { + /// Total bytes sent over all TCP connections. + #[metrics(unit = Unit::Bytes)] + pub(crate) sent: Counter, + /// Total bytes received over all TCP connections. + #[metrics(unit = Unit::Bytes)] + pub(crate) received: Counter, + /// TCP connections established since the process started. + pub(crate) established: Family, + /// Number of currently active TCP connections. + pub(crate) active: Family, +} + +/// TCP metrics instance. +#[vise::register] +static TCP_METRICS: vise::Global = vise::Global::new(); /// General-purpose network metrics exposed via a collector. #[derive(Debug, Metrics)] diff --git a/node/actors/network/src/noise/stream.rs b/node/actors/network/src/noise/stream.rs index adf94b703..d3b975445 100644 --- a/node/actors/network/src/noise/stream.rs +++ b/node/actors/network/src/noise/stream.rs @@ -1,9 +1,9 @@ //! `tokio::io` stream using Noise encryption. use super::bytes; +use crate::metrics::MeteredStream; use concurrency::{ ctx, io, io::{AsyncRead as _, AsyncWrite as _}, - net, }; use crypto::{sha256::Sha256, ByteFmt}; use std::{ @@ -32,65 +32,6 @@ fn params() -> snow::params::NoiseParams { } } -impl Stream { - /// Performs a server-side noise handshake and returns the encrypted stream. - pub(crate) async fn server_handshake( - ctx: &ctx::Ctx, - s: net::tcp::Stream, - ) -> anyhow::Result { - Self::handshake(ctx, s, snow::Builder::new(params()).build_responder()?).await - } - - /// Performs a client-side noise handshake and returns the encrypted stream. - pub(crate) async fn client_handshake( - ctx: &ctx::Ctx, - s: net::tcp::Stream, - ) -> anyhow::Result { - Self::handshake(ctx, s, snow::Builder::new(params()).build_initiator()?).await - } - - /// Performs the noise handshake given the HandshakeState. - async fn handshake( - ctx: &ctx::Ctx, - mut stream: net::tcp::Stream, - mut hs: snow::HandshakeState, - ) -> anyhow::Result { - let mut buf = vec![0; 65536]; - let mut payload = vec![]; - loop { - if hs.is_handshake_finished() { - return Ok(Stream { - id: ByteFmt::decode(hs.get_handshake_hash()).unwrap(), - inner: stream, - noise: hs.into_transport_mode()?, - read_buf: Box::default(), - write_buf: Box::default(), - }); - } - if hs.is_my_turn() { - let n = hs.write_message(&payload, &mut buf)?; - // TODO(gprusak): writing/reading length field and the frame content could be - // done in a single syscall. - io::write_all(ctx, &mut stream, &u16::to_le_bytes(n as u16)).await??; - io::write_all(ctx, &mut stream, &buf[..n]).await??; - io::flush(ctx, &mut stream).await??; - } else { - let mut msg_size = [0u8, 2]; - io::read_exact(ctx, &mut stream, &mut msg_size).await??; - let n = u16::from_le_bytes(msg_size) as usize; - io::read_exact(ctx, &mut stream, &mut buf[..n]).await??; - hs.read_message(&buf[..n], &mut payload)?; - } - } - } - - /// Returns the noise session id. - /// See `Stream::id`. - pub(crate) fn id(&self) -> Sha256 { - self.id - } -} - // Constants from the Noise spec. /// Maximal size of the encrypted frame that Noise may output. @@ -130,16 +71,14 @@ impl Default for Buffer { /// Encrypted stream. /// It implements tokio::io::AsyncRead/AsyncWrite. -#[pin_project::pin_project(project=StreamProject)] -pub(crate) struct Stream { +#[pin_project::pin_project(project = StreamProject)] +pub(crate) struct Stream { /// Hash of the handshake messages. /// Uniquely identifies the noise session. id: Sha256, /// Underlying TCP stream. - /// TODO(gprusak): we can generalize noise::Stream to wrap an arbitrary - /// stream if needed. #[pin] - inner: net::tcp::Stream, + inner: S, /// Noise protocol state, used to encrypt/decrypt frames. noise: snow::TransportState, /// Buffers used for the read half of the stream. @@ -148,12 +87,66 @@ pub(crate) struct Stream { write_buf: Box, } -impl Stream { +impl Stream +where + S: io::AsyncRead + io::AsyncWrite + Unpin, +{ + /// Performs a server-side noise handshake and returns the encrypted stream. + pub(crate) async fn server_handshake(ctx: &ctx::Ctx, stream: S) -> anyhow::Result { + Self::handshake(ctx, stream, snow::Builder::new(params()).build_responder()?).await + } + + /// Performs a client-side noise handshake and returns the encrypted stream. + pub(crate) async fn client_handshake(ctx: &ctx::Ctx, stream: S) -> anyhow::Result { + Self::handshake(ctx, stream, snow::Builder::new(params()).build_initiator()?).await + } + + /// Performs the noise handshake given the HandshakeState. + async fn handshake( + ctx: &ctx::Ctx, + mut stream: S, + mut hs: snow::HandshakeState, + ) -> anyhow::Result { + let mut buf = vec![0; 65536]; + let mut payload = vec![]; + loop { + if hs.is_handshake_finished() { + return Ok(Self { + id: ByteFmt::decode(hs.get_handshake_hash()).unwrap(), + inner: stream, + noise: hs.into_transport_mode()?, + read_buf: Box::default(), + write_buf: Box::default(), + }); + } + if hs.is_my_turn() { + let n = hs.write_message(&payload, &mut buf)?; + // TODO(gprusak): writing/reading length field and the frame content could be + // done in a single syscall. + io::write_all(ctx, &mut stream, &u16::to_le_bytes(n as u16)).await??; + io::write_all(ctx, &mut stream, &buf[..n]).await??; + io::flush(ctx, &mut stream).await??; + } else { + let mut msg_size = [0u8, 2]; + io::read_exact(ctx, &mut stream, &mut msg_size).await??; + let n = u16::from_le_bytes(msg_size) as usize; + io::read_exact(ctx, &mut stream, &mut buf[..n]).await??; + hs.read_message(&buf[..n], &mut payload)?; + } + } + } + + /// Returns the noise session id. + /// See `Stream::id`. + pub(crate) fn id(&self) -> Sha256 { + self.id + } + /// Wait until a frame is fully loaded. /// Returns the size of the frame. /// Returns None in case EOF is reached before the frame is loaded. fn poll_read_frame( - this: &mut StreamProject<'_>, + this: &mut StreamProject<'_, S>, cx: &mut Context<'_>, ) -> Poll>> { // Fetch frame until complete. @@ -179,7 +172,7 @@ impl Stream { /// Wait until payload is nonempty. fn poll_read_payload( - this: &mut StreamProject<'_>, + this: &mut StreamProject<'_, S>, cx: &mut Context<'_>, ) -> Poll> { if this.read_buf.payload.len() > 0 { @@ -203,7 +196,10 @@ impl Stream { } } -impl io::AsyncRead for Stream { +impl io::AsyncRead for Stream +where + S: io::AsyncRead + io::AsyncWrite + Unpin, +{ /// From tokio::io::AsyncRead: /// * The amount of data read can be determined by the increase /// in the length of the slice returned by ReadBuf::filled. @@ -227,9 +223,15 @@ impl io::AsyncRead for Stream { } } -impl Stream { +impl Stream +where + S: io::AsyncRead + io::AsyncWrite + Unpin, +{ /// poll_flush_frame will either flush this.write_buf.frame, or return an error. - fn poll_flush_frame(this: &mut StreamProject, cx: &mut Context<'_>) -> Poll> { + fn poll_flush_frame( + this: &mut StreamProject<'_, S>, + cx: &mut Context<'_>, + ) -> Poll> { while this.write_buf.frame.len() > 0 { let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.write_buf.frame.as_slice()))?; @@ -242,7 +244,10 @@ impl Stream { } /// poll_flush_payload will either flush this.write_buf.payload, or return an error. - fn poll_flush_payload(this: &mut StreamProject, cx: &mut Context<'_>) -> Poll> { + fn poll_flush_payload( + this: &mut StreamProject<'_, S>, + cx: &mut Context<'_>, + ) -> Poll> { if this.write_buf.payload.len() == 0 { return Poll::Ready(Ok(())); } @@ -266,7 +271,10 @@ impl Stream { } } -impl io::AsyncWrite for Stream { +impl io::AsyncWrite for Stream +where + S: io::AsyncRead + io::AsyncWrite + Unpin, +{ /// from futures::io::AsyncWrite: /// * poll_write must try to make progress by flushing if needed to become writable /// from std::io::Write: diff --git a/node/actors/network/src/noise/testonly.rs b/node/actors/network/src/noise/testonly.rs index 425e2f99a..c5a8529c0 100644 --- a/node/actors/network/src/noise/testonly.rs +++ b/node/actors/network/src/noise/testonly.rs @@ -1,12 +1,21 @@ -use crate::noise; +use crate::{metrics, noise}; use concurrency::{ctx, net, scope}; pub(crate) async fn pipe(ctx: &ctx::Ctx) -> (noise::Stream, noise::Stream) { scope::run!(ctx, |ctx, s| async { - let (s1, s2) = net::tcp::testonly::pipe(ctx).await; - let s1 = s.spawn(async { noise::Stream::client_handshake(ctx, s1).await }); - let s2 = s.spawn(async { noise::Stream::server_handshake(ctx, s2).await }); - Ok((s1.join(ctx).await?, s2.join(ctx).await?)) + let (outbound_stream, inbound_stream) = net::tcp::testonly::pipe(ctx).await; + let outbound_stream = + metrics::MeteredStream::new(outbound_stream, metrics::Direction::Outbound); + let inbound_stream = + metrics::MeteredStream::new(inbound_stream, metrics::Direction::Inbound); + let outbound_task = + s.spawn(async { noise::Stream::client_handshake(ctx, outbound_stream).await }); + let inbound_task = + s.spawn(async { noise::Stream::server_handshake(ctx, inbound_stream).await }); + Ok(( + outbound_task.join(ctx).await?, + inbound_task.join(ctx).await?, + )) }) .await .unwrap() diff --git a/node/actors/network/src/preface.rs b/node/actors/network/src/preface.rs index 7641468ce..b96c89286 100644 --- a/node/actors/network/src/preface.rs +++ b/node/actors/network/src/preface.rs @@ -7,7 +7,7 @@ //! //! Hence, the preface protocol is used to enable encryption //! and multiplex between mutliple endpoints available on the same TCP port. -use crate::{frame, noise}; +use crate::{frame, metrics, noise}; use concurrency::{ctx, net, time}; use schema::{proto::network::preface as proto, required, ProtoFmt}; @@ -79,7 +79,8 @@ pub(crate) async fn connect( endpoint: Endpoint, ) -> anyhow::Result { let ctx = &ctx.with_timeout(TIMEOUT); - let mut stream = net::tcp::connect(ctx, addr).await??; + let stream = net::tcp::connect(ctx, addr).await??; + let mut stream = metrics::MeteredStream::new(stream, metrics::Direction::Outbound); frame::send_proto(ctx, &mut stream, &Encryption::NoiseNN).await?; let mut stream = noise::Stream::client_handshake(ctx, stream).await?; frame::send_proto(ctx, &mut stream, &endpoint).await?; @@ -89,8 +90,9 @@ pub(crate) async fn connect( /// Performs a server-side preface protocol. pub(crate) async fn accept( ctx: &ctx::Ctx, - mut stream: net::tcp::Stream, + stream: net::tcp::Stream, ) -> anyhow::Result<(noise::Stream, Endpoint)> { + let mut stream = metrics::MeteredStream::new(stream, metrics::Direction::Inbound); let ctx = &ctx.with_timeout(TIMEOUT); let _: Encryption = frame::recv_proto(ctx, &mut stream).await?; let mut stream = noise::Stream::server_handshake(ctx, stream).await?; diff --git a/node/libs/concurrency/src/metrics.rs b/node/libs/concurrency/src/metrics.rs index 82cfac27a..236986287 100644 --- a/node/libs/concurrency/src/metrics.rs +++ b/node/libs/concurrency/src/metrics.rs @@ -1,7 +1,7 @@ //! Prometheus metrics utilities. use std::time::Duration; -use vise::{Counter, EncodeLabelSet, EncodeLabelValue, Family, Gauge, Metrics, Unit}; +use vise::Gauge; /// Guard which increments the gauge when constructed /// and decrements it when dropped. @@ -20,36 +20,6 @@ impl Drop for GaugeGuard { } } -/// Direction of a TCP connection. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EncodeLabelSet, EncodeLabelValue)] -#[metrics(label = "direction", rename_all = "snake_case")] -pub(crate) enum Direction { - /// Inbound connection. - Inbound, - /// Outbound connection. - Outbound, -} - -/// Metrics reported for TCP connections. -#[derive(Debug, Metrics)] -#[metrics(prefix = "concurrency_net_tcp")] -pub(crate) struct TcpMetrics { - /// Total bytes sent over all TCP connections. - #[metrics(unit = Unit::Bytes)] - pub(crate) sent: Counter, - /// Total bytes received over all TCP connections. - #[metrics(unit = Unit::Bytes)] - pub(crate) received: Counter, - /// TCP connections established since the process started. - pub(crate) established: Family, - /// Number of currently active TCP connections. - pub(crate) active: Family, -} - -/// TCP metrics instance. -#[vise::register] -pub(crate) static TCP_METRICS: vise::Global = vise::Global::new(); - /// Extension trait for latency histograms. pub trait LatencyHistogramExt { /// Observes latency. diff --git a/node/libs/concurrency/src/net/tcp/mod.rs b/node/libs/concurrency/src/net/tcp/mod.rs index 2cc7d86b4..097b4e3f3 100644 --- a/node/libs/concurrency/src/net/tcp/mod.rs +++ b/node/libs/concurrency/src/net/tcp/mod.rs @@ -3,46 +3,26 @@ //! algorithm (so that the transmission latency is more //! predictable), so the caller is expected to apply //! user space buffering. -use crate::{ - ctx, - metrics::{self, Direction}, -}; +use crate::ctx; pub use listener_addr::*; -use std::{ - pin::Pin, - task::{ready, Context, Poll}, -}; use tokio::io; mod listener_addr; pub mod testonly; /// TCP stream. -#[pin_project::pin_project] -pub struct Stream { - #[pin] - stream: tokio::net::TcpStream, - _active: metrics::GaugeGuard, -} - +pub type Stream = tokio::net::TcpStream; /// TCP listener. pub type Listener = tokio::net::TcpListener; /// Accepts an INBOUND listener connection. pub async fn accept(ctx: &ctx::Ctx, this: &mut Listener) -> ctx::OrCanceled> { - Ok(ctx.wait(this.accept()).await?.map(|stream| { - metrics::TCP_METRICS.established[&Direction::Inbound].inc(); - + Ok(ctx.wait(this.accept()).await?.map(|(stream, _)| { // We are the only owner of the correctly opened // socket at this point so `set_nodelay` should // always succeed. - stream.0.set_nodelay(true).unwrap(); - Stream { - stream: stream.0, - _active: metrics::TCP_METRICS.active[&Direction::Inbound] - .clone() - .into(), - } + stream.set_nodelay(true).unwrap(); + stream })) } @@ -55,54 +35,10 @@ pub async fn connect( .wait(tokio::net::TcpStream::connect(addr)) .await? .map(|stream| { - metrics::TCP_METRICS.established[&Direction::Outbound].inc(); // We are the only owner of the correctly opened // socket at this point so `set_nodelay` should // always succeed. stream.set_nodelay(true).unwrap(); - Stream { - stream, - _active: metrics::TCP_METRICS.active[&Direction::Outbound] - .clone() - .into(), - } + stream })) } - -impl io::AsyncRead for Stream { - #[inline(always)] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut io::ReadBuf<'_>, - ) -> Poll> { - let this = self.project(); - let before = buf.remaining(); - let res = this.stream.poll_read(cx, buf); - let after = buf.remaining(); - metrics::TCP_METRICS - .received - .inc_by((before - after) as u64); - res - } -} - -impl io::AsyncWrite for Stream { - #[inline(always)] - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { - let this = self.project(); - let res = ready!(this.stream.poll_write(cx, buf))?; - metrics::TCP_METRICS.sent.inc_by(res as u64); - Poll::Ready(Ok(res)) - } - - #[inline(always)] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project().stream.poll_flush(cx) - } - - #[inline(always)] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project().stream.poll_shutdown(cx) - } -}