Skip to content

Commit

Permalink
Expose OpaqueIpcMessage as IpcMessage (#362)
Browse files Browse the repository at this point in the history
This makes the API a bit friendlier and future-proof. Instead of
returning a tuple from raw `recv()` return the `IpcMessage` struct that
is used internally in the crate. In addition, this is used to pass data
from the platform layer -- removing clippy warnings about complex data
types. One upside to this is that `Option<IpcSharedMemory>` is turned
into `IpcSharedMemory` in some places where having a `None` value
shouldn't be possible.

This simplification removes about 300 lines of code.
  • Loading branch information
mrobinson authored Oct 13, 2024
1 parent 2bcf4dd commit a29f223
Show file tree
Hide file tree
Showing 9 changed files with 295 additions and 616 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ name = "ipc_receiver_set"
harness = false

[features]
default = []
force-inprocess = []
memfd = ["sc"]
async = ["futures", "futures-test"]
Expand Down
16 changes: 6 additions & 10 deletions src/asynch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::ipc;
use crate::ipc::IpcReceiver;
use crate::ipc::IpcReceiverSet;
use crate::ipc::IpcSelectionResult;
use crate::ipc::IpcSender;
use crate::ipc::OpaqueIpcMessage;
use crate::ipc::OpaqueIpcReceiver;
use crate::ipc::{
self, IpcMessage, IpcReceiver, IpcReceiverSet, IpcSelectionResult, IpcSender, OpaqueIpcReceiver,
};
use futures::channel::mpsc::UnboundedReceiver;
use futures::channel::mpsc::UnboundedSender;
use futures::stream::FusedStream;
Expand All @@ -30,15 +26,15 @@ use std::sync::Mutex;
use std::thread;

/// A stream built from an IPC channel.
pub struct IpcStream<T>(UnboundedReceiver<OpaqueIpcMessage>, PhantomData<T>);
pub struct IpcStream<T>(UnboundedReceiver<IpcMessage>, PhantomData<T>);

impl<T> Unpin for IpcStream<T> {}

// A router which routes from an IPC channel to a stream.
struct Router {
// Send `(ipc_recv, send)` to this router to add a route
// from the IPC receiver to the sender.
add_route: UnboundedSender<(OpaqueIpcReceiver, UnboundedSender<OpaqueIpcMessage>)>,
add_route: UnboundedSender<(OpaqueIpcReceiver, UnboundedSender<IpcMessage>)>,

// Wake up the routing thread.
wakeup: Mutex<IpcSender<()>>,
Expand All @@ -52,7 +48,7 @@ lazy_static! {
let (waker, wakee) = ipc::channel().expect("Failed to create IPC channel");
thread::spawn(move || {
let mut receivers = IpcReceiverSet::new().expect("Failed to create receiver set");
let mut senders = HashMap::<u64, UnboundedSender<OpaqueIpcMessage>>::new();
let mut senders = HashMap::<u64, UnboundedSender<IpcMessage>>::new();
let _ = receivers.add(wakee);
while let Ok(mut selections) = receivers.select() {
for selection in selections.drain(..) {
Expand Down
136 changes: 60 additions & 76 deletions src/ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,16 +251,13 @@ where
{
/// Blocking receive.
pub fn recv(&self) -> Result<T, IpcError> {
let (data, os_ipc_channels, os_ipc_shared_memory_regions) = self.os_receiver.recv()?;
OpaqueIpcMessage::new(data, os_ipc_channels, os_ipc_shared_memory_regions)
.to()
.map_err(IpcError::Bincode)
self.os_receiver.recv()?.to().map_err(IpcError::Bincode)
}

/// Non-blocking receive
pub fn try_recv(&self) -> Result<T, TryRecvError> {
let (data, os_ipc_channels, os_ipc_shared_memory_regions) = self.os_receiver.try_recv()?;
OpaqueIpcMessage::new(data, os_ipc_channels, os_ipc_shared_memory_regions)
self.os_receiver
.try_recv()?
.to()
.map_err(IpcError::Bincode)
.map_err(TryRecvError::IpcError)
Expand All @@ -273,9 +270,8 @@ where
/// block forever. At the time of writing, the smallest duration that may trigger this behavior
/// is over 24 days.
pub fn try_recv_timeout(&self, duration: Duration) -> Result<T, TryRecvError> {
let (data, os_ipc_channels, os_ipc_shared_memory_regions) =
self.os_receiver.try_recv_timeout(duration)?;
OpaqueIpcMessage::new(data, os_ipc_channels, os_ipc_shared_memory_regions)
self.os_receiver
.try_recv_timeout(duration)?
.to()
.map_err(IpcError::Bincode)
.map_err(TryRecvError::IpcError)
Expand Down Expand Up @@ -506,22 +502,9 @@ impl IpcReceiverSet {
Ok(results
.into_iter()
.map(|result| match result {
OsIpcSelectionResult::DataReceived(
os_receiver_id,
data,
os_ipc_channels,
os_ipc_shared_memory_regions,
) => IpcSelectionResult::MessageReceived(
os_receiver_id,
OpaqueIpcMessage {
data,
os_ipc_channels,
os_ipc_shared_memory_regions: os_ipc_shared_memory_regions
.into_iter()
.map(Some)
.collect(),
},
),
OsIpcSelectionResult::DataReceived(os_receiver_id, ipc_message) => {
IpcSelectionResult::MessageReceived(os_receiver_id, ipc_message)
},
OsIpcSelectionResult::ChannelClosed(os_receiver_id) => {
IpcSelectionResult::ChannelClosed(os_receiver_id)
},
Expand Down Expand Up @@ -569,21 +552,21 @@ impl<'de> Deserialize<'de> for IpcSharedMemory {
{
let index: usize = Deserialize::deserialize(deserializer)?;
if index == usize::MAX {
Ok(IpcSharedMemory::empty())
} else {
let os_shared_memory = OS_IPC_SHARED_MEMORY_REGIONS_FOR_DESERIALIZATION.with(
|os_ipc_shared_memory_regions_for_deserialization| {
// FIXME(pcwalton): This could panic if the data was corrupt and the index was out
// of bounds. We should return an `Err` result instead.
os_ipc_shared_memory_regions_for_deserialization.borrow_mut()[index]
.take()
.unwrap()
},
);
Ok(IpcSharedMemory {
os_shared_memory: Some(os_shared_memory),
})
return Ok(IpcSharedMemory::empty());
}

let os_shared_memory = OS_IPC_SHARED_MEMORY_REGIONS_FOR_DESERIALIZATION.with(
|os_ipc_shared_memory_regions_for_deserialization| {
// FIXME(pcwalton): This could panic if the data was corrupt and the index was out
// of bounds. We should return an `Err` result instead.
os_ipc_shared_memory_regions_for_deserialization.borrow_mut()[index]
.take()
.unwrap()
},
);
Ok(IpcSharedMemory {
os_shared_memory: Some(os_shared_memory),
})
}
}

Expand Down Expand Up @@ -646,12 +629,9 @@ impl IpcSharedMemory {
///
/// [IpcReceiverSet::select]: struct.IpcReceiverSet.html#method.select
pub enum IpcSelectionResult {
/// A message received from the [IpcReceiver] in the [opaque] form,
/// A message received from the [`IpcReceiver`] in the [`IpcMessage`] form,
/// identified by the `u64` value.
///
/// [IpcReceiver]: struct.IpcReceiver.html
/// [opaque]: struct.OpaqueIpcMessage.html
MessageReceived(u64, OpaqueIpcMessage),
MessageReceived(u64, IpcMessage),
/// The channel has been closed for the [IpcReceiver] identified by the `u64` value.
/// [IpcReceiver]: struct.IpcReceiver.html
ChannelClosed(u64),
Expand All @@ -668,7 +648,7 @@ impl IpcSelectionResult {
/// [IpcSelectionResult]: enum.IpcSelectionResult.html
/// [MessageReceived]: enum.IpcSelectionResult.html#variant.MessageReceived
/// [ChannelClosed]: enum.IpcSelectionResult.html#variant.ChannelClosed
pub fn unwrap(self) -> (u64, OpaqueIpcMessage) {
pub fn unwrap(self) -> (u64, IpcMessage) {
match self {
IpcSelectionResult::MessageReceived(id, message) => (id, message),
IpcSelectionResult::ChannelClosed(id) => {
Expand All @@ -678,19 +658,31 @@ impl IpcSelectionResult {
}
}

/// Structure used to represent a raw message from an [IpcSender].
/// Structure used to represent a raw message from an [`IpcSender`].
///
/// Use the [to] method to deserialize the raw result into the requested type.
///
/// [IpcSender]: struct.IpcSender.html
/// [to]: #method.to
pub struct OpaqueIpcMessage {
data: Vec<u8>,
os_ipc_channels: Vec<OsOpaqueIpcChannel>,
os_ipc_shared_memory_regions: Vec<Option<OsIpcSharedMemory>>,
#[derive(PartialEq)]
pub struct IpcMessage {
pub data: Vec<u8>,
pub os_ipc_channels: Vec<OsOpaqueIpcChannel>,
pub os_ipc_shared_memory_regions: Vec<OsIpcSharedMemory>,
}

impl IpcMessage {
/// Create a new [`IpcMessage`] with data and without any [`OsOpaqueIpcChannel`]s and
/// [`OsIpcSharedMemory`] regions.
pub fn from_data(data: Vec<u8>) -> Self {
Self {
data,
os_ipc_channels: vec![],
os_ipc_shared_memory_regions: vec![],
}
}
}

impl Debug for OpaqueIpcMessage {
impl Debug for IpcMessage {
fn fmt(&self, formatter: &mut Formatter) -> Result<(), fmt::Error> {
match String::from_utf8(self.data.clone()) {
Ok(string) => string.chars().take(256).collect::<String>().fmt(formatter),
Expand All @@ -699,19 +691,16 @@ impl Debug for OpaqueIpcMessage {
}
}

impl OpaqueIpcMessage {
fn new(
impl IpcMessage {
pub(crate) fn new(
data: Vec<u8>,
os_ipc_channels: Vec<OsOpaqueIpcChannel>,
os_ipc_shared_memory_regions: Vec<OsIpcSharedMemory>,
) -> OpaqueIpcMessage {
OpaqueIpcMessage {
) -> IpcMessage {
IpcMessage {
data,
os_ipc_channels,
os_ipc_shared_memory_regions: os_ipc_shared_memory_regions
.into_iter()
.map(Some)
.collect(),
os_ipc_shared_memory_regions,
}
}

Expand All @@ -727,15 +716,16 @@ impl OpaqueIpcMessage {
&mut *os_ipc_channels_for_deserialization.borrow_mut(),
&mut self.os_ipc_channels,
);
mem::swap(
let old_ipc_shared_memory_regions_for_deserialization = mem::replace(
&mut *os_ipc_shared_memory_regions_for_deserialization.borrow_mut(),
&mut self.os_ipc_shared_memory_regions,
self.os_ipc_shared_memory_regions
.into_iter()
.map(Some)
.collect(),
);
let result = bincode::deserialize(&self.data[..]);
mem::swap(
&mut *os_ipc_shared_memory_regions_for_deserialization.borrow_mut(),
&mut self.os_ipc_shared_memory_regions,
);
*os_ipc_shared_memory_regions_for_deserialization.borrow_mut() =
old_ipc_shared_memory_regions_for_deserialization;
mem::swap(
&mut *os_ipc_channels_for_deserialization.borrow_mut(),
&mut self.os_ipc_channels,
Expand Down Expand Up @@ -875,19 +865,13 @@ where
}

pub fn accept(self) -> Result<(IpcReceiver<T>, T), bincode::Error> {
let (os_receiver, data, os_channels, os_shared_memory_regions) = self.os_server.accept()?;
let value = OpaqueIpcMessage {
data,
os_ipc_channels: os_channels,
os_ipc_shared_memory_regions: os_shared_memory_regions.into_iter().map(Some).collect(),
}
.to()?;
let (os_receiver, ipc_message) = self.os_server.accept()?;
Ok((
IpcReceiver {
os_receiver,
phantom: PhantomData,
},
value,
ipc_message.to()?,
))
}
}
Expand All @@ -903,15 +887,15 @@ impl IpcBytesReceiver {
#[inline]
pub fn recv(&self) -> Result<Vec<u8>, IpcError> {
match self.os_receiver.recv() {
Ok((data, _, _)) => Ok(data),
Ok(ipc_message) => Ok(ipc_message.data),
Err(err) => Err(err.into()),
}
}

/// Non-blocking receive
pub fn try_recv(&self) -> Result<Vec<u8>, TryRecvError> {
match self.os_receiver.try_recv() {
Ok((data, _, _)) => Ok(data),
Ok(ipc_message) => Ok(ipc_message.data),
Err(err) => Err(err.into()),
}
}
Expand Down
Loading

0 comments on commit a29f223

Please sign in to comment.