diff --git a/Cargo.toml b/Cargo.toml index 09e2b8e..872c6c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,8 +24,8 @@ incremental = false [dependencies] log = { version = "0.3", features = ["max_level_trace", "release_max_level_info"] } url="1.7.2" -futures = "0.1.23" -mio = "0.6" +#futures = "0.1.23" +mio = "0.6.16" mio-uds = "0.6" serde = "1.0.27" serde_derive = "1.0.27" @@ -34,13 +34,9 @@ crossbeam-channel = "0.3.2" crossbeam = "0.7.1" hashbrown = "0.1.8" parking_lot = "0.7.1" +openssl="0.10.20" -[features] -tls = ["openssl"] -[dependencies.openssl] -optional = true -version= "0.10.16" diff --git a/src/config.rs b/src/config.rs index c4bf6b6..9edbb6e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -9,11 +9,14 @@ use std::fs::File; use std::io::prelude::*; +use std::path::PathBuf; use std::str; -#[cfg(feature = "tls")] +//#[cfg(feature = "tls")] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct SSLConfig { + /// SSL enabled + pub enabled: bool, /// SSL Protocol //pub ssl_protocol : Option<>, /// Certificate File @@ -27,10 +30,12 @@ pub struct SSLConfig { /// Verify depth pub verify_depth: Option, } -#[cfg(feature = "tls")] + +//#[cfg(feature = "tls")] impl Default for SSLConfig { fn default() -> SSLConfig { SSLConfig { + enabled: false, certificate_file: None, private_key_file: None, ca_file: None, @@ -39,6 +44,7 @@ impl Default for SSLConfig { } } } + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct GaduConfig { /// The server to connect to. @@ -49,7 +55,7 @@ pub struct GaduConfig { pub write_timeout: usize, /// keep alive time pub keep_alive_time: u64, - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] pub ssl_config: SSLConfig, } @@ -60,7 +66,7 @@ impl Default for GaduConfig { read_timeout: 60_000, write_timeout: 60_000, keep_alive_time: 60_000, - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] ssl_config: SSLConfig::default(), } } diff --git a/src/conn.rs b/src/conn.rs index ecdc547..086ddf8 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -9,14 +9,15 @@ use hashbrown::HashMap; use mio::net::TcpStream; use mio_uds::UnixStream; -#[cfg(feature = "tls")] +use openssl::ssl::{HandshakeError, MidHandshakeSslStream}; +//#[cfg(feature = "tls")] use openssl::ssl::SslStream; +use std::io::{Error, ErrorKind}; +use std::io::{Read, Write}; +use std::io::Result as IoResult; use std::net::Shutdown; use std::net::SocketAddr; use std::os::unix::net::SocketAddr as UnixSocketAddr; - -use std::io::Result as IoResult; -use std::io::{Read, Write}; use std::time::Duration; use url::Url; @@ -47,6 +48,7 @@ impl Conn { } } + #[inline] pub fn is_remote_connection(&self) -> bool { self.url.is_empty() @@ -112,7 +114,7 @@ impl Conn { if let Err(e) = sock.set_keepalive(Some(Conn::KEEP_ALIVE_TIME)) { error!("Faile to set keep alive : {}.", e.to_string()); } - info!("Tcp client connected with server at {}", addr); + debug!("Tcp client connected with server at {}", addr); ( NetStream::UnsecuredTcpStream(sock), NetAddr::NetSocketAddress(conn_addr), @@ -128,7 +130,7 @@ impl Conn { return Err(e.to_string()); } }; - info!("Unix Socket connected with server at {}", path); + debug!("Unix Socket connected with server at {}", path); let addr = sock.peer_addr().unwrap(); (NetStream::UdsStream(sock), NetAddr::UdsSocketAddress(addr)) } @@ -173,9 +175,9 @@ impl Conn { match self.stream.write(self.output.as_slice()) { Ok(n) => { if n < self.output.len() { - // let mut output = Vec::new(); - // output.extend_from_slice(&self.output[n..self.output.len()]); - // self.output = output + // let mut output = Vec::new(); + // output.extend_from_slice(&self.output[n..self.output.len()]); + // self.output = output self.output.drain(0..n); } else { self.output.clear(); @@ -183,7 +185,7 @@ impl Conn { } Err(ref e) => { if e.kind() == std::io::ErrorKind::WouldBlock { - warn!( + debug!( "Write: ErrorKind::WouldBlock on connection :{}", self.addr.to_string() ); @@ -215,7 +217,7 @@ impl Conn { debug!("Conn bytes read: {}", n); if n == 0 { self.close = true; - //break; + //break; } else { self.input.extend_from_slice(&packet[0..n]); //self.input.extend(&buffer); @@ -223,21 +225,20 @@ impl Conn { "Received Length:{}, data: {}. ", n, String::from_utf8_lossy(&self.input) - ); } } Err(ref e) => { if e.kind() == std::io::ErrorKind::WouldBlock { - warn!( + debug!( "Read: ErrorKind::WouldBlock on connection :{}", self.addr.to_string() ); - //break; + //break; } else if e.kind() == std::io::ErrorKind::ConnectionReset { info!("Read: Connection reset by peer:{}", self.addr.to_string()); self.close = true; - //break; + //break; } else { error!( "Read: Peer Connection:{}, Read Error: {:?}", @@ -260,6 +261,56 @@ impl Conn { pub fn shutdown(&mut self) { let _ = self.stream.shutdown(); } + + pub fn is_ssl_handshake_pending(&self) -> bool { + match self.stream { + NetStream::SslMidHandshakeStream(_) => { + true + } + _ => { false } + } + } + + + pub fn ssl_handshake(&mut self) -> Result<(), Error> { + use std::mem; + let old = mem::replace(&mut self.stream, NetStream::Invalid); + match old { + NetStream::SslMidHandshakeStream(mid_stream) => { + match mid_stream.handshake() { + Ok(s) => { + debug!("ssl_handshake:SSL Handshake successful"); + self.stream = NetStream::SslTcpStream(s); + Ok(()) + } + Err(e) => { + debug!("{:?}", e); + let err_str = e.to_string(); + match e { + HandshakeError::WouldBlock(s) => { + debug!("Failed to handshake on SSL connection. Error:{}", err_str); + self.stream = NetStream::SslMidHandshakeStream(s); + Ok(()) + } + _ => { + error!("Failed to accept SSL connection. Error:{}", err_str); + Err(Error::new( + ErrorKind::Other, + format!("An SSL error occurred.{}", err_str), + )) + } + } + } + } + } + _ => { Ok(()) } + } + } + /* if let Ok(_) = NetStream::ssl_handshake(&mut self.stream) { + //self.stream = stream; + }else { + self.close = true; + }*/ } #[derive(Debug)] @@ -270,8 +321,12 @@ pub enum NetStream { UdsStream(UnixStream), /// An SSL-secured TcpStream. /// This is only available when compiled with SSL support. - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] SslTcpStream(SslStream), + //SSL mid handshake stream + SslMidHandshakeStream(MidHandshakeSslStream), + + Invalid, } impl NetStream { @@ -279,11 +334,17 @@ impl NetStream { match *self { NetStream::UnsecuredTcpStream(ref stream) => stream.shutdown(Shutdown::Both), NetStream::UdsStream(ref stream) => stream.shutdown(Shutdown::Both), - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] NetStream::SslTcpStream(ref mut stream) => { - stream.shutdown(); + if let Err(e) = stream.shutdown() { + warn!("Failed to shutdown SSL stream. Error:{:?}",e); + } Ok(()) } + NetStream::SslMidHandshakeStream(ref mut mid_stream) => { + mid_stream.get_mut().shutdown(Shutdown::Both) + } + NetStream::Invalid => { Ok(()) } } } #[inline] @@ -291,8 +352,12 @@ impl NetStream { match *self { NetStream::UnsecuredTcpStream(ref mut stream) => stream.read(buf), NetStream::UdsStream(ref mut stream) => stream.read(buf), - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] NetStream::SslTcpStream(ref mut stream) => stream.read(buf), + NetStream::SslMidHandshakeStream(ref mut mid_stream) => { + mid_stream.get_mut().read(buf) + } + NetStream::Invalid => { Ok(0) } } } #[inline] @@ -300,11 +365,15 @@ impl NetStream { match *self { NetStream::UnsecuredTcpStream(ref mut stream) => stream.write(buf), NetStream::UdsStream(ref mut stream) => stream.write(buf), - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] NetStream::SslTcpStream(ref mut stream) => { // Arc::get_mut(stream).unwrap().write(buf) stream.write(buf) } + NetStream::SslMidHandshakeStream(ref mut mid_stream) => { + mid_stream.get_mut().write(buf) + } + NetStream::Invalid => { Ok(0) } } } #[inline] @@ -312,18 +381,66 @@ impl NetStream { match *self { NetStream::UnsecuredTcpStream(ref mut stream) => stream.write_all(buf), NetStream::UdsStream(ref mut stream) => stream.write_all(buf), - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] NetStream::SslTcpStream(ref mut stream) => stream.write_all(buf), + NetStream::SslMidHandshakeStream(ref mut mid_stream) => { + mid_stream.get_mut().write_all(buf) + } + NetStream::Invalid => { Ok(()) } } } pub fn flush(&mut self) -> IoResult<()> { match *self { NetStream::UnsecuredTcpStream(ref mut stream) => stream.flush(), NetStream::UdsStream(ref mut stream) => stream.flush(), - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] NetStream::SslTcpStream(ref mut stream) => stream.flush(), + NetStream::SslMidHandshakeStream(ref mut mid_stream) => { + mid_stream.get_mut().flush() + } + NetStream::Invalid => { Ok(()) } } } + + /*pub fn ssl_handshake(stream : &mut NetStream) -> Result<(), Error> { + match *stream { + NetStream::SslMidHandshakeStream(mut mid_stream) => { + match mid_stream.handshake() { + Ok(s) => { + info!("ssl_handshake:SSL Handshake successful"); + *stream = NetStream::SslTcpStream(s); + Ok(()) + }, + Err(e) => { + info!("{:?}", e); + let err_str = e.to_string(); + match e { + HandshakeError::WouldBlock(s) => { + info!("Failed to handshake on SSL connection. Error:{}", err_str); + *stream = NetStream::SslMidHandshakeStream(s); + Ok(()) + }, + _ => { + error!("Failed to accept SSL connection. Error:{}", err_str); + return Err(Error::new( + ErrorKind::Other, + format!("An SSL error occurred.{}", err_str), + )); + } + } + } + } + }, + _ => { + error!("Invalid operation. SSL Handshake is only supported on SslMidHandshakeStream"); + return Err(Error::new( + ErrorKind::Other, + "An SSL error occurred.Invalid operation. SSL Handshake is only supported on SslMidHandshakeStream", + )); + } + + } + }*/ } #[derive(Debug)] diff --git a/src/events.rs b/src/events.rs index 2c9fad5..ce3d6b9 100644 --- a/src/events.rs +++ b/src/events.rs @@ -7,37 +7,33 @@ **************************************************/ -use crate::conn::{Conn, NetAddr, NetStream}; -use crate::server::Server; +use crossbeam_channel as mpsc; +use hashbrown::HashMap; +use mio::{Events, Poll, PollOpt, Ready, Token}; use mio::event::Event; use mio::unix::UnixReady; -use mio::{Events, Poll, PollOpt, Ready, Token}; use parking_lot::Mutex; -use std::sync::atomic::Ordering; use std::sync::Arc; - -use crate::config::GaduConfig; -use crossbeam_channel as mpsc; -use hashbrown::HashMap; use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use std::thread; use std::time::Duration; -pub trait NetEvents { - fn event_opened(&self, id: usize, conn: &mut Conn) -> (Vec, bool); - fn event_closed(&self, id: usize, conn: &mut Conn) -> Result; - fn event_data( - &self, - id: usize, - conn_tags: &mut HashMap, - input_buffer: &mut Vec, - output_buffer: &mut Vec, - ) -> bool; -} +use crate::config::GaduConfig; +use crate::conn::{Conn, NetAddr, NetStream}; +use crate::net_events::NetEvents; +use crate::server::Server; + +//use futures::future::Future; +//use futures::StreamExt; +//use futures::executor::{self, ThreadPool}; +//use futures::io::AsyncWriteExt; +//use futures::task::{SpawnExt}; + pub struct ServerEventHandler { pub server_id: usize, - pub server: Server, + pub server: Arc, pub conn_handlers: Vec>, pub shutdown: bool, } @@ -48,10 +44,10 @@ impl ServerEventHandler { num_threads: usize, config: &GaduConfig, ) -> Result { - let server = Server::init(server_id, config)?; + let server = Arc::new(Server::init(server_id, config)?); let mut conn_handlers = Vec::with_capacity(num_threads); - for _i in 0..num_threads { - let handler = ConnEventHandler::new()?; + for i in 0..num_threads { + let handler = ConnEventHandler::new(i)?; conn_handlers.push(Arc::new(handler)); } Ok(ServerEventHandler { @@ -61,76 +57,18 @@ impl ServerEventHandler { shutdown: false, }) } - - /* - pub fn run_loop(&mut self, event_handler: Arc) where T: NetEvents + 'static + Sync + Send + Sized { - - let mut id = self.server_id; - - crossbeam::scope(|scope| { - for conn_handler in self.conn_handlers.iter() { - - let ev = event_handler.clone(); - let c =conn_handler.clone(); - scope.spawn(move || c.child_loop(ev)); - } - - while !self.shutdown { - id = id + 1; - if let Ok(mut conn) = self.server.accept_connection() { - let (output, close) = event_handler.event_opened(id, &conn); - if close { - conn.close = close; - }else if !output.is_empty() { - conn.output = output; - conn.reg_write = true; - - } - self.add_connection(id, conn); - - continue; - } - } - - }); - }*/ } pub struct ConnEventHandler { + id: usize, pub conns: Arc>>, pub poll: Poll, - //receiver: mpsc::Receiver, - //sender: mpsc::Sender + } -unsafe impl Send for ConnEventHandler {} -unsafe impl Sync for ConnEventHandler {} -/* -impl NetEvents for ConnEventHandler { - /// - /// event opened - #[inline] - fn event_opened(&self, id: usize, conn: &Conn) -> (Vec, bool) { - // new connection, update write command - debug!( - "event_opened: New connection opend with id:{}, address: {:?}", - id, conn.addr - ); - - (Vec::new(), false) - } - /// - /// event close - #[inline] - fn event_closed(&self, id: usize) { - // FUTURE: Adios connection. - debug!("event_closed: Connection close for id:{}", id); - } +//unsafe impl Send for ConnEventHandler {} +//unsafe impl Sync for ConnEventHandler {} - fn event_data(&self, id: usize, buffer: &mut Vec)-> (Vec, bool){ - return (vec![], false) - } -}*/ #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ConnMsg { @@ -139,13 +77,14 @@ pub struct ConnMsg { } impl ConnEventHandler { - pub fn new() -> Result { + pub fn new(id: usize) -> Result { let res = Poll::new(); if let Err(e) = res { return { Err(e.to_string()) }; } //let (sender, receiver) = mpsc::unbounded::(); let conn_handler = ConnEventHandler { + id, conns: Arc::new(Mutex::new(HashMap::new())), poll: res.unwrap(), // receiver, @@ -153,16 +92,18 @@ impl ConnEventHandler { }; Ok(conn_handler) } - pub fn add_connection(&self, id: usize, conn: Conn) -> Result<(), String> { - debug!("ConnEventHandler::add_connection with id:{}", id); - if let Err(e) = self.register(id, &conn) { + pub fn add_connection(&self, conn_id: usize, conn: Conn) -> Result<(), String> { + info!("ConnEventHandler: {} ::add_connection with conn id:{}", self.id, conn_id); + if let Err(e) = self.register(conn_id, &conn) { // read only return Err(e.to_string()); } - self.conns.lock().insert(id, conn); + + self.conns.lock().insert(conn_id, conn); Ok(()) } + /*pub fn add_message(&self, id: usize, msg: &[u8])-> Result<(), String> { let mut output = Vec::with_capacity(msg.len()); output.extend(msg); @@ -206,11 +147,16 @@ impl ConnEventHandler { self.poll .register(stream, Token(id), flags, poll_opt) } - #[cfg(feature = "tls")] + // #[cfg(feature = "tls")] NetStream::SslTcpStream(ref stream) => { self.poll .register(stream.get_ref(), Token(id), flags, poll_opt) } + NetStream::SslMidHandshakeStream(ref stream) => { + self.poll + .register(stream.get_ref(), Token(id), flags, poll_opt) + } + NetStream::Invalid => { Ok(()) } }; if let Err(e) = res { error!("Failed to register connection with id:{} for flag:{:?}. Error:{:?}", id, flags, e); @@ -218,7 +164,7 @@ impl ConnEventHandler { "Failed to register connection with id:{}. Error:{:?}", id, e ) - .to_owned()); + .to_owned()); } Ok(()) } @@ -227,8 +173,10 @@ impl ConnEventHandler { let _res = match conn.get_stream() { NetStream::UnsecuredTcpStream(ref stream) => self.poll.deregister(stream), NetStream::UdsStream(ref stream) => self.poll.deregister(stream), - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] NetStream::SslTcpStream(ref stream) => self.poll.deregister(stream.get_ref()), + NetStream::SslMidHandshakeStream(ref stream) => self.poll.deregister(stream.get_ref()), + NetStream::Invalid => { Ok(()) } }; } #[inline] @@ -236,7 +184,7 @@ impl ConnEventHandler { debug!("check_error_event:event:{:?}", event.readiness()); let er = UnixReady::from(event.readiness()); if er.is_hup() { - debug!( + info!( "UnixReady:Closing peer connection {} due to HUP signal received ", addr.to_string() ); @@ -259,15 +207,13 @@ impl ConnEventHandler { ) where T: NetEvents + ?Sized, { + info!("Child Loop with receiver started for ConnectionHandler :{}", self.id); let mut streams: HashMap = HashMap::new(); let mut events = Events::with_capacity(1024); let mut read_buffer = [0; 32768]; - let mut timeout = Some(Duration::from_millis(1)); - if receiver.is_some() { - timeout = Some(Duration::from_millis(1)); - } + let timeout = if receiver.is_some() { Some(Duration::from_millis(100)) } else { Some(Duration::from_millis(250)) }; loop { //check if shutdown signal received @@ -324,28 +270,37 @@ impl ConnEventHandler { if conn.close { debug!("check_error_event(): Connection closed status:{}", close); } + found = true; if !conn.close { - loop { - debug!("output len:{}", conn.output.len()); - if !conn.output.is_empty() { - close = conn.write(); - } else if !conn.close { - close = conn.read(&mut read_buffer); - // PROFILER.lock().unwrap().start("/tmp/my-prof.profile").expect("Couldn't start"); - debug!("Invoking event_handler::event_data for connection id:{}", id); - let close_conn = - event_handler.event_data(id, &mut conn.tags, &mut conn.input, &mut conn.output); - // PROFILER.lock().unwrap().stop().expect("Couldn't stop"); - debug!("event_data output:{}", String::from_utf8_lossy(&conn.output)); - // conn.output.extend(&output); - conn.close = close_conn; - } - if !conn.close && !conn.output.is_empty() { - continue; + // check if SSL handshake was still pending + if conn.is_ssl_handshake_pending() { + if let Err(e) = conn.ssl_handshake() { + warn!("SSL Handshake failed. Error:{:?}",e); } + continue; + } else { + loop { + debug!("output len:{}", conn.output.len()); + if !conn.output.is_empty() { + close = conn.write(); + } else if !conn.close { + close = conn.read(&mut read_buffer); + // PROFILER.lock().unwrap().start("/tmp/my-prof.profile").expect("Couldn't start"); + debug!("Invoking event_handler::event_data for connection id:{}", id); + let close_conn = + event_handler.event_data(id, &mut conn.tags, &mut conn.input, &mut conn.output); + // PROFILER.lock().unwrap().stop().expect("Couldn't stop"); + debug!("event_data output:{}", String::from_utf8_lossy(&conn.output)); + // conn.output.extend(&output); + conn.close = close_conn; + } + if !conn.close && !conn.output.is_empty() { + continue; + } - break; + break; + } } } if !conn.output.is_empty() { @@ -374,7 +329,7 @@ impl ConnEventHandler { self.deregister(id, &conn); if let Ok(true) = event_handler.event_closed(id, &mut conn) { if self.register(id, &conn).is_ok() { - debug!("Auto reconnect successful. Reregistering socket"); + info!("Auto reconnect successful. Reregistering socket"); conn.reg_write = false; if let Err(e) = self.reregister(&conn, id, true) { error!("Failed to reregister. Error:{:?}", e); //FIXME: should this be closed @@ -407,14 +362,16 @@ impl ConnEventHandler { } pub fn child_loop(&self, event_handler: Arc, shutdown: Arc) - where - T: NetEvents + 'static + Sync + Send + Sized, + where + T: NetEvents + 'static + Sync + Send + Sized, { + info!("Child Loop started for ConnectionHandler :{}", self.id); let mut streams: HashMap = HashMap::new(); let mut events = Events::with_capacity(1024); - let mut read_buffer = [0; 32768]; + let mut read_buffer = [0; 51200]; + + let timeout = Some(Duration::from_millis(250)); - let timeout = Some(Duration::from_millis(1)); loop { //check if shutdown signal received @@ -426,16 +383,6 @@ impl ConnEventHandler { return; } - // if any data in the receiver channel, assigned to connect to send it out - - /*let mut data: Vec = self.receiver.try_iter().collect(); - for msg in data.iter() { - if let Some(mut conn) = streams.get_mut(&msg.id) { - conn.output.extend(&msg.output); - conn.reg_write = true; - } - }*/ - let total_events = match self.poll.poll(&mut events, timeout) { Ok(total_events) => total_events, Err(e) => { @@ -448,7 +395,7 @@ impl ConnEventHandler { continue; } - debug!("Child Poll: Total events received:{}", total_events); + debug!("Child Poll id: {}: Total events received:{}, Number of connections:{}", self.id, total_events, streams.len()); for event in &events { let token = event.token(); @@ -463,33 +410,41 @@ impl ConnEventHandler { //check error/hup event received if conn.close { debug!("Got connection from stream for id {}. Connection closed status:{}", id, conn.close); + } else { + close = ConnEventHandler::check_error_event(&conn.get_address(), &event); } - close = ConnEventHandler::check_error_event(&conn.get_address(), &event); conn.close = close; if conn.close { debug!("check_error_event:Connection closed status:{}", close); } found = true; + if !conn.close { - loop { - debug!("output len:{}", conn.output.len()); - if !conn.output.is_empty() { - close = conn.write(); - } else if !conn.close { - close = conn.read(&mut read_buffer); - // PROFILER.lock().unwrap().start("/tmp/my-prof.profile").expect("Couldn't start"); - let close_conn = - event_handler.event_data(id, &mut conn.tags, &mut conn.input, &mut conn.output); - // PROFILER.lock().unwrap().stop().expect("Couldn't stop"); - debug!("event_data output:{}", String::from_utf8_lossy(&conn.output)); - //conn.output.extend(&output); - conn.close = close_conn; - } - if !conn.close && !conn.output.is_empty() { - continue; + // check if SSL handshake was still pending + if conn.is_ssl_handshake_pending() { + if let Err(e) = conn.ssl_handshake() { + warn!("SSL Handshake failed. Error:{:?}",e); } + } else { + //perform read/write operations + loop { + debug!("output len:{}", conn.output.len()); + if !conn.output.is_empty() { + close = conn.write(); + } else if !conn.close { + close = conn.read(&mut read_buffer); + let close_conn = + event_handler.event_data(id, &mut conn.tags, &mut conn.input, &mut conn.output); + debug!("event_data output:{}", String::from_utf8_lossy(&conn.output)); + //conn.output.extend(&output); + conn.close = close_conn; + } + if !conn.close && !conn.output.is_empty() { + continue; + } - break; + break; + } } } if !conn.output.is_empty() { @@ -540,8 +495,8 @@ impl ConnEventHandler { ); } else if !found { if let Some(conn) = self.conns.lock().remove(&id) { - // if self.reregister( &conn, id, true).is_ok() { - streams.insert(id, conn); + // if self.reregister( &conn, id, true).is_ok() { + streams.insert(id, conn); //} } } diff --git a/src/lib.rs b/src/lib.rs index b429cab..0177654 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,26 +6,32 @@ License: Apache 2.0 **************************************************/ +//#![feature(await_macro, async_await, futures_api)] +//#[macro_use] +//extern crate tokio; #[macro_use] extern crate log; #[macro_use] extern crate serde_derive; +pub mod net_events; pub mod config; pub mod conn; pub mod events; pub mod network_server; pub mod server; + #[cfg(test)] mod tests { + use std::thread; - use super::*; use crate::conn::Conn; use crate::server::Server; - use std::thread; + + use super::*; + #[test] fn it_works() { assert_eq!(2 + 2, 4); } - } diff --git a/src/net_events.rs b/src/net_events.rs new file mode 100644 index 0000000..f79656a --- /dev/null +++ b/src/net_events.rs @@ -0,0 +1,26 @@ +/************************************************ + + File: rato:net_events.rs:NetEvents + Author: ytr289 + Date: 2019-03-26:09:33 + LICENSE: Apache 2.0 + +**************************************************/ +use hashbrown::HashMap; + +use crate::conn::Conn; + +/// +/// NetEvents +/// +pub trait NetEvents { + fn event_opened(&self, id: usize, conn: &mut Conn) -> (Vec, bool); + fn event_closed(&self, id: usize, conn: &mut Conn) -> Result; + fn event_data( + &self, + id: usize, + conn_tags: &mut HashMap, + input_buffer: &mut Vec, + output_buffer: &mut Vec, + ) -> bool; +} \ No newline at end of file diff --git a/src/network_server.rs b/src/network_server.rs index 808a4f2..836a9db 100644 --- a/src/network_server.rs +++ b/src/network_server.rs @@ -6,17 +6,21 @@ LICENSE: Apache 2.0 **************************************************/ -use crate::config::NetworkServerConfig; -use crate::conn::Conn; -use crate::events::{ConnEventHandler, NetEvents, ServerEventHandler}; use crossbeam::thread::Scope; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + +use crate::config::NetworkServerConfig; +use crate::conn::Conn; +use crate::events::{ConnEventHandler, ServerEventHandler}; +use crate::net_events::NetEvents; + + /// /// NetworkServer /// pub struct NetworkServer { - conf : NetworkServerConfig, + conf: NetworkServerConfig, server_event_handler: ServerEventHandler, conn_handlers: Vec>, pub shutdown: Arc, @@ -38,11 +42,11 @@ impl NetworkServer { let mut conn_handlers = Vec::with_capacity(conf.num_threads); for i in 0..conf.num_threads { debug!("Initializing {} NetworkServer::ConnEventHandler {}", conf.server_name, i); - let handler = ConnEventHandler::new()?; + let handler = ConnEventHandler::new(i)?; conn_handlers.push(Arc::new(handler)); } Ok(NetworkServer { - conf : conf.clone(), + conf: conf.clone(), server_event_handler, conn_handlers, shutdown, @@ -51,44 +55,44 @@ impl NetworkServer { } fn add_connection(&self, id: usize, conn: Conn) -> Result<(), String> { debug!("add_connection with id:{}", id); - let mut index = self.round_robin_counter.fetch_add(1, Ordering::SeqCst); if index >= self.conn_handlers.len() { self.round_robin_counter.store(0, Ordering::SeqCst); - index = 0; + index = self.round_robin_counter.fetch_add(1, Ordering::SeqCst); + self.conn_handlers[index].add_connection(id, conn) + } else { + self.conn_handlers[index].add_connection(id, conn) } - self.conn_handlers[index].add_connection(id, conn) } fn server_loop(&self, net_event_handler: Arc) { let mut id = self.server_event_handler.server_id; info!("Waiting for connection..."); while !self.shutdown.load(Ordering::SeqCst) { - id += 1; - if let Ok(mut conn) = self.server_event_handler.server.accept_connection() { - let (output, close) = net_event_handler.event_opened(id, &mut conn); - if close { - conn.close = close; - } else if !output.is_empty() { - conn.output = output; - conn.reg_write = true; - } + if let Ok(conns) = self.server_event_handler.server.accept_connection() { + for mut conn in conns.into_iter() { + id += 1; - if let Err(e) = self.add_connection(id, conn) { - error!("Failed to add connect with id: {}. Error:{:?}", id, e); - } + let (output, close) = net_event_handler.event_opened(id, &mut conn); + if close { + conn.close = close; + } else if !output.is_empty() { + conn.output = output; + conn.reg_write = true; + } - continue; - } else { - std::thread::sleep(std::time::Duration::from_millis(250)); + if let Err(e) = self.add_connection(id, conn) { + error!("Failed to add connect with id: {}. Error:{:?}", id, e); + } + } } } info!("Shutdown received. Exiting {} NetworkServer server_loop...", self.conf.server_name); } - pub fn run_loop(scope: &Scope, network_server: Arc, net_event_handler: Arc, non_blocking:bool) - where - T: NetEvents + 'static + Sync + Send + Sized, + pub fn run_loop(scope: &Scope, network_server: Arc, net_event_handler: Arc, non_blocking: bool) + where + T: NetEvents + 'static + Sync + Send + Sized, { for conn_handler in network_server.conn_handlers.iter() { let ev = net_event_handler.clone(); //kanudo.network_controller.clone(); @@ -101,9 +105,8 @@ impl NetworkServer { info!("Starting kanudo server loop"); if non_blocking { scope.spawn(move |_| network_server.server_loop(net_event_handler)); - }else { + } else { network_server.server_loop(net_event_handler) - } } } diff --git a/src/server.rs b/src/server.rs index 58e7402..e09c989 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,196 +6,71 @@ License: Apache 2.0 **************************************************/ -#[cfg(feature = "tls")] + +use mio::*; +use mio::net::TcpListener; + +use mio_uds::UnixListener; +//#[cfg(feature = "tls")] +use openssl::error::ErrorStack; +//#[cfg(feature = "tls")] use openssl::ssl::{ - SslAcceptor, SslConnectorBuilder, SslFiletype, SslMethod, SslStream, SslVerifyMode, + HandshakeError, SslAcceptor, SslFiletype, SslMethod, SslMode, SslVerifyMode, SslVersion, }; +//#[cfg(feature = "tls")] -#[cfg(feature = "tls")] -use openssl::error::ErrorStack; -#[cfg(feature = "tls")] -use openssl::x509; //use std::io::{Read, Write}; -use std::io::Error; +use std::io::{Error, ErrorKind}; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; use std::time::Duration; +use url::Url; use crate::config::GaduConfig; use crate::conn::{Conn, NetAddr, NetStream}; -use mio::net::TcpListener; -use mio::*; -use mio_uds::UnixListener; -use url::Url; + pub enum NetListener { UnsecuredTcpListener(TcpListener), - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] SslTcpListener(TcpListener), /// Unix domain socket stream UdsListener(UnixListener), -} - -impl NetListener { - pub fn accept_tcp_connection( - listener: &TcpListener, - config: &GaduConfig, - ) -> Result<(NetStream, NetAddr), Error> { - let s = listener.accept()?; - debug!("New peer connection received from: {:?}", s.1); - if let Err(e) = - s.0.set_keepalive(Some(Duration::from_millis(config.keep_alive_time))) - { - error!("Failed to set keepalive. Error:{:?}", e); - } - - if let Err(e) = - s.0.set_nodelay(true) - { - error!("Failed to set nodelay to true. Error:{:?}", e); - } - - Ok(( - NetStream::UnsecuredTcpStream(s.0), - NetAddr::NetSocketAddress(s.1), - )) - /* - match conn_res { - Ok(s) => { - debug!("New peer connection received from: {:?}", s.1); - s.0.set_keepalive(Some(Duration::from_millis(config.keep_alive_time))); - Ok((NetStream::UnsecuredTcpStream(s.0), NetAddr::NetSocketAddress(s.1))) - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - { - Err(Error::new( - io::ErrorKind::WouldBlock, - "Failed to accept new connection", - )) - }, - Err(e) => panic!("encountered IO error: {}", e), - }*/ - } - - pub fn accept_uds_connection(listener: &UnixListener) -> Result<(NetStream, NetAddr), Error> { - let s = listener.accept()?; - let (sock, addr) = s.unwrap(); - //let addr = s.0.peer_addr().unwrap(); - debug!("New peer connection received from: {:?}", addr); - - Ok((NetStream::UdsStream(sock), NetAddr::UdsSocketAddress(addr))) - /*match conn_res { - Ok(s) => { - let (sock, addr) = s.unwrap(); - //let addr = s.0.peer_addr().unwrap(); - info!("New peer connection received from: {:?}", addr); - Ok((NetStream::UdsStream(sock), NetAddr::UdsSocketAddress(addr))) - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - { - Err(Error::new( - io::ErrorKind::WouldBlock, - format!("An UDS error occurred.({:?}",e) - ) - ) - - }, - Err(e) => panic!("encountered IO error: {}", e), - }*/ - } - #[cfg(feature = "tls")] - pub fn accept_ssl_connection( - listener: &TcpListener, - config: &GaduConfig, - acceptor: Arc, - ) -> Result<(NetStream, NetAddr), Error> { - let s = listener.accept()?; - debug!("New peer connection received from: {:?}", s.1); - s.0.set_keepalive(Some(Duration::from_millis(config.keep_alive_time))); - let stream = match acceptor.accept(s.0) { - Ok(s) => s, - Err(e) => { - return Err(Error::new( - io::ErrorKind::Other, - format!("An SSL error occurred.({:?}", e), - )); - } - }; - - Ok(( - NetStream::SslTcpStream(stream), - NetAddr::NetSocketAddress(s.1), - )) - /* - match listener.accept() { - Ok(s) => { - info!("New peer connection received from: {:?}", s.1); - s.0.set_keepalive(Some(Duration::from_millis(config.keep_alive_time))); - let stream = match acceptor.accept(s.0) { - Ok(s) => s, - Err(e) => { - return Err( - Error::new( - io::ErrorKind::Other, - format!("An SSL error occurred.({:?}", e) - ) - ); - } - }; - - Ok((NetStream::SslTcpStream(stream), NetAddr::NetSocketAddress(s.1))) - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - { - Err(Error::new( - io::ErrorKind::WouldBlock, - "Failed to accept new connection", - )) - }, - Err(e) => panic!("encountered IO error: {}", e), - }*/ - } - /* - pub fn accept_connection(&self, config : &GaduConfig, acceptor: Arc) -> Result<(NetStream,NetAddr), Error>{ - - - match self { - &NetListener::UnsecuredTcpListener(ref listener) => { - NetListener::accept_tcp_connection(&listener, &config) - } - #[cfg(feature = "tls")] - &NetListener::SslTcpListener(ref listener) => { - NetListener::accept_ssl_connection(&listener, &config, acceptor) - }, - &NetListener::UdsListener(ref listener) => { - NetListener::accept_uds_connection(&listener, &config) - } - } - - }*/ } + pub struct Server { pub id: Token, pub config: GaduConfig, pub url: Url, + pub mail_poll: Poll, + poll_timeout: Duration, pub server: NetListener, - #[cfg(feature = "tls")] - acceptor: Arc, + acceptor: Option>, + pub next_token_id: AtomicUsize, } impl Server { pub fn init(token_id: usize, config: &GaduConfig) -> Result { - #[cfg(feature = "tls")] - info!("SSL Server initializing.."); + let res = Poll::new(); + if let Err(e) = res { + return Err(e.to_string()); + } + let poll = res.unwrap(); + + //#[cfg(feature = "tls")] + info!("Server initializing.."); let url = match Url::parse(&config.url) { Ok(url) => url, Err(e) => { return Err(e.to_string()); } }; + let net_server = match url.scheme() { - #[cfg(not(feature = "ssl"))] + // #[cfg(not(feature = "ssl"))] "tcp" => { if !url.has_host() { return Err( @@ -219,37 +94,19 @@ impl Server { } }; - info!("Tcp Server started on {}", addr); - NetListener::UnsecuredTcpListener(server) - } - #[cfg(feature = "tls")] - "tcp" | "ssl" | "tls" => { - info!("SSL enabled"); - if !url.has_host() { - return Err( - "Invalid Url. It must have host defined. e.g. ssl://host:port".to_owned(), - ); - } - if url.port().is_none() { - return Err( - "Invalid Url. It must have port defined. e.g. ssl://host:port".to_owned(), - ); - } - let addr = format!("{}:{}", url.host_str().unwrap(), url.port().unwrap()); - - debug!("Binding Server at {}", &config.url); - let server = match TcpListener::bind(&addr.parse().unwrap()) { - Ok(sock) => sock, - Err(e) => { - error!("EventHandler: Couldn't bind at {}. Error: {:?}", addr, e); - return Err(e.to_string()); + match config.ssl_config.enabled { + true => { + info!("Secure Server started on {}", addr); + poll.register(&server, Token(token_id), Ready::readable(), PollOpt::empty()).unwrap(); + NetListener::SslTcpListener(server) } - }; - - info!("Secure Server started on {}", addr); - - NetListener::SslTcpListener(server) + _ => { + info!("Tcp Server started on {}", addr); + poll.register(&server, Token(token_id), Ready::readable(), PollOpt::empty()).unwrap(); + NetListener::UnsecuredTcpListener(server) + } + } } "unix" => { let path = url.path(); @@ -262,53 +119,107 @@ impl Server { } }; info!("Unix Server started on {}", path); + poll.register(&server, Token(token_id), Ready::readable(), PollOpt::empty()).unwrap(); NetListener::UdsListener(server) } _ => { return Err("Unsupported scheme. Valid schemes are unix and tcp".to_owned()); } }; - - + + + let acceptor = if config.ssl_config.enabled { Some(Arc::new(Server::init_ssl_acceptor(&config)?)) } else { None }; + Ok(Server { id: Token(token_id), config: config.clone(), url, + mail_poll: poll, + poll_timeout: Duration::from_millis(250), server: net_server, - #[cfg(feature = "tls")] - acceptor: Arc::new(Server::init_ssl_acceptor(&config)?), + //#[cfg(feature = "tls")] + acceptor, + next_token_id: AtomicUsize::new(0), }) } - pub fn accept_connection(&self) -> Result { - let (net_stream, net_addr) = match self.server { - NetListener::UnsecuredTcpListener(ref listener) => { - NetListener::accept_tcp_connection(&listener, &self.config)? - } - #[cfg(feature = "tls")] - NetListener::SslTcpListener(ref listener) => { - NetListener::accept_ssl_connection(&listener, &self.config, self.acceptor.clone())? - } - NetListener::UdsListener(ref listener) => { - NetListener::accept_uds_connection(&listener)? - } - }; - Ok(Conn::new(net_stream, net_addr)) + pub fn accept_connection(&self) -> Result, Error> { + let mut events = Events::with_capacity(1024); + let total_events = self.mail_poll.poll(&mut events, Some(self.poll_timeout))?; + + let mut conns = Vec::with_capacity(10); + + if total_events == 0 { + return Ok(conns); + } + + debug!("Main Poll: Total events received:{}", total_events); + + + for event in &events { + debug!("event readiness:{:?}", event.readiness()); + if self.id == event.token() { + let (net_stream, net_addr) = match self.server { + NetListener::UnsecuredTcpListener(ref listener) => { + self.accept_tcp_connection(&listener)? + } + // #[cfg(feature = "tls")] + NetListener::SslTcpListener(ref listener) => { + self.accept_ssl_connection(&listener)? + } + NetListener::UdsListener(ref listener) => { + self.accept_uds_connection(&listener)? + } + }; + conns.push(Conn::new(net_stream, net_addr)); + } +// }else { +// // these are uncompleted SSL handshake or others... +// let entry_res = server.pending_streams.lock().remove(&event.token().0); +// if entry_res.is_none() { +// continue; +// } +// let (ssl_net_stream, ssl_net_addr) = entry_res.unwrap(); +// match ssl_net_stream { +// NetStream::SslMidHandshakeStream(stream) => { +// server.mail_poll.deregister(stream.get_ref()); +// let (net_stream, net_addr) = Server::ssl_handshake(server.clone(),stream, ssl_net_addr)?; +// conns.push(Conn::new(net_stream, net_addr)); +// +// } +// _ => { +// error!("Not supported pending stream"); +// continue; +// } +// } +// } + //} //match + } //for events + Ok(conns) } - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] fn init_ssl_acceptor(config: &GaduConfig) -> Result { match Server::build_acceptor(&config) { Ok(acceptor) => Ok(acceptor), - Err(e) => Err(e.to_string()), + Err(e) => { + error!("Failed to build SSL acceptor"); + Err(e.to_string()) + } } } - #[cfg(feature = "tls")] + //#[cfg(feature = "tls")] fn build_acceptor(config: &GaduConfig) -> Result { let mut ctx = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; { ctx.set_default_verify_paths()?; + //ctx.clear_options(SslOptions::NO_TLSV1_3); + + ctx.set_min_proto_version(Some(SslVersion::TLS1_2))?; + + ctx.set_mode(SslMode::AUTO_RETRY); + // verify peer if config.ssl_config.verify.unwrap_or(false) { ctx.set_verify(SslVerifyMode::PEER); @@ -320,21 +231,166 @@ impl Server { ctx.set_verify_depth(config.ssl_config.verify_depth.unwrap()); } if config.ssl_config.certificate_file.is_some() { - ctx.set_certificate_file( - config.ssl_config.certificate_file.as_ref().unwrap(), - SslFiletype::PEM, + ctx.set_certificate_chain_file( + config.ssl_config.certificate_file.as_ref().unwrap() + // SslFiletype::PEM, )?; + info!("Setting SSL certificate file: {:?}", config.ssl_config.certificate_file.as_ref().unwrap()); } if config.ssl_config.private_key_file.is_some() { ctx.set_private_key_file( config.ssl_config.private_key_file.as_ref().unwrap(), SslFiletype::PEM, )?; + info!("Setting SSL private key file: {:?}", config.ssl_config.private_key_file.as_ref().unwrap()); + ctx.check_private_key()?; + info!("Checking private key file successful"); } if config.ssl_config.ca_file.is_some() { - let _ = ctx.set_ca_file(config.ssl_config.ca_file.as_ref().unwrap())?; + ctx.set_ca_file(config.ssl_config.ca_file.as_ref().unwrap())?; + info!("Setting SSL CA cert file: {:?}", config.ssl_config.ca_file.as_ref().unwrap()); } } Ok(ctx.build()) } + pub fn accept_tcp_connection(&self, + listener: &TcpListener, + ) -> Result<(NetStream, NetAddr), Error> { + let s = listener.accept()?; + + debug!("TCP:New peer connection received from: {:?}", s.1); + //if let Err(e) = s.0.set_nonblocking(true) { + // error!("Failed to set nonblocking to true. Error:{:?}", e); + // + //} + if let Err(e) = s.0.set_nodelay(true) { + error!("Failed to set nodelay to true. Error:{:?}", e); + } + if let Err(e) = + s.0.set_keepalive(Some(Duration::from_millis(self.config.keep_alive_time))) + { + error!("Failed to set keepalive. Error:{:?}", e); + } + + + Ok(( + NetStream::UnsecuredTcpStream(s.0), + NetAddr::NetSocketAddress(s.1), + )) + /* + match conn_res { + Ok(s) => { + debug!("New peer connection received from: {:?}", s.1); + s.0.set_keepalive(Some(Duration::from_millis(config.keep_alive_time))); + Ok((NetStream::UnsecuredTcpStream(s.0), NetAddr::NetSocketAddress(s.1))) + }, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => + { + Err(Error::new( + io::ErrorKind::WouldBlock, + "Failed to accept new connection", + )) + }, + Err(e) => panic!("encountered IO error: {}", e), + }*/ + } + + pub fn accept_uds_connection(&self, listener: &UnixListener) -> Result<(NetStream, NetAddr), Error> { + debug!("accept_uds_connection()"); + let accept_results = listener.accept()?; + + if accept_results.is_none() { + //error!("Failed to get uds connection"); + return Err(Error::new(ErrorKind::Other, "none retuned")); + } + + let (stream, addr) = accept_results.unwrap(); + + //let stream = mio_uds::UnixListener::from_listener(stream)?; + + debug!("UDS: New peer connection received from: {:?}", addr); + //sock.set_nonblocking(true); + + Ok((NetStream::UdsStream(stream), NetAddr::UdsSocketAddress(addr))) + /*match conn_res { + Ok(s) => { + let (sock, addr) = s.unwrap(); + //let addr = s.0.peer_addr().unwrap(); + info!("New peer connection received from: {:?}", addr); + + Ok((NetStream::UdsStream(sock), NetAddr::UdsSocketAddress(addr))) + }, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => + { + Err(Error::new( + io::ErrorKind::WouldBlock, + format!("An UDS error occurred.({:?}",e) + ) + ) + + }, + Err(e) => panic!("encountered IO error: {}", e), + }*/ + } + + + // #[cfg(feature = "tls")] + pub fn accept_ssl_connection(&self, + listener: &TcpListener, + ) -> Result<(NetStream, NetAddr), Error> { +// if acceptor.is_none() { +// return Err(Error::new( +// ErrorKind::Other, +// "An SSL error occurred. SslAcceptor not initialized".to_string(), +// )); +// } + let (sock, addr) = listener.accept_std()?; + debug!("SSL: New peer connection received from: {:?}", addr); + let sock = mio::net::TcpStream::from_stream(sock)?; + + if let Err(e) = sock.set_nodelay(true) { + error!("Failed to set nodelay to true for addr: {:?}. Error:{:?}", addr, e); + } + if let Err(e) = sock.set_keepalive(Some(Duration::from_millis(self.config.keep_alive_time))) { + error!("Failed to set keepalive for addr: {:?}. Error:{:?}", addr, e); + } + + match self.acceptor.as_ref().unwrap().accept(sock) { + Ok(s) => { + debug!("SSL Handshake successful"); + Ok(( + NetStream::SslTcpStream(s), + NetAddr::NetSocketAddress(addr) + )) + } + Err(e) => { + debug!("{:?}", e); + let err_str = e.to_string(); + match e { + HandshakeError::WouldBlock(s) => { + debug!("Failed to handshake on SSL connection. Received error: HandshakeError::WouldBlock"); + Ok(( + NetStream::SslMidHandshakeStream(s), + NetAddr::NetSocketAddress(addr) + )) +// server.mail_poll.register(s.get_ref(), Token(server.next_token_id.load(Ordering::Relaxed)), Ready::readable() , PollOpt::empty() ).unwrap(); +// server.pending_streams.lock().insert(server.next_token_id.load(Ordering::Relaxed), (NetStream::SslMidHandshakeStream(s), NetAddr::NetSocketAddress(addr))); +// server.next_token_id.fetch_add(1, Ordering::Relaxed); +// +// return Err(Error::new( +// ErrorKind::Other, +// format!("An SSL error occurred.{}", err_str), +// )); + } + _ => { + error!("Failed to accept SSL connection. Error:{}", err_str); + Err(Error::new( + ErrorKind::Other, + format!("An SSL error occurred.{}", err_str), + )) + } + } + } + } + } }