Skip to content

Commit

Permalink
Convert Function::call to take GuardedArgs (#715)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstoza authored Oct 11, 2024
1 parent 9c525d7 commit 3726f5f
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 13 deletions.
9 changes: 9 additions & 0 deletions crates/rune-macros/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ where
raw_str,
ref_,
static_type_mod,
to_value,
type_info,
type_of,
unsafe_to_mut,
Expand Down Expand Up @@ -695,6 +696,10 @@ where
let (shared, guard) = #vm_try!(#value::from_ref(self));
#vm_result::Ok((shared, guard))
}

fn try_into_to_value(self) -> Option<impl #to_value> {
Option::<&str>::None
}
}

#[automatically_derived]
Expand All @@ -705,6 +710,10 @@ where
let (shared, guard) = #vm_try!(#value::from_mut(self));
#vm_result::Ok((shared, guard))
}

fn try_into_to_value(self) -> Option<impl #to_value> {
Option::<&str>::None
}
}
})
} else {
Expand Down
30 changes: 18 additions & 12 deletions crates/rune/src/runtime/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ use crate::alloc::{self, Box, Vec};
use crate::function;
use crate::runtime::vm::Isolated;
use crate::runtime::{
Args, Call, ConstValue, FromValue, FunctionHandler, InstAddress, Mutable, Output, OwnedTuple,
Rtti, RuntimeContext, Stack, Unit, Value, ValueRef, VariantRtti, Vm, VmCall, VmErrorKind,
VmHalt, VmResult,
Args, Call, ConstValue, FromValue, FunctionHandler, GuardedArgs, InstAddress, Mutable, Output,
OwnedTuple, Rtti, RuntimeContext, Stack, Unit, Value, ValueRef, VariantRtti, Vm, VmCall,
VmErrorKind, VmHalt, VmResult,
};
use crate::shared::AssertSend;
use crate::Any;
Expand Down Expand Up @@ -146,7 +146,7 @@ impl Function {
/// [Send].
pub async fn async_send_call<A, T>(&self, args: A) -> VmResult<T>
where
A: Send + Args,
A: Send + GuardedArgs,
T: Send + FromValue,
{
self.0.async_send_call(args).await
Expand Down Expand Up @@ -179,7 +179,7 @@ impl Function {
/// assert_eq!(value.call::<u32>((1, 2)).into_result()?, 3);
/// # Ok::<_, rune::support::Error>(())
/// ```
pub fn call<T>(&self, args: impl Args) -> VmResult<T>
pub fn call<T>(&self, args: impl GuardedArgs) -> VmResult<T>
where
T: FromValue,
{
Expand Down Expand Up @@ -403,7 +403,7 @@ impl SyncFunction {
/// # })?;
/// # Ok::<_, rune::support::Error>(())
/// ```
pub async fn async_send_call<T>(&self, args: impl Args + Send) -> VmResult<T>
pub async fn async_send_call<T>(&self, args: impl GuardedArgs + Send) -> VmResult<T>
where
T: Send + FromValue,
{
Expand Down Expand Up @@ -437,7 +437,7 @@ impl SyncFunction {
/// assert_eq!(add.call::<u32>((1, 2)).into_result()?, 3);
/// # Ok::<_, rune::support::Error>(())
/// ```
pub fn call<T>(&self, args: impl Args) -> VmResult<T>
pub fn call<T>(&self, args: impl GuardedArgs) -> VmResult<T>
where
T: FromValue,
{
Expand Down Expand Up @@ -506,7 +506,7 @@ where
OwnedTuple: TryFrom<Box<[V]>>,
VmErrorKind: From<<OwnedTuple as TryFrom<Box<[V]>>>::Error>,
{
fn call<T>(&self, args: impl Args) -> VmResult<T>
fn call<T>(&self, args: impl GuardedArgs) -> VmResult<T>
where
T: FromValue,
{
Expand All @@ -516,7 +516,7 @@ where
let size = count.max(1);
// Ensure we have space for the return value.
let mut stack = vm_try!(Stack::with_capacity(size));
vm_try!(args.into_stack(&mut stack));
let _guard = vm_try!(unsafe { args.unsafe_into_stack(&mut stack) });
vm_try!(stack.resize(size));
vm_try!((handler.handler)(
&mut stack,
Expand All @@ -540,6 +540,9 @@ where
}
Inner::FnTupleStruct(tuple) => {
vm_try!(check_args(args.count(), tuple.args));
let Some(args) = args.try_into_args() else {
return VmResult::err(VmErrorKind::InvalidTupleCall);
};
vm_try!(Value::tuple_struct(
tuple.rtti.clone(),
vm_try!(args.try_into_vec())
Expand All @@ -551,6 +554,9 @@ where
}
Inner::FnTupleVariant(tuple) => {
vm_try!(check_args(args.count(), tuple.args));
let Some(args) = args.try_into_args() else {
return VmResult::err(VmErrorKind::InvalidTupleCall);
};
vm_try!(Value::tuple_variant(
tuple.rtti.clone(),
vm_try!(args.try_into_vec())
Expand All @@ -563,7 +569,7 @@ where

fn async_send_call<'a, A, T>(&'a self, args: A) -> impl Future<Output = VmResult<T>> + Send + 'a
where
A: 'a + Send + Args,
A: 'a + Send + GuardedArgs,
T: 'a + Send + FromValue,
{
let future = async move {
Expand Down Expand Up @@ -902,7 +908,7 @@ struct FnOffset {
impl FnOffset {
/// Perform a call into the specified offset and return the produced value.
#[tracing::instrument(skip_all, fields(args = args.count(), extra = extra.count(), ?self.offset, ?self.call, ?self.args, ?self.hash))]
fn call(&self, args: impl Args, extra: impl Args) -> VmResult<Value> {
fn call(&self, args: impl GuardedArgs, extra: impl Args) -> VmResult<Value> {
vm_try!(check_args(
args.count().wrapping_add(extra.count()),
self.args
Expand All @@ -911,7 +917,7 @@ impl FnOffset {
let mut vm = Vm::new(self.context.clone(), self.unit.clone());

vm.set_ip(self.offset);
vm_try!(args.into_stack(vm.stack_mut()));
let _guard = vm_try!(unsafe { args.unsafe_into_stack(vm.stack_mut()) });
vm_try!(extra.into_stack(vm.stack_mut()));

self.call.call_with_vm(vm)
Expand Down
50 changes: 49 additions & 1 deletion crates/rune/src/runtime/guarded_args.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::runtime::{Stack, UnsafeToValue, VmResult};
use crate::alloc::Vec;
use crate::runtime::Args;
use crate::runtime::{Stack, UnsafeToValue, Value, VmResult};

/// Trait for converting arguments onto the stack.
///
Expand All @@ -18,6 +20,10 @@ pub trait GuardedArgs {
/// invalidated.
unsafe fn unsafe_into_stack(self, stack: &mut Stack) -> VmResult<Self::Guard>;

/// Attempts to convert this type into Args, which will only succeed as long
/// as it doesn't contain any references to Any types.
fn try_into_args(self) -> Option<impl Args>;

/// The number of arguments.
fn count(&self) -> usize;
}
Expand All @@ -38,6 +44,11 @@ macro_rules! impl_into_args {
VmResult::Ok(($($value.1,)*))
}

fn try_into_args(self) -> Option<impl Args> {
let ($($value,)*) = self;
Some(($($value.try_into_to_value()?,)*))
}

fn count(&self) -> usize {
$count
}
Expand All @@ -46,3 +57,40 @@ macro_rules! impl_into_args {
}

repeat_macro!(impl_into_args);

impl GuardedArgs for Vec<Value> {
type Guard = ();

#[inline]
unsafe fn unsafe_into_stack(self, stack: &mut Stack) -> VmResult<Self::Guard> {
self.into_stack(stack)
}

fn try_into_args(self) -> Option<impl Args> {
Some(self)
}

#[inline]
fn count(&self) -> usize {
(self as &dyn Args).count()
}
}

#[cfg(feature = "alloc")]
impl GuardedArgs for ::rust_alloc::vec::Vec<Value> {
type Guard = ();

#[inline]
unsafe fn unsafe_into_stack(self, stack: &mut Stack) -> VmResult<Self::Guard> {
self.into_stack(stack)
}

fn try_into_args(self) -> Option<impl Args> {
Some(self)
}

#[inline]
fn count(&self) -> usize {
(self as &dyn Args).count()
}
}
8 changes: 8 additions & 0 deletions crates/rune/src/runtime/to_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ pub trait UnsafeToValue: Sized {
/// The value returned must not be used after the guard associated with it
/// has been dropped.
unsafe fn unsafe_to_value(self) -> VmResult<(Value, Self::Guard)>;

/// Attempts to convert this UnsafeToValue into a ToValue, which is only
/// possible if it is not a reference to an Any type.
fn try_into_to_value(self) -> Option<impl ToValue>;
}

impl<T> ToValue for T
Expand All @@ -142,6 +146,10 @@ where
unsafe fn unsafe_to_value(self) -> VmResult<(Value, Self::Guard)> {
VmResult::Ok((vm_try!(self.to_value()), ()))
}

fn try_into_to_value(self) -> Option<impl ToValue> {
Some(self)
}
}

impl ToValue for &Value {
Expand Down
7 changes: 7 additions & 0 deletions crates/rune/src/runtime/vm_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ pub(crate) enum VmErrorKind {
},
MissingCallFrame,
IllegalFormat,
InvalidTupleCall,
}

impl fmt::Display for VmErrorKind {
Expand Down Expand Up @@ -949,6 +950,12 @@ impl fmt::Display for VmErrorKind {
VmErrorKind::IllegalFormat => {
write!(f, "Value cannot be formatted")
}
VmErrorKind::InvalidTupleCall => {
write!(
f,
"Tuple struct/variant constructors cannot be called with references"
)
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/rune/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ mod external_constructor;
mod external_generic;
mod external_match;
mod external_ops;
mod function_guardedargs;
mod getter_setter;
mod iterator;
mod macros;
Expand Down
98 changes: 98 additions & 0 deletions crates/rune/src/tests/function_guardedargs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
prelude!();

#[derive(Default)]
struct MyAny {}

crate::__internal_impl_any!(self, MyAny);

fn get_vm() -> crate::support::Result<crate::Vm> {
use std::sync::Arc;

let mut sources = crate::sources! {
entry => {
enum Enum {
Variant(internal)
}

struct Struct(internal);

pub fn function(argument) {}
}
};

let context = crate::Context::with_default_modules()?;
let unit = crate::prepare(&mut sources).build()?;
Ok(crate::Vm::new(Arc::new(context.runtime()?), Arc::new(unit)))
}

#[test]
fn references_allowed_for_function_calls() {
let vm = get_vm().unwrap();
let function = vm.lookup_function(["function"]).unwrap();

let value_result = function.call::<crate::Value>((crate::Value::unit(),));
assert!(value_result.is_ok());

let mut mine = MyAny::default();

let ref_result = function.call::<crate::Value>((&mine,));
assert!(ref_result.is_ok());

let mut_result = function.call::<crate::Value>((&mut mine,));
assert!(mut_result.is_ok());

let any_result = function.call::<crate::Value>((mine,));
assert!(any_result.is_ok());
}

#[test]
fn references_disallowed_for_tuple_variant() {
use crate::runtime::{VmErrorKind, VmResult};

let vm = get_vm().unwrap();
let constructor = vm.lookup_function(["Enum", "Variant"]).unwrap();

let value_result = constructor.call::<crate::Value>((crate::Value::unit(),));
assert!(value_result.is_ok());

let mut mine = MyAny::default();

let VmResult::Err(ref_error) = constructor.call::<crate::Value>((&mine,)) else {
panic!("expected ref call to return an error")
};
assert_eq!(ref_error.into_kind(), VmErrorKind::InvalidTupleCall);

let VmResult::Err(mut_error) = constructor.call::<crate::Value>((&mut mine,)) else {
panic!("expected mut call to return an error")
};
assert_eq!(mut_error.into_kind(), VmErrorKind::InvalidTupleCall);

let any_result = constructor.call::<crate::Value>((mine,));
assert!(any_result.is_ok());
}

#[test]
fn references_disallowed_for_tuple_struct() {
use crate::runtime::{VmErrorKind, VmResult};

let vm = get_vm().unwrap();
let constructor = vm.lookup_function(["Struct"]).unwrap();

let value_result = constructor.call::<crate::Value>((crate::Value::unit(),));
assert!(value_result.is_ok());

let mut mine = MyAny::default();

let VmResult::Err(ref_error) = constructor.call::<crate::Value>((&mine,)) else {
panic!("expected ref call to return an error")
};
assert_eq!(ref_error.into_kind(), VmErrorKind::InvalidTupleCall);

let VmResult::Err(mut_error) = constructor.call::<crate::Value>((&mut mine,)) else {
panic!("expected mut call to return an error")
};
assert_eq!(mut_error.into_kind(), VmErrorKind::InvalidTupleCall);

let any_result = constructor.call::<crate::Value>((mine,));
assert!(any_result.is_ok());
}

0 comments on commit 3726f5f

Please sign in to comment.