Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rtt to Client #914

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .config/nats.dic
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ rustls
Acker
EndpointSchema
auth
filter_subject
filter_subjects
rollup
IoT
RttError
40 changes: 40 additions & 0 deletions async-nats/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,35 @@ impl Client {
Ok(())
}

/// Calculates the round trip time between this client and the server,
/// if the server is currently connected.
///
/// # Examples
///
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let rtt = client.rtt().await?;
/// println!("server rtt: {:?}", rtt);
/// # Ok(())
/// # }
/// ```
pub async fn rtt(&self) -> Result<Duration, RttError> {
let (tx, rx) = tokio::sync::oneshot::channel();

self.sender.send(Command::Rtt { result: tx }).await?;

let rtt = rx
.await
// first handle rx error
.map_err(|err| RttError(Box::new(err)))?
// second handle the actual rtt error
.map_err(|err| RttError(Box::new(err)))?;

Ok(rtt)
}

/// Returns the current state of the connection.
///
/// # Examples
Expand Down Expand Up @@ -684,3 +713,14 @@ impl From<SubscribeError> for RequestError {
RequestError::with_source(RequestErrorKind::Other, e)
}
}

/// Error returned when doing a round-trip time measurement fails.
#[derive(Debug, Error)]
#[error("failed to measure round-trip time: {0}")]
pub struct RttError(#[source] Box<dyn std::error::Error + Send + Sync>);

impl From<tokio::sync::mpsc::error::SendError<Command>> for RttError {
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
RttError(Box::new(err))
}
}
65 changes: 65 additions & 0 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ use thiserror::Error;
use futures::future::FutureExt;
use futures::select;
use futures::stream::Stream;
use std::time::Instant;
use tracing::{debug, error};

use core::fmt;
Expand Down Expand Up @@ -280,6 +281,9 @@ pub(crate) enum Command {
result: oneshot::Sender<Result<(), io::Error>>,
},
TryFlush,
Rtt {
result: oneshot::Sender<Result<Duration, io::Error>>,
},
}

/// `ClientOp` represents all actions of `Client`.
Expand Down Expand Up @@ -323,6 +327,9 @@ pub(crate) struct ConnectionHandler {
info_sender: tokio::sync::watch::Sender<ServerInfo>,
ping_interval: Interval,
flush_interval: Interval,
last_ping_time: Option<Instant>,
last_pong_time: Option<Instant>,
rtt_senders: Vec<oneshot::Sender<Result<Duration, io::Error>>>,
}

impl ConnectionHandler {
Expand All @@ -347,6 +354,9 @@ impl ConnectionHandler {
info_sender,
ping_interval,
flush_interval,
last_ping_time: None,
last_pong_time: None,
rtt_senders: Vec::new(),
}
}

Expand Down Expand Up @@ -425,6 +435,22 @@ impl ConnectionHandler {
}
ServerOp::Pong => {
debug!("received PONG");
if self.pending_pings == 1 {
self.last_pong_time = Some(Instant::now());

while let Some(sender) = self.rtt_senders.pop() {
if let (Some(ping), Some(pong)) = (self.last_ping_time, self.last_pong_time)
{
let rtt = pong.duration_since(ping);
sender.send(Ok(rtt)).map_err(|_| {
io::Error::new(
io::ErrorKind::Other,
"one shot failed to be received",
)
})?;
}
}
}
self.pending_pings = self.pending_pings.saturating_sub(1);
}
ServerOp::Error(error) => {
Expand Down Expand Up @@ -538,6 +564,14 @@ impl ConnectionHandler {
}
}
}
Command::Rtt { result } => {
self.rtt_senders.push(result);

if self.pending_pings == 0 {
// do a ping and expect a pong - will calculate rtt when handling the pong
self.handle_ping().await?;
}
}
Command::Flush { result } => {
if let Err(_err) = self.handle_flush().await {
if let Err(err) = self.handle_disconnect().await {
Expand Down Expand Up @@ -612,8 +646,39 @@ impl ConnectionHandler {
Ok(())
}

async fn handle_ping(&mut self) -> Result<(), io::Error> {
debug!(
"PING command. Pending pings {}, max pings {}",
self.pending_pings, MAX_PENDING_PINGS
);
self.pending_pings += 1;
self.ping_interval.reset();

if self.pending_pings > MAX_PENDING_PINGS {
debug!(
"pending pings {}, max pings {}. disconnecting",
self.pending_pings, MAX_PENDING_PINGS
);
self.handle_disconnect().await?;
}

if self.pending_pings == 1 {
// start the clock for calculating round trip time
self.last_ping_time = Some(Instant::now());
}

if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
self.handle_disconnect().await?;
}

self.handle_flush().await?;
Ok(())
}

async fn handle_disconnect(&mut self) -> io::Result<()> {
self.pending_pings = 0;
self.last_ping_time = None;
self.last_pong_time = None;
self.connector.events_tx.try_send(Event::Disconnected).ok();
self.connector.state_tx.send(State::Disconnected).ok();
self.handle_reconnect().await?;
Expand Down
11 changes: 11 additions & 0 deletions async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,4 +867,15 @@ mod client {
.await
.unwrap();
}

#[tokio::test]
async fn rtt() {
let server = nats_server::run_basic_server();
let client = async_nats::connect(server.client_url()).await.unwrap();

let rtt = client.rtt().await.unwrap();

println!("rtt: {:?}", rtt);
assert!(rtt.as_nanos() > 0);
}
}