Skip to content

Commit

Permalink
Merge pull request #4174 from RalfJung/read-write-callback
Browse files Browse the repository at this point in the history
files: make read/write take callback to store result
  • Loading branch information
RalfJung authored Feb 2, 2025
2 parents 70fd1ee + 16d331b commit 5afb453
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 138 deletions.
101 changes: 46 additions & 55 deletions src/shims/files.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::any::Any;
use std::collections::BTreeMap;
use std::io::{IsTerminal, Read, SeekFrom, Write};
use std::io::{IsTerminal, SeekFrom, Write};
use std::marker::CoercePointee;
use std::ops::Deref;
use std::rc::{Rc, Weak};
Expand Down Expand Up @@ -140,8 +140,8 @@ pub trait FileDescription: std::fmt::Debug + FileDescriptionExt {
_communicate_allowed: bool,
_ptr: Pointer,
_len: usize,
_dest: &MPlaceTy<'tcx>,
_ecx: &mut MiriInterpCx<'tcx>,
_finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
throw_unsup_format!("cannot read from {}", self.name());
}
Expand All @@ -154,8 +154,8 @@ pub trait FileDescription: std::fmt::Debug + FileDescriptionExt {
_communicate_allowed: bool,
_ptr: Pointer,
_len: usize,
_dest: &MPlaceTy<'tcx>,
_ecx: &mut MiriInterpCx<'tcx>,
_finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
throw_unsup_format!("cannot write to {}", self.name());
}
Expand Down Expand Up @@ -207,19 +207,16 @@ impl FileDescription for io::Stdin {
communicate_allowed: bool,
ptr: Pointer,
len: usize,
dest: &MPlaceTy<'tcx>,
ecx: &mut MiriInterpCx<'tcx>,
finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
let mut bytes = vec![0; len];
if !communicate_allowed {
// We want isolation mode to be deterministic, so we have to disallow all reads, even stdin.
helpers::isolation_abort_error("`read` from stdin")?;
}
let result = Read::read(&mut &*self, &mut bytes);
match result {
Ok(read_size) => ecx.return_read_success(ptr, &bytes, read_size, dest),
Err(e) => ecx.set_last_error_and_return(e, dest),
}

let result = ecx.read_from_host(&*self, len, ptr)?;
finish.call(ecx, result)
}

fn is_tty(&self, communicate_allowed: bool) -> bool {
Expand All @@ -237,22 +234,19 @@ impl FileDescription for io::Stdout {
_communicate_allowed: bool,
ptr: Pointer,
len: usize,
dest: &MPlaceTy<'tcx>,
ecx: &mut MiriInterpCx<'tcx>,
finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
let bytes = ecx.read_bytes_ptr_strip_provenance(ptr, Size::from_bytes(len))?;
// We allow writing to stderr even with isolation enabled.
let result = Write::write(&mut &*self, bytes);
// We allow writing to stdout even with isolation enabled.
let result = ecx.write_to_host(&*self, len, ptr)?;
// Stdout is buffered, flush to make sure it appears on the
// screen. This is the write() syscall of the interpreted
// program, we want it to correspond to a write() syscall on
// the host -- there is no good in adding extra buffering
// here.
io::stdout().flush().unwrap();
match result {
Ok(write_size) => ecx.return_write_success(write_size, dest),
Err(e) => ecx.set_last_error_and_return(e, dest),
}

finish.call(ecx, result)
}

fn is_tty(&self, communicate_allowed: bool) -> bool {
Expand All @@ -270,17 +264,13 @@ impl FileDescription for io::Stderr {
_communicate_allowed: bool,
ptr: Pointer,
len: usize,
dest: &MPlaceTy<'tcx>,
ecx: &mut MiriInterpCx<'tcx>,
finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
let bytes = ecx.read_bytes_ptr_strip_provenance(ptr, Size::from_bytes(len))?;
// We allow writing to stderr even with isolation enabled.
let result = ecx.write_to_host(&*self, len, ptr)?;
// No need to flush, stderr is not buffered.
let result = Write::write(&mut &*self, bytes);
match result {
Ok(write_size) => ecx.return_write_success(write_size, dest),
Err(e) => ecx.set_last_error_and_return(e, dest),
}
finish.call(ecx, result)
}

fn is_tty(&self, communicate_allowed: bool) -> bool {
Expand All @@ -302,11 +292,11 @@ impl FileDescription for NullOutput {
_communicate_allowed: bool,
_ptr: Pointer,
len: usize,
dest: &MPlaceTy<'tcx>,
ecx: &mut MiriInterpCx<'tcx>,
finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
// We just don't write anything, but report to the user that we did.
ecx.return_write_success(len, dest)
finish.call(ecx, Ok(len))
}
}

Expand Down Expand Up @@ -405,40 +395,41 @@ impl FdTable {

impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
/// Helper to implement `FileDescription::read`:
/// This is only used when `read` is successful.
/// `actual_read_size` should be the return value of some underlying `read` call that used
/// `bytes` as its output buffer.
/// The length of `bytes` must not exceed either the host's or the target's `isize`.
/// `bytes` is written to `buf` and the size is written to `dest`.
fn return_read_success(
/// Read data from a host `Read` type, store the result into machine memory,
/// and return whether that worked.
fn read_from_host(
&mut self,
buf: Pointer,
bytes: &[u8],
actual_read_size: usize,
dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx> {
mut file: impl io::Read,
len: usize,
ptr: Pointer,
) -> InterpResult<'tcx, Result<usize, IoError>> {
let this = self.eval_context_mut();
// If reading to `bytes` did not fail, we write those bytes to the buffer.
// Crucially, if fewer than `bytes.len()` bytes were read, only write
// that much into the output buffer!
this.write_bytes_ptr(buf, bytes[..actual_read_size].iter().copied())?;

// The actual read size is always less than what got originally requested so this cannot fail.
this.write_int(u64::try_from(actual_read_size).unwrap(), dest)?;
interp_ok(())
let mut bytes = vec![0; len];
let result = file.read(&mut bytes);
match result {
Ok(read_size) => {
// If reading to `bytes` did not fail, we write those bytes to the buffer.
// Crucially, if fewer than `bytes.len()` bytes were read, only write
// that much into the output buffer!
this.write_bytes_ptr(ptr, bytes[..read_size].iter().copied())?;
interp_ok(Ok(read_size))
}
Err(e) => interp_ok(Err(IoError::HostError(e))),
}
}

/// Helper to implement `FileDescription::write`:
/// This function is only used when `write` is successful, and writes `actual_write_size` to `dest`
fn return_write_success(
/// Write data to a host `Write` type, withthe bytes taken from machine memory.
fn write_to_host(
&mut self,
actual_write_size: usize,
dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx> {
mut file: impl io::Write,
len: usize,
ptr: Pointer,
) -> InterpResult<'tcx, Result<usize, IoError>> {
let this = self.eval_context_mut();
// The actual write size is always less than what got originally requested so this cannot fail.
this.write_int(u64::try_from(actual_write_size).unwrap(), dest)?;
interp_ok(())

let bytes = this.read_bytes_ptr_strip_provenance(ptr, Size::from_bytes(len))?;
let result = file.write(bytes);
interp_ok(result.map_err(IoError::HostError))
}
}
54 changes: 47 additions & 7 deletions src/shims/unix/fd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ pub trait UnixFileDescription: FileDescription {
_offset: u64,
_ptr: Pointer,
_len: usize,
_dest: &MPlaceTy<'tcx>,
_ecx: &mut MiriInterpCx<'tcx>,
_finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
throw_unsup_format!("cannot pread from {}", self.name());
}
Expand All @@ -46,8 +46,8 @@ pub trait UnixFileDescription: FileDescription {
_ptr: Pointer,
_len: usize,
_offset: u64,
_dest: &MPlaceTy<'tcx>,
_ecx: &mut MiriInterpCx<'tcx>,
_finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
throw_unsup_format!("cannot pwrite to {}", self.name());
}
Expand Down Expand Up @@ -236,7 +236,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let count = usize::try_from(count).unwrap(); // now it fits in a `usize`
let communicate = this.machine.communicate();

// We temporarily dup the FD to be able to retain mutable access to `this`.
// Get the FD.
let Some(fd) = this.machine.fds.get(fd_num) else {
trace!("read: FD not found");
return this.set_last_error_and_return(LibcError("EBADF"), dest);
Expand All @@ -247,13 +247,33 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// because it was a target's `usize`. Also we are sure that its smaller than
// `usize::MAX` because it is bounded by the host's `isize`.

let finish = {
let dest = dest.clone();
callback!(
@capture<'tcx> {
count: usize,
dest: MPlaceTy<'tcx>,
}
|this, result: Result<usize, IoError>| {
match result {
Ok(read_size) => {
assert!(read_size <= count);
// This must fit since `count` fits.
this.write_int(u64::try_from(read_size).unwrap(), &dest)
}
Err(e) => {
this.set_last_error_and_return(e, &dest)
}
}}
)
};
match offset {
None => fd.read(communicate, buf, count, dest, this)?,
None => fd.read(communicate, buf, count, this, finish)?,
Some(offset) => {
let Ok(offset) = u64::try_from(offset) else {
return this.set_last_error_and_return(LibcError("EINVAL"), dest);
};
fd.as_unix().pread(communicate, offset, buf, count, dest, this)?
fd.as_unix().pread(communicate, offset, buf, count, this, finish)?
}
};
interp_ok(())
Expand Down Expand Up @@ -287,13 +307,33 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
return this.set_last_error_and_return(LibcError("EBADF"), dest);
};

let finish = {
let dest = dest.clone();
callback!(
@capture<'tcx> {
count: usize,
dest: MPlaceTy<'tcx>,
}
|this, result: Result<usize, IoError>| {
match result {
Ok(write_size) => {
assert!(write_size <= count);
// This must fit since `count` fits.
this.write_int(u64::try_from(write_size).unwrap(), &dest)
}
Err(e) => {
this.set_last_error_and_return(e, &dest)
}
}}
)
};
match offset {
None => fd.write(communicate, buf, count, dest, this)?,
None => fd.write(communicate, buf, count, this, finish)?,
Some(offset) => {
let Ok(offset) = u64::try_from(offset) else {
return this.set_last_error_and_return(LibcError("EINVAL"), dest);
};
fd.as_unix().pwrite(communicate, buf, count, offset, dest, this)?
fd.as_unix().pwrite(communicate, buf, count, offset, this, finish)?
}
};
interp_ok(())
Expand Down
47 changes: 22 additions & 25 deletions src/shims/unix/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,27 @@ impl FileDescription for FileHandle {
communicate_allowed: bool,
ptr: Pointer,
len: usize,
dest: &MPlaceTy<'tcx>,
ecx: &mut MiriInterpCx<'tcx>,
finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
assert!(communicate_allowed, "isolation should have prevented even opening a file");
let mut bytes = vec![0; len];
let result = (&mut &self.file).read(&mut bytes);
match result {
Ok(read_size) => ecx.return_read_success(ptr, &bytes, read_size, dest),
Err(e) => ecx.set_last_error_and_return(e, dest),
}

let result = ecx.read_from_host(&self.file, len, ptr)?;
finish.call(ecx, result)
}

fn write<'tcx>(
self: FileDescriptionRef<Self>,
communicate_allowed: bool,
ptr: Pointer,
len: usize,
dest: &MPlaceTy<'tcx>,
ecx: &mut MiriInterpCx<'tcx>,
finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
assert!(communicate_allowed, "isolation should have prevented even opening a file");
let bytes = ecx.read_bytes_ptr_strip_provenance(ptr, Size::from_bytes(len))?;
let result = (&mut &self.file).write(bytes);
match result {
Ok(write_size) => ecx.return_write_success(write_size, dest),
Err(e) => ecx.set_last_error_and_return(e, dest),
}

let result = ecx.write_to_host(&self.file, len, ptr)?;
finish.call(ecx, result)
}

fn seek<'tcx>(
Expand Down Expand Up @@ -119,8 +113,8 @@ impl UnixFileDescription for FileHandle {
offset: u64,
ptr: Pointer,
len: usize,
dest: &MPlaceTy<'tcx>,
ecx: &mut MiriInterpCx<'tcx>,
finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
assert!(communicate_allowed, "isolation should have prevented even opening a file");
let mut bytes = vec![0; len];
Expand All @@ -137,11 +131,17 @@ impl UnixFileDescription for FileHandle {
.expect("failed to restore file position, this shouldn't be possible");
res
};
let result = f();
match result {
Ok(read_size) => ecx.return_read_success(ptr, &bytes, read_size, dest),
Err(e) => ecx.set_last_error_and_return(e, dest),
}
let result = match f() {
Ok(read_size) => {
// If reading to `bytes` did not fail, we write those bytes to the buffer.
// Crucially, if fewer than `bytes.len()` bytes were read, only write
// that much into the output buffer!
ecx.write_bytes_ptr(ptr, bytes[..read_size].iter().copied())?;
Ok(read_size)
}
Err(e) => Err(IoError::HostError(e)),
};
finish.call(ecx, result)
}

fn pwrite<'tcx>(
Expand All @@ -150,8 +150,8 @@ impl UnixFileDescription for FileHandle {
ptr: Pointer,
len: usize,
offset: u64,
dest: &MPlaceTy<'tcx>,
ecx: &mut MiriInterpCx<'tcx>,
finish: DynMachineCallback<'tcx, Result<usize, IoError>>,
) -> InterpResult<'tcx> {
assert!(communicate_allowed, "isolation should have prevented even opening a file");
// Emulates pwrite using seek + write + seek to restore cursor position.
Expand All @@ -169,10 +169,7 @@ impl UnixFileDescription for FileHandle {
res
};
let result = f();
match result {
Ok(write_size) => ecx.return_write_success(write_size, dest),
Err(e) => ecx.set_last_error_and_return(e, dest),
}
finish.call(ecx, result.map_err(IoError::HostError))
}

fn flock<'tcx>(
Expand Down
Loading

0 comments on commit 5afb453

Please sign in to comment.