diff --git a/src/peer.rs b/src/peer.rs index 2bd0f44..b11ccf0 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -12,6 +12,7 @@ use tokio::sync::{ use crate::{ ctx::SharedCtx, msg::P2pServiceId, + now_ms, secure::HandshakeProtocol, stream::{wait_object, write_object, P2pQuicStream}, ConnectionId, PeerId, @@ -136,13 +137,13 @@ async fn run_connection( internal_tx: Sender, ) -> anyhow::Result<()> { let to_id = if let PeerConnectionDirection::Outgoing(dest) = direction { - let auth = secure.create_request(local_id, dest); + let auth = secure.create_request(local_id, dest, now_ms()); write_object::<_, _, 500>(&mut send, &ConnectReq { from: local_id, to: dest, auth }).await?; let res: ConnectRes = wait_object::<_, _, 500>(&mut recv).await?; log::info!("{res:?}"); match res.result { Ok(auth) => { - if let Err(e) = secure.verify_response(auth, dest, local_id) { + if let Err(e) = secure.verify_response(auth, dest, local_id, now_ms()) { return Err(anyhow!("destination auth failure: {e}")); } dest @@ -153,7 +154,7 @@ async fn run_connection( } } else { let req: ConnectReq = wait_object::<_, _, 500>(&mut recv).await?; - if let Err(e) = secure.verify_request(req.auth, req.from, req.to) { + if let Err(e) = secure.verify_request(req.auth, req.from, req.to, now_ms()) { write_object::<_, _, 500>(&mut send, &ConnectRes { result: Err(e.clone()) }).await?; return Err(anyhow!("destination auth failure: {e}")); } else if req.to != local_id { @@ -166,7 +167,7 @@ async fn run_connection( .await?; return Err(anyhow!("destination wrong")); } else { - let auth = secure.create_response(req.to, req.from); + let auth = secure.create_response(req.to, req.from, now_ms()); write_object::<_, _, 500>(&mut send, &ConnectRes { result: Ok(auth) }).await?; req.from } diff --git a/src/secure.rs b/src/secure.rs index ff10f21..9075aee 100644 --- a/src/secure.rs +++ b/src/secure.rs @@ -1,11 +1,11 @@ -use crate::{now_ms, PeerId}; +use crate::PeerId; use serde::{Deserialize, Serialize}; pub trait HandshakeProtocol: Send + Sync + 'static { - fn create_request(&self, from: PeerId, to: PeerId) -> Vec; - fn verify_request(&self, data: Vec, expected_from: PeerId, expected_to: PeerId) -> Result<(), String>; - fn create_response(&self, from: PeerId, to: PeerId) -> Vec; - fn verify_response(&self, data: Vec, expected_from: PeerId, expected_to: PeerId) -> Result<(), String>; + fn create_request(&self, from: PeerId, to: PeerId, now: u64) -> Vec; + fn verify_request(&self, data: Vec, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String>; + fn create_response(&self, from: PeerId, to: PeerId, now: u64) -> Vec; + fn verify_response(&self, data: Vec, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String>; } const HASH_SEED: &str = "atm0s-small-p2p"; @@ -40,11 +40,11 @@ impl From<&str> for SharedKeyHandshake { } impl SharedKeyHandshake { - fn generate_handshake(&self, from: PeerId, to: PeerId, is_client: bool) -> Vec { + fn generate_handshake(&self, from: PeerId, to: PeerId, is_client: bool, now: u64) -> Vec { let handshake_data = HandshakeData { from, to, - timestamp: now_ms(), + timestamp: now, is_initiator: is_client, }; @@ -59,15 +59,14 @@ impl SharedKeyHandshake { bincode::serialize(&handshake).unwrap() } - fn validate_handshake(&self, data: Vec, expected_from: PeerId, expected_to: PeerId, expected_is_client: bool) -> Result<(), String> { + fn validate_handshake(&self, data: Vec, expected_from: PeerId, expected_to: PeerId, expected_is_client: bool, current_ts: u64) -> Result<(), String> { let handshake: HandshakeMessage = bincode::deserialize(&data).map_err(|_| "Invalid handshake format".to_string())?; let handshake_data: HandshakeData = bincode::deserialize(&handshake.payload).map_err(|_| "Invalid handshake data format".to_string())?; // Verify timestamp - let current_ts = now_ms(); - if current_ts - handshake_data.timestamp > HANDSHAKE_TIMEOUT { - return Err("Handshake timeout".to_string()); + if current_ts > handshake_data.timestamp + HANDSHAKE_TIMEOUT { + return Err(format!("Handshake timeout {} vs {}", current_ts, handshake_data.timestamp)); } // Verify peer IDs @@ -95,25 +94,27 @@ impl SharedKeyHandshake { } impl HandshakeProtocol for SharedKeyHandshake { - fn create_request(&self, from: PeerId, to: PeerId) -> Vec { - self.generate_handshake(from, to, true) + fn create_request(&self, from: PeerId, to: PeerId, now: u64) -> Vec { + self.generate_handshake(from, to, true, now) } - fn verify_request(&self, data: Vec, expected_from: PeerId, expected_to: PeerId) -> Result<(), String> { - self.validate_handshake(data, expected_from, expected_to, true) + fn verify_request(&self, data: Vec, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String> { + self.validate_handshake(data, expected_from, expected_to, true, now) } - fn create_response(&self, from: PeerId, to: PeerId) -> Vec { - self.generate_handshake(from, to, false) + fn create_response(&self, from: PeerId, to: PeerId, now: u64) -> Vec { + self.generate_handshake(from, to, false, now) } - fn verify_response(&self, data: Vec, expected_from: PeerId, expected_to: PeerId) -> Result<(), String> { - self.validate_handshake(data, expected_from, expected_to, false) + fn verify_response(&self, data: Vec, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String> { + self.validate_handshake(data, expected_from, expected_to, false, now) } } #[cfg(test)] mod tests { + use crate::now_ms; + use super::*; #[test] @@ -123,12 +124,12 @@ mod tests { let peer2 = PeerId::from(2); // Test request handshake - let request = secure.create_request(peer1, peer2); - assert!(secure.verify_request(request, peer1, peer2).is_ok()); + let request = secure.create_request(peer1, peer2, now_ms()); + assert!(secure.verify_request(request, peer1, peer2, now_ms()).is_ok()); // Test response handshake - let response = secure.create_response(peer2, peer1); - assert!(secure.verify_response(response, peer2, peer1).is_ok()); + let response = secure.create_response(peer2, peer1, now_ms()); + assert!(secure.verify_response(response, peer2, peer1, now_ms()).is_ok()); } #[test] @@ -138,7 +139,22 @@ mod tests { let peer1 = PeerId::from(1); let peer2 = PeerId::from(2); - let request = secure1.create_request(peer1, peer2); - assert!(secure2.verify_request(request, peer1, peer2).is_err()); + let request = secure1.create_request(peer1, peer2, now_ms()); + assert!(secure2.verify_request(request, peer1, peer2, now_ms()).is_err()); + } + + #[test] + fn test_handshake_timeout() { + let secure = SharedKeyHandshake::from("test_key"); + let peer1 = PeerId::from(1); + let peer2 = PeerId::from(2); + + // when date of peer2 is faster than peer1 + let request = secure.create_request(peer2, peer1, 1000); + assert!(secure.verify_request(request, peer2, peer1, 980).is_ok()); + + // when peer2 is too slow + let request = secure.create_request(peer2, peer1, 1000); + assert!(secure.verify_request(request, peer2, peer1, 1000 + HANDSHAKE_TIMEOUT + 1).is_err()); } }