Skip to content

Commit

Permalink
io: Simplify handshakes
Browse files Browse the repository at this point in the history
  • Loading branch information
ohsayan committed Apr 5, 2024
1 parent 5e1b401 commit 33b7993
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 59 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

All changes in this project will be noted in this file.

### 0.9.0
### 0.9.0 (unreleased)

> **Minimum Supported Skytable Version**: 0.8.2
> - **Minimum Supported Skytable Version**: 0.8.2
> - **Field change warnings**:
> - The `Config` struct now has two additional fields. This is not a breaking change because the functionality of the library remains unchanged
- Added support for pipelines

Expand Down
44 changes: 38 additions & 6 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
//! let mut db = Config::new("subnetx2_db1", 2008, "username", "password").connect().unwrap();
//! ```
pub use crate::protocol::handshake::ProtocolVersion;

/// The default host
///
/// NOTE: If you are using a clustering setup, don't use this!
Expand All @@ -46,21 +48,42 @@ pub struct Config {
port: u16,
username: Box<str>,
password: Box<str>,
protocol: ProtocolVersion,
pub(crate) protocol_changed: bool,
}

impl Config {
fn _new(
host: Box<str>,
port: u16,
username: Box<str>,
password: Box<str>,
protocol: ProtocolVersion,
) -> Self {
Self {
host,
port,
username,
password,
protocol,
protocol_changed: false,
}
}
/// Create a new [`Config`] using the default connection settings and using the provided username and password
pub fn new_default(username: &str, password: &str) -> Self {
Self::new(DEFAULT_HOST, DEFAULT_TCP_PORT, username, password)
}
/// Create a new [`Config`] using the given settings
/// Create a new [`Config`] using the given settings.
///
/// **PROTOCOL VERSION**: Defaults to [`ProtocolVersion::V2_0`]
pub fn new(host: &str, port: u16, username: &str, password: &str) -> Self {
Self {
host: host.into(),
Self::_new(
host.into(),
port,
username: username.into(),
password: password.into(),
}
username.into(),
password.into(),
ProtocolVersion::V2_0,
)
}
/// Returns the host setting for this this configuration
pub fn host(&self) -> &str {
Expand All @@ -78,4 +101,13 @@ impl Config {
pub fn password(&self) -> &str {
self.password.as_ref()
}
/// Set the protocol
pub fn set_protocol(&mut self, protocol: ProtocolVersion) {
self.protocol_changed = true;
self.protocol = protocol;
}
/// Returns the protocol used for connections
pub fn protocol(&self) -> ProtocolVersion {
self.protocol
}
}
48 changes: 23 additions & 25 deletions src/io/aio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,12 @@ impl DerefMut for ConnectionTlsAsync {
impl Config {
/// Establish an async connection to the database using the current configuration
pub async fn connect_async(&self) -> ClientResult<ConnectionAsync> {
let mut tcpstream = TcpStream::connect((self.host(), self.port())).await?;
let handshake = ClientHandshake::new_v1(self);
tcpstream.write_all(handshake.inner()).await?;
let mut resp = [0u8; 4];
tcpstream.read_exact(&mut resp).await?;
match ServerHandshake::parse(resp)? {
ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()),
ServerHandshake::Okay(_suggestion) => {
return Ok(ConnectionAsync(TcpConnection::new(tcpstream)))
}
}
TcpStream::connect((self.host(), self.port()))
.await
.map(TcpConnection::new)?
._handshake(self)
.await
.map(ConnectionAsync)
}
/// Establish an async TLS connection to the database using the current configuration.
/// Pass the certificate in PEM format.
Expand All @@ -115,22 +110,15 @@ impl Config {
let connector = builder.build().map_err(|e| {
ConnectionSetupError::Other(format!("failed to set up TLS acceptor: {e}"))
})?;
// init
let mut stream = TlsConnector::from(connector)
// init and handshake
TlsConnector::from(connector)
.connect(self.host(), stream)
.await
.map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))?;
// handshake
let handshake = ClientHandshake::new_v1(self);
stream.write_all(handshake.inner()).await?;
let mut resp = [0u8; 4];
stream.read_exact(&mut resp).await?;
match ServerHandshake::parse(resp)? {
ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()),
ServerHandshake::Okay(_suggestion) => {
return Ok(ConnectionTlsAsync(TcpConnection::new(stream)))
}
}
.map(TcpConnection::new)
.map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))?
._handshake(self)
.await
.map(ConnectionTlsAsync)
}
}

Expand All @@ -148,6 +136,16 @@ impl<C: AsyncWriteExt + AsyncReadExt + Unpin> TcpConnection<C> {
buf: Vec::with_capacity(crate::BUFSIZE),
}
}
async fn _handshake(mut self, cfg: &Config) -> ClientResult<Self> {
let handshake = ClientHandshake::new(cfg);
self.con.write_all(handshake.inner()).await?;
let mut resp = [0u8; 4];
self.con.read_exact(&mut resp).await?;
match ServerHandshake::parse(resp)? {
ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()),
ServerHandshake::Okay(_suggestion) => return Ok(self),
}
}
/// Execute a pipeline. The server returns the queries in the order they were sent (unless otherwise set).
pub async fn execute_pipeline(&mut self, pipeline: &Pipeline) -> ClientResult<Vec<Response>> {
self.buf.clear();
Expand Down
45 changes: 22 additions & 23 deletions src/io/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,17 @@ impl DerefMut for ConnectionTls {
impl Config {
/// Establish a connection to the database using the current configuration
pub fn connect(&self) -> ClientResult<Connection> {
let mut tcpstream = TcpStream::connect((self.host(), self.port()))?;
let handshake = ClientHandshake::new_v1(self);
tcpstream.write_all(handshake.inner())?;
let mut resp = [0u8; 4];
tcpstream.read_exact(&mut resp)?;
match ServerHandshake::parse(resp)? {
ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()),
ServerHandshake::Okay(_suggestion) => {
return Ok(Connection(TcpConnection::new(tcpstream)))
}
}
TcpStream::connect((self.host(), self.port()))
.map(TcpConnection::new)?
._handshake(self)
.map(Connection)
}
/// Establish a TLS connection to the database using the current configuration.
/// Pass the certificate in PEM format.
pub fn connect_tls(&self, cert: &str) -> ClientResult<ConnectionTls> {
let stream = TcpStream::connect((self.host(), self.port()))?;
let mut stream = TlsConnector::builder()
TlsConnector::builder()
// build TLS connector
.add_root_certificate(Certificate::from_pem(cert.as_bytes()).map_err(|e| {
ConnectionSetupError::Other(format!("failed to parse certificate: {e}"))
})?)
Expand All @@ -111,18 +105,13 @@ impl Config {
.map_err(|e| {
ConnectionSetupError::Other(format!("failed to set up TLS acceptor: {e}"))
})?
// connect
.connect(self.host(), stream)
.map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))?;
let handshake = ClientHandshake::new_v1(self);
stream.write_all(handshake.inner())?;
let mut resp = [0u8; 4];
stream.read_exact(&mut resp)?;
match ServerHandshake::parse(resp)? {
ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()),
ServerHandshake::Okay(_suggestion) => {
return Ok(ConnectionTls(TcpConnection::new(stream)))
}
}
.map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))
.map(TcpConnection::new)?
// handshake
._handshake(self)
.map(ConnectionTls)
}
}

Expand All @@ -142,6 +131,16 @@ impl<C: Write + Read> TcpConnection<C> {
buf: Vec::with_capacity(crate::BUFSIZE),
}
}
fn _handshake(mut self, cfg: &Config) -> ClientResult<Self> {
let handshake = ClientHandshake::new(cfg);
self.con.write_all(handshake.inner())?;
let mut resp = [0u8; 4];
self.con.read_exact(&mut resp)?;
match ServerHandshake::parse(resp)? {
ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()),
ServerHandshake::Okay(_suggestion) => return Ok(self),
}
}
/// Execute a pipeline. The server returns the queries in the order they were sent (unless otherwise set).
pub fn execute_pipeline(&mut self, pipeline: &Pipeline) -> ClientResult<Vec<Response>> {
self.buf.clear();
Expand Down
21 changes: 18 additions & 3 deletions src/protocol/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,26 @@ use crate::{
ClientResult, Config,
};

#[derive(Debug, PartialEq, Clone, Copy)]
#[repr(u8)]
/// The Skyhash protocol version
pub enum ProtocolVersion {
/// Skyhash 2.0
V2_0,
}

impl ProtocolVersion {
pub(crate) const fn hs_block(&self) -> [u8; 6] {
match self {
Self::V2_0 => [b'H', 0, 0, 0, 0, 0],
}
}
}

pub struct ClientHandshake(Box<[u8]>);
impl ClientHandshake {
const HANDSHAKE_PROTO_V1: [u8; 6] = [b'H', 0, 0, 0, 0, 0];
pub(crate) fn new_v1(cfg: &Config) -> Self {
Self::_new(Self::HANDSHAKE_PROTO_V1, cfg)
pub(crate) fn new(cfg: &Config) -> Self {
Self::_new(cfg.protocol().hs_block(), cfg)
}
fn _new(hs: [u8; 6], cfg: &Config) -> Self {
let mut v = Vec::with_capacity(6 + cfg.username().len() + cfg.password().len() + 5);
Expand Down

0 comments on commit 33b7993

Please sign in to comment.