diff --git a/crates/muvm/src/guest/bin/muvm-guest.rs b/crates/muvm/src/guest/bin/muvm-guest.rs index fd5de38..ae3d2a6 100644 --- a/crates/muvm/src/guest/bin/muvm-guest.rs +++ b/crates/muvm/src/guest/bin/muvm-guest.rs @@ -6,6 +6,8 @@ use std::process::Command; use std::{cmp, env, fs, thread}; use anyhow::{Context, Result}; +use muvm::guest::bridge::pipewire::start_pwbridge; +use muvm::guest::bridge::x11::start_x11bridge; use muvm::guest::fex::setup_fex; use muvm::guest::hidpipe::start_hidpipe; use muvm::guest::mount::mount_filesystems; @@ -14,7 +16,6 @@ use muvm::guest::server::server_main; use muvm::guest::socket::setup_socket_proxy; use muvm::guest::user::setup_user; use muvm::guest::x11::setup_x11_forwarding; -use muvm::guest::x11bridge::start_x11bridge; use muvm::utils::launch::{GuestConfiguration, PULSE_SOCKET}; use nix::unistd::{Gid, Uid}; use rustix::process::{getrlimit, setrlimit, Resource}; @@ -107,6 +108,12 @@ fn main() -> Result<()> { } }); + thread::spawn(|| { + if catch_unwind(start_pwbridge).is_err() { + eprintln!("pwbridge thread crashed, pipewire passthrough will no longer function"); + } + }); + let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async { server_main(options.command.command, options.command.command_args).await }) } diff --git a/crates/muvm/src/guest/bridge/common.rs b/crates/muvm/src/guest/bridge/common.rs new file mode 100644 index 0000000..51f35b3 --- /dev/null +++ b/crates/muvm/src/guest/bridge/common.rs @@ -0,0 +1,980 @@ +use anyhow::Result; +use nix::errno::Errno; +use nix::libc::{c_int, c_void, off_t, O_RDWR}; +use nix::sys::epoll::{Epoll, EpollCreateFlags, EpollEvent, EpollFlags, EpollTimeout}; +use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags}; +use nix::sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, RecvMsg}; +use nix::unistd::read; +use nix::{cmsg_space, ioctl_readwrite, ioctl_write_ptr}; +use std::cell::RefCell; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::fs::File; +use std::io::{IoSlice, IoSliceMut, Read, Write}; +use std::net::{TcpListener, TcpStream}; +use std::num::NonZeroUsize; +use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::ptr::NonNull; +use std::rc::{Rc, Weak}; +use std::{env, fs, mem, slice, thread}; + +pub const PAGE_SIZE: usize = 4096; + +const VIRTGPU_CONTEXT_PARAM_CAPSET_ID: u64 = 0x0001; +const VIRTGPU_CONTEXT_PARAM_NUM_RINGS: u64 = 0x0002; +const VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK: u64 = 0x0003; +const CAPSET_CROSS_DOMAIN: u64 = 5; +const CROSS_DOMAIN_CHANNEL_RING: u32 = 1; +const VIRTGPU_BLOB_MEM_GUEST: u32 = 0x0001; +const VIRTGPU_BLOB_MEM_HOST3D: u32 = 0x0002; +const VIRTGPU_BLOB_FLAG_USE_MAPPABLE: u32 = 0x0001; +const VIRTGPU_BLOB_FLAG_USE_SHAREABLE: u32 = 0x0002; +const VIRTGPU_EVENT_FENCE_SIGNALED: u32 = 0x90000000; +const CROSS_DOMAIN_ID_TYPE_VIRTGPU_BLOB: u32 = 1; + +#[repr(C)] +#[derive(Default)] +struct DrmVirtgpuContextInit { + num_params: u32, + pad: u32, + ctx_set_params: u64, +} + +#[repr(C)] +#[derive(Default)] +struct DrmVirtgpuContextSetParam { + param: u64, + value: u64, +} + +#[rustfmt::skip] +ioctl_readwrite!(drm_virtgpu_context_init, 'd', 0x40 + 0xb, DrmVirtgpuContextInit); + +#[repr(C)] +#[derive(Default)] +struct DrmVirtgpuResourceCreateBlob { + blob_mem: u32, + blob_flags: u32, + bo_handle: u32, + res_handle: u32, + size: u64, + pad: u32, + cmd_size: u32, + cmd: u64, + blob_id: u64, +} + +#[rustfmt::skip] 
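+// DRM_IOCTL_VIRTGPU_RESOURCE_CREATE_BLOB ('d', DRM_COMMAND_BASE 0x40 + 0x0a): allocates a virtio-gpu blob
+// resource; used below both for the guest-memory rings and for re-importing host blobs as prime fds.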
+ioctl_readwrite!(drm_virtgpu_resource_create_blob, 'd', 0x40 + 0xa, DrmVirtgpuResourceCreateBlob); + +#[repr(C)] +#[derive(Default)] +struct DrmVirtgpuMap { + offset: u64, + handle: u32, + pad: u32, +} + +#[rustfmt::skip] +ioctl_readwrite!(drm_virtgpu_map, 'd', 0x40 + 0x1, DrmVirtgpuMap); + +#[repr(C)] +#[derive(Default)] +struct DrmGemClose { + handle: u32, + pad: u32, +} + +impl DrmGemClose { + fn new(handle: u32) -> DrmGemClose { + DrmGemClose { + handle, + ..DrmGemClose::default() + } + } +} + +#[rustfmt::skip] +ioctl_write_ptr!(drm_gem_close, 'd', 0x9, DrmGemClose); + +#[repr(C)] +#[derive(Default)] +struct DrmPrimeHandle { + handle: u32, + flags: u32, + fd: i32, +} + +#[rustfmt::skip] +ioctl_readwrite!(drm_prime_handle_to_fd, 'd', 0x2d, DrmPrimeHandle); +#[rustfmt::skip] +ioctl_readwrite!(drm_prime_fd_to_handle, 'd', 0x2e, DrmPrimeHandle); + +#[repr(C)] +#[derive(Default)] +struct DrmEvent { + ty: u32, + length: u32, +} + +const VIRTGPU_EXECBUF_RING_IDX: u32 = 0x04; +#[repr(C)] +#[derive(Default)] +struct DrmVirtgpuExecbuffer { + flags: u32, + size: u32, + command: u64, + bo_handles: u64, + num_bo_handles: u32, + fence_fd: i32, + ring_idx: u32, + pad: u32, +} + +#[rustfmt::skip] +ioctl_readwrite!(drm_virtgpu_execbuffer, 'd', 0x40 + 0x2, DrmVirtgpuExecbuffer); + +#[repr(C)] +#[derive(Default)] +struct DrmVirtgpuResourceInfo { + bo_handle: u32, + res_handle: u32, + size: u32, + blob_mem: u32, +} + +#[rustfmt::skip] +ioctl_readwrite!(drm_virtgpu_resource_info, 'd', 0x40 + 0x5, DrmVirtgpuResourceInfo); + +#[repr(C)] +#[derive(Default)] +pub struct CrossDomainHeader { + pub cmd: u8, + pub fence_ctx_idx: u8, + pub cmd_size: u16, + pub pad: u32, +} + +impl CrossDomainHeader { + pub fn new(cmd: u8, cmd_size: u16) -> CrossDomainHeader { + CrossDomainHeader { + cmd, + cmd_size, + ..CrossDomainHeader::default() + } + } +} + +const CROSS_DOMAIN_CMD_INIT: u8 = 1; +const CROSS_DOMAIN_CMD_POLL: u8 = 3; +const CROSS_DOMAIN_PROTOCOL_VERSION: u32 = 1; +#[repr(C)] +#[derive(Default)] +struct CrossDomainInit { + hdr: CrossDomainHeader, + query_ring_id: u32, + channel_ring_id: u32, + channel_type: u32, + protocol_version: u32, +} + +#[repr(C)] +#[derive(Default)] +struct CrossDomainPoll { + hdr: CrossDomainHeader, + pad: u64, +} + +impl CrossDomainPoll { + fn new() -> CrossDomainPoll { + CrossDomainPoll { + hdr: CrossDomainHeader::new( + CROSS_DOMAIN_CMD_POLL, + mem::size_of::() as u16, + ), + ..CrossDomainPoll::default() + } + } +} + +const CROSS_DOMAIN_MAX_IDENTIFIERS: usize = 28; +const CROSS_DOMAIN_CMD_SEND: u8 = 4; +const CROSS_DOMAIN_CMD_RECEIVE: u8 = 5; + +#[repr(C)] +struct CrossDomainSendReceive { + hdr: CrossDomainHeader, + num_identifiers: u32, + opaque_data_size: u32, + identifiers: [u32; CROSS_DOMAIN_MAX_IDENTIFIERS], + identifier_types: [u32; CROSS_DOMAIN_MAX_IDENTIFIERS], + identifier_sizes: [u32; CROSS_DOMAIN_MAX_IDENTIFIERS], + data: T, +} + +const CROSS_DOMAIN_SR_TAIL_SIZE: usize = PAGE_SIZE - mem::size_of::>(); + +pub struct GpuRing { + handle: u32, + res_id: u32, + pub address: *mut c_void, + fd: OwnedFd, +} + +impl GpuRing { + fn new(fd: &OwnedFd) -> Result { + let fd = fd.try_clone().unwrap(); + let mut create_blob = DrmVirtgpuResourceCreateBlob { + size: PAGE_SIZE as u64, + blob_mem: VIRTGPU_BLOB_MEM_GUEST, + blob_flags: VIRTGPU_BLOB_FLAG_USE_MAPPABLE, + ..DrmVirtgpuResourceCreateBlob::default() + }; + unsafe { + drm_virtgpu_resource_create_blob(fd.as_raw_fd() as c_int, &mut create_blob)?; + } + let mut map = DrmVirtgpuMap { + handle: create_blob.bo_handle, + 
..DrmVirtgpuMap::default() + }; + unsafe { + drm_virtgpu_map(fd.as_raw_fd() as c_int, &mut map)?; + } + let ptr = unsafe { + mmap( + None, + NonZeroUsize::new(PAGE_SIZE).unwrap(), + ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, + MapFlags::MAP_SHARED, + &fd, + map.offset as off_t, + )? + .as_ptr() + }; + Ok(GpuRing { + fd, + handle: create_blob.bo_handle, + res_id: create_blob.res_handle, + address: ptr, + }) + } +} + +impl Drop for GpuRing { + fn drop(&mut self) { + unsafe { + munmap(NonNull::new(self.address).unwrap(), PAGE_SIZE).unwrap(); + let close = DrmGemClose::new(self.handle); + drm_gem_close(self.fd.as_raw_fd() as c_int, &close).unwrap(); + } + } +} + +pub struct Context { + pub fd: OwnedFd, + pub channel_ring: GpuRing, + query_ring: GpuRing, +} + +impl Context { + fn new(channel_type: u32) -> Result { + let mut params = [ + DrmVirtgpuContextSetParam { + param: VIRTGPU_CONTEXT_PARAM_CAPSET_ID, + value: CAPSET_CROSS_DOMAIN, + }, + DrmVirtgpuContextSetParam { + param: VIRTGPU_CONTEXT_PARAM_NUM_RINGS, + value: 2, + }, + DrmVirtgpuContextSetParam { + param: VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK, + value: 1 << CROSS_DOMAIN_CHANNEL_RING, + }, + ]; + let mut init = DrmVirtgpuContextInit { + num_params: 3, + pad: 0, + ctx_set_params: params.as_mut_ptr() as u64, + }; + let fd: OwnedFd = File::options() + .write(true) + .read(true) + .open("/dev/dri/renderD128")? + .into(); + unsafe { + drm_virtgpu_context_init(fd.as_raw_fd() as c_int, &mut init)?; + } + + let query_ring = GpuRing::new(&fd)?; + let channel_ring = GpuRing::new(&fd)?; + let this = Context { + fd, + query_ring, + channel_ring, + }; + let init_cmd = CrossDomainInit { + hdr: CrossDomainHeader::new( + CROSS_DOMAIN_CMD_INIT, + mem::size_of::() as u16, + ), + query_ring_id: this.query_ring.res_id, + channel_ring_id: this.channel_ring.res_id, + channel_type, + protocol_version: CROSS_DOMAIN_PROTOCOL_VERSION, + }; + this.submit_cmd(&init_cmd, mem::size_of::(), None, None)?; + this.poll_cmd()?; + Ok(this) + } + pub fn submit_cmd( + &self, + cmd: &T, + cmd_size: usize, + ring_idx: Option, + ring_handle: Option, + ) -> Result<()> { + submit_cmd_raw( + self.fd.as_raw_fd() as c_int, + cmd, + cmd_size, + ring_idx, + ring_handle, + ) + } + fn poll_cmd(&self) -> Result<()> { + let cmd = CrossDomainPoll::new(); + self.submit_cmd( + &cmd, + mem::size_of::(), + Some(CROSS_DOMAIN_CHANNEL_RING), + None, + ) + } +} + +pub fn submit_cmd_raw( + fd: c_int, + cmd: &T, + cmd_size: usize, + ring_idx: Option, + ring_handle: Option, +) -> Result<()> { + let cmd_buf = cmd as *const T as *const u8; + let mut exec = DrmVirtgpuExecbuffer { + command: cmd_buf as u64, + size: cmd_size as u32, + ..DrmVirtgpuExecbuffer::default() + }; + if let Some(ring_idx) = ring_idx { + exec.ring_idx = ring_idx; + exec.flags = VIRTGPU_EXECBUF_RING_IDX; + } + let ring_handle = &ring_handle; + if let Some(ring_handle) = ring_handle { + exec.bo_handles = ring_handle as *const u32 as u64; + exec.num_bo_handles = 1; + } + unsafe { + drm_virtgpu_execbuffer(fd, &mut exec)?; + } + if ring_handle.is_some() { + unimplemented!(); + } + Ok(()) +} + +struct DebugLoopInner { + ls_remote: TcpStream, + ls_local: TcpStream, +} + +struct DebugLoop(Option); + +impl DebugLoop { + fn new() -> DebugLoop { + if !env::var("X11VG_DEBUG") + .map(|x| x == "1") + .unwrap_or_default() + { + return DebugLoop(None); + } + let ls_remote_l = TcpListener::bind(("0.0.0.0", 6001)).unwrap(); + let ls_local_jh = thread::spawn(|| TcpStream::connect(("0.0.0.0", 6001)).unwrap()); + let ls_remote = 
ls_remote_l.accept().unwrap().0; + let ls_local = ls_local_jh.join().unwrap(); + DebugLoop(Some(DebugLoopInner { + ls_remote, + ls_local, + })) + } + fn loop_remote(&mut self, data: &[u8]) { + if let Some(this) = &mut self.0 { + this.ls_remote.write_all(data).unwrap(); + let mut trash = vec![0; data.len()]; + this.ls_local.read_exact(&mut trash).unwrap(); + } + } + fn loop_local(&mut self, data: &[u8]) { + if let Some(this) = &mut self.0 { + this.ls_local.write_all(data).unwrap(); + let mut trash = vec![0; data.len()]; + this.ls_remote.read_exact(&mut trash).unwrap(); + } + } +} + +pub struct SendPacket { + pub data: Vec, + pub fds: Vec, +} + +pub trait MessageResourceFinalizer { + type Handler: ProtocolHandler; + fn finalize(self, client: &mut Client) -> Result<()>; +} + +#[derive(Debug)] +pub struct CrossDomainResource { + pub identifier: u32, + pub identifier_type: u32, + pub identifier_size: u32, +} + +pub enum StreamSendResult { + WantMore, + Processed { + consumed_bytes: usize, + resources: Vec, + finalizers: Vec, + }, +} + +pub enum StreamRecvResult { + WantMore, + Processed { + consumed_bytes: usize, + fds: Vec, + }, +} + +pub trait ProtocolHandler: Sized { + type ResourceFinalizer: MessageResourceFinalizer; + const CHANNEL_TYPE: u32; + fn new() -> Self; + fn process_recv_stream( + this: &mut Client, + data: &[u8], + resources: &mut VecDeque, + ) -> Result; + fn process_send_stream( + this: &mut Client, + data: &mut [u8], + ) -> Result>; + fn process_vgpu_extra(this: &mut Client, cmd: u8) -> Result<()>; +} + +pub struct Client<'a, P: ProtocolHandler> { + // protocol_handler must be dropped before gpu_ctx, so it goes first + pub protocol_handler: P, + pub gpu_ctx: Context, + pub socket: UnixStream, + reply_tail: usize, + reply_head: Vec, + request_tail: usize, + request_head: Vec, + pub request_fds: Vec, + debug_loop: DebugLoop, + pub send_queue: VecDeque, + pub sub_poll: SubPoll<'a, P>, +} + +#[derive(Debug)] +enum ClientEvent { + None, + StartSend, + StopSend, + Close, +} + +pub struct GemHandleFinalizer(u32); + +impl GemHandleFinalizer { + pub fn finalize(self, client: &mut Client) -> Result<()> { + unsafe { + let close = DrmGemClose::new(self.0); + drm_gem_close(client.gpu_ctx.fd.as_raw_fd() as c_int, &close)?; + } + Ok(()) + } +} + +impl<'a, P: ProtocolHandler> Client<'a, P> { + fn new( + socket: UnixStream, + protocol_handler: P, + sub_poll: SubPoll<'a, P>, + ) -> Result>>> { + let this = Rc::new(RefCell::new(Client { + socket, + protocol_handler, + gpu_ctx: Context::new(P::CHANNEL_TYPE)?, + reply_tail: 0, + reply_head: Vec::new(), + request_tail: 0, + request_head: Vec::new(), + request_fds: Vec::new(), + debug_loop: DebugLoop::new(), + send_queue: VecDeque::new(), + sub_poll, + })); + { + let mut borrow = this.borrow_mut(); + let borrow = &mut *borrow; + borrow.sub_poll.my_client = Rc::downgrade(&this); + borrow + .sub_poll + .add(borrow.socket.as_fd(), EpollFlags::EPOLLIN); + borrow + .sub_poll + .add(borrow.gpu_ctx.fd.as_fd(), EpollFlags::EPOLLIN); + } + Ok(this) + } + fn process_socket(&mut self, events: EpollFlags) -> Result { + if events.contains(EpollFlags::EPOLLIN) { + let queue_empty = self.send_queue.is_empty(); + if self.process_socket_recv()? 
{ + return Ok(ClientEvent::Close); + } + if queue_empty && !self.send_queue.is_empty() { + return Ok(ClientEvent::StartSend); + } + } + if events.contains(EpollFlags::EPOLLOUT) { + self.process_socket_send()?; + if self.send_queue.is_empty() { + return Ok(ClientEvent::StopSend); + } + } + Ok(ClientEvent::None) + } + + fn process_socket_send(&mut self) -> Result<()> { + let mut msg = self.send_queue.pop_front().unwrap(); + let fds: Vec = msg.fds.iter().map(|a| a.as_raw_fd()).collect(); + let cmsgs = if fds.is_empty() { + Vec::new() + } else { + vec![ControlMessage::ScmRights(&fds)] + }; + match sendmsg::<()>( + self.socket.as_raw_fd(), + &[IoSlice::new(&msg.data)], + &cmsgs, + MsgFlags::empty(), + None, + ) { + Ok(sent) => { + if sent < msg.data.len() { + msg.data = msg.data.split_off(sent); + self.send_queue.push_front(SendPacket { + data: msg.data.split_off(sent), + fds: Vec::new(), + }); + } + }, + Err(Errno::EAGAIN) => self.send_queue.push_front(msg), + Err(e) => return Err(e.into()), + }; + Ok(()) + } + fn process_socket_recv(&mut self) -> Result { + let mut fdspace = cmsg_space!([RawFd; CROSS_DOMAIN_MAX_IDENTIFIERS]); + let mut ring_msg = CrossDomainSendReceive { + hdr: CrossDomainHeader::new(CROSS_DOMAIN_CMD_SEND, 0), + num_identifiers: 0, + opaque_data_size: 0, + identifiers: [0; CROSS_DOMAIN_MAX_IDENTIFIERS], + identifier_types: [0; CROSS_DOMAIN_MAX_IDENTIFIERS], + identifier_sizes: [0; CROSS_DOMAIN_MAX_IDENTIFIERS], + data: [0u8; CROSS_DOMAIN_SR_TAIL_SIZE], + }; + let recv_buf = if self.request_tail > 0 { + assert!(self.request_head.is_empty()); + assert!(self.request_fds.is_empty()); + let len = self.request_tail.min(ring_msg.data.len()); + &mut ring_msg.data[..len] + } else { + let head_len = self.request_head.len(); + ring_msg.data[..head_len].copy_from_slice(&self.request_head); + self.request_head.clear(); + &mut ring_msg.data[head_len..] + }; + let mut ioslice = [IoSliceMut::new(recv_buf)]; + let msg: RecvMsg<()> = recvmsg( + self.socket.as_raw_fd(), + &mut ioslice, + Some(&mut fdspace), + MsgFlags::empty(), + )?; + for cmsg in msg.cmsgs()? { + match cmsg { + ControlMessageOwned::ScmRights(rf) => { + for fd in rf { + self.request_fds.push(unsafe { OwnedFd::from_raw_fd(fd) }); + } + }, + _ => unimplemented!(), + } + } + let len = if let Some(iov) = msg.iovs().next() { + iov.len() + } else { + return Ok(true); + }; + let buf = &mut ring_msg.data[..len]; + self.debug_loop.loop_local(buf); + let mut resources = Vec::new(); + let mut finalizers = Vec::new(); + if self.request_tail > 0 { + assert!(self.request_fds.is_empty()); + self.request_tail -= buf.len(); + } else { + let mut ptr = 0; + while ptr < buf.len() { + match P::process_send_stream(self, &mut buf[ptr..])? 
{ + StreamSendResult::WantMore => break, + StreamSendResult::Processed { + resources: rs, + finalizers: fns, + consumed_bytes: msg_size, + } => { + ptr += msg_size; + resources.extend(rs); + finalizers.extend(fns); + }, + } + } + if ptr < buf.len() { + self.request_head = buf[ptr..].to_vec(); + } else { + self.request_tail = ptr - buf.len(); + } + } + if !self.request_head.is_empty() { + assert_eq!(self.request_tail, 0); + } + let send_len = buf.len() - self.request_head.len(); + let size = mem::size_of::>() + send_len; + ring_msg.opaque_data_size = send_len as u32; + ring_msg.hdr.cmd_size = size as u16; + ring_msg.num_identifiers = resources.len() as u32; + for (i, res) in resources.into_iter().enumerate() { + ring_msg.identifiers[i] = res.identifier; + ring_msg.identifier_types[i] = res.identifier_type; + ring_msg.identifier_sizes[i] = res.identifier_size; + } + self.gpu_ctx.submit_cmd(&ring_msg, size, None, None)?; + for fin in finalizers { + fin.finalize(self)?; + } + Ok(false) + } + pub fn vgpu_id_from_prime( + &mut self, + fd: OwnedFd, + ) -> Result<(CrossDomainResource, GemHandleFinalizer)> { + let mut to_handle = DrmPrimeHandle { + fd: fd.as_raw_fd(), + ..DrmPrimeHandle::default() + }; + unsafe { + drm_prime_fd_to_handle(self.gpu_ctx.fd.as_raw_fd() as c_int, &mut to_handle)?; + } + let mut res_info = DrmVirtgpuResourceInfo { + bo_handle: to_handle.handle, + ..DrmVirtgpuResourceInfo::default() + }; + unsafe { + drm_virtgpu_resource_info(self.gpu_ctx.fd.as_raw_fd() as c_int, &mut res_info)?; + } + Ok(( + CrossDomainResource { + identifier: res_info.res_handle, + identifier_type: CROSS_DOMAIN_ID_TYPE_VIRTGPU_BLOB, + identifier_size: 0, + }, + GemHandleFinalizer(to_handle.handle), + )) + } + + fn process_vgpu(&mut self) -> Result { + let mut evt = DrmEvent::default(); + read(self.gpu_ctx.fd.as_raw_fd(), unsafe { + slice::from_raw_parts_mut( + &mut evt as *mut DrmEvent as *mut u8, + mem::size_of::(), + ) + })?; + assert_eq!(evt.ty, VIRTGPU_EVENT_FENCE_SIGNALED); + let cmd = unsafe { + (self.gpu_ctx.channel_ring.address as *const CrossDomainHeader) + .as_ref() + .unwrap() + .cmd + }; + match cmd { + CROSS_DOMAIN_CMD_RECEIVE => { + let recv = unsafe { + (self.gpu_ctx.channel_ring.address + as *const CrossDomainSendReceive<[u8; CROSS_DOMAIN_SR_TAIL_SIZE]>) + .as_ref() + .unwrap() + }; + if recv.opaque_data_size == 0 { + return Ok(true); + } + self.process_receive(recv)?; + }, + cmd => P::process_vgpu_extra(self, cmd)?, + }; + self.gpu_ctx.poll_cmd()?; + Ok(false) + } + pub fn virtgpu_id_to_prime(&mut self, rsc: CrossDomainResource) -> Result { + let mut create_blob = DrmVirtgpuResourceCreateBlob { + blob_mem: VIRTGPU_BLOB_MEM_HOST3D, + size: rsc.identifier_size as u64, + blob_id: rsc.identifier as u64, + blob_flags: VIRTGPU_BLOB_FLAG_USE_MAPPABLE | VIRTGPU_BLOB_FLAG_USE_SHAREABLE, + ..DrmVirtgpuResourceCreateBlob::default() + }; + unsafe { + drm_virtgpu_resource_create_blob( + self.gpu_ctx.fd.as_raw_fd() as c_int, + &mut create_blob, + )?; + } + let mut to_fd = DrmPrimeHandle { + handle: create_blob.bo_handle, + flags: O_RDWR as u32, + fd: -1, + }; + unsafe { + drm_prime_handle_to_fd(self.gpu_ctx.fd.as_raw_fd() as c_int, &mut to_fd)?; + let close = DrmGemClose::new(create_blob.bo_handle); + drm_gem_close(self.gpu_ctx.fd.as_raw_fd() as c_int, &close)?; + } + Ok(unsafe { OwnedFd::from_raw_fd(to_fd.fd) }) + } + fn process_receive(&mut self, recv: &CrossDomainSendReceive<[u8]>) -> Result<()> { + let mut identifiers = VecDeque::with_capacity(recv.num_identifiers as usize); + for i in 
0..recv.num_identifiers as usize { + identifiers.push_back(CrossDomainResource { + identifier: recv.identifiers[i], + identifier_size: recv.identifier_sizes[i], + identifier_type: recv.identifier_types[i], + }); + } + let data = &recv.data[..(recv.opaque_data_size as usize)]; + self.debug_loop.loop_remote(data); + let data = if self.reply_tail > 0 { + assert!(self.reply_head.is_empty()); + let block = self.reply_tail.min(data.len()); + let (block_data, data) = data.split_at(block); + // If we have a reply tail, we need to send it separately. This is to ensure + // that no fds are attached to it, since libxcb cannot handle fds not + // attached to a packet header. + self.send_queue.push_back(SendPacket { + data: block_data.into(), + fds: Vec::new(), + }); + + self.reply_tail -= block; + data + } else { + data + }; + assert!(self.reply_tail == 0 || data.is_empty()); + if data.is_empty() { + assert_eq!(recv.num_identifiers, 0); + return Ok(()); + } + + let data = if self.reply_head.is_empty() { + data.to_vec() + } else { + let mut new_data = core::mem::take(&mut self.reply_head); + new_data.extend_from_slice(data); + new_data + }; + + let mut ptr = 0; + let mut owned_fds = Vec::new(); + while ptr < data.len() { + match P::process_recv_stream(self, &data[ptr..], &mut identifiers)? { + StreamRecvResult::Processed { + consumed_bytes, + fds, + } => { + ptr += consumed_bytes; + owned_fds.extend(fds); + }, + StreamRecvResult::WantMore => break, + } + } + let block = if ptr < data.len() { + let (block, next_head) = data.split_at(ptr); + self.reply_head = next_head.to_vec(); + block.to_vec() + } else { + self.reply_tail = ptr - data.len(); + data.to_vec() + }; + self.send_queue.push_back(SendPacket { + data: block, + fds: owned_fds, + }); + Ok(()) + } + fn process_epoll(&mut self, fd: u64, events: EpollFlags) { + if fd == self.socket.as_raw_fd() as u64 { + let event = self + .process_socket(events) + .map_err(|e| { + eprintln!("Client {fd} disconnected with error: {e:?}"); + e + }) + .unwrap_or(ClientEvent::Close); + match event { + ClientEvent::None => {}, + ClientEvent::StartSend => { + self.sub_poll.modify( + self.socket.as_fd(), + EpollFlags::EPOLLOUT | EpollFlags::EPOLLIN, + ); + }, + ClientEvent::StopSend => { + self.sub_poll + .modify(self.socket.as_fd(), EpollFlags::EPOLLIN); + }, + ClientEvent::Close => { + self.sub_poll.close(); + }, + } + } else if fd == self.gpu_ctx.fd.as_raw_fd() as u64 { + let queue_empty = self.send_queue.is_empty(); + let close = self + .process_vgpu() + .map_err(|e| { + eprintln!("Server {fd} disconnected with error: {e:?}"); + e + }) + .unwrap_or(true); + if close { + self.sub_poll.close(); + } else if queue_empty && !self.send_queue.is_empty() { + self.sub_poll.modify( + self.socket.as_fd(), + EpollFlags::EPOLLOUT | EpollFlags::EPOLLIN, + ); + } + } else { + unimplemented!() + } + } +} + +type ClientMap<'a, T> = Rc>>>>>; + +pub struct SubPoll<'a, T: ProtocolHandler> { + epoll: &'a Epoll, + all_clients: ClientMap<'a, T>, + my_client: Weak>>, + my_entries: HashSet, +} + +impl<'a, T: ProtocolHandler> SubPoll<'a, T> { + fn new(epoll: &'a Epoll, all_clients: ClientMap<'a, T>) -> SubPoll<'a, T> { + SubPoll { + epoll, + all_clients, + my_client: Weak::new(), + my_entries: HashSet::new(), + } + } + pub fn add(&mut self, fd: BorrowedFd, events: EpollFlags) { + let my_client = self.my_client.upgrade().unwrap(); + let mut clients = self.all_clients.borrow_mut(); + let raw = fd.as_raw_fd() as u64; + self.epoll.add(fd, EpollEvent::new(events, raw)).unwrap(); + 
clients.insert(raw, my_client.clone()); + self.my_entries.insert(raw); + } + pub fn modify(&mut self, fd: BorrowedFd, events: EpollFlags) { + self.epoll + .modify(fd, &mut EpollEvent::new(events, fd.as_raw_fd() as u64)) + .unwrap(); + } + pub fn remove(&mut self, fd: BorrowedFd) { + let mut clients = self.all_clients.borrow_mut(); + let raw = fd.as_raw_fd() as u64; + self.epoll.delete(fd).unwrap(); + self.my_entries.remove(&raw); + clients.remove(&raw); + } + fn close(&mut self) { + let mut clients = self.all_clients.borrow_mut(); + for entry in self.my_entries.drain() { + clients.remove(&entry); + } + // No need to remove from epoll, fds get automatically removed on close. + } +} + +pub fn bridge_loop(sock_path: &str) { + let epoll = Epoll::new(EpollCreateFlags::empty()).unwrap(); + _ = fs::remove_file(sock_path); + let listen_sock = UnixListener::bind(sock_path).unwrap(); + epoll + .add( + &listen_sock, + EpollEvent::new(EpollFlags::EPOLLIN, listen_sock.as_raw_fd() as u64), + ) + .unwrap(); + let clients = Rc::new(RefCell::new(HashMap::>>>::new())); + loop { + let mut evts = [EpollEvent::empty(); 16]; + let count = match epoll.wait(&mut evts, EpollTimeout::NONE) { + Err(Errno::EINTR) | Ok(0) => continue, + a => a.unwrap(), + }; + for evt in &evts[..count.min(evts.len())] { + let fd = evt.data(); + let events = evt.events(); + if fd == listen_sock.as_raw_fd() as u64 { + let res = listen_sock.accept(); + if res.is_err() { + eprintln!( + "Failed to accept a connection, error: {:?}", + res.unwrap_err() + ); + continue; + } + let stream = res.unwrap().0; + stream.set_nonblocking(true).unwrap(); + let sub_poll = SubPoll::new(&epoll, clients.clone()); + Client::new(stream, T::new(), sub_poll).unwrap(); + continue; + } + let client = { + // Ensure the borrow on `clients` is dropped when we are calling `process_epoll` + clients.borrow().get(&fd).cloned() + }; + if let Some(client) = client { + client.borrow_mut().process_epoll(fd, events); + } + } + } +} diff --git a/crates/muvm/src/guest/bridge/mod.rs b/crates/muvm/src/guest/bridge/mod.rs new file mode 100644 index 0000000..708a7cc --- /dev/null +++ b/crates/muvm/src/guest/bridge/mod.rs @@ -0,0 +1,3 @@ +pub mod common; +pub mod pipewire; +pub mod x11; diff --git a/crates/muvm/src/guest/bridge/pipewire.rs b/crates/muvm/src/guest/bridge/pipewire.rs new file mode 100644 index 0000000..516ad71 --- /dev/null +++ b/crates/muvm/src/guest/bridge/pipewire.rs @@ -0,0 +1,336 @@ +use crate::guest::bridge::common; +use crate::guest::bridge::common::{ + Client, CrossDomainHeader, CrossDomainResource, MessageResourceFinalizer, ProtocolHandler, + StreamRecvResult, StreamSendResult, +}; +use anyhow::Result; +use nix::errno::Errno; +use nix::sys::epoll::EpollFlags; +use nix::sys::eventfd::{EfdFlags, EventFd}; +use std::collections::{HashMap, VecDeque}; +use std::ffi::CStr; +use std::os::fd::{AsFd, AsRawFd, OwnedFd}; +use std::{env, mem}; + +const CROSS_DOMAIN_CHANNEL_TYPE_PW: u32 = 0x10; +const CROSS_DOMAIN_CMD_READ_EVENTFD_NEW: u8 = 11; +const CROSS_DOMAIN_CMD_READ: u8 = 6; + +const SPA_TYPE_STRUCT: u32 = 14; + +const PW_OPC_CORE_CREATE_OBJECT: u8 = 6; +const PW_OPC_CORE_ADD_MEM: u8 = 6; +const PW_OPC_CLIENT_UPDATE_PROPERTIES: u8 = 2; +const PW_OPC_CLIENT_NODE_TRANSPORT: u8 = 0; +const PW_OPC_CLIENT_NODE_SET_ACTIVATION: u8 = 10; + +#[repr(C)] +struct CrossDomainReadWrite { + hdr: CrossDomainHeader, + identifier: u32, + hang_up: u32, + opaque_data_size: u32, + pad: u32, + data: T, +} + +#[repr(C)] +struct CrossDomainReadEventfdNew { + pub hdr: CrossDomainHeader, 
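+    // Identifier of the cross-domain resource whose CROSS_DOMAIN_CMD_READ payloads are written to the newly registered eventfd.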
+ pub id: u32, + pub pad: u32, +} + +fn align_up(v: u32, a: u32) -> u32 { + (v + a - 1) & !(a - 1) +} + +fn read_u32(data: &[u8], at: usize) -> u32 { + u32::from_ne_bytes(data[at..(at + 4)].try_into().unwrap()) +} + +#[derive(Debug)] +struct CoreCreateObject<'a> { + obj_type: &'a CStr, + new_id: u32, +} + +impl<'a> CoreCreateObject<'a> { + fn new(data: &'a [u8]) -> Self { + let ty = read_u32(data, 4); + assert_eq!(ty, SPA_TYPE_STRUCT); + let factory_name_ptr = 8; + let factory_name_size = read_u32(data, factory_name_ptr); + let type_ptr = factory_name_ptr + align_up(factory_name_size + 8, 8) as usize; + let type_size = read_u32(data, type_ptr); + let obj_type = + CStr::from_bytes_with_nul(&data[(type_ptr + 8)..(type_ptr + 8 + type_size as usize)]) + .unwrap(); + let version_ptr = type_ptr + align_up(type_size + 8, 8) as usize; + let version_size = read_u32(data, version_ptr); + let props_ptr = version_ptr + align_up(version_size + 8, 8) as usize; + let props_size = read_u32(data, props_ptr); + let new_id_ptr = props_ptr + align_up(props_size + 8, 8) as usize; + let new_id = read_u32(data, new_id_ptr + 8); + CoreCreateObject { obj_type, new_id } + } +} + +#[derive(Debug)] +struct ClientUpdateProperties<'a> { + props: Vec<(&'a mut [u8], &'a mut [u8])>, +} + +impl<'a> ClientUpdateProperties<'a> { + fn new(mut data: &'a mut [u8]) -> Self { + let ty = read_u32(data, 4); + assert_eq!(ty, SPA_TYPE_STRUCT); + let props_ptr = 8; + let n_items_ptr = props_ptr + 8; + let n_items_size = read_u32(data, n_items_ptr); + let n_items = read_u32(data, n_items_ptr + 8) as usize; + let key_ptr = n_items_ptr + align_up(n_items_size + 8, 8) as usize; + let mut props = Vec::with_capacity(n_items); + data = data.split_at_mut(key_ptr).1; + for _ in 0..n_items { + let key_size = read_u32(data, 0); + data = data.split_at_mut(8).1; + let (key, data2) = data.split_at_mut(key_size as usize); + data = data2; + let pad_size = (align_up(key_size, 8) - key_size) as usize; + data = data.split_at_mut(pad_size).1; + let value_size = read_u32(data, 0); + data = data.split_at_mut(8).1; + let (value, data2) = data.split_at_mut(value_size as usize); + data = data2; + let pad_size = (align_up(value_size, 8) - value_size) as usize; + data = data.split_at_mut(pad_size).1; + props.push((key, value)); + } + ClientUpdateProperties { props } + } +} + +struct PipeWireHeader { + id: u32, + opcode: u8, + size: usize, + num_fd: usize, +} + +impl PipeWireHeader { + const SIZE: usize = 16; + fn from_stream(data: &[u8]) -> PipeWireHeader { + let id = read_u32(data, 0); + let opc_len_word = read_u32(data, 4) as usize; + let opcode = (opc_len_word >> 24) as u8; + let size = (opc_len_word & 0xFFFFFF) + 16; + let num_fd = read_u32(data, 12) as usize; + PipeWireHeader { + id, + opcode, + size, + num_fd, + } + } +} + +struct PipeWireResourceFinalizer; + +impl MessageResourceFinalizer for PipeWireResourceFinalizer { + type Handler = PipeWireProtocolHandler; + + fn finalize(self, _: &mut Client) -> Result<()> { + unreachable!() + } +} + +struct CrossDomainEventFd { + event_fd: EventFd, +} + +struct ClientNodeData { + host_to_guest: Vec, + guest_to_host: Vec, +} + +impl ClientNodeData { + fn new() -> Self { + ClientNodeData { + host_to_guest: Vec::new(), + guest_to_host: Vec::new(), + } + } +} + +struct PipeWireProtocolHandler { + client_nodes: HashMap, + guest_to_host_eventfds: HashMap, + host_to_guest_eventfds: HashMap, +} + +impl PipeWireProtocolHandler { + fn create_guest_to_host_eventfd(this: &mut Client, node_id: u32) -> Result { + let efd = 
EventFd::from_flags(EfdFlags::EFD_NONBLOCK)?; + let ofd = efd.as_fd().try_clone_to_owned()?; + this.sub_poll.add(efd.as_fd(), EpollFlags::EPOLLIN); + let raw = efd.as_raw_fd() as u64; + this.protocol_handler + .guest_to_host_eventfds + .insert(raw, CrossDomainEventFd { event_fd: efd }); + this.protocol_handler + .client_nodes + .get_mut(&node_id) + .unwrap() + .guest_to_host + .push(raw); + Ok(ofd) + } + fn create_host_to_guest_eventfd( + this: &mut Client, + node_id: u32, + resource: CrossDomainResource, + ) -> Result { + let efd = EventFd::from_flags(EfdFlags::EFD_NONBLOCK)?; + let ofd = efd.as_fd().try_clone_to_owned()?; + let msg_size = mem::size_of::(); + let msg = CrossDomainReadEventfdNew { + hdr: CrossDomainHeader::new(CROSS_DOMAIN_CMD_READ_EVENTFD_NEW, msg_size as u16), + id: resource.identifier, + pad: 0, + }; + this.protocol_handler + .client_nodes + .get_mut(&node_id) + .unwrap() + .host_to_guest + .push(resource.identifier); + this.gpu_ctx.submit_cmd(&msg, msg_size, None, None)?; + this.protocol_handler + .host_to_guest_eventfds + .insert(resource.identifier, CrossDomainEventFd { event_fd: efd }); + Ok(ofd) + } +} + +impl ProtocolHandler for PipeWireProtocolHandler { + type ResourceFinalizer = PipeWireResourceFinalizer; + const CHANNEL_TYPE: u32 = CROSS_DOMAIN_CHANNEL_TYPE_PW; + + fn new() -> Self { + PipeWireProtocolHandler { + client_nodes: HashMap::new(), + guest_to_host_eventfds: HashMap::new(), + host_to_guest_eventfds: HashMap::new(), + } + } + + fn process_recv_stream( + this: &mut Client, + data: &[u8], + resources: &mut VecDeque, + ) -> Result { + if data.len() < PipeWireHeader::SIZE { + eprintln!( + "Pipewire message truncated (expected at least 16 bytes, got {})", + data.len(), + ); + return Ok(StreamRecvResult::WantMore); + } + let hdr = PipeWireHeader::from_stream(data); + let mut fds = Vec::with_capacity(hdr.num_fd); + if hdr.num_fd != 0 { + if hdr.id == 0 && hdr.opcode == PW_OPC_CORE_ADD_MEM { + let rsc = resources.pop_front().ok_or(Errno::EIO)?; + fds.push(this.virtgpu_id_to_prime(rsc)?); + } else if this.protocol_handler.client_nodes.contains_key(&hdr.id) { + if hdr.opcode == PW_OPC_CLIENT_NODE_SET_ACTIVATION { + resources.pop_front().ok_or(Errno::EIO)?; + fds.push(Self::create_guest_to_host_eventfd(this, hdr.id)?); + } else if hdr.opcode == PW_OPC_CLIENT_NODE_TRANSPORT { + let rsc1 = resources.pop_front().ok_or(Errno::EIO)?; + fds.push(Self::create_host_to_guest_eventfd(this, hdr.id, rsc1)?); + resources.pop_front().ok_or(Errno::EIO)?; + fds.push(Self::create_guest_to_host_eventfd(this, hdr.id)?); + } else { + unimplemented!() + } + } else { + unimplemented!(); + } + }; + Ok(StreamRecvResult::Processed { + consumed_bytes: hdr.size, + fds, + }) + } + + fn process_send_stream( + this: &mut Client, + data: &mut [u8], + ) -> Result> { + if data.len() < PipeWireHeader::SIZE { + eprintln!( + "Pipewire message truncated (expected at least 16 bytes, got {})", + data.len(), + ); + return Ok(StreamSendResult::WantMore); + } + let hdr = PipeWireHeader::from_stream(data); + if hdr.id == 1 && hdr.opcode == PW_OPC_CLIENT_UPDATE_PROPERTIES { + let msg = ClientUpdateProperties::new(&mut data[PipeWireHeader::SIZE..]); + for (k, _) in msg.props { + if CStr::from_bytes_with_nul(k).unwrap() == c"pipewire.access.portal.app_id" { + k.copy_from_slice(c"pipewire.access.muvm00.app_id".to_bytes_with_nul()); + } + } + } + if hdr.id == 0 && hdr.opcode == PW_OPC_CORE_CREATE_OBJECT { + let msg = CoreCreateObject::new(&data[PipeWireHeader::SIZE..]); + if msg.obj_type == 
c"PipeWire:Interface:ClientNode" { + this.protocol_handler + .client_nodes + .insert(msg.new_id, ClientNodeData::new()); + } + } + if hdr.num_fd != 0 { + unimplemented!(); + }; + Ok(StreamSendResult::Processed { + consumed_bytes: hdr.size, + resources: Vec::new(), + finalizers: Vec::new(), + }) + } + + fn process_vgpu_extra(this: &mut Client, cmd: u8) -> Result<()> { + if cmd != CROSS_DOMAIN_CMD_READ { + return Err(Errno::EINVAL.into()); + } + let recv = unsafe { + (this.gpu_ctx.channel_ring.address + as *const CrossDomainReadWrite<[u8; mem::size_of::()]>) + .as_ref() + .unwrap() + }; + if (recv.opaque_data_size as usize) < mem::size_of::() { + return Err(Errno::EINVAL.into()); + } + if let Some(efd) = this + .protocol_handler + .host_to_guest_eventfds + .get(&recv.identifier) + { + efd.event_fd.write(u64::from_ne_bytes(recv.data))?; + Ok(()) + } else { + Err(Errno::ENOENT.into()) + } + } +} + +pub fn start_pwbridge() { + let sock_path = format!("{}/pipewire-0", env::var("XDG_RUNTIME_DIR").unwrap()); + + common::bridge_loop::(&sock_path) +} diff --git a/crates/muvm/src/guest/bridge/x11.rs b/crates/muvm/src/guest/bridge/x11.rs new file mode 100644 index 0000000..bfa2761 --- /dev/null +++ b/crates/muvm/src/guest/bridge/x11.rs @@ -0,0 +1,783 @@ +use crate::guest::bridge::common; +use crate::guest::bridge::common::{ + Client, CrossDomainHeader, CrossDomainResource, GemHandleFinalizer, MessageResourceFinalizer, + ProtocolHandler, SendPacket, StreamRecvResult, StreamSendResult, PAGE_SIZE, +}; +use anyhow::Result; +use nix::errno::Errno; +use nix::fcntl::readlink; +use nix::libc::{ + c_int, c_ulonglong, pid_t, user_regs_struct, SYS_close, SYS_dup3, SYS_mmap, SYS_munmap, + SYS_openat, AT_FDCWD, MAP_ANONYMOUS, MAP_FIXED, MAP_PRIVATE, MAP_SHARED, O_CLOEXEC, O_RDWR, + PROT_READ, PROT_WRITE, +}; +use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags}; +use nix::sys::ptrace; +use nix::sys::signal::Signal; +use nix::sys::socket::getsockopt; +use nix::sys::socket::sockopt::PeerCredentials; +use nix::sys::stat::fstat; +use nix::sys::uio::{process_vm_writev, RemoteIoVec}; +use nix::sys::wait::{waitpid, WaitPidFlag, WaitStatus}; +use nix::unistd::{mkstemp, read, Pid}; +use nix::{ioctl_read, NixPath}; +use std::borrow::Cow; +use std::collections::{HashMap, VecDeque}; +use std::ffi::{c_long, c_void, CString}; +use std::fs::{read_to_string, remove_file, File}; +use std::io::{IoSlice, Write}; +use std::os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd}; +use std::process::exit; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::{Arc, OnceLock}; +use std::thread::JoinHandle; +use std::{fs, mem, ptr, thread}; + +const X11_OPCODE_QUERY_EXTENSION: u8 = 98; +const X11_OPCODE_NOP: u8 = 127; +const X11_REPLY: u8 = 1; +const X11_GENERIC_EVENT: u8 = 35; +const DRI3_OPCODE_VERSION: u8 = 0; +const DRI3_OPCODE_OPEN: u8 = 1; +const DRI3_OPCODE_PIXMAP_FROM_BUFFER: u8 = 2; +const DRI3_OPCODE_FENCE_FROM_FD: u8 = 4; +const SYNC_OPCODE_DESTROY_FENCE: u8 = 17; +const DRI3_OPCODE_PIXMAP_FROM_BUFFERS: u8 = 7; +const PRESENT_OPCODE_PRESENT_PIXMAP: u8 = 1; +pub const SHM_TEMPLATE: &str = "/dev/shm/krshm-XXXXXX"; +pub const SHM_DIR: &str = "/dev/shm/"; +const SYSCALL_INSTR: u32 = 0xd4000001; +static SYSCALL_OFFSET: OnceLock = OnceLock::new(); +const CROSS_DOMAIN_CHANNEL_TYPE_X11: u32 = 0x11; +const CROSS_DOMAIN_ID_TYPE_SHM: u32 = 5; +const CROSS_DOMAIN_CMD_FUTEX_NEW: u8 = 8; +const CROSS_DOMAIN_CMD_FUTEX_SIGNAL: u8 = 9; +const CROSS_DOMAIN_CMD_FUTEX_DESTROY: u8 = 10; + +#[repr(C)] 
+#[derive(Debug, Default)] +struct ExportedHandle { + fs_id: u64, + handle: u64, +} + +const VIRTIO_IOC_MAGIC: u8 = b'v'; +const VIRTIO_IOC_TYPE_EXPORT_FD: u8 = 1; + +ioctl_read!( + virtio_export_handle, + VIRTIO_IOC_MAGIC, + VIRTIO_IOC_TYPE_EXPORT_FD, + ExportedHandle +); + +#[repr(C)] +struct CrossDomainFutexNew { + hdr: CrossDomainHeader, + fs_id: u64, + handle: u64, + id: u32, + pad: u32, +} + +#[repr(C)] +struct CrossDomainFutexSignal { + hdr: CrossDomainHeader, + id: u32, + pad: u32, +} + +#[repr(C)] +struct CrossDomainFutexDestroy { + hdr: CrossDomainHeader, + id: u32, + pad: u32, +} + +enum X11ResourceFinalizer { + Gem(GemHandleFinalizer), + Futex(u32), +} + +impl MessageResourceFinalizer for X11ResourceFinalizer { + type Handler = X11ProtocolHandler; + fn finalize(self, client: &mut Client) -> Result<()> { + match self { + X11ResourceFinalizer::Gem(fin) => fin.finalize(client)?, + X11ResourceFinalizer::Futex(xid) => { + client.protocol_handler.futex_watchers.remove(&xid).unwrap(); + let ft_destroy_msg_size = mem::size_of::(); + let ft_msg = CrossDomainFutexDestroy { + hdr: CrossDomainHeader::new( + CROSS_DOMAIN_CMD_FUTEX_DESTROY, + ft_destroy_msg_size as u16, + ), + id: xid, + pad: 0, + }; + client + .gpu_ctx + .submit_cmd(&ft_msg, ft_destroy_msg_size, None, None)?; + }, + } + Ok(()) + } +} + +struct X11ProtocolHandler { + // futex_watchers gets dropped first + futex_watchers: HashMap, + got_first_req: bool, + seq_no: u16, + got_first_resp: bool, + dri3_ext_opcode: Option, + dri3_qe_resp_seq: Option, + sync_ext_opcode: Option, + sync_qe_resp_seq: Option, + present_ext_opcode: Option, + present_qe_resp_seq: Option, +} + +impl ProtocolHandler for X11ProtocolHandler { + type ResourceFinalizer = X11ResourceFinalizer; + const CHANNEL_TYPE: u32 = CROSS_DOMAIN_CHANNEL_TYPE_X11; + fn new() -> X11ProtocolHandler { + X11ProtocolHandler { + futex_watchers: HashMap::new(), + got_first_req: false, + seq_no: 1, + dri3_ext_opcode: None, + dri3_qe_resp_seq: None, + sync_ext_opcode: None, + sync_qe_resp_seq: None, + present_qe_resp_seq: None, + present_ext_opcode: None, + got_first_resp: false, + } + } + fn process_recv_stream( + this: &mut Client, + data: &[u8], + resources: &mut VecDeque, + ) -> Result { + if !this.protocol_handler.got_first_resp { + this.protocol_handler.got_first_resp = true; + let size = u16::from_ne_bytes(data[6..8].try_into().unwrap()) as usize * 4 + 8; + return Ok(StreamRecvResult::Processed { + consumed_bytes: size, + fds: Vec::new(), + }); + } + if data.len() < 32 { + eprintln!( + "X11 message truncated (expected at least 32 bytes, got {})", + data.len(), + ); + return Ok(StreamRecvResult::WantMore); + } + let mut fds = Vec::new(); + for ident in resources.drain(..) 
{ + fds.push(this.virtgpu_id_to_prime(ident)?); + } + let seq_no = u16::from_ne_bytes(data[2..4].try_into().unwrap()); + let is_reply = data[0] == X11_REPLY; + let is_generic = data[0] == X11_GENERIC_EVENT; + let len = if is_reply || is_generic { + u32::from_ne_bytes(data[4..8].try_into().unwrap()) as usize * 4 + } else { + 0 + } + 32; + if is_reply { + if Some(seq_no) == this.protocol_handler.dri3_qe_resp_seq { + this.protocol_handler.dri3_qe_resp_seq = None; + this.protocol_handler.dri3_ext_opcode = extract_opcode_from_qe_resp(data); + } else if Some(seq_no) == this.protocol_handler.sync_qe_resp_seq { + this.protocol_handler.sync_qe_resp_seq = None; + this.protocol_handler.sync_ext_opcode = extract_opcode_from_qe_resp(data); + } else if Some(seq_no) == this.protocol_handler.present_qe_resp_seq { + this.protocol_handler.present_qe_resp_seq = None; + this.protocol_handler.present_ext_opcode = extract_opcode_from_qe_resp(data); + } + } + Ok(StreamRecvResult::Processed { + consumed_bytes: len, + fds, + }) + } + fn process_send_stream( + this: &mut Client, + buf: &mut [u8], + ) -> Result> { + let mut resources = Vec::new(); + let mut finalizers = Vec::new(); + if !this.protocol_handler.got_first_req { + this.protocol_handler.got_first_req = true; + return Ok(StreamSendResult::Processed { + consumed_bytes: buf.len(), + resources, + finalizers, + }); + } + if buf.len() < 4 { + eprintln!( + "X11 message truncated (expected at least 4 bytes, got {})", + buf.len(), + ); + return Ok(StreamSendResult::WantMore); + } + let mut req_len = u16::from_ne_bytes(buf[2..4].try_into().unwrap()) as usize * 4; + if req_len == 0 { + if buf.len() < 8 { + eprintln!( + "X11 message truncated (expected at least 8 bytes, got {})", + buf.len(), + ); + return Ok(StreamSendResult::WantMore); + } + req_len = u32::from_ne_bytes(buf[4..8].try_into().unwrap()) as usize * 4; + } + if buf[0] == X11_OPCODE_QUERY_EXTENSION { + let namelen = u16::from_ne_bytes(buf[4..6].try_into().unwrap()) as usize; + let name = String::from_utf8_lossy(&buf[8..(8 + namelen)]); + if name == "DRI3" { + this.protocol_handler.dri3_qe_resp_seq = Some(this.protocol_handler.seq_no); + } else if name == "SYNC" { + this.protocol_handler.sync_qe_resp_seq = Some(this.protocol_handler.seq_no) + } else if name == "Present" { + this.protocol_handler.present_qe_resp_seq = Some(this.protocol_handler.seq_no); + } + } else if Some(buf[0]) == this.protocol_handler.dri3_ext_opcode { + if buf[1] == DRI3_OPCODE_VERSION { + buf[8] = buf[8].min(3); + } else if buf[1] == DRI3_OPCODE_OPEN { + buf[0] = X11_OPCODE_NOP; + let mut reply = vec![ + 1, + 1, + (this.protocol_handler.seq_no & 0xff) as u8, + (this.protocol_handler.seq_no >> 8) as u8, + ]; + reply.extend_from_slice(&[0u8; 28]); + let render = File::options() + .read(true) + .write(true) + .open("/dev/dri/renderD128")?; + this.send_queue.push_back(SendPacket { + data: reply, + fds: vec![render.into()], + }); + } else if buf[1] == DRI3_OPCODE_PIXMAP_FROM_BUFFER { + let fd = this.request_fds.remove(0); + let (res, finalizer) = this.vgpu_id_from_prime(fd)?; + resources.push(res); + finalizers.push(X11ResourceFinalizer::Gem(finalizer)); + } else if buf[1] == DRI3_OPCODE_FENCE_FROM_FD { + let xid = u32::from_ne_bytes(buf[8..12].try_into().unwrap()); + let fd = this.request_fds.remove(0); + let filename = readlink(format!("/proc/self/fd/{}", fd.as_raw_fd()).as_str())?; + let filename = filename.to_string_lossy(); + let creds = getsockopt(&this.socket.as_fd(), PeerCredentials)?; + let res = 
Self::create_cross_vm_futex(this, fd, xid, creds.pid(), filename)?; + resources.push(res); + } else if buf[1] == DRI3_OPCODE_PIXMAP_FROM_BUFFERS { + let num_bufs = buf[12] as usize; + for fd in this.request_fds.drain(..num_bufs).collect::>() { + let (res, finalizer) = this.vgpu_id_from_prime(fd)?; + resources.push(res); + finalizers.push(X11ResourceFinalizer::Gem(finalizer)); + } + } + } else if Some(buf[0]) == this.protocol_handler.sync_ext_opcode { + if buf[1] == SYNC_OPCODE_DESTROY_FENCE { + let xid = u32::from_ne_bytes(buf[4..8].try_into().unwrap()); + finalizers.push(X11ResourceFinalizer::Futex(xid)); + } + } else if Some(buf[0]) == this.protocol_handler.present_ext_opcode + && buf[1] == PRESENT_OPCODE_PRESENT_PIXMAP + { + /* TODO: Implement GPU fence passing here when we have it. */ + } + this.protocol_handler.seq_no = this.protocol_handler.seq_no.wrapping_add(1); + Ok(StreamSendResult::Processed { + finalizers, + resources, + consumed_bytes: req_len, + }) + } + fn process_vgpu_extra(this: &mut Client, cmd: u8) -> Result<()> { + if cmd != CROSS_DOMAIN_CMD_FUTEX_SIGNAL { + return Err(Errno::EINVAL.into()); + } + let recv = unsafe { + (this.gpu_ctx.channel_ring.address as *const CrossDomainFutexSignal) + .as_ref() + .unwrap() + }; + this.protocol_handler.process_futex_signal(recv) + } +} + +impl X11ProtocolHandler { + fn ptrace_all_threads(pid: Pid) -> Result> { + let mut tids = Vec::new(); + for entry in fs::read_dir(format!("/proc/{pid}/task"))? { + let entry = match entry { + Err(_) => continue, + Ok(a) => a, + }; + let tid = Pid::from_raw( + entry + .file_name() + .into_string() + .or(Err(Errno::EIO))? + .parse()?, + ); + if let Err(e) = ptrace::attach(tid) { + // This could be a race (thread exited), so keep going + // unless this is the top-level PID + if tid == pid { + return Err(e.into()); + } + eprintln!("ptrace::attach({pid}, ...) failed (continuing)"); + continue; + } + let ptid = PtracedPid(tid); + wait_for_stop(ptid.pid())?; + tids.push(ptid); + } + Ok(tids) + } + fn replace_futex_storage( + my_fd: RawFd, + pid: Pid, + shmem_path: &str, + shmem_file: &mut File, + ) -> Result<()> { + let traced = Self::ptrace_all_threads(pid)?; + + let mut data = [0; 4]; + read(my_fd, &mut data)?; + shmem_file.write_all(&data)?; + + // TODO: match st_dev too to avoid false positives + let my_ino = fstat(my_fd)?.st_ino; + let mut fds_to_replace = Vec::new(); + for entry in fs::read_dir(format!("/proc/{pid}/fd"))? 
{ + let entry = entry?; + if let Ok(file) = File::options().open(entry.path()) { + if fstat(file.as_raw_fd())?.st_ino == my_ino { + fds_to_replace.push(entry.file_name().to_string_lossy().parse::()?); + } + } + } + let mut pages_to_replace = Vec::new(); + for line in read_to_string(format!("/proc/{pid}/maps"))?.lines() { + let f: Vec<&str> = line.split_whitespace().collect(); + let ino: u64 = f[4].parse()?; + if ino == my_ino { + let addr = usize::from_str_radix(f[0].split('-').next().unwrap(), 16)?; + pages_to_replace.push(addr); + } + } + RemoteCaller::with(pid, |caller| { + let scratch_page = caller.mmap( + 0, + PAGE_SIZE, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + 0, + 0, + )?; + let path_cstr = CString::new(shmem_path).unwrap(); + process_vm_writev( + pid, + &[IoSlice::new(path_cstr.as_bytes_with_nul())], + &[RemoteIoVec { + base: scratch_page, + len: path_cstr.len(), + }], + )?; + let remote_shm = caller.open(scratch_page, O_CLOEXEC | O_RDWR, 0o600)?; + for fd in fds_to_replace { + caller.dup2(remote_shm, fd)?; + } + for page in pages_to_replace { + caller.mmap( + page, + PAGE_SIZE, + PROT_READ | PROT_WRITE, + MAP_SHARED | MAP_FIXED, + remote_shm, + 0, + )?; + } + caller.munmap(scratch_page, PAGE_SIZE)?; + caller.close(remote_shm)?; + Ok(()) + })?; + // This detaches all the traced threads + mem::drop(traced); + Ok(()) + } + fn create_cross_vm_futex( + this: &mut Client, + memfd: OwnedFd, + xid: u32, + pid: pid_t, + filename: Cow<'_, str>, + ) -> Result { + // Allow everything in /dev/shm (including paths with trailing '(deleted)') + let shmem_file = if filename.starts_with(SHM_DIR) { + File::from(memfd) + } else if cfg!(not(target_arch = "aarch64")) { + return Err(Errno::EOPNOTSUPP.into()); + } else { + let (fd, shmem_path) = mkstemp(SHM_TEMPLATE)?; + let mut shmem_file = unsafe { File::from_raw_fd(fd) }; + let ret = Self::replace_futex_storage( + memfd.as_raw_fd(), + Pid::from_raw(pid), + shmem_path.as_os_str().to_str().unwrap(), + &mut shmem_file, + ); + remove_file(&shmem_path)?; + ret?; + shmem_file + }; + + let mut handle: ExportedHandle = Default::default(); + unsafe { virtio_export_handle(shmem_file.as_raw_fd(), &mut handle) }?; + + let addr = FutexPtr(unsafe { + mmap( + None, + 4.try_into().unwrap(), + ProtFlags::PROT_WRITE | ProtFlags::PROT_READ, + MapFlags::MAP_SHARED, + shmem_file, + 0, + )? 
+ .as_ptr() + }); + let initial_value = + unsafe { AtomicU32::from_ptr(addr.0 as *mut u32) }.load(Ordering::Relaxed); + + let ft_new_msg_size = mem::size_of::(); + let ft_msg = CrossDomainFutexNew { + hdr: CrossDomainHeader::new(CROSS_DOMAIN_CMD_FUTEX_NEW, ft_new_msg_size as u16), + id: xid, + fs_id: handle.fs_id, + handle: handle.handle, + pad: 0, + }; + this.gpu_ctx + .submit_cmd(&ft_msg, ft_new_msg_size, None, None)?; + let fd = this.gpu_ctx.fd.as_raw_fd() as c_int; + this.protocol_handler + .futex_watchers + .insert(xid, FutexWatcherThread::new(fd, xid, addr, initial_value)); + Ok(CrossDomainResource { + identifier: xid, + identifier_type: CROSS_DOMAIN_ID_TYPE_SHM, + identifier_size: 0, + }) + } + fn process_futex_signal(&mut self, recv: &CrossDomainFutexSignal) -> Result<()> { + let watcher = match self.futex_watchers.get(&recv.id) { + Some(a) => a, + None => { + eprintln!("Unknown futex id {}", recv.id); + return Ok(()); + }, + }; + + watcher.signal(); + + Ok(()) + } +} + +#[derive(Clone)] +struct FutexPtr(*mut c_void); + +unsafe impl Send for FutexPtr {} + +impl Drop for FutexPtr { + fn drop(&mut self) { + unsafe { + munmap(NonNull::new_unchecked(self.0), 4).unwrap(); + } + } +} + +fn extract_opcode_from_qe_resp(data: &[u8]) -> Option { + if data[8] != 0 { + Some(data[9]) + } else { + None + } +} + +struct FutexWatcherThread { + join_handle: Option>, + shutdown: Arc, + futex: FutexPtr, +} + +unsafe fn wake_futex(futex: *mut c_void, val3: u32) { + let op = nix::libc::FUTEX_WAKE_BITSET; + let val = c_int::MAX; + let timeout = ptr::null::<()>(); + let uaddr2 = ptr::null::<()>(); + unsafe { + nix::libc::syscall(nix::libc::SYS_futex, futex, op, val, timeout, uaddr2, val3); + } +} + +impl FutexWatcherThread { + fn new(fd: c_int, xid: u32, futex: FutexPtr, initial_value: u32) -> FutexWatcherThread { + let shutdown = Arc::new(AtomicBool::new(false)); + let shutdown2 = shutdown.clone(); + let futex2 = futex.clone(); + let handle = thread::spawn(move || { + let uaddr = futex2; + let op = nix::libc::FUTEX_WAIT_BITSET; + let timeout = ptr::null::<()>(); + let uaddr2 = ptr::null::<()>(); + let val3 = 1u32; + let mut val = initial_value; + let atomic_val = unsafe { AtomicU32::from_ptr(uaddr.0 as *mut u32) }; + loop { + if shutdown2.load(Ordering::SeqCst) { + break; + } + unsafe { + nix::libc::syscall( + nix::libc::SYS_futex, + uaddr.0, + op, + val, + timeout, + uaddr2, + val3, + ); + } + val = atomic_val.load(Ordering::SeqCst); + let ft_signal_msg_size = mem::size_of::(); + let ft_signal_cmd = CrossDomainFutexSignal { + hdr: CrossDomainHeader::new( + CROSS_DOMAIN_CMD_FUTEX_SIGNAL, + ft_signal_msg_size as u16, + ), + id: xid, + pad: 0, + }; + common::submit_cmd_raw(fd, &ft_signal_cmd, ft_signal_msg_size, None, None).unwrap(); + } + }); + FutexWatcherThread { + futex, + join_handle: Some(handle), + shutdown, + } + } + + fn signal(&self) { + unsafe { + wake_futex(self.futex.0, !1); + } + } +} + +impl Drop for FutexWatcherThread { + fn drop(&mut self) { + self.shutdown.store(true, Ordering::Release); + let atomic_val = unsafe { AtomicU32::from_ptr(self.futex.0 as *mut u32) }; + let v = atomic_val.load(Ordering::SeqCst); + atomic_val.store(!v, Ordering::SeqCst); + unsafe { + wake_futex(self.futex.0, !0); + } + self.join_handle.take().unwrap().join().unwrap(); + } +} + +#[allow(dead_code)] +struct RemoteCaller { + pid: Pid, + regs: user_regs_struct, +} + +impl RemoteCaller { + // This is arch-specific, so gate it off of x86_64 builds done for CI purposes + #[cfg(target_arch = "aarch64")] + fn with(pid: 
Pid, f: F) -> Result + where + F: FnOnce(&RemoteCaller) -> Result, + { + let old_regs = ptrace::getregs(pid)?; + + // Find the vDSO and the address of a syscall instruction within it + let (vdso_start, _) = find_vdso(Some(pid))?; + let syscall_addr = vdso_start + SYSCALL_OFFSET.get().unwrap(); + + let mut regs = old_regs; + regs.pc = syscall_addr as u64; + ptrace::setregs(pid, regs)?; + let res = f(&RemoteCaller { regs, pid })?; + ptrace::setregs(pid, old_regs)?; + Ok(res) + } + fn dup2(&self, oldfd: i32, newfd: i32) -> Result { + self.syscall(SYS_dup3, [oldfd as u64, newfd as u64, 0, 0, 0, 0]) + .map(|x| x as i32) + } + fn close(&self, fd: i32) -> Result { + self.syscall(SYS_close, [fd as u64, 0, 0, 0, 0, 0]) + .map(|x| x as i32) + } + fn mmap( + &self, + addr: usize, + length: usize, + prot: i32, + flags: i32, + fd: i32, + offset: usize, + ) -> Result { + self.syscall( + SYS_mmap, + [ + addr as u64, + length as u64, + prot as u64, + flags as u64, + fd as u64, + offset as u64, + ], + ) + .map(|x| x as usize) + } + fn munmap(&self, addr: usize, length: usize) -> Result { + self.syscall(SYS_munmap, [addr as u64, length as u64, 0, 0, 0, 0]) + .map(|x| x as i32) + } + fn open(&self, path: usize, flags: i32, mode: i32) -> Result { + self.syscall( + SYS_openat, + [ + AT_FDCWD as u64, + path as u64, + flags as u64, + mode as u64, + 0, + 0, + ], + ) + .map(|x| x as i32) + } + + // This is arch-specific, so gate it off of x86_64 builds done for CI purposes + #[cfg(target_arch = "aarch64")] + fn syscall(&self, syscall_no: c_long, args: [c_ulonglong; 6]) -> Result { + let mut regs = self.regs; + regs.regs[..6].copy_from_slice(&args); + regs.regs[8] = syscall_no as c_ulonglong; + ptrace::setregs(self.pid, regs)?; + ptrace::step(self.pid, None)?; + let evt = waitpid(self.pid, Some(WaitPidFlag::__WALL))?; + if !matches!(evt, WaitStatus::Stopped(_, _)) { + unimplemented!(); + } + regs = ptrace::getregs(self.pid)?; + Ok(regs.regs[0]) + } + + #[cfg(not(target_arch = "aarch64"))] + fn with(_pid: Pid, _f: F) -> Result + where + F: FnOnce(&RemoteCaller) -> Result, + { + Err(Errno::EOPNOTSUPP.into()) + } + #[cfg(not(target_arch = "aarch64"))] + fn syscall(&self, _syscall_no: c_long, _args: [c_ulonglong; 6]) -> Result { + Err(Errno::EOPNOTSUPP.into()) + } +} + +fn wait_for_stop(pid: Pid) -> Result<()> { + loop { + let event = waitpid(pid, Some(WaitPidFlag::__WALL))?; + match event { + WaitStatus::Stopped(_, sig) => { + if sig == Signal::SIGSTOP { + return Ok(()); + } else { + ptrace::cont(pid, sig)?; + } + }, + _ => unimplemented!(), + } + } +} + +struct PtracedPid(Pid); + +impl PtracedPid { + fn pid(&self) -> Pid { + self.0 + } +} + +impl Drop for PtracedPid { + fn drop(&mut self) { + if ptrace::detach(self.0, None).is_err() { + eprintln!("Failed to ptrace::detach({}) (continuing)", self.0); + } + } +} + +fn find_vdso(pid: Option) -> Result<(usize, usize), Errno> { + let path = format!( + "/proc/{}/maps", + pid.map(|a| a.to_string()).unwrap_or("self".into()) + ); + + for line in read_to_string(path).unwrap().lines() { + if line.ends_with("[vdso]") { + let a = line.find('-').ok_or(Errno::EINVAL)?; + let b = line.find(' ').ok_or(Errno::EINVAL)?; + let start = usize::from_str_radix(&line[..a], 16).or(Err(Errno::EINVAL))?; + let end = usize::from_str_radix(&line[a + 1..b], 16).or(Err(Errno::EINVAL))?; + + return Ok((start, end)); + } + } + + Err(Errno::EINVAL) +} + +pub fn start_x11bridge(display: u32) { + let sock_path = format!("/tmp/.X11-unix/X{display}"); + + // Look for a syscall instruction in the vDSO. 
We assume all processes map + // the same vDSO (which should be true if they are running under the same + // kernel!) + let (vdso_start, vdso_end) = find_vdso(None).unwrap(); + for off in (0..(vdso_end - vdso_start)).step_by(4) { + let addr = vdso_start + off; + let val = unsafe { std::ptr::read(addr as *const u32) }; + if val == SYSCALL_INSTR { + SYSCALL_OFFSET.set(off).unwrap(); + break; + } + } + if SYSCALL_OFFSET.get().is_none() { + eprintln!("Failed to find syscall instruction in vDSO"); + exit(1); + } + + common::bridge_loop::(&sock_path) +} diff --git a/crates/muvm/src/guest/mod.rs b/crates/muvm/src/guest/mod.rs index 5cd32d6..904ceae 100644 --- a/crates/muvm/src/guest/mod.rs +++ b/crates/muvm/src/guest/mod.rs @@ -1,3 +1,4 @@ +pub mod bridge; pub mod fex; pub mod hidpipe; pub mod mount; @@ -7,4 +8,3 @@ pub mod server_worker; pub mod socket; pub mod user; pub mod x11; -pub mod x11bridge; diff --git a/crates/muvm/src/guest/x11.rs b/crates/muvm/src/guest/x11.rs index 44041c7..bb74d45 100644 --- a/crates/muvm/src/guest/x11.rs +++ b/crates/muvm/src/guest/x11.rs @@ -7,7 +7,7 @@ use std::{env, thread}; use anyhow::{anyhow, Context, Result}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use crate::guest::x11bridge::start_x11bridge; +use crate::guest::bridge::x11::start_x11bridge; pub fn setup_x11_forwarding
(run_path: P, host_display: &str) -> Result<()> where diff --git a/crates/muvm/src/guest/x11bridge.rs b/crates/muvm/src/guest/x11bridge.rs deleted file mode 100644 index f3a8ffe..0000000 --- a/crates/muvm/src/guest/x11bridge.rs +++ /dev/null @@ -1,1599 +0,0 @@ -use anyhow::Result; -use nix::errno::Errno; -use nix::fcntl::readlink; -use nix::libc::{ - c_int, c_ulonglong, c_void, off_t, pid_t, user_regs_struct, SYS_close, SYS_dup3, SYS_mmap, - SYS_munmap, SYS_openat, AT_FDCWD, MAP_ANONYMOUS, MAP_FIXED, MAP_PRIVATE, MAP_SHARED, O_CLOEXEC, - O_RDWR, PROT_READ, PROT_WRITE, -}; -use nix::sys::epoll::{Epoll, EpollCreateFlags, EpollEvent, EpollFlags, EpollTimeout}; -use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags}; -use nix::sys::ptrace; -use nix::sys::signal::Signal; -use nix::sys::socket::sockopt::PeerCredentials; -use nix::sys::socket::{ - getsockopt, recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, RecvMsg, -}; -use nix::sys::stat::fstat; -use nix::sys::uio::{process_vm_writev, RemoteIoVec}; -use nix::sys::wait::{waitpid, WaitPidFlag, WaitStatus}; -use nix::unistd::{mkstemp, read, Pid}; -use nix::{cmsg_space, ioctl_read, ioctl_readwrite, ioctl_write_ptr, NixPath}; -use std::borrow::Cow; -use std::cell::RefCell; -use std::collections::{HashMap, VecDeque}; -use std::ffi::{c_long, CString}; -use std::fs::{read_to_string, remove_file, File}; -use std::io::{IoSlice, IoSliceMut, Read, Write}; -use std::net::{TcpListener, TcpStream}; -use std::num::NonZeroUsize; -use std::os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd}; -use std::os::unix::net::{UnixListener, UnixStream}; -use std::process::exit; -use std::ptr::NonNull; -use std::rc::Rc; -use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; -use std::sync::{Arc, OnceLock}; -use std::thread::JoinHandle; -use std::{env, fs, mem, ptr, slice, thread}; - -const PAGE_SIZE: usize = 4096; - -const VIRTGPU_CONTEXT_PARAM_CAPSET_ID: u64 = 0x0001; -const VIRTGPU_CONTEXT_PARAM_NUM_RINGS: u64 = 0x0002; -const VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK: u64 = 0x0003; -const CAPSET_CROSS_DOMAIN: u64 = 5; -const CROSS_DOMAIN_CHANNEL_RING: u32 = 1; -const VIRTGPU_BLOB_MEM_GUEST: u32 = 0x0001; -const VIRTGPU_BLOB_MEM_HOST3D: u32 = 0x0002; -const VIRTGPU_BLOB_FLAG_USE_MAPPABLE: u32 = 0x0001; -const VIRTGPU_BLOB_FLAG_USE_SHAREABLE: u32 = 0x0002; -const VIRTGPU_EVENT_FENCE_SIGNALED: u32 = 0x90000000; -const CROSS_DOMAIN_ID_TYPE_VIRTGPU_BLOB: u32 = 1; -const CROSS_DOMAIN_ID_TYPE_SHM: u32 = 5; - -const X11_OPCODE_CREATE_PIXMAP: u8 = 53; -const X11_OPCODE_FREE_PIXMAP: u8 = 54; -const X11_OPCODE_QUERY_EXTENSION: u8 = 98; -const X11_OPCODE_NOP: u8 = 127; -const X11_REPLY: u8 = 1; -const X11_GENERIC_EVENT: u8 = 35; -const DRI3_OPCODE_VERSION: u8 = 0; -const DRI3_OPCODE_OPEN: u8 = 1; -const DRI3_OPCODE_PIXMAP_FROM_BUFFER: u8 = 2; -const DRI3_OPCODE_FENCE_FROM_FD: u8 = 4; -const SYNC_OPCODE_DESTROY_FENCE: u8 = 17; -const DRI3_OPCODE_PIXMAP_FROM_BUFFERS: u8 = 7; -const PRESENT_OPCODE_PRESENT_PIXMAP: u8 = 1; - -pub const SHM_TEMPLATE: &str = "/dev/shm/krshm-XXXXXX"; -pub const SHM_DIR: &str = "/dev/shm/"; - -#[repr(C)] -#[derive(Debug, Default)] -struct ExportedHandle { - fs_id: u64, - handle: u64, -} - -const SYSCALL_INSTR: u32 = 0xd4000001; -static SYSCALL_OFFSET: OnceLock = OnceLock::new(); - -const VIRTIO_IOC_MAGIC: u8 = b'v'; -const VIRTIO_IOC_TYPE_EXPORT_FD: u8 = 1; - -ioctl_read!( - virtio_export_handle, - VIRTIO_IOC_MAGIC, - VIRTIO_IOC_TYPE_EXPORT_FD, - ExportedHandle -); - -#[repr(C)] -#[derive(Default)] -struct DrmVirtgpuContextInit { - 
num_params: u32, - pad: u32, - ctx_set_params: u64, -} - -#[repr(C)] -#[derive(Default)] -struct DrmVirtgpuContextSetParam { - param: u64, - value: u64, -} - -#[rustfmt::skip] -ioctl_readwrite!(drm_virtgpu_context_init, 'd', 0x40 + 0xb, DrmVirtgpuContextInit); - -#[repr(C)] -#[derive(Default)] -struct DrmVirtgpuResourceCreateBlob { - blob_mem: u32, - blob_flags: u32, - bo_handle: u32, - res_handle: u32, - size: u64, - pad: u32, - cmd_size: u32, - cmd: u64, - blob_id: u64, -} - -#[rustfmt::skip] -ioctl_readwrite!(drm_virtgpu_resource_create_blob, 'd', 0x40 + 0xa, DrmVirtgpuResourceCreateBlob); - -#[repr(C)] -#[derive(Default)] -struct DrmVirtgpuMap { - offset: u64, - handle: u32, - pad: u32, -} - -#[rustfmt::skip] -ioctl_readwrite!(drm_virtgpu_map, 'd', 0x40 + 0x1, DrmVirtgpuMap); - -#[repr(C)] -#[derive(Default)] -struct DrmGemClose { - handle: u32, - pad: u32, -} - -impl DrmGemClose { - fn new(handle: u32) -> DrmGemClose { - DrmGemClose { - handle, - ..DrmGemClose::default() - } - } -} - -#[rustfmt::skip] -ioctl_write_ptr!(drm_gem_close, 'd', 0x9, DrmGemClose); - -#[repr(C)] -#[derive(Default)] -struct DrmPrimeHandle { - handle: u32, - flags: u32, - fd: i32, -} - -#[rustfmt::skip] -ioctl_readwrite!(drm_prime_handle_to_fd, 'd', 0x2d, DrmPrimeHandle); -#[rustfmt::skip] -ioctl_readwrite!(drm_prime_fd_to_handle, 'd', 0x2e, DrmPrimeHandle); - -#[repr(C)] -#[derive(Default)] -struct DrmEvent { - ty: u32, - length: u32, -} - -const VIRTGPU_EXECBUF_RING_IDX: u32 = 0x04; -#[repr(C)] -#[derive(Default)] -struct DrmVirtgpuExecbuffer { - flags: u32, - size: u32, - command: u64, - bo_handles: u64, - num_bo_handles: u32, - fence_fd: i32, - ring_idx: u32, - pad: u32, -} - -#[rustfmt::skip] -ioctl_readwrite!(drm_virtgpu_execbuffer, 'd', 0x40 + 0x2, DrmVirtgpuExecbuffer); - -#[repr(C)] -#[derive(Default)] -struct DrmVirtgpuResourceInfo { - bo_handle: u32, - res_handle: u32, - size: u32, - blob_mem: u32, -} - -#[rustfmt::skip] -ioctl_readwrite!(drm_virtgpu_resource_info, 'd', 0x40 + 0x5, DrmVirtgpuResourceInfo); - -#[repr(C)] -#[derive(Default)] -struct CrossDomainHeader { - cmd: u8, - fence_ctx_idx: u8, - cmd_size: u16, - pad: u32, -} - -impl CrossDomainHeader { - fn new(cmd: u8, cmd_size: u16) -> CrossDomainHeader { - CrossDomainHeader { - cmd, - cmd_size, - ..CrossDomainHeader::default() - } - } -} - -const CROSS_DOMAIN_CMD_INIT: u8 = 1; -const CROSS_DOMAIN_CMD_POLL: u8 = 3; -const CROSS_DOMAIN_CHANNEL_TYPE_X11: u32 = 0x11; -#[repr(C)] -#[derive(Default)] -struct CrossDomainInit { - hdr: CrossDomainHeader, - query_ring_id: u32, - channel_ring_id: u32, - channel_type: u32, -} - -#[repr(C)] -#[derive(Default)] -struct CrossDomainPoll { - hdr: CrossDomainHeader, - pad: u64, -} - -impl CrossDomainPoll { - fn new() -> CrossDomainPoll { - CrossDomainPoll { - hdr: CrossDomainHeader::new( - CROSS_DOMAIN_CMD_POLL, - mem::size_of::() as u16, - ), - ..CrossDomainPoll::default() - } - } -} - -#[repr(C)] -pub struct CrossDomainFutexNew { - hdr: CrossDomainHeader, - fs_id: u64, - handle: u64, - id: u32, - pad: u32, -} - -#[repr(C)] -struct CrossDomainFutexSignal { - hdr: CrossDomainHeader, - id: u32, - pad: u32, -} - -#[repr(C)] -struct CrossDomainFutexDestroy { - hdr: CrossDomainHeader, - id: u32, - pad: u32, -} - -const CROSS_DOMAIN_MAX_IDENTIFIERS: usize = 4; -const CROSS_DOMAIN_CMD_SEND: u8 = 4; -const CROSS_DOMAIN_CMD_RECEIVE: u8 = 5; -const CROSS_DOMAIN_CMD_FUTEX_NEW: u8 = 8; -const CROSS_DOMAIN_CMD_FUTEX_SIGNAL: u8 = 9; -pub const CROSS_DOMAIN_CMD_FUTEX_DESTROY: u8 = 10; - -#[repr(C)] -struct 
CrossDomainSendReceive { - hdr: CrossDomainHeader, - num_identifiers: u32, - opaque_data_size: u32, - identifiers: [u32; CROSS_DOMAIN_MAX_IDENTIFIERS], - identifier_types: [u32; CROSS_DOMAIN_MAX_IDENTIFIERS], - identifier_sizes: [u32; CROSS_DOMAIN_MAX_IDENTIFIERS], - data: T, -} - -const CROSS_DOMAIN_SR_TAIL_SIZE: usize = PAGE_SIZE - mem::size_of::>(); - -struct GpuRing { - handle: u32, - res_id: u32, - address: *mut c_void, - fd: OwnedFd, -} - -impl GpuRing { - fn new(fd: &OwnedFd) -> Result { - let fd = fd.try_clone().unwrap(); - let mut create_blob = DrmVirtgpuResourceCreateBlob { - size: PAGE_SIZE as u64, - blob_mem: VIRTGPU_BLOB_MEM_GUEST, - blob_flags: VIRTGPU_BLOB_FLAG_USE_MAPPABLE, - ..DrmVirtgpuResourceCreateBlob::default() - }; - unsafe { - drm_virtgpu_resource_create_blob(fd.as_raw_fd() as c_int, &mut create_blob)?; - } - let mut map = DrmVirtgpuMap { - handle: create_blob.bo_handle, - ..DrmVirtgpuMap::default() - }; - unsafe { - drm_virtgpu_map(fd.as_raw_fd() as c_int, &mut map)?; - } - let ptr = unsafe { - mmap( - None, - NonZeroUsize::new(PAGE_SIZE).unwrap(), - ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, - MapFlags::MAP_SHARED, - &fd, - map.offset as off_t, - )? - .as_ptr() - }; - Ok(GpuRing { - fd, - handle: create_blob.bo_handle, - res_id: create_blob.res_handle, - address: ptr, - }) - } -} - -impl Drop for GpuRing { - fn drop(&mut self) { - unsafe { - munmap(NonNull::new(self.address).unwrap(), PAGE_SIZE).unwrap(); - let close = DrmGemClose::new(self.handle); - drm_gem_close(self.fd.as_raw_fd() as c_int, &close).unwrap(); - } - } -} - -struct Context { - fd: OwnedFd, - channel_ring: GpuRing, - query_ring: GpuRing, -} - -impl Context { - fn new() -> Result { - let mut params = [ - DrmVirtgpuContextSetParam { - param: VIRTGPU_CONTEXT_PARAM_CAPSET_ID, - value: CAPSET_CROSS_DOMAIN, - }, - DrmVirtgpuContextSetParam { - param: VIRTGPU_CONTEXT_PARAM_NUM_RINGS, - value: 2, - }, - DrmVirtgpuContextSetParam { - param: VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK, - value: 1 << CROSS_DOMAIN_CHANNEL_RING, - }, - ]; - let mut init = DrmVirtgpuContextInit { - num_params: 3, - pad: 0, - ctx_set_params: params.as_mut_ptr() as u64, - }; - let fd: OwnedFd = File::options() - .write(true) - .read(true) - .open("/dev/dri/renderD128")? 
- .into(); - unsafe { - drm_virtgpu_context_init(fd.as_raw_fd() as c_int, &mut init)?; - } - - let query_ring = GpuRing::new(&fd)?; - let channel_ring = GpuRing::new(&fd)?; - let this = Context { - fd, - query_ring, - channel_ring, - }; - let init_cmd = CrossDomainInit { - hdr: CrossDomainHeader::new( - CROSS_DOMAIN_CMD_INIT, - mem::size_of::() as u16, - ), - query_ring_id: this.query_ring.res_id, - channel_ring_id: this.channel_ring.res_id, - channel_type: CROSS_DOMAIN_CHANNEL_TYPE_X11, - }; - this.submit_cmd(&init_cmd, mem::size_of::(), None, None)?; - this.poll_cmd()?; - Ok(this) - } - fn submit_cmd( - &self, - cmd: &T, - cmd_size: usize, - ring_idx: Option, - ring_handle: Option, - ) -> Result<()> { - submit_cmd_raw( - self.fd.as_raw_fd() as c_int, - cmd, - cmd_size, - ring_idx, - ring_handle, - ) - } - fn poll_cmd(&self) -> Result<()> { - let cmd = CrossDomainPoll::new(); - self.submit_cmd( - &cmd, - mem::size_of::(), - Some(CROSS_DOMAIN_CHANNEL_RING), - None, - ) - } -} - -fn submit_cmd_raw( - fd: c_int, - cmd: &T, - cmd_size: usize, - ring_idx: Option, - ring_handle: Option, -) -> Result<()> { - let cmd_buf = cmd as *const T as *const u8; - let mut exec = DrmVirtgpuExecbuffer { - command: cmd_buf as u64, - size: cmd_size as u32, - ..DrmVirtgpuExecbuffer::default() - }; - if let Some(ring_idx) = ring_idx { - exec.ring_idx = ring_idx; - exec.flags = VIRTGPU_EXECBUF_RING_IDX; - } - let ring_handle = &ring_handle; - if let Some(ring_handle) = ring_handle { - exec.bo_handles = ring_handle as *const u32 as u64; - exec.num_bo_handles = 1; - } - unsafe { - drm_virtgpu_execbuffer(fd, &mut exec)?; - } - if ring_handle.is_some() { - unimplemented!(); - } - Ok(()) -} - -struct DebugLoopInner { - ls_remote: TcpStream, - ls_local: TcpStream, -} - -struct DebugLoop(Option); - -impl DebugLoop { - fn new() -> DebugLoop { - if !env::var("X11VG_DEBUG") - .map(|x| x == "1") - .unwrap_or_default() - { - return DebugLoop(None); - } - let ls_remote_l = TcpListener::bind(("0.0.0.0", 6001)).unwrap(); - let ls_local_jh = thread::spawn(|| TcpStream::connect(("0.0.0.0", 6001)).unwrap()); - let ls_remote = ls_remote_l.accept().unwrap().0; - let ls_local = ls_local_jh.join().unwrap(); - DebugLoop(Some(DebugLoopInner { - ls_remote, - ls_local, - })) - } - fn loop_remote(&mut self, data: &[u8]) { - if let Some(this) = &mut self.0 { - this.ls_remote.write_all(data).unwrap(); - let mut trash = vec![0; data.len()]; - this.ls_local.read_exact(&mut trash).unwrap(); - } - } - fn loop_local(&mut self, data: &[u8]) { - if let Some(this) = &mut self.0 { - this.ls_local.write_all(data).unwrap(); - let mut trash = vec![0; data.len()]; - this.ls_remote.read_exact(&mut trash).unwrap(); - } - } -} - -struct SendPacket { - data: Vec, - fds: Vec, -} - -struct Client { - // futex_watchers must be dropped before gpu_ctx, so it goes first - futex_watchers: HashMap, - gpu_ctx: Context, - socket: UnixStream, - got_first_req: bool, - got_first_resp: bool, - dri3_ext_opcode: Option, - dri3_qe_resp_seq: Option, - sync_ext_opcode: Option, - sync_qe_resp_seq: Option, - present_ext_opcode: Option, - present_qe_resp_seq: Option, - seq_no: u16, - reply_tail: usize, - reply_head: Vec, - request_tail: usize, - request_head: Vec, - request_fds: Vec, - debug_loop: DebugLoop, - buffers_for_pixmap: HashMap>, - send_queue: VecDeque, -} - -#[derive(Clone)] -struct FutexPtr(*mut c_void); - -unsafe impl Send for FutexPtr {} - -impl Drop for FutexPtr { - fn drop(&mut self) { - unsafe { - munmap(NonNull::new_unchecked(self.0), 4).unwrap(); - } - } -} - 
-fn extract_opcode_from_qe_resp(data: &[u8], ptr: usize) -> Option { - if data[ptr + 8] != 0 { - Some(data[ptr + 9]) - } else { - None - } -} - -struct FutexWatcherThread { - join_handle: Option>, - shutdown: Arc, - futex: FutexPtr, -} - -unsafe fn wake_futex(futex: *mut c_void, val3: u32) { - let op = nix::libc::FUTEX_WAKE_BITSET; - let val = c_int::MAX; - let timeout = ptr::null::<()>(); - let uaddr2 = ptr::null::<()>(); - unsafe { - nix::libc::syscall(nix::libc::SYS_futex, futex, op, val, timeout, uaddr2, val3); - } -} - -impl FutexWatcherThread { - fn new(fd: c_int, xid: u32, futex: FutexPtr, initial_value: u32) -> FutexWatcherThread { - let shutdown = Arc::new(AtomicBool::new(false)); - let shutdown2 = shutdown.clone(); - let futex2 = futex.clone(); - let handle = thread::spawn(move || { - let uaddr = futex2; - let op = nix::libc::FUTEX_WAIT_BITSET; - let timeout = ptr::null::<()>(); - let uaddr2 = ptr::null::<()>(); - let val3 = 1u32; - let mut val = initial_value; - let atomic_val = unsafe { AtomicU32::from_ptr(uaddr.0 as *mut u32) }; - loop { - if shutdown2.load(Ordering::SeqCst) { - break; - } - unsafe { - nix::libc::syscall( - nix::libc::SYS_futex, - uaddr.0, - op, - val, - timeout, - uaddr2, - val3, - ); - } - val = atomic_val.load(Ordering::SeqCst); - let ft_signal_msg_size = mem::size_of::(); - let ft_signal_cmd = CrossDomainFutexSignal { - hdr: CrossDomainHeader::new( - CROSS_DOMAIN_CMD_FUTEX_SIGNAL, - ft_signal_msg_size as u16, - ), - id: xid, - pad: 0, - }; - submit_cmd_raw(fd, &ft_signal_cmd, ft_signal_msg_size, None, None).unwrap(); - } - }); - FutexWatcherThread { - futex, - join_handle: Some(handle), - shutdown, - } - } - - fn signal(&self) { - unsafe { - wake_futex(self.futex.0, !1); - } - } -} - -impl Drop for FutexWatcherThread { - fn drop(&mut self) { - self.shutdown.store(true, Ordering::Release); - let atomic_val = unsafe { AtomicU32::from_ptr(self.futex.0 as *mut u32) }; - let v = atomic_val.load(Ordering::SeqCst); - atomic_val.store(!v, Ordering::SeqCst); - unsafe { - wake_futex(self.futex.0, !0); - } - self.join_handle.take().unwrap().join().unwrap(); - } -} - -#[allow(dead_code)] -struct RemoteCaller { - pid: Pid, - regs: user_regs_struct, -} - -impl RemoteCaller { - // This is arch-specific, so gate it off of x86_64 builds done for CI purposes - #[cfg(target_arch = "aarch64")] - fn with(pid: Pid, f: F) -> Result - where - F: FnOnce(&RemoteCaller) -> Result, - { - let old_regs = ptrace::getregs(pid)?; - - // Find the vDSO and the address of a syscall instruction within it - let (vdso_start, _) = find_vdso(Some(pid))?; - let syscall_addr = vdso_start + SYSCALL_OFFSET.get().unwrap(); - - let mut regs = old_regs; - regs.pc = syscall_addr as u64; - ptrace::setregs(pid, regs)?; - let res = f(&RemoteCaller { regs, pid })?; - ptrace::setregs(pid, old_regs)?; - Ok(res) - } - fn dup2(&self, oldfd: i32, newfd: i32) -> Result { - self.syscall(SYS_dup3, [oldfd as u64, newfd as u64, 0, 0, 0, 0]) - .map(|x| x as i32) - } - fn close(&self, fd: i32) -> Result { - self.syscall(SYS_close, [fd as u64, 0, 0, 0, 0, 0]) - .map(|x| x as i32) - } - fn mmap( - &self, - addr: usize, - length: usize, - prot: i32, - flags: i32, - fd: i32, - offset: usize, - ) -> Result { - self.syscall( - SYS_mmap, - [ - addr as u64, - length as u64, - prot as u64, - flags as u64, - fd as u64, - offset as u64, - ], - ) - .map(|x| x as usize) - } - fn munmap(&self, addr: usize, length: usize) -> Result { - self.syscall(SYS_munmap, [addr as u64, length as u64, 0, 0, 0, 0]) - .map(|x| x as i32) - } - fn 
open(&self, path: usize, flags: i32, mode: i32) -> Result { - self.syscall( - SYS_openat, - [ - AT_FDCWD as u64, - path as u64, - flags as u64, - mode as u64, - 0, - 0, - ], - ) - .map(|x| x as i32) - } - - // This is arch-specific, so gate it off of x86_64 builds done for CI purposes - #[cfg(target_arch = "aarch64")] - fn syscall(&self, syscall_no: c_long, args: [c_ulonglong; 6]) -> Result { - let mut regs = self.regs; - regs.regs[..6].copy_from_slice(&args); - regs.regs[8] = syscall_no as c_ulonglong; - ptrace::setregs(self.pid, regs)?; - ptrace::step(self.pid, None)?; - let evt = waitpid(self.pid, Some(WaitPidFlag::__WALL))?; - if !matches!(evt, WaitStatus::Stopped(_, _)) { - unimplemented!(); - } - regs = ptrace::getregs(self.pid)?; - Ok(regs.regs[0]) - } - - #[cfg(not(target_arch = "aarch64"))] - fn with(_pid: Pid, _f: F) -> Result - where - F: FnOnce(&RemoteCaller) -> Result, - { - Err(Errno::EOPNOTSUPP.into()) - } - #[cfg(not(target_arch = "aarch64"))] - fn syscall(&self, _syscall_no: c_long, _args: [c_ulonglong; 6]) -> Result { - Err(Errno::EOPNOTSUPP.into()) - } -} - -fn wait_for_stop(pid: Pid) -> Result<()> { - loop { - let event = waitpid(pid, Some(WaitPidFlag::__WALL))?; - match event { - WaitStatus::Stopped(_, sig) => { - if sig == Signal::SIGSTOP { - return Ok(()); - } else { - ptrace::cont(pid, sig)?; - } - }, - _ => unimplemented!(), - } - } -} - -struct PtracedPid(Pid); - -impl PtracedPid { - fn pid(&self) -> Pid { - self.0 - } -} - -impl Drop for PtracedPid { - fn drop(&mut self) { - if ptrace::detach(self.0, None).is_err() { - eprintln!("Failed to ptrace::detach({}) (continuing)", self.0); - } - } -} - -#[derive(Debug)] -enum ClientEvent { - None, - StartSend, - StopSend, - Close, -} - -impl Client { - fn new(socket: UnixStream) -> Result { - Ok(Client { - socket, - gpu_ctx: Context::new()?, - got_first_req: false, - dri3_ext_opcode: None, - dri3_qe_resp_seq: None, - sync_ext_opcode: None, - sync_qe_resp_seq: None, - present_qe_resp_seq: None, - present_ext_opcode: None, - seq_no: 1, - reply_tail: 0, - reply_head: Vec::new(), - got_first_resp: false, - request_tail: 0, - request_head: Vec::new(), - request_fds: Vec::new(), - futex_watchers: HashMap::new(), - debug_loop: DebugLoop::new(), - buffers_for_pixmap: HashMap::new(), - send_queue: VecDeque::new(), - }) - } - fn process_socket(&mut self, events: EpollFlags) -> Result { - if events.contains(EpollFlags::EPOLLIN) { - let queue_empty = self.send_queue.is_empty(); - if self.process_socket_recv()? 
{ - return Ok(ClientEvent::Close); - } - if queue_empty && !self.send_queue.is_empty() { - return Ok(ClientEvent::StartSend); - } - } - if events.contains(EpollFlags::EPOLLOUT) { - self.process_socket_send()?; - if self.send_queue.is_empty() { - return Ok(ClientEvent::StopSend); - } - } - Ok(ClientEvent::None) - } - - fn process_socket_send(&mut self) -> Result<()> { - let mut msg = self.send_queue.pop_front().unwrap(); - let fds: Vec = msg.fds.iter().map(|a| a.as_raw_fd()).collect(); - let cmsgs = if fds.is_empty() { - Vec::new() - } else { - vec![ControlMessage::ScmRights(&fds)] - }; - match sendmsg::<()>( - self.socket.as_raw_fd(), - &[IoSlice::new(&msg.data)], - &cmsgs, - MsgFlags::empty(), - None, - ) { - Ok(sent) => { - if sent < msg.data.len() { - msg.data = msg.data.split_off(sent); - self.send_queue.push_front(SendPacket { - data: msg.data.split_off(sent), - fds: Vec::new(), - }); - } - }, - Err(Errno::EAGAIN) => self.send_queue.push_front(msg), - Err(e) => return Err(e.into()), - }; - Ok(()) - } - fn process_socket_recv(&mut self) -> Result { - let mut fdspace = cmsg_space!([RawFd; CROSS_DOMAIN_MAX_IDENTIFIERS]); - let mut ring_msg = CrossDomainSendReceive { - hdr: CrossDomainHeader::new(CROSS_DOMAIN_CMD_SEND, 0), - num_identifiers: 0, - opaque_data_size: 0, - identifiers: [0; CROSS_DOMAIN_MAX_IDENTIFIERS], - identifier_types: [0; CROSS_DOMAIN_MAX_IDENTIFIERS], - identifier_sizes: [0; CROSS_DOMAIN_MAX_IDENTIFIERS], - data: [0u8; CROSS_DOMAIN_SR_TAIL_SIZE], - }; - let recv_buf = if self.request_tail > 0 { - assert!(self.request_head.is_empty()); - assert!(self.request_fds.is_empty()); - let len = self.request_tail.min(ring_msg.data.len()); - &mut ring_msg.data[..len] - } else { - let head_len = self.request_head.len(); - ring_msg.data[..head_len].copy_from_slice(&self.request_head); - self.request_head.clear(); - &mut ring_msg.data[head_len..] - }; - let mut ioslice = [IoSliceMut::new(recv_buf)]; - let msg: RecvMsg<()> = recvmsg( - self.socket.as_raw_fd(), - &mut ioslice, - Some(&mut fdspace), - MsgFlags::empty(), - )?; - for cmsg in msg.cmsgs()? 
{ - match cmsg { - ControlMessageOwned::ScmRights(rf) => { - for fd in rf { - self.request_fds.push(unsafe { OwnedFd::from_raw_fd(fd) }); - } - }, - _ => unimplemented!(), - } - } - let len = if let Some(iov) = msg.iovs().next() { - iov.len() - } else { - return Ok(true); - }; - let buf = &mut ring_msg.data[..len]; - self.debug_loop.loop_local(buf); - let mut fd_xids = [None; CROSS_DOMAIN_MAX_IDENTIFIERS]; - let mut cur_fd_for_msg = 0; - let mut fences_to_destroy = Vec::new(); - if !self.got_first_req { - self.got_first_req = true; - } else if self.request_tail > 0 { - assert!(self.request_fds.is_empty()); - self.request_tail -= buf.len(); - } else if self.request_tail == 0 { - let mut ptr = 0; - while ptr < buf.len() { - if buf.len() - ptr < 4 { - eprintln!( - "X11 message truncated (expected at least 4 bytes, got {}:{} = {})", - ptr, - buf.len(), - buf.len() - ptr - ); - break; - } - let mut req_len = - u16::from_ne_bytes(buf[(ptr + 2)..(ptr + 4)].try_into().unwrap()) as usize * 4; - if req_len == 0 { - if buf.len() - ptr < 8 { - eprintln!( - "X11 message truncated (expected at least 8 bytes, got {}:{} = {})", - ptr, - buf.len(), - buf.len() - ptr - ); - break; - } - req_len = u32::from_ne_bytes(buf[(ptr + 4)..(ptr + 8)].try_into().unwrap()) - as usize - * 4; - } - if buf[ptr] == X11_OPCODE_QUERY_EXTENSION { - let namelen = - u16::from_ne_bytes(buf[(ptr + 4)..(ptr + 6)].try_into().unwrap()) as usize; - let name = String::from_utf8_lossy(&buf[(ptr + 8)..(ptr + 8 + namelen)]); - if name == "DRI3" { - self.dri3_qe_resp_seq = Some(self.seq_no); - } else if name == "SYNC" { - self.sync_qe_resp_seq = Some(self.seq_no) - } else if name == "Present" { - self.present_qe_resp_seq = Some(self.seq_no); - } - } else if Some(buf[ptr]) == self.dri3_ext_opcode { - if buf[ptr + 1] == DRI3_OPCODE_VERSION { - buf[ptr + 8] = buf[ptr + 8].min(3); - } else if buf[ptr + 1] == DRI3_OPCODE_OPEN { - buf[ptr] = X11_OPCODE_NOP; - let mut reply = - vec![1, 1, (self.seq_no & 0xff) as u8, (self.seq_no >> 8) as u8]; - reply.extend_from_slice(&[0u8; 28]); - let render = File::options() - .read(true) - .write(true) - .open("/dev/dri/renderD128")?; - self.send_queue.push_back(SendPacket { - data: reply, - fds: vec![render.into()], - }); - } else if buf[ptr + 1] == DRI3_OPCODE_PIXMAP_FROM_BUFFER { - let xid = u32::from_ne_bytes(buf[(ptr + 4)..(ptr + 8)].try_into().unwrap()); - fd_xids[cur_fd_for_msg] = Some(xid); - cur_fd_for_msg += 1; - } else if buf[ptr + 1] == DRI3_OPCODE_FENCE_FROM_FD { - let xid = - u32::from_ne_bytes(buf[(ptr + 8)..(ptr + 12)].try_into().unwrap()); - fd_xids[cur_fd_for_msg] = Some(xid); - cur_fd_for_msg += 1; - } else if buf[ptr + 1] == DRI3_OPCODE_PIXMAP_FROM_BUFFERS { - let xid = u32::from_ne_bytes(buf[(ptr + 4)..(ptr + 8)].try_into().unwrap()); - let num_bufs = buf[ptr + 12] as usize; - for i in 0..num_bufs { - fd_xids[cur_fd_for_msg + i] = Some(xid); - } - cur_fd_for_msg += num_bufs; - } - } else if Some(buf[ptr]) == self.sync_ext_opcode { - if buf[ptr + 1] == SYNC_OPCODE_DESTROY_FENCE { - let xid = u32::from_ne_bytes(buf[(ptr + 4)..(ptr + 8)].try_into().unwrap()); - fences_to_destroy.push(xid); - } - } else if Some(buf[ptr]) == self.present_ext_opcode { - if buf[ptr + 1] == PRESENT_OPCODE_PRESENT_PIXMAP { - /* TODO: Implement GPU fence passing here when we have it. 
*/ - } - } else if buf[ptr] == X11_OPCODE_CREATE_PIXMAP { - let xid = u32::from_ne_bytes(buf[(ptr + 4)..(ptr + 8)].try_into().unwrap()); - self.buffers_for_pixmap.insert(xid, Vec::new()); - } else if buf[ptr] == X11_OPCODE_FREE_PIXMAP { - let xid = u32::from_ne_bytes(buf[(ptr + 4)..(ptr + 8)].try_into().unwrap()); - self.buffers_for_pixmap.remove(&xid); - } - self.seq_no = self.seq_no.wrapping_add(1); - ptr += req_len; - } - if ptr < buf.len() { - self.request_head = buf[ptr..].to_vec(); - } else { - self.request_tail = ptr - buf.len(); - } - } - if self.request_head.is_empty() { - assert_eq!(cur_fd_for_msg, self.request_fds.len()); - } else { - assert_eq!(self.request_tail, 0); - assert!(cur_fd_for_msg <= self.request_fds.len()); - } - let send_len = buf.len() - self.request_head.len(); - let size = mem::size_of::>() + send_len; - ring_msg.opaque_data_size = send_len as u32; - ring_msg.hdr.cmd_size = size as u16; - ring_msg.num_identifiers = cur_fd_for_msg as u32; - let mut gem_handles = Vec::with_capacity(cur_fd_for_msg); - let fds: Vec = self.request_fds.drain(..cur_fd_for_msg).collect(); - for (i, fd) in fds.into_iter().enumerate() { - let filename = readlink(format!("/proc/self/fd/{}", fd.as_raw_fd()).as_str())?; - let filename = filename.to_string_lossy(); - if filename.starts_with("/dmabuf:") { - let gh = self.vgpu_id_from_prime(&mut ring_msg, i, &fd_xids, fd)?; - gem_handles.push(gh); - continue; - } - let creds = getsockopt(&self.socket.as_fd(), PeerCredentials)?; - self.create_cross_vm_futex(&mut ring_msg, i, fd, &fd_xids, creds.pid(), filename)?; - } - self.gpu_ctx.submit_cmd(&ring_msg, size, None, None)?; - for gem_handle in gem_handles { - unsafe { - let close = DrmGemClose::new(gem_handle); - drm_gem_close(self.gpu_ctx.fd.as_raw_fd() as c_int, &close)?; - } - } - for xid in fences_to_destroy { - self.futex_watchers.remove(&xid).unwrap(); - let ft_destroy_msg_size = mem::size_of::(); - let ft_msg = CrossDomainFutexDestroy { - hdr: CrossDomainHeader::new( - CROSS_DOMAIN_CMD_FUTEX_DESTROY, - ft_destroy_msg_size as u16, - ), - id: xid, - pad: 0, - }; - self.gpu_ctx - .submit_cmd(&ft_msg, ft_destroy_msg_size, None, None)?; - } - Ok(false) - } - fn vgpu_id_from_prime( - &mut self, - ring_msg: &mut CrossDomainSendReceive, - i: usize, - fd_xids: &[Option], - fd: OwnedFd, - ) -> Result { - let mut to_handle = DrmPrimeHandle { - fd: fd.as_raw_fd(), - ..DrmPrimeHandle::default() - }; - unsafe { - drm_prime_fd_to_handle(self.gpu_ctx.fd.as_raw_fd() as c_int, &mut to_handle)?; - } - self.buffers_for_pixmap - .entry(fd_xids[i].unwrap()) - .or_default() - .push(fd); - let mut res_info = DrmVirtgpuResourceInfo { - bo_handle: to_handle.handle, - ..DrmVirtgpuResourceInfo::default() - }; - unsafe { - drm_virtgpu_resource_info(self.gpu_ctx.fd.as_raw_fd() as c_int, &mut res_info)?; - } - ring_msg.identifiers[i] = res_info.res_handle; - ring_msg.identifier_types[i] = CROSS_DOMAIN_ID_TYPE_VIRTGPU_BLOB; - Ok(to_handle.handle) - } - - fn ptrace_all_threads(pid: Pid) -> Result> { - let mut tids = Vec::new(); - for entry in fs::read_dir(format!("/proc/{pid}/task"))? { - let entry = match entry { - Err(_) => continue, - Ok(a) => a, - }; - let tid = Pid::from_raw( - entry - .file_name() - .into_string() - .or(Err(Errno::EIO))? - .parse()?, - ); - if let Err(e) = ptrace::attach(tid) { - // This could be a race (thread exited), so keep going - // unless this is the top-level PID - if tid == pid { - return Err(e.into()); - } - eprintln!("ptrace::attach({pid}, ...) 
failed (continuing)"); - continue; - } - let ptid = PtracedPid(tid); - wait_for_stop(ptid.pid())?; - tids.push(ptid); - } - Ok(tids) - } - - fn replace_futex_storage( - my_fd: RawFd, - pid: Pid, - shmem_path: &str, - shmem_file: &mut File, - ) -> Result<()> { - let traced = Self::ptrace_all_threads(pid)?; - - let mut data = [0; 4]; - read(my_fd, &mut data)?; - shmem_file.write_all(&data)?; - - // TODO: match st_dev too to avoid false positives - let my_ino = fstat(my_fd)?.st_ino; - let mut fds_to_replace = Vec::new(); - for entry in fs::read_dir(format!("/proc/{pid}/fd"))? { - let entry = entry?; - if let Ok(file) = File::options().open(entry.path()) { - if fstat(file.as_raw_fd())?.st_ino == my_ino { - fds_to_replace.push(entry.file_name().to_string_lossy().parse::()?); - } - } - } - let mut pages_to_replace = Vec::new(); - for line in read_to_string(format!("/proc/{pid}/maps"))?.lines() { - let f: Vec<&str> = line.split_whitespace().collect(); - let ino: u64 = f[4].parse()?; - if ino == my_ino { - let addr = usize::from_str_radix(f[0].split('-').next().unwrap(), 16)?; - pages_to_replace.push(addr); - } - } - RemoteCaller::with(pid, |caller| { - let scratch_page = caller.mmap( - 0, - PAGE_SIZE, - PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, - 0, - 0, - )?; - let path_cstr = CString::new(shmem_path).unwrap(); - process_vm_writev( - pid, - &[IoSlice::new(path_cstr.as_bytes_with_nul())], - &[RemoteIoVec { - base: scratch_page, - len: path_cstr.len(), - }], - )?; - let remote_shm = caller.open(scratch_page, O_CLOEXEC | O_RDWR, 0o600)?; - for fd in fds_to_replace { - caller.dup2(remote_shm, fd)?; - } - for page in pages_to_replace { - caller.mmap( - page, - PAGE_SIZE, - PROT_READ | PROT_WRITE, - MAP_SHARED | MAP_FIXED, - remote_shm, - 0, - )?; - } - caller.munmap(scratch_page, PAGE_SIZE)?; - caller.close(remote_shm)?; - Ok(()) - })?; - // This detaches all the traced threads - mem::drop(traced); - Ok(()) - } - fn create_cross_vm_futex( - &mut self, - ring_msg: &mut CrossDomainSendReceive, - i: usize, - memfd: OwnedFd, - fd_xids: &[Option], - pid: pid_t, - filename: Cow<'_, str>, - ) -> Result<()> { - // Allow everything in /dev/shm (including paths with trailing '(deleted)') - let shmem_file = if filename.starts_with(SHM_DIR) { - File::from(memfd) - } else if cfg!(not(target_arch = "aarch64")) { - return Err(Errno::EOPNOTSUPP.into()); - } else { - let (fd, shmem_path) = mkstemp(SHM_TEMPLATE)?; - let mut shmem_file = unsafe { File::from_raw_fd(fd) }; - let ret = Self::replace_futex_storage( - memfd.as_raw_fd(), - Pid::from_raw(pid), - shmem_path.as_os_str().to_str().unwrap(), - &mut shmem_file, - ); - remove_file(&shmem_path)?; - ret?; - shmem_file - }; - - let mut handle: ExportedHandle = Default::default(); - unsafe { virtio_export_handle(shmem_file.as_raw_fd(), &mut handle) }?; - - let addr = FutexPtr(unsafe { - mmap( - None, - 4.try_into().unwrap(), - ProtFlags::PROT_WRITE | ProtFlags::PROT_READ, - MapFlags::MAP_SHARED, - shmem_file, - 0, - )? 
- .as_ptr() - }); - let initial_value = - unsafe { AtomicU32::from_ptr(addr.0 as *mut u32) }.load(Ordering::Relaxed); - - let ft_new_msg_size = mem::size_of::(); - let ft_msg = CrossDomainFutexNew { - hdr: CrossDomainHeader::new(CROSS_DOMAIN_CMD_FUTEX_NEW, ft_new_msg_size as u16), - id: fd_xids[i].unwrap(), - fs_id: handle.fs_id, - handle: handle.handle, - pad: 0, - }; - self.gpu_ctx - .submit_cmd(&ft_msg, ft_new_msg_size, None, None)?; - let sync_xid = fd_xids[i].unwrap(); - let fd = self.gpu_ctx.fd.as_raw_fd() as c_int; - // TODO: do we need to wait here? - //thread::sleep(Duration::from_millis(33)); - self.futex_watchers.insert( - sync_xid, - FutexWatcherThread::new(fd, sync_xid, addr, initial_value), - ); - ring_msg.identifiers[i] = sync_xid; - ring_msg.identifier_types[i] = CROSS_DOMAIN_ID_TYPE_SHM; - Ok(()) - } - fn process_vgpu(&mut self) -> Result { - let mut evt = DrmEvent::default(); - read(self.gpu_ctx.fd.as_raw_fd(), unsafe { - slice::from_raw_parts_mut( - &mut evt as *mut DrmEvent as *mut u8, - mem::size_of::(), - ) - })?; - assert_eq!(evt.ty, VIRTGPU_EVENT_FENCE_SIGNALED); - let cmd = unsafe { - (self.gpu_ctx.channel_ring.address as *const CrossDomainHeader) - .as_ref() - .unwrap() - .cmd - }; - match cmd { - CROSS_DOMAIN_CMD_RECEIVE => { - let recv = unsafe { - (self.gpu_ctx.channel_ring.address - as *const CrossDomainSendReceive<[u8; CROSS_DOMAIN_SR_TAIL_SIZE]>) - .as_ref() - .unwrap() - }; - if recv.opaque_data_size == 0 { - return Ok(true); - } - self.process_receive(recv)?; - }, - CROSS_DOMAIN_CMD_FUTEX_SIGNAL => { - let recv = unsafe { - (self.gpu_ctx.channel_ring.address as *const CrossDomainFutexSignal) - .as_ref() - .unwrap() - }; - self.process_futex_signal(recv)?; - }, - a => { - eprintln!("Received unknown cross-domain command {a}"); - }, - }; - self.gpu_ctx.poll_cmd()?; - Ok(false) - } - fn process_receive(&mut self, recv: &CrossDomainSendReceive<[u8]>) -> Result<()> { - let mut owned_fds = Vec::with_capacity(recv.num_identifiers as usize); - for i in 0..recv.num_identifiers as usize { - assert_eq!(recv.identifier_types[i], CROSS_DOMAIN_ID_TYPE_VIRTGPU_BLOB); - let mut create_blob = DrmVirtgpuResourceCreateBlob { - blob_mem: VIRTGPU_BLOB_MEM_HOST3D, - size: recv.identifier_sizes[i] as u64, - blob_id: recv.identifiers[i] as u64, - blob_flags: VIRTGPU_BLOB_FLAG_USE_MAPPABLE | VIRTGPU_BLOB_FLAG_USE_SHAREABLE, - ..DrmVirtgpuResourceCreateBlob::default() - }; - unsafe { - drm_virtgpu_resource_create_blob( - self.gpu_ctx.fd.as_raw_fd() as c_int, - &mut create_blob, - )?; - } - let mut to_fd = DrmPrimeHandle { - handle: create_blob.bo_handle, - flags: O_RDWR as u32, - fd: -1, - }; - unsafe { - drm_prime_handle_to_fd(self.gpu_ctx.fd.as_raw_fd() as c_int, &mut to_fd)?; - let close = DrmGemClose::new(create_blob.bo_handle); - drm_gem_close(self.gpu_ctx.fd.as_raw_fd() as c_int, &close)?; - } - unsafe { owned_fds.push(OwnedFd::from_raw_fd(to_fd.fd)) } - } - let data = &recv.data[..(recv.opaque_data_size as usize)]; - self.debug_loop.loop_remote(data); - if !self.got_first_resp { - self.got_first_resp = true; - self.reply_tail = u16::from_ne_bytes(data[6..8].try_into().unwrap()) as usize * 4 + 8; - } - let data = if self.reply_tail > 0 { - assert!(self.reply_head.is_empty()); - let block = self.reply_tail.min(data.len()); - let (block_data, data) = data.split_at(block); - // If we have a reply tail, we need to send it separately. This is to ensure - // that no fds are attached to it, since libxcb cannot handle fds not - // attached to a packet header. 
- self.send_queue.push_back(SendPacket { - data: block_data.into(), - fds: Vec::new(), - }); - - self.reply_tail -= block; - data - } else { - data - }; - assert!(self.reply_tail == 0 || data.is_empty()); - if data.is_empty() { - assert!(owned_fds.is_empty()); - return Ok(()); - } - - let data = if self.reply_head.is_empty() { - data.to_vec() - } else { - let mut new_data = core::mem::take(&mut self.reply_head); - new_data.extend_from_slice(data); - new_data - }; - - let mut ptr = 0; - while ptr < data.len() { - if data.len() - ptr < 32 { - eprintln!( - "X11 message truncated (expected at least 32 bytes, got {}:{} = {})", - ptr, - data.len(), - data.len() - ptr - ); - break; - } - let seq_no = u16::from_ne_bytes(data[(ptr + 2)..(ptr + 4)].try_into().unwrap()); - let is_reply = data[ptr] == X11_REPLY; - let is_generic = data[ptr] == X11_GENERIC_EVENT; - let len = if is_reply || is_generic { - u32::from_ne_bytes(data[(ptr + 4)..(ptr + 8)].try_into().unwrap()) as usize * 4 - } else { - 0 - } + 32; - if is_reply { - if Some(seq_no) == self.dri3_qe_resp_seq { - self.dri3_qe_resp_seq = None; - self.dri3_ext_opcode = extract_opcode_from_qe_resp(&data, ptr); - } else if Some(seq_no) == self.sync_qe_resp_seq { - self.sync_qe_resp_seq = None; - self.sync_ext_opcode = extract_opcode_from_qe_resp(&data, ptr); - } else if Some(seq_no) == self.present_qe_resp_seq { - self.present_qe_resp_seq = None; - self.present_ext_opcode = extract_opcode_from_qe_resp(&data, ptr); - } - } - ptr += len; - } - let block = if ptr < data.len() { - let (block, next_head) = data.split_at(ptr); - self.reply_head = next_head.to_vec(); - block.to_vec() - } else { - self.reply_tail = ptr - data.len(); - data.to_vec() - }; - self.send_queue.push_back(SendPacket { - data: block, - fds: owned_fds, - }); - Ok(()) - } - fn process_futex_signal(&mut self, recv: &CrossDomainFutexSignal) -> Result<()> { - let watcher = match self.futex_watchers.get(&recv.id) { - Some(a) => a, - None => { - eprintln!("Unknown futex id {}", recv.id); - return Ok(()); - }, - }; - - watcher.signal(); - - Ok(()) - } -} - -fn find_vdso(pid: Option) -> Result<(usize, usize), Errno> { - let path = format!( - "/proc/{}/maps", - pid.map(|a| a.to_string()).unwrap_or("self".into()) - ); - - for line in read_to_string(path).unwrap().lines() { - if line.ends_with("[vdso]") { - let a = line.find('-').ok_or(Errno::EINVAL)?; - let b = line.find(' ').ok_or(Errno::EINVAL)?; - let start = usize::from_str_radix(&line[..a], 16).or(Err(Errno::EINVAL))?; - let end = usize::from_str_radix(&line[a + 1..b], 16).or(Err(Errno::EINVAL))?; - - return Ok((start, end)); - } - } - - Err(Errno::EINVAL) -} - -pub fn start_x11bridge(display: u32) { - let sock_path = format!("/tmp/.X11-unix/X{display}"); - - // Look for a syscall instruction in the vDSO. We assume all processes map - // the same vDSO (which should be true if they are running under the same - // kernel!) 
- let (vdso_start, vdso_end) = find_vdso(None).unwrap(); - for off in (0..(vdso_end - vdso_start)).step_by(4) { - let addr = vdso_start + off; - let val = unsafe { std::ptr::read(addr as *const u32) }; - if val == SYSCALL_INSTR { - SYSCALL_OFFSET.set(off).unwrap(); - break; - } - } - if SYSCALL_OFFSET.get().is_none() { - eprintln!("Failed to find syscall instruction in vDSO"); - exit(1); - } - - let epoll = Epoll::new(EpollCreateFlags::empty()).unwrap(); - _ = fs::remove_file(&sock_path); - let listen_sock = UnixListener::bind(sock_path).unwrap(); - epoll - .add( - &listen_sock, - EpollEvent::new(EpollFlags::EPOLLIN, listen_sock.as_raw_fd() as u64), - ) - .unwrap(); - let mut client_sock = HashMap::>>::new(); - let mut client_vgpu = HashMap::>>::new(); - loop { - let mut evts = [EpollEvent::empty(); 16]; - let count = match epoll.wait(&mut evts, EpollTimeout::NONE) { - Err(Errno::EINTR) | Ok(0) => continue, - a => a.unwrap(), - }; - for evt in &evts[..count.min(evts.len())] { - let fd = evt.data(); - let events = evt.events(); - if fd == listen_sock.as_raw_fd() as u64 { - let res = listen_sock.accept(); - if res.is_err() { - eprintln!( - "Failed to accept a connection, error: {:?}", - res.unwrap_err() - ); - continue; - } - let stream = res.unwrap().0; - stream.set_nonblocking(true).unwrap(); - let client = Rc::new(RefCell::new(Client::new(stream).unwrap())); - client_sock.insert(client.borrow().socket.as_raw_fd() as u64, client.clone()); - epoll - .add( - &client.borrow().socket, - EpollEvent::new( - EpollFlags::EPOLLIN, - client.borrow().socket.as_raw_fd() as u64, - ), - ) - .unwrap(); - client_vgpu.insert( - client.borrow().gpu_ctx.fd.as_raw_fd() as u64, - client.clone(), - ); - epoll - .add( - &client.borrow().gpu_ctx.fd, - EpollEvent::new( - EpollFlags::EPOLLIN, - client.borrow().gpu_ctx.fd.as_raw_fd() as u64, - ), - ) - .unwrap(); - } else if let Some(client) = client_sock.get_mut(&fd) { - let event = client - .borrow_mut() - .process_socket(events) - .map_err(|e| { - eprintln!("Client {fd} disconnected with error: {e:?}"); - e - }) - .unwrap_or(ClientEvent::Close); - match event { - ClientEvent::None => (), - ClientEvent::StartSend => { - epoll - .modify( - &client.borrow().socket, - &mut EpollEvent::new( - EpollFlags::EPOLLOUT | EpollFlags::EPOLLIN, - client.borrow().socket.as_raw_fd() as u64, - ), - ) - .unwrap(); - }, - ClientEvent::StopSend => { - epoll - .modify( - &client.borrow().socket, - &mut EpollEvent::new( - EpollFlags::EPOLLIN, - client.borrow().socket.as_raw_fd() as u64, - ), - ) - .unwrap(); - }, - ClientEvent::Close => { - let client = client.borrow(); - let gpu_fd = client.gpu_ctx.fd.as_fd(); - epoll.delete(gpu_fd).unwrap(); - epoll.delete(&client.socket).unwrap(); - let gpu_fd = gpu_fd.as_raw_fd() as u64; - drop(client); - client_vgpu.remove(&gpu_fd).unwrap(); - client_sock.remove(&fd).unwrap(); - }, - } - } else if let Some(client) = client_vgpu.get_mut(&fd) { - let queue_empty = client.borrow().send_queue.is_empty(); - let close = client - .borrow_mut() - .process_vgpu() - .map_err(|e| { - eprintln!("Server {fd} disconnected with error: {e:?}"); - e - }) - .unwrap_or(true); - if close { - let client = client.borrow(); - let gpu_fd = client.gpu_ctx.fd.as_fd(); - epoll.delete(gpu_fd).unwrap(); - let client_fd = client.socket.as_raw_fd() as u64; - epoll.delete(&client.socket).unwrap(); - drop(client); - client_vgpu.remove(&fd).unwrap(); - client_sock.remove(&client_fd).unwrap(); - } else if queue_empty && !client.borrow().send_queue.is_empty() { - epoll - .modify( - 
&client.borrow().socket, - &mut EpollEvent::new( - EpollFlags::EPOLLOUT | EpollFlags::EPOLLIN, - client.borrow().socket.as_raw_fd() as u64, - ), - ) - .unwrap(); - } - } - } - } -} diff --git a/share/wireplumber/scripts/client/access-muvm.lua b/share/wireplumber/scripts/client/access-muvm.lua new file mode 100644 index 0000000..df05fd9 --- /dev/null +++ b/share/wireplumber/scripts/client/access-muvm.lua @@ -0,0 +1,143 @@ +MEDIA_ROLE_NONE = 0 +MEDIA_ROLE_CAMERA = 1 << 0 + +log = Log.open_topic ("s-client") + +function hasPermission (permissions, app_id, lookup) + if permissions then + for key, values in pairs(permissions) do + if key == app_id then + for _, v in pairs(values) do + if v == lookup then + return true + end + end + end + end + end + return false +end + +function parseMediaRoles (media_roles_str) + local media_roles = MEDIA_ROLE_NONE + for role in media_roles_str:gmatch('[^,%s]+') do + if role == "Camera" then + media_roles = media_roles | MEDIA_ROLE_CAMERA + end + end + return media_roles +end + +function setPermissions (client, allow_client, allow_nodes) + local client_id = client["bound-id"] + log:info(client, "Granting ALL access to client " .. client_id) + + -- Update permissions on client + client:update_permissions { [client_id] = allow_client and "all" or "-" } + + -- Update permissions on camera source nodes + for node in nodes_om:iterate() do + local node_id = node["bound-id"] + client:update_permissions { [node_id] = allow_nodes and "all" or "-" } + end +end + +function updateClientPermissions (client, permissions) + local client_id = client["bound-id"] + local str_prop = nil + local app_id = nil + local media_roles = nil + local allowed = false + + -- Make sure the client is not the portal itself + str_prop = client.properties["pipewire.access.portal.is_portal"] + if str_prop == "yes" then + log:info (client, "client is the portal itself") + return + end + + -- Make sure the client has a portal app Id + str_prop = client.properties["pipewire.access.muvm00.app_id"] + if str_prop == nil then + log:info (client, "Portal managed client did not set app_id") + return + end + if str_prop == "" then + log:info (client, "Ignoring portal check for non-sandboxed client") + setPermissions (client, true, true) + return + end + app_id = str_prop + + -- Make sure the client has portal media roles + str_prop = client.properties["pipewire.access.portal.media_roles"] + if str_prop == nil then + log:info (client, "Portal managed client did not set media_roles") + return + end + media_roles = parseMediaRoles (str_prop) + if (media_roles & MEDIA_ROLE_CAMERA) == 0 then + log:info (client, "Ignoring portal check for clients without camera role") + return + end + + -- Update permissions + allowed = hasPermission (permissions, app_id, "yes") + + log:info (client, "setting permissions: " .. 
tostring(allowed))
+  setPermissions (client, allowed, allowed)
+end
+
+-- Create portal clients object manager
+clients_om = ObjectManager {
+  Interest {
+    type = "client",
+    Constraint { "pipewire.access.muvm00.app_id", "+", type = "pw" },
+  }
+}
+
+-- Set permissions to portal clients from the permission store if loaded
+pps_plugin = Plugin.find("portal-permissionstore")
+if pps_plugin then
+  nodes_om = ObjectManager {
+    Interest {
+      type = "node",
+      Constraint { "media.role", "=", "Camera" },
+      Constraint { "media.class", "=", "Video/Source" },
+    }
+  }
+  nodes_om:activate()
+
+  clients_om:connect("object-added", function (om, client)
+    local new_perms = pps_plugin:call("lookup", "devices", "camera");
+    updateClientPermissions (client, new_perms)
+  end)
+
+  nodes_om:connect("object-added", function (om, node)
+    local new_perms = pps_plugin:call("lookup", "devices", "camera");
+    for client in clients_om:iterate() do
+      updateClientPermissions (client, new_perms)
+    end
+  end)
+
+  pps_plugin:connect("changed", function (p, table, id, deleted, permissions)
+    if table == "devices" or id == "camera" then
+      for app_id, _ in pairs(permissions) do
+        for client in clients_om:iterate {
+          Constraint { "pipewire.access.muvm00.app_id", "=", app_id }
+        } do
+          updateClientPermissions (client, permissions)
+        end
+      end
+    end
+  end)
+else
+  -- Otherwise, just set all permissions to all portal clients
+  clients_om:connect("object-added", function (om, client)
+    local id = client["bound-id"]
+    log:info(client, "Granting ALL access to client " .. id)
+    client:update_permissions { ["any"] = "all" }
+  end)
+end
+
+clients_om:activate()
diff --git a/share/wireplumber/wireplumber.conf.d/50-muvm-access.conf b/share/wireplumber/wireplumber.conf.d/50-muvm-access.conf
new file mode 100644
index 0000000..c27f204
--- /dev/null
+++ b/share/wireplumber/wireplumber.conf.d/50-muvm-access.conf
@@ -0,0 +1,13 @@
+wireplumber.components = [
+  {
+    name = client/access-muvm.lua, type = script/lua,
+    provides = custom.access-muvm
+    requires = [ support.portal-permissionstore ]
+  }
+]
+
+wireplumber.profiles = {
+  main = {
+    custom.access-muvm = required
+  }
+}
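
Note on the request framing the bridge's client-to-server parser relies on: a core X11 request carries a 16-bit length field (in 4-byte units) at offset 2, and a length of 0 marks a BIG-REQUESTS request whose real length follows as a 32-bit field (again in 4-byte units) at offset 4. The sketch below is illustrative only and not part of the patch; the helper name request_len is invented here for clarity.

// Illustrative helper (not part of the patch): byte length of one X11 request
// starting at the beginning of `buf`, mirroring the req_len computation in the
// bridge. Returns None if the buffer is too short to hold the length fields.
fn request_len(buf: &[u8]) -> Option<usize> {
    // Core protocol: 16-bit length in 4-byte units at offset 2.
    let short = u16::from_ne_bytes(buf.get(2..4)?.try_into().ok()?) as usize * 4;
    if short != 0 {
        return Some(short);
    }
    // BIG-REQUESTS: a zero short length means a 32-bit extended length
    // (also in 4-byte units) is carried at offset 4.
    let long = u32::from_ne_bytes(buf.get(4..8)?.try_into().ok()?) as usize * 4;
    Some(long)
}

Walking request boundaries this way is what lets the bridge associate the file descriptors received over SCM_RIGHTS with the XIDs named in DRI3 and SYNC requests before forwarding them across the cross-domain channel.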
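On the server-to-client path the same kind of walk applies to replies and events: every unit is at least 32 bytes, and replies (code 1) and generic events (code 35) append an extra payload whose size, in 4-byte units, sits in a 32-bit field at offset 4. A minimal sketch under those assumptions; reply_len is an invented name, the two constants mirror the ones defined in the patch.

const X11_REPLY: u8 = 1;
const X11_GENERIC_EVENT: u8 = 35;

// Illustrative helper (not part of the patch): byte length of one
// server-to-client message (reply, event, or error) starting at `buf`.
fn reply_len(buf: &[u8]) -> Option<usize> {
    if buf.len() < 32 {
        return None; // every unit is at least 32 bytes
    }
    let extra = if buf[0] == X11_REPLY || buf[0] == X11_GENERIC_EVENT {
        // Replies and generic events carry `extra` additional 4-byte units
        // after the fixed 32-byte header.
        u32::from_ne_bytes(buf[4..8].try_into().ok()?) as usize * 4
    } else {
        0
    };
    Some(32 + extra)
}

This is why the bridge queues a partially received reply tail as a separate packet with no descriptors attached: libxcb only accepts file descriptors that arrive together with a packet header, so fds must ride on the buffer that starts a message.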