From d9049ab4c1166146cf2e3d8158e51073fe181d54 Mon Sep 17 00:00:00 2001 From: Adrien Guillo Date: Thu, 14 Mar 2024 05:55:20 -0400 Subject: [PATCH] Add protocol version to message header (#140) --- chitchat/src/message.rs | 66 ++++++++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 14 deletions(-) diff --git a/chitchat/src/message.rs b/chitchat/src/message.rs index 714d4fb..3dded42 100644 --- a/chitchat/src/message.rs +++ b/chitchat/src/message.rs @@ -1,6 +1,6 @@ use std::io::BufRead; -use anyhow::Context; +use anyhow::{bail, Context}; use crate::delta::Delta; use crate::digest::Digest; @@ -26,7 +26,26 @@ pub enum ChitchatMessage { BadCluster, } -#[derive(Copy, Clone)] +#[derive(Clone, Copy, Eq, PartialEq)] +#[repr(u8)] +enum ProtocolVersion { + V0 = 0, +} + +impl ProtocolVersion { + pub fn from_code(code: u8) -> Option { + match code { + 0 => Some(Self::V0), + _ => None, + } + } + + pub fn to_code(self) -> u8 { + self as u8 + } +} + +#[derive(Clone, Copy)] #[repr(u8)] enum MessageType { Syn = 0, @@ -45,6 +64,7 @@ impl MessageType { _ => None, } } + pub fn to_code(self) -> u8 { self as u8 } @@ -52,6 +72,8 @@ impl MessageType { impl Serializable for ChitchatMessage { fn serialize(&self, buf: &mut Vec) { + ProtocolVersion::V0.to_code().serialize(buf); + match self { ChitchatMessage::Syn { cluster_id, digest } => { buf.push(MessageType::Syn.to_code()); @@ -74,7 +96,7 @@ impl Serializable for ChitchatMessage { } fn serialized_len(&self) -> usize { - match self { + 1 + match self { ChitchatMessage::Syn { cluster_id, digest } => { 1 + cluster_id.serialized_len() + digest.serialized_len() } @@ -89,13 +111,28 @@ impl Serializable for ChitchatMessage { impl Deserializable for ChitchatMessage { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let code = buf + let protocol_version = buf + .first() + .copied() + .and_then(ProtocolVersion::from_code) + .context("invalid protocol version")?; + + if protocol_version != ProtocolVersion::V0 { + bail!( + "unsupported protocol version `{}`", + protocol_version.to_code() + ) + } + buf.consume(1); + + let message_type = buf .first() .copied() .and_then(MessageType::from_code) .context("invalid message type")?; buf.consume(1); - match code { + + match message_type { MessageType::Syn => { let digest = Digest::deserialize(buf)?; let cluster_id = String::deserialize(buf)?; @@ -127,7 +164,7 @@ mod tests { cluster_id: "cluster-a".to_string(), digest: Digest::default(), }; - test_serdeser_aux(&syn, 14); + test_serdeser_aux(&syn, 15); } { let mut digest = Digest::default(); @@ -138,7 +175,7 @@ mod tests { cluster_id: "cluster-a".to_string(), digest, }; - test_serdeser_aux(&syn, 65); + test_serdeser_aux(&syn, 66); } } @@ -149,8 +186,8 @@ mod tests { digest: Digest::default(), delta: Delta::default(), }; - // 1 (message tag) + 2 (digest len) + 1 (delta end op) - test_serdeser_aux(&syn_ack, 4); + // 1 (protocol version) + 1 (message tag) + 2 (digest len) + 1 (delta end op) + test_serdeser_aux(&syn_ack, 5); } { // 2 bytes. @@ -173,8 +210,9 @@ mod tests { delta.set_serialized_len(60); let syn_ack = ChitchatMessage::SynAck { digest, delta }; - // 1 bytes (syn ack message) + 45 bytes (digest) + 69 bytes (delta). - test_serdeser_aux(&syn_ack, 1 + 53 + 60); + // 1 byte (protocol version) + 1 byte (message tag) + 53 bytes (digest) + 60 bytes + // (delta). + test_serdeser_aux(&syn_ack, 1 + 1 + 53 + 60); } } @@ -183,7 +221,7 @@ mod tests { { let delta = Delta::default(); let ack = ChitchatMessage::Ack { delta }; - test_serdeser_aux(&ack, 2); + test_serdeser_aux(&ack, 3); } { // 4 bytes. @@ -195,12 +233,12 @@ mod tests { delta.add_kv(&node, "key", "value", 0, true); delta.set_serialized_len(60); let ack = ChitchatMessage::Ack { delta }; - test_serdeser_aux(&ack, 1 + 60); + test_serdeser_aux(&ack, 1 + 1 + 60); } } #[test] fn test_bad_cluster() { - test_serdeser_aux(&ChitchatMessage::BadCluster, 1); + test_serdeser_aux(&ChitchatMessage::BadCluster, 2); } }