Skip to content

Commit

Permalink
Merge pull request #1472 from CosmWasm/query-responses-generic-trait-…
Browse files Browse the repository at this point in the history
…bounds

QueryResponses: infer the JsonSchema trait bound
  • Loading branch information
webmaster128 authored Nov 15, 2022
2 parents 4415fd1 + 7613623 commit 6329540
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ and this project adheres to
cannot properly measure different runtimes for differet Wasm opcodes.
- cosmwasm-schema: schema generation is now locked to produce strictly
`draft-07` schemas
- cosmwasm-schema: `QueryResponses` derive now sets the `JsonSchema` trait bound
on the generated `impl` block. This allows the contract dev to not add a
`JsonSchema` trait bound on the type itself.

[#1465]: https://github.com/CosmWasm/cosmwasm/pull/1465

Expand Down
47 changes: 30 additions & 17 deletions packages/schema-derive/src/query_responses.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use syn::{parse_quote, Expr, ExprTuple, Ident, ItemEnum, ItemImpl, Type, Variant};
mod context;

use syn::{parse_quote, Expr, ExprTuple, Generics, ItemEnum, ItemImpl, Type, Variant};

use self::context::Context;

pub fn query_responses_derive_impl(input: ItemEnum) -> ItemImpl {
let is_nested = has_attr(&input, "query_responses", "nested");
let ctx = context::get_context(&input);

if is_nested {
if ctx.is_nested {
let ident = input.ident;
let subquery_calls = input.variants.into_iter().map(parse_subquery);

// Handle generics if the type has any
let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
let (_, type_generics, where_clause) = input.generics.split_for_impl();
let impl_generics = impl_generics(&ctx, &input.generics);

let subquery_len = subquery_calls.len();
parse_quote! {
Expand All @@ -31,7 +36,8 @@ pub fn query_responses_derive_impl(input: ItemEnum) -> ItemImpl {
let mappings = mappings.map(parse_tuple);

// Handle generics if the type has any
let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
let (_, type_generics, where_clause) = input.generics.split_for_impl();
let impl_generics = impl_generics(&ctx, &input.generics);

parse_quote! {
#[automatically_derived]
Expand All @@ -47,6 +53,21 @@ pub fn query_responses_derive_impl(input: ItemEnum) -> ItemImpl {
}
}

/// Takes a list of generics from the type definition and produces a list of generics
/// for the expanded `impl` block, adding trait bounds like `JsonSchema` as appropriate.
fn impl_generics(ctx: &Context, generics: &Generics) -> Generics {
let mut impl_generics = generics.to_owned();
for param in impl_generics.type_params_mut() {
if !ctx.no_bounds_for.contains(&param.ident) {
param
.bounds
.push(parse_quote! {::cosmwasm_schema::schemars::JsonSchema})
}
}

impl_generics
}

/// Extract the query -> response mapping out of an enum variant.
fn parse_query(v: Variant) -> (String, Expr) {
let query = to_snake_case(&v.ident.to_string());
Expand Down Expand Up @@ -80,14 +101,6 @@ fn parse_subquery(v: Variant) -> Expr {
parse_quote!(<#submsg as ::cosmwasm_schema::QueryResponses>::response_schemas_impl())
}

/// Checks whether the input has the given `#[$path($attr))]` attribute
fn has_attr(input: &ItemEnum, path: &str, attr: &str) -> bool {
input.attrs.iter().any(|a| {
a.path.get_ident().unwrap() == path
&& a.parse_args::<Ident>().ok().map_or(false, |i| i == attr)
})
}

fn parse_tuple((q, r): (String, Expr)) -> ExprTuple {
parse_quote! {
(#q.to_string(), #r)
Expand Down Expand Up @@ -202,13 +215,13 @@ mod tests {
};

let result = query_responses_derive_impl(input);
dbg!(&result);

assert_eq!(
result,
parse_quote! {
#[automatically_derived]
#[cfg(not(target_arch = "wasm32"))]
impl<T> ::cosmwasm_schema::QueryResponses for QueryMsg<T> {
impl<T: ::cosmwasm_schema::schemars::JsonSchema> ::cosmwasm_schema::QueryResponses for QueryMsg<T> {
fn response_schemas_impl() -> ::std::collections::BTreeMap<String, ::cosmwasm_schema::schemars::schema::RootSchema> {
::std::collections::BTreeMap::from([
("foo".to_string(), ::cosmwasm_schema::schema_for!(bool)),
Expand All @@ -223,7 +236,7 @@ mod tests {
parse_quote! {
#[automatically_derived]
#[cfg(not(target_arch = "wasm32"))]
impl<T: std::fmt::Debug + SomeTrait> ::cosmwasm_schema::QueryResponses for QueryMsg<T> {
impl<T: std::fmt::Debug + SomeTrait + ::cosmwasm_schema::schemars::JsonSchema> ::cosmwasm_schema::QueryResponses for QueryMsg<T> {
fn response_schemas_impl() -> ::std::collections::BTreeMap<String, ::cosmwasm_schema::schemars::schema::RootSchema> {
::std::collections::BTreeMap::from([
("foo".to_string(), ::cosmwasm_schema::schema_for!(bool)),
Expand All @@ -239,7 +252,7 @@ mod tests {
parse_quote! {
#[automatically_derived]
#[cfg(not(target_arch = "wasm32"))]
impl<T> ::cosmwasm_schema::QueryResponses for QueryMsg<T>
impl<T: ::cosmwasm_schema::schemars::JsonSchema> ::cosmwasm_schema::QueryResponses for QueryMsg<T>
where T: std::fmt::Debug + SomeTrait,
{
fn response_schemas_impl() -> ::std::collections::BTreeMap<String, ::cosmwasm_schema::schemars::schema::RootSchema> {
Expand Down
63 changes: 63 additions & 0 deletions packages/schema-derive/src/query_responses/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::collections::HashSet;

use syn::{Ident, ItemEnum, Meta, NestedMeta};

const ATTR_PATH: &str = "query_responses";

pub struct Context {
/// If the enum we're trying to derive QueryResponses for collects other QueryMsgs,
/// setting this flag will derive the implementation appropriately, collecting all
/// KV pairs from the nested enums rather than expecting `#[return]` annotations.
pub is_nested: bool,
/// Disable infering the `JsonSchema` trait bound for chosen type parameters.
pub no_bounds_for: HashSet<Ident>,
}

pub fn get_context(input: &ItemEnum) -> Context {
let params = input
.attrs
.iter()
.filter(|attr| matches!(attr.path.get_ident(), Some(id) if *id == ATTR_PATH))
.flat_map(|attr| {
if let Meta::List(l) = attr.parse_meta().unwrap() {
l.nested
} else {
panic!("{} attribute must contain a meta list", ATTR_PATH);
}
})
.map(|nested_meta| {
if let NestedMeta::Meta(m) = nested_meta {
m
} else {
panic!("no literals allowed in QueryResponses params")
}
});

let mut ctx = Context {
is_nested: false,
no_bounds_for: HashSet::new(),
};

for param in params {
match param.path().get_ident().unwrap().to_string().as_str() {
"no_bounds_for" => {
if let Meta::List(l) = param {
for item in l.nested {
match item {
NestedMeta::Meta(Meta::Path(p)) => {
ctx.no_bounds_for.insert(p.get_ident().unwrap().clone());
}
_ => panic!("`no_bounds_for` only accepts a list of type params"),
}
}
} else {
panic!("expected a list for `no_bounds_for`")
}
}
"nested" => ctx.is_nested = true,
path => panic!("unrecognized QueryResponses param: {}", path),
}
}

ctx
}
28 changes: 24 additions & 4 deletions packages/schema/src/query_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub use cosmwasm_schema_derive::QueryResponses;
///
/// Using the derive macro is the preferred way of implementing this trait.
///
/// # Example
/// # Examples
/// ```
/// use cosmwasm_schema::QueryResponses;
/// use schemars::JsonSchema;
Expand All @@ -30,20 +30,40 @@ pub use cosmwasm_schema_derive::QueryResponses;
/// #[returns(AccountInfo)]
/// AccountInfo { account: String },
/// }
/// ```
///
/// You can compose multiple queries using `#[query_responses(nested)]`. This might be useful
/// together with `#[serde(untagged)]`. If the `nested` flag is set, no `returns` attributes
/// are necessary on the enum variants. Instead, the response types are collected from the
/// nested enums.
///
/// // You can also compose multiple queries using #[query_responses(nested)]:
/// ```
/// # use cosmwasm_schema::QueryResponses;
/// # use schemars::JsonSchema;
/// #[derive(JsonSchema, QueryResponses)]
/// #[query_responses(nested)]
/// #[serde(untagged)]
/// enum QueryMsg2 {
/// MsgA(QueryMsg),
/// enum QueryMsg {
/// MsgA(QueryA),
/// MsgB(QueryB),
/// }
///
/// #[derive(JsonSchema, QueryResponses)]
/// enum QueryA {
/// #[returns(Vec<String>)]
/// Denoms {},
/// }
///
/// #[derive(JsonSchema, QueryResponses)]
/// enum QueryB {
/// #[returns(AccountInfo)]
/// AccountInfo { account: String },
/// }
///
/// # #[derive(JsonSchema)]
/// # struct AccountInfo {
/// # IcqHandle: String,
/// # }
/// ```
pub trait QueryResponses: JsonSchema {
fn response_schemas() -> Result<BTreeMap<String, RootSchema>, IntegrityError> {
Expand Down
41 changes: 39 additions & 2 deletions packages/schema/tests/idl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,16 @@ fn test_query_responses() {

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, JsonSchema, QueryResponses)]
#[serde(rename_all = "snake_case")]
pub enum QueryMsgWithGenerics<T: std::fmt::Debug>
pub enum QueryMsgWithGenerics<T> {
#[returns(u128)]
QueryData { data: T },
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, JsonSchema, QueryResponses)]
#[serde(rename_all = "snake_case")]
pub enum QueryMsgWithGenericsAndTraitBounds<T: std::fmt::Debug>
where
T: JsonSchema,
T: PartialEq,
{
#[returns(u128)]
QueryData { data: T },
Expand Down Expand Up @@ -147,6 +154,36 @@ fn test_query_responses_generics() {
api.get("responses").unwrap().get("query_data").unwrap();
}

#[test]
fn test_query_responses_generics_and_trait_bounds() {
let api_str = generate_api! {
instantiate: InstantiateMsg,
query: QueryMsgWithGenericsAndTraitBounds<u32>,
}
.render()
.to_string()
.unwrap();

let api: Value = serde_json::from_str(&api_str).unwrap();
let queries = api
.get("query")
.unwrap()
.get("oneOf")
.unwrap()
.as_array()
.unwrap();

// Find the "query_data" query in the queries schema
assert_eq!(queries.len(), 1);
assert_eq!(
queries[0].get("required").unwrap().get(0).unwrap(),
"query_data"
);

// Find the "query_data" query in responses
api.get("responses").unwrap().get("query_data").unwrap();
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, JsonSchema, QueryResponses)]
#[serde(untagged)]
#[query_responses(nested)]
Expand Down

0 comments on commit 6329540

Please sign in to comment.