Skip to content

Commit

Permalink
fixed SSL habdshake issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Joshi committed Apr 8, 2019
1 parent 0b794f4 commit e3fa4b1
Show file tree
Hide file tree
Showing 8 changed files with 593 additions and 428 deletions.
10 changes: 3 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ incremental = false
[dependencies]
log = { version = "0.3", features = ["max_level_trace", "release_max_level_info"] }
url="1.7.2"
futures = "0.1.23"
mio = "0.6"
#futures = "0.1.23"
mio = "0.6.16"
mio-uds = "0.6"
serde = "1.0.27"
serde_derive = "1.0.27"
Expand All @@ -34,13 +34,9 @@ crossbeam-channel = "0.3.2"
crossbeam = "0.7.1"
hashbrown = "0.1.8"
parking_lot = "0.7.1"
openssl="0.10.20"

[features]
tls = ["openssl"]

[dependencies.openssl]
optional = true
version= "0.10.16"



14 changes: 10 additions & 4 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@

use std::fs::File;
use std::io::prelude::*;
use std::path::PathBuf;
use std::str;

#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SSLConfig {
/// SSL enabled
pub enabled: bool,
/// SSL Protocol
//pub ssl_protocol : Option<>,
/// Certificate File
Expand All @@ -27,10 +30,12 @@ pub struct SSLConfig {
/// Verify depth
pub verify_depth: Option<u32>,
}
#[cfg(feature = "tls")]

//#[cfg(feature = "tls")]
impl Default for SSLConfig {
fn default() -> SSLConfig {
SSLConfig {
enabled: false,
certificate_file: None,
private_key_file: None,
ca_file: None,
Expand All @@ -39,6 +44,7 @@ impl Default for SSLConfig {
}
}
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GaduConfig {
/// The server to connect to.
Expand All @@ -49,7 +55,7 @@ pub struct GaduConfig {
pub write_timeout: usize,
/// keep alive time
pub keep_alive_time: u64,
#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
pub ssl_config: SSLConfig,
}

Expand All @@ -60,7 +66,7 @@ impl Default for GaduConfig {
read_timeout: 60_000,
write_timeout: 60_000,
keep_alive_time: 60_000,
#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
ssl_config: SSLConfig::default(),
}
}
Expand Down
161 changes: 139 additions & 22 deletions src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
use hashbrown::HashMap;
use mio::net::TcpStream;
use mio_uds::UnixStream;
#[cfg(feature = "tls")]
use openssl::ssl::{HandshakeError, MidHandshakeSslStream};
//#[cfg(feature = "tls")]
use openssl::ssl::SslStream;
use std::io::{Error, ErrorKind};
use std::io::{Read, Write};
use std::io::Result as IoResult;
use std::net::Shutdown;
use std::net::SocketAddr;
use std::os::unix::net::SocketAddr as UnixSocketAddr;

use std::io::Result as IoResult;
use std::io::{Read, Write};
use std::time::Duration;
use url::Url;

Expand Down Expand Up @@ -47,6 +48,7 @@ impl Conn {
}
}


#[inline]
pub fn is_remote_connection(&self) -> bool {
self.url.is_empty()
Expand Down Expand Up @@ -112,7 +114,7 @@ impl Conn {
if let Err(e) = sock.set_keepalive(Some(Conn::KEEP_ALIVE_TIME)) {
error!("Faile to set keep alive : {}.", e.to_string());
}
info!("Tcp client connected with server at {}", addr);
debug!("Tcp client connected with server at {}", addr);
(
NetStream::UnsecuredTcpStream(sock),
NetAddr::NetSocketAddress(conn_addr),
Expand All @@ -128,7 +130,7 @@ impl Conn {
return Err(e.to_string());
}
};
info!("Unix Socket connected with server at {}", path);
debug!("Unix Socket connected with server at {}", path);
let addr = sock.peer_addr().unwrap();
(NetStream::UdsStream(sock), NetAddr::UdsSocketAddress(addr))
}
Expand Down Expand Up @@ -173,17 +175,17 @@ impl Conn {
match self.stream.write(self.output.as_slice()) {
Ok(n) => {
if n < self.output.len() {
// let mut output = Vec::new();
// output.extend_from_slice(&self.output[n..self.output.len()]);
// self.output = output
// let mut output = Vec::new();
// output.extend_from_slice(&self.output[n..self.output.len()]);
// self.output = output
self.output.drain(0..n);
} else {
self.output.clear();
}
}
Err(ref e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
warn!(
debug!(
"Write: ErrorKind::WouldBlock on connection :{}",
self.addr.to_string()
);
Expand Down Expand Up @@ -215,29 +217,28 @@ impl Conn {
debug!("Conn bytes read: {}", n);
if n == 0 {
self.close = true;
//break;
//break;
} else {
self.input.extend_from_slice(&packet[0..n]);
//self.input.extend(&buffer);
debug!(
"Received Length:{}, data: {}. ",
n,
String::from_utf8_lossy(&self.input)

);
}
}
Err(ref e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
warn!(
debug!(
"Read: ErrorKind::WouldBlock on connection :{}",
self.addr.to_string()
);
//break;
//break;
} else if e.kind() == std::io::ErrorKind::ConnectionReset {
info!("Read: Connection reset by peer:{}", self.addr.to_string());
self.close = true;
//break;
//break;
} else {
error!(
"Read: Peer Connection:{}, Read Error: {:?}",
Expand All @@ -260,6 +261,56 @@ impl Conn {
pub fn shutdown(&mut self) {
let _ = self.stream.shutdown();
}

pub fn is_ssl_handshake_pending(&self) -> bool {
match self.stream {
NetStream::SslMidHandshakeStream(_) => {
true
}
_ => { false }
}
}


pub fn ssl_handshake(&mut self) -> Result<(), Error> {
use std::mem;
let old = mem::replace(&mut self.stream, NetStream::Invalid);
match old {
NetStream::SslMidHandshakeStream(mid_stream) => {
match mid_stream.handshake() {
Ok(s) => {
debug!("ssl_handshake:SSL Handshake successful");
self.stream = NetStream::SslTcpStream(s);
Ok(())
}
Err(e) => {
debug!("{:?}", e);
let err_str = e.to_string();
match e {
HandshakeError::WouldBlock(s) => {
debug!("Failed to handshake on SSL connection. Error:{}", err_str);
self.stream = NetStream::SslMidHandshakeStream(s);
Ok(())
}
_ => {
error!("Failed to accept SSL connection. Error:{}", err_str);
Err(Error::new(
ErrorKind::Other,
format!("An SSL error occurred.{}", err_str),
))
}
}
}
}
}
_ => { Ok(()) }
}
}
/* if let Ok(_) = NetStream::ssl_handshake(&mut self.stream) {
//self.stream = stream;
}else {
self.close = true;
}*/
}

#[derive(Debug)]
Expand All @@ -270,60 +321,126 @@ pub enum NetStream {
UdsStream(UnixStream),
/// An SSL-secured TcpStream.
/// This is only available when compiled with SSL support.
#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
SslTcpStream(SslStream<TcpStream>),
//SSL mid handshake stream
SslMidHandshakeStream(MidHandshakeSslStream<TcpStream>),

Invalid,
}

impl NetStream {
pub fn shutdown(&mut self) -> Result<(), std::io::Error> {
match *self {
NetStream::UnsecuredTcpStream(ref stream) => stream.shutdown(Shutdown::Both),
NetStream::UdsStream(ref stream) => stream.shutdown(Shutdown::Both),
#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
NetStream::SslTcpStream(ref mut stream) => {
stream.shutdown();
if let Err(e) = stream.shutdown() {
warn!("Failed to shutdown SSL stream. Error:{:?}",e);
}
Ok(())
}
NetStream::SslMidHandshakeStream(ref mut mid_stream) => {
mid_stream.get_mut().shutdown(Shutdown::Both)
}
NetStream::Invalid => { Ok(()) }
}
}
#[inline]
pub fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match *self {
NetStream::UnsecuredTcpStream(ref mut stream) => stream.read(buf),
NetStream::UdsStream(ref mut stream) => stream.read(buf),
#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
NetStream::SslTcpStream(ref mut stream) => stream.read(buf),
NetStream::SslMidHandshakeStream(ref mut mid_stream) => {
mid_stream.get_mut().read(buf)
}
NetStream::Invalid => { Ok(0) }
}
}
#[inline]
pub fn write(&mut self, buf: &[u8]) -> IoResult<(usize)> {
match *self {
NetStream::UnsecuredTcpStream(ref mut stream) => stream.write(buf),
NetStream::UdsStream(ref mut stream) => stream.write(buf),
#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
NetStream::SslTcpStream(ref mut stream) => {
// Arc::get_mut(stream).unwrap().write(buf)
stream.write(buf)
}
NetStream::SslMidHandshakeStream(ref mut mid_stream) => {
mid_stream.get_mut().write(buf)
}
NetStream::Invalid => { Ok(0) }
}
}
#[inline]
pub fn write_all(&mut self, buf: &[u8]) -> IoResult<()> {
match *self {
NetStream::UnsecuredTcpStream(ref mut stream) => stream.write_all(buf),
NetStream::UdsStream(ref mut stream) => stream.write_all(buf),
#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
NetStream::SslTcpStream(ref mut stream) => stream.write_all(buf),
NetStream::SslMidHandshakeStream(ref mut mid_stream) => {
mid_stream.get_mut().write_all(buf)
}
NetStream::Invalid => { Ok(()) }
}
}
pub fn flush(&mut self) -> IoResult<()> {
match *self {
NetStream::UnsecuredTcpStream(ref mut stream) => stream.flush(),
NetStream::UdsStream(ref mut stream) => stream.flush(),
#[cfg(feature = "tls")]
//#[cfg(feature = "tls")]
NetStream::SslTcpStream(ref mut stream) => stream.flush(),
NetStream::SslMidHandshakeStream(ref mut mid_stream) => {
mid_stream.get_mut().flush()
}
NetStream::Invalid => { Ok(()) }
}
}

/*pub fn ssl_handshake(stream : &mut NetStream) -> Result<(), Error> {
match *stream {
NetStream::SslMidHandshakeStream(mut mid_stream) => {
match mid_stream.handshake() {
Ok(s) => {
info!("ssl_handshake:SSL Handshake successful");
*stream = NetStream::SslTcpStream(s);
Ok(())
},
Err(e) => {
info!("{:?}", e);
let err_str = e.to_string();
match e {
HandshakeError::WouldBlock(s) => {
info!("Failed to handshake on SSL connection. Error:{}", err_str);
*stream = NetStream::SslMidHandshakeStream(s);
Ok(())
},
_ => {
error!("Failed to accept SSL connection. Error:{}", err_str);
return Err(Error::new(
ErrorKind::Other,
format!("An SSL error occurred.{}", err_str),
));
}
}
}
}
},
_ => {
error!("Invalid operation. SSL Handshake is only supported on SslMidHandshakeStream");
return Err(Error::new(
ErrorKind::Other,
"An SSL error occurred.Invalid operation. SSL Handshake is only supported on SslMidHandshakeStream",
));
}
}
}*/
}

#[derive(Debug)]
Expand Down
Loading

0 comments on commit e3fa4b1

Please sign in to comment.