Skip to content

Commit

Permalink
feat: secure HandshakeProtocol trait and simple SharedKeyHandshake (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
giangndm authored Oct 9, 2024
1 parent 14e963e commit ba1a842
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 43 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ parking_lot = { version = "0.12" }
rand = { version = "0.8" }
futures = { version = "0.3" }
bincode = { version = "1.3" }
blake3 = { version = "1.3" }
tokio-util = { version = "0.7", features = ["codec"] }
lru = { version = "0.12" }

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
test-log = { version = "0.2" }
clap = { version = "4.4", features = ["derive", "env", "color"] }
tracing-subscriber = { version = "0.3", features = ["env-filter", "std"] }
tracing-subscriber = { version = "0.3", features = ["env-filter", "std"] }
13 changes: 9 additions & 4 deletions examples/simple.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{net::SocketAddr, str::FromStr};

use atm0s_small_p2p::{P2pNetwork, P2pNetworkConfig, PeerAddress};
use atm0s_small_p2p::{P2pNetwork, P2pNetworkConfig, PeerAddress, SharedKeyHandshake};
use clap::Parser;
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
Expand Down Expand Up @@ -29,6 +29,10 @@ struct Args {
/// This option is useful with high performance relay node
#[arg(env, long)]
sdn_advertise_address: Option<SocketAddr>,

/// Sdn secure code
#[arg(env, long, default_value = "insecure")]
sdn_secure_code: String,
}

#[tokio::main]
Expand All @@ -44,17 +48,18 @@ async fn main() {
let args: Args = Args::parse();
tracing_subscriber::registry().with(fmt::layer()).with(EnvFilter::from_default_env()).init();

let key = PrivatePkcs8KeyDer::from(DEFAULT_CLUSTER_KEY.to_vec());
let priv_key = PrivatePkcs8KeyDer::from(DEFAULT_CLUSTER_KEY.to_vec());
let cert = CertificateDer::from(DEFAULT_CLUSTER_CERT.to_vec());

let mut p2p = P2pNetwork::new(P2pNetworkConfig {
peer_id: args.sdn_peer_id.into(),
listen_addr: args.sdn_listener,
advertise: args.sdn_advertise_address.map(|a| a.into()),
priv_key: key,
cert: cert,
priv_key,
cert,
tick_ms: 100,
seeds: args.sdn_seeds.into_iter().map(|s| PeerAddress::from_str(s.as_str()).expect("should parse address")).collect::<Vec<_>>(),
secure: SharedKeyHandshake::from(args.sdn_secure_code.as_str()),
})
.await
.expect("should create network");
Expand Down
24 changes: 15 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{
net::SocketAddr,
ops::Deref,
str::FromStr,
sync::Arc,
time::Duration,
};

Expand Down Expand Up @@ -36,6 +37,7 @@ mod peer;
mod quic;
mod requester;
mod router;
mod secure;
mod service;
mod stream;
#[cfg(test)]
Expand All @@ -44,6 +46,7 @@ mod utils;

pub use requester::P2pNetworkRequester;
pub use router::SharedRouterTable;
pub use secure::*;
pub use service::*;
pub use stream::P2pQuicStream;
pub use utils::*;
Expand Down Expand Up @@ -116,14 +119,15 @@ enum ControlCmd {
Connect(PeerAddress, Option<oneshot::Sender<anyhow::Result<()>>>),
}

pub struct P2pNetworkConfig {
pub struct P2pNetworkConfig<SECURE> {
pub peer_id: PeerId,
pub listen_addr: SocketAddr,
pub advertise: Option<NetworkAddress>,
pub priv_key: PrivatePkcs8KeyDer<'static>,
pub cert: CertificateDer<'static>,
pub tick_ms: u64,
pub seeds: Vec<PeerAddress>,
pub secure: SECURE,
}

#[derive(Debug, PartialEq, Eq)]
Expand All @@ -133,7 +137,7 @@ pub enum P2pNetworkEvent {
Continue,
}

pub struct P2pNetwork {
pub struct P2pNetwork<SECURE> {
local_id: PeerId,
endpoint: Endpoint,
control_tx: UnboundedSender<ControlCmd>,
Expand All @@ -145,10 +149,11 @@ pub struct P2pNetwork {
router: SharedRouterTable,
discovery: PeerDiscovery,
ctx: SharedCtx,
secure: Arc<SECURE>,
}

impl P2pNetwork {
pub async fn new(cfg: P2pNetworkConfig) -> anyhow::Result<Self> {
impl<SECURE: HandshakeProtocol> P2pNetwork<SECURE> {
pub async fn new(cfg: P2pNetworkConfig<SECURE>) -> anyhow::Result<Self> {
log::info!("[P2pNetwork] starting node {}@{}", cfg.peer_id, cfg.listen_addr);
let endpoint = make_server_endpoint(cfg.listen_addr, cfg.priv_key, cfg.cert)?;
let (internal_tx, internal_rx) = channel(10);
Expand All @@ -172,6 +177,7 @@ impl P2pNetwork {
ctx: SharedCtx::new(router.clone()),
router,
discovery,
secure: Arc::new(cfg.secure),
})
}

Expand Down Expand Up @@ -229,21 +235,21 @@ impl P2pNetwork {
fn process_incoming(&mut self, incoming: Incoming) -> anyhow::Result<P2pNetworkEvent> {
let remote = incoming.remote_address();
log::info!("[P2pNetwork] incoming connect from {remote} => accept");
let conn = PeerConnection::new_incoming(self.local_id, incoming, self.internal_tx.clone(), self.ctx.clone());
let conn = PeerConnection::new_incoming(self.secure.clone(), self.local_id, incoming, self.internal_tx.clone(), self.ctx.clone());
self.neighbours.insert(conn.conn_id(), conn);
Ok(P2pNetworkEvent::Continue)
}

fn process_internal(&mut self, now_ms: u64, event: InternalEvent) -> anyhow::Result<P2pNetworkEvent> {
match event {
InternalEvent::PeerConnected(conn, peer, ttl_ms) => {
log::info!("[P2pNetwork] connected to {peer}");
log::info!("[P2pNetwork] connection {conn} connected to {peer}");
self.router.set_direct(conn, peer, ttl_ms);
self.neighbours.mark_connected(&conn, peer);
Ok(P2pNetworkEvent::PeerConnected(conn, peer))
}
InternalEvent::PeerData(conn, peer, data) => {
log::debug!("[P2pNetwork] on data {data:?} from {peer}");
log::debug!("[P2pNetwork] connection {conn} on data {data:?} from {peer}");
match data {
PeerMainData::Sync { route, advertise } => {
self.router.apply_sync(conn, route);
Expand All @@ -257,7 +263,7 @@ impl P2pNetwork {
Ok(P2pNetworkEvent::Continue)
}
InternalEvent::PeerDisconnected(conn, peer) => {
log::info!("[P2pNetwork] disconnected from {peer}");
log::info!("[P2pNetwork] connection {conn} disconnected from {peer}");
self.router.del_direct(&conn);
self.neighbours.remove(&conn);
Ok(P2pNetworkEvent::PeerDisconnected(conn, peer))
Expand All @@ -274,7 +280,7 @@ impl P2pNetwork {
log::info!("[P2pNetwork] connecting to {addr}");
match self.endpoint.connect(*addr.network_address().deref(), "cluster") {
Ok(connecting) => {
let conn = PeerConnection::new_connecting(self.local_id, addr.peer_id(), connecting, self.internal_tx.clone(), self.ctx.clone());
let conn = PeerConnection::new_connecting(self.secure.clone(), self.local_id, addr.peer_id(), connecting, self.internal_tx.clone(), self.ctx.clone());
self.neighbours.insert(conn.conn_id(), conn);
Ok(())
}
Expand Down
71 changes: 47 additions & 24 deletions src/peer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::{net::SocketAddr, sync::Arc, time::Duration};

use anyhow::anyhow;
use peer_internal::PeerConnectionInternal;
Expand All @@ -12,6 +12,7 @@ use tokio::sync::{
use crate::{
ctx::SharedCtx,
msg::P2pServiceId,
secure::HandshakeProtocol,
stream::{wait_object, write_object, P2pQuicStream},
ConnectionId, PeerId,
};
Expand All @@ -34,19 +35,20 @@ pub struct PeerConnection {
}

impl PeerConnection {
pub fn new_incoming(local_id: PeerId, incoming: Incoming, internal_tx: Sender<InternalEvent>, ctx: SharedCtx) -> Self {
pub fn new_incoming<SECURE: HandshakeProtocol>(secure: Arc<SECURE>, local_id: PeerId, incoming: Incoming, internal_tx: Sender<InternalEvent>, ctx: SharedCtx) -> Self {
let remote = incoming.remote_address();
let conn_id = ConnectionId::rand();

tokio::spawn(async move {
log::info!("[PeerConnection] wait incoming from {remote}");
log::info!("[PeerConnection {conn_id}] wait incoming from {remote}");
match incoming.await {
Ok(connection) => {
log::info!("[PeerConnection] got connection from {remote}");
log::info!("[PeerConnection {conn_id}] got connection from {remote}");
match connection.accept_bi().await {
Ok((send, recv)) => {
if let Err(e) = run_connection(ctx, remote, conn_id, local_id, PeerConnectionDirection::Incoming, connection, send, recv, internal_tx).await {
log::error!("[PeerConnection] connection from {remote} error {e}");
if let Err(e) = run_connection(secure, ctx, remote, conn_id, local_id, PeerConnectionDirection::Incoming, &connection, send, recv, internal_tx).await {
log::error!("[PeerConnection {conn_id}] connection from {remote} error {e}");
let _ = tokio::time::timeout(Duration::from_secs(2), connection.closed()).await;
}
}
Err(err) => internal_tx.send(InternalEvent::PeerConnectError(conn_id, None, err.into())).await.expect("should send to main"),
Expand All @@ -58,18 +60,18 @@ impl PeerConnection {
Self { conn_id, peer_id: None }
}

pub fn new_connecting(local_id: PeerId, to_peer: PeerId, connecting: Connecting, internal_tx: Sender<InternalEvent>, ctx: SharedCtx) -> Self {
pub fn new_connecting<SECURE: HandshakeProtocol>(secure: Arc<SECURE>, local_id: PeerId, to_peer: PeerId, connecting: Connecting, internal_tx: Sender<InternalEvent>, ctx: SharedCtx) -> Self {
let remote = connecting.remote_address();
let conn_id = ConnectionId::rand();

tokio::spawn(async move {
match connecting.await {
Ok(connection) => {
log::info!("[PeerConnection] connected to {remote}");
log::info!("[PeerConnection {conn_id}] connected to {remote}");
match connection.open_bi().await {
Ok((send, recv)) => {
if let Err(e) = run_connection(ctx, remote, conn_id, local_id, PeerConnectionDirection::Outgoing(to_peer), connection, send, recv, internal_tx).await {
log::error!("[PeerConnection] connection from {remote} error {e}");
if let Err(e) = run_connection(secure, ctx, remote, conn_id, local_id, PeerConnectionDirection::Outgoing(to_peer), &connection, send, recv, internal_tx).await {
log::error!("[PeerConnection {conn_id}] connection to {remote} error {e}");
}
}
Err(err) => internal_tx
Expand Down Expand Up @@ -113,38 +115,59 @@ enum PeerConnectionDirection {
struct ConnectReq {
from: PeerId,
to: PeerId,
auth: Vec<u8>,
}

#[derive(Debug, Serialize, Deserialize)]
struct ConnectRes {
success: bool,
result: Result<Vec<u8>, String>,
}

async fn run_connection(
async fn run_connection<SECURE: HandshakeProtocol>(
secure: Arc<SECURE>,
ctx: SharedCtx,
remote: SocketAddr,
conn_id: ConnectionId,
local_id: PeerId,
direction: PeerConnectionDirection,
connection: Connection,
connection: &Connection,
mut send: SendStream,
mut recv: RecvStream,
internal_tx: Sender<InternalEvent>,
) -> anyhow::Result<()> {
let to_id = if let PeerConnectionDirection::Outgoing(dest) = direction {
write_object::<_, _, 500>(&mut send, &ConnectReq { from: local_id, to: dest }).await?;
let auth = secure.create_request(local_id, dest);
write_object::<_, _, 500>(&mut send, &ConnectReq { from: local_id, to: dest, auth }).await?;
let res: ConnectRes = wait_object::<_, _, 500>(&mut recv).await?;
if !res.success {
return Err(anyhow!("destination rejected"));
log::info!("{res:?}");
match res.result {
Ok(auth) => {
if let Err(e) = secure.verify_response(auth, dest, local_id) {
return Err(anyhow!("destination auth failure: {e}"));
}
dest
}
Err(err) => {
return Err(anyhow!("destination rejected: {err}"));
}
}
dest
} else {
let req: ConnectReq = wait_object::<_, _, 500>(&mut recv).await?;
if req.to != local_id {
write_object::<_, _, 500>(&mut send, &ConnectRes { success: false }).await?;
if let Err(e) = secure.verify_request(req.auth, req.from, req.to) {
write_object::<_, _, 500>(&mut send, &ConnectRes { result: Err(e.clone()) }).await?;
return Err(anyhow!("destination auth failure: {e}"));
} else if req.to != local_id {
write_object::<_, _, 500>(
&mut send,
&ConnectRes {
result: Err("destination not match".to_owned()),
},
)
.await?;
return Err(anyhow!("destination wrong"));
} else {
write_object::<_, _, 500>(&mut send, &ConnectRes { success: true }).await?;
let auth = secure.create_response(req.to, req.from);
write_object::<_, _, 500>(&mut send, &ConnectRes { result: Ok(auth) }).await?;
req.from
}
};
Expand All @@ -153,18 +176,18 @@ async fn run_connection(
let (control_tx, control_rx) = channel(10);
let alias = PeerConnectionAlias::new(local_id, to_id, conn_id, control_tx);
let mut internal = PeerConnectionInternal::new(ctx.clone(), conn_id, to_id, connection.clone(), send, recv, internal_tx.clone(), control_rx);
log::info!("[PeerConnection] started {remote}, rtt: {rtt_ms}");
log::info!("[PeerConnection {conn_id}] started {remote}, rtt: {rtt_ms}");
ctx.register_conn(conn_id, alias);
internal_tx.send(InternalEvent::PeerConnected(conn_id, to_id, rtt_ms)).await.expect("should send to main");
log::info!("[PeerConnection] run loop for {remote}");
log::info!("[PeerConnection {conn_id}] run loop for {remote}");
loop {
if let Err(e) = internal.recv_complex().await {
log::error!("[PeerConnection] {remote} error {e}");
log::error!("[PeerConnection {conn_id}] {remote} error {e}");
break;
}
}
internal_tx.send(InternalEvent::PeerDisconnected(conn_id, to_id)).await.expect("should send to main");
log::info!("[PeerConnection] end loop for {remote}");
log::info!("[PeerConnection {conn_id}] end loop for {remote}");
ctx.unregister_conn(&conn_id);
Ok(())
}
Loading

0 comments on commit ba1a842

Please sign in to comment.