diff --git a/packages/vm/src/cache.rs b/packages/vm/src/cache.rs index 7a78d5a968..bab59b27a8 100644 --- a/packages/vm/src/cache.rs +++ b/packages/vm/src/cache.rs @@ -14,8 +14,9 @@ use crate::errors::{VmError, VmResult}; use crate::filesystem::mkdir_p; use crate::instance::{Instance, InstanceOptions}; use crate::modules::{CachedModule, FileSystemCache, InMemoryCache, PinnedMemoryCache}; +use crate::parsed_wasm::ParsedWasm; use crate::size::Size; -use crate::static_analysis::{deserialize_exports, has_ibc_entry_points}; +use crate::static_analysis::has_ibc_entry_points; use crate::wasm_backend::{compile, make_compiling_engine, make_runtime_engine}; const STATE_DIR: &str = "state"; @@ -254,7 +255,7 @@ where pub fn analyze(&self, checksum: &Checksum) -> VmResult { // Here we could use a streaming deserializer to slightly improve performance. However, this way it is DRYer. let wasm = self.load_wasm(checksum)?; - let module = deserialize_exports(&wasm)?; + let module = ParsedWasm::parse(&wasm)?; Ok(AnalysisReport { has_ibc_entry_points: has_ibc_entry_points(&module), required_capabilities: required_capabilities_from_module(&module), @@ -571,10 +572,7 @@ mod tests { let save_result = cache.save_wasm(&wasm); match save_result.unwrap_err() { VmError::StaticValidationErr { msg, .. } => { - assert_eq!( - msg, - "Wasm contract missing a required marker export: interface_version_*" - ) + assert_eq!(msg, "Wasm contract must contain exactly one memory") } e => panic!("Unexpected error {e:?}"), } diff --git a/packages/vm/src/capabilities.rs b/packages/vm/src/capabilities.rs index af560c007d..8281105011 100644 --- a/packages/vm/src/capabilities.rs +++ b/packages/vm/src/capabilities.rs @@ -32,8 +32,9 @@ pub fn required_capabilities_from_module(module: impl ExportInfo) -> HashSet) -> VmResult<()> { - let mut memory_section: Option> = None; - validate_wasm(wasm_code, |payload| { - match payload { - Payload::TableSection(t) => check_wasm_tables(t)?, - Payload::MemorySection(m) => memory_section = Some(m), - Payload::ExportSection(e) => { - let exports = e.into_iter().collect::, _>>()?; - check_interface_version(&exports)?; - check_wasm_exports(&exports)?; - check_wasm_capabilities(&exports, available_capabilities)?; - } - Payload::ImportSection(i) => check_wasm_imports(i, SUPPORTED_IMPORTS)?, - _ => {} - } - Ok(()) - })?; - // we want to fail if there is no memory section, so this check is delayed until the end - check_wasm_memories(memory_section)?; + let module = ParsedWasm::parse(wasm_code)?; + + check_wasm_tables(&module)?; + check_wasm_memories(&module)?; + check_interface_version(&module)?; + check_wasm_exports(&module)?; + check_wasm_imports(&module, SUPPORTED_IMPORTS)?; + check_wasm_capabilities(&module, available_capabilities)?; Ok(()) } -fn check_wasm_tables(mut tables: TableSectionReader<'_>) -> VmResult<()> { - match tables.get_count() { +fn check_wasm_tables(module: &ParsedWasm) -> VmResult<()> { + match module.tables.len() { 0 => Ok(()), 1 => { - let limits = tables.read()?; + let limits = &module.tables[0]; if let Some(maximum) = limits.maximum { if maximum > TABLE_SIZE_LIMIT { return Err(VmError::static_validation_err( @@ -126,23 +111,13 @@ fn check_wasm_tables(mut tables: TableSectionReader<'_>) -> VmResult<()> { } } -fn check_wasm_memories(memory: Option>) -> VmResult<()> { - let mut section = match memory { - Some(section) => section, - None => { - return Err(VmError::static_validation_err( - "Wasm contract doesn't have a memory section", - )); - } - }; - - if section.get_count() != 1 { +fn check_wasm_memories(module: &ParsedWasm) -> VmResult<()> { + if module.memories.len() != 1 { return Err(VmError::static_validation_err( "Wasm contract must contain exactly one memory", )); } - - let memory = section.read()?; + let memory = &module.memories[0]; if memory.initial > MEMORY_LIMIT as u64 { return Err(VmError::static_validation_err(format!( @@ -158,8 +133,8 @@ fn check_wasm_memories(memory: Option>) -> VmResult<()> Ok(()) } -fn check_interface_version(exports: &[Export<'_>]) -> VmResult<()> { - let mut interface_version_exports = exports +fn check_interface_version(module: &ParsedWasm) -> VmResult<()> { + let mut interface_version_exports = module .exported_function_names(Some(INTERFACE_VERSION_PREFIX)) .into_iter(); if let Some(first_interface_version_export) = interface_version_exports.next() { @@ -188,8 +163,8 @@ fn check_interface_version(exports: &[Export<'_>]) -> VmResult<()> { } } -fn check_wasm_exports(exports: &[Export<'_>]) -> VmResult<()> { - let available_exports: HashSet = exports.exported_function_names(None); +fn check_wasm_exports(module: &ParsedWasm) -> VmResult<()> { + let available_exports: HashSet = module.exported_function_names(None); for required_export in REQUIRED_EXPORTS { if !available_exports.contains(*required_export) { return Err(VmError::static_validation_err(format!( @@ -203,25 +178,20 @@ fn check_wasm_exports(exports: &[Export<'_>]) -> VmResult<()> { /// Checks if the import requirements of the contract are satisfied. /// When this is not the case, we either have an incompatibility between contract and VM /// or a error in the contract. -fn check_wasm_imports(imports: ImportSectionReader, supported_imports: &[&str]) -> VmResult<()> { - if imports.get_count() > MAX_IMPORTS { +fn check_wasm_imports(module: &ParsedWasm, supported_imports: &[&str]) -> VmResult<()> { + if module.imports.len() > MAX_IMPORTS { return Err(VmError::static_validation_err(format!( "Import count exceeds limit. Imports: {}. Limit: {}.", - imports.get_count(), + module.imports.len(), MAX_IMPORTS ))); } - let required_imports = imports - .into_iter() - .map(|i| Ok(i?)) - .collect::>>()?; - - for required_import in &required_imports { + for required_import in &module.imports { let full_name = full_import_name(required_import); if !supported_imports.contains(&full_name.as_str()) { let required_import_names: BTreeSet<_> = - required_imports.iter().map(full_import_name).collect(); + module.imports.iter().map(full_import_name).collect(); return Err(VmError::static_validation_err(format!( "Wasm contract requires unsupported import: \"{}\". Required imports: {}. Available imports: {:?}.", full_name, required_import_names.to_string_limited(200), supported_imports @@ -243,7 +213,7 @@ fn full_import_name(ie: &Import) -> String { } fn check_wasm_capabilities( - exports: &[Export<'_>], + exports: &ParsedWasm, available_capabilities: &HashSet, ) -> VmResult<()> { let required_capabilities = required_capabilities_from_module(exports); @@ -263,10 +233,7 @@ fn check_wasm_capabilities( #[cfg(test)] mod tests { use super::*; - use crate::{ - errors::VmError, - static_analysis::{deserialize_exports, extract_reader}, - }; + use crate::errors::VmError; static CONTRACT_0_7: &[u8] = include_bytes!("../testdata/hackatom_0.7.wasm"); static CONTRACT_0_12: &[u8] = include_bytes!("../testdata/hackatom_0.12.wasm"); @@ -313,7 +280,9 @@ mod tests { match check_wasm(CONTRACT_0_12, &default_capabilities()) { Err(VmError::StaticValidationErr { msg, .. }) => { - assert!(msg.contains("Wasm contract requires unsupported import")) + assert!(msg.contains( + "Wasm contract missing a required marker export: interface_version_*" + )) } Err(e) => panic!("Unexpected error {e:?}"), Ok(_) => panic!("This must not succeeed"), @@ -321,7 +290,9 @@ mod tests { match check_wasm(CONTRACT_0_7, &default_capabilities()) { Err(VmError::StaticValidationErr { msg, .. }) => { - assert!(msg.contains("Wasm contract requires unsupported import")) + assert!(msg.contains( + "Wasm contract missing a required marker export: interface_version_*" + )) } Err(e) => panic!("Unexpected error {e:?}"), Ok(_) => panic!("This must not succeeed"), @@ -332,49 +303,30 @@ mod tests { fn check_wasm_tables_works() { // No tables is fine let wasm = wat::parse_str("(module)").unwrap(); - assert!(extract_reader!(&wasm, TableSection, TableSectionReader<'_>) - .unwrap() - .is_none()); + assert!(ParsedWasm::parse(&wasm).unwrap().memories.is_empty()); // One table (bound) let wasm = wat::parse_str("(module (table $name 123 123 funcref))").unwrap(); - check_wasm_tables( - extract_reader!(&wasm, TableSection, TableSectionReader<'_>) - .unwrap() - .unwrap(), - ) - .unwrap(); + check_wasm_tables(&ParsedWasm::parse(&wasm).unwrap()).unwrap(); // One table (bound, initial > max) let wasm = wat::parse_str("(module (table $name 124 123 funcref))").unwrap(); // this should be caught by the validator - let err = extract_reader!(&wasm, TableSection, TableSectionReader<'_>) - .map(|_| ()) - .unwrap_err(); + let err = &ParsedWasm::parse(&wasm).unwrap_err(); assert!(err .to_string() .contains("size minimum must not be greater than maximum")); // One table (bound, max too large) let wasm = wat::parse_str("(module (table $name 100 9999 funcref))").unwrap(); - let err = check_wasm_tables( - extract_reader!(&wasm, TableSection, TableSectionReader<'_>) - .unwrap() - .unwrap(), - ) - .unwrap_err(); + let err = check_wasm_tables(&ParsedWasm::parse(&wasm).unwrap()).unwrap_err(); assert!(err .to_string() .contains("Wasm contract's first table section has a too large max limit")); // One table (unbound) let wasm = wat::parse_str("(module (table $name 100 funcref))").unwrap(); - let err = check_wasm_tables( - extract_reader!(&wasm, TableSection, TableSectionReader<'_>) - .unwrap() - .unwrap(), - ) - .unwrap_err(); + let err = check_wasm_tables(&ParsedWasm::parse(&wasm).unwrap()).unwrap_err(); assert!(err .to_string() .contains("Wasm contract must not have unbound table section")); @@ -383,18 +335,15 @@ mod tests { #[test] fn check_wasm_memories_ok() { let wasm = wat::parse_str("(module (memory 1))").unwrap(); - check_wasm_memories(extract_reader!(&wasm, MemorySection, MemorySectionReader<'_>).unwrap()) - .unwrap() + check_wasm_memories(&ParsedWasm::parse(&wasm).unwrap()).unwrap() } #[test] fn check_wasm_memories_no_memory() { let wasm = wat::parse_str("(module)").unwrap(); - match check_wasm_memories( - extract_reader!(&wasm, MemorySection, MemorySectionReader<'_>).unwrap(), - ) { + match check_wasm_memories(&ParsedWasm::parse(&wasm).unwrap()) { Err(VmError::StaticValidationErr { msg, .. }) => { - assert!(msg.starts_with("Wasm contract doesn't have a memory section")); + assert!(msg.starts_with("Wasm contract must contain exactly one memory")); } Err(e) => panic!("Unexpected error {e:?}"), Ok(_) => panic!("Didn't reject wasm with invalid api"), @@ -417,7 +366,7 @@ mod tests { .unwrap(); // wrong number of memories should be caught by the validator - match extract_reader!(&wasm, MemorySection, MemorySectionReader<'_>) { + match ParsedWasm::parse(&wasm) { Err(VmError::StaticValidationErr { msg, .. }) => { assert!(msg.contains("multiple memories")); } @@ -438,9 +387,7 @@ mod tests { )) .unwrap(); - match check_wasm_memories( - extract_reader!(&wasm, MemorySection, MemorySectionReader<'_>).unwrap(), - ) { + match check_wasm_memories(&ParsedWasm::parse(&wasm).unwrap()) { Err(VmError::StaticValidationErr { msg, .. }) => { assert!(msg.starts_with("Wasm contract must contain exactly one memory")); } @@ -452,15 +399,10 @@ mod tests { #[test] fn check_wasm_memories_initial_size() { let wasm_ok = wat::parse_str("(module (memory 512))").unwrap(); - check_wasm_memories( - extract_reader!(&wasm_ok, MemorySection, MemorySectionReader<'_>).unwrap(), - ) - .unwrap(); + check_wasm_memories(&ParsedWasm::parse(&wasm_ok).unwrap()).unwrap(); let wasm_too_big = wat::parse_str("(module (memory 513))").unwrap(); - match check_wasm_memories( - extract_reader!(&wasm_too_big, MemorySection, MemorySectionReader<'_>).unwrap(), - ) { + match check_wasm_memories(&ParsedWasm::parse(&wasm_too_big).unwrap()) { Err(VmError::StaticValidationErr { msg, .. }) => { assert!(msg.starts_with("Wasm contract memory's minimum must not exceed 512 pages")); } @@ -472,9 +414,7 @@ mod tests { #[test] fn check_wasm_memories_maximum_size() { let wasm_max = wat::parse_str("(module (memory 1 5))").unwrap(); - match check_wasm_memories( - extract_reader!(&wasm_max, MemorySection, MemorySectionReader<'_>).unwrap(), - ) { + match check_wasm_memories(&ParsedWasm::parse(&wasm_max).unwrap()) { Err(VmError::StaticValidationErr { msg, .. }) => { assert!(msg.starts_with("Wasm contract memory's maximum must be unset")); } @@ -498,7 +438,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); check_interface_version(&module).unwrap(); #[cfg(feature = "allow_interface_version_7")] @@ -516,7 +456,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); check_interface_version(&module).unwrap(); } @@ -532,7 +472,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); match check_interface_version(&module).unwrap_err() { VmError::StaticValidationErr { msg, .. } => { assert_eq!( @@ -557,7 +497,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); match check_interface_version(&module).unwrap_err() { VmError::StaticValidationErr { msg, .. } => { assert_eq!( @@ -581,7 +521,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); match check_interface_version(&module).unwrap_err() { VmError::StaticValidationErr { msg, .. } => { assert_eq!(msg, "Wasm contract has unknown interface_version_* marker export (see https://github.com/CosmWasm/cosmwasm/blob/main/packages/vm/README.md)"); @@ -602,7 +542,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); match check_interface_version(&module).unwrap_err() { VmError::StaticValidationErr { msg, .. } => { assert_eq!(msg, "Wasm contract has unknown interface_version_* marker export (see https://github.com/CosmWasm/cosmwasm/blob/main/packages/vm/README.md)"); @@ -625,7 +565,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); check_wasm_exports(&module).unwrap(); // this is invalid, as it doesn't any required export @@ -637,7 +577,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); match check_wasm_exports(&module) { Err(VmError::StaticValidationErr { msg, .. }) => { assert!(msg.starts_with("Wasm contract doesn't have required export: \"allocate\"")); @@ -656,7 +596,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); match check_wasm_exports(&module) { Err(VmError::StaticValidationErr { msg, .. }) => { assert!( @@ -670,7 +610,7 @@ mod tests { #[test] fn check_wasm_exports_of_old_contract() { - let module = deserialize_exports(CONTRACT_0_7).unwrap(); + let module = ParsedWasm::parse(CONTRACT_0_7).unwrap(); match check_wasm_exports(&module) { Err(VmError::StaticValidationErr { msg, .. }) => { assert!( @@ -699,13 +639,7 @@ mod tests { )"#, ) .unwrap(); - check_wasm_imports( - extract_reader!(&wasm, ImportSection, ImportSectionReader<'_>) - .unwrap() - .unwrap(), - SUPPORTED_IMPORTS, - ) - .unwrap(); + check_wasm_imports(&ParsedWasm::parse(&wasm).unwrap(), SUPPORTED_IMPORTS).unwrap(); } #[test] @@ -816,13 +750,8 @@ mod tests { )"#, ) .unwrap(); - let err = check_wasm_imports( - extract_reader!(&wasm, ImportSection, ImportSectionReader<'_>) - .unwrap() - .unwrap(), - SUPPORTED_IMPORTS, - ) - .unwrap_err(); + let err = + check_wasm_imports(&ParsedWasm::parse(&wasm).unwrap(), SUPPORTED_IMPORTS).unwrap_err(); match err { VmError::StaticValidationErr { msg, .. } => { assert_eq!(msg, "Import count exceeds limit. Imports: 101. Limit: 100."); @@ -859,12 +788,7 @@ mod tests { "env.debug", "env.query_chain", ]; - let result = check_wasm_imports( - extract_reader!(&wasm, ImportSection, ImportSectionReader<'_>) - .unwrap() - .unwrap(), - supported_imports, - ); + let result = check_wasm_imports(&ParsedWasm::parse(&wasm).unwrap(), supported_imports); match result.unwrap_err() { VmError::StaticValidationErr { msg, .. } => { println!("{msg}"); @@ -879,9 +803,7 @@ mod tests { #[test] fn check_wasm_imports_of_old_contract() { - let module = extract_reader!(CONTRACT_0_7, ImportSection, ImportSectionReader<'_>) - .unwrap() - .unwrap(); + let module = &ParsedWasm::parse(CONTRACT_0_7).unwrap(); let result = check_wasm_imports(module, SUPPORTED_IMPORTS); match result.unwrap_err() { VmError::StaticValidationErr { msg, .. } => { @@ -896,12 +818,7 @@ mod tests { #[test] fn check_wasm_imports_wrong_type() { let wasm = wat::parse_str(r#"(module (import "env" "db_read" (memory 1 1)))"#).unwrap(); - let result = check_wasm_imports( - extract_reader!(&wasm, ImportSection, ImportSectionReader<'_>) - .unwrap() - .unwrap(), - SUPPORTED_IMPORTS, - ); + let result = check_wasm_imports(&ParsedWasm::parse(&wasm).unwrap(), SUPPORTED_IMPORTS); match result.unwrap_err() { VmError::StaticValidationErr { msg, .. } => { assert!( @@ -927,7 +844,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); let available = [ "water".to_string(), "nutrients".to_string(), @@ -954,7 +871,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); // Available set 1 let available = [ diff --git a/packages/vm/src/lib.rs b/packages/vm/src/lib.rs index af7ed2f519..05987bac76 100644 --- a/packages/vm/src/lib.rs +++ b/packages/vm/src/lib.rs @@ -16,6 +16,7 @@ mod instance; mod limited; mod memory; mod modules; +mod parsed_wasm; mod sections; mod serde; mod size; diff --git a/packages/vm/src/parsed_wasm.rs b/packages/vm/src/parsed_wasm.rs new file mode 100644 index 0000000000..6de9fbaed4 --- /dev/null +++ b/packages/vm/src/parsed_wasm.rs @@ -0,0 +1,70 @@ +use wasmer::wasmparser::{ + Export, Import, MemoryType, Parser, TableType, ValidPayload, Validator, WasmFeatures, +}; + +use crate::VmResult; + +/// A parsed and validated wasm module. +/// It keeps track of the parts that are important for our static analysis and compatibility checks. +#[derive(Debug)] +pub struct ParsedWasm<'a> { + pub version: u32, + pub exports: Vec>, + pub imports: Vec>, + pub tables: Vec, + pub memories: Vec, +} + +impl<'a> ParsedWasm<'a> { + pub fn parse(wasm: &'a [u8]) -> VmResult { + let mut validator = Validator::new_with_features(WasmFeatures { + deterministic_only: true, + component_model: false, + simd: false, + relaxed_simd: false, + threads: false, + multi_memory: false, + memory64: false, + ..Default::default() + }); + + let mut this = Self { + version: 0, + exports: vec![], + imports: vec![], + tables: vec![], + memories: vec![], + }; + + let mut fun_allocations = Default::default(); + for p in Parser::new(0).parse_all(wasm) { + let p = p?; + // validate the payload + if let ValidPayload::Func(fv, body) = validator.payload(&p)? { + // also validate function bodies + let mut fun_validator = fv.into_validator(fun_allocations); + fun_validator.validate(&body)?; + fun_allocations = fun_validator.into_allocations(); + } + + match p { + wasmer::wasmparser::Payload::Version { num, .. } => this.version = num, + wasmer::wasmparser::Payload::ImportSection(i) => { + this.imports = i.into_iter().collect::, _>>()?; + } + wasmer::wasmparser::Payload::TableSection(t) => { + this.tables = t.into_iter().collect::, _>>()?; + } + wasmer::wasmparser::Payload::MemorySection(m) => { + this.memories = m.into_iter().collect::, _>>()?; + } + wasmer::wasmparser::Payload::ExportSection(e) => { + this.exports = e.into_iter().collect::, _>>()?; + } + _ => {} // ignore everything else + } + } + + Ok(this) + } +} diff --git a/packages/vm/src/static_analysis.rs b/packages/vm/src/static_analysis.rs index f3e5804427..7e20fb95f3 100644 --- a/packages/vm/src/static_analysis.rs +++ b/packages/vm/src/static_analysis.rs @@ -1,11 +1,8 @@ use std::collections::HashSet; -use wasmer::wasmparser::{ - Export, ExportSectionReader, ExternalKind, Parser, Payload, ValidPayload, Validator, - WasmFeatures, -}; +use wasmer::wasmparser::ExternalKind; -use crate::errors::VmResult; +use crate::parsed_wasm::ParsedWasm; pub const REQUIRED_IBC_EXPORTS: &[&str] = &[ "ibc_channel_open", @@ -16,68 +13,6 @@ pub const REQUIRED_IBC_EXPORTS: &[&str] = &[ "ibc_packet_timeout", ]; -/// Validates the given wasm code and calls the callback for each payload. -/// "Validates" in this case refers to general WebAssembly validation, not specific to CosmWasm. -pub fn validate_wasm<'a>( - wasm_code: &'a [u8], - mut handle_payload: impl FnMut(Payload<'a>) -> VmResult<()>, -) -> VmResult<()> { - let mut validator = Validator::new_with_features(WasmFeatures { - deterministic_only: true, - component_model: false, - simd: false, - relaxed_simd: false, - threads: false, - multi_memory: false, - memory64: false, - ..Default::default() - }); - - let mut fun_allocations = Default::default(); - for p in Parser::new(0).parse_all(wasm_code) { - let p = p?; - // validate the payload - if let ValidPayload::Func(fv, body) = validator.payload(&p)? { - // also validate function bodies - let mut fun_validator = fv.into_validator(fun_allocations); - fun_validator.validate(&body)?; - fun_allocations = fun_validator.into_allocations(); - } - // tell caller about the payload - handle_payload(p)?; - } - - Ok(()) -} - -/// A small helper macro to validate the wasm module and extract a reader for a specific section. -macro_rules! extract_reader { - ($wasm_code: expr, $payload: ident, $t: ty) => {{ - fn extract(wasm_code: &[u8]) -> $crate::VmResult> { - let mut value = None; - $crate::static_analysis::validate_wasm(wasm_code, |p| { - if let Payload::$payload(p) = p { - value = Some(p); - } - Ok(()) - })?; - Ok(value) - } - - extract($wasm_code) - }}; -} - -pub(crate) use extract_reader; - -pub fn deserialize_exports(wasm_code: &[u8]) -> VmResult>> { - let exports = extract_reader!(wasm_code, ExportSection, ExportSectionReader<'_>)?; - Ok(exports - .map(|e| e.into_iter().collect::, _>>()) - .transpose()? - .unwrap_or_default()) -} - /// A trait that allows accessing shared functionality of `parity_wasm::elements::Module` /// and `wasmer::Module` in a shared fashion. pub trait ExportInfo { @@ -85,9 +20,10 @@ pub trait ExportInfo { fn exported_function_names(self, prefix: Option<&str>) -> HashSet; } -impl ExportInfo for &[Export<'_>] { +impl ExportInfo for &ParsedWasm<'_> { fn exported_function_names(self, prefix: Option<&str>) -> HashSet { - self.iter() + self.exports + .iter() .filter_map(|export| match export.kind { ExternalKind::Func => Some(export.name), _ => None, @@ -104,12 +40,6 @@ impl ExportInfo for &[Export<'_>] { } } -impl ExportInfo for &Vec> { - fn exported_function_names(self, prefix: Option<&str>) -> HashSet { - self[..].exported_function_names(prefix) - } -} - impl ExportInfo for &wasmer::Module { fn exported_function_names(self, prefix: Option<&str>) -> HashSet { self.exports() @@ -152,15 +82,17 @@ mod tests { #[test] fn deserialize_exports_works() { - let module = deserialize_exports(CONTRACT).unwrap(); - // assert_eq!(module.version(), 1); // TODO: not implemented anymore + let module = ParsedWasm::parse(CONTRACT).unwrap(); + assert_eq!(module.version, 1); let exported_functions = module + .exports .iter() .filter(|entry| matches!(entry.kind, ExternalKind::Func)); assert_eq!(exported_functions.count(), 8); // 4 required exports plus "execute", "migrate", "query" and "sudo" let exported_memories = module + .exports .iter() .filter(|entry| matches!(entry.kind, ExternalKind::Memory)); assert_eq!(exported_memories.count(), 1); @@ -168,7 +100,7 @@ mod tests { #[test] fn deserialize_wasm_corrupted_data() { - match deserialize_exports(CORRUPTED).unwrap_err() { + match ParsedWasm::parse(CORRUPTED).unwrap_err() { VmError::StaticValidationErr { msg, .. } => { assert!(msg.starts_with("Wasm bytecode could not be deserialized.")) } @@ -179,7 +111,7 @@ mod tests { #[test] fn exported_function_names_works_for_parity_with_no_prefix() { let wasm = wat::parse_str(r#"(module)"#).unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); let exports = module.exported_function_names(None); assert_eq!(exports, HashSet::new()); @@ -195,7 +127,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); let exports = module.exported_function_names(None); assert_eq!( exports, @@ -206,7 +138,7 @@ mod tests { #[test] fn exported_function_names_works_for_parity_with_prefix() { let wasm = wat::parse_str(r#"(module)"#).unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); let exports = module.exported_function_names(Some("b")); assert_eq!(exports, HashSet::new()); @@ -223,7 +155,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); let exports = module.exported_function_names(Some("b")); assert_eq!( exports, @@ -311,7 +243,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); assert!(!has_ibc_entry_points(&module)); // IBC contract @@ -336,7 +268,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); assert!(has_ibc_entry_points(&module)); // Missing packet ack @@ -360,7 +292,7 @@ mod tests { )"#, ) .unwrap(); - let module = deserialize_exports(&wasm).unwrap(); + let module = ParsedWasm::parse(&wasm).unwrap(); assert!(!has_ibc_entry_points(&module)); } }