Skip to content

Commit

Permalink
async implement
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive committed Aug 30, 2024
1 parent 634a711 commit 74593ac
Show file tree
Hide file tree
Showing 9 changed files with 509 additions and 41 deletions.
12 changes: 11 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ targets = [

[features]
default = []
# default = ["verify_binary_signature", "panic_on_unsent_packets"]
# default = ["verify_binary_signature", "panic_on_unsent_packets", "async"]
async = ["async-task", "blocking", "futures"]
panic_on_unsent_packets = []
verify_binary_signature = []

[dependencies]
async-task = { version = "4", optional = true }
blocking = { version = "1", optional = true }
c2rust-bitfields = "0.18"
futures = { version = "0.3", optional = true }
libloading = "0.8"
log = "0.4"
thiserror = "1"
Expand Down Expand Up @@ -59,3 +63,9 @@ packet = "0.1"
pcap-file = "2"
serde_json = "1"
subprocess = "0.2"
tokio = { version = "1", features = ["full"] }

[[example]]
name = "udp-echo-async"
path = "examples/udp-echo-async.rs"
required-features = ["async"]
19 changes: 14 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,19 @@ wintun's internal ring buffer.

- `verify_binary_signature`: Verifies the signature of the wintun dll file before loading it.

## TODO:
- Add async support
Requires hooking into a windows specific reactor and registering read interest on wintun's read
handle. Asyncify other slow operations via tokio::spawn_blocking. As always, PR's are welcome!

- `async`: Enables async support for the library.
Just add `async` feature to your `Cargo.toml`:
```toml
[dependencies]
wintun-bindings = { version = "0.1", features = ["async"] }
```
And simply transform your `Session` into an `AsyncSession`:
```rust
// ...
let session = Arc::new(adapter.start_session(MAX_RING_CAPACITY)?);
let mut reader_session: AsyncSession = session.clone().try_into()?;
let mut writer_session: AsyncSession = session.clone().try_into()?;
// ...
```

License: MIT
282 changes: 282 additions & 0 deletions examples/udp-echo-async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
//! This example demonstrates how to use Wintun to create a simple UDP echo server.
//!
//! You can see packets being received by wintun by runnig: `nc -u 10.28.13.100 4321`
//! and sending lines of text.
use futures::{AsyncReadExt, AsyncWriteExt};
use std::{
net::{IpAddr, SocketAddr},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tokio::sync::mpsc::channel;
use windows_sys::Win32::{
Foundation::FALSE,
Security::Cryptography::{CryptAcquireContextW, CryptGenRandom, CryptReleaseContext, PROV_RSA_FULL},
};
use wintun_bindings::{
get_active_network_interface_gateways, get_running_driver_version, get_wintun_bin_pattern_path, load_from_path,
run_command, Adapter, AsyncSession, BoxError, Error, MAX_RING_CAPACITY,
};

#[derive(Debug)]
struct NaiveUdpPacket {
src_addr: SocketAddr,
dst_addr: SocketAddr,
data: Vec<u8>,
}

impl NaiveUdpPacket {
fn new(src_addr: SocketAddr, dst_addr: SocketAddr, data: &[u8]) -> Self {
Self {
src_addr,
dst_addr,
data: data.to_vec(),
}
}
}

impl std::fmt::Display for NaiveUdpPacket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"src=\"{}\", dst=\"{}\", data length {}",
self.src_addr,
self.dst_addr,
self.data.len()
)
}
}

#[tokio::main]
async fn main() -> Result<(), BoxError> {
dotenvy::dotenv().ok();
env_logger::init();
// Loading wintun
let dll_path = get_wintun_bin_pattern_path()?;
let wintun = unsafe { load_from_path(dll_path)? };

let version = get_running_driver_version(&wintun);
println!("Wintun version: {:?}", version);

let adapter_name = "Demo";
let guid = 2131231231231231231_u128;

// Open or create a new adapter
let adapter = match Adapter::open(&wintun, adapter_name) {
Ok(a) => a,
Err(_) => Adapter::create(&wintun, adapter_name, "MyTunnelType", Some(guid))?,
};

let version = get_running_driver_version(&wintun)?;
println!("Wintun version: {}", version);

// set metric command: `netsh interface ipv4 set interface adapter_name metric=255`
let args = &["interface", "ipv4", "set", "interface", adapter_name, "metric=255"];
run_command("netsh", args)?;
println!("netsh {}", args.join(" "));

// Execute the network card initialization command, setting virtual network card information
// ip = 10.28.13.2 mask = 255.255.255.0 gateway = 10.28.13.1
// command: `netsh interface ipv4 set address adapter_name static 10.28.13.2/24 gateway=10.28.13.1`
let args = &[
"interface",
"ipv4",
"set",
"address",
adapter_name,
"static",
"10.28.13.2/24",
"gateway=10.28.13.1",
];
run_command("netsh", args)?;
println!("netsh {}", args.join(" "));

let dns = "8.8.8.8".parse::<IpAddr>().unwrap();
let dns2 = "8.8.4.4".parse::<IpAddr>().unwrap();
adapter.set_dns_servers(&[dns, dns2])?;

let v = adapter.get_addresses()?;
for addr in &v {
let mask = adapter.get_netmask_of_address(addr)?;
println!("address {} netmask: {}", addr, mask);
}

let gateways = adapter.get_gateways()?;
println!("adapter gateways: {gateways:?}");

// adapter.set_name("MyNewName")?;
// println!("adapter name: {}", adapter.get_name()?);

// adapter.set_address("10.28.13.2".parse()?)?;

println!("adapter mtu: {}", adapter.get_mtu()?);

println!(
"active adapter gateways: {:?}",
get_active_network_interface_gateways()?
);

let session = Arc::new(adapter.start_session(MAX_RING_CAPACITY)?);

let mut reader_session: AsyncSession = session.clone().try_into()?;
let mut writer_session: AsyncSession = session.clone().try_into()?;

let (tx, mut rx) = channel::<NaiveUdpPacket>(1000);

// Global flag to stop the session
static RUNNING: AtomicBool = AtomicBool::new(true);

let reader = tokio::task::spawn(async move {
let block = async {
while RUNNING.load(Ordering::Relaxed) {
let mut bytes = [0u8; 1500];

// recieved IP packet
let len = reader_session.read(&mut bytes).await?;
if len == 0 {
break;
}

let udp_packet = extract_udp_packet(&bytes[..len]);
if let Err(err) = udp_packet {
println!("{}", err);
continue;
}

// swap src and dst
let mut udp_packet = udp_packet?;
let src_addr = udp_packet.src_addr;
let dst_addr = udp_packet.dst_addr;
udp_packet.src_addr = dst_addr;
udp_packet.dst_addr = src_addr;

// send to writer
tx.send(udp_packet).await?;
}
Ok::<(), BoxError>(())
};
if let Err(err) = block.await {
println!("Reader {}", err);
}
});

let writer = tokio::task::spawn(async move {
let block = async {
while RUNNING.load(Ordering::Relaxed) {
let resp = rx.recv().await.ok_or("Channel closed")?;

let src_addr = match resp.src_addr.ip() {
IpAddr::V4(addr) => addr,
IpAddr::V6(_) => return Err("IPv6 addresses are not supported".into()),
};

let dst_addr = match resp.dst_addr.ip() {
IpAddr::V4(addr) => addr,
IpAddr::V6(_) => return Err("IPv6 addresses are not supported".into()),
};

let v = generate_random_bytes(2)?;
let id = u16::from_ne_bytes([v[0], v[1]]);

// build response IP packet
use packet::Builder;
let ip_packet = packet::ip::v4::Builder::default()
.id(id)?
.ttl(64)?
.source(src_addr)?
.destination(dst_addr)?
.udp()?
.source(resp.src_addr.port())?
.destination(resp.dst_addr.port())?
.payload(&resp.data)?
.build()?;

// // The following code will be better than above, the `ipv4_udp_build` function link is
// //
// // https://github.com/pysrc/study-udp/blob/59d7ba210a022d207c60ad5370de37110fefaefb/src/protocol.rs#L157-L252
// //
// let mut ip_packet = vec![0u8; 28 + resp.data.len()];
// protocol::ipv4_udp_build(
// &mut ip_packet,
// &src_addr.octets(),
// resp.src_addr.port(),
// &dst_addr.octets(),
// resp.dst_addr.port(),
// &resp.data,
// );

writer_session.write_all(&ip_packet).await?;
}
Ok::<(), BoxError>(())
};
if let Err(err) = block.await {
println!("Writer {}", err);
}
});

println!("Press enter to stop session");

let mut line = String::new();
let _ = std::io::stdin().read_line(&mut line);
println!("Shutting down session");
RUNNING.store(false, Ordering::Relaxed);
session.shutdown()?;
let _ = reader.await;
let _ = writer.await;
Ok(())
}

fn extract_udp_packet(packet: &[u8]) -> Result<NaiveUdpPacket, Error> {
use packet::{ip, udp, AsPacket, Packet};
let packet: ip::Packet<_> = packet.as_packet().map_err(|err| format!("{}", err))?;
let info: String;
match packet {
ip::Packet::V4(a) => {
let src_addr = a.source();
let dst_addr = a.destination();
let protocol = a.protocol();
let payload = a.payload();
match protocol {
ip::Protocol::Udp => {
let udp = udp::Packet::new(payload).map_err(|err| format!("{}", err))?;
let src_port = udp.source();
let dst_port = udp.destination();
let src_addr = SocketAddr::new(src_addr.into(), src_port);
let dst_addr = SocketAddr::new(dst_addr.into(), dst_port);
let data = udp.payload();
let udp_packet = NaiveUdpPacket::new(src_addr, dst_addr, data);
log::trace!("{protocol:?} {}", udp_packet);
return Ok(udp_packet);
}
_ => {
info = format!("{:?} src={}, dst={}", protocol, src_addr, dst_addr);
}
}
}
ip::Packet::V6(a) => {
info = format!("{:?}", a);
}
}
Err(info.into())
}

fn generate_random_bytes(len: usize) -> std::io::Result<Vec<u8>> {
let mut buf = vec![0u8; len];
unsafe {
let mut h_prov = 0_usize;
let null = std::ptr::null_mut();
if FALSE == CryptAcquireContextW(&mut h_prov, null, null, PROV_RSA_FULL, 0) {
return Err(std::io::Error::last_os_error());
}
if FALSE == CryptGenRandom(h_prov, buf.len() as _, buf.as_mut_ptr()) {
return Err(std::io::Error::last_os_error());
}
if FALSE == CryptReleaseContext(h_prov, 0) {
return Err(std::io::Error::last_os_error());
}
};
Ok(buf)
}
14 changes: 6 additions & 8 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
/// wintun functionality
use crate::{
error::{Error, OutOfRangeData},
handle::{SafeEvent, UnsafeHandle},
session,
util::{self, UnsafeHandle},
util::{self},
wintun_raw, Wintun,
};
use std::{
Expand All @@ -20,11 +21,7 @@ use std::{
};
use windows_sys::{
core::GUID,
Win32::{
Foundation::FALSE,
NetworkManagement::{IpHelper::ConvertLengthToIpv4Mask, Ndis::NET_LUID_LH},
System::Threading::CreateEventA,
},
Win32::NetworkManagement::{IpHelper::ConvertLengthToIpv4Mask, Ndis::NET_LUID_LH},
};

/// Wrapper around a <https://git.zx2c4.com/wintun/about/#wintun_adapter_handle>
Expand Down Expand Up @@ -157,11 +154,12 @@ impl Adapter {
if result.is_null() {
return Err("WintunStartSession failed".into());
}
let shutdown_event = unsafe { CreateEventA(std::ptr::null_mut(), FALSE, FALSE, std::ptr::null_mut()) };
// Manual reset, because we use this event once and it must fire on all threads
let shutdown_event = SafeEvent::new(true, false)?;
Ok(session::Session {
session: UnsafeHandle(result),
read_event: OnceLock::new(),
shutdown_event: UnsafeHandle(shutdown_event),
shutdown_event: Arc::new(shutdown_event),
adapter: self.clone(),
})
}
Expand Down
Loading

0 comments on commit 74593ac

Please sign in to comment.