diff --git a/Cargo.lock b/Cargo.lock index 31cf774982c..e52eab8f022 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -689,7 +689,6 @@ dependencies = [ "num-traits 0.2.19", "pretty_assertions", "rust-analyzer-salsa", - "smol_str", "test-log", ] diff --git a/corelib/src/bytes_31.cairo b/corelib/src/bytes_31.cairo index df73b6332c8..24ac3d0d541 100644 --- a/corelib/src/bytes_31.cairo +++ b/corelib/src/bytes_31.cairo @@ -88,31 +88,31 @@ pub(crate) impl Felt252TryIntoBytes31 of TryInto { impl Bytes31Serde = crate::serde::into_felt252_based::SerdeImpl; pub(crate) impl U8IntoBytes31 of Into { - fn into(self: u8) -> bytes31 { + const fn into(self: u8) -> bytes31 { crate::integer::upcast(self) } } impl U16IntoBytes31 of Into { - fn into(self: u16) -> bytes31 { + const fn into(self: u16) -> bytes31 { crate::integer::upcast(self) } } impl U32IntoBytes31 of Into { - fn into(self: u32) -> bytes31 { + const fn into(self: u32) -> bytes31 { crate::integer::upcast(self) } } impl U64IntoBytes31 of Into { - fn into(self: u64) -> bytes31 { + const fn into(self: u64) -> bytes31 { crate::integer::upcast(self) } } pub(crate) impl U128IntoBytes31 of Into { - fn into(self: u128) -> bytes31 { + const fn into(self: u128) -> bytes31 { crate::integer::upcast(self) } } diff --git a/corelib/src/integer.cairo b/corelib/src/integer.cairo index 5e5937ed04d..423303dad28 100644 --- a/corelib/src/integer.cairo +++ b/corelib/src/integer.cairo @@ -1426,11 +1426,11 @@ pub(crate) impl I128IntoFelt252 of Into { // TODO(lior): Restrict the function (using traits) in the high-level compiler so that wrong types // will not lead to Sierra errors. -pub(crate) extern fn upcast(x: FromType) -> ToType nopanic; +pub(crate) extern const fn upcast(x: FromType) -> ToType nopanic; // TODO(lior): Restrict the function (using traits) in the high-level compiler so that wrong types // will not lead to Sierra errors. -pub(crate) extern fn downcast( +pub(crate) extern const fn downcast( x: FromType, ) -> Option implicits(RangeCheck) nopanic; @@ -1599,7 +1599,7 @@ impl U128Felt252DictValue of Felt252DictValue { } impl UpcastableInto> of Into { - fn into(self: From) -> To { + const fn into(self: From) -> To { upcast(self) } } @@ -1607,13 +1607,13 @@ impl UpcastableInto> of Into { impl DowncastableIntTryInto< From, To, +DowncastableInt, +DowncastableInt, -Into, > of TryInto { - fn try_into(self: From) -> Option { + const fn try_into(self: From) -> Option { downcast(self) } } impl U8IntoU256 of Into { - fn into(self: u8) -> u256 { + const fn into(self: u8) -> u256 { u256 { low: upcast(self), high: 0_u128 } } } @@ -1631,7 +1631,7 @@ impl U256TryIntoU8 of TryInto { } impl U16IntoU256 of Into { - fn into(self: u16) -> u256 { + const fn into(self: u16) -> u256 { u256 { low: upcast(self), high: 0_u128 } } } @@ -1649,7 +1649,7 @@ impl U256TryIntoU16 of TryInto { } impl U32IntoU256 of Into { - fn into(self: u32) -> u256 { + const fn into(self: u32) -> u256 { u256 { low: upcast(self), high: 0_u128 } } } @@ -1667,7 +1667,7 @@ impl U256TryIntoU32 of TryInto { } impl U64IntoU256 of Into { - fn into(self: u64) -> u256 { + const fn into(self: u64) -> u256 { u256 { low: upcast(self), high: 0_u128 } } } diff --git a/corelib/src/internal/bounded_int.cairo b/corelib/src/internal/bounded_int.cairo index 9585ba3efed..ce60cc9c920 100644 --- a/corelib/src/internal/bounded_int.cairo +++ b/corelib/src/internal/bounded_int.cairo @@ -11,7 +11,7 @@ impl NumericLiteralBoundedInt< impl BoundedIntIntoFelt252< const MIN: felt252, const MAX: felt252, > of Into, felt252> { - fn into(self: BoundedInt) -> felt252 { + const fn into(self: BoundedInt) -> felt252 { upcast(self) } } @@ -19,7 +19,7 @@ impl BoundedIntIntoFelt252< impl Felt252TryIntoBoundedInt< const MIN: felt252, const MAX: felt252, > of TryInto> { - fn try_into(self: felt252) -> Option> { + const fn try_into(self: felt252) -> Option> { // Using `downcast` is allowed, since `BoundedInt` itself is not `pub`, and only has a few // specific `pub` instances, such as `u96`, `ConstZero` and `ConstOne`. downcast(self) @@ -113,56 +113,49 @@ extern fn bounded_int_constrain Result implicits(RangeCheck) nopanic; -/// A helper trait for trimming a `BoundedInt` instance. -pub trait TrimHelper { +/// A helper trait for trimming a `BoundedInt` instance min value. +pub trait TrimMinHelper { + type Target; +} +/// A helper trait for trimming a `BoundedInt` instance max value. +pub trait TrimMaxHelper { type Target; } mod trim_impl { - pub impl Impl< - T, const TRIMMED_VALUE: felt252, const MIN: felt252, const MAX: felt252, - > of super::TrimHelper { + pub impl Min of super::TrimMinHelper { + type Target = super::BoundedInt; + } + pub impl Max of super::TrimMaxHelper { type Target = super::BoundedInt; } } -impl U8TrimBelow = trim_impl::Impl; -impl U8TrimAbove = trim_impl::Impl; -impl I8TrimBelow = trim_impl::Impl; -impl I8TrimAbove = trim_impl::Impl; -impl U16TrimBelow = trim_impl::Impl; -impl U16TrimAbove = trim_impl::Impl; -impl I16TrimBelow = trim_impl::Impl; -impl I16TrimAbove = trim_impl::Impl; -impl U32TrimBelow = trim_impl::Impl; -impl U32TrimAbove = trim_impl::Impl; -impl I32TrimBelow = trim_impl::Impl; -impl I32TrimAbove = trim_impl::Impl; -impl U64TrimBelow = trim_impl::Impl; -impl U64TrimAbove = trim_impl::Impl; -impl I64TrimBelow = - trim_impl::Impl; -impl I64TrimAbove = - trim_impl::Impl; -impl U128TrimBelow = trim_impl::Impl; -impl U128TrimAbove = - trim_impl::Impl< - u128, 0xffffffffffffffffffffffffffffffff, 0, 0xfffffffffffffffffffffffffffffffe, - >; +impl U8TrimBelow = trim_impl::Min; +impl U8TrimAbove = trim_impl::Max; +impl I8TrimBelow = trim_impl::Min; +impl I8TrimAbove = trim_impl::Max; +impl U16TrimBelow = trim_impl::Min; +impl U16TrimAbove = trim_impl::Max; +impl I16TrimBelow = trim_impl::Min; +impl I16TrimAbove = trim_impl::Max; +impl U32TrimBelow = trim_impl::Min; +impl U32TrimAbove = trim_impl::Max; +impl I32TrimBelow = trim_impl::Min; +impl I32TrimAbove = trim_impl::Max; +impl U64TrimBelow = trim_impl::Min; +impl U64TrimAbove = trim_impl::Max; +impl I64TrimBelow = trim_impl::Min; +impl I64TrimAbove = trim_impl::Max; +impl U128TrimBelow = trim_impl::Min; +impl U128TrimAbove = trim_impl::Max; impl I128TrimBelow = - trim_impl::Impl< - i128, - -0x80000000000000000000000000000000, - -0x7fffffffffffffffffffffffffffffff, - 0x7fffffffffffffffffffffffffffffff, - >; + trim_impl::Min; impl I128TrimAbove = - trim_impl::Impl< - i128, - 0x7fffffffffffffffffffffffffffffff, - -0x80000000000000000000000000000000, - 0x7ffffffffffffffffffffffffffffffe, - >; + trim_impl::Max; -extern fn bounded_int_trim>( +extern fn bounded_int_trim_min>( + value: T, +) -> core::internal::OptionRev nopanic; +extern fn bounded_int_trim_max>( value: T, ) -> core::internal::OptionRev nopanic; @@ -272,5 +265,5 @@ impl MulMinusOneNegateHelper> of NegateHelper< pub use { bounded_int_add as add, bounded_int_constrain as constrain, bounded_int_div_rem as div_rem, bounded_int_is_zero as is_zero, bounded_int_mul as mul, bounded_int_sub as sub, - bounded_int_trim as trim, + bounded_int_trim_max as trim_max, bounded_int_trim_min as trim_min, }; diff --git a/corelib/src/iter/adapters.cairo b/corelib/src/iter/adapters.cairo index fdd38e85c67..ee647323441 100644 --- a/corelib/src/iter/adapters.cairo +++ b/corelib/src/iter/adapters.cairo @@ -2,3 +2,8 @@ mod map; pub use map::Map; #[allow(unused_imports)] pub(crate) use map::mapped_iterator; + +mod enumerate; +pub use enumerate::Enumerate; +#[allow(unused_imports)] +pub(crate) use enumerate::enumerated_iterator; diff --git a/corelib/src/iter/adapters/enumerate.cairo b/corelib/src/iter/adapters/enumerate.cairo new file mode 100644 index 00000000000..7157a9aaf05 --- /dev/null +++ b/corelib/src/iter/adapters/enumerate.cairo @@ -0,0 +1,41 @@ +/// An iterator that yields the current count and the element during iteration. +/// +/// This `struct` is created by the [`enumerate`] method on [`Iterator`]. See its +/// documentation for more. +/// +/// [`enumerate`]: Iterator::enumerate +/// [`Iterator`]: core::iter::Iterator +#[must_use] +#[derive(Drop, Clone, Debug)] +pub struct Enumerate { + iter: I, + count: usize, +} + +pub fn enumerated_iterator(iter: I) -> Enumerate { + Enumerate { iter, count: 0 } +} + +impl EnumerateIterator< + I, T, +Iterator[Item: T], +Destruct, +Destruct, +> of Iterator> { + type Item = (usize, T); + + /// # Overflow Behavior + /// + /// The method does no guarding against overflows, so enumerating more than + /// `Bounded::::MAX` elements will always panic. + /// + /// [`Bounded`]: core::num::traits::Bounded + /// + /// # Panics + /// + /// Will panic if the index of the element overflows a `usize`. + #[inline] + fn next(ref self: Enumerate) -> Option { + let a = self.iter.next()?; + let i = self.count; + self.count += 1; + Option::Some((i, a)) + } +} diff --git a/corelib/src/iter/traits/iterator.cairo b/corelib/src/iter/traits/iterator.cairo index 7ee26837f4c..daffe4d4e7a 100644 --- a/corelib/src/iter/traits/iterator.cairo +++ b/corelib/src/iter/traits/iterator.cairo @@ -1,4 +1,4 @@ -use crate::iter::adapters::{Map, mapped_iterator}; +use crate::iter::adapters::{Enumerate, Map, enumerated_iterator, mapped_iterator}; /// A trait for dealing with iterators. /// @@ -41,6 +41,47 @@ pub trait Iterator { /// ``` fn next(ref self: T) -> Option; + /// Advances the iterator by `n` elements. + /// + /// This method will eagerly skip `n` elements by calling [`next`] up to `n` + /// times until [`None`] is encountered. + /// + /// `advance_by(n)` will return `Ok(())` if the iterator successfully advances by + /// `n` elements, or a `Err(NonZero)` with value `k` if [`None`] is encountered, + /// where `k` is remaining number of steps that could not be advanced because the iterator ran + /// out. + /// If `self` is empty and `n` is non-zero, then this returns `Err(n)`. + /// Otherwise, `k` is always less than `n`. + /// + /// [`None`]: Option::None + /// [`next`]: Iterator::next + /// + /// # Examples + /// + /// ``` + /// let mut iter = array![1_u8, 2, 3, 4].into_iter(); + /// + /// assert_eq!(iter.advance_by(2), Result::Ok(())); + /// assert_eq!(iter.next(), Option::Some(3)); + /// assert_eq!(iter.advance_by(0), Result::Ok(())); + /// assert_eq!(iter.advance_by(100), Result::Err(99)); + /// ``` + fn advance_by<+Destruct, +Destruct>( + ref self: T, n: usize, + ) -> Result< + (), NonZero, + > { + if let Option::Some(nz_n) = n.try_into() { + if let Option::Some(_) = Self::next(ref self) { + return Self::advance_by(ref self, n - 1); + } else { + Result::Err(nz_n) + } + } else { + Result::Ok(()) + } + } + /// Takes a closure and creates an iterator which calls that closure on each /// element. /// @@ -93,4 +134,133 @@ pub trait Iterator { ) -> Map { mapped_iterator(self, f) } + + /// Creates an iterator which gives the current iteration count as well as + /// the next value. + /// + /// The iterator returned yields pairs `(i, val)`, where `i` is the + /// current index of iteration and `val` is the value returned by the + /// iterator. + /// + /// `enumerate()` keeps its count as a [`usize`]. + /// + /// # Overflow Behavior + /// + /// The method does no guarding against overflows, so enumerating more than + /// `Bounded::::MAX` elements will always panic. + /// + /// [`Bounded`]: core::num::traits::Bounded + /// + /// # Panics + /// + /// Will panic if the to-be-returned index overflows a `usize`. + /// + /// # Examples + /// + /// ``` + /// let mut iter = array!['a', 'b', 'c'].into_iter().enumerate(); + /// + /// assert_eq!(iter.next(), Option::Some((0, 'a'))); + /// assert_eq!(iter.next(), Option::Some((1, 'b'))); + /// assert_eq!(iter.next(), Option::Some((2, 'c'))); + /// assert_eq!(iter.next(), Option::None); + /// ``` + #[inline] + fn enumerate(self: T) -> Enumerate { + enumerated_iterator(self) + } + + /// Folds every element into an accumulator by applying an operation, + /// returning the final result. + /// + /// `fold()` takes two arguments: an initial value, and a closure with two + /// arguments: an 'accumulator', and an element. The closure returns the value that + /// the accumulator should have for the next iteration. + /// + /// The initial value is the value the accumulator will have on the first + /// call. + /// + /// After applying this closure to every element of the iterator, `fold()` + /// returns the accumulator. + /// + /// Folding is useful whenever you have a collection of something, and want + /// to produce a single value from it. + /// + /// Note: `fold()`, and similar methods that traverse the entire iterator, + /// might not terminate for infinite iterators, even on traits for which a + /// result is determinable in finite time. + /// + /// Note: `fold()` combines elements in a *left-associative* fashion. For associative + /// operators like `+`, the order the elements are combined in is not important, but for + /// non-associative operators like `-` the order will affect the final result. + /// + /// # Note to Implementors + /// + /// Several of the other (forward) methods have default implementations in + /// terms of this one, so try to implement this explicitly if it can + /// do something better than the default `for` loop implementation. + /// + /// In particular, try to have this call `fold()` on the internal parts + /// from which this iterator is composed. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// let mut iter = array![1, 2, 3].into_iter(); + /// + /// // the sum of all of the elements of the array + /// let sum = iter.fold(0, |acc, x| acc + x); + /// + /// assert_eq!(sum, 6); + /// ``` + /// + /// Let's walk through each step of the iteration here: + /// + /// | element | acc | x | result | + /// |---------|-----|---|--------| + /// | | 0 | | | + /// | 1 | 0 | 1 | 1 | + /// | 2 | 1 | 2 | 3 | + /// | 3 | 3 | 3 | 6 | + /// + /// And so, our final result, `6`. + /// + /// It's common for people who haven't used iterators a lot to + /// use a `for` loop with a list of things to build up a result. Those + /// can be turned into `fold()`s: + /// + /// ``` + /// let mut numbers = array![1, 2, 3, 4, 5].span(); + /// + /// let mut result = 0; + /// + /// // for loop: + /// for i in numbers{ + /// result = result + (*i); + /// }; + /// + /// // fold: + /// let mut numbers_iter = numbers.into_iter(); + /// let result2 = numbers_iter.fold(0, |acc, x| acc + (*x)); + /// + /// // they're the same + /// assert_eq!(result, result2); + /// ``` + fn fold< + B, + F, + +core::ops::Fn[Output: B], + +Destruct, + +Destruct, + +Destruct, + >( + ref self: T, init: B, f: F, + ) -> B { + match Self::next(ref self) { + Option::None => init, + Option::Some(x) => Self::fold(ref self, f(init, x), f), + } + } } diff --git a/corelib/src/starknet/storage.cairo b/corelib/src/starknet/storage.cairo index 34ed12c175f..04e3424d143 100644 --- a/corelib/src/starknet/storage.cairo +++ b/corelib/src/starknet/storage.cairo @@ -139,7 +139,6 @@ //! address (the `sn_keccak` hash of the variable name) combined with the mapping keys or vector //! indices. //! See their respective module documentation for more details. - use core::hash::HashStateTrait; #[allow(unused_imports)] use core::pedersen::HashState; @@ -161,7 +160,10 @@ mod sub_pointers; pub use sub_pointers::{SubPointers, SubPointersForward, SubPointersMut, SubPointersMutForward}; mod vec; -use vec::{MutableVecIndexView, VecIndexView}; +use vec::{ + MutableVecIndexView, MutableVecIntoIterRange, PathableMutableVecIntoIterRange, + PathableVecIntoIterRange, VecIndexView, VecIntoIterRange, +}; pub use vec::{MutableVecTrait, Vec, VecTrait}; /// A pointer to an address in storage, can be used to read and write values, if the generic type @@ -539,3 +541,13 @@ trait MutableTrait { impl MutableImpl of MutableTrait> { type InnerType = T; } + +/// Trait for turning collection of values into an iterator over a specific range. +pub trait IntoIterRange { + type IntoIter; + impl Iterator: Iterator; + /// Creates an iterator over a range from a collection. + fn into_iter_range(self: T, range: core::ops::Range) -> Self::IntoIter; + /// Creates an iterator over the full range of a collection. + fn into_iter_full_range(self: T) -> Self::IntoIter; +} diff --git a/corelib/src/starknet/storage/vec.cairo b/corelib/src/starknet/storage/vec.cairo index 5e03fa885c4..2a4a3bc8b0d 100644 --- a/corelib/src/starknet/storage/vec.cairo +++ b/corelib/src/starknet/storage/vec.cairo @@ -77,11 +77,11 @@ //! arr //! } //! ``` - -use core::Option; +use core::ops::Range; use super::{ - Mutable, StorageAsPath, StorageAsPointer, StoragePath, StoragePathTrait, StoragePathUpdateTrait, - StoragePointer0Offset, StoragePointerReadAccess, StoragePointerWriteAccess, + IntoIterRange, Mutable, StorageAsPath, StorageAsPointer, StoragePath, StoragePathTrait, + StoragePathUpdateTrait, StoragePointer0Offset, StoragePointerReadAccess, + StoragePointerWriteAccess, }; /// Represents a dynamic array in contract storage. @@ -396,3 +396,105 @@ pub impl MutableVecIndexView< (*self).at(index) } } + +/// An iterator struct over a `Vec` in storage. +#[derive(Drop)] +pub struct VecIter> { + vec: T, + current_index: crate::ops::RangeIterator, +} + +impl VecIterator, +Drop, +Copy> of Iterator> { + type Item = StoragePath; + fn next(ref self: VecIter) -> Option { + self.vec.get(self.current_index.next()?) + } +} + +// Implement `IntoIterRange` for `StoragePath>` +pub impl VecIntoIterRange< + T, impl VecTraitImpl: VecTrait>>, +> of IntoIterRange>> { + type IntoIter = VecIter>, VecTraitImpl>; + #[inline] + fn into_iter_range(self: StoragePath>, range: Range) -> Self::IntoIter { + VecIter { current_index: range.into_iter(), vec: self } + } + #[inline] + fn into_iter_full_range(self: StoragePath>) -> Self::IntoIter { + VecIter { current_index: (0..core::num::traits::Bounded::MAX).into_iter(), vec: self } + } +} + +/// Implement `IntoIterRange` for any type that implements StorageAsPath into a storage path +/// that implements VecTrait. +pub impl PathableVecIntoIterRange< + T, + +Destruct, + impl PathImpl: StorageAsPath, + impl VecTraitImpl: VecTrait>, +> of IntoIterRange { + type IntoIter = VecIter, VecTraitImpl>; + #[inline] + fn into_iter_range(self: T, range: Range) -> Self::IntoIter { + VecIter { current_index: range.into_iter(), vec: self.as_path() } + } + #[inline] + fn into_iter_full_range(self: T) -> Self::IntoIter { + let vec = self.as_path(); + VecIter { current_index: (0..core::num::traits::Bounded::MAX).into_iter(), vec } + } +} + +/// An iterator struct over a `Mutable` in storage. +#[derive(Drop)] +struct MutableVecIter> { + vec: T, + current_index: crate::ops::RangeIterator, +} + +impl MutableVecIterator< + T, +Drop, +Copy, impl MutVecTraitImpl: MutableVecTrait, +> of Iterator> { + type Item = StoragePath>; + fn next(ref self: MutableVecIter) -> Option { + self.vec.get(self.current_index.next()?) + } +} + +// Implement `IntoIterRange` for `StoragePath>>` +pub impl MutableVecIntoIterRange< + T, impl MutVecTraitImpl: MutableVecTrait>>>, +> of IntoIterRange>>> { + type IntoIter = MutableVecIter>>, MutVecTraitImpl>; + #[inline] + fn into_iter_range(self: StoragePath>>, range: Range) -> Self::IntoIter { + MutableVecIter { current_index: range.into_iter(), vec: self } + } + #[inline] + fn into_iter_full_range(self: StoragePath>>) -> Self::IntoIter { + MutableVecIter { + current_index: (0..core::num::traits::Bounded::MAX).into_iter(), vec: self, + } + } +} + +/// Implement `IntoIterRange` for any type that implements StorageAsPath into a storage path +/// that implements MutableVecTrait. +pub impl PathableMutableVecIntoIterRange< + T, + +Destruct, + impl PathImpl: StorageAsPath, + impl MutVecTraitImpl: MutableVecTrait>, +> of IntoIterRange { + type IntoIter = MutableVecIter, MutVecTraitImpl>; + #[inline] + fn into_iter_range(self: T, range: Range) -> Self::IntoIter { + MutableVecIter { current_index: range.into_iter(), vec: self.as_path() } + } + #[inline] + fn into_iter_full_range(self: T) -> Self::IntoIter { + let vec = self.as_path(); + MutableVecIter { current_index: (0..core::num::traits::Bounded::MAX).into_iter(), vec } + } +} diff --git a/corelib/src/test/integer_test.cairo b/corelib/src/test/integer_test.cairo index f4bac082882..a60e191a32b 100644 --- a/corelib/src/test/integer_test.cairo +++ b/corelib/src/test/integer_test.cairo @@ -2161,81 +2161,80 @@ mod bounded_int { #[test] fn test_trim() { use core::internal::OptionRev; - assert!(bounded_int::trim::(0) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - assert!(bounded_int::trim::(0xff) == OptionRev::None); - assert!(bounded_int::trim::(0xfe) == OptionRev::Some(0xfe)); - assert!(bounded_int::trim::(-0x80) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - assert!(bounded_int::trim::(0x7f) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - - assert!(bounded_int::trim::(0) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - assert!(bounded_int::trim::(0xffff) == OptionRev::None); - assert!(bounded_int::trim::(0xfffe) == OptionRev::Some(0xfffe)); - assert!(bounded_int::trim::(-0x8000) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - assert!(bounded_int::trim::(0x7fff) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - - assert!(bounded_int::trim::(0) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - assert!(bounded_int::trim::(0xffffffff) == OptionRev::None); - assert!(bounded_int::trim::(0xfffffffe) == OptionRev::Some(0xfffffffe)); - assert!(bounded_int::trim::(-0x80000000) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - assert!(bounded_int::trim::(0x7fffffff) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - - assert!(bounded_int::trim::(0) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_min::(0) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_max::(0xff) == OptionRev::None); + assert!(bounded_int::trim_max::(0xfe) == OptionRev::Some(0xfe)); + assert!(bounded_int::trim_min::(-0x80) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_max::(0x7f) == OptionRev::None); + assert!(bounded_int::trim_max::(1) == OptionRev::Some(1)); + + assert!(bounded_int::trim_min::(0) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_max::(0xffff) == OptionRev::None); + assert!(bounded_int::trim_max::(0xfffe) == OptionRev::Some(0xfffe)); + assert!(bounded_int::trim_min::(-0x8000) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_max::(0x7fff) == OptionRev::None); + assert!(bounded_int::trim_max::(1) == OptionRev::Some(1)); + + assert!(bounded_int::trim_min::(0) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_max::(0xffffffff) == OptionRev::None); + assert!(bounded_int::trim_max::(0xfffffffe) == OptionRev::Some(0xfffffffe)); + assert!(bounded_int::trim_min::(-0x80000000) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_max::(0x7fffffff) == OptionRev::None); + assert!(bounded_int::trim_max::(1) == OptionRev::Some(1)); + + assert!(bounded_int::trim_min::(0) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_max::(0xffffffffffffffff) == OptionRev::None); assert!( - bounded_int::trim::(0xffffffffffffffff) == OptionRev::None, + bounded_int::trim_max::(0xfffffffffffffffe) == OptionRev::Some(0xfffffffffffffffe), ); - assert!( - bounded_int::trim::< - u64, 0xffffffffffffffff, - >(0xfffffffffffffffe) == OptionRev::Some(0xfffffffffffffffe), - ); - assert!( - bounded_int::trim::(-0x8000000000000000) == OptionRev::None, - ); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); - assert!( - bounded_int::trim::(0x7fffffffffffffff) == OptionRev::None, - ); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_min::(-0x8000000000000000) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_max::(0x7fffffffffffffff) == OptionRev::None); + assert!(bounded_int::trim_max::(1) == OptionRev::Some(1)); - assert!(bounded_int::trim::(0) == OptionRev::None); - assert!(bounded_int::trim::(1) == OptionRev::Some(1)); + assert!(bounded_int::trim_min::(0) == OptionRev::None); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); assert!( - bounded_int::trim::< - u128, 0xffffffffffffffffffffffffffffffff, - >(0xffffffffffffffffffffffffffffffff) == OptionRev::None, + bounded_int::trim_max::(0xffffffffffffffffffffffffffffffff) == OptionRev::None, ); assert!( - bounded_int::trim::< - u128, 0xffffffffffffffffffffffffffffffff, + bounded_int::trim_max::< + u128, >( 0xfffffffffffffffffffffffffffffffe, ) == OptionRev::Some(0xfffffffffffffffffffffffffffffffe), ); assert!( - bounded_int::trim::< - i128, -0x80000000000000000000000000000000, - >(-0x80000000000000000000000000000000) == OptionRev::None, - ); - assert!( - bounded_int::trim::(1) == OptionRev::Some(1), + bounded_int::trim_min::(-0x80000000000000000000000000000000) == OptionRev::None, ); + assert!(bounded_int::trim_min::(1) == OptionRev::Some(1)); assert!( - bounded_int::trim::< - i128, 0x7fffffffffffffffffffffffffffffff, - >(0x7fffffffffffffffffffffffffffffff) == OptionRev::None, - ); - assert!( - bounded_int::trim::(1) == OptionRev::Some(1), + bounded_int::trim_max::(0x7fffffffffffffffffffffffffffffff) == OptionRev::None, ); + assert!(bounded_int::trim_max::(1) == OptionRev::Some(1)); } } + +#[test] +fn test_upcast_in_const() { + const AS_U8: u8 = 10; + const AS_U16: u16 = AS_U8.into(); + assert_eq!(AS_U16, 10); +} + +#[test] +fn test_downcast_in_const() { + const IN_RANGE: u16 = 10; + const OUT_OF_RANGE: u16 = 300; + const IN_RANGE_AS_U8: Option = IN_RANGE.try_into(); + const OUT_OF_RANGE_AS_U8: Option = OUT_OF_RANGE.try_into(); + assert_eq!(IN_RANGE_AS_U8, Option::Some(10)); + assert_eq!(OUT_OF_RANGE_AS_U8, Option::None); +} diff --git a/corelib/src/test/iter_test.cairo b/corelib/src/test/iter_test.cairo index d5aa58efdd6..6053c440625 100644 --- a/corelib/src/test/iter_test.cairo +++ b/corelib/src/test/iter_test.cairo @@ -1,3 +1,13 @@ +#[test] +fn test_advance_by() { + let mut iter = array![1_u8, 2, 3, 4].into_iter(); + + assert_eq!(iter.advance_by(2), Result::Ok(())); + assert_eq!(iter.next(), Option::Some(3)); + assert_eq!(iter.advance_by(0), Result::Ok(())); + assert_eq!(iter.advance_by(100), Result::Err(99)); +} + #[test] fn test_iter_adapter_map() { let mut iter = array![1, 2, 3].into_iter().map(|x| 2 * x); @@ -7,3 +17,21 @@ fn test_iter_adapter_map() { assert_eq!(iter.next(), Option::Some(6)); assert_eq!(iter.next(), Option::None); } + +#[test] +fn test_iterator_enumerate() { + let mut iter = array!['a', 'b', 'c'].into_iter().enumerate(); + + assert_eq!(iter.next(), Option::Some((0, 'a'))); + assert_eq!(iter.next(), Option::Some((1, 'b'))); + assert_eq!(iter.next(), Option::Some((2, 'c'))); + assert_eq!(iter.next(), Option::None); +} + +#[test] +fn test_iter_adapter_fold() { + let mut iter = array![1, 2, 3].into_iter(); + let sum = iter.fold(0, |acc, x| acc + x); + + assert_eq!(sum, 6); +} diff --git a/corelib/src/test/language_features/const_test.cairo b/corelib/src/test/language_features/const_test.cairo index 702ef3e1ca4..872373a7ee6 100644 --- a/corelib/src/test/language_features/const_test.cairo +++ b/corelib/src/test/language_features/const_test.cairo @@ -158,3 +158,25 @@ fn test_complex_consts() { }; assert_eq!(IF_CONST_FALSE, 7); } + +mod const_starknet_consts { + pub extern fn const_as_box() -> Box< + (starknet::ContractAddress, starknet::ClassHash), + > nopanic; +} + +#[test] +fn test_starknet_consts() { + assert!( + const_starknet_consts::const_as_box::< + struct2::Const< + (starknet::ContractAddress, starknet::ClassHash), + value::Const, + value::Const, + >, + 0, + >() + .unbox() == (1000.try_into().unwrap(), 1001.try_into().unwrap()), + ); +} + diff --git a/crates/cairo-lang-lowering/Cargo.toml b/crates/cairo-lang-lowering/Cargo.toml index 4de6fd482cb..a1a5f3c9105 100644 --- a/crates/cairo-lang-lowering/Cargo.toml +++ b/crates/cairo-lang-lowering/Cargo.toml @@ -23,7 +23,6 @@ num-bigint = { workspace = true, default-features = true } num-integer = { workspace = true, default-features = true } num-traits = { workspace = true, default-features = true } salsa.workspace = true -smol_str.workspace = true [dev-dependencies] cairo-lang-plugins = { path = "../cairo-lang-plugins" } diff --git a/crates/cairo-lang-lowering/src/lower/mod.rs b/crates/cairo-lang-lowering/src/lower/mod.rs index 7b852bb3192..895394ce2a4 100644 --- a/crates/cairo-lang-lowering/src/lower/mod.rs +++ b/crates/cairo-lang-lowering/src/lower/mod.rs @@ -1889,6 +1889,8 @@ fn add_closure_call_function( .add(&mut ctx, &mut builder.statements); for (param_var, param) in param_vars.into_iter().zip(expr.params.iter()) { builder.semantics.introduce((¶meter_as_member_path(param.clone())).into(), param_var); + ctx.semantic_defs + .insert(semantic::VarId::Param(param.id), semantic::Binding::Param(param.clone())); } let lowered_expr = lower_expr(&mut ctx, &mut builder, expr.body); let maybe_sealed_block = lowered_expr_to_block_scope_end(&mut ctx, builder, lowered_expr); diff --git a/crates/cairo-lang-lowering/src/lower/test_data/closure b/crates/cairo-lang-lowering/src/lower/test_data/closure index a121fe34854..4f93940c588 100644 --- a/crates/cairo-lang-lowering/src/lower/test_data/closure +++ b/crates/cairo-lang-lowering/src/lower/test_data/closure @@ -299,3 +299,157 @@ Statements: (v3: core::integer::u32) <- struct_destructure(v2) End: Return(v3) + +//! > ========================================================================== + +//! > Test closure with branching. + +//! > test_runner_name +test_generated_function + +//! > function +fn foo(a: u32) { + let f = |a: felt252| { + let mut b = @0; + if 1 == 2 { + b = @a; + } else { + b = @a; + } + }; + let _ = f(0); +} + +//! > function_name +foo + +//! > module_code + +//! > semantic_diagnostics + +//! > lowering_diagnostics + +//! > lowering +Main: +Parameters: v0: core::integer::u32 +blk0 (root): +Statements: + (v1: {closure@lib.cairo:2:13: 2:25}) <- struct_construct() + (v2: {closure@lib.cairo:2:13: 2:25}, v3: @{closure@lib.cairo:2:13: 2:25}) <- snapshot(v1) + (v4: core::felt252) <- 0 + (v5: (core::felt252,)) <- struct_construct(v4) + (v6: ()) <- Generated core::ops::function::Fn::<{closure@lib.cairo:2:13: 2:25}, (core::felt252,)>::call(v3, v5) + (v7: ()) <- struct_construct() +End: + Return(v7) + + +Final lowering: +Parameters: v0: core::integer::u32 +blk0 (root): +Statements: + (v1: core::felt252) <- 1 + (v2: core::felt252) <- 2 + (v3: core::felt252) <- core::felt252_sub(v1, v2) +End: + Match(match core::felt252_is_zero(v3) { + IsZeroResult::Zero => blk1, + IsZeroResult::NonZero(v4) => blk2, + }) + +blk1: +Statements: +End: + Return() + +blk2: +Statements: +End: + Return() + + +Generated core::traits::Destruct::destruct lowering for source location: + let f = |a: felt252| { + ^^^^^^^^^^^^ + +Parameters: v0: {closure@lib.cairo:2:13: 2:25} +blk0 (root): +Statements: + () <- struct_destructure(v0) + (v1: ()) <- struct_construct() +End: + Return(v1) + + +Final lowering: +Parameters: v0: {closure@lib.cairo:2:13: 2:25} +blk0 (root): +Statements: +End: + Return() + + +Generated core::ops::function::Fn::call lowering for source location: + let f = |a: felt252| { + ^^^^^^^^^^^^ + +Parameters: v0: @{closure@lib.cairo:2:13: 2:25}, v2: (core::felt252,) +blk0 (root): +Statements: + (v1: {closure@lib.cairo:2:13: 2:25}) <- desnap(v0) + () <- struct_destructure(v1) + (v3: core::felt252) <- struct_destructure(v2) + (v4: core::felt252) <- 0 + (v5: core::felt252, v6: @core::felt252) <- snapshot(v4) + (v7: core::felt252) <- 1 + (v8: core::felt252, v9: @core::felt252) <- snapshot(v7) + (v10: core::felt252) <- 2 + (v11: core::felt252, v12: @core::felt252) <- snapshot(v10) + (v13: core::bool) <- core::Felt252PartialEq::eq(v9, v12) +End: + Match(match_enum(v13) { + bool::False(v17) => blk2, + bool::True(v14) => blk1, + }) + +blk1: +Statements: + (v15: core::felt252, v16: @core::felt252) <- snapshot(v3) +End: + Goto(blk3, {v15 -> v20, v16 -> v21}) + +blk2: +Statements: + (v18: core::felt252, v19: @core::felt252) <- snapshot(v3) +End: + Goto(blk3, {v18 -> v20, v19 -> v21}) + +blk3: +Statements: + (v22: ()) <- struct_construct() +End: + Return(v22) + + +Final lowering: +Parameters: v0: @{closure@lib.cairo:2:13: 2:25}, v1: (core::felt252,) +blk0 (root): +Statements: + (v2: core::felt252) <- 1 + (v3: core::felt252) <- 2 + (v4: core::felt252) <- core::felt252_sub(v2, v3) +End: + Match(match core::felt252_is_zero(v4) { + IsZeroResult::Zero => blk1, + IsZeroResult::NonZero(v5) => blk2, + }) + +blk1: +Statements: +End: + Return() + +blk2: +Statements: +End: + Return() diff --git a/crates/cairo-lang-lowering/src/lower/test_data/loop b/crates/cairo-lang-lowering/src/lower/test_data/loop index 03e52d1164a..14998075155 100644 --- a/crates/cairo-lang-lowering/src/lower/test_data/loop +++ b/crates/cairo-lang-lowering/src/lower/test_data/loop @@ -1527,3 +1527,82 @@ Statements: (v27: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- PanicResult::Err(v26) End: Return(v6, v7, v27) + +//! > ========================================================================== + +//! > Test default implementation with loop. + +//! > test_runner_name +test_generated_function + +//! > function +fn foo() { + MyTrait::impl_in_trait(); +} + +//! > function_name +foo + +//! > module_code +trait MyTrait { + fn impl_in_impl(x: u8) -> bool; + fn impl_in_trait() -> u8 { + let mut i = 0; + loop { + if Self::impl_in_impl(i) { + break; + } + i += 1; + }; + i + } +} + +impl MyImpl of MyTrait { + fn impl_in_impl(x: u8) -> bool { + x == 30 + } +} + +//! > expected_diagnostics + +//! > semantic_diagnostics + +//! > lowering +Main: +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u8) <- test::MyImpl::impl_in_trait() + (v1: ()) <- struct_construct() +End: + Return(v1) + + +Final lowering: +Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin +blk0 (root): +Statements: + (v2: core::integer::u8) <- 0 + (v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[expr9](v0, v1, v2) +End: + Match(match_enum(v5) { + PanicResult::Ok(v6) => blk1, + PanicResult::Err(v7) => blk2, + }) + +blk1: +Statements: + (v8: ()) <- struct_construct() + (v9: ((),)) <- struct_construct(v8) + (v10: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v9) +End: + Return(v3, v4, v10) + +blk2: +Statements: + (v11: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v7) +End: + Return(v3, v4, v11) + +//! > lowering_diagnostics diff --git a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs index 341b6d3c1ab..0ad93d15b1b 100644 --- a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs +++ b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs @@ -4,7 +4,8 @@ mod test; use std::sync::Arc; -use cairo_lang_defs::ids::{ExternFunctionId, ModuleId, ModuleItemId}; +use cairo_lang_defs::ids::{ExternFunctionId, ModuleId}; +use cairo_lang_semantic::helper::ModuleHelper; use cairo_lang_semantic::items::constant::ConstValue; use cairo_lang_semantic::items::imp::ImplLookupContext; use cairo_lang_semantic::{GenericArgumentId, MatchArmSelector, TypeId, corelib}; @@ -17,10 +18,9 @@ use itertools::{chain, zip_eq}; use num_bigint::BigInt; use num_integer::Integer; use num_traits::Zero; -use smol_str::SmolStr; use crate::db::LoweringGroup; -use crate::ids::{FunctionId, FunctionLongId}; +use crate::ids::{FunctionId, SemanticFunctionIdEx}; use crate::{ BlockId, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct, @@ -276,10 +276,14 @@ impl ConstFoldingContext<'_> { let input_var = stmt.inputs[0].var_id; if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) { stmt.inputs.clear(); - stmt.function = ModuleHelper { db: self.db, id: self.storage_access_module } - .function_id("storage_base_address_const", vec![GenericArgumentId::Constant( - ConstValue::Int(val.clone(), *ty).intern(self.db), - )]); + stmt.function = + ModuleHelper { db: self.db.upcast(), id: self.storage_access_module } + .function_id("storage_base_address_const", vec![ + GenericArgumentId::Constant( + ConstValue::Int(val.clone(), *ty).intern(self.db), + ), + ]) + .lowered(self.db); } None } else if id == self.into_box { @@ -479,8 +483,9 @@ impl ConstFoldingContext<'_> { let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone()); let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone()); info.inputs.truncate(1); - info.function = ModuleHelper { db: self.db, id: self.array_module } - .function_id("array_snapshot_pop_front", generic_args); + info.function = ModuleHelper { db: self.db.upcast(), id: self.array_module } + .function_id("array_snapshot_pop_front", generic_args) + .lowered(self.db); success.var_ids.insert(0, unused_arr_output0); failure.var_ids.insert(0, unused_arr_output1); } @@ -539,52 +544,6 @@ pub fn priv_const_folding_info( Arc::new(ConstFoldingLibfuncInfo::new(db)) } -/// Helper for getting functions in the corelib. -struct ModuleHelper<'a> { - /// The db. - db: &'a dyn LoweringGroup, - /// The current module id. - id: ModuleId, -} -impl<'a> ModuleHelper<'a> { - /// Returns a helper for the core module. - fn core(db: &'a dyn LoweringGroup) -> Self { - Self { db, id: corelib::core_module(db.upcast()) } - } - /// Returns a helper for a submodule named `name` of the current module. - fn submodule(&self, name: &str) -> Self { - let id = corelib::get_submodule(self.db.upcast(), self.id, name).unwrap_or_else(|| { - panic!("`{name}` missing in `{}`.", self.id.full_path(self.db.upcast())) - }); - Self { db: self.db, id } - } - /// Returns the id of an extern function named `name` in the current module. - fn extern_function_id(&self, name: impl Into) -> ExternFunctionId { - let name = name.into(); - let Ok(Some(ModuleItemId::ExternFunction(id))) = - self.db.module_item_by_name(self.id, name.clone()) - else { - panic!("`{}` not found in `{}`.", name, self.id.full_path(self.db.upcast())); - }; - id - } - /// Returns the id of a function named `name` in the current module, with the given - /// `generic_args`. - fn function_id( - &self, - name: impl Into, - generic_args: Vec, - ) -> FunctionId { - FunctionLongId::Semantic(corelib::get_function_id( - self.db.upcast(), - self.id, - name.into(), - generic_args, - )) - .intern(self.db) - } -} - /// Holds static information about libfuncs required for the optimization. #[derive(Debug, PartialEq, Eq)] pub struct ConstFoldingLibfuncInfo { @@ -633,7 +592,7 @@ pub struct ConstFoldingLibfuncInfo { } impl ConstFoldingLibfuncInfo { fn new(db: &dyn LoweringGroup) -> Self { - let core = ModuleHelper::core(db); + let core = ModuleHelper::core(db.upcast()); let felt_sub = core.extern_function_id("felt252_sub"); let box_module = core.submodule("box"); let into_box = box_module.extern_function_id("into_box"); @@ -707,7 +666,9 @@ impl ConstFoldingLibfuncInfo { let info = TypeInfo { min, max, - is_zero: integer_module.function_id(format!("{ty}_is_zero"), vec![]), + is_zero: integer_module + .function_id(format!("{ty}_is_zero"), vec![]) + .lowered(db), }; (corelib::get_core_ty_by_name(db.upcast(), ty.into(), vec![]), info) }), diff --git a/crates/cairo-lang-parser/src/parser_test_data/partial_trees/if_else b/crates/cairo-lang-parser/src/parser_test_data/partial_trees/if_else index b4b460b5426..0409f6edc9c 100644 --- a/crates/cairo-lang-parser/src/parser_test_data/partial_trees/if_else +++ b/crates/cairo-lang-parser/src/parser_test_data/partial_trees/if_else @@ -155,7 +155,7 @@ ExprBlock //! > ========================================================================== -//! > Test if-let boolean opperators +//! > Test if-let boolean operators //! > test_runner_name test_partial_parser_tree(expect_diagnostics: false) diff --git a/crates/cairo-lang-semantic/src/corelib.rs b/crates/cairo-lang-semantic/src/corelib.rs index f1a5c18379c..f161a538073 100644 --- a/crates/cairo-lang-semantic/src/corelib.rs +++ b/crates/cairo-lang-semantic/src/corelib.rs @@ -904,7 +904,7 @@ impl LiteralError { pub fn validate_literal( db: &dyn SemanticGroup, ty: TypeId, - value: BigInt, + value: &BigInt, ) -> Result<(), LiteralError> { if let Some(nz_wrapped_ty) = try_extract_nz_wrapped_type(db, ty) { return if value.is_zero() { @@ -914,7 +914,7 @@ pub fn validate_literal( }; } let is_out_of_range = if let Some((min, max)) = try_extract_bounded_int_type_ranges(db, ty) { - value < min || value > max + *value < min || *value > max } else if ty == db.core_felt252_ty() { value.abs() > BigInt::from_str_radix( diff --git a/crates/cairo-lang-semantic/src/diagnostic.rs b/crates/cairo-lang-semantic/src/diagnostic.rs index 652df9e2798..9490c31765f 100644 --- a/crates/cairo-lang-semantic/src/diagnostic.rs +++ b/crates/cairo-lang-semantic/src/diagnostic.rs @@ -997,6 +997,9 @@ impl DiagnosticEntry for SemanticDiagnostic { SemanticDiagnosticKind::RefClosureArgument => { "Arguments to closure functions cannot be references".into() } + SemanticDiagnosticKind::RefClosureParam => { + "Closure parameters cannot be references".into() + } SemanticDiagnosticKind::MutableCapturedVariable => { "Capture of mutable variables in a closure is not supported".into() } @@ -1422,6 +1425,7 @@ pub enum SemanticDiagnosticKind { shadowed_function_name: SmolStr, }, RefClosureArgument, + RefClosureParam, MutableCapturedVariable, NonTraitTypeConstrained { identifier: SmolStr, diff --git a/crates/cairo-lang-semantic/src/expr/compute.rs b/crates/cairo-lang-semantic/src/expr/compute.rs index 09a0d88d416..b3346a1ca03 100644 --- a/crates/cairo-lang-semantic/src/expr/compute.rs +++ b/crates/cairo-lang-semantic/src/expr/compute.rs @@ -66,7 +66,7 @@ use crate::expr::inference::{ImplVarTraitItemMappings, InferenceId}; use crate::items::constant::{ConstValue, resolve_const_expr_and_evaluate, validate_const_expr}; use crate::items::enm::SemanticEnumEx; use crate::items::feature_kind::extract_item_feature_config; -use crate::items::functions::function_signature_params; +use crate::items::functions::{concrete_function_closure_params, function_signature_params}; use crate::items::imp::{ImplLookupContext, filter_candidate_traits, infer_impl_by_self}; use crate::items::modifiers::compute_mutability; use crate::items::us::get_use_path_segments; @@ -84,8 +84,8 @@ use crate::types::{ }; use crate::usage::Usages; use crate::{ - ConcreteEnumId, GenericArgumentId, GenericParam, LocalItem, Member, Mutability, Parameter, - PatternStringLiteral, PatternStruct, Signature, StatementItemKind, + ConcreteEnumId, ConcreteFunction, GenericArgumentId, GenericParam, LocalItem, Member, + Mutability, Parameter, PatternStringLiteral, PatternStruct, Signature, StatementItemKind, }; /// Expression with its id. @@ -424,7 +424,7 @@ pub fn maybe_compute_expr_semantic( ast::Expr::Indexed(expr) => compute_expr_indexed_semantic(ctx, expr), ast::Expr::FixedSizeArray(expr) => compute_expr_fixed_size_array_semantic(ctx, expr), ast::Expr::For(expr) => compute_expr_for_semantic(ctx, expr), - ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr), + ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr, None), } } @@ -882,7 +882,7 @@ fn compute_expr_function_call_semantic( let mut arg_types = vec![]; for arg_syntax in args_iter { let stable_ptr = arg_syntax.stable_ptr(); - let arg = compute_named_argument_clause(ctx, arg_syntax); + let arg = compute_named_argument_clause(ctx, arg_syntax, None); if arg.2 != Mutability::Immutable { return Err(ctx.diagnostics.report(stable_ptr, RefClosureArgument)); } @@ -930,7 +930,7 @@ fn compute_expr_function_call_semantic( let named_args: Vec<_> = args_syntax .elements(syntax_db) .into_iter() - .map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax)) + .map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax, None)) .collect(); if named_args.len() != 1 { return Err(ctx.diagnostics.report(syntax, WrongNumberOfArguments { @@ -979,16 +979,22 @@ fn compute_expr_function_call_semantic( let mut args_iter = args_syntax.elements(syntax_db).into_iter(); // Normal parameters let mut named_args = vec![]; - for _ in function_parameter_types(ctx, function)? { + let ConcreteFunction { .. } = function.lookup_intern(db).function; + let closure_params = concrete_function_closure_params(db, function)?; + for ty in function_parameter_types(ctx, function)? { let Some(arg_syntax) = args_iter.next() else { continue; }; - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause( + ctx, + arg_syntax, + closure_params.get(&ty).cloned(), + )); } // Maybe coupon if let Some(arg_syntax) = args_iter.next() { - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause(ctx, arg_syntax, None)); } expr_function_call(ctx, function, named_args, syntax, syntax.stable_ptr().into()) @@ -1006,6 +1012,7 @@ fn compute_expr_function_call_semantic( pub fn compute_named_argument_clause( ctx: &mut ComputationContext<'_>, arg_syntax: ast::Arg, + closure_param_types: Option, ) -> NamedArg { let syntax_db = ctx.db.upcast(); @@ -1018,12 +1025,38 @@ pub fn compute_named_argument_clause( let arg_clause = arg_syntax.arg_clause(syntax_db); let (expr, arg_name_identifier) = match arg_clause { ast::ArgClause::Unnamed(arg_unnamed) => { - (compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None) + let arg_expr = arg_unnamed.value(syntax_db); + if let ast::Expr::Closure(expr_closure) = arg_expr { + let expr = compute_expr_closure_semantic(ctx, &expr_closure, closure_param_types); + let expr = wrap_maybe_with_missing( + ctx, + expr, + ast::ExprPtr::from(expr_closure.stable_ptr()), + ); + let id = ctx.arenas.exprs.alloc(expr.clone()); + (ExprAndId { expr, id }, None) + } else { + (compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None) + } + } + ast::ArgClause::Named(arg_named) => { + let arg_expr = arg_named.value(syntax_db); + if let ast::Expr::Closure(expr_closure) = arg_expr { + let expr = compute_expr_closure_semantic(ctx, &expr_closure, closure_param_types); + let expr = wrap_maybe_with_missing( + ctx, + expr, + ast::ExprPtr::from(expr_closure.stable_ptr()), + ); + let id = ctx.arenas.exprs.alloc(expr.clone()); + (ExprAndId { expr, id }, None) + } else { + ( + compute_expr_semantic(ctx, &arg_named.value(syntax_db)), + Some(arg_named.name(syntax_db)), + ) + } } - ast::ArgClause::Named(arg_named) => ( - compute_expr_semantic(ctx, &arg_named.value(syntax_db)), - Some(arg_named.name(syntax_db)), - ), ast::ArgClause::FieldInitShorthand(arg_field_init_shorthand) => { let name_expr = arg_field_init_shorthand.name(syntax_db); let stable_ptr: ast::ExprPtr = name_expr.stable_ptr().into(); @@ -1645,6 +1678,7 @@ fn compute_loop_body_semantic( fn compute_expr_closure_semantic( ctx: &mut ComputationContext<'_>, syntax: &ast::ExprClosure, + param_types: Option, ) -> Maybe { ctx.are_closures_in_context = true; let syntax_db = ctx.db.upcast(); @@ -1663,6 +1697,18 @@ fn compute_expr_closure_semantic( } else { vec![] }; + let closure_type = + TypeLongId::Tuple(params.iter().map(|param| param.ty).collect()).intern(new_ctx.db); + if let Some(param_types) = param_types { + if let Err(err_set) = new_ctx.resolver.inference().conform_ty(closure_type, param_types) + { + new_ctx.resolver.inference().consume_error_without_reporting(err_set); + } + } + + params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| { + new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam); + }); new_ctx .semantic_defs @@ -2830,16 +2876,22 @@ fn method_call_expr( // Self argument. let mut named_args = vec![NamedArg(fixed_lexpr, None, mutability)]; // Other arguments. - for _ in function_parameter_types(ctx, function_id)?.skip(1) { + let ConcreteFunction { .. } = function_id.lookup_intern(ctx.db).function; + let closure_params = concrete_function_closure_params(ctx.db, function_id)?; + for ty in function_parameter_types(ctx, function_id)?.skip(1) { let Some(arg_syntax) = args_iter.next() else { break; }; - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause( + ctx, + arg_syntax, + closure_params.get(&ty).cloned(), + )); } // Maybe coupon if let Some(arg_syntax) = args_iter.next() { - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause(ctx, arg_syntax, None)); } expr_function_call(ctx, function_id, named_args, &expr, stable_ptr) @@ -3259,7 +3311,6 @@ fn expr_function_call( // Check argument names and types. check_named_arguments(&named_args, &signature, ctx)?; - let mut args = Vec::new(); for (NamedArg(arg, _name, mutability), param) in named_args.into_iter().zip(signature.params.iter()) diff --git a/crates/cairo-lang-semantic/src/expr/test_data/closure b/crates/cairo-lang-semantic/src/expr/test_data/closure index b47c05c6902..8f907f5ff3e 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/closure +++ b/crates/cairo-lang-semantic/src/expr/test_data/closure @@ -743,3 +743,48 @@ error: Cannot assign to an immutable variable. --> lib.cairo:3:9 a = a + 2 ^^^^^^^^^ + +//! > ========================================================================== + +//! > Closures with ref arguments. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: true) + +//! > function +fn foo() { + let _ = |ref a| { + a = a + 2 + }; +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics +error: Closure parameters cannot be references + --> lib.cairo:2:14 + let _ = |ref a| { + ^^^^^ + +//! > ========================================================================== + +//! > Passing closures as args with less explicit typing. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: false) + +//! > function +fn foo() -> Option { + let x: Option> = Option::Some(array![1, 2, 3]); + x.map(|x| x.len()) +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics diff --git a/crates/cairo-lang-semantic/src/helper.rs b/crates/cairo-lang-semantic/src/helper.rs new file mode 100644 index 00000000000..9ca65443c57 --- /dev/null +++ b/crates/cairo-lang-semantic/src/helper.rs @@ -0,0 +1,45 @@ +use cairo_lang_defs::ids::{ExternFunctionId, ModuleId, ModuleItemId}; +use smol_str::SmolStr; + +use crate::db::SemanticGroup; +use crate::{FunctionId, GenericArgumentId, corelib}; + +/// Helper for getting functions in the corelib. +pub struct ModuleHelper<'a> { + /// The db. + pub db: &'a dyn SemanticGroup, + /// The current module id. + pub id: ModuleId, +} +impl<'a> ModuleHelper<'a> { + /// Returns a helper for the core module. + pub fn core(db: &'a dyn SemanticGroup) -> Self { + Self { db, id: db.core_module() } + } + /// Returns a helper for a submodule named `name` of the current module. + pub fn submodule(&self, name: &str) -> Self { + let id = corelib::get_submodule(self.db, self.id, name).unwrap_or_else(|| { + panic!("`{name}` missing in `{}`.", self.id.full_path(self.db.upcast())) + }); + Self { db: self.db, id } + } + /// Returns the id of an extern function named `name` in the current module. + pub fn extern_function_id(&self, name: impl Into) -> ExternFunctionId { + let name = name.into(); + let Ok(Some(ModuleItemId::ExternFunction(id))) = + self.db.module_item_by_name(self.id, name.clone()) + else { + panic!("`{}` not found in `{}`.", name, self.id.full_path(self.db.upcast())); + }; + id + } + /// Returns the id of a function named `name` in the current module, with the given + /// `generic_args`. + pub fn function_id( + &self, + name: impl Into, + generic_args: Vec, + ) -> FunctionId { + corelib::get_function_id(self.db, self.id, name.into(), generic_args) + } +} diff --git a/crates/cairo-lang-semantic/src/items/constant.rs b/crates/cairo-lang-semantic/src/items/constant.rs index 7dee8015366..6db9d3f4717 100644 --- a/crates/cairo-lang-semantic/src/items/constant.rs +++ b/crates/cairo-lang-semantic/src/items/constant.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use cairo_lang_debug::DebugWithDb; use cairo_lang_defs::ids::{ - ConstantId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId, + ConstantId, ExternFunctionId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId, NamedLanguageElementId, TraitConstantId, TraitFunctionId, TraitId, VarId, }; use cairo_lang_diagnostics::{ @@ -26,8 +26,8 @@ use smol_str::SmolStr; use super::functions::{GenericFunctionId, GenericFunctionWithBodyId}; use super::imp::{ImplId, ImplLongId}; use crate::corelib::{ - CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, false_variant, - get_core_function_id, get_core_trait, get_core_ty_by_name, true_variant, + CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, false_variant, get_core_trait, + get_core_ty_by_name, option_none_variant, option_some_variant, true_variant, try_extract_nz_wrapped_type, unit_ty, validate_literal, }; use crate::db::SemanticGroup; @@ -37,6 +37,7 @@ use crate::expr::compute::{ }; use crate::expr::inference::conform::InferenceConform; use crate::expr::inference::{ConstVar, InferenceId}; +use crate::helper::ModuleHelper; use crate::literals::try_extract_minus_literal; use crate::resolve::{Resolver, ResolverData}; use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter}; @@ -44,8 +45,8 @@ use crate::types::resolve_type; use crate::{ Arenas, ConcreteFunction, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock, ExprConstant, ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, - FunctionId, GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement, - TypeId, TypeLongId, semantic_object_for_id, + FunctionId, GenericArgumentId, GenericParam, LogicalOperator, Pattern, PatternId, + SemanticDiagnostic, Statement, TypeId, TypeLongId, semantic_object_for_id, }; #[derive(Clone, Debug, PartialEq, Eq, DebugWithDb)] @@ -440,7 +441,7 @@ pub fn value_as_const_value( ty: TypeId, value: &BigInt, ) -> Result { - validate_literal(db.upcast(), ty, value.clone())?; + validate_literal(db.upcast(), ty, value)?; let get_basic_const_value = |ty| { let u256_ty = get_core_ty_by_name(db.upcast(), "u256".into(), vec![]); @@ -501,7 +502,7 @@ impl ConstantEvaluateContext<'_> { } Expr::FunctionCall(expr) => { if let Some(value) = try_extract_minus_literal(self.db, &self.arenas.exprs, expr) { - if let Err(err) = validate_literal(self.db, expr.ty, value) { + if let Err(err) = validate_literal(self.db, expr.ty, &value) { self.diagnostics.report( expr.stable_ptr.untyped(), SemanticDiagnosticKind::LiteralError(err), @@ -531,7 +532,7 @@ impl ConstantEvaluateContext<'_> { } } Expr::Literal(expr) => { - if let Err(err) = validate_literal(self.db, expr.ty, expr.value.clone()) { + if let Err(err) = validate_literal(self.db, expr.ty, &expr.value) { self.diagnostics.report( expr.stable_ptr.untyped(), SemanticDiagnosticKind::LiteralError(err), @@ -918,6 +919,40 @@ impl ConstantEvaluateContext<'_> { stable_ptr: SyntaxStablePtrId, ) -> Option { let db = self.db; + if let GenericFunctionId::Extern(extern_fn) = concrete_function.generic_function { + if extern_fn == self.upcast_fn { + let ( + [ConstValue::Int(value, _)], + [GenericArgumentId::Type(_in_ty), GenericArgumentId::Type(out_ty)], + ) = (args, &concrete_function.generic_args[..]) + else { + return None; + }; + return Some(ConstValue::Int(value.clone(), *out_ty)); + } else if extern_fn == self.downcast_fn { + let ( + [ConstValue::Int(value, _)], + [GenericArgumentId::Type(_in_ty), GenericArgumentId::Type(out_ty)], + ) = (args, &concrete_function.generic_args[..]) + else { + return None; + }; + return Some(match validate_literal(db, *out_ty, value) { + Ok(()) => ConstValue::Enum( + option_some_variant(db, *out_ty), + ConstValue::Int(value.clone(), *out_ty).into(), + ), + Err(LiteralError::OutOfRange(_)) => ConstValue::Enum( + option_none_variant(db, *out_ty), + self.unit_const.clone().into(), + ), + Err(LiteralError::InvalidTypeForLiteral(_)) => unreachable!( + "`downcast` is only allowed into types that can be literals. Got `{}`.", + out_ty.format(db) + ), + }); + } + } let body_id = concrete_function.body(db).ok()??; let concrete_body_id = body_id.function_with_body_id(db); let signature = db.function_with_body_signature(concrete_body_id).ok()?; @@ -1210,6 +1245,10 @@ pub struct ConstCalcInfo { false_const: ConstValue, /// The function for panicking with a felt252. panic_with_felt252: FunctionId, + /// The integer `upcast` function. + upcast_fn: ExternFunctionId, + /// The integer `downcast` function. + downcast_fn: ExternFunctionId, } impl ConstCalcInfo { @@ -1232,6 +1271,8 @@ impl ConstCalcInfo { db.trait_function_by_name(trait_id, name.into()).unwrap().unwrap() }; let unit_const = ConstValue::Struct(vec![], unit_ty(db)); + let core = ModuleHelper::core(db); + let integer = core.submodule("integer"); Self { const_traits: [ neg_trait, @@ -1270,7 +1311,9 @@ impl ConstCalcInfo { true_const: ConstValue::Enum(true_variant(db), unit_const.clone().into()), false_const: ConstValue::Enum(false_variant(db), unit_const.clone().into()), unit_const, - panic_with_felt252: get_core_function_id(db, "panic_with_felt252".into(), vec![]), + panic_with_felt252: core.function_id("panic_with_felt252", vec![]), + upcast_fn: integer.extern_function_id("upcast"), + downcast_fn: integer.extern_function_id("downcast"), } } } diff --git a/crates/cairo-lang-semantic/src/items/functions.rs b/crates/cairo-lang-semantic/src/items/functions.rs index 65dc2ae0d3a..7d98f2f4fd6 100644 --- a/crates/cairo-lang-semantic/src/items/functions.rs +++ b/crates/cairo-lang-semantic/src/items/functions.rs @@ -1,4 +1,3 @@ -use std::fmt::Debug; use std::sync::Arc; use cairo_lang_debug::DebugWithDb; @@ -14,6 +13,7 @@ use cairo_lang_proc_macros::{DebugWithDb, SemanticObject}; use cairo_lang_syntax as syntax; use cairo_lang_syntax::attribute::structured::Attribute; use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast}; +use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::{ Intern, LookupIntern, OptionFrom, define_short_id, require, try_extract_matches, }; @@ -27,7 +27,7 @@ use super::generics::{fmt_generic_args, generic_params_to_args}; use super::imp::{ImplId, ImplLongId}; use super::modifiers; use super::trt::ConcreteTraitGenericFunctionId; -use crate::corelib::{panic_destruct_trait_fn, unit_ty}; +use crate::corelib::{fn_traits, panic_destruct_trait_fn, unit_ty}; use crate::db::SemanticGroup; use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder}; use crate::expr::compute::Environment; @@ -35,8 +35,8 @@ use crate::resolve::{Resolver, ResolverData}; use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter}; use crate::types::resolve_type; use crate::{ - ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericParam, SemanticDiagnostic, - TypeId, semantic, semantic_object_for_id, + ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericArgumentId, GenericParam, + SemanticDiagnostic, TypeId, semantic, semantic_object_for_id, }; /// A generic function of an impl. @@ -124,6 +124,36 @@ impl GenericFunctionId { } } } + + pub fn get_closure_params( + &self, + db: &dyn SemanticGroup, + ) -> Maybe> { + let mut closure_params_map = OrderedHashMap::default(); + let generic_params = self.generic_params(db)?; + + for param in generic_params { + if let GenericParam::Impl(generic_param_impl) = param { + let trait_id = generic_param_impl.concrete_trait?.trait_id(db); + + if fn_traits(db).contains(&trait_id) { + if let Ok(concrete_trait) = generic_param_impl.concrete_trait { + let [ + GenericArgumentId::Type(closure_type), + GenericArgumentId::Type(params_type), + ] = *concrete_trait.generic_args(db) + else { + unreachable!() + }; + + closure_params_map.insert(closure_type, params_type); + } + } + } + } + Ok(closure_params_map) + } + pub fn generic_signature(&self, db: &dyn SemanticGroup) -> Maybe { match *self { GenericFunctionId::Free(id) => db.free_function_signature(id), @@ -146,8 +176,11 @@ impl GenericFunctionId { GenericFunctionId::Extern(id) => db.extern_function_declaration_generic_params(id), GenericFunctionId::Impl(id) => { let concrete_trait_id = db.impl_concrete_trait(id.impl_id)?; - let id = ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function); - db.concrete_trait_function_generic_params(id) + let concrete_id = + ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function); + let substitution = GenericSubstitution::from_impl(id.impl_id); + let mut rewriter = SubstitutionRewriter { db, substitution: &substitution }; + rewriter.rewrite(db.concrete_trait_function_generic_params(concrete_id)?) } GenericFunctionId::Trait(id) => db.concrete_trait_function_generic_params(id), } @@ -860,6 +893,19 @@ pub fn concrete_function_signature( SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_signature) } +/// Query implementation of [crate::db::SemanticGroup::concrete_function_closure_params]. +pub fn concrete_function_closure_params( + db: &dyn SemanticGroup, + function_id: FunctionId, +) -> Maybe> { + let ConcreteFunction { generic_function, generic_args, .. } = + function_id.lookup_intern(db).function; + let generic_params = generic_function.generic_params(db)?; + let generic_closure_params = generic_function.get_closure_params(db)?; + let substitution = GenericSubstitution::new(&generic_params, &generic_args); + SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_closure_params) +} + /// For a given list of AST parameters, returns the list of semantic parameters along with the /// corresponding environment. fn update_env_with_ast_params( diff --git a/crates/cairo-lang-semantic/src/lib.rs b/crates/cairo-lang-semantic/src/lib.rs index 55bf65d53b8..de15d4574ee 100644 --- a/crates/cairo-lang-semantic/src/lib.rs +++ b/crates/cairo-lang-semantic/src/lib.rs @@ -6,6 +6,7 @@ pub mod corelib; pub mod db; pub mod diagnostic; pub mod expr; +pub mod helper; pub mod inline_macros; pub mod items; pub mod literals; diff --git a/crates/cairo-lang-semantic/src/substitution.rs b/crates/cairo-lang-semantic/src/substitution.rs index 0ae2177e339..92971ee4d2f 100644 --- a/crates/cairo-lang-semantic/src/substitution.rs +++ b/crates/cairo-lang-semantic/src/substitution.rs @@ -464,7 +464,7 @@ add_basic_rewrites!( <'a>, SubstitutionRewriter<'a>, DiagnosticAdded, - @exclude TypeId TypeLongId ImplId ImplLongId ConstValue GenericFunctionId + @exclude TypeId TypeLongId ImplId ImplLongId ConstValue GenericFunctionId GenericFunctionWithBodyId ); impl SemanticRewriter for SubstitutionRewriter<'_> { @@ -633,3 +633,21 @@ impl SemanticRewriter for SubstitutionRewrit value.default_rewrite(self) } } +impl SemanticRewriter for SubstitutionRewriter<'_> { + fn internal_rewrite(&mut self, value: &mut GenericFunctionWithBodyId) -> Maybe { + if let GenericFunctionWithBodyId::Trait(id) = value { + if let Some(self_impl) = &self.substitution.self_impl { + if let ImplLongId::Concrete(concrete_impl_id) = self_impl.lookup_intern(self.db) { + if id.concrete_trait(self.db.upcast()) == self_impl.concrete_trait(self.db)? { + *value = GenericFunctionWithBodyId::Impl(ImplGenericFunctionWithBodyId { + concrete_impl_id, + function_body: ImplFunctionBodyId::Trait(id.trait_function(self.db)), + }); + return Ok(RewriteResult::Modified); + } + } + } + } + value.default_rewrite(self) + } +} diff --git a/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs b/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs index 2d06edc6fbf..2b40ec7a47d 100644 --- a/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs +++ b/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs @@ -394,7 +394,8 @@ pub fn core_libfunc_ap_change( ApChange::Known(1 + if libfunc.boundary.is_zero() { 0 } else { 1 }), ] } - BoundedIntConcreteLibfunc::Trim(libfunc) => { + BoundedIntConcreteLibfunc::TrimMin(libfunc) + | BoundedIntConcreteLibfunc::TrimMax(libfunc) => { let ap_change = if libfunc.trimmed_value.is_zero() { 0 } else { 1 }; vec![ApChange::Known(ap_change), ApChange::Known(ap_change)] } diff --git a/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs b/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs index af1fff62434..5e68f985650 100644 --- a/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs +++ b/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs @@ -511,7 +511,8 @@ pub fn core_libfunc_cost( .into(), ] } - BoundedIntConcreteLibfunc::Trim(libfunc) => { + BoundedIntConcreteLibfunc::TrimMin(libfunc) + | BoundedIntConcreteLibfunc::TrimMax(libfunc) => { let steps: BranchCost = ConstCost::steps(if libfunc.trimmed_value.is_zero() { 1 } else { 2 }).into(); vec![steps.clone(), steps] diff --git a/crates/cairo-lang-sierra-to-casm/src/invocations/int/bounded.rs b/crates/cairo-lang-sierra-to-casm/src/invocations/int/bounded.rs index b39fbeb07cb..972e53c147c 100644 --- a/crates/cairo-lang-sierra-to-casm/src/invocations/int/bounded.rs +++ b/crates/cairo-lang-sierra-to-casm/src/invocations/int/bounded.rs @@ -39,7 +39,10 @@ pub fn build( BoundedIntConcreteLibfunc::Constrain(libfunc) => { build_constrain(builder, &libfunc.boundary) } - BoundedIntConcreteLibfunc::Trim(libfunc) => build_trim(builder, &libfunc.trimmed_value), + BoundedIntConcreteLibfunc::TrimMin(libfunc) + | BoundedIntConcreteLibfunc::TrimMax(libfunc) => { + build_trim(builder, &libfunc.trimmed_value) + } BoundedIntConcreteLibfunc::IsZero(_) => build_is_zero(builder), BoundedIntConcreteLibfunc::WrapNonZero(_) => build_identity(builder), } diff --git a/crates/cairo-lang-sierra-to-casm/src/test_data/errors b/crates/cairo-lang-sierra-to-casm/src/test_data/errors index 52ad030dc12..81467fd49f8 100644 --- a/crates/cairo-lang-sierra-to-casm/src/test_data/errors +++ b/crates/cairo-lang-sierra-to-casm/src/test_data/errors @@ -1079,3 +1079,33 @@ foo@0() -> (); //! > error Code size limit exceeded. + +//! > ========================================================================== + +//! > ContractAddress out of range. + +//! > test_runner_name +compiler_errors + +//! > sierra_code +type ContractAddress = ContractAddress; +type InRange = Const; +type OutOfRange = Const; + +//! > error +Error from program registry: Error during type specialization of `OutOfRange`: Could not specialize type + +//! > ========================================================================== + +//! > ClassHash out of range. + +//! > test_runner_name +compiler_errors + +//! > sierra_code +type ClassHash = ClassHash; +type InRange = Const; +type OutOfRange = Const; + +//! > error +Error from program registry: Error during type specialization of `OutOfRange`: Could not specialize type diff --git a/crates/cairo-lang-sierra/src/extensions/modules/bounded_int.rs b/crates/cairo-lang-sierra/src/extensions/modules/bounded_int.rs index b73fcc53d47..ba8d94b3874 100644 --- a/crates/cairo-lang-sierra/src/extensions/modules/bounded_int.rs +++ b/crates/cairo-lang-sierra/src/extensions/modules/bounded_int.rs @@ -79,7 +79,8 @@ define_libfunc_hierarchy! { Mul(BoundedIntMulLibfunc), DivRem(BoundedIntDivRemLibfunc), Constrain(BoundedIntConstrainLibfunc), - Trim(BoundedIntTrimLibfunc), + TrimMin(BoundedIntTrimLibfunc), + TrimMax(BoundedIntTrimLibfunc), IsZero(BoundedIntIsZeroLibfunc), WrapNonZero(BoundedIntWrapNonZeroLibfunc), }, BoundedIntConcreteLibfunc @@ -388,67 +389,71 @@ impl SignatureBasedConcreteLibfunc for BoundedIntConstrainConcreteLibfunc { /// Libfunc for trimming a BoundedInt by removing `Min` or `Max` from the range. /// The libfunc is also applicable for standard types such as u* and i*. #[derive(Default)] -pub struct BoundedIntTrimLibfunc {} -impl NamedLibfunc for BoundedIntTrimLibfunc { +pub struct BoundedIntTrimLibfunc {} +impl NamedLibfunc for BoundedIntTrimLibfunc { type Concrete = BoundedIntTrimConcreteLibfunc; - const STR_ID: &'static str = "bounded_int_trim"; + const STR_ID: &'static str = + if IS_MAX { "bounded_int_trim_max" } else { "bounded_int_trim_min" }; fn specialize_signature( &self, context: &dyn SignatureSpecializationContext, args: &[GenericArg], ) -> Result { - let (ty, trimmed_value) = match args { - [GenericArg::Type(ty), GenericArg::Value(trimmed_value)] => Ok((ty, trimmed_value)), - [_, _] => Err(SpecializationError::UnsupportedGenericArg), - _ => Err(SpecializationError::WrongNumberOfGenericArgs), - }?; + Ok(Self::Concrete::new::(context, args)?.signature) + } + + fn specialize( + &self, + context: &dyn SpecializationContext, + args: &[GenericArg], + ) -> Result { + Self::Concrete::new::(context.upcast(), args) + } +} + +pub struct BoundedIntTrimConcreteLibfunc { + pub trimmed_value: BigInt, + signature: LibfuncSignature, +} +impl BoundedIntTrimConcreteLibfunc { + fn new( + context: &dyn SignatureSpecializationContext, + args: &[GenericArg], + ) -> Result { + let ty = args_as_single_type(args)?; let ty_info = context.get_type_info(ty.clone())?; - let mut range = Range::from_type_info(&ty_info)?; - if trimmed_value == &range.lower { - range.lower += 1; + let range = Range::from_type_info(&ty_info)?; + let (res_ty, trimmed_value) = if IS_MAX { + ( + bounded_int_ty(context, range.lower.clone(), range.upper.clone() - 2)?, + range.upper - 1, + ) } else { - range.upper -= 1; - require(&range.upper == trimmed_value) - .ok_or(SpecializationError::UnsupportedGenericArg)?; - } + ( + bounded_int_ty(context, range.lower.clone() + 1, range.upper.clone() - 1)?, + range.lower, + ) + }; let ap_change = SierraApChange::Known { new_vars_only: trimmed_value.is_zero() }; - Ok(LibfuncSignature { + let signature = LibfuncSignature { param_signatures: vec![ParamSignature::new(ty.clone())], branch_signatures: vec![ BranchSignature { vars: vec![], ap_change: ap_change.clone() }, BranchSignature { vars: vec![OutputVarInfo { - ty: bounded_int_ty(context, range.lower, range.upper - 1)?, + ty: res_ty, ref_info: OutputVarReferenceInfo::SameAsParam { param_idx: 0 }, }], ap_change, }, ], fallthrough: Some(0), - }) - } - - fn specialize( - &self, - context: &dyn SpecializationContext, - args: &[GenericArg], - ) -> Result { - let trimmed_value = match args { - [GenericArg::Type(_), GenericArg::Value(trimmed_value)] => Ok(trimmed_value.clone()), - [_, _] => Err(SpecializationError::UnsupportedGenericArg), - _ => Err(SpecializationError::WrongNumberOfGenericArgs), - }?; - let context = context.upcast(); - Ok(Self::Concrete { trimmed_value, signature: self.specialize_signature(context, args)? }) + }; + Ok(Self { trimmed_value, signature }) } } - -pub struct BoundedIntTrimConcreteLibfunc { - pub trimmed_value: BigInt, - signature: LibfuncSignature, -} impl SignatureBasedConcreteLibfunc for BoundedIntTrimConcreteLibfunc { fn signature(&self) -> &LibfuncSignature { &self.signature diff --git a/crates/cairo-lang-sierra/src/extensions/modules/const_type.rs b/crates/cairo-lang-sierra/src/extensions/modules/const_type.rs index e7d439273c6..886712b3a4e 100644 --- a/crates/cairo-lang-sierra/src/extensions/modules/const_type.rs +++ b/crates/cairo-lang-sierra/src/extensions/modules/const_type.rs @@ -3,9 +3,14 @@ use itertools::Itertools; use num_traits::{ToPrimitive, Zero}; use super::boxing::box_ty; +use super::consts::ConstGenLibfunc; use super::enm::EnumType; use super::int::unsigned128::Uint128Type; use super::non_zero::NonZeroType; +use super::starknet::interoperability::{ + ClassHashConstLibfuncWrapped, ClassHashType, ContractAddressConstLibfuncWrapped, + ContractAddressType, +}; use super::structure::StructType; use super::utils::Range; use crate::define_libfunc_hierarchy; @@ -67,20 +72,26 @@ fn validate_const_data( inner_data: &[GenericArg], ) -> Result<(), SpecializationError> { let inner_type_info = context.get_type_info(inner_ty.clone())?; - if inner_type_info.long_id.generic_id == StructType::ID { - validate_const_struct_data(context, &inner_type_info, inner_data)?; - } else if inner_type_info.long_id.generic_id == EnumType::ID { - validate_const_enum_data(context, &inner_type_info, inner_data)?; - } else if inner_type_info.long_id.generic_id == NonZeroType::ID { - validate_const_nz_data(context, &inner_type_info, inner_data)?; + let inner_generic_id = &inner_type_info.long_id.generic_id; + if *inner_generic_id == StructType::ID { + return validate_const_struct_data(context, &inner_type_info, inner_data); + } else if *inner_generic_id == EnumType::ID { + return validate_const_enum_data(context, &inner_type_info, inner_data); + } else if *inner_generic_id == NonZeroType::ID { + return validate_const_nz_data(context, &inner_type_info, inner_data); + } + let type_range = if *inner_generic_id == ContractAddressType::id() { + Range::half_open(0, ContractAddressConstLibfuncWrapped::bound()) + } else if *inner_generic_id == ClassHashType::id() { + Range::half_open(0, ClassHashConstLibfuncWrapped::bound()) } else { - let type_range = Range::from_type_info(&inner_type_info)?; - let [GenericArg::Value(value)] = inner_data else { - return Err(SpecializationError::WrongNumberOfGenericArgs); - }; - if !(&type_range.lower <= value && value < &type_range.upper) { - return Err(SpecializationError::UnsupportedGenericArg); - } + Range::from_type_info(&inner_type_info)? + }; + let [GenericArg::Value(value)] = inner_data else { + return Err(SpecializationError::WrongNumberOfGenericArgs); + }; + if !(&type_range.lower <= value && value < &type_range.upper) { + return Err(SpecializationError::UnsupportedGenericArg); } Ok(()) } diff --git a/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json b/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json index a4bd46518b8..a4cee384724 100644 --- a/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json +++ b/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json @@ -25,7 +25,8 @@ "bounded_int_is_zero", "bounded_int_mul", "bounded_int_sub", - "bounded_int_trim", + "bounded_int_trim_max", + "bounded_int_trim_min", "bounded_int_wrap_non_zero", "box_forward_snapshot", "branch_align", diff --git a/crates/cairo-lang-starknet/cairo_level_tests/collections_test.cairo b/crates/cairo-lang-starknet/cairo_level_tests/collections_test.cairo index e0b2427e084..e5f492a7b49 100644 --- a/crates/cairo-lang-starknet/cairo_level_tests/collections_test.cairo +++ b/crates/cairo-lang-starknet/cairo_level_tests/collections_test.cairo @@ -1,8 +1,8 @@ use starknet::storage::{ - MutableVecTrait, StoragePathEntry, StoragePointerReadAccess, StoragePointerWriteAccess, + IntoIterRange, MutableVecTrait, StoragePathEntry, StoragePointerReadAccess, + StoragePointerWriteAccess, }; - #[starknet::contract] mod contract_with_map { use starknet::storage::Map; @@ -32,6 +32,51 @@ fn test_simple_member_write_to_map() { assert_eq!(vec_entry.read(), 1); } +#[test] +fn test_vec_iter() { + let mut mut_state = contract_with_vec::contract_state_for_testing(); + for i in 0..9_usize { + mut_state.simple.append().write(i); + }; + + let state = @contract_with_vec::contract_state_for_testing(); + let mut i = 0; + for entry in state.simple.into_iter_full_range() { + assert_eq!(entry.read(), i); + i += 1; + }; + assert_eq!(i, 9); + + let mut i = 2; + for entry in state.simple.into_iter_range(2..5) { + assert_eq!(entry.read(), i); + i += 1; + }; + assert_eq!(i, 5); +} + +#[test] +fn test_mut_vec_iter() { + let mut mut_state = contract_with_vec::contract_state_for_testing(); + for i in 0..9_usize { + mut_state.simple.append().write(i); + }; + + let mut i = 0; + for entry in mut_state.simple.into_iter_full_range() { + assert_eq!(entry.read(), i); + i += 1; + }; + assert_eq!(i, 9); + + let mut i = 2; + for entry in mut_state.simple.into_iter_range(2..5) { + assert_eq!(entry.read(), i); + i += 1; + }; + assert_eq!(i, 5); +} + #[test] fn test_simple_member_write_to_vec() { let mut map_contract_state = contract_with_map::contract_state_for_testing(); diff --git a/scripts/release_crates.sh b/scripts/release_crates.sh index 58aa67949af..03fee32838d 100755 --- a/scripts/release_crates.sh +++ b/scripts/release_crates.sh @@ -1,5 +1,9 @@ # An optional argument, '-s | --skip-first <#num_to_skip>', can be passed to skip the first #num_to_skip crates, otherwise SKIP_FIRST is set to 0. if [ "$1" == "-s" ] || [ "$1" == "--skip-first" ]; then + if [ -z "$2" ]; then + echo "Error: --skip-first requires a numeric argument." + exit 1 + fi SKIP_FIRST=$2 else SKIP_FIRST=0 diff --git a/tests/bug_samples/issue7031.cairo b/tests/bug_samples/issue7031.cairo index e6df5cdc41d..21d030e5c7c 100644 --- a/tests/bug_samples/issue7031.cairo +++ b/tests/bug_samples/issue7031.cairo @@ -1,5 +1,5 @@ pub trait IteratorEx, +Destruct, +Drop> { - fn advance_by( + fn advance_by_( ref self: T, n: usize, ) -> Result< (), NonZero, @@ -19,7 +19,7 @@ impl ItratorExImpl, +Destruct, +Drop> of Iter #[test] fn test_advance_by() { let mut iter = array![1_u8, 2, 3, 4].into_iter(); - assert_eq!(iter.advance_by(2), Result::Ok(())); + assert_eq!(iter.advance_by_(2), Result::Ok(())); assert_eq!(iter.next(), Option::Some(3)); - assert_eq!(iter.advance_by(0), Result::Ok(())); + assert_eq!(iter.advance_by_(0), Result::Ok(())); } diff --git a/tests/bug_samples/issue7083.cairo b/tests/bug_samples/issue7083.cairo new file mode 100644 index 00000000000..ab4f10fd0a5 --- /dev/null +++ b/tests/bug_samples/issue7083.cairo @@ -0,0 +1,10 @@ +fn main() { + let zero: ByteArray = "0"; + + let format_string = |acc: ByteArray, x: u8| { + format!("({acc} + {x})") + }; + + let result = format_string(zero, 1); + assert_eq!(result, "(0 + 1)"); +} diff --git a/tests/bug_samples/lib.cairo b/tests/bug_samples/lib.cairo index a3d01f3e257..9d035b98434 100644 --- a/tests/bug_samples/lib.cairo +++ b/tests/bug_samples/lib.cairo @@ -57,6 +57,7 @@ mod issue7031; mod issue7038; mod issue7060; mod issue7071; +mod issue7083; mod loop_break_in_match; mod loop_only_change; mod partial_param_local; diff --git a/tests/e2e_test_data/libfuncs/bounded_int b/tests/e2e_test_data/libfuncs/bounded_int index cda5140e51d..ee55a736f84 100644 --- a/tests/e2e_test_data/libfuncs/bounded_int +++ b/tests/e2e_test_data/libfuncs/bounded_int @@ -826,7 +826,7 @@ test::foo@0([0]: BoundedInt<-680564733841876926926749214863536422911, -1>) -> (N //! > ========================================================================== -//! > bounded_int_trim libfunc remove 0 below. +//! > bounded_int_trim_min libfunc remove 0. //! > test_runner_name SmallE2ETestRunner @@ -834,10 +834,10 @@ SmallE2ETestRunner //! > cairo extern type BoundedInt; type Res = core::internal::OptionRev>; -extern fn bounded_int_trim(value: T) -> Res nopanic; +extern fn bounded_int_trim_min(value: T) -> Res nopanic; fn foo(value: u8) -> Res { - bounded_int_trim::<_, 0>(value) + bounded_int_trim_min(value) } //! > casm @@ -855,14 +855,14 @@ type Unit = Struct [storable: true, drop: true, dup: true, zero_sized: type BoundedInt<1, 255> = BoundedInt<1, 255> [storable: true, drop: true, dup: true, zero_sized: false]; type core::internal::OptionRev::> = Enum>, Unit, BoundedInt<1, 255>> [storable: true, drop: true, dup: true, zero_sized: false]; -libfunc bounded_int_trim = bounded_int_trim; +libfunc bounded_int_trim_min = bounded_int_trim_min; libfunc branch_align = branch_align; libfunc struct_construct = struct_construct; libfunc enum_init>, 0> = enum_init>, 0>; libfunc store_temp>> = store_temp>>; libfunc enum_init>, 1> = enum_init>, 1>; -bounded_int_trim([0]) { fallthrough() 6([1]) }; // 0 +bounded_int_trim_min([0]) { fallthrough() 6([1]) }; // 0 branch_align() -> (); // 1 struct_construct() -> ([2]); // 2 enum_init>, 0>([2]) -> ([3]); // 3 @@ -880,7 +880,7 @@ test::foo: OrderedHashMap({Const: 300}) //! > ========================================================================== -//! > bounded_int_trim libfunc remove 0 above. +//! > bounded_int_trim_max libfunc remove 0. //! > test_runner_name SmallE2ETestRunner @@ -888,10 +888,10 @@ SmallE2ETestRunner //! > cairo extern type BoundedInt; type Res = core::internal::OptionRev>; -extern fn bounded_int_trim(value: T) -> Res nopanic; +extern fn bounded_int_trim_max(value: T) -> Res nopanic; fn foo(value: BoundedInt<-0xff, 0>) -> Res { - bounded_int_trim::<_, 0>(value) + bounded_int_trim_max(value) } //! > casm @@ -909,14 +909,14 @@ type Unit = Struct [storable: true, drop: true, dup: true, zero_sized: type BoundedInt<-255, -1> = BoundedInt<-255, -1> [storable: true, drop: true, dup: true, zero_sized: false]; type core::internal::OptionRev::> = Enum>, Unit, BoundedInt<-255, -1>> [storable: true, drop: true, dup: true, zero_sized: false]; -libfunc bounded_int_trim, 0> = bounded_int_trim, 0>; +libfunc bounded_int_trim_max> = bounded_int_trim_max>; libfunc branch_align = branch_align; libfunc struct_construct = struct_construct; libfunc enum_init>, 0> = enum_init>, 0>; libfunc store_temp>> = store_temp>>; libfunc enum_init>, 1> = enum_init>, 1>; -bounded_int_trim, 0>([0]) { fallthrough() 6([1]) }; // 0 +bounded_int_trim_max>([0]) { fallthrough() 6([1]) }; // 0 branch_align() -> (); // 1 struct_construct() -> ([2]); // 2 enum_init>, 0>([2]) -> ([3]); // 3 @@ -934,7 +934,7 @@ test::foo: OrderedHashMap({Const: 300}) //! > ========================================================================== -//! > bounded_int_trim libfunc remove non-0 below. +//! > bounded_int_trim_min libfunc remove non-0. //! > test_runner_name SmallE2ETestRunner @@ -942,10 +942,10 @@ SmallE2ETestRunner //! > cairo extern type BoundedInt; type Res = core::internal::OptionRev>; -extern fn bounded_int_trim(value: T) -> Res nopanic; +extern fn bounded_int_trim_min(value: T) -> Res nopanic; fn foo(value: i8) -> Res { - bounded_int_trim::<_, -0x80>(value) + bounded_int_trim_min(value) } //! > casm @@ -964,14 +964,14 @@ type Unit = Struct [storable: true, drop: true, dup: true, zero_sized: type BoundedInt<-127, 127> = BoundedInt<-127, 127> [storable: true, drop: true, dup: true, zero_sized: false]; type core::internal::OptionRev::> = Enum>, Unit, BoundedInt<-127, 127>> [storable: true, drop: true, dup: true, zero_sized: false]; -libfunc bounded_int_trim = bounded_int_trim; +libfunc bounded_int_trim_min = bounded_int_trim_min; libfunc branch_align = branch_align; libfunc struct_construct = struct_construct; libfunc enum_init>, 0> = enum_init>, 0>; libfunc store_temp>> = store_temp>>; libfunc enum_init>, 1> = enum_init>, 1>; -bounded_int_trim([0]) { fallthrough() 6([1]) }; // 0 +bounded_int_trim_min([0]) { fallthrough() 6([1]) }; // 0 branch_align() -> (); // 1 struct_construct() -> ([2]); // 2 enum_init>, 0>([2]) -> ([3]); // 3 @@ -989,7 +989,7 @@ test::foo: OrderedHashMap({Const: 400}) //! > ========================================================================== -//! > bounded_int_trim libfunc remove non-0 above. +//! > bounded_int_trim_max libfunc remove non-0. //! > test_runner_name SmallE2ETestRunner @@ -997,10 +997,10 @@ SmallE2ETestRunner //! > cairo extern type BoundedInt; type Res = core::internal::OptionRev>; -extern fn bounded_int_trim(value: T) -> Res nopanic; +extern fn bounded_int_trim_max(value: T) -> Res nopanic; fn foo(value: u8) -> Res { - bounded_int_trim::<_, 0xff>(value) + bounded_int_trim_max(value) } //! > casm @@ -1019,14 +1019,14 @@ type Unit = Struct [storable: true, drop: true, dup: true, zero_sized: type BoundedInt<0, 254> = BoundedInt<0, 254> [storable: true, drop: true, dup: true, zero_sized: false]; type core::internal::OptionRev::> = Enum>, Unit, BoundedInt<0, 254>> [storable: true, drop: true, dup: true, zero_sized: false]; -libfunc bounded_int_trim = bounded_int_trim; +libfunc bounded_int_trim_max = bounded_int_trim_max; libfunc branch_align = branch_align; libfunc struct_construct = struct_construct; libfunc enum_init>, 0> = enum_init>, 0>; libfunc store_temp>> = store_temp>>; libfunc enum_init>, 1> = enum_init>, 1>; -bounded_int_trim([0]) { fallthrough() 6([1]) }; // 0 +bounded_int_trim_max([0]) { fallthrough() 6([1]) }; // 0 branch_align() -> (); // 1 struct_construct() -> ([2]); // 2 enum_init>, 0>([2]) -> ([3]); // 3