diff --git a/src/session.rs b/src/session.rs index 7adebfc..16bbaa0 100644 --- a/src/session.rs +++ b/src/session.rs @@ -122,6 +122,7 @@ impl Session { self.wait_read()?; } } + fn wait_read(&self) -> Result<(), Error> { //Wait on both the read handle and the shutdown handle so that we stop when requested let handles = [self.get_read_wait_event()?.0, self.shutdown_event.0 .0]; @@ -154,6 +155,7 @@ impl Session { Ok(()) } } + impl Session { pub fn try_recv(&self, buf: &mut [u8]) -> std::io::Result { let mut size = 0u32; @@ -163,7 +165,7 @@ impl Session { debug_assert!(size <= u16::MAX as u32); if ptr.is_null() { - //Wintun returns ERROR_NO_MORE_ITEMS instead of blocking if packets are not available + // Wintun returns ERROR_NO_MORE_ITEMS instead of blocking if packets are not available return match unsafe { GetLastError() } { ERROR_NO_MORE_ITEMS => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), e => Err(std::io::Error::from_raw_os_error(e as i32)), @@ -171,14 +173,10 @@ impl Session { } let size = size as usize; if size > buf.len() { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "destination buffer too small", - )); - } - unsafe { - ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), size); + use std::io::{Error, ErrorKind::InvalidInput}; + return Err(Error::new(InvalidInput, "destination buffer too small")); } + unsafe { ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), size) }; Ok(size) } @@ -187,14 +185,14 @@ impl Session { /// will return Err(()) pub fn recv(&self, buf: &mut [u8]) -> std::io::Result { loop { - //Try 5 times to receive without blocking so we don't have to issue a syscall to wait - //for the event if packets are being received at a rapid rate + // Try 5 times to receive without blocking so we don't have to issue a syscall to wait + // for the event if packets are being received at a rapid rate for _ in 0..5 { return match self.try_recv(buf) { Ok(len) => Ok(len), Err(e) => { if e.kind() == std::io::ErrorKind::WouldBlock { - //Try again + // Try again continue; } Err(e) @@ -204,6 +202,7 @@ impl Session { self.wait_read()?; } } + pub fn send(&self, buf: &[u8]) -> std::io::Result { let wintun = &self.adapter.wintun; let size = buf.len(); @@ -211,13 +210,12 @@ impl Session { if ptr.is_null() { util::get_last_error()?; } - unsafe { - ptr::copy_nonoverlapping(buf.as_ptr(), ptr, size); - wintun.WintunSendPacket(self.session.0, ptr); - } + unsafe { ptr::copy_nonoverlapping(buf.as_ptr(), ptr, size) }; + unsafe { wintun.WintunSendPacket(self.session.0, ptr) }; Ok(buf.len()) } } + impl Drop for Session { fn drop(&mut self) { if let Err(e) = self.shutdown() {