diff --git a/CHANGELOG.md b/CHANGELOG.md index a2d5db4..37ca6db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * Add support for bitwise operators * [EXPERIMENTAL] Add support for anonymous functions * [EXPERIMENTAL] Add support for fetching events +* Add support for negative indices in arrays and slices ### Bug fixes diff --git a/src/interpreter/interpreter.rs b/src/interpreter/interpreter.rs index f2ab614..a7cf57d 100644 --- a/src/interpreter/interpreter.rs +++ b/src/interpreter/interpreter.rs @@ -15,7 +15,7 @@ use super::assignment::Lhs; use super::builtins; use super::functions::{AnonymousFunction, FunctionDef, UserDefinedFunction}; use super::parsing::ParsedCode; -use super::types::{HashableIndexMap, Type}; +use super::types::{ArrayIndex, HashableIndexMap, Type}; use super::utils::parse_rational_literal; use super::{env::Env, parsing, value::Value}; @@ -444,10 +444,8 @@ pub fn evaluate_expression(env: &mut Env, expr: Box) -> BoxFuture<'_ Value::Tuple(values) | Value::Array(values, _) => { let subscript = subscript_opt .ok_or(anyhow!("tuples and arrays do not support empty subscript"))?; - let index = evaluate_expression(env, subscript).await?.as_usize()?; - if index >= values.len() { - bail!("index out of bounds"); - } + let value = evaluate_expression(env, subscript).await?; + let index = ArrayIndex::try_from(value)?.get_index(values.len())?; Ok(values[index].clone()) } Value::Mapping(values, kt, _) => { @@ -474,11 +472,11 @@ pub fn evaluate_expression(env: &mut Env, expr: Box) -> BoxFuture<'_ Expression::ArraySlice(_, arr_expr, start_expr, end_expr) => { let value = evaluate_expression(env, arr_expr).await?; let start = match start_expr { - Some(expr) => Some(evaluate_expression(env, expr).await?.as_usize()?), + Some(expr) => Some(evaluate_expression(env, expr).await?.try_into()?), None => None, }; let end = match end_expr { - Some(expr) => Some(evaluate_expression(env, expr).await?.as_usize()?), + Some(expr) => Some(evaluate_expression(env, expr).await?.try_into()?), None => None, }; value.slice(start, end) diff --git a/src/interpreter/types.rs b/src/interpreter/types.rs index 41759b1..53f604e 100644 --- a/src/interpreter/types.rs +++ b/src/interpreter/types.rs @@ -53,6 +53,36 @@ where } } +pub struct ArrayIndex(pub i64); +impl ArrayIndex { + pub fn get_index(&self, array_size: usize) -> Result { + let index = if self.0 < 0 { + array_size as i64 + self.0 + } else { + self.0 + }; + if index < 0 || index as usize >= array_size { + bail!( + "index out of bounds: {} for array of size {}", + self.0, + array_size + ) + } + Ok(index as usize) + } +} +impl TryFrom for ArrayIndex { + type Error = anyhow::Error; + + fn try_from(value: Value) -> Result { + match value { + Value::Int(i, _) => Ok(ArrayIndex(i.try_into()?)), + Value::Uint(i, _) => Ok(ArrayIndex(i.try_into()?)), + _ => bail!("cannot convert {} to array index", value), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ContractInfo(pub String, pub JsonAbi); @@ -663,4 +693,13 @@ mod tests { Value::from_hex(padded_selector).unwrap() ); } + + #[test] + fn array_index() { + let size = 10; + let cases = vec![(0, 0), (1, 1), (-1, 9), (5, 5), (-5, 5)]; + for (index, expected) in cases { + assert_eq!(super::ArrayIndex(index).get_index(size).unwrap(), expected); + } + } } diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs index 1be2ab6..58c098b 100644 --- a/src/interpreter/value.rs +++ b/src/interpreter/value.rs @@ -17,7 +17,7 @@ use std::{ use super::{ builtins::{INSTANCE_METHODS, STATIC_METHODS, TYPE_METHODS}, functions::Function, - types::{ContractInfo, HashableIndexMap, Type, LOG_TYPE}, + types::{ArrayIndex, ContractInfo, HashableIndexMap, Type, LOG_TYPE}, }; #[derive(Debug, Clone, Hash, PartialEq, Eq)] @@ -543,12 +543,14 @@ impl Value { } } - pub fn slice(&self, start: Option, end: Option) -> Result { - let start = start.unwrap_or(0); - let end = end.unwrap_or(self.len()?); - if end > self.len()? { - bail!("index out of bounds") - } + pub fn slice(&self, start: Option, end: Option) -> Result { + let length = self.len()?; + let start = start.unwrap_or(ArrayIndex(0)).get_index(length)?; + let end = match end { + Some(end) => end.get_index(length)?, + None => length, + }; + match self { Value::Array(items, t) => { let items = items[start..end].to_vec(); @@ -758,22 +760,34 @@ mod tests { vec![Value::from(1u64), Value::from(2u64), Value::from(3u64)], Box::new(Type::Int(256)), ); - let slice = array.slice(Some(1), Some(2)).unwrap(); + let slice = array + .slice(Some(ArrayIndex(1)), Some(ArrayIndex(2))) + .unwrap(); + assert_eq!( + slice, + Value::Array(vec![Value::from(2u64)], Box::new(Type::Int(256))) + ); + + let slice = array + .slice(Some(ArrayIndex(1)), Some(ArrayIndex(-1))) + .unwrap(); assert_eq!( slice, Value::Array(vec![Value::from(2u64)], Box::new(Type::Int(256))) ); let bytes = Value::Bytes(vec![1, 2, 3]); - let slice = bytes.slice(Some(1), Some(2)).unwrap(); + let slice = bytes + .slice(Some(ArrayIndex(1)), Some(ArrayIndex(2))) + .unwrap(); assert_eq!(slice, Value::Bytes(vec![2])); let bytes = Value::Bytes(vec![1, 2, 3]); - let slice = bytes.slice(Some(1), None).unwrap(); + let slice = bytes.slice(Some(ArrayIndex(1)), None).unwrap(); assert_eq!(slice, Value::Bytes(vec![2, 3])); let str = Value::Str("hello".to_string()); - let slice = str.slice(Some(1), Some(3)).unwrap(); + let slice = str.slice(Some(ArrayIndex(1)), Some(ArrayIndex(3))).unwrap(); assert_eq!(slice, Value::Str("el".to_string())); } }