diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 77f47b4..1c3c099 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -1,6 +1,6 @@ use proc_macro2::{Ident, Span}; use quote::quote; -use syn::{parse_macro_input, FnArg, ItemFn, ItemForeignMod}; +use syn::{parse_macro_input, FnArg, GenericArgument, ItemFn, ItemForeignMod, PathArguments}; /// `plugin_fn` is used to define an Extism callable function to export /// @@ -179,16 +179,36 @@ pub fn shared_fn( let (no_result, raw_output) = match output { syn::ReturnType::Default => (true, quote! {}), syn::ReturnType::Type(_, t) => { + let mut is_unit = false; if let syn::Type::Path(p) = t.as_ref() { if let Some(t) = p.path.segments.last() { if t.ident != "SharedFnResult" { panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult"); } + match &t.arguments { + PathArguments::AngleBracketed(args) => { + if args.args.len() == 1 { + match &args.args[0] { + GenericArgument::Type(syn::Type::Tuple(t)) => { + if t.elems.is_empty() { + is_unit = true; + } + } + _ => (), + } + } + } + _ => (), + } } else { panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult"); } }; - (false, quote! {-> u64 }) + if is_unit { + (true, quote! {}) + } else { + (false, quote! {-> u64 }) + } } };