diff --git a/packages/data-structures/src/lib.cairo b/packages/data-structures/src/lib.cairo index a742cedc3..acc3fe699 100644 --- a/packages/data-structures/src/lib.cairo +++ b/packages/data-structures/src/lib.cairo @@ -1,3 +1,4 @@ pub mod span; +pub mod vec; pub use {span::FixedSpanMath, span::SpanMathTrait}; diff --git a/packages/data-structures/src/span.cairo b/packages/data-structures/src/span.cairo index d31ddf755..30666bc02 100644 --- a/packages/data-structures/src/span.cairo +++ b/packages/data-structures/src/span.cairo @@ -1,8 +1,5 @@ -pub mod span_fixed; - use orion_numbers::FixedTrait; - -pub use span_fixed::FixedSpanMath; +use core::ops::AddAssign; pub trait SpanMathTrait { fn arange(n: u32) -> Span; @@ -12,3 +9,207 @@ pub trait SpanMathTrait { fn prod(self: Span) -> T; fn sum(self: Span) -> T; } + +pub impl FixedSpanMath< + T, + S, + +FixedTrait, + +Into, + +Drop, + +Copy, + +Add, + +Mul, + +AddAssign, + +PartialOrd +> of SpanMathTrait { + fn arange(n: u32) -> Span { + let mut i = 0; + let mut arr = array![]; + while i < n { + arr.append(FixedTrait::new_unscaled(i.into())); + i += 1; + }; + + arr.span() + } + + fn dot(self: Span, other: Span) -> T { + let mut i = 0; + let mut acc = FixedTrait::ZERO(); + while i != self.len() { + acc += *self.at(i) * *other.at(i); + i += 1; + }; + + acc + } + + fn max(self: Span) -> T { + assert(self.len() > 0, 'span cannot be empty'); + + let mut max = FixedTrait::MIN(); + let mut self = self; + loop { + match self.pop_front() { + Option::Some(item) => { if *item > max { + max = *item; + } }, + Option::None => { break max; }, + } + } + } + + fn min(self: Span) -> T { + assert(self.len() > 0, 'span cannot be empty'); + + let mut min = FixedTrait::MAX(); + let mut self = self; + loop { + match self.pop_front() { + Option::Some(item) => { if *item < min { + min = *item; + } }, + Option::None => { break min; }, + } + } + } + + fn prod(self: Span) -> T { + let mut prod = FixedTrait::ONE(); + let mut self = self; + loop { + match self.pop_front() { + Option::Some(v) => { prod = prod * *v; }, + Option::None => { break prod; } + }; + } + } + + fn sum(self: Span) -> T { + let mut prod = FixedTrait::ZERO(); + let mut self = self; + loop { + match self.pop_front() { + Option::Some(v) => { prod = prod + *v; }, + Option::None => { break prod; } + }; + } + } +} + + +#[cfg(test)] +mod tests { + use super::{FixedSpanMath, FixedTrait}; + use orion_numbers::{F64, F64Impl, f64::helpers::assert_precise}; + + + #[test] + fn test_arange() { + let n = 6; + let res = FixedSpanMath::arange(n); + + let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480].span(); + + assert_precise(*res.at(0), *x.at(0), 'should be equal', Option::None); + assert_precise(*res.at(1), *x.at(1), 'should be equal', Option::None); + assert_precise(*res.at(2), *x.at(2), 'should be equal', Option::None); + assert_precise(*res.at(3), *x.at(3), 'should be equal', Option::None); + assert_precise(*res.at(4), *x.at(4), 'should be equal', Option::None); + assert_precise(*res.at(5), *x.at(5), 'should be equal', Option::None); + } + + #[test] + fn test_dot() { + let x = array![ + F64Impl::new(0), + F64Impl::new(4294967296), + F64Impl::new(8589934592), + F64Impl::new(12884901888), + F64Impl::new(17179869184), + F64Impl::new(21474836480) + ] + .span(); // 0, 1, 2, 3, 4, 5 + let y = array![ + F64Impl::new(0), + F64Impl::new(8589934592), + F64Impl::new(17179869184), + F64Impl::new(25769803776), + F64Impl::new(34359738368), + F64Impl::new(42949672960) + ] + .span(); // 0, 2, 4, 6, 8, 10 + let result = FixedSpanMath::dot(x, y); + + assert_precise( + result, (F64Impl::new(472446402560)).into(), 'should be equal', Option::None + ); + } + + #[test] + fn test_max() { + let x = array![ + F64Impl::new(0), + F64Impl::new(4294967296), + F64Impl::new(8589934592), + F64Impl::new(12884901888), + F64Impl::new(17179869184), + F64Impl::new(21474836480) + ] + .span(); // 0, 1, 2, 3, 4, 5 + + let result = FixedSpanMath::max(x); + + assert_precise(result, 21474836480, 'should be equal', Option::None); + } + + #[test] + fn test_min() { + let x = array![ + F64Impl::new(0), + F64Impl::new(4294967296), + F64Impl::new(8589934592), + F64Impl::new(12884901888), + F64Impl::new(17179869184), + F64Impl::new(21474836480) + ] + .span(); // 0, 1, 2, 3, 4, 5 + + let result = FixedSpanMath::min(x); + + assert_precise(result, 0, 'should be equal', Option::None); + } + + #[test] + fn test_prod() { + let x = array![ + F64Impl::new(0), + F64Impl::new(4294967296), + F64Impl::new(8589934592), + F64Impl::new(12884901888), + F64Impl::new(17179869184), + F64Impl::new(21474836480) + ] + .span(); // 0, 1, 2, 3, 4, 5 + + let result = FixedSpanMath::prod(x); + + assert_precise(result, 0, 'should be equal', Option::None); + } + #[test] + fn test_sum() { + let x = array![ + F64Impl::new(0), + F64Impl::new(4294967296), + F64Impl::new(8589934592), + F64Impl::new(12884901888), + F64Impl::new(17179869184), + F64Impl::new(21474836480) + ] + .span(); // 0, 1, 2, 3, 4, 5 + + let result = FixedSpanMath::sum(x); + + assert_precise(result, 64424509440, 'should be equal', Option::None); + } +} diff --git a/packages/data-structures/src/span/span_fixed.cairo b/packages/data-structures/src/span/span_fixed.cairo deleted file mode 100644 index c61631c5b..000000000 --- a/packages/data-structures/src/span/span_fixed.cairo +++ /dev/null @@ -1,210 +0,0 @@ -use core::ops::AddAssign; - -use orion_numbers::{FixedTrait}; - -use super::SpanMathTrait; - - -pub impl FixedSpanMath< - T, - S, - +FixedTrait, - +Into, - +Drop, - +Copy, - +Add, - +Mul, - +AddAssign, - +PartialOrd -> of SpanMathTrait { - fn arange(n: u32) -> Span { - let mut i = 0; - let mut arr = array![]; - while i < n { - arr.append(FixedTrait::new_unscaled(i.into())); - i += 1; - }; - - arr.span() - } - - fn dot(self: Span, other: Span) -> T { - let mut i = 0; - let mut acc = FixedTrait::ZERO(); - while i != self.len() { - acc += *self.at(i) * *other.at(i); - i += 1; - }; - - acc - } - - fn max(self: Span) -> T { - assert(self.len() > 0, 'span cannot be empty'); - - let mut max = FixedTrait::MIN(); - let mut self = self; - loop { - match self.pop_front() { - Option::Some(item) => { if *item > max { - max = *item; - } }, - Option::None => { break max; }, - } - } - } - - fn min(self: Span) -> T { - assert(self.len() > 0, 'span cannot be empty'); - - let mut min = FixedTrait::MAX(); - let mut self = self; - loop { - match self.pop_front() { - Option::Some(item) => { if *item < min { - min = *item; - } }, - Option::None => { break min; }, - } - } - } - - fn prod(self: Span) -> T { - let mut prod = FixedTrait::ONE(); - let mut self = self; - loop { - match self.pop_front() { - Option::Some(v) => { prod = prod * *v; }, - Option::None => { break prod; } - }; - } - } - - fn sum(self: Span) -> T { - let mut prod = FixedTrait::ZERO(); - let mut self = self; - loop { - match self.pop_front() { - Option::Some(v) => { prod = prod + *v; }, - Option::None => { break prod; } - }; - } - } -} - - -#[cfg(test)] -mod tests { - use super::{FixedSpanMath, FixedTrait}; - use orion_numbers::{F64, F64Impl, f64::helpers::assert_precise}; - - - #[test] - fn test_arange() { - let n = 6; - let res = FixedSpanMath::arange(n); - - let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480].span(); - - assert_precise(*res.at(0), *x.at(0), 'should be equal', Option::None); - assert_precise(*res.at(1), *x.at(1), 'should be equal', Option::None); - assert_precise(*res.at(2), *x.at(2), 'should be equal', Option::None); - assert_precise(*res.at(3), *x.at(3), 'should be equal', Option::None); - assert_precise(*res.at(4), *x.at(4), 'should be equal', Option::None); - assert_precise(*res.at(5), *x.at(5), 'should be equal', Option::None); - } - - #[test] - fn test_dot() { - let x = array![ - F64Impl::new(0), - F64Impl::new(4294967296), - F64Impl::new(8589934592), - F64Impl::new(12884901888), - F64Impl::new(17179869184), - F64Impl::new(21474836480) - ] - .span(); // 0, 1, 2, 3, 4, 5 - let y = array![ - F64Impl::new(0), - F64Impl::new(8589934592), - F64Impl::new(17179869184), - F64Impl::new(25769803776), - F64Impl::new(34359738368), - F64Impl::new(42949672960) - ] - .span(); // 0, 2, 4, 6, 8, 10 - let result = FixedSpanMath::dot(x, y); - - assert_precise( - result, (F64Impl::new(472446402560)).into(), 'should be equal', Option::None - ); - } - - #[test] - fn test_max() { - let x = array![ - F64Impl::new(0), - F64Impl::new(4294967296), - F64Impl::new(8589934592), - F64Impl::new(12884901888), - F64Impl::new(17179869184), - F64Impl::new(21474836480) - ] - .span(); // 0, 1, 2, 3, 4, 5 - - let result = FixedSpanMath::max(x); - - assert_precise(result, 21474836480, 'should be equal', Option::None); - } - - #[test] - fn test_min() { - let x = array![ - F64Impl::new(0), - F64Impl::new(4294967296), - F64Impl::new(8589934592), - F64Impl::new(12884901888), - F64Impl::new(17179869184), - F64Impl::new(21474836480) - ] - .span(); // 0, 1, 2, 3, 4, 5 - - let result = FixedSpanMath::min(x); - - assert_precise(result, 0, 'should be equal', Option::None); - } - - #[test] - fn test_prod() { - let x = array![ - F64Impl::new(0), - F64Impl::new(4294967296), - F64Impl::new(8589934592), - F64Impl::new(12884901888), - F64Impl::new(17179869184), - F64Impl::new(21474836480) - ] - .span(); // 0, 1, 2, 3, 4, 5 - - let result = FixedSpanMath::prod(x); - - assert_precise(result, 0, 'should be equal', Option::None); - } - #[test] - fn test_sum() { - let x = array![ - F64Impl::new(0), - F64Impl::new(4294967296), - F64Impl::new(8589934592), - F64Impl::new(12884901888), - F64Impl::new(17179869184), - F64Impl::new(21474836480) - ] - .span(); // 0, 1, 2, 3, 4, 5 - - let result = FixedSpanMath::sum(x); - - assert_precise(result, 64424509440, 'should be equal', Option::None); - } -} diff --git a/packages/data-structures/src/vec.cairo b/packages/data-structures/src/vec.cairo new file mode 100644 index 000000000..6e43b2d54 --- /dev/null +++ b/packages/data-structures/src/vec.cairo @@ -0,0 +1,104 @@ +use core::nullable::NullableImpl; +use core::num::traits::WrappingAdd; +use core::ops::index::Index; + +pub trait VecTrait { + /// Creates a new V instance. + /// Parameters + /// * size The size of the vec to initialize. + /// Returns + /// * V The new vec instance. + fn new(size: usize) -> V; + + /// Returns the item at the given index, or None if the index is out of bounds. + /// Parameters + /// * self The vec instance. + /// * index The index of the item to get. + /// Returns + /// * Option The item at the given index, or None if the index is out of bounds. + fn get(ref self: V, index: usize) -> Option; + + /// Returns the item at the given index, or panics if the index is out of bounds. + /// Parameters + /// * self The vec instance. + /// * index The index of the item to get. + /// Returns + /// * T The item at the given index. + fn at(ref self: V, index: usize) -> T; + + /// Pushes a new item to the vec. + /// Parameters + /// * self The vec instance. + /// * value The value to push onto the vec. + fn push(ref self: V, value: T); + + /// Sets the item at the given index to the given value. + /// Panics if the index is out of bounds. + /// Parameters + /// * self The vec instance. + /// * index The index of the item to set. + /// * value The value to set the item to. + fn set(ref self: V, index: usize, value: T); + + /// Returns the length of the vec. + /// Parameters + /// * self The vec instance. + /// Returns + /// * usize The length of the vec. + fn len(self: @V) -> usize; +} + +impl VecIndex> of Index { + type Target = T; + + #[inline(always)] + fn index(ref self: V, index: usize) -> T { + self.at(index) + } +} + +pub struct NullableVec { + items: Felt252Dict>, + len: usize, +} + +impl DestructNullableVec> of Destruct> { + fn destruct(self: NullableVec) nopanic { + self.items.squash(); + } +} + +use core::num::traits::Zero; + +impl NullableVecImpl, +Drop, +Copy> of VecTrait, T> { + fn new(size: usize) -> NullableVec { + NullableVec { items: Default::default(), len: size } + } + + fn get(ref self: NullableVec, index: usize) -> Option { + if index < self.len() { + Option::Some(self.items.get(index.into()).deref_or(Zero::zero())) + } else { + Option::None + } + } + + fn at(ref self: NullableVec, index: usize) -> T { + assert(index < self.len(), 'Index out of bounds'); + self.items.get(index.into()).deref_or(Zero::zero()) + } + + fn push(ref self: NullableVec, value: T) { + self.items.insert(self.len.into(), NullableImpl::new(value)); + self.len = self.len.wrapping_add(1); + } + + fn set(ref self: NullableVec, index: usize, value: T) { + assert(index < self.len(), 'Index out of bounds'); + self.items.insert(index.into(), NullableImpl::new(value)); + } + + fn len(self: @NullableVec) -> usize { + *self.len + } +} diff --git a/packages/deep-learning/Scarb.toml b/packages/deep-learning/Scarb.toml index 47c02f3ee..f41a7021c 100644 --- a/packages/deep-learning/Scarb.toml +++ b/packages/deep-learning/Scarb.toml @@ -10,6 +10,7 @@ cairo_test = "2.7.0" [dependencies] orion_numbers = { path = "../numbers" } +orion_data_structures = { path = "../data-structures" } [cairo] enable-gas = false diff --git a/packages/deep-learning/src/lib.cairo b/packages/deep-learning/src/lib.cairo index f15f51e01..489ba4e62 100644 --- a/packages/deep-learning/src/lib.cairo +++ b/packages/deep-learning/src/lib.cairo @@ -3,9 +3,22 @@ pub(crate) mod utils; pub use ops::binary::{BinaryOpMetadata, tensor_add, tensor_mul, tensor_rem, tensor_lt}; pub use ops::unary::{tensor_log2, tensor_exp2, tensor_sin, tensor_sqrt, tensor_recip}; -pub use ops::reduce::{tensor_reduce_sum, ReduceOpMetadata}; +pub use ops::reduce::{tensor_reduce_sum_1d, tensor_reduce_sum_nd, ReduceOpMetadata}; + #[derive(Drop, Copy)] pub struct Tensor { pub data: Span, } + +use orion_data_structures::vec::NullableVec; + +pub struct MutTensor { + pub data: NullableVec, +} + +pub impl MutTensorDestruct> of Destruct> { + fn destruct(self: MutTensor) nopanic { + self.data.destruct() + } +} diff --git a/packages/deep-learning/src/ops/reduce.cairo b/packages/deep-learning/src/ops/reduce.cairo index f54476f9f..3e8d9aa90 100644 --- a/packages/deep-learning/src/ops/reduce.cairo +++ b/packages/deep-learning/src/ops/reduce.cairo @@ -1,6 +1,7 @@ use core::num::traits::Zero; use core::fmt::Debug; -use orion_dl::Tensor; +use orion_dl::{Tensor, MutTensor}; +use orion_data_structures::vec::{NullableVec, VecTrait}; use core::ops::AddAssign; @@ -10,62 +11,71 @@ pub(crate) struct ReduceOpMetadata { output_size: usize, } -pub(crate) fn tensor_reduce_sum< - T, +Add, +AddAssign, +Zero, +Copy, +Drop, +Debug ->( - mut input: Tensor, ref metadata: ReduceOpMetadata +pub(crate) fn tensor_reduce_sum_1d, +Zero, +Copy, +Drop, +Debug>( + mut input: Tensor ) -> Tensor { + let mut result = Zero::::zero(); + + loop { + match input.data.pop_front() { + Option::Some(input_value) => { result = result + *input_value; }, + Option::None(_) => { break; } + }; + }; + let mut result_data = ArrayTrait::new(); - let mut partial_sums = ArrayTrait::new(); - let mut partial_indices = ArrayTrait::new(); + result_data.append(result); + + Tensor { data: result_data.span() } +} + +pub(crate) fn tensor_reduce_sum_nd, +Copy, +Drop, +Debug, +Zero>( + mut input: Tensor, ref metadata: ReduceOpMetadata +) -> MutTensor { + let mut result_data: NullableVec = VecTrait::new(metadata.output_size); loop { match input.data.pop_front() { Option::Some(input_value) => { match metadata.output_indices.pop_front() { Option::Some(output_index) => { - partial_sums.append(*input_value); - partial_indices.append(*output_index); + let current_sum = result_data.at(*output_index); + result_data.set(*output_index, current_sum + *input_value); }, Option::None(_) => { break; // This should never happen if metadata is correct } } }, - Option::None(_) => { break; } - }; - }; - - // Combine partial sums - let mut i = 0; - loop { - if i == metadata.output_size { - break; + Option::None => { break; }, } - let mut current_sum = Zero::::zero(); - let mut partial_sums_span = partial_sums.span(); - let mut partial_indices_span = partial_indices.span(); - loop { - match partial_indices_span.pop_front() { - Option::Some(index) => { - let sum = partial_sums_span.pop_front().unwrap(); - if *index == i { - current_sum = current_sum + *sum - } - }, - Option::None(_) => { break; } - } - }; - result_data.append(current_sum); - i += 1; }; - Tensor { data: result_data.span() } + MutTensor { data: result_data } } #[cfg(test)] mod tests { - use super::{Tensor, ReduceOpMetadata, tensor_reduce_sum}; + use super::{ + Tensor, MutTensor, VecTrait, NullableVec, ReduceOpMetadata, tensor_reduce_sum_1d, + tensor_reduce_sum_nd + }; + + + #[test] + #[available_gas(20000000)] + fn test_tensor_reduce_sum_1d() { + // Test case: Reduce sum along axis 0 for a 1D tensor (full reduction) + let input_data: Array = array![1, 2, 3, 4, 5]; + + let input = Tensor { data: input_data.span() }; + + let result = tensor_reduce_sum_1d(input); + + let expected = array![15]; // [1+2+3+4+5] + assert_eq!(result.data.len(), expected.len(), "Incorrect result length"); + assert_eq!(*result.data.at(0), *expected[0], "Incorrect sum"); + } #[test] #[available_gas(20000000)] @@ -79,12 +89,12 @@ mod tests { output_indices: output_indices.span(), output_size: 2, }; - let result = tensor_reduce_sum(input, ref metadata); + let mut result = tensor_reduce_sum_nd(input, ref metadata); let expected = array![6, 15]; // [1+2+3, 4+5+6] assert_eq!(result.data.len(), expected.len(), "Incorrect result length"); - assert_eq!(*result.data.at(0), *expected[0], "Incorrect first sum"); - assert_eq!(*result.data.at(1), *expected[1], "Incorrect second sum"); + assert_eq!(result.data.at(0), *expected[0], "Incorrect first sum"); + assert_eq!(result.data.at(1), *expected[1], "Incorrect second sum"); } #[test] @@ -99,7 +109,7 @@ mod tests { output_indices: output_indices.span(), output_size: 4, }; - let result = tensor_reduce_sum(input, ref metadata); + let mut result = tensor_reduce_sum_nd(input, ref metadata); let expected = array![6, 8, 10, 12]; // [1+5, 2+6, 3+7, 4+8] assert_eq!(result.data.len(), expected.len(), "Incorrect result length"); @@ -108,7 +118,7 @@ mod tests { if i == expected.len() { break; } - assert_eq!(*result.data.at(i), *expected[i], "Incorrect sum at index"); + assert_eq!(result.data.at(i), *expected[i], "Incorrect sum at index"); i += 1; }; } @@ -125,7 +135,7 @@ mod tests { output_indices: output_indices.span(), output_size: 4, }; - let result = tensor_reduce_sum(input, ref metadata); + let mut result = tensor_reduce_sum_nd(input, ref metadata); let expected = array![9, 12, 27, 30]; // [1+3+5, 2+4+6, 7+9+11, 8+10+12] assert_eq!(result.data.len(), expected.len(), "Incorrect result length"); @@ -134,30 +144,11 @@ mod tests { if i == expected.len() { break; } - assert_eq!(*result.data.at(i), *expected[i], "Incorrect sum at index"); + assert_eq!(result.data.at(i), *expected[i], "Incorrect sum at index"); i += 1; }; } - #[test] - #[available_gas(20000000)] - fn test_tensor_reduce_sum_1d() { - // Test case: Reduce sum along axis 0 for a 1D tensor (full reduction) - let input_data: Array = array![1, 2, 3, 4, 5]; - let output_indices: Array = array![0, 0, 0, 0, 0]; - - let input = Tensor { data: input_data.span() }; - let mut metadata = ReduceOpMetadata { - output_indices: output_indices.span(), output_size: 1, - }; - - let result = tensor_reduce_sum(input, ref metadata); - - let expected = array![15]; // [1+2+3+4+5] - assert_eq!(result.data.len(), expected.len(), "Incorrect result length"); - assert_eq!(*result.data.at(0), *expected[0], "Incorrect sum"); - } - #[test] #[available_gas(20000000)] fn test_tensor_reduce_sum_4d() { @@ -174,7 +165,7 @@ mod tests { output_indices: output_indices.span(), output_size: 8, }; - let result = tensor_reduce_sum(input, ref metadata); + let mut result = tensor_reduce_sum_nd(input, ref metadata); let expected = array![9, 12, 27, 30, 45, 48, 63, 66]; // [1+3+5, 2+4+6, 7+9+11, 8+10+12, 13+15+17, 14+16+18, 19+21+23, 20+22+24] @@ -184,7 +175,7 @@ mod tests { if i == expected.len() { break; } - assert_eq!(*result.data.at(i), *expected[i], "Incorrect sum at index"); + assert_eq!(result.data.at(i), *expected[i], "Incorrect sum at index"); i += 1; }; }