Skip to content

Commit

Permalink
feat(struct): add struct as marshallable field type
Browse files Browse the repository at this point in the history
  • Loading branch information
Wodann committed Jan 29, 2020
1 parent 8596546 commit ded289b
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 118 deletions.
16 changes: 16 additions & 0 deletions crates/mun_abi/src/autogen_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,22 @@ impl StructInfo {
unsafe { slice::from_raw_parts(self.field_sizes, self.num_fields as usize) }
}
}

/// Returns the index of the field matching the specified `field_name`.
pub fn find_field_index(struct_info: &StructInfo, field_name: &str) -> Result<usize, String> {
struct_info
.field_names()
.enumerate()
.find(|(_, name)| *name == field_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
struct_info.name(),
field_name
)
})
}
}

impl fmt::Display for StructInfo {
Expand Down
54 changes: 20 additions & 34 deletions crates/mun_runtime/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,53 +105,39 @@ macro_rules! invoke_fn_impl {
if arg_types.len() != num_args {
return Err(format!(
"Invalid number of arguments. Expected: {}. Found: {}.",
num_args,
arg_types.len(),
num_args,
));
}

#[allow(unused_mut, unused_variables)]
let mut idx = 0;
$(
if arg_types[idx].guid != $Arg.type_guid() {
return Err(format!(
"Invalid argument type at index {}. Expected: {}. Found: {}.",
idx,
$Arg.type_name(),
arg_types[idx].name(),
));
}
crate::reflection::equals_argument_type(&arg_types[idx], &$Arg)
.map_err(|(expected, found)| {
format!(
"Invalid argument type at index {}. Expected: {}. Found: {}.",
idx,
expected,
found,
)
})?;
idx += 1;
)*

if let Some(return_type) = function_info.signature.return_type() {
match return_type.group {
abi::TypeGroup::FundamentalTypes => {
if return_type.guid != Output::type_guid() {
return Err(format!(
"Invalid return type. Expected: {}. Found: {}",
Output::type_name(),
return_type.name(),
));
}
}
abi::TypeGroup::StructTypes => {
if <Struct as ReturnTypeReflection>::type_guid() != Output::type_guid() {
return Err(format!(
"Invalid return type. Expected: {}. Found: Struct",
Output::type_name(),
));
}
}
}

crate::reflection::equals_return_type::<Output>(return_type)
} else if <() as ReturnTypeReflection>::type_guid() != Output::type_guid() {
return Err(format!(
Err((<() as ReturnTypeReflection>::type_name(), Output::type_name()))
} else {
Ok(())
}.map_err(|(expected, found)| {
format!(
"Invalid return type. Expected: {}. Found: {}",
Output::type_name(),
<() as ReturnTypeReflection>::type_name(),
));
}
expected,
found,
)
})?;

Ok(function_info)
}) {
Expand Down
39 changes: 35 additions & 4 deletions crates/mun_runtime/src/reflection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
use crate::marshal::MarshalInto;
use abi::Guid;
use crate::{marshal::MarshalInto, Struct};
use abi::{Guid, TypeInfo};
use md5;

/// Returns whether the specified argument type matches the `type_info`.
pub fn equals_argument_type<'e, 'f, T: ArgumentReflection>(
type_info: &'e TypeInfo,
arg: &'f T,
) -> Result<(), (&'e str, &'f str)> {
if type_info.guid != arg.type_guid() {
Err((type_info.name(), arg.type_name()))
} else {
Ok(())
}
}

/// Returns whether the specified return type matches the `type_info`.
pub fn equals_return_type<T: ReturnTypeReflection>(
type_info: &TypeInfo,
) -> Result<(), (&str, &str)> {
match type_info.group {
abi::TypeGroup::FundamentalTypes => {
if type_info.guid != T::type_guid() {
return Err((type_info.name(), T::type_name()));
}
}
abi::TypeGroup::StructTypes => {
if <Struct as ReturnTypeReflection>::type_guid() != T::type_guid() {
return Err(("struct", T::type_name()));
}
}
}
Ok(())
}

/// A type to emulate dynamic typing across compilation units for static types.
pub trait ReturnTypeReflection: Sized + 'static {
/// The resulting type after marshaling.
Expand All @@ -19,9 +50,9 @@ pub trait ReturnTypeReflection: Sized + 'static {
}

/// A type to emulate dynamic typing across compilation units for statically typed values.
pub trait ArgumentReflection {
pub trait ArgumentReflection: Sized {
/// The resulting type after dereferencing.
type Marshalled: Sized;
type Marshalled: MarshalInto<Self>;

/// Retrieves the `Guid` of the value's type.
fn type_guid(&self) -> Guid {
Expand Down
112 changes: 42 additions & 70 deletions crates/mun_runtime/src/struct.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{
marshal::MarshalInto,
reflection::{ArgumentReflection, ReturnTypeReflection},
reflection::{
equals_argument_type, equals_return_type, ArgumentReflection, ReturnTypeReflection,
},
};
use abi::{StructInfo, TypeInfo};
use std::mem;
Expand Down Expand Up @@ -39,114 +41,84 @@ impl Struct {
}

/// Retrieves the value of the field corresponding to the specified `field_name`.
pub fn get<T: ReturnTypeReflection>(&self, field_name: &str) -> Result<&T, String> {
let field_idx = self
.info
.field_names()
.enumerate()
.find(|(_, name)| *name == field_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
self.info.name(),
field_name
)
})?;

pub fn get<T: ReturnTypeReflection>(&self, field_name: &str) -> Result<T, String> {
let field_idx = StructInfo::find_field_index(&self.info, field_name)?;
let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) };
if T::type_guid() != field_type.guid {
return Err(format!(
equals_return_type::<T>(&field_type).map_err(|(expected, found)| {
format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
self.info.name(),
field_name,
field_type.name(),
T::type_name()
));
}
expected,
found,
)
})?;

unsafe {
let field_value = unsafe {
// If we found the `field_idx`, we are guaranteed to also have the `field_offset`
let offset = *self.info.field_offsets().get_unchecked(field_idx);
// self.ptr is never null
Ok(&*self.raw.0.add(offset as usize).cast::<T>())
}
// TODO: The unsafe `read` fn could be avoided by adding the `Clone` bound on
// `T::Marshalled`, but its only available on nightly:
// `ReturnTypeReflection<Marshalled: Clone>`
self.raw
.0
.add(offset as usize)
.cast::<T::Marshalled>()
.read()
};
Ok(field_value.marshal_into(Some(*field_type)))
}

/// Replaces the value of the field corresponding to the specified `field_name` and returns the
/// old value.
pub fn replace<T: ArgumentReflection>(
&mut self,
field_name: &str,
mut value: T,
value: T,
) -> Result<T, String> {
let field_idx = self
.info
.field_names()
.enumerate()
.find(|(_, name)| *name == field_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
self.info.name(),
field_name
)
})?;

let field_idx = StructInfo::find_field_index(&self.info, field_name)?;
let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) };
if value.type_guid() != field_type.guid {
return Err(format!(
equals_argument_type(&field_type, &value).map_err(|(expected, found)| {
format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
self.info.name(),
field_name,
field_type.name(),
value.type_name()
));
}
expected,
found,
)
})?;

let mut marshalled: T::Marshalled = value.marshal();
let ptr = unsafe {
// If we found the `field_idx`, we are guaranteed to also have the `field_offset`
let offset = *self.info.field_offsets().get_unchecked(field_idx);
// self.ptr is never null
&mut *self.raw.0.add(offset as usize).cast::<T>()
&mut *self.raw.0.add(offset as usize).cast::<T::Marshalled>()
};
mem::swap(&mut value, ptr);
Ok(value)
mem::swap(&mut marshalled, ptr);
Ok(marshalled.marshal_into(Some(*field_type)))
}

/// Sets the value of the field corresponding to the specified `field_name`.
pub fn set<T: ArgumentReflection>(&mut self, field_name: &str, value: T) -> Result<(), String> {
let field_idx = self
.info
.field_names()
.enumerate()
.find(|(_, name)| *name == field_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
self.info.name(),
field_name
)
})?;

let field_idx = StructInfo::find_field_index(&self.info, field_name)?;
let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) };
if value.type_guid() != field_type.guid {
return Err(format!(
equals_argument_type(&field_type, &value).map_err(|(expected, found)| {
format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
self.info.name(),
field_name,
field_type.name(),
value.type_name()
));
}
expected,
found,
)
})?;

unsafe {
// If we found the `field_idx`, we are guaranteed to also have the `field_offset`
let offset = *self.info.field_offsets().get_unchecked(field_idx);
// self.ptr is never null
*self.raw.0.add(offset as usize).cast::<T>() = value;
*self.raw.0.add(offset as usize).cast::<T::Marshalled>() = value.marshal();
}
Ok(())
}
Expand Down
51 changes: 41 additions & 10 deletions crates/mun_runtime/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,15 @@ fn marshal_struct() {
let mut driver = TestDriver::new(
r#"
struct(gc) Foo { a: int, b: bool, c: float, };
struct Bar(Foo);
fn foo_new(a: int, b: bool, c: float): Foo {
Foo { a, b, c, }
}
fn bar_new(foo: Foo): Bar {
Bar(foo)
}
fn foo_a(foo: Foo):int { foo.a }
fn foo_b(foo: Foo):bool { foo.b }
fn foo_c(foo: Foo):float { foo.c }
Expand All @@ -401,9 +406,9 @@ fn marshal_struct() {
let b = true;
let c = 1.23f64;
let mut foo: Struct = invoke_fn!(driver.runtime, "foo_new", a, b, c).unwrap();
assert_eq!(Ok(&a), foo.get::<i64>("a"));
assert_eq!(Ok(&b), foo.get::<bool>("b"));
assert_eq!(Ok(&c), foo.get::<f64>("c"));
assert_eq!(Ok(a), foo.get::<i64>("a"));
assert_eq!(Ok(b), foo.get::<bool>("b"));
assert_eq!(Ok(c), foo.get::<f64>("c"));

let d = 6i64;
let e = false;
Expand All @@ -412,19 +417,45 @@ fn marshal_struct() {
foo.set("b", e).unwrap();
foo.set("c", f).unwrap();

assert_eq!(Ok(&d), foo.get::<i64>("a"));
assert_eq!(Ok(&e), foo.get::<bool>("b"));
assert_eq!(Ok(&f), foo.get::<f64>("c"));
assert_eq!(Ok(d), foo.get::<i64>("a"));
assert_eq!(Ok(e), foo.get::<bool>("b"));
assert_eq!(Ok(f), foo.get::<f64>("c"));

assert_eq!(Ok(d), foo.replace("a", a));
assert_eq!(Ok(e), foo.replace("b", b));
assert_eq!(Ok(f), foo.replace("c", c));

assert_eq!(Ok(&a), foo.get::<i64>("a"));
assert_eq!(Ok(&b), foo.get::<bool>("b"));
assert_eq!(Ok(&c), foo.get::<f64>("c"));
assert_eq!(Ok(a), foo.get::<i64>("a"));
assert_eq!(Ok(b), foo.get::<bool>("b"));
assert_eq!(Ok(c), foo.get::<f64>("c"));

assert_invoke_eq!(i64, a, driver, "foo_a", foo.clone());
assert_invoke_eq!(bool, b, driver, "foo_b", foo.clone());
assert_invoke_eq!(f64, c, driver, "foo_c", foo);
assert_invoke_eq!(f64, c, driver, "foo_c", foo.clone());

let mut bar: Struct = invoke_fn!(driver.runtime, "bar_new", foo.clone()).unwrap();
let foo2 = bar.get::<Struct>("0").unwrap();
assert_eq!(Ok(a), foo2.get::<i64>("a"));
assert_eq!(foo2.get::<bool>("b"), foo.get::<bool>("b"));
assert_eq!(foo2.get::<f64>("c"), foo.get::<f64>("c"));

// Specify invalid return type
let bar_err = bar.get::<i64>("0");
assert!(bar_err.is_err());

// Specify invalid argument type
let bar_err = bar.replace("0", 1i64);
assert!(bar_err.is_err());

// Specify invalid argument type
let bar_err = bar.set("0", 1i64);
assert!(bar_err.is_err());

// Specify invalid return type
let bar_err: Result<i64, _> = invoke_fn!(driver.runtime, "bar_new", foo);
assert!(bar_err.is_err());

// Pass invalid struct type
let bar_err: Result<Struct, _> = invoke_fn!(driver.runtime, "bar_new", bar);
assert!(bar_err.is_err());
}

0 comments on commit ded289b

Please sign in to comment.