softmax wide compute
raphaelDkhn committed Oct 21, 2023
1 parent 0c89931 commit 99647a2
Showing 16 changed files with 199 additions and 38 deletions.
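In short, this commit threads a new wide flag through NNTrait::softmax. When the flag is set, the FP16x16 and FP8x23 implementations compute the exponentials and the sum reduction in FP32x32 (via the new exp_upcast helper and the Into/TryInto conversions added below), then narrow the quotient back down with div_downcast. A minimal caller-side sketch follows; the import paths, shape, and values are assumptions for illustration, not part of this commit:

use orion::operators::tensor::core::{Tensor, TensorTrait};
use orion::operators::tensor::FP16x16Tensor;
use orion::operators::nn::core::NNTrait;
use orion::operators::nn::implementations::nn_fp16x16::FP16x16NN;
use orion::numbers::FP16x16;
use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16Impl;

fn softmax_wide_example() -> Tensor<FP16x16> {
    let x = TensorTrait::new(
        array![2, 2].span(),
        array![
            FP16x16Impl::new_unscaled(1, false), FP16x16Impl::new_unscaled(2, false),
            FP16x16Impl::new_unscaled(3, false), FP16x16Impl::new_unscaled(4, false)
        ].span()
    );
    // wide = true: exp and reduce_sum run in FP32x32 before dividing back to FP16x16.
    NNTrait::softmax(@x, 1, true)
}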
8 changes: 6 additions & 2 deletions src/numbers.cairo
@@ -48,7 +48,9 @@ trait NumberTrait<T, MAG> {
fn sign(self: T) -> T;
}

use orion::numbers::fixed_point::implementations::fp8x23::core::{FP8x23Impl, FP8x23};
use orion::numbers::fixed_point::implementations::fp8x23::core::{
FP8x23Impl, FP8x23, FP8x23IntoFP32x32
};
use orion::numbers::fixed_point::implementations::fp8x23::math::core as core_fp8x23;
use orion::numbers::fixed_point::implementations::fp8x23::math::comp as comp_fp8x23;

@@ -213,7 +215,9 @@ impl FP8x23Number of NumberTrait<FP8x23, u32> {
}
}

use orion::numbers::fixed_point::implementations::fp16x16::core::{FP16x16Impl, FP16x16};
use orion::numbers::fixed_point::implementations::fp16x16::core::{
FP16x16Impl, FP16x16, FP16x16IntoFP32x32
};
use orion::numbers::fixed_point::implementations::fp16x16::math::core as core_fp16x16;
use orion::numbers::fixed_point::implementations::fp16x16::math::comp as comp_fp16x16;

27 changes: 27 additions & 0 deletions src/numbers/fixed_point/implementations/fp16x16/core.cairo
@@ -5,6 +5,7 @@ use result::{ResultTrait, ResultTraitImpl};
use traits::{TryInto, Into};

use orion::numbers::signed_integer::{i32::i32, i8::i8};
use orion::numbers::{FP32x32, FP32x32Impl};
use orion::numbers::fixed_point::core::FixedTrait;
use orion::numbers::fixed_point::implementations::fp16x16::math::{core, trig, hyp};
use orion::numbers::fixed_point::utils;
@@ -192,6 +193,32 @@ impl FP16x16Print of PrintTrait<FP16x16> {
}
}

impl FP16x16IntoFP32x32 of Into<FP16x16, FP32x32> {
fn into(self: FP16x16) -> FP32x32 {
return FP32x32 { mag: self.mag.into() * 65536_u64, sign: self.sign };
}
}

#[test]
fn test_fp16x16_into_fp32x32() {
let a = FP16x16Impl::new_unscaled(42, true);
let b: FP32x32 = a.into();
assert(b.mag == 180388626432, 'invalid conversion');
}

impl FP32x32TryIntoFP16x16 of TryInto<FP32x32, FP16x16> {
fn try_into(self: FP32x32) -> Option<FP16x16> {
Option::Some(FP16x16 { mag: (self.mag / 65536).try_into().unwrap(), sign: self.sign })
}
}

#[test]
fn test_fp32x32_try_into_fp16x16() {
let a = FP32x32Impl::new_unscaled(42, true);
let b: FP16x16 = a.try_into().unwrap();
assert(b == FP16x16Impl::new_unscaled(42, true), 'invalid conversion');
}
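
// A quick check on the constants above: FP16x16 carries 16 fractional bits and
// FP32x32 carries 32, so widening multiplies the raw magnitude by 2^(32-16) = 65536.
// new_unscaled(42, true) has mag 42 * 2^16 = 2752512, and 2752512 * 65536 =
// 180388626432 = 42 * 2^32, exactly what the test asserts. The TryInto direction
// divides mag by 65536, truncating any bits below 2^-16, and its u64 -> u32 unwrap
// succeeds only for FP32x32 values with magnitude under 2^16.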

// Into a raw felt without unscaling
impl FP16x16IntoFelt252 of Into<FP16x16, felt252> {
fn into(self: FP16x16) -> felt252 {
27 changes: 27 additions & 0 deletions src/numbers/fixed_point/implementations/fp8x23/core.cairo
@@ -5,6 +5,7 @@ use result::{ResultTrait, ResultTraitImpl};
use traits::{TryInto, Into};

use orion::numbers::signed_integer::{i32::i32, i8::i8};
use orion::numbers::{FP32x32, FP32x32Impl};
use orion::numbers::fixed_point::core::{FixedTrait};
use orion::numbers::fixed_point::implementations::fp8x23::math::{core, trig, hyp};
use orion::numbers::fixed_point::utils;
@@ -192,6 +193,19 @@ impl FP8x23Print of PrintTrait<FP8x23> {
}
}

impl FP8x23IntoFP32x32 of Into<FP8x23, FP32x32> {
fn into(self: FP8x23) -> FP32x32 {
return FP32x32 { mag: self.mag.into() * 512_u64, sign: self.sign };
}
}

#[test]
fn test_fp8x23_into_fp32x32() {
let a = FP8x23Impl::new_unscaled(42, true);
let b: FP32x32 = a.into();
assert(b.mag == 180388626432, 'invalid conversion');
}

// Into a raw felt without unscaling
impl FP8x23IntoFelt252 of Into<FP8x23, felt252> {
fn into(self: FP8x23) -> felt252 {
@@ -205,6 +219,19 @@ impl FP8x23IntoFelt252 of Into<FP8x23, felt252> {
}
}

impl FP32x32TryIntoFP8x23 of TryInto<FP32x32, FP8x23> {
fn try_into(self: FP32x32) -> Option<FP8x23> {
Option::Some(FP8x23 { mag: (self.mag / 512).try_into().unwrap(), sign: self.sign })
}
}

#[test]
fn test_fp32x32_try_into_fp8x23() {
let a = FP32x32Impl::new_unscaled(42, true);
let b: FP8x23 = a.try_into().unwrap();
assert(b == FP8x23Impl::new_unscaled(42, true), 'invalid conversion');
}
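
// The same arithmetic applies here with 23 fractional bits: the widening factor is
// 2^(32-23) = 512, and 42 * 2^23 * 512 = 42 * 2^32 = 180388626432, matching the
// FP16x16 test. Narrowing divides by 512, so the u64 -> u32 unwrap succeeds only
// for FP32x32 values with magnitude under 2^9 = 512.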

impl FP8x23TryIntoU128 of TryInto<FP8x23, u128> {
fn try_into(self: FP8x23) -> Option<u128> {
if self.sign {
2 changes: 1 addition & 1 deletion src/operators/nn/core.cairo
@@ -111,7 +111,7 @@ trait NNTrait<T> {
/// // [[0.2689, 0.7311],[0.2689, 0.7311]]
/// ```
///
fn softmax(tensor: @Tensor<T>, axis: usize) -> Tensor<T>;
fn softmax(tensor: @Tensor<T>, axis: usize, wide: bool) -> Tensor<T>;
/// # NNTrait::logsoftmax
///
/// ```rust
31 changes: 25 additions & 6 deletions src/operators/nn/functional/softmax.cairo
@@ -1,20 +1,39 @@
use orion::operators::tensor::core::{Tensor, TensorTrait};

use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast};
use orion::numbers::fixed_point::core::FixedTrait;

/// Cf: NNTrait::softmax docstring
fn softmax<
T,
TMAG,
U,
UMAG,
impl UFixedTrait: FixedTrait<U, UMAG>,
impl TTensor: TensorTrait<T>,
impl UTensor: TensorTrait<U>,
impl TTensorDiv: Div<Tensor<T>>,
impl TDiv: Div<T>,
impl TIntoU: Into<T, U>,
    impl UTryIntoT: TryInto<U, T>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>,
impl UCopy: Copy<U>,
impl UDrop: Drop<U>,
>(
z: @Tensor<T>, axis: usize
z: @Tensor<T>, axis: usize, wide: bool
) -> Tensor<T> {
let exp_tensor = z.exp();
let sum = exp_tensor.reduce_sum(axis, true);
let softmax = exp_tensor / sum;
if wide {
let exp_tensor: Tensor<U> = exp_upcast::<T, TMAG, U, UMAG,>(*z);
let sum = exp_tensor.reduce_sum(axis, true);
let softmax: Tensor<T> = div_downcast(@exp_tensor, @sum);

return softmax;
} else {
let exp_tensor = z.exp();
let sum = exp_tensor.reduce_sum(axis, true);
let softmax = exp_tensor / sum;

return softmax;
return softmax;
}
}
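
// The wide branch mirrors the narrow one step for step: exp_upcast converts each
// element from T to U before exponentiating, reduce_sum accumulates in U, and
// div_downcast narrows numerator and denominator back to T before dividing. Note
// that, as written, the extra width protects the individual exponentials and the
// accumulated sum, but since div_downcast converts its operands before the division,
// the reduced sum must still be representable in T at the end.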

10 changes: 4 additions & 6 deletions src/operators/nn/implementations/nn_fp16x16.cairo
@@ -3,10 +3,8 @@ use core::option::OptionTrait;
use orion::operators::tensor::core::Tensor;
use orion::operators::nn::core::NNTrait;
use orion::operators::nn::functional;
use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16;
use orion::operators::tensor::implementations::tensor_fp16x16::{
FP16x16Tensor, FP16x16TensorDiv, FP16x16TensorAdd
};
use orion::numbers::{FP16x16, FP16x16IntoFP32x32, FP32x32, FP32x32Impl};
use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorDiv, FP16x16TensorAdd, FP32x32Tensor};

impl FP16x16NN of NNTrait<FP16x16> {
fn relu(tensor: @Tensor<FP16x16>) -> Tensor<FP16x16> {
@@ -17,8 +15,8 @@ impl FP16x16NN of NNTrait<FP16x16> {
functional::sigmoid::sigmoid(*tensor)
}

fn softmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
functional::softmax::softmax(tensor, axis)
fn softmax(tensor: @Tensor<FP16x16>, axis: usize, wide: bool) -> Tensor<FP16x16> {
functional::softmax::softmax::<FP16x16, u32, FP32x32, u64>(tensor, axis, wide)
}

fn logsoftmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
5 changes: 3 additions & 2 deletions src/operators/nn/implementations/nn_fp32x32.cairo
@@ -17,8 +17,9 @@ impl FP32x32NN of NNTrait<FP32x32> {
functional::sigmoid::sigmoid(*tensor)
}

fn softmax(tensor: @Tensor<FP32x32>, axis: usize) -> Tensor<FP32x32> {
functional::softmax::softmax(tensor, axis)
fn softmax(tensor: @Tensor<FP32x32>, axis: usize, wide: bool) -> Tensor<FP32x32> {
// functional::softmax::softmax(tensor, axis)
panic(array!['not supported!'])
}

fn logsoftmax(tensor: @Tensor<FP32x32>, axis: usize) -> Tensor<FP32x32> {
5 changes: 3 additions & 2 deletions src/operators/nn/implementations/nn_fp64x64.cairo
@@ -17,8 +17,9 @@ impl FP64x64NN of NNTrait<FP64x64> {
functional::sigmoid::sigmoid(*tensor)
}

fn softmax(tensor: @Tensor<FP64x64>, axis: usize) -> Tensor<FP64x64> {
functional::softmax::softmax(tensor, axis)
fn softmax(tensor: @Tensor<FP64x64>, axis: usize, wide: bool) -> Tensor<FP64x64> {
// functional::softmax::softmax(tensor, axis)
panic(array!['not supported!'])
}

fn logsoftmax(tensor: @Tensor<FP64x64>, axis: usize) -> Tensor<FP64x64> {
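Both the FP32x32 and FP64x64 implementations now stub softmax out with a panic, leaving the old call commented out, presumably because the wide path would require an upcast target even wider than these types and none is wired up in this commit.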
12 changes: 7 additions & 5 deletions src/operators/nn/implementations/nn_fp8x23.cairo
@@ -3,11 +3,13 @@ use core::option::OptionTrait;
use orion::operators::tensor::core::Tensor;
use orion::operators::nn::core::NNTrait;
use orion::operators::nn::functional;
use orion::numbers::fixed_point::implementations::fp8x23::core::FP8x23;
use orion::operators::tensor::implementations::tensor_fp8x23::{
FP8x23Tensor, FP8x23TensorDiv, FP8x23TensorAdd
use orion::numbers::{FP8x23, FP8x23IntoFP32x32, FP32x32, FP32x32Impl};
use orion::operators::tensor::{
FP8x23Tensor, FP8x23TensorDiv, FP8x23TensorAdd, FP32x32Tensor
};



impl FP8x23NN of NNTrait<FP8x23> {
fn relu(tensor: @Tensor<FP8x23>) -> Tensor<FP8x23> {
functional::relu::relu(*tensor)
@@ -17,8 +19,8 @@ impl FP8x23NN of NNTrait<FP8x23> {
functional::sigmoid::sigmoid(*tensor)
}

fn softmax(tensor: @Tensor<FP8x23>, axis: usize) -> Tensor<FP8x23> {
functional::softmax::softmax(tensor, axis)
fn softmax(tensor: @Tensor<FP8x23>, axis: usize, wide: bool) -> Tensor<FP8x23> {
functional::softmax::softmax::<FP8x23, u32, FP32x32, u64>(tensor, axis, wide)
}

fn logsoftmax(tensor: @Tensor<FP8x23>, axis: usize) -> Tensor<FP8x23> {
2 changes: 1 addition & 1 deletion src/operators/nn/implementations/nn_i32.cairo
@@ -15,7 +15,7 @@ impl I32NN of NNTrait<i32> {
panic(array!['not supported!'])
}

fn softmax(tensor: @Tensor<i32>, axis: usize) -> Tensor<i32> {
fn softmax(tensor: @Tensor<i32>, axis: usize, wide: bool) -> Tensor<i32> {
panic(array!['not supported!'])
}

2 changes: 1 addition & 1 deletion src/operators/nn/implementations/nn_i8.cairo
@@ -15,7 +15,7 @@ impl I8NN of NNTrait<i8> {
panic(array!['not supported!'])
}

fn softmax(tensor: @Tensor<i8>, axis: usize) -> Tensor<i8> {
fn softmax(tensor: @Tensor<i8>, axis: usize, wide: bool) -> Tensor<i8> {
panic(array!['not supported!'])
}

2 changes: 1 addition & 1 deletion src/operators/nn/implementations/nn_u32.cairo
@@ -14,7 +14,7 @@ impl U32NN of NNTrait<u32> {
panic(array!['not supported!'])
}

fn softmax(tensor: @Tensor<u32>, axis: usize) -> Tensor<u32> {
fn softmax(tensor: @Tensor<u32>, axis: usize, wide: bool) -> Tensor<u32> {
panic(array!['not supported!'])
}

10 changes: 10 additions & 0 deletions src/operators/tensor.cairo
@@ -17,6 +17,16 @@ use orion::operators::tensor::implementations::tensor_fp16x16::{
FP16x16TensorPartialEq,
};

use orion::operators::tensor::implementations::tensor_fp32x32::{
FP32x32Tensor, FP32x32TensorAdd, FP32x32TensorSub, FP32x32TensorMul, FP32x32TensorDiv,
FP32x32TensorPartialEq,
};

use orion::operators::tensor::implementations::tensor_fp64x64::{
FP64x64Tensor, FP64x64TensorAdd, FP64x64TensorSub, FP64x64TensorMul, FP64x64TensorDiv,
FP64x64TensorPartialEq,
};

use orion::operators::tensor::implementations::tensor_i8::{
I8Tensor, I8TensorAdd, I8TensorSub, I8TensorMul, I8TensorDiv, I8TensorPartialEq,
};
43 changes: 42 additions & 1 deletion src/operators/tensor/math/arithmetic.cairo
@@ -232,7 +232,7 @@ fn saturated_mul<
}

fn div<
T, impl TTensor: TensorTrait<T>, impl TMul: Div<T>, impl TCopy: Copy<T>, impl TDrop: Drop<T>
T, impl TTensor: TensorTrait<T>, impl TDiv: Div<T>, impl TCopy: Copy<T>, impl TDrop: Drop<T>
>(
self: @Tensor<T>, other: @Tensor<T>
) -> Tensor<T> {
@@ -304,3 +304,44 @@ fn saturated_div<

return TensorTrait::<Q>::new(broadcasted_shape, result.span());
}

fn div_downcast<
T,
D,
impl DTensor: TensorTrait<D>,
impl TTensor: TensorTrait<T>,
impl DDiv: Div<D>,
impl TTryIntoD: TryInto<T, D>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>,
impl DCopy: Copy<D>,
impl DDrop: Drop<D>
>(
self: @Tensor<T>, other: @Tensor<T>
) -> Tensor<D> {
let broadcasted_shape = broadcast_shape(*self.shape, *other.shape);
let mut result = ArrayTrait::new();

let num_elements = len_from_shape(broadcasted_shape);

let mut n: usize = 0;
loop {
let indices_broadcasted = unravel_index(n, broadcasted_shape);

let indices_self = broadcast_index_mapping(*self.shape, indices_broadcasted);
let indices_other = broadcast_index_mapping(*other.shape, indices_broadcasted);

result
.append(
TTryIntoD::try_into(*(*self.data)[indices_self]).unwrap()
/ TTryIntoD::try_into(*(*other.data)[indices_other]).unwrap()
);

n += 1;
if n == num_elements {
break ();
};
};

return TensorTrait::<D>::new(broadcasted_shape, result.span());
}
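
// div_downcast follows the same broadcasting scheme as div above; the only
// difference is that each operand is converted from T to D before the element-wise
// division. A hypothetical call, with operand names assumed:
//
//     // Divide two FP32x32 tensors, producing an FP16x16 quotient in one pass.
//     // Panics if a narrowed operand does not fit in FP16x16.
//     let q: Tensor<FP16x16> = div_downcast(@numerator_fp32x32, @denominator_fp32x32);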
39 changes: 35 additions & 4 deletions src/operators/tensor/math/exp.cairo
@@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait};
fn exp<
T,
MAG,
impl FFixedTrait: FixedTrait<T, MAG>,
impl FTensor: TensorTrait<T>,
impl FCopy: Copy<T>,
impl FDrop: Drop<T>,
impl TFixedTrait: FixedTrait<T, MAG>,
impl TTensor: TensorTrait<T>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>,
>(
mut self: Tensor<T>
) -> Tensor<T> {
@@ -34,3 +34,34 @@

return TensorTrait::new(self.shape, result.span());
}

fn exp_upcast<
T,
TMAG,
U,
UMAG,
impl UFixedTrait: FixedTrait<U, UMAG>,
impl UTensor: TensorTrait<U>,
impl TIntoU: Into<T, U>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>,
impl UCopy: Copy<U>,
impl UDrop: Drop<U>,
>(
mut self: Tensor<T>
) -> Tensor<U> {
let mut result: Array<U> = ArrayTrait::new();

loop {
match self.data.pop_front() {
Option::Some(item) => {
result.append(TIntoU::into(*item).exp());
},
Option::None(_) => {
break;
}
};
};

return TensorTrait::new(self.shape, result.span());
}
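
// exp_upcast is identical to exp above except that each element is converted from
// T to U before FixedTrait::exp is applied, so the result tensor is in the wider
// type. A hypothetical call matching how the wide softmax path instantiates it,
// with x: Tensor<FP16x16> assumed:
//
//     // Exponentiate with FP32x32 range and precision; consumes x.
//     let wide_exp: Tensor<FP32x32> = exp_upcast::<FP16x16, u32, FP32x32, u64>(x);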