Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add macro to add the contract state version as a custom section #2124

Merged
merged 12 commits into from
Apr 25, 2024
169 changes: 158 additions & 11 deletions packages/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use syn::{
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
Token,
ItemFn, Token,
};

macro_rules! maybe {
Expand Down Expand Up @@ -93,6 +93,24 @@ impl Parse for Options {
///
/// where `InstantiateMsg`, `ExecuteMsg`, and `QueryMsg` are contract defined
/// types that implement `DeserializeOwned + JsonSchema`.
///
/// ## Set the version of the state of your contract
///
/// The VM will use this as a hint whether it needs to run the migrate function of your contract or not.
///
/// ```
/// # use cosmwasm_std::{
/// # DepsMut, entry_point, Env,
/// # Response, StdResult,
/// # };
/// #
/// # type MigrateMsg = ();
/// #[entry_point]
/// #[state_version(2)]
/// pub fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> StdResult<Response> {
/// todo!();
/// }
/// ```
#[proc_macro_attribute]
pub fn entry_point(
attr: proc_macro::TokenStream,
Expand All @@ -101,32 +119,76 @@ pub fn entry_point(
entry_point_impl(attr.into(), item.into()).into()
}

fn entry_point_impl(attr: TokenStream, mut item: TokenStream) -> TokenStream {
let cloned = item.clone();
let function: syn::ItemFn = maybe!(syn::parse2(cloned));
fn expand_attributes(func: &mut ItemFn) -> syn::Result<TokenStream> {
let attributes = std::mem::take(&mut func.attrs);
let mut stream = TokenStream::new();
for attribute in attributes {
if !attribute.path().is_ident("state_version") {
func.attrs.push(attribute);
continue;
}

if func.sig.ident != "migrate" {
return Err(syn::Error::new_spanned(
&attribute,
"you only want to add this attribute to your migrate function",
));
aumetra marked this conversation as resolved.
Show resolved Hide resolved
}

let version: syn::LitInt = attribute.parse_args()?;
// Enforce that the version is a valid u64 and non-zero
if version.base10_parse::<u64>()? == 0 {
return Err(syn::Error::new_spanned(
version,
"please start versioning with 1",
));
}

let version = version.base10_digits();
aumetra marked this conversation as resolved.
Show resolved Hide resolved

stream = quote! {
#stream

#[allow(unused)]
#[doc(hidden)]
#[link_section = "cw_state_version"]
/// This is an internal constant exported as a custom section denoting the contract state version.
/// The format and even the existence of this value is an implementation detail, DO NOT RELY ON THIS!
static __CW_STATE_VERSION: &str = #version;
};
}

Ok(stream)
}

fn entry_point_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut function: syn::ItemFn = maybe!(syn::parse2(item));
let Options { crate_path } = maybe!(syn::parse2(attr));

let attribute_code = maybe!(expand_attributes(&mut function));

// The first argument is `deps`, the rest is region pointers
let args = function.sig.inputs.len() - 1;
let fn_name = function.sig.ident;
let args = function.sig.inputs.len().saturating_sub(1);
let fn_name = &function.sig.ident;
let wasm_export = format_ident!("__wasm_export_{fn_name}");
let do_call = format_ident!("do_{fn_name}");

let decl_args = (0..args).map(|item| format_ident!("ptr_{item}"));
let call_args = decl_args.clone();

let new_code = quote! {
quote! {
#attribute_code

#function

#[cfg(target_arch = "wasm32")]
mod #wasm_export { // new module to avoid conflict of function name
#[no_mangle]
extern "C" fn #fn_name(#( #decl_args : u32 ),*) -> u32 {
#crate_path::#do_call(&super::#fn_name, #( #call_args ),*)
}
}
};

item.extend(new_code);
item
}
}

#[cfg(test)]
Expand All @@ -136,6 +198,91 @@ mod test {

use crate::entry_point_impl;

#[test]
fn contract_state_zero_not_allowed() {
let code = quote! {
#[state_version(0)]
fn migrate() -> Response {
// Logic here
}
};

let actual = entry_point_impl(TokenStream::new(), code);
let expected = quote! {
::core::compile_error! { "please start versioning with 1" }
};

assert_eq!(actual.to_string(), expected.to_string());
}

#[test]
fn contract_state_version_on_non_migrate() {
let code = quote! {
#[state_version(42)]
fn anything_else() -> Response {
// Logic here
}
};

let actual = entry_point_impl(TokenStream::new(), code);
let expected = quote! {
::core::compile_error! { "you only want to add this attribute to your migrate function" }
};

assert_eq!(actual.to_string(), expected.to_string());
}

#[test]
fn contract_state_version_in_u64() {
let code = quote! {
#[state_version(0xDEAD_BEEF_FFFF_DEAD_2BAD)]
fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
// Logic here
}
};

let actual = entry_point_impl(TokenStream::new(), code);
let expected = quote! {
::core::compile_error! { "number too large to fit in target type" }
};

assert_eq!(actual.to_string(), expected.to_string());
}

#[test]
fn contract_state_version_expansion() {
let code = quote! {
#[state_version(2)]
fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
// Logic here
}
};

let actual = entry_point_impl(TokenStream::new(), code);
let expected = quote! {
#[allow(unused)]
#[doc(hidden)]
#[link_section = "cw_state_version"]
/// This is an internal constant exported as a custom section denoting the contract state version.
/// The format and even the existence of this value is an implementation detail, DO NOT RELY ON THIS!
static __CW_STATE_VERSION: &str = "2";

fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
// Logic here
}

#[cfg(target_arch = "wasm32")]
mod __wasm_export_migrate {
#[no_mangle]
extern "C" fn migrate(ptr_0: u32, ptr_1: u32) -> u32 {
::cosmwasm_std::do_migrate(&super::migrate, ptr_0, ptr_1)
}
}
};

assert_eq!(actual.to_string(), expected.to_string());
}

#[test]
fn default_expansion() {
let code = quote! {
Expand Down