Skip to content

Commit

Permalink
Reduce some code duplication around agent connection and protocol (me…
Browse files Browse the repository at this point in the history
…talbear-co#2094)

* Use generics in Codecs to reduce duplication

* refactor to eliminate code duplication for agent connection

* changelog

* CR: use default instead of new

* CR: Client and Daemon Codec docs.

* bump protocol crate version.

* cargo fmt

* CR: remove confusing comments about phantom data.
  • Loading branch information
t4lz authored Dec 5, 2023
1 parent e5e04ae commit 2e9ba6a
Show file tree
Hide file tree
Showing 13 changed files with 73 additions and 141 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions changelog.d/+refactor-connection.internal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce some code duplication around protocol and agent connection.
2 changes: 1 addition & 1 deletion mirrord/agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl State {
cancellation_token: CancellationToken,
protocol_version: semver::Version,
) -> Result<Option<JoinHandle<u32>>> {
let mut stream = Framed::new(stream, DaemonCodec::new());
let mut stream = Framed::new(stream, DaemonCodec::default());

let client_id = match self.new_client().await {
Ok(id) => id,
Expand Down
2 changes: 1 addition & 1 deletion mirrord/agent/tests/blackbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ mod tests {
}
});

let mut codec = Framed::new(stream, ClientCodec::new());
let mut codec = Framed::new(stream, ClientCodec::default());
let subscription_port = 1337;

codec
Expand Down
17 changes: 16 additions & 1 deletion mirrord/cli/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::PathBuf;

use miette::Diagnostic;
use mirrord_console::error::ConsoleError;
use mirrord_intproxy::error::IntProxyError;
use mirrord_intproxy::{agent_conn::AgentConnectionError, error::IntProxyError};
use mirrord_kube::error::KubeApiError;
use mirrord_operator::client::{HttpError, OperatorApiError};
use thiserror::Error;
Expand Down Expand Up @@ -279,3 +279,18 @@ impl From<OperatorApiError> for CliError {
}
}
}

impl From<AgentConnectionError> for CliError {
fn from(err: AgentConnectionError) -> Self {
match err {
AgentConnectionError::Io(err) => {
CliError::InternalProxySetupError(InternalProxySetupError::TcpConnectError(err))
}
AgentConnectionError::NoConnectionMethod => {
CliError::InternalProxySetupError(InternalProxySetupError::NoConnectionMethod)
}
AgentConnectionError::Operator(err) => err.into(),
AgentConnectionError::Kube(err) => err.into(),
}
}
}
73 changes: 14 additions & 59 deletions mirrord/cli/src/internal_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,13 @@ use std::{

use mirrord_analytics::{AnalyticsError, AnalyticsReporter, CollectAnalytics};
use mirrord_config::LayerConfig;
use mirrord_intproxy::{agent_conn::AgentConnectInfo, IntProxy};
use mirrord_kube::api::{kubernetes::KubernetesAPI, wrap_raw_connection, AgentManagment};
use mirrord_operator::client::{
OperatorApi, OperatorSessionConnection, OperatorSessionInformation,
use mirrord_intproxy::{
agent_conn::{AgentConnectInfo, AgentConnection},
IntProxy,
};
use mirrord_protocol::{pause::DaemonPauseTarget, ClientMessage, DaemonMessage};
use nix::libc;
use tokio::{
net::{TcpListener, TcpStream},
sync::mpsc,
task::JoinHandle,
};
use tokio::{net::TcpListener, sync::mpsc, task::JoinHandle};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, log::trace};

Expand Down Expand Up @@ -160,10 +155,9 @@ pub(crate) async fn proxy(watch: drain::Watch) -> Result<()> {
// Create a main connection, that will be held until proxy is closed.
// This will guarantee agent staying alive and will enable us to
// make the agent close on last connection close immediately (will help in tests)
let mut main_connection =
connect_and_ping(&config, agent_connect_info.clone(), Some(&mut analytics))
.await
.inspect_err(|_| analytics.set_error(AnalyticsError::AgentConnection))?;
let mut main_connection = connect_and_ping(&config, agent_connect_info.clone(), &mut analytics)
.await
.inspect_err(|_| analytics.set_error(AnalyticsError::AgentConnection))?;

if config.pause {
tokio::time::timeout(
Expand Down Expand Up @@ -215,11 +209,14 @@ pub(crate) async fn proxy(watch: drain::Watch) -> Result<()> {
async fn connect_and_ping(
config: &LayerConfig,
agent_connect_info: Option<AgentConnectInfo>,
analytics: Option<&mut AnalyticsReporter>,
analytics: &mut AnalyticsReporter,
) -> Result<(mpsc::Sender<ClientMessage>, mpsc::Receiver<DaemonMessage>)> {
let ((sender, mut receiver), _) = connect(config, agent_connect_info, analytics).await?;
ping(&sender, &mut receiver).await?;
Ok((sender, receiver))
let AgentConnection {
agent_tx,
mut agent_rx,
} = AgentConnection::new(config, agent_connect_info, Some(analytics)).await?;
ping(&agent_tx, &mut agent_rx).await?;
Ok((agent_tx, agent_rx))
}

/// Sends a ping the connection and expects a pong.
Expand Down Expand Up @@ -270,45 +267,3 @@ fn create_ping_loop(

(cancellation_token, join_handle)
}

/// Connects to an agent pod depending on how [`LayerConfig`] is set-up:
///
/// - `connect_tcp`: connects directly to the `address` specified, and calls [`wrap_raw_connection`]
/// on the [`TcpStream`];
///
/// - `connect_agent_name`: Connects to an agent with `connect_agent_name` on `connect_agent_port`
/// using [`KubernetesAPI];
///
/// - None of the above: uses the [`OperatorApi`] to establish the connection.
/// Returns the tx/rx and whether the operator is used.
async fn connect(
config: &LayerConfig,
agent_connect_info: Option<AgentConnectInfo>,
analytics: Option<&mut AnalyticsReporter>,
) -> Result<(
(mpsc::Sender<ClientMessage>, mpsc::Receiver<DaemonMessage>),
Option<OperatorSessionInformation>,
)> {
match agent_connect_info {
Some(AgentConnectInfo::Operator(operator_session_information)) => {
let OperatorSessionConnection { tx, rx, info } =
OperatorApi::connect(config, operator_session_information, analytics).await?;
Ok(((tx, rx), Some(info)))
}
Some(AgentConnectInfo::DirectKubernetes(connect_info)) => {
let k8s_api = KubernetesAPI::create(config).await?;
let connection = k8s_api.create_connection(connect_info).await?;
Ok((connection, None))
}
None => {
if let Some(address) = &config.connect_tcp {
let stream = TcpStream::connect(address)
.await
.map_err(InternalProxySetupError::TcpConnectError)?;
Ok((wrap_raw_connection(stream), None))
} else {
Err(InternalProxySetupError::NoConnectionMethod.into())
}
}
}
}
1 change: 1 addition & 0 deletions mirrord/intproxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mirrord-kube = { path = "../kube" }
mirrord-operator = { path = "../operator", features = ["client"] }
mirrord-protocol = { path = "../protocol" }
mirrord-intproxy-protocol = { path = "./protocol", features = ["codec-async"] }
mirrord-analytics = { path = "../analytics"}

serde.workspace = true
thiserror.workspace = true
Expand Down
8 changes: 5 additions & 3 deletions mirrord/intproxy/src/agent_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::{io, net::SocketAddr};

use mirrord_analytics::AnalyticsReporter;
use mirrord_config::LayerConfig;
use mirrord_kube::{
api::{
Expand Down Expand Up @@ -60,8 +61,8 @@ pub enum AgentConnectInfo {
/// [`mpsc`](tokio::sync::mpsc) channels returned from other functions and implements the
/// [`BackgroundTask`] trait.
pub struct AgentConnection {
agent_tx: Sender<ClientMessage>,
agent_rx: Receiver<DaemonMessage>,
pub agent_tx: Sender<ClientMessage>,
pub agent_rx: Receiver<DaemonMessage>,
}

impl AgentConnection {
Expand All @@ -70,10 +71,11 @@ impl AgentConnection {
pub async fn new(
config: &LayerConfig,
connect_info: Option<AgentConnectInfo>,
analytics: Option<&mut AnalyticsReporter>,
) -> Result<Self, AgentConnectionError> {
let (agent_tx, agent_rx) = match connect_info {
Some(AgentConnectInfo::Operator(operator_session_information)) => {
let session = OperatorApi::connect(config, operator_session_information, None)
let session = OperatorApi::connect(config, operator_session_information, analytics)
.await
.map_err(AgentConnectionError::Operator)?;

Expand Down
7 changes: 4 additions & 3 deletions mirrord/intproxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

use std::{collections::HashMap, time::Duration};

use agent_conn::AgentConnection;
use background_tasks::{BackgroundTasks, TaskSender, TaskUpdate};
use layer_conn::LayerConnection;
use layer_initializer::LayerInitializer;
Expand All @@ -21,7 +20,9 @@ use proxies::{
use tokio::{net::TcpListener, time};

use crate::{
agent_conn::AgentConnectInfo, background_tasks::TaskError, error::IntProxyError,
agent_conn::{AgentConnectInfo, AgentConnection},
background_tasks::TaskError,
error::IntProxyError,
main_tasks::LayerClosed,
};

Expand Down Expand Up @@ -72,7 +73,7 @@ impl IntProxy {
agent_connect_info: Option<AgentConnectInfo>,
listener: TcpListener,
) -> Result<Self, IntProxyError> {
let agent_conn = AgentConnection::new(config, agent_connect_info).await?;
let agent_conn = AgentConnection::new(config, agent_connect_info, None).await?;
Ok(Self::new_with_connection(agent_conn, listener))
}

Expand Down
2 changes: 1 addition & 1 deletion mirrord/kube/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const CONNECTION_CHANNEL_SIZE: usize = 1000;
pub fn wrap_raw_connection(
stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
) -> (mpsc::Sender<ClientMessage>, mpsc::Receiver<DaemonMessage>) {
let mut codec = actix_codec::Framed::new(stream, ClientCodec::new());
let mut codec = actix_codec::Framed::new(stream, ClientCodec::default());

let (in_tx, mut in_rx) = mpsc::channel(CONNECTION_CHANNEL_SIZE);
let (out_tx, out_rx) = mpsc::channel(CONNECTION_CHANNEL_SIZE);
Expand Down
2 changes: 1 addition & 1 deletion mirrord/layer/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl TestIntProxy {
});

let (stream, _buffer_size) = fake_agent_listener.accept().await.unwrap();
let codec = Framed::new(stream, DaemonCodec::new());
let codec = Framed::new(stream, DaemonCodec::default());

let mut res = Self {
codec,
Expand Down
2 changes: 1 addition & 1 deletion mirrord/protocol/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mirrord-protocol"
version = "1.3.2"
version = "1.3.3"
authors.workspace = true
description.workspace = true
documentation.workspace = true
Expand Down
Loading

0 comments on commit 2e9ba6a

Please sign in to comment.