Skip to content

Commit

Permalink
fix bug timestamp is negative when other node's is faster than curren…
Browse files Browse the repository at this point in the history
…t node
  • Loading branch information
marverlous811 committed Oct 10, 2024
1 parent 919a74e commit cf2b2d4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 28 deletions.
9 changes: 5 additions & 4 deletions src/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use tokio::sync::{
use crate::{
ctx::SharedCtx,
msg::P2pServiceId,
now_ms,
secure::HandshakeProtocol,
stream::{wait_object, write_object, P2pQuicStream},
ConnectionId, PeerId,
Expand Down Expand Up @@ -136,13 +137,13 @@ async fn run_connection<SECURE: HandshakeProtocol>(
internal_tx: Sender<InternalEvent>,
) -> anyhow::Result<()> {
let to_id = if let PeerConnectionDirection::Outgoing(dest) = direction {
let auth = secure.create_request(local_id, dest);
let auth = secure.create_request(local_id, dest, now_ms());
write_object::<_, _, 500>(&mut send, &ConnectReq { from: local_id, to: dest, auth }).await?;
let res: ConnectRes = wait_object::<_, _, 500>(&mut recv).await?;
log::info!("{res:?}");
match res.result {
Ok(auth) => {
if let Err(e) = secure.verify_response(auth, dest, local_id) {
if let Err(e) = secure.verify_response(auth, dest, local_id, now_ms()) {
return Err(anyhow!("destination auth failure: {e}"));
}
dest
Expand All @@ -153,7 +154,7 @@ async fn run_connection<SECURE: HandshakeProtocol>(
}
} else {
let req: ConnectReq = wait_object::<_, _, 500>(&mut recv).await?;
if let Err(e) = secure.verify_request(req.auth, req.from, req.to) {
if let Err(e) = secure.verify_request(req.auth, req.from, req.to, now_ms()) {
write_object::<_, _, 500>(&mut send, &ConnectRes { result: Err(e.clone()) }).await?;
return Err(anyhow!("destination auth failure: {e}"));
} else if req.to != local_id {
Expand All @@ -166,7 +167,7 @@ async fn run_connection<SECURE: HandshakeProtocol>(
.await?;
return Err(anyhow!("destination wrong"));
} else {
let auth = secure.create_response(req.to, req.from);
let auth = secure.create_response(req.to, req.from, now_ms());
write_object::<_, _, 500>(&mut send, &ConnectRes { result: Ok(auth) }).await?;
req.from
}
Expand Down
64 changes: 40 additions & 24 deletions src/secure.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{now_ms, PeerId};
use crate::PeerId;
use serde::{Deserialize, Serialize};

pub trait HandshakeProtocol: Send + Sync + 'static {
fn create_request(&self, from: PeerId, to: PeerId) -> Vec<u8>;
fn verify_request(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId) -> Result<(), String>;
fn create_response(&self, from: PeerId, to: PeerId) -> Vec<u8>;
fn verify_response(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId) -> Result<(), String>;
fn create_request(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8>;
fn verify_request(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String>;
fn create_response(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8>;
fn verify_response(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String>;
}

const HASH_SEED: &str = "atm0s-small-p2p";
Expand Down Expand Up @@ -40,11 +40,11 @@ impl From<&str> for SharedKeyHandshake {
}

impl SharedKeyHandshake {
fn generate_handshake(&self, from: PeerId, to: PeerId, is_client: bool) -> Vec<u8> {
fn generate_handshake(&self, from: PeerId, to: PeerId, is_client: bool, now: u64) -> Vec<u8> {
let handshake_data = HandshakeData {
from,
to,
timestamp: now_ms(),
timestamp: now,
is_initiator: is_client,
};

Expand All @@ -59,14 +59,13 @@ impl SharedKeyHandshake {
bincode::serialize(&handshake).unwrap()
}

fn validate_handshake(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, expected_is_client: bool) -> Result<(), String> {
fn validate_handshake(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, expected_is_client: bool, current_ts: u64) -> Result<(), String> {
let handshake: HandshakeMessage = bincode::deserialize(&data).map_err(|_| "Invalid handshake format".to_string())?;

let handshake_data: HandshakeData = bincode::deserialize(&handshake.payload).map_err(|_| "Invalid handshake data format".to_string())?;

// Verify timestamp
let current_ts = now_ms();
if current_ts - handshake_data.timestamp > HANDSHAKE_TIMEOUT {
if current_ts > handshake_data.timestamp + HANDSHAKE_TIMEOUT {
return Err(format!("Handshake timeout {} vs {}", current_ts, handshake_data.timestamp));
}

Expand Down Expand Up @@ -95,25 +94,27 @@ impl SharedKeyHandshake {
}

impl HandshakeProtocol for SharedKeyHandshake {
fn create_request(&self, from: PeerId, to: PeerId) -> Vec<u8> {
self.generate_handshake(from, to, true)
fn create_request(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8> {
self.generate_handshake(from, to, true, now)
}

fn verify_request(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId) -> Result<(), String> {
self.validate_handshake(data, expected_from, expected_to, true)
fn verify_request(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String> {
self.validate_handshake(data, expected_from, expected_to, true, now)
}

fn create_response(&self, from: PeerId, to: PeerId) -> Vec<u8> {
self.generate_handshake(from, to, false)
fn create_response(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8> {
self.generate_handshake(from, to, false, now)
}

fn verify_response(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId) -> Result<(), String> {
self.validate_handshake(data, expected_from, expected_to, false)
fn verify_response(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String> {
self.validate_handshake(data, expected_from, expected_to, false, now)
}
}

#[cfg(test)]
mod tests {
use crate::now_ms;

use super::*;

#[test]
Expand All @@ -123,12 +124,12 @@ mod tests {
let peer2 = PeerId::from(2);

// Test request handshake
let request = secure.create_request(peer1, peer2);
assert!(secure.verify_request(request, peer1, peer2).is_ok());
let request = secure.create_request(peer1, peer2, now_ms());
assert!(secure.verify_request(request, peer1, peer2, now_ms()).is_ok());

// Test response handshake
let response = secure.create_response(peer2, peer1);
assert!(secure.verify_response(response, peer2, peer1).is_ok());
let response = secure.create_response(peer2, peer1, now_ms());
assert!(secure.verify_response(response, peer2, peer1, now_ms()).is_ok());
}

#[test]
Expand All @@ -138,7 +139,22 @@ mod tests {
let peer1 = PeerId::from(1);
let peer2 = PeerId::from(2);

let request = secure1.create_request(peer1, peer2);
assert!(secure2.verify_request(request, peer1, peer2).is_err());
let request = secure1.create_request(peer1, peer2, now_ms());
assert!(secure2.verify_request(request, peer1, peer2, now_ms()).is_err());
}

#[test]
fn test_handshake_timeout() {
let secure = SharedKeyHandshake::from("test_key");
let peer1 = PeerId::from(1);
let peer2 = PeerId::from(2);

// when date of peer2 is faster than peer1
let request = secure.create_request(peer2, peer1, 1000);
assert!(secure.verify_request(request, peer2, peer1, 980).is_ok());

// when peer2 is too slow
let request = secure.create_request(peer2, peer1, 1000);
assert!(secure.verify_request(request, peer2, peer1, 1000 + HANDSHAKE_TIMEOUT + 1).is_err());
}
}

0 comments on commit cf2b2d4

Please sign in to comment.