diff --git a/src/interpreter/builtins/concat.rs b/src/interpreter/builtins/concat.rs index 275b8e4..51aafce 100644 --- a/src/interpreter/builtins/concat.rs +++ b/src/interpreter/builtins/concat.rs @@ -6,6 +6,7 @@ use lazy_static::lazy_static; use crate::interpreter::{ functions::{FunctionDef, FunctionParam}, + types::HashableIndexMap, Env, Value, }; @@ -66,7 +67,12 @@ impl FunctionDef for Concat { false } - fn execute<'a>(&'a self, _env: &'a mut Env, args: &'a [Value]) -> BoxFuture<'a, Result> { + fn execute<'a>( + &'a self, + _env: &'a mut Env, + args: &'a [Value], + _options: &'a HashableIndexMap, + ) -> BoxFuture<'a, Result> { async { concat(args) }.boxed() } } diff --git a/src/interpreter/builtins/iterable.rs b/src/interpreter/builtins/iterable.rs index 3fa9026..94a8128 100644 --- a/src/interpreter/builtins/iterable.rs +++ b/src/interpreter/builtins/iterable.rs @@ -18,7 +18,7 @@ fn map<'a>( let mut values = vec![]; for v in receiver.get_items()? { let value = match args.first() { - Some(Value::Func(func)) => func.execute(&[v.clone()], env).await?, + Some(Value::Func(func)) => func.execute(env, &[v.clone()]).await?, Some(Value::TypeObject(type_)) => type_.cast(&v)?, _ => bail!("map function expects a function or type as an argument"), }; diff --git a/src/interpreter/builtins/mod.rs b/src/interpreter/builtins/mod.rs index 4d5f211..49960c7 100644 --- a/src/interpreter/builtins/mod.rs +++ b/src/interpreter/builtins/mod.rs @@ -17,7 +17,6 @@ mod receipt; mod repl; use crate::interpreter::functions::Function; -use crate::interpreter::functions::FunctionCall; use crate::interpreter::functions::FunctionDef; use crate::interpreter::types::NonParametricType; use crate::interpreter::Type; @@ -42,7 +41,10 @@ lazy_static! { ("type", misc::GET_TYPE.clone()), ]; for (name, func) in funcs { - m.insert(name.to_string(), Value::Func(Function::Call(Box::new(FunctionCall::new(func, None))))); + m.insert( + name.to_string(), + Value::Func(Box::new(Function::new(func, None))), + ); } m @@ -90,12 +92,12 @@ lazy_static! { let mut transaction_methods = HashMap::new(); transaction_methods.insert("format".to_string(), format::NON_NUM_FORMAT.clone()); - // transaction_methods.insert("getReceipt".to_string(), receipt::TX_GET_RECEIPT.clone()); + transaction_methods.insert("getReceipt".to_string(), receipt::TX_GET_RECEIPT.clone()); m.insert(NonParametricType::Transaction, transaction_methods); let mut mapping_methods = HashMap::new(); mapping_methods.insert("format".to_string(), format::NON_NUM_FORMAT.clone()); - // mapping_methods.insert("keys".to_string(), misc::MAPPING_KEYS.clone()); + mapping_methods.insert("keys".to_string(), misc::MAPPING_KEYS.clone()); m.insert(NonParametricType::Mapping, mapping_methods); m diff --git a/src/interpreter/functions/contract.rs b/src/interpreter/functions/contract.rs new file mode 100644 index 0000000..801016c --- /dev/null +++ b/src/interpreter/functions/contract.rs @@ -0,0 +1,230 @@ +use std::sync::Arc; + +use alloy::{ + contract::{CallBuilder, ContractInstance, Interface}, + json_abi::StateMutability, + network::{Network, TransactionBuilder}, + primitives::{keccak256, Address, FixedBytes}, + providers::Provider, + rpc::types::{TransactionInput, TransactionRequest}, + transports::Transport, +}; +use anyhow::{anyhow, bail, Result}; +use futures::{future::BoxFuture, FutureExt}; +use itertools::Itertools; + +use crate::interpreter::{types::HashableIndexMap, ContractInfo, Env, Type, Value}; + +use super::{Function, FunctionDef, FunctionParam}; + +#[derive(Debug, Clone, PartialEq, Hash, Eq)] +pub enum ContractCallMode { + Default, + Encode, + Call, + Send, +} + +impl TryFrom<&str> for ContractCallMode { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s { + "encode" => Ok(ContractCallMode::Encode), + "call" => Ok(ContractCallMode::Call), + "send" => Ok(ContractCallMode::Send), + _ => bail!("{} does not exist for contract call", s), + } + } +} + +#[derive(Debug, Clone, Default, Hash, PartialEq, Eq)] +pub struct CallOptions { + value: Option>, +} + +impl std::fmt::Display for CallOptions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(v) = &self.value { + write!(f, "value: {}", v) + } else { + write!(f, "") + } + } +} + +impl TryFrom<&HashableIndexMap> for CallOptions { + type Error = anyhow::Error; + + fn try_from(value: &HashableIndexMap) -> std::result::Result { + let mut opts = CallOptions::default(); + for (k, v) in value.0.iter() { + match k.as_str() { + "value" => opts.value = Some(Box::new(v.clone())), + _ => bail!("unexpected key {}", k), + } + } + Ok(opts) + } +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct ContractFunction { + func_name: String, + mode: ContractCallMode, +} + +impl ContractFunction { + pub fn arc(name: &str) -> Arc { + Arc::new(Self { + func_name: name.to_string(), + mode: ContractCallMode::Default, + }) + } + + pub fn with_mode(&self, mode: ContractCallMode) -> Self { + let mut new = self.clone(); + new.mode = mode; + new + } + + pub fn get_signature(&self, types_: &[Type]) -> String { + let mut selector = self.func_name.clone(); + selector.push('('); + let args_str = types_ + .iter() + .map(|t| t.canonical_string().expect("canonical string")) + .join(","); + selector.push_str(&args_str); + selector.push(')'); + selector + } + + pub fn get_selector(&self, types_: &[Type]) -> FixedBytes<4> { + let signature_hash = keccak256(self.get_signature(types_)); + FixedBytes::<4>::from_slice(&signature_hash[..4]) + } +} + +impl FunctionDef for ContractFunction { + fn name(&self) -> &str { + &self.func_name + } + + fn get_valid_args(&self, receiver: &Option) -> Vec> { + let (ContractInfo(_, abi), _) = receiver.clone().unwrap().as_contract().unwrap(); + let functions = abi.function(self.name()).cloned().unwrap_or(vec![]); + + functions + .into_iter() + .filter_map(|f| { + f.inputs + .into_iter() + .map(FunctionParam::try_from) + .collect::>>() + .ok() + }) + .collect() + } + + fn is_property(&self) -> bool { + false + } + + fn member_access(&self, receiver: &Option, member: &str) -> Option { + ContractCallMode::try_from(member) + .map(|m| Function::new(Arc::new(self.with_mode(m)), receiver.as_ref()).into()) + .ok() + } + + fn execute<'a>( + &'a self, + env: &'a mut Env, + values: &'a [Value], + options: &'a HashableIndexMap, + ) -> BoxFuture<'a, Result> { + let (ContractInfo(_, abi), addr) = values[0].as_contract().unwrap(); + let types_ = values[1..].iter().map(Value::get_type).collect::>(); + let selector = self.get_selector(&types_); + + async move { + let abi_func = abi + .functions() + .find(|f| f.selector() == selector) + .ok_or_else(|| anyhow!("function {} not found", self.get_signature(&types_)))?; + let interface = Interface::new(abi.clone()); + let contract = + ContractInstance::new(addr, env.get_provider().root().clone(), interface); + let call_options: CallOptions = options.try_into()?; + let tokens = values[1..] + .iter() + .map(|arg| arg.try_into()) + .collect::>>()?; + let func = contract.function_from_selector(&selector, &tokens)?; + let is_view = abi_func.state_mutability == StateMutability::Pure + || abi_func.state_mutability == StateMutability::View; + + if self.mode == ContractCallMode::Encode { + let encoded = func.calldata(); + Ok(Value::Bytes(encoded[..].to_vec())) + } else if self.mode == ContractCallMode::Call + || (self.mode == ContractCallMode::Default && is_view) + { + _execute_contract_call(func).await + } else { + _execute_contract_send(&addr, func, &call_options, env).await + } + } + .boxed() + } +} + +async fn _execute_contract_send( + addr: &Address, + func: CallBuilder, + opts: &CallOptions, + env: &Env, +) -> Result +where + T: Transport + Clone, + P: Provider, + N: Network, +{ + let data = func.calldata(); + let input = TransactionInput::new(data.clone()); + let from_ = env + .get_default_sender() + .ok_or(anyhow!("no wallet connected"))?; + let mut tx_req = TransactionRequest::default() + .with_from(from_) + .with_to(*addr) + .input(input); + if let Some(value) = opts.value.as_ref() { + let value = value.as_u256()?; + tx_req = tx_req.with_value(value); + } + + let provider = env.get_provider(); + let tx = provider.send_transaction(tx_req).await?; + Ok(Value::Transaction(*tx.tx_hash())) +} + +async fn _execute_contract_call( + func: CallBuilder, +) -> Result +where + T: Transport + Clone, + P: Provider, + N: Network, +{ + let result = func.call().await?; + let return_values = result + .into_iter() + .map(Value::try_from) + .collect::>>()?; + if return_values.len() == 1 { + Ok(return_values.into_iter().next().unwrap()) + } else { + Ok(Value::Tuple(return_values)) + } +} diff --git a/src/interpreter/functions/definition.rs b/src/interpreter/functions/definition.rs index ccb2463..ecaede5 100644 --- a/src/interpreter/functions/definition.rs +++ b/src/interpreter/functions/definition.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::interpreter::{functions::FunctionParam, Env, Value}; +use crate::interpreter::{functions::FunctionParam, types::HashableIndexMap, Env, Value}; use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, FutureExt}; @@ -11,8 +11,16 @@ pub trait FunctionDef: std::fmt::Debug + Send + Sync { fn is_property(&self) -> bool; - fn execute<'a>(&'a self, env: &'a mut Env, values: &'a [Value]) - -> BoxFuture<'a, Result>; + fn execute<'a>( + &'a self, + env: &'a mut Env, + values: &'a [Value], + options: &'a HashableIndexMap, + ) -> BoxFuture<'a, Result>; + + fn member_access(&self, _receiver: &Option, _member: &str) -> Option { + None + } } #[derive(Debug)] @@ -47,6 +55,7 @@ impl FunctionDef for SyncProperty { &'a self, env: &'a mut Env, values: &'a [Value], + _options: &'a HashableIndexMap, ) -> BoxFuture<'a, Result> { async move { let receiver = values.first().ok_or(anyhow!("no receiver"))?; @@ -91,6 +100,7 @@ impl FunctionDef for AsyncProperty { &'a self, env: &'a mut Env, values: &'a [Value], + _options: &'a HashableIndexMap, ) -> BoxFuture<'a, Result> { async move { let receiver = values.first().ok_or(anyhow!("no receiver"))?; @@ -138,6 +148,7 @@ impl FunctionDef for SyncMethod { &'a self, env: &'a mut Env, values: &'a [Value], + _options: &'a HashableIndexMap, ) -> BoxFuture<'a, Result> { async move { let receiver = values.first().ok_or(anyhow!("no receiver"))?; @@ -185,6 +196,7 @@ impl FunctionDef for SyncFunction { &'a self, env: &'a mut Env, values: &'a [Value], + _options: &'a HashableIndexMap, ) -> BoxFuture<'a, Result> { async move { (self.f)(env, values) }.boxed() } @@ -228,6 +240,7 @@ impl FunctionDef for AsyncMethod { &'a self, env: &'a mut Env, values: &'a [Value], + _options: &'a HashableIndexMap, ) -> BoxFuture<'a, Result> { async move { let receiver = values.first().ok_or(anyhow!("no receiver"))?; diff --git a/src/interpreter/functions/call.rs b/src/interpreter/functions/function.rs similarity index 79% rename from src/interpreter/functions/call.rs rename to src/interpreter/functions/function.rs index 305592c..04c21d1 100644 --- a/src/interpreter/functions/call.rs +++ b/src/interpreter/functions/function.rs @@ -1,17 +1,18 @@ -use crate::interpreter::{utils::join_with_final, Env, Value}; -use anyhow::{bail, Result}; +use crate::interpreter::{types::HashableIndexMap, utils::join_with_final, Env, Value}; +use anyhow::{anyhow, bail, Result}; use itertools::Itertools; use std::{fmt, sync::Arc}; use super::{definition::FunctionDef, FunctionParam}; #[derive(Debug, Clone)] -pub struct FunctionCall { +pub struct Function { def: Arc, receiver: Option, + options: HashableIndexMap, } -impl std::hash::Hash for FunctionCall { +impl std::hash::Hash for Function { fn hash(&self, state: &mut H) { self.receiver.hash(state); self.def.name().hash(state); @@ -20,7 +21,7 @@ impl std::hash::Hash for FunctionCall { } } -impl std::cmp::PartialEq for FunctionCall { +impl std::cmp::PartialEq for Function { fn eq(&self, other: &Self) -> bool { if self.receiver != other.receiver || self.def.name() != other.def.name() { return false; @@ -31,9 +32,9 @@ impl std::cmp::PartialEq for FunctionCall { } } -impl std::cmp::Eq for FunctionCall {} +impl std::cmp::Eq for Function {} -impl fmt::Display for FunctionCall { +impl fmt::Display for Function { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let variants = self.get_variants(); for (i, variant) in variants.iter().enumerate() { @@ -49,14 +50,27 @@ impl fmt::Display for FunctionCall { } } -impl FunctionCall { +impl Function { pub fn new(def: Arc, receiver: Option<&Value>) -> Self { - FunctionCall { + Function { def: def.clone(), receiver: receiver.cloned(), + options: HashableIndexMap::default(), } } + pub fn member_access(&self, member: &str) -> Result { + self.def + .member_access(&self.receiver, member) + .ok_or(anyhow!("no member {} for {}", member, self)) + } + + pub fn with_opts(self, opts: HashableIndexMap) -> Self { + let mut new = self; + new.options = opts; + new + } + pub fn get_valid_args_lengths(&self) -> Vec { let args = self.def.get_valid_args(&self.receiver); let valid_lengths = args.iter().map(|args| args.len()); @@ -82,20 +96,23 @@ impl FunctionCall { Self::new(def, Some(receiver)) } - pub fn function(def: Arc) -> Self { - Self::new(def, None) - } - pub fn is_property(&self) -> bool { self.def.is_property() } pub async fn execute(&self, env: &mut Env, args: &[Value]) -> Result { + env.push_scope(); + let result = self.execute_in_current_scope(env, args).await; + env.pop_scope(); + result + } + + pub async fn execute_in_current_scope(&self, env: &mut Env, args: &[Value]) -> Result { let mut unified_args = self.get_unified_args(args)?; if let Some(receiver) = &self.receiver { unified_args.insert(0, receiver.clone()); } - self.def.execute(env, &unified_args).await + self.def.execute(env, &unified_args, &self.options).await } fn get_unified_args(&self, args: &[Value]) -> Result> { diff --git a/src/interpreter/functions/mod.rs b/src/interpreter/functions/mod.rs index 98aff0c..b021fd7 100644 --- a/src/interpreter/functions/mod.rs +++ b/src/interpreter/functions/mod.rs @@ -1,339 +1,13 @@ -mod call; +mod contract; mod definition; +mod function; mod param; mod user_defined; -pub use call::FunctionCall; +pub use contract::ContractFunction; pub use definition::{ AsyncMethod, AsyncProperty, FunctionDef, SyncFunction, SyncMethod, SyncProperty, }; +pub use function::Function; pub use param::FunctionParam; pub use user_defined::UserDefinedFunction; - -use std::fmt::Display; - -use alloy::{ - contract::{CallBuilder, ContractInstance, Interface}, - dyn_abi::Specifier, - json_abi::StateMutability, - network::{Network, TransactionBuilder}, - primitives::Address, - providers::Provider, - rpc::types::{TransactionInput, TransactionRequest}, - transports::Transport, -}; -use anyhow::{anyhow, bail, Result}; - -use super::{types::ContractInfo, Env, StatementResult, Type, Value}; - -#[derive(Debug, Clone, PartialEq, Hash, Eq)] -pub enum ContractCallMode { - Default, - Encode, - Call, - Send, -} - -impl TryFrom<&str> for ContractCallMode { - type Error = anyhow::Error; - - fn try_from(s: &str) -> Result { - match s { - "encode" => Ok(ContractCallMode::Encode), - "call" => Ok(ContractCallMode::Call), - "send" => Ok(ContractCallMode::Send), - _ => bail!("{} does not exist for contract call", s), - } - } -} - -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct ContractCall { - info: ContractInfo, - addr: Address, - func_name: String, - mode: ContractCallMode, - options: CallOptions, -} - -impl ContractCall { - pub fn new(info: ContractInfo, addr: Address, func_name: String) -> Self { - ContractCall { - info, - addr, - func_name, - mode: ContractCallMode::Default, - options: CallOptions::default(), - } - } - - pub fn with_options(self, options: CallOptions) -> Self { - ContractCall { options, ..self } - } - - pub fn with_mode(self, mode: ContractCallMode) -> Self { - ContractCall { mode, ..self } - } -} - -#[derive(Debug, Clone, Default, Hash, PartialEq, Eq)] -pub struct CallOptions { - value: Option>, -} - -impl Display for CallOptions { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(v) = &self.value { - write!(f, "value: {}", v) - } else { - write!(f, "") - } - } -} - -impl TryFrom for CallOptions { - type Error = anyhow::Error; - - fn try_from(value: Value) -> std::result::Result { - match value { - Value::NamedTuple(_, m) => { - let mut opts = CallOptions::default(); - for (k, v) in m.0.iter() { - match k.as_str() { - "value" => opts.value = Some(Box::new(v.clone())), - _ => bail!("unexpected key {}", k), - } - } - Ok(opts) - } - _ => bail!("expected indexed map but got {}", value), - } - } -} - -impl TryFrom for CallOptions { - type Error = anyhow::Error; - - fn try_from(value: StatementResult) -> std::result::Result { - match value { - StatementResult::Value(v) => v.try_into(), - _ => bail!("expected indexed map but got {}", value), - } - } -} - -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub enum Function { - ContractCall(ContractCall), - Call(Box), -} - -impl Display for Function { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Function::Call(call) => write!(f, "{}", call), - Function::ContractCall(ContractCall { - info: ContractInfo(name, abi), - addr, - func_name, - mode, - options, - }) => { - let arg_types = abi - .function(func_name) - .map(|f| { - f[0].inputs - .iter() - .map(|t| t.to_string()) - .collect::>() - }) - .unwrap_or_default(); - let suffix = if mode == &ContractCallMode::Encode { - ".encode" - } else { - "" - }; - write!(f, "{}({}).{}", name, addr, func_name)?; - let formatted_options = format!("{}", options); - if !formatted_options.is_empty() { - write!(f, "{{{}}}", formatted_options)?; - } - write!(f, "({}){}", arg_types.join(","), suffix) - } - } - } -} - -impl Function { - pub fn with_opts(self, opts: CallOptions) -> Self { - match self { - Function::ContractCall(call) => Function::ContractCall(call.with_options(opts)), - v => v, - } - } - - pub async fn execute_in_current_scope(&self, args: &[Value], env: &mut Env) -> Result { - match self { - Function::ContractCall(call) => { - self._execute_contract_interaction(call, args, env).await - } - Function::Call(call) => call.execute(env, args).await, - } - } - - pub fn is_property(&self) -> bool { - match self { - Function::ContractCall(_) => false, - Function::Call(c) => c.is_property(), - } - } - - pub async fn execute(&self, args: &[Value], env: &mut Env) -> Result { - env.push_scope(); - let result = self.execute_in_current_scope(args, env).await; - env.pop_scope(); - result - } - - async fn _execute_contract_interaction( - &self, - call: &ContractCall, - args: &[Value], - env: &Env, - ) -> Result { - let ContractInfo(name, abi) = &call.info; - let funcs = abi.function(&call.func_name).ok_or(anyhow!( - "function {} not found in {}", - call.func_name, - name - ))?; - let contract = ContractInstance::new( - call.addr, - env.get_provider().root().clone(), - Interface::new(abi.clone()), - ); - let mut call_result = Ok(Value::Null); - for func_abi in funcs.iter() { - let types = func_abi - .inputs - .iter() - .map(|t| t.resolve().map(Type::from).map_err(|e| anyhow!(e))) - .collect::>>()?; - match self._unify_types(args, &types) { - Ok(values) => { - let tokens = values - .iter() - .map(|arg| arg.try_into()) - .collect::>>()?; - let func = contract.function_from_selector(&func_abi.selector(), &tokens)?; - let is_view = func_abi.state_mutability == StateMutability::Pure - || func_abi.state_mutability == StateMutability::View; - match call.mode { - ContractCallMode::Default => { - if is_view { - call_result = self._execute_contract_call(func).await; - } else { - call_result = self - ._execute_contract_send(&call.addr, func, &call.options, env) - .await - } - } - ContractCallMode::Encode => { - let encoded = func.calldata(); - call_result = Ok(Value::Bytes(encoded[..].to_vec())); - } - ContractCallMode::Call => { - call_result = self._execute_contract_call(func).await - } - ContractCallMode::Send => { - call_result = self - ._execute_contract_send(&call.addr, func, &call.options, env) - .await - } - } - break; - } - Err(e) => call_result = Err(e), - } - } - call_result - } - - async fn _execute_contract_send( - &self, - addr: &Address, - func: CallBuilder, - opts: &CallOptions, - env: &Env, - ) -> Result - where - T: Transport + Clone, - P: Provider, - N: Network, - { - let data = func.calldata(); - let input = TransactionInput::new(data.clone()); - let from_ = env - .get_default_sender() - .ok_or(anyhow!("no wallet connected"))?; - let mut tx_req = TransactionRequest::default() - .with_from(from_) - .with_to(*addr) - .input(input); - if let Some(value) = opts.value.as_ref() { - let value = value.as_u256()?; - tx_req = tx_req.with_value(value); - } - - let provider = env.get_provider(); - let tx = provider.send_transaction(tx_req).await?; - Ok(Value::Transaction(*tx.tx_hash())) - } - - async fn _execute_contract_call( - &self, - func: CallBuilder, - ) -> Result - where - T: Transport + Clone, - P: Provider, - N: Network, - { - let result = func.call().await?; - let return_values = result - .into_iter() - .map(Value::try_from) - .collect::>>()?; - if return_values.len() == 1 { - Ok(return_values.into_iter().next().unwrap()) - } else { - Ok(Value::Tuple(return_values)) - } - } - - fn _unify_types(&self, args: &[Value], types: &[Type]) -> Result> { - if args.len() != types.len() { - bail!( - "function {} expects {} arguments, but got {}", - self, - types.len(), - args.len() - ); - } - let mut result = Vec::new(); - for (i, (arg, type_)) in args.iter().zip(types).enumerate() { - match type_.cast(arg) { - Ok(v) => result.push(v), - Err(e) => bail!( - "expected {} argument {} to be {}, but got {} ({})", - self, - i, - type_, - arg.get_type(), - e - ), - } - } - Ok(result) - } -} diff --git a/src/interpreter/functions/param.rs b/src/interpreter/functions/param.rs index b977081..d8137fa 100644 --- a/src/interpreter/functions/param.rs +++ b/src/interpreter/functions/param.rs @@ -1,3 +1,4 @@ +use alloy::{dyn_abi::Specifier, json_abi::Param}; use anyhow::{bail, Result}; use solang_parser::pt::{Expression, Identifier, Parameter}; use std::fmt; @@ -33,6 +34,16 @@ impl fmt::Display for FunctionParam { } } +impl TryFrom for FunctionParam { + type Error = anyhow::Error; + + fn try_from(param: Param) -> std::result::Result { + let name = param.name.clone(); + let type_ = param.resolve()?.into(); + Ok(FunctionParam { name, type_ }) + } +} + impl TryFrom for FunctionParam { type Error = anyhow::Error; diff --git a/src/interpreter/functions/user_defined.rs b/src/interpreter/functions/user_defined.rs index f3c386c..c24c394 100644 --- a/src/interpreter/functions/user_defined.rs +++ b/src/interpreter/functions/user_defined.rs @@ -1,8 +1,8 @@ use std::sync::Arc; -use crate::interpreter::{evaluate_statement, Env, Value}; +use crate::interpreter::{evaluate_statement, types::HashableIndexMap, Env, Value}; -use super::{Function, FunctionCall, FunctionDef, FunctionParam}; +use super::{Function, FunctionDef, FunctionParam}; use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, FutureExt}; use solang_parser::pt::Statement; @@ -16,10 +16,7 @@ pub struct UserDefinedFunction { impl From for Value { fn from(f: UserDefinedFunction) -> Self { - Value::Func(Function::Call(Box::new(FunctionCall::new( - Arc::new(f), - None, - )))) + Value::Func(Box::new(Function::new(Arc::new(f), None))) } } @@ -40,6 +37,7 @@ impl FunctionDef for UserDefinedFunction { &'a self, env: &'a mut Env, values: &'a [Value], + _options: &'a HashableIndexMap, ) -> BoxFuture<'a, Result> { async move { for (param, arg) in self.params.iter().zip(values.iter()) { diff --git a/src/interpreter/interpreter.rs b/src/interpreter/interpreter.rs index 87ca749..6598eac 100644 --- a/src/interpreter/interpreter.rs +++ b/src/interpreter/interpreter.rs @@ -48,6 +48,10 @@ impl StatementResult { _ => None, } } + + pub fn as_value(&self) -> Result<&Value> { + self.value().ok_or(anyhow!("expected value, got {}", self)) + } } unsafe impl std::marker::Send for StatementResult {} @@ -77,7 +81,7 @@ pub async fn evaluate_setup(env: &mut Env, code: &str) -> Result<()> { evaluate_contract_parts(env, &def.parts).await?; let setup = env.get_var(SETUP_FUNCTION_NAME).cloned(); if let Some(Value::Func(func)) = setup { - func.execute_in_current_scope(&[], env).await?; + func.execute_in_current_scope(env, &[]).await?; env.delete_var(SETUP_FUNCTION_NAME) } @@ -430,7 +434,7 @@ pub fn evaluate_expression(env: &mut Env, expr: Box) -> BoxFuture<'_ Expression::MemberAccess(_, receiver_expr, method) => { let receiver = evaluate_expression(env, receiver_expr).await?; match receiver.member_access(&method.name) { - Result::Ok(Value::Func(f)) if f.is_property() => f.execute(&[], env).await, + Result::Ok(Value::Func(f)) if f.is_property() => f.execute(env, &[]).await, v => v, } } @@ -564,7 +568,7 @@ pub fn evaluate_expression(env: &mut Env, expr: Box) -> BoxFuture<'_ args.push(evaluate_expression(env, Box::new(arg.clone())).await?); } match evaluate_expression(env, func_expr).await? { - Value::Func(f) => f.execute(&args, env).await, + Value::Func(f) => f.execute(env, &args).await, Value::TypeObject(type_) => { if let [arg] = &args[..] { type_.cast(arg) @@ -579,7 +583,10 @@ pub fn evaluate_expression(env: &mut Env, expr: Box) -> BoxFuture<'_ Expression::FunctionCallBlock(_, func_expr, stmt) => { let res = evaluate_statement(env, stmt).await?; match evaluate_expression(env, func_expr).await? { - Value::Func(f) => Ok(Value::Func(f.with_opts(res.try_into()?))), + Value::Func(f) => { + let opts = res.as_value()?.as_record()?.clone(); + Ok(f.with_opts(opts).into()) + } _ => bail!("expected function"), } } diff --git a/src/interpreter/types.rs b/src/interpreter/types.rs index 400eb91..68a75a9 100644 --- a/src/interpreter/types.rs +++ b/src/interpreter/types.rs @@ -13,7 +13,7 @@ use solang_parser::pt as parser; use super::{ builtins::{INSTANCE_METHODS, STATIC_METHODS}, - functions::{ContractCall, Function}, + functions::{ContractFunction, Function}, Value, }; @@ -23,6 +23,16 @@ where K: Eq + std::hash::Hash, V: Eq; +impl std::default::Default for HashableIndexMap +where + K: Eq + std::hash::Hash, + V: Eq, +{ + fn default() -> Self { + Self(IndexMap::default()) + } +} + impl std::hash::Hash for HashableIndexMap where K: std::hash::Hash + Eq, @@ -47,16 +57,15 @@ where pub struct ContractInfo(pub String, pub JsonAbi); impl ContractInfo { - pub fn create_call(&self, name: &str, addr: Address) -> Result { + pub fn make_function(&self, name: &str, addr: Address) -> Result { let _func = self .1 .function(name) .ok_or_else(|| anyhow::anyhow!("function {} not found in contract {}", name, self.0))?; - Ok(Function::ContractCall(ContractCall::new( - self.clone(), - addr, - name.to_string(), - ))) + Ok(Function::new( + ContractFunction::arc(name), + Some(&Value::Contract(self.clone(), addr)), + )) } } @@ -364,6 +373,15 @@ impl TryFrom for DynSolType { } } +fn canonical_string_for_tuple(types: &[Type]) -> Result { + let items = types + .iter() + .map(|t| t.canonical_string()) + .collect::>>()? + .join(","); + Ok(format!("({})", items)) +} + impl Type { pub fn default_value(&self) -> Result { let value = match self { @@ -399,6 +417,27 @@ impl Type { Ok(value) } + pub fn canonical_string(&self) -> Result { + let result = match self { + Type::Address => "address".to_string(), + Type::Bool => "bool".to_string(), + Type::Int(size) => format!("int{}", size), + Type::Uint(size) => format!("uint{}", size), + Type::FixBytes(size) => format!("bytes{}", size), + Type::Bytes => "bytes".to_string(), + Type::String => "string".to_string(), + Type::Array(t) => format!("{}[]", t.canonical_string()?), + Type::FixedArray(t, size) => format!("{}[{}]", t.canonical_string()?, size), + Type::NamedTuple(_, fields) => { + let types = fields.0.values().cloned().collect_vec(); + canonical_string_for_tuple(&types)? + } + Type::Tuple(types) => canonical_string_for_tuple(types.as_slice())?, + _ => bail!("cannot get canonical string for type {}", self), + }; + Ok(result) + } + pub fn is_int(&self) -> bool { matches!(self, Type::Int(_) | Type::Uint(_)) } diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs index 67448e1..e5d6380 100644 --- a/src/interpreter/value.rs +++ b/src/interpreter/value.rs @@ -13,7 +13,7 @@ use std::{ use super::{ builtins::{INSTANCE_METHODS, STATIC_METHODS, TYPE_METHODS}, - functions::{ContractCallMode, Function, FunctionCall}, + functions::Function, types::{ContractInfo, HashableIndexMap, Receipt, Type}, }; @@ -35,7 +35,7 @@ pub enum Value { TypeObject(Type), Transaction(B256), TransactionReceipt(Receipt), - Func(Function), + Func(Box), } fn _values_to_string(values: &[Value]) -> String { @@ -219,9 +219,9 @@ impl From<&str> for Value { } } -impl From for Value { - fn from(f: FunctionCall) -> Self { - Value::Func(Function::Call(Box::new(f))) +impl From for Value { + fn from(f: Function) -> Self { + Value::Func(Box::new(f)) } } @@ -359,6 +359,13 @@ impl Value { } } + pub fn as_contract(&self) -> Result<(ContractInfo, Address)> { + match self { + Value::Contract(info, addr) => Ok((info.clone(), *addr)), + _ => bail!("cannot convert {} to contract", self.get_type()), + } + } + pub fn as_string(&self) -> Result { match self { Value::Str(str) => Ok(str.clone()), @@ -397,6 +404,13 @@ impl Value { } } + pub fn as_record(&self) -> Result<&HashableIndexMap> { + match self { + Value::NamedTuple(_, map) => Ok(map), + _ => bail!("cannot convert {} to map", self.get_type()), + } + } + pub fn get_field(&self, field: &str) -> Result { match self { Value::NamedTuple(_, fields) => fields @@ -437,11 +451,8 @@ impl Value { Ok(kv.0.get(member).unwrap().clone()) } Value::TransactionReceipt(r) if r.contains_key(member) => r.get(member), - Value::Contract(c, addr) => c.create_call(member, *addr).map(Value::Func), - - Value::Func(Function::ContractCall(call)) => Ok(Value::Func(Function::ContractCall( - call.clone().with_mode(ContractCallMode::try_from(member)?), - ))), + Value::Contract(c, addr) => c.make_function(member, *addr).map(Into::into), + Value::Func(f) => f.member_access(member), _ => { let (type_, methods) = match self { Value::TypeObject(Type::Type(type_)) => { @@ -458,7 +469,7 @@ impl Value { type_, member ))?; - Ok(FunctionCall::method(func.clone(), self).into()) + Ok(Function::method(func.clone(), self).into()) } } }