diff --git a/src/transport/multicast_server.rs b/src/transport/multicast_server.rs index 2699a4e9..7bc92db4 100644 --- a/src/transport/multicast_server.rs +++ b/src/transport/multicast_server.rs @@ -17,8 +17,9 @@ use log::*; use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, RwLock}; -use std::thread; +use std::thread::{self, JoinHandle}; use crate::protocol::Message; use crate::transport::default::*; @@ -29,6 +30,8 @@ use crate::transport::udp_socket::UdpSocket; pub struct MulticastServer { socket: Arc>, notifier: Notifier, + stop_flag: Arc, + thread_handle: Option>, } impl MulticastServer { @@ -36,6 +39,8 @@ impl MulticastServer { MulticastServer { socket: Arc::new(RwLock::new(UdpSocket::new())), notifier: notifier_new(), + stop_flag: Arc::new(AtomicBool::new(false)), + thread_handle: None, } } @@ -119,17 +124,21 @@ impl MulticastServer { true } - pub fn close(&self) -> bool { - self.socket.read().unwrap().close(); - true + pub fn close(&mut self) -> bool { + let sock = self.socket.try_write(); + if sock.is_err() { + return false; + } + sock.unwrap().close() } pub fn start(&mut self) -> bool { let socket = self.socket.clone(); let notifier = self.notifier.clone(); - thread::spawn(move || { + let stop_flag = self.stop_flag.clone(); + let handle = thread::spawn(move || { let mut buf = [0 as u8; MAX_PACKET_SIZE]; - loop { + while !stop_flag.load(Ordering::Relaxed) { let recv_res = socket.read().unwrap().recv_from(&mut buf); match &recv_res { Ok((n_bytes, remote_addr)) => { @@ -164,10 +173,16 @@ impl MulticastServer { } } }); + + self.thread_handle = Some(handle); true } - pub fn stop(&self) -> bool { + pub fn stop(&mut self) -> bool { + self.stop_flag.store(true, Ordering::Relaxed); + if let Some(handle) = self.thread_handle.take() { + handle.join().unwrap(); + } if !self.close() { return false; } diff --git a/src/transport/udp_socket.rs b/src/transport/udp_socket.rs index df3d0291..ebb3e4f1 100644 --- a/src/transport/udp_socket.rs +++ b/src/transport/udp_socket.rs @@ -82,6 +82,7 @@ impl UdpSocket { "socket is not bound", )) } + pub fn bind(&mut self, ifaddr: SocketAddr) -> Result<()> { if self.sock.is_some() { self.close(); @@ -115,9 +116,9 @@ impl UdpSocket { Ok(()) } - pub fn close(&self) { + pub fn close(&mut self) -> bool { if self.sock.is_none() { - return; + return true; } #[cfg(feature = "unix")] { @@ -129,9 +130,12 @@ impl UdpSocket { let res = close(fd); if res.is_err() { warn!("close {:?}", res.err()); + return false; } + self.sock = None; } thread::sleep(time::Duration::from_millis(UDP_SOCKET_BIND_SLEEP_MSEC)); + true } pub fn send_to(&self, buf: &[u8], to_addr: SocketAddr) -> Result { diff --git a/src/transport/unicast_server.rs b/src/transport/unicast_server.rs index 78372194..0200026d 100644 --- a/src/transport/unicast_server.rs +++ b/src/transport/unicast_server.rs @@ -17,8 +17,9 @@ use log::*; use std::io; use std::net::{IpAddr, SocketAddr}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, RwLock}; -use std::thread; +use std::thread::{self, JoinHandle}; use crate::protocol::Message; use crate::transport::default::*; @@ -29,6 +30,8 @@ use crate::transport::udp_socket::UdpSocket; pub struct UnicastServer { socket: Arc>, notifier: Notifier, + stop_flag: Arc, + thread_handle: Option>, } impl UnicastServer { @@ -36,6 +39,8 @@ impl UnicastServer { UnicastServer { socket: Arc::new(RwLock::new(UdpSocket::new())), notifier: notifier_new(), + stop_flag: Arc::new(AtomicBool::new(false)), + thread_handle: None, } } @@ -85,17 +90,21 @@ impl UnicastServer { true } - pub fn close(&self) -> bool { - self.socket.read().unwrap().close(); - true + pub fn close(&mut self) -> bool { + let sock = self.socket.try_write(); + if sock.is_err() { + return false; + } + sock.unwrap().close() } pub fn start(&mut self) -> bool { let socket = self.socket.clone(); let notifier = self.notifier.clone(); - thread::spawn(move || { + let stop_flag = self.stop_flag.clone(); + let handle = thread::spawn(move || { let mut buf = [0 as u8; MAX_PACKET_SIZE]; - loop { + while !stop_flag.load(Ordering::Relaxed) { let recv_res = socket.read().unwrap().recv_from(&mut buf); match &recv_res { Ok((n_bytes, remote_addr)) => { @@ -130,10 +139,15 @@ impl UnicastServer { } } }); + self.thread_handle = Some(handle); true } - pub fn stop(&self) -> bool { + pub fn stop(&mut self) -> bool { + self.stop_flag.store(true, Ordering::Relaxed); + if let Some(handle) = self.thread_handle.take() { + handle.join().unwrap(); + } if !self.close() { return false; }