Skip to content

Commit

Permalink
WIP: linux hotplug
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmehall committed Jun 1, 2024
1 parent 8fa83a4 commit b684798
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 23 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ env_logger = "0.10.0"
futures-lite = "1.13.0"

[target.'cfg(target_os="linux")'.dependencies]
rustix = { version = "0.38.17", features = ["fs", "event"] }
rustix = { version = "0.38.17", features = ["fs", "event", "net"] }

[target.'cfg(target_os="windows")'.dependencies]
windows-sys = { version = "0.48.0", features = ["Win32_Devices_Usb", "Win32_Devices_DeviceAndDriverInstallation", "Win32_Foundation", "Win32_Devices_Properties", "Win32_Storage_FileSystem", "Win32_Security", "Win32_System_IO", "Win32_System_Registry", "Win32_System_Com"] }
Expand All @@ -30,3 +30,6 @@ windows-sys = { version = "0.48.0", features = ["Win32_Devices_Usb", "Win32_Devi
core-foundation = "0.9.3"
core-foundation-sys = "0.8.4"
io-kit-sys = "0.4.0"

[patch.crates-io]
rustix = { git = "https://github.com/kevinmehall/rustix.git", rev = "9b432db1b4ed6cd8ec58fd88815a785a03300ebe" }
13 changes: 10 additions & 3 deletions src/platform/linux_usbfs/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use std::{
};

use log::{debug, error};
use rustix::event::epoll;
use rustix::fd::AsFd;
use rustix::{
fd::{AsRawFd, FromRawFd, OwnedFd},
fs::{Mode, OFlags},
Expand All @@ -22,6 +24,7 @@ use super::{
usbfs::{self, Urb},
SysfsPath,
};
use crate::platform::linux_usbfs::events::Watch;
use crate::{
descriptors::{parse_concatenated_config_descriptors, DESCRIPTOR_LEN_DEVICE},
transfer::{
Expand Down Expand Up @@ -61,7 +64,11 @@ impl LinuxDevice {
// because there's no Arc::try_new_cyclic
let mut events_err = None;
let arc = Arc::new_cyclic(|weak| {
let res = events::register(&fd, weak.clone());
let res = events::register(
fd.as_fd(),
Watch::Device(weak.clone()),
epoll::EventFlags::OUT,
);
let events_id = *res.as_ref().unwrap_or(&usize::MAX);
events_err = res.err();
LinuxDevice {
Expand Down Expand Up @@ -109,7 +116,7 @@ impl LinuxDevice {
// only returns ENODEV after all events are received, so unregister to
// keep the event thread from spinning because we won't receive further events.
// The drop impl will try to unregister again, but that's ok.
events::unregister_fd(&self.fd);
events::unregister_fd(self.fd.as_fd());
}
Err(e) => {
error!("Unexpected error {e} from REAPURBNDELAY");
Expand Down Expand Up @@ -282,7 +289,7 @@ impl LinuxDevice {
impl Drop for LinuxDevice {
fn drop(&mut self) {
debug!("Closing device {}", self.events_id);
events::unregister(&self.fd, self.events_id)
events::unregister(self.fd.as_fd(), self.events_id)
}
}

Expand Down
86 changes: 70 additions & 16 deletions src/platform/linux_usbfs/events.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
use atomic_waker::AtomicWaker;
/// Epoll based event loop for Linux.
///
/// Launches a thread when opening the first device that polls
/// for events on usbfs devices and arbitrary file descriptors
/// (used for udev hotplug).
///
/// ### Why not share an event loop with `tokio` or `async-io`?
///
/// This event loop will call USBFS_REAP_URB on the event thread and
/// dispatch to the transfer's waker directly. Since all USB transfers
/// on a device use the same file descriptor, putting USB-specific
/// dispatch in the event loop avoids additonal synchronization.
use once_cell::sync::OnceCell;
use rustix::{
event::epoll::{self, EventData},
fd::OwnedFd,
event::epoll::{self, EventData, EventFlags},
fd::{AsFd, BorrowedFd, OwnedFd},
io::retry_on_intr,
};
use slab::Slab;
use std::{
sync::{Mutex, Weak},
io,
sync::{Arc, Mutex, Weak},
task::Waker,
thread,
};

Expand All @@ -15,38 +30,43 @@ use crate::Error;
use super::Device;

static EPOLL_FD: OnceCell<OwnedFd> = OnceCell::new();
static DEVICES: Mutex<Slab<Weak<Device>>> = Mutex::new(Slab::new());
static WATCHES: Mutex<Slab<Watch>> = Mutex::new(Slab::new());

pub(super) fn register(usb_fd: &OwnedFd, weak_device: Weak<Device>) -> Result<usize, Error> {
pub(super) enum Watch {
Device(Weak<Device>),
Fd(Arc<AtomicWaker>),
}

pub(super) fn register(fd: BorrowedFd, watch: Watch, flags: EventFlags) -> Result<usize, Error> {
let mut start_thread = false;
let epoll_fd = EPOLL_FD.get_or_try_init(|| {
start_thread = true;
epoll::create(epoll::CreateFlags::CLOEXEC)
})?;

let id = {
let mut devices = DEVICES.lock().unwrap();
devices.insert(weak_device)
let mut watches = WATCHES.lock().unwrap();
watches.insert(watch)
};

if start_thread {
thread::spawn(event_loop);
}

let data = EventData::new_u64(id as u64);
epoll::add(epoll_fd, usb_fd, data, epoll::EventFlags::OUT)?;
epoll::add(epoll_fd, fd, data, flags)?;
Ok(id)
}

pub(super) fn unregister_fd(fd: &OwnedFd) {
pub(super) fn unregister_fd(fd: BorrowedFd) {
let epoll_fd = EPOLL_FD.get().unwrap();
epoll::delete(epoll_fd, fd).ok();
}

pub(super) fn unregister(fd: &OwnedFd, events_id: usize) {
pub(super) fn unregister(fd: BorrowedFd, events_id: usize) {
let epoll_fd = EPOLL_FD.get().unwrap();
epoll::delete(epoll_fd, fd).ok();
DEVICES.lock().unwrap().remove(events_id);
WATCHES.lock().unwrap().remove(events_id);
}

fn event_loop() {
Expand All @@ -56,13 +76,47 @@ fn event_loop() {
retry_on_intr(|| epoll::wait(epoll_fd, &mut event_list, -1)).unwrap();
for event in &event_list {
let key = event.data.u64() as usize;
let device = DEVICES.lock().unwrap().get(key).and_then(|w| w.upgrade());
log::info!("event on {key}");
let lock = WATCHES.lock().unwrap();
let Some(watch) = lock.get(key) else { continue };

if let Some(device) = device {
device.handle_events();
// `device` gets dropped here. if it was the last reference, the LinuxDevice will be dropped.
// That will unregister its fd, so it's important that DEVICES is unlocked here, or we'd deadlock.
match watch {
Watch::Device(w) => {
if let Some(device) = w.upgrade() {
drop(lock);
device.handle_events();
// `device` gets dropped here. if it was the last reference, the LinuxDevice will be dropped.
// That will unregister its fd, so it's important that WATCHES is unlocked here, or we'd deadlock.
}
}
Watch::Fd(waker) => waker.wake(),
}
}
}
}

pub(crate) struct Async<T> {
pub(crate) inner: T,
waker: Arc<AtomicWaker>,
id: usize,
}

impl<T: AsFd> Async<T> {
pub fn new(inner: T) -> Result<Self, io::Error> {
let waker = Arc::new(AtomicWaker::new());
let id = register(inner.as_fd(), Watch::Fd(waker.clone()), EventFlags::empty())?;
Ok(Async { inner, id, waker })
}

pub fn register(&self, waker: &Waker) -> Result<(), io::Error> {
self.waker.register(waker);
let epoll_fd = EPOLL_FD.get().unwrap();
epoll::modify(
epoll_fd,
self.inner.as_fd(),
EventData::new_u64(self.id as u64),
EventFlags::ONESHOT | EventFlags::IN,
)?;
Ok(())
}
}
138 changes: 135 additions & 3 deletions src/platform/linux_usbfs/hotplug.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,147 @@
use std::{io::ErrorKind, task::Poll};
use log::{debug, error, warn};
use rustix::{
fd::{AsFd, OwnedFd},
net::{
bind,
netlink::{self, SocketAddrNetlink},
recvfrom, socket_with, AddressFamily, RecvFlags, SocketAddrAny, SocketFlags, SocketType,
},
};
use std::{io::ErrorKind, os::unix::prelude::BorrowedFd, path::Path, task::Poll};

use crate::{hotplug::HotplugEvent, Error};

pub(crate) struct LinuxHotplugWatch {}
use super::{enumeration::probe_device, events::Async, SysfsPath};

const UDEV_MAGIC: &[u8; 12] = b"libudev\0\xfe\xed\xca\xfe";
const UDEV_MULTICAST_GROUP: u32 = 1 << 1;

pub(crate) struct LinuxHotplugWatch {
fd: Async<OwnedFd>,
}

impl LinuxHotplugWatch {
pub(crate) fn new() -> Result<Self, Error> {
Err(Error::new(ErrorKind::Unsupported, "Not implemented."))
let fd = socket_with(
AddressFamily::NETLINK,
SocketType::RAW,
SocketFlags::CLOEXEC,
Some(netlink::KOBJECT_UEVENT),
)?;
bind(&fd, &SocketAddrNetlink::new(0, UDEV_MULTICAST_GROUP))?;
Ok(LinuxHotplugWatch {
fd: Async::new(fd)?,
})
}

pub(crate) fn poll_next(&mut self, cx: &mut std::task::Context<'_>) -> Poll<HotplugEvent> {
if let Some(event) = try_receive_event(self.fd.inner.as_fd()) {
return Poll::Ready(event);
}

if let Err(e) = self.fd.register(cx.waker()) {
log::error!("failed to register udev socket with epoll: {e}");
}

Poll::Pending
}
}

fn try_receive_event(fd: BorrowedFd) -> Option<HotplugEvent> {
let mut buf = [0; 8192];

match recvfrom(fd, &mut buf, RecvFlags::DONTWAIT) {
// udev messages will normally be sent to a multicast group, which only
// root can send to. Reject unicast messages that may be from anywhere.
Ok((size, Some(SocketAddrAny::Netlink(nl)))) if nl.groups() == UDEV_MULTICAST_GROUP => {
parse_packet(&buf[..size])
}
Ok((_, src)) => {
warn!("udev netlink socket received message from {src:?}");
None
}
Err(e) if e.kind() == ErrorKind::WouldBlock => None,
Err(e) => {
error!("udev netlink socket recvfrom failed with {e}");
None
}
}
}

fn parse_packet(buf: &[u8]) -> Option<HotplugEvent> {
if buf.len() < 24 {
error!("packet too short: {buf:x?}");
return None;
}

if !buf.starts_with(UDEV_MAGIC) {
error!("packet does not start with expected header: {buf:x?}");
return None;
}

let properties_off = u32::from_ne_bytes(buf[16..20].try_into().unwrap()) as usize;
let properties_len = u32::from_ne_bytes(buf[20..24].try_into().unwrap()) as usize;
let Some(properties_buf) = buf.get(properties_off..properties_off + properties_len) else {
error!("properties offset={properties_off} length={properties_len} exceeds buffer length {len}", len = buf.len());
return None;
};

let mut is_add = None;
let mut busnum = None;
let mut devnum = None;
let mut devpath = None;

for (k, v) in parse_properties(properties_buf) {
debug!("uevent property {k} = {v}");
match k {
"SUBSYSTEM" if v != "usb" => return None,
"DEVTYPE" if v != "usb_device" => return None,
"ACTION" => {
is_add = Some(match v {
"add" => true,
"remove" => false,
_ => return None,
});
}
"BUSNUM" => {
busnum = v.parse::<u8>().ok();
}
"DEVNUM" => {
devnum = v.parse::<u8>().ok();
}
"DEVPATH" => {
devpath = Some(v);
}
_ => {}
}
}

let is_add = is_add?;
let busnum = busnum?;
let devnum = devnum?;
let devpath = devpath?;

if is_add {
let path = Path::new("/sys/").join(devpath.trim_start_matches('/'));
match probe_device(SysfsPath(path.clone())) {
Ok(d) => Some(HotplugEvent::Connected(d)),
Err(e) => {
error!("Failed to probe device {path:?}: {e}");
None
}
}
} else {
Some(HotplugEvent::Disconnected(crate::DeviceId(
super::DeviceId {
bus: busnum,
addr: devnum,
},
)))
}
}

/// Split nul-separated key=value pairs
fn parse_properties(buf: &[u8]) -> impl Iterator<Item = (&str, &str)> + '_ {
buf.split(|b| b == &0)
.filter_map(|entry| std::str::from_utf8(entry).ok()?.split_once('='))
}

0 comments on commit b684798

Please sign in to comment.