implement softmaxWide
raphaelDkhn committed Oct 22, 2023
1 parent 53181cf commit fc98e82
Showing 4 changed files with 111 additions and 2 deletions.
30 changes: 29 additions & 1 deletion src/operators/nn/functional/softmax.cairo
@@ -1,5 +1,6 @@
use orion::operators::tensor::core::{Tensor, TensorTrait};

use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast};
use orion::numbers::fixed_point::core::FixedTrait;

/// Cf: NNTrait::softmax docstring
fn softmax<
@@ -19,3 +20,30 @@ fn softmax<
    return softmax;
}

/// Cf: NNTrait::softmax docstring
fn softmaxWide<
    T,
    TMAG,
    W,
    WMAG,
    impl TTensor: TensorTrait<T>,
    impl WTensor: TensorTrait<W>,
    impl TDiv: Div<T>,
    impl TIntoW: Into<T, W>,
    impl WTryIntoT: TryInto<W, T>,
    impl TCopy: Copy<T>,
    impl TDrop: Drop<T>,
    impl WCopy: Copy<W>,
    impl WDrop: Drop<W>,
    impl TFixed: FixedTrait<T, TMAG>,
    impl WFixed: FixedTrait<W, WMAG>,
>(
    z: @Tensor<T>, axis: usize
) -> Tensor<T> {
    // Exponentiate into the wider fixed-point type W so the intermediate values
    // and their sum have more headroom than T provides.
    let exp_tensor: Tensor<W> = exp_upcast(*z);
    let sum = exp_tensor.reduce_sum(axis, true);
    // Narrow each operand back down and divide element-wise, returning a Tensor<T>.
    let softmax: Tensor<T> = div_downcast(@exp_tensor, @sum);

    return softmax;
}
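
Both the existing softmax and the new softmaxWide compute the usual definition, reduced along the requested axis:

\[
\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}
\]

The wide variant only changes where the intermediate work happens: exp_upcast exponentiates into the wider fixed-point type W, the sum is accumulated there, and div_downcast casts the operands back down to T as the quotient is formed, so the exponentials and their running sum get the extra range of W.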

8 changes: 7 additions & 1 deletion src/operators/nn/implementations/nn_fp16x16.cairo
@@ -7,6 +7,12 @@ use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16
use orion::operators::tensor::implementations::tensor_fp16x16::{
    FP16x16Tensor, FP16x16TensorDiv, FP16x16TensorAdd
};
use orion::numbers::fixed_point::implementations::fp16x16wide::core::{
    FP16x16WImpl, FP16x16WTryIntoFP16x16, FP16x16W, FP16x16IntoFP16x16W
};
use orion::operators::tensor::implementations::tensor_fp16x16wide::{
    FP16x16WTensor, FP16x16WTensorDiv, FP16x16WTensorAdd
};

impl FP16x16NN of NNTrait<FP16x16> {
    fn relu(tensor: @Tensor<FP16x16>) -> Tensor<FP16x16> {
@@ -18,7 +24,7 @@ impl FP16x16NN of NNTrait<FP16x16> {
    }

    fn softmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
-        functional::softmax::softmax(tensor, axis)
+        functional::softmax::softmaxWide::<FP16x16, u32, FP16x16W, u64>(tensor, axis)
    }

    fn logsoftmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
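
Callers do not change: NNTrait::softmax keeps its signature, and the FP16x16 implementation now dispatches to softmaxWide with FP16x16W (u64 magnitude) as the working type. Below is a minimal usage sketch; softmax_example is a hypothetical helper, the import paths are inferred from the file layout in this commit, and details such as the exact TensorTrait::new arguments and the FixedTrait::new_unscaled constructor have varied between Orion releases, so treat it as a sketch rather than the library's documented example.

use array::ArrayTrait;
use orion::operators::tensor::core::{Tensor, TensorTrait};
use orion::operators::tensor::implementations::tensor_fp16x16::FP16x16Tensor;
use orion::operators::nn::core::NNTrait;
use orion::operators::nn::implementations::nn_fp16x16::FP16x16NN;
use orion::numbers::fixed_point::core::FixedTrait;
use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16;

fn softmax_example() -> Tensor<FP16x16> {
    // Build a 2x2 tensor holding the fixed-point values [[0, 1], [2, 3]].
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(2);
    shape.append(2);

    let mut data = ArrayTrait::<FP16x16>::new();
    data.append(FixedTrait::new_unscaled(0, false));
    data.append(FixedTrait::new_unscaled(1, false));
    data.append(FixedTrait::new_unscaled(2, false));
    data.append(FixedTrait::new_unscaled(3, false));

    let tensor = TensorTrait::<FP16x16>::new(shape.span(), data.span());

    // Softmax along axis 1 normalizes each row; the exponentials and their sums
    // are computed internally in FP16x16W before being cast back to FP16x16.
    NNTrait::softmax(@tensor, 1)
}
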
41 changes: 41 additions & 0 deletions src/operators/tensor/math/arithmetic.cairo
@@ -304,3 +304,44 @@ fn saturated_div<

    return TensorTrait::<Q>::new(broadcasted_shape, result.span());
}

fn div_downcast<
    T,
    D,
    impl TTensor: TensorTrait<T>,
    impl DTensor: TensorTrait<D>,
    impl DDiv: Div<D>,
    impl TTryIntoD: TryInto<T, D>,
    impl TCopy: Copy<T>,
    impl TDrop: Drop<T>,
    impl DCopy: Copy<D>,
    impl DDrop: Drop<D>
>(
    self: @Tensor<T>, other: @Tensor<T>
) -> Tensor<D> {
    let broadcasted_shape = broadcast_shape(*self.shape, *other.shape);
    let mut result = ArrayTrait::new();

    let num_elements = len_from_shape(broadcasted_shape);

    let mut n: usize = 0;
    loop {
        let indices_broadcasted = unravel_index(n, broadcasted_shape);

        let indices_self = broadcast_index_mapping(*self.shape, indices_broadcasted);
        let indices_other = broadcast_index_mapping(*other.shape, indices_broadcasted);

        // Narrow both operands from T to D before dividing; the unwrap panics if
        // a value does not fit in the narrower type.
        result
            .append(
                (*(*self.data)[indices_self]).try_into().unwrap()
                    / (*(*other.data)[indices_other]).try_into().unwrap()
            );

        n += 1;
        if n == num_elements {
            break ();
        };
    };

    return TensorTrait::<D>::new(broadcasted_shape, result.span());
}
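
As used by softmaxWide, self is the exponentiated tensor and other is its reduce_sum with keepdims, so the size-1 reduced axis broadcasts back over the original shape. For a 2x2 input reduced along axis 1, the element-wise quotient works out to

\[
\begin{pmatrix} e^{z_{00}} & e^{z_{01}} \\ e^{z_{10}} & e^{z_{11}} \end{pmatrix}
\Big/
\begin{pmatrix} e^{z_{00}} + e^{z_{01}} \\ e^{z_{10}} + e^{z_{11}} \end{pmatrix}
=
\begin{pmatrix}
\dfrac{e^{z_{00}}}{e^{z_{00}} + e^{z_{01}}} & \dfrac{e^{z_{01}}}{e^{z_{00}} + e^{z_{01}}} \\[4pt]
\dfrac{e^{z_{10}}}{e^{z_{10}} + e^{z_{11}}} & \dfrac{e^{z_{11}}}{e^{z_{10}} + e^{z_{11}}}
\end{pmatrix}
\]

with every operand narrowed from T (the wide type here) to the result type D immediately before each division.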
34 changes: 34 additions & 0 deletions src/operators/tensor/math/exp.cairo
@@ -34,3 +34,37 @@ fn exp<

    return TensorTrait::new(self.shape, result.span());
}

/// Cf: TensorTrait::exp docstring
fn exp_upcast<
    T,
    MAG,
    W,
    WMAG,
    impl TFixedTrait: FixedTrait<T, MAG>,
    impl TTensor: TensorTrait<T>,
    impl TCopy: Copy<T>,
    impl TDrop: Drop<T>,
    impl WFixedTrait: FixedTrait<W, WMAG>,
    impl WTensor: TensorTrait<W>,
    impl WCopy: Copy<W>,
    impl WDrop: Drop<W>,
    impl TIntoW: Into<T, W>,
>(
    mut self: Tensor<T>
) -> Tensor<W> {
    let mut result = ArrayTrait::new();

    loop {
        match self.data.pop_front() {
            Option::Some(item) => {
                // Widen each element from T to W before exponentiating, so the
                // result is stored in the wider type rather than overflowing T.
                result.append((TIntoW::into(*item)).exp());
            },
            Option::None(_) => {
                break;
            }
        };
    };

    return TensorTrait::new(self.shape, result.span());
}
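
The upcast is what gives softmaxWide its headroom. FP16x16 stores a 16.16 fixed-point magnitude in a u32, so its largest representable value is about 2^16 = 65536, and a single exponential e^z already exceeds that once z passes ln(2^16). Assuming FP16x16W keeps the same 16 fractional bits in a u64 magnitude (which is what the u64 type parameter in nn_fp16x16.cairo suggests), the wide type reaches roughly 2^48 ≈ 2.8 × 10^14, moving the overflow threshold out accordingly:

\[
\ln(2^{16}) \approx 11.09 \qquad\longrightarrow\qquad \ln(2^{48}) \approx 33.27
\]

and leaving the sum over an axis the same extra room before it is divided and cast back down.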
