diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo
index 528856265..fdbb7054f 100644
--- a/src/operators/nn/functional/softmax.cairo
+++ b/src/operators/nn/functional/softmax.cairo
@@ -1,5 +1,6 @@
 use orion::operators::tensor::core::{Tensor, TensorTrait};
-
+use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast};
+use orion::numbers::fixed_point::core::FixedTrait;
 
 /// Cf: NNTrait::softmax docstring
 fn softmax<
@@ -19,3 +20,30 @@ fn softmax<
 
     return softmax;
 }
+
+/// Cf: NNTrait::softmax docstring
+fn softmaxWide<
+    T,
+    TMAG,
+    W,
+    WMAG,
+    impl TTensor: TensorTrait<T>,
+    impl WTensor: TensorTrait<W>,
+    impl TDiv: Div<T>,
+    impl TIntoW: Into<T, W>,
+    impl WTryIntoT: TryInto<W, T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>,
+    impl WCopy: Copy<W>,
+    impl WDrop: Drop<W>,
+    impl TFixed: FixedTrait<T, TMAG>,
+    impl WFixed: FixedTrait<W, WMAG>,
+>(
+    z: @Tensor<T>, axis: usize
+) -> Tensor<T> {
+    let exp_tensor: Tensor<W> = exp_upcast(*z);
+    let sum = exp_tensor.reduce_sum(axis, true);
+    let softmax: Tensor<T> = div_downcast(@exp_tensor, @sum);
+
+    return softmax;
+}
diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo
index a0094de29..b940d8742 100644
--- a/src/operators/nn/implementations/nn_fp16x16.cairo
+++ b/src/operators/nn/implementations/nn_fp16x16.cairo
@@ -7,6 +7,12 @@ use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16;
 use orion::operators::tensor::implementations::tensor_fp16x16::{
     FP16x16Tensor, FP16x16TensorDiv, FP16x16TensorAdd
 };
+use orion::numbers::fixed_point::implementations::fp16x16wide::core::{
+    FP16x16WImpl, FP16x16WTryIntoFP16x16, FP16x16W, FP16x16IntoFP16x16W
+};
+use orion::operators::tensor::implementations::tensor_fp16x16wide::{
+    FP16x16WTensor, FP16x16WTensorDiv, FP16x16WTensorAdd
+};
 
 impl FP16x16NN of NNTrait<FP16x16> {
     fn relu(tensor: @Tensor<FP16x16>) -> Tensor<FP16x16> {
@@ -18,7 +24,7 @@ impl FP16x16NN of NNTrait<FP16x16> {
     }
 
     fn softmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
-        functional::softmax::softmax(tensor, axis)
+        functional::softmax::softmaxWide::<FP16x16, u32, FP16x16W, u64>(tensor, axis)
     }
 
     fn logsoftmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
diff --git a/src/operators/tensor/math/arithmetic.cairo b/src/operators/tensor/math/arithmetic.cairo
index 075f565b8..06879a4af 100644
--- a/src/operators/tensor/math/arithmetic.cairo
+++ b/src/operators/tensor/math/arithmetic.cairo
@@ -304,3 +304,44 @@ fn saturated_div<
 
     return TensorTrait::<T>::new(broadcasted_shape, result.span());
 }
+
+fn div_downcast<
+    T,
+    D,
+    impl TTensor: TensorTrait<T>,
+    impl DTensor: TensorTrait<D>,
+    impl DDiv: Div<D>,
+    impl TTryIntoD: TryInto<T, D>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>,
+    impl DCopy: Copy<D>,
+    impl DDrop: Drop<D>
+>(
+    self: @Tensor<T>, other: @Tensor<T>
+) -> Tensor<D> {
+    let broadcasted_shape = broadcast_shape(*self.shape, *other.shape);
+    let mut result = ArrayTrait::new();
+
+    let num_elements = len_from_shape(broadcasted_shape);
+
+    let mut n: usize = 0;
+    loop {
+        let indices_broadcasted = unravel_index(n, broadcasted_shape);
+
+        let indices_self = broadcast_index_mapping(*self.shape, indices_broadcasted);
+        let indices_other = broadcast_index_mapping(*other.shape, indices_broadcasted);
+
+        result
+            .append(
+                (*(*self.data)[indices_self]).try_into().unwrap()
+                    / (*(*other.data)[indices_other]).try_into().unwrap()
+            );
+
+        n += 1;
+        if n == num_elements {
+            break ();
+        };
+    };
+
+    return TensorTrait::<D>::new(broadcasted_shape, result.span());
+}
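Annotation (not part of the patch): `div_downcast` broadcasts the two wide tensors, narrows each operand to `D` with `try_into().unwrap()`, and then divides in the narrow type. Both the numerator and the (broadcast) denominator must therefore already be representable in `D`: the wide pipeline protects `exp` and the in-flight accumulation inside `reduce_sum`, but the reduced sum still has to fit `FP16x16` at the point of division. A minimal direct-call sketch follows; the import paths and impl names (e.g. `FP16x16Div`, `FP16x16WTryIntoFP16x16`) are assumptions about the tree this patch targets, not something the diff itself shows.

    // Sketch only: one FP16x16W division, narrowed back to FP16x16.
    use orion::operators::tensor::core::{Tensor, TensorTrait};
    use orion::operators::tensor::implementations::tensor_fp16x16::FP16x16Tensor;
    use orion::operators::tensor::implementations::tensor_fp16x16wide::FP16x16WTensor;
    use orion::operators::tensor::math::arithmetic::div_downcast;
    use orion::numbers::fixed_point::core::FixedTrait;
    use orion::numbers::fixed_point::implementations::fp16x16::core::{FP16x16, FP16x16Div};
    use orion::numbers::fixed_point::implementations::fp16x16wide::core::{
        FP16x16W, FP16x16WTryIntoFP16x16
    };

    fn div_downcast_demo() -> Tensor<FP16x16> {
        let mut shape = ArrayTrait::new();
        shape.append(1);

        let mut num = ArrayTrait::new();
        num.append(FixedTrait::<FP16x16W, u64>::new_unscaled(6, false));
        let mut den = ArrayTrait::new();
        den.append(FixedTrait::<FP16x16W, u64>::new_unscaled(2, false));

        let a = TensorTrait::<FP16x16W>::new(shape.span(), num.span());
        let b = TensorTrait::<FP16x16W>::new(shape.span(), den.span());

        // 6 / 2 = 3; both narrowed operands fit FP16x16, so the
        // try_into().unwrap() inside div_downcast cannot panic here.
        div_downcast(@a, @b)
    }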
diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo
index 3ba1e97d7..5ba161030 100644
--- a/src/operators/tensor/math/exp.cairo
+++ b/src/operators/tensor/math/exp.cairo
@@ -34,3 +34,37 @@ fn exp<
 
     return TensorTrait::new(self.shape, result.span());
 }
+
+/// Cf: TensorTrait::exp docstring
+fn exp_upcast<
+    T,
+    MAG,
+    W,
+    WMAG,
+    impl TFixedTrait: FixedTrait<T, MAG>,
+    impl TTensor: TensorTrait<T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>,
+    impl WFixedTrait: FixedTrait<W, WMAG>,
+    impl WTensor: TensorTrait<W>,
+    impl WCopy: Copy<W>,
+    impl WDrop: Drop<W>,
+    impl TIntoW: Into<T, W>,
+>(
+    mut self: Tensor<T>
+) -> Tensor<W> {
+    let mut result = ArrayTrait::new();
+
+    loop {
+        match self.data.pop_front() {
+            Option::Some(item) => {
+                result.append((TIntoW::into(*item)).exp());
+            },
+            Option::None(_) => {
+                break;
+            }
+        };
+    };
+
+    return TensorTrait::new(self.shape, result.span());
+}
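Annotation (not part of the patch): taken together, the patch routes the `FP16x16` softmax through a widen → `exp` → `reduce_sum` → divide-and-narrow pipeline. This matters because, assuming Orion's usual `FP16x16` layout (`u32` magnitude, 16 fractional bits, values topping out near 2^15 = 32768), `exp(x)` already saturates around x ≈ 10.4 and a row sum of exponentials saturates sooner still, while the `u64` magnitude of `FP16x16W` gives the intermediates headroom up to roughly 2^47. The final quotients all lie in [0, 1], so narrowing them back to `FP16x16` is safe. A usage sketch, assuming the turbofish in nn_fp16x16.cairo binds `T = FP16x16, TMAG = u32, W = FP16x16W, WMAG = u64` and that these import paths match the tree:

    use orion::operators::tensor::core::{Tensor, TensorTrait};
    use orion::operators::tensor::implementations::tensor_fp16x16::FP16x16Tensor;
    use orion::operators::nn::core::NNTrait;
    use orion::operators::nn::implementations::nn_fp16x16::FP16x16NN;
    use orion::numbers::fixed_point::core::FixedTrait;
    use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16;

    fn softmax_demo() -> Tensor<FP16x16> {
        // 1x3 logits [1, 2, 3]; exp and the row sum run in FP16x16W internally.
        let mut shape = ArrayTrait::new();
        shape.append(1);
        shape.append(3);
        let mut data = ArrayTrait::new();
        data.append(FixedTrait::new_unscaled(1, false));
        data.append(FixedTrait::new_unscaled(2, false));
        data.append(FixedTrait::new_unscaled(3, false));
        let logits = TensorTrait::<FP16x16>::new(shape.span(), data.span());
        // Each row of the result sums to ~1.0 in FP16x16.
        NNTrait::softmax(@logits, 1)
    }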