Skip to content

Commit

Permalink
feat: Add shared_fn macro (#55)
Browse files Browse the repository at this point in the history
* feat: add export_fn

* feat: add Pointer type to help with export_fn

* test: add nothing test

* cleanup: rename Pointer -> MemoryPointer

* cleanup: derives

* cleanup: rename export_fn -> shared_fn

* doc: add example to plugin_fn and shared_fn
  • Loading branch information
zshipko authored May 21, 2024
1 parent 91485fd commit b0e8c3b
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 7 deletions.
File renamed without changes.
136 changes: 132 additions & 4 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
use proc_macro2::{Ident, Span};
use quote::quote;
use syn::{parse_macro_input, ItemFn, ItemForeignMod};
use syn::{parse_macro_input, FnArg, ItemFn, ItemForeignMod};

/// `plugin_fn` is used to define a function that will be exported by a plugin
/// `plugin_fn` is used to define an Extism callable function to export
///
/// It should be added to a function you would like to export, the function should
/// accept a parameter that implements `extism_pdk::FromBytes` and return a
/// `extism_pdk::FnResult` that contains a value that implements
/// `extism_pdk::ToBytes`.
/// `extism_pdk::ToBytes`. This maps input and output parameters to Extism input
/// and output instead of using function arguments directly.
///
/// ## Example
///
/// ```rust
/// use extism_pdk::{FnResult, plugin_fn};
/// #[plugin_fn]
/// pub fn greet(name: String) -> FnResult<String> {
/// let s = format!("Hello, {name}");
/// Ok(s)
/// }
/// ```
#[proc_macro_attribute]
pub fn plugin_fn(
_attr: proc_macro::TokenStream,
Expand Down Expand Up @@ -103,7 +116,122 @@ pub fn plugin_fn(
}
}

/// `host_fn` is used to define a host function that will be callable from within a plugin
/// `shared_fn` is used to define a function that will be exported by a plugin but is not directly
/// callable by an Extism runtime. These functions can be used for runtime linking and mocking host
/// functions for tests. If direct access to Wasm native parameters is needed, then a bare
/// `extern "C" fn` should be used instead.
///
/// All arguments should implement `extism_pdk::ToBytes` and the return value should implement
/// `extism_pdk::FromBytes`
/// ## Example
///
/// ```rust
/// use extism_pdk::{FnResult, shared_fn};
/// #[shared_fn]
/// pub fn greet2(greeting: String, name: String) -> FnResult<String> {
/// let s = format!("{greeting}, {name}");
/// Ok(name)
/// }
/// ```
#[proc_macro_attribute]
pub fn shared_fn(
_attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let mut function = parse_macro_input!(item as ItemFn);

if !matches!(function.vis, syn::Visibility::Public(..)) {
panic!("extism_pdk::shared_fn expects a public function");
}

let name = &function.sig.ident;
let constness = &function.sig.constness;
let unsafety = &function.sig.unsafety;
let generics = &function.sig.generics;
let inputs = &mut function.sig.inputs;
let output = &mut function.sig.output;
let block = &function.block;

let (raw_inputs, raw_args): (Vec<_>, Vec<_>) = inputs
.iter()
.enumerate()
.map(|(i, x)| {
let t = match x {
FnArg::Receiver(_) => {
panic!("Receiver argument (self) cannot be used in extism_pdk::shared_fn")
}
FnArg::Typed(t) => &t.ty,
};
let arg = Ident::new(&format!("arg{i}"), Span::call_site());
(
quote! { #arg: extism_pdk::MemoryPointer<#t> },
quote! { #arg.get()? },
)
})
.unzip();

if name == "main" {
panic!(
"export_pdk::shared_fn must not be applied to a `main` function. To fix, rename this to something other than `main`."
)
}

let (no_result, raw_output) = match output {
syn::ReturnType::Default => (true, quote! {}),
syn::ReturnType::Type(_, t) => {
if let syn::Type::Path(p) = t.as_ref() {
if let Some(t) = p.path.segments.last() {
if t.ident != "SharedFnResult" {
panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult");
}
} else {
panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult");
}
};
(false, quote! {-> u64 })
}
};

if no_result {
quote! {
#[no_mangle]
pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) {
#constness #unsafety fn inner #generics(#inputs) -> extism_pdk::SharedFnResult<()> {
#block
}


let r = || inner(#(#raw_args,)*);
if let Err(rc) = r() {
panic!("{}", rc.to_string());
}
}
}
.into()
} else {
quote! {
#[no_mangle]
pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) #raw_output {
#constness #unsafety fn inner #generics(#inputs) #output {
#block
}

let r = || inner(#(#raw_args,)*);
match r().and_then(|x| extism_pdk::Memory::new(&x)) {
Ok(mem) => {
mem.offset()
},
Err(rc) => {
panic!("{}", rc.to_string());
}
}
}
}
.into()
}
}

/// `host_fn` is used to import a host function from an `extern` block
#[proc_macro_attribute]
pub fn host_fn(
attr: proc_macro::TokenStream,
Expand Down
13 changes: 13 additions & 0 deletions examples/reflect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#![no_main]

use extism_pdk::*;

#[shared_fn]
pub fn host_reflect(input: String) -> SharedFnResult<Vec<u8>> {
Ok(input.to_lowercase().into_bytes())
}

#[shared_fn]
pub fn nothing() -> SharedFnResult<()> {
Ok(())
}
2 changes: 1 addition & 1 deletion src/extism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ extern "C" {
}

/// Loads a byte array from Extism's memory. Only use this if you
/// have already considered the plugin_fn macro as well as the [extism_load_input] function.
/// have already considered the plugin_fn macro as well as the `extism_load_input` function.
///
/// # Arguments
///
Expand Down
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ pub mod http;
pub use anyhow::Error;
pub use extism_convert::*;
pub use extism_convert::{FromBytes, FromBytesOwned, ToBytes};
pub use extism_pdk_derive::{host_fn, plugin_fn};
pub use memory::Memory;
pub use extism_pdk_derive::{host_fn, plugin_fn, shared_fn};
pub use memory::{Memory, MemoryPointer};
pub use to_memory::ToMemory;

#[cfg(feature = "http")]
Expand All @@ -37,6 +37,9 @@ pub use http::HttpResponse;
/// The return type of a plugin function
pub type FnResult<T> = Result<T, WithReturnCode<Error>>;

/// The return type of a `shared_fn`
pub type SharedFnResult<T> = Result<T, Error>;

/// Logging levels
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
Expand Down
20 changes: 20 additions & 0 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,23 @@ impl From<i64> for Memory {
Memory::find(offset as u64).unwrap_or_else(Memory::null)
}
}

#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct MemoryPointer<T>(u64, std::marker::PhantomData<T>);

impl<T> MemoryPointer<T> {
pub unsafe fn new(x: u64) -> Self {
MemoryPointer(x, Default::default())
}
}

impl<T: FromBytesOwned> MemoryPointer<T> {
pub fn get(&self) -> Result<T, Error> {
let mem = Memory::find(self.0);
match mem {
Some(mem) => T::from_bytes_owned(&mem.to_vec()),
None => anyhow::bail!("Invalid pointer offset {}", self.0),
}
}
}

0 comments on commit b0e8c3b

Please sign in to comment.