From 33b79932eba78c2fa9f8a8af80734dfc1f4b3e43 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Fri, 5 Apr 2024 12:26:48 +0530 Subject: [PATCH] io: Simplify handshakes --- CHANGELOG.md | 6 +++-- src/config.rs | 44 ++++++++++++++++++++++++++++++----- src/io/aio.rs | 48 +++++++++++++++++++-------------------- src/io/sync.rs | 45 ++++++++++++++++++------------------ src/protocol/handshake.rs | 21 ++++++++++++++--- 5 files changed, 105 insertions(+), 59 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e05af97..aabe50e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,11 @@ All changes in this project will be noted in this file. -### 0.9.0 +### 0.9.0 (unreleased) -> **Minimum Supported Skytable Version**: 0.8.2 +> - **Minimum Supported Skytable Version**: 0.8.2 +> - **Field change warnings**: +> - The `Config` struct now has two additional fields. This is not a breaking change because the functionality of the library remains unchanged - Added support for pipelines diff --git a/src/config.rs b/src/config.rs index a0e9e03..b818f68 100644 --- a/src/config.rs +++ b/src/config.rs @@ -30,6 +30,8 @@ //! let mut db = Config::new("subnetx2_db1", 2008, "username", "password").connect().unwrap(); //! ``` +pub use crate::protocol::handshake::ProtocolVersion; + /// The default host /// /// NOTE: If you are using a clustering setup, don't use this! @@ -46,21 +48,42 @@ pub struct Config { port: u16, username: Box, password: Box, + protocol: ProtocolVersion, + pub(crate) protocol_changed: bool, } impl Config { + fn _new( + host: Box, + port: u16, + username: Box, + password: Box, + protocol: ProtocolVersion, + ) -> Self { + Self { + host, + port, + username, + password, + protocol, + protocol_changed: false, + } + } /// Create a new [`Config`] using the default connection settings and using the provided username and password pub fn new_default(username: &str, password: &str) -> Self { Self::new(DEFAULT_HOST, DEFAULT_TCP_PORT, username, password) } - /// Create a new [`Config`] using the given settings + /// Create a new [`Config`] using the given settings. + /// + /// **PROTOCOL VERSION**: Defaults to [`ProtocolVersion::V2_0`] pub fn new(host: &str, port: u16, username: &str, password: &str) -> Self { - Self { - host: host.into(), + Self::_new( + host.into(), port, - username: username.into(), - password: password.into(), - } + username.into(), + password.into(), + ProtocolVersion::V2_0, + ) } /// Returns the host setting for this this configuration pub fn host(&self) -> &str { @@ -78,4 +101,13 @@ impl Config { pub fn password(&self) -> &str { self.password.as_ref() } + /// Set the protocol + pub fn set_protocol(&mut self, protocol: ProtocolVersion) { + self.protocol_changed = true; + self.protocol = protocol; + } + /// Returns the protocol used for connections + pub fn protocol(&self) -> ProtocolVersion { + self.protocol + } } diff --git a/src/io/aio.rs b/src/io/aio.rs index 9491e24..86e1193 100644 --- a/src/io/aio.rs +++ b/src/io/aio.rs @@ -85,17 +85,12 @@ impl DerefMut for ConnectionTlsAsync { impl Config { /// Establish an async connection to the database using the current configuration pub async fn connect_async(&self) -> ClientResult { - let mut tcpstream = TcpStream::connect((self.host(), self.port())).await?; - let handshake = ClientHandshake::new_v1(self); - tcpstream.write_all(handshake.inner()).await?; - let mut resp = [0u8; 4]; - tcpstream.read_exact(&mut resp).await?; - match ServerHandshake::parse(resp)? { - ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), - ServerHandshake::Okay(_suggestion) => { - return Ok(ConnectionAsync(TcpConnection::new(tcpstream))) - } - } + TcpStream::connect((self.host(), self.port())) + .await + .map(TcpConnection::new)? + ._handshake(self) + .await + .map(ConnectionAsync) } /// Establish an async TLS connection to the database using the current configuration. /// Pass the certificate in PEM format. @@ -115,22 +110,15 @@ impl Config { let connector = builder.build().map_err(|e| { ConnectionSetupError::Other(format!("failed to set up TLS acceptor: {e}")) })?; - // init - let mut stream = TlsConnector::from(connector) + // init and handshake + TlsConnector::from(connector) .connect(self.host(), stream) .await - .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))?; - // handshake - let handshake = ClientHandshake::new_v1(self); - stream.write_all(handshake.inner()).await?; - let mut resp = [0u8; 4]; - stream.read_exact(&mut resp).await?; - match ServerHandshake::parse(resp)? { - ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), - ServerHandshake::Okay(_suggestion) => { - return Ok(ConnectionTlsAsync(TcpConnection::new(stream))) - } - } + .map(TcpConnection::new) + .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))? + ._handshake(self) + .await + .map(ConnectionTlsAsync) } } @@ -148,6 +136,16 @@ impl TcpConnection { buf: Vec::with_capacity(crate::BUFSIZE), } } + async fn _handshake(mut self, cfg: &Config) -> ClientResult { + let handshake = ClientHandshake::new(cfg); + self.con.write_all(handshake.inner()).await?; + let mut resp = [0u8; 4]; + self.con.read_exact(&mut resp).await?; + match ServerHandshake::parse(resp)? { + ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), + ServerHandshake::Okay(_suggestion) => return Ok(self), + } + } /// Execute a pipeline. The server returns the queries in the order they were sent (unless otherwise set). pub async fn execute_pipeline(&mut self, pipeline: &Pipeline) -> ClientResult> { self.buf.clear(); diff --git a/src/io/sync.rs b/src/io/sync.rs index 68603c2..226084e 100644 --- a/src/io/sync.rs +++ b/src/io/sync.rs @@ -86,23 +86,17 @@ impl DerefMut for ConnectionTls { impl Config { /// Establish a connection to the database using the current configuration pub fn connect(&self) -> ClientResult { - let mut tcpstream = TcpStream::connect((self.host(), self.port()))?; - let handshake = ClientHandshake::new_v1(self); - tcpstream.write_all(handshake.inner())?; - let mut resp = [0u8; 4]; - tcpstream.read_exact(&mut resp)?; - match ServerHandshake::parse(resp)? { - ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), - ServerHandshake::Okay(_suggestion) => { - return Ok(Connection(TcpConnection::new(tcpstream))) - } - } + TcpStream::connect((self.host(), self.port())) + .map(TcpConnection::new)? + ._handshake(self) + .map(Connection) } /// Establish a TLS connection to the database using the current configuration. /// Pass the certificate in PEM format. pub fn connect_tls(&self, cert: &str) -> ClientResult { let stream = TcpStream::connect((self.host(), self.port()))?; - let mut stream = TlsConnector::builder() + TlsConnector::builder() + // build TLS connector .add_root_certificate(Certificate::from_pem(cert.as_bytes()).map_err(|e| { ConnectionSetupError::Other(format!("failed to parse certificate: {e}")) })?) @@ -111,18 +105,13 @@ impl Config { .map_err(|e| { ConnectionSetupError::Other(format!("failed to set up TLS acceptor: {e}")) })? + // connect .connect(self.host(), stream) - .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))?; - let handshake = ClientHandshake::new_v1(self); - stream.write_all(handshake.inner())?; - let mut resp = [0u8; 4]; - stream.read_exact(&mut resp)?; - match ServerHandshake::parse(resp)? { - ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), - ServerHandshake::Okay(_suggestion) => { - return Ok(ConnectionTls(TcpConnection::new(stream))) - } - } + .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}"))) + .map(TcpConnection::new)? + // handshake + ._handshake(self) + .map(ConnectionTls) } } @@ -142,6 +131,16 @@ impl TcpConnection { buf: Vec::with_capacity(crate::BUFSIZE), } } + fn _handshake(mut self, cfg: &Config) -> ClientResult { + let handshake = ClientHandshake::new(cfg); + self.con.write_all(handshake.inner())?; + let mut resp = [0u8; 4]; + self.con.read_exact(&mut resp)?; + match ServerHandshake::parse(resp)? { + ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), + ServerHandshake::Okay(_suggestion) => return Ok(self), + } + } /// Execute a pipeline. The server returns the queries in the order they were sent (unless otherwise set). pub fn execute_pipeline(&mut self, pipeline: &Pipeline) -> ClientResult> { self.buf.clear(); diff --git a/src/protocol/handshake.rs b/src/protocol/handshake.rs index 51b8fcc..05f62a7 100644 --- a/src/protocol/handshake.rs +++ b/src/protocol/handshake.rs @@ -19,11 +19,26 @@ use crate::{ ClientResult, Config, }; +#[derive(Debug, PartialEq, Clone, Copy)] +#[repr(u8)] +/// The Skyhash protocol version +pub enum ProtocolVersion { + /// Skyhash 2.0 + V2_0, +} + +impl ProtocolVersion { + pub(crate) const fn hs_block(&self) -> [u8; 6] { + match self { + Self::V2_0 => [b'H', 0, 0, 0, 0, 0], + } + } +} + pub struct ClientHandshake(Box<[u8]>); impl ClientHandshake { - const HANDSHAKE_PROTO_V1: [u8; 6] = [b'H', 0, 0, 0, 0, 0]; - pub(crate) fn new_v1(cfg: &Config) -> Self { - Self::_new(Self::HANDSHAKE_PROTO_V1, cfg) + pub(crate) fn new(cfg: &Config) -> Self { + Self::_new(cfg.protocol().hs_block(), cfg) } fn _new(hs: [u8; 6], cfg: &Config) -> Self { let mut v = Vec::with_capacity(6 + cfg.username().len() + cfg.password().len() + 5);